Skip to main content

chio_guards/
spider_sense.rs

1//! SpiderSense embedding detector - cosine-similarity anomaly detection.
2//!
3//! Roadmap phase 5.4.  Ported from ClawdStrike's `spider_sense.rs` sync
4//! detector and wrapped in Chio's synchronous [`chio_kernel::Guard`] trait.
5//!
6//! The guard compares a per-request embedding vector against a pre-loaded
7//! pattern database of known-threat embeddings using cosine similarity.
8//! Top-K scoring + thresholded verdict with a configurable ambiguity band:
9//!
10//! - `top_score >= threshold + ambiguity_band` → [`Verdict::Deny`];
11//! - `top_score <= threshold - ambiguity_band` → [`Verdict::Allow`];
12//! - scores inside the band fall back to the configured
13//!   [`AmbiguousPolicy`] (default: [`AmbiguousPolicy::Allow`]).
14//!
15//! Embedding extraction from tool-call arguments:
16//!
17//! 1. A top-level `embedding` / `vector` array of numbers is preferred.
18//! 2. An `embeddings` field of shape `[[f32; D], ...]` is averaged.
19//! 3. Otherwise the guard returns [`Verdict::Allow`] (no embedding → no
20//!    signal; the guard does not try to hash text into a pseudo-embedding
21//!    because downstream consumers rely on explicit embeddings from the
22//!    upstream SpiderSense model).
23//!
24//! Fail-closed semantics:
25//!
26//! - malformed pattern JSON at construction time → [`SpiderSenseError`];
27//! - non-finite values in a request embedding → [`Verdict::Deny`];
28//! - embedding-dimension mismatch with the pattern DB → [`Verdict::Deny`];
29//! - cosine norm collapse (zero vector) → similarity score `0.0` (not
30//!   NaN).
31//!
32//! Hand-rolled f64-accumulated dot product avoids any dependency on
33//! `ndarray` / BLAS.
34
35use std::sync::Arc;
36
37use serde::{Deserialize, Serialize};
38use serde_json::Value;
39use thiserror::Error;
40
41use chio_kernel::{Guard, GuardContext, KernelError, Verdict};
42
/// Cosine-similarity threshold used when a configuration does not override it.
pub const DEFAULT_SIMILARITY_THRESHOLD: f64 = 0.85;
/// Half-width of the ambiguity band placed around the threshold by default.
pub const DEFAULT_AMBIGUITY_BAND: f64 = 0.1;
/// Number of best-matching patterns a query retains by default.
pub const DEFAULT_TOP_K: usize = 5;
49
50/// Policy for scores landing inside the ambiguity band.
51#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize, PartialEq, Eq)]
52#[serde(rename_all = "snake_case")]
53pub enum AmbiguousPolicy {
54    /// Treat ambiguous scores as benign (default).
55    #[default]
56    Allow,
57    /// Treat ambiguous scores as threats.
58    Deny,
59}
60
61/// Errors from [`SpiderSenseGuard`] construction.
62#[derive(Debug, Error)]
63pub enum SpiderSenseError {
64    /// Pattern JSON failed to parse.
65    #[error("pattern database parse error: {0}")]
66    Parse(String),
67    /// Pattern database is empty or has inconsistent dimensionality.
68    #[error("pattern database is invalid: {0}")]
69    Invalid(String),
70    /// Configuration value is out of range.
71    #[error("invalid configuration: {0}")]
72    Config(String),
73    /// I/O error reading the pattern database from disk.
74    #[error("failed to read pattern database: {0}")]
75    Io(String),
76}
77
78/// Configuration for [`SpiderSenseGuard`].
79#[derive(Clone, Debug, Deserialize, Serialize)]
80#[serde(deny_unknown_fields)]
81pub struct SpiderSenseConfig {
82    /// Cosine similarity threshold.  Scores ≥ `threshold + ambiguity_band`
83    /// are denied; scores ≤ `threshold - ambiguity_band` are allowed.
84    #[serde(default = "default_threshold")]
85    pub similarity_threshold: f64,
86    /// Half-width of the ambiguity band around the threshold.
87    #[serde(default = "default_band")]
88    pub ambiguity_band: f64,
89    /// Number of top matches retained per query.
90    #[serde(default = "default_top_k")]
91    pub top_k: usize,
92    /// Policy for ambiguous scores (inside the band).
93    #[serde(default)]
94    pub ambiguous_policy: AmbiguousPolicy,
95}
96
97fn default_threshold() -> f64 {
98    DEFAULT_SIMILARITY_THRESHOLD
99}
100fn default_band() -> f64 {
101    DEFAULT_AMBIGUITY_BAND
102}
103fn default_top_k() -> usize {
104    DEFAULT_TOP_K
105}
106
107impl Default for SpiderSenseConfig {
108    fn default() -> Self {
109        Self {
110            similarity_threshold: DEFAULT_SIMILARITY_THRESHOLD,
111            ambiguity_band: DEFAULT_AMBIGUITY_BAND,
112            top_k: DEFAULT_TOP_K,
113            ambiguous_policy: AmbiguousPolicy::Allow,
114        }
115    }
116}
117
118/// A single entry in the pattern database.
119#[derive(Clone, Debug, Deserialize, Serialize)]
120pub struct PatternEntry {
121    /// Stable identifier.
122    pub id: String,
123    /// Threat category (e.g., `prompt_injection`, `data_exfiltration`).
124    pub category: String,
125    /// SpiderSense stage (perception / cognition / action / feedback).
126    pub stage: String,
127    /// Human-readable label.
128    pub label: String,
129    /// Pre-computed embedding vector.
130    pub embedding: Vec<f32>,
131}
132
133/// Immutable pattern database loaded at construction time.
134#[derive(Clone, Debug)]
135pub struct PatternDb {
136    entries: Arc<Vec<PatternEntry>>,
137    dim: usize,
138}
139
140impl PatternDb {
141    /// Parse a JSON array of [`PatternEntry`] values.  Validates:
142    ///
143    /// - array is non-empty;
144    /// - all embeddings share the same non-zero dimensionality;
145    /// - every embedding value is finite.
146    pub fn from_json(json: &str) -> Result<Self, SpiderSenseError> {
147        let entries: Vec<PatternEntry> =
148            serde_json::from_str(json).map_err(|e| SpiderSenseError::Parse(e.to_string()))?;
149        Self::from_entries(entries)
150    }
151
152    /// Build from an explicit entry vector (convenience for tests).
153    pub fn from_entries(entries: Vec<PatternEntry>) -> Result<Self, SpiderSenseError> {
154        if entries.is_empty() {
155            return Err(SpiderSenseError::Invalid(
156                "pattern database must contain at least one entry".into(),
157            ));
158        }
159        let dim = entries[0].embedding.len();
160        if dim == 0 {
161            return Err(SpiderSenseError::Invalid(
162                "pattern embeddings must be non-empty".into(),
163            ));
164        }
165        for (i, entry) in entries.iter().enumerate() {
166            if entry.embedding.len() != dim {
167                return Err(SpiderSenseError::Invalid(format!(
168                    "dimension mismatch at index {i}: expected {dim}, got {}",
169                    entry.embedding.len()
170                )));
171            }
172            if let Some(j) = entry.embedding.iter().position(|v| !v.is_finite()) {
173                return Err(SpiderSenseError::Invalid(format!(
174                    "entry {i} has non-finite embedding value at dimension {j}"
175                )));
176            }
177        }
178        Ok(Self {
179            entries: Arc::new(entries),
180            dim,
181        })
182    }
183
184    /// Expected embedding dimensionality.
185    pub fn dim(&self) -> usize {
186        self.dim
187    }
188
189    /// Number of patterns in the database.
190    pub fn len(&self) -> usize {
191        self.entries.len()
192    }
193
194    /// Whether the database is empty (always `false` after construction).
195    pub fn is_empty(&self) -> bool {
196        self.entries.is_empty()
197    }
198}
199
200/// SpiderSense embedding detector guard.
201pub struct SpiderSenseGuard {
202    db: PatternDb,
203    upper: f64,
204    lower: f64,
205    top_k: usize,
206    ambiguous_policy: AmbiguousPolicy,
207}
208
209impl SpiderSenseGuard {
210    /// Build a guard from a pattern database and configuration.
211    pub fn new(db: PatternDb, config: SpiderSenseConfig) -> Result<Self, SpiderSenseError> {
212        if !config.similarity_threshold.is_finite()
213            || !(0.0..=1.0).contains(&config.similarity_threshold)
214        {
215            return Err(SpiderSenseError::Config(format!(
216                "similarity_threshold must be finite in [0.0, 1.0], got {}",
217                config.similarity_threshold
218            )));
219        }
220        if !config.ambiguity_band.is_finite() || !(0.0..=1.0).contains(&config.ambiguity_band) {
221            return Err(SpiderSenseError::Config(format!(
222                "ambiguity_band must be finite in [0.0, 1.0], got {}",
223                config.ambiguity_band
224            )));
225        }
226        let upper = config.similarity_threshold + config.ambiguity_band;
227        let lower = config.similarity_threshold - config.ambiguity_band;
228        if !(0.0..=1.0).contains(&upper) || !(0.0..=1.0).contains(&lower) {
229            return Err(SpiderSenseError::Config(format!(
230                "threshold ± band must stay inside [0.0, 1.0]; got lower={lower:.3}, upper={upper:.3}"
231            )));
232        }
233        if config.top_k == 0 {
234            return Err(SpiderSenseError::Config("top_k must be ≥ 1".into()));
235        }
236        Ok(Self {
237            db,
238            upper,
239            lower,
240            top_k: config.top_k,
241            ambiguous_policy: config.ambiguous_policy,
242        })
243    }
244
245    /// Convenience: build from a JSON pattern database string and defaults.
246    pub fn from_json(json: &str) -> Result<Self, SpiderSenseError> {
247        let db = PatternDb::from_json(json)?;
248        Self::new(db, SpiderSenseConfig::default())
249    }
250
251    /// Read a pattern database from a JSON file on disk.
252    pub fn from_json_file(path: &str) -> Result<Self, SpiderSenseError> {
253        let data = std::fs::read_to_string(path)
254            .map_err(|e| SpiderSenseError::Io(format!("{path}: {e}")))?;
255        Self::from_json(&data)
256    }
257
258    /// Score an embedding against the pattern database.  Returns the
259    /// cosine similarity of the best-matching pattern (0.0 if the
260    /// embedding is invalid).
261    pub fn score(&self, embedding: &[f32]) -> f64 {
262        if embedding.len() != self.db.dim {
263            return 0.0;
264        }
265        if embedding.iter().any(|v| !v.is_finite()) {
266            return 0.0;
267        }
268        let mut best = 0.0_f64;
269        let mut seen = 0usize;
270        for entry in self.db.entries.iter() {
271            let score = cosine_similarity(embedding, &entry.embedding);
272            if score > best {
273                best = score;
274            }
275            seen += 1;
276            if seen >= self.top_k {
277                // We keep scanning - top_k is informational, not a cap on
278                // work, because the DB is typically small.  Break only
279                // when the scan has observed at least top_k entries; we
280                // still want the maximum across the full DB.
281                // (Equivalent to Clawdstrike's sort+truncate.)
282            }
283        }
284        best
285    }
286
287    /// Decide a verdict for a given top-score.
288    fn verdict_for(&self, score: f64) -> Verdict {
289        if !score.is_finite() {
290            return Verdict::Deny;
291        }
292        if score >= self.upper {
293            Verdict::Deny
294        } else if score <= self.lower {
295            Verdict::Allow
296        } else {
297            match self.ambiguous_policy {
298                AmbiguousPolicy::Allow => Verdict::Allow,
299                AmbiguousPolicy::Deny => Verdict::Deny,
300            }
301        }
302    }
303
304    /// Number of patterns in the database.
305    pub fn pattern_count(&self) -> usize {
306        self.db.len()
307    }
308
309    /// Pattern database dimensionality.
310    pub fn dim(&self) -> usize {
311        self.db.dim()
312    }
313}
314
315impl Guard for SpiderSenseGuard {
316    fn name(&self) -> &str {
317        "spider-sense"
318    }
319
320    fn evaluate(&self, ctx: &GuardContext) -> Result<Verdict, KernelError> {
321        let embedding = match extract_embedding(&ctx.request.arguments) {
322            Some(e) => e,
323            None => return Ok(Verdict::Allow),
324        };
325        if embedding.len() != self.db.dim {
326            // Dimension mismatch = fail-closed.
327            return Ok(Verdict::Deny);
328        }
329        if embedding.iter().any(|v| !v.is_finite()) {
330            return Ok(Verdict::Deny);
331        }
332        let score = self.score(&embedding);
333        Ok(self.verdict_for(score))
334    }
335}
336
337/// Extract a query embedding vector from tool-call arguments.
338///
339/// Preferred shapes (first match wins):
340///
341/// 1. `embedding: [f32; D]`
342/// 2. `vector: [f32; D]`
343/// 3. `embeddings: [[f32; D], ...]` → mean-pooled to a single vector
344///
345/// Returns `None` if no recognised embedding field is present.
346pub fn extract_embedding(arguments: &Value) -> Option<Vec<f32>> {
347    if let Some(vec) = arguments
348        .get("embedding")
349        .or_else(|| arguments.get("vector"))
350        .and_then(array_as_f32_vec)
351    {
352        return Some(vec);
353    }
354    if let Some(array) = arguments.get("embeddings").and_then(|v| v.as_array()) {
355        let vectors: Vec<Vec<f32>> = array.iter().filter_map(array_as_f32_vec).collect();
356        if vectors.is_empty() {
357            return None;
358        }
359        let dim = vectors[0].len();
360        if dim == 0 || vectors.iter().any(|v| v.len() != dim) {
361            return None;
362        }
363        let mut sum = vec![0.0_f64; dim];
364        for v in &vectors {
365            for (i, x) in v.iter().enumerate() {
366                sum[i] += f64::from(*x);
367            }
368        }
369        let n = vectors.len() as f64;
370        return Some(sum.into_iter().map(|s| (s / n) as f32).collect());
371    }
372    None
373}
374
375fn array_as_f32_vec(value: &Value) -> Option<Vec<f32>> {
376    let array = value.as_array()?;
377    if array.is_empty() {
378        return None;
379    }
380    let mut out = Vec::with_capacity(array.len());
381    for v in array {
382        let n = v.as_f64()?;
383        if !n.is_finite() {
384            return None;
385        }
386        out.push(n as f32);
387    }
388    Some(out)
389}
390
/// Cosine similarity of two `f32` vectors, accumulated in `f64`.
///
/// Returns `0.0` in any of these cases:
///
/// - the lengths differ, or the vectors are empty;
/// - a component of either vector is NaN or ±inf;
/// - either vector's L2 norm makes the denominator zero or subnormal;
/// - the quotient is non-finite.
///
/// Otherwise the result is clamped to `[-1.0, 1.0]`, so rounding can never
/// report a similarity above 1.0 (e.g. a vector compared with itself).
///
/// This is intentionally hand-rolled (no `ndarray`) to keep the
/// dependency surface minimal and WASM-friendly.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let mut dot = 0.0_f64;
    let mut norm_a_sq = 0.0_f64;
    let mut norm_b_sq = 0.0_f64;
    for (&x, &y) in a.iter().zip(b) {
        let (x, y) = (f64::from(x), f64::from(y));
        if !(x.is_finite() && y.is_finite()) {
            return 0.0;
        }
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }
    // Multiply the square roots (rather than taking the root of the product)
    // to keep the intermediate value in range.
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if !denom.is_normal() {
        return 0.0;
    }
    let r = dot / denom;
    if r.is_finite() {
        r.clamp(-1.0, 1.0)
    } else {
        0.0
    }
}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Two orthogonal 3-D patterns: `a` along x, `b` along y.
    fn sample_db() -> PatternDb {
        PatternDb::from_json(
            r#"[
                {"id":"a","category":"x","stage":"perception","label":"l","embedding":[1.0,0.0,0.0]},
                {"id":"b","category":"y","stage":"action","label":"l","embedding":[0.0,1.0,0.0]}
            ]"#,
        )
        .expect("sample DB parses")
    }

    /// Guard over `sample_db` with the default configuration.
    fn default_guard() -> SpiderSenseGuard {
        SpiderSenseGuard::new(sample_db(), SpiderSenseConfig::default()).expect("build")
    }

    /// Minimal pattern entry carrying only the given embedding.
    fn entry(embedding: Vec<f32>) -> PatternEntry {
        PatternEntry {
            id: "t".into(),
            category: "c".into(),
            stage: "s".into(),
            label: "l".into(),
            embedding,
        }
    }

    #[test]
    fn cosine_basics() {
        assert!((cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 1e-9);
        assert!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-9);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
        assert_eq!(cosine_similarity(&[f32::NAN, 0.0], &[1.0, 0.0]), 0.0);
        assert_eq!(cosine_similarity(&[f32::INFINITY, 0.0], &[1.0, 0.0]), 0.0);
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 0.0]), 0.0);
    }

    #[test]
    fn pattern_db_rejects_empty() {
        assert!(matches!(
            PatternDb::from_json("[]"),
            Err(SpiderSenseError::Invalid(_))
        ));
    }

    #[test]
    fn pattern_db_rejects_dim_mismatch() {
        let json = r#"[
            {"id":"a","category":"x","stage":"s","label":"l","embedding":[1.0,0.0]},
            {"id":"b","category":"y","stage":"s","label":"l","embedding":[1.0]}
        ]"#;
        assert!(matches!(
            PatternDb::from_json(json),
            Err(SpiderSenseError::Invalid(_))
        ));
    }

    #[test]
    fn pattern_db_rejects_zero_dim() {
        assert!(matches!(
            PatternDb::from_entries(vec![entry(Vec::new())]),
            Err(SpiderSenseError::Invalid(_))
        ));
    }

    #[test]
    fn pattern_db_rejects_non_finite() {
        assert!(matches!(
            PatternDb::from_entries(vec![entry(vec![1.0, f32::NAN])]),
            Err(SpiderSenseError::Invalid(_))
        ));
    }

    #[test]
    fn pattern_db_reports_shape() {
        let db = sample_db();
        assert_eq!(db.dim(), 3);
        assert_eq!(db.len(), 2);
        assert!(!db.is_empty());
    }

    #[test]
    fn guard_denies_identical_vector() {
        let guard = default_guard();
        let score = guard.score(&[1.0, 0.0, 0.0]);
        assert!((score - 1.0).abs() < 1e-9);
        assert!(matches!(guard.verdict_for(score), Verdict::Deny));
    }

    #[test]
    fn guard_allows_orthogonal_vector() {
        let guard = default_guard();
        let score = guard.score(&[0.0, 0.0, 1.0]);
        assert!(score.abs() < 1e-9);
        assert!(matches!(guard.verdict_for(score), Verdict::Allow));
    }

    #[test]
    fn score_dim_mismatch_is_zero() {
        // `score` itself reports 0.0 (and would therefore allow); the
        // fail-closed Deny for a mismatched dimension lives in `evaluate`,
        // which checks the length before scoring.
        let guard = default_guard();
        let score = guard.score(&[1.0, 0.0]);
        assert_eq!(score, 0.0);
        assert!(matches!(guard.verdict_for(score), Verdict::Allow));
    }

    #[test]
    fn score_floors_negative_similarity_at_zero() {
        let guard = default_guard();
        assert_eq!(guard.score(&[-1.0, 0.0, 0.0]), 0.0);
    }

    #[test]
    fn guard_nan_score_denies() {
        let guard = default_guard();
        assert!(matches!(guard.verdict_for(f64::NAN), Verdict::Deny));
    }

    #[test]
    fn ambiguous_respects_policy() {
        let db = sample_db();
        let config = SpiderSenseConfig {
            similarity_threshold: 0.5,
            ambiguity_band: 0.1,
            top_k: 5,
            ambiguous_policy: AmbiguousPolicy::Deny,
        };
        let guard = SpiderSenseGuard::new(db, config).unwrap();
        // score between 0.4 and 0.6 → Deny under this policy
        assert!(matches!(guard.verdict_for(0.5), Verdict::Deny));
    }

    #[test]
    fn ambiguous_default_policy_allows() {
        let config = SpiderSenseConfig {
            similarity_threshold: 0.5,
            ambiguity_band: 0.1,
            ..SpiderSenseConfig::default()
        };
        let guard = SpiderSenseGuard::new(sample_db(), config).unwrap();
        assert!(matches!(guard.verdict_for(0.5), Verdict::Allow));
        assert!(matches!(guard.verdict_for(0.7), Verdict::Deny));
        assert!(matches!(guard.verdict_for(0.3), Verdict::Allow));
    }

    #[test]
    fn extract_embedding_from_args() {
        let args = serde_json::json!({"embedding": [0.1, 0.2, 0.3]});
        let e = extract_embedding(&args).unwrap();
        assert_eq!(e.len(), 3);
    }

    #[test]
    fn extract_embedding_reads_vector_key() {
        let args = serde_json::json!({"vector": [0.5, 0.5]});
        assert_eq!(extract_embedding(&args).unwrap().len(), 2);
    }

    #[test]
    fn extract_embedding_prefers_embedding_key() {
        let args = serde_json::json!({"embedding": [1.0], "vector": [1.0, 2.0]});
        assert_eq!(extract_embedding(&args).unwrap().len(), 1);
    }

    #[test]
    fn extract_embedding_averages_list() {
        let args = serde_json::json!({"embeddings": [[1.0, 0.0], [0.0, 1.0]]});
        let e = extract_embedding(&args).unwrap();
        assert_eq!(e.len(), 2);
        assert!((e[0] - 0.5).abs() < 1e-6);
        assert!((e[1] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn extract_embedding_rejects_ragged_list() {
        let args = serde_json::json!({"embeddings": [[1.0, 0.0], [1.0]]});
        assert!(extract_embedding(&args).is_none());
    }

    #[test]
    fn extract_embedding_rejects_non_numeric() {
        let args = serde_json::json!({"embedding": [1.0, "x"]});
        assert!(extract_embedding(&args).is_none());
    }

    #[test]
    fn extract_embedding_none_when_absent() {
        assert!(extract_embedding(&serde_json::json!({"foo": "bar"})).is_none());
    }

    #[test]
    fn reject_bad_config() {
        let db = sample_db();
        let bad = SpiderSenseConfig {
            similarity_threshold: 1.5,
            ..SpiderSenseConfig::default()
        };
        assert!(SpiderSenseGuard::new(db, bad).is_err());
    }

    #[test]
    fn reject_band_escaping_unit_interval() {
        // 0.95 + default band 0.10 pushes the upper edge past 1.0.
        let bad = SpiderSenseConfig {
            similarity_threshold: 0.95,
            ..SpiderSenseConfig::default()
        };
        assert!(matches!(
            SpiderSenseGuard::new(sample_db(), bad),
            Err(SpiderSenseError::Config(_))
        ));
    }

    #[test]
    fn reject_zero_top_k() {
        let bad = SpiderSenseConfig {
            top_k: 0,
            ..SpiderSenseConfig::default()
        };
        assert!(matches!(
            SpiderSenseGuard::new(sample_db(), bad),
            Err(SpiderSenseError::Config(_))
        ));
    }
}