Skip to main content

rig_resources/
baseline.rs

1//! Environmental baselines and the `baseline.compare` tool.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10use thiserror::Error;
11
12use rig_compose::{KernelError, Tool, ToolSchema};
13
14use crate::trace::ResourceTraceEnvelope;
15
// Identifiers stamped on every envelope produced by
// `baseline_compare_trace_envelope`: the resource name, the operation
// performed on it, and the trace kind.
const TRACE_RESOURCE: &str = "baseline";
const TRACE_OPERATION: &str = "compare";
const TRACE_KIND: &str = "baseline_compare";

/// Trace reason emitted when the store held no baseline for the requested
/// `(entity, metric)` pair.
pub const TRACE_REASON_NOT_FOUND: &str = "baseline_not_found";
/// Trace reason emitted when the observed value fell inside the
/// `mean ± k·σ` bound.
pub const TRACE_REASON_WITHIN_BOUNDS: &str = "within_bounds";
/// Trace reason emitted when the observed value fell outside the
/// `mean ± k·σ` bound.
pub const TRACE_REASON_EXCEEDS_BOUNDS: &str = "exceeds_bounds";
27
/// Errors returned by baseline stores.
#[derive(Debug, Error)]
pub enum BaselineError {
    /// No baseline exists for the requested `(entity, metric)` pair.
    ///
    /// The `Display` message names both identifiers, e.g.
    /// ``baseline `host-1/fanout` not found``.
    #[error("baseline `{entity}/{metric}` not found")]
    NotFound {
        /// Entity identifier that was looked up.
        entity: String,
        /// Metric identifier that was looked up.
        metric: String,
    },
}
40
41/// Statistical envelope for one (entity, metric) pair.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct EntityBaseline {
44    /// Entity identifier (host, user, service, etc.).
45    pub entity: String,
46    /// Metric name represented by this baseline.
47    pub metric: String,
48    /// Observed mean.
49    pub mean: f64,
50    /// Sample standard deviation.
51    pub std_dev: f64,
52    /// Number of observations used to build the baseline.
53    pub samples: u64,
54}
55
56impl EntityBaseline {
57    /// Build a baseline envelope from online statistics.
58    pub fn from_stats(
59        entity: impl Into<String>,
60        metric: impl Into<String>,
61        stats: &OnlineStats,
62    ) -> Self {
63        Self {
64            entity: entity.into(),
65            metric: metric.into(),
66            mean: stats.mean(),
67            std_dev: stats.std_dev(),
68            samples: stats.count(),
69        }
70    }
71
72    /// Return `true` when `value` falls within `mean ± k * std_dev`.
73    pub fn within(&self, value: f64, k: f64) -> bool {
74        let bound = (k * self.std_dev).max(f64::EPSILON);
75        (value - self.mean).abs() <= bound
76    }
77}
78
79/// Online mean/variance accumulator for building [`EntityBaseline`] values.
80///
81/// Uses Welford's algorithm, so callers can update an environmental baseline
82/// one observation at a time without storing raw samples.
83#[derive(Debug, Clone, Default, Serialize, Deserialize)]
84pub struct OnlineStats {
85    count: u64,
86    mean: f64,
87    m2: f64,
88}
89
90impl OnlineStats {
91    /// Create an empty accumulator.
92    pub fn new() -> Self {
93        Self::default()
94    }
95
96    /// Add one sample to the accumulator.
97    pub fn push(&mut self, value: f64) {
98        self.count = self.count.saturating_add(1);
99        let delta = value - self.mean;
100        self.mean += delta / self.count as f64;
101        let delta2 = value - self.mean;
102        self.m2 += delta * delta2;
103    }
104
105    /// Number of samples observed.
106    pub fn count(&self) -> u64 {
107        self.count
108    }
109
110    /// Whether no samples have been observed.
111    pub fn is_empty(&self) -> bool {
112        self.count == 0
113    }
114
115    /// Current mean, or `0.0` before the first sample.
116    pub fn mean(&self) -> f64 {
117        self.mean
118    }
119
120    /// Sample variance. Returns `0.0` until at least two samples exist.
121    pub fn variance(&self) -> f64 {
122        if self.count < 2 {
123            0.0
124        } else {
125            self.m2 / (self.count - 1) as f64
126        }
127    }
128
129    /// Sample standard deviation.
130    pub fn std_dev(&self) -> f64 {
131        self.variance().sqrt()
132    }
133
134    /// Convert the accumulated stats into an [`EntityBaseline`].
135    pub fn to_baseline(
136        &self,
137        entity: impl Into<String>,
138        metric: impl Into<String>,
139    ) -> EntityBaseline {
140        EntityBaseline::from_stats(entity, metric, self)
141    }
142}
143
144/// Storage contract for entity/metric baselines.
145#[async_trait]
146pub trait BaselineStore: Send + Sync {
147    /// Insert or replace a baseline.
148    async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError>;
149    /// Fetch one baseline by entity and metric.
150    async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError>;
151    /// Return `true` when a baseline exists for entity and metric.
152    async fn contains(&self, entity: &str, metric: &str) -> bool;
153}
154
155/// In-memory baseline store for tests, examples, and single-process agents.
156#[derive(Clone, Default)]
157pub struct InMemoryBaselineStore {
158    inner: Arc<RwLock<HashMap<(String, String), EntityBaseline>>>,
159}
160
161impl InMemoryBaselineStore {
162    /// Create an empty store.
163    pub fn new() -> Self {
164        Self::default()
165    }
166    /// Create an empty store wrapped in [`Arc`].
167    pub fn arc() -> Arc<Self> {
168        Arc::new(Self::new())
169    }
170    /// Number of baselines stored.
171    pub fn len(&self) -> usize {
172        self.inner.read().len()
173    }
174    /// Whether the store contains no baselines.
175    pub fn is_empty(&self) -> bool {
176        self.inner.read().is_empty()
177    }
178}
179
180#[async_trait]
181impl BaselineStore for InMemoryBaselineStore {
182    async fn put(&self, baseline: EntityBaseline) -> Result<(), BaselineError> {
183        self.inner
184            .write()
185            .insert((baseline.entity.clone(), baseline.metric.clone()), baseline);
186        Ok(())
187    }
188    async fn get(&self, entity: &str, metric: &str) -> Result<EntityBaseline, BaselineError> {
189        self.inner
190            .read()
191            .get(&(entity.to_string(), metric.to_string()))
192            .cloned()
193            .ok_or_else(|| BaselineError::NotFound {
194                entity: entity.to_string(),
195                metric: metric.to_string(),
196            })
197    }
198    async fn contains(&self, entity: &str, metric: &str) -> bool {
199        self.inner
200            .read()
201            .contains_key(&(entity.to_string(), metric.to_string()))
202    }
203}
204
205/// `baseline.compare` — kernel tool.
206pub struct BaselineCompareTool {
207    store: Arc<dyn BaselineStore>,
208}
209
210impl BaselineCompareTool {
211    /// Canonical tool name registered with `rig-compose`.
212    pub const NAME: &'static str = "baseline.compare";
213
214    /// Build a tool backed by `store`.
215    pub fn new(store: Arc<dyn BaselineStore>) -> Self {
216        Self { store }
217    }
218
219    /// Build a trait-object handle suitable for direct registry insertion.
220    pub fn arc(store: Arc<dyn BaselineStore>) -> Arc<dyn Tool> {
221        Arc::new(Self::new(store))
222    }
223}
224
225#[async_trait]
226impl Tool for BaselineCompareTool {
227    fn schema(&self) -> ToolSchema {
228        ToolSchema {
229            name: Self::NAME.into(),
230            description:
231                "Compare an observed value to the entity's baseline (mean +/- k*sigma). Returns availability and within-bound flags."
232                    .into(),
233            args_schema: json!({
234                "type": "object",
235                "required": ["entity", "metric", "value"],
236                "properties": {
237                    "entity": {"type": "string"},
238                    "metric": {"type": "string"},
239                    "value": {"type": "number"},
240                    "k": {"type": "number", "default": 2.0}
241                }
242            }),
243            result_schema: json!({"type": "object"}),
244        }
245    }
246
247    fn name(&self) -> rig_compose::tool::ToolName {
248        Self::NAME.to_string()
249    }
250
251    async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
252        #[derive(serde::Deserialize)]
253        struct Args {
254            entity: String,
255            metric: String,
256            value: f64,
257            #[serde(default = "default_k")]
258            k: f64,
259        }
260        fn default_k() -> f64 {
261            2.0
262        }
263        let parsed: Args = serde_json::from_value(args)?;
264        match self.store.get(&parsed.entity, &parsed.metric).await {
265            Ok(baseline) => Ok(json!({
266                "available": true,
267                "within": baseline.within(parsed.value, parsed.k),
268                "mean": baseline.mean,
269                "std_dev": baseline.std_dev,
270                "k": parsed.k,
271            })),
272            Err(_) => Ok(json!({
273                "available": false,
274                "within": false,
275                "k": parsed.k,
276            })),
277        }
278    }
279}
280
281/// Build a [`ResourceTraceEnvelope`] describing a single `baseline.compare`
282/// evaluation.
283///
284/// Pass `baseline` as `Some(&EntityBaseline)` when the store had a record
285/// for the `(entity, metric)` pair, or `None` to record a not-available
286/// comparison. The envelope mirrors the structure of
287/// [`crate::security_finding_trace_envelope`] and
288/// [`crate::memory_lookup_trace_envelope`] so audit and observability
289/// pipelines can route all three with one shape.
290///
291/// Reason codes:
292/// * `None` → [`TRACE_REASON_NOT_FOUND`]
293/// * `Some(_)` and inside `mean ± k·σ` → [`TRACE_REASON_WITHIN_BOUNDS`]
294/// * `Some(_)` and outside the bound → [`TRACE_REASON_EXCEEDS_BOUNDS`]
295///
296/// ```no_run
297/// use rig_resources::{EntityBaseline, baseline_compare_trace_envelope};
298///
299/// let baseline = EntityBaseline {
300///     entity: "host-1".into(),
301///     metric: "fanout".into(),
302///     mean: 10.0,
303///     std_dev: 2.0,
304///     samples: 100,
305/// };
306/// let envelope =
307///     baseline_compare_trace_envelope("host-1", "fanout", 11.0, 2.0, Some(&baseline));
308/// assert_eq!(envelope.resource, "baseline");
309/// assert_eq!(envelope.output_summary["within"], true);
310/// ```
311#[must_use]
312pub fn baseline_compare_trace_envelope(
313    entity: &str,
314    metric: &str,
315    observed: f64,
316    k: f64,
317    baseline: Option<&EntityBaseline>,
318) -> ResourceTraceEnvelope {
319    let input = json!({
320        "entity": entity,
321        "metric": metric,
322        "observed_value": observed,
323        "k": k,
324    });
325
326    let mut envelope = ResourceTraceEnvelope::new(TRACE_RESOURCE, TRACE_OPERATION, TRACE_KIND)
327        .with_input_summary(input);
328
329    match baseline {
330        None => {
331            envelope = envelope
332                .with_output_summary(json!({
333                    "available": false,
334                    "within": false,
335                }))
336                .with_reason(TRACE_REASON_NOT_FOUND);
337        }
338        Some(baseline) => {
339            let within = baseline.within(observed, k);
340            let bound = (k * baseline.std_dev).max(f64::EPSILON);
341            let deviation = (observed - baseline.mean).abs();
342            envelope = envelope
343                .with_output_summary(json!({
344                    "available": true,
345                    "within": within,
346                    "mean": baseline.mean,
347                    "std_dev": baseline.std_dev,
348                    "bound": bound,
349                    "deviation": deviation,
350                }))
351                .with_reason(if within {
352                    TRACE_REASON_WITHIN_BOUNDS
353                } else {
354                    TRACE_REASON_EXCEEDS_BOUNDS
355                });
356
357            let mut metadata = json!({
358                "samples": baseline.samples,
359            });
360            if baseline.std_dev > f64::EPSILON
361                && let Some(map) = metadata.as_object_mut()
362                && let Some(z) =
363                    serde_json::Number::from_f64((observed - baseline.mean) / baseline.std_dev)
364            {
365                map.insert("z_score".into(), Value::Number(z));
366            }
367            envelope = envelope.with_metadata(metadata);
368        }
369    }
370
371    envelope
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    fn baseline(entity: &str, metric: &str, mean: f64, sd: f64) -> EntityBaseline {
        EntityBaseline {
            entity: entity.into(),
            metric: metric.into(),
            mean,
            std_dev: sd,
            samples: 100,
        }
    }

    // Pure synchronous logic: no async runtime needed.
    #[test]
    fn within_bounds_check() {
        let b = baseline("e", "fanout", 10.0, 2.0);
        assert!(b.within(11.0, 2.0));
        // The bound is inclusive: |14 - 10| == 2 * 2.
        assert!(b.within(14.0, 2.0));
        assert!(!b.within(20.0, 2.0));
    }

    #[test]
    fn zero_variance_baseline_accepts_only_its_mean() {
        let b = baseline("e", "m", 3.0, 0.0);
        assert!(b.within(3.0, 2.0));
        assert!(!b.within(3.5, 2.0));
    }

    #[test]
    fn online_stats_builds_entity_baseline() {
        let mut stats = OnlineStats::new();
        for value in [2.0_f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
            stats.push(value);
        }
        let baseline = stats.to_baseline("host", "bytes");
        assert_eq!(baseline.samples, 8);
        assert!((baseline.mean - 5.0).abs() < 1e-12);
        // Squared deviations from the mean sum to 32, so the sample variance
        // is exactly 32 / 7.
        assert!((baseline.std_dev - (32.0_f64 / 7.0).sqrt()).abs() < 1e-12);
    }

    #[test]
    fn online_stats_variance_needs_two_samples() {
        let mut stats = OnlineStats::new();
        assert!(stats.is_empty());
        assert_eq!(stats.variance(), 0.0);
        stats.push(42.0);
        assert!(!stats.is_empty());
        assert_eq!(stats.count(), 1);
        assert_eq!(stats.mean(), 42.0);
        assert_eq!(stats.variance(), 0.0);
    }

    #[tokio::test]
    async fn store_put_then_get() {
        let store = InMemoryBaselineStore::new();
        store.put(baseline("e", "m", 5.0, 1.0)).await.unwrap();
        let got = store.get("e", "m").await.unwrap();
        assert_eq!(got.samples, 100);
        assert!(store.contains("e", "m").await);
    }

    #[tokio::test]
    async fn store_get_missing_returns_not_found() {
        let store = InMemoryBaselineStore::new();
        let err = store.get("e", "m").await.unwrap_err();
        assert_eq!(err.to_string(), "baseline `e/m` not found");
        assert!(!store.contains("e", "m").await);
        assert!(store.is_empty());
    }

    #[tokio::test]
    async fn tool_reports_available_and_within() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        store.put(baseline("e", "m", 100.0, 5.0)).await.unwrap();
        let tool = BaselineCompareTool::new(store);
        let out = tool
            .invoke(json!({"entity": "e", "metric": "m", "value": 102.0, "k": 2.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], true);
        assert_eq!(out["within"], true);
    }

    #[tokio::test]
    async fn tool_reports_unavailable_for_missing_baseline() {
        let store: Arc<dyn BaselineStore> = Arc::new(InMemoryBaselineStore::new());
        let tool = BaselineCompareTool::new(store);
        let out = tool
            .invoke(json!({"entity": "e", "metric": "m", "value": 1.0}))
            .await
            .unwrap();
        assert_eq!(out["available"], false);
        assert_eq!(out["within"], false);
        // `k` falls back to its documented default.
        assert_eq!(out["k"], 2.0);
    }

    #[test]
    fn trace_envelope_within_bounds_includes_metadata() {
        let b = baseline("host-1", "fanout", 10.0, 2.0);
        let envelope = baseline_compare_trace_envelope("host-1", "fanout", 11.0, 2.0, Some(&b));

        assert_eq!(envelope.version, ResourceTraceEnvelope::VERSION);
        assert_eq!(envelope.resource, "baseline");
        assert_eq!(envelope.operation, "compare");
        assert_eq!(envelope.trace_kind, "baseline_compare");
        assert_eq!(envelope.input_summary["entity"], "host-1");
        assert_eq!(envelope.input_summary["metric"], "fanout");
        let observed = envelope.input_summary["observed_value"].as_f64().unwrap();
        assert!((observed - 11.0).abs() < 1e-9);
        assert_eq!(envelope.output_summary["available"], true);
        assert_eq!(envelope.output_summary["within"], true);
        let mean = envelope.output_summary["mean"].as_f64().unwrap();
        assert!((mean - 10.0).abs() < 1e-9);
        let bound = envelope.output_summary["bound"].as_f64().unwrap();
        assert!((bound - 4.0).abs() < 1e-9);
        assert_eq!(envelope.reason.as_deref(), Some(TRACE_REASON_WITHIN_BOUNDS));
        assert_eq!(envelope.metadata["samples"], 100);
        let z = envelope.metadata["z_score"].as_f64().unwrap();
        assert!((z - 0.5).abs() < 1e-9);
    }

    #[test]
    fn trace_envelope_exceeds_bounds_sets_reason() {
        let b = baseline("host-1", "fanout", 10.0, 2.0);
        let envelope = baseline_compare_trace_envelope("host-1", "fanout", 20.0, 2.0, Some(&b));
        assert_eq!(envelope.output_summary["within"], false);
        assert_eq!(
            envelope.reason.as_deref(),
            Some(TRACE_REASON_EXCEEDS_BOUNDS)
        );
        let deviation = envelope.output_summary["deviation"].as_f64().unwrap();
        assert!((deviation - 10.0).abs() < 1e-9);
    }

    #[test]
    fn trace_envelope_not_found_omits_baseline_fields() {
        let envelope = baseline_compare_trace_envelope("ghost", "metric", 7.0, 2.0, None);
        assert_eq!(envelope.output_summary["available"], false);
        assert_eq!(envelope.output_summary["within"], false);
        assert!(envelope.output_summary.get("mean").is_none());
        assert_eq!(envelope.reason.as_deref(), Some(TRACE_REASON_NOT_FOUND));
        assert!(envelope.metadata.is_null());
    }
}