1use super::PromptAttachment;
2use agent_client_protocol as acp;
3use base64::Engine;
4use base64::engine::general_purpose::STANDARD as BASE64;
5use std::path::Path;
6use tokio::io::AsyncReadExt;
7use url::Url;
8
/// Maximum number of bytes of a text attachment embedded inline (1 MiB); longer files are
/// truncated.
const MAX_EMBED_TEXT_BYTES: usize = 1 << 20;
/// Image/audio attachments larger than this (10 MiB) are skipped instead of base64-encoded.
const MAX_MEDIA_BYTES: usize = 10 << 20;

/// MIME types classified as [`AttachmentKind::Image`].
const IMAGE_MIME_TYPES: &[&str] = &[
    "image/png",
    "image/jpeg",
    "image/gif",
    "image/webp",
];
/// MIME types classified as [`AttachmentKind::Audio`].
const AUDIO_MIME_TYPES: &[&str] = &[
    "audio/wav",
    "audio/mpeg",
    "audio/mp3",
    "audio/ogg",
];
14
/// Coarse category of an attachment, derived from the MIME type guessed from its path.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AttachmentKind {
    /// A `text/*` MIME type.
    Text,
    /// One of the MIME types listed in `IMAGE_MIME_TYPES`.
    Image,
    /// One of the MIME types listed in `AUDIO_MIME_TYPES`.
    Audio,
    /// Any other MIME type, including the `application/octet-stream` fallback used when no
    /// type can be guessed.
    Unsupported,
}
22
23pub fn classify_attachment(path: &Path) -> AttachmentKind {
24 let mime = mime_guess::from_path(path)
25 .first_or_octet_stream()
26 .to_string();
27
28 if IMAGE_MIME_TYPES.contains(&mime.as_str()) {
29 AttachmentKind::Image
30 } else if AUDIO_MIME_TYPES.contains(&mime.as_str()) {
31 AttachmentKind::Audio
32 } else if mime.starts_with("text/") {
33 AttachmentKind::Text
34 } else {
35 AttachmentKind::Unsupported
36 }
37}
38
39#[derive(Debug, Default)]
40pub struct AttachmentBuildOutcome {
41 pub blocks: Vec<acp::ContentBlock>,
42 pub transcript_placeholders: Vec<String>,
43 pub warnings: Vec<String>,
44}
45
46pub async fn build_attachment_blocks(attachments: &[PromptAttachment]) -> AttachmentBuildOutcome {
47 let mut outcome = AttachmentBuildOutcome::default();
48
49 for attachment in attachments {
50 match try_build_attachment_block(&attachment.path, &attachment.display_name).await {
51 Ok(result) => {
52 outcome.blocks.push(result.block);
53 if let Some(placeholder) = result.transcript_placeholder {
54 outcome.transcript_placeholders.push(placeholder);
55 }
56 if let Some(warning) = result.warning {
57 outcome.warnings.push(warning);
58 }
59 }
60 Err(warning) => outcome.warnings.push(warning),
61 }
62 }
63
64 outcome
65}
66
67struct AttachmentBlockResult {
68 block: acp::ContentBlock,
69 transcript_placeholder: Option<String>,
70 warning: Option<String>,
71}
72
73async fn try_build_attachment_block(
74 path: &Path,
75 display_name: &str,
76) -> Result<AttachmentBlockResult, String> {
77 let kind = classify_attachment(path);
78 let mime_type = mime_guess::from_path(path)
79 .first_or_octet_stream()
80 .to_string();
81
82 match kind {
83 AttachmentKind::Image | AttachmentKind::Audio => {
84 let bytes = read_media_bytes(path, display_name).await?;
85 let data = BASE64.encode(&bytes);
86 let (block, placeholder) = match kind {
87 AttachmentKind::Image => (
88 acp::ContentBlock::Image(acp::ImageContent::new(data, &mime_type)),
89 format!("[image attachment: {display_name}]"),
90 ),
91 _ => (
92 acp::ContentBlock::Audio(acp::AudioContent::new(data, &mime_type)),
93 format!("[audio attachment: {display_name}]"),
94 ),
95 };
96 Ok(AttachmentBlockResult {
97 block,
98 transcript_placeholder: Some(placeholder),
99 warning: None,
100 })
101 }
102 _ => build_text_resource_block(path, display_name, &mime_type).await,
103 }
104}
105
106async fn read_media_bytes(path: &Path, display_name: &str) -> Result<Vec<u8>, String> {
107 let metadata = tokio::fs::metadata(path)
108 .await
109 .map_err(|e| format!("Failed to read {display_name}: {e}"))?;
110
111 if metadata.len() > MAX_MEDIA_BYTES as u64 {
112 return Err(format!(
113 "Skipped {display_name}: file too large ({} bytes, max {})",
114 metadata.len(),
115 MAX_MEDIA_BYTES
116 ));
117 }
118
119 tokio::fs::read(path)
120 .await
121 .map_err(|e| format!("Failed to read {display_name}: {e}"))
122}
123
124async fn build_text_resource_block(
125 path: &Path,
126 display_name: &str,
127 mime_type: &str,
128) -> Result<AttachmentBlockResult, String> {
129 let file = tokio::fs::File::open(path)
130 .await
131 .map_err(|error| format!("Failed to read {display_name}: {error}"))?;
132
133 let mut bytes = Vec::new();
134 file.take((MAX_EMBED_TEXT_BYTES + 1) as u64)
135 .read_to_end(&mut bytes)
136 .await
137 .map_err(|error| format!("Failed to read {display_name}: {error}"))?;
138
139 let truncated = bytes.len() > MAX_EMBED_TEXT_BYTES;
140 if truncated {
141 bytes.truncate(MAX_EMBED_TEXT_BYTES);
142 }
143 let text_bytes = bytes.as_slice();
144
145 let text = match std::str::from_utf8(text_bytes) {
146 Ok(text) => text.to_string(),
147 Err(error) if truncated && error.valid_up_to() > 0 => {
148 let valid_bytes = &text_bytes[..error.valid_up_to()];
149 std::str::from_utf8(valid_bytes)
150 .expect("valid_up_to must point at a utf8 boundary")
151 .to_string()
152 }
153 Err(_) => return Err(format!("Skipped binary or non-UTF8 file: {display_name}")),
154 };
155
156 let file_uri = build_attachment_file_uri(path, display_name).await?;
157 let warning =
158 truncated.then(|| format!("Truncated {display_name} to {MAX_EMBED_TEXT_BYTES} bytes"));
159
160 let block = acp::ContentBlock::Resource(acp::EmbeddedResource::new(
161 acp::EmbeddedResourceResource::TextResourceContents(
162 acp::TextResourceContents::new(text, file_uri).mime_type(mime_type),
163 ),
164 ));
165
166 Ok(AttachmentBlockResult {
167 block,
168 transcript_placeholder: None,
169 warning,
170 })
171}
172
173async fn build_attachment_file_uri(path: &Path, display_name: &str) -> Result<String, String> {
174 let canonical_path = tokio::fs::canonicalize(path).await.ok();
175 let uri_path = canonical_path.as_deref().unwrap_or(path);
176 Url::from_file_path(uri_path)
177 .map_err(|()| format!("Failed to build file URI for {display_name}"))
178 .map(|url| url.to_string())
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// Builds a `PromptAttachment` for a file created inside a test's temp dir.
    fn attachment(path: PathBuf, display_name: &str) -> PromptAttachment {
        PromptAttachment {
            path,
            display_name: display_name.to_string(),
        }
    }

    #[tokio::test]
    async fn build_attachment_blocks_truncates_large_file_with_warning() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("large.txt");
        let display_name = "large.txt";
        std::fs::write(&path, "x".repeat(MAX_EMBED_TEXT_BYTES + 64)).unwrap();

        let outcome = build_attachment_blocks(&[attachment(path, display_name)]).await;

        assert_eq!(outcome.blocks.len(), 1);
        assert_eq!(outcome.warnings.len(), 1);
        assert!(outcome.warnings[0].contains(&format!(
            "Truncated {display_name} to {MAX_EMBED_TEXT_BYTES} bytes"
        )));
    }

    #[tokio::test]
    async fn build_attachment_blocks_truncates_on_utf8_char_boundary() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("accents.txt");
        // Place a two-byte character so the byte limit falls between its two bytes.
        let mut content = "x".repeat(MAX_EMBED_TEXT_BYTES - 1);
        content.push('é');
        content.push_str("tail");
        std::fs::write(&path, content).unwrap();

        let outcome = build_attachment_blocks(&[attachment(path, "accents.txt")]).await;

        assert_eq!(outcome.blocks.len(), 1);
        assert!(matches!(outcome.blocks[0], acp::ContentBlock::Resource(_)));
        assert_eq!(outcome.warnings.len(), 1);
        assert!(outcome.warnings[0].contains("Truncated accents.txt"));
    }

    #[tokio::test]
    async fn build_attachment_blocks_skips_non_utf8_files() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("binary.bin");
        let display_name = "binary.bin";
        std::fs::write(&path, [0xff, 0xfe, 0xfd]).unwrap();

        let outcome = build_attachment_blocks(&[attachment(path, display_name)]).await;

        assert!(outcome.blocks.is_empty());
        assert_eq!(outcome.warnings.len(), 1);
        assert!(
            outcome.warnings[0]
                .contains(&format!("Skipped binary or non-UTF8 file: {display_name}"))
        );
    }

    #[tokio::test]
    async fn unrecognized_extension_with_utf8_content_is_embedded_as_text() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("notes.xyz");
        std::fs::write(&path, "plain text").unwrap();

        let outcome = build_attachment_blocks(&[attachment(path, "notes.xyz")]).await;

        assert_eq!(outcome.blocks.len(), 1);
        assert!(outcome.warnings.is_empty());
        assert!(outcome.transcript_placeholders.is_empty());
        assert!(matches!(outcome.blocks[0], acp::ContentBlock::Resource(_)));
    }

    #[tokio::test]
    async fn build_attachment_file_uri_falls_back_when_canonicalize_fails() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("missing.txt");
        let expected = Url::from_file_path(&path).unwrap().to_string();

        let uri = build_attachment_file_uri(&path, "missing.txt")
            .await
            .expect("URI should be built from original absolute path");

        assert_eq!(uri, expected);
    }

    #[tokio::test]
    async fn png_file_produces_image_content_block() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("test.png");
        std::fs::write(&path, b"fake png data").unwrap();

        let outcome = build_attachment_blocks(&[attachment(path, "test.png")]).await;

        assert_eq!(outcome.blocks.len(), 1);
        assert!(outcome.warnings.is_empty());
        assert_eq!(
            outcome.transcript_placeholders,
            vec!["[image attachment: test.png]"]
        );
        assert!(matches!(outcome.blocks[0], acp::ContentBlock::Image(_)));
    }

    #[tokio::test]
    async fn wav_file_produces_audio_content_block() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("test.wav");
        std::fs::write(&path, b"fake wav data").unwrap();

        let outcome = build_attachment_blocks(&[attachment(path, "test.wav")]).await;

        assert_eq!(outcome.blocks.len(), 1);
        assert!(outcome.warnings.is_empty());
        assert_eq!(
            outcome.transcript_placeholders,
            vec!["[audio attachment: test.wav]"]
        );
        assert!(matches!(outcome.blocks[0], acp::ContentBlock::Audio(_)));
    }

    #[test]
    fn classify_attachment_detects_images() {
        for name in ["photo.png", "photo.jpg", "photo.gif", "photo.webp"] {
            assert_eq!(
                classify_attachment(Path::new(name)),
                AttachmentKind::Image,
                "{name}"
            );
        }
    }

    #[test]
    fn classify_attachment_detects_audio() {
        for name in ["note.wav", "note.mp3", "note.ogg"] {
            assert_eq!(
                classify_attachment(Path::new(name)),
                AttachmentKind::Audio,
                "{name}"
            );
        }
    }

    #[test]
    fn classify_attachment_detects_text() {
        assert_eq!(
            classify_attachment(Path::new("readme.txt")),
            AttachmentKind::Text
        );
    }

    #[test]
    fn classify_attachment_unknown_extension_is_unsupported() {
        assert_eq!(
            classify_attachment(Path::new("data.xyz")),
            AttachmentKind::Unsupported
        );
    }
}