Skip to main content

koda_cli/
completer.rs

1//! Tab completion for TUI input.
2//!
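//! Handles three completion modes (repeated Tab cycles through candidates):
//! - **Slash commands**: `/e` → `/exit`, `/expand`
//! - **Model names**: `/model <partial>` → matching model aliases and cached provider models
//! - **@file paths**: `explain @src/m` → `explain @src/main.rs`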
6
7use std::path::{Path, PathBuf};
8
/// One slash-command table entry: `(command, description, arg_hint)`.
type SlashCommandSpec = (&'static str, &'static str, Option<&'static str>);

/// Every slash command the TUI knows about, listed alphabetically.
///
/// Each entry is `(command, description, arg_hint)`. `arg_hint` carries a
/// `<placeholder>` for commands that expect an argument, and is `None` for
/// commands that stand alone or open a picker. This table is the single
/// source of truth for both Tab completion and the auto-dropdown; the order
/// here is the order in which completions are offered.
pub const SLASH_COMMANDS: &[SlashCommandSpec] = &[
    ("/agent", "Switch to a sub-agent", Some("<name>")),
    (
        "/compact",
        "Summarize conversation to reclaim context",
        None,
    ),
    ("/diff", "Show git diff (review, commit)", None),
    ("/exit", "Quit the session", None),
    ("/expand", "Show full output of last tool call", None),
    ("/help", "Show commands and shortcuts", None),
    ("/key", "Manage API keys", None),
    ("/memory", "View/save project & global memory", None),
    ("/model", "Pick a model (aliases + local)", None),
    ("/provider", "Browse all models from a provider", None),
    (
        "/purge",
        "Delete archived history (e.g. /purge 90d)",
        Some("<days>"),
    ),
    ("/sessions", "List/resume/delete sessions", None),
    ("/skills", "List available skills (search with query)", None),
    ("/undo", "Undo last turn's file changes", None),
    ("/verbose", "Toggle full tool output", None),
];
38
/// Tab-completion engine shared by slash commands, `/model` names and `@file` paths.
///
/// Remembers the candidate list for the token currently being completed,
/// together with a cursor, so repeated Tab presses step through the
/// candidates in order.
pub struct InputCompleter {
    /// Root directory that `@file` paths are resolved against.
    project_root: PathBuf,
    /// Model names offered after `/model `, refreshed through `set_model_names`.
    model_names: Vec<String>,
    /// Text that produced the current `matches`; used to notice when it changes.
    token: String,
    /// Candidates for the current token, in the order they are offered.
    matches: Vec<String>,
    /// Position in `matches` of the next candidate to hand out.
    idx: usize,
}
52
53impl InputCompleter {
54    pub fn new(project_root: PathBuf) -> Self {
55        Self {
56            matches: Vec::new(),
57            idx: 0,
58            token: String::new(),
59            project_root,
60            model_names: Vec::new(),
61        }
62    }
63
64    /// Update the cached model names (call after provider switch or model list fetch).
65    pub fn set_model_names(&mut self, names: Vec<String>) {
66        self.model_names = names;
67    }
68
69    /// Attempt to complete the current input text.
70    ///
71    /// Returns `Some(replacement_text)` with the full input line replaced,
72    /// or `None` if no completion is available.
73    /// Repeated calls cycle through matches.
74    pub fn complete(&mut self, current_text: &str) -> Option<String> {
75        let trimmed = current_text.trim_end();
76
77        // Slash command completion: input starts with /
78        if trimmed.starts_with('/') {
79            // /model <partial> → complete model names
80            if let Some(partial) = trimmed.strip_prefix("/model ") {
81                return self.complete_model(partial);
82            }
83            return self.complete_slash(trimmed);
84        }
85
86        // @file completion: find the last @token in the input
87        if let Some(at_pos) = find_last_at_token(trimmed) {
88            let partial = &trimmed[at_pos + 1..]; // after @
89            let prefix = &trimmed[..at_pos]; // everything before @
90            return self.complete_file(prefix, partial);
91        }
92
93        self.reset();
94        None
95    }
96
97    /// Reset completion state (call on non-Tab keystrokes).
98    pub fn reset(&mut self) {
99        self.matches.clear();
100        self.idx = 0;
101        self.token.clear();
102    }
103
104    // ── Slash command completion ─────────────────────────────
105
106    fn complete_slash(&mut self, trimmed: &str) -> Option<String> {
107        // Rebuild matches if the token changed
108        if trimmed != self.token && !self.matches.iter().any(|m| m == trimmed) {
109            self.token = trimmed.to_string();
110            self.matches = SLASH_COMMANDS
111                .iter()
112                .filter(|(cmd, _, _)| cmd.starts_with(trimmed) && *cmd != trimmed)
113                .map(|(cmd, _, _)| cmd.to_string())
114                .collect();
115            self.idx = 0;
116        }
117
118        if self.matches.is_empty() {
119            return None;
120        }
121
122        let result = self.matches[self.idx].clone();
123        self.idx = (self.idx + 1) % self.matches.len();
124        Some(result)
125    }
126
127    // ── /model name completion ──────────────────────────────
128
129    fn complete_model(&mut self, partial: &str) -> Option<String> {
130        let token_key = format!("/model {partial}");
131
132        if token_key != self.token {
133            self.token = token_key;
134            // Complete against alias names + cached provider model names
135            let alias_names = koda_core::model_alias::alias_names();
136            self.matches = alias_names
137                .iter()
138                .map(|s| s.to_string())
139                .chain(self.model_names.iter().cloned())
140                .filter(|name| name.contains(partial) && name.as_str() != partial)
141                .map(|name| format!("/model {name}"))
142                .collect();
143            self.idx = 0;
144        }
145
146        if self.matches.is_empty() {
147            return None;
148        }
149
150        let result = self.matches[self.idx].clone();
151        self.idx = (self.idx + 1) % self.matches.len();
152        Some(result)
153    }
154
155    // ── @file path completion ────────────────────────────────
156
157    fn complete_file(&mut self, prefix: &str, partial: &str) -> Option<String> {
158        // Check if the partial is already one of our matches (user is cycling)
159        let is_cycling = !self.matches.is_empty() && self.matches.iter().any(|m| m == partial);
160
161        if !is_cycling {
162            self.token = format!("@{partial}");
163            self.matches = list_path_matches(&self.project_root, partial);
164            self.idx = 0;
165        }
166
167        if self.matches.is_empty() {
168            return None;
169        }
170
171        let path = &self.matches[self.idx];
172        self.idx = (self.idx + 1) % self.matches.len();
173
174        // Rebuild full input: prefix + @completed_path
175        Some(format!("{prefix}@{path}"))
176    }
177}
178
179// ── Helpers ─────────────────────────────────────────────────
180
/// Find the byte position of the last `@` that starts a file reference.
///
/// An `@` counts as a file reference if it is at the start of the input or
/// is immediately preceded by any Unicode whitespace character (space, tab,
/// newline, …). An `@` glued to a preceding non-space character — as in an
/// email address like `user@host` — is ignored.
///
/// Returns the byte offset of that `@`, or `None` if there is none.
pub fn find_last_at_token(text: &str) -> Option<usize> {
    text.char_indices()
        .rev()
        .filter(|&(_, c)| c == '@')
        .map(|(i, _)| i)
        // `i` comes from `char_indices`, so slicing at it is always on a char boundary.
        .find(|&i| i == 0 || text[..i].ends_with(char::is_whitespace))
}
193
/// List filesystem paths matching a partial path relative to `project_root`.
///
/// Public wrapper around the private `list_path_matches`, used by the `@`
/// auto-dropdown in `tui_app.rs`. Tab completion (`complete_file`) calls the
/// same function, so the dropdown and Tab offer identically ranked candidates.
pub fn list_path_matches_public(project_root: &Path, partial: &str) -> Vec<String> {
    list_path_matches(project_root, partial)
}
199
200/// List filesystem paths matching a partial path relative to project_root.
201///
202/// Uses fuzzy subsequence matching: `@mrs` matches `main.rs`, `@ctml` matches `Cargo.toml`.
203/// Prefix matches rank higher than fuzzy matches.
204/// Directories get a trailing `/` to encourage further completion.
205fn list_path_matches(project_root: &Path, partial: &str) -> Vec<String> {
206    let (dir_part, file_prefix) = match partial.rfind('/') {
207        Some(pos) => (&partial[..=pos], &partial[pos + 1..]),
208        None => ("", partial),
209    };
210
211    let search_dir = if dir_part.is_empty() {
212        project_root.to_path_buf()
213    } else {
214        // Security: reject paths with traversal components
215        if dir_part.contains("..") {
216            return Vec::new();
217        }
218        project_root.join(dir_part)
219    };
220
221    let entries = match std::fs::read_dir(&search_dir) {
222        Ok(entries) => entries,
223        Err(_) => return Vec::new(),
224    };
225
226    let lower_prefix = file_prefix.to_lowercase();
227
228    let mut scored: Vec<(i32, String)> = entries
229        .filter_map(|e| e.ok())
230        .filter_map(|entry| {
231            let name = entry.file_name().to_string_lossy().to_string();
232
233            // Skip hidden files and common noise
234            if name.starts_with('.') {
235                return None;
236            }
237
238            let is_dir = entry.file_type().map(|t| t.is_dir()).unwrap_or(false);
239
240            // Skip build artifacts / deps
241            if is_dir
242                && matches!(
243                    name.as_str(),
244                    "target" | "node_modules" | "__pycache__" | ".git"
245                )
246            {
247                return None;
248            }
249
250            // query is lowered; target keeps original case for camelCase detection
251            let score = fuzzy_score(&lower_prefix, &name)?;
252
253            let path = if is_dir {
254                format!("{dir_part}{name}/")
255            } else {
256                format!("{dir_part}{name}")
257            };
258            Some((score, path))
259        })
260        .collect();
261
262    // Sort by score (higher = better match), then alphabetically
263    scored.sort_by(|a, b| b.0.cmp(&a.0).then_with(|| a.1.cmp(&b.1)));
264    scored.into_iter().map(|(_, path)| path).collect()
265}
266
/// Fuzzy subsequence scoring.
///
/// Returns `Some(score)` if all chars of `query` appear in `target` in order
/// (matched greedily, left to right), `None` otherwise. An empty query
/// matches everything with score 0. Higher score = better match.
///
/// Scoring (nucleo-inspired, matching CC's `native-ts/file-index`):
/// - Base: +1 per matched char
/// - First query char matched at target position 0: +100
/// - Consecutive chars: +10
/// - After separator (`_`, `-`, `.`, `/`): +5
/// - camelCase transition (lower→upper): +6
/// - Gap penalty: two successive matches separated by `g > 0` skipped chars
///   cost −(3 + g); characters skipped before the first match are free
///
/// `query` must be lowercased. `target` is **original case** so camelCase
/// transitions can be detected; character comparison is case-insensitive,
/// including for non-ASCII letters (`Ü` matches `ü`).
fn fuzzy_score(query: &str, target: &str) -> Option<i32> {
    if query.is_empty() {
        return Some(0);
    }

    let query_chars: Vec<char> = query.chars().collect();
    let target_chars: Vec<char> = target.chars().collect();

    let mut qi = 0;
    let mut score: i32 = 0;
    let mut prev_match_pos: Option<usize> = None;

    for (ti, &tc) in target_chars.iter().enumerate() {
        if qi == query_chars.len() {
            // Every query char is matched; later target chars cannot change the score.
            break;
        }

        let qc = query_chars[qi];
        // Unicode-aware case folding: `to_ascii_lowercase` alone would leave
        // non-ASCII capitals (e.g. `Ü`) unequal to their lowercased query form.
        if tc != qc && !tc.to_lowercase().eq(std::iter::once(qc)) {
            continue;
        }

        score += 1;

        // Bonus: prefix match
        if qi == 0 && ti == 0 {
            score += 100;
        }

        // Bonus: consecutive match
        if ti > 0 && prev_match_pos == Some(ti - 1) {
            score += 10;
        }

        // Bonus: after separator
        if ti > 0 && matches!(target_chars[ti - 1], '_' | '-' | '.' | '/') {
            score += 5;
        }

        // Bonus: camelCase transition (previous char lowercase, current uppercase)
        if ti > 0 && target_chars[ti - 1].is_lowercase() && tc.is_uppercase() {
            score += 6;
        }

        // Penalty: gap between successive matches (start cost + one per skipped char)
        if let Some(prev) = prev_match_pos {
            let gap = ti - prev - 1;
            if gap > 0 {
                score -= 3 + gap as i32;
            }
        }

        prev_match_pos = Some(ti);
        qi += 1;
    }

    if qi == query_chars.len() {
        Some(score)
    } else {
        None // Not all query chars matched
    }
}
337
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    // ── Slash command tests ─────────────────────────────────

    #[test]
    fn test_complete_slash_d() {
        let tmp = tempdir().unwrap();
        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        let first = c.complete("/d");
        assert!(first.is_some());
        assert!(first.unwrap().starts_with("/d"));
    }

    #[test]
    fn test_complete_cycles() {
        let tmp = tempdir().unwrap();
        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        // "/e" matches /exit and /expand. Feed each result back in, as the TUI
        // does when it replaces the input line with the completion.
        let a = c.complete("/e").unwrap();
        let b = c.complete(&a).unwrap();
        let back = c.complete(&b).unwrap();
        assert_eq!(a, "/exit");
        assert_eq!(b, "/expand");
        assert_eq!(back, a, "should cycle back to first");
    }

    #[test]
    fn test_no_match() {
        let tmp = tempdir().unwrap();
        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        assert!(c.complete("/zzz").is_none());
    }

    #[test]
    fn test_non_slash_no_at_returns_none() {
        let tmp = tempdir().unwrap();
        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        assert!(c.complete("hello").is_none());
    }

    #[test]
    fn test_exact_match_no_complete() {
        let tmp = tempdir().unwrap();
        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        assert!(c.complete("/exit").is_none());
    }

    // ── @file completion tests ───────────────────────────────

    #[test]
    fn test_at_file_completes() {
        let tmp = tempdir().unwrap();
        fs::write(tmp.path().join("main.rs"), "fn main() {}").unwrap();
        fs::write(tmp.path().join("mod.rs"), "").unwrap();

        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        let result = c.complete("explain @m");
        assert!(result.is_some());
        let text = result.unwrap();
        assert!(text.starts_with("explain @m"), "got: {text}");
        assert!(
            text.contains("main.rs") || text.contains("mod.rs"),
            "got: {text}"
        );
    }

    #[test]
    fn test_at_file_in_subdir() {
        let tmp = tempdir().unwrap();
        fs::create_dir_all(tmp.path().join("src")).unwrap();
        fs::write(tmp.path().join("src/lib.rs"), "").unwrap();
        fs::write(tmp.path().join("src/main.rs"), "").unwrap();

        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        let result = c.complete("@src/l");
        assert_eq!(result, Some("@src/lib.rs".to_string()));
    }

    #[test]
    fn test_at_file_dir_gets_trailing_slash() {
        let tmp = tempdir().unwrap();
        fs::create_dir_all(tmp.path().join("src")).unwrap();

        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        let result = c.complete("@s");
        assert_eq!(result, Some("@src/".to_string()));
    }

    #[test]
    fn test_at_file_cycles() {
        let tmp = tempdir().unwrap();
        fs::write(tmp.path().join("alpha.rs"), "").unwrap();
        fs::write(tmp.path().join("beta.rs"), "").unwrap();

        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        // First Tab: input is "@" → returns first match
        let a = c.complete("@").unwrap();
        // Second Tab: input is now the completed text (e.g., "@alpha.rs")
        let b = c.complete(&a).unwrap();
        assert_ne!(a, b, "should cycle through different files");
        // Third Tab: should cycle back
        let c_result = c.complete(&b).unwrap();
        assert_eq!(c_result, a, "should cycle back to first");
    }

    #[test]
    fn test_at_file_skips_hidden() {
        let tmp = tempdir().unwrap();
        fs::write(tmp.path().join(".hidden"), "").unwrap();
        fs::write(tmp.path().join("visible.rs"), "").unwrap();

        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        let result = c.complete("@");
        assert_eq!(result, Some("@visible.rs".to_string()));
    }

    #[test]
    fn test_at_file_case_insensitive() {
        let tmp = tempdir().unwrap();
        fs::write(tmp.path().join("Makefile"), "").unwrap();
        fs::write(tmp.path().join("README.md"), "").unwrap();

        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        let result = c.complete("@make");
        assert_eq!(result, Some("@Makefile".to_string()));

        c.reset();
        let result = c.complete("@read");
        assert_eq!(result, Some("@README.md".to_string()));
    }

    #[test]
    fn test_at_file_preserves_prefix_text() {
        let tmp = tempdir().unwrap();
        fs::write(tmp.path().join("config.toml"), "").unwrap();

        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        let result = c.complete("review this @c");
        assert_eq!(result, Some("review this @config.toml".to_string()));
    }

    // ── /model completion tests ──────────────────────────────

    #[test]
    fn test_model_complete() {
        let tmp = tempdir().unwrap();
        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        c.set_model_names(vec![
            "gpt-4o".into(),
            "gpt-4o-mini".into(),
            "gpt-3.5-turbo".into(),
        ]);
        let result = c.complete("/model gpt-4");
        assert!(result.is_some());
        let text = result.unwrap();
        assert!(text.starts_with("/model gpt-4"), "got: {text}");
    }

    #[test]
    fn test_model_complete_cycles() {
        let tmp = tempdir().unwrap();
        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        c.set_model_names(vec!["gpt-4o".into(), "gpt-4o-mini".into()]);
        let a = c.complete("/model gpt");
        let b = c.complete("/model gpt");
        assert!(a.is_some());
        assert!(b.is_some());
        assert_ne!(a, b, "should cycle through models");
    }

    #[test]
    fn test_model_no_names_returns_none() {
        let tmp = tempdir().unwrap();
        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        // No provider model names set; "gpt" matches no aliases (we only have gemini/claude)
        assert!(c.complete("/model gpt").is_none());
    }

    #[test]
    fn test_model_no_match_returns_none() {
        let tmp = tempdir().unwrap();
        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        // "zzz" matches no aliases or model names
        assert!(c.complete("/model zzz").is_none());
    }

    #[test]
    fn test_model_substring_match() {
        let tmp = tempdir().unwrap();
        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        c.set_model_names(vec!["claude-3-sonnet".into(), "claude-3-opus".into()]);
        let result = c.complete("/model opus");
        // "opus" matches both the alias "claude-opus" and "claude-3-opus" from model_names
        assert!(result.is_some());
        let text = result.unwrap();
        assert!(text.contains("opus"), "got: {text}");
    }

    // ── Helper tests ────────────────────────────────────────

    #[test]
    fn test_find_last_at_token() {
        assert_eq!(find_last_at_token("@file"), Some(0));
        assert_eq!(find_last_at_token("explain @file"), Some(8));
        assert_eq!(find_last_at_token("email@domain"), None); // no space before @
        assert_eq!(find_last_at_token("a @b @c"), Some(5)); // last @
        assert_eq!(find_last_at_token("no at here"), None);
        // @ after newline (multi-line input via Alt+Enter)
        assert_eq!(find_last_at_token("line1\n@file"), Some(6));
        assert_eq!(find_last_at_token("a\nb\n@c"), Some(4));
    }

    #[test]
    fn test_at_file_after_newline() {
        let tmp = tempdir().unwrap();
        fs::write(tmp.path().join("config.toml"), "").unwrap();

        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        // Simulate multi-line input: first line + newline + @partial
        let result = c.complete("explain this\n@c");
        assert_eq!(result, Some("explain this\n@config.toml".to_string()));
    }

    #[test]
    fn test_at_file_traversal_blocked() {
        let tmp = tempdir().unwrap();
        fs::write(tmp.path().join("safe.rs"), "").unwrap();

        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        // Attempt path traversal — should return no matches
        let result = c.complete("@../../etc/");
        assert!(result.is_none(), "traversal should be blocked");
    }

    // ── Fuzzy matching tests ────────────────────────────────

    #[test]
    fn test_fuzzy_score_basic() {
        // Exact prefix → high score
        assert!(fuzzy_score("main", "main.rs").unwrap() > 100);
        // Subsequence match
        assert!(fuzzy_score("mrs", "main.rs").is_some());
        // No match
        assert!(fuzzy_score("xyz", "main.rs").is_none());
    }

    #[test]
    fn test_fuzzy_score_prefix_wins() {
        let prefix = fuzzy_score("ma", "main.rs").unwrap();
        let fuzzy = fuzzy_score("ma", "format.rs").unwrap();
        assert!(prefix > fuzzy, "prefix {prefix} should beat fuzzy {fuzzy}");
    }

    #[test]
    fn test_fuzzy_at_file() {
        let tmp = tempdir().unwrap();
        fs::write(tmp.path().join("main.rs"), "").unwrap();
        fs::write(tmp.path().join("Cargo.toml"), "").unwrap();
        fs::write(tmp.path().join("config.rs"), "").unwrap();

        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        // "mrs" → should fuzzy-match main.rs (m...r.s)
        let result = c.complete("@mrs");
        assert_eq!(result, Some("@main.rs".to_string()));
    }

    #[test]
    fn test_fuzzy_cargo_toml() {
        let tmp = tempdir().unwrap();
        fs::write(tmp.path().join("Cargo.toml"), "").unwrap();
        fs::write(tmp.path().join("config.rs"), "").unwrap();

        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        // "ctml" → fuzzy-match Cargo.toml (c...t..m.l)
        let result = c.complete("@ctml");
        assert_eq!(result, Some("@Cargo.toml".to_string()));
    }

    #[test]
    fn test_fuzzy_prefix_ranked_first() {
        let tmp = tempdir().unwrap();
        fs::write(tmp.path().join("main.rs"), "").unwrap();
        fs::write(tmp.path().join("format.rs"), "").unwrap();

        let mut c = InputCompleter::new(tmp.path().to_path_buf());
        // "m" → main.rs should come before format.rs (prefix match wins)
        let result = c.complete("@m");
        assert_eq!(result, Some("@main.rs".to_string()));
    }

    // ── Gap penalty tests ──────────────────────────────────

    #[test]
    fn test_gap_penalty_tight_beats_scattered() {
        // "mrs": main.rs has one gap (m, skip "ain.", r, s); the scattered name has larger gaps
        let tight = fuzzy_score("mrs", "main.rs").unwrap();
        let scattered = fuzzy_score("mrs", "my_really_long_script.rs").unwrap();
        assert!(
            tight > scattered,
            "tight {tight} should beat scattered {scattered}"
        );
    }

    #[test]
    fn test_gap_penalty_consecutive_no_penalty() {
        // Consecutive chars should get bonus, not penalty
        let consec = fuzzy_score("mai", "main.rs").unwrap();
        let gapped = fuzzy_score("mai", "m_a_i.rs").unwrap();
        assert!(
            consec > gapped,
            "consecutive {consec} should beat gapped {gapped}"
        );
    }

    // ── camelCase bonus tests ──────────────────────────────

    #[test]
    fn test_camel_case_bonus() {
        // "dm" at camelCase boundary (D→M) should score higher
        let camel = fuzzy_score("dm", "DropdownMenu").unwrap();
        let flat = fuzzy_score("dm", "random_dm_file").unwrap();
        assert!(camel > flat, "camelCase {camel} should beat flat {flat}");
    }
}