hookwise/cascade/
token_sim.rs1use std::sync::RwLock;
2
3use async_trait::async_trait;
4use chrono::Utc;
5
6use crate::cascade::{CascadeInput, CascadeTier};
7use crate::decision::{CacheKey, Decision, DecisionMetadata, DecisionRecord, DecisionTier};
8use crate::error::Result;
9
10#[derive(Debug, Clone)]
12pub struct TokenEntry {
13 pub tokens: Vec<String>,
14 pub cache_key: CacheKey,
15 pub record: DecisionRecord,
16}
17
18pub struct TokenJaccard {
20 entries: RwLock<Vec<TokenEntry>>,
21 threshold: f64,
22 min_tokens: usize,
23}
24
25impl TokenJaccard {
26 pub fn new(threshold: f64, min_tokens: usize) -> Self {
27 Self {
28 entries: RwLock::new(Vec::new()),
29 threshold,
30 min_tokens,
31 }
32 }
33
34 pub fn load_from(&self, records: &[DecisionRecord]) {
36 let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
37 for record in records {
38 let tokens = Self::tokenize(&record.key.sanitized_input);
39 entries.push(TokenEntry {
40 tokens,
41 cache_key: record.key.clone(),
42 record: record.clone(),
43 });
44 }
45 }
46
47 pub fn insert(&self, record: &DecisionRecord) {
49 let tokens = Self::tokenize(&record.key.sanitized_input);
50 let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
51 entries.push(TokenEntry {
52 tokens,
53 cache_key: record.key.clone(),
54 record: record.clone(),
55 });
56 }
57
58 pub fn tokenize(input: &str) -> Vec<String> {
61 let mut tokens: Vec<String> = input
62 .split(|c: char| c.is_whitespace() || c.is_ascii_punctuation())
63 .filter(|s| !s.is_empty())
64 .map(|s| s.to_lowercase())
65 .collect();
66 tokens.sort();
67 tokens.dedup();
68 tokens
69 }
70
71 pub fn jaccard_coefficient(a: &[String], b: &[String]) -> f64 {
73 if a.is_empty() && b.is_empty() {
74 return 1.0;
75 }
76 let intersection = Self::sorted_intersection_count(a, b);
77 let union = a.len() + b.len() - intersection;
78 if union == 0 {
79 return 0.0;
80 }
81 intersection as f64 / union as f64
82 }
83
84 fn sorted_intersection_count(a: &[String], b: &[String]) -> usize {
86 let mut count = 0;
87 let (mut i, mut j) = (0, 0);
88 while i < a.len() && j < b.len() {
89 match a[i].cmp(&b[j]) {
90 std::cmp::Ordering::Less => i += 1,
91 std::cmp::Ordering::Greater => j += 1,
92 std::cmp::Ordering::Equal => {
93 count += 1;
94 i += 1;
95 j += 1;
96 }
97 }
98 }
99 count
100 }
101
102 pub fn invalidate_role(&self, role: &str) {
104 let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
105 entries.retain(|e| e.cache_key.role != role);
106 }
107
108 pub fn invalidate_all(&self) {
110 let mut entries = self.entries.write().unwrap_or_else(|e| e.into_inner());
111 entries.clear();
112 }
113}
114
115#[async_trait]
116impl CascadeTier for TokenJaccard {
117 async fn evaluate(&self, input: &CascadeInput) -> Result<Option<DecisionRecord>> {
118 let query_tokens = Self::tokenize(&input.sanitized_input);
119
120 if query_tokens.len() < self.min_tokens {
122 return Ok(None);
123 }
124
125 let role_name = input
126 .session
127 .role
128 .as_ref()
129 .map(|r| r.name.as_str())
130 .unwrap_or("*");
131
132 let entries = self.entries.read().unwrap_or_else(|e| e.into_inner());
133
134 let mut best_match: Option<(f64, &TokenEntry)> = None;
135
136 for entry in entries.iter() {
137 if entry.cache_key.role != role_name && entry.cache_key.role != "*" {
139 continue;
140 }
141 if entry.cache_key.tool != input.tool_name {
143 continue;
144 }
145
146 let score = Self::jaccard_coefficient(&query_tokens, &entry.tokens);
147
148 if score >= self.threshold && best_match.as_ref().is_none_or(|(best, _)| score > *best)
149 {
150 best_match = Some((score, entry));
151 }
152 }
153
154 match best_match {
155 Some((score, entry)) => {
156 match entry.record.decision {
161 Decision::Deny => Ok(None), Decision::Allow | Decision::Ask => {
163 Ok(Some(DecisionRecord {
164 key: CacheKey {
165 sanitized_input: input.sanitized_input.clone(),
166 tool: input.tool_name.clone(),
167 role: role_name.to_string(),
168 },
169 decision: entry.record.decision,
170 metadata: DecisionMetadata {
171 tier: DecisionTier::TokenJaccard,
172 confidence: score,
173 reason: format!(
174 "token Jaccard similarity {:.3} >= {:.3} with cached {}",
175 score, self.threshold, entry.record.decision
176 ),
177 matched_key: Some(entry.cache_key.clone()),
178 similarity_score: Some(score),
179 },
180 timestamp: Utc::now(),
181 scope: entry.record.scope,
182 file_path: input.file_path.clone(),
183 session_id: String::new(), }))
185 }
186 }
187 }
188 None => Ok(None), }
190 }
191
192 fn tier(&self) -> DecisionTier {
193 DecisionTier::TokenJaccard
194 }
195
196 fn name(&self) -> &str {
197 "token-jaccard"
198 }
199}