1use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10use thiserror::Error;
11
12use rig_compose::{KernelError, Tool, ToolSchema};
13
14#[derive(Debug, Error)]
15pub enum BaselineError {
16 #[error("baseline `{entity}/{metric}` not found")]
17 NotFound { entity: String, metric: String },
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct EntityBaseline {
23 pub entity: String,
24 pub metric: String,
25 pub mean: f64,
26 pub std_dev: f64,
27 pub samples: u64,
28}
29
30impl EntityBaseline {
31 pub fn within(&self, value: f64, k: f64) -> bool {
32 let bound = (k * self.std_dev).max(f64::EPSILON);
33 (value - self.mean).abs() <= bound
34 }
35}
36
37#[async_trait]
38pub trait BaselineStore: Send + Sync {
39 async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError>;
40 async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError>;
41 async fn contains(&self, entity: &str, metric: &str) -> bool;
42}
43
44#[derive(Clone, Default)]
45pub struct InMemoryBaselineStore {
46 inner: Arc<RwLock<HashMap<(String, String), EntityBaseline>>>,
47}
48
49impl InMemoryBaselineStore {
50 pub fn new() -> Self {
51 Self::default()
52 }
53 pub fn arc() -> Arc<Self> {
54 Arc::new(Self::new())
55 }
56 pub fn len(&self) -> usize {
57 self.inner.read().len()
58 }
59 pub fn is_empty(&self) -> bool {
60 self.inner.read().is_empty()
61 }
62}
63
64#[async_trait]
65impl BaselineStore for InMemoryBaselineStore {
66 async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError> {
67 self.inner
68 .write()
69 .insert((baseline.entity.clone(), baseline.metric.clone()), baseline);
70 Ok(())
71 }
72 async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError> {
73 self.inner
74 .read()
75 .get(&(entity.to_string(), metric.to_string()))
76 .cloned()
77 .ok_or_else(|| BaselineError::NotFound {
78 entity: entity.to_string(),
79 metric: metric.to_string(),
80 })
81 }
82 async fn contains(&self, entity: &str, metric: &str) -> bool {
83 self.inner
84 .read()
85 .contains_key(&(entity.to_string(), metric.to_string()))
86 }
87}
88
89pub struct BaselineCompareTool {
91 store: Arc<dyn BaselineStore>,
92}
93
94impl BaselineCompareTool {
95 pub const NAME: &'static str = "baseline.compare";
96
97 pub fn new(store: Arc<dyn BaselineStore>) -> Self {
98 Self { store }
99 }
100
101 pub fn arc(store: Arc<dyn BaselineStore>) -> Arc<dyn Tool> {
102 Arc::new(Self::new(store))
103 }
104}
105
106#[async_trait]
107impl Tool for BaselineCompareTool {
108 fn schema(&self) -> ToolSchema {
109 ToolSchema {
110 name: Self::NAME.into(),
111 description:
112 "Compare an observed value to the entity's baseline (mean +/- k*sigma). Returns availability and within-bound flags."
113 .into(),
114 args_schema: json!({
115 "type": "object",
116 "required": ["entity", "metric", "value"],
117 "properties": {
118 "entity": {"type": "string"},
119 "metric": {"type": "string"},
120 "value": {"type": "number"},
121 "k": {"type": "number", "default": 2.0}
122 }
123 }),
124 result_schema: json!({"type": "object"}),
125 }
126 }
127
128 fn name(&self) -> rig_compose::tool::ToolName {
129 Self::NAME.to_string()
130 }
131
132 async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
133 #[derive(serde::Deserialize)]
134 struct Args {
135 entity: String,
136 metric: String,
137 value: f64,
138 #[serde(default = "default_k")]
139 k: f64,
140 }
141 fn default_k() -> f64 {
142 2.0
143 }
144 let parsed: Args = serde_json::from_value(args)?;
145 match self.store.get(&parsed.entity, &parsed.metric).await {
146 Ok(baseline) => Ok(json!({
147 "available": true,
148 "within": baseline.within(parsed.value, parsed.k),
149 "mean": baseline.mean,
150 "std_dev": baseline.std_dev,
151 "k": parsed.k,
152 })),
153 Err(_) => Ok(json!({
154 "available": false,
155 "within": false,
156 "k": parsed.k,
157 })),
158 }
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a baseline fixture with a fixed sample count.
    fn baseline(entity: &str, metric: &str, mean: f64, sd: f64) -> EntityBaseline {
        EntityBaseline {
            entity: entity.into(),
            metric: metric.into(),
            mean,
            std_dev: sd,
            samples: 100,
        }
    }

    // `within` is synchronous; a plain `#[test]` avoids spinning up a runtime.
    #[test]
    fn within_bounds_check() {
        let b = baseline("e", "fanout", 10.0, 2.0);
        assert!(b.within(11.0, 2.0));
        // The bound is inclusive: |14 - 10| == 2 * 2.
        assert!(b.within(14.0, 2.0));
        assert!(!b.within(20.0, 2.0));
    }

    #[test]
    fn zero_spread_accepts_only_the_mean() {
        let b = baseline("e", "m", 3.0, 0.0);
        assert!(b.within(3.0, 2.0));
        assert!(!b.within(3.5, 2.0));
    }

    #[tokio::test]
    async fn store_put_then_get() {
        let store = InMemoryBaselineStore::new();
        store.put(baseline("e", "m", 5.0, 1.0)).await.unwrap();
        let got = store.get("e", "m").await.unwrap();
        assert_eq!(got.samples, 100);
        assert!(store.contains("e", "m").await);
    }

    #[tokio::test]
    async fn store_put_replaces_existing_entry() {
        let store = InMemoryBaselineStore::new();
        store.put(baseline("e", "m", 1.0, 1.0)).await.unwrap();
        store.put(baseline("e", "m", 2.0, 1.0)).await.unwrap();
        assert_eq!(store.len(), 1);
        assert_eq!(store.get("e", "m").await.unwrap().mean, 2.0);
    }

    #[tokio::test]
    async fn store_get_missing_is_not_found() {
        let store = InMemoryBaselineStore::new();
        assert!(store.is_empty());
        assert!(!store.contains("e", "m").await);
        let err = store.get("e", "m").await.unwrap_err();
        assert!(matches!(
            &err,
            BaselineError::NotFound { entity, metric } if entity == "e" && metric == "m"
        ));
        assert_eq!(err.to_string(), "baseline `e/m` not found");
    }

    #[tokio::test]
    async fn tool_reports_available_and_within() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        store.put(baseline("e", "m", 100.0, 5.0)).await.unwrap();
        let tool = BaselineCompareTool::new(store);
        let out = tool
            .invoke(json!({"entity": "e", "metric": "m", "value": 102.0, "k": 2.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], true);
        assert_eq!(out["within"], true);
    }

    #[tokio::test]
    async fn tool_reports_out_of_bounds_with_default_k() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        store.put(baseline("e", "m", 100.0, 5.0)).await.unwrap();
        let tool = BaselineCompareTool::new(store);
        let out = tool
            .invoke(json!({"entity": "e", "metric": "m", "value": 120.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], true);
        assert_eq!(out["within"], false);
        // `k` was omitted, so the documented default applies.
        assert_eq!(out["k"], 2.0);
        assert_eq!(out["mean"], 100.0);
        assert_eq!(out["std_dev"], 5.0);
    }

    #[tokio::test]
    async fn tool_reports_unavailable_for_missing_baseline() {
        let tool = BaselineCompareTool::new(InMemoryBaselineStore::arc());
        let out = tool
            .invoke(json!({"entity": "e", "metric": "m", "value": 1.0, "k": 3.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], false);
        assert_eq!(out["within"], false);
        assert_eq!(out["k"], 3.0);
        assert!(out.get("mean").is_none());
    }

    #[tokio::test]
    async fn tool_rejects_malformed_args() {
        let tool = BaselineCompareTool::new(InMemoryBaselineStore::arc());
        assert!(tool.invoke(json!({"entity": "e"})).await.is_err());
    }
}