Skip to main content

hookwise/cascade/
token_sim.rs

1use std::sync::RwLock;
2
3use async_trait::async_trait;
4use chrono::Utc;
5
6use crate::cascade::{CascadeInput, CascadeTier};
7use crate::decision::{CacheKey, Decision, DecisionMetadata, DecisionRecord, DecisionTier};
8use crate::error::Result;
9
10/// A token set entry for Jaccard comparison.
11#[derive(Debug, Clone)]
12pub struct TokenEntry {
13    pub tokens: Vec<String>,
14    pub cache_key: CacheKey,
15    pub record: DecisionRecord,
16}
17
18/// Tier 2a: Token-level Jaccard similarity.
19pub struct TokenJaccard {
20    entries: RwLock<Vec<TokenEntry>>,
21    threshold: f64,
22    min_tokens: usize,
23}
24
25impl TokenJaccard {
26    pub fn new(threshold: f64, min_tokens: usize) -> Self {
27        Self {
28            entries: RwLock::new(Vec::new()),
29            threshold,
30            min_tokens,
31        }
32    }
33
34    /// Load entries from cached decisions.
35    pub fn load_from(&self, records: &[DecisionRecord]) {
36        let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
37        for record in records {
38            let tokens = Self::tokenize(&record.key.sanitized_input);
39            entries.push(TokenEntry {
40                tokens,
41                cache_key: record.key.clone(),
42                record: record.clone(),
43            });
44        }
45    }
46
47    /// Add a single entry.
48    pub fn insert(&self, record: &DecisionRecord) {
49        let tokens = Self::tokenize(&record.key.sanitized_input);
50        let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
51        entries.push(TokenEntry {
52            tokens,
53            cache_key: record.key.clone(),
54            record: record.clone(),
55        });
56    }
57
58    /// Tokenize an input string: split on whitespace + punctuation, lowercase,
59    /// deduplicate, sort.
60    pub fn tokenize(input: &str) -> Vec<String> {
61        let mut tokens: Vec<String> = input
62            .split(|c: char| c.is_whitespace() || c.is_ascii_punctuation())
63            .filter(|s| !s.is_empty())
64            .map(|s| s.to_lowercase())
65            .collect();
66        tokens.sort();
67        tokens.dedup();
68        tokens
69    }
70
71    /// Compute Jaccard coefficient between two sorted token slices.
72    pub fn jaccard_coefficient(a: &[String], b: &[String]) -> f64 {
73        if a.is_empty() && b.is_empty() {
74            return 1.0;
75        }
76        let intersection = Self::sorted_intersection_count(a, b);
77        let union = a.len() + b.len() - intersection;
78        if union == 0 {
79            return 0.0;
80        }
81        intersection as f64 / union as f64
82    }
83
84    /// Count intersection of two sorted slices using merge-join.
85    fn sorted_intersection_count(a: &[String], b: &[String]) -> usize {
86        let mut count = 0;
87        let (mut i, mut j) = (0, 0);
88        while i < a.len() && j < b.len() {
89            match a[i].cmp(&b[j]) {
90                std::cmp::Ordering::Less => i += 1,
91                std::cmp::Ordering::Greater => j += 1,
92                std::cmp::Ordering::Equal => {
93                    count += 1;
94                    i += 1;
95                    j += 1;
96                }
97            }
98        }
99        count
100    }
101
102    /// Remove all entries for a specific role.
103    pub fn invalidate_role(&self, role: &str) {
104        let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
105        entries.retain(|e| e.cache_key.role != role);
106    }
107
108    /// Remove all entries.
109    pub fn invalidate_all(&self) {
110        let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
111        entries.clear();
112    }
113}
114
115#[async_trait]
116impl CascadeTier for TokenJaccard {
117    async fn evaluate(&self, input: &CascadeInput) -> Result<Option<DecisionRecord>> {
118        let query_tokens = Self::tokenize(&input.sanitized_input);
119
120        // Skip if too few tokens
121        if query_tokens.len() < self.min_tokens {
122            return Ok(None);
123        }
124
125        let role_name = input
126            .session
127            .role
128            .as_ref()
129            .map(|r| r.name.as_str())
130            .unwrap_or("*");
131
132        let entries = self.entries.read().unwrap_or_else(|e| e.into_inner());
133
134        let mut best_match: Option<(f64, &TokenEntry)> = None;
135
136        for entry in entries.iter() {
137            // Only match same role or wildcard entries
138            if entry.cache_key.role != role_name && entry.cache_key.role != "*" {
139                continue;
140            }
141            // Only match same tool
142            if entry.cache_key.tool != input.tool_name {
143                continue;
144            }
145
146            let score = Self::jaccard_coefficient(&query_tokens, &entry.tokens);
147
148            if score >= self.threshold && best_match.as_ref().is_none_or(|(best, _)| score > *best)
149            {
150                best_match = Some((score, entry));
151            }
152        }
153
154        match best_match {
155            Some((score, entry)) => {
156                // Similarity behavior:
157                // - allow -> auto-approve
158                // - deny -> fall through (similarity never auto-denies)
159                // - ask -> return ask (escalate)
160                match entry.record.decision {
161                    Decision::Deny => Ok(None), // Never auto-deny from similarity
162                    Decision::Allow | Decision::Ask => {
163                        Ok(Some(DecisionRecord {
164                            key: CacheKey {
165                                sanitized_input: input.sanitized_input.clone(),
166                                tool: input.tool_name.clone(),
167                                role: role_name.to_string(),
168                            },
169                            decision: entry.record.decision,
170                            metadata: DecisionMetadata {
171                                tier: DecisionTier::TokenJaccard,
172                                confidence: score,
173                                reason: format!(
174                                    "token Jaccard similarity {:.3} >= {:.3} with cached {}",
175                                    score, self.threshold, entry.record.decision
176                                ),
177                                matched_key: Some(entry.cache_key.clone()),
178                                similarity_score: Some(score),
179                            },
180                            timestamp: Utc::now(),
181                            scope: entry.record.scope,
182                            file_path: input.file_path.clone(),
183                            session_id: String::new(), // Filled by CascadeRunner
184                        }))
185                    }
186                }
187            }
188            None => Ok(None), // No match above threshold
189        }
190    }
191
192    fn tier(&self) -> DecisionTier {
193        DecisionTier::TokenJaccard
194    }
195
196    fn name(&self) -> &str {
197        "token-jaccard"
198    }
199}