use std::time::Duration;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use uuid::Uuid;
use crate::error::StorageError;
use crate::provider::{ContentPart, ModelName, TokenUsage, ToolCall};
pub mod memory;
pub(crate) mod util;
#[cfg(any(
feature = "sqlx-postgres",
feature = "sqlx-mysql",
feature = "sqlx-sqlite"
))]
pub mod sql;
#[cfg(feature = "mongodb")]
pub mod mongodb;
#[cfg(feature = "surrealdb")]
pub mod surrealdb;
#[cfg(feature = "redis")]
pub mod redis;
#[cfg(feature = "qdrant")]
pub mod qdrant;
#[cfg(feature = "object_store")]
pub mod object;
/// Result type returned by every storage trait in this module; failures are [`StorageError`].
pub type StoreResult<T> = Result<T, StorageError>;
/// A window into an ordered listing.
///
/// The provided `SessionStore` fallbacks skip `offset` items and then return at most `limit`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct Pagination {
    /// Maximum number of items to return.
    pub limit: u32,
    /// Number of leading items to skip.
    pub offset: u32,
}

impl Default for Pagination {
    /// The first page, holding up to 100 items.
    fn default() -> Self {
        Pagination {
            offset: 0,
            limit: 100,
        }
    }
}
/// Optional constraints for [`SessionStore::list_sessions_paginated`].
///
/// A `None` field imposes no constraint.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SessionFilter {
    /// Metadata to match.
    ///
    /// The provided trait implementation keeps only sessions whose metadata is exactly equal
    /// to this value.
    pub metadata_filter: Option<Value>,
    /// Lower bound on `created_at`. The provided implementation treats it as inclusive.
    pub created_after: Option<DateTime<Utc>>,
    /// Upper bound on `created_at`. The provided implementation treats it as exclusive.
    pub created_before: Option<DateTime<Utc>>,
}
/// A stored conversation session.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Session {
    /// Session identifier (UUIDv7 when built with [`Session::new`]).
    pub id: Uuid,
    /// Human-readable title.
    pub title: String,
    /// Model associated with the session.
    pub model: ModelName,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Last-update time (UTC).
    pub updated_at: DateTime<Utc>,
    /// Free-form metadata; `Value::Null` unless set via [`Session::with_metadata`].
    pub metadata: Value,
}

impl Session {
    /// Builds a session with a fresh UUIDv7 id and `Value::Null` metadata.
    ///
    /// `created_at` and `updated_at` are both set to the same current UTC instant.
    #[must_use]
    pub fn new(title: impl Into<String>, model: ModelName) -> Self {
        let stamp = Utc::now();
        Session {
            id: Uuid::now_v7(),
            title: title.into(),
            model,
            created_at: stamp,
            updated_at: stamp,
            metadata: Value::Null,
        }
    }

    /// Returns this session with `metadata` replacing any existing metadata.
    #[must_use]
    pub fn with_metadata(self, metadata: Value) -> Self {
        Session { metadata, ..self }
    }
}
/// Bookkeeping attached to compaction and summary messages.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CompactionMeta {
    /// Message id supplied to [`CompactionMeta::new`].
    pub tail_start_id: Option<Uuid>,
    /// Id of an earlier compaction, set via [`CompactionMeta::with_previous`].
    pub previous_compaction_id: Option<Uuid>,
    /// Summary text, set via [`CompactionMeta::with_summary`].
    pub summary_text: Option<String>,
}

impl CompactionMeta {
    /// Metadata that records `tail_start_id`, with no previous compaction and no summary.
    #[must_use]
    pub fn new(tail_start_id: Uuid) -> Self {
        CompactionMeta {
            summary_text: None,
            previous_compaction_id: None,
            tail_start_id: Some(tail_start_id),
        }
    }

    /// Returns a copy that links to the earlier compaction `previous_id`.
    #[must_use]
    pub fn with_previous(self, previous_id: Uuid) -> Self {
        CompactionMeta {
            previous_compaction_id: Some(previous_id),
            ..self
        }
    }

