cognee-session 0.1.3

Session/conversation store (filesystem, Redis, SeaORM) for the cognee pipeline.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
use std::collections::HashMap;
use std::sync::Arc;

use cognee_llm::Llm;
use cognee_llm::Message;
use tracing::debug;

use crate::error::SessionError;
use crate::feedback;
use crate::session_store::{SessionQAUpdate, SessionStore};
use crate::types::{SessionQAEntry, SessionTraceStep, UsedGraphElementIds};

/// Session id used when a caller passes `session_id: None` and no override was
/// configured via `SessionManager::with_default_session_id`.
const DEFAULT_SESSION_ID: &str = "default_session";
/// Number of most-recent Q&A entries loaded as conversation history unless
/// overridden via `SessionManager::with_history_limit`.
const DEFAULT_HISTORY_LIMIT: usize = 10;

/// Orchestrates session operations using an `Arc<dyn SessionStore>`.
///
/// Analogous to Python's `SessionManager`. Loads conversation history as
/// `Vec<Message>` for LLM multi-turn conversations, and saves Q&A entries
/// after each search completion.
///
/// Cloning is cheap: the store and the optional LLM are shared through `Arc`.
#[derive(Clone)]
pub struct SessionManager {
    /// Backend that persists Q&A entries, graph context and agent-trace steps.
    store: Arc<dyn SessionStore>,
    /// Session id substituted whenever a caller passes `session_id: None`.
    default_session_id: String,
    /// Maximum number of Q&A entries requested by the history loaders.
    history_limit: usize,
    /// Optional LLM used by `add_agent_trace_step` to summarize method results.
    llm: Option<Arc<dyn Llm>>,
}

// Written by hand rather than derived: the `SessionStore` / `Llm` trait objects
// are not assumed to implement `Debug`, so only the plain configuration values
// (and whether an LLM is wired) are reported.
impl std::fmt::Debug for SessionManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SessionManager")
            .field("default_session_id", &self.default_session_id)
            .field("history_limit", &self.history_limit)
            .field("has_llm", &self.llm.is_some())
            .finish_non_exhaustive()
    }
}

impl SessionManager {
    /// Create a manager backed by `store`.
    ///
    /// Uses `DEFAULT_SESSION_ID` as the fallback session id and
    /// `DEFAULT_HISTORY_LIMIT` as the history size. No LLM is wired.
    pub fn new(store: Arc<dyn SessionStore>) -> Self {
        Self {
            store,
            default_session_id: DEFAULT_SESSION_ID.to_string(),
            history_limit: DEFAULT_HISTORY_LIMIT,
            llm: None,
        }
    }

    /// Attach an LLM that `add_agent_trace_step` uses to generate feedback.
    #[must_use]
    pub fn with_llm(mut self, llm: Arc<dyn Llm>) -> Self {
        self.llm = Some(llm);
        self
    }

    /// Override the session id used when callers pass `session_id: None`.
    #[must_use]
    pub fn with_default_session_id(mut self, id: impl Into<String>) -> Self {
        self.default_session_id = id.into();
        self
    }

    /// Override how many of the latest Q&A entries the history loaders request.
    #[must_use]
    pub fn with_history_limit(mut self, limit: usize) -> Self {
        self.history_limit = limit;
        self
    }

    /// Return `session_id` if given, otherwise the configured default id.
    fn resolve_session_id<'a>(&'a self, session_id: Option<&'a str>) -> &'a str {
        session_id.unwrap_or(&self.default_session_id)
    }

    /// Load conversation history as alternating User/Assistant messages.
    ///
    /// Returns the last `history_limit` Q&A pairs as:
    /// `[User(q1), Assistant(a1), User(q2), Assistant(a2), ...]`
    ///
    /// # Errors
    /// Propagates any error from the underlying store.
    pub async fn load_history_messages(
        &self,
        session_id: Option<&str>,
        user_id: Option<&str>,
    ) -> Result<Vec<Message>, SessionError> {
        let resolved_id = self.resolve_session_id(session_id);
        let entries = self
            .store
            .get_latest_qa_entries(resolved_id, user_id, self.history_limit)
            .await?;

        debug!(
            session_id = resolved_id,
            entries = entries.len(),
            "Loaded session history"
        );

        Ok(entries_to_messages(&entries))
    }

