Skip to main content

adk_artifact/
file.rs

1use crate::service::*;
2use adk_core::{Part, Result};
3use async_trait::async_trait;
4use std::collections::BTreeSet;
5use std::path::{Path, PathBuf};
6use std::time::{SystemTime, UNIX_EPOCH};
7use tokio::fs;
8
/// Name of the per-user directory, a sibling of the session directories under
/// `<base_dir>/<app_name>/<user_id>/`, in which `user:`-prefixed artifacts are stored.
const USER_SCOPED_DIR: &str = "_user_scoped_";

/// Persist artifacts on the local filesystem.
///
/// All persisted state lives under `base_dir`; the service value itself only holds that
/// root path, so cloning it is cheap and every clone operates on the same on-disk store.
#[derive(Debug, Clone)]
pub struct FileArtifactService {
    /// Root directory under which every app, user, and session directory is created.
    base_dir: PathBuf,
}
15
16impl FileArtifactService {
17    /// Create a new filesystem-backed artifact service rooted at `base_dir`.
18    pub fn new(base_dir: impl Into<PathBuf>) -> Self {
19        Self { base_dir: base_dir.into() }
20    }
21
22    fn validate_file_name(file_name: &str) -> Result<()> {
23        if file_name.is_empty() {
24            return Err(adk_core::AdkError::Artifact(
25                "invalid artifact file name: empty name".to_string(),
26            ));
27        }
28
29        if file_name.contains('/')
30            || file_name.contains('\\')
31            || file_name == "."
32            || file_name == ".."
33            || file_name.contains("..")
34        {
35            return Err(adk_core::AdkError::Artifact(format!(
36                "invalid artifact file name '{}': path separators and traversal patterns are not allowed",
37                file_name
38            )));
39        }
40
41        Ok(())
42    }
43
44    fn is_user_scoped(file_name: &str) -> bool {
45        file_name.starts_with("user:")
46    }
47
48    fn artifact_dir(
49        &self,
50        app_name: &str,
51        user_id: &str,
52        session_id: &str,
53        file_name: &str,
54    ) -> PathBuf {
55        if Self::is_user_scoped(file_name) {
56            self.base_dir.join(app_name).join(user_id).join(USER_SCOPED_DIR).join(file_name)
57        } else {
58            self.base_dir.join(app_name).join(user_id).join(session_id).join(file_name)
59        }
60    }
61
62    fn version_path(
63        &self,
64        app_name: &str,
65        user_id: &str,
66        session_id: &str,
67        file_name: &str,
68        version: i64,
69    ) -> PathBuf {
70        self.artifact_dir(app_name, user_id, session_id, file_name).join(format!("v{version}.json"))
71    }
72
73    async fn read_versions(
74        &self,
75        app_name: &str,
76        user_id: &str,
77        session_id: &str,
78        file_name: &str,
79    ) -> Result<Vec<i64>> {
80        let dir = self.artifact_dir(app_name, user_id, session_id, file_name);
81        let mut entries = match fs::read_dir(&dir).await {
82            Ok(entries) => entries,
83            Err(error) if error.kind() == std::io::ErrorKind::NotFound => {
84                return Err(adk_core::AdkError::Artifact("artifact not found".into()));
85            }
86            Err(error) => return Err(error.into()),
87        };
88
89        let mut versions = Vec::new();
90        while let Some(entry) = entries.next_entry().await? {
91            let Some(file_name) = entry.file_name().to_str().map(ToString::to_string) else {
92                continue;
93            };
94            let Some(raw) =
95                file_name.strip_prefix('v').and_then(|value| value.strip_suffix(".json"))
96            else {
97                continue;
98            };
99            if let Ok(version) = raw.parse::<i64>() {
100                versions.push(version);
101            }
102        }
103
104        if versions.is_empty() {
105            return Err(adk_core::AdkError::Artifact("artifact not found".into()));
106        }
107
108        versions.sort_by(|left, right| right.cmp(left));
109        Ok(versions)
110    }
111
112    async fn list_scope_dir(path: &Path) -> Result<BTreeSet<String>> {
113        let mut names = BTreeSet::new();
114        let mut entries = match fs::read_dir(path).await {
115            Ok(entries) => entries,
116            Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(names),
117            Err(error) => return Err(error.into()),
118        };
119
120        while let Some(entry) = entries.next_entry().await? {
121            if entry.file_type().await?.is_dir() {
122                if let Some(name) = entry.file_name().to_str() {
123                    names.insert(name.to_string());
124                }
125            }
126        }
127
128        Ok(names)
129    }
130}
131
132#[async_trait]
133impl ArtifactService for FileArtifactService {
134    async fn save(&self, req: SaveRequest) -> Result<SaveResponse> {
135        Self::validate_file_name(&req.file_name)?;
136
137        let version = match req.version {
138            Some(version) => version,
139            None => self
140                .read_versions(&req.app_name, &req.user_id, &req.session_id, &req.file_name)
141                .await
142                .map(|versions| versions[0] + 1)
143                .unwrap_or(1),
144        };
145
146        let dir = self.artifact_dir(&req.app_name, &req.user_id, &req.session_id, &req.file_name);
147        fs::create_dir_all(&dir).await?;
148        let path = self.version_path(
149            &req.app_name,
150            &req.user_id,
151            &req.session_id,
152            &req.file_name,
153            version,
154        );
155        let payload = serde_json::to_vec(&req.part)
156            .map_err(|error| adk_core::AdkError::Artifact(error.to_string()))?;
157        fs::write(path, payload).await?;
158
159        Ok(SaveResponse { version })
160    }
161
162    async fn load(&self, req: LoadRequest) -> Result<LoadResponse> {
163        Self::validate_file_name(&req.file_name)?;
164
165        let version = match req.version {
166            Some(version) => version,
167            None => {
168                self.read_versions(&req.app_name, &req.user_id, &req.session_id, &req.file_name)
169                    .await?[0]
170            }
171        };
172
173        let payload = fs::read(self.version_path(
174            &req.app_name,
175            &req.user_id,
176            &req.session_id,
177            &req.file_name,
178            version,
179        ))
180        .await
181        .map_err(|error| {
182            if error.kind() == std::io::ErrorKind::NotFound {
183                adk_core::AdkError::Artifact("artifact not found".into())
184            } else {
185                error.into()
186            }
187        })?;
188
189        let part = serde_json::from_slice::<Part>(&payload)
190            .map_err(|error| adk_core::AdkError::Artifact(error.to_string()))?;
191
192        Ok(LoadResponse { part })
193    }
194
195    async fn delete(&self, req: DeleteRequest) -> Result<()> {
196        Self::validate_file_name(&req.file_name)?;
197
198        if let Some(version) = req.version {
199            let path = self.version_path(
200                &req.app_name,
201                &req.user_id,
202                &req.session_id,
203                &req.file_name,
204                version,
205            );
206            match fs::remove_file(path).await {
207                Ok(()) => {}
208                Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
209                Err(error) => return Err(error.into()),
210            }
211        } else {
212            let dir =
213                self.artifact_dir(&req.app_name, &req.user_id, &req.session_id, &req.file_name);
214            match fs::remove_dir_all(dir).await {
215                Ok(()) => {}
216                Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
217                Err(error) => return Err(error.into()),
218            }
219        }
220
221        Ok(())
222    }
223
224    async fn list(&self, req: ListRequest) -> Result<ListResponse> {
225        let session_dir =
226            self.base_dir.join(&req.app_name).join(&req.user_id).join(&req.session_id);
227        let user_dir = self.base_dir.join(&req.app_name).join(&req.user_id).join(USER_SCOPED_DIR);
228
229        let mut names = Self::list_scope_dir(&session_dir).await?;
230        names.extend(Self::list_scope_dir(&user_dir).await?);
231
232        Ok(ListResponse { file_names: names.into_iter().collect() })
233    }
234
235    async fn versions(&self, req: VersionsRequest) -> Result<VersionsResponse> {
236        Self::validate_file_name(&req.file_name)?;
237        let versions = self
238            .read_versions(&req.app_name, &req.user_id, &req.session_id, &req.file_name)
239            .await?;
240        Ok(VersionsResponse { versions })
241    }
242
243    async fn health_check(&self) -> Result<()> {
244        fs::create_dir_all(&self.base_dir).await?;
245        let nonce = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_nanos();
246        let path = self.base_dir.join(format!(".healthcheck-{nonce}"));
247        fs::write(&path, b"ok").await?;
248        fs::remove_file(path).await?;
249        Ok(())
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a save request for `app`/`user` with an auto-assigned version.
    fn text_save(session_id: &str, file_name: &str, text: &str) -> SaveRequest {
        SaveRequest {
            app_name: "app".into(),
            user_id: "user".into(),
            session_id: session_id.into(),
            file_name: file_name.into(),
            part: Part::Text { text: text.into() },
            version: None,
        }
    }

    /// Build a list request for `app`/`user` in the given session.
    fn list_in(session_id: &str) -> ListRequest {
        ListRequest { app_name: "app".into(), user_id: "user".into(), session_id: session_id.into() }
    }

    #[tokio::test]
    async fn user_scoped_artifacts_are_visible_across_sessions() {
        let tempdir = tempfile::tempdir().unwrap();
        let service = FileArtifactService::new(tempdir.path());

        service.save(text_save("s1", "user:shared.txt", "hello")).await.unwrap();

        let list = service.list(list_in("s2")).await.unwrap();
        assert_eq!(list.file_names, vec!["user:shared.txt".to_string()]);
    }

    #[tokio::test]
    async fn session_scoped_artifacts_stay_in_their_session() {
        let tempdir = tempfile::tempdir().unwrap();
        let service = FileArtifactService::new(tempdir.path());

        service.save(text_save("s1", "notes.txt", "hello")).await.unwrap();

        let own = service.list(list_in("s1")).await.unwrap();
        assert_eq!(own.file_names, vec!["notes.txt".to_string()]);
        let other = service.list(list_in("s2")).await.unwrap();
        assert!(other.file_names.is_empty());
    }

    #[tokio::test]
    async fn save_without_version_increments_from_latest() {
        let tempdir = tempfile::tempdir().unwrap();
        let service = FileArtifactService::new(tempdir.path());

        let first = service.save(text_save("s1", "draft.txt", "one")).await.unwrap();
        let second = service.save(text_save("s1", "draft.txt", "two")).await.unwrap();

        assert_eq!(first.version, 1);
        assert_eq!(second.version, 2);
    }

    #[tokio::test]
    async fn save_rejects_unsafe_file_names() {
        let tempdir = tempfile::tempdir().unwrap();
        let service = FileArtifactService::new(tempdir.path());

        for name in ["", ".", "..", "../escape", "a/b", "a\\b", "user:.."] {
            let outcome = service.save(text_save("s1", name, "x")).await;
            assert!(outcome.is_err(), "{name:?} should be rejected");
        }
    }
}