Skip to main content

rig_resources/
baseline.rs

1//! Environmental baselines and the `baseline.compare` tool.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10use thiserror::Error;
11
12use rig_compose::{KernelError, Tool, ToolSchema};
13
14use crate::trace::ResourceTraceEnvelope;
15
/// Resource label passed as the first argument to `ResourceTraceEnvelope::new`
/// by [`baseline_compare_trace_envelope`].
const TRACE_RESOURCE: &str = "baseline";
/// Operation label passed as the second argument to
/// `ResourceTraceEnvelope::new` by [`baseline_compare_trace_envelope`].
const TRACE_OPERATION: &str = "compare";
/// Trace-kind label passed as the third argument to
/// `ResourceTraceEnvelope::new` by [`baseline_compare_trace_envelope`].
const TRACE_KIND: &str = "baseline_compare";

/// Reason emitted when no baseline existed for the requested
/// `(entity, metric)` pair.
pub const TRACE_REASON_NOT_FOUND: &str = "baseline_not_found";
/// Reason emitted when the observation fell inside the `mean ± k·σ` bound.
pub const TRACE_REASON_WITHIN_BOUNDS: &str = "within_bounds";
/// Reason emitted when the observation fell outside the `mean ± k·σ` bound.
pub const TRACE_REASON_EXCEEDS_BOUNDS: &str = "exceeds_bounds";
27
/// Errors returned by [`BaselineStore`] operations.
#[derive(Debug, Error)]
pub enum BaselineError {
    /// No baseline is stored for the requested `(entity, metric)` pair.
    ///
    /// Both halves of the key are carried so the rendered message can name
    /// the missing record as `entity/metric`.
    #[error("baseline `{entity}/{metric}` not found")]
    NotFound { entity: String, metric: String },
}
33
/// Statistical envelope for one (entity, metric) pair.
///
/// [`EntityBaseline::within`] checks an observation against the band
/// `mean ± k·std_dev` described by these fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityBaseline {
    /// Entity half of the `(entity, metric)` key.
    pub entity: String,
    /// Metric half of the `(entity, metric)` key.
    pub metric: String,
    /// Centre of the acceptance band.
    pub mean: f64,
    /// Spread of the acceptance band; scaled by `k` in
    /// [`EntityBaseline::within`].
    pub std_dev: f64,
    /// Number of observations behind the statistics; filled from
    /// [`OnlineStats::count`] when built via [`EntityBaseline::from_stats`].
    pub samples: u64,
}
43
44impl EntityBaseline {
45    pub fn from_stats(
46        entity: impl Into<String>,
47        metric: impl Into<String>,
48        stats: &OnlineStats,
49    ) -> Self {
50        Self {
51            entity: entity.into(),
52            metric: metric.into(),
53            mean: stats.mean(),
54            std_dev: stats.std_dev(),
55            samples: stats.count(),
56        }
57    }
58
59    pub fn within(&self, value: f64, k: f64) -> bool {
60        let bound = (k * self.std_dev).max(f64::EPSILON);
61        (value - self.mean).abs() <= bound
62    }
63}
64
/// Online mean/variance accumulator for building [`EntityBaseline`] values.
///
/// Uses Welford's algorithm, so callers can update an environmental baseline
/// one observation at a time without storing raw samples.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OnlineStats {
    /// Number of observations pushed so far (incremented with saturation).
    count: u64,
    /// Running mean of the pushed observations.
    mean: f64,
    /// Running sum of squared deviations from the mean; divided by
    /// `count - 1` to obtain the sample variance.
    m2: f64,
}
75
76impl OnlineStats {
77    pub fn new() -> Self {
78        Self::default()
79    }
80
81    pub fn push(&mut self, value: f64) {
82        self.count = self.count.saturating_add(1);
83        let delta = value - self.mean;
84        self.mean += delta / self.count as f64;
85        let delta2 = value - self.mean;
86        self.m2 += delta * delta2;
87    }
88
89    pub fn count(&self) -> u64 {
90        self.count
91    }
92
93    pub fn is_empty(&self) -> bool {
94        self.count == 0
95    }
96
97    pub fn mean(&self) -> f64 {
98        self.mean
99    }
100
101    /// Sample variance. Returns `0.0` until at least two samples exist.
102    pub fn variance(&self) -> f64 {
103        if self.count < 2 {
104            0.0
105        } else {
106            self.m2 / (self.count - 1) as f64
107        }
108    }
109
110    pub fn std_dev(&self) -> f64 {
111        self.variance().sqrt()
112    }
113
114    pub fn to_baseline(
115        &self,
116        entity: impl Into<String>,
117        metric: impl Into<String>,
118    ) -> EntityBaseline {
119        EntityBaseline::from_stats(entity, metric, self)
120    }
121}
122
/// Async storage for [`EntityBaseline`] records addressed by
/// `(entity, metric)`.
///
/// The `Send + Sync` bound lets a store be shared as `Arc<dyn BaselineStore>`.
#[async_trait]
pub trait BaselineStore: Send + Sync {
    /// Store `baseline` under its own `(entity, metric)` pair.
    ///
    /// NOTE(review): whether an existing record is replaced is up to the
    /// implementation — confirm per store.
    async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError>;
    /// Fetch the baseline for `(entity, metric)`, or a [`BaselineError`] when
    /// it cannot be produced.
    async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError>;
    /// Report whether a baseline exists for `(entity, metric)`.
    async fn contains(&self, entity: &str, metric: &str) -> bool;
}
129
130#[derive(Clone, Default)]
131pub struct InMemoryBaselineStore {
132    inner: Arc<RwLock<HashMap<(String, String), EntityBaseline>>>,
133}
134
135impl InMemoryBaselineStore {
136    pub fn new() -> Self {
137        Self::default()
138    }
139    pub fn arc() -> Arc<Self> {
140        Arc::new(Self::new())
141    }
142    pub fn len(&self) -> usize {
143        self.inner.read().len()
144    }
145    pub fn is_empty(&self) -> bool {
146        self.inner.read().is_empty()
147    }
148}
149
150#[async_trait]
151impl BaselineStore for InMemoryBaselineStore {
152    async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError> {
153        self.inner
154            .write()
155            .insert((baseline.entity.clone(), baseline.metric.clone()), baseline);
156        Ok(())
157    }
158    async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError> {
159        self.inner
160            .read()
161            .get(&(entity.to_string(), metric.to_string()))
162            .cloned()
163            .ok_or_else(|| BaselineError::NotFound {
164                entity: entity.to_string(),
165                metric: metric.to_string(),
166            })
167    }
168    async fn contains(&self, entity: &str, metric: &str) -> bool {
169        self.inner
170            .read()
171            .contains_key(&(entity.to_string(), metric.to_string()))
172    }
173}
174
/// `baseline.compare` — kernel tool.
///
/// Looks up the baseline for an `(entity, metric)` pair in its store and
/// reports whether an observed value lies inside `mean ± k·σ`.
pub struct BaselineCompareTool {
    /// Store consulted on every invocation.
    store: Arc<dyn BaselineStore>,
}
179
180impl BaselineCompareTool {
181    pub const NAME: &'static str = "baseline.compare";
182
183    pub fn new(store: Arc<dyn BaselineStore>) -> Self {
184        Self { store }
185    }
186
187    pub fn arc(store: Arc<dyn BaselineStore>) -> Arc<dyn Tool> {
188        Arc::new(Self::new(store))
189    }
190}
191
192#[async_trait]
193impl Tool for BaselineCompareTool {
194    fn schema(&self) -> ToolSchema {
195        ToolSchema {
196            name: Self::NAME.into(),
197            description:
198                "Compare an observed value to the entity's baseline (mean +/- k*sigma). Returns availability and within-bound flags."
199                    .into(),
200            args_schema: json!({
201                "type": "object",
202                "required": ["entity", "metric", "value"],
203                "properties": {
204                    "entity": {"type": "string"},
205                    "metric": {"type": "string"},
206                    "value": {"type": "number"},
207                    "k": {"type": "number", "default": 2.0}
208                }
209            }),
210            result_schema: json!({"type": "object"}),
211        }
212    }
213
214    fn name(&self) -> rig_compose::tool::ToolName {
215        Self::NAME.to_string()
216    }
217
218    async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
219        #[derive(serde::Deserialize)]
220        struct Args {
221            entity: String,
222            metric: String,
223            value: f64,
224            #[serde(default = "default_k")]
225            k: f64,
226        }
227        fn default_k() -> f64 {
228            2.0
229        }
230        let parsed: Args = serde_json::from_value(args)?;
231        match self.store.get(&parsed.entity, &parsed.metric).await {
232            Ok(baseline) => Ok(json!({
233                "available": true,
234                "within": baseline.within(parsed.value, parsed.k),
235                "mean": baseline.mean,
236                "std_dev": baseline.std_dev,
237                "k": parsed.k,
238            })),
239            Err(_) => Ok(json!({
240                "available": false,
241                "within": false,
242                "k": parsed.k,
243            })),
244        }
245    }
246}
247
248/// Build a [`ResourceTraceEnvelope`] describing a single `baseline.compare`
249/// evaluation.
250///
251/// Pass `baseline` as `Some(&EntityBaseline)` when the store had a record
252/// for the `(entity, metric)` pair, or `None` to record a not-available
253/// comparison. The envelope mirrors the structure of
254/// [`crate::security_finding_trace_envelope`] and
255/// [`crate::memory_lookup_trace_envelope`] so audit and observability
256/// pipelines can route all three with one shape.
257///
258/// Reason codes:
259/// * `None` → [`TRACE_REASON_NOT_FOUND`]
260/// * `Some(_)` and inside `mean ± k·σ` → [`TRACE_REASON_WITHIN_BOUNDS`]
261/// * `Some(_)` and outside the bound → [`TRACE_REASON_EXCEEDS_BOUNDS`]
262///
263/// ```no_run
264/// use rig_resources::{EntityBaseline, baseline_compare_trace_envelope};
265///
266/// let baseline = EntityBaseline {
267///     entity: "host-1".into(),
268///     metric: "fanout".into(),
269///     mean: 10.0,
270///     std_dev: 2.0,
271///     samples: 100,
272/// };
273/// let envelope =
274///     baseline_compare_trace_envelope("host-1", "fanout", 11.0, 2.0, Some(&baseline));
275/// assert_eq!(envelope.resource, "baseline");
276/// assert_eq!(envelope.output_summary["within"], true);
277/// ```
278#[must_use]
279pub fn baseline_compare_trace_envelope(
280    entity: &str,
281    metric: &str,
282    observed: f64,
283    k: f64,
284    baseline: Option<&EntityBaseline>,
285) -> ResourceTraceEnvelope {
286    let input = json!({
287        "entity": entity,
288        "metric": metric,
289        "observed_value": observed,
290        "k": k,
291    });
292
293    let mut envelope = ResourceTraceEnvelope::new(TRACE_RESOURCE, TRACE_OPERATION, TRACE_KIND)
294        .with_input_summary(input);
295
296    match baseline {
297        None => {
298            envelope = envelope
299                .with_output_summary(json!({
300                    "available": false,
301                    "within": false,
302                }))
303                .with_reason(TRACE_REASON_NOT_FOUND);
304        }
305        Some(baseline) => {
306            let within = baseline.within(observed, k);
307            let bound = (k * baseline.std_dev).max(f64::EPSILON);
308            let deviation = (observed - baseline.mean).abs();
309            envelope = envelope
310                .with_output_summary(json!({
311                    "available": true,
312                    "within": within,
313                    "mean": baseline.mean,
314                    "std_dev": baseline.std_dev,
315                    "bound": bound,
316                    "deviation": deviation,
317                }))
318                .with_reason(if within {
319                    TRACE_REASON_WITHIN_BOUNDS
320                } else {
321                    TRACE_REASON_EXCEEDS_BOUNDS
322                });
323
324            let mut metadata = json!({
325                "samples": baseline.samples,
326            });
327            if baseline.std_dev > f64::EPSILON
328                && let Some(map) = metadata.as_object_mut()
329                && let Some(z) =
330                    serde_json::Number::from_f64((observed - baseline.mean) / baseline.std_dev)
331            {
332                map.insert("z_score".into(), Value::Number(z));
333            }
334            envelope = envelope.with_metadata(metadata);
335        }
336    }
337
338    envelope
339}
340
/// Unit tests for the baseline statistics, the in-memory store, the
/// `baseline.compare` tool and the trace-envelope builder.
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a baseline with a fixed sample count of 100.
    fn baseline(entity: &str, metric: &str, mean: f64, sd: f64) -> EntityBaseline {
        EntityBaseline {
            entity: entity.into(),
            metric: metric.into(),
            mean,
            std_dev: sd,
            samples: 100,
        }
    }

    // `within` is synchronous, so a plain `#[test]` is enough; no async
    // runtime needs to be started for it.
    #[test]
    fn within_bounds_check() {
        let b = baseline("e", "fanout", 10.0, 2.0);
        assert!(b.within(11.0, 2.0));
        assert!(!b.within(20.0, 2.0));
    }

    #[test]
    fn within_with_zero_sigma_accepts_only_the_mean() {
        let b = baseline("e", "fanout", 10.0, 0.0);
        assert!(b.within(10.0, 2.0));
        assert!(!b.within(10.5, 2.0));
    }

    #[test]
    fn online_stats_builds_entity_baseline() {
        let mut stats = OnlineStats::new();
        for value in [2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
            stats.push(value);
        }
        let baseline = stats.to_baseline("host", "bytes");
        assert_eq!(baseline.samples, 8);
        assert!((baseline.mean - 5.0).abs() < 1e-12);
        assert!((baseline.std_dev - 4.571_428_571_428_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn online_stats_needs_two_samples_for_variance() {
        let mut stats = OnlineStats::new();
        assert!(stats.is_empty());
        assert!(stats.variance().abs() < f64::EPSILON);

        stats.push(3.0);
        assert_eq!(stats.count(), 1);
        assert!((stats.mean() - 3.0).abs() < 1e-12);
        assert!(stats.std_dev().abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn store_put_then_get() {
        let store = InMemoryBaselineStore::new();
        store.put(baseline("e", "m", 5.0, 1.0)).await.unwrap();
        let got = store.get("e", "m").await.unwrap();
        assert_eq!(got.samples, 100);
        assert!(store.contains("e", "m").await);
    }

    #[tokio::test]
    async fn store_get_missing_returns_not_found() {
        let store = InMemoryBaselineStore::new();
        assert!(store.is_empty());
        assert!(!store.contains("ghost", "m").await);
        match store.get("ghost", "m").await {
            Err(BaselineError::NotFound { entity, metric }) => {
                assert_eq!(entity, "ghost");
                assert_eq!(metric, "m");
            }
            Ok(found) => panic!("unexpected baseline: {found:?}"),
        }
    }

    #[tokio::test]
    async fn tool_reports_available_and_within() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        store.put(baseline("e", "m", 100.0, 5.0)).await.unwrap();
        let tool = BaselineCompareTool::new(store);
        let out = tool
            .invoke(json!({"entity": "e", "metric": "m", "value": 102.0, "k": 2.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], true);
        assert_eq!(out["within"], true);
    }

    #[tokio::test]
    async fn tool_reports_unavailable_with_default_k() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        let tool = BaselineCompareTool::new(store);
        // `k` is omitted on purpose to exercise the serde default.
        let out = tool
            .invoke(json!({"entity": "ghost", "metric": "m", "value": 1.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], false);
        assert_eq!(out["within"], false);
        let k = out["k"].as_f64().unwrap();
        assert!((k - 2.0).abs() < 1e-12);
    }

    #[test]
    fn trace_envelope_within_bounds_includes_metadata() {
        let b = baseline("host-1", "fanout", 10.0, 2.0);
        let envelope = baseline_compare_trace_envelope("host-1", "fanout", 11.0, 2.0, Some(&b));

        assert_eq!(envelope.version, ResourceTraceEnvelope::VERSION);
        assert_eq!(envelope.resource, "baseline");
        assert_eq!(envelope.operation, "compare");
        assert_eq!(envelope.trace_kind, "baseline_compare");
        assert_eq!(envelope.input_summary["entity"], "host-1");
        assert_eq!(envelope.input_summary["metric"], "fanout");
        let observed = envelope.input_summary["observed_value"].as_f64().unwrap();
        assert!((observed - 11.0).abs() < 1e-9);
        assert_eq!(envelope.output_summary["available"], true);
        assert_eq!(envelope.output_summary["within"], true);
        let mean = envelope.output_summary["mean"].as_f64().unwrap();
        assert!((mean - 10.0).abs() < 1e-9);
        let bound = envelope.output_summary["bound"].as_f64().unwrap();
        assert!((bound - 4.0).abs() < 1e-9);
        assert_eq!(envelope.reason.as_deref(), Some(TRACE_REASON_WITHIN_BOUNDS));
        assert_eq!(envelope.metadata["samples"], 100);
        let z = envelope.metadata["z_score"].as_f64().unwrap();
        assert!((z - 0.5).abs() < 1e-9);
    }

    #[test]
    fn trace_envelope_exceeds_bounds_sets_reason() {
        let b = baseline("host-1", "fanout", 10.0, 2.0);
        let envelope = baseline_compare_trace_envelope("host-1", "fanout", 20.0, 2.0, Some(&b));
        assert_eq!(envelope.output_summary["within"], false);
        assert_eq!(
            envelope.reason.as_deref(),
            Some(TRACE_REASON_EXCEEDS_BOUNDS)
        );
        let deviation = envelope.output_summary["deviation"].as_f64().unwrap();
        assert!((deviation - 10.0).abs() < 1e-9);
    }

    #[test]
    fn trace_envelope_zero_sigma_omits_z_score() {
        let b = baseline("host-1", "fanout", 10.0, 0.0);
        let envelope = baseline_compare_trace_envelope("host-1", "fanout", 10.0, 2.0, Some(&b));
        assert_eq!(envelope.output_summary["within"], true);
        assert_eq!(envelope.metadata["samples"], 100);
        // A degenerate spread has no meaningful z-score, so the key is absent.
        assert!(envelope.metadata["z_score"].is_null());
    }

    #[test]
    fn trace_envelope_not_found_omits_baseline_fields() {
        let envelope = baseline_compare_trace_envelope("ghost", "metric", 7.0, 2.0, None);
        assert_eq!(envelope.output_summary["available"], false);
        assert_eq!(envelope.output_summary["within"], false);
        assert!(envelope.output_summary.get("mean").is_none());
        assert_eq!(envelope.reason.as_deref(), Some(TRACE_REASON_NOT_FOUND));
        assert!(envelope.metadata.is_null());
    }
}
443}