Skip to main content

rs_adk/artifacts/
file_service.rs

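//! Filesystem-backed artifact service with versioning.
//!
//! Directory layout: `{root}/{session_id}/{artifact_name}/v{version}/data`
//! Metadata stored as: `{root}/{session_id}/{artifact_name}/v{version}/metadata.json`
//!
//! `session_id` and `artifact_name` are sanitized before being used as
//! directory names, so the on-disk names may differ from the caller's values.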
5
6use std::path::PathBuf;
7
8use async_trait::async_trait;
9use tokio::fs;
10
11use super::{now_secs, Artifact, ArtifactError, ArtifactMetadata, ArtifactService};
12
/// Filesystem-backed artifact storage with versioning.
///
/// Every saved version of an artifact gets its own directory under the
/// service root, holding a `data` file with the raw payload and a
/// `metadata.json` sidecar. Session IDs and artifact names are passed through
/// `sanitize_path_component` before they are used as directory names, so path
/// separators and dots in caller-supplied identifiers never reach the
/// filesystem.
#[derive(Debug, Clone)]
pub struct FileArtifactService {
    /// Directory under which all session and artifact directories are created.
    root_dir: PathBuf,
}
21
22impl FileArtifactService {
23    /// Create a new file artifact service rooted at the given directory.
24    ///
25    /// Creates the root directory if it doesn't exist.
26    pub fn new(root_dir: impl Into<PathBuf>) -> Result<Self, ArtifactError> {
27        let root = root_dir.into();
28        // Create root dir if it doesn't exist (use std::fs since this is construction)
29        std::fs::create_dir_all(&root)
30            .map_err(|e| ArtifactError::Storage(format!("Failed to create root dir: {}", e)))?;
31        Ok(Self { root_dir: root })
32    }
33
34    fn artifact_dir(&self, session_id: &str, name: &str) -> PathBuf {
35        // Sanitize inputs to prevent path traversal
36        let safe_session = sanitize_path_component(session_id);
37        let safe_name = sanitize_path_component(name);
38        self.root_dir.join(&safe_session).join(&safe_name)
39    }
40
41    fn version_dir(&self, session_id: &str, name: &str, version: u32) -> PathBuf {
42        self.artifact_dir(session_id, name)
43            .join(format!("v{}", version))
44    }
45
46    /// Get the next version number by counting existing version directories.
47    async fn next_version(&self, session_id: &str, name: &str) -> u32 {
48        let dir = self.artifact_dir(session_id, name);
49        if !dir.exists() {
50            return 1;
51        }
52        let mut max_version = 0u32;
53        if let Ok(mut entries) = fs::read_dir(&dir).await {
54            while let Ok(Some(entry)) = entries.next_entry().await {
55                if let Some(name) = entry.file_name().to_str() {
56                    if let Some(v) = name.strip_prefix('v') {
57                        if let Ok(version) = v.parse::<u32>() {
58                            max_version = max_version.max(version);
59                        }
60                    }
61                }
62            }
63        }
64        max_version + 1
65    }
66}
67
/// Turn a caller-supplied identifier into a single safe directory name.
///
/// Every `/`, `\` and `.` is replaced with `_`, so the result can never be
/// `.` or `..` and never contains a path separator. An empty input maps to
/// `_`: joining an empty string onto a path adds no component, which would
/// make the identifier alias its parent directory.
///
/// The mapping is not injective: `a/b`, `a.b` and `a_b` all become `a_b`.
///
/// NOTE(review): `:` is left untouched. On Windows a component such as `C:x`
/// carries a drive prefix, and `PathBuf::join` replaces the base path when
/// given one — confirm whether Windows is a supported target.
fn sanitize_path_component(s: &str) -> String {
    if s.is_empty() {
        return "_".to_owned();
    }
    s.replace(['/', '\\', '.'], "_")
}
72
73#[async_trait]
74impl ArtifactService for FileArtifactService {
75    async fn save(
76        &self,
77        session_id: &str,
78        artifact: Artifact,
79    ) -> Result<ArtifactMetadata, ArtifactError> {
80        let version = self.next_version(session_id, &artifact.metadata.name).await;
81        let ver_dir = self.version_dir(session_id, &artifact.metadata.name, version);
82        fs::create_dir_all(&ver_dir)
83            .await
84            .map_err(|e| ArtifactError::Storage(e.to_string()))?;
85
86        // Write data
87        fs::write(ver_dir.join("data"), &artifact.data)
88            .await
89            .map_err(|e| ArtifactError::Storage(e.to_string()))?;
90
91        // Update metadata with correct version and write as JSON sidecar
92        let mut metadata = artifact.metadata;
93        metadata.version = version;
94        metadata.updated_at = now_secs();
95        if version == 1 {
96            metadata.created_at = metadata.updated_at;
97        }
98
99        let metadata_json = serde_json::to_string_pretty(&metadata)
100            .map_err(|e| ArtifactError::Storage(e.to_string()))?;
101        fs::write(ver_dir.join("metadata.json"), metadata_json)
102            .await
103            .map_err(|e| ArtifactError::Storage(e.to_string()))?;
104
105        Ok(metadata)
106    }
107
108    async fn load(&self, session_id: &str, name: &str) -> Result<Option<Artifact>, ArtifactError> {
109        let latest = self.next_version(session_id, name).await;
110        if latest == 1 {
111            return Ok(None);
112        }
113        self.load_version(session_id, name, latest - 1).await
114    }
115
116    async fn load_version(
117        &self,
118        session_id: &str,
119        name: &str,
120        version: u32,
121    ) -> Result<Option<Artifact>, ArtifactError> {
122        let ver_dir = self.version_dir(session_id, name, version);
123        if !ver_dir.exists() {
124            return Ok(None);
125        }
126
127        let data = fs::read(ver_dir.join("data"))
128            .await
129            .map_err(|e| ArtifactError::Storage(e.to_string()))?;
130        let metadata_str = fs::read_to_string(ver_dir.join("metadata.json"))
131            .await
132            .map_err(|e| ArtifactError::Storage(e.to_string()))?;
133        let metadata: ArtifactMetadata = serde_json::from_str(&metadata_str)
134            .map_err(|e| ArtifactError::Storage(e.to_string()))?;
135
136        Ok(Some(Artifact { metadata, data }))
137    }
138
139    async fn list(&self, session_id: &str) -> Result<Vec<ArtifactMetadata>, ArtifactError> {
140        let session_dir = self.root_dir.join(sanitize_path_component(session_id));
141        if !session_dir.exists() {
142            return Ok(vec![]);
143        }
144
145        let mut result = vec![];
146        let mut entries = fs::read_dir(&session_dir)
147            .await
148            .map_err(|e| ArtifactError::Storage(e.to_string()))?;
149
150        while let Some(entry) = entries
151            .next_entry()
152            .await
153            .map_err(|e| ArtifactError::Storage(e.to_string()))?
154        {
155            if entry.file_type().await.map(|t| t.is_dir()).unwrap_or(false) {
156                let name = entry.file_name().to_string_lossy().to_string();
157                // Load latest version metadata
158                if let Ok(Some(artifact)) = self.load(session_id, &name).await {
159                    result.push(artifact.metadata);
160                }
161            }
162        }
163        Ok(result)
164    }
165
166    async fn delete(&self, session_id: &str, name: &str) -> Result<(), ArtifactError> {
167        let dir = self.artifact_dir(session_id, name);
168        if dir.exists() {
169            fs::remove_dir_all(&dir)
170                .await
171                .map_err(|e| ArtifactError::Storage(e.to_string()))?;
172        }
173        Ok(())
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Build a temp-directory path unique to this process and call.
    ///
    /// The path combines the process ID with a global counter, so tests
    /// running in parallel never share a root. Any leftover directory from an
    /// earlier run is removed; the directory itself is not created here.
    fn test_dir() -> PathBuf {
        static COUNTER: AtomicU32 = AtomicU32::new(0);
        let id = COUNTER.fetch_add(1, Ordering::SeqCst);
        let dir = std::env::temp_dir()
            .join("rs_adk_file_artifact_tests")
            .join(format!("test_{}_{}", std::process::id(), id));
        // Clean up any leftovers from previous runs
        let _ = std::fs::remove_dir_all(&dir);
        dir
    }

    /// A saved artifact can be loaded back with the same payload and metadata.
    #[tokio::test]
    async fn save_and_load_round_trip() {
        let dir = test_dir();
        let svc = FileArtifactService::new(&dir).unwrap();

        let artifact = Artifact::text("notes", "Hello, world!");
        let meta = svc.save("session1", artifact).await.unwrap();
        assert_eq!(meta.name, "notes");
        assert_eq!(meta.version, 1);

        let loaded = svc.load("session1", "notes").await.unwrap().unwrap();
        assert_eq!(std::str::from_utf8(&loaded.data).unwrap(), "Hello, world!");
        assert_eq!(loaded.metadata.version, 1);
        assert_eq!(loaded.metadata.mime_type, "text/plain");

        std::fs::remove_dir_all(&dir).ok();
    }

    /// Each save bumps the version and `load` returns the newest one.
    #[tokio::test]
    async fn versioning_increments_and_load_gets_latest() {
        let dir = test_dir();
        let svc = FileArtifactService::new(&dir).unwrap();

        let m1 = svc
            .save("s1", Artifact::text("doc", "version 1"))
            .await
            .unwrap();
        assert_eq!(m1.version, 1);

        let m2 = svc
            .save("s1", Artifact::text("doc", "version 2"))
            .await
            .unwrap();
        assert_eq!(m2.version, 2);

        let m3 = svc
            .save("s1", Artifact::text("doc", "version 3"))
            .await
            .unwrap();
        assert_eq!(m3.version, 3);

        // load() should return latest (v3)
        let latest = svc.load("s1", "doc").await.unwrap().unwrap();
        assert_eq!(latest.metadata.version, 3);
        assert_eq!(std::str::from_utf8(&latest.data).unwrap(), "version 3");

        std::fs::remove_dir_all(&dir).ok();
    }

    /// `load_version` returns the requested version, or `None` if it is absent.
    #[tokio::test]
    async fn load_specific_version() {
        let dir = test_dir();
        let svc = FileArtifactService::new(&dir).unwrap();

        svc.save("s1", Artifact::text("doc", "v1 data"))
            .await
            .unwrap();
        svc.save("s1", Artifact::text("doc", "v2 data"))
            .await
            .unwrap();
        svc.save("s1", Artifact::text("doc", "v3 data"))
            .await
            .unwrap();

        let v1 = svc.load_version("s1", "doc", 1).await.unwrap().unwrap();
        assert_eq!(std::str::from_utf8(&v1.data).unwrap(), "v1 data");
        assert_eq!(v1.metadata.version, 1);

        let v2 = svc.load_version("s1", "doc", 2).await.unwrap().unwrap();
        assert_eq!(std::str::from_utf8(&v2.data).unwrap(), "v2 data");

        let v3 = svc.load_version("s1", "doc", 3).await.unwrap().unwrap();
        assert_eq!(std::str::from_utf8(&v3.data).unwrap(), "v3 data");

        // Nonexistent version returns None
        let v99 = svc.load_version("s1", "doc", 99).await.unwrap();
        assert!(v99.is_none());

        std::fs::remove_dir_all(&dir).ok();
    }

    /// `list` returns only the artifacts that belong to the given session.
    #[tokio::test]
    async fn list_artifacts() {
        let dir = test_dir();
        let svc = FileArtifactService::new(&dir).unwrap();

        svc.save("s1", Artifact::text("alpha", "data"))
            .await
            .unwrap();
        svc.save("s1", Artifact::text("beta", "data"))
            .await
            .unwrap();
        svc.save("s2", Artifact::text("gamma", "data"))
            .await
            .unwrap();

        let list = svc.list("s1").await.unwrap();
        assert_eq!(list.len(), 2);
        let names: Vec<&str> = list.iter().map(|m| m.name.as_str()).collect();
        assert!(names.contains(&"alpha"));
        assert!(names.contains(&"beta"));

        // Different session
        let list2 = svc.list("s2").await.unwrap();
        assert_eq!(list2.len(), 1);
        assert_eq!(list2[0].name, "gamma");

        std::fs::remove_dir_all(&dir).ok();
    }

    /// `delete` removes every version and is safe to call twice.
    #[tokio::test]
    async fn delete_artifact() {
        let dir = test_dir();
        let svc = FileArtifactService::new(&dir).unwrap();

        svc.save("s1", Artifact::text("notes", "data"))
            .await
            .unwrap();
        svc.save("s1", Artifact::text("notes", "v2")).await.unwrap();

        svc.delete("s1", "notes").await.unwrap();

        let result = svc.load("s1", "notes").await.unwrap();
        assert!(result.is_none());

        // Deleting again should be a no-op
        svc.delete("s1", "notes").await.unwrap();

        std::fs::remove_dir_all(&dir).ok();
    }

    /// Loading from an unknown session/artifact yields `None`, not an error.
    #[tokio::test]
    async fn load_nonexistent_returns_none() {
        let dir = test_dir();
        let svc = FileArtifactService::new(&dir).unwrap();

        let result = svc.load("no_session", "no_artifact").await.unwrap();
        assert!(result.is_none());

        std::fs::remove_dir_all(&dir).ok();
    }

    /// Traversal sequences in IDs are neutralized and data stays under the root.
    #[tokio::test]
    async fn path_traversal_prevention() {
        let dir = test_dir();
        let svc = FileArtifactService::new(&dir).unwrap();

        // Session ID and name with traversal attempts should be sanitized
        let artifact = Artifact::text("../../../etc/passwd", "malicious");
        let meta = svc.save("../../hack", artifact).await.unwrap();
        assert_eq!(meta.version, 1);

        // The sanitized name should not contain path separators or dots
        let sanitized_session = sanitize_path_component("../../hack");
        let sanitized_name = sanitize_path_component("../../../etc/passwd");
        assert!(!sanitized_session.contains('/'));
        assert!(!sanitized_session.contains('\\'));
        assert!(!sanitized_session.contains('.'));
        assert!(!sanitized_name.contains('/'));
        assert!(!sanitized_name.contains('\\'));
        assert!(!sanitized_name.contains('.'));

        // Should be able to load with the original (unsanitized) names
        let loaded = svc.load("../../hack", "../../../etc/passwd").await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(
            std::str::from_utf8(&loaded.unwrap().data).unwrap(),
            "malicious"
        );

        // Verify files stayed within root: the payload must sit under the
        // sanitized session/artifact directories, and the root must contain
        // that single session directory and nothing else.
        let stored = dir
            .join(&sanitized_session)
            .join(&sanitized_name)
            .join("v1")
            .join("data");
        assert!(stored.is_file());
        let root_entries = std::fs::read_dir(&dir).unwrap().count();
        assert_eq!(root_entries, 1);

        std::fs::remove_dir_all(&dir).ok();
    }

    /// Separators and dots are each replaced by a single underscore.
    #[test]
    fn sanitize_removes_dangerous_chars() {
        assert_eq!(sanitize_path_component("normal"), "normal");
        assert_eq!(sanitize_path_component(".."), "__");
        assert_eq!(sanitize_path_component("a/b"), "a_b");
        assert_eq!(sanitize_path_component("a\\b"), "a_b");
        assert_eq!(sanitize_path_component("../../etc"), "______etc");
        assert_eq!(sanitize_path_component("file.txt"), "file_txt");
    }
}