    /// Returns a copy that carries `summary` as its summary text.
    #[must_use]
    pub fn with_summary(self, summary: String) -> Self {
        CompactionMeta {
            summary_text: Some(summary),
            ..self
        }
    }
}
/// A single message stored in a session's history.
///
/// Fields marked `#[serde(default)]` fall back to their default when absent on input.
/// Empty/`None` optional fields are left out of serialized output.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct MessageRecord {
    /// Message identifier (UUIDv7 when built with [`MessageRecord::new`]).
    pub id: Uuid,
    /// Owning session.
    pub session_id: Uuid,
    /// Author role.
    pub role: MessageRole,
    /// Message body parts.
    pub content: Vec<ContentPart>,
    /// Tool calls carried by this message; omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Tool call id set by [`MessageRecord::with_tool_result`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool name set by [`MessageRecord::with_tool_result`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_name: Option<String>,
    /// Token usage attached via [`MessageRecord::with_usage`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub usage: Option<TokenUsage>,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Set to `true` by [`MessageRecord::with_compaction`].
    #[serde(default)]
    pub is_compaction: bool,
    /// Set to `true` by [`MessageRecord::with_summary`].
    #[serde(default)]
    pub is_summary: bool,
    /// Metadata attached by `with_compaction` or `with_summary`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub compaction_meta: Option<CompactionMeta>,
}

impl MessageRecord {
    /// Builds a plain message with a fresh UUIDv7 id and the current UTC time.
    ///
    /// No tool data or usage is attached, and both compaction flags are `false`.
    #[must_use]
    pub fn new(session_id: Uuid, role: MessageRole, content: Vec<ContentPart>) -> Self {
        let id = Uuid::now_v7();
        let created_at = Utc::now();
        MessageRecord {
            id,
            session_id,
            role,
            content,
            created_at,
            tool_calls: Vec::new(),
            tool_call_id: None,
            tool_name: None,
            usage: None,
            is_compaction: false,
            is_summary: false,
            compaction_meta: None,
        }
    }

    /// Tags this message with the tool call id `call_id` and tool name `name`.
    #[must_use]
    pub fn with_tool_result(self, call_id: String, name: String) -> Self {
        MessageRecord {
            tool_call_id: Some(call_id),
            tool_name: Some(name),
            ..self
        }
    }

    /// Replaces the tool calls carried by this message.
    #[must_use]
    pub fn with_tool_calls(self, tool_calls: Vec<ToolCall>) -> Self {
        MessageRecord { tool_calls, ..self }
    }

    /// Attaches token usage to this message.
    #[must_use]
    pub fn with_usage(self, usage: TokenUsage) -> Self {
        MessageRecord {
            usage: Some(usage),
            ..self
        }
    }

    /// Flags this message as a compaction and stores `meta`.
    #[must_use]
    pub fn with_compaction(self, meta: CompactionMeta) -> Self {
        MessageRecord {
            is_compaction: true,
            compaction_meta: Some(meta),
            ..self
        }
    }

    /// Flags this message as a summary and stores `meta`.
    #[must_use]
    pub fn with_summary(self, meta: CompactionMeta) -> Self {
        MessageRecord {
            is_summary: true,
            compaction_meta: Some(meta),
            ..self
        }
    }
}
impl From<&crate::provider::Message> for MessageRole {
fn from(message: &crate::provider::Message) -> Self {
match message {
crate::provider::Message::System { .. } => Self::System,
crate::provider::Message::User { .. } => Self::User,
crate::provider::Message::Assistant { .. } => Self::Assistant,
crate::provider::Message::Tool { .. } => Self::Tool,
}
}
}
/// Author of a stored message.
///
/// Serialized in snake_case: `"system"`, `"user"`, `"assistant"`, `"tool"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum MessageRole {
    /// System prompt.
    System,
    /// End-user input.
    User,
    /// Model output.
    Assistant,
    /// Tool result.
    Tool,
}
/// A stored embedding vector plus its provenance.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct EmbeddingRecord {
    /// Record identifier (UUIDv7 when built with [`EmbeddingRecord::new`]).
    pub id: Uuid,
    /// Session the embedding is associated with, if any.
    pub session_id: Option<Uuid>,
    /// Name of the model that produced `vector`.
    pub model: String,
    /// The embedding itself.
    pub vector: Vec<f32>,
    /// Free-form metadata; `Value::Null` unless set via [`EmbeddingRecord::with_metadata`].
    pub metadata: Value,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
}

