Skip to main content

adk_artifact/
file.rs

1use crate::service::*;
2use adk_core::{Part, Result};
3use async_trait::async_trait;
4use std::collections::BTreeSet;
5use std::path::{Path, PathBuf};
6use std::time::{SystemTime, UNIX_EPOCH};
7use tokio::fs;
8
/// Directory name (under `<app>/<user>/`) holding artifacts shared by all sessions of a user.
const USER_SCOPED_DIR: &str = "_user_scoped_";

/// Encode an artifact name so it can be used as a single filesystem path component.
///
/// ADK artifact names may contain `:` (for example `user:shared.txt`), which is not
/// allowed in Windows file or directory names, so every `:` is rewritten to `__`.
///
/// NOTE(review): this mapping is not injective — a name that already contains `__`
/// (e.g. `a__b`) encodes to the same component as `a:b`, so the two share storage.
fn fs_safe_name(name: &str) -> String {
    let pieces: Vec<&str> = name.split(':').collect();
    pieces.join("__")
}

/// Decode a path component produced by [`fs_safe_name`] back into an artifact name.
///
/// Every non-overlapping `__` (scanned left to right) becomes `:`. Because of the
/// ambiguity noted on [`fs_safe_name`], names that originally contained `__` do not
/// round-trip (e.g. `a___b` decodes to `a:_b`).
fn fs_unsafe_name(name: &str) -> String {
    let pieces: Vec<&str> = name.split("__").collect();
    pieces.join(":")
}
23
/// Persist artifacts on the local filesystem.
///
/// The base directory is created and canonicalized at construction time, and
/// artifact paths built by this service are checked to stay inside it.
#[derive(Debug)]
pub struct FileArtifactService {
    /// Canonical (absolute, resolved) base directory. Set once at construction.
    base_dir: PathBuf,
}
31
32impl FileArtifactService {
33    /// Create a new filesystem-backed artifact service rooted at `base_dir`.
34    ///
35    /// Creates the directory if it doesn't exist and stores the canonical path.
36    ///
37    /// # Errors
38    ///
39    /// Returns an error if the directory cannot be created or canonicalized.
40    pub fn new(base_dir: impl Into<PathBuf>) -> Result<Self> {
41        let raw = base_dir.into();
42        std::fs::create_dir_all(&raw)
43            .map_err(|e| adk_core::AdkError::artifact(format!("create base dir failed: {e}")))?;
44        let canonical = raw.canonicalize().map_err(|e| {
45            adk_core::AdkError::artifact(format!("canonicalize base dir failed: {e}"))
46        })?;
47        Ok(Self { base_dir: canonical })
48    }
49
50    fn validate_file_name(file_name: &str) -> Result<()> {
51        if file_name.is_empty() {
52            return Err(adk_core::AdkError::artifact("invalid artifact file name: empty name"));
53        }
54
55        if file_name.contains('/')
56            || file_name.contains('\\')
57            || file_name == "."
58            || file_name == ".."
59            || file_name.contains("..")
60        {
61            return Err(adk_core::AdkError::artifact(format!(
62                "invalid artifact file name '{}': path separators and traversal patterns are not allowed",
63                file_name
64            )));
65        }
66
67        Ok(())
68    }
69
70    /// Validates a path component (app_name, user_id, session_id) used to build artifact paths.
71    ///
72    /// Rejects empty values, directory separators, and traversal patterns.
73    fn validate_path_component(component: &str, field: &str) -> Result<()> {
74        if component.is_empty() {
75            return Err(adk_core::AdkError::artifact(format!(
76                "invalid artifact {field}: empty value"
77            )));
78        }
79
80        if component.contains('/')
81            || component.contains('\\')
82            || component == "."
83            || component == ".."
84            || component.contains("..")
85        {
86            return Err(adk_core::AdkError::artifact(format!(
87                "invalid artifact {field} '{component}': path separators and traversal patterns are not allowed"
88            )));
89        }
90
91        Ok(())
92    }
93
94    /// Ensures the given path stays within the configured base directory.
95    fn ensure_within_base_dir(&self, path: &Path) -> Result<()> {
96        let canonical_base = self.base_dir.canonicalize().map_err(|e| {
97            adk_core::AdkError::artifact(format!("canonicalize base dir failed: {e}"))
98        })?;
99
100        // For paths that may not exist yet, resolve relative to canonical base
101        let canonical_path = match path.canonicalize() {
102            Ok(canonical) => canonical,
103            Err(_) => {
104                let relative = path.strip_prefix(&self.base_dir).unwrap_or(path);
105                canonical_base.join(relative)
106            }
107        };
108
109        if !canonical_path.starts_with(&canonical_base) {
110            return Err(adk_core::AdkError::artifact(
111                "artifact path escapes configured base directory",
112            ));
113        }
114
115        Ok(())
116    }
117
118    fn is_user_scoped(file_name: &str) -> bool {
119        file_name.starts_with("user:")
120    }
121
122    /// Build a safe artifact directory path from validated components.
123    ///
124    /// All components must pass `validate_path_component` before calling this.
125    /// The returned path is guaranteed to be under `self.base_dir`.
126    fn safe_artifact_dir(
127        &self,
128        app_name: &str,
129        user_id: &str,
130        session_id: &str,
131        file_name: &str,
132    ) -> Result<PathBuf> {
133        Self::validate_path_component(app_name, "app name")?;
134        Self::validate_path_component(user_id, "user id")?;
135        Self::validate_path_component(session_id, "session id")?;
136        Self::validate_file_name(file_name)?;
137
138        let safe_name = fs_safe_name(file_name);
139        let dir = if Self::is_user_scoped(file_name) {
140            self.base_dir.join(app_name).join(user_id).join(USER_SCOPED_DIR).join(&safe_name)
141        } else {
142            self.base_dir.join(app_name).join(user_id).join(session_id).join(&safe_name)
143        };
144
145        // Verify the constructed path hasn't escaped base_dir
146        self.ensure_within_base_dir(&dir)?;
147        Ok(dir)
148    }
149
150    /// Build a safe version file path from validated components.
151    fn safe_version_path(
152        &self,
153        app_name: &str,
154        user_id: &str,
155        session_id: &str,
156        file_name: &str,
157        version: i64,
158    ) -> Result<PathBuf> {
159        let dir = self.safe_artifact_dir(app_name, user_id, session_id, file_name)?;
160        let path = dir.join(format!("v{version}.json"));
161        Ok(path)
162    }
163
164    async fn read_versions(
165        &self,
166        app_name: &str,
167        user_id: &str,
168        session_id: &str,
169        file_name: &str,
170    ) -> Result<Vec<i64>> {
171        let dir = self.safe_artifact_dir(app_name, user_id, session_id, file_name)?;
172        let mut entries = match fs::read_dir(&dir).await {
173            Ok(entries) => entries,
174            Err(error) if error.kind() == std::io::ErrorKind::NotFound => {
175                return Err(adk_core::AdkError::artifact("artifact not found"));
176            }
177            Err(error) => {
178                return Err(adk_core::AdkError::artifact(format!("read dir failed: {error}")));
179            }
180        };
181
182        let mut versions = Vec::new();
183        while let Some(entry) = entries
184            .next_entry()
185            .await
186            .map_err(|e| adk_core::AdkError::artifact(format!("read dir entry failed: {e}")))?
187        {
188            let Some(file_name) = entry.file_name().to_str().map(ToString::to_string) else {
189                continue;
190            };
191            let Some(raw) =
192                file_name.strip_prefix('v').and_then(|value| value.strip_suffix(".json"))
193            else {
194                continue;
195            };
196            if let Ok(version) = raw.parse::<i64>() {
197                versions.push(version);
198            }
199        }
200
201        if versions.is_empty() {
202            return Err(adk_core::AdkError::artifact("artifact not found"));
203        }
204
205        versions.sort_by(|left, right| right.cmp(left));
206        Ok(versions)
207    }
208
209    async fn list_scope_dir(path: &Path) -> Result<BTreeSet<String>> {
210        let mut names = BTreeSet::new();
211        let mut entries = match fs::read_dir(path).await {
212            Ok(entries) => entries,
213            Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(names),
214            Err(error) => {
215                return Err(adk_core::AdkError::artifact(format!("read dir failed: {error}")));
216            }
217        };
218
219        while let Some(entry) = entries
220            .next_entry()
221            .await
222            .map_err(|e| adk_core::AdkError::artifact(format!("read dir entry failed: {e}")))?
223        {
224            if entry
225                .file_type()
226                .await
227                .map_err(|e| adk_core::AdkError::artifact(format!("file type check failed: {e}")))?
228                .is_dir()
229            {
230                if let Some(name) = entry.file_name().to_str() {
231                    names.insert(fs_unsafe_name(name));
232                }
233            }
234        }
235
236        Ok(names)
237    }
238}
239
240#[async_trait]
241impl ArtifactService for FileArtifactService {
242    async fn save(&self, req: SaveRequest) -> Result<SaveResponse> {
243        let version = match req.version {
244            Some(version) => version,
245            None => self
246                .read_versions(&req.app_name, &req.user_id, &req.session_id, &req.file_name)
247                .await
248                .map(|versions| versions[0] + 1)
249                .unwrap_or(1),
250        };
251
252        // Validate all components reject traversal patterns
253        Self::validate_path_component(&req.app_name, "app name")?;
254        Self::validate_path_component(&req.user_id, "user id")?;
255        Self::validate_path_component(&req.session_id, "session id")?;
256        Self::validate_file_name(&req.file_name)?;
257
258        // base_dir is already canonical from construction
259        let canonical_base = &self.base_dir;
260
261        // Build path from canonical base + validated segments (no user data in base)
262        let safe_name = fs_safe_name(&req.file_name);
263        let canonical_dir = if Self::is_user_scoped(&req.file_name) {
264            canonical_base
265                .join(&req.app_name)
266                .join(&req.user_id)
267                .join(USER_SCOPED_DIR)
268                .join(&safe_name)
269        } else {
270            canonical_base
271                .join(&req.app_name)
272                .join(&req.user_id)
273                .join(&req.session_id)
274                .join(&safe_name)
275        };
276
277        fs::create_dir_all(&canonical_dir)
278            .await
279            .map_err(|e| adk_core::AdkError::artifact(format!("create dir failed: {e}")))?;
280
281        // Final canonicalization check after directory exists
282        let verified_dir = canonical_dir.canonicalize().map_err(|e| {
283            adk_core::AdkError::artifact(format!("canonicalize artifact dir failed: {e}"))
284        })?;
285        if !verified_dir.starts_with(canonical_base) {
286            return Err(adk_core::AdkError::artifact(
287                "artifact path escapes configured base directory",
288            ));
289        }
290
291        let write_path = verified_dir.join(format!("v{version}.json"));
292        let payload = serde_json::to_vec(&req.part)
293            .map_err(|error| adk_core::AdkError::artifact(error.to_string()))?;
294        fs::write(write_path, payload)
295            .await
296            .map_err(|e| adk_core::AdkError::artifact(format!("write failed: {e}")))?;
297
298        Ok(SaveResponse { version })
299    }
300
301    async fn load(&self, req: LoadRequest) -> Result<LoadResponse> {
302        let version = match req.version {
303            Some(version) => version,
304            None => {
305                self.read_versions(&req.app_name, &req.user_id, &req.session_id, &req.file_name)
306                    .await?[0]
307            }
308        };
309
310        let path = self.safe_version_path(
311            &req.app_name,
312            &req.user_id,
313            &req.session_id,
314            &req.file_name,
315            version,
316        )?;
317        let payload = fs::read(&path).await.map_err(|error| {
318            if error.kind() == std::io::ErrorKind::NotFound {
319                adk_core::AdkError::artifact("artifact not found")
320            } else {
321                adk_core::AdkError::artifact(format!("read failed: {error}"))
322            }
323        })?;
324
325        let part = serde_json::from_slice::<Part>(&payload)
326            .map_err(|error| adk_core::AdkError::artifact(error.to_string()))?;
327
328        Ok(LoadResponse { part })
329    }
330
331    async fn delete(&self, req: DeleteRequest) -> Result<()> {
332        if let Some(version) = req.version {
333            let path = self.safe_version_path(
334                &req.app_name,
335                &req.user_id,
336                &req.session_id,
337                &req.file_name,
338                version,
339            )?;
340            match fs::remove_file(path).await {
341                Ok(()) => {}
342                Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
343                Err(error) => {
344                    return Err(adk_core::AdkError::artifact(format!(
345                        "remove file failed: {error}"
346                    )));
347                }
348            }
349        } else {
350            let dir = self.safe_artifact_dir(
351                &req.app_name,
352                &req.user_id,
353                &req.session_id,
354                &req.file_name,
355            )?;
356            match fs::remove_dir_all(dir).await {
357                Ok(()) => {}
358                Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
359                Err(error) => {
360                    return Err(adk_core::AdkError::artifact(format!(
361                        "remove dir failed: {error}"
362                    )));
363                }
364            }
365        }
366
367        Ok(())
368    }
369
370    async fn list(&self, req: ListRequest) -> Result<ListResponse> {
371        Self::validate_path_component(&req.app_name, "app name")?;
372        Self::validate_path_component(&req.user_id, "user id")?;
373        Self::validate_path_component(&req.session_id, "session id")?;
374
375        // Build paths from validated components only
376        let app = req.app_name.clone();
377        let user = req.user_id.clone();
378        let session = req.session_id.clone();
379        let session_dir = self.base_dir.join(&app).join(&user).join(&session);
380        let user_dir = self.base_dir.join(&app).join(&user).join(USER_SCOPED_DIR);
381
382        self.ensure_within_base_dir(&session_dir)?;
383        self.ensure_within_base_dir(&user_dir)?;
384
385        let mut names = Self::list_scope_dir(&session_dir).await?;
386        names.extend(Self::list_scope_dir(&user_dir).await?);
387
388        Ok(ListResponse { file_names: names.into_iter().collect() })
389    }
390
391    async fn versions(&self, req: VersionsRequest) -> Result<VersionsResponse> {
392        // Validation happens inside read_versions → safe_artifact_dir
393        let versions = self
394            .read_versions(&req.app_name, &req.user_id, &req.session_id, &req.file_name)
395            .await?;
396        Ok(VersionsResponse { versions })
397    }
398
399    async fn health_check(&self) -> Result<()> {
400        fs::create_dir_all(&self.base_dir)
401            .await
402            .map_err(|e| adk_core::AdkError::artifact(format!("health check failed: {e}")))?;
403        let nonce = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_nanos();
404        let path = self.base_dir.join(format!(".healthcheck-{nonce}"));
405        fs::write(&path, b"ok")
406            .await
407            .map_err(|e| adk_core::AdkError::artifact(format!("health check failed: {e}")))?;
408        fs::remove_file(path)
409            .await
410            .map_err(|e| adk_core::AdkError::artifact(format!("health check failed: {e}")))?;
411        Ok(())
412    }
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    /// A `user:`-prefixed artifact saved from one session must be listed when a
    /// different session of the same user asks for its artifacts.
    #[tokio::test]
    async fn user_scoped_artifacts_are_visible_across_sessions() {
        let root = tempfile::tempdir().unwrap();
        let store = FileArtifactService::new(root.path()).unwrap();

        let save_req = SaveRequest {
            app_name: "app".into(),
            user_id: "user".into(),
            session_id: "s1".into(),
            file_name: "user:shared.txt".into(),
            part: Part::Text { text: "hello".into() },
            version: None,
        };
        store.save(save_req).await.unwrap();

        let list_req = ListRequest {
            app_name: "app".into(),
            user_id: "user".into(),
            session_id: "s2".into(),
        };
        let listed = store.list(list_req).await.unwrap();

        let expected = vec!["user:shared.txt".to_string()];
        assert_eq!(listed.file_names, expected);
    }
}