use serde_json::{json, Value};
use sqlx::types::Json as SqlxJson;
use sqlx::PgPool;
use uuid::Uuid;
use super::state::AgentState;
/// Number of most recent turns that rolling summarization leaves unsummarized;
/// they remain part of the history loaded by `load_or_create`.
const KEEP_RECENT_TURNS: i32 = 4;
/// One stored conversation turn, reduced to its role and plain text.
#[derive(Clone, Debug)]
pub struct Turn {
    /// Role label stored with the turn (the `role` column).
    pub role: String,
    /// Plain text of the turn, taken from its stored `content_json`.
    pub text: String,
}
/// Conversation state resolved for a request by `load_or_create`.
#[derive(Clone, Debug)]
pub struct SessionContext {
    /// Identifier of the session (the requested one, or a newly generated one).
    pub session_id: Uuid,
    /// Stored turns at or after the session's `summary_turn_index`, oldest first.
    pub history: Vec<Turn>,
    /// The session's `rolling_summary`, if one has been written.
    pub summary: Option<String>,
    /// Index to use for the next persisted turn (highest stored index + 1, or 0).
    pub next_turn_index: i32,
}
pub async fn load_or_create(
pool: Option<&PgPool>,
requested: Option<&str>,
tenant: Uuid,
auth_subject: &str,
source: &str,
) -> SessionContext {
let Some(pool) = pool else {
let sid = requested
.and_then(|s| Uuid::parse_str(s).ok())
.unwrap_or_else(Uuid::new_v4);
return SessionContext { session_id: sid, history: vec![], summary: None, next_turn_index: 0 };
};
if let Some(sid) = requested.and_then(|s| Uuid::parse_str(s).ok()) {
let touched: Option<(Option<String>, i32)> = sqlx::query_as(
"UPDATE agent_sessions SET last_active = NOW() \
WHERE session_id = $1 RETURNING rolling_summary, summary_turn_index",
)
.bind(sid)
.fetch_optional(pool)
.await
.ok()
.flatten();
if let Some((summary, summary_idx)) = touched {
let next = next_turn_index(pool, sid).await;
let rows: Vec<(String, SqlxJson<Value>)> = sqlx::query_as(
"SELECT role, content_json FROM agent_session_turns \
WHERE session_id = $1 AND turn_index >= $2 ORDER BY turn_index ASC",
)
.bind(sid)
.bind(summary_idx)
.fetch_all(pool)
.await
.unwrap_or_default();
let history = rows.into_iter().map(|(role, content)| Turn { role, text: text_of(&content.0) }).collect();
return SessionContext { session_id: sid, history, summary, next_turn_index: next };
}
let _ = sqlx::query(
"INSERT INTO agent_sessions (session_id, tenant_id, auth_subject, source) \
VALUES ($1, $2, $3, $4) ON CONFLICT (session_id) DO NOTHING",
)
.bind(sid)
.bind(tenant)
.bind(auth_subject)
.bind(source)
.execute(pool)
.await;
return SessionContext { session_id: sid, history: vec![], summary: None, next_turn_index: 0 };
}
let sid = Uuid::new_v4();
let _ = sqlx::query(
"INSERT INTO agent_sessions (session_id, tenant_id, auth_subject, source) VALUES ($1, $2, $3, $4)",
)
.bind(sid)
.bind(tenant)
.bind(auth_subject)
.bind(source)
.execute(pool)
.await;
SessionContext { session_id: sid, history: vec![], summary: None, next_turn_index: 0 }
}
pub async fn persist_turn(
pool: Option<&PgPool>,
session_id: Uuid,
tenant: Uuid,
turn_index: i32,
role: &str,
text: &str,
run_id: Option<Uuid>,
) {
let Some(pool) = pool else { return }; let _ = sqlx::query(
"INSERT INTO agent_session_turns (session_id, tenant_id, turn_index, role, content_json, run_id) \
VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (session_id, turn_index) DO NOTHING",
)
.bind(session_id)
.bind(tenant)
.bind(turn_index)
.bind(role)
.bind(SqlxJson(json!({ "text": text })))
.bind(run_id)
.execute(pool)
.await;
}
pub fn maybe_summarize_detached(state: AgentState, session_id: Uuid, every: i32) {
tokio::spawn(async move {
if let Err(e) = summarize_if_needed(&state, session_id, every).await {
tracing::warn!(session_id = %session_id, error = %e, "rolling summary failed");
}
});
}
async fn summarize_if_needed(state: &AgentState, session_id: Uuid, every: i32) -> anyhow::Result<()> {
let Some(pool) = state.pool.as_ref() else { return Ok(()) }; let (prev_summary, summary_idx): (Option<String>, i32) =
sqlx::query_as("SELECT rolling_summary, summary_turn_index FROM agent_sessions WHERE session_id = $1")
.bind(session_id)
.fetch_one(pool)
.await?;
let max_idx: Option<i32> =
sqlx::query_scalar::<_, Option<i32>>("SELECT MAX(turn_index) FROM agent_session_turns WHERE session_id = $1")
.bind(session_id)
.fetch_one(pool)
.await
.unwrap_or(None);
let Some(max_idx) = max_idx else { return Ok(()) };
if max_idx - summary_idx < every {
return Ok(());
}
let boundary = max_idx + 1 - KEEP_RECENT_TURNS;
if boundary <= summary_idx {
return Ok(());
}
let rows: Vec<(String, SqlxJson<Value>)> = sqlx::query_as(
"SELECT role, content_json FROM agent_session_turns \
WHERE session_id = $1 AND turn_index >= $2 AND turn_index < $3 ORDER BY turn_index ASC",
)
.bind(session_id)
.bind(summary_idx)
.bind(boundary)
.fetch_all(pool)
.await?;
let mut transcript = String::new();
if let Some(ps) = &prev_summary {
if !ps.trim().is_empty() {
transcript.push_str(&format!("Summary so far:\n{ps}\n\n"));
}
}
for (role, content) in &rows {
transcript.push_str(&format!("{role}: {}\n", text_of(&content.0)));
}
let summary = super::runner::summarize_conversation(state, &transcript).await?;
if summary.trim().is_empty() {
return Ok(());
}
sqlx::query("UPDATE agent_sessions SET rolling_summary = $1, summary_turn_index = $2 WHERE session_id = $3")
.bind(summary.trim())
.bind(boundary)
.bind(session_id)
.execute(pool)
.await?;
Ok(())
}
async fn next_turn_index(pool: &PgPool, session_id: Uuid) -> i32 {
let max_idx: Option<i32> =
sqlx::query_scalar::<_, Option<i32>>("SELECT MAX(turn_index) FROM agent_session_turns WHERE session_id = $1")
.bind(session_id)
.fetch_one(pool)
.await
.unwrap_or(None);
max_idx.map(|m| m + 1).unwrap_or(0)
}
fn text_of(content: &Value) -> String {
content.get("text").and_then(|v| v.as_str()).unwrap_or_default().to_string()
}