Skip to main content

wisp/components/app/
attachments.rs

1use super::PromptAttachment;
2use agent_client_protocol as acp;
3use base64::Engine;
4use base64::engine::general_purpose::STANDARD as BASE64;
5use std::path::Path;
6use tokio::io::AsyncReadExt;
7use url::Url;
8
/// Maximum number of bytes of a text attachment that are embedded (1 MiB); longer files are
/// truncated to this length.
const MAX_EMBED_TEXT_BYTES: usize = 1 << 20;
/// Maximum size of an image or audio attachment (10 MiB); larger files are skipped.
const MAX_MEDIA_BYTES: usize = 10 << 20;

/// MIME types that `classify_attachment` reports as images.
const IMAGE_MIME_TYPES: &[&str] = &["image/png", "image/jpeg", "image/gif", "image/webp"];
/// MIME types that `classify_attachment` reports as audio.
const AUDIO_MIME_TYPES: &[&str] = &["audio/wav", "audio/mpeg", "audio/mp3", "audio/ogg"];
14
/// How an attachment is handled, as decided from its file extension.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AttachmentKind {
    /// A `text/*` MIME type.
    Text,
    /// One of the recognised image MIME types.
    Image,
    /// One of the recognised audio MIME types.
    Audio,
    /// Any other (or unknown) MIME type.
    Unsupported,
}
22
23pub fn classify_attachment(path: &Path) -> AttachmentKind {
24    let mime = mime_guess::from_path(path).first_or_octet_stream().to_string();
25
26    if IMAGE_MIME_TYPES.contains(&mime.as_str()) {
27        AttachmentKind::Image
28    } else if AUDIO_MIME_TYPES.contains(&mime.as_str()) {
29        AttachmentKind::Audio
30    } else if mime.starts_with("text/") {
31        AttachmentKind::Text
32    } else {
33        AttachmentKind::Unsupported
34    }
35}
36
37#[derive(Debug, Default)]
38pub struct AttachmentBuildOutcome {
39    pub blocks: Vec<acp::ContentBlock>,
40    pub transcript_placeholders: Vec<String>,
41    pub warnings: Vec<String>,
42}
43
44pub async fn build_attachment_blocks(attachments: &[PromptAttachment]) -> AttachmentBuildOutcome {
45    let mut outcome = AttachmentBuildOutcome::default();
46
47    for attachment in attachments {
48        match try_build_attachment_block(&attachment.path, &attachment.display_name).await {
49            Ok(result) => {
50                outcome.blocks.push(result.block);
51                if let Some(placeholder) = result.transcript_placeholder {
52                    outcome.transcript_placeholders.push(placeholder);
53                }
54                if let Some(warning) = result.warning {
55                    outcome.warnings.push(warning);
56                }
57            }
58            Err(warning) => outcome.warnings.push(warning),
59        }
60    }
61
62    outcome
63}
64
65struct AttachmentBlockResult {
66    block: acp::ContentBlock,
67    transcript_placeholder: Option<String>,
68    warning: Option<String>,
69}
70
71async fn try_build_attachment_block(path: &Path, display_name: &str) -> Result<AttachmentBlockResult, String> {
72    let kind = classify_attachment(path);
73    let mime_type = mime_guess::from_path(path).first_or_octet_stream().to_string();
74
75    match kind {
76        AttachmentKind::Image | AttachmentKind::Audio => {
77            let bytes = read_media_bytes(path, display_name).await?;
78            let data = BASE64.encode(&bytes);
79            let (block, placeholder) = match kind {
80                AttachmentKind::Image => (
81                    acp::ContentBlock::Image(acp::ImageContent::new(data, &mime_type)),
82                    format!("[image attachment: {display_name}]"),
83                ),
84                _ => (
85                    acp::ContentBlock::Audio(acp::AudioContent::new(data, &mime_type)),
86                    format!("[audio attachment: {display_name}]"),
87                ),
88            };
89            Ok(AttachmentBlockResult { block, transcript_placeholder: Some(placeholder), warning: None })
90        }
91        _ => build_text_resource_block(path, display_name, &mime_type).await,
92    }
93}
94
95async fn read_media_bytes(path: &Path, display_name: &str) -> Result<Vec<u8>, String> {
96    let metadata = tokio::fs::metadata(path).await.map_err(|e| format!("Failed to read {display_name}: {e}"))?;
97
98    if metadata.len() > MAX_MEDIA_BYTES as u64 {
99        return Err(format!(
100            "Skipped {display_name}: file too large ({} bytes, max {})",
101            metadata.len(),
102            MAX_MEDIA_BYTES
103        ));
104    }
105
106    tokio::fs::read(path).await.map_err(|e| format!("Failed to read {display_name}: {e}"))
107}
108
109async fn build_text_resource_block(
110    path: &Path,
111    display_name: &str,
112    mime_type: &str,
113) -> Result<AttachmentBlockResult, String> {
114    let file = tokio::fs::File::open(path).await.map_err(|error| format!("Failed to read {display_name}: {error}"))?;
115
116    let mut bytes = Vec::new();
117    file.take((MAX_EMBED_TEXT_BYTES + 1) as u64)
118        .read_to_end(&mut bytes)
119        .await
120        .map_err(|error| format!("Failed to read {display_name}: {error}"))?;
121
122    let truncated = bytes.len() > MAX_EMBED_TEXT_BYTES;
123    if truncated {
124        bytes.truncate(MAX_EMBED_TEXT_BYTES);
125    }
126    let text_bytes = bytes.as_slice();
127
128    let text = match std::str::from_utf8(text_bytes) {
129        Ok(text) => text.to_string(),
130        Err(error) if truncated && error.valid_up_to() > 0 => {
131            let valid_bytes = &text_bytes[..error.valid_up_to()];
132            std::str::from_utf8(valid_bytes).expect("valid_up_to must point at a utf8 boundary").to_string()
133        }
134        Err(_) => return Err(format!("Skipped binary or non-UTF8 file: {display_name}")),
135    };
136
137    let file_uri = build_attachment_file_uri(path, display_name).await?;
138    let warning = truncated.then(|| format!("Truncated {display_name} to {MAX_EMBED_TEXT_BYTES} bytes"));
139
140    let block =
141        acp::ContentBlock::Resource(acp::EmbeddedResource::new(acp::EmbeddedResourceResource::TextResourceContents(
142            acp::TextResourceContents::new(text, file_uri).mime_type(mime_type),
143        )));
144
145    Ok(AttachmentBlockResult { block, transcript_placeholder: None, warning })
146}
147
148async fn build_attachment_file_uri(path: &Path, display_name: &str) -> Result<String, String> {
149    let canonical_path = tokio::fs::canonicalize(path).await.ok();
150    let uri_path = canonical_path.as_deref().unwrap_or(path);
151    Url::from_file_path(uri_path)
152        .map_err(|()| format!("Failed to build file URI for {display_name}"))
153        .map(|url| url.to_string())
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// Runs `build_attachment_blocks` for a single attachment.
    async fn build_single(path: PathBuf, display_name: &str) -> AttachmentBuildOutcome {
        let attachments = vec![PromptAttachment { path, display_name: display_name.to_string() }];
        build_attachment_blocks(&attachments).await
    }

    #[tokio::test]
    async fn build_attachment_blocks_truncates_large_file_with_warning() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("large.txt");
        let display_name = "large.txt";
        std::fs::write(&path, "x".repeat(MAX_EMBED_TEXT_BYTES + 64)).unwrap();

        let outcome = build_single(path, display_name).await;

        assert_eq!(outcome.blocks.len(), 1);
        assert_eq!(outcome.warnings.len(), 1);
        assert!(outcome.warnings[0].contains(&format!("Truncated {display_name} to {MAX_EMBED_TEXT_BYTES} bytes")));
    }

    #[tokio::test]
    async fn build_attachment_blocks_truncates_inside_multibyte_char() {
        // '€' is 3 bytes, so the byte limit must split a character for this test to be meaningful.
        assert_ne!(MAX_EMBED_TEXT_BYTES % 3, 0);
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("euro.txt");
        std::fs::write(&path, "€".repeat(MAX_EMBED_TEXT_BYTES / 3 + 2)).unwrap();

        let outcome = build_single(path, "euro.txt").await;

        assert_eq!(outcome.blocks.len(), 1);
        assert_eq!(outcome.warnings.len(), 1);
        assert!(outcome.warnings[0].starts_with("Truncated euro.txt"));
    }

    #[tokio::test]
    async fn build_attachment_blocks_skips_non_utf8_files() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("binary.bin");
        let display_name = "binary.bin";
        std::fs::write(&path, [0xff, 0xfe, 0xfd]).unwrap();

        let outcome = build_single(path, display_name).await;

        assert!(outcome.blocks.is_empty());
        assert_eq!(outcome.warnings.len(), 1);
        assert!(outcome.warnings[0].contains(&format!("Skipped binary or non-UTF8 file: {display_name}")));
    }

    #[tokio::test]
    async fn utf8_file_with_unknown_extension_is_embedded_as_text() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("notes.xyz");
        std::fs::write(&path, "plain text").unwrap();

        let outcome = build_single(path, "notes.xyz").await;

        assert_eq!(outcome.blocks.len(), 1);
        assert!(outcome.warnings.is_empty());
        assert!(outcome.transcript_placeholders.is_empty());
        assert!(matches!(outcome.blocks[0], acp::ContentBlock::Resource(_)));
    }

    #[tokio::test]
    async fn missing_text_file_is_reported_as_warning() {
        let tmp = TempDir::new().unwrap();

        let outcome = build_single(tmp.path().join("gone.txt"), "gone.txt").await;

        assert!(outcome.blocks.is_empty());
        assert_eq!(outcome.warnings.len(), 1);
        assert!(outcome.warnings[0].starts_with("Failed to read gone.txt"));
    }

    #[tokio::test]
    async fn missing_media_file_is_reported_as_warning() {
        let tmp = TempDir::new().unwrap();

        let outcome = build_single(tmp.path().join("gone.png"), "gone.png").await;

        assert!(outcome.blocks.is_empty());
        assert!(outcome.transcript_placeholders.is_empty());
        assert_eq!(outcome.warnings.len(), 1);
        assert!(outcome.warnings[0].starts_with("Failed to read gone.png"));
    }

    #[tokio::test]
    async fn build_attachment_file_uri_falls_back_when_canonicalize_fails() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("missing.txt");
        let expected = Url::from_file_path(&path).unwrap().to_string();

        let uri = build_attachment_file_uri(&path, "missing.txt")
            .await
            .expect("URI should be built from original absolute path");

        assert_eq!(uri, expected);
    }

    #[tokio::test]
    async fn png_file_produces_image_content_block() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("test.png");
        std::fs::write(&path, b"fake png data").unwrap();

        let outcome = build_single(path, "test.png").await;

        assert_eq!(outcome.blocks.len(), 1);
        assert!(outcome.warnings.is_empty());
        assert_eq!(outcome.transcript_placeholders, vec!["[image attachment: test.png]"]);
        assert!(matches!(outcome.blocks[0], acp::ContentBlock::Image(_)));
    }

    #[tokio::test]
    async fn wav_file_produces_audio_content_block() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("test.wav");
        std::fs::write(&path, b"fake wav data").unwrap();

        let outcome = build_single(path, "test.wav").await;

        assert_eq!(outcome.blocks.len(), 1);
        assert!(outcome.warnings.is_empty());
        assert_eq!(outcome.transcript_placeholders, vec!["[audio attachment: test.wav]"]);
        assert!(matches!(outcome.blocks[0], acp::ContentBlock::Audio(_)));
    }

    #[test]
    fn classify_attachment_detects_images() {
        for name in ["photo.png", "photo.jpg", "photo.gif", "photo.webp"] {
            assert_eq!(classify_attachment(Path::new(name)), AttachmentKind::Image, "{name}");
        }
    }

    #[test]
    fn classify_attachment_detects_audio() {
        for name in ["note.wav", "note.mp3", "note.ogg"] {
            assert_eq!(classify_attachment(Path::new(name)), AttachmentKind::Audio, "{name}");
        }
    }

    #[test]
    fn classify_attachment_detects_text() {
        assert_eq!(classify_attachment(Path::new("readme.txt")), AttachmentKind::Text);
    }

    #[test]
    fn classify_attachment_unknown_extension_is_unsupported() {
        assert_eq!(classify_attachment(Path::new("data.xyz")), AttachmentKind::Unsupported);
    }
}