use std::time::Duration;
use osproxy_core::{Clock, EndpointKind, IndexName, PartitionId, PrincipalId};
use serde_json::Value;
use crate::directive::{DiagLevel, DiagnosticsDirective, DirectiveMatch, DirectiveSet};
/// Decodes a JSON request body of the form `{"directives": [ ... ]}` into a [`DirectiveSet`].
///
/// Every array element is decoded with `decode_one`, which computes each directive's expiry
/// from `clock.now()`. Decoding stops at the first element that fails.
///
/// # Errors
///
/// - `"invalid_json"` when `body` is not valid JSON.
/// - `"not_an_object"` / `"unknown_field"` when the top level is not an object or carries
///   any key other than `directives`.
/// - `"missing_directives"` when `directives` is absent or is not an array.
/// - Whatever `decode_one` reports for the first invalid element.
pub fn decode_directive_set(body: &[u8], clock: &dyn Clock) -> Result<DirectiveSet, &'static str> {
    let root: Value = match serde_json::from_slice(body) {
        Ok(parsed) => parsed,
        Err(_) => return Err("invalid_json"),
    };
    reject_unknown_keys(&root, &["directives"])?;

    let entries = match root.get("directives").and_then(Value::as_array) {
        Some(list) => list,
        None => return Err("missing_directives"),
    };

    // Collecting into `Result` short-circuits on the first decoding error.
    let decoded = entries
        .iter()
        .map(|entry| decode_one(entry, clock))
        .collect::<Result<Vec<_>, _>>()?;

    Ok(DirectiveSet::from_directives(decoded))
}
/// The complete set of field names a single directive object may carry.
///
/// `decode_one` passes this list to `reject_unknown_keys`, so any other key in a directive
/// object is reported as `"unknown_field"`.
const DIRECTIVE_KEYS: &[&str] = &[
    // Required: identity, verbosity and lifetime.
    "id", "level", "ttl_secs",
    // Optional scope selectors applied to the directive's match.
    "tenant", "index", "principal", "endpoint",
    // Optional behaviour knobs.
    "sample_per_mille", "ring_buffer", "capture",
];
/// Checks that `v` is a JSON object and that every one of its keys appears in `allowed`.
///
/// # Errors
///
/// - `"not_an_object"` when `v` is not a JSON object.
/// - `"unknown_field"` when at least one key is not listed in `allowed`.
fn reject_unknown_keys(v: &Value, allowed: &[&str]) -> Result<(), &'static str> {
    let obj = match v.as_object() {
        Some(map) => map,
        None => return Err("not_an_object"),
    };

    let has_stray_key = obj.keys().any(|key| !allowed.contains(&key.as_str()));
    if has_stray_key {
        Err("unknown_field")
    } else {
        Ok(())
    }
}
/// Decodes a single directive object into a [`DiagnosticsDirective`].
///
/// Fields (any key outside `DIRECTIVE_KEYS` is rejected):
/// - `id` (required string).
/// - `level` (required string accepted by `DiagLevel::from_name`).
/// - `ttl_secs` (required positive integer). The directive's `expires_at` is
///   `clock.now()` plus this many seconds, saturating.
/// - `sample_per_mille` (optional integer in `0..=1000`, default `1000`).
/// - `tenant`, `index`, `principal`, `endpoint` (optional strings). Each one that is present
///   is applied on top of `DirectiveMatch::all()`.
/// - `ring_buffer`, `capture` (optional booleans, default `false`).
///
/// An optional field that is present with the wrong JSON type (including `null`) is an error
/// rather than being ignored. Silently dropping a mistyped scope selector would leave the
/// directive matching every tenant/index/principal instead of the one intended.
///
/// # Errors
///
/// - `"not_an_object"` / `"unknown_field"` from the key check.
/// - `"missing_id"`, `"missing_level"`, `"missing_ttl_secs"` when a required field is absent
///   or has the wrong type; `"unknown_level"` for an unrecognised level name.
/// - `"zero_ttl"` when `ttl_secs` is `0`.
/// - `"bad_sample_rate"` when `sample_per_mille` is not an integer in `0..=1000`.
/// - `"bad_tenant"`, `"bad_index"`, `"bad_principal"`, `"bad_endpoint"` when a scope selector
///   is not a string; `"unknown_endpoint"` for an unrecognised endpoint name.
/// - `"bad_ring_buffer"`, `"bad_capture"` when those flags are not booleans.
fn decode_one(v: &Value, clock: &dyn Clock) -> Result<DiagnosticsDirective, &'static str> {
    reject_unknown_keys(v, DIRECTIVE_KEYS)?;

    let id = v
        .get("id")
        .and_then(Value::as_str)
        .ok_or("missing_id")?
        .to_owned();
    let level_name = v
        .get("level")
        .and_then(Value::as_str)
        .ok_or("missing_level")?;
    let level = DiagLevel::from_name(level_name).ok_or("unknown_level")?;

    let ttl_secs = v
        .get("ttl_secs")
        .and_then(Value::as_u64)
        .ok_or("missing_ttl_secs")?;
    if ttl_secs == 0 {
        return Err("zero_ttl");
    }
    let expires_at = clock.now().saturating_add(Duration::from_secs(ttl_secs));

    // Absent means the default of 1000; anything but an integer in 0..=1000 is rejected.
    let sample_per_mille = match v.get("sample_per_mille") {
        None => 1000,
        Some(n) => n
            .as_u64()
            .filter(|&n| n <= 1000)
            .and_then(|n| u16::try_from(n).ok())
            .ok_or("bad_sample_rate")?,
    };

    let mut match_ = DirectiveMatch::all();
    if let Some(t) = optional_str(v, "tenant", "bad_tenant")? {
        match_ = match_.for_tenant(PartitionId::from(t));
    }
    if let Some(i) = optional_str(v, "index", "bad_index")? {
        match_ = match_.for_index(IndexName::from(i));
    }
    if let Some(p) = optional_str(v, "principal", "bad_principal")? {
        match_ = match_.for_principal(PrincipalId::from(p));
    }
    if let Some(name) = optional_str(v, "endpoint", "bad_endpoint")? {
        let kind = EndpointKind::from_name(name).ok_or("unknown_endpoint")?;
        match_ = match_.for_endpoint(kind);
    }

    let ring_buffer = optional_bool(v, "ring_buffer", "bad_ring_buffer")?.unwrap_or(false);
    let capture = optional_bool(v, "capture", "bad_capture")?.unwrap_or(false);

    Ok(DiagnosticsDirective {
        id,
        match_,
        level,
        sample_per_mille,
        expires_at,
        ring_buffer,
        capture,
    })
}

/// Reads an optional string field: `Ok(None)` if `key` is absent, `Err(err)` if it is present
/// but not a JSON string (including `null`).
fn optional_str<'a>(
    v: &'a Value,
    key: &str,
    err: &'static str,
) -> Result<Option<&'a str>, &'static str> {
    v.get(key).map(|x| x.as_str().ok_or(err)).transpose()
}

/// Reads an optional boolean field: `Ok(None)` if `key` is absent, `Err(err)` if it is present
/// but not a JSON boolean (including `null`).
fn optional_bool(v: &Value, key: &str, err: &'static str) -> Result<Option<bool>, &'static str> {
    v.get(key).map(|x| x.as_bool().ok_or(err)).transpose()
}
#[cfg(test)]
#[path = "decode_tests.rs"]
mod tests;