Skip to main content

cognee_session/
session_manager.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use cognee_llm::Llm;
5use cognee_llm::Message;
6use tracing::debug;
7
8use crate::error::SessionError;
9use crate::feedback;
10use crate::session_store::{SessionQAUpdate, SessionStore};
11use crate::types::{SessionQAEntry, SessionTraceStep, UsedGraphElementIds};
12
/// Session id substituted whenever a caller does not pass an explicit one.
const DEFAULT_SESSION_ID: &str = "default_session";
/// How many of the most recent Q&A entries are loaded as history by default.
const DEFAULT_HISTORY_LIMIT: usize = 10;
15
16/// Orchestrates session operations using an `Arc<dyn SessionStore>`.
17///
18/// Analogous to Python's `SessionManager`. Loads conversation history as
19/// `Vec<Message>` for LLM multi-turn conversations, and saves Q&A entries
20/// after each search completion.
21#[derive(Clone)]
22pub struct SessionManager {
23    store: Arc<dyn SessionStore>,
24    default_session_id: String,
25    history_limit: usize,
26    llm: Option<Arc<dyn Llm>>,
27}
28
29impl SessionManager {
30    pub fn new(store: Arc<dyn SessionStore>) -> Self {
31        Self {
32            store,
33            default_session_id: DEFAULT_SESSION_ID.to_string(),
34            history_limit: DEFAULT_HISTORY_LIMIT,
35            llm: None,
36        }
37    }
38
39    pub fn with_llm(mut self, llm: Arc<dyn Llm>) -> Self {
40        self.llm = Some(llm);
41        self
42    }
43
44    pub fn with_default_session_id(mut self, id: impl Into<String>) -> Self {
45        self.default_session_id = id.into();
46        self
47    }
48
49    pub fn with_history_limit(mut self, limit: usize) -> Self {
50        self.history_limit = limit;
51        self
52    }
53
54    fn resolve_session_id<'a>(&'a self, session_id: Option<&'a str>) -> &'a str {
55        session_id.unwrap_or(&self.default_session_id)
56    }
57
58    /// Load conversation history as alternating User/Assistant messages.
59    ///
60    /// Returns the last `history_limit` Q&A pairs as:
61    /// `[User(q1), Assistant(a1), User(q2), Assistant(a2), ...]`
62    pub async fn load_history_messages(
63        &self,
64        session_id: Option<&str>,
65        user_id: Option<&str>,
66    ) -> Result<Vec<Message>, SessionError> {
67        let resolved_id = self.resolve_session_id(session_id);
68        let entries = self
69            .store
70            .get_latest_qa_entries(resolved_id, user_id, self.history_limit)
71            .await?;
72
73        debug!(
74            session_id = resolved_id,
75            entries = entries.len(),
76            "Loaded session history"
77        );
78
79        Ok(entries_to_messages(&entries))
80    }
81
82    /// Load history as structured messages AND a formatted string, with a single store round-trip.
83    pub async fn load_history_both(
84        &self,
85        session_id: Option<&str>,
86        user_id: Option<&str>,
87    ) -> Result<(Vec<Message>, String), SessionError> {
88        let resolved_id = self.resolve_session_id(session_id);
89        let entries = self
90            .store
91            .get_latest_qa_entries(resolved_id, user_id, self.history_limit)
92            .await?;
93
94        debug!(
95            session_id = resolved_id,
96            entries = entries.len(),
97            "Loaded session history (both forms)"
98        );
99
100        let messages = entries_to_messages(&entries);
101        let formatted = Self::format_entries(&entries);
102        Ok((messages, formatted))
103    }
104
105    /// Save a Q&A exchange to the session. Returns the generated `qa_id`.
106    ///
107    /// `used_graph_element_ids` carries the node/edge IDs that were consulted
108    /// during retrieval so the memify pipeline can trace which graph elements
109    /// produced the answer (mirrors Python `session_manager.py:492-525`,
110    /// `add_qa(..., used_graph_element_ids=used_graph_element_ids)`).
111    pub async fn save_qa(
112        &self,
113        session_id: Option<&str>,
114        user_id: Option<&str>,
115        question: &str,
116        answer: &str,
117        context: Option<&str>,
118        used_graph_element_ids: Option<UsedGraphElementIds>,
119    ) -> Result<String, SessionError> {
120        let resolved_id = self.resolve_session_id(session_id);
121        let qa_id = self
122            .store
123            .create_qa_entry(resolved_id, user_id, question, answer, context)
124            .await?;
125
126        // Write used_graph_element_ids if provided.
127        if let Some(ids) = used_graph_element_ids
128            && let Err(e) = self
129                .store
130                .update_qa_entry(
131                    resolved_id,
132                    user_id,
133                    &qa_id,
134                    SessionQAUpdate {
135                        used_graph_element_ids: Some(Some(ids)),
136                        ..Default::default()
137                    },
138                )
139                .await
140        {
141            tracing::warn!(
142                qa_id = %qa_id,
143                "save_qa: failed to persist used_graph_element_ids (non-fatal): {e}"
144            );
145        }
146
147        // Mirrors Python `send_telemetry("cognee.session.add_qa", ...)` from
148        // cognee/memory/session_manager.py:171.
149        #[cfg(feature = "telemetry")]
150        {
151            let data_size_bytes =
152                question.len() + answer.len() + context.map(|c| c.len()).unwrap_or(0);
153            cognee_telemetry::send_telemetry(
154                "cognee.session.add_qa",
155                user_id.unwrap_or("sdk"),
156                Some(serde_json::json!({
157                    "session_id": resolved_id,
158                    "data_size_bytes": data_size_bytes,
159                    "has_feedback": false,
160                    "has_graph_elements": false,
161                })),
162            );
163        }
164
165        Ok(qa_id)
166    }
167
168    /// Delete all Q&A entries for a session.
169    pub async fn delete_session(
170        &self,
171        session_id: Option<&str>,
172        user_id: Option<&str>,
173    ) -> Result<bool, SessionError> {
174        let resolved_id = self.resolve_session_id(session_id);
175        self.store.delete_session(resolved_id, user_id).await
176    }
177
178    /// Format Q&A entries as a human-readable string (for debugging / compatibility
179    /// with Python's `SessionManager.format_entries`).
180    ///
181    /// When `include_context` is `true`, the context field is included between
182    /// QUESTION and ANSWER (matching the Python `include_context` parameter).
183    pub fn format_entries(entries: &[SessionQAEntry]) -> String {
184        Self::format_entries_with_context(entries, false)
185    }
186
187    /// Format Q&A entries, optionally including context.
188    pub fn format_entries_with_context(
189        entries: &[SessionQAEntry],
190        include_context: bool,
191    ) -> String {
192        if entries.is_empty() {
193            return String::new();
194        }
195        let mut lines = vec!["Previous conversation:\n\n".to_string()];
196        for entry in entries {
197            lines.push(format!("[{}]\n", entry.created_at.to_rfc3339()));
198            lines.push(format!("QUESTION: {}\n", entry.question));
199            if include_context && let Some(ref ctx) = entry.context {
200                lines.push(format!("CONTEXT: {ctx}\n"));
201            }
202            lines.push(format!("ANSWER: {}\n\n", entry.answer));
203        }
204        lines.concat()
205    }
206
207    /// Update arbitrary fields on a QA entry.
208    pub async fn update_qa(
209        &self,
210        session_id: Option<&str>,
211        user_id: Option<&str>,
212        qa_id: &str,
213        updates: SessionQAUpdate,
214    ) -> Result<bool, SessionError> {
215        let resolved_id = self.resolve_session_id(session_id);
216        self.store
217            .update_qa_entry(resolved_id, user_id, qa_id, updates)
218            .await
219    }
220
221    /// Add or update feedback on a QA entry (convenience over `update_qa`).
222    ///
223    /// Resets `memify_metadata.feedback_weights_applied` to `false` so that the
224    /// memify pipeline will re-apply weights on the next run.
225    pub async fn add_feedback(
226        &self,
227        session_id: Option<&str>,
228        user_id: Option<&str>,
229        qa_id: &str,
230        feedback_text: Option<&str>,
231        feedback_score: Option<i32>,
232    ) -> Result<bool, SessionError> {
233        if let Some(score) = feedback_score
234            && !(1..=5).contains(&score)
235        {
236            return Err(SessionError::InvalidParameter(format!(
237                "feedback_score must be between 1 and 5, got {score}"
238            )));
239        }
240
241        let mut memify = HashMap::new();
242        memify.insert("feedback_weights_applied".to_string(), false);
243
244        self.update_qa(
245            session_id,
246            user_id,
247            qa_id,
248            SessionQAUpdate {
249                feedback_text: Some(feedback_text.map(|s| s.to_string())),
250                feedback_score: Some(feedback_score),
251                memify_metadata: Some(Some(memify)),
252                ..Default::default()
253            },
254        )
255        .await
256    }
257
258    /// Clear feedback from a QA entry.
259    pub async fn delete_feedback(
260        &self,
261        session_id: Option<&str>,
262        user_id: Option<&str>,
263        qa_id: &str,
264    ) -> Result<bool, SessionError> {
265        self.update_qa(
266            session_id,
267            user_id,
268            qa_id,
269            SessionQAUpdate {
270                feedback_text: Some(None),
271                feedback_score: Some(None),
272                ..Default::default()
273            },
274        )
275        .await
276    }
277
278    /// Return the `qa_id` of the most-recent Q&A entry in the session.
279    ///
280    /// Returns `None` when the session has no entries yet. Used to route
281    /// conversationally-detected feedback to the prior QA entry before saving the
282    /// new turn (mirrors Python `session_manager.py:462-469`).
283    pub async fn latest_qa_id(
284        &self,
285        session_id: Option<&str>,
286        user_id: Option<&str>,
287    ) -> Result<Option<String>, SessionError> {
288        let resolved_id = self.resolve_session_id(session_id);
289        self.store.latest_qa_id(resolved_id, user_id).await
290    }
291
292    /// Retrieve graph knowledge snapshot for a session.
293    pub async fn get_graph_context(
294        &self,
295        session_id: Option<&str>,
296        user_id: Option<&str>,
297    ) -> Result<Option<String>, SessionError> {
298        let resolved_id = self.resolve_session_id(session_id);
299        self.store.get_graph_context(resolved_id, user_id).await
300    }
301
302    /// Store graph knowledge snapshot for a session.
303    pub async fn set_graph_context(
304        &self,
305        session_id: Option<&str>,
306        user_id: Option<&str>,
307        context: &str,
308    ) -> Result<(), SessionError> {
309        let resolved_id = self.resolve_session_id(session_id);
310        self.store
311            .set_graph_context(resolved_id, user_id, context)
312            .await
313    }
314
315    /// Append one agent-trace step to the session and return the generated
316    /// `trace_id` (UUID4).
317    ///
318    /// Mirrors Python's `SessionManager.add_agent_trace_step`.
319    ///
320    /// When `generate_feedback` is `true`, this method attempts to use the
321    /// configured LLM (`with_llm`) to summarize `method_return_value`; if no
322    /// LLM is wired or generation fails, it falls back to deterministic
323    /// feedback (`<origin> succeeded/failed`).
324    #[allow(clippy::too_many_arguments)]
325    pub async fn add_agent_trace_step(
326        &self,
327        user_id: &str,
328        session_id: Option<&str>,
329        origin_function: &str,
330        status: &str,
331        memory_query: &str,
332        memory_context: &str,
333        method_params: serde_json::Value,
334        method_return_value: Option<serde_json::Value>,
335        error_message: &str,
336        generate_feedback: bool,
337    ) -> Result<String, SessionError> {
338        let resolved_id = self.resolve_session_id(session_id);
339        let trace_id = uuid::Uuid::new_v4().to_string();
340        let session_feedback = if generate_feedback {
341            if let Some(llm) = self.llm.as_ref() {
342                feedback::generate_session_feedback(
343                    llm.as_ref(),
344                    origin_function,
345                    status,
346                    method_return_value.as_ref(),
347                    error_message,
348                )
349                .await
350            } else {
351                tracing::warn!(
352                    origin_function,
353                    session_id = resolved_id,
354                    "add_agent_trace_step: generate_feedback=true but no LLM wired; using deterministic fallback"
355                );
356                feedback::fallback_feedback(origin_function, status, error_message)
357            }
358        } else {
359            feedback::fallback_feedback(origin_function, status, error_message)
360        };
361
362        let step = SessionTraceStep {
363            trace_id: trace_id.clone(),
364            origin_function: origin_function.to_string(),
365            status: status.to_string(),
366            memory_query: memory_query.to_string(),
367            memory_context: memory_context.to_string(),
368            method_params,
369            method_return_value,
370            error_message: error_message.to_string(),
371            session_feedback,
372        };
373        self.store.save_trace_step(user_id, resolved_id, step).await
374    }
375
376    /// Retrieve agent-trace steps for a session, oldest-first.
377    ///
378    /// If `last_n` is `Some(n)`, the trailing `n` entries are returned
379    /// (mirrors Python's `entries[-last_n:]`).
380    pub async fn get_agent_trace_session(
381        &self,
382        user_id: &str,
383        session_id: Option<&str>,
384        last_n: Option<usize>,
385    ) -> Result<Vec<SessionTraceStep>, SessionError> {
386        let resolved_id = self.resolve_session_id(session_id);
387        let mut entries = self.store.read_trace_steps(user_id, resolved_id).await?;
388        if let Some(n) = last_n {
389            let drop = entries.len().saturating_sub(n);
390            entries = entries.split_off(drop);
391        }
392        Ok(entries)
393    }
394}
395
396/// Convert session Q&A entries to alternating User/Assistant LLM messages.
397fn entries_to_messages(entries: &[SessionQAEntry]) -> Vec<Message> {
398    let mut messages = Vec::with_capacity(entries.len() * 2);
399    for entry in entries {
400        messages.push(Message::user(&entry.question));
401        messages.push(Message::assistant(&entry.answer));
402    }
403    messages
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal Q&A entry in session `s1` with no context, feedback,
    /// graph ids or memify metadata.
    fn make_entry(question: &str, answer: &str) -> SessionQAEntry {
        SessionQAEntry {
            id: uuid::Uuid::new_v4(),
            session_id: "s1".to_string(),
            user_id: None,
            question: question.to_string(),
            answer: answer.to_string(),
            context: None,
            created_at: chrono::Utc::now(),
            feedback_text: None,
            feedback_score: None,
            used_graph_element_ids: None,
            memify_metadata: None,
        }
    }

    #[test]
    fn entries_to_messages_alternates_roles() {
        let entries = vec![
            make_entry("What is Rust?", "A systems programming language."),
            make_entry("Tell me more.", "It focuses on safety and performance."),
        ];

        let messages = entries_to_messages(&entries);
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].role, cognee_llm::MessageRole::User);
        assert_eq!(messages[0].content, "What is Rust?");
        assert_eq!(messages[1].role, cognee_llm::MessageRole::Assistant);
        assert_eq!(messages[1].content, "A systems programming language.");
        assert_eq!(messages[2].role, cognee_llm::MessageRole::User);
        assert_eq!(messages[2].content, "Tell me more.");
        assert_eq!(messages[3].role, cognee_llm::MessageRole::Assistant);
        assert_eq!(messages[3].content, "It focuses on safety and performance.");
    }

    #[test]
    fn entries_to_messages_empty_input_yields_no_messages() {
        assert!(entries_to_messages(&[]).is_empty());
    }

    #[test]
    fn format_entries_produces_expected_output() {
        let entries = vec![make_entry("Hello?", "Hi there!")];

        let formatted = SessionManager::format_entries(&entries);
        assert!(formatted.contains("Previous conversation:"));
        assert!(formatted.contains("QUESTION: Hello?"));
        assert!(formatted.contains("ANSWER: Hi there!"));
    }

    #[test]
    fn format_entries_preserves_entry_order() {
        let entries = vec![make_entry("first?", "one"), make_entry("second?", "two")];

        let formatted = SessionManager::format_entries(&entries);
        let first = formatted
            .find("QUESTION: first?")
            .expect("first question should be formatted");
        let second = formatted
            .find("QUESTION: second?")
            .expect("second question should be formatted");
        assert!(first < second);
    }

    #[test]
    fn format_entries_empty_returns_empty_string() {
        assert_eq!(SessionManager::format_entries(&[]), "");
    }

    #[test]
    fn format_entries_with_context_includes_context() {
        let mut entry = make_entry("Hello?", "Hi there!");
        entry.context = Some("Some context here".to_string());
        let entries = vec![entry];

        let formatted = SessionManager::format_entries_with_context(&entries, true);
        let question = formatted
            .find("QUESTION: Hello?")
            .expect("question line should be present");
        let context = formatted
            .find("CONTEXT: Some context here")
            .expect("context line should be present");
        let answer = formatted
            .find("ANSWER: Hi there!")
            .expect("answer line should be present");
        // Context sits between the question and the answer.
        assert!(question < context && context < answer);
    }

    #[test]
    fn format_entries_with_context_true_skips_missing_context() {
        let entries = vec![make_entry("Hello?", "Hi there!")];

        let formatted = SessionManager::format_entries_with_context(&entries, true);
        assert!(!formatted.contains("CONTEXT:"));
        assert!(formatted.contains("ANSWER: Hi there!"));
    }

    #[test]
    fn format_entries_with_context_false_omits_context() {
        let mut entry = make_entry("Hello?", "Hi there!");
        entry.context = Some("Some context here".to_string());
        let entries = vec![entry];

        let formatted = SessionManager::format_entries_with_context(&entries, false);
        assert!(!formatted.contains("CONTEXT:"));
    }
}