use koda_core::providers::ImageData;
use std::path::{Path, PathBuf};
/// A block of pasted text that travels alongside the prompt.
///
/// [`format_paste_blocks`] renders each block as a
/// `<reference type="pasted" chars="…">…</reference>` element.
#[derive(Debug, Clone)]
pub struct PasteBlock {
/// The pasted text itself.
pub content: String,
/// Size reported in the `chars` attribute and in the truncation notice.
/// NOTE(review): presumably the number of characters in `content`; nothing in
/// this file computes or validates it — confirm where blocks are constructed.
pub char_count: usize,
}
/// Result of running [`process_input`] over one line of user input.
#[derive(Debug)]
pub struct ProcessedInput {
/// The tokens that were not consumed as attachments, joined with single
/// spaces, or a default instruction when only attachments were given.
pub prompt: String,
/// Text files successfully loaded from `@path` references.
pub context_files: Vec<FileContext>,
/// Images loaded from `@path` references or from bare image paths.
pub images: Vec<ImageData>,
/// Pasted blocks. [`process_input`] always leaves this empty.
/// NOTE(review): presumably filled in by the caller — confirm.
pub paste_blocks: Vec<PasteBlock>,
}
/// A text file attached to the prompt through an `@path` reference.
#[derive(Debug)]
pub struct FileContext {
/// The path as the user wrote it after `@` (surrounding quotes removed),
/// not the resolved on-disk path.
pub path: String,
/// Full file contents as read by `std::fs::read_to_string`.
pub content: String,
}
/// File extensions (lowercase, without the leading dot) that are treated as
/// images.
const IMAGE_EXTENSIONS: &[&str] = &["png", "jpg", "jpeg", "gif", "webp", "bmp"];

