Skip to main content

mofa_foundation/prompt/
store.rs

//! Persistent storage for prompts.
//!
//! Provides database storage support for prompt templates.
4
use super::template::{
    PromptComposition, PromptError, PromptResult, PromptTemplate, PromptVariable,
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
12
13/// Prompt 模板数据库实体
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct PromptEntity {
16    /// 唯一 ID
17    pub id: Uuid,
18    /// 模板标识符(用于查找)
19    pub template_id: String,
20    /// 模板名称
21    pub name: Option<String>,
22    /// 模板描述
23    pub description: Option<String>,
24    /// 模板内容
25    pub content: String,
26    /// 变量定义(JSON)
27    pub variables: serde_json::Value,
28    /// 标签列表
29    pub tags: Vec<String>,
30    /// 版本号
31    pub version: Option<String>,
32    /// 元数据(JSON)
33    pub metadata: serde_json::Value,
34    /// 是否启用
35    pub enabled: bool,
36    /// 创建时间
37    pub created_at: chrono::DateTime<chrono::Utc>,
38    /// 更新时间
39    pub updated_at: chrono::DateTime<chrono::Utc>,
40    /// 创建者 ID
41    pub created_by: Option<Uuid>,
42    /// 租户 ID(用于多租户隔离)
43    pub tenant_id: Option<Uuid>,
44}
45
46impl PromptEntity {
47    /// 从 PromptTemplate 创建实体
48    pub fn from_template(template: &PromptTemplate) -> Self {
49        let now = chrono::Utc::now();
50        let variables = serde_json::to_value(&template.variables).unwrap_or_default();
51        let metadata = serde_json::to_value(&template.metadata).unwrap_or_default();
52
53        Self {
54            id: Uuid::now_v7(),
55            template_id: template.id.clone(),
56            name: template.name.clone(),
57            description: template.description.clone(),
58            content: template.content.clone(),
59            variables,
60            tags: template.tags.clone(),
61            version: template.version.clone(),
62            metadata,
63            enabled: true,
64            created_at: now,
65            updated_at: now,
66            created_by: None,
67            tenant_id: None,
68        }
69    }
70
71    /// 转换为 PromptTemplate
72    pub fn to_template(&self) -> PromptResult<PromptTemplate> {
73        let variables: Vec<PromptVariable> = serde_json::from_value(self.variables.clone())
74            .map_err(|e| PromptError::ParseError(e.to_string()))?;
75        let metadata: HashMap<String, String> =
76            serde_json::from_value(self.metadata.clone()).unwrap_or_default();
77
78        Ok(PromptTemplate {
79            id: self.template_id.clone(),
80            name: self.name.clone(),
81            description: self.description.clone(),
82            content: self.content.clone(),
83            variables,
84            tags: self.tags.clone(),
85            version: self.version.clone(),
86            metadata,
87        })
88    }
89
90    /// 设置创建者
91    pub fn with_creator(mut self, creator_id: Uuid) -> Self {
92        self.created_by = Some(creator_id);
93        self
94    }
95
96    /// 设置租户
97    pub fn with_tenant(mut self, tenant_id: Uuid) -> Self {
98        self.tenant_id = Some(tenant_id);
99        self
100    }
101}
102
103/// Prompt 组合数据库实体
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct PromptCompositionEntity {
106    /// 唯一 ID
107    pub id: Uuid,
108    /// 组合标识符
109    pub composition_id: String,
110    /// 描述
111    pub description: Option<String>,
112    /// 模板 ID 列表
113    pub template_ids: Vec<String>,
114    /// 分隔符
115    pub separator: String,
116    /// 是否启用
117    pub enabled: bool,
118    /// 创建时间
119    pub created_at: chrono::DateTime<chrono::Utc>,
120    /// 更新时间
121    pub updated_at: chrono::DateTime<chrono::Utc>,
122    /// 租户 ID
123    pub tenant_id: Option<Uuid>,
124}
125
126impl PromptCompositionEntity {
127    /// 从 PromptComposition 创建实体
128    pub fn from_composition(composition: &PromptComposition) -> Self {
129        let now = chrono::Utc::now();
130        Self {
131            id: Uuid::now_v7(),
132            composition_id: composition.id.clone(),
133            description: composition.description.clone(),
134            template_ids: composition.template_ids.clone(),
135            separator: composition.separator.clone(),
136            enabled: true,
137            created_at: now,
138            updated_at: now,
139            tenant_id: None,
140        }
141    }
142
143    /// 转换为 PromptComposition
144    pub fn to_composition(&self) -> PromptComposition {
145        PromptComposition {
146            id: self.composition_id.clone(),
147            description: self.description.clone(),
148            template_ids: self.template_ids.clone(),
149            separator: self.separator.clone(),
150        }
151    }
152}
153
154/// Prompt 查询过滤器
155#[derive(Debug, Clone, Default)]
156pub struct PromptFilter {
157    /// 按模板 ID 查找
158    pub template_id: Option<String>,
159    /// 按标签查找
160    pub tags: Option<Vec<String>>,
161    /// 搜索关键词(名称、描述)
162    pub search: Option<String>,
163    /// 只返回启用的
164    pub enabled_only: bool,
165    /// 租户 ID
166    pub tenant_id: Option<Uuid>,
167    /// 分页偏移
168    pub offset: Option<i64>,
169    /// 分页限制
170    pub limit: Option<i64>,
171}
172
173impl PromptFilter {
174    pub fn new() -> Self {
175        Self {
176            enabled_only: true,
177            ..Default::default()
178        }
179    }
180
181    pub fn template_id(mut self, id: impl Into<String>) -> Self {
182        self.template_id = Some(id.into());
183        self
184    }
185
186    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
187        self.tags.get_or_insert_with(Vec::new).push(tag.into());
188        self
189    }
190
191    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
192        self.tags = Some(tags);
193        self
194    }
195
196    pub fn search(mut self, keyword: impl Into<String>) -> Self {
197        self.search = Some(keyword.into());
198        self
199    }
200
201    pub fn include_disabled(mut self) -> Self {
202        self.enabled_only = false;
203        self
204    }
205
206    pub fn tenant(mut self, tenant_id: Uuid) -> Self {
207        self.tenant_id = Some(tenant_id);
208        self
209    }
210
211    pub fn paginate(mut self, offset: i64, limit: i64) -> Self {
212        self.offset = Some(offset);
213        self.limit = Some(limit);
214        self
215    }
216}
217
218/// Prompt 存储 trait
219///
220/// 定义 Prompt 模板的 CRUD 操作
221#[async_trait]
222pub trait PromptStore: Send + Sync {
223    /// 保存模板
224    async fn save_template(&self, entity: &PromptEntity) -> PromptResult<()>;
225
226    /// 批量保存模板
227    async fn save_templates(&self, entities: &[PromptEntity]) -> PromptResult<()> {
228        for entity in entities {
229            self.save_template(entity).await?;
230        }
231        Ok(())
232    }
233
234    /// 获取模板(按 UUID)
235    async fn get_template_by_id(&self, id: Uuid) -> PromptResult<Option<PromptEntity>>;
236
237    /// 获取模板(按模板 ID)
238    async fn get_template(&self, template_id: &str) -> PromptResult<Option<PromptEntity>>;
239
240    /// 查询模板列表
241    async fn query_templates(&self, filter: &PromptFilter) -> PromptResult<Vec<PromptEntity>>;
242
243    /// 按标签查找模板
244    async fn find_by_tag(&self, tag: &str) -> PromptResult<Vec<PromptEntity>>;
245
246    /// 搜索模板
247    async fn search_templates(&self, keyword: &str) -> PromptResult<Vec<PromptEntity>>;
248
249    /// 更新模板
250    async fn update_template(&self, entity: &PromptEntity) -> PromptResult<()>;
251
252    /// 删除模板(按 UUID)
253    async fn delete_template_by_id(&self, id: Uuid) -> PromptResult<bool>;
254
255    /// 删除模板(按模板 ID)
256    async fn delete_template(&self, template_id: &str) -> PromptResult<bool>;
257
258    /// 启用/禁用模板
259    async fn set_template_enabled(&self, template_id: &str, enabled: bool) -> PromptResult<()>;
260
261    /// 检查模板是否存在
262    async fn exists(&self, template_id: &str) -> PromptResult<bool>;
263
264    /// 统计模板数量
265    async fn count(&self, filter: &PromptFilter) -> PromptResult<i64>;
266
267    /// 获取所有标签
268    async fn get_all_tags(&self) -> PromptResult<Vec<String>>;
269
270    // ========== 组合操作 ==========
271
272    /// 保存组合
273    async fn save_composition(&self, entity: &PromptCompositionEntity) -> PromptResult<()>;
274
275    /// 获取组合
276    async fn get_composition(
277        &self,
278        composition_id: &str,
279    ) -> PromptResult<Option<PromptCompositionEntity>>;
280
281    /// 查询所有组合
282    async fn query_compositions(&self) -> PromptResult<Vec<PromptCompositionEntity>>;
283
284    /// 删除组合
285    async fn delete_composition(&self, composition_id: &str) -> PromptResult<bool>;
286}
287
288/// 动态分发的 PromptStore
289pub type DynPromptStore = std::sync::Arc<dyn PromptStore>;
290
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created entity copies the template's fields and starts
    /// enabled, with no creator or tenant and identical timestamps.
    #[test]
    fn test_prompt_entity_from_template() {
        let template = PromptTemplate::new("test")
            .with_name("Test Template")
            .with_content("Hello, {name}!")
            .with_tag("greeting");

        let entity = PromptEntity::from_template(&template);

        assert_eq!(entity.template_id, "test");
        assert_eq!(entity.name, Some("Test Template".to_string()));
        assert_eq!(entity.tags, template.tags);
        assert!(entity.enabled);
        assert!(entity.created_by.is_none());
        assert!(entity.tenant_id.is_none());
        assert_eq!(entity.created_at, entity.updated_at);
    }

    /// Template -> entity -> template keeps the template's own fields.
    #[test]
    fn test_prompt_entity_to_template() {
        let template = PromptTemplate::new("test")
            .with_name("Test Template")
            .with_content("Hello, {name}!")
            .with_tag("greeting");

        let entity = PromptEntity::from_template(&template);
        let converted = entity.to_template().unwrap();

        assert_eq!(converted.id, template.id);
        assert_eq!(converted.name, template.name);
        assert_eq!(converted.content, template.content);
        assert_eq!(converted.tags, template.tags);
    }

    /// The builder methods store exactly what they are given, and `new()`
    /// starts with `enabled_only` set.
    #[test]
    fn test_prompt_filter_builder() {
        let filter = PromptFilter::new()
            .with_tag("code")
            .search("review")
            .paginate(0, 10);

        assert_eq!(filter.tags, Some(vec!["code".to_string()]));
        assert_eq!(filter.search, Some("review".to_string()));
        assert_eq!(filter.offset, Some(0));
        assert_eq!(filter.limit, Some(10));
        assert!(filter.enabled_only);
    }
}