// koda_cli/input.rs

1//! Input processing — @file references and image loading.
2//!
3//! Processes user input for `@path` references, loading file contents
4//! as additional context and images for multi-modal prompts.
5
6use koda_core::providers::ImageData;
7use std::path::{Path, PathBuf};
8
9// ── @file pre-processor ────────────────────────────────────────
10
/// A chunk of text that arrived through a bracketed (clipboard) paste.
#[derive(Clone, Debug)]
pub struct PasteBlock {
    /// The pasted text, unmodified.
    pub content: String,
    /// Character count reported for the paste.
    pub char_count: usize,
}
19
20/// Result of processing user input for `@path` references.
21#[derive(Debug)]
22pub struct ProcessedInput {
23    /// The cleaned prompt text (with @references stripped).
24    pub prompt: String,
25    /// File contents to inject as additional context.
26    pub context_files: Vec<FileContext>,
27    /// Base64-encoded images from @image references.
28    pub images: Vec<ImageData>,
29    /// Pasted content blocks (from bracketed paste).
30    pub paste_blocks: Vec<PasteBlock>,
31}
32
/// The contents of a text file loaded from an `@path` reference.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileContext {
    /// The path as the user wrote it after the `@` (quotes stripped), used
    /// as the label when the file is rendered for the model.
    pub path: String,
    /// The full text read from the file.
    pub content: String,
}
39
/// Image file extensions (lowercase, without the dot) recognized for
/// multi-modal input.
const IMAGE_EXTENSIONS: &[&str] = &["png", "jpg", "jpeg", "gif", "webp", "bmp"];

