1use super::PromptAttachment;
2use agent_client_protocol as acp;
3use base64::Engine;
4use base64::engine::general_purpose::STANDARD as BASE64;
5use std::path::Path;
6use tokio::io::AsyncReadExt;
7use url::Url;
8
/// Maximum number of bytes of a text attachment embedded in a prompt (1 MiB); longer files are truncated.
const MAX_EMBED_TEXT_BYTES: usize = 1 << 20;
/// Maximum size of an image or audio attachment (10 MiB); larger files are skipped.
const MAX_MEDIA_BYTES: usize = 10 << 20;

/// MIME types that are sent to the agent as image content blocks.
const IMAGE_MIME_TYPES: &[&str] = &["image/png", "image/jpeg", "image/gif", "image/webp"];
/// MIME types that are sent to the agent as audio content blocks.
const AUDIO_MIME_TYPES: &[&str] = &["audio/wav", "audio/mpeg", "audio/mp3", "audio/ogg"];
14
/// How an attachment should be sent to the agent, as decided by `classify_attachment` from the
/// MIME type guessed for the file's path.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AttachmentKind {
    /// A `text/*` MIME type.
    Text,
    /// One of the MIME types listed in `IMAGE_MIME_TYPES`.
    Image,
    /// One of the MIME types listed in `AUDIO_MIME_TYPES`.
    Audio,
    /// Any other MIME type (including the `application/octet-stream` fallback).
    Unsupported,
}
22
23pub fn classify_attachment(path: &Path) -> AttachmentKind {
24 let mime = mime_guess::from_path(path).first_or_octet_stream().to_string();
25
26 if IMAGE_MIME_TYPES.contains(&mime.as_str()) {
27 AttachmentKind::Image
28 } else if AUDIO_MIME_TYPES.contains(&mime.as_str()) {
29 AttachmentKind::Audio
30 } else if mime.starts_with("text/") {
31 AttachmentKind::Text
32 } else {
33 AttachmentKind::Unsupported
34 }
35}
36
37#[derive(Debug, Default)]
38pub struct AttachmentBuildOutcome {
39 pub blocks: Vec<acp::ContentBlock>,
40 pub transcript_placeholders: Vec<String>,
41 pub warnings: Vec<String>,
42}
43
44pub async fn build_attachment_blocks(attachments: &[PromptAttachment]) -> AttachmentBuildOutcome {
45 let mut outcome = AttachmentBuildOutcome::default();
46
47 for attachment in attachments {
48 match try_build_attachment_block(&attachment.path, &attachment.display_name).await {
49 Ok(result) => {
50 outcome.blocks.push(result.block);
51 if let Some(placeholder) = result.transcript_placeholder {
52 outcome.transcript_placeholders.push(placeholder);
53 }
54 if let Some(warning) = result.warning {
55 outcome.warnings.push(warning);
56 }
57 }
58 Err(warning) => outcome.warnings.push(warning),
59 }
60 }
61
62 outcome
63}
64
65struct AttachmentBlockResult {
66 block: acp::ContentBlock,
67 transcript_placeholder: Option<String>,
68 warning: Option<String>,
69}
70
71async fn try_build_attachment_block(path: &Path, display_name: &str) -> Result<AttachmentBlockResult, String> {
72 let kind = classify_attachment(path);
73 let mime_type = mime_guess::from_path(path).first_or_octet_stream().to_string();
74
75 match kind {
76 AttachmentKind::Image | AttachmentKind::Audio => {
77 let bytes = read_media_bytes(path, display_name).await?;
78 let data = BASE64.encode(&bytes);
79 let (block, placeholder) = match kind {
80 AttachmentKind::Image => (
81 acp::ContentBlock::Image(acp::ImageContent::new(data, &mime_type)),
82 format!("[image attachment: {display_name}]"),
83 ),
84 _ => (
85 acp::ContentBlock::Audio(acp::AudioContent::new(data, &mime_type)),
86 format!("[audio attachment: {display_name}]"),
87 ),
88 };
89 Ok(AttachmentBlockResult { block, transcript_placeholder: Some(placeholder), warning: None })
90 }
91 _ => build_text_resource_block(path, display_name, &mime_type).await,
92 }
93}
94
95async fn read_media_bytes(path: &Path, display_name: &str) -> Result<Vec<u8>, String> {
96 let metadata = tokio::fs::metadata(path).await.map_err(|e| format!("Failed to read {display_name}: {e}"))?;
97
98 if metadata.len() > MAX_MEDIA_BYTES as u64 {
99 return Err(format!(
100 "Skipped {display_name}: file too large ({} bytes, max {})",
101 metadata.len(),
102 MAX_MEDIA_BYTES
103 ));
104 }
105
106 tokio::fs::read(path).await.map_err(|e| format!("Failed to read {display_name}: {e}"))
107}
108
109async fn build_text_resource_block(
110 path: &Path,
111 display_name: &str,
112 mime_type: &str,
113) -> Result<AttachmentBlockResult, String> {
114 let file = tokio::fs::File::open(path).await.map_err(|error| format!("Failed to read {display_name}: {error}"))?;
115
116 let mut bytes = Vec::new();
117 file.take((MAX_EMBED_TEXT_BYTES + 1) as u64)
118 .read_to_end(&mut bytes)
119 .await
120 .map_err(|error| format!("Failed to read {display_name}: {error}"))?;
121
122 let truncated = bytes.len() > MAX_EMBED_TEXT_BYTES;
123 if truncated {
124 bytes.truncate(MAX_EMBED_TEXT_BYTES);
125 }
126 let text_bytes = bytes.as_slice();
127
128 let text = match std::str::from_utf8(text_bytes) {
129 Ok(text) => text.to_string(),
130 Err(error) if truncated && error.valid_up_to() > 0 => {
131 let valid_bytes = &text_bytes[..error.valid_up_to()];
132 std::str::from_utf8(valid_bytes).expect("valid_up_to must point at a utf8 boundary").to_string()
133 }
134 Err(_) => return Err(format!("Skipped binary or non-UTF8 file: {display_name}")),
135 };
136
137 let file_uri = build_attachment_file_uri(path, display_name).await?;
138 let warning = truncated.then(|| format!("Truncated {display_name} to {MAX_EMBED_TEXT_BYTES} bytes"));
139
140 let block =
141 acp::ContentBlock::Resource(acp::EmbeddedResource::new(acp::EmbeddedResourceResource::TextResourceContents(
142 acp::TextResourceContents::new(text, file_uri).mime_type(mime_type),
143 )));
144
145 Ok(AttachmentBlockResult { block, transcript_placeholder: None, warning })
146}
147
148async fn build_attachment_file_uri(path: &Path, display_name: &str) -> Result<String, String> {
149 let canonical_path = tokio::fs::canonicalize(path).await.ok();
150 let uri_path = canonical_path.as_deref().unwrap_or(path);
151 Url::from_file_path(uri_path)
152 .map_err(|()| format!("Failed to build file URI for {display_name}"))
153 .map(|url| url.to_string())
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `contents` to a file called `name` inside a fresh temporary directory.
    /// Returns the directory guard together with a one-element attachment list for that file.
    ///
    /// Callers must keep the guard alive (bind it to `_dir`, not `_`), because dropping a
    /// `TempDir` deletes the directory and the file with it.
    fn single_attachment(name: &str, contents: impl AsRef<[u8]>) -> (TempDir, Vec<PromptAttachment>) {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join(name);
        std::fs::write(&path, contents).unwrap();
        (dir, vec![PromptAttachment { path, display_name: name.to_string() }])
    }

    #[tokio::test]
    async fn build_attachment_blocks_truncates_large_file_with_warning() {
        let (_dir, attachments) = single_attachment("large.txt", "x".repeat(MAX_EMBED_TEXT_BYTES + 64));

        let outcome = build_attachment_blocks(&attachments).await;

        assert_eq!(outcome.blocks.len(), 1);
        assert_eq!(outcome.warnings.len(), 1);
        let expected = format!("Truncated large.txt to {MAX_EMBED_TEXT_BYTES} bytes");
        assert!(outcome.warnings[0].contains(&expected));
    }

    #[tokio::test]
    async fn build_attachment_blocks_skips_non_utf8_files() {
        let (_dir, attachments) = single_attachment("binary.bin", [0xff_u8, 0xfe, 0xfd]);

        let outcome = build_attachment_blocks(&attachments).await;

        assert!(outcome.blocks.is_empty());
        assert_eq!(outcome.warnings.len(), 1);
        assert!(outcome.warnings[0].contains("Skipped binary or non-UTF8 file: binary.bin"));
    }

    #[tokio::test]
    async fn build_attachment_file_uri_falls_back_when_canonicalize_fails() {
        let dir = TempDir::new().unwrap();
        let missing = dir.path().join("missing.txt");
        let expected = Url::from_file_path(&missing).unwrap().to_string();

        let uri = build_attachment_file_uri(&missing, "missing.txt")
            .await
            .expect("URI should be built from original absolute path");

        assert_eq!(uri, expected);
    }

    #[tokio::test]
    async fn png_file_produces_image_content_block() {
        let (_dir, attachments) = single_attachment("test.png", b"fake png data");

        let outcome = build_attachment_blocks(&attachments).await;

        assert_eq!(outcome.blocks.len(), 1);
        assert!(outcome.warnings.is_empty());
        assert_eq!(outcome.transcript_placeholders, vec!["[image attachment: test.png]"]);
        assert!(matches!(outcome.blocks[0], acp::ContentBlock::Image(_)));
    }

    #[tokio::test]
    async fn wav_file_produces_audio_content_block() {
        let (_dir, attachments) = single_attachment("test.wav", b"fake wav data");

        let outcome = build_attachment_blocks(&attachments).await;

        assert_eq!(outcome.blocks.len(), 1);
        assert!(outcome.warnings.is_empty());
        assert_eq!(outcome.transcript_placeholders, vec!["[audio attachment: test.wav]"]);
        assert!(matches!(outcome.blocks[0], acp::ContentBlock::Audio(_)));
    }

    #[test]
    fn classify_attachment_detects_images() {
        for name in ["photo.png", "photo.jpg", "photo.gif", "photo.webp"] {
            assert_eq!(classify_attachment(Path::new(name)), AttachmentKind::Image, "{name}");
        }
    }

    #[test]
    fn classify_attachment_detects_audio() {
        for name in ["note.wav", "note.mp3", "note.ogg"] {
            assert_eq!(classify_attachment(Path::new(name)), AttachmentKind::Audio, "{name}");
        }
    }

    #[test]
    fn classify_attachment_detects_text() {
        assert_eq!(classify_attachment(Path::new("readme.txt")), AttachmentKind::Text);
    }

    #[test]
    fn classify_attachment_unknown_extension_is_unsupported() {
        assert_eq!(classify_attachment(Path::new("data.xyz")), AttachmentKind::Unsupported);
    }
}