/// Returns `true` when `path` ends in `.<ext>` for one of
/// [`IMAGE_EXTENSIONS`], compared case-insensitively.
///
/// The dot separator is required: a bare suffix match would classify names
/// such as `optipng` or `notes-webp` as images even though they have no
/// extension at all.
fn is_image_file(path: &str) -> bool {
    let lower = path.to_lowercase();
    // Everything after the last '.' is the extension; no dot means no extension.
    lower
        .rsplit_once('.')
        .is_some_and(|(_, ext)| IMAGE_EXTENSIONS.contains(&ext))
}
/// Maps the extension of `path` (matched case-insensitively) to an image MIME
/// type.
///
/// Anything that is not a recognised image extension yields
/// `"application/octet-stream"`.
fn mime_type_for(path: &str) -> &'static str {
    let lowered = path.to_lowercase();
    // The text after the final '.' is the extension; `None` when there is no dot.
    let extension = lowered.rsplit_once('.').map(|(_, ext)| ext);
    match extension {
        Some("png") => "image/png",
        Some("jpg" | "jpeg") => "image/jpeg",
        Some("gif") => "image/gif",
        Some("webp") => "image/webp",
        Some("bmp") => "image/bmp",
        _ => "application/octet-stream",
    }
}
/// Removes one pair of matching surrounding quotes (`"…"` or `'…'`) from `s`.
///
/// The input is returned unchanged when it is not wrapped in a matching pair,
/// including the case of a single lone quote character.
fn strip_quotes(s: &str) -> &str {
    // Peels `quote` off both ends, or yields `None` if either end lacks it.
    let unwrap_with =
        |quote: char| s.strip_prefix(quote).and_then(|rest| rest.strip_suffix(quote));

    unwrap_with('"')
        .or_else(|| unwrap_with('\''))
        .unwrap_or(s)
}
/// Heuristic: does `token` (optionally wrapped in quotes) look like a
/// filesystem path rather than an ordinary word?
///
/// Recognised shapes are a leading `/`, `~/`, `./` or `..`, and a drive-letter
/// prefix such as `C:\` or `D:/`.
fn looks_like_file_path(token: &str) -> bool {
    const PATH_PREFIXES: [&str; 4] = ["/", "~/", "./", ".."];

    let candidate = strip_quotes(token);
    if PATH_PREFIXES
        .iter()
        .any(|&prefix| candidate.starts_with(prefix))
    {
        return true;
    }

    // Drive-letter form: ASCII letter, colon, then either slash flavour.
    matches!(
        candidate.as_bytes(),
        [drive, b':', b'\\' | b'/', ..] if drive.is_ascii_alphabetic()
    )
}
/// Reads the image at `path` and returns it base64-encoded together with a
/// MIME type derived from `display_path`.
///
/// On a read failure a warning naming `display_path` is printed to stderr and
/// `None` is returned.
fn try_load_image(path: &Path, display_path: &str) -> Option<ImageData> {
    use base64::Engine;

    let Ok(bytes) = std::fs::read(path) else {
        eprintln!(" \x1b[33m\u{26a0} Could not read image: {display_path}\x1b[0m");
        return None;
    };

    Some(ImageData {
        media_type: mime_type_for(display_path).to_string(),
        base64: base64::engine::general_purpose::STANDARD.encode(&bytes),
    })
}
/// Turns a bare (possibly quoted) path token into a filesystem path.
///
/// * `~/rest` is joined onto `$HOME`, falling back to `%USERPROFILE%`;
///   `None` is returned when neither variable is readable.
/// * Absolute paths are returned as-is.
/// * Anything else is joined onto the current working directory; `None` is
///   returned when that directory cannot be determined.
fn resolve_bare_path(token: &str) -> Option<PathBuf> {
    let cleaned = strip_quotes(token);

    if let Some(rest) = cleaned.strip_prefix("~/") {
        let home = match std::env::var("HOME") {
            Ok(dir) => dir,
            Err(_) => std::env::var("USERPROFILE").ok()?,
        };
        return Some(PathBuf::from(home).join(rest));
    }

    let literal = PathBuf::from(cleaned);
    if literal.is_absolute() {
        return Some(literal);
    }

    let cwd = std::env::current_dir().ok()?;
    Some(cwd.join(literal))
}
pub fn process_input(input: &str, project_root: &Path) -> ProcessedInput {
let mut prompt_parts = Vec::new();
let mut context_files = Vec::new();
let mut images = Vec::new();
for token in input.split_whitespace() {
if let Some(raw_path) = token.strip_prefix('@') {
if raw_path.is_empty() {
prompt_parts.push(token.to_string());
continue;
}
let raw_path = strip_quotes(raw_path);
let full_path = match koda_core::tools::safe_resolve_path(project_root, raw_path) {
Ok(p) => p,
Err(_) => {
tracing::warn!("@file path escapes project root: {raw_path}");
prompt_parts.push(token.to_string());
continue;
}
};
if is_image_file(raw_path) {
if let Some(img) = try_load_image(&full_path, raw_path) {
images.push(img);
} else {
prompt_parts.push(token.to_string());
}
continue;
}
match std::fs::read_to_string(&full_path) {
Ok(content) => {
context_files.push(FileContext {
path: raw_path.to_string(),
content,
});
}
Err(_) => {
eprintln!(" \x1b[33m\u{26a0} Could not read: {raw_path}\x1b[0m");
prompt_parts.push(token.to_string());
}
}
continue;
}
let unquoted = strip_quotes(token);
if looks_like_file_path(token)
&& is_image_file(unquoted)
&& let Some(resolved) = resolve_bare_path(token)
&& resolved.exists()
{
let display = resolved.display().to_string();
if let Some(img) = try_load_image(&resolved, &display) {
images.push(img);
continue;
}
}
prompt_parts.push(token.to_string());
}
let prompt = prompt_parts.join(" ");
let prompt = if prompt.trim().is_empty() && (!context_files.is_empty() || !images.is_empty()) {
if !images.is_empty() && context_files.is_empty() {
"Describe and analyze this image.".to_string()
} else {
"Describe and explain the attached files.".to_string()
}
} else {
prompt
};
ProcessedInput {
prompt,
context_files,
images,
paste_blocks: Vec::new(),
}
}
/// Renders attached files as `<file path="…">…</file>` elements separated by
/// blank lines.
///
/// Returns `None` when `files` is empty. Content longer than 40,000 bytes is
/// cut at the nearest character boundary at or below that size and followed
/// by a notice stating the full byte length.
pub fn format_context_files(files: &[FileContext]) -> Option<String> {
    const MAX_BYTES: usize = 40_000;

    if files.is_empty() {
        return None;
    }

    let rendered: Vec<String> = files
        .iter()
        .map(|file| {
            let body = if file.content.len() <= MAX_BYTES {
                file.content.clone()
            } else {
                // Walk back from the limit to a valid UTF-8 boundary; index 0
                // always qualifies, so the search cannot come up empty.
                let cut = (0..=MAX_BYTES)
                    .rev()
                    .find(|&idx| file.content.is_char_boundary(idx))
                    .unwrap_or(0);
                format!(
                    "{}\n\n[truncated — {} bytes total]",
                    &file.content[..cut],
                    file.content.len()
                )
            };
            format!("<file path=\"{}\">{}</file>", file.path, body)
        })
        .collect();

    Some(rendered.join("\n\n"))
}
/// Size threshold associated with paste blocks.
/// NOTE(review): not referenced in the visible part of this file; presumably
/// the minimum pasted length at which input is turned into a [`PasteBlock`],
/// and presumably measured in characters — confirm at the call sites.
pub const PASTE_BLOCK_THRESHOLD: usize = 200;
/// Maximum number of characters of a paste block that is forwarded verbatim.
const PASTE_BLOCK_MAX_CHARS: usize = 40_000;

