use anyhow::Result;
use std::path::Path;
/// Author of a chat message.
///
/// Its lowercase text form is produced by `Role::as_str` / `Display` and parsed
/// back by the `FromStr` impl. `Eq` and `Hash` are derived so roles can be used
/// as keys in `HashMap`/`HashSet` (equality on this fieldless enum is total).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// NOTE(review): presumably not every variant is constructed yet; confirm this
// allow is still needed.
#[allow(dead_code)]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}
impl Role {
pub fn as_str(&self) -> &'static str {
match self {
Self::System => "system",
Self::User => "user",
Self::Assistant => "assistant",
Self::Tool => "tool",
}
}
}
/// Formats the role as its lowercase name (the same text as `Role::as_str`).
///
/// Uses `Formatter::pad`, so width, fill and alignment flags are honored.
/// For example, `format!("{:<9}|", Role::User)` yields `"user     |"`.
/// Plain `{}` output is unchanged.
impl std::fmt::Display for Role {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `write!(f, "{}", ..)` would start a fresh format spec and silently drop
        // the caller's width/alignment; `pad` applies them to the name.
        f.pad(self.as_str())
    }
}
impl std::str::FromStr for Role {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"system" => Ok(Self::System),
"user" => Ok(Self::User),
"assistant" => Ok(Self::Assistant),
"tool" => Ok(Self::Tool),
other => Err(format!("unknown role: {other}")),
}
}
}
/// One stored chat message, as returned by `Persistence::load_context` and
/// `Persistence::load_all_messages`.
///
/// Every field except `id`, `session_id` and `role` is optional.
#[derive(Debug, Clone)]
// NOTE(review): presumably some fields are not read anywhere yet; confirm this
// allow is still needed.
#[allow(dead_code)]
pub struct Message {
    /// Row id of the message (presumably the `message_id` accepted by
    /// `Persistence::mark_message_complete`).
    pub id: i64,
    /// Id of the session this message belongs to.
    pub session_id: String,
    /// Author of the message.
    pub role: Role,
    /// Message text; `None` when no text was stored.
    pub content: Option<String>,
    /// Untruncated content, presumably written by
    /// `Persistence::insert_tool_message_with_full` — TODO confirm.
    pub full_content: Option<String>,
    /// Tool calls carried by this message, stored as serialized text
    /// (encoding not visible here; presumably JSON).
    pub tool_calls: Option<String>,
    /// For tool-result messages, presumably the id of the tool call answered.
    pub tool_call_id: Option<String>,
    // Token counters below are presumably copied from the provider usage passed
    // to `insert_message`; `None` when no usage was recorded.
    pub prompt_tokens: Option<i64>,
    pub completion_tokens: Option<i64>,
    pub cache_read_tokens: Option<i64>,
    pub cache_creation_tokens: Option<i64>,
    pub thinking_tokens: Option<i64>,
    /// Reasoning text, presumably set via
    /// `Persistence::update_message_thinking_content`.
    pub thinking_content: Option<String>,
    /// Creation timestamp as stored text (format not visible here).
    pub created_at: Option<String>,
}
/// What kind of work was interrupted.
///
/// NOTE(review): the semantics here are inferred from variant names only;
/// confirm against the code that constructs these values.
#[derive(Debug, Clone, PartialEq)]
pub enum InterruptionKind {
    /// Interrupted while handling a prompt; the `String` presumably holds the
    /// prompt text.
    Prompt(String),
    /// Presumably interrupted during tool execution.
    Tool,
}
/// Token totals for a session, as returned by
/// `Persistence::session_token_usage` and, per agent, by
/// `Persistence::session_usage_by_agent`.
///
/// `Default` yields all counters at 0.
#[derive(Debug, Clone, Default)]
pub struct SessionUsage {
    pub prompt_tokens: i64,
    pub completion_tokens: i64,
    pub cache_read_tokens: i64,
    pub cache_creation_tokens: i64,
    pub thinking_tokens: i64,
    /// Number of API calls counted; presumably one per message that carried
    /// usage — TODO confirm.
    pub api_calls: i64,
}
/// Summary row for one session, as returned by `Persistence::list_sessions`.
#[derive(Debug, Clone)]
pub struct SessionInfo {
    /// Session id (the `session_id` accepted by other `Persistence` methods).
    pub id: String,
    /// Agent the session was created for (see `Persistence::create_session`).
    pub agent_name: String,
    /// Creation timestamp as stored text (format not visible here).
    pub created_at: String,
    /// Number of messages in the session.
    pub message_count: i64,
    /// Token total for the session; which counters are summed is not visible
    /// here — TODO confirm.
    pub total_tokens: i64,
    /// Title set via `Persistence::set_session_title`, if any.
    pub title: Option<String>,
    /// Mode set via `Persistence::set_session_mode`, if any.
    pub mode: Option<String>,
}
/// Aggregate statistics about compacted data, as returned by
/// `Persistence::compacted_stats`.
///
/// `Default` yields zero counts and no `oldest` entry.
#[derive(Debug, Clone, Default)]
pub struct CompactedStats {
    /// Number of compacted messages.
    pub message_count: i64,
    /// Number of sessions that contain compacted messages (presumably).
    pub session_count: i64,
    /// Storage size in bytes attributed to compacted data (how it is measured
    /// is not visible here).
    pub size_bytes: i64,
    /// Timestamp text of the oldest compacted entry, or `None` when there is none.
    pub oldest: Option<String>,
}
/// Async storage backend for chat sessions, their messages, token usage, and
/// per-session key/value data (metadata and a todo list).
///
/// The `Send + Sync` supertraits let one implementation be shared across
/// threads/tasks.
///
/// NOTE(review): the per-method notes below are derived from names and
/// signatures only. Confirm exact semantics (ordering, behavior for missing
/// rows) against the concrete implementation.
#[async_trait::async_trait]
pub trait Persistence: Send + Sync {
    /// Creates a session for `agent_name` scoped to `project_root`. The
    /// returned `String` is presumably the new session's id.
    async fn create_session(&self, agent_name: &str, project_root: &Path) -> Result<String>;

