1use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10use thiserror::Error;
11
12use rig_compose::{KernelError, Tool, ToolSchema};
13
14use crate::trace::ResourceTraceEnvelope;
15
// Identity stamped on every envelope built by `baseline_compare_trace_envelope`.
const TRACE_RESOURCE: &str = "baseline";
const TRACE_OPERATION: &str = "compare";
const TRACE_KIND: &str = "baseline_compare";

/// Trace reason recorded when no baseline exists for the requested entity/metric.
pub const TRACE_REASON_NOT_FOUND: &str = "baseline_not_found";
/// Trace reason recorded when the observed value deviates from the baseline mean by no
/// more than the comparison bound (`k` standard deviations, floored at `f64::EPSILON`).
pub const TRACE_REASON_WITHIN_BOUNDS: &str = "within_bounds";
/// Trace reason recorded when the observed value deviates from the mean by more than the bound.
pub const TRACE_REASON_EXCEEDS_BOUNDS: &str = "exceeds_bounds";
27
28#[derive(Debug, Error)]
29pub enum BaselineError {
30 #[error("baseline `{entity}/{metric}` not found")]
31 NotFound { entity: String, metric: String },
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct EntityBaseline {
37 pub entity: String,
38 pub metric: String,
39 pub mean: f64,
40 pub std_dev: f64,
41 pub samples: u64,
42}
43
44impl EntityBaseline {
45 pub fn from_stats(
46 entity: impl Into<String>,
47 metric: impl Into<String>,
48 stats: &OnlineStats,
49 ) -> Self {
50 Self {
51 entity: entity.into(),
52 metric: metric.into(),
53 mean: stats.mean(),
54 std_dev: stats.std_dev(),
55 samples: stats.count(),
56 }
57 }
58
59 pub fn within(&self, value: f64, k: f64) -> bool {
60 let bound = (k * self.std_dev).max(f64::EPSILON);
61 (value - self.mean).abs() <= bound
62 }
63}
64
65#[derive(Debug, Clone, Default, Serialize, Deserialize)]
70pub struct OnlineStats {
71 count: u64,
72 mean: f64,
73 m2: f64,
74}
75
76impl OnlineStats {
77 pub fn new() -> Self {
78 Self::default()
79 }
80
81 pub fn push(&mut self, value: f64) {
82 self.count = self.count.saturating_add(1);
83 let delta = value - self.mean;
84 self.mean += delta / self.count as f64;
85 let delta2 = value - self.mean;
86 self.m2 += delta * delta2;
87 }
88
89 pub fn count(&self) -> u64 {
90 self.count
91 }
92
93 pub fn is_empty(&self) -> bool {
94 self.count == 0
95 }
96
97 pub fn mean(&self) -> f64 {
98 self.mean
99 }
100
101 pub fn variance(&self) -> f64 {
103 if self.count < 2 {
104 0.0
105 } else {
106 self.m2 / (self.count - 1) as f64
107 }
108 }
109
110 pub fn std_dev(&self) -> f64 {
111 self.variance().sqrt()
112 }
113
114 pub fn to_baseline(
115 &self,
116 entity: impl Into<String>,
117 metric: impl Into<String>,
118 ) -> EntityBaseline {
119 EntityBaseline::from_stats(entity, metric, self)
120 }
121}
122
123#[async_trait]
124pub trait BaselineStore: Send + Sync {
125 async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError>;
126 async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError>;
127 async fn contains(&self, entity: &str, metric: &str) -> bool;
128}
129
130#[derive(Clone, Default)]
131pub struct InMemoryBaselineStore {
132 inner: Arc<RwLock<HashMap<(String, String), EntityBaseline>>>,
133}
134
135impl InMemoryBaselineStore {
136 pub fn new() -> Self {
137 Self::default()
138 }
139 pub fn arc() -> Arc<Self> {
140 Arc::new(Self::new())
141 }
142 pub fn len(&self) -> usize {
143 self.inner.read().len()
144 }
145 pub fn is_empty(&self) -> bool {
146 self.inner.read().is_empty()
147 }
148}
149
150#[async_trait]
151impl BaselineStore for InMemoryBaselineStore {
152 async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError> {
153 self.inner
154 .write()
155 .insert((baseline.entity.clone(), baseline.metric.clone()), baseline);
156 Ok(())
157 }
158 async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError> {
159 self.inner
160 .read()
161 .get(&(entity.to_string(), metric.to_string()))
162 .cloned()
163 .ok_or_else(|| BaselineError::NotFound {
164 entity: entity.to_string(),
165 metric: metric.to_string(),
166 })
167 }
168 async fn contains(&self, entity: &str, metric: &str) -> bool {
169 self.inner
170 .read()
171 .contains_key(&(entity.to_string(), metric.to_string()))
172 }
173}
174
175pub struct BaselineCompareTool {
177 store: Arc<dyn BaselineStore>,
178}
179
180impl BaselineCompareTool {
181 pub const NAME: &'static str = "baseline.compare";
182
183 pub fn new(store: Arc<dyn BaselineStore>) -> Self {
184 Self { store }
185 }
186
187 pub fn arc(store: Arc<dyn BaselineStore>) -> Arc<dyn Tool> {
188 Arc::new(Self::new(store))
189 }
190}
191
192#[async_trait]
193impl Tool for BaselineCompareTool {
194 fn schema(&self) -> ToolSchema {
195 ToolSchema {
196 name: Self::NAME.into(),
197 description:
198 "Compare an observed value to the entity's baseline (mean +/- k*sigma). Returns availability and within-bound flags."
199 .into(),
200 args_schema: json!({
201 "type": "object",
202 "required": ["entity", "metric", "value"],
203 "properties": {
204 "entity": {"type": "string"},
205 "metric": {"type": "string"},
206 "value": {"type": "number"},
207 "k": {"type": "number", "default": 2.0}
208 }
209 }),
210 result_schema: json!({"type": "object"}),
211 }
212 }
213
214 fn name(&self) -> rig_compose::tool::ToolName {
215 Self::NAME.to_string()
216 }
217
218 async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
219 #[derive(serde::Deserialize)]
220 struct Args {
221 entity: String,
222 metric: String,
223 value: f64,
224 #[serde(default = "default_k")]
225 k: f64,
226 }
227 fn default_k() -> f64 {
228 2.0
229 }
230 let parsed: Args = serde_json::from_value(args)?;
231 match self.store.get(&parsed.entity, &parsed.metric).await {
232 Ok(baseline) => Ok(json!({
233 "available": true,
234 "within": baseline.within(parsed.value, parsed.k),
235 "mean": baseline.mean,
236 "std_dev": baseline.std_dev,
237 "k": parsed.k,
238 })),
239 Err(_) => Ok(json!({
240 "available": false,
241 "within": false,
242 "k": parsed.k,
243 })),
244 }
245 }
246}
247
248#[must_use]
279pub fn baseline_compare_trace_envelope(
280 entity: &str,
281 metric: &str,
282 observed: f64,
283 k: f64,
284 baseline: Option<&EntityBaseline>,
285) -> ResourceTraceEnvelope {
286 let input = json!({
287 "entity": entity,
288 "metric": metric,
289 "observed_value": observed,
290 "k": k,
291 });
292
293 let mut envelope = ResourceTraceEnvelope::new(TRACE_RESOURCE, TRACE_OPERATION, TRACE_KIND)
294 .with_input_summary(input);
295
296 match baseline {
297 None => {
298 envelope = envelope
299 .with_output_summary(json!({
300 "available": false,
301 "within": false,
302 }))
303 .with_reason(TRACE_REASON_NOT_FOUND);
304 }
305 Some(baseline) => {
306 let within = baseline.within(observed, k);
307 let bound = (k * baseline.std_dev).max(f64::EPSILON);
308 let deviation = (observed - baseline.mean).abs();
309 envelope = envelope
310 .with_output_summary(json!({
311 "available": true,
312 "within": within,
313 "mean": baseline.mean,
314 "std_dev": baseline.std_dev,
315 "bound": bound,
316 "deviation": deviation,
317 }))
318 .with_reason(if within {
319 TRACE_REASON_WITHIN_BOUNDS
320 } else {
321 TRACE_REASON_EXCEEDS_BOUNDS
322 });
323
324 let mut metadata = json!({
325 "samples": baseline.samples,
326 });
327 if baseline.std_dev > f64::EPSILON
328 && let Some(map) = metadata.as_object_mut()
329 && let Some(z) =
330 serde_json::Number::from_f64((observed - baseline.mean) / baseline.std_dev)
331 {
332 map.insert("z_score".into(), Value::Number(z));
333 }
334 envelope = envelope.with_metadata(metadata);
335 }
336 }
337
338 envelope
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a baseline with a fixed sample count of 100.
    fn baseline(entity: &str, metric: &str, mean: f64, sd: f64) -> EntityBaseline {
        EntityBaseline {
            entity: entity.into(),
            metric: metric.into(),
            mean,
            std_dev: sd,
            samples: 100,
        }
    }

    // Purely synchronous: a plain #[test] avoids building a tokio runtime for nothing.
    #[test]
    fn within_bounds_check() {
        let b = baseline("e", "fanout", 10.0, 2.0);
        assert!(b.within(11.0, 2.0));
        assert!(!b.within(20.0, 2.0));
    }

    #[test]
    fn zero_std_dev_baseline_only_accepts_the_mean() {
        let b = baseline("e", "flat", 5.0, 0.0);
        assert!(b.within(5.0, 2.0));
        assert!(!b.within(5.001, 2.0));
    }

    #[test]
    fn online_stats_builds_entity_baseline() {
        let mut stats = OnlineStats::new();
        for value in [2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
            stats.push(value);
        }
        let baseline = stats.to_baseline("host", "bytes");
        assert_eq!(baseline.samples, 8);
        assert!((baseline.mean - 5.0).abs() < 1e-12);
        // Sum of squared deviations is 32 over n - 1 = 7 degrees of freedom.
        assert!((baseline.std_dev - (32.0_f64 / 7.0).sqrt()).abs() < 1e-12);
    }

    #[test]
    fn online_stats_empty_and_single_sample() {
        let mut stats = OnlineStats::new();
        assert!(stats.is_empty());
        assert_eq!(stats.count(), 0);
        assert_eq!(stats.variance(), 0.0);

        stats.push(3.5);
        assert!(!stats.is_empty());
        assert_eq!(stats.count(), 1);
        assert!((stats.mean() - 3.5).abs() < 1e-12);
        // A single observation has no spread under the n - 1 convention.
        assert_eq!(stats.std_dev(), 0.0);
    }

    #[tokio::test]
    async fn store_put_then_get() {
        let store = InMemoryBaselineStore::new();
        store.put(baseline("e", "m", 5.0, 1.0)).await.unwrap();
        let got = store.get("e", "m").await.unwrap();
        assert_eq!(got.samples, 100);
        assert!(store.contains("e", "m").await);
    }

    #[tokio::test]
    async fn store_get_missing_reports_not_found() {
        let store = InMemoryBaselineStore::new();
        assert!(store.is_empty());
        let err = store.get("e", "m").await.unwrap_err();
        assert!(matches!(err, BaselineError::NotFound { .. }));
        assert_eq!(err.to_string(), "baseline `e/m` not found");
        assert!(!store.contains("e", "m").await);
    }

    #[tokio::test]
    async fn store_put_replaces_existing_entry() {
        let store = InMemoryBaselineStore::new();
        store.put(baseline("e", "m", 1.0, 1.0)).await.unwrap();
        store.put(baseline("e", "m", 2.0, 1.0)).await.unwrap();
        assert_eq!(store.len(), 1);
        let got = store.get("e", "m").await.unwrap();
        assert!((got.mean - 2.0).abs() < 1e-12);
    }

    #[tokio::test]
    async fn tool_reports_available_and_within() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        store.put(baseline("e", "m", 100.0, 5.0)).await.unwrap();
        let tool = BaselineCompareTool::new(store);
        let out = tool
            .invoke(json!({"entity": "e", "metric": "m", "value": 102.0, "k": 2.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], true);
        assert_eq!(out["within"], true);
    }

    #[tokio::test]
    async fn tool_reports_unavailable_for_missing_baseline() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        let tool = BaselineCompareTool::new(store);
        let out = tool
            .invoke(json!({"entity": "ghost", "metric": "m", "value": 1.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], false);
        assert_eq!(out["within"], false);
        // `k` was omitted, so the serde default applies.
        assert_eq!(out["k"], 2.0);
    }

    #[test]
    fn trace_envelope_within_bounds_includes_metadata() {
        let b = baseline("host-1", "fanout", 10.0, 2.0);
        let envelope = baseline_compare_trace_envelope("host-1", "fanout", 11.0, 2.0, Some(&b));

        assert_eq!(envelope.version, ResourceTraceEnvelope::VERSION);
        assert_eq!(envelope.resource, "baseline");
        assert_eq!(envelope.operation, "compare");
        assert_eq!(envelope.trace_kind, "baseline_compare");
        assert_eq!(envelope.input_summary["entity"], "host-1");
        assert_eq!(envelope.input_summary["metric"], "fanout");
        let observed = envelope.input_summary["observed_value"].as_f64().unwrap();
        assert!((observed - 11.0).abs() < 1e-9);
        assert_eq!(envelope.output_summary["available"], true);
        assert_eq!(envelope.output_summary["within"], true);
        let mean = envelope.output_summary["mean"].as_f64().unwrap();
        assert!((mean - 10.0).abs() < 1e-9);
        let bound = envelope.output_summary["bound"].as_f64().unwrap();
        assert!((bound - 4.0).abs() < 1e-9);
        assert_eq!(envelope.reason.as_deref(), Some(TRACE_REASON_WITHIN_BOUNDS));
        assert_eq!(envelope.metadata["samples"], 100);
        let z = envelope.metadata["z_score"].as_f64().unwrap();
        assert!((z - 0.5).abs() < 1e-9);
    }

    #[test]
    fn trace_envelope_exceeds_bounds_sets_reason() {
        let b = baseline("host-1", "fanout", 10.0, 2.0);
        let envelope = baseline_compare_trace_envelope("host-1", "fanout", 20.0, 2.0, Some(&b));
        assert_eq!(envelope.output_summary["within"], false);
        assert_eq!(
            envelope.reason.as_deref(),
            Some(TRACE_REASON_EXCEEDS_BOUNDS)
        );
        let deviation = envelope.output_summary["deviation"].as_f64().unwrap();
        assert!((deviation - 10.0).abs() < 1e-9);
    }

    #[test]
    fn trace_envelope_zero_std_dev_omits_z_score() {
        let b = baseline("host-1", "fanout", 10.0, 0.0);
        let envelope = baseline_compare_trace_envelope("host-1", "fanout", 10.0, 2.0, Some(&b));
        assert_eq!(envelope.output_summary["within"], true);
        assert_eq!(envelope.metadata["samples"], 100);
        assert!(envelope.metadata.get("z_score").is_none());
    }

    #[test]
    fn trace_envelope_not_found_omits_baseline_fields() {
        let envelope = baseline_compare_trace_envelope("ghost", "metric", 7.0, 2.0, None);
        assert_eq!(envelope.output_summary["available"], false);
        assert_eq!(envelope.output_summary["within"], false);
        assert!(envelope.output_summary.get("mean").is_none());
        assert_eq!(envelope.reason.as_deref(), Some(TRACE_REASON_NOT_FOUND));
        assert!(envelope.metadata.is_null());
    }
}