1use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10use thiserror::Error;
11
12use rig_compose::{KernelError, Tool, ToolSchema};
13
14use crate::trace::ResourceTraceEnvelope;
15
/// Resource label passed to `ResourceTraceEnvelope::new` for every envelope built in this module.
const TRACE_RESOURCE: &str = "baseline";
/// Operation label passed to `ResourceTraceEnvelope::new` for every envelope built in this module.
const TRACE_OPERATION: &str = "compare";
/// Trace-kind label passed to `ResourceTraceEnvelope::new` for every envelope built in this module.
const TRACE_KIND: &str = "baseline_compare";

/// Trace reason recorded when no baseline was supplied for the comparison.
pub const TRACE_REASON_NOT_FOUND: &str = "baseline_not_found";
/// Trace reason recorded when the observed value passes [`EntityBaseline::within`].
pub const TRACE_REASON_WITHIN_BOUNDS: &str = "within_bounds";
/// Trace reason recorded when the observed value fails [`EntityBaseline::within`].
pub const TRACE_REASON_EXCEEDS_BOUNDS: &str = "exceeds_bounds";
27
/// Errors returned by [`BaselineStore`] operations.
#[derive(Debug, Error)]
pub enum BaselineError {
    /// No baseline is stored under the requested entity/metric pair.
    #[error("baseline `{entity}/{metric}` not found")]
    NotFound {
        /// Entity component of the key that was looked up.
        entity: String,
        /// Metric component of the key that was looked up.
        metric: String,
    },
}
40
/// Summary statistics describing one metric of one entity.
///
/// Derives `Serialize`/`Deserialize`, so values can be round-tripped through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityBaseline {
    /// Entity the statistics belong to.
    pub entity: String,
    /// Metric the statistics describe.
    pub metric: String,
    /// Mean of the summarised samples.
    pub mean: f64,
    /// Standard deviation of the summarised samples.
    pub std_dev: f64,
    /// Number of samples that were summarised.
    pub samples: u64,
}
55
56impl EntityBaseline {
57 pub fn from_stats(
59 entity: impl Into<String>,
60 metric: impl Into<String>,
61 stats: &OnlineStats,
62 ) -> Self {
63 Self {
64 entity: entity.into(),
65 metric: metric.into(),
66 mean: stats.mean(),
67 std_dev: stats.std_dev(),
68 samples: stats.count(),
69 }
70 }
71
72 pub fn within(&self, value: f64, k: f64) -> bool {
74 let bound = (k * self.std_dev).max(f64::EPSILON);
75 (value - self.mean).abs() <= bound
76 }
77}
78
/// Running accumulator of count, mean, and squared deviations for a stream of `f64` samples.
///
/// `Default` yields an empty accumulator with every field zero.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OnlineStats {
    /// Number of samples pushed so far (saturates instead of overflowing).
    count: u64,
    /// Running mean of the pushed samples.
    mean: f64,
    /// Running sum of squared deviations from the mean; `variance` divides it by `count - 1`.
    m2: f64,
}
89
90impl OnlineStats {
91 pub fn new() -> Self {
93 Self::default()
94 }
95
96 pub fn push(&mut self, value: f64) {
98 self.count = self.count.saturating_add(1);
99 let delta = value - self.mean;
100 self.mean += delta / self.count as f64;
101 let delta2 = value - self.mean;
102 self.m2 += delta * delta2;
103 }
104
105 pub fn count(&self) -> u64 {
107 self.count
108 }
109
110 pub fn is_empty(&self) -> bool {
112 self.count == 0
113 }
114
115 pub fn mean(&self) -> f64 {
117 self.mean
118 }
119
120 pub fn variance(&self) -> f64 {
122 if self.count < 2 {
123 0.0
124 } else {
125 self.m2 / (self.count - 1) as f64
126 }
127 }
128
129 pub fn std_dev(&self) -> f64 {
131 self.variance().sqrt()
132 }
133
134 pub fn to_baseline(
136 &self,
137 entity: impl Into<String>,
138 metric: impl Into<String>,
139 ) -> EntityBaseline {
140 EntityBaseline::from_stats(entity, metric, self)
141 }
142}
143
/// Asynchronous storage of [`EntityBaseline`] records addressed by entity and metric.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait BaselineStore: Send + Sync {
    /// Stores `baseline`.
    async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError>;
    /// Fetches the baseline stored for `entity`/`metric`.
    ///
    /// The in-memory implementation in this file returns [`BaselineError::NotFound`]
    /// when nothing is stored under that pair.
    async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError>;
    /// Reports whether a baseline is stored for `entity`/`metric`.
    async fn contains(&self, entity: &str, metric: &str) -> bool;
}
154
/// [`BaselineStore`] backed by a `HashMap` guarded by an `RwLock`.
///
/// The map lives behind an `Arc`, so clones of the store share the same contents.
#[derive(Clone, Default)]
pub struct InMemoryBaselineStore {
    /// Baselines keyed by `(entity, metric)`.
    inner: Arc<RwLock<HashMap<(String, String), EntityBaseline>>>,
}
160
161impl InMemoryBaselineStore {
162 pub fn new() -> Self {
164 Self::default()
165 }
166 pub fn arc() -> Arc<Self> {
168 Arc::new(Self::new())
169 }
170 pub fn len(&self) -> usize {
172 self.inner.read().len()
173 }
174 pub fn is_empty(&self) -> bool {
176 self.inner.read().is_empty()
177 }
178}
179
180#[async_trait]
181impl BaselineStore for InMemoryBaselineStore {
182 async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError> {
183 self.inner
184 .write()
185 .insert((baseline.entity.clone(), baseline.metric.clone()), baseline);
186 Ok(())
187 }
188 async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError> {
189 self.inner
190 .read()
191 .get(&(entity.to_string(), metric.to_string()))
192 .cloned()
193 .ok_or_else(|| BaselineError::NotFound {
194 entity: entity.to_string(),
195 metric: metric.to_string(),
196 })
197 }
198 async fn contains(&self, entity: &str, metric: &str) -> bool {
199 self.inner
200 .read()
201 .contains_key(&(entity.to_string(), metric.to_string()))
202 }
203}
204
/// Tool that compares an observed value against a baseline read from a [`BaselineStore`].
pub struct BaselineCompareTool {
    /// Store the baselines are looked up in.
    store: Arc<dyn BaselineStore>,
}
209
210impl BaselineCompareTool {
211 pub const NAME: &'static str = "baseline.compare";
213
214 pub fn new(store: Arc<dyn BaselineStore>) -> Self {
216 Self { store }
217 }
218
219 pub fn arc(store: Arc<dyn BaselineStore>) -> Arc<dyn Tool> {
221 Arc::new(Self::new(store))
222 }
223}
224
225#[async_trait]
226impl Tool for BaselineCompareTool {
227 fn schema(&self) -> ToolSchema {
228 ToolSchema {
229 name: Self::NAME.into(),
230 description:
231 "Compare an observed value to the entity's baseline (mean +/- k*sigma). Returns availability and within-bound flags."
232 .into(),
233 args_schema: json!({
234 "type": "object",
235 "required": ["entity", "metric", "value"],
236 "properties": {
237 "entity": {"type": "string"},
238 "metric": {"type": "string"},
239 "value": {"type": "number"},
240 "k": {"type": "number", "default": 2.0}
241 }
242 }),
243 result_schema: json!({"type": "object"}),
244 }
245 }
246
247 fn name(&self) -> rig_compose::tool::ToolName {
248 Self::NAME.to_string()
249 }
250
251 async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
252 #[derive(serde::Deserialize)]
253 struct Args {
254 entity: String,
255 metric: String,
256 value: f64,
257 #[serde(default = "default_k")]
258 k: f64,
259 }
260 fn default_k() -> f64 {
261 2.0
262 }
263 let parsed: Args = serde_json::from_value(args)?;
264 match self.store.get(&parsed.entity, &parsed.metric).await {
265 Ok(baseline) => Ok(json!({
266 "available": true,
267 "within": baseline.within(parsed.value, parsed.k),
268 "mean": baseline.mean,
269 "std_dev": baseline.std_dev,
270 "k": parsed.k,
271 })),
272 Err(_) => Ok(json!({
273 "available": false,
274 "within": false,
275 "k": parsed.k,
276 })),
277 }
278 }
279}
280
281#[must_use]
312pub fn baseline_compare_trace_envelope(
313 entity: &str,
314 metric: &str,
315 observed: f64,
316 k: f64,
317 baseline: Option<&EntityBaseline>,
318) -> ResourceTraceEnvelope {
319 let input = json!({
320 "entity": entity,
321 "metric": metric,
322 "observed_value": observed,
323 "k": k,
324 });
325
326 let mut envelope = ResourceTraceEnvelope::new(TRACE_RESOURCE, TRACE_OPERATION, TRACE_KIND)
327 .with_input_summary(input);
328
329 match baseline {
330 None => {
331 envelope = envelope
332 .with_output_summary(json!({
333 "available": false,
334 "within": false,
335 }))
336 .with_reason(TRACE_REASON_NOT_FOUND);
337 }
338 Some(baseline) => {
339 let within = baseline.within(observed, k);
340 let bound = (k * baseline.std_dev).max(f64::EPSILON);
341 let deviation = (observed - baseline.mean).abs();
342 envelope = envelope
343 .with_output_summary(json!({
344 "available": true,
345 "within": within,
346 "mean": baseline.mean,
347 "std_dev": baseline.std_dev,
348 "bound": bound,
349 "deviation": deviation,
350 }))
351 .with_reason(if within {
352 TRACE_REASON_WITHIN_BOUNDS
353 } else {
354 TRACE_REASON_EXCEEDS_BOUNDS
355 });
356
357 let mut metadata = json!({
358 "samples": baseline.samples,
359 });
360 if baseline.std_dev > f64::EPSILON
361 && let Some(map) = metadata.as_object_mut()
362 && let Some(z) =
363 serde_json::Number::from_f64((observed - baseline.mean) / baseline.std_dev)
364 {
365 map.insert("z_score".into(), Value::Number(z));
366 }
367 envelope = envelope.with_metadata(metadata);
368 }
369 }
370
371 envelope
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a baseline fixture with a fixed sample count of 100.
    fn baseline(entity: &str, metric: &str, mean: f64, sd: f64) -> EntityBaseline {
        EntityBaseline {
            entity: entity.into(),
            metric: metric.into(),
            mean,
            std_dev: sd,
            samples: 100,
        }
    }

    // Purely synchronous assertion, so a plain `#[test]` is used: no async runtime is needed.
    #[test]
    fn within_bounds_check() {
        let b = baseline("e", "fanout", 10.0, 2.0);
        assert!(b.within(11.0, 2.0));
        assert!(!b.within(20.0, 2.0));
    }

    #[test]
    fn online_stats_builds_entity_baseline() {
        let mut stats = OnlineStats::new();
        for value in [2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
            stats.push(value);
        }
        let baseline = stats.to_baseline("host", "bytes");
        assert_eq!(baseline.samples, 8);
        assert!((baseline.mean - 5.0).abs() < 1e-12);
        // Sample variance of the data set is 32 / 7.
        assert!((baseline.std_dev - 4.571_428_571_428_f64.sqrt()).abs() < 1e-12);
    }

    #[tokio::test]
    async fn store_put_then_get() {
        let store = InMemoryBaselineStore::new();
        store.put(baseline("e", "m", 5.0, 1.0)).await.unwrap();
        let got = store.get("e", "m").await.unwrap();
        assert_eq!(got.samples, 100);
        assert!(store.contains("e", "m").await);
    }

    #[tokio::test]
    async fn tool_reports_available_and_within() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        store.put(baseline("e", "m", 100.0, 5.0)).await.unwrap();
        let tool = BaselineCompareTool::new(store);
        let out = tool
            .invoke(json!({"entity": "e", "metric": "m", "value": 102.0, "k": 2.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], true);
        assert_eq!(out["within"], true);
    }

    // Covers the missing-baseline branch and the default applied when `k` is omitted.
    #[tokio::test]
    async fn tool_reports_unavailable_when_baseline_missing() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        let tool = BaselineCompareTool::new(store);
        let out = tool
            .invoke(json!({"entity": "ghost", "metric": "m", "value": 1.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], false);
        assert_eq!(out["within"], false);
        assert_eq!(out["k"], 2.0);
    }

    #[test]
    fn trace_envelope_within_bounds_includes_metadata() {
        let b = baseline("host-1", "fanout", 10.0, 2.0);
        let envelope = baseline_compare_trace_envelope("host-1", "fanout", 11.0, 2.0, Some(&b));

        assert_eq!(envelope.version, ResourceTraceEnvelope::VERSION);
        assert_eq!(envelope.resource, "baseline");
        assert_eq!(envelope.operation, "compare");
        assert_eq!(envelope.trace_kind, "baseline_compare");
        assert_eq!(envelope.input_summary["entity"], "host-1");
        assert_eq!(envelope.input_summary["metric"], "fanout");
        let observed = envelope.input_summary["observed_value"].as_f64().unwrap();
        assert!((observed - 11.0).abs() < 1e-9);
        assert_eq!(envelope.output_summary["available"], true);
        assert_eq!(envelope.output_summary["within"], true);
        let mean = envelope.output_summary["mean"].as_f64().unwrap();
        assert!((mean - 10.0).abs() < 1e-9);
        let bound = envelope.output_summary["bound"].as_f64().unwrap();
        assert!((bound - 4.0).abs() < 1e-9);
        assert_eq!(envelope.reason.as_deref(), Some(TRACE_REASON_WITHIN_BOUNDS));
        assert_eq!(envelope.metadata["samples"], 100);
        let z = envelope.metadata["z_score"].as_f64().unwrap();
        assert!((z - 0.5).abs() < 1e-9);
    }

    #[test]
    fn trace_envelope_exceeds_bounds_sets_reason() {
        let b = baseline("host-1", "fanout", 10.0, 2.0);
        let envelope = baseline_compare_trace_envelope("host-1", "fanout", 20.0, 2.0, Some(&b));
        assert_eq!(envelope.output_summary["within"], false);
        assert_eq!(
            envelope.reason.as_deref(),
            Some(TRACE_REASON_EXCEEDS_BOUNDS)
        );
        let deviation = envelope.output_summary["deviation"].as_f64().unwrap();
        assert!((deviation - 10.0).abs() < 1e-9);
    }

    #[test]
    fn trace_envelope_not_found_omits_baseline_fields() {
        let envelope = baseline_compare_trace_envelope("ghost", "metric", 7.0, 2.0, None);
        assert_eq!(envelope.output_summary["available"], false);
        assert_eq!(envelope.output_summary["within"], false);
        assert!(envelope.output_summary.get("mean").is_none());
        assert_eq!(envelope.reason.as_deref(), Some(TRACE_REASON_NOT_FOUND));
        assert!(envelope.metadata.is_null());
    }
}