Skip to main content

rig_resources/
baseline.rs

1//! Environmental baselines and the `baseline.compare` tool.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10use thiserror::Error;
11
12use rig_compose::{KernelError, Tool, ToolSchema};
13
14#[derive(Debug, Error)]
15pub enum BaselineError {
16    #[error("baseline `{entity}/{metric}` not found")]
17    NotFound { entity: String, metric: String },
18}
19
20/// Statistical envelope for one (entity, metric) pair.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct EntityBaseline {
23    pub entity: String,
24    pub metric: String,
25    pub mean: f64,
26    pub std_dev: f64,
27    pub samples: u64,
28}
29
30impl EntityBaseline {
31    pub fn from_stats(
32        entity: impl Into<String>,
33        metric: impl Into<String>,
34        stats: &OnlineStats,
35    ) -> Self {
36        Self {
37            entity: entity.into(),
38            metric: metric.into(),
39            mean: stats.mean(),
40            std_dev: stats.std_dev(),
41            samples: stats.count(),
42        }
43    }
44
45    pub fn within(&self, value: f64, k: f64) -> bool {
46        let bound = (k * self.std_dev).max(f64::EPSILON);
47        (value - self.mean).abs() <= bound
48    }
49}
50
51/// Online mean/variance accumulator for building [`EntityBaseline`] values.
52///
53/// Uses Welford's algorithm, so callers can update an environmental baseline
54/// one observation at a time without storing raw samples.
55#[derive(Debug, Clone, Default, Serialize, Deserialize)]
56pub struct OnlineStats {
57    count: u64,
58    mean: f64,
59    m2: f64,
60}
61
62impl OnlineStats {
63    pub fn new() -> Self {
64        Self::default()
65    }
66
67    pub fn push(&mut self, value: f64) {
68        self.count = self.count.saturating_add(1);
69        let delta = value - self.mean;
70        self.mean += delta / self.count as f64;
71        let delta2 = value - self.mean;
72        self.m2 += delta * delta2;
73    }
74
75    pub fn count(&self) -> u64 {
76        self.count
77    }
78
79    pub fn is_empty(&self) -> bool {
80        self.count == 0
81    }
82
83    pub fn mean(&self) -> f64 {
84        self.mean
85    }
86
87    /// Sample variance. Returns `0.0` until at least two samples exist.
88    pub fn variance(&self) -> f64 {
89        if self.count < 2 {
90            0.0
91        } else {
92            self.m2 / (self.count - 1) as f64
93        }
94    }
95
96    pub fn std_dev(&self) -> f64 {
97        self.variance().sqrt()
98    }
99
100    pub fn to_baseline(
101        &self,
102        entity: impl Into<String>,
103        metric: impl Into<String>,
104    ) -> EntityBaseline {
105        EntityBaseline::from_stats(entity, metric, self)
106    }
107}
108
109#[async_trait]
110pub trait BaselineStore: Send + Sync {
111    async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError>;
112    async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError>;
113    async fn contains(&self, entity: &str, metric: &str) -> bool;
114}
115
116#[derive(Clone, Default)]
117pub struct InMemoryBaselineStore {
118    inner: Arc<RwLock<HashMap<(String, String), EntityBaseline>>>,
119}
120
121impl InMemoryBaselineStore {
122    pub fn new() -> Self {
123        Self::default()
124    }
125    pub fn arc() -> Arc<Self> {
126        Arc::new(Self::new())
127    }
128    pub fn len(&self) -> usize {
129        self.inner.read().len()
130    }
131    pub fn is_empty(&self) -> bool {
132        self.inner.read().is_empty()
133    }
134}
135
136#[async_trait]
137impl BaselineStore for InMemoryBaselineStore {
138    async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError> {
139        self.inner
140            .write()
141            .insert((baseline.entity.clone(), baseline.metric.clone()), baseline);
142        Ok(())
143    }
144    async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError> {
145        self.inner
146            .read()
147            .get(&(entity.to_string(), metric.to_string()))
148            .cloned()
149            .ok_or_else(|| BaselineError::NotFound {
150                entity: entity.to_string(),
151                metric: metric.to_string(),
152            })
153    }
154    async fn contains(&self, entity: &str, metric: &str) -> bool {
155        self.inner
156            .read()
157            .contains_key(&(entity.to_string(), metric.to_string()))
158    }
159}
160
161/// `baseline.compare` — kernel tool.
162pub struct BaselineCompareTool {
163    store: Arc<dyn BaselineStore>,
164}
165
166impl BaselineCompareTool {
167    pub const NAME: &'static str = "baseline.compare";
168
169    pub fn new(store: Arc<dyn BaselineStore>) -> Self {
170        Self { store }
171    }
172
173    pub fn arc(store: Arc<dyn BaselineStore>) -> Arc<dyn Tool> {
174        Arc::new(Self::new(store))
175    }
176}
177
178#[async_trait]
179impl Tool for BaselineCompareTool {
180    fn schema(&self) -> ToolSchema {
181        ToolSchema {
182            name: Self::NAME.into(),
183            description:
184                "Compare an observed value to the entity's baseline (mean +/- k*sigma). Returns availability and within-bound flags."
185                    .into(),
186            args_schema: json!({
187                "type": "object",
188                "required": ["entity", "metric", "value"],
189                "properties": {
190                    "entity": {"type": "string"},
191                    "metric": {"type": "string"},
192                    "value": {"type": "number"},
193                    "k": {"type": "number", "default": 2.0}
194                }
195            }),
196            result_schema: json!({"type": "object"}),
197        }
198    }
199
200    fn name(&self) -> rig_compose::tool::ToolName {
201        Self::NAME.to_string()
202    }
203
204    async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
205        #[derive(serde::Deserialize)]
206        struct Args {
207            entity: String,
208            metric: String,
209            value: f64,
210            #[serde(default = "default_k")]
211            k: f64,
212        }
213        fn default_k() -> f64 {
214            2.0
215        }
216        let parsed: Args = serde_json::from_value(args)?;
217        match self.store.get(&parsed.entity, &parsed.metric).await {
218            Ok(baseline) => Ok(json!({
219                "available": true,
220                "within": baseline.within(parsed.value, parsed.k),
221                "mean": baseline.mean,
222                "std_dev": baseline.std_dev,
223                "k": parsed.k,
224            })),
225            Err(_) => Ok(json!({
226                "available": false,
227                "within": false,
228                "k": parsed.k,
229            })),
230        }
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a baseline fixture with a fixed sample count of 100.
    fn baseline(entity: &str, metric: &str, mean: f64, sd: f64) -> EntityBaseline {
        EntityBaseline {
            entity: entity.into(),
            metric: metric.into(),
            mean,
            std_dev: sd,
            samples: 100,
        }
    }

    // `within` is synchronous, so a plain `#[test]` is enough; no async
    // runtime needs to be started for it.
    #[test]
    fn within_bounds_check() {
        let b = baseline("e", "fanout", 10.0, 2.0);
        assert!(b.within(11.0, 2.0));
        assert!(!b.within(20.0, 2.0));
    }

    #[test]
    fn within_is_inclusive_at_the_bound() {
        // mean 10, sd 2, k 2 => accepted range is [6, 14].
        let b = baseline("e", "fanout", 10.0, 2.0);
        assert!(b.within(14.0, 2.0));
        assert!(b.within(6.0, 2.0));
        assert!(!b.within(14.5, 2.0));
    }

    #[test]
    fn zero_std_dev_accepts_only_the_mean() {
        let b = baseline("e", "fanout", 10.0, 0.0);
        assert!(b.within(10.0, 2.0));
        assert!(!b.within(10.1, 2.0));
    }

    #[test]
    fn online_stats_builds_entity_baseline() {
        let mut stats = OnlineStats::new();
        for value in [2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
            stats.push(value);
        }
        let baseline = stats.to_baseline("host", "bytes");
        assert_eq!(baseline.samples, 8);
        assert!((baseline.mean - 5.0).abs() < 1e-12);
        // Squared deviations from the mean sum to 32, so the sample variance
        // is exactly 32 / (8 - 1).
        let expected_std_dev = (32.0_f64 / 7.0).sqrt();
        assert!((baseline.std_dev - expected_std_dev).abs() < 1e-12);
    }

    #[test]
    fn online_stats_variance_needs_two_samples() {
        let mut stats = OnlineStats::new();
        assert!(stats.is_empty());
        assert_eq!(stats.variance(), 0.0);

        stats.push(3.0);
        assert!(!stats.is_empty());
        assert_eq!(stats.count(), 1);
        assert_eq!(stats.mean(), 3.0);
        assert_eq!(stats.variance(), 0.0);
    }

    #[tokio::test]
    async fn store_put_then_get() {
        let store = InMemoryBaselineStore::new();
        store.put(baseline("e", "m", 5.0, 1.0)).await.unwrap();
        let got = store.get("e", "m").await.unwrap();
        assert_eq!(got.samples, 100);
        assert!(store.contains("e", "m").await);
    }

    #[tokio::test]
    async fn store_reports_missing_baseline() {
        let store = InMemoryBaselineStore::new();
        assert!(store.is_empty());
        assert!(!store.contains("e", "m").await);
        assert!(matches!(
            store.get("e", "m").await,
            Err(BaselineError::NotFound { .. })
        ));
    }

    #[tokio::test]
    async fn tool_reports_available_and_within() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        store.put(baseline("e", "m", 100.0, 5.0)).await.unwrap();
        let tool = BaselineCompareTool::new(store);
        let out = tool
            .invoke(json!({"entity": "e", "metric": "m", "value": 102.0, "k": 2.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], true);
        assert_eq!(out["within"], true);
    }

    #[tokio::test]
    async fn tool_reports_unavailable_and_defaults_k() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        let tool = BaselineCompareTool::new(store);
        // `k` is omitted on purpose to exercise the default.
        let out = tool
            .invoke(json!({"entity": "missing", "metric": "m", "value": 1.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], false);
        assert_eq!(out["within"], false);
        assert_eq!(out["k"], 2.0);
    }
}