    /// Lists sessions for `project_root`, presumably capped at `limit` entries.
    async fn list_sessions(&self, limit: i64, project_root: &Path) -> Result<Vec<SessionInfo>>;

    /// Deletes a session. The `bool` presumably reports whether anything was
    /// removed.
    async fn delete_session(&self, session_id: &str) -> Result<bool>;

    /// Sets the session's title (surfaced as `SessionInfo::title`).
    async fn set_session_title(&self, session_id: &str, title: &str) -> Result<()>;

    /// Stores the session's mode string (surfaced as `SessionInfo::mode`).
    async fn set_session_mode(&self, session_id: &str, mode: &str) -> Result<()>;

    /// Reads the session's mode; presumably `None` when no mode has been set.
    async fn get_session_mode(&self, session_id: &str) -> Result<Option<String>>;

    /// Seconds the session has been idle (presumably since its last message),
    /// or `None` when that cannot be determined.
    async fn get_session_idle_secs(&self, session_id: &str) -> Result<Option<i64>>;

    /// Appends a message to `session_id` and returns its id.
    ///
    /// The id is presumably the `message_id` later passed to
    /// `mark_message_complete`. `usage` carries provider-reported token counts,
    /// if any.
    async fn insert_message(
        &self,
        session_id: &str,
        role: &Role,
        content: Option<&str>,
        tool_calls: Option<&str>,
        tool_call_id: Option<&str>,
        usage: Option<&crate::providers::TokenUsage>,
    ) -> Result<i64>;

    /// Like `insert_message`, additionally attributing the message to
    /// `agent_name`. That attribution is presumably what
    /// `session_usage_by_agent` groups by.
    // 8 inputs counting `self` exceed clippy's default `too_many_arguments`
    // threshold of 7, so the allow is required here.
    #[allow(clippy::too_many_arguments)]
    async fn insert_message_with_agent(
        &self,
        session_id: &str,
        role: &Role,
        content: Option<&str>,
        tool_calls: Option<&str>,
        tool_call_id: Option<&str>,
        usage: Option<&crate::providers::TokenUsage>,
        agent_name: Option<&str>,
    ) -> Result<i64>;

    /// Inserts a tool-result message answering `tool_call_id`.
    ///
    /// It stores both `content` and the untruncated `full_content`.
    /// Presumably `content` is a shortened form used in context — TODO confirm.
    async fn insert_tool_message_with_full(
        &self,
        session_id: &str,
        content: &str,
        tool_call_id: &str,
        full_content: &str,
    ) -> Result<i64>;

