Skip to main content

cloakrs_core/
context.rs

1//! Context-aware confidence scoring.
2
3use crate::Span;
4use serde::{Deserialize, Serialize};
5
6/// Configuration for context-aware confidence scoring.
7///
8/// # Examples
9///
10/// ```
11/// use cloakrs_core::ContextConfig;
12///
13/// let config = ContextConfig::default();
14/// assert_eq!(config.words_before, 5);
15/// ```
16#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
17pub struct ContextConfig {
18    /// Number of words to inspect before the match.
19    pub words_before: usize,
20    /// Number of words to inspect after the match.
21    pub words_after: usize,
22    /// Boost for one positive context hit.
23    pub single_hit_boost: f64,
24    /// Boost for two or more positive context hits.
25    pub multi_hit_boost: f64,
26    /// Penalty when code-like context is detected.
27    pub code_context_penalty: f64,
28}
29
30impl Default for ContextConfig {
31    fn default() -> Self {
32        Self {
33            words_before: 5,
34            words_after: 3,
35            single_hit_boost: 0.10,
36            multi_hit_boost: 0.15,
37            code_context_penalty: 0.20,
38        }
39    }
40}
41
42/// Words surrounding a candidate match.
43///
44/// # Examples
45///
46/// ```
47/// use cloakrs_core::{surrounding_context, ContextConfig, Span};
48///
49/// let window = surrounding_context(
50///     "my email is jane@example.com today",
51///     Span::new(12, 28),
52///     &ContextConfig::default(),
53/// );
54/// assert!(window.before.contains(&"email".to_string()));
55/// ```
56#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
57pub struct ContextWindow {
58    /// Lowercased words before the match.
59    pub before: Vec<String>,
60    /// Lowercased words after the match.
61    pub after: Vec<String>,
62}
63
64/// Result of context scoring.
65///
66/// # Examples
67///
68/// ```
69/// use cloakrs_core::{context_score, ContextConfig, Span};
70///
71/// let score = context_score(
72///     "email jane@example.com",
73///     Span::new(6, 22),
74///     &["email"],
75///     &ContextConfig::default(),
76/// );
77/// assert!(score.adjustment > 0.0);
78/// ```
79#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
80pub struct ContextScore {
81    /// Adjustment to add to a recognizer's base confidence.
82    pub adjustment: f64,
83    /// Number of matching positive context terms.
84    pub positive_hits: usize,
85    /// Whether code-like context was detected.
86    pub code_like: bool,
87}
88
89/// Extracts lowercased words before and after a candidate span.
90#[must_use]
91pub fn surrounding_context(text: &str, span: Span, config: &ContextConfig) -> ContextWindow {
92    let start = span.start.min(text.len());
93    let end = span.end.min(text.len());
94    let before_text = text.get(..start).unwrap_or_default();
95    let after_text = text.get(end..).unwrap_or_default();
96
97    let mut before = words(before_text);
98    let before_start = before.len().saturating_sub(config.words_before);
99    before = before.split_off(before_start);
100
101    let after = words(after_text)
102        .into_iter()
103        .take(config.words_after)
104        .collect();
105
106    ContextWindow { before, after }
107}
108
109/// Computes a bounded context confidence adjustment.
110#[must_use]
111pub fn context_score(
112    text: &str,
113    span: Span,
114    positive_terms: &[&str],
115    config: &ContextConfig,
116) -> ContextScore {
117    let window = surrounding_context(text, span, config);
118    let haystack = window
119        .before
120        .iter()
121        .chain(window.after.iter())
122        .cloned()
123        .collect::<Vec<_>>()
124        .join(" ");
125
126    let positive_hits = positive_terms
127        .iter()
128        .filter(|term| haystack.contains(&term.to_ascii_lowercase()))
129        .count();
130
131    let mut adjustment = match positive_hits {
132        0 => 0.0,
133        1 => config.single_hit_boost,
134        _ => config.multi_hit_boost,
135    };
136
137    let code_like = is_code_like_context(&window);
138    if code_like {
139        adjustment -= config.code_context_penalty;
140    }
141
142    ContextScore {
143        adjustment,
144        positive_hits,
145        code_like,
146    }
147}
148
149fn is_code_like_context(window: &ContextWindow) -> bool {
150    const CODE_TERMS: &[&str] = &["=", "var", "let", "const", "fn", "function"];
151    window
152        .before
153        .iter()
154        .rev()
155        .take(3)
156        .any(|word| CODE_TERMS.contains(&word.as_str()))
157}
158
/// Splits `text` into ASCII-lowercased tokens.
///
/// A token is a maximal run of alphanumeric characters plus `_`, `-`, `:` and
/// `=`; every other character acts as a separator. Empty pieces between
/// adjacent separators are discarded.
fn words(text: &str) -> Vec<String> {
    let is_separator = |c: char| !c.is_alphanumeric() && !matches!(c, '_' | '-' | ':' | '=');

    let mut tokens = Vec::new();
    for piece in text.split(is_separator) {
        if piece.is_empty() {
            continue;
        }
        tokens.push(piece.to_ascii_lowercase());
    }
    tokens
}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Scores `text` against `terms` using the default configuration.
    fn score_with_defaults(text: &str, span: Span, terms: &[&str]) -> ContextScore {
        context_score(text, span, terms, &ContextConfig::default())
    }

    #[test]
    fn test_surrounding_context_collects_before_and_after_words() {
        let config = ContextConfig::default();
        let text = "my email is jane@example.com today please";

        let ContextWindow { before, after } =
            surrounding_context(text, Span::new(12, 28), &config);

        assert!(before.iter().any(|word| word == "email"));
        assert!(after.iter().any(|word| word == "today"));
    }

    #[test]
    fn test_context_score_single_hit_boosts() {
        let score = score_with_defaults("email jane@example.com", Span::new(6, 22), &["email"]);

        assert_eq!(score.positive_hits, 1);
        assert_eq!(score.adjustment, 0.10);
    }

    #[test]
    fn test_context_score_multiple_hits_uses_larger_boost() {
        let score = score_with_defaults(
            "contact email jane@example.com",
            Span::new(14, 30),
            &["contact", "email"],
        );

        assert_eq!(score.positive_hits, 2);
        assert_eq!(score.adjustment, 0.15);
    }

    #[test]
    fn test_context_score_code_context_penalizes() {
        // The `let ... =` prefix marks the match as sitting inside code.
        let score = score_with_defaults(
            "let email = jane@example.com",
            Span::new(12, 28),
            &["email"],
        );

        assert!(score.code_like);
        assert!(score.adjustment < 0.0);
    }

    #[test]
    fn test_context_score_no_terms_returns_zero() {
        // "value" is not one of the positive terms, so nothing is boosted.
        let score = score_with_defaults("value jane@example.com", Span::new(6, 22), &["email"]);

        assert_eq!(score.adjustment, 0.0);
    }
}