1use crate::service::*;
2use adk_core::{Part, Result};
3use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
/// Sentinel stored in place of a real session id for user-scoped artifacts
/// (file names starting with `user:`), so that they are visible from every
/// session belonging to the same user.
const USER_SCOPED_KEY: &str = "user";
8
/// Composite map key identifying one stored version of one artifact.
///
/// Field order is significant: the derived `Ord` compares fields top to
/// bottom, so keys order by app, then user, then session, then file name,
/// and finally version.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct ArtifactKey {
    /// Application the artifact belongs to.
    app_name: String,
    /// User that owns the artifact.
    user_id: String,
    /// Owning session id, or `USER_SCOPED_KEY` for `user:`-prefixed files.
    session_id: String,
    /// File name exactly as supplied by the caller (prefix included).
    file_name: String,
    /// Version number of this particular copy.
    version: i64,
}
17
18pub struct InMemoryArtifactService {
19 artifacts: Arc<RwLock<HashMap<ArtifactKey, Part>>>,
20}
21
22impl InMemoryArtifactService {
23 pub fn new() -> Self {
24 Self { artifacts: Arc::new(RwLock::new(HashMap::new())) }
25 }
26
27 fn is_user_scoped(file_name: &str) -> bool {
28 file_name.starts_with("user:")
29 }
30
31 fn get_session_id(session_id: &str, file_name: &str) -> String {
32 if Self::is_user_scoped(file_name) {
33 USER_SCOPED_KEY.to_string()
34 } else {
35 session_id.to_string()
36 }
37 }
38
39 fn find_latest_version(
40 &self,
41 app_name: &str,
42 user_id: &str,
43 session_id: &str,
44 file_name: &str,
45 ) -> Option<(i64, Part)> {
46 let artifacts = self.artifacts.read().unwrap();
47 let mut versions: Vec<_> = artifacts
48 .iter()
49 .filter(|(k, _)| {
50 k.app_name == app_name
51 && k.user_id == user_id
52 && k.session_id == session_id
53 && k.file_name == file_name
54 })
55 .collect();
56
57 versions.sort_by(|a, b| b.0.version.cmp(&a.0.version));
58 versions.first().map(|(k, v)| (k.version, (*v).clone()))
59 }
60}
61
62impl Default for InMemoryArtifactService {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68#[async_trait]
69impl ArtifactService for InMemoryArtifactService {
70 async fn save(&self, req: SaveRequest) -> Result<SaveResponse> {
71 let session_id = Self::get_session_id(&req.session_id, &req.file_name);
72
73 let version = if let Some(v) = req.version {
74 v
75 } else {
76 let latest =
77 self.find_latest_version(&req.app_name, &req.user_id, &session_id, &req.file_name);
78 latest.map(|(v, _)| v + 1).unwrap_or(1)
79 };
80
81 let key = ArtifactKey {
82 app_name: req.app_name,
83 user_id: req.user_id,
84 session_id,
85 file_name: req.file_name,
86 version,
87 };
88
89 let mut artifacts = self.artifacts.write().unwrap();
90 artifacts.insert(key, req.part);
91
92 Ok(SaveResponse { version })
93 }
94
95 async fn load(&self, req: LoadRequest) -> Result<LoadResponse> {
96 let session_id = Self::get_session_id(&req.session_id, &req.file_name);
97
98 if let Some(version) = req.version {
99 let key = ArtifactKey {
100 app_name: req.app_name,
101 user_id: req.user_id,
102 session_id,
103 file_name: req.file_name,
104 version,
105 };
106
107 let artifacts = self.artifacts.read().unwrap();
108 let part = artifacts
109 .get(&key)
110 .ok_or_else(|| adk_core::AdkError::Artifact("artifact not found".into()))?;
111
112 Ok(LoadResponse { part: part.clone() })
113 } else {
114 let (_, part) = self
115 .find_latest_version(&req.app_name, &req.user_id, &session_id, &req.file_name)
116 .ok_or_else(|| adk_core::AdkError::Artifact("artifact not found".into()))?;
117
118 Ok(LoadResponse { part })
119 }
120 }
121
122 async fn delete(&self, req: DeleteRequest) -> Result<()> {
123 let session_id = Self::get_session_id(&req.session_id, &req.file_name);
124
125 let mut artifacts = self.artifacts.write().unwrap();
126
127 if let Some(version) = req.version {
128 let key = ArtifactKey {
129 app_name: req.app_name,
130 user_id: req.user_id,
131 session_id,
132 file_name: req.file_name,
133 version,
134 };
135 artifacts.remove(&key);
136 } else {
137 artifacts.retain(|k, _| {
138 !(k.app_name == req.app_name
139 && k.user_id == req.user_id
140 && k.session_id == session_id
141 && k.file_name == req.file_name)
142 });
143 }
144
145 Ok(())
146 }
147
148 async fn list(&self, req: ListRequest) -> Result<ListResponse> {
149 let artifacts = self.artifacts.read().unwrap();
150 let mut file_names = std::collections::HashSet::new();
151
152 for key in artifacts.keys() {
153 if key.app_name == req.app_name
154 && key.user_id == req.user_id
155 && (key.session_id == req.session_id || key.session_id == USER_SCOPED_KEY)
156 {
157 file_names.insert(key.file_name.clone());
158 }
159 }
160
161 let mut sorted: Vec<_> = file_names.into_iter().collect();
162 sorted.sort();
163
164 Ok(ListResponse { file_names: sorted })
165 }
166
167 async fn versions(&self, req: VersionsRequest) -> Result<VersionsResponse> {
168 let session_id = Self::get_session_id(&req.session_id, &req.file_name);
169 let artifacts = self.artifacts.read().unwrap();
170
171 let mut versions: Vec<i64> = artifacts
172 .keys()
173 .filter(|k| {
174 k.app_name == req.app_name
175 && k.user_id == req.user_id
176 && k.session_id == session_id
177 && k.file_name == req.file_name
178 })
179 .map(|k| k.version)
180 .collect();
181
182 if versions.is_empty() {
183 return Err(adk_core::AdkError::Artifact("artifact not found".into()));
184 }
185
186 versions.sort_by(|a, b| b.cmp(a));
187
188 Ok(VersionsResponse { versions })
189 }
190}