1use std::path::PathBuf;
7
8use async_trait::async_trait;
9use tokio::fs;
10
11use super::{now_secs, Artifact, ArtifactError, ArtifactMetadata, ArtifactService};
12
/// [`ArtifactService`] implementation that persists artifacts on the local
/// filesystem.
///
/// On-disk layout: `<root_dir>/<session>/<name>/v<N>/data` holds the raw bytes
/// and `<root_dir>/<session>/<name>/v<N>/metadata.json` holds the serialized
/// metadata. `<session>` and `<name>` are passed through
/// `sanitize_path_component` before being joined onto the root.
#[derive(Debug, Clone)]
pub struct FileArtifactService {
    /// Directory under which every session's artifacts are stored.
    root_dir: PathBuf,
}
21
22impl FileArtifactService {
23 pub fn new(root_dir: impl Into<PathBuf>) -> Result<Self, ArtifactError> {
27 let root = root_dir.into();
28 std::fs::create_dir_all(&root)
30 .map_err(|e| ArtifactError::Storage(format!("Failed to create root dir: {}", e)))?;
31 Ok(Self { root_dir: root })
32 }
33
34 fn artifact_dir(&self, session_id: &str, name: &str) -> PathBuf {
35 let safe_session = sanitize_path_component(session_id);
37 let safe_name = sanitize_path_component(name);
38 self.root_dir.join(&safe_session).join(&safe_name)
39 }
40
41 fn version_dir(&self, session_id: &str, name: &str, version: u32) -> PathBuf {
42 self.artifact_dir(session_id, name)
43 .join(format!("v{}", version))
44 }
45
46 async fn next_version(&self, session_id: &str, name: &str) -> u32 {
48 let dir = self.artifact_dir(session_id, name);
49 if !dir.exists() {
50 return 1;
51 }
52 let mut max_version = 0u32;
53 if let Ok(mut entries) = fs::read_dir(&dir).await {
54 while let Ok(Some(entry)) = entries.next_entry().await {
55 if let Some(name) = entry.file_name().to_str() {
56 if let Some(v) = name.strip_prefix('v') {
57 if let Ok(version) = v.parse::<u32>() {
58 max_version = max_version.max(version);
59 }
60 }
61 }
62 }
63 }
64 max_version + 1
65 }
66}
67
/// Maps an untrusted session id or artifact name to a single path component
/// that stays inside the directory it is joined onto.
///
/// Path separators (`/`, `\`) and `.` are replaced with `_`, so the result is
/// never `..` and never contains a separator. On Windows, `:` is replaced as
/// well: a component such as `C:evil` carries a drive prefix, and `Path::join`
/// with a prefixed path discards the base path entirely. An empty input becomes
/// `_`, because joining an empty component would alias the parent directory.
///
/// The mapping is lossy: distinct inputs (e.g. `a.b` and `a_b`) can map to the
/// same component and therefore share storage.
fn sanitize_path_component(s: &str) -> String {
    if s.is_empty() {
        return String::from("_");
    }
    s.replace(
        |c: char| matches!(c, '/' | '\\' | '.') || (cfg!(windows) && c == ':'),
        "_",
    )
}
72
73#[async_trait]
74impl ArtifactService for FileArtifactService {
75 async fn save(
76 &self,
77 session_id: &str,
78 artifact: Artifact,
79 ) -> Result<ArtifactMetadata, ArtifactError> {
80 let version = self.next_version(session_id, &artifact.metadata.name).await;
81 let ver_dir = self.version_dir(session_id, &artifact.metadata.name, version);
82 fs::create_dir_all(&ver_dir)
83 .await
84 .map_err(|e| ArtifactError::Storage(e.to_string()))?;
85
86 fs::write(ver_dir.join("data"), &artifact.data)
88 .await
89 .map_err(|e| ArtifactError::Storage(e.to_string()))?;
90
91 let mut metadata = artifact.metadata;
93 metadata.version = version;
94 metadata.updated_at = now_secs();
95 if version == 1 {
96 metadata.created_at = metadata.updated_at;
97 }
98
99 let metadata_json = serde_json::to_string_pretty(&metadata)
100 .map_err(|e| ArtifactError::Storage(e.to_string()))?;
101 fs::write(ver_dir.join("metadata.json"), metadata_json)
102 .await
103 .map_err(|e| ArtifactError::Storage(e.to_string()))?;
104
105 Ok(metadata)
106 }
107
108 async fn load(&self, session_id: &str, name: &str) -> Result<Option<Artifact>, ArtifactError> {
109 let latest = self.next_version(session_id, name).await;
110 if latest == 1 {
111 return Ok(None);
112 }
113 self.load_version(session_id, name, latest - 1).await
114 }
115
116 async fn load_version(
117 &self,
118 session_id: &str,
119 name: &str,
120 version: u32,
121 ) -> Result<Option<Artifact>, ArtifactError> {
122 let ver_dir = self.version_dir(session_id, name, version);
123 if !ver_dir.exists() {
124 return Ok(None);
125 }
126
127 let data = fs::read(ver_dir.join("data"))
128 .await
129 .map_err(|e| ArtifactError::Storage(e.to_string()))?;
130 let metadata_str = fs::read_to_string(ver_dir.join("metadata.json"))
131 .await
132 .map_err(|e| ArtifactError::Storage(e.to_string()))?;
133 let metadata: ArtifactMetadata = serde_json::from_str(&metadata_str)
134 .map_err(|e| ArtifactError::Storage(e.to_string()))?;
135
136 Ok(Some(Artifact { metadata, data }))
137 }
138
139 async fn list(&self, session_id: &str) -> Result<Vec<ArtifactMetadata>, ArtifactError> {
140 let session_dir = self.root_dir.join(sanitize_path_component(session_id));
141 if !session_dir.exists() {
142 return Ok(vec![]);
143 }
144
145 let mut result = vec![];
146 let mut entries = fs::read_dir(&session_dir)
147 .await
148 .map_err(|e| ArtifactError::Storage(e.to_string()))?;
149
150 while let Some(entry) = entries
151 .next_entry()
152 .await
153 .map_err(|e| ArtifactError::Storage(e.to_string()))?
154 {
155 if entry.file_type().await.map(|t| t.is_dir()).unwrap_or(false) {
156 let name = entry.file_name().to_string_lossy().to_string();
157 if let Ok(Some(artifact)) = self.load(session_id, &name).await {
159 result.push(artifact.metadata);
160 }
161 }
162 }
163 Ok(result)
164 }
165
166 async fn delete(&self, session_id: &str, name: &str) -> Result<(), ArtifactError> {
167 let dir = self.artifact_dir(session_id, name);
168 if dir.exists() {
169 fs::remove_dir_all(&dir)
170 .await
171 .map_err(|e| ArtifactError::Storage(e.to_string()))?;
172 }
173 Ok(())
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Per-test scratch directory under the system temp dir.
    ///
    /// The directory is removed when the guard is dropped. Cleanup therefore
    /// also happens when an assertion panics part-way through a test; a
    /// trailing `remove_dir_all` call would be skipped in that case.
    struct TestDir(PathBuf);

    impl TestDir {
        fn new() -> Self {
            // Process id plus a per-process counter keeps concurrently running
            // tests in separate directories.
            static COUNTER: AtomicU32 = AtomicU32::new(0);
            let id = COUNTER.fetch_add(1, Ordering::SeqCst);
            let dir = std::env::temp_dir()
                .join("rs_adk_file_artifact_tests")
                .join(format!("test_{}_{}", std::process::id(), id));
            // Clear leftovers from an earlier run that reused the same pid.
            let _ = std::fs::remove_dir_all(&dir);
            TestDir(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TestDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Creates a fresh scratch directory and a service rooted in it. Bind the
    /// returned guard to a named variable (not `_`) so it lives for the test.
    fn setup() -> (TestDir, FileArtifactService) {
        let dir = TestDir::new();
        let svc = FileArtifactService::new(dir.path()).unwrap();
        (dir, svc)
    }

    /// The artifact payload as UTF-8 text.
    fn body_str(artifact: &Artifact) -> &str {
        std::str::from_utf8(&artifact.data).unwrap()
    }

    #[tokio::test]
    async fn save_and_load_round_trip() {
        let (_dir, svc) = setup();

        let meta = svc
            .save("session1", Artifact::text("notes", "Hello, world!"))
            .await
            .unwrap();
        assert_eq!(meta.name, "notes");
        assert_eq!(meta.version, 1);

        let loaded = svc.load("session1", "notes").await.unwrap().unwrap();
        assert_eq!(body_str(&loaded), "Hello, world!");
        assert_eq!(loaded.metadata.version, 1);
        assert_eq!(loaded.metadata.mime_type, "text/plain");
    }

    #[tokio::test]
    async fn versioning_increments_and_load_gets_latest() {
        let (_dir, svc) = setup();

        for (expected, body) in [(1, "version 1"), (2, "version 2"), (3, "version 3")] {
            let meta = svc.save("s1", Artifact::text("doc", body)).await.unwrap();
            assert_eq!(meta.version, expected);
        }

        let latest = svc.load("s1", "doc").await.unwrap().unwrap();
        assert_eq!(latest.metadata.version, 3);
        assert_eq!(body_str(&latest), "version 3");
    }

    #[tokio::test]
    async fn load_specific_version() {
        let (_dir, svc) = setup();

        for body in ["v1 data", "v2 data", "v3 data"] {
            svc.save("s1", Artifact::text("doc", body)).await.unwrap();
        }

        let v1 = svc.load_version("s1", "doc", 1).await.unwrap().unwrap();
        assert_eq!(body_str(&v1), "v1 data");
        assert_eq!(v1.metadata.version, 1);

        let v2 = svc.load_version("s1", "doc", 2).await.unwrap().unwrap();
        assert_eq!(body_str(&v2), "v2 data");

        let v3 = svc.load_version("s1", "doc", 3).await.unwrap().unwrap();
        assert_eq!(body_str(&v3), "v3 data");

        let v99 = svc.load_version("s1", "doc", 99).await.unwrap();
        assert!(v99.is_none());
    }

    #[tokio::test]
    async fn list_artifacts() {
        let (_dir, svc) = setup();

        svc.save("s1", Artifact::text("alpha", "data")).await.unwrap();
        svc.save("s1", Artifact::text("beta", "data")).await.unwrap();
        svc.save("s2", Artifact::text("gamma", "data")).await.unwrap();

        let list = svc.list("s1").await.unwrap();
        assert_eq!(list.len(), 2);
        let names: Vec<&str> = list.iter().map(|m| m.name.as_str()).collect();
        assert!(names.contains(&"alpha"));
        assert!(names.contains(&"beta"));

        let list2 = svc.list("s2").await.unwrap();
        assert_eq!(list2.len(), 1);
        assert_eq!(list2[0].name, "gamma");
    }

    #[tokio::test]
    async fn list_unknown_session_is_empty() {
        let (_dir, svc) = setup();
        assert!(svc.list("never_used").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn delete_artifact() {
        let (_dir, svc) = setup();

        svc.save("s1", Artifact::text("notes", "data")).await.unwrap();
        svc.save("s1", Artifact::text("notes", "v2")).await.unwrap();

        svc.delete("s1", "notes").await.unwrap();
        assert!(svc.load("s1", "notes").await.unwrap().is_none());

        // Deleting something that is already gone must still succeed.
        svc.delete("s1", "notes").await.unwrap();
    }

    #[tokio::test]
    async fn load_nonexistent_returns_none() {
        let (_dir, svc) = setup();

        let result = svc.load("no_session", "no_artifact").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn path_traversal_prevention() {
        let (dir, svc) = setup();

        let meta = svc
            .save("../../hack", Artifact::text("../../../etc/passwd", "malicious"))
            .await
            .unwrap();
        assert_eq!(meta.version, 1);

        let sanitized_session = sanitize_path_component("../../hack");
        let sanitized_name = sanitize_path_component("../../../etc/passwd");
        for component in [&sanitized_session, &sanitized_name] {
            assert!(!component.contains('/'));
            assert!(!component.contains('\\'));
            assert!(!component.contains('.'));
        }

        // The payload must have landed inside the service root.
        let stored = dir
            .path()
            .join(&sanitized_session)
            .join(&sanitized_name)
            .join("v1")
            .join("data");
        assert!(stored.is_file(), "artifact was not stored under the service root");

        let loaded = svc
            .load("../../hack", "../../../etc/passwd")
            .await
            .unwrap()
            .expect("artifact should round-trip under its unsanitized names");
        assert_eq!(body_str(&loaded), "malicious");

        assert!(dir.path().exists());
    }

    #[test]
    fn sanitize_removes_dangerous_chars() {
        assert_eq!(sanitize_path_component("normal"), "normal");
        assert_eq!(sanitize_path_component(".."), "__");
        assert_eq!(sanitize_path_component("a/b"), "a_b");
        assert_eq!(sanitize_path_component("a\\b"), "a_b");
        assert_eq!(sanitize_path_component("../../etc"), "______etc");
        assert_eq!(sanitize_path_component("file.txt"), "file_txt");
    }
}