1use std::collections::HashMap;
2use std::sync::Arc;
3
4use cognee_llm::Llm;
5use cognee_llm::Message;
6use tracing::debug;
7
8use crate::error::SessionError;
9use crate::feedback;
10use crate::session_store::{SessionQAUpdate, SessionStore};
11use crate::types::{SessionQAEntry, SessionTraceStep, UsedGraphElementIds};
12
/// Session id a new `SessionManager` falls back to when a caller passes no
/// explicit session id.
const DEFAULT_SESSION_ID: &str = "default_session";

/// History limit a new `SessionManager` starts with; it is the limit passed
/// to the store when loading the latest QA entries.
const DEFAULT_HISTORY_LIMIT: usize = 10;
15
16#[derive(Clone)]
22pub struct SessionManager {
23 store: Arc<dyn SessionStore>,
24 default_session_id: String,
25 history_limit: usize,
26 llm: Option<Arc<dyn Llm>>,
27}
28
29impl SessionManager {
30 pub fn new(store: Arc<dyn SessionStore>) -> Self {
31 Self {
32 store,
33 default_session_id: DEFAULT_SESSION_ID.to_string(),
34 history_limit: DEFAULT_HISTORY_LIMIT,
35 llm: None,
36 }
37 }
38
39 pub fn with_llm(mut self, llm: Arc<dyn Llm>) -> Self {
40 self.llm = Some(llm);
41 self
42 }
43
44 pub fn with_default_session_id(mut self, id: impl Into<String>) -> Self {
45 self.default_session_id = id.into();
46 self
47 }
48
49 pub fn with_history_limit(mut self, limit: usize) -> Self {
50 self.history_limit = limit;
51 self
52 }
53
54 fn resolve_session_id<'a>(&'a self, session_id: Option<&'a str>) -> &'a str {
55 session_id.unwrap_or(&self.default_session_id)
56 }
57
58 pub async fn load_history_messages(
63 &self,
64 session_id: Option<&str>,
65 user_id: Option<&str>,
66 ) -> Result<Vec<Message>, SessionError> {
67 let resolved_id = self.resolve_session_id(session_id);
68 let entries = self
69 .store
70 .get_latest_qa_entries(resolved_id, user_id, self.history_limit)
71 .await?;
72
73 debug!(
74 session_id = resolved_id,
75 entries = entries.len(),
76 "Loaded session history"
77 );
78
79 Ok(entries_to_messages(&entries))
80 }
81
82 pub async fn load_history_both(
84 &self,
85 session_id: Option<&str>,
86 user_id: Option<&str>,
87 ) -> Result<(Vec<Message>, String), SessionError> {
88 let resolved_id = self.resolve_session_id(session_id);
89 let entries = self
90 .store
91 .get_latest_qa_entries(resolved_id, user_id, self.history_limit)
92 .await?;
93
94 debug!(
95 session_id = resolved_id,
96 entries = entries.len(),
97 "Loaded session history (both forms)"
98 );
99
100 let messages = entries_to_messages(&entries);
101 let formatted = Self::format_entries(&entries);
102 Ok((messages, formatted))
103 }
104
105 pub async fn save_qa(
112 &self,
113 session_id: Option<&str>,
114 user_id: Option<&str>,
115 question: &str,
116 answer: &str,
117 context: Option<&str>,
118 used_graph_element_ids: Option<UsedGraphElementIds>,
119 ) -> Result<String, SessionError> {
120 let resolved_id = self.resolve_session_id(session_id);
121 let qa_id = self
122 .store
123 .create_qa_entry(resolved_id, user_id, question, answer, context)
124 .await?;
125
126 if let Some(ids) = used_graph_element_ids
128 && let Err(e) = self
129 .store
130 .update_qa_entry(
131 resolved_id,
132 user_id,
133 &qa_id,
134 SessionQAUpdate {
135 used_graph_element_ids: Some(Some(ids)),
136 ..Default::default()
137 },
138 )
139 .await
140 {
141 tracing::warn!(
142 qa_id = %qa_id,
143 "save_qa: failed to persist used_graph_element_ids (non-fatal): {e}"
144 );
145 }
146
147 #[cfg(feature = "telemetry")]
150 {
151 let data_size_bytes =
152 question.len() + answer.len() + context.map(|c| c.len()).unwrap_or(0);
153 cognee_telemetry::send_telemetry(
154 "cognee.session.add_qa",
155 user_id.unwrap_or("sdk"),
156 Some(serde_json::json!({
157 "session_id": resolved_id,
158 "data_size_bytes": data_size_bytes,
159 "has_feedback": false,
160 "has_graph_elements": false,
161 })),
162 );
163 }
164
165 Ok(qa_id)
166 }
167
168 pub async fn delete_session(
170 &self,
171 session_id: Option<&str>,
172 user_id: Option<&str>,
173 ) -> Result<bool, SessionError> {
174 let resolved_id = self.resolve_session_id(session_id);
175 self.store.delete_session(resolved_id, user_id).await
176 }
177
178 pub fn format_entries(entries: &[SessionQAEntry]) -> String {
184 Self::format_entries_with_context(entries, false)
185 }
186
187 pub fn format_entries_with_context(
189 entries: &[SessionQAEntry],
190 include_context: bool,
191 ) -> String {
192 if entries.is_empty() {
193 return String::new();
194 }
195 let mut lines = vec!["Previous conversation:\n\n".to_string()];
196 for entry in entries {
197 lines.push(format!("[{}]\n", entry.created_at.to_rfc3339()));
198 lines.push(format!("QUESTION: {}\n", entry.question));
199 if include_context && let Some(ref ctx) = entry.context {
200 lines.push(format!("CONTEXT: {ctx}\n"));
201 }
202 lines.push(format!("ANSWER: {}\n\n", entry.answer));
203 }
204 lines.concat()
205 }
206
207 pub async fn update_qa(
209 &self,
210 session_id: Option<&str>,
211 user_id: Option<&str>,
212 qa_id: &str,
213 updates: SessionQAUpdate,
214 ) -> Result<bool, SessionError> {
215 let resolved_id = self.resolve_session_id(session_id);
216 self.store
217 .update_qa_entry(resolved_id, user_id, qa_id, updates)
218 .await
219 }
220
221 pub async fn add_feedback(
226 &self,
227 session_id: Option<&str>,
228 user_id: Option<&str>,
229 qa_id: &str,
230 feedback_text: Option<&str>,
231 feedback_score: Option<i32>,
232 ) -> Result<bool, SessionError> {
233 if let Some(score) = feedback_score
234 && !(1..=5).contains(&score)
235 {
236 return Err(SessionError::InvalidParameter(format!(
237 "feedback_score must be between 1 and 5, got {score}"
238 )));
239 }
240
241 let mut memify = HashMap::new();
242 memify.insert("feedback_weights_applied".to_string(), false);
243
244 self.update_qa(
245 session_id,
246 user_id,
247 qa_id,
248 SessionQAUpdate {
249 feedback_text: Some(feedback_text.map(|s| s.to_string())),
250 feedback_score: Some(feedback_score),
251 memify_metadata: Some(Some(memify)),
252 ..Default::default()
253 },
254 )
255 .await
256 }
257
258 pub async fn delete_feedback(
260 &self,
261 session_id: Option<&str>,
262 user_id: Option<&str>,
263 qa_id: &str,
264 ) -> Result<bool, SessionError> {
265 self.update_qa(
266 session_id,
267 user_id,
268 qa_id,
269 SessionQAUpdate {
270 feedback_text: Some(None),
271 feedback_score: Some(None),
272 ..Default::default()
273 },
274 )
275 .await
276 }
277
278 pub async fn latest_qa_id(
284 &self,
285 session_id: Option<&str>,
286 user_id: Option<&str>,
287 ) -> Result<Option<String>, SessionError> {
288 let resolved_id = self.resolve_session_id(session_id);
289 self.store.latest_qa_id(resolved_id, user_id).await
290 }
291
292 pub async fn get_graph_context(
294 &self,
295 session_id: Option<&str>,
296 user_id: Option<&str>,
297 ) -> Result<Option<String>, SessionError> {
298 let resolved_id = self.resolve_session_id(session_id);
299 self.store.get_graph_context(resolved_id, user_id).await
300 }
301
302 pub async fn set_graph_context(
304 &self,
305 session_id: Option<&str>,
306 user_id: Option<&str>,
307 context: &str,
308 ) -> Result<(), SessionError> {
309 let resolved_id = self.resolve_session_id(session_id);
310 self.store
311 .set_graph_context(resolved_id, user_id, context)
312 .await
313 }
314
315 #[allow(clippy::too_many_arguments)]
325 pub async fn add_agent_trace_step(
326 &self,
327 user_id: &str,
328 session_id: Option<&str>,
329 origin_function: &str,
330 status: &str,
331 memory_query: &str,
332 memory_context: &str,
333 method_params: serde_json::Value,
334 method_return_value: Option<serde_json::Value>,
335 error_message: &str,
336 generate_feedback: bool,
337 ) -> Result<String, SessionError> {
338 let resolved_id = self.resolve_session_id(session_id);
339 let trace_id = uuid::Uuid::new_v4().to_string();
340 let session_feedback = if generate_feedback {
341 if let Some(llm) = self.llm.as_ref() {
342 feedback::generate_session_feedback(
343 llm.as_ref(),
344 origin_function,
345 status,
346 method_return_value.as_ref(),
347 error_message,
348 )
349 .await
350 } else {
351 tracing::warn!(
352 origin_function,
353 session_id = resolved_id,
354 "add_agent_trace_step: generate_feedback=true but no LLM wired; using deterministic fallback"
355 );
356 feedback::fallback_feedback(origin_function, status, error_message)
357 }
358 } else {
359 feedback::fallback_feedback(origin_function, status, error_message)
360 };
361
362 let step = SessionTraceStep {
363 trace_id: trace_id.clone(),
364 origin_function: origin_function.to_string(),
365 status: status.to_string(),
366 memory_query: memory_query.to_string(),
367 memory_context: memory_context.to_string(),
368 method_params,
369 method_return_value,
370 error_message: error_message.to_string(),
371 session_feedback,
372 };
373 self.store.save_trace_step(user_id, resolved_id, step).await
374 }
375
376 pub async fn get_agent_trace_session(
381 &self,
382 user_id: &str,
383 session_id: Option<&str>,
384 last_n: Option<usize>,
385 ) -> Result<Vec<SessionTraceStep>, SessionError> {
386 let resolved_id = self.resolve_session_id(session_id);
387 let mut entries = self.store.read_trace_steps(user_id, resolved_id).await?;
388 if let Some(n) = last_n {
389 let drop = entries.len().saturating_sub(n);
390 entries = entries.split_off(drop);
391 }
392 Ok(entries)
393 }
394}
395
396fn entries_to_messages(entries: &[SessionQAEntry]) -> Vec<Message> {
398 let mut messages = Vec::with_capacity(entries.len() * 2);
399 for entry in entries {
400 messages.push(Message::user(&entry.question));
401 messages.push(Message::assistant(&entry.answer));
402 }
403 messages
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal entry in session `s1` with a fresh id and the current
    /// time, and with no user, context, feedback, or metadata.
    fn make_entry(question: &str, answer: &str) -> SessionQAEntry {
        SessionQAEntry {
            id: uuid::Uuid::new_v4(),
            session_id: String::from("s1"),
            user_id: None,
            question: question.to_owned(),
            answer: answer.to_owned(),
            context: None,
            created_at: chrono::Utc::now(),
            feedback_text: None,
            feedback_score: None,
            used_graph_element_ids: None,
            memify_metadata: None,
        }
    }

    /// One-element history whose single entry carries a context string.
    fn history_with_context() -> Vec<SessionQAEntry> {
        let mut entry = make_entry("Hello?", "Hi there!");
        entry.context = Some(String::from("Some context here"));
        vec![entry]
    }

    #[test]
    fn entries_to_messages_alternates_roles() {
        use cognee_llm::MessageRole;

        let history = [
            make_entry("What is Rust?", "A systems programming language."),
            make_entry("Tell me more.", "It focuses on safety and performance."),
        ];

        let messages = entries_to_messages(&history);

        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].role, MessageRole::User);
        assert_eq!(messages[0].content, "What is Rust?");
        assert_eq!(messages[1].role, MessageRole::Assistant);
        assert_eq!(messages[1].content, "A systems programming language.");
        assert_eq!(messages[2].role, MessageRole::User);
        assert_eq!(messages[3].role, MessageRole::Assistant);
    }

    #[test]
    fn format_entries_produces_expected_output() {
        let history = [make_entry("Hello?", "Hi there!")];

        let text = SessionManager::format_entries(&history);

        for expected in [
            "Previous conversation:",
            "QUESTION: Hello?",
            "ANSWER: Hi there!",
        ] {
            assert!(text.contains(expected));
        }
    }

    #[test]
    fn format_entries_empty_returns_empty_string() {
        let text = SessionManager::format_entries(&[]);
        assert_eq!(text, "");
    }

    #[test]
    fn format_entries_with_context_includes_context() {
        let history = history_with_context();

        let text = SessionManager::format_entries_with_context(&history, true);

        assert!(text.contains("CONTEXT: Some context here"));
    }

    #[test]
    fn format_entries_with_context_false_omits_context() {
        let history = history_with_context();

        let text = SessionManager::format_entries_with_context(&history, false);

        assert!(!text.contains("CONTEXT:"));
    }
}