1use crate::Span;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
17pub struct ContextConfig {
18 pub words_before: usize,
20 pub words_after: usize,
22 pub single_hit_boost: f64,
24 pub multi_hit_boost: f64,
26 pub code_context_penalty: f64,
28}
29
30impl Default for ContextConfig {
31 fn default() -> Self {
32 Self {
33 words_before: 5,
34 words_after: 3,
35 single_hit_boost: 0.10,
36 multi_hit_boost: 0.15,
37 code_context_penalty: 0.20,
38 }
39 }
40}
41
42#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
57pub struct ContextWindow {
58 pub before: Vec<String>,
60 pub after: Vec<String>,
62}
63
64#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
80pub struct ContextScore {
81 pub adjustment: f64,
83 pub positive_hits: usize,
85 pub code_like: bool,
87}
88
89#[must_use]
91pub fn surrounding_context(text: &str, span: Span, config: &ContextConfig) -> ContextWindow {
92 let start = span.start.min(text.len());
93 let end = span.end.min(text.len());
94 let before_text = text.get(..start).unwrap_or_default();
95 let after_text = text.get(end..).unwrap_or_default();
96
97 let mut before = words(before_text);
98 let before_start = before.len().saturating_sub(config.words_before);
99 before = before.split_off(before_start);
100
101 let after = words(after_text)
102 .into_iter()
103 .take(config.words_after)
104 .collect();
105
106 ContextWindow { before, after }
107}
108
109#[must_use]
111pub fn context_score(
112 text: &str,
113 span: Span,
114 positive_terms: &[&str],
115 config: &ContextConfig,
116) -> ContextScore {
117 let window = surrounding_context(text, span, config);
118 let haystack = window
119 .before
120 .iter()
121 .chain(window.after.iter())
122 .cloned()
123 .collect::<Vec<_>>()
124 .join(" ");
125
126 let positive_hits = positive_terms
127 .iter()
128 .filter(|term| haystack.contains(&term.to_ascii_lowercase()))
129 .count();
130
131 let mut adjustment = match positive_hits {
132 0 => 0.0,
133 1 => config.single_hit_boost,
134 _ => config.multi_hit_boost,
135 };
136
137 let code_like = is_code_like_context(&window);
138 if code_like {
139 adjustment -= config.code_context_penalty;
140 }
141
142 ContextScore {
143 adjustment,
144 positive_hits,
145 code_like,
146 }
147}
148
149fn is_code_like_context(window: &ContextWindow) -> bool {
150 const CODE_TERMS: &[&str] = &["=", "var", "let", "const", "fn", "function"];
151 window
152 .before
153 .iter()
154 .rev()
155 .take(3)
156 .any(|word| CODE_TERMS.contains(&word.as_str()))
157}
158
/// Splits `text` into ASCII-lowercased tokens.
///
/// A token is a maximal run of alphanumeric characters, `_`, `-`, `:` and `=`. Every other
/// character acts as a separator, and empty pieces between separators are dropped.
fn words(text: &str) -> Vec<String> {
    let is_word_char = |c: char| c.is_alphanumeric() || "_-:=".contains(c);
    let mut tokens = Vec::new();
    for piece in text.split(|c: char| !is_word_char(c)) {
        if !piece.is_empty() {
            tokens.push(piece.to_ascii_lowercase());
        }
    }
    tokens
}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `context_score` with the default configuration.
    fn score_with_defaults(text: &str, span: Span, terms: &[&str]) -> ContextScore {
        context_score(text, span, terms, &ContextConfig::default())
    }

    #[test]
    fn test_surrounding_context_collects_before_and_after_words() {
        let text = "my email is jane@example.com today please";
        let config = ContextConfig::default();
        let window = surrounding_context(text, Span::new(12, 28), &config);

        assert!(window.before.iter().any(|word| word == "email"));
        assert!(window.after.iter().any(|word| word == "today"));
    }

    #[test]
    fn test_context_score_single_hit_boosts() {
        let score = score_with_defaults("email jane@example.com", Span::new(6, 22), &["email"]);

        assert_eq!(score.positive_hits, 1);
        assert_eq!(score.adjustment, 0.10);
    }

    #[test]
    fn test_context_score_multiple_hits_uses_larger_boost() {
        let terms = ["contact", "email"];
        let score = score_with_defaults("contact email jane@example.com", Span::new(14, 30), &terms);

        assert_eq!(score.positive_hits, 2);
        assert_eq!(score.adjustment, 0.15);
    }

    #[test]
    fn test_context_score_code_context_penalizes() {
        let score = score_with_defaults("let email = jane@example.com", Span::new(12, 28), &["email"]);

        assert!(score.code_like);
        assert!(score.adjustment < 0.0);
    }

    #[test]
    fn test_context_score_no_terms_returns_zero() {
        let score = score_with_defaults("value jane@example.com", Span::new(6, 22), &["email"]);

        assert_eq!(score.adjustment, 0.0);
    }
}