Skip to main content

mofa_foundation/prompt/
memory_store.rs

//! In-memory prompt store implementation.
//!
//! Provides an in-memory store for prompt templates, suitable for development and testing.
4
5use super::store::{PromptCompositionEntity, PromptEntity, PromptFilter, PromptStore};
6use super::template::PromptResult;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::RwLock;
10use uuid::Uuid;
11
12/// 内存 Prompt 存储
13///
14/// 线程安全的内存存储实现,适用于:
15/// - 开发和测试环境
16/// - 不需要持久化的场景
17/// - 与预置模板库配合使用
18///
19/// # 示例
20///
21/// ```rust,ignore
22/// use mofa_foundation::prompt::{InMemoryPromptStore, PromptEntity, PromptTemplate};
23///
24/// let store = InMemoryPromptStore::new();
25///
26/// // 保存模板
27/// let template = PromptTemplate::new("greeting")
28///     .with_content("Hello, {name}!");
29/// let entity = PromptEntity::from_template(&template);
30/// store.save_template(&entity).await?;
31///
32/// // 查询模板
33/// let found = store.get_template("greeting").await?;
34/// ```
35pub struct InMemoryPromptStore {
36    /// 模板存储 (UUID -> Entity)
37    templates: RwLock<HashMap<Uuid, PromptEntity>>,
38    /// 模板 ID 索引 (template_id -> UUID)
39    template_index: RwLock<HashMap<String, Uuid>>,
40    /// 组合存储
41    compositions: RwLock<HashMap<String, PromptCompositionEntity>>,
42}
43
44impl Default for InMemoryPromptStore {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl InMemoryPromptStore {
51    /// 创建新的内存存储
52    pub fn new() -> Self {
53        Self {
54            templates: RwLock::new(HashMap::new()),
55            template_index: RwLock::new(HashMap::new()),
56            compositions: RwLock::new(HashMap::new()),
57        }
58    }
59
60    /// 创建共享实例
61    pub fn shared() -> std::sync::Arc<Self> {
62        std::sync::Arc::new(Self::new())
63    }
64
65    /// 获取模板数量
66    pub fn template_count(&self) -> usize {
67        self.templates.read().unwrap().len()
68    }
69
70    /// 清空所有数据
71    pub fn clear(&self) {
72        self.templates.write().unwrap().clear();
73        self.template_index.write().unwrap().clear();
74        self.compositions.write().unwrap().clear();
75    }
76}
77
78#[async_trait]
79impl PromptStore for InMemoryPromptStore {
80    async fn save_template(&self, entity: &PromptEntity) -> PromptResult<()> {
81        let mut templates = self.templates.write().unwrap();
82        let mut index = self.template_index.write().unwrap();
83
84        // 如果已存在相同 template_id,删除旧的
85        if let Some(&old_id) = index.get(&entity.template_id) {
86            templates.remove(&old_id);
87        }
88
89        templates.insert(entity.id, entity.clone());
90        index.insert(entity.template_id.clone(), entity.id);
91
92        Ok(())
93    }
94
95    async fn get_template_by_id(&self, id: Uuid) -> PromptResult<Option<PromptEntity>> {
96        let templates = self.templates.read().unwrap();
97        Ok(templates.get(&id).cloned())
98    }
99
100    async fn get_template(&self, template_id: &str) -> PromptResult<Option<PromptEntity>> {
101        let index = self.template_index.read().unwrap();
102        let templates = self.templates.read().unwrap();
103
104        if let Some(&uuid) = index.get(template_id) {
105            Ok(templates.get(&uuid).cloned())
106        } else {
107            Ok(None)
108        }
109    }
110
111    async fn query_templates(&self, filter: &PromptFilter) -> PromptResult<Vec<PromptEntity>> {
112        let templates = self.templates.read().unwrap();
113        let mut results: Vec<PromptEntity> = templates
114            .values()
115            .filter(|e| {
116                // 按启用状态过滤
117                if filter.enabled_only && !e.enabled {
118                    return false;
119                }
120
121                // 按模板 ID 过滤
122                if let Some(ref tid) = filter.template_id
123                    && &e.template_id != tid
124                {
125                    return false;
126                }
127
128                // 按租户过滤
129                if let Some(tenant_id) = filter.tenant_id
130                    && e.tenant_id != Some(tenant_id)
131                {
132                    return false;
133                }
134
135                // 按标签过滤
136                if let Some(ref tags) = filter.tags
137                    && !tags.iter().any(|t| e.tags.contains(t))
138                {
139                    return false;
140                }
141
142                // 按关键词搜索
143                if let Some(ref keyword) = filter.search {
144                    let kw = keyword.to_lowercase();
145                    let match_id = e.template_id.to_lowercase().contains(&kw);
146                    let match_name = e
147                        .name
148                        .as_ref()
149                        .is_some_and(|n| n.to_lowercase().contains(&kw));
150                    let match_desc = e
151                        .description
152                        .as_ref()
153                        .is_some_and(|d| d.to_lowercase().contains(&kw));
154
155                    if !match_id && !match_name && !match_desc {
156                        return false;
157                    }
158                }
159
160                true
161            })
162            .cloned()
163            .collect();
164
165        // 按更新时间排序
166        results.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
167
168        // 分页
169        let offset = filter.offset.unwrap_or(0) as usize;
170        let limit = filter.limit.unwrap_or(100) as usize;
171
172        Ok(results.into_iter().skip(offset).take(limit).collect())
173    }
174
175    async fn find_by_tag(&self, tag: &str) -> PromptResult<Vec<PromptEntity>> {
176        let filter = PromptFilter::new().with_tag(tag);
177        self.query_templates(&filter).await
178    }
179
180    async fn search_templates(&self, keyword: &str) -> PromptResult<Vec<PromptEntity>> {
181        let filter = PromptFilter::new().search(keyword);
182        self.query_templates(&filter).await
183    }
184
185    async fn update_template(&self, entity: &PromptEntity) -> PromptResult<()> {
186        let mut templates = self.templates.write().unwrap();
187        let index = self.template_index.read().unwrap();
188
189        // 确保模板存在
190        if let Some(&uuid) = index.get(&entity.template_id) {
191            let mut updated = entity.clone();
192            updated.id = uuid;
193            updated.updated_at = chrono::Utc::now();
194            templates.insert(uuid, updated);
195        }
196
197        Ok(())
198    }
199
200    async fn delete_template_by_id(&self, id: Uuid) -> PromptResult<bool> {
201        let mut templates = self.templates.write().unwrap();
202        let mut index = self.template_index.write().unwrap();
203
204        if let Some(entity) = templates.remove(&id) {
205            index.remove(&entity.template_id);
206            Ok(true)
207        } else {
208            Ok(false)
209        }
210    }
211
212    async fn delete_template(&self, template_id: &str) -> PromptResult<bool> {
213        let uuid = {
214            let index = self.template_index.read().unwrap();
215            index.get(template_id).copied()
216        };
217        if let Some(uuid) = uuid {
218            self.delete_template_by_id(uuid).await
219        } else {
220            Ok(false)
221        }
222    }
223
224    async fn set_template_enabled(&self, template_id: &str, enabled: bool) -> PromptResult<()> {
225        let index = self.template_index.read().unwrap();
226        let mut templates = self.templates.write().unwrap();
227
228        if let Some(&uuid) = index.get(template_id)
229            && let Some(entity) = templates.get_mut(&uuid)
230        {
231            entity.enabled = enabled;
232            entity.updated_at = chrono::Utc::now();
233        }
234
235        Ok(())
236    }
237
238    async fn exists(&self, template_id: &str) -> PromptResult<bool> {
239        let index = self.template_index.read().unwrap();
240        Ok(index.contains_key(template_id))
241    }
242
243    async fn count(&self, filter: &PromptFilter) -> PromptResult<i64> {
244        let results = self.query_templates(filter).await?;
245        Ok(results.len() as i64)
246    }
247
248    async fn get_all_tags(&self) -> PromptResult<Vec<String>> {
249        let templates = self.templates.read().unwrap();
250        let mut tags: std::collections::HashSet<String> = std::collections::HashSet::new();
251
252        for entity in templates.values() {
253            for tag in &entity.tags {
254                tags.insert(tag.clone());
255            }
256        }
257
258        let mut result: Vec<String> = tags.into_iter().collect();
259        result.sort();
260        Ok(result)
261    }
262
263    async fn save_composition(&self, entity: &PromptCompositionEntity) -> PromptResult<()> {
264        let mut compositions = self.compositions.write().unwrap();
265        compositions.insert(entity.composition_id.clone(), entity.clone());
266        Ok(())
267    }
268
269    async fn get_composition(
270        &self,
271        composition_id: &str,
272    ) -> PromptResult<Option<PromptCompositionEntity>> {
273        let compositions = self.compositions.read().unwrap();
274        Ok(compositions.get(composition_id).cloned())
275    }
276
277    async fn query_compositions(&self) -> PromptResult<Vec<PromptCompositionEntity>> {
278        let compositions = self.compositions.read().unwrap();
279        Ok(compositions.values().cloned().collect())
280    }
281
282    async fn delete_composition(&self, composition_id: &str) -> PromptResult<bool> {
283        let mut compositions = self.compositions.write().unwrap();
284        Ok(compositions.remove(composition_id).is_some())
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prompt::template::PromptTemplate;

    /// Builds a fresh store and saves the given templates into it, in order.
    async fn store_with(templates: Vec<PromptTemplate>) -> InMemoryPromptStore {
        let store = InMemoryPromptStore::new();
        for template in &templates {
            store
                .save_template(&PromptEntity::from_template(template))
                .await
                .unwrap();
        }
        store
    }

    /// A saved template is visible through `exists`, `template_count` and `get_template`.
    #[tokio::test]
    async fn test_memory_store_basic() {
        let store = store_with(vec![
            PromptTemplate::new("test")
                .with_name("Test Template")
                .with_content("Hello, {name}!")
                .with_tag("greeting"),
        ])
        .await;

        assert!(store.exists("test").await.unwrap());
        assert_eq!(store.template_count(), 1);

        let found = store.get_template("test").await.unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().template_id, "test");
    }

    /// `find_by_tag` returns exactly the templates carrying the tag.
    #[tokio::test]
    async fn test_memory_store_query() {
        // Five templates, alternately tagged "even" / "odd"
        let templates = (0..5)
            .map(|i| {
                let tag = if i % 2 == 0 { "even" } else { "odd" };
                PromptTemplate::new(format!("template-{}", i))
                    .with_name(format!("Template {}", i))
                    .with_tag(tag)
            })
            .collect();
        let store = store_with(templates).await;

        // Query by tag
        assert_eq!(store.find_by_tag("even").await.unwrap().len(), 3);
        assert_eq!(store.find_by_tag("odd").await.unwrap().len(), 2);
    }

    /// Keyword search matches against ID, name and description.
    #[tokio::test]
    async fn test_memory_store_search() {
        let store = store_with(vec![
            PromptTemplate::new("code-review")
                .with_name("Code Review")
                .with_description("Review code for issues"),
            PromptTemplate::new("code-explain")
                .with_name("Code Explanation")
                .with_description("Explain code in detail"),
            PromptTemplate::new("chat").with_name("Chat Assistant"),
        ])
        .await;

        // Search for "code"
        let code_hits = store.search_templates("code").await.unwrap();
        assert_eq!(code_hits.len(), 2);

        // Search for "review"
        let review_hits = store.search_templates("review").await.unwrap();
        assert_eq!(review_hits.len(), 1);
    }

    /// A template deleted by its template ID no longer exists.
    #[tokio::test]
    async fn test_memory_store_delete() {
        let store = store_with(vec![PromptTemplate::new("test").with_content("test")]).await;
        assert!(store.exists("test").await.unwrap());

        store.delete_template("test").await.unwrap();
        assert!(!store.exists("test").await.unwrap());
    }

    /// A disabled template is hidden from default queries but visible when
    /// disabled templates are explicitly included.
    #[tokio::test]
    async fn test_memory_store_enable_disable() {
        let store = store_with(vec![PromptTemplate::new("test").with_content("test")]).await;

        // Disable the template
        store.set_template_enabled("test", false).await.unwrap();

        // A default (enabled-only) query must not find it
        let enabled_only = store.query_templates(&PromptFilter::new()).await.unwrap();
        assert_eq!(enabled_only.len(), 0);

        // A query that includes disabled templates must find it
        let with_disabled = store
            .query_templates(&PromptFilter::new().include_disabled())
            .await
            .unwrap();
        assert_eq!(with_disabled.len(), 1);
    }

    /// `get_all_tags` returns the distinct tags in sorted order.
    #[tokio::test]
    async fn test_memory_store_tags() {
        let store = store_with(vec![
            PromptTemplate::new("t1").with_tag("a").with_tag("b"),
            PromptTemplate::new("t2").with_tag("b").with_tag("c"),
        ])
        .await;

        let tags = store.get_all_tags().await.unwrap();
        assert_eq!(tags, vec!["a", "b", "c"]);
    }
}
431}