// adk_artifact/inmemory.rs

use crate::service::*;
use adk_core::{Part, Result};
use async_trait::async_trait;
use std::collections::{BTreeSet, HashMap};
use std::sync::{Arc, RwLock};
/// Storage-bucket id used in place of a session id for `user:`-prefixed
/// (user-scoped) artifacts, so they are shared by all of a user's sessions.
const USER_SCOPED_KEY: &str = "user";
/// Composite key identifying one stored version of one artifact.
///
/// `session_id` holds the storage bucket rather than necessarily the caller's
/// session: user-scoped (`user:`-prefixed) files are stored under
/// `USER_SCOPED_KEY` instead of a session id.
#[derive(Clone, Debug)]
#[derive(PartialEq, Eq, Hash, PartialOrd, Ord)]
struct ArtifactKey {
    /// Application that owns the artifact.
    app_name: String,
    /// User that owns the artifact.
    user_id: String,
    /// Session id, or the shared user bucket for user-scoped files.
    session_id: String,
    /// Artifact file name as supplied by the caller.
    file_name: String,
    /// Version number of this entry.
    version: i64,
}
18pub struct InMemoryArtifactService {
19    artifacts: Arc<RwLock<HashMap<ArtifactKey, Part>>>,
20}
21
22impl InMemoryArtifactService {
23    pub fn new() -> Self {
24        Self { artifacts: Arc::new(RwLock::new(HashMap::new())) }
25    }
26
27    fn is_user_scoped(file_name: &str) -> bool {
28        file_name.starts_with("user:")
29    }
30
31    fn get_session_id(session_id: &str, file_name: &str) -> String {
32        if Self::is_user_scoped(file_name) {
33            USER_SCOPED_KEY.to_string()
34        } else {
35            session_id.to_string()
36        }
37    }
38
39    fn find_latest_version(
40        &self,
41        app_name: &str,
42        user_id: &str,
43        session_id: &str,
44        file_name: &str,
45    ) -> Option<(i64, Part)> {
46        let artifacts = self.artifacts.read().unwrap();
47        let mut versions: Vec<_> = artifacts
48            .iter()
49            .filter(|(k, _)| {
50                k.app_name == app_name
51                    && k.user_id == user_id
52                    && k.session_id == session_id
53                    && k.file_name == file_name
54            })
55            .collect();
56
57        versions.sort_by(|a, b| b.0.version.cmp(&a.0.version));
58        versions.first().map(|(k, v)| (k.version, (*v).clone()))
59    }
60}
61
62impl Default for InMemoryArtifactService {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68#[async_trait]
69impl ArtifactService for InMemoryArtifactService {
70    async fn save(&self, req: SaveRequest) -> Result<SaveResponse> {
71        let session_id = Self::get_session_id(&req.session_id, &req.file_name);
72
73        let version = if let Some(v) = req.version {
74            v
75        } else {
76            let latest =
77                self.find_latest_version(&req.app_name, &req.user_id, &session_id, &req.file_name);
78            latest.map(|(v, _)| v + 1).unwrap_or(1)
79        };
80
81        let key = ArtifactKey {
82            app_name: req.app_name,
83            user_id: req.user_id,
84            session_id,
85            file_name: req.file_name,
86            version,
87        };
88
89        let mut artifacts = self.artifacts.write().unwrap();
90        artifacts.insert(key, req.part);
91
92        Ok(SaveResponse { version })
93    }
94
95    async fn load(&self, req: LoadRequest) -> Result<LoadResponse> {
96        let session_id = Self::get_session_id(&req.session_id, &req.file_name);
97
98        if let Some(version) = req.version {
99            let key = ArtifactKey {
100                app_name: req.app_name,
101                user_id: req.user_id,
102                session_id,
103                file_name: req.file_name,
104                version,
105            };
106
107            let artifacts = self.artifacts.read().unwrap();
108            let part = artifacts
109                .get(&key)
110                .ok_or_else(|| adk_core::AdkError::Artifact("artifact not found".into()))?;
111
112            Ok(LoadResponse { part: part.clone() })
113        } else {
114            let (_, part) = self
115                .find_latest_version(&req.app_name, &req.user_id, &session_id, &req.file_name)
116                .ok_or_else(|| adk_core::AdkError::Artifact("artifact not found".into()))?;
117
118            Ok(LoadResponse { part })
119        }
120    }
121
122    async fn delete(&self, req: DeleteRequest) -> Result<()> {
123        let session_id = Self::get_session_id(&req.session_id, &req.file_name);
124
125        let mut artifacts = self.artifacts.write().unwrap();
126
127        if let Some(version) = req.version {
128            let key = ArtifactKey {
129                app_name: req.app_name,
130                user_id: req.user_id,
131                session_id,
132                file_name: req.file_name,
133                version,
134            };
135            artifacts.remove(&key);
136        } else {
137            artifacts.retain(|k, _| {
138                !(k.app_name == req.app_name
139                    && k.user_id == req.user_id
140                    && k.session_id == session_id
141                    && k.file_name == req.file_name)
142            });
143        }
144
145        Ok(())
146    }
147
148    async fn list(&self, req: ListRequest) -> Result<ListResponse> {
149        let artifacts = self.artifacts.read().unwrap();
150        let mut file_names = std::collections::HashSet::new();
151
152        for key in artifacts.keys() {
153            if key.app_name == req.app_name
154                && key.user_id == req.user_id
155                && (key.session_id == req.session_id || key.session_id == USER_SCOPED_KEY)
156            {
157                file_names.insert(key.file_name.clone());
158            }
159        }
160
161        let mut sorted: Vec<_> = file_names.into_iter().collect();
162        sorted.sort();
163
164        Ok(ListResponse { file_names: sorted })
165    }
166
167    async fn versions(&self, req: VersionsRequest) -> Result<VersionsResponse> {
168        let session_id = Self::get_session_id(&req.session_id, &req.file_name);
169        let artifacts = self.artifacts.read().unwrap();
170
171        let mut versions: Vec<i64> = artifacts
172            .keys()
173            .filter(|k| {
174                k.app_name == req.app_name
175                    && k.user_id == req.user_id
176                    && k.session_id == session_id
177                    && k.file_name == req.file_name
178            })
179            .map(|k| k.version)
180            .collect();
181
182        if versions.is_empty() {
183            return Err(adk_core::AdkError::Artifact("artifact not found".into()));
184        }
185
186        versions.sort_by(|a, b| b.cmp(a));
187
188        Ok(VersionsResponse { versions })
189    }
190}