Skip to main content

cognee_models/
memory.rs

1//! Discriminated-union memory entries for `remember()` typed dispatch.
2//!
3//! Mirrors Python's `cognee/memory/entries.py`. Typed payloads let callers
4//! pass rich structured data to `cognee.remember()` — Q&A turns, agent
5//! trace steps, feedback attachments — in addition to the legacy
6//! "blob of text/files" shape. Each entry carries a literal `type`
7//! discriminator so the `remember_entry()` dispatch can route to the
8//! right `SessionManager` method.
9//!
10//! Wire shape (Decision 10): the `type` discriminator stays snake_case
11//! (`"qa"` / `"trace"` / `"feedback"`) per Python's
12//! `Literal["qa"|"trace"|"feedback"]`, while every multi-word inner
13//! field name is camelCase on the wire (`feedbackText`, `originFunction`,
14//! `methodParams`, etc.). Snake-case `serde(alias)` attributes accept
15//! the legacy snake_case form on input (Python `populate_by_name=True`
16//! parity).
17
18use serde::{Deserialize, Serialize};
19
20/// Tagged union of typed memory payloads dispatched by `remember_entry()`.
21///
22/// Python parity: `cognee/memory/entries.py:67` (`Union[QAEntry,
23/// TraceEntry, FeedbackEntry]`). The `type` discriminator on the wire
24/// stays snake_case (`"qa"` / `"trace"` / `"feedback"`).
25#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(tag = "type", rename_all = "snake_case")]
27pub enum MemoryEntry {
28    /// A Q&A turn stored in the session cache. Dispatched to
29    /// `SessionManager::save_qa` (+ optional `update_qa`).
30    Qa(QAEntry),
31    /// One step of an agent trace. Dispatched to
32    /// `SessionManager::add_agent_trace_step`.
33    Trace(TraceEntry),
34    /// Feedback attached to an existing QA entry. Dispatched to
35    /// `SessionManager::add_feedback`.
36    Feedback(FeedbackEntry),
37}
38
39impl MemoryEntry {
40    /// Python parity helper — the lowercase string discriminator
41    /// (`"qa"` / `"trace"` / `"feedback"`) populated on
42    /// `RememberResult.entry_type`.
43    pub fn type_str(&self) -> &'static str {
44        match self {
45            MemoryEntry::Qa(_) => "qa",
46            MemoryEntry::Trace(_) => "trace",
47            MemoryEntry::Feedback(_) => "feedback",
48        }
49    }
50}
51
52/// A Q&A turn stored in the session cache.
53///
54/// Python parity: `cognee/memory/entries.py:18-31`. `context` defaults
55/// to `""`; the three optional feedback fields default to `None`.
56#[derive(Debug, Clone, Serialize, Deserialize)]
57#[serde(rename_all = "camelCase")]
58pub struct QAEntry {
59    /// The user question.
60    pub question: String,
61    /// The assistant answer.
62    pub answer: String,
63    /// Optional retrieval context. Defaults to `""` (Python parity).
64    #[serde(default)]
65    pub context: String,
66    /// Optional free-form feedback string.
67    #[serde(default, alias = "feedback_text")]
68    pub feedback_text: Option<String>,
69    /// Optional 1..=5 feedback score (validated downstream by
70    /// `SessionManager::add_feedback`).
71    #[serde(default, alias = "feedback_score")]
72    pub feedback_score: Option<i32>,
73    /// Optional graph element ids consulted to produce the answer.
74    /// Wire shape mirrors Python's `dict` — unconstrained `serde_json::Value`.
75    #[serde(default, alias = "used_graph_element_ids")]
76    pub used_graph_element_ids: Option<serde_json::Value>,
77}
78
79/// One step of an agent trace.
80///
81/// Python parity: `cognee/memory/entries.py:34-50`. `status` defaults
82/// to `"success"`; `generate_feedback_with_llm` defaults to `false`;
83/// the three string fields (`memory_query`, `memory_context`,
84/// `error_message`) default to `""`.
85#[derive(Debug, Clone, Serialize, Deserialize)]
86#[serde(rename_all = "camelCase")]
87pub struct TraceEntry {
88    /// Name of the originating function/tool.
89    #[serde(alias = "origin_function")]
90    pub origin_function: String,
91    /// Free-form per Python validator; typically `"success"` / `"error"`.
92    #[serde(default = "default_trace_status")]
93    pub status: String,
94    /// Method parameters (wire: `methodParams`). Optional so callers
95    /// may omit it; converted to `Value::Null` when dispatched.
96    #[serde(default, alias = "method_params")]
97    pub method_params: Option<serde_json::Value>,
98    /// Optional method return value.
99    #[serde(default, alias = "method_return_value")]
100    pub method_return_value: Option<serde_json::Value>,
101    /// Memory query string. Defaults to `""`.
102    #[serde(default, alias = "memory_query")]
103    pub memory_query: String,
104    /// Memory context string. Defaults to `""`.
105    #[serde(default, alias = "memory_context")]
106    pub memory_context: String,
107    /// Error message string. Defaults to `""`.
108    #[serde(default, alias = "error_message")]
109    pub error_message: String,
110    /// If `true`, instructs the dispatcher to generate `session_feedback`
111    /// via an LLM call. **TODO(LIB-01-followup)**: LLM plumbing not in
112    /// scope for LIB-01; the dispatch passes `session_feedback = ""`.
113    #[serde(default, alias = "generate_feedback_with_llm")]
114    pub generate_feedback_with_llm: bool,
115}
116
117/// Feedback attached to an existing QA entry.
118///
119/// Python parity: `cognee/memory/entries.py:53-64`.
120#[derive(Debug, Clone, Serialize, Deserialize)]
121#[serde(rename_all = "camelCase")]
122pub struct FeedbackEntry {
123    /// QA id this feedback is attached to (required).
124    #[serde(alias = "qa_id")]
125    pub qa_id: String,
126    /// Optional free-form feedback string.
127    #[serde(default, alias = "feedback_text")]
128    pub feedback_text: Option<String>,
129    /// Optional 1..=5 feedback score.
130    #[serde(default, alias = "feedback_score")]
131    pub feedback_score: Option<i32>,
132}
133
/// Serde default for `TraceEntry::status`: an omitted status means the
/// step succeeded (Python parity).
fn default_trace_status() -> String {
    String::from("success")
}
137
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Serializes `entry` to JSON, parses it back into a `MemoryEntry` and
    /// re-serializes it. Returns `(first, second)` so callers can assert the
    /// wire representation survives a full round trip unchanged.
    fn round_trip(entry: &MemoryEntry) -> (serde_json::Value, serde_json::Value) {
        let first = serde_json::to_value(entry).expect("serialize");
        let back: MemoryEntry =
            serde_json::from_value(first.clone()).expect("deserialize serialized entry");
        let second = serde_json::to_value(&back).expect("re-serialize");
        (first, second)
    }

    #[test]
    fn test_round_trip_memory_entry_qa_json() {
        // camelCase wire input
        let camel = r#"{
            "type": "qa",
            "question": "Q?",
            "answer": "A.",
            "feedbackText": "good",
            "feedbackScore": 5,
            "usedGraphElementIds": {"node_ids": ["n1"], "edge_ids": []}
        }"#;
        let entry: MemoryEntry = serde_json::from_str(camel).expect("camelCase parse");
        match entry {
            MemoryEntry::Qa(ref q) => {
                assert_eq!(q.question, "Q?");
                assert_eq!(q.answer, "A.");
                assert_eq!(q.context, "", "context defaults to empty string");
                assert_eq!(q.feedback_text.as_deref(), Some("good"));
                assert_eq!(q.feedback_score, Some(5));
                assert!(q.used_graph_element_ids.is_some());
            }
            other => panic!("expected MemoryEntry::Qa, got {other:?}"),
        }

        // snake_case alias parity
        let snake = r#"{
            "type": "qa",
            "question": "Q?",
            "answer": "A.",
            "feedback_text": "good",
            "feedback_score": 4
        }"#;
        let entry: MemoryEntry = serde_json::from_str(snake).expect("snake_case alias parse");
        match entry {
            MemoryEntry::Qa(q) => {
                assert_eq!(q.feedback_text.as_deref(), Some("good"));
                assert_eq!(q.feedback_score, Some(4));
                assert_eq!(q.context, "");
            }
            other => panic!("expected MemoryEntry::Qa, got {other:?}"),
        }

        // Minimal QAEntry — only required fields.
        let minimal = r#"{"type":"qa","question":"q","answer":"a"}"#;
        let entry: MemoryEntry = serde_json::from_str(minimal).expect("minimal parse");
        match entry {
            MemoryEntry::Qa(q) => {
                assert_eq!(q.context, "");
                assert!(q.feedback_text.is_none());
                assert!(q.feedback_score.is_none());
                assert!(q.used_graph_element_ids.is_none());
            }
            other => panic!("expected MemoryEntry::Qa, got {other:?}"),
        }

        // Serialization emits camelCase + the snake_case `type` discriminator.
        let entry = MemoryEntry::Qa(QAEntry {
            question: "q".into(),
            answer: "a".into(),
            context: "".into(),
            feedback_text: Some("nice".into()),
            feedback_score: Some(3),
            used_graph_element_ids: Some(serde_json::json!({"node_ids": ["n1"]})),
        });
        let s = serde_json::to_string(&entry).expect("serialize");
        assert!(
            s.contains("\"type\":\"qa\""),
            "discriminator stays snake_case: {s}"
        );
        assert!(
            s.contains("\"feedbackText\":\"nice\""),
            "camelCase wire: {s}"
        );
        assert!(s.contains("\"feedbackScore\":3"), "camelCase wire: {s}");
        assert!(s.contains("\"usedGraphElementIds\""), "camelCase wire: {s}");

        // A real round trip: what we emit must parse back to the same wire form.
        let (first, second) = round_trip(&entry);
        assert_eq!(first, second, "QA entry must survive serialize/deserialize");
    }

    #[test]
    fn test_round_trip_memory_entry_trace_json() {
        // camelCase wire input with all fields.
        let camel = r#"{
            "type": "trace",
            "originFunction": "search",
            "status": "error",
            "methodParams": {"q": "hello"},
            "methodReturnValue": {"hits": 3},
            "memoryQuery": "what?",
            "memoryContext": "context",
            "errorMessage": "boom",
            "generateFeedbackWithLlm": true
        }"#;
        let entry: MemoryEntry = serde_json::from_str(camel).expect("camelCase trace parse");
        match entry {
            MemoryEntry::Trace(t) => {
                assert_eq!(t.origin_function, "search");
                assert_eq!(t.status, "error");
                assert_eq!(t.memory_query, "what?");
                assert_eq!(t.memory_context, "context");
                assert_eq!(t.error_message, "boom");
                assert!(t.generate_feedback_with_llm);
                assert!(t.method_params.is_some());
                assert!(t.method_return_value.is_some());
            }
            other => panic!("expected MemoryEntry::Trace, got {other:?}"),
        }

        // snake_case alias parity + defaults.
        let snake = r#"{
            "type": "trace",
            "origin_function": "fn",
            "method_params": null,
            "method_return_value": null
        }"#;
        let entry: MemoryEntry = serde_json::from_str(snake).expect("snake_case trace parse");
        match entry {
            MemoryEntry::Trace(t) => {
                assert_eq!(t.origin_function, "fn");
                assert_eq!(t.status, "success", "status defaults to success");
                assert_eq!(t.memory_query, "");
                assert_eq!(t.memory_context, "");
                assert_eq!(t.error_message, "");
                assert!(!t.generate_feedback_with_llm);
                // An explicit JSON `null` for an `Option<Value>` field
                // deserializes to `None`, not `Some(Value::Null)`: serde's
                // `Option` impl consumes the `null` itself.
                assert!(t.method_params.is_none());
                assert!(t.method_return_value.is_none());
            }
            other => panic!("expected MemoryEntry::Trace, got {other:?}"),
        }

        // Serialization uses camelCase + snake-case discriminator.
        let entry = MemoryEntry::Trace(TraceEntry {
            origin_function: "f".into(),
            status: "success".into(),
            method_params: Some(serde_json::json!({"k": "v"})),
            method_return_value: None,
            memory_query: "".into(),
            memory_context: "".into(),
            error_message: "".into(),
            generate_feedback_with_llm: false,
        });
        let s = serde_json::to_string(&entry).expect("serialize trace");
        assert!(s.contains("\"type\":\"trace\""));
        assert!(s.contains("\"originFunction\":\"f\""));
        assert!(s.contains("\"methodParams\""));
        assert!(s.contains("\"generateFeedbackWithLlm\":false"));

        let (first, second) = round_trip(&entry);
        assert_eq!(first, second, "trace entry must survive serialize/deserialize");
    }

    #[test]
    fn test_round_trip_memory_entry_feedback_json() {
        // camelCase wire input.
        let camel = r#"{
            "type": "feedback",
            "qaId": "abc-123",
            "feedbackText": "great",
            "feedbackScore": 5
        }"#;
        let entry: MemoryEntry = serde_json::from_str(camel).expect("camelCase feedback parse");
        match entry {
            MemoryEntry::Feedback(ref f) => {
                assert_eq!(f.qa_id, "abc-123");
                assert_eq!(f.feedback_text.as_deref(), Some("great"));
                assert_eq!(f.feedback_score, Some(5));
            }
            other => panic!("expected MemoryEntry::Feedback, got {other:?}"),
        }

        // snake_case alias.
        let snake = r#"{
            "type": "feedback",
            "qa_id": "xyz",
            "feedback_text": "ok"
        }"#;
        let entry: MemoryEntry = serde_json::from_str(snake).expect("snake_case feedback parse");
        match entry {
            MemoryEntry::Feedback(f) => {
                assert_eq!(f.qa_id, "xyz");
                assert_eq!(f.feedback_text.as_deref(), Some("ok"));
                assert!(f.feedback_score.is_none());
            }
            other => panic!("expected MemoryEntry::Feedback, got {other:?}"),
        }

        // Serialization emits camelCase wire fields + snake-case `type`.
        let entry = MemoryEntry::Feedback(FeedbackEntry {
            qa_id: "id".into(),
            feedback_text: Some("ok".into()),
            feedback_score: None,
        });
        let s = serde_json::to_string(&entry).expect("serialize feedback");
        assert!(s.contains("\"type\":\"feedback\""));
        assert!(s.contains("\"qaId\":\"id\""));
        assert!(s.contains("\"feedbackText\":\"ok\""));

        let (first, second) = round_trip(&entry);
        assert_eq!(first, second, "feedback entry must survive serialize/deserialize");
    }

    #[test]
    fn test_rejects_malformed_entries() {
        // Unknown discriminator.
        assert!(
            serde_json::from_str::<MemoryEntry>(r#"{"type":"bogus","question":"q","answer":"a"}"#)
                .is_err(),
            "unknown `type` must be rejected"
        );
        // Discriminator is case-sensitive: the wire form is lowercase.
        assert!(
            serde_json::from_str::<MemoryEntry>(r#"{"type":"Qa","question":"q","answer":"a"}"#)
                .is_err(),
            "variant name casing must not leak onto the wire"
        );
        // Missing discriminator.
        assert!(
            serde_json::from_str::<MemoryEntry>(r#"{"question":"q","answer":"a"}"#).is_err(),
            "`type` tag is required"
        );
        // Missing required payload fields.
        assert!(
            serde_json::from_str::<MemoryEntry>(r#"{"type":"qa","question":"q"}"#).is_err(),
            "`answer` is required"
        );
        assert!(
            serde_json::from_str::<MemoryEntry>(r#"{"type":"trace"}"#).is_err(),
            "`originFunction` is required"
        );
        assert!(
            serde_json::from_str::<MemoryEntry>(r#"{"type":"feedback","feedbackText":"x"}"#)
                .is_err(),
            "`qaId` is required"
        );
    }

    #[test]
    fn test_type_str_helper() {
        let q = MemoryEntry::Qa(QAEntry {
            question: "".into(),
            answer: "".into(),
            context: "".into(),
            feedback_text: None,
            feedback_score: None,
            used_graph_element_ids: None,
        });
        assert_eq!(q.type_str(), "qa");

        let t = MemoryEntry::Trace(TraceEntry {
            origin_function: "x".into(),
            status: "success".into(),
            method_params: None,
            method_return_value: None,
            memory_query: "".into(),
            memory_context: "".into(),
            error_message: "".into(),
            generate_feedback_with_llm: false,
        });
        assert_eq!(t.type_str(), "trace");

        let f = MemoryEntry::Feedback(FeedbackEntry {
            qa_id: "x".into(),
            feedback_text: None,
            feedback_score: None,
        });
        assert_eq!(f.type_str(), "feedback");

        // `type_str` duplicates the serde tag strings by hand; pin the two
        // together so a rename on either side is caught.
        for entry in [q, t, f] {
            let wire = serde_json::to_value(&entry).expect("serialize");
            assert_eq!(
                wire["type"],
                entry.type_str(),
                "type_str must match the serialized `type` tag"
            );
        }
    }
}