impl EmbeddingRecord {
    /// Builds a record with a fresh UUIDv7 id, no session, null metadata and the current time.
    #[must_use]
    pub fn new(model: impl Into<String>, vector: Vec<f32>) -> Self {
        let id = Uuid::now_v7();
        let model = model.into();
        EmbeddingRecord {
            id,
            model,
            vector,
            session_id: None,
            metadata: Value::Null,
            created_at: Utc::now(),
        }
    }

    /// Associates the record with `session_id`.
    #[must_use]
    pub fn with_session(self, session_id: Uuid) -> Self {
        EmbeddingRecord {
            session_id: Some(session_id),
            ..self
        }
    }

    /// Replaces the record's metadata.
    #[must_use]
    pub fn with_metadata(self, metadata: Value) -> Self {
        EmbeddingRecord { metadata, ..self }
    }
}
/// An [`EmbeddingRecord`] returned by [`EmbeddingStore::search`], paired with its score.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ScoredEmbedding {
    /// The matched record.
    pub record: EmbeddingRecord,
    /// Score assigned by the backend. Its scale and ordering are backend-defined.
    pub score: f32,
}
/// A binary blob stored alongside a session (or standalone).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Artifact {
    /// Artifact identifier (UUIDv7 when built with [`Artifact::new`]).
    pub id: Uuid,
    /// Owning session, if any.
    pub session_id: Option<Uuid>,
    /// Display name.
    pub name: String,
    /// Content type string supplied by the caller.
    pub content_type: String,
    /// Raw bytes; serialized as a base64 string.
    #[serde(with = "base64_bytes")]
    pub data: Vec<u8>,
    /// Free-form metadata; `Value::Null` unless set via [`Artifact::with_metadata`].
    pub metadata: Value,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
}

impl Artifact {
    /// Builds an artifact with a fresh UUIDv7 id, no session, null metadata and the current time.
    #[must_use]
    pub fn new(name: impl Into<String>, content_type: impl Into<String>, data: Vec<u8>) -> Self {
        let id = Uuid::now_v7();
        let (name, content_type) = (name.into(), content_type.into());
        Artifact {
            id,
            name,
            content_type,
            data,
            session_id: None,
            metadata: Value::Null,
            created_at: Utc::now(),
        }
    }

    /// Associates the artifact with `session_id`.
    #[must_use]
    pub fn with_session(self, session_id: Uuid) -> Self {
        Artifact {
            session_id: Some(session_id),
            ..self
        }
    }

    /// Replaces the artifact's metadata.
    #[must_use]
    pub fn with_metadata(self, metadata: Value) -> Self {
        Artifact { metadata, ..self }
    }
}
/// Serde adapter that stores byte buffers as base64 strings using the `base64` crate's
/// `general_purpose::STANDARD` engine. Used via `#[serde(with = "base64_bytes")]`.
mod base64_bytes {
    use base64::Engine as _;
    use serde::{Deserialize, Deserializer, Serializer};

    /// Encodes `data` as a base64 string.
    ///
    /// The parameter is `&[u8]` rather than `&Vec<u8>` (clippy `ptr_arg`). Serde's
    /// `with`-wrapper passes a `&Vec<u8>`, which deref-coerces at the call site.
    pub(super) fn serialize<S>(data: &[u8], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let encoded = base64::engine::general_purpose::STANDARD.encode(data);
        serializer.serialize_str(&encoded)
    }

    /// Decodes a base64 string back into bytes.
    ///
    /// # Errors
    /// Fails if the input is not a string.
    /// Also fails, with a custom error, if the string is not valid input for the `STANDARD`
    /// engine.
    pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let encoded = String::deserialize(deserializer)?;
        base64::engine::general_purpose::STANDARD
            .decode(encoded.as_bytes())
            .map_err(serde::de::Error::custom)
    }
}
/// Persistence for chat sessions and their message history.
///
/// The provided (default) methods are in-memory fallbacks built on
/// [`SessionStore::list_sessions`] and [`SessionStore::list_messages`]. They load the full
/// listing before filtering or paging, so backends with native query support may prefer to
/// override them.
#[async_trait]
pub trait SessionStore: Send + Sync {
    /// Stores a new session.
    async fn create_session(&self, session: Session) -> StoreResult<Session>;

    /// Returns all sessions.
    async fn list_sessions(&self) -> StoreResult<Vec<Session>>;

    /// Looks up a session by id; `Ok(None)` when absent.
    async fn get_session(&self, id: &Uuid) -> StoreResult<Option<Session>>;

