mofa_foundation/prompt/
memory_store.rs1use super::store::{PromptCompositionEntity, PromptEntity, PromptFilter, PromptStore};
6use super::template::PromptResult;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::RwLock;
10use uuid::Uuid;
11
12pub struct InMemoryPromptStore {
36 templates: RwLock<HashMap<Uuid, PromptEntity>>,
38 template_index: RwLock<HashMap<String, Uuid>>,
40 compositions: RwLock<HashMap<String, PromptCompositionEntity>>,
42}
43
44impl Default for InMemoryPromptStore {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl InMemoryPromptStore {
51 pub fn new() -> Self {
53 Self {
54 templates: RwLock::new(HashMap::new()),
55 template_index: RwLock::new(HashMap::new()),
56 compositions: RwLock::new(HashMap::new()),
57 }
58 }
59
60 pub fn shared() -> std::sync::Arc<Self> {
62 std::sync::Arc::new(Self::new())
63 }
64
65 pub fn template_count(&self) -> usize {
67 self.templates.read().unwrap().len()
68 }
69
70 pub fn clear(&self) {
72 self.templates.write().unwrap().clear();
73 self.template_index.write().unwrap().clear();
74 self.compositions.write().unwrap().clear();
75 }
76}
77
78#[async_trait]
79impl PromptStore for InMemoryPromptStore {
80 async fn save_template(&self, entity: &PromptEntity) -> PromptResult<()> {
81 let mut templates = self.templates.write().unwrap();
82 let mut index = self.template_index.write().unwrap();
83
84 if let Some(&old_id) = index.get(&entity.template_id) {
86 templates.remove(&old_id);
87 }
88
89 templates.insert(entity.id, entity.clone());
90 index.insert(entity.template_id.clone(), entity.id);
91
92 Ok(())
93 }
94
95 async fn get_template_by_id(&self, id: Uuid) -> PromptResult<Option<PromptEntity>> {
96 let templates = self.templates.read().unwrap();
97 Ok(templates.get(&id).cloned())
98 }
99
100 async fn get_template(&self, template_id: &str) -> PromptResult<Option<PromptEntity>> {
101 let index = self.template_index.read().unwrap();
102 let templates = self.templates.read().unwrap();
103
104 if let Some(&uuid) = index.get(template_id) {
105 Ok(templates.get(&uuid).cloned())
106 } else {
107 Ok(None)
108 }
109 }
110
111 async fn query_templates(&self, filter: &PromptFilter) -> PromptResult<Vec<PromptEntity>> {
112 let templates = self.templates.read().unwrap();
113 let mut results: Vec<PromptEntity> = templates
114 .values()
115 .filter(|e| {
116 if filter.enabled_only && !e.enabled {
118 return false;
119 }
120
121 if let Some(ref tid) = filter.template_id
123 && &e.template_id != tid
124 {
125 return false;
126 }
127
128 if let Some(tenant_id) = filter.tenant_id
130 && e.tenant_id != Some(tenant_id)
131 {
132 return false;
133 }
134
135 if let Some(ref tags) = filter.tags
137 && !tags.iter().any(|t| e.tags.contains(t))
138 {
139 return false;
140 }
141
142 if let Some(ref keyword) = filter.search {
144 let kw = keyword.to_lowercase();
145 let match_id = e.template_id.to_lowercase().contains(&kw);
146 let match_name = e
147 .name
148 .as_ref()
149 .is_some_and(|n| n.to_lowercase().contains(&kw));
150 let match_desc = e
151 .description
152 .as_ref()
153 .is_some_and(|d| d.to_lowercase().contains(&kw));
154
155 if !match_id && !match_name && !match_desc {
156 return false;
157 }
158 }
159
160 true
161 })
162 .cloned()
163 .collect();
164
165 results.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
167
168 let offset = filter.offset.unwrap_or(0) as usize;
170 let limit = filter.limit.unwrap_or(100) as usize;
171
172 Ok(results.into_iter().skip(offset).take(limit).collect())
173 }
174
175 async fn find_by_tag(&self, tag: &str) -> PromptResult<Vec<PromptEntity>> {
176 let filter = PromptFilter::new().with_tag(tag);
177 self.query_templates(&filter).await
178 }
179
180 async fn search_templates(&self, keyword: &str) -> PromptResult<Vec<PromptEntity>> {
181 let filter = PromptFilter::new().search(keyword);
182 self.query_templates(&filter).await
183 }
184
185 async fn update_template(&self, entity: &PromptEntity) -> PromptResult<()> {
186 let mut templates = self.templates.write().unwrap();
187 let index = self.template_index.read().unwrap();
188
189 if let Some(&uuid) = index.get(&entity.template_id) {
191 let mut updated = entity.clone();
192 updated.id = uuid;
193 updated.updated_at = chrono::Utc::now();
194 templates.insert(uuid, updated);
195 }
196
197 Ok(())
198 }
199
200 async fn delete_template_by_id(&self, id: Uuid) -> PromptResult<bool> {
201 let mut templates = self.templates.write().unwrap();
202 let mut index = self.template_index.write().unwrap();
203
204 if let Some(entity) = templates.remove(&id) {
205 index.remove(&entity.template_id);
206 Ok(true)
207 } else {
208 Ok(false)
209 }
210 }
211
212 async fn delete_template(&self, template_id: &str) -> PromptResult<bool> {
213 let uuid = {
214 let index = self.template_index.read().unwrap();
215 index.get(template_id).copied()
216 };
217 if let Some(uuid) = uuid {
218 self.delete_template_by_id(uuid).await
219 } else {
220 Ok(false)
221 }
222 }
223
224 async fn set_template_enabled(&self, template_id: &str, enabled: bool) -> PromptResult<()> {
225 let index = self.template_index.read().unwrap();
226 let mut templates = self.templates.write().unwrap();
227
228 if let Some(&uuid) = index.get(template_id)
229 && let Some(entity) = templates.get_mut(&uuid)
230 {
231 entity.enabled = enabled;
232 entity.updated_at = chrono::Utc::now();
233 }
234
235 Ok(())
236 }
237
238 async fn exists(&self, template_id: &str) -> PromptResult<bool> {
239 let index = self.template_index.read().unwrap();
240 Ok(index.contains_key(template_id))
241 }
242
243 async fn count(&self, filter: &PromptFilter) -> PromptResult<i64> {
244 let results = self.query_templates(filter).await?;
245 Ok(results.len() as i64)
246 }
247
248 async fn get_all_tags(&self) -> PromptResult<Vec<String>> {
249 let templates = self.templates.read().unwrap();
250 let mut tags: std::collections::HashSet<String> = std::collections::HashSet::new();
251
252 for entity in templates.values() {
253 for tag in &entity.tags {
254 tags.insert(tag.clone());
255 }
256 }
257
258 let mut result: Vec<String> = tags.into_iter().collect();
259 result.sort();
260 Ok(result)
261 }
262
263 async fn save_composition(&self, entity: &PromptCompositionEntity) -> PromptResult<()> {
264 let mut compositions = self.compositions.write().unwrap();
265 compositions.insert(entity.composition_id.clone(), entity.clone());
266 Ok(())
267 }
268
269 async fn get_composition(
270 &self,
271 composition_id: &str,
272 ) -> PromptResult<Option<PromptCompositionEntity>> {
273 let compositions = self.compositions.read().unwrap();
274 Ok(compositions.get(composition_id).cloned())
275 }
276
277 async fn query_compositions(&self) -> PromptResult<Vec<PromptCompositionEntity>> {
278 let compositions = self.compositions.read().unwrap();
279 Ok(compositions.values().cloned().collect())
280 }
281
282 async fn delete_composition(&self, composition_id: &str) -> PromptResult<bool> {
283 let mut compositions = self.compositions.write().unwrap();
284 Ok(compositions.remove(composition_id).is_some())
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prompt::template::PromptTemplate;

    /// Converts `template` to an entity and saves it, panicking on failure.
    async fn put(store: &InMemoryPromptStore, template: &PromptTemplate) {
        store
            .save_template(&PromptEntity::from_template(template))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_memory_store_basic() {
        let store = InMemoryPromptStore::new();
        put(
            &store,
            &PromptTemplate::new("test")
                .with_name("Test Template")
                .with_content("Hello, {name}!")
                .with_tag("greeting"),
        )
        .await;

        assert!(store.exists("test").await.unwrap());
        assert_eq!(store.template_count(), 1);

        let fetched = store
            .get_template("test")
            .await
            .unwrap()
            .expect("template should be present");
        assert_eq!(fetched.template_id, "test");
    }

    #[tokio::test]
    async fn test_memory_store_query() {
        let store = InMemoryPromptStore::new();

        // Indices 0, 2, 4 are tagged "even"; 1, 3 are tagged "odd".
        for n in 0..5 {
            let parity = if n % 2 == 0 { "even" } else { "odd" };
            put(
                &store,
                &PromptTemplate::new(format!("template-{n}"))
                    .with_name(format!("Template {n}"))
                    .with_tag(parity),
            )
            .await;
        }

        assert_eq!(store.find_by_tag("even").await.unwrap().len(), 3);
        assert_eq!(store.find_by_tag("odd").await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_memory_store_search() {
        let store = InMemoryPromptStore::new();
        let fixtures = [
            PromptTemplate::new("code-review")
                .with_name("Code Review")
                .with_description("Review code for issues"),
            PromptTemplate::new("code-explain")
                .with_name("Code Explanation")
                .with_description("Explain code in detail"),
            PromptTemplate::new("chat").with_name("Chat Assistant"),
        ];
        for template in &fixtures {
            put(&store, template).await;
        }

        // "code" appears in two ids; "review" only in the first template.
        assert_eq!(store.search_templates("code").await.unwrap().len(), 2);
        assert_eq!(store.search_templates("review").await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_memory_store_delete() {
        let store = InMemoryPromptStore::new();
        put(&store, &PromptTemplate::new("test").with_content("test")).await;
        assert!(store.exists("test").await.unwrap());

        store.delete_template("test").await.unwrap();
        assert!(!store.exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_store_enable_disable() {
        let store = InMemoryPromptStore::new();
        put(&store, &PromptTemplate::new("test").with_content("test")).await;
        store.set_template_enabled("test", false).await.unwrap();

        // The default filter hides disabled templates.
        let enabled_only = store.query_templates(&PromptFilter::new()).await.unwrap();
        assert!(enabled_only.is_empty());

        let everything = store
            .query_templates(&PromptFilter::new().include_disabled())
            .await
            .unwrap();
        assert_eq!(everything.len(), 1);
    }

    #[tokio::test]
    async fn test_memory_store_tags() {
        let store = InMemoryPromptStore::new();
        put(&store, &PromptTemplate::new("t1").with_tag("a").with_tag("b")).await;
        put(&store, &PromptTemplate::new("t2").with_tag("b").with_tag("c")).await;

        // Shared tag "b" appears once; output is sorted.
        assert_eq!(store.get_all_tags().await.unwrap(), vec!["a", "b", "c"]);
    }
}