    /// Loads the messages that form the model context for the session
    /// (presumably excluding history removed by compaction).
    async fn load_context(&self, session_id: &str) -> Result<Vec<Message>>;

    /// Loads every stored message of the session.
    async fn load_all_messages(&self, session_id: &str) -> Result<Vec<Message>>;

    /// Returns up to `limit` recent user-message texts. This method takes no
    /// `session_id` argument.
    async fn recent_user_messages(&self, limit: i64) -> Result<Vec<String>>;

    /// Text of the session's most recent assistant message.
    async fn last_assistant_message(&self, session_id: &str) -> Result<String>;

    /// Text of the session's most recent user message.
    async fn last_user_message(&self, session_id: &str) -> Result<String>;

    /// Whether the session has tool calls still awaiting results (presumably).
    async fn has_pending_tool_calls(&self, session_id: &str) -> Result<bool>;

    /// Marks message `message_id` as complete.
    async fn mark_message_complete(&self, message_id: i64) -> Result<()>;

    /// Updates the thinking/reasoning text of message `message_id` to `content`.
    async fn update_message_thinking_content(&self, message_id: i64, content: &str) -> Result<()>;

    /// Token usage aggregated over the whole session.
    async fn session_token_usage(&self, session_id: &str) -> Result<SessionUsage>;

    /// Token usage for the session, grouped by agent name.
    async fn session_usage_by_agent(&self, session_id: &str)
        -> Result<Vec<(String, SessionUsage)>>;

    /// Compacts the session's history into `summary`, presumably keeping the
    /// newest `preserve_count` messages.
    ///
    /// Returns a count, presumably of the messages compacted.
    async fn compact_session(
        &self,
        session_id: &str,
        summary: &str,
        preserve_count: usize,
    ) -> Result<usize>;

    /// Presumably replaces the content of each message in `message_ids` with
    /// `stub`.
    async fn clear_message_content(&self, message_ids: &[i64], stub: &str) -> Result<()>;

    /// Statistics about compacted data; this method takes no `session_id`.
    async fn compacted_stats(&self) -> Result<CompactedStats>;

    /// Purges compacted data at least `min_age_days` old. Returns a count,
    /// presumably of the rows removed.
    async fn purge_compacted(&self, min_age_days: u32) -> Result<usize>;

    /// Reads the session metadata value for `key`; presumably `None` when
    /// absent.
    async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>>;

    /// Writes the session metadata `value` for `key`.
    async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()>;

    /// Reads the session's todo text; presumably `None` when none is stored.
    async fn get_todo(&self, session_id: &str) -> Result<Option<String>>;

    /// Writes the session's todo text.
    async fn set_todo(&self, session_id: &str, content: &str) -> Result<()>;
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_role_as_str_all_variants() {
        assert_eq!(Role::System.as_str(), "system");
        assert_eq!(Role::User.as_str(), "user");
        assert_eq!(Role::Assistant.as_str(), "assistant");
        assert_eq!(Role::Tool.as_str(), "tool");
    }

    #[test]
    fn test_role_from_str_round_trips() {
        for (s, expected) in [
            ("system", Role::System),
            ("user", Role::User),
            ("assistant", Role::Assistant),
            ("tool", Role::Tool),
        ] {
            let parsed: Role = s.parse().expect(s);
            // Role derives PartialEq + Debug, so compare variants directly
            // rather than their string forms.
            assert_eq!(parsed, expected);
            // A genuine round trip: Role -> Display text -> Role.
            let reparsed: Role = expected.to_string().parse().expect(s);
            assert_eq!(reparsed, expected);
        }
    }

    #[test]
    fn test_role_from_str_unknown_returns_error() {
        // Parsing is exact and case-sensitive; the error echoes the input.
        for bad in ["unknown", "", "System", " user"] {
            let result: Result<Role, _> = bad.parse();
            assert_eq!(result.unwrap_err(), format!("unknown role: {bad}"));
        }
    }

    #[test]
    fn test_role_display_matches_as_str() {
        for role in [Role::System, Role::User, Role::Assistant, Role::Tool] {
            assert_eq!(role.to_string(), role.as_str());
        }
    }
}