1use std::path::Path;
4
5use base64::Engine;
6#[cfg(feature = "pyo3_macros")]
7use pyo3::{pyclass, pymethods};
8use serde::{Deserialize, Serialize};
9
10mod inject;
11mod store;
12pub use inject::{
13 compose_tool_response_with_files, merge_required_outputs_into_args,
14 required_files_tool_addendum, tool_file_to_file,
15};
16pub use store::{FileStore, DEFAULT_FILE_TTL};
17
/// Byte budget of 1 KiB.
///
/// NOTE(review): the name suggests this caps how much file content is shown
/// inline to the model; it is not referenced in this file — confirm against
/// the `inject`/`store` modules.
pub const MODEL_INLINE_BYTES: usize = 1024;

/// Largest declared file size (8 MiB) that [`File::elide_for_wire`] passes
/// through unchanged; anything larger has its body dropped.
pub const WIRE_EMBED_LIMIT_BYTES: u64 = 8 * 1024 * 1024;

/// Limit of 64 Ki (65 536).
///
/// NOTE(review): presumably the maximum number of characters returned by a
/// single read-file slice; not referenced in this file — confirm at the call
/// site.
pub const READ_FILE_MAX_SLICE_CHARS: usize = 64 * 1024;
26
/// Origin record attached to every [`File`] through its `source` field.
///
/// When the `pyo3_macros` feature is enabled the struct is a Python class and
/// all fields are readable from Python (`get_all`).
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileSource {
    /// Name of the tool associated with the file (the tests in this file use
    /// `"execute_python"`).
    pub tool: String,
    /// Round index.
    /// NOTE(review): presumably the same round number that `File::make_id`
    /// embeds in the id — confirm with the code that builds files.
    pub round: usize,
    /// Turn index; falls back to `0` when the key is missing from the
    /// deserialized input.
    #[serde(default)]
    pub turn: usize,
}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42#[serde(untagged)]
43pub enum FileContent {
44 Text {
46 #[serde(default, skip_serializing_if = "Option::is_none")]
47 text: Option<String>,
48 #[serde(default, skip_serializing_if = "Option::is_none")]
49 preview: Option<String>,
50 },
51 Binary {
53 #[serde(default, skip_serializing_if = "Option::is_none")]
54 data_base64: Option<String>,
55 },
56 Error { code: String, message: String },
58}
59
60impl FileContent {
61 pub fn is_error(&self) -> bool {
62 matches!(self, Self::Error { .. })
63 }
64
65 pub fn is_text(&self) -> bool {
66 matches!(self, Self::Text { .. })
67 }
68
69 pub fn is_binary(&self) -> bool {
70 matches!(self, Self::Binary { .. })
71 }
72}
73
/// A named file together with its metadata, origin and (possibly elided) body.
///
/// With the `pyo3_macros` feature enabled this is a Python class; attribute
/// access from Python goes through the getters in the `#[pymethods]` block
/// rather than `get_all`.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct File {
    /// Identifier of the file; `File::make_id` produces ids of the form
    /// `file_{run_id}_r{round}_{idx}`.
    pub id: String,
    /// File name (e.g. `"x.txt"` in the tests).
    pub name: String,
    /// Format label such as `"txt"`; omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub format: Option<String>,
    /// MIME type; omitted from the serialized form when `None`. `is_image` and
    /// `is_video` are decided from this value.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mime_type: Option<String>,
    /// Declared size in bytes. It is kept even when the body is elided and is
    /// what `elide_for_wire` compares with `WIRE_EMBED_LIMIT_BYTES`.
    pub bytes: u64,
    /// Creation timestamp; `0` when missing from the deserialized input.
    /// NOTE(review): presumably Unix seconds as returned by
    /// `File::now_unix_secs` — confirm where files are constructed.
    #[serde(default)]
    pub created_at: u64,
    /// Origin record for the file.
    pub source: FileSource,
    /// Body of the file; its keys are flattened into this struct's own map
    /// when (de)serialized.
    #[serde(flatten)]
    pub content: FileContent,
}
94
95impl File {
96 pub(crate) fn make_id(run_id: &str, round: usize, idx: usize) -> String {
97 format!("file_{run_id}_r{round}_{idx}")
98 }
99
100 pub(crate) fn truncate_utf8(s: &str, n: usize) -> &str {
102 if s.len() <= n {
103 return s;
104 }
105 let mut end = n;
106 while end > 0 && !s.is_char_boundary(end) {
107 end -= 1;
108 }
109 &s[..end]
110 }
111
112 pub fn as_text(&self) -> Option<&str> {
114 match &self.content {
115 FileContent::Text { text, .. } => text.as_deref(),
116 _ => None,
117 }
118 }
119
120 pub fn preview_str(&self) -> Option<&str> {
122 match &self.content {
123 FileContent::Text { preview, text } => preview.as_deref().or(text.as_deref()),
124 _ => None,
125 }
126 }
127
128 pub fn binary_data(&self) -> Option<&str> {
130 match &self.content {
131 FileContent::Binary { data_base64 } => data_base64.as_deref(),
132 _ => None,
133 }
134 }
135
136 pub fn is_truncated(&self) -> bool {
138 match &self.content {
139 FileContent::Text { text, .. } => text.is_none() && self.bytes > 0,
140 FileContent::Binary { data_base64 } => data_base64.is_none() && self.bytes > 0,
141 FileContent::Error { .. } => false,
142 }
143 }
144
145 pub fn elide_for_wire(&self) -> File {
147 if self.bytes <= WIRE_EMBED_LIMIT_BYTES {
148 return self.clone();
149 }
150 let content = match &self.content {
151 FileContent::Text { preview, .. } => FileContent::Text {
152 text: None,
153 preview: preview.clone(),
154 },
155 FileContent::Binary { .. } => FileContent::Binary { data_base64: None },
156 FileContent::Error { code, message } => FileContent::Error {
157 code: code.clone(),
158 message: message.clone(),
159 },
160 };
161 File {
162 id: self.id.clone(),
163 name: self.name.clone(),
164 format: self.format.clone(),
165 mime_type: self.mime_type.clone(),
166 bytes: self.bytes,
167 created_at: self.created_at,
168 source: self.source.clone(),
169 content,
170 }
171 }
172
173 pub(crate) fn now_unix_secs() -> u64 {
174 std::time::SystemTime::now()
175 .duration_since(std::time::UNIX_EPOCH)
176 .map(|d| d.as_secs())
177 .unwrap_or(0)
178 }
179
180 pub fn is_text(&self) -> bool {
181 self.content.is_text()
182 }
183
184 pub fn is_binary(&self) -> bool {
185 self.content.is_binary()
186 }
187
188 pub fn is_error(&self) -> bool {
189 self.content.is_error()
190 }
191
192 pub fn is_image(&self) -> bool {
193 self.mime_type
194 .as_deref()
195 .is_some_and(|m| m.to_ascii_lowercase().starts_with("image/"))
196 }
197
198 pub fn is_video(&self) -> bool {
199 self.mime_type
200 .as_deref()
201 .is_some_and(|m| m.to_ascii_lowercase().starts_with("video/"))
202 }
203
204 pub fn save<P: AsRef<Path>>(&self, path: P) -> std::io::Result<()> {
206 match &self.content {
207 FileContent::Text { text: Some(t), .. } => std::fs::write(path, t),
208 FileContent::Binary {
209 data_base64: Some(b64),
210 } => {
211 let bytes = base64::engine::general_purpose::STANDARD
212 .decode(b64)
213 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
214 std::fs::write(path, bytes)
215 }
216 FileContent::Text { text: None, .. } | FileContent::Binary { data_base64: None } => {
217 Err(std::io::Error::new(
218 std::io::ErrorKind::InvalidData,
219 format!(
220 "file '{}' body was elided due to wire embed cap; fetch by id first",
221 self.id
222 ),
223 ))
224 }
225 FileContent::Error { code, message } => Err(std::io::Error::other(format!(
226 "file '{}' is an error placeholder: {code}: {message}",
227 self.id
228 ))),
229 }
230 }
231}
232
/// Python-facing accessors for [`File`], compiled only with the
/// `pyo3_macros` feature.
///
/// `File` does not use `get_all`, so each attribute visible from Python is
/// spelled out here. `created_at` is included so that every scalar metadata
/// field of the struct is reachable from Python.
#[cfg(feature = "pyo3_macros")]
#[pymethods]
impl File {
    /// Identifier of the file.
    #[getter]
    fn id(&self) -> &str {
        &self.id
    }

