1use serde::{Deserialize, Serialize};
4use std::fmt;
5use std::str::FromStr;
6
7use super::fields::ParseEnumError;
8
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10#[serde(rename_all = "snake_case")]
11pub enum MemoryAction {
12 Get,
13 Upsert,
14 Delete,
15 List,
16 Search,
17}
18
19impl fmt::Display for MemoryAction {
20 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21 match self {
22 Self::Get => write!(f, "get"),
23 Self::Upsert => write!(f, "upsert"),
24 Self::Delete => write!(f, "delete"),
25 Self::List => write!(f, "list"),
26 Self::Search => write!(f, "search"),
27 }
28 }
29}
30
31impl FromStr for MemoryAction {
32 type Err = ParseEnumError;
33
34 fn from_str(s: &str) -> Result<Self, Self::Err> {
35 match s {
36 "get" => Ok(Self::Get),
37 "upsert" => Ok(Self::Upsert),
38 "delete" => Ok(Self::Delete),
39 "list" => Ok(Self::List),
40 "search" => Ok(Self::Search),
41 _ => Err(ParseEnumError {
42 type_name: "MemoryAction",
43 value: s.to_owned(),
44 }),
45 }
46 }
47}
48
49#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
50#[serde(rename_all = "snake_case")]
51pub enum MemoryScope {
52 Node,
53 Session,
54 Project,
55 Global,
56}
57
58impl fmt::Display for MemoryScope {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 match self {
61 Self::Node => write!(f, "node"),
62 Self::Session => write!(f, "session"),
63 Self::Project => write!(f, "project"),
64 Self::Global => write!(f, "global"),
65 }
66 }
67}
68
69impl FromStr for MemoryScope {
70 type Err = ParseEnumError;
71
72 fn from_str(s: &str) -> Result<Self, Self::Err> {
73 match s {
74 "node" => Ok(Self::Node),
75 "session" => Ok(Self::Session),
76 "project" => Ok(Self::Project),
77 "global" => Ok(Self::Global),
78 _ => Err(ParseEnumError {
79 type_name: "MemoryScope",
80 value: s.to_owned(),
81 }),
82 }
83 }
84}
85
/// Time-to-live setting attached to a [`MemoryEntry`] (its optional `ttl` field).
///
/// The string form is `permanent`, `session`, or `duration:<value>`. `Display` and
/// `FromStr` use it, and so do the hand-written serde impls, which (de)serialize a
/// `MemoryTtl` as that single string.
///
/// The `<value>` of [`MemoryTtl::Duration`] is stored verbatim, and parsing only rejects
/// an empty value. The tests use ISO 8601 durations such as `P7D` / `PT1H`, but that
/// format is not validated here (NOTE(review): confirm what consumers expect).
///
/// Because the variant is public, `Duration(String::new())` can be constructed directly.
/// Its string form (`duration:`) will not parse back, so such a value does not round-trip.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MemoryTtl {
    /// String form: `permanent`.
    Permanent,
    /// String form: `session`.
    Session,
    /// String form: `duration:<value>`; holds `<value>` without the `duration:` prefix.
    Duration(String),
}
92
93impl fmt::Display for MemoryTtl {
94 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95 match self {
96 Self::Permanent => write!(f, "permanent"),
97 Self::Session => write!(f, "session"),
98 Self::Duration(d) => write!(f, "duration:{d}"),
99 }
100 }
101}
102
103impl FromStr for MemoryTtl {
104 type Err = ParseEnumError;
105
106 fn from_str(s: &str) -> Result<Self, Self::Err> {
107 match s {
108 "permanent" => Ok(Self::Permanent),
109 "session" => Ok(Self::Session),
110 _ if s.starts_with("duration:") => {
111 let duration = s.strip_prefix("duration:").unwrap().to_owned();
112 if duration.is_empty() {
113 return Err(ParseEnumError {
114 type_name: "MemoryTtl",
115 value: s.to_owned(),
116 });
117 }
118 Ok(Self::Duration(duration))
119 }
120 _ => Err(ParseEnumError {
121 type_name: "MemoryTtl",
122 value: s.to_owned(),
123 }),
124 }
125 }
126}
127
128impl Serialize for MemoryTtl {
129 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
130 serializer.serialize_str(&self.to_string())
131 }
132}
133
134impl<'de> Deserialize<'de> for MemoryTtl {
135 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
136 let s = String::deserialize(deserializer)?;
137 s.parse().map_err(serde::de::Error::custom)
138 }
139}
140
141#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
142pub struct MemoryEntry {
143 pub key: String,
144 pub topic: String,
145 pub action: MemoryAction,
146 #[serde(skip_serializing_if = "Option::is_none")]
147 pub value: Option<String>,
148 #[serde(skip_serializing_if = "Option::is_none")]
149 pub scope: Option<MemoryScope>,
150 #[serde(skip_serializing_if = "Option::is_none")]
151 pub ttl: Option<MemoryTtl>,
152 #[serde(skip_serializing_if = "Option::is_none")]
153 pub query: Option<String>,
154 #[serde(skip_serializing_if = "Option::is_none")]
155 pub max_results: Option<u32>,
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `MemoryAction` variant paired with its expected wire string.
    fn all_actions() -> [(&'static str, MemoryAction); 5] {
        [
            ("get", MemoryAction::Get),
            ("upsert", MemoryAction::Upsert),
            ("delete", MemoryAction::Delete),
            ("list", MemoryAction::List),
            ("search", MemoryAction::Search),
        ]
    }

    /// Every `MemoryScope` variant paired with its expected wire string.
    fn all_scopes() -> [(&'static str, MemoryScope); 4] {
        [
            ("node", MemoryScope::Node),
            ("session", MemoryScope::Session),
            ("project", MemoryScope::Project),
            ("global", MemoryScope::Global),
        ]
    }

    /// One `MemoryTtl` per variant (plus a second duration) paired with its wire string.
    fn sample_ttls() -> [(&'static str, MemoryTtl); 4] {
        [
            ("permanent", MemoryTtl::Permanent),
            ("session", MemoryTtl::Session),
            ("duration:P7D", MemoryTtl::Duration("P7D".to_owned())),
            ("duration:PT1H", MemoryTtl::Duration("PT1H".to_owned())),
        ]
    }

    /// Builds a `MemoryEntry` with the given identity fields and every optional field unset.
    fn bare_entry(key: &str, topic: &str, action: MemoryAction) -> MemoryEntry {
        MemoryEntry {
            key: key.to_owned(),
            topic: topic.to_owned(),
            action,
            value: None,
            scope: None,
            ttl: None,
            query: None,
            max_results: None,
        }
    }

    // ----- MemoryAction -----

    #[test]
    fn test_memory_action_from_str_valid_returns_ok() {
        for (text, expected) in all_actions() {
            assert_eq!(
                text.parse::<MemoryAction>().unwrap(),
                expected,
                "input {:?}",
                text
            );
        }
    }

    #[test]
    fn test_memory_action_from_str_invalid_returns_error() {
        let err = "update".parse::<MemoryAction>().unwrap_err();
        assert_eq!(err.type_name, "MemoryAction");
        assert_eq!(err.value, "update");
    }

    #[test]
    fn test_memory_action_from_str_requires_exact_match() {
        for bad in ["", "GET", "Get", " get", "get "] {
            assert!(bad.parse::<MemoryAction>().is_err(), "input {:?}", bad);
        }
    }

    #[test]
    fn test_memory_action_display_roundtrip() {
        for (text, action) in all_actions() {
            let shown = action.to_string();
            assert_eq!(shown, text);
            assert_eq!(shown.parse::<MemoryAction>().unwrap(), action);
        }
    }

    #[test]
    fn test_memory_action_serde_matches_display() {
        for (text, action) in all_actions() {
            let json = serde_json::to_string(&action).unwrap();
            assert_eq!(json, format!("\"{text}\""));
            let back: MemoryAction = serde_json::from_str(&json).unwrap();
            assert_eq!(back, action);
        }
    }

    // ----- MemoryScope -----

    #[test]
    fn test_memory_scope_from_str_valid_returns_ok() {
        for (text, expected) in all_scopes() {
            assert_eq!(
                text.parse::<MemoryScope>().unwrap(),
                expected,
                "input {:?}",
                text
            );
        }
    }

    #[test]
    fn test_memory_scope_from_str_invalid_returns_error() {
        let err = "workspace".parse::<MemoryScope>().unwrap_err();
        assert_eq!(err.type_name, "MemoryScope");
        assert_eq!(err.value, "workspace");
    }

    #[test]
    fn test_memory_scope_from_str_requires_exact_match() {
        for bad in ["", "NODE", "Global", " project", "session "] {
            assert!(bad.parse::<MemoryScope>().is_err(), "input {:?}", bad);
        }
    }

    #[test]
    fn test_memory_scope_display_roundtrip() {
        for (text, scope) in all_scopes() {
            let shown = scope.to_string();
            assert_eq!(shown, text);
            assert_eq!(shown.parse::<MemoryScope>().unwrap(), scope);
        }
    }

    #[test]
    fn test_memory_scope_serde_matches_display() {
        for (text, scope) in all_scopes() {
            let json = serde_json::to_string(&scope).unwrap();
            assert_eq!(json, format!("\"{text}\""));
            let back: MemoryScope = serde_json::from_str(&json).unwrap();
            assert_eq!(back, scope);
        }
    }

    // ----- MemoryTtl -----

    #[test]
    fn test_memory_ttl_from_str_valid_returns_ok() {
        for (text, expected) in sample_ttls() {
            assert_eq!(
                text.parse::<MemoryTtl>().unwrap(),
                expected,
                "input {:?}",
                text
            );
        }
    }

    #[test]
    fn test_memory_ttl_from_str_empty_duration_returns_error() {
        let err = "duration:".parse::<MemoryTtl>().unwrap_err();
        assert_eq!(err.type_name, "MemoryTtl");
        assert_eq!(err.value, "duration:");
    }

    #[test]
    fn test_memory_ttl_from_str_invalid_returns_error() {
        let err = "forever".parse::<MemoryTtl>().unwrap_err();
        assert_eq!(err.type_name, "MemoryTtl");
        assert_eq!(err.value, "forever");
        for bad in ["", "duration", "Duration:P7D", "PERMANENT", " session"] {
            assert!(bad.parse::<MemoryTtl>().is_err(), "input {:?}", bad);
        }
    }

    #[test]
    fn test_memory_ttl_duration_value_is_kept_verbatim() {
        // Only the leading `duration:` is removed; the remainder is stored untouched.
        assert_eq!(
            "duration:duration:x".parse::<MemoryTtl>().unwrap(),
            MemoryTtl::Duration("duration:x".to_owned())
        );
    }

    #[test]
    fn test_memory_ttl_display_roundtrip() {
        for (text, ttl) in sample_ttls() {
            let shown = ttl.to_string();
            assert_eq!(shown, text);
            assert_eq!(shown.parse::<MemoryTtl>().unwrap(), ttl);
        }
    }

    #[test]
    fn test_memory_ttl_serde_roundtrip_uses_display_string() {
        for (text, ttl) in sample_ttls() {
            let json = serde_json::to_string(&ttl).unwrap();
            assert_eq!(json, format!("\"{text}\""));
            let back: MemoryTtl = serde_json::from_str(&json).unwrap();
            assert_eq!(back, ttl);
        }
    }

    #[test]
    fn test_memory_ttl_deserialize_invalid_returns_error() {
        for bad in [r#""forever""#, r#""duration:""#, "42", "null"] {
            assert!(
                serde_json::from_str::<MemoryTtl>(bad).is_err(),
                "input {:?}",
                bad
            );
        }
    }

    // ----- MemoryEntry -----

    #[test]
    fn test_memory_entry_upsert_serde_roundtrip() {
        let entry = MemoryEntry {
            value: Some("row_to_column uses get()".to_owned()),
            scope: Some(MemoryScope::Project),
            ttl: Some(MemoryTtl::Permanent),
            ..bare_entry("repo.pattern", "rust.repository", MemoryAction::Upsert)
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert_eq!(
            json,
            r#"{"key":"repo.pattern","topic":"rust.repository","action":"upsert","value":"row_to_column uses get()","scope":"project","ttl":"permanent"}"#
        );
        let back: MemoryEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(entry, back);
    }

    #[test]
    fn test_memory_entry_search_serde_roundtrip() {
        let entry = MemoryEntry {
            query: Some("how are optional fields handled".to_owned()),
            max_results: Some(5),
            ..bare_entry("search.patterns", "rust.repository", MemoryAction::Search)
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert_eq!(
            json,
            r#"{"key":"search.patterns","topic":"rust.repository","action":"search","query":"how are optional fields handled","max_results":5}"#
        );
        let back: MemoryEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(entry, back);
    }

    #[test]
    fn test_memory_entry_duration_ttl_roundtrip() {
        let json =
            r#"{"key":"k","topic":"t","action":"upsert","value":"v","ttl":"duration:PT1H"}"#;
        let entry: MemoryEntry = serde_json::from_str(json).unwrap();
        assert_eq!(entry.ttl, Some(MemoryTtl::Duration("PT1H".to_owned())));
        assert_eq!(entry.scope, None);
        assert_eq!(serde_json::to_string(&entry).unwrap(), json);
    }

    #[test]
    fn test_memory_entry_optional_fields_absent() {
        let entry = bare_entry("test.key", "test", MemoryAction::Get);
        let json = serde_json::to_string(&entry).unwrap();
        assert_eq!(json, r#"{"key":"test.key","topic":"test","action":"get"}"#);
    }

    #[test]
    fn test_memory_entry_missing_optional_fields_deserialize_as_none() {
        let entry: MemoryEntry =
            serde_json::from_str(r#"{"key":"test.key","topic":"test","action":"get"}"#).unwrap();
        assert_eq!(entry, bare_entry("test.key", "test", MemoryAction::Get));
    }

    #[test]
    fn test_memory_entry_invalid_fields_fail_to_deserialize() {
        let bad_ttl = r#"{"key":"k","topic":"t","action":"upsert","ttl":"forever"}"#;
        assert!(serde_json::from_str::<MemoryEntry>(bad_ttl).is_err());
        let bad_action = r#"{"key":"k","topic":"t","action":"update"}"#;
        assert!(serde_json::from_str::<MemoryEntry>(bad_action).is_err());
        let missing_key = r#"{"topic":"t","action":"get"}"#;
        assert!(serde_json::from_str::<MemoryEntry>(missing_key).is_err());
    }
}