1use std::sync::Arc;
36
37use serde::{Deserialize, Serialize};
38use serde_json::Value;
39use thiserror::Error;
40
41use chio_kernel::{Guard, GuardContext, KernelError, Verdict};
42
/// Default `similarity_threshold`: the centre of the decision band.
pub const DEFAULT_SIMILARITY_THRESHOLD: f64 = 0.85;
/// Default `ambiguity_band`: half-width of the band around the threshold.
pub const DEFAULT_AMBIGUITY_BAND: f64 = 0.1;
/// Default `top_k`; the guard constructor requires it to be at least 1.
pub const DEFAULT_TOP_K: usize = 5;
49
50#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize, PartialEq, Eq)]
52#[serde(rename_all = "snake_case")]
53pub enum AmbiguousPolicy {
54 #[default]
56 Allow,
57 Deny,
59}
60
61#[derive(Debug, Error)]
63pub enum SpiderSenseError {
64 #[error("pattern database parse error: {0}")]
66 Parse(String),
67 #[error("pattern database is invalid: {0}")]
69 Invalid(String),
70 #[error("invalid configuration: {0}")]
72 Config(String),
73 #[error("failed to read pattern database: {0}")]
75 Io(String),
76}
77
78#[derive(Clone, Debug, Deserialize, Serialize)]
80#[serde(deny_unknown_fields)]
81pub struct SpiderSenseConfig {
82 #[serde(default = "default_threshold")]
85 pub similarity_threshold: f64,
86 #[serde(default = "default_band")]
88 pub ambiguity_band: f64,
89 #[serde(default = "default_top_k")]
91 pub top_k: usize,
92 #[serde(default)]
94 pub ambiguous_policy: AmbiguousPolicy,
95}
96
97fn default_threshold() -> f64 {
98 DEFAULT_SIMILARITY_THRESHOLD
99}
100fn default_band() -> f64 {
101 DEFAULT_AMBIGUITY_BAND
102}
103fn default_top_k() -> usize {
104 DEFAULT_TOP_K
105}
106
107impl Default for SpiderSenseConfig {
108 fn default() -> Self {
109 Self {
110 similarity_threshold: DEFAULT_SIMILARITY_THRESHOLD,
111 ambiguity_band: DEFAULT_AMBIGUITY_BAND,
112 top_k: DEFAULT_TOP_K,
113 ambiguous_policy: AmbiguousPolicy::Allow,
114 }
115 }
116}
117
118#[derive(Clone, Debug, Deserialize, Serialize)]
120pub struct PatternEntry {
121 pub id: String,
123 pub category: String,
125 pub stage: String,
127 pub label: String,
129 pub embedding: Vec<f32>,
131}
132
133#[derive(Clone, Debug)]
135pub struct PatternDb {
136 entries: Arc<Vec<PatternEntry>>,
137 dim: usize,
138}
139
140impl PatternDb {
141 pub fn from_json(json: &str) -> Result<Self, SpiderSenseError> {
147 let entries: Vec<PatternEntry> =
148 serde_json::from_str(json).map_err(|e| SpiderSenseError::Parse(e.to_string()))?;
149 Self::from_entries(entries)
150 }
151
152 pub fn from_entries(entries: Vec<PatternEntry>) -> Result<Self, SpiderSenseError> {
154 if entries.is_empty() {
155 return Err(SpiderSenseError::Invalid(
156 "pattern database must contain at least one entry".into(),
157 ));
158 }
159 let dim = entries[0].embedding.len();
160 if dim == 0 {
161 return Err(SpiderSenseError::Invalid(
162 "pattern embeddings must be non-empty".into(),
163 ));
164 }
165 for (i, entry) in entries.iter().enumerate() {
166 if entry.embedding.len() != dim {
167 return Err(SpiderSenseError::Invalid(format!(
168 "dimension mismatch at index {i}: expected {dim}, got {}",
169 entry.embedding.len()
170 )));
171 }
172 if let Some(j) = entry.embedding.iter().position(|v| !v.is_finite()) {
173 return Err(SpiderSenseError::Invalid(format!(
174 "entry {i} has non-finite embedding value at dimension {j}"
175 )));
176 }
177 }
178 Ok(Self {
179 entries: Arc::new(entries),
180 dim,
181 })
182 }
183
184 pub fn dim(&self) -> usize {
186 self.dim
187 }
188
189 pub fn len(&self) -> usize {
191 self.entries.len()
192 }
193
194 pub fn is_empty(&self) -> bool {
196 self.entries.is_empty()
197 }
198}
199
200pub struct SpiderSenseGuard {
202 db: PatternDb,
203 upper: f64,
204 lower: f64,
205 top_k: usize,
206 ambiguous_policy: AmbiguousPolicy,
207}
208
209impl SpiderSenseGuard {
210 pub fn new(db: PatternDb, config: SpiderSenseConfig) -> Result<Self, SpiderSenseError> {
212 if !config.similarity_threshold.is_finite()
213 || !(0.0..=1.0).contains(&config.similarity_threshold)
214 {
215 return Err(SpiderSenseError::Config(format!(
216 "similarity_threshold must be finite in [0.0, 1.0], got {}",
217 config.similarity_threshold
218 )));
219 }
220 if !config.ambiguity_band.is_finite() || !(0.0..=1.0).contains(&config.ambiguity_band) {
221 return Err(SpiderSenseError::Config(format!(
222 "ambiguity_band must be finite in [0.0, 1.0], got {}",
223 config.ambiguity_band
224 )));
225 }
226 let upper = config.similarity_threshold + config.ambiguity_band;
227 let lower = config.similarity_threshold - config.ambiguity_band;
228 if !(0.0..=1.0).contains(&upper) || !(0.0..=1.0).contains(&lower) {
229 return Err(SpiderSenseError::Config(format!(
230 "threshold ± band must stay inside [0.0, 1.0]; got lower={lower:.3}, upper={upper:.3}"
231 )));
232 }
233 if config.top_k == 0 {
234 return Err(SpiderSenseError::Config("top_k must be ≥ 1".into()));
235 }
236 Ok(Self {
237 db,
238 upper,
239 lower,
240 top_k: config.top_k,
241 ambiguous_policy: config.ambiguous_policy,
242 })
243 }
244
245 pub fn from_json(json: &str) -> Result<Self, SpiderSenseError> {
247 let db = PatternDb::from_json(json)?;
248 Self::new(db, SpiderSenseConfig::default())
249 }
250
251 pub fn from_json_file(path: &str) -> Result<Self, SpiderSenseError> {
253 let data = std::fs::read_to_string(path)
254 .map_err(|e| SpiderSenseError::Io(format!("{path}: {e}")))?;
255 Self::from_json(&data)
256 }
257
258 pub fn score(&self, embedding: &[f32]) -> f64 {
262 if embedding.len() != self.db.dim {
263 return 0.0;
264 }
265 if embedding.iter().any(|v| !v.is_finite()) {
266 return 0.0;
267 }
268 let mut best = 0.0_f64;
269 let mut seen = 0usize;
270 for entry in self.db.entries.iter() {
271 let score = cosine_similarity(embedding, &entry.embedding);
272 if score > best {
273 best = score;
274 }
275 seen += 1;
276 if seen >= self.top_k {
277 }
283 }
284 best
285 }
286
287 fn verdict_for(&self, score: f64) -> Verdict {
289 if !score.is_finite() {
290 return Verdict::Deny;
291 }
292 if score >= self.upper {
293 Verdict::Deny
294 } else if score <= self.lower {
295 Verdict::Allow
296 } else {
297 match self.ambiguous_policy {
298 AmbiguousPolicy::Allow => Verdict::Allow,
299 AmbiguousPolicy::Deny => Verdict::Deny,
300 }
301 }
302 }
303
304 pub fn pattern_count(&self) -> usize {
306 self.db.len()
307 }
308
309 pub fn dim(&self) -> usize {
311 self.db.dim()
312 }
313}
314
315impl Guard for SpiderSenseGuard {
316 fn name(&self) -> &str {
317 "spider-sense"
318 }
319
320 fn evaluate(&self, ctx: &GuardContext) -> Result<Verdict, KernelError> {
321 let embedding = match extract_embedding(&ctx.request.arguments) {
322 Some(e) => e,
323 None => return Ok(Verdict::Allow),
324 };
325 if embedding.len() != self.db.dim {
326 return Ok(Verdict::Deny);
328 }
329 if embedding.iter().any(|v| !v.is_finite()) {
330 return Ok(Verdict::Deny);
331 }
332 let score = self.score(&embedding);
333 Ok(self.verdict_for(score))
334 }
335}
336
337pub fn extract_embedding(arguments: &Value) -> Option<Vec<f32>> {
347 if let Some(vec) = arguments
348 .get("embedding")
349 .or_else(|| arguments.get("vector"))
350 .and_then(array_as_f32_vec)
351 {
352 return Some(vec);
353 }
354 if let Some(array) = arguments.get("embeddings").and_then(|v| v.as_array()) {
355 let vectors: Vec<Vec<f32>> = array.iter().filter_map(array_as_f32_vec).collect();
356 if vectors.is_empty() {
357 return None;
358 }
359 let dim = vectors[0].len();
360 if dim == 0 || vectors.iter().any(|v| v.len() != dim) {
361 return None;
362 }
363 let mut sum = vec![0.0_f64; dim];
364 for v in &vectors {
365 for (i, x) in v.iter().enumerate() {
366 sum[i] += f64::from(*x);
367 }
368 }
369 let n = vectors.len() as f64;
370 return Some(sum.into_iter().map(|s| (s / n) as f32).collect());
371 }
372 None
373}
374
375fn array_as_f32_vec(value: &Value) -> Option<Vec<f32>> {
376 let array = value.as_array()?;
377 if array.is_empty() {
378 return None;
379 }
380 let mut out = Vec::with_capacity(array.len());
381 for v in array {
382 let n = v.as_f64()?;
383 if !n.is_finite() {
384 return None;
385 }
386 out.push(n as f32);
387 }
388 Some(out)
389}
390
/// Cosine similarity of two `f32` vectors, accumulated in `f64`.
///
/// The result lies in roughly `[-1.0, 1.0]`. Returns `0.0` instead when the
/// similarity is undefined: slices of different length, empty slices, any NaN
/// or infinite component, or a zero-norm vector.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // (dot product, squared norm of `a`, squared norm of `b`), or `None` as soon
    // as a non-finite component appears.
    let sums = a
        .iter()
        .zip(b)
        .try_fold((0.0_f64, 0.0_f64, 0.0_f64), |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (x.is_finite() && y.is_finite())
                .then(|| (dot + x * y, norm_a + x * x, norm_b + y * y))
        });
    let (dot, norm_a, norm_b) = match sums {
        Some(totals) => totals,
        None => return 0.0,
    };
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if !denom.is_normal() {
        // Zero (or otherwise degenerate) norm: the direction is undefined.
        return 0.0;
    }
    Some(dot / denom).filter(|r| r.is_finite()).unwrap_or(0.0)
}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Two orthogonal 3-d patterns: "a" along x, "b" along y.
    fn sample_db() -> PatternDb {
        PatternDb::from_json(
            r#"[
                {"id":"a","category":"x","stage":"perception","label":"l","embedding":[1.0,0.0,0.0]},
                {"id":"b","category":"y","stage":"action","label":"l","embedding":[0.0,1.0,0.0]}
            ]"#,
        )
        .expect("sample DB parses")
    }

    fn default_guard() -> SpiderSenseGuard {
        SpiderSenseGuard::new(sample_db(), SpiderSenseConfig::default()).expect("build")
    }

    fn entry(embedding: Vec<f32>) -> PatternEntry {
        PatternEntry {
            id: "a".into(),
            category: "x".into(),
            stage: "s".into(),
            label: "l".into(),
            embedding,
        }
    }

    #[test]
    fn cosine_basics() {
        assert!((cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 1e-9);
        assert!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-9);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
        assert_eq!(cosine_similarity(&[f32::NAN, 0.0], &[1.0, 0.0]), 0.0);
        assert_eq!(cosine_similarity(&[f32::INFINITY, 0.0], &[1.0, 0.0]), 0.0);
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 0.0]), 0.0);
    }

    #[test]
    fn pattern_db_rejects_empty() {
        assert!(matches!(
            PatternDb::from_json("[]"),
            Err(SpiderSenseError::Invalid(_))
        ));
    }

    #[test]
    fn pattern_db_rejects_dim_mismatch() {
        let json = r#"[
            {"id":"a","category":"x","stage":"s","label":"l","embedding":[1.0,0.0]},
            {"id":"b","category":"y","stage":"s","label":"l","embedding":[1.0]}
        ]"#;
        assert!(matches!(
            PatternDb::from_json(json),
            Err(SpiderSenseError::Invalid(_))
        ));
    }

    #[test]
    fn pattern_db_rejects_zero_dim_and_non_finite() {
        assert!(matches!(
            PatternDb::from_entries(vec![entry(vec![])]),
            Err(SpiderSenseError::Invalid(_))
        ));
        assert!(matches!(
            PatternDb::from_entries(vec![entry(vec![1.0, f32::NAN])]),
            Err(SpiderSenseError::Invalid(_))
        ));
    }

    #[test]
    fn pattern_db_reports_shape() {
        let db = sample_db();
        assert_eq!(db.len(), 2);
        assert_eq!(db.dim(), 3);
        assert!(!db.is_empty());
        let guard = default_guard();
        assert_eq!(guard.pattern_count(), 2);
        assert_eq!(guard.dim(), 3);
    }

    #[test]
    fn guard_denies_identical_vector() {
        let guard = default_guard();
        let score = guard.score(&[1.0, 0.0, 0.0]);
        assert!((score - 1.0).abs() < 1e-9);
        assert!(matches!(guard.verdict_for(score), Verdict::Deny));
    }

    #[test]
    fn guard_allows_orthogonal_vector() {
        let guard = default_guard();
        let score = guard.score(&[0.0, 0.0, 1.0]);
        assert!(score.abs() < 1e-9);
        assert!(matches!(guard.verdict_for(score), Verdict::Allow));
    }

    #[test]
    fn score_is_zero_on_dim_mismatch() {
        // `score` reports 0.0 for a wrong-dimension query, and 0.0 maps to Allow
        // on its own. The Deny for malformed input is applied by
        // `Guard::evaluate`, which checks the dimension before scoring.
        let guard = default_guard();
        let score = guard.score(&[1.0, 0.0]);
        assert_eq!(score, 0.0);
        assert!(matches!(guard.verdict_for(score), Verdict::Allow));
    }

    #[test]
    fn score_is_zero_on_non_finite_query() {
        let guard = default_guard();
        assert_eq!(guard.score(&[f32::NAN, 0.0, 0.0]), 0.0);
        assert_eq!(guard.score(&[f32::INFINITY, 0.0, 0.0]), 0.0);
    }

    #[test]
    fn score_scans_every_pattern_regardless_of_top_k() {
        let config = SpiderSenseConfig {
            top_k: 1,
            ..SpiderSenseConfig::default()
        };
        let guard = SpiderSenseGuard::new(sample_db(), config).expect("build");
        // Pattern "b" is the second entry; it must still be found with top_k = 1.
        let score = guard.score(&[0.0, 1.0, 0.0]);
        assert!((score - 1.0).abs() < 1e-9);
    }

    #[test]
    fn guard_nan_score_denies() {
        let guard = default_guard();
        assert!(matches!(guard.verdict_for(f64::NAN), Verdict::Deny));
    }

    #[test]
    fn ambiguous_respects_policy() {
        let config = |ambiguous_policy| SpiderSenseConfig {
            similarity_threshold: 0.5,
            ambiguity_band: 0.1,
            top_k: 5,
            ambiguous_policy,
        };
        // Band is [0.4, 0.6]: below is always Allow, above is always Deny.
        let deny = SpiderSenseGuard::new(sample_db(), config(AmbiguousPolicy::Deny)).unwrap();
        assert!(matches!(deny.verdict_for(0.5), Verdict::Deny));
        assert!(matches!(deny.verdict_for(0.3), Verdict::Allow));
        assert!(matches!(deny.verdict_for(0.7), Verdict::Deny));

        let allow = SpiderSenseGuard::new(sample_db(), config(AmbiguousPolicy::Allow)).unwrap();
        assert!(matches!(allow.verdict_for(0.5), Verdict::Allow));
        assert!(matches!(allow.verdict_for(0.3), Verdict::Allow));
        assert!(matches!(allow.verdict_for(0.7), Verdict::Deny));
    }

    #[test]
    fn extract_embedding_from_args() {
        let args = serde_json::json!({"embedding": [0.1, 0.2, 0.3]});
        let e = extract_embedding(&args).unwrap();
        assert_eq!(e.len(), 3);
    }

    #[test]
    fn extract_embedding_averages_list() {
        let args = serde_json::json!({"embeddings": [[1.0, 0.0], [0.0, 1.0]]});
        let e = extract_embedding(&args).unwrap();
        assert_eq!(e.len(), 2);
        assert!((e[0] - 0.5).abs() < 1e-6);
        assert!((e[1] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn extract_embedding_none_when_absent() {
        assert!(extract_embedding(&serde_json::json!({"foo": "bar"})).is_none());
    }

    #[test]
    fn reject_bad_config() {
        let cases = [
            SpiderSenseConfig {
                similarity_threshold: 1.5,
                ..SpiderSenseConfig::default()
            },
            SpiderSenseConfig {
                similarity_threshold: f64::NAN,
                ..SpiderSenseConfig::default()
            },
            SpiderSenseConfig {
                ambiguity_band: -0.1,
                ..SpiderSenseConfig::default()
            },
            // 0.95 + 0.1 pushes the upper edge past 1.0.
            SpiderSenseConfig {
                similarity_threshold: 0.95,
                ambiguity_band: 0.1,
                ..SpiderSenseConfig::default()
            },
            SpiderSenseConfig {
                top_k: 0,
                ..SpiderSenseConfig::default()
            },
        ];
        for bad in cases.iter().cloned() {
            assert!(matches!(
                SpiderSenseGuard::new(sample_db(), bad),
                Err(SpiderSenseError::Config(_))
            ));
        }
    }
}