/// Renders pasted blocks as
/// `<reference type="pasted" chars="N">…</reference>` elements separated by
/// blank lines, where `N` is the block's `char_count`.
///
/// Returns `None` when `blocks` is empty. Content longer than
/// [`PASTE_BLOCK_MAX_CHARS`] characters is cut after that many characters and
/// followed by a notice stating the block's `char_count`.
///
/// The limit is applied in characters, not bytes, so that it matches the
/// constant's name and the unit used in the notice; multi-byte text is no
/// longer truncated early just because its UTF-8 encoding is larger.
pub fn format_paste_blocks(blocks: &[PasteBlock]) -> Option<String> {
    if blocks.is_empty() {
        return None;
    }
    let parts: Vec<String> = blocks
        .iter()
        .map(|b| {
            // Byte length is an upper bound on the character count, so content
            // within the limit in bytes can skip the character scan entirely.
            let cut = if b.content.len() > PASTE_BLOCK_MAX_CHARS {
                // Byte offset where character number PASTE_BLOCK_MAX_CHARS
                // starts; `None` when the content has no more characters than
                // the limit. The offset is always a valid char boundary.
                b.content
                    .char_indices()
                    .nth(PASTE_BLOCK_MAX_CHARS)
                    .map(|(idx, _)| idx)
            } else {
                None
            };
            let content = match cut {
                Some(end) => format!(
                    "{}\n\n[truncated — {} chars total]",
                    &b.content[..end],
                    b.char_count
                ),
                None => b.content.clone(),
            };
            format!(
                "<reference type=\"pasted\" chars=\"{}\">{}</reference>",
                b.char_count, content
            )
        })
        .collect();
    Some(parts.join("\n\n"))
}
// Unit tests for input processing: `@file` references, image detection,
// drag-and-drop style bare paths, path helpers, and paste-block formatting.
// Tests that touch the filesystem work inside a fresh temporary directory.
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::TempDir;
// An `@file` reference to an existing text file is removed from the prompt
// and loaded as a context file.
#[test]
fn test_process_input_with_file_ref() {
let dir = TempDir::new().unwrap();
fs::write(dir.path().join("test.rs"), "fn test() {}").unwrap();
let result = process_input("explain @test.rs", dir.path());
assert_eq!(result.prompt, "explain");
assert_eq!(result.context_files.len(), 1);
assert_eq!(result.context_files[0].path, "test.rs");
assert_eq!(result.context_files[0].content, "fn test() {}");
}
// Input without references passes through unchanged.
#[test]
fn test_process_input_no_refs() {
let dir = TempDir::new().unwrap();
let result = process_input("just a normal question", dir.path());
assert_eq!(result.prompt, "just a normal question");
assert!(result.context_files.is_empty());
}
// A lone text-file reference produces the default "attached files" prompt.
#[test]
fn test_process_input_only_ref() {
let dir = TempDir::new().unwrap();
fs::write(dir.path().join("code.py"), "print('hi')").unwrap();
let result = process_input("@code.py", dir.path());
assert_eq!(result.prompt, "Describe and explain the attached files.");
assert_eq!(result.context_files.len(), 1);
}
// A reference to a file that does not exist stays in the prompt verbatim.
#[test]
fn test_process_input_missing_file() {
let dir = TempDir::new().unwrap();
let result = process_input("explain @nonexistent.rs", dir.path());
assert!(result.prompt.contains("@nonexistent.rs"));
assert!(result.context_files.is_empty());
}
// No files means no formatted output at all.
#[test]
fn test_format_context_files_empty() {
assert!(format_context_files(&[]).is_none());
}
// A single file is wrapped in a `<file path="…">` element.
#[test]
fn test_format_context_files() {
let files = vec![FileContext {
path: "main.rs".into(),
content: "fn main() {}".into(),
}];
let result = format_context_files(&files).unwrap();
assert!(result.contains("<file path=\"main.rs\">"));
assert!(result.contains("fn main() {}"));
assert!(result.contains("</file>"));
}
// Every supported image extension is recognised, regardless of case;
// non-image extensions are rejected.
#[test]
fn test_is_image_file() {
assert!(is_image_file("photo.png"));
assert!(is_image_file("photo.PNG"));
assert!(is_image_file("photo.jpg"));
assert!(is_image_file("photo.jpeg"));
assert!(is_image_file("photo.gif"));
assert!(is_image_file("photo.webp"));
assert!(is_image_file("photo.bmp"));
assert!(!is_image_file("code.rs"));
assert!(!is_image_file("data.json"));
assert!(!is_image_file("readme.md"));
}
// Each supported extension maps to its MIME type.
#[test]
fn test_mime_type_for() {
assert_eq!(mime_type_for("x.png"), "image/png");
assert_eq!(mime_type_for("x.jpg"), "image/jpeg");
assert_eq!(mime_type_for("x.jpeg"), "image/jpeg");
assert_eq!(mime_type_for("x.gif"), "image/gif");
assert_eq!(mime_type_for("x.webp"), "image/webp");
assert_eq!(mime_type_for("x.bmp"), "image/bmp");
}
// An `@image` reference is loaded as an image, not as a context file.
// The fixture bytes are the 8-byte PNG signature.
#[test]
fn test_process_input_image_ref() {
let dir = TempDir::new().unwrap();
let png_bytes: [u8; 8] = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
fs::write(dir.path().join("screenshot.png"), png_bytes).unwrap();
let result = process_input("what is this @screenshot.png", dir.path());
assert_eq!(result.prompt, "what is this");
assert!(result.context_files.is_empty());
assert_eq!(result.images.len(), 1);
assert_eq!(result.images[0].media_type, "image/png");
assert!(!result.images[0].base64.is_empty());
}
// A lone image reference produces the default image prompt.
#[test]
fn test_process_input_image_only_default_prompt() {
let dir = TempDir::new().unwrap();
let png_bytes: [u8; 8] = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
fs::write(dir.path().join("ui.png"), png_bytes).unwrap();
let result = process_input("@ui.png", dir.path());
assert_eq!(result.prompt, "Describe and analyze this image.");
assert_eq!(result.images.len(), 1);
}
// Text and image references in the same input are routed separately.
#[test]
fn test_process_input_mixed_image_and_file() {
let dir = TempDir::new().unwrap();
fs::write(dir.path().join("code.rs"), "fn main() {}").unwrap();
let png_bytes: [u8; 8] = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
fs::write(dir.path().join("error.png"), png_bytes).unwrap();
let result = process_input("fix this @code.rs @error.png", dir.path());
assert_eq!(result.prompt, "fix this");
assert_eq!(result.context_files.len(), 1);
assert_eq!(result.images.len(), 1);
}
// Matching quote pairs are removed; mismatched or lone quotes are kept.
#[test]
fn test_strip_quotes() {
assert_eq!(strip_quotes("'/path/to/file.png'"), "/path/to/file.png");
assert_eq!(strip_quotes("\"/path/to/file.png\""), "/path/to/file.png");
assert_eq!(strip_quotes("/no/quotes.png"), "/no/quotes.png");
assert_eq!(strip_quotes("'mismatched"), "'mismatched");
assert_eq!(strip_quotes("'"), "'");
assert_eq!(strip_quotes("\""), "\"");
}
// Absolute, home-relative, dot-relative, quoted and drive-letter forms
// count as paths; plain words and bare file names do not.
#[test]
fn test_looks_like_file_path() {
assert!(looks_like_file_path("/absolute/path.png"));
assert!(looks_like_file_path("~/Desktop/img.jpg"));
assert!(looks_like_file_path("./relative/img.png"));
assert!(looks_like_file_path("../parent/img.png"));
assert!(looks_like_file_path("'/quoted/path.png'"));
assert!(looks_like_file_path("C:\\Users\\test\\img.png"));
assert!(looks_like_file_path("D:/tmp/img.png"));
assert!(!looks_like_file_path("just-a-word"));
assert!(!looks_like_file_path("relative.png"));
}
// A bare absolute path to an existing image is attached and removed from
// the prompt.
#[test]
fn test_drag_and_drop_absolute_path() {
let dir = TempDir::new().unwrap();
let png_bytes: [u8; 8] = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
let img_path = dir.path().join("screenshot.png");
fs::write(&img_path, png_bytes).unwrap();
let input = format!("what is this {}", img_path.display());
let result = process_input(&input, dir.path());
assert_eq!(result.prompt, "what is this");
assert_eq!(result.images.len(), 1);
assert_eq!(result.images[0].media_type, "image/png");
}
// The same works when the path is wrapped in single quotes.
#[test]
fn test_drag_and_drop_quoted_path() {
let dir = TempDir::new().unwrap();
let png_bytes: [u8; 8] = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
let img_path = dir.path().join("screenshot.png");
fs::write(&img_path, png_bytes).unwrap();
let input = format!("explain '{}'", img_path.display());
let result = process_input(&input, dir.path());
assert_eq!(result.prompt, "explain");
assert_eq!(result.images.len(), 1);
}
// A bare image path that does not exist is left in the prompt.
#[test]
fn test_drag_and_drop_nonexistent_stays_in_prompt() {
let dir = TempDir::new().unwrap();
let input = "/tmp/nonexistent_image_12345.png what is this";
let result = process_input(input, dir.path());
assert!(result.prompt.contains("/tmp/nonexistent_image_12345.png"));
assert!(result.images.is_empty());
}
// A bare path to an existing non-image file is not attached.
#[test]
fn test_non_image_absolute_path_stays_in_prompt() {
let dir = TempDir::new().unwrap();
fs::write(dir.path().join("data.json"), "{}").unwrap();
let input = format!("read {}", dir.path().join("data.json").display());
let result = process_input(&input, dir.path());
assert!(result.prompt.contains("data.json"));
assert!(result.images.is_empty());
}
// Absolute paths are returned unchanged (platform-specific fixtures).
#[test]
fn test_resolve_bare_path_absolute() {
#[cfg(unix)]
{
let resolved = resolve_bare_path("/tmp/test.png");
assert_eq!(resolved, Some(PathBuf::from("/tmp/test.png")));
}
#[cfg(windows)]
{
let resolved = resolve_bare_path("C:\\tmp\\test.png");
assert_eq!(resolved, Some(PathBuf::from("C:\\tmp\\test.png")));
}
}
// `~/` is expanded; the check only runs when HOME is set.
#[test]
fn test_resolve_bare_path_home() {
if std::env::var("HOME").is_ok() {
let resolved = resolve_bare_path("~/test.png");
assert!(resolved.is_some());
let path = resolved.unwrap();
assert!(!path.to_string_lossy().contains('~'));
assert!(path.to_string_lossy().ends_with("test.png"));
}
}
// Surrounding quotes are removed before resolving.
#[test]
fn test_resolve_bare_path_quoted() {
#[cfg(unix)]
{
let resolved = resolve_bare_path("'/tmp/test.png'");
assert_eq!(resolved, Some(PathBuf::from("/tmp/test.png")));
}
#[cfg(windows)]
{
let resolved = resolve_bare_path("'C:\\tmp\\test.png'");
assert_eq!(resolved, Some(PathBuf::from("C:\\tmp\\test.png")));
}
}
// Relative paths are made absolute.
#[test]
fn test_resolve_bare_path_relative() {
let resolved = resolve_bare_path("./test.png");
assert!(resolved.is_some());
assert!(resolved.unwrap().is_absolute());
}
// An `@` reference that climbs out of the project root loads nothing and
// is kept in the prompt.
#[test]
fn test_at_file_traversal_blocked() {
let dir = tempfile::tempdir().unwrap();
std::fs::write(dir.path().join("safe.rs"), "fn main() {}").unwrap();
let result = process_input("read @../../etc/passwd", dir.path());
assert!(
result.context_files.is_empty(),
"traversal should not load files outside project root"
);
assert!(result.prompt.contains("@../../etc/passwd"));
}
// No blocks means no formatted output at all.
#[test]
fn test_format_paste_blocks_empty() {
assert!(format_paste_blocks(&[]).is_none());
}
// A single block is wrapped in a `<reference>` element carrying its
// char count.
#[test]
fn test_format_paste_blocks_single() {
let blocks = vec![PasteBlock {
content: "hello world".into(),
char_count: 11,
}];
let result = format_paste_blocks(&blocks).unwrap();
assert!(result.contains("<reference type=\"pasted\" chars=\"11\">"));
assert!(result.contains("hello world"));
assert!(result.contains("</reference>"));
}
// Multiple blocks are separated by a blank line.
#[test]
fn test_format_paste_blocks_multiple() {
let blocks = vec![
PasteBlock {
content: "block one".into(),
char_count: 9,
},
PasteBlock {
content: "block two".into(),
char_count: 9,
},
];
let result = format_paste_blocks(&blocks).unwrap();
assert!(result.contains("block one"));
assert!(result.contains("block two"));
assert!(result.contains("</reference>\n\n<reference"));
}
// Oversized content is cut and followed by a truncation notice.
#[test]
fn test_format_paste_blocks_truncation() {
let long_content = "a".repeat(50_000);
let blocks = vec![PasteBlock {
content: long_content,
char_count: 50_000,
}];
let result = format_paste_blocks(&blocks).unwrap();
assert!(result.contains("[truncated — 50000 chars total]"));
assert!(result.len() < 45_000);
}
}