use std::collections::HashMap;
use std::sync::Arc;
use cognee_llm::Llm;
use cognee_llm::Message;
use tracing::debug;
use crate::error::SessionError;
use crate::feedback;
use crate::session_store::{SessionQAUpdate, SessionStore};
use crate::types::{SessionQAEntry, SessionTraceStep, UsedGraphElementIds};
/// Session id that `SessionManager::new` installs as the fallback; it is used
/// whenever a caller passes `None` for `session_id`.
const DEFAULT_SESSION_ID: &str = "default_session";
/// Initial `history_limit` set by `SessionManager::new`; the limit is handed to
/// the store's `get_latest_qa_entries` call when history is loaded.
const DEFAULT_HISTORY_LIMIT: usize = 10;
/// Front-end over a [`SessionStore`] that resolves optional session ids and
/// forwards Q&A, feedback, graph-context and trace operations to the store.
///
/// Clones share the same store and LLM handles, since both are held in `Arc`s.
#[derive(Clone)]
pub struct SessionManager {
/// Backend that the methods of this type delegate to.
store: Arc<dyn SessionStore>,
/// Session id substituted when a caller passes `None`.
default_session_id: String,
/// Passed to the store's `get_latest_qa_entries` when loading history
/// (presumably the maximum number of entries returned — confirm against the store).
history_limit: usize,
/// Optional LLM; only consulted by `add_agent_trace_step` when feedback
/// generation is requested.
llm: Option<Arc<dyn Llm>>,
}
impl SessionManager {
/// Creates a manager backed by `store`.
///
/// The fallback session id is `DEFAULT_SESSION_ID`, the history limit is
/// `DEFAULT_HISTORY_LIMIT`, and no LLM is attached.
pub fn new(store: Arc<dyn SessionStore>) -> Self {
    let default_session_id = String::from(DEFAULT_SESSION_ID);
    let history_limit = DEFAULT_HISTORY_LIMIT;
    Self {
        store,
        default_session_id,
        history_limit,
        llm: None,
    }
}
pub fn with_llm(mut self, llm: Arc<dyn Llm>) -> Self {
self.llm = Some(llm);
self
}
pub fn with_default_session_id(mut self, id: impl Into<String>) -> Self {
self.default_session_id = id.into();
self
}
pub fn with_history_limit(mut self, limit: usize) -> Self {
self.history_limit = limit;
self
}
/// Returns the caller-supplied session id, or the configured default when
/// `session_id` is `None`.
fn resolve_session_id<'a>(&'a self, session_id: Option<&'a str>) -> &'a str {
    match session_id {
        Some(explicit) => explicit,
        None => self.default_session_id.as_str(),
    }
}
/// Loads the latest Q&A entries for the session and converts them to chat
/// messages (one user message and one assistant message per entry).
///
/// `session_id` falls back to the configured default when `None`.
///
/// # Errors
/// Propagates any error returned by the store's `get_latest_qa_entries`.
pub async fn load_history_messages(
    &self,
    session_id: Option<&str>,
    user_id: Option<&str>,
) -> Result<Vec<Message>, SessionError> {
    let sid = self.resolve_session_id(session_id);
    let history = self
        .store
        .get_latest_qa_entries(sid, user_id, self.history_limit)
        .await?;
    let entry_count = history.len();
    debug!(
        session_id = sid,
        entries = entry_count,
        "Loaded session history"
    );
    let messages = entries_to_messages(&history);
    Ok(messages)
}
/// Loads the latest Q&A entries once and returns them in two shapes: as chat
/// messages and as the plain-text transcript produced by `format_entries`.
///
/// `session_id` falls back to the configured default when `None`.
///
/// # Errors
/// Propagates any error returned by the store's `get_latest_qa_entries`.
pub async fn load_history_both(
    &self,
    session_id: Option<&str>,
    user_id: Option<&str>,
) -> Result<(Vec<Message>, String), SessionError> {
    let sid = self.resolve_session_id(session_id);
    let history = self
        .store
        .get_latest_qa_entries(sid, user_id, self.history_limit)
        .await?;
    let entry_count = history.len();
    debug!(
        session_id = sid,
        entries = entry_count,
        "Loaded session history (both forms)"
    );
    // Tuple fields evaluate left to right: messages first, then the transcript.
    Ok((
        entries_to_messages(&history),
        Self::format_entries(&history),
    ))
}
/// Stores one question/answer exchange and returns the id the store assigned to it.
///
/// `session_id` falls back to the configured default when `None`. When
/// `used_graph_element_ids` is supplied, it is attached to the new entry with
/// a follow-up `update_qa_entry` call.
///
/// With the `telemetry` feature enabled, a `cognee.session.add_qa` event is
/// sent; its `has_graph_elements` flag reports whether graph element ids were
/// supplied by the caller.
///
/// # Errors
/// Propagates any error from the store's `create_qa_entry`. A failure of the
/// follow-up graph-element update is only logged as a warning; the entry id
/// is still returned.
pub async fn save_qa(
    &self,
    session_id: Option<&str>,
    user_id: Option<&str>,
    question: &str,
    answer: &str,
    context: Option<&str>,
    used_graph_element_ids: Option<UsedGraphElementIds>,
) -> Result<String, SessionError> {
    let resolved_id = self.resolve_session_id(session_id);

    // Captured up front because `used_graph_element_ids` is moved into the
    // update below, before the telemetry event is built.
    #[cfg(feature = "telemetry")]
    let has_graph_elements = used_graph_element_ids.is_some();

    let qa_id = self
        .store
        .create_qa_entry(resolved_id, user_id, question, answer, context)
        .await?;

    if let Some(ids) = used_graph_element_ids {
        let update = SessionQAUpdate {
            used_graph_element_ids: Some(Some(ids)),
            ..Default::default()
        };
        // Best effort: the entry already exists, so a failed update must not
        // fail the whole save.
        if let Err(e) = self
            .store
            .update_qa_entry(resolved_id, user_id, &qa_id, update)
            .await
        {
            tracing::warn!(
                qa_id = %qa_id,
                "save_qa: failed to persist used_graph_element_ids (non-fatal): {e}"
            );
        }
    }

    #[cfg(feature = "telemetry")]
    {
        let data_size_bytes = question.len() + answer.len() + context.map_or(0, str::len);
        cognee_telemetry::send_telemetry(
            "cognee.session.add_qa",
            user_id.unwrap_or("sdk"),
            Some(serde_json::json!({
                "session_id": resolved_id,
                "data_size_bytes": data_size_bytes,
                "has_feedback": false,
                "has_graph_elements": has_graph_elements,
            })),
        );
    }

    Ok(qa_id)
}
/// Asks the store to delete the session (the default one when `session_id`
/// is `None`) and returns the store's result unchanged.
///
/// # Errors
/// Propagates any error returned by the store's `delete_session`.
pub async fn delete_session(
    &self,
    session_id: Option<&str>,
    user_id: Option<&str>,
) -> Result<bool, SessionError> {
    self.store
        .delete_session(self.resolve_session_id(session_id), user_id)
        .await
}
/// Renders `entries` as a plain-text transcript without `CONTEXT:` lines.
///
/// Shorthand for `format_entries_with_context(entries, false)`.
pub fn format_entries(entries: &[SessionQAEntry]) -> String {
    let include_context = false;
    Self::format_entries_with_context(entries, include_context)
}
pub fn format_entries_with_context(
entries: &[SessionQAEntry],
include_context: bool,
) -> String {
if entries.is_empty() {
return String::new();
}
let mut lines = vec!["Previous conversation:\n\n".to_string()];
for entry in entries {
lines.push(format!("[{}]\n", entry.created_at.to_rfc3339()));
lines.push(format!("QUESTION: {}\n", entry.question));
if include_context && let Some(ref ctx) = entry.context {
lines.push(format!("CONTEXT: {ctx}\n"));
}
lines.push(format!("ANSWER: {}\n\n", entry.answer));
}
lines.concat()
}
/// Forwards `updates` for entry `qa_id` to the store and returns the store's
/// result unchanged. `session_id` falls back to the configured default.
///
/// # Errors
/// Propagates any error returned by the store's `update_qa_entry`.
pub async fn update_qa(
    &self,
    session_id: Option<&str>,
    user_id: Option<&str>,
    qa_id: &str,
    updates: SessionQAUpdate,
) -> Result<bool, SessionError> {
    let sid = self.resolve_session_id(session_id);
    let outcome = self.store.update_qa_entry(sid, user_id, qa_id, updates);
    outcome.await
}
/// Records feedback text and/or a score on entry `qa_id`.
///
/// Alongside the feedback fields, the entry's memify metadata is set to a
/// map containing `feedback_weights_applied = false`.
///
/// # Errors
/// Returns `SessionError::InvalidParameter` when `feedback_score` is present
/// and outside `1..=5`; otherwise propagates any error from `update_qa`.
pub async fn add_feedback(
    &self,
    session_id: Option<&str>,
    user_id: Option<&str>,
    qa_id: &str,
    feedback_text: Option<&str>,
    feedback_score: Option<i32>,
) -> Result<bool, SessionError> {
    // Reject out-of-range scores before touching the store.
    match feedback_score {
        Some(score) if !(1..=5).contains(&score) => {
            return Err(SessionError::InvalidParameter(format!(
                "feedback_score must be between 1 and 5, got {score}"
            )));
        }
        _ => {}
    }

    let memify = HashMap::from([("feedback_weights_applied".to_string(), false)]);
    let updates = SessionQAUpdate {
        feedback_text: Some(feedback_text.map(str::to_owned)),
        feedback_score: Some(feedback_score),
        memify_metadata: Some(Some(memify)),
        ..Default::default()
    };
    self.update_qa(session_id, user_id, qa_id, updates).await
}
/// Sends an update for entry `qa_id` whose feedback text and score are both
/// `Some(None)` (presumably the store treats that as "clear this field" —
/// confirm against `SessionQAUpdate`). All other update fields keep their
/// `Default` values.
///
/// # Errors
/// Propagates any error from `update_qa`.
pub async fn delete_feedback(
    &self,
    session_id: Option<&str>,
    user_id: Option<&str>,
    qa_id: &str,
) -> Result<bool, SessionError> {
    let cleared = SessionQAUpdate {
        feedback_text: Some(None),
        feedback_score: Some(None),
        ..Default::default()
    };
    self.update_qa(session_id, user_id, qa_id, cleared).await
}
/// Returns whatever the store reports as the latest Q&A id for the session
/// (the default session when `session_id` is `None`).
///
/// # Errors
/// Propagates any error returned by the store's `latest_qa_id`.
pub async fn latest_qa_id(
    &self,
    session_id: Option<&str>,
    user_id: Option<&str>,
) -> Result<Option<String>, SessionError> {
    self.store
        .latest_qa_id(self.resolve_session_id(session_id), user_id)
        .await
}
/// Fetches the graph context stored for the session (the default session
/// when `session_id` is `None`), as returned by the store.
///
/// # Errors
/// Propagates any error returned by the store's `get_graph_context`.
pub async fn get_graph_context(
    &self,
    session_id: Option<&str>,
    user_id: Option<&str>,
) -> Result<Option<String>, SessionError> {
    self.store
        .get_graph_context(self.resolve_session_id(session_id), user_id)
        .await
}
/// Hands `context` to the store as the graph context for the session (the
/// default session when `session_id` is `None`).
///
/// # Errors
/// Propagates any error returned by the store's `set_graph_context`.
pub async fn set_graph_context(
    &self,
    session_id: Option<&str>,
    user_id: Option<&str>,
    context: &str,
) -> Result<(), SessionError> {
    let sid = self.resolve_session_id(session_id);
    let pending = self.store.set_graph_context(sid, user_id, context);
    pending.await
}
/// Builds a [`SessionTraceStep`] with a fresh UUID v4 trace id and saves it
/// through the store.
///
/// The step's `session_feedback` comes from
/// `feedback::generate_session_feedback` only when `generate_feedback` is
/// true *and* an LLM was attached via `with_llm`. In every other case
/// `feedback::fallback_feedback` is used; a warning is logged when feedback
/// was requested but no LLM is available.
///
/// Returns the value produced by the store's `save_trace_step` (presumably
/// the trace id — confirm against the store implementation).
///
/// # Errors
/// Propagates any error returned by the store's `save_trace_step`.
// The argument list mirrors the fields of `SessionTraceStep`.
#[allow(clippy::too_many_arguments)]
pub async fn add_agent_trace_step(
    &self,
    user_id: &str,
    session_id: Option<&str>,
    origin_function: &str,
    status: &str,
    memory_query: &str,
    memory_context: &str,
    method_params: serde_json::Value,
    method_return_value: Option<serde_json::Value>,
    error_message: &str,
    generate_feedback: bool,
) -> Result<String, SessionError> {
    let resolved_id = self.resolve_session_id(session_id);
    let trace_id = uuid::Uuid::new_v4().to_string();

    // The LLM is only eligible when the caller asked for generated feedback.
    let session_feedback = match self.llm.as_ref().filter(|_| generate_feedback) {
        Some(llm) => {
            feedback::generate_session_feedback(
                llm.as_ref(),
                origin_function,
                status,
                method_return_value.as_ref(),
                error_message,
            )
            .await
        }
        None => {
            if generate_feedback {
                tracing::warn!(
                    origin_function,
                    session_id = resolved_id,
                    "add_agent_trace_step: generate_feedback=true but no LLM wired; using deterministic fallback"
                );
            }
            feedback::fallback_feedback(origin_function, status, error_message)
        }
    };

    let step = SessionTraceStep {
        // Moved rather than cloned: `trace_id` is not used after this point.
        trace_id,
        origin_function: origin_function.to_string(),
        status: status.to_string(),
        memory_query: memory_query.to_string(),
        memory_context: memory_context.to_string(),
        method_params,
        method_return_value,
        error_message: error_message.to_string(),
        session_feedback,
    };
    self.store.save_trace_step(user_id, resolved_id, step).await
}
/// Reads the trace steps the store holds for the session (the default
/// session when `session_id` is `None`).
///
/// When `last_n` is `Some(n)`, only the final `n` steps of the returned
/// sequence are kept (all of them if there are fewer than `n`); `None`
/// returns everything.
///
/// # Errors
/// Propagates any error returned by the store's `read_trace_steps`.
pub async fn get_agent_trace_session(
    &self,
    user_id: &str,
    session_id: Option<&str>,
    last_n: Option<usize>,
) -> Result<Vec<SessionTraceStep>, SessionError> {
    let sid = self.resolve_session_id(session_id);
    let mut steps = self.store.read_trace_steps(user_id, sid).await?;
    if let Some(keep) = last_n {
        // Number of leading steps that fall outside the requested tail.
        let excess = steps.len().saturating_sub(keep);
        steps.drain(..excess);
    }
    Ok(steps)
}
}
/// Expands each entry into a user message (its question) followed by an
/// assistant message (its answer), preserving the order of `entries`.
fn entries_to_messages(entries: &[SessionQAEntry]) -> Vec<Message> {
    entries
        .iter()
        .flat_map(|entry| {
            [
                Message::user(&entry.question),
                Message::assistant(&entry.answer),
            ]
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use cognee_llm::MessageRole;

    /// Builds an entry for session `s1` with no user, context, feedback,
    /// graph elements or memify metadata.
    fn make_entry(question: &str, answer: &str) -> SessionQAEntry {
        SessionQAEntry {
            id: uuid::Uuid::new_v4(),
            session_id: "s1".to_string(),
            user_id: None,
            question: question.to_string(),
            answer: answer.to_string(),
            context: None,
            created_at: chrono::Utc::now(),
            feedback_text: None,
            feedback_score: None,
            used_graph_element_ids: None,
            memify_metadata: None,
        }
    }

    /// Single-entry fixture whose `context` field is populated.
    fn entries_with_context() -> Vec<SessionQAEntry> {
        let mut entry = make_entry("Hello?", "Hi there!");
        entry.context = Some("Some context here".to_string());
        vec![entry]
    }

    #[test]
    fn entries_to_messages_alternates_roles() {
        let entries = vec![
            make_entry("What is Rust?", "A systems programming language."),
            make_entry("Tell me more.", "It focuses on safety and performance."),
        ];

        let messages = entries_to_messages(&entries);

        assert_eq!(messages.len(), 4);
        let roles: Vec<&MessageRole> = messages.iter().map(|m| &m.role).collect();
        assert_eq!(
            roles,
            vec![
                &MessageRole::User,
                &MessageRole::Assistant,
                &MessageRole::User,
                &MessageRole::Assistant,
            ]
        );
        assert_eq!(messages[0].content, "What is Rust?");
        assert_eq!(messages[1].content, "A systems programming language.");
    }

    #[test]
    fn format_entries_produces_expected_output() {
        let formatted = SessionManager::format_entries(&[make_entry("Hello?", "Hi there!")]);
        for expected in [
            "Previous conversation:",
            "QUESTION: Hello?",
            "ANSWER: Hi there!",
        ] {
            assert!(formatted.contains(expected));
        }
    }

    #[test]
    fn format_entries_empty_returns_empty_string() {
        let formatted = SessionManager::format_entries(&[]);
        assert_eq!(formatted, "");
    }

    #[test]
    fn format_entries_with_context_includes_context() {
        let formatted =
            SessionManager::format_entries_with_context(&entries_with_context(), true);
        assert!(formatted.contains("CONTEXT: Some context here"));
    }

    #[test]
    fn format_entries_with_context_false_omits_context() {
        let formatted =
            SessionManager::format_entries_with_context(&entries_with_context(), false);
        assert!(!formatted.contains("CONTEXT:"));
    }
}