    /// Load history in two forms from a single store round-trip.
    ///
    /// Returns the structured messages (same shape as `load_history_messages`)
    /// and the `format_entries` string, without context.
    ///
    /// # Errors
    /// Propagates any error from the underlying store.
    pub async fn load_history_both(
        &self,
        session_id: Option<&str>,
        user_id: Option<&str>,
    ) -> Result<(Vec<Message>, String), SessionError> {
        let resolved_id = self.resolve_session_id(session_id);
        let entries = self
            .store
            .get_latest_qa_entries(resolved_id, user_id, self.history_limit)
            .await?;

        debug!(
            session_id = resolved_id,
            entries = entries.len(),
            "Loaded session history (both forms)"
        );

        let messages = entries_to_messages(&entries);
        let formatted = Self::format_entries(&entries);
        Ok((messages, formatted))
    }

    /// Save a Q&A exchange to the session. Returns the generated `qa_id`.
    ///
    /// `used_graph_element_ids` carries the node/edge IDs that were consulted
    /// during retrieval so the memify pipeline can trace which graph elements
    /// produced the answer (mirrors Python `session_manager.py:492-525`,
    /// `add_qa(..., used_graph_element_ids=used_graph_element_ids)`).
    ///
    /// Persisting the graph element ids is a second, best-effort store call.
    /// If it fails, a warning is logged and the `qa_id` is still returned.
    ///
    /// # Errors
    /// Returns the store error if creating the Q&A entry itself fails.
    pub async fn save_qa(
        &self,
        session_id: Option<&str>,
        user_id: Option<&str>,
        question: &str,
        answer: &str,
        context: Option<&str>,
        used_graph_element_ids: Option<UsedGraphElementIds>,
    ) -> Result<String, SessionError> {
        let resolved_id = self.resolve_session_id(session_id);

        // Capture this before `used_graph_element_ids` is moved into the update
        // below. It reports whether ids were supplied, even if persisting them
        // later fails.
        #[cfg(feature = "telemetry")]
        let has_graph_elements = used_graph_element_ids.is_some();

        let qa_id = self
            .store
            .create_qa_entry(resolved_id, user_id, question, answer, context)
            .await?;

        // Write used_graph_element_ids if provided. Failure is non-fatal.
        if let Some(ids) = used_graph_element_ids
            && let Err(e) = self
                .store
                .update_qa_entry(
                    resolved_id,
                    user_id,
                    &qa_id,
                    SessionQAUpdate {
                        used_graph_element_ids: Some(Some(ids)),
                        ..Default::default()
                    },
                )
                .await
        {
            tracing::warn!(
                qa_id = %qa_id,
                "save_qa: failed to persist used_graph_element_ids (non-fatal): {e}"
            );
        }

        // Mirrors Python `send_telemetry("cognee.session.add_qa", ...)` from
        // cognee/memory/session_manager.py:171.
        #[cfg(feature = "telemetry")]
        {
            let data_size_bytes =
                question.len() + answer.len() + context.map(|c| c.len()).unwrap_or(0);
            cognee_telemetry::send_telemetry(
                "cognee.session.add_qa",
                user_id.unwrap_or("sdk"),
                Some(serde_json::json!({
                    "session_id": resolved_id,
                    "data_size_bytes": data_size_bytes,
                    // A freshly created entry never carries feedback yet.
                    "has_feedback": false,
                    "has_graph_elements": has_graph_elements,
                })),
            );
        }

        Ok(qa_id)
    }

    /// Delete all Q&A entries for a session.
    ///
    /// Returns the store's boolean result.
    ///
    /// # Errors
    /// Propagates any error from the underlying store.
    pub async fn delete_session(
        &self,
        session_id: Option<&str>,
        user_id: Option<&str>,
    ) -> Result<bool, SessionError> {
        let resolved_id = self.resolve_session_id(session_id);
        self.store.delete_session(resolved_id, user_id).await
    }

