Skip to main content

memrec_common/protocol/
request.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use uuid::Uuid;
4
5use crate::types::MemoryType;
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
8#[serde(rename_all = "snake_case")]
9pub enum RequestAction {
10    Add,
11    Get,
12    Update,
13    Delete,
14    Search,
15    List,
16    Tag,
17    
18    SearchMemory,
19    GetProjectInfo,
20    GetVersion,
21    
22    ProjectCreate,
23    ProjectList,
24    ProjectSwitch,
25    ProjectDelete,
26    
27    ConfigGet,
28    ConfigSet,
29    
30    Stats,
31}
32
33impl std::fmt::Display for RequestAction {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        match self {
36            RequestAction::Add => write!(f, "add"),
37            RequestAction::Get => write!(f, "get"),
38            RequestAction::Update => write!(f, "update"),
39            RequestAction::Delete => write!(f, "delete"),
40            RequestAction::Search => write!(f, "search"),
41            RequestAction::List => write!(f, "list"),
42            RequestAction::Tag => write!(f, "tag"),
43            RequestAction::SearchMemory => write!(f, "search_memory"),
44            RequestAction::GetProjectInfo => write!(f, "get_project_info"),
45            RequestAction::GetVersion => write!(f, "get_version"),
46            RequestAction::ProjectCreate => write!(f, "project_create"),
47            RequestAction::ProjectList => write!(f, "project_list"),
48            RequestAction::ProjectSwitch => write!(f, "project_switch"),
49            RequestAction::ProjectDelete => write!(f, "project_delete"),
50            RequestAction::ConfigGet => write!(f, "config_get"),
51            RequestAction::ConfigSet => write!(f, "config_set"),
52            RequestAction::Stats => write!(f, "stats"),
53        }
54    }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct JsonRpcRequest {
59    pub jsonrpc: String,
60    pub method: RequestAction,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub params: Option<RequestParams>,
63    pub id: u64,
64}
65
66impl JsonRpcRequest {
67    pub fn new(method: RequestAction, params: Option<RequestParams>, id: u64) -> Self {
68        Self {
69            jsonrpc: "2.0".to_string(),
70            method,
71            params,
72            id,
73        }
74    }
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
78#[serde(tag = "type", rename_all = "snake_case")]
79pub enum RequestParams {
80    Add(AddParams),
81    Get(GetParams),
82    Update(UpdateParams),
83    Delete(DeleteParams),
84    Search(SearchParams),
85    List(ListParams),
86    Tag(TagParams),
87    
88    SearchMemory(SearchMemoryParams),
89    GetProjectInfo(GetProjectInfoParams),
90    GetVersion(GetVersionParams),
91    
92    ProjectCreate(ProjectCreateParams),
93    ProjectSwitch(ProjectSwitchParams),
94    ProjectDelete(ProjectDeleteParams),
95    
96    ConfigSet(ConfigSetParams),
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct AddParams {
101    pub content: String,
102    #[serde(default)]
103    pub memory_type: MemoryType,
104    #[serde(default)]
105    pub tags: Vec<String>,
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub project_id: Option<Uuid>,
108    #[serde(default)]
109    pub is_global: bool,
110    #[serde(skip_serializing_if = "Option::is_none")]
111    pub working_dir: Option<String>,
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct GetParams {
116    pub id: Uuid,
117    #[serde(default)]
118    pub merge: bool,
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct UpdateParams {
123    pub id: Uuid,
124    #[serde(skip_serializing_if = "Option::is_none")]
125    pub content: Option<String>,
126    #[serde(skip_serializing_if = "Option::is_none")]
127    pub tags: Option<Vec<String>>,
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct DeleteParams {
132    pub id: Uuid,
133    #[serde(default)]
134    pub force: bool,
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct SearchParams {
139    #[serde(skip_serializing_if = "Option::is_none")]
140    pub text: Option<String>,
141    #[serde(default)]
142    pub mode: SearchMode,
143    #[serde(default)]
144    pub tags: Vec<String>,
145    #[serde(skip_serializing_if = "Option::is_none")]
146    pub time_range: Option<TimeRange>,
147    #[serde(skip_serializing_if = "Option::is_none")]
148    pub project_id: Option<Uuid>,
149    #[serde(default = "default_top_k")]
150    pub top_k: usize,
151    #[serde(default)]
152    pub min_importance: f32,
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct SearchMemoryParams {
157    pub query: String,
158    #[serde(skip_serializing_if = "Option::is_none")]
159    pub project_id: Option<Uuid>,
160    #[serde(default = "default_include_global")]
161    pub include_global: bool,
162    #[serde(default)]
163    pub project_only: bool,
164    #[serde(default)]
165    pub global_only: bool,
166    #[serde(default)]
167    pub cross_project: bool,
168    #[serde(skip_serializing_if = "Option::is_none")]
169    pub memory_type: Option<MemoryType>,
170    #[serde(default = "default_top_k")]
171    pub top_k: usize,
172    #[serde(default = "default_min_score")]
173    pub min_score: f32,
174    #[serde(skip_serializing_if = "Option::is_none")]
175    pub working_dir: Option<String>,
176}
177
/// Serde default for `SearchMemoryParams::include_global`.
///
/// Global memories are included unless the request explicitly opts out.
pub fn default_include_global() -> bool {
    true
}
/// Score threshold used when no usable `MEMREC_MIN_SCORE` override is present.
const FALLBACK_MIN_SCORE: f32 = 0.75;

/// Serde default for `SearchMemoryParams::min_score`.
///
/// Reads the `MEMREC_MIN_SCORE` environment variable on every call (i.e. each
/// time a request omits `min_score`). Falls back to `0.75` when the variable
/// is unset, not valid Unicode, unparsable, or not a finite number.
pub fn default_min_score() -> f32 {
    parse_min_score(std::env::var("MEMREC_MIN_SCORE").ok().as_deref())
}

/// Parses a raw min-score override, returning [`FALLBACK_MIN_SCORE`] for
/// missing, malformed, or non-finite input.
fn parse_min_score(raw: Option<&str>) -> f32 {
    raw
        // Tolerate stray whitespace such as `MEMREC_MIN_SCORE=" 0.8"`.
        .and_then(|value| value.trim().parse::<f32>().ok())
        // `f32::from_str` accepts "NaN"/"inf"; a NaN threshold would make every
        // `score >= min_score` comparison false, so reject non-finite values.
        .filter(|score| score.is_finite())
        .unwrap_or(FALLBACK_MIN_SCORE)
}
185
186#[derive(Debug, Clone, Serialize, Deserialize, Default)]
187pub struct GetProjectInfoParams {
188    #[serde(default, skip_serializing_if = "Option::is_none")]
189    pub working_dir: Option<String>,
190}
191
/// Arguments for [`RequestAction::GetVersion`]; the action takes none.
///
/// Intentionally a unit struct: rewriting it as `struct GetVersionParams {}`
/// would change its serde representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetVersionParams;
194
195#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
196#[serde(rename_all = "lowercase")]
197#[derive(Default)]
198pub enum SearchMode {
199    Exact,
200    Semantic,
201    #[default]
202    Hybrid,
203}
204
205
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct TimeRange {
208    pub start: DateTime<Utc>,
209    pub end: DateTime<Utc>,
210}
211
/// Result count used for `top_k` when a search request omits it.
const DEFAULT_TOP_K: usize = 10;

/// Serde default for `top_k` on `SearchParams` and `SearchMemoryParams`.
fn default_top_k() -> usize {
    DEFAULT_TOP_K
}
213
214#[derive(Debug, Clone, Serialize, Deserialize)]
215pub struct ListParams {
216    #[serde(default)]
217    pub tags: Vec<String>,
218    #[serde(skip_serializing_if = "Option::is_none")]
219    pub memory_type: Option<MemoryType>,
220    #[serde(default = "default_limit")]
221    pub limit: usize,
222    #[serde(default)]
223    pub project_only: bool,
224    #[serde(default)]
225    pub global_only: bool,
226    #[serde(skip_serializing_if = "Option::is_none")]
227    pub project_id: Option<Uuid>,
228}
229
/// Entry count used for `ListParams::limit` when a request omits it.
const DEFAULT_LIST_LIMIT: usize = 20;

/// Serde default for `ListParams::limit`.
fn default_limit() -> usize {
    DEFAULT_LIST_LIMIT
}
231
232#[derive(Debug, Clone, Serialize, Deserialize)]
233pub struct TagParams {
234    pub id: Uuid,
235    pub tag: String,
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct ProjectCreateParams {
240    pub name: String,
241    #[serde(skip_serializing_if = "Option::is_none")]
242    pub description: Option<String>,
243}
244
245#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct ProjectSwitchParams {
247    pub name: String,
248}
249
250#[derive(Debug, Clone, Serialize, Deserialize)]
251pub struct ProjectDeleteParams {
252    pub name: String,
253    #[serde(default)]
254    pub force: bool,
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
258pub struct ConfigSetParams {
259    pub key: String,
260    pub value: String,
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_request_creation() {
        let req = JsonRpcRequest::new(
            RequestAction::Add,
            Some(RequestParams::Add(AddParams {
                content: "test".to_string(),
                memory_type: MemoryType::Knowledge,
                tags: vec!["tag1".to_string()],
                project_id: None,
                is_global: false,
                working_dir: None,
            })),
            1,
        );

        assert_eq!(req.jsonrpc, "2.0");
        assert_eq!(req.method, RequestAction::Add);
        assert_eq!(req.id, 1);
        assert!(matches!(&req.params, Some(RequestParams::Add(p)) if p.content == "test"));
    }

    #[test]
    fn test_request_serde() {
        let req = JsonRpcRequest::new(
            RequestAction::Get,
            Some(RequestParams::Get(GetParams {
                id: Uuid::nil(),
                merge: false,
            })),
            1,
        );

        let json = serde_json::to_string(&req).unwrap();
        let parsed: JsonRpcRequest = serde_json::from_str(&json).unwrap();

        // Compare every field, not just the version string.
        assert_eq!(parsed.jsonrpc, req.jsonrpc);
        assert_eq!(parsed.method, req.method);
        assert_eq!(parsed.id, req.id);
        match parsed.params {
            Some(RequestParams::Get(p)) => {
                assert_eq!(p.id, Uuid::nil());
                assert!(!p.merge);
            }
            other => panic!("unexpected params after round-trip: {other:?}"),
        }
    }

    #[test]
    fn test_request_wire_format() {
        // Pins the on-the-wire shape: snake_case method and an internally
        // tagged params object carrying a "type" key.
        let req = JsonRpcRequest::new(
            RequestAction::Get,
            Some(RequestParams::Get(GetParams {
                id: Uuid::nil(),
                merge: true,
            })),
            7,
        );

        let value = serde_json::to_value(&req).unwrap();
        assert_eq!(value["jsonrpc"], "2.0");
        assert_eq!(value["method"], "get");
        assert_eq!(value["params"]["type"], "get");
        assert_eq!(value["params"]["merge"], true);
        assert_eq!(value["id"], 7u64);
    }

    #[test]
    fn test_request_without_params() {
        let req = JsonRpcRequest::new(RequestAction::Stats, None, 3);
        let value = serde_json::to_value(&req).unwrap();
        assert!(value.get("params").is_none(), "None params must be omitted");

        let parsed: JsonRpcRequest =
            serde_json::from_str(r#"{"jsonrpc":"2.0","method":"stats","id":3}"#).unwrap();
        assert_eq!(parsed.method, RequestAction::Stats);
        assert!(parsed.params.is_none());
        assert_eq!(parsed.id, 3);
    }

    #[test]
    fn test_display_matches_serde_names() {
        // The hand-written Display impl must stay in sync with serde's
        // snake_case renaming for every variant.
        let all = [
            RequestAction::Add,
            RequestAction::Get,
            RequestAction::Update,
            RequestAction::Delete,
            RequestAction::Search,
            RequestAction::List,
            RequestAction::Tag,
            RequestAction::SearchMemory,
            RequestAction::GetProjectInfo,
            RequestAction::GetVersion,
            RequestAction::ProjectCreate,
            RequestAction::ProjectList,
            RequestAction::ProjectSwitch,
            RequestAction::ProjectDelete,
            RequestAction::ConfigGet,
            RequestAction::ConfigSet,
            RequestAction::Stats,
        ];

        for action in &all {
            assert_eq!(
                serde_json::to_value(action).unwrap(),
                serde_json::Value::String(action.to_string()),
                "Display and serde disagree for {action:?}"
            );
        }
    }

    #[test]
    fn test_search_params_defaults() {
        // Exercise the serde defaults themselves rather than a hand-built struct.
        let params: SearchParams = serde_json::from_str("{}").unwrap();

        assert!(params.text.is_none());
        assert_eq!(params.mode, SearchMode::Hybrid);
        assert!(params.tags.is_empty());
        assert!(params.time_range.is_none());
        assert!(params.project_id.is_none());
        assert_eq!(params.top_k, 10);
        assert_eq!(params.min_importance, 0.0);
    }

    #[test]
    fn test_search_memory_params_defaults() {
        let params: SearchMemoryParams = serde_json::from_str(r#"{"query":"test"}"#).unwrap();

        assert_eq!(params.query, "test");
        assert!(params.include_global);
        assert!(!params.project_only);
        assert!(!params.global_only);
        assert!(!params.cross_project);
        assert!(params.memory_type.is_none());
        assert_eq!(params.top_k, 10);
        assert_eq!(params.min_score, default_min_score());
        // The literal fallback only applies when no environment override is set;
        // asserting it unconditionally made this test environment-dependent.
        if std::env::var_os("MEMREC_MIN_SCORE").is_none() {
            assert_eq!(params.min_score, 0.75);
        }
        assert!(params.working_dir.is_none());
    }

    #[test]
    fn test_search_memory_params_serde() {
        let params = SearchMemoryParams {
            query: "authentication".to_string(),
            project_id: Some(Uuid::new_v4()),
            include_global: true,
            project_only: false,
            global_only: false,
            cross_project: false,
            memory_type: Some(MemoryType::Decision),
            top_k: 20,
            min_score: 0.8,
            working_dir: None,
        };

        let json = serde_json::to_string(&params).unwrap();
        let parsed: SearchMemoryParams = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.query, params.query);
        assert_eq!(parsed.project_id, params.project_id);
        assert_eq!(parsed.include_global, params.include_global);
        assert_eq!(parsed.top_k, params.top_k);
        assert!((parsed.min_score - params.min_score).abs() < 1e-6);
        assert!(parsed.memory_type.is_some());
        assert!(parsed.working_dir.is_none());
    }
}