    /// File name.
    #[getter]
    fn name(&self) -> &str {
        &self.name
    }

    /// Format label, or `None` when unset.
    #[getter]
    fn format(&self) -> Option<&str> {
        self.format.as_deref()
    }

    /// MIME type, or `None` when unset.
    #[getter]
    fn mime_type(&self) -> Option<&str> {
        self.mime_type.as_deref()
    }

    /// Declared size in bytes.
    #[getter]
    fn bytes(&self) -> u64 {
        self.bytes
    }

    /// Creation timestamp stored on the file (`0` when it was never set).
    #[getter]
    fn created_at(&self) -> u64 {
        self.created_at
    }

    /// Copy of the origin record.
    #[getter]
    fn source(&self) -> FileSource {
        self.source.clone()
    }

    /// Full text body; `None` for non-text content or when the body is absent.
    #[getter]
    fn text(&self) -> Option<&str> {
        self.as_text()
    }

    /// Base64 body; `None` for non-binary content or when the body is absent.
    #[getter]
    fn data_base64(&self) -> Option<&str> {
        self.binary_data()
    }

    /// Preview of a text file, falling back to its full text.
    #[getter]
    fn preview(&self) -> Option<&str> {
        self.preview_str()
    }

    /// True for text content.
    #[pyo3(name = "is_text")]
    fn py_is_text(&self) -> bool {
        self.is_text()
    }