    /// Format Q&A entries as a human-readable string, without context.
    ///
    /// Used for debugging and for compatibility with Python's
    /// `SessionManager.format_entries`. Equivalent to
    /// `format_entries_with_context(entries, false)`. Returns an empty string
    /// for an empty slice.
    pub fn format_entries(entries: &[SessionQAEntry]) -> String {
        Self::format_entries_with_context(entries, false)
    }

    /// Format Q&A entries, optionally including context.
    ///
    /// Layout:
    ///
    /// ```text
    /// Previous conversation:
    ///
    /// [<created_at RFC 3339>]
    /// QUESTION: <question>
    /// CONTEXT: <context>      (only when `include_context` and context is set)
    /// ANSWER: <answer>
    ///
    /// ```
    ///
    /// When `include_context` is `true`, the context field is included between
    /// QUESTION and ANSWER (matching the Python `include_context` parameter).
    /// Returns an empty string for an empty slice.
    pub fn format_entries_with_context(
        entries: &[SessionQAEntry],
        include_context: bool,
    ) -> String {
        use std::fmt::Write as _;

        if entries.is_empty() {
            return String::new();
        }

        let mut out = String::from("Previous conversation:\n\n");
        for entry in entries {
            // `fmt::Write` for `String` never returns `Err`, so the results are ignored.
            let _ = writeln!(out, "[{}]", entry.created_at.to_rfc3339());
            let _ = writeln!(out, "QUESTION: {}", entry.question);
            if include_context && let Some(ctx) = entry.context.as_deref() {
                let _ = writeln!(out, "CONTEXT: {ctx}");
            }
            // Trailing blank line separates consecutive entries.
            let _ = writeln!(out, "ANSWER: {}\n", entry.answer);
        }
        out
    }

    /// Update arbitrary fields on a QA entry.
    ///
    /// Returns the store's boolean result.
    ///
    /// # Errors
    /// Propagates any error from the underlying store.
    pub async fn update_qa(
        &self,
        session_id: Option<&str>,
        user_id: Option<&str>,
        qa_id: &str,
        updates: SessionQAUpdate,
    ) -> Result<bool, SessionError> {
        let resolved_id = self.resolve_session_id(session_id);
        self.store
            .update_qa_entry(resolved_id, user_id, qa_id, updates)
            .await
    }

    /// Add or update feedback on a QA entry (convenience over `update_qa`).
    ///
    /// Both feedback fields are always written. Passing `None` for either one
    /// writes the same "cleared" value that `delete_feedback` uses.
    ///
    /// Resets `memify_metadata.feedback_weights_applied` to `false` so that the
    /// memify pipeline will re-apply weights on the next run.
    ///
    /// # Errors
    /// Returns `SessionError::InvalidParameter` when `feedback_score` is outside
    /// `1..=5`. Otherwise propagates store errors.
    pub async fn add_feedback(
        &self,
        session_id: Option<&str>,
        user_id: Option<&str>,
        qa_id: &str,
        feedback_text: Option<&str>,
        feedback_score: Option<i32>,
    ) -> Result<bool, SessionError> {
        if let Some(score) = feedback_score
            && !(1..=5).contains(&score)
        {
            return Err(SessionError::InvalidParameter(format!(
                "feedback_score must be between 1 and 5, got {score}"
            )));
        }

        // NOTE(review): this sends a one-key map. If the store replaces
        // `memify_metadata` wholesale rather than merging, any other memify keys
        // on the entry are dropped. Confirm against the `SessionStore` impls.
        let memify = HashMap::from([("feedback_weights_applied".to_string(), false)]);

        self.update_qa(
            session_id,
            user_id,
            qa_id,
            SessionQAUpdate {
                feedback_text: Some(feedback_text.map(|s| s.to_string())),
                feedback_score: Some(feedback_score),
                memify_metadata: Some(Some(memify)),
                ..Default::default()
            },
        )
        .await
    }

    /// Clear feedback text and score from a QA entry.
    ///
    /// `memify_metadata` is left untouched.
    ///
    /// # Errors
    /// Propagates any error from the underlying store.
    pub async fn delete_feedback(
        &self,
        session_id: Option<&str>,
        user_id: Option<&str>,
        qa_id: &str,
    ) -> Result<bool, SessionError> {
        self.update_qa(
            session_id,
            user_id,
            qa_id,
            SessionQAUpdate {
                feedback_text: Some(None),
                feedback_score: Some(None),
                ..Default::default()
            },
        )
        .await
    }

