Skip to main content

adk_artifact/
file.rs

1use crate::service::*;
2use adk_core::{Part, Result};
3use async_trait::async_trait;
4use std::collections::BTreeSet;
5use std::path::{Path, PathBuf};
6use std::time::{SystemTime, UNIX_EPOCH};
7use tokio::fs;
8
/// Directory name, placed under `{app}/{user}/`, that holds `user:`-prefixed artifacts so
/// they are shared by every session of that user.
const USER_SCOPED_DIR: &str = "_user_scoped_";

/// Persist artifacts on the local filesystem.
///
/// Each saved version is stored as JSON at
/// `{base}/{app}/{user}/{session}/{file_name}/v{N}.json`, or at
/// `{base}/{app}/{user}/_user_scoped_/{file_name}/v{N}.json` for file names starting with
/// `user:`.
///
/// The base directory is created and canonicalized at construction time.
#[derive(Debug)]
pub struct FileArtifactService {
    /// Canonical (absolute, resolved) base directory. Set once at construction.
    base_dir: PathBuf,
}
18
19impl FileArtifactService {
20    /// Create a new filesystem-backed artifact service rooted at `base_dir`.
21    ///
22    /// Creates the directory if it doesn't exist and stores the canonical path.
23    ///
24    /// # Errors
25    ///
26    /// Returns an error if the directory cannot be created or canonicalized.
27    pub fn new(base_dir: impl Into<PathBuf>) -> Result<Self> {
28        let raw = base_dir.into();
29        std::fs::create_dir_all(&raw)
30            .map_err(|e| adk_core::AdkError::artifact(format!("create base dir failed: {e}")))?;
31        let canonical = raw.canonicalize().map_err(|e| {
32            adk_core::AdkError::artifact(format!("canonicalize base dir failed: {e}"))
33        })?;
34        Ok(Self { base_dir: canonical })
35    }
36
37    fn validate_file_name(file_name: &str) -> Result<()> {
38        if file_name.is_empty() {
39            return Err(adk_core::AdkError::artifact("invalid artifact file name: empty name"));
40        }
41
42        if file_name.contains('/')
43            || file_name.contains('\\')
44            || file_name == "."
45            || file_name == ".."
46            || file_name.contains("..")
47        {
48            return Err(adk_core::AdkError::artifact(format!(
49                "invalid artifact file name '{}': path separators and traversal patterns are not allowed",
50                file_name
51            )));
52        }
53
54        Ok(())
55    }
56
57    /// Validates a path component (app_name, user_id, session_id) used to build artifact paths.
58    ///
59    /// Rejects empty values, directory separators, and traversal patterns.
60    fn validate_path_component(component: &str, field: &str) -> Result<()> {
61        if component.is_empty() {
62            return Err(adk_core::AdkError::artifact(format!(
63                "invalid artifact {field}: empty value"
64            )));
65        }
66
67        if component.contains('/')
68            || component.contains('\\')
69            || component == "."
70            || component == ".."
71            || component.contains("..")
72        {
73            return Err(adk_core::AdkError::artifact(format!(
74                "invalid artifact {field} '{component}': path separators and traversal patterns are not allowed"
75            )));
76        }
77
78        Ok(())
79    }
80
81    /// Ensures the given path stays within the configured base directory.
82    fn ensure_within_base_dir(&self, path: &Path) -> Result<()> {
83        let canonical_base = self.base_dir.canonicalize().map_err(|e| {
84            adk_core::AdkError::artifact(format!("canonicalize base dir failed: {e}"))
85        })?;
86
87        // For paths that may not exist yet, resolve relative to canonical base
88        let canonical_path = match path.canonicalize() {
89            Ok(canonical) => canonical,
90            Err(_) => {
91                let relative = path.strip_prefix(&self.base_dir).unwrap_or(path);
92                canonical_base.join(relative)
93            }
94        };
95
96        if !canonical_path.starts_with(&canonical_base) {
97            return Err(adk_core::AdkError::artifact(
98                "artifact path escapes configured base directory",
99            ));
100        }
101
102        Ok(())
103    }
104
105    fn is_user_scoped(file_name: &str) -> bool {
106        file_name.starts_with("user:")
107    }
108
109    /// Build a safe artifact directory path from validated components.
110    ///
111    /// All components must pass `validate_path_component` before calling this.
112    /// The returned path is guaranteed to be under `self.base_dir`.
113    fn safe_artifact_dir(
114        &self,
115        app_name: &str,
116        user_id: &str,
117        session_id: &str,
118        file_name: &str,
119    ) -> Result<PathBuf> {
120        Self::validate_path_component(app_name, "app name")?;
121        Self::validate_path_component(user_id, "user id")?;
122        Self::validate_path_component(session_id, "session id")?;
123        Self::validate_file_name(file_name)?;
124
125        let dir = if Self::is_user_scoped(file_name) {
126            self.base_dir.join(app_name).join(user_id).join(USER_SCOPED_DIR).join(file_name)
127        } else {
128            self.base_dir.join(app_name).join(user_id).join(session_id).join(file_name)
129        };
130
131        // Verify the constructed path hasn't escaped base_dir
132        self.ensure_within_base_dir(&dir)?;
133        Ok(dir)
134    }
135
136    /// Build a safe version file path from validated components.
137    fn safe_version_path(
138        &self,
139        app_name: &str,
140        user_id: &str,
141        session_id: &str,
142        file_name: &str,
143        version: i64,
144    ) -> Result<PathBuf> {
145        let dir = self.safe_artifact_dir(app_name, user_id, session_id, file_name)?;
146        let path = dir.join(format!("v{version}.json"));
147        Ok(path)
148    }
149
150    async fn read_versions(
151        &self,
152        app_name: &str,
153        user_id: &str,
154        session_id: &str,
155        file_name: &str,
156    ) -> Result<Vec<i64>> {
157        let dir = self.safe_artifact_dir(app_name, user_id, session_id, file_name)?;
158        let mut entries = match fs::read_dir(&dir).await {
159            Ok(entries) => entries,
160            Err(error) if error.kind() == std::io::ErrorKind::NotFound => {
161                return Err(adk_core::AdkError::artifact("artifact not found"));
162            }
163            Err(error) => {
164                return Err(adk_core::AdkError::artifact(format!("read dir failed: {error}")));
165            }
166        };
167
168        let mut versions = Vec::new();
169        while let Some(entry) = entries
170            .next_entry()
171            .await
172            .map_err(|e| adk_core::AdkError::artifact(format!("read dir entry failed: {e}")))?
173        {
174            let Some(file_name) = entry.file_name().to_str().map(ToString::to_string) else {
175                continue;
176            };
177            let Some(raw) =
178                file_name.strip_prefix('v').and_then(|value| value.strip_suffix(".json"))
179            else {
180                continue;
181            };
182            if let Ok(version) = raw.parse::<i64>() {
183                versions.push(version);
184            }
185        }
186
187        if versions.is_empty() {
188            return Err(adk_core::AdkError::artifact("artifact not found"));
189        }
190
191        versions.sort_by(|left, right| right.cmp(left));
192        Ok(versions)
193    }
194
195    async fn list_scope_dir(path: &Path) -> Result<BTreeSet<String>> {
196        let mut names = BTreeSet::new();
197        let mut entries = match fs::read_dir(path).await {
198            Ok(entries) => entries,
199            Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(names),
200            Err(error) => {
201                return Err(adk_core::AdkError::artifact(format!("read dir failed: {error}")));
202            }
203        };
204
205        while let Some(entry) = entries
206            .next_entry()
207            .await
208            .map_err(|e| adk_core::AdkError::artifact(format!("read dir entry failed: {e}")))?
209        {
210            if entry
211                .file_type()
212                .await
213                .map_err(|e| adk_core::AdkError::artifact(format!("file type check failed: {e}")))?
214                .is_dir()
215            {
216                if let Some(name) = entry.file_name().to_str() {
217                    names.insert(name.to_string());
218                }
219            }
220        }
221
222        Ok(names)
223    }
224}
225
226#[async_trait]
227impl ArtifactService for FileArtifactService {
228    async fn save(&self, req: SaveRequest) -> Result<SaveResponse> {
229        let version = match req.version {
230            Some(version) => version,
231            None => self
232                .read_versions(&req.app_name, &req.user_id, &req.session_id, &req.file_name)
233                .await
234                .map(|versions| versions[0] + 1)
235                .unwrap_or(1),
236        };
237
238        // Validate all components reject traversal patterns
239        Self::validate_path_component(&req.app_name, "app name")?;
240        Self::validate_path_component(&req.user_id, "user id")?;
241        Self::validate_path_component(&req.session_id, "session id")?;
242        Self::validate_file_name(&req.file_name)?;
243
244        // base_dir is already canonical from construction
245        let canonical_base = &self.base_dir;
246
247        // Build path from canonical base + validated segments (no user data in base)
248        let canonical_dir = if Self::is_user_scoped(&req.file_name) {
249            canonical_base
250                .join(&req.app_name)
251                .join(&req.user_id)
252                .join(USER_SCOPED_DIR)
253                .join(&req.file_name)
254        } else {
255            canonical_base
256                .join(&req.app_name)
257                .join(&req.user_id)
258                .join(&req.session_id)
259                .join(&req.file_name)
260        };
261
262        fs::create_dir_all(&canonical_dir)
263            .await
264            .map_err(|e| adk_core::AdkError::artifact(format!("create dir failed: {e}")))?;
265
266        // Final canonicalization check after directory exists
267        let verified_dir = canonical_dir.canonicalize().map_err(|e| {
268            adk_core::AdkError::artifact(format!("canonicalize artifact dir failed: {e}"))
269        })?;
270        if !verified_dir.starts_with(canonical_base) {
271            return Err(adk_core::AdkError::artifact(
272                "artifact path escapes configured base directory",
273            ));
274        }
275
276        let write_path = verified_dir.join(format!("v{version}.json"));
277        let payload = serde_json::to_vec(&req.part)
278            .map_err(|error| adk_core::AdkError::artifact(error.to_string()))?;
279        fs::write(write_path, payload)
280            .await
281            .map_err(|e| adk_core::AdkError::artifact(format!("write failed: {e}")))?;
282
283        Ok(SaveResponse { version })
284    }
285
286    async fn load(&self, req: LoadRequest) -> Result<LoadResponse> {
287        let version = match req.version {
288            Some(version) => version,
289            None => {
290                self.read_versions(&req.app_name, &req.user_id, &req.session_id, &req.file_name)
291                    .await?[0]
292            }
293        };
294
295        let path = self.safe_version_path(
296            &req.app_name,
297            &req.user_id,
298            &req.session_id,
299            &req.file_name,
300            version,
301        )?;
302        let payload = fs::read(&path).await.map_err(|error| {
303            if error.kind() == std::io::ErrorKind::NotFound {
304                adk_core::AdkError::artifact("artifact not found")
305            } else {
306                adk_core::AdkError::artifact(format!("read failed: {error}"))
307            }
308        })?;
309
310        let part = serde_json::from_slice::<Part>(&payload)
311            .map_err(|error| adk_core::AdkError::artifact(error.to_string()))?;
312
313        Ok(LoadResponse { part })
314    }
315
316    async fn delete(&self, req: DeleteRequest) -> Result<()> {
317        if let Some(version) = req.version {
318            let path = self.safe_version_path(
319                &req.app_name,
320                &req.user_id,
321                &req.session_id,
322                &req.file_name,
323                version,
324            )?;
325            match fs::remove_file(path).await {
326                Ok(()) => {}
327                Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
328                Err(error) => {
329                    return Err(adk_core::AdkError::artifact(format!(
330                        "remove file failed: {error}"
331                    )));
332                }
333            }
334        } else {
335            let dir = self.safe_artifact_dir(
336                &req.app_name,
337                &req.user_id,
338                &req.session_id,
339                &req.file_name,
340            )?;
341            match fs::remove_dir_all(dir).await {
342                Ok(()) => {}
343                Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
344                Err(error) => {
345                    return Err(adk_core::AdkError::artifact(format!(
346                        "remove dir failed: {error}"
347                    )));
348                }
349            }
350        }
351
352        Ok(())
353    }
354
355    async fn list(&self, req: ListRequest) -> Result<ListResponse> {
356        Self::validate_path_component(&req.app_name, "app name")?;
357        Self::validate_path_component(&req.user_id, "user id")?;
358        Self::validate_path_component(&req.session_id, "session id")?;
359
360        // Build paths from validated components only
361        let app = req.app_name.clone();
362        let user = req.user_id.clone();
363        let session = req.session_id.clone();
364        let session_dir = self.base_dir.join(&app).join(&user).join(&session);
365        let user_dir = self.base_dir.join(&app).join(&user).join(USER_SCOPED_DIR);
366
367        self.ensure_within_base_dir(&session_dir)?;
368        self.ensure_within_base_dir(&user_dir)?;
369
370        let mut names = Self::list_scope_dir(&session_dir).await?;
371        names.extend(Self::list_scope_dir(&user_dir).await?);
372
373        Ok(ListResponse { file_names: names.into_iter().collect() })
374    }
375
376    async fn versions(&self, req: VersionsRequest) -> Result<VersionsResponse> {
377        // Validation happens inside read_versions → safe_artifact_dir
378        let versions = self
379            .read_versions(&req.app_name, &req.user_id, &req.session_id, &req.file_name)
380            .await?;
381        Ok(VersionsResponse { versions })
382    }
383
384    async fn health_check(&self) -> Result<()> {
385        fs::create_dir_all(&self.base_dir)
386            .await
387            .map_err(|e| adk_core::AdkError::artifact(format!("health check failed: {e}")))?;
388        let nonce = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_nanos();
389        let path = self.base_dir.join(format!(".healthcheck-{nonce}"));
390        fs::write(&path, b"ok")
391            .await
392            .map_err(|e| adk_core::AdkError::artifact(format!("health check failed: {e}")))?;
393        fs::remove_file(path)
394            .await
395            .map_err(|e| adk_core::AdkError::artifact(format!("health check failed: {e}")))?;
396        Ok(())
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a save request for app `app` / user `user` with the given session, file name,
    /// text payload, and optional explicit version.
    fn save_request(
        session_id: &str,
        file_name: &str,
        text: &str,
        version: Option<i64>,
    ) -> SaveRequest {
        SaveRequest {
            app_name: "app".into(),
            user_id: "user".into(),
            session_id: session_id.into(),
            file_name: file_name.into(),
            part: Part::Text { text: text.into() },
            version,
        }
    }

    #[tokio::test]
    async fn user_scoped_artifacts_are_visible_across_sessions() {
        let tempdir = tempfile::tempdir().unwrap();
        let service = FileArtifactService::new(tempdir.path()).unwrap();

        service.save(save_request("s1", "user:shared.txt", "hello", None)).await.unwrap();

        let list = service
            .list(ListRequest {
                app_name: "app".into(),
                user_id: "user".into(),
                session_id: "s2".into(),
            })
            .await
            .unwrap();

        assert_eq!(list.file_names, vec!["user:shared.txt".to_string()]);
    }

    #[tokio::test]
    async fn save_without_version_continues_after_latest() {
        let tempdir = tempfile::tempdir().unwrap();
        let service = FileArtifactService::new(tempdir.path()).unwrap();

        let first = service.save(save_request("s1", "notes.txt", "a", None)).await.unwrap();
        assert_eq!(first.version, 1);

        let explicit =
            service.save(save_request("s1", "notes.txt", "b", Some(7))).await.unwrap();
        assert_eq!(explicit.version, 7);

        let next = service.save(save_request("s1", "notes.txt", "c", None)).await.unwrap();
        assert_eq!(next.version, 8);
    }

    #[tokio::test]
    async fn traversal_components_are_rejected() {
        let tempdir = tempfile::tempdir().unwrap();
        let service = FileArtifactService::new(tempdir.path()).unwrap();

        assert!(service.save(save_request("s1", "../escape.txt", "x", None)).await.is_err());
        assert!(service.save(save_request("..", "notes.txt", "x", None)).await.is_err());

        let listed = service
            .list(ListRequest {
                app_name: "app".into(),
                user_id: "../user".into(),
                session_id: "s1".into(),
            })
            .await;
        assert!(listed.is_err());
    }
}