Skip to main content

fresh/services/completion/
buffer_words.rs

1//! Buffer-word completion provider with proximity scoring.
2//!
3//! Collects unique words from the buffer scan window and ranks them by:
//! 1. **Proximity** to the cursor (byte-distance approximation).
5//! 2. **Viewport bias** — words visible on screen are boosted.
6//! 3. **Frequency** — words that appear more often get a small bonus.
7//!
8//! Smart-case matching, language-aware word boundaries, and multi-buffer
9//! support are all provided through the shared `CompletionContext`.
10//!
11//! # Huge-file safety
12//!
13//! Only the pre-sliced `buffer_window` is scanned. For large files the
14//! completion service limits this to 32 KB around the cursor.
15
16use std::collections::HashMap;
17
18use unicode_segmentation::UnicodeSegmentation;
19
20use super::provider::{
21    case_mismatch_penalty, is_word_grapheme_for_lang, smart_case_matches, CompletionCandidate,
22    CompletionContext, CompletionProvider, CompletionSourceId, ProviderResult,
23};
24
/// Words shorter than this many grapheme clusters are never suggested.
const MIN_WORD_LEN_GRAPHEMES: usize = 2;

/// Upper bound on how many candidates a single `provide` call returns.
const MAX_CANDIDATES: usize = 40;
30
/// Completion provider that suggests words already present in the buffer
/// (and in other open buffers), ranked by proximity to the cursor.
///
/// The provider holds no state: each `provide` call rescans the window it is
/// given, so the type is zero-sized and freely copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct BufferWordProvider;