    /// Return the `qa_id` of the most-recent Q&A entry in the session.
    ///
    /// Returns `None` when the session has no entries yet. Used to route
    /// conversationally-detected feedback to the prior QA entry before saving the
    /// new turn (mirrors Python `session_manager.py:462-469`).
    ///
    /// # Errors
    /// Propagates any error from the underlying store.
    pub async fn latest_qa_id(
        &self,
        session_id: Option<&str>,
        user_id: Option<&str>,
    ) -> Result<Option<String>, SessionError> {
        let resolved_id = self.resolve_session_id(session_id);
        self.store.latest_qa_id(resolved_id, user_id).await
    }

    /// Retrieve the graph knowledge snapshot for a session, if one was stored.
    ///
    /// # Errors
    /// Propagates any error from the underlying store.
    pub async fn get_graph_context(
        &self,
        session_id: Option<&str>,
        user_id: Option<&str>,
    ) -> Result<Option<String>, SessionError> {
        let resolved_id = self.resolve_session_id(session_id);
        self.store.get_graph_context(resolved_id, user_id).await
    }

    /// Store the graph knowledge snapshot for a session.
    ///
    /// # Errors
    /// Propagates any error from the underlying store.
    pub async fn set_graph_context(
        &self,
        session_id: Option<&str>,
        user_id: Option<&str>,
        context: &str,
    ) -> Result<(), SessionError> {
        let resolved_id = self.resolve_session_id(session_id);
        self.store
            .set_graph_context(resolved_id, user_id, context)
            .await
    }

    /// Append one agent-trace step to the session.
    ///
    /// The `trace_id` is generated locally as a UUID4. The returned string is
    /// whatever the store's `save_trace_step` yields.
    ///
    /// Mirrors Python's `SessionManager.add_agent_trace_step`.
    ///
    /// When `generate_feedback` is `true`, this method attempts to use the
    /// configured LLM (`with_llm`) to summarize `method_return_value`. If no
    /// LLM is wired, it logs a warning and uses deterministic fallback feedback
    /// (`<origin> succeeded/failed`). Any fallback inside the LLM path is up to
    /// `feedback::generate_session_feedback`.
    ///
    /// # Errors
    /// Propagates any error from the underlying store.
    // Mirrors the Python signature one-to-one, hence the argument count.
    #[allow(clippy::too_many_arguments)]
    pub async fn add_agent_trace_step(
        &self,
        user_id: &str,
        session_id: Option<&str>,
        origin_function: &str,
        status: &str,
        memory_query: &str,
        memory_context: &str,
        method_params: serde_json::Value,
        method_return_value: Option<serde_json::Value>,
        error_message: &str,
        generate_feedback: bool,
    ) -> Result<String, SessionError> {
        let resolved_id = self.resolve_session_id(session_id);
        let trace_id = uuid::Uuid::new_v4().to_string();
        let session_feedback = if generate_feedback {
            if let Some(llm) = self.llm.as_ref() {
                feedback::generate_session_feedback(
                    llm.as_ref(),
                    origin_function,
                    status,
                    method_return_value.as_ref(),
                    error_message,
                )
                .await
            } else {
                tracing::warn!(
                    origin_function,
                    session_id = resolved_id,
                    "add_agent_trace_step: generate_feedback=true but no LLM wired; using deterministic fallback"
                );
                feedback::fallback_feedback(origin_function, status, error_message)
            }
        } else {
            feedback::fallback_feedback(origin_function, status, error_message)
        };

        let step = SessionTraceStep {
            trace_id: trace_id.clone(),
            origin_function: origin_function.to_string(),
            status: status.to_string(),
            memory_query: memory_query.to_string(),
            memory_context: memory_context.to_string(),
            method_params,
            method_return_value,
            error_message: error_message.to_string(),
            session_feedback,
        };
        self.store.save_trace_step(user_id, resolved_id, step).await
    }

