use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;

use crate::types::MemoryType;
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
8#[serde(rename_all = "snake_case")]
9pub enum RequestAction {
10 Add,
11 Get,
12 Update,
13 Delete,
14 Search,
15 List,
16 Tag,
17
18 SearchMemory,
19 GetProjectInfo,
20 GetVersion,
21
22 ProjectCreate,
23 ProjectList,
24 ProjectSwitch,
25 ProjectDelete,
26
27 ConfigGet,
28 ConfigSet,
29
30 Stats,
31}
32
33impl std::fmt::Display for RequestAction {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self {
36 RequestAction::Add => write!(f, "add"),
37 RequestAction::Get => write!(f, "get"),
38 RequestAction::Update => write!(f, "update"),
39 RequestAction::Delete => write!(f, "delete"),
40 RequestAction::Search => write!(f, "search"),
41 RequestAction::List => write!(f, "list"),
42 RequestAction::Tag => write!(f, "tag"),
43 RequestAction::SearchMemory => write!(f, "search_memory"),
44 RequestAction::GetProjectInfo => write!(f, "get_project_info"),
45 RequestAction::GetVersion => write!(f, "get_version"),
46 RequestAction::ProjectCreate => write!(f, "project_create"),
47 RequestAction::ProjectList => write!(f, "project_list"),
48 RequestAction::ProjectSwitch => write!(f, "project_switch"),
49 RequestAction::ProjectDelete => write!(f, "project_delete"),
50 RequestAction::ConfigGet => write!(f, "config_get"),
51 RequestAction::ConfigSet => write!(f, "config_set"),
52 RequestAction::Stats => write!(f, "stats"),
53 }
54 }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct JsonRpcRequest {
59 pub jsonrpc: String,
60 pub method: RequestAction,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 pub params: Option<RequestParams>,
63 pub id: u64,
64}
65
66impl JsonRpcRequest {
67 pub fn new(method: RequestAction, params: Option<RequestParams>, id: u64) -> Self {
68 Self {
69 jsonrpc: "2.0".to_string(),
70 method,
71 params,
72 id,
73 }
74 }
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
78#[serde(tag = "type", rename_all = "snake_case")]
79pub enum RequestParams {
80 Add(AddParams),
81 Get(GetParams),
82 Update(UpdateParams),
83 Delete(DeleteParams),
84 Search(SearchParams),
85 List(ListParams),
86 Tag(TagParams),
87
88 SearchMemory(SearchMemoryParams),
89 GetProjectInfo(GetProjectInfoParams),
90 GetVersion(GetVersionParams),
91
92 ProjectCreate(ProjectCreateParams),
93 ProjectSwitch(ProjectSwitchParams),
94 ProjectDelete(ProjectDeleteParams),
95
96 ConfigSet(ConfigSetParams),
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct AddParams {
101 pub content: String,
102 #[serde(default)]
103 pub memory_type: MemoryType,
104 #[serde(default)]
105 pub tags: Vec<String>,
106 #[serde(skip_serializing_if = "Option::is_none")]
107 pub project_id: Option<Uuid>,
108 #[serde(default)]
109 pub is_global: bool,
110 #[serde(skip_serializing_if = "Option::is_none")]
111 pub working_dir: Option<String>,
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct GetParams {
116 pub id: Uuid,
117 #[serde(default)]
118 pub merge: bool,
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct UpdateParams {
123 pub id: Uuid,
124 #[serde(skip_serializing_if = "Option::is_none")]
125 pub content: Option<String>,
126 #[serde(skip_serializing_if = "Option::is_none")]
127 pub tags: Option<Vec<String>>,
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct DeleteParams {
132 pub id: Uuid,
133 #[serde(default)]
134 pub force: bool,
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct SearchParams {
139 #[serde(skip_serializing_if = "Option::is_none")]
140 pub text: Option<String>,
141 #[serde(default)]
142 pub mode: SearchMode,
143 #[serde(default)]
144 pub tags: Vec<String>,
145 #[serde(skip_serializing_if = "Option::is_none")]
146 pub time_range: Option<TimeRange>,
147 #[serde(skip_serializing_if = "Option::is_none")]
148 pub project_id: Option<Uuid>,
149 #[serde(default = "default_top_k")]
150 pub top_k: usize,
151 #[serde(default)]
152 pub min_importance: f32,
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct SearchMemoryParams {
157 pub query: String,
158 #[serde(skip_serializing_if = "Option::is_none")]
159 pub project_id: Option<Uuid>,
160 #[serde(default = "default_include_global")]
161 pub include_global: bool,
162 #[serde(default)]
163 pub project_only: bool,
164 #[serde(default)]
165 pub global_only: bool,
166 #[serde(default)]
167 pub cross_project: bool,
168 #[serde(skip_serializing_if = "Option::is_none")]
169 pub memory_type: Option<MemoryType>,
170 #[serde(default = "default_top_k")]
171 pub top_k: usize,
172 #[serde(default = "default_min_score")]
173 pub min_score: f32,
174 #[serde(skip_serializing_if = "Option::is_none")]
175 pub working_dir: Option<String>,
176}
177
/// Serde default for `SearchMemoryParams::include_global`: always `true`.
pub fn default_include_global() -> bool {
    true
}
/// Serde default for `SearchMemoryParams::min_score`.
///
/// Reads the `MEMREC_MIN_SCORE` environment variable, trims surrounding
/// whitespace and parses it as an `f32`. Returns `0.75` when the variable is
/// unset, is not valid Unicode, fails to parse, or parses to a non-finite
/// value.
pub fn default_min_score() -> f32 {
    const FALLBACK: f32 = 0.75;

    std::env::var("MEMREC_MIN_SCORE")
        .ok()
        .and_then(|raw| raw.trim().parse::<f32>().ok())
        // `f32::from_str` accepts "NaN" and "inf". A NaN threshold compares
        // false against every score, so treat non-finite overrides as invalid.
        .filter(|score| score.is_finite())
        .unwrap_or(FALLBACK)
}
185
186#[derive(Debug, Clone, Serialize, Deserialize, Default)]
187pub struct GetProjectInfoParams {
188 #[serde(default, skip_serializing_if = "Option::is_none")]
189 pub working_dir: Option<String>,
190}
191
192#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct GetVersionParams;
194
195#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
196#[serde(rename_all = "lowercase")]
197#[derive(Default)]
198pub enum SearchMode {
199 Exact,
200 Semantic,
201 #[default]
202 Hybrid,
203}
204
205
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct TimeRange {
208 pub start: DateTime<Utc>,
209 pub end: DateTime<Utc>,
210}
211
/// Serde default for the `top_k` fields of `SearchParams` and
/// `SearchMemoryParams`: always 10.
fn default_top_k() -> usize {
    10
}
213
214#[derive(Debug, Clone, Serialize, Deserialize)]
215pub struct ListParams {
216 #[serde(default)]
217 pub tags: Vec<String>,
218 #[serde(skip_serializing_if = "Option::is_none")]
219 pub memory_type: Option<MemoryType>,
220 #[serde(default = "default_limit")]
221 pub limit: usize,
222 #[serde(default)]
223 pub project_only: bool,
224 #[serde(default)]
225 pub global_only: bool,
226 #[serde(skip_serializing_if = "Option::is_none")]
227 pub project_id: Option<Uuid>,
228}
229
/// Serde default for `ListParams::limit`: always 20.
fn default_limit() -> usize {
    20
}
231
232#[derive(Debug, Clone, Serialize, Deserialize)]
233pub struct TagParams {
234 pub id: Uuid,
235 pub tag: String,
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct ProjectCreateParams {
240 pub name: String,
241 #[serde(skip_serializing_if = "Option::is_none")]
242 pub description: Option<String>,
243}
244
245#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct ProjectSwitchParams {
247 pub name: String,
248}
249
250#[derive(Debug, Clone, Serialize, Deserialize)]
251pub struct ProjectDeleteParams {
252 pub name: String,
253 #[serde(default)]
254 pub force: bool,
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
258pub struct ConfigSetParams {
259 pub key: String,
260 pub value: String,
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// `JsonRpcRequest::new` stamps version "2.0" and keeps the caller's values.
    #[test]
    fn test_request_creation() {
        let req = JsonRpcRequest::new(
            RequestAction::Add,
            Some(RequestParams::Add(AddParams {
                content: "test".to_string(),
                memory_type: MemoryType::Knowledge,
                tags: vec!["tag1".to_string()],
                project_id: None,
                is_global: false,
                working_dir: None,
            })),
            1,
        );

        assert_eq!(req.jsonrpc, "2.0");
        assert_eq!(req.method, RequestAction::Add);
        assert_eq!(req.id, 1);
    }

    /// A request survives a JSON round trip.
    #[test]
    fn test_request_serde() {
        let req = JsonRpcRequest::new(
            RequestAction::Get,
            Some(RequestParams::Get(GetParams {
                id: Uuid::nil(),
                merge: false,
            })),
            1,
        );

        let json = serde_json::to_string(&req).unwrap();
        let parsed: JsonRpcRequest = serde_json::from_str(&json).unwrap();

        assert_eq!(req.jsonrpc, parsed.jsonrpc);
        assert_eq!(req.method, parsed.method);
        assert_eq!(req.id, parsed.id);
    }

    #[test]
    fn test_search_params_defaults() {
        let params = SearchParams {
            text: None,
            mode: SearchMode::default(),
            tags: Vec::new(),
            time_range: None,
            project_id: None,
            top_k: default_top_k(),
            min_importance: 0.0,
        };

        assert_eq!(params.mode, SearchMode::Hybrid);
        assert_eq!(params.top_k, 10);
    }

    #[test]
    fn test_search_memory_params_defaults() {
        let params = SearchMemoryParams {
            query: "test".to_string(),
            project_id: None,
            include_global: default_include_global(),
            project_only: false,
            global_only: false,
            cross_project: false,
            memory_type: None,
            top_k: default_top_k(),
            min_score: default_min_score(),
            working_dir: None,
        };

        assert!(params.include_global);
        assert_eq!(params.top_k, 10);
        // `default_min_score` honours the MEMREC_MIN_SCORE override, so the
        // built-in 0.75 is only guaranteed when the variable is not set.
        if std::env::var_os("MEMREC_MIN_SCORE").is_none() {
            assert_eq!(params.min_score, 0.75);
        }
    }

    /// Search-memory parameters survive a JSON round trip.
    #[test]
    fn test_search_memory_params_serde() {
        let params = SearchMemoryParams {
            query: "authentication".to_string(),
            project_id: Some(Uuid::new_v4()),
            include_global: true,
            project_only: false,
            global_only: false,
            cross_project: false,
            memory_type: Some(MemoryType::Decision),
            top_k: 20,
            min_score: 0.8,
            working_dir: None,
        };

        let json = serde_json::to_string(&params).unwrap();
        let parsed: SearchMemoryParams = serde_json::from_str(&json).unwrap();

        assert_eq!(params.query, parsed.query);
        assert_eq!(params.project_id, parsed.project_id);
        assert_eq!(params.include_global, parsed.include_global);
        assert_eq!(params.top_k, parsed.top_k);
    }
}