    /// Deletes a session by id.
    async fn delete_session(&self, id: &Uuid) -> StoreResult<()>;

    /// Changes a session's title and, optionally, its model.
    ///
    /// The provided implementation does not support updates. It always returns
    /// [`StorageError::BackendError`] with backend `"session"`, so backends must override it.
    async fn update_session(
        &self,
        id: &Uuid,
        title: &str,
        model: Option<&ModelName>,
    ) -> StoreResult<Session> {
        // The fallback has nothing to do with its arguments.
        let _ = (id, title, model);
        let message = "update_session not implemented for this backend".to_owned();
        Err(StorageError::BackendError {
            backend: "session".to_owned(),
            message,
            source: None,
        })
    }

    /// Stores a message in its session's history.
    async fn append_message(&self, message: MessageRecord) -> StoreResult<MessageRecord>;

    /// Returns every message of `session_id`.
    async fn list_messages(&self, session_id: &Uuid) -> StoreResult<Vec<MessageRecord>>;

    /// Sets the token usage recorded for `message_id`.
    async fn update_usage(&self, message_id: &Uuid, usage: TokenUsage) -> StoreResult<()>;

    /// Lists sessions matching `filter`, then applies `pagination`.
    ///
    /// Fallback semantics:
    /// - `created_after` is an inclusive lower bound.
    /// - `created_before` is an exclusive upper bound.
    /// - `metadata_filter` requires the session's metadata to be exactly equal to the filter.
    ///
    /// Ordering is whatever [`SessionStore::list_sessions`] returns.
    async fn list_sessions_paginated(
        &self,
        pagination: Pagination,
        filter: SessionFilter,
    ) -> StoreResult<Vec<Session>> {
        let everything = self.list_sessions().await?;
        let page: Vec<Session> = everything
            .into_iter()
            .filter(|session| {
                let too_early = matches!(
                    filter.created_after,
                    Some(after) if session.created_at < after
                );
                let too_late = matches!(
                    filter.created_before,
                    Some(before) if session.created_at >= before
                );
                let metadata_mismatch = matches!(
                    &filter.metadata_filter,
                    Some(wanted) if session.metadata != *wanted
                );
                !(too_early || too_late || metadata_mismatch)
            })
            .skip(pagination.offset as usize)
            .take(pagination.limit as usize)
            .collect();
        Ok(page)
    }

    /// Returns one page of `session_id`'s messages, in [`SessionStore::list_messages`] order.
    async fn list_messages_paginated(
        &self,
        session_id: &Uuid,
        pagination: Pagination,
    ) -> StoreResult<Vec<MessageRecord>> {
        let history = self.list_messages(session_id).await?;
        let (skip, take) = (pagination.offset as usize, pagination.limit as usize);
        let page: Vec<MessageRecord> = history.into_iter().skip(skip).take(take).collect();
        Ok(page)
    }

    /// Reports backend health; the provided implementation always succeeds.
    async fn health_check(&self) -> StoreResult<()> {
        Ok(())
    }

    /// Returns the last message with `is_compaction` set, in `list_messages` order.
    async fn get_latest_compaction(&self, session_id: &Uuid) -> StoreResult<Option<MessageRecord>> {
        let history = self.list_messages(session_id).await?;
        Ok(history.into_iter().rfind(|record| record.is_compaction))
    }

    /// Marks a message as compacted; the provided implementation is a no-op returning `Ok(())`.
    async fn mark_compacted(&self, _message_id: &Uuid) -> StoreResult<()> {
        Ok(())
    }
}
/// Storage for embedding vectors with similarity search.
#[async_trait]
pub trait EmbeddingStore: Send + Sync {
    /// Inserts `record`, or updates it if it is already stored.
    async fn upsert(&self, record: EmbeddingRecord) -> StoreResult<EmbeddingRecord>;

    /// Searches for records near `query`, returning at most `limit` scored matches.
    async fn search(&self, query: &[f32], limit: usize) -> StoreResult<Vec<ScoredEmbedding>>;

    /// Deletes a single record by id.
    async fn delete(&self, id: &Uuid) -> StoreResult<()>;

    /// Deletes every record tied to `session_id`.
    ///
    /// The returned count presumably reports how many records were removed.
    async fn delete_by_session(&self, session_id: &Uuid) -> StoreResult<u64>;
}
/// Storage for binary [`Artifact`]s.
#[async_trait]
pub trait ArtifactStore: Send + Sync {
    /// Stores `artifact`.
    async fn put(&self, artifact: Artifact) -> StoreResult<Artifact>;