    /// Retrieve agent-trace steps for a session in the order the store returns
    /// them (oldest-first).
    ///
    /// If `last_n` is `Some(n)`, only the trailing `n` entries are returned
    /// (mirrors Python's `entries[-last_n:]`). All entries are returned when `n`
    /// exceeds the count.
    ///
    /// # Errors
    /// Propagates any error from the underlying store.
    pub async fn get_agent_trace_session(
        &self,
        user_id: &str,
        session_id: Option<&str>,
        last_n: Option<usize>,
    ) -> Result<Vec<SessionTraceStep>, SessionError> {
        let resolved_id = self.resolve_session_id(session_id);
        let mut entries = self.store.read_trace_steps(user_id, resolved_id).await?;
        if let Some(n) = last_n {
            // Index of the first entry to keep. It saturates to 0 when `n` covers everything.
            let keep_from = entries.len().saturating_sub(n);
            entries = entries.split_off(keep_from);
        }
        Ok(entries)
    }
}

/// Flatten Q&A entries into LLM chat turns.
///
/// Each entry contributes exactly two messages, in order: a user message
/// holding its question, then an assistant message holding its answer.
fn entries_to_messages(entries: &[SessionQAEntry]) -> Vec<Message> {
    let mut turns = Vec::with_capacity(2 * entries.len());
    turns.extend(entries.iter().flat_map(|entry| {
        [
            Message::user(&entry.question),
            Message::assistant(&entry.answer),
        ]
    }));
    turns
}

#[cfg(test)]
mod tests {
    use super::*;
    use cognee_llm::MessageRole;

    /// Build a minimal entry in session `"s1"` with no context, feedback or metadata.
    fn make_entry(question: &str, answer: &str) -> SessionQAEntry {
        SessionQAEntry {
            id: uuid::Uuid::new_v4(),
            session_id: String::from("s1"),
            user_id: None,
            question: String::from(question),
            answer: String::from(answer),
            context: None,
            created_at: chrono::Utc::now(),
            feedback_text: None,
            feedback_score: None,
            used_graph_element_ids: None,
            memify_metadata: None,
        }
    }

    /// Like `make_entry`, but with the `context` field populated.
    fn make_entry_with_context(question: &str, answer: &str, context: &str) -> SessionQAEntry {
        SessionQAEntry {
            context: Some(context.to_owned()),
            ..make_entry(question, answer)
        }
    }

    #[test]
    fn entries_to_messages_alternates_roles() {
        let entries = [
            make_entry("What is Rust?", "A systems programming language."),
            make_entry("Tell me more.", "It focuses on safety and performance."),
        ];

        let messages = entries_to_messages(&entries);
        assert_eq!(messages.len(), 4);

        // Even positions are user turns, odd positions are assistant turns.
        for (idx, message) in messages.iter().enumerate() {
            if idx % 2 == 0 {
                assert_eq!(message.role, MessageRole::User);
            } else {
                assert_eq!(message.role, MessageRole::Assistant);
            }
        }
        assert_eq!(messages[0].content, "What is Rust?");
        assert_eq!(messages[1].content, "A systems programming language.");
    }

    #[test]
    fn format_entries_produces_expected_output() {
        let formatted = SessionManager::format_entries(&[make_entry("Hello?", "Hi there!")]);
        for needle in ["Previous conversation:", "QUESTION: Hello?", "ANSWER: Hi there!"] {
            assert!(formatted.contains(needle), "missing {needle:?} in {formatted:?}");
        }
    }

    #[test]
    fn format_entries_empty_returns_empty_string() {
        assert!(SessionManager::format_entries(&[]).is_empty());
    }

    #[test]
    fn format_entries_with_context_includes_context() {
        let entries = [make_entry_with_context("Hello?", "Hi there!", "Some context here")];
        let formatted = SessionManager::format_entries_with_context(&entries, true);
        assert!(formatted.contains("CONTEXT: Some context here"));
    }

    #[test]
    fn format_entries_with_context_false_omits_context() {
        let entries = [make_entry_with_context("Hello?", "Hi there!", "Some context here")];
        let formatted = SessionManager::format_entries_with_context(&entries, false);
        assert!(!formatted.contains("CONTEXT:"));
    }
}