fresh/services/completion/
buffer_words.rs1use std::collections::HashMap;
17
18use unicode_segmentation::UnicodeSegmentation;
19
20use super::provider::{
21 case_mismatch_penalty, is_word_grapheme_for_lang, smart_case_matches, CompletionCandidate,
22 CompletionContext, CompletionProvider, CompletionSourceId, ProviderResult,
23};
24
/// Upper bound on how many candidates a single `provide` call returns.
const MAX_CANDIDATES: usize = 40;

/// Words made of fewer graphemes than this are never offered as candidates.
const MIN_WORD_LEN_GRAPHEMES: usize = 2;
30
/// Completion source that offers words already present in the scanned buffer window
/// (and in any other buffers supplied through the completion context).
///
/// The provider holds no state; `new()` and `Default::default()` produce the same value.
#[derive(Default)]
pub struct BufferWordProvider;

impl BufferWordProvider {
    /// Creates the (stateless) provider.
    pub fn new() -> Self {
        BufferWordProvider
    }
}
44
/// Aggregated occurrence data for one word, keyed elsewhere by its lowercase form.
struct WordStats {
    /// Spelling of the first occurrence recorded for this key.
    text: String,
    /// Number of graphemes in `text`.
    grapheme_len: usize,
    /// How many times the word (ignoring case) was seen.
    count: u32,
    /// Byte offset of the occurrence closest to the cursor.
    nearest_offset: usize,
    /// Byte distance from that occurrence to the cursor; entries merged in from other
    /// buffers carry an extra distance added by the provider.
    nearest_dist: usize,
    /// Whether any occurrence starts inside the viewport range.
    in_viewport: bool,
}
60
61fn collect_word_stats(
63 text: &str,
64 extra: &str,
65 cursor_in_window: usize,
66 viewport_start_in_window: usize,
67 viewport_end_in_window: usize,
68) -> HashMap<String, WordStats> {
69 let mut stats: HashMap<String, WordStats> = HashMap::new();
70
71 let mut current_word = String::new();
72 let mut word_start: usize = 0;
73 let mut word_grapheme_count: usize = 0;
74 let mut byte_pos: usize = 0;
75
76 for grapheme in text.graphemes(true) {
77 if is_word_grapheme_for_lang(grapheme, extra) {
78 if current_word.is_empty() {
79 word_start = byte_pos;
80 word_grapheme_count = 0;
81 }
82 current_word.push_str(grapheme);
83 word_grapheme_count += 1;
84 } else if !current_word.is_empty() {
85 record_word(
86 &mut stats,
87 std::mem::take(&mut current_word),
88 word_grapheme_count,
89 word_start,
90 cursor_in_window,
91 viewport_start_in_window,
92 viewport_end_in_window,
93 );
94 word_grapheme_count = 0;
95 }
96 byte_pos += grapheme.len();
97 }
98 if !current_word.is_empty() {
99 record_word(
100 &mut stats,
101 current_word,
102 word_grapheme_count,
103 word_start,
104 cursor_in_window,
105 viewport_start_in_window,
106 viewport_end_in_window,
107 );
108 }
109
110 stats
111}
112
113fn record_word(
114 stats: &mut HashMap<String, WordStats>,
115 word: String,
116 grapheme_len: usize,
117 byte_offset: usize,
118 cursor_in_window: usize,
119 viewport_start: usize,
120 viewport_end: usize,
121) {
122 let dist = byte_offset.abs_diff(cursor_in_window);
123 let in_vp = byte_offset >= viewport_start && byte_offset < viewport_end;
124 let key = word.to_lowercase();
125
126 stats
127 .entry(key)
128 .and_modify(|s| {
129 s.count += 1;
130 if dist < s.nearest_dist {
131 s.nearest_dist = dist;
132 s.nearest_offset = byte_offset;
133 }
134 s.in_viewport |= in_vp;
135 })
136 .or_insert(WordStats {
137 text: word,
138 count: 1,
139 nearest_offset: byte_offset,
140 nearest_dist: dist,
141 in_viewport: in_vp,
142 grapheme_len,
143 });
144}
145
146impl CompletionProvider for BufferWordProvider {
147 fn id(&self) -> CompletionSourceId {
148 CompletionSourceId("buffer_words".into())
149 }
150
151 fn display_name(&self) -> &str {
152 "Buffer Words"
153 }
154
155 fn is_enabled(&self, ctx: &CompletionContext) -> bool {
156 !ctx.prefix.is_empty()
157 }
158
159 fn provide(&self, ctx: &CompletionContext, buffer_window: &[u8]) -> ProviderResult {
160 let text = String::from_utf8_lossy(buffer_window);
161 let extra = &ctx.word_chars_extra;
162
163 let cursor_in_window = ctx.cursor_byte.saturating_sub(ctx.scan_range.start);
164 let vp_start = ctx.viewport_top_byte.saturating_sub(ctx.scan_range.start);
165 let vp_end = ctx
166 .viewport_bottom_byte
167 .saturating_sub(ctx.scan_range.start)
168 .min(buffer_window.len());
169
170 let mut all_stats = collect_word_stats(&text, extra, cursor_in_window, vp_start, vp_end);
171
172 for (i, other) in ctx.other_buffers.iter().enumerate() {
174 let other_text = String::from_utf8_lossy(&other.bytes);
175 let other_stats = collect_word_stats(&other_text, extra, 0, 0, 0);
176 let cross_buffer_dist_offset = 300_000 * (i + 1);
177 for (key, os) in other_stats {
178 all_stats.entry(key).or_insert(WordStats {
179 text: os.text,
180 count: os.count,
181 nearest_offset: os.nearest_offset,
182 nearest_dist: os.nearest_dist + cross_buffer_dist_offset,
183 in_viewport: false,
184 grapheme_len: os.grapheme_len,
185 });
186 }
187 }
188
189 let mut scored: Vec<(i64, &WordStats)> = all_stats
190 .values()
191 .filter(|s| {
192 s.grapheme_len >= MIN_WORD_LEN_GRAPHEMES
193 && smart_case_matches(&s.text, &ctx.prefix, ctx.prefix_has_uppercase)
194 && s.text.to_lowercase() != ctx.prefix.to_lowercase()
195 })
196 .map(|s| {
197 let mut score: i64 = 500_000i64.saturating_sub(s.nearest_dist as i64);
199 if s.in_viewport {
201 score += 100_000;
202 }
203 score += (s.count.min(10) as i64 - 1) * 5_000;
205 score += case_mismatch_penalty(&s.text, &ctx.prefix, ctx.prefix_has_uppercase);
207 (score, s)
208 })
209 .collect();
210
211 scored.sort_by(|a, b| b.0.cmp(&a.0));
212
213 let candidates = scored
214 .into_iter()
215 .take(MAX_CANDIDATES)
216 .map(|(score, s)| CompletionCandidate::word(s.text.clone(), score))
217 .collect();
218
219 ProviderResult::Ready(candidates)
220 }
221
222 fn priority(&self) -> u32 {
223 20
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::super::provider::OtherBufferSlice;
    use super::*;

    /// Context whose scan range and viewport both cover the whole `buf_len`-byte buffer, with
    /// `prefix` ending at `cursor`.
    fn make_ctx(prefix: &str, cursor: usize, buf_len: usize) -> CompletionContext {
        CompletionContext {
            prefix: prefix.into(),
            cursor_byte: cursor,
            word_start_byte: cursor.saturating_sub(prefix.len()),
            buffer_len: buf_len,
            is_large_file: false,
            scan_range: 0..buf_len,
            viewport_top_byte: 0,
            viewport_bottom_byte: buf_len,
            language_id: None,
            word_chars_extra: String::new(),
            prefix_has_uppercase: prefix.chars().any(|c| c.is_uppercase()),
            other_buffers: Vec::new(),
        }
    }

    #[test]
    fn proximity_beats_frequency() {
        // `far_match` appears three times but ~20 KB before the cursor; `fa_near` appears once,
        // right next to it. The distance gap outweighs the frequency bonus.
        let text = format!(
            "far_match far_match far_match{}fa_near fa",
            " ".repeat(20_000)
        );
        let bytes = text.as_bytes();
        let provider = BufferWordProvider::new();
        let ctx = make_ctx("fa", bytes.len(), bytes.len());
        match provider.provide(&ctx, bytes) {
            ProviderResult::Ready(candidates) => {
                let labels: Vec<&str> = candidates.iter().map(|c| c.label.as_str()).collect();
                assert_eq!(labels, vec!["fa_near", "far_match"]);
            }
            _ => panic!("expected Ready"),
        }
    }

    #[test]
    fn frequency_wins_when_distances_are_close() {
        // Same words without the padding: both sit within a few dozen bytes of the cursor, so
        // the extra occurrences of `far_match` decide the order.
        let text = b"far_match far_match far_match fa_near fa";
        let provider = BufferWordProvider::new();
        let ctx = make_ctx("fa", text.len(), text.len());
        match provider.provide(&ctx, text) {
            ProviderResult::Ready(candidates) => {
                let labels: Vec<&str> = candidates.iter().map(|c| c.label.as_str()).collect();
                assert_eq!(labels, vec!["far_match", "fa_near"]);
            }
            _ => panic!("expected Ready"),
        }
    }

    #[test]
    fn viewport_boost() {
        let text = b"alpha_one xxxxxxxxx alpha_two";
        let provider = BufferWordProvider::new();
        let mut ctx = make_ctx("alpha", 15, text.len());
        ctx.viewport_top_byte = 20;
        match provider.provide(&ctx, text) {
            ProviderResult::Ready(candidates) => {
                assert_eq!(candidates.len(), 2);
                assert_eq!(candidates[0].label, "alpha_two");
            }
            _ => panic!("expected Ready"),
        }
    }

    #[test]
    fn min_length_filter() {
        let text = b"a b cc dd hello";
        let provider = BufferWordProvider::new();
        let ctx = make_ctx("h", 15, text.len());
        match provider.provide(&ctx, text) {
            ProviderResult::Ready(candidates) => {
                assert_eq!(candidates.len(), 1);
                assert_eq!(candidates[0].label, "hello");
            }
            _ => panic!("expected Ready"),
        }
    }

    #[test]
    fn unicode_words() {
        let text = "naïve_var naïve_fn naïf".as_bytes();
        let provider = BufferWordProvider::new();
        let ctx = make_ctx("naïve", 0, text.len());
        match provider.provide(&ctx, text) {
            ProviderResult::Ready(candidates) => {
                let labels: Vec<&str> = candidates.iter().map(|c| c.label.as_str()).collect();
                assert!(labels.contains(&"naïve_var"));
                assert!(labels.contains(&"naïve_fn"));
                assert!(!labels.contains(&"naïf"));
            }
            _ => panic!("expected Ready"),
        }
    }

    #[test]
    fn smart_case_penalizes_mismatch() {
        let text = b"http_request HttpServer HTTP_CONST";
        let provider = BufferWordProvider::new();
        let ctx = make_ctx("http", 0, text.len());
        match provider.provide(&ctx, text) {
            ProviderResult::Ready(candidates) => {
                assert_eq!(candidates.len(), 3);
                let req = candidates
                    .iter()
                    .find(|c| c.label == "http_request")
                    .unwrap();
                let srv = candidates.iter().find(|c| c.label == "HttpServer").unwrap();
                assert!(req.score > srv.score);
            }
            _ => panic!("expected Ready"),
        }
    }

    #[test]
    fn multi_buffer_words() {
        let text = b"local_var another";
        let provider = BufferWordProvider::new();
        let mut ctx = make_ctx("lo", 0, text.len());
        ctx.other_buffers = vec![OtherBufferSlice {
            buffer_id: 2,
            bytes: b"long_name logging".to_vec(),
            label: "other.rs".into(),
        }];
        match provider.provide(&ctx, text) {
            ProviderResult::Ready(candidates) => {
                let labels: Vec<&str> = candidates.iter().map(|c| c.label.as_str()).collect();
                assert!(labels.contains(&"local_var"));
                assert!(labels.contains(&"long_name"));
                assert!(labels.contains(&"logging"));
                let local = candidates.iter().find(|c| c.label == "local_var").unwrap();
                let long = candidates.iter().find(|c| c.label == "long_name").unwrap();
                assert!(local.score > long.score);
            }
            _ => panic!("expected Ready"),
        }
    }
}