    /// Looks up an artifact by id; `Ok(None)` when absent.
    async fn get(&self, id: &Uuid) -> StoreResult<Option<Artifact>>;

    /// Deletes an artifact by id.
    async fn delete(&self, id: &Uuid) -> StoreResult<()>;

    /// Lists the artifacts tied to `session_id`.
    async fn list_by_session(&self, session_id: &Uuid) -> StoreResult<Vec<Artifact>>;

    /// Deletes the artifacts tied to a session.
    ///
    /// The provided implementation deletes nothing and returns `Ok(0)`.
    async fn delete_by_session(&self, _session_id: &Uuid) -> StoreResult<u64> {
        Ok(0)
    }
}
/// Audit record for one tool invocation.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ToolExecution {
    /// Record identifier (UUIDv7 when built with [`ToolExecution::new`]).
    pub id: Uuid,
    /// Owning session.
    pub session_id: Uuid,
    /// Message that triggered the call.
    pub message_id: Uuid,
    /// Provider-side tool call id.
    pub call_id: String,
    /// Name of the invoked tool.
    pub tool_name: String,
    /// Arguments passed to the tool.
    pub arguments: Value,
    /// Output, set by [`ToolExecution::with_success`].
    pub result: Option<Value>,
    /// Current lifecycle state.
    pub status: ToolExecutionStatus,
    /// Error text, set by [`ToolExecution::with_failure`].
    pub error: Option<String>,
    /// Wall time of the call; serialized as whole milliseconds.
    #[serde(with = "duration_millis")]
    pub duration: Duration,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
}

impl ToolExecution {
    /// Starts a `Pending` record with a fresh UUIDv7 id.
    ///
    /// The record has no result, no error and a zero duration.
    #[must_use]
    pub fn new(
        session_id: Uuid,
        message_id: Uuid,
        call_id: impl Into<String>,
        tool_name: impl Into<String>,
        arguments: Value,
    ) -> Self {
        let id = Uuid::now_v7();
        let (call_id, tool_name) = (call_id.into(), tool_name.into());
        ToolExecution {
            id,
            session_id,
            message_id,
            call_id,
            tool_name,
            arguments,
            status: ToolExecutionStatus::Pending,
            result: None,
            error: None,
            duration: Duration::ZERO,
            created_at: Utc::now(),
        }
    }

    /// Marks the execution `Success`, storing `result` and the elapsed `duration`.
    #[must_use]
    pub fn with_success(self, result: Value, duration: Duration) -> Self {
        ToolExecution {
            status: ToolExecutionStatus::Success,
            result: Some(result),
            duration,
            ..self
        }
    }

    /// Marks the execution `Failed`, storing the `error` text and the elapsed `duration`.
    #[must_use]
    pub fn with_failure(self, error: impl Into<String>, duration: Duration) -> Self {
        ToolExecution {
            status: ToolExecutionStatus::Failed,
            error: Some(error.into()),
            duration,
            ..self
        }
    }
}
/// Lifecycle state of a [`ToolExecution`].
///
/// Serialized as `"pending"`, `"success"` or `"failed"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ToolExecutionStatus {
    /// Initial state assigned by [`ToolExecution::new`].
    Pending,
    /// Assigned by [`ToolExecution::with_success`].
    Success,
    /// Assigned by [`ToolExecution::with_failure`].
    Failed,
}
/// Token accounting for one message, tagged with the provider and model that produced it.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct UsageRecord {
    /// Record identifier (UUIDv7 when built with [`UsageRecord::new`]).
    pub id: Uuid,
    /// Owning session.
    pub session_id: Uuid,
    /// Message the usage belongs to.
    pub message_id: Uuid,
    /// Provider name.
    pub provider: String,
    /// Model name.
    pub model: String,
    /// Copied from `TokenUsage::input_tokens`.
    pub input_tokens: u64,
    /// Copied from `TokenUsage::output_tokens`.
    pub output_tokens: u64,
    /// Copied from `TokenUsage::total_tokens`.
    pub total_tokens: u64,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
}

