Skip to main content

adk_artifact/
file.rs

1use crate::service::*;
2use adk_core::{Part, Result};
3use async_trait::async_trait;
4use std::collections::BTreeSet;
5use std::path::{Path, PathBuf};
6use std::time::{SystemTime, UNIX_EPOCH};
7use tokio::fs;
8
/// Directory name used in place of the session id when building the on-disk
/// path of a user-scoped artifact (a file name starting with `user:`).
const USER_SCOPED_DIR: &str = "_user_scoped_";
10
/// Turn an artifact name into a string usable as a single path component.
///
/// ADK artifact names may contain colons (for example `user:shared.txt`),
/// which Windows forbids in file and directory names, so every `:` is written
/// as `__` on disk.
fn fs_safe_name(name: &str) -> String {
    let pieces: Vec<&str> = name.split(':').collect();
    pieces.join("__")
}
18
/// Map an on-disk component name back to an artifact name by turning every
/// `__` into `:`.
///
/// NOTE(review): a name that already contained a literal `__` before it was
/// sanitized also comes back with `:` in that position, so this is an exact
/// inverse only for names without double underscores.
fn fs_unsafe_name(name: &str) -> String {
    let pieces: Vec<&str> = name.split("__").collect();
    pieces.join(":")
}
23
/// Persist artifacts on the local filesystem.
///
/// The base directory is created and canonicalized at construction time.
/// The only state is that path, so the service derives `Debug` for
/// diagnostics and `Clone` for cheap duplication.
#[derive(Debug, Clone)]
pub struct FileArtifactService {
    /// Canonical (absolute, resolved) base directory. Set once at construction.
    base_dir: PathBuf,
}
31
32impl FileArtifactService {
33    /// Create a new filesystem-backed artifact service rooted at `base_dir`.
34    ///
35    /// Creates the directory if it doesn't exist and stores the canonical path.
36    ///
37    /// # Errors
38    ///
39    /// Returns an error if the directory cannot be created or canonicalized.
40    pub fn new(base_dir: impl Into<PathBuf>) -> Result<Self> {
41        let raw = base_dir.into();
42        std::fs::create_dir_all(&raw)
43            .map_err(|e| adk_core::AdkError::artifact(format!("create base dir failed: {e}")))?;
44        let canonical = raw.canonicalize().map_err(|e| {
45            adk_core::AdkError::artifact(format!("canonicalize base dir failed: {e}"))
46        })?;
47        Ok(Self { base_dir: canonical })
48    }
49
50    fn validate_file_name(file_name: &str) -> Result<()> {
51        if file_name.is_empty() {
52            return Err(adk_core::AdkError::artifact("invalid artifact file name: empty name"));
53        }
54
55        if file_name.contains('/')
56            || file_name.contains('\\')
57            || file_name == "."
58            || file_name == ".."
59            || file_name.contains("..")
60        {
61            return Err(adk_core::AdkError::artifact(format!(
62                "invalid artifact file name '{}': path separators and traversal patterns are not allowed",
63                file_name
64            )));
65        }
66
67        Ok(())
68    }
69
70    /// Validates a path component (app_name, user_id, session_id) used to build artifact paths.
71    ///
72    /// Rejects empty values, directory separators, and traversal patterns.
73    fn validate_path_component(component: &str, field: &str) -> Result<()> {
74        if component.is_empty() {
75            return Err(adk_core::AdkError::artifact(format!(
76                "invalid artifact {field}: empty value"
77            )));
78        }
79
80        if component.contains('/')
81            || component.contains('\\')
82            || component == "."
83            || component == ".."
84            || component.contains("..")
85        {
86            return Err(adk_core::AdkError::artifact(format!(
87                "invalid artifact {field} '{component}': path separators and traversal patterns are not allowed"
88            )));
89        }
90
91        Ok(())
92    }
93
94    /// Ensures the given path stays within the configured base directory.
95    fn ensure_within_base_dir(&self, path: &Path) -> Result<()> {
96        let canonical_base = self.base_dir.canonicalize().map_err(|e| {
97            adk_core::AdkError::artifact(format!("canonicalize base dir failed: {e}"))
98        })?;
99
100        // For paths that may not exist yet, resolve relative to canonical base
101        let canonical_path = match path.canonicalize() {
102            Ok(canonical) => canonical,
103            Err(_) => {
104                let relative = path.strip_prefix(&self.base_dir).unwrap_or(path);
105                canonical_base.join(relative)
106            }
107        };
108
109        if !canonical_path.starts_with(&canonical_base) {
110            return Err(adk_core::AdkError::artifact(
111                "artifact path escapes configured base directory",
112            ));
113        }
114
115        Ok(())
116    }
117
118    fn is_user_scoped(file_name: &str) -> bool {
119        file_name.starts_with("user:")
120    }
121
122    /// Build a safe artifact directory path from validated components.
123    ///
124    /// All components must pass `validate_path_component` before calling this.
125    /// The returned path is guaranteed to be under `self.base_dir`.
126    fn safe_artifact_dir(
127        &self,
128        app_name: &str,
129        user_id: &str,
130        session_id: &str,
131        file_name: &str,
132    ) -> Result<PathBuf> {
133        Self::validate_path_component(app_name, "app name")?;
134        Self::validate_path_component(user_id, "user id")?;
135        Self::validate_path_component(session_id, "session id")?;
136        Self::validate_file_name(file_name)?;
137
138        let safe_name = fs_safe_name(file_name);
139        let dir = if Self::is_user_scoped(file_name) {
140            self.base_dir.join(app_name).join(user_id).join(USER_SCOPED_DIR).join(&safe_name)
141        } else {
142            self.base_dir.join(app_name).join(user_id).join(session_id).join(&safe_name)
143        };
144
145        // Verify the constructed path hasn't escaped base_dir
146        self.ensure_within_base_dir(&dir)?;
147        Ok(dir)
148    }
149
150    /// Build a safe version file path from validated components.
151    fn safe_version_path(
152        &self,
153        app_name: &str,
154        user_id: &str,
155        session_id: &str,
156        file_name: &str,
157        version: i64,
158    ) -> Result<PathBuf> {
159        let dir = self.safe_artifact_dir(app_name, user_id, session_id, file_name)?;
160        let path = dir.join(format!("v{version}.json"));
161        Ok(path)
162    }
163
164    async fn read_versions(
165        &self,
166        app_name: &str,
167        user_id: &str,
168        session_id: &str,
169        file_name: &str,
170    ) -> Result<Vec<i64>> {
171        let dir = self.safe_artifact_dir(app_name, user_id, session_id, file_name)?;
172        let mut entries = match fs::read_dir(&dir).await {
173            Ok(entries) => entries,
174            Err(error) if error.kind() == std::io::ErrorKind::NotFound => {
175                return Err(adk_core::AdkError::artifact("artifact not found"));
176            }
177            Err(error) => {
178                return Err(adk_core::AdkError::artifact(format!("read dir failed: {error}")));
179            }
180        };
181
182        let mut versions = Vec::new();
183        while let Some(entry) = entries
184            .next_entry()
185            .await
186            .map_err(|e| adk_core::AdkError::artifact(format!("read dir entry failed: {e}")))?
187        {
188            let Some(file_name) = entry.file_name().to_str().map(ToString::to_string) else {
189                continue;
190            };
191            let Some(raw) =
192                file_name.strip_prefix('v').and_then(|value| value.strip_suffix(".json"))
193            else {
194                continue;
195            };
196            if let Ok(version) = raw.parse::<i64>() {
197                versions.push(version);
198            }
199        }
200
201        if versions.is_empty() {
202            return Err(adk_core::AdkError::artifact("artifact not found"));
203        }
204
205        versions.sort_by(|left, right| right.cmp(left));
206        Ok(versions)
207    }
208
209    async fn list_scope_dir(path: &Path) -> Result<BTreeSet<String>> {
210        let mut names = BTreeSet::new();
211        let mut entries = match fs::read_dir(path).await {
212            Ok(entries) => entries,
213            Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(names),
214            Err(error) => {
215                return Err(adk_core::AdkError::artifact(format!("read dir failed: {error}")));
216            }
217        };
218
219        while let Some(entry) = entries
220            .next_entry()
221            .await
222            .map_err(|e| adk_core::AdkError::artifact(format!("read dir entry failed: {e}")))?
223        {
224            if entry
225                .file_type()
226                .await
227                .map_err(|e| adk_core::AdkError::artifact(format!("file type check failed: {e}")))?
228                .is_dir()
229                && let Some(name) = entry.file_name().to_str()
230            {
231                names.insert(fs_unsafe_name(name));
232            }
233        }
234
235        Ok(names)
236    }
237}
238
239#[async_trait]
240impl ArtifactService for FileArtifactService {
241    async fn save(&self, req: SaveRequest) -> Result<SaveResponse> {
242        let version = match req.version {
243            Some(version) => version,
244            None => self
245                .read_versions(&req.app_name, &req.user_id, &req.session_id, &req.file_name)
246                .await
247                .map(|versions| versions[0] + 1)
248                .unwrap_or(1),
249        };
250
251        // Validate all components reject traversal patterns
252        Self::validate_path_component(&req.app_name, "app name")?;
253        Self::validate_path_component(&req.user_id, "user id")?;
254        Self::validate_path_component(&req.session_id, "session id")?;
255        Self::validate_file_name(&req.file_name)?;
256
257        // base_dir is already canonical from construction
258        let canonical_base = &self.base_dir;
259
260        // Build path from canonical base + validated segments (no user data in base)
261        let safe_name = fs_safe_name(&req.file_name);
262        let canonical_dir = if Self::is_user_scoped(&req.file_name) {
263            canonical_base
264                .join(&req.app_name)
265                .join(&req.user_id)
266                .join(USER_SCOPED_DIR)
267                .join(&safe_name)
268        } else {
269            canonical_base
270                .join(&req.app_name)
271                .join(&req.user_id)
272                .join(&req.session_id)
273                .join(&safe_name)
274        };
275
276        fs::create_dir_all(&canonical_dir)
277            .await
278            .map_err(|e| adk_core::AdkError::artifact(format!("create dir failed: {e}")))?;
279
280        // Final canonicalization check after directory exists
281        let verified_dir = canonical_dir.canonicalize().map_err(|e| {
282            adk_core::AdkError::artifact(format!("canonicalize artifact dir failed: {e}"))
283        })?;
284        if !verified_dir.starts_with(canonical_base) {
285            return Err(adk_core::AdkError::artifact(
286                "artifact path escapes configured base directory",
287            ));
288        }
289
290        let write_path = verified_dir.join(format!("v{version}.json"));
291        let payload = serde_json::to_vec(&req.part)
292            .map_err(|error| adk_core::AdkError::artifact(error.to_string()))?;
293        fs::write(write_path, payload)
294            .await
295            .map_err(|e| adk_core::AdkError::artifact(format!("write failed: {e}")))?;
296
297        Ok(SaveResponse { version })
298    }
299
300    async fn load(&self, req: LoadRequest) -> Result<LoadResponse> {
301        let version = match req.version {
302            Some(version) => version,
303            None => {
304                self.read_versions(&req.app_name, &req.user_id, &req.session_id, &req.file_name)
305                    .await?[0]
306            }
307        };
308
309        let path = self.safe_version_path(
310            &req.app_name,
311            &req.user_id,
312            &req.session_id,
313            &req.file_name,
314            version,
315        )?;
316        let payload = fs::read(&path).await.map_err(|error| {
317            if error.kind() == std::io::ErrorKind::NotFound {
318                adk_core::AdkError::artifact("artifact not found")
319            } else {
320                adk_core::AdkError::artifact(format!("read failed: {error}"))
321            }
322        })?;
323
324        let part = serde_json::from_slice::<Part>(&payload)
325            .map_err(|error| adk_core::AdkError::artifact(error.to_string()))?;
326
327        Ok(LoadResponse { part })
328    }
329
330    async fn delete(&self, req: DeleteRequest) -> Result<()> {
331        if let Some(version) = req.version {
332            let path = self.safe_version_path(
333                &req.app_name,
334                &req.user_id,
335                &req.session_id,
336                &req.file_name,
337                version,
338            )?;
339            match fs::remove_file(path).await {
340                Ok(()) => {}
341                Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
342                Err(error) => {
343                    return Err(adk_core::AdkError::artifact(format!(
344                        "remove file failed: {error}"
345                    )));
346                }
347            }
348        } else {
349            let dir = self.safe_artifact_dir(
350                &req.app_name,
351                &req.user_id,
352                &req.session_id,
353                &req.file_name,
354            )?;
355            match fs::remove_dir_all(dir).await {
356                Ok(()) => {}
357                Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
358                Err(error) => {
359                    return Err(adk_core::AdkError::artifact(format!(
360                        "remove dir failed: {error}"
361                    )));
362                }
363            }
364        }
365
366        Ok(())
367    }
368
369    async fn list(&self, req: ListRequest) -> Result<ListResponse> {
370        Self::validate_path_component(&req.app_name, "app name")?;
371        Self::validate_path_component(&req.user_id, "user id")?;
372        Self::validate_path_component(&req.session_id, "session id")?;
373
374        // Build paths from validated components only
375        let app = req.app_name.clone();
376        let user = req.user_id.clone();
377        let session = req.session_id.clone();
378        let session_dir = self.base_dir.join(&app).join(&user).join(&session);
379        let user_dir = self.base_dir.join(&app).join(&user).join(USER_SCOPED_DIR);
380
381        self.ensure_within_base_dir(&session_dir)?;
382        self.ensure_within_base_dir(&user_dir)?;
383
384        let mut names = Self::list_scope_dir(&session_dir).await?;
385        names.extend(Self::list_scope_dir(&user_dir).await?);
386
387        Ok(ListResponse { file_names: names.into_iter().collect() })
388    }
389
390    async fn versions(&self, req: VersionsRequest) -> Result<VersionsResponse> {
391        // Validation happens inside read_versions → safe_artifact_dir
392        let versions = self
393            .read_versions(&req.app_name, &req.user_id, &req.session_id, &req.file_name)
394            .await?;
395        Ok(VersionsResponse { versions })
396    }
397
398    async fn health_check(&self) -> Result<()> {
399        fs::create_dir_all(&self.base_dir)
400            .await
401            .map_err(|e| adk_core::AdkError::artifact(format!("health check failed: {e}")))?;
402        let nonce = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_nanos();
403        let path = self.base_dir.join(format!(".healthcheck-{nonce}"));
404        fs::write(&path, b"ok")
405            .await
406            .map_err(|e| adk_core::AdkError::artifact(format!("health check failed: {e}")))?;
407        fs::remove_file(path)
408            .await
409            .map_err(|e| adk_core::AdkError::artifact(format!("health check failed: {e}")))?;
410        Ok(())
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    /// An artifact saved with a `user:` name in session `s1` must be listed when
    /// a different session (`s2`) of the same app and user is queried.
    #[tokio::test]
    async fn user_scoped_artifacts_are_visible_across_sessions() {
        let root = tempfile::tempdir().unwrap();
        let artifacts = FileArtifactService::new(root.path()).unwrap();

        let save_request = SaveRequest {
            app_name: "app".into(),
            user_id: "user".into(),
            session_id: "s1".into(),
            file_name: "user:shared.txt".into(),
            part: Part::Text { text: "hello".into() },
            version: None,
        };
        artifacts.save(save_request).await.unwrap();

        let list_request = ListRequest {
            app_name: "app".into(),
            user_id: "user".into(),
            session_id: "s2".into(),
        };
        let listed = artifacts.list(list_request).await.unwrap();

        assert_eq!(listed.file_names, vec!["user:shared.txt".to_string()]);
    }
}