Skip to main content

wisp/components/app/
attachments.rs

1use super::PromptAttachment;
2use agent_client_protocol as acp;
3use base64::Engine;
4use base64::engine::general_purpose::STANDARD as BASE64;
5use std::path::Path;
6use tokio::io::AsyncReadExt;
7use url::Url;
8
/// Largest number of bytes of a text attachment that is embedded (1 MiB); longer files are
/// truncated to this length.
const MAX_EMBED_TEXT_BYTES: usize = 1 << 20;
/// Largest image or audio attachment, in bytes, that is accepted (10 MiB); larger files are
/// skipped.
const MAX_MEDIA_BYTES: usize = 10 << 20;

/// MIME types that are sent as image content blocks.
const IMAGE_MIME_TYPES: &[&str] = &["image/png", "image/jpeg", "image/gif", "image/webp"];
/// MIME types that are sent as audio content blocks.
const AUDIO_MIME_TYPES: &[&str] = &["audio/wav", "audio/mpeg", "audio/mp3", "audio/ogg"];
14
/// Category assigned to an attachment by `classify_attachment`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AttachmentKind {
    /// The guessed MIME type starts with `text/`.
    Text,
    /// The guessed MIME type is listed in `IMAGE_MIME_TYPES`.
    Image,
    /// The guessed MIME type is listed in `AUDIO_MIME_TYPES`.
    Audio,
    /// Any other MIME type, including the octet-stream fallback for unknown extensions.
    Unsupported,
}
22
23pub fn classify_attachment(path: &Path) -> AttachmentKind {
24    let mime = mime_guess::from_path(path)
25        .first_or_octet_stream()
26        .to_string();
27
28    if IMAGE_MIME_TYPES.contains(&mime.as_str()) {
29        AttachmentKind::Image
30    } else if AUDIO_MIME_TYPES.contains(&mime.as_str()) {
31        AttachmentKind::Audio
32    } else if mime.starts_with("text/") {
33        AttachmentKind::Text
34    } else {
35        AttachmentKind::Unsupported
36    }
37}
38
/// Combined result of converting a list of prompt attachments into content blocks.
#[derive(Debug, Default)]
pub struct AttachmentBuildOutcome {
    /// One content block per attachment that was converted successfully, in input order.
    pub blocks: Vec<acp::ContentBlock>,
    /// Placeholder strings for the attachments that supplied one (image and audio attachments
    /// in this file).
    pub transcript_placeholders: Vec<String>,
    /// Messages for attachments that were skipped, could not be read, or were truncated.
    pub warnings: Vec<String>,
}
45
46pub async fn build_attachment_blocks(attachments: &[PromptAttachment]) -> AttachmentBuildOutcome {
47    let mut outcome = AttachmentBuildOutcome::default();
48
49    for attachment in attachments {
50        match try_build_attachment_block(&attachment.path, &attachment.display_name).await {
51            Ok(result) => {
52                outcome.blocks.push(result.block);
53                if let Some(placeholder) = result.transcript_placeholder {
54                    outcome.transcript_placeholders.push(placeholder);
55                }
56                if let Some(warning) = result.warning {
57                    outcome.warnings.push(warning);
58                }
59            }
60            Err(warning) => outcome.warnings.push(warning),
61        }
62    }
63
64    outcome
65}
66
/// Result of converting a single attachment, as returned by `try_build_attachment_block`.
struct AttachmentBlockResult {
    /// The content block built for the attachment.
    block: acp::ContentBlock,
    /// Placeholder text for the attachment; set for image and audio blocks, `None` for text
    /// resources.
    transcript_placeholder: Option<String>,
    /// Non-fatal note to report alongside the block; set when embedded text was truncated.
    warning: Option<String>,
}
72
73async fn try_build_attachment_block(
74    path: &Path,
75    display_name: &str,
76) -> Result<AttachmentBlockResult, String> {
77    let kind = classify_attachment(path);
78    let mime_type = mime_guess::from_path(path)
79        .first_or_octet_stream()
80        .to_string();
81
82    match kind {
83        AttachmentKind::Image | AttachmentKind::Audio => {
84            let bytes = read_media_bytes(path, display_name).await?;
85            let data = BASE64.encode(&bytes);
86            let (block, placeholder) = match kind {
87                AttachmentKind::Image => (
88                    acp::ContentBlock::Image(acp::ImageContent::new(data, &mime_type)),
89                    format!("[image attachment: {display_name}]"),
90                ),
91                _ => (
92                    acp::ContentBlock::Audio(acp::AudioContent::new(data, &mime_type)),
93                    format!("[audio attachment: {display_name}]"),
94                ),
95            };
96            Ok(AttachmentBlockResult {
97                block,
98                transcript_placeholder: Some(placeholder),
99                warning: None,
100            })
101        }
102        _ => build_text_resource_block(path, display_name, &mime_type).await,
103    }
104}
105
106async fn read_media_bytes(path: &Path, display_name: &str) -> Result<Vec<u8>, String> {
107    let metadata = tokio::fs::metadata(path)
108        .await
109        .map_err(|e| format!("Failed to read {display_name}: {e}"))?;
110
111    if metadata.len() > MAX_MEDIA_BYTES as u64 {
112        return Err(format!(
113            "Skipped {display_name}: file too large ({} bytes, max {})",
114            metadata.len(),
115            MAX_MEDIA_BYTES
116        ));
117    }
118
119    tokio::fs::read(path)
120        .await
121        .map_err(|e| format!("Failed to read {display_name}: {e}"))
122}
123
124async fn build_text_resource_block(
125    path: &Path,
126    display_name: &str,
127    mime_type: &str,
128) -> Result<AttachmentBlockResult, String> {
129    let file = tokio::fs::File::open(path)
130        .await
131        .map_err(|error| format!("Failed to read {display_name}: {error}"))?;
132
133    let mut bytes = Vec::new();
134    file.take((MAX_EMBED_TEXT_BYTES + 1) as u64)
135        .read_to_end(&mut bytes)
136        .await
137        .map_err(|error| format!("Failed to read {display_name}: {error}"))?;
138
139    let truncated = bytes.len() > MAX_EMBED_TEXT_BYTES;
140    if truncated {
141        bytes.truncate(MAX_EMBED_TEXT_BYTES);
142    }
143    let text_bytes = bytes.as_slice();
144
145    let text = match std::str::from_utf8(text_bytes) {
146        Ok(text) => text.to_string(),
147        Err(error) if truncated && error.valid_up_to() > 0 => {
148            let valid_bytes = &text_bytes[..error.valid_up_to()];
149            std::str::from_utf8(valid_bytes)
150                .expect("valid_up_to must point at a utf8 boundary")
151                .to_string()
152        }
153        Err(_) => return Err(format!("Skipped binary or non-UTF8 file: {display_name}")),
154    };
155
156    let file_uri = build_attachment_file_uri(path, display_name).await?;
157    let warning =
158        truncated.then(|| format!("Truncated {display_name} to {MAX_EMBED_TEXT_BYTES} bytes"));
159
160    let block = acp::ContentBlock::Resource(acp::EmbeddedResource::new(
161        acp::EmbeddedResourceResource::TextResourceContents(
162            acp::TextResourceContents::new(text, file_uri).mime_type(mime_type),
163        ),
164    ));
165
166    Ok(AttachmentBlockResult {
167        block,
168        transcript_placeholder: None,
169        warning,
170    })
171}
172
173async fn build_attachment_file_uri(path: &Path, display_name: &str) -> Result<String, String> {
174    let canonical_path = tokio::fs::canonicalize(path).await.ok();
175    let uri_path = canonical_path.as_deref().unwrap_or(path);
176    Url::from_file_path(uri_path)
177        .map_err(|()| format!("Failed to build file URI for {display_name}"))
178        .map(|url| url.to_string())
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `contents` to `name` inside `dir` and returns a one-element attachment list whose
    /// display name is `name`.
    fn single_attachment(
        dir: &TempDir,
        name: &str,
        contents: impl AsRef<[u8]>,
    ) -> Vec<PromptAttachment> {
        let path = dir.path().join(name);
        std::fs::write(&path, contents).unwrap();
        vec![PromptAttachment {
            path,
            display_name: name.to_string(),
        }]
    }

    /// Asserts that every file name in `names` is classified as `expected`.
    fn assert_classified(names: &[&str], expected: AttachmentKind) {
        for &name in names {
            assert_eq!(classify_attachment(Path::new(name)), expected, "{name}");
        }
    }

    #[tokio::test]
    async fn build_attachment_blocks_truncates_large_file_with_warning() {
        let tmp = TempDir::new().unwrap();
        let oversized = "x".repeat(MAX_EMBED_TEXT_BYTES + 64);
        let attachments = single_attachment(&tmp, "large.txt", oversized);

        let outcome = build_attachment_blocks(&attachments).await;

        assert_eq!(outcome.blocks.len(), 1);
        assert_eq!(outcome.warnings.len(), 1);
        let expected = format!("Truncated large.txt to {MAX_EMBED_TEXT_BYTES} bytes");
        assert!(outcome.warnings[0].contains(&expected));
    }

    #[tokio::test]
    async fn build_attachment_blocks_skips_non_utf8_files() {
        let tmp = TempDir::new().unwrap();
        let attachments = single_attachment(&tmp, "binary.bin", [0xff_u8, 0xfe, 0xfd]);

        let outcome = build_attachment_blocks(&attachments).await;

        assert!(outcome.blocks.is_empty());
        assert_eq!(outcome.warnings.len(), 1);
        assert!(outcome.warnings[0].contains("Skipped binary or non-UTF8 file: binary.bin"));
    }

    #[tokio::test]
    async fn build_attachment_file_uri_falls_back_when_canonicalize_fails() {
        let tmp = TempDir::new().unwrap();
        // Never created on disk, so canonicalization has nothing to resolve.
        let missing = tmp.path().join("missing.txt");
        let expected = Url::from_file_path(&missing).unwrap().to_string();

        let uri = build_attachment_file_uri(&missing, "missing.txt")
            .await
            .expect("URI should be built from original absolute path");

        assert_eq!(uri, expected);
    }

    #[tokio::test]
    async fn png_file_produces_image_content_block() {
        let tmp = TempDir::new().unwrap();
        let attachments = single_attachment(&tmp, "test.png", b"fake png data");

        let outcome = build_attachment_blocks(&attachments).await;

        assert_eq!(outcome.blocks.len(), 1);
        assert!(outcome.warnings.is_empty());
        assert_eq!(
            outcome.transcript_placeholders,
            vec!["[image attachment: test.png]"]
        );
        assert!(matches!(outcome.blocks[0], acp::ContentBlock::Image(_)));
    }

    #[tokio::test]
    async fn wav_file_produces_audio_content_block() {
        let tmp = TempDir::new().unwrap();
        let attachments = single_attachment(&tmp, "test.wav", b"fake wav data");

        let outcome = build_attachment_blocks(&attachments).await;

        assert_eq!(outcome.blocks.len(), 1);
        assert!(outcome.warnings.is_empty());
        assert_eq!(
            outcome.transcript_placeholders,
            vec!["[audio attachment: test.wav]"]
        );
        assert!(matches!(outcome.blocks[0], acp::ContentBlock::Audio(_)));
    }

    #[test]
    fn classify_attachment_detects_images() {
        assert_classified(
            &["photo.png", "photo.jpg", "photo.gif", "photo.webp"],
            AttachmentKind::Image,
        );
    }

    #[test]
    fn classify_attachment_detects_audio() {
        assert_classified(
            &["note.wav", "note.mp3", "note.ogg"],
            AttachmentKind::Audio,
        );
    }

    #[test]
    fn classify_attachment_detects_text() {
        assert_classified(&["readme.txt"], AttachmentKind::Text);
    }

    #[test]
    fn classify_attachment_unknown_extension_is_unsupported() {
        assert_classified(&["data.xyz"], AttachmentKind::Unsupported);
    }
}
329            classify_attachment(Path::new("data.xyz")),
330            AttachmentKind::Unsupported
331        );
332    }
333}