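// rig_resources/baseline.rs
//! Environmental baselines and the `baseline.compare` tool.
//!
//! An [`EntityBaseline`] records the mean and standard deviation of one metric for one
//! entity. A [`BaselineStore`] holds baselines keyed by `(entity, metric)`, and
//! [`BaselineCompareTool`] exposes a `mean ± k·std_dev` check as the `baseline.compare` tool.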
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10use thiserror::Error;
11
12use rig_compose::{KernelError, Tool, ToolSchema};
13
14#[derive(Debug, Error)]
15pub enum BaselineError {
16    #[error("baseline `{entity}/{metric}` not found")]
17    NotFound { entity: String, metric: String },
18}
19
20/// Statistical envelope for one (entity, metric) pair.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct EntityBaseline {
23    pub entity: String,
24    pub metric: String,
25    pub mean: f64,
26    pub std_dev: f64,
27    pub samples: u64,
28}
29
30impl EntityBaseline {
31    pub fn within(&self, value: f64, k: f64) -> bool {
32        let bound = (k * self.std_dev).max(f64::EPSILON);
33        (value - self.mean).abs() <= bound
34    }
35}
36
37#[async_trait]
38pub trait BaselineStore: Send + Sync {
39    async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError>;
40    async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError>;
41    async fn contains(&self, entity: &str, metric: &str) -> bool;
42}
43
44#[derive(Clone, Default)]
45pub struct InMemoryBaselineStore {
46    inner: Arc<RwLock<HashMap<(String, String), EntityBaseline>>>,
47}
48
49impl InMemoryBaselineStore {
50    pub fn new() -> Self {
51        Self::default()
52    }
53    pub fn arc() -> Arc<Self> {
54        Arc::new(Self::new())
55    }
56    pub fn len(&self) -> usize {
57        self.inner.read().len()
58    }
59    pub fn is_empty(&self) -> bool {
60        self.inner.read().is_empty()
61    }
62}
63
64#[async_trait]
65impl BaselineStore for InMemoryBaselineStore {
66    async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError> {
67        self.inner
68            .write()
69            .insert((baseline.entity.clone(), baseline.metric.clone()), baseline);
70        Ok(())
71    }
72    async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError> {
73        self.inner
74            .read()
75            .get(&(entity.to_string(), metric.to_string()))
76            .cloned()
77            .ok_or_else(|| BaselineError::NotFound {
78                entity: entity.to_string(),
79                metric: metric.to_string(),
80            })
81    }
82    async fn contains(&self, entity: &str, metric: &str) -> bool {
83        self.inner
84            .read()
85            .contains_key(&(entity.to_string(), metric.to_string()))
86    }
87}
88
89/// `baseline.compare` — kernel tool.
90pub struct BaselineCompareTool {
91    store: Arc<dyn BaselineStore>,
92}
93
94impl BaselineCompareTool {
95    pub const NAME: &'static str = "baseline.compare";
96
97    pub fn new(store: Arc<dyn BaselineStore>) -> Self {
98        Self { store }
99    }
100
101    pub fn arc(store: Arc<dyn BaselineStore>) -> Arc<dyn Tool> {
102        Arc::new(Self::new(store))
103    }
104}
105
106#[async_trait]
107impl Tool for BaselineCompareTool {
108    fn schema(&self) -> ToolSchema {
109        ToolSchema {
110            name: Self::NAME.into(),
111            description:
112                "Compare an observed value to the entity's baseline (mean +/- k*sigma). Returns availability and within-bound flags."
113                    .into(),
114            args_schema: json!({
115                "type": "object",
116                "required": ["entity", "metric", "value"],
117                "properties": {
118                    "entity": {"type": "string"},
119                    "metric": {"type": "string"},
120                    "value": {"type": "number"},
121                    "k": {"type": "number", "default": 2.0}
122                }
123            }),
124            result_schema: json!({"type": "object"}),
125        }
126    }
127
128    fn name(&self) -> rig_compose::tool::ToolName {
129        Self::NAME.to_string()
130    }
131
132    async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
133        #[derive(serde::Deserialize)]
134        struct Args {
135            entity: String,
136            metric: String,
137            value: f64,
138            #[serde(default = "default_k")]
139            k: f64,
140        }
141        fn default_k() -> f64 {
142            2.0
143        }
144        let parsed: Args = serde_json::from_value(args)?;
145        match self.store.get(&parsed.entity, &parsed.metric).await {
146            Ok(baseline) => Ok(json!({
147                "available": true,
148                "within": baseline.within(parsed.value, parsed.k),
149                "mean": baseline.mean,
150                "std_dev": baseline.std_dev,
151                "k": parsed.k,
152            })),
153            Err(_) => Ok(json!({
154                "available": false,
155                "within": false,
156                "k": parsed.k,
157            })),
158        }
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a baseline fixture with a fixed sample count.
    fn baseline(entity: &str, metric: &str, mean: f64, sd: f64) -> EntityBaseline {
        EntityBaseline {
            entity: entity.into(),
            metric: metric.into(),
            mean,
            std_dev: sd,
            samples: 100,
        }
    }

    // `within` is synchronous, so a plain test suffices; no async runtime is needed.
    #[test]
    fn within_bounds_check() {
        let b = baseline("e", "fanout", 10.0, 2.0);
        assert!(b.within(11.0, 2.0));
        assert!(b.within(14.0, 2.0), "the bound itself is inclusive");
        assert!(!b.within(20.0, 2.0));
    }

    #[test]
    fn zero_std_dev_only_accepts_the_mean() {
        let b = baseline("e", "m", 5.0, 0.0);
        assert!(b.within(5.0, 2.0));
        assert!(!b.within(5.001, 2.0));
    }

    #[tokio::test]
    async fn store_put_then_get() {
        let store = InMemoryBaselineStore::new();
        store.put(baseline("e", "m", 5.0, 1.0)).await.unwrap();
        let got = store.get("e", "m").await.unwrap();
        assert_eq!(got.samples, 100);
        assert!(store.contains("e", "m").await);
    }

    #[tokio::test]
    async fn store_get_missing_reports_not_found() {
        let store = InMemoryBaselineStore::new();
        assert!(store.is_empty());
        let err = store.get("e", "m").await.unwrap_err();
        assert!(matches!(
            &err,
            BaselineError::NotFound { entity, metric } if entity == "e" && metric == "m"
        ));
        assert!(!store.contains("e", "m").await);
    }

    #[tokio::test]
    async fn store_put_replaces_existing_entry() {
        let store = InMemoryBaselineStore::new();
        store.put(baseline("e", "m", 1.0, 1.0)).await.unwrap();
        store.put(baseline("e", "m", 2.0, 1.0)).await.unwrap();
        assert_eq!(store.len(), 1);
        assert_eq!(store.get("e", "m").await.unwrap().mean, 2.0);
    }

    #[tokio::test]
    async fn tool_reports_available_and_within() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        store.put(baseline("e", "m", 100.0, 5.0)).await.unwrap();
        let tool = BaselineCompareTool::new(store);
        let out = tool
            .invoke(json!({"entity": "e", "metric": "m", "value": 102.0, "k": 2.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], true);
        assert_eq!(out["within"], true);
    }

    #[tokio::test]
    async fn tool_uses_default_k_when_omitted() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        store.put(baseline("e", "m", 100.0, 5.0)).await.unwrap();
        let tool = BaselineCompareTool::new(store);

        let inside = tool
            .invoke(json!({"entity": "e", "metric": "m", "value": 109.0}))
            .await
            .unwrap();
        assert_eq!(inside["k"], 2.0);
        assert_eq!(inside["within"], true);

        let outside = tool
            .invoke(json!({"entity": "e", "metric": "m", "value": 111.0}))
            .await
            .unwrap();
        assert_eq!(outside["within"], false);
    }

    #[tokio::test]
    async fn tool_reports_unavailable_for_missing_baseline() {
        let tool = BaselineCompareTool::new(InMemoryBaselineStore::arc());
        let out = tool
            .invoke(json!({"entity": "x", "metric": "y", "value": 1.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], false);
        assert_eq!(out["within"], false);
        assert!(out.get("mean").is_none());
    }

    #[tokio::test]
    async fn tool_rejects_malformed_args() {
        let tool = BaselineCompareTool::new(InMemoryBaselineStore::arc());
        assert!(tool.invoke(json!({"entity": "e"})).await.is_err());
    }
}