impl UsageRecord {
    /// Flattens `usage` into a new record with a fresh UUIDv7 id and the current time.
    #[must_use]
    pub fn new(
        session_id: Uuid,
        message_id: Uuid,
        provider: impl Into<String>,
        model: impl Into<String>,
        usage: TokenUsage,
    ) -> Self {
        let (input_tokens, output_tokens, total_tokens) =
            (usage.input_tokens, usage.output_tokens, usage.total_tokens);
        let id = Uuid::now_v7();
        UsageRecord {
            id,
            session_id,
            message_id,
            provider: provider.into(),
            model: model.into(),
            input_tokens,
            output_tokens,
            total_tokens,
            created_at: Utc::now(),
        }
    }
}
/// Aggregate counters for one session, as returned by [`ExecutionStore::session_stats`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SessionStats {
    /// Session the counters describe.
    pub session_id: Uuid,
    /// Number of messages.
    pub message_count: u64,
    /// Number of tool calls.
    pub tool_call_count: u64,
    /// Tool calls that succeeded.
    pub tool_success_count: u64,
    /// Tool calls that failed.
    pub tool_failure_count: u64,
    /// Sum of input tokens.
    pub total_input_tokens: u64,
    /// Sum of output tokens.
    pub total_output_tokens: u64,
    /// Sum of total tokens.
    pub total_tokens: u64,
    /// Mean tool duration in milliseconds.
    pub avg_tool_duration_ms: u64,
}

impl SessionStats {
    /// Stats for `session_id` with every counter set to zero.
    #[must_use]
    pub fn empty(session_id: Uuid) -> Self {
        const ZERO: u64 = 0;
        SessionStats {
            session_id,
            message_count: ZERO,
            tool_call_count: ZERO,
            tool_success_count: ZERO,
            tool_failure_count: ZERO,
            total_input_tokens: ZERO,
            total_output_tokens: ZERO,
            total_tokens: ZERO,
            avg_tool_duration_ms: ZERO,
        }
    }
}
/// Audit storage for tool executions and token usage.
#[async_trait]
pub trait ExecutionStore: Send + Sync {
    /// Stores a tool execution record.
    async fn record_execution(&self, execution: ToolExecution) -> StoreResult<ToolExecution>;

    /// Lists the executions recorded for `session_id`.
    async fn list_executions(&self, session_id: &Uuid) -> StoreResult<Vec<ToolExecution>>;

    /// Lists the executions recorded for `message_id`.
    async fn list_executions_by_message(&self, message_id: &Uuid)
        -> StoreResult<Vec<ToolExecution>>;

    /// Stores a usage record.
    async fn record_usage(&self, record: UsageRecord) -> StoreResult<UsageRecord>;

    /// Lists the usage records for `session_id`.
    async fn list_usage(&self, session_id: &Uuid) -> StoreResult<Vec<UsageRecord>>;

    /// Computes aggregate counters for `session_id`.
    async fn session_stats(&self, session_id: &Uuid) -> StoreResult<SessionStats>;

    /// Deletes the execution data tied to a session.
    ///
    /// The provided implementation deletes nothing and returns `Ok(0)`.
    async fn delete_by_session(&self, _session_id: &Uuid) -> StoreResult<u64> {
        Ok(0)
    }
}
/// Serde adapter that stores a `Duration` as a whole number of milliseconds in a `u64`.
///
/// Sub-millisecond precision is truncated on serialization.
mod duration_millis {
    use serde::{Deserialize, Deserializer, Serialize, Serializer};
    use std::time::Duration;

    /// Serializes `duration` as whole milliseconds in a `u64`.
    ///
    /// The previous version emitted `Duration::as_millis()`, which is a `u128`. That had two
    /// problems:
    /// - serde's default `Serializer::serialize_u128` returns an error, so formats that don't
    ///   override it could not serialize a `ToolExecution` at all;
    /// - `deserialize` below reads a `u64`, so the two sides were asymmetric.
    ///
    /// Durations longer than `u64::MAX` milliseconds saturate instead of failing.
    pub(super) fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Equivalent to `min(as_millis(), u64::MAX)`, computed without a narrowing cast.
        let millis = duration
            .as_secs()
            .saturating_mul(1000)
            .saturating_add(u64::from(duration.subsec_millis()));
        millis.serialize(serializer)
    }

    /// Reads a `u64` millisecond count back into a `Duration`.
    pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
    where
        D: Deserializer<'de>,
    {
        u64::deserialize(deserializer).map(Duration::from_millis)
    }
}