Skip to main content

agm_core/model/
memory.rs

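//! Memory model types (spec S28): the `MemoryAction`, `MemoryScope`, and `MemoryTtl`
//! enums and the `MemoryEntry` record that carries them.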
2
3use serde::{Deserialize, Serialize};
4use std::fmt;
5use std::str::FromStr;
6
7use super::fields::ParseEnumError;
8
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10#[serde(rename_all = "snake_case")]
11pub enum MemoryAction {
12    Get,
13    Upsert,
14    Delete,
15    List,
16    Search,
17}
18
19impl fmt::Display for MemoryAction {
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        match self {
22            Self::Get => write!(f, "get"),
23            Self::Upsert => write!(f, "upsert"),
24            Self::Delete => write!(f, "delete"),
25            Self::List => write!(f, "list"),
26            Self::Search => write!(f, "search"),
27        }
28    }
29}
30
31impl FromStr for MemoryAction {
32    type Err = ParseEnumError;
33
34    fn from_str(s: &str) -> Result<Self, Self::Err> {
35        match s {
36            "get" => Ok(Self::Get),
37            "upsert" => Ok(Self::Upsert),
38            "delete" => Ok(Self::Delete),
39            "list" => Ok(Self::List),
40            "search" => Ok(Self::Search),
41            _ => Err(ParseEnumError {
42                type_name: "MemoryAction",
43                value: s.to_owned(),
44            }),
45        }
46    }
47}
48
49#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
50#[serde(rename_all = "snake_case")]
51pub enum MemoryScope {
52    Node,
53    Session,
54    Project,
55    Global,
56}
57
58impl fmt::Display for MemoryScope {
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        match self {
61            Self::Node => write!(f, "node"),
62            Self::Session => write!(f, "session"),
63            Self::Project => write!(f, "project"),
64            Self::Global => write!(f, "global"),
65        }
66    }
67}
68
69impl FromStr for MemoryScope {
70    type Err = ParseEnumError;
71
72    fn from_str(s: &str) -> Result<Self, Self::Err> {
73        match s {
74            "node" => Ok(Self::Node),
75            "session" => Ok(Self::Session),
76            "project" => Ok(Self::Project),
77            "global" => Ok(Self::Global),
78            _ => Err(ParseEnumError {
79                type_name: "MemoryScope",
80                value: s.to_owned(),
81            }),
82        }
83    }
84}
85
/// TTL setting attached to a [`MemoryEntry`].
///
/// Text form (used by `Display`, `FromStr`, and the string-based serde impls):
/// `permanent`, `session`, or `duration:<value>`.
///
/// Note: `Duration(String::new())` can be constructed directly, but its text form
/// `duration:` is rejected by `FromStr`, so such a value will not round-trip.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MemoryTtl {
    /// Text form `permanent`.
    Permanent,
    /// Text form `session`.
    Session,
    /// The raw text after the `duration:` prefix. Parsing only checks that it is non-empty;
    /// the tests use ISO-8601-style values such as `P7D` and `PT1H`.
    // NOTE(review): the duration syntax is not validated here — confirm whether spec S28
    // requires ISO 8601 and where that is enforced.
    Duration(String),
}
92
93impl fmt::Display for MemoryTtl {
94    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        match self {
96            Self::Permanent => write!(f, "permanent"),
97            Self::Session => write!(f, "session"),
98            Self::Duration(d) => write!(f, "duration:{d}"),
99        }
100    }
101}
102
103impl FromStr for MemoryTtl {
104    type Err = ParseEnumError;
105
106    fn from_str(s: &str) -> Result<Self, Self::Err> {
107        match s {
108            "permanent" => Ok(Self::Permanent),
109            "session" => Ok(Self::Session),
110            _ if s.starts_with("duration:") => {
111                let duration = s.strip_prefix("duration:").unwrap().to_owned();
112                if duration.is_empty() {
113                    return Err(ParseEnumError {
114                        type_name: "MemoryTtl",
115                        value: s.to_owned(),
116                    });
117                }
118                Ok(Self::Duration(duration))
119            }
120            _ => Err(ParseEnumError {
121                type_name: "MemoryTtl",
122                value: s.to_owned(),
123            }),
124        }
125    }
126}
127
128impl Serialize for MemoryTtl {
129    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
130        serializer.serialize_str(&self.to_string())
131    }
132}
133
134impl<'de> Deserialize<'de> for MemoryTtl {
135    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
136        let s = String::deserialize(deserializer)?;
137        s.parse().map_err(serde::de::Error::custom)
138    }
139}
140
/// A single memory record or request (spec S28).
///
/// `key`, `topic`, and `action` are always present. Every other field is optional and is
/// omitted from serialized output when `None`.
// NOTE(review): which optional fields apply presumably depends on `action`. The tests pair
// `value`/`scope`/`ttl` with `Upsert` and `query`/`max_results` with `Search`, but the type
// does not enforce this — confirm against spec S28.
// Field order is kept as-is: it fixes the serialized key order and the `Debug` output.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Entry key (the tests use dotted names such as `repo.pattern`).
    pub key: String,
    /// Topic string (the tests use values such as `rust.repository`).
    pub topic: String,
    /// Requested operation; serialized as its lowercase name.
    pub action: MemoryAction,
    /// Optional value text; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub value: Option<String>,
    /// Optional scope; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<MemoryScope>,
    /// Optional TTL, serialized in its string form; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ttl: Option<MemoryTtl>,
    /// Optional query text; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub query: Option<String>,
    /// Optional result limit; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_results: Option<u32>,
}
157
#[cfg(test)]
mod tests {
    use super::*;

    // ---- MemoryAction -------------------------------------------------------------------

    #[test]
    fn test_memory_action_from_str_valid_returns_ok() {
        assert_eq!("get".parse::<MemoryAction>().unwrap(), MemoryAction::Get);
        assert_eq!(
            "upsert".parse::<MemoryAction>().unwrap(),
            MemoryAction::Upsert
        );
        assert_eq!(
            "delete".parse::<MemoryAction>().unwrap(),
            MemoryAction::Delete
        );
        assert_eq!("list".parse::<MemoryAction>().unwrap(), MemoryAction::List);
        assert_eq!(
            "search".parse::<MemoryAction>().unwrap(),
            MemoryAction::Search
        );
    }

    #[test]
    fn test_memory_action_from_str_invalid_returns_error() {
        let err = "update".parse::<MemoryAction>().unwrap_err();
        assert_eq!(err.type_name, "MemoryAction");
        assert_eq!(err.value, "update");
    }

    #[test]
    fn test_memory_action_from_str_is_case_sensitive() {
        assert!("GET".parse::<MemoryAction>().is_err());
        assert!("Search".parse::<MemoryAction>().is_err());
    }

    #[test]
    fn test_memory_action_display_roundtrip() {
        for a in [
            MemoryAction::Get,
            MemoryAction::Upsert,
            MemoryAction::Delete,
            MemoryAction::List,
            MemoryAction::Search,
        ] {
            let text = a.to_string();
            assert_eq!(text.parse::<MemoryAction>().unwrap(), a);
        }
    }

    #[test]
    fn test_memory_action_serde_matches_display() {
        for a in [
            MemoryAction::Get,
            MemoryAction::Upsert,
            MemoryAction::Delete,
            MemoryAction::List,
            MemoryAction::Search,
        ] {
            let json = serde_json::to_string(&a).unwrap();
            assert_eq!(json, format!("\"{a}\""));
            let back: MemoryAction = serde_json::from_str(&json).unwrap();
            assert_eq!(back, a);
        }
    }

    // ---- MemoryScope --------------------------------------------------------------------

    #[test]
    fn test_memory_scope_from_str_valid_returns_ok() {
        assert_eq!("node".parse::<MemoryScope>().unwrap(), MemoryScope::Node);
        assert_eq!(
            "session".parse::<MemoryScope>().unwrap(),
            MemoryScope::Session
        );
        assert_eq!(
            "project".parse::<MemoryScope>().unwrap(),
            MemoryScope::Project
        );
        assert_eq!(
            "global".parse::<MemoryScope>().unwrap(),
            MemoryScope::Global
        );
    }

    #[test]
    fn test_memory_scope_from_str_invalid_returns_error() {
        let err = "workspace".parse::<MemoryScope>().unwrap_err();
        assert_eq!(err.type_name, "MemoryScope");
        assert_eq!(err.value, "workspace");
    }

    #[test]
    fn test_memory_scope_display_roundtrip() {
        for s in [
            MemoryScope::Node,
            MemoryScope::Session,
            MemoryScope::Project,
            MemoryScope::Global,
        ] {
            let text = s.to_string();
            assert_eq!(text.parse::<MemoryScope>().unwrap(), s);
        }
    }

    #[test]
    fn test_memory_scope_serde_matches_display() {
        for s in [
            MemoryScope::Node,
            MemoryScope::Session,
            MemoryScope::Project,
            MemoryScope::Global,
        ] {
            let json = serde_json::to_string(&s).unwrap();
            assert_eq!(json, format!("\"{s}\""));
            let back: MemoryScope = serde_json::from_str(&json).unwrap();
            assert_eq!(back, s);
        }
    }

    // ---- MemoryTtl ----------------------------------------------------------------------

    #[test]
    fn test_memory_ttl_from_str_permanent_returns_ok() {
        assert_eq!(
            "permanent".parse::<MemoryTtl>().unwrap(),
            MemoryTtl::Permanent
        );
    }

    #[test]
    fn test_memory_ttl_from_str_session_returns_ok() {
        assert_eq!("session".parse::<MemoryTtl>().unwrap(), MemoryTtl::Session);
    }

    #[test]
    fn test_memory_ttl_from_str_duration_returns_ok() {
        assert_eq!(
            "duration:P7D".parse::<MemoryTtl>().unwrap(),
            MemoryTtl::Duration("P7D".to_owned())
        );
    }

    #[test]
    fn test_memory_ttl_from_str_duration_pt1h_returns_ok() {
        assert_eq!(
            "duration:PT1H".parse::<MemoryTtl>().unwrap(),
            MemoryTtl::Duration("PT1H".to_owned())
        );
    }

    #[test]
    fn test_memory_ttl_from_str_empty_duration_returns_error() {
        let err = "duration:".parse::<MemoryTtl>().unwrap_err();
        assert_eq!(err.type_name, "MemoryTtl");
        assert_eq!(err.value, "duration:");
    }

    #[test]
    fn test_memory_ttl_from_str_invalid_returns_error() {
        let err = "forever".parse::<MemoryTtl>().unwrap_err();
        assert_eq!(err.type_name, "MemoryTtl");
        assert_eq!(err.value, "forever");
    }

    #[test]
    fn test_memory_ttl_display_roundtrip() {
        for t in [
            MemoryTtl::Permanent,
            MemoryTtl::Session,
            MemoryTtl::Duration("P30D".to_owned()),
        ] {
            let text = t.to_string();
            assert_eq!(text.parse::<MemoryTtl>().unwrap(), t);
        }
    }

    #[test]
    fn test_memory_ttl_serde_roundtrip() {
        for t in [
            MemoryTtl::Permanent,
            MemoryTtl::Session,
            MemoryTtl::Duration("P7D".to_owned()),
        ] {
            let json = serde_json::to_string(&t).unwrap();
            let back: MemoryTtl = serde_json::from_str(&json).unwrap();
            assert_eq!(t, back);
        }
    }

    #[test]
    fn test_memory_ttl_serializes_as_display_string() {
        assert_eq!(
            serde_json::to_string(&MemoryTtl::Permanent).unwrap(),
            "\"permanent\""
        );
        assert_eq!(
            serde_json::to_string(&MemoryTtl::Session).unwrap(),
            "\"session\""
        );
        assert_eq!(
            serde_json::to_string(&MemoryTtl::Duration("PT1H".to_owned())).unwrap(),
            "\"duration:PT1H\""
        );
    }

    #[test]
    fn test_memory_ttl_deserialize_invalid_returns_error() {
        assert!(serde_json::from_str::<MemoryTtl>("\"forever\"").is_err());
        assert!(serde_json::from_str::<MemoryTtl>("\"duration:\"").is_err());
        assert!(serde_json::from_str::<MemoryTtl>("7").is_err());
    }

    // ---- MemoryEntry --------------------------------------------------------------------

    #[test]
    fn test_memory_entry_upsert_serde_roundtrip() {
        let entry = MemoryEntry {
            key: "repo.pattern".to_owned(),
            topic: "rust.repository".to_owned(),
            action: MemoryAction::Upsert,
            value: Some("row_to_column uses get()".to_owned()),
            scope: Some(MemoryScope::Project),
            ttl: Some(MemoryTtl::Permanent),
            query: None,
            max_results: None,
        };
        let json = serde_json::to_string(&entry).unwrap();
        let back: MemoryEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(entry, back);
    }

    #[test]
    fn test_memory_entry_search_serde_roundtrip() {
        let entry = MemoryEntry {
            key: "search.patterns".to_owned(),
            topic: "rust.repository".to_owned(),
            action: MemoryAction::Search,
            value: None,
            scope: None,
            ttl: None,
            query: Some("how are optional fields handled".to_owned()),
            max_results: Some(5),
        };
        let json = serde_json::to_string(&entry).unwrap();
        let back: MemoryEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(entry, back);
    }

    #[test]
    fn test_memory_entry_optional_fields_absent() {
        let entry = MemoryEntry {
            key: "test.key".to_owned(),
            topic: "test".to_owned(),
            action: MemoryAction::Get,
            value: None,
            scope: None,
            ttl: None,
            query: None,
            max_results: None,
        };
        let json = serde_json::to_string(&entry).unwrap();
        // An exact match (rather than substring checks) also pins the field order and
        // cannot be fooled by a key or topic that happens to contain a field name.
        assert_eq!(json, r#"{"key":"test.key","topic":"test","action":"get"}"#);
    }

    #[test]
    fn test_memory_entry_deserialize_minimal_json() {
        let entry: MemoryEntry =
            serde_json::from_str(r#"{"key":"k","topic":"t","action":"list"}"#).unwrap();
        assert_eq!(
            entry,
            MemoryEntry {
                key: "k".to_owned(),
                topic: "t".to_owned(),
                action: MemoryAction::List,
                value: None,
                scope: None,
                ttl: None,
                query: None,
                max_results: None,
            }
        );
    }

    #[test]
    fn test_memory_entry_deserialize_invalid_action_returns_error() {
        let json = r#"{"key":"k","topic":"t","action":"update"}"#;
        assert!(serde_json::from_str::<MemoryEntry>(json).is_err());
    }

    #[test]
    fn test_memory_entry_deserialize_invalid_ttl_returns_error() {
        let json = r#"{"key":"k","topic":"t","action":"upsert","ttl":"duration:"}"#;
        assert!(serde_json::from_str::<MemoryEntry>(json).is_err());
    }
}