1use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10use thiserror::Error;
11
12use rig_compose::{KernelError, Tool, ToolSchema};
13
14#[derive(Debug, Error)]
15pub enum BaselineError {
16 #[error("baseline `{entity}/{metric}` not found")]
17 NotFound { entity: String, metric: String },
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct EntityBaseline {
23 pub entity: String,
24 pub metric: String,
25 pub mean: f64,
26 pub std_dev: f64,
27 pub samples: u64,
28}
29
30impl EntityBaseline {
31 pub fn from_stats(
32 entity: impl Into<String>,
33 metric: impl Into<String>,
34 stats: &OnlineStats,
35 ) -> Self {
36 Self {
37 entity: entity.into(),
38 metric: metric.into(),
39 mean: stats.mean(),
40 std_dev: stats.std_dev(),
41 samples: stats.count(),
42 }
43 }
44
45 pub fn within(&self, value: f64, k: f64) -> bool {
46 let bound = (k * self.std_dev).max(f64::EPSILON);
47 (value - self.mean).abs() <= bound
48 }
49}
50
51#[derive(Debug, Clone, Default, Serialize, Deserialize)]
56pub struct OnlineStats {
57 count: u64,
58 mean: f64,
59 m2: f64,
60}
61
62impl OnlineStats {
63 pub fn new() -> Self {
64 Self::default()
65 }
66
67 pub fn push(&mut self, value: f64) {
68 self.count = self.count.saturating_add(1);
69 let delta = value - self.mean;
70 self.mean += delta / self.count as f64;
71 let delta2 = value - self.mean;
72 self.m2 += delta * delta2;
73 }
74
75 pub fn count(&self) -> u64 {
76 self.count
77 }
78
79 pub fn is_empty(&self) -> bool {
80 self.count == 0
81 }
82
83 pub fn mean(&self) -> f64 {
84 self.mean
85 }
86
87 pub fn variance(&self) -> f64 {
89 if self.count < 2 {
90 0.0
91 } else {
92 self.m2 / (self.count - 1) as f64
93 }
94 }
95
96 pub fn std_dev(&self) -> f64 {
97 self.variance().sqrt()
98 }
99
100 pub fn to_baseline(
101 &self,
102 entity: impl Into<String>,
103 metric: impl Into<String>,
104 ) -> EntityBaseline {
105 EntityBaseline::from_stats(entity, metric, self)
106 }
107}
108
109#[async_trait]
110pub trait BaselineStore: Send + Sync {
111 async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError>;
112 async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError>;
113 async fn contains(&self, entity: &str, metric: &str) -> bool;
114}
115
116#[derive(Clone, Default)]
117pub struct InMemoryBaselineStore {
118 inner: Arc<RwLock<HashMap<(String, String), EntityBaseline>>>,
119}
120
121impl InMemoryBaselineStore {
122 pub fn new() -> Self {
123 Self::default()
124 }
125 pub fn arc() -> Arc<Self> {
126 Arc::new(Self::new())
127 }
128 pub fn len(&self) -> usize {
129 self.inner.read().len()
130 }
131 pub fn is_empty(&self) -> bool {
132 self.inner.read().is_empty()
133 }
134}
135
136#[async_trait]
137impl BaselineStore for InMemoryBaselineStore {
138 async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError> {
139 self.inner
140 .write()
141 .insert((baseline.entity.clone(), baseline.metric.clone()), baseline);
142 Ok(())
143 }
144 async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError> {
145 self.inner
146 .read()
147 .get(&(entity.to_string(), metric.to_string()))
148 .cloned()
149 .ok_or_else(|| BaselineError::NotFound {
150 entity: entity.to_string(),
151 metric: metric.to_string(),
152 })
153 }
154 async fn contains(&self, entity: &str, metric: &str) -> bool {
155 self.inner
156 .read()
157 .contains_key(&(entity.to_string(), metric.to_string()))
158 }
159}
160
161pub struct BaselineCompareTool {
163 store: Arc<dyn BaselineStore>,
164}
165
166impl BaselineCompareTool {
167 pub const NAME: &'static str = "baseline.compare";
168
169 pub fn new(store: Arc<dyn BaselineStore>) -> Self {
170 Self { store }
171 }
172
173 pub fn arc(store: Arc<dyn BaselineStore>) -> Arc<dyn Tool> {
174 Arc::new(Self::new(store))
175 }
176}
177
178#[async_trait]
179impl Tool for BaselineCompareTool {
180 fn schema(&self) -> ToolSchema {
181 ToolSchema {
182 name: Self::NAME.into(),
183 description:
184 "Compare an observed value to the entity's baseline (mean +/- k*sigma). Returns availability and within-bound flags."
185 .into(),
186 args_schema: json!({
187 "type": "object",
188 "required": ["entity", "metric", "value"],
189 "properties": {
190 "entity": {"type": "string"},
191 "metric": {"type": "string"},
192 "value": {"type": "number"},
193 "k": {"type": "number", "default": 2.0}
194 }
195 }),
196 result_schema: json!({"type": "object"}),
197 }
198 }
199
200 fn name(&self) -> rig_compose::tool::ToolName {
201 Self::NAME.to_string()
202 }
203
204 async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
205 #[derive(serde::Deserialize)]
206 struct Args {
207 entity: String,
208 metric: String,
209 value: f64,
210 #[serde(default = "default_k")]
211 k: f64,
212 }
213 fn default_k() -> f64 {
214 2.0
215 }
216 let parsed: Args = serde_json::from_value(args)?;
217 match self.store.get(&parsed.entity, &parsed.metric).await {
218 Ok(baseline) => Ok(json!({
219 "available": true,
220 "within": baseline.within(parsed.value, parsed.k),
221 "mean": baseline.mean,
222 "std_dev": baseline.std_dev,
223 "k": parsed.k,
224 })),
225 Err(_) => Ok(json!({
226 "available": false,
227 "within": false,
228 "k": parsed.k,
229 })),
230 }
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a baseline with a fixed sample count of 100.
    fn baseline(entity: &str, metric: &str, mean: f64, sd: f64) -> EntityBaseline {
        EntityBaseline {
            entity: entity.into(),
            metric: metric.into(),
            mean,
            std_dev: sd,
            samples: 100,
        }
    }

    // Purely synchronous, so no async runtime is needed.
    #[test]
    fn within_bounds_check() {
        let b = baseline("e", "fanout", 10.0, 2.0);
        assert!(b.within(11.0, 2.0));
        assert!(!b.within(20.0, 2.0));
    }

    #[test]
    fn online_stats_builds_entity_baseline() {
        let mut stats = OnlineStats::new();
        for value in [2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
            stats.push(value);
        }
        let baseline = stats.to_baseline("host", "bytes");
        assert_eq!(baseline.samples, 8);
        assert!((baseline.mean - 5.0).abs() < 1e-12);
        // Squared deviations from the mean sum to 32 over n - 1 = 7 degrees of freedom.
        assert!((baseline.std_dev - (32.0_f64 / 7.0).sqrt()).abs() < 1e-12);
    }

    #[tokio::test]
    async fn store_put_then_get() {
        let store = InMemoryBaselineStore::new();
        store.put(baseline("e", "m", 5.0, 1.0)).await.unwrap();
        let got = store.get("e", "m").await.unwrap();
        assert_eq!(got.samples, 100);
        assert!(store.contains("e", "m").await);
    }

    #[tokio::test]
    async fn store_get_missing_reports_not_found() {
        let store = InMemoryBaselineStore::new();
        assert!(store.is_empty());
        assert!(!store.contains("e", "m").await);
        match store.get("e", "m").await {
            Err(BaselineError::NotFound { entity, metric }) => {
                assert_eq!(entity, "e");
                assert_eq!(metric, "m");
            }
            Ok(found) => panic!("unexpected baseline: {:?}", found),
        }
    }

    #[tokio::test]
    async fn tool_reports_available_and_within() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        store.put(baseline("e", "m", 100.0, 5.0)).await.unwrap();
        let tool = BaselineCompareTool::new(store);
        let out = tool
            .invoke(json!({"entity": "e", "metric": "m", "value": 102.0, "k": 2.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], true);
        assert_eq!(out["within"], true);
    }

    #[tokio::test]
    async fn tool_reports_unavailable_for_unknown_baseline() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        let tool = BaselineCompareTool::new(store);
        // `k` is omitted on purpose to exercise the serde default.
        let out = tool
            .invoke(json!({"entity": "e", "metric": "m", "value": 1.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], false);
        assert_eq!(out["within"], false);
        assert_eq!(out["k"].as_f64(), Some(2.0));
    }
}