    /// True for binary content.
    #[pyo3(name = "is_binary")]
    fn py_is_binary(&self) -> bool {
        self.is_binary()
    }

    /// True for an error placeholder.
    #[pyo3(name = "is_error")]
    fn py_is_error(&self) -> bool {
        self.is_error()
    }

    /// True when the MIME type starts with `image/`.
    #[pyo3(name = "is_image")]
    fn py_is_image(&self) -> bool {
        self.is_image()
    }

    /// True when the MIME type starts with `video/`.
    #[pyo3(name = "is_video")]
    fn py_is_video(&self) -> bool {
        self.is_video()
    }

    /// True when the file declares a size but carries no body.
    #[pyo3(name = "is_truncated")]
    fn py_is_truncated(&self) -> bool {
        self.is_truncated()
    }

    /// Writes the body to `path`; any failure is raised as a Python `IOError`.
    #[pyo3(name = "save")]
    fn py_save(&self, path: &str) -> pyo3::PyResult<()> {
        self.save(path)
            .map_err(|e| pyo3::exceptions::PyIOError::new_err(format!("failed to save file: {e}")))
    }

    /// Pretty-printed `Debug` rendering of the whole struct.
    ///
    /// NOTE(review): this includes the full body, which can be large for
    /// embedded files — consider a compact repr.
    fn __repr__(&self) -> String {
        format!("{self:#?}")
    }
}
321
/// Description of a file that is being asked for: a name plus an optional
/// format and free-text description.
///
/// With the `pyo3_macros` feature enabled this is a Python class whose fields
/// are all readable (`get_all`).
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestedFile {
    /// Name of the requested file.
    pub name: String,
    /// Optional format label; omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub format: Option<String>,
    /// Optional description; omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
335
336impl RequestedFile {
337 pub fn new(name: impl Into<String>) -> Self {
338 Self {
339 name: name.into(),
340 format: None,
341 description: None,
342 }
343 }
344
345 pub fn with_format(mut self, format: impl Into<String>) -> Self {
346 self.format = Some(format.into());
347 self
348 }
349
350 pub fn with_description(mut self, description: impl Into<String>) -> Self {
351 self.description = Some(description.into());
352 self
353 }
354}
355
/// Derives a lowercase format label from the extension of `name`.
///
/// The extension is whatever follows the last `.`. `None` is returned when
/// there is no `.`, when nothing follows it (`"report."`), or when the text
/// after it contains a path separator — in that case the dot belongs to a
/// directory component (`"out.d/readme"`), not to the file name.
pub fn format_from_name(name: &str) -> Option<String> {
    let (_, ext) = name.rsplit_once('.')?;
    let spans_path_component = ext.contains(|c: char| c == '/' || c == '\\');
    if ext.is_empty() || spans_path_component {
        return None;
    }
    Some(ext.to_ascii_lowercase())
}
361
362pub fn mime_for_format(format: &str) -> String {
364 let ext = match format.to_ascii_lowercase().as_str() {
365 "markdown" => "md".to_string(),
366 "yml" => "yaml".to_string(),
367 "latex" | "tex" => "tex".to_string(),
368 "python" => "py".to_string(),
369 "rust" => "rs".to_string(),
370 "vega-lite" | "vega" | "geojson" => "json".to_string(),
371 "text" => "txt".to_string(),
372 other => other.to_string(),
373 };
374 mime_guess::from_ext(&ext)
375 .first_or_octet_stream()
376 .essence_str()
377 .to_string()
378}
379
/// Reports whether `mime` denotes textual content.
///
/// Only the essence of the MIME type is considered: parameters after `;`
/// (for example `; charset=utf-8`) and surrounding whitespace are ignored,
/// and the comparison is case-insensitive. Everything under `text/` counts as
/// text, as does a fixed list of textual types registered under other
/// top-level types.
pub fn is_text_mime(mime: &str) -> bool {
    /// Textual types that do not live under `text/`.
    const TEXTUAL_TYPES: [&str; 9] = [
        "application/json",
        "application/geo+json",
        "application/xml",
        "application/yaml",
        "application/x-yaml",
        "application/toml",
        "application/sql",
        "application/x-tex",
        "image/svg+xml",
    ];

    // `split` always yields at least one item, so the fallback is never used.
    let essence = mime
        .split(';')
        .next()
        .unwrap_or(mime)
        .trim()
        .to_ascii_lowercase();

    essence.starts_with("text/") || TEXTUAL_TYPES.contains(&essence.as_str())
}
397
// Unit tests for the id helper, UTF-8 truncation, content accessors and the
// format/MIME helpers defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a small text file whose declared size matches the body length.
    fn text_file(id: &str, body: &str) -> File {
        File {
            id: id.into(),
            name: "x.txt".into(),
            format: Some("txt".into()),
            mime_type: Some("text/plain".into()),
            bytes: body.len() as u64,
            created_at: 0,
            source: FileSource {
                tool: "execute_python".into(),
                round: 0,
                turn: 0,
            },
            content: FileContent::Text {
                text: Some(body.into()),
                preview: None,
            },
        }
    }

    // Ids follow `file_{run_id}_r{round}_{idx}`.
    #[test]
    fn id_format() {
        assert_eq!(File::make_id("abc", 0, 0), "file_abc_r0_0");
        assert_eq!(File::make_id("xyz", 3, 7), "file_xyz_r3_7");
    }

    // 'é' occupies bytes 1..3, so a cut at byte 2 must back off to byte 1.
    #[test]
    fn truncate_respects_utf8() {
        let s = "héllo";
        assert_eq!(File::truncate_utf8(s, 10), "héllo");
        // Limit falls inside 'é': only "h" survives.
        assert_eq!(File::truncate_utf8(s, 2), "h");
        // Limit lands right after 'é': both characters are kept.
        assert_eq!(File::truncate_utf8(s, 3), "hé");
    }

    // With no stored preview, `preview_str` falls back to the full text.
    #[test]
    fn text_accessors() {
        let f = text_file("file_x_r0_0", "hello");
        assert_eq!(f.as_text(), Some("hello"));
        assert_eq!(f.preview_str(), Some("hello"));
        assert!(f.binary_data().is_none());
        assert!(!f.is_truncated());
    }

    // A binary file that declares a size but has no payload is truncated.
    #[test]
    fn truncated_binary_is_flagged() {
        let f = File {
            id: "file_x_r0_0".into(),
            name: "big.bin".into(),
            format: Some("bin".into()),
            mime_type: Some("application/octet-stream".into()),
            bytes: 64 * 1024 * 1024,
            created_at: 0,
            source: FileSource {
                tool: "execute_python".into(),
                round: 0,
                turn: 0,
            },
            content: FileContent::Binary { data_base64: None },
        };
        assert!(f.is_truncated());
    }

    // The extension is lowercased; names without a dot have no format.
    #[test]
    fn format_from_name_extension() {
        assert_eq!(format_from_name("plot.png"), Some("png".into()));
        assert_eq!(format_from_name("data.csv"), Some("csv".into()));
        assert_eq!(format_from_name("report.MD"), Some("md".into()));
        assert_eq!(format_from_name("noext"), None);
    }

    // Aliases are rewritten before lookup; unknown labels fall back to
    // `application/octet-stream`.
    #[test]
    fn mime_lookup() {
        assert_eq!(mime_for_format("csv"), "text/csv");
        assert_eq!(mime_for_format("PNG"), "image/png");
        assert_eq!(mime_for_format("markdown"), "text/markdown");
        assert_eq!(mime_for_format("geojson"), "application/json");
        assert_eq!(mime_for_format("unknown_"), "application/octet-stream");
    }

    // SVG counts as text even though it lives under `image/`.
    #[test]
    fn text_mime_classifier() {
        assert!(is_text_mime("text/csv"));
        assert!(is_text_mime("application/json"));
        assert!(is_text_mime("image/svg+xml"));
        assert!(!is_text_mime("image/png"));
        assert!(!is_text_mime("application/octet-stream"));
    }
}