impl BufferWordProvider {
    /// Create a provider. Equivalent to `BufferWordProvider::default()`.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
44
/// Aggregated occurrences of one word within a scanned text.
///
/// Occurrences that differ only by letter case are folded into one entry.
struct WordStats {
    /// Display spelling: the casing seen at the word's first occurrence.
    text: String,
    /// Word length in grapheme clusters, used for the minimum-length filter.
    grapheme_len: usize,
    /// How many times the word occurs in the scanned text.
    count: u32,
    /// Byte offset of the occurrence nearest to the cursor.
    // NOTE(review): written by the scanner but not read by the scoring code
    // visible in this file.
    nearest_offset: usize,
    /// Absolute byte distance between that nearest occurrence and the cursor.
    nearest_dist: usize,
    /// Set when at least one occurrence starts inside the viewport.
    in_viewport: bool,
}
60
61/// Collect word statistics from a text window.
62fn collect_word_stats(
63    text: &str,
64    extra: &str,
65    cursor_in_window: usize,
66    viewport_start_in_window: usize,
67    viewport_end_in_window: usize,
68) -> HashMap<String, WordStats> {
69    let mut stats: HashMap<String, WordStats> = HashMap::new();
70
71    let mut current_word = String::new();
72    let mut word_start: usize = 0;
73    let mut word_grapheme_count: usize = 0;
74    let mut byte_pos: usize = 0;
75
76    for grapheme in text.graphemes(true) {
77        if is_word_grapheme_for_lang(grapheme, extra) {
78            if current_word.is_empty() {
79                word_start = byte_pos;
80                word_grapheme_count = 0;
81            }
82            current_word.push_str(grapheme);
83            word_grapheme_count += 1;
84        } else if !current_word.is_empty() {
85            record_word(
86                &mut stats,
87                std::mem::take(&mut current_word),
88                word_grapheme_count,
89                word_start,
90                cursor_in_window,
91                viewport_start_in_window,
92                viewport_end_in_window,
93            );
94            word_grapheme_count = 0;
95        }
96        byte_pos += grapheme.len();
97    }
98    if !current_word.is_empty() {
99        record_word(
100            &mut stats,
101            current_word,
102            word_grapheme_count,
103            word_start,
104            cursor_in_window,
105            viewport_start_in_window,
106            viewport_end_in_window,
107        );
108    }
109
110    stats
111}
112
113fn record_word(
114    stats: &mut HashMap<String, WordStats>,
115    word: String,
116    grapheme_len: usize,
117    byte_offset: usize,
118    cursor_in_window: usize,
119    viewport_start: usize,
120    viewport_end: usize,
121) {
122    let dist = byte_offset.abs_diff(cursor_in_window);
123    let in_vp = byte_offset >= viewport_start && byte_offset < viewport_end;
124    let key = word.to_lowercase();
125
126    stats
127        .entry(key)
128        .and_modify(|s| {
129            s.count += 1;
130            if dist < s.nearest_dist {
131                s.nearest_dist = dist;
132                s.nearest_offset = byte_offset;
133            }
134            s.in_viewport |= in_vp;
135        })
136        .or_insert(WordStats {
137            text: word,
138            count: 1,
139            nearest_offset: byte_offset,
140            nearest_dist: dist,
141            in_viewport: in_vp,
142            grapheme_len,
143        });
144}
145
146impl CompletionProvider for BufferWordProvider {
147    fn id(&self) -> CompletionSourceId {
148        CompletionSourceId("buffer_words".into())
149    }
150
151    fn display_name(&self) -> &str {
152        "Buffer Words"
153    }
154
155    fn is_enabled(&self, ctx: &CompletionContext) -> bool {
156        !ctx.prefix.is_empty()
157    }
158
159    fn provide(&self, ctx: &CompletionContext, buffer_window: &[u8]) -> ProviderResult {
160        let text = String::from_utf8_lossy(buffer_window);
161        let extra = &ctx.word_chars_extra;
162
163        let cursor_in_window = ctx.cursor_byte.saturating_sub(ctx.scan_range.start);
164        let vp_start = ctx.viewport_top_byte.saturating_sub(ctx.scan_range.start);
165        let vp_end = ctx
166            .viewport_bottom_byte
167            .saturating_sub(ctx.scan_range.start)
168            .min(buffer_window.len());
169
170        let mut all_stats = collect_word_stats(&text, extra, cursor_in_window, vp_start, vp_end);
171
172        // Merge stats from other open buffers (lower priority).
173        for (i, other) in ctx.other_buffers.iter().enumerate() {
174            let other_text = String::from_utf8_lossy(&other.bytes);
175            let other_stats = collect_word_stats(&other_text, extra, 0, 0, 0);
176            let cross_buffer_dist_offset = 300_000 * (i + 1);
177            for (key, os) in other_stats {
178                all_stats.entry(key).or_insert(WordStats {
179                    text: os.text,
180                    count: os.count,
181                    nearest_offset: os.nearest_offset,
182                    nearest_dist: os.nearest_dist + cross_buffer_dist_offset,
183                    in_viewport: false,
184                    grapheme_len: os.grapheme_len,
185                });
186            }
187        }
188
189        let mut scored: Vec<(i64, &WordStats)> = all_stats
190            .values()
191            .filter(|s| {
192                s.grapheme_len >= MIN_WORD_LEN_GRAPHEMES
193                    && smart_case_matches(&s.text, &ctx.prefix, ctx.prefix_has_uppercase)
194                    && s.text.to_lowercase() != ctx.prefix.to_lowercase()
195            })
196            .map(|s| {
197                // Base: proximity score (closer = higher).
198                let mut score: i64 = 500_000i64.saturating_sub(s.nearest_dist as i64);
199                // Viewport boost: +100k if any occurrence is visible.
200                if s.in_viewport {
201                    score += 100_000;
202                }
203                // Frequency bonus: +5k per extra occurrence (capped).
204                score += (s.count.min(10) as i64 - 1) * 5_000;
205                // Smart-case penalty.
206                score += case_mismatch_penalty(&s.text, &ctx.prefix, ctx.prefix_has_uppercase);
207                (score, s)
208            })
209            .collect();
210
211        scored.sort_by(|a, b| b.0.cmp(&a.0));
212
213        let candidates = scored
214            .into_iter()
215            .take(MAX_CANDIDATES)
216            .map(|(score, s)| CompletionCandidate::word(s.text.clone(), score))
217            .collect();
218
219        ProviderResult::Ready(candidates)
220    }
221
222    fn priority(&self) -> u32 {
223        20
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::super::provider::OtherBufferSlice;
    use super::*;

    /// Context whose scan window and viewport both cover the whole buffer.
    fn make_ctx(prefix: &str, cursor: usize, buf_len: usize) -> CompletionContext {
        CompletionContext {
            prefix: prefix.into(),
            cursor_byte: cursor,
            word_start_byte: cursor.saturating_sub(prefix.len()),
            buffer_len: buf_len,
            is_large_file: false,
            scan_range: 0..buf_len,
            viewport_top_byte: 0,
            viewport_bottom_byte: buf_len,
            language_id: None,
            word_chars_extra: String::new(),
            prefix_has_uppercase: prefix.chars().any(|c| c.is_uppercase()),
            other_buffers: Vec::new(),
        }
    }

    #[test]
    fn proximity_beats_frequency() {
        // "far_often" occurs three times (+10k frequency bonus) but ~20 KB away
        // from the cursor. "far_near" occurs once, a few bytes away. With both
        // matching the prefix, proximity must outweigh frequency.
        let text = format!(
            "far_often far_often far_often {}far_near far",
            " ".repeat(20_000)
        );
        let text = text.as_bytes();
        let provider = BufferWordProvider::new();
        let ctx = make_ctx("far", text.len(), text.len());
        let result = provider.provide(&ctx, text);
        match result {
            ProviderResult::Ready(candidates) => {
                assert_eq!(candidates.len(), 2);
                assert_eq!(candidates[0].label, "far_near");
                assert_eq!(candidates[1].label, "far_often");
            }
            _ => panic!("expected Ready"),
        }
    }

    #[test]
    fn viewport_boost() {
        // The cursor sits at the top of the viewport. "alpha_one" is closer to
        // the cursor (10 bytes) but lies above the viewport. "alpha_two" is
        // farther (13 bytes) but visible, so only the viewport bonus can rank
        // it first.
        let text = b"alpha_one zzzzzzzzzzzz alpha_two";
        let provider = BufferWordProvider::new();
        let mut ctx = make_ctx("alpha", 10, text.len());
        ctx.viewport_top_byte = 10;
        let result = provider.provide(&ctx, text);
        match result {
            ProviderResult::Ready(candidates) => {
                assert_eq!(candidates.len(), 2);
                assert_eq!(candidates[0].label, "alpha_two");
                assert_eq!(candidates[1].label, "alpha_one");
            }
            _ => panic!("expected Ready"),
        }
    }

    #[test]
    fn min_length_filter() {
        let text = b"a b cc dd hello";
        let provider = BufferWordProvider::new();
        let ctx = make_ctx("h", 15, text.len());
        let result = provider.provide(&ctx, text);
        match result {
            ProviderResult::Ready(candidates) => {
                assert_eq!(candidates.len(), 1);
                assert_eq!(candidates[0].label, "hello");
            }
            _ => panic!("expected Ready"),
        }
    }

    #[test]
    fn unicode_words() {
        let text = "naïve_var naïve_fn naïf".as_bytes();
        let provider = BufferWordProvider::new();
        let ctx = make_ctx("naïve", 0, text.len());
        let result = provider.provide(&ctx, text);
        match result {
            ProviderResult::Ready(candidates) => {
                let labels: Vec<&str> = candidates.iter().map(|c| c.label.as_str()).collect();
                assert!(labels.contains(&"naïve_var"));
                assert!(labels.contains(&"naïve_fn"));
                assert!(!labels.contains(&"naïf"));
            }
            _ => panic!("expected Ready"),
        }
    }

    #[test]
    fn smart_case_penalizes_mismatch() {
        let text = b"http_request HttpServer HTTP_CONST";
        let provider = BufferWordProvider::new();
        let ctx = make_ctx("http", 0, text.len());
        let result = provider.provide(&ctx, text);
        match result {
            ProviderResult::Ready(candidates) => {
                assert_eq!(candidates.len(), 3);
                // Exact-case "http_request" should rank above "HttpServer".
                let req = candidates
                    .iter()
                    .find(|c| c.label == "http_request")
                    .unwrap();
                let srv = candidates.iter().find(|c| c.label == "HttpServer").unwrap();
                assert!(req.score > srv.score);
            }
            _ => panic!("expected Ready"),
        }
    }

    #[test]
    fn multi_buffer_words() {
        let text = b"local_var another";
        let provider = BufferWordProvider::new();
        let mut ctx = make_ctx("lo", 0, text.len());
        ctx.other_buffers = vec![OtherBufferSlice {
            buffer_id: 2,
            bytes: b"long_name logging".to_vec(),
            label: "other.rs".into(),
        }];
        let result = provider.provide(&ctx, text);
        match result {
            ProviderResult::Ready(candidates) => {
                let labels: Vec<&str> = candidates.iter().map(|c| c.label.as_str()).collect();
                assert!(labels.contains(&"local_var"));
                assert!(labels.contains(&"long_name"));
                assert!(labels.contains(&"logging"));
                // Words from the active buffer should outscore cross-buffer words.
                let local = candidates.iter().find(|c| c.label == "local_var").unwrap();
                let long = candidates.iter().find(|c| c.label == "long_name").unwrap();
                assert!(local.score > long.score);
            }
            _ => panic!("expected Ready"),
        }
    }
}