1use crate::service::*;
2use adk_core::{Part, Result};
3use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
/// Value stored in `ArtifactKey::session_id` in place of the caller's session id
/// for artifacts whose file name starts with `user:`.
///
/// Because every session of a user maps such files to this same key, those
/// artifacts are visible from all of that user's sessions.
const USER_SCOPED_KEY: &str = "user";
8
/// Composite map key identifying one stored version of one artifact.
///
/// The derived ordering is lexicographic over the fields in declaration order
/// (app, user, session, file, version), so field order must not be changed.
#[derive(Clone, Debug)]
#[derive(PartialEq, Eq, Hash)]
#[derive(PartialOrd, Ord)]
struct ArtifactKey {
    /// Application the artifact belongs to.
    app_name: String,
    /// Owning user.
    user_id: String,
    /// Session id, or `USER_SCOPED_KEY` for `user:`-prefixed file names.
    session_id: String,
    /// Artifact file name as supplied by the caller.
    file_name: String,
    /// Version number of this stored copy.
    version: i64,
}
17
18pub struct InMemoryArtifactService {
23 artifacts: Arc<RwLock<HashMap<ArtifactKey, Part>>>,
24}
25
26impl InMemoryArtifactService {
27 pub fn new() -> Self {
29 Self { artifacts: Arc::new(RwLock::new(HashMap::new())) }
30 }
31
32 fn is_user_scoped(file_name: &str) -> bool {
33 file_name.starts_with("user:")
34 }
35
36 fn get_session_id(session_id: &str, file_name: &str) -> String {
37 if Self::is_user_scoped(file_name) {
38 USER_SCOPED_KEY.to_string()
39 } else {
40 session_id.to_string()
41 }
42 }
43
44 fn validate_file_name(file_name: &str) -> Result<()> {
45 if file_name.is_empty() {
46 return Err(adk_core::AdkError::artifact("invalid artifact file name: empty name"));
47 }
48
49 if file_name.contains('/')
51 || file_name.contains('\\')
52 || file_name == "."
53 || file_name == ".."
54 || file_name.contains("..")
55 {
56 return Err(adk_core::AdkError::artifact(format!(
57 "invalid artifact file name '{}': path separators and traversal patterns are not allowed",
58 file_name
59 )));
60 }
61
62 Ok(())
63 }
64
65 fn find_latest_version(
66 &self,
67 app_name: &str,
68 user_id: &str,
69 session_id: &str,
70 file_name: &str,
71 ) -> Option<(i64, Part)> {
72 let artifacts = self.artifacts.read().unwrap();
73 let mut versions: Vec<_> = artifacts
74 .iter()
75 .filter(|(k, _)| {
76 k.app_name == app_name
77 && k.user_id == user_id
78 && k.session_id == session_id
79 && k.file_name == file_name
80 })
81 .collect();
82
83 versions.sort_by_key(|b| std::cmp::Reverse(b.0.version));
84 versions.first().map(|(k, v)| (k.version, (*v).clone()))
85 }
86}
87
88impl Default for InMemoryArtifactService {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94#[async_trait]
95impl ArtifactService for InMemoryArtifactService {
96 async fn save(&self, req: SaveRequest) -> Result<SaveResponse> {
97 Self::validate_file_name(&req.file_name)?;
98 let session_id = Self::get_session_id(&req.session_id, &req.file_name);
99
100 let version = if let Some(v) = req.version {
101 v
102 } else {
103 let latest =
104 self.find_latest_version(&req.app_name, &req.user_id, &session_id, &req.file_name);
105 latest.map(|(v, _)| v + 1).unwrap_or(1)
106 };
107
108 let key = ArtifactKey {
109 app_name: req.app_name,
110 user_id: req.user_id,
111 session_id,
112 file_name: req.file_name,
113 version,
114 };
115
116 let mut artifacts = self.artifacts.write().unwrap();
117 artifacts.insert(key, req.part);
118
119 Ok(SaveResponse { version })
120 }
121
122 async fn load(&self, req: LoadRequest) -> Result<LoadResponse> {
123 Self::validate_file_name(&req.file_name)?;
124 let session_id = Self::get_session_id(&req.session_id, &req.file_name);
125
126 if let Some(version) = req.version {
127 let key = ArtifactKey {
128 app_name: req.app_name,
129 user_id: req.user_id,
130 session_id,
131 file_name: req.file_name,
132 version,
133 };
134
135 let artifacts = self.artifacts.read().unwrap();
136 let part = artifacts
137 .get(&key)
138 .ok_or_else(|| adk_core::AdkError::artifact("artifact not found"))?;
139
140 Ok(LoadResponse { part: part.clone() })
141 } else {
142 let (_, part) = self
143 .find_latest_version(&req.app_name, &req.user_id, &session_id, &req.file_name)
144 .ok_or_else(|| adk_core::AdkError::artifact("artifact not found"))?;
145
146 Ok(LoadResponse { part })
147 }
148 }
149
150 async fn delete(&self, req: DeleteRequest) -> Result<()> {
151 Self::validate_file_name(&req.file_name)?;
152 let session_id = Self::get_session_id(&req.session_id, &req.file_name);
153
154 let mut artifacts = self.artifacts.write().unwrap();
155
156 if let Some(version) = req.version {
157 let key = ArtifactKey {
158 app_name: req.app_name,
159 user_id: req.user_id,
160 session_id,
161 file_name: req.file_name,
162 version,
163 };
164 artifacts.remove(&key);
165 } else {
166 artifacts.retain(|k, _| {
167 !(k.app_name == req.app_name
168 && k.user_id == req.user_id
169 && k.session_id == session_id
170 && k.file_name == req.file_name)
171 });
172 }
173
174 Ok(())
175 }
176
177 async fn list(&self, req: ListRequest) -> Result<ListResponse> {
178 let artifacts = self.artifacts.read().unwrap();
179 let mut file_names = std::collections::HashSet::new();
180
181 for key in artifacts.keys() {
182 if key.app_name == req.app_name
183 && key.user_id == req.user_id
184 && (key.session_id == req.session_id || key.session_id == USER_SCOPED_KEY)
185 {
186 file_names.insert(key.file_name.clone());
187 }
188 }
189
190 let mut sorted: Vec<_> = file_names.into_iter().collect();
191 sorted.sort();
192
193 Ok(ListResponse { file_names: sorted })
194 }
195
196 async fn versions(&self, req: VersionsRequest) -> Result<VersionsResponse> {
197 Self::validate_file_name(&req.file_name)?;
198 let session_id = Self::get_session_id(&req.session_id, &req.file_name);
199 let artifacts = self.artifacts.read().unwrap();
200
201 let mut versions: Vec<i64> = artifacts
202 .keys()
203 .filter(|k| {
204 k.app_name == req.app_name
205 && k.user_id == req.user_id
206 && k.session_id == session_id
207 && k.file_name == req.file_name
208 })
209 .map(|k| k.version)
210 .collect();
211
212 if versions.is_empty() {
213 return Err(adk_core::AdkError::artifact("artifact not found"));
214 }
215
216 versions.sort_by(|a, b| b.cmp(a));
217
218 Ok(VersionsResponse { versions })
219 }
220}