Skip to main content

adk_artifact/
inmemory.rs

1use crate::service::*;
2use adk_core::{Part, Result};
3use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
/// Placeholder stored in the `session_id` slot of user-scoped (`user:`-prefixed)
/// artifacts, so that they are shared by every session of the same user.
const USER_SCOPED_KEY: &str = "user";

/// Composite storage key: one map entry exists per stored artifact version.
///
/// For user-scoped artifacts, `session_id` holds [`USER_SCOPED_KEY`] instead of a
/// real session id.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct ArtifactKey {
    /// Application namespace of the artifact.
    app_name: String,
    /// User that owns the artifact.
    user_id: String,
    /// Owning session, or [`USER_SCOPED_KEY`] for user-scoped artifacts.
    session_id: String,
    /// Logical artifact name.
    file_name: String,
    /// Version number of this particular revision.
    version: i64,
}
17
18/// In-memory artifact storage for development and testing.
19///
20/// Artifacts are stored in a `HashMap` behind an `RwLock`. Data is lost
21/// when the process exits. For persistent storage, use [`FileArtifactService`](crate::FileArtifactService).
22pub struct InMemoryArtifactService {
23    artifacts: Arc<RwLock<HashMap<ArtifactKey, Part>>>,
24}
25
26impl InMemoryArtifactService {
27    /// Create a new empty in-memory artifact store.
28    pub fn new() -> Self {
29        Self { artifacts: Arc::new(RwLock::new(HashMap::new())) }
30    }
31
32    fn is_user_scoped(file_name: &str) -> bool {
33        file_name.starts_with("user:")
34    }
35
36    fn get_session_id(session_id: &str, file_name: &str) -> String {
37        if Self::is_user_scoped(file_name) {
38            USER_SCOPED_KEY.to_string()
39        } else {
40            session_id.to_string()
41        }
42    }
43
44    fn validate_file_name(file_name: &str) -> Result<()> {
45        if file_name.is_empty() {
46            return Err(adk_core::AdkError::artifact("invalid artifact file name: empty name"));
47        }
48
49        // Prevent path traversal and path-like names; artifacts are logical keys, not paths.
50        if file_name.contains('/')
51            || file_name.contains('\\')
52            || file_name == "."
53            || file_name == ".."
54            || file_name.contains("..")
55        {
56            return Err(adk_core::AdkError::artifact(format!(
57                "invalid artifact file name '{}': path separators and traversal patterns are not allowed",
58                file_name
59            )));
60        }
61
62        Ok(())
63    }
64
65    fn find_latest_version(
66        &self,
67        app_name: &str,
68        user_id: &str,
69        session_id: &str,
70        file_name: &str,
71    ) -> Option<(i64, Part)> {
72        let artifacts = self.artifacts.read().unwrap();
73        let mut versions: Vec<_> = artifacts
74            .iter()
75            .filter(|(k, _)| {
76                k.app_name == app_name
77                    && k.user_id == user_id
78                    && k.session_id == session_id
79                    && k.file_name == file_name
80            })
81            .collect();
82
83        versions.sort_by_key(|b| std::cmp::Reverse(b.0.version));
84        versions.first().map(|(k, v)| (k.version, (*v).clone()))
85    }
86}
87
88impl Default for InMemoryArtifactService {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94#[async_trait]
95impl ArtifactService for InMemoryArtifactService {
96    async fn save(&self, req: SaveRequest) -> Result<SaveResponse> {
97        Self::validate_file_name(&req.file_name)?;
98        let session_id = Self::get_session_id(&req.session_id, &req.file_name);
99
100        let version = if let Some(v) = req.version {
101            v
102        } else {
103            let latest =
104                self.find_latest_version(&req.app_name, &req.user_id, &session_id, &req.file_name);
105            latest.map(|(v, _)| v + 1).unwrap_or(1)
106        };
107
108        let key = ArtifactKey {
109            app_name: req.app_name,
110            user_id: req.user_id,
111            session_id,
112            file_name: req.file_name,
113            version,
114        };
115
116        let mut artifacts = self.artifacts.write().unwrap();
117        artifacts.insert(key, req.part);
118
119        Ok(SaveResponse { version })
120    }
121
122    async fn load(&self, req: LoadRequest) -> Result<LoadResponse> {
123        Self::validate_file_name(&req.file_name)?;
124        let session_id = Self::get_session_id(&req.session_id, &req.file_name);
125
126        if let Some(version) = req.version {
127            let key = ArtifactKey {
128                app_name: req.app_name,
129                user_id: req.user_id,
130                session_id,
131                file_name: req.file_name,
132                version,
133            };
134
135            let artifacts = self.artifacts.read().unwrap();
136            let part = artifacts
137                .get(&key)
138                .ok_or_else(|| adk_core::AdkError::artifact("artifact not found"))?;
139
140            Ok(LoadResponse { part: part.clone() })
141        } else {
142            let (_, part) = self
143                .find_latest_version(&req.app_name, &req.user_id, &session_id, &req.file_name)
144                .ok_or_else(|| adk_core::AdkError::artifact("artifact not found"))?;
145
146            Ok(LoadResponse { part })
147        }
148    }
149
150    async fn delete(&self, req: DeleteRequest) -> Result<()> {
151        Self::validate_file_name(&req.file_name)?;
152        let session_id = Self::get_session_id(&req.session_id, &req.file_name);
153
154        let mut artifacts = self.artifacts.write().unwrap();
155
156        if let Some(version) = req.version {
157            let key = ArtifactKey {
158                app_name: req.app_name,
159                user_id: req.user_id,
160                session_id,
161                file_name: req.file_name,
162                version,
163            };
164            artifacts.remove(&key);
165        } else {
166            artifacts.retain(|k, _| {
167                !(k.app_name == req.app_name
168                    && k.user_id == req.user_id
169                    && k.session_id == session_id
170                    && k.file_name == req.file_name)
171            });
172        }
173
174        Ok(())
175    }
176
177    async fn list(&self, req: ListRequest) -> Result<ListResponse> {
178        let artifacts = self.artifacts.read().unwrap();
179        let mut file_names = std::collections::HashSet::new();
180
181        for key in artifacts.keys() {
182            if key.app_name == req.app_name
183                && key.user_id == req.user_id
184                && (key.session_id == req.session_id || key.session_id == USER_SCOPED_KEY)
185            {
186                file_names.insert(key.file_name.clone());
187            }
188        }
189
190        let mut sorted: Vec<_> = file_names.into_iter().collect();
191        sorted.sort();
192
193        Ok(ListResponse { file_names: sorted })
194    }
195
196    async fn versions(&self, req: VersionsRequest) -> Result<VersionsResponse> {
197        Self::validate_file_name(&req.file_name)?;
198        let session_id = Self::get_session_id(&req.session_id, &req.file_name);
199        let artifacts = self.artifacts.read().unwrap();
200
201        let mut versions: Vec<i64> = artifacts
202            .keys()
203            .filter(|k| {
204                k.app_name == req.app_name
205                    && k.user_id == req.user_id
206                    && k.session_id == session_id
207                    && k.file_name == req.file_name
208            })
209            .map(|k| k.version)
210            .collect();
211
212        if versions.is_empty() {
213            return Err(adk_core::AdkError::artifact("artifact not found"));
214        }
215
216        versions.sort_by(|a, b| b.cmp(a));
217
218        Ok(VersionsResponse { versions })
219    }
220}