/// Report whether `path` names an image, judged by its extension.
///
/// The extension is the text after the last `.` and is compared
/// case-insensitively against [`IMAGE_EXTENSIONS`]. Requiring the dot keeps
/// names that merely end in the same letters (`backup_png`, `/home/user/png`,
/// `notes.xjpg`) from being mistaken for images.
fn is_image_file(path: &str) -> bool {
    path.rsplit_once('.').is_some_and(|(_, ext)| {
        IMAGE_EXTENSIONS
            .iter()
            .any(|known| ext.eq_ignore_ascii_case(known))
    })
}
48
/// Map a path's extension to the MIME type reported for the image.
///
/// The extension (text after the final `.`) is matched case-insensitively;
/// anything unrecognized, including a path with no dot, yields
/// `application/octet-stream`.
fn mime_type_for(path: &str) -> &'static str {
    let lowered = path.to_lowercase();
    let extension = lowered.rsplit_once('.').map(|(_, ext)| ext);

    match extension {
        Some("png") => "image/png",
        Some("jpg" | "jpeg") => "image/jpeg",
        Some("gif") => "image/gif",
        Some("webp") => "image/webp",
        Some("bmp") => "image/bmp",
        _ => "application/octet-stream",
    }
}
66
/// Remove one matching pair of surrounding quotes (`"…"` or `'…'`).
///
/// Terminals often quote paths that are dragged in. A token without a
/// matching pair — including a lone quote character — is returned unchanged.
fn strip_quotes(s: &str) -> &str {
    ['"', '\'']
        .iter()
        .find_map(|&quote| s.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(s)
}
77
78/// Check if a token looks like a bare file path (absolute, ~/, or ./ prefixed).
79fn looks_like_file_path(token: &str) -> bool {
80    let cleaned = strip_quotes(token);
81    cleaned.starts_with('/')
82        || cleaned.starts_with("~/")
83        || cleaned.starts_with("./")
84        || cleaned.starts_with("..")
85        // Windows absolute paths: C:\ or D:\
86        || (cleaned.len() >= 3
87            && cleaned.as_bytes()[0].is_ascii_alphabetic()
88            && cleaned.as_bytes()[1] == b':'
89            && (cleaned.as_bytes()[2] == b'\\' || cleaned.as_bytes()[2] == b'/'))
90}
91
92/// Try to load an image file, returning the ImageData if successful.
93fn try_load_image(path: &Path, display_path: &str) -> Option<ImageData> {
94    match std::fs::read(path) {
95        Ok(bytes) => {
96            use base64::Engine;
97            let b64 = base64::engine::general_purpose::STANDARD.encode(&bytes);
98            let media_type = mime_type_for(display_path).to_string();
99            Some(ImageData {
100                media_type,
101                base64: b64,
102            })
103        }
104        Err(_) => {
105            eprintln!("  \x1b[33m\u{26a0} Could not read image: {display_path}\x1b[0m");
106            None
107        }
108    }
109}
110
111/// Resolve a bare path token to an absolute path, expanding ~ if needed.
112fn resolve_bare_path(token: &str) -> Option<PathBuf> {
113    let cleaned = strip_quotes(token);
114    if let Some(rest) = cleaned.strip_prefix("~/") {
115        let home = std::env::var("HOME")
116            .or_else(|_| std::env::var("USERPROFILE"))
117            .ok()?;
118        Some(PathBuf::from(home).join(rest))
119    } else {
120        let p = PathBuf::from(cleaned);
121        if p.is_absolute() {
122            Some(p)
123        } else {
124            // Relative paths like ./foo or ../foo — resolve from cwd
125            std::env::current_dir().ok().map(|cwd| cwd.join(cleaned))
126        }
127    }
128}
129
130/// Scan input for `@path` tokens and bare image paths (drag-and-drop),
131/// read the files, and return cleaned prompt plus file contents and images.
132pub fn process_input(input: &str, project_root: &Path) -> ProcessedInput {
133    let mut prompt_parts = Vec::new();
134    let mut context_files = Vec::new();
135    let mut images = Vec::new();
136
137    for token in input.split_whitespace() {
138        // ── @path references (explicit) ───────────────────────
139        if let Some(raw_path) = token.strip_prefix('@') {
140            if raw_path.is_empty() {
141                prompt_parts.push(token.to_string());
142                continue;
143            }
144
145            let raw_path = strip_quotes(raw_path);
146
147            // Security: reject paths that escape the project root
148            let full_path = match koda_core::tools::safe_resolve_path(project_root, raw_path) {
149                Ok(p) => p,
150                Err(_) => {
151                    tracing::warn!("@file path escapes project root: {raw_path}");
152                    prompt_parts.push(token.to_string());
153                    continue;
154                }
155            };
156
157            // Image files → base64 encode for multi-modal
158            if is_image_file(raw_path) {
159                if let Some(img) = try_load_image(&full_path, raw_path) {
160                    images.push(img);
161                } else {
162                    prompt_parts.push(token.to_string());
163                }
164                continue;
165            }
166
167            // Text files → read as string context
168            match std::fs::read_to_string(&full_path) {
169                Ok(content) => {
170                    context_files.push(FileContext {
171                        path: raw_path.to_string(),
172                        content,
173                    });
174                }
175                Err(_) => {
176                    eprintln!("  \x1b[33m\u{26a0} Could not read: {raw_path}\x1b[0m");
177                    prompt_parts.push(token.to_string());
178                }
179            }
180            continue;
181        }
182
183        // ── Bare image paths (drag-and-drop) ──────────────────
184        // Detect absolute/relative paths to image files pasted directly
185        let unquoted = strip_quotes(token);
186        if looks_like_file_path(token)
187            && is_image_file(unquoted)
188            && let Some(resolved) = resolve_bare_path(token)
189            && resolved.exists()
190        {
191            let display = resolved.display().to_string();
192            if let Some(img) = try_load_image(&resolved, &display) {
193                images.push(img);
194                continue;
195            }
196        }
197
198        prompt_parts.push(token.to_string());
199    }
200
201    let prompt = prompt_parts.join(" ");
202
203    // If only @refs were provided with no other text, add a default prompt
204    let prompt = if prompt.trim().is_empty() && (!context_files.is_empty() || !images.is_empty()) {
205        if !images.is_empty() && context_files.is_empty() {
206            "Describe and analyze this image.".to_string()
207        } else {
208            "Describe and explain the attached files.".to_string()
209        }
210    } else {
211        prompt
212    };
213
214    ProcessedInput {
215        prompt,
216        context_files,
217        images,
218        paste_blocks: Vec::new(),
219    }
220}
221
222/// Format file contexts into a string suitable for injection into the user
223/// message sent to the LLM.
224pub fn format_context_files(files: &[FileContext]) -> Option<String> {
225    if files.is_empty() {
226        return None;
227    }
228
229    let mut parts = Vec::new();
230    for f in files {
231        parts.push(format!(
232            "<file path=\"{}\">{}</file>",
233            f.path,
234            // Cap at ~40k chars (~10k tokens) per file
235            if f.content.len() > 40_000 {
236                // Snap to char boundary to avoid panic on multi-byte chars
237                let mut end = 40_000;
238                while !f.content.is_char_boundary(end) {
239                    end -= 1;
240                }
241                format!(
242                    "{}\n\n[truncated — {} bytes total]",
243                    &f.content[..end],
244                    f.content.len()
245                )
246            } else {
247                f.content.clone()
248            }
249        ));
250    }
251
252    Some(parts.join("\n\n"))
253}
254
255/// Pastes shorter than this go inline in the textarea; longer ones become PasteBlocks.
256pub const PASTE_BLOCK_THRESHOLD: usize = 200;
257
258/// Max chars per paste block (~40k chars, matching file truncation policy).
259const PASTE_BLOCK_MAX_CHARS: usize = 40_000;
260
261/// Format paste blocks into semantically tagged XML for the LLM.
262///
263/// Each block is wrapped in `<reference type="pasted" chars="N">...</reference>`
264/// so the model can distinguish pasted reference material from direct instructions.
265pub fn format_paste_blocks(blocks: &[PasteBlock]) -> Option<String> {
266    if blocks.is_empty() {
267        return None;
268    }
269
270    let parts: Vec<String> = blocks
271        .iter()
272        .map(|b| {
273            let content = if b.content.len() > PASTE_BLOCK_MAX_CHARS {
274                let mut end = PASTE_BLOCK_MAX_CHARS;
275                while !b.content.is_char_boundary(end) {
276                    end -= 1;
277                }
278                format!(
279                    "{}\n\n[truncated — {} chars total]",
280                    &b.content[..end],
281                    b.char_count
282                )
283            } else {
284                b.content.clone()
285            };
286            format!(
287                "<reference type=\"pasted\" chars=\"{}\">{}</reference>",
288                b.char_count, content
289            )
290        })
291        .collect();
292
293    Some(parts.join("\n\n"))
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// The eight-byte PNG file signature. It is not a complete image, but the
    /// loader only base64-encodes the bytes, so it is enough for these tests.
    const PNG_SIGNATURE: [u8; 8] = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];

    /// Write the PNG signature to `name` inside `dir` and return the full path.
    fn write_png(dir: &TempDir, name: &str) -> PathBuf {
        let path = dir.path().join(name);
        fs::write(&path, PNG_SIGNATURE).unwrap();
        path
    }

    // ── @file references ──────────────────────────────────────

    #[test]
    fn test_process_input_with_file_ref() {
        let dir = TempDir::new().unwrap();
        fs::write(dir.path().join("test.rs"), "fn test() {}").unwrap();

        let processed = process_input("explain @test.rs", dir.path());

        assert_eq!(processed.prompt, "explain");
        assert_eq!(processed.context_files.len(), 1);
        let file = &processed.context_files[0];
        assert_eq!(file.path, "test.rs");
        assert_eq!(file.content, "fn test() {}");
    }

    #[test]
    fn test_process_input_no_refs() {
        let dir = TempDir::new().unwrap();

        let processed = process_input("just a normal question", dir.path());

        assert_eq!(processed.prompt, "just a normal question");
        assert!(processed.context_files.is_empty());
    }

    #[test]
    fn test_process_input_only_ref() {
        let dir = TempDir::new().unwrap();
        fs::write(dir.path().join("code.py"), "print('hi')").unwrap();

        let processed = process_input("@code.py", dir.path());

        assert_eq!(processed.prompt, "Describe and explain the attached files.");
        assert_eq!(processed.context_files.len(), 1);
    }

    #[test]
    fn test_process_input_missing_file() {
        let dir = TempDir::new().unwrap();

        let processed = process_input("explain @nonexistent.rs", dir.path());

        // A reference that cannot be read stays in the prompt untouched.
        assert!(processed.prompt.contains("@nonexistent.rs"));
        assert!(processed.context_files.is_empty());
    }

    // ── format_context_files ──────────────────────────────────

    #[test]
    fn test_format_context_files_empty() {
        assert!(format_context_files(&[]).is_none());
    }

    #[test]
    fn test_format_context_files() {
        let files = vec![FileContext {
            path: "main.rs".into(),
            content: "fn main() {}".into(),
        }];

        let rendered = format_context_files(&files).unwrap();

        for expected in ["<file path=\"main.rs\">", "fn main() {}", "</file>"] {
            assert!(rendered.contains(expected));
        }
    }

    // ── Extension helpers ─────────────────────────────────────

    #[test]
    fn test_is_image_file() {
        for image in [
            "photo.png",
            "photo.PNG",
            "photo.jpg",
            "photo.jpeg",
            "photo.gif",
            "photo.webp",
            "photo.bmp",
        ] {
            assert!(is_image_file(image));
        }
        for other in ["code.rs", "data.json", "readme.md"] {
            assert!(!is_image_file(other));
        }
    }

    #[test]
    fn test_mime_type_for() {
        let expectations = [
            ("x.png", "image/png"),
            ("x.jpg", "image/jpeg"),
            ("x.jpeg", "image/jpeg"),
            ("x.gif", "image/gif"),
            ("x.webp", "image/webp"),
            ("x.bmp", "image/bmp"),
        ];
        for (path, mime) in expectations {
            assert_eq!(mime_type_for(path), mime);
        }
    }

    // ── @image references ─────────────────────────────────────

    #[test]
    fn test_process_input_image_ref() {
        let dir = TempDir::new().unwrap();
        write_png(&dir, "screenshot.png");

        let processed = process_input("what is this @screenshot.png", dir.path());

        assert_eq!(processed.prompt, "what is this");
        assert!(processed.context_files.is_empty());
        assert_eq!(processed.images.len(), 1);
        assert_eq!(processed.images[0].media_type, "image/png");
        assert!(!processed.images[0].base64.is_empty());
    }

    #[test]
    fn test_process_input_image_only_default_prompt() {
        let dir = TempDir::new().unwrap();
        write_png(&dir, "ui.png");

        let processed = process_input("@ui.png", dir.path());

        assert_eq!(processed.prompt, "Describe and analyze this image.");
        assert_eq!(processed.images.len(), 1);
    }

    #[test]
    fn test_process_input_mixed_image_and_file() {
        let dir = TempDir::new().unwrap();
        fs::write(dir.path().join("code.rs"), "fn main() {}").unwrap();
        write_png(&dir, "error.png");

        let processed = process_input("fix this @code.rs @error.png", dir.path());

        assert_eq!(processed.prompt, "fix this");
        assert_eq!(processed.context_files.len(), 1);
        assert_eq!(processed.images.len(), 1);
    }

    // ── Token helpers ─────────────────────────────────────────

    #[test]
    fn test_strip_quotes() {
        let cases = [
            ("'/path/to/file.png'", "/path/to/file.png"),
            ("\"/path/to/file.png\"", "/path/to/file.png"),
            ("/no/quotes.png", "/no/quotes.png"),
            ("'mismatched", "'mismatched"),
            ("'", "'"),
            ("\"", "\""),
        ];
        for (input, expected) in cases {
            assert_eq!(strip_quotes(input), expected);
        }
    }

    #[test]
    fn test_looks_like_file_path() {
        let path_like = [
            "/absolute/path.png",
            "~/Desktop/img.jpg",
            "./relative/img.png",
            "../parent/img.png",
            "'/quoted/path.png'",
            // Windows paths
            "C:\\Users\\test\\img.png",
            "D:/tmp/img.png",
        ];
        for token in path_like {
            assert!(looks_like_file_path(token));
        }

        for token in ["just-a-word", "relative.png"] {
            assert!(!looks_like_file_path(token));
        }
    }

    // ── Bare image paths (drag-and-drop) ──────────────────────

    #[test]
    fn test_drag_and_drop_absolute_path() {
        let dir = TempDir::new().unwrap();
        let img_path = write_png(&dir, "screenshot.png");

        let input = format!("what is this {}", img_path.display());
        let processed = process_input(&input, dir.path());

        assert_eq!(processed.prompt, "what is this");
        assert_eq!(processed.images.len(), 1);
        assert_eq!(processed.images[0].media_type, "image/png");
    }

    #[test]
    fn test_drag_and_drop_quoted_path() {
        let dir = TempDir::new().unwrap();
        let img_path = write_png(&dir, "screenshot.png");

        // Some terminals wrap dragged paths in single quotes.
        let input = format!("explain '{}'", img_path.display());
        let processed = process_input(&input, dir.path());

        assert_eq!(processed.prompt, "explain");
        assert_eq!(processed.images.len(), 1);
    }

    #[test]
    fn test_drag_and_drop_nonexistent_stays_in_prompt() {
        let dir = TempDir::new().unwrap();

        let processed = process_input(
            "/tmp/nonexistent_image_12345.png what is this",
            dir.path(),
        );

        // A path that does not exist is treated as plain prompt text.
        assert!(processed.prompt.contains("/tmp/nonexistent_image_12345.png"));
        assert!(processed.images.is_empty());
    }

    #[test]
    fn test_non_image_absolute_path_stays_in_prompt() {
        let dir = TempDir::new().unwrap();
        let json_path = dir.path().join("data.json");
        fs::write(&json_path, "{}").unwrap();

        let input = format!("read {}", json_path.display());
        let processed = process_input(&input, dir.path());

        // Only bare *image* paths are consumed automatically.
        assert!(processed.prompt.contains("data.json"));
        assert!(processed.images.is_empty());
    }

    // ── resolve_bare_path ─────────────────────────────────────

    #[test]
    fn test_resolve_bare_path_absolute() {
        #[cfg(unix)]
        assert_eq!(
            resolve_bare_path("/tmp/test.png"),
            Some(PathBuf::from("/tmp/test.png"))
        );
        #[cfg(windows)]
        assert_eq!(
            resolve_bare_path("C:\\tmp\\test.png"),
            Some(PathBuf::from("C:\\tmp\\test.png"))
        );
    }

    #[test]
    fn test_resolve_bare_path_home() {
        // Meaningful only when HOME is set; skipped otherwise.
        if std::env::var("HOME").is_ok() {
            let resolved = resolve_bare_path("~/test.png");
            assert!(resolved.is_some());
            let path = resolved.unwrap();
            let text = path.to_string_lossy();
            assert!(!text.contains('~'));
            assert!(text.ends_with("test.png"));
        }
    }

    #[test]
    fn test_resolve_bare_path_quoted() {
        #[cfg(unix)]
        assert_eq!(
            resolve_bare_path("'/tmp/test.png'"),
            Some(PathBuf::from("/tmp/test.png"))
        );
        #[cfg(windows)]
        assert_eq!(
            resolve_bare_path("'C:\\tmp\\test.png'"),
            Some(PathBuf::from("C:\\tmp\\test.png"))
        );
    }

    #[test]
    fn test_resolve_bare_path_relative() {
        let resolved = resolve_bare_path("./test.png");
        assert!(resolved.is_some());
        // Relative input is anchored at the cwd, so the result is absolute.
        assert!(resolved.unwrap().is_absolute());
    }

    // ── Security ──────────────────────────────────────────────

    #[test]
    fn test_at_file_traversal_blocked() {
        let dir = TempDir::new().unwrap();
        fs::write(dir.path().join("safe.rs"), "fn main() {}").unwrap();

        let processed = process_input("read @../../etc/passwd", dir.path());

        // The traversal path is rejected, so nothing is loaded…
        assert!(
            processed.context_files.is_empty(),
            "traversal should not load files outside project root"
        );
        // …and the reference stays in the prompt exactly as typed.
        assert!(processed.prompt.contains("@../../etc/passwd"));
    }

    // ── format_paste_blocks ───────────────────────────────────

    #[test]
    fn test_format_paste_blocks_empty() {
        assert!(format_paste_blocks(&[]).is_none());
    }

    #[test]
    fn test_format_paste_blocks_single() {
        let blocks = vec![PasteBlock {
            content: "hello world".into(),
            char_count: 11,
        }];

        let rendered = format_paste_blocks(&blocks).unwrap();

        assert!(rendered.contains("<reference type=\"pasted\" chars=\"11\">"));
        assert!(rendered.contains("hello world"));
        assert!(rendered.contains("</reference>"));
    }

    #[test]
    fn test_format_paste_blocks_multiple() {
        let blocks: Vec<PasteBlock> = ["block one", "block two"]
            .iter()
            .map(|text| PasteBlock {
                content: text.to_string(),
                char_count: 9,
            })
            .collect();

        let rendered = format_paste_blocks(&blocks).unwrap();

        assert!(rendered.contains("block one"));
        assert!(rendered.contains("block two"));
        // Blocks are separated by a blank line.
        assert!(rendered.contains("</reference>\n\n<reference"));
    }

    #[test]
    fn test_format_paste_blocks_truncation() {
        let blocks = vec![PasteBlock {
            content: "a".repeat(50_000),
            char_count: 50_000,
        }];

        let rendered = format_paste_blocks(&blocks).unwrap();

        assert!(rendered.contains("[truncated — 50000 chars total]"));
        // Output is capped rather than carrying the full 50k.
        assert!(rendered.len() < 45_000);
    }
}