use crate::storage::{
BackendInfo, Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest,
CreateSessionRequest, CreateSessionResult, ListCheckpointsQuery, ListCheckpointsResult,
ListSessionsQuery, ListSessionsResult, Session, SessionStatus, SessionStorage, SessionSummary,
SessionVisibility, StorageError, UpdateSessionRequest,
};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use libsql::{Connection, Database};
use std::path::Path;
use std::str::FromStr;
use tempfile::TempDir;
use uuid::Uuid;
pub struct LocalStorage {
db: Database,
backend_info: BackendInfo,
_temp_dir: Option<TempDir>,
}
impl LocalStorage {
pub async fn new(db_path: &str) -> Result<Self, StorageError> {
let (resolved_path, temp_dir) = if db_path == ":memory:" {
let temp_dir = tempfile::tempdir().map_err(|e| {
StorageError::Connection(format!("Failed to create temp directory: {}", e))
})?;
(temp_dir.path().join("local-storage.db"), Some(temp_dir))
} else {
(Path::new(db_path).to_path_buf(), None)
};
if let Some(parent) = resolved_path.parent() {
std::fs::create_dir_all(parent).map_err(|e| {
StorageError::Connection(format!("Failed to create database directory: {}", e))
})?;
}
let db = libsql::Builder::new_local(&resolved_path)
.build()
.await
.map_err(|e| StorageError::Connection(format!("Failed to open database: {}", e)))?;
let storage = Self {
db,
backend_info: BackendInfo::local(db_path.to_string()),
_temp_dir: temp_dir,
};
storage.configure_database_pragmas().await?;
storage.init_schema().await?;
Ok(storage)
}
pub async fn from_db_and_connection(
db: Database,
_conn: Connection,
) -> Result<Self, StorageError> {
let storage = Self {
db,
backend_info: BackendInfo::local(":memory:"),
_temp_dir: None,
};
storage.configure_database_pragmas().await?;
storage.init_schema().await?;
Ok(storage)
}
async fn configure_database_pragmas(&self) -> Result<(), StorageError> {
let conn = self.connect_raw()?;
stakpak_shared::sqlite::apply_database_pragmas(&conn)
.await
.map_err(|e| StorageError::Connection(e.to_string()))?;
Ok(())
}
#[deprecated(note = "use LocalStorage::new(...) or LocalStorage::from_db_and_connection(...)")]
pub async fn from_connection(_conn: Connection) -> Result<Self, StorageError> {
Err(StorageError::Connection(
"LocalStorage::from_connection is unsupported without the owning libsql::Database; use LocalStorage::new(...) or LocalStorage::from_db_and_connection(...)".to_string(),
))
}
async fn init_schema(&self) -> Result<(), StorageError> {
let conn = self.connection().await?;
super::migrations::run_migrations(&conn)
.await
.map_err(StorageError::Internal)
}
pub(crate) fn connect_raw(&self) -> Result<Connection, StorageError> {
self.db
.connect()
.map_err(|e| StorageError::Connection(format!("Failed to connect to database: {}", e)))
}
pub(crate) async fn connection(&self) -> Result<Connection, StorageError> {
let conn = self.connect_raw()?;
stakpak_shared::sqlite::apply_connection_pragmas(&conn)
.await
.map_err(|e| StorageError::Connection(e.to_string()))?;
Ok(conn)
}
async fn get_latest_checkpoint_for_session_inner(
conn: &Connection,
session_id: Uuid,
) -> Result<Checkpoint, StorageError> {
let mut rows = conn
.query(
"SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints
WHERE session_id = ? ORDER BY created_at DESC LIMIT 1",
[session_id.to_string()],
)
.await
.map_err(|e| StorageError::Internal(e.to_string()))?;
if let Ok(Some(row)) = rows.next().await {
let id: String = row
.get(0)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let session_id: String = row
.get(1)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let parent_id: Option<String> = row.get(2).ok();
let state: Option<String> = row.get(3).ok();
let created_at: String = row
.get(4)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let updated_at: String = row
.get(5)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let state: CheckpointState = if let Some(state_str) = state {
serde_json::from_str(&state_str).unwrap_or_default()
} else {
CheckpointState::default()
};
Ok(Checkpoint {
id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
session_id: Uuid::from_str(&session_id)
.map_err(|e| StorageError::Internal(e.to_string()))?,
parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
state,
created_at: parse_datetime(&created_at)?,
updated_at: parse_datetime(&updated_at)?,
})
} else {
Err(StorageError::NotFound(format!(
"No checkpoints found for session {}",
session_id
)))
}
}
}
#[async_trait]
impl SessionStorage for LocalStorage {
fn backend_info(&self) -> BackendInfo {
self.backend_info.clone()
}
async fn list_sessions(
&self,
query: &ListSessionsQuery,
) -> Result<ListSessionsResult, StorageError> {
let limit = query.limit.unwrap_or(100);
let offset = query.offset.unwrap_or(0);
let mut sql = "SELECT s.id, s.title, s.visibility, COALESCE(s.status, 'ACTIVE') as status, s.cwd, s.created_at, s.updated_at,
COALESCE((
SELECT json_array_length(c.state, '$.messages')
FROM checkpoints c
WHERE c.session_id = s.id
ORDER BY c.created_at DESC
LIMIT 1
), 0) as message_count,
(SELECT id FROM checkpoints c WHERE c.session_id = s.id ORDER BY c.created_at DESC LIMIT 1) as active_checkpoint_id
FROM sessions s WHERE 1=1".to_string();
if let Some(status) = &query.status {
sql.push_str(&format!(" AND s.status = '{}'", status));
}
if let Some(visibility) = &query.visibility {
sql.push_str(&format!(" AND s.visibility = '{}'", visibility));
}
if query.search.is_some() {
sql.push_str(" AND s.title LIKE '%' || ? || '%'");
}
sql.push_str(&format!(
" ORDER BY s.updated_at DESC LIMIT {} OFFSET {}",
limit, offset
));
let conn = self.connection().await?;
let mut rows = if let Some(search) = &query.search {
conn.query(&sql, [search.as_str()])
.await
.map_err(|e| StorageError::Internal(e.to_string()))?
} else {
conn.query(&sql, ())
.await
.map_err(|e| StorageError::Internal(e.to_string()))?
};
let mut sessions = Vec::new();
while let Ok(Some(row)) = rows.next().await {
let id: String = row
.get(0)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let title: String = row
.get(1)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let visibility: String = row
.get(2)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let status: String = row
.get(3)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let cwd: Option<String> = row.get(4).ok();
let created_at: String = row
.get(5)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let updated_at: String = row
.get(6)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let message_count: i64 = row.get(7).unwrap_or(0);
let active_checkpoint_id: Option<String> = row.get(8).ok();
sessions.push(SessionSummary {
id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
title,
visibility: parse_visibility(&visibility),
status: parse_status(&status),
cwd,
created_at: parse_datetime(&created_at)?,
updated_at: parse_datetime(&updated_at)?,
message_count: message_count.max(0) as u32,
active_checkpoint_id: active_checkpoint_id.and_then(|id| Uuid::from_str(&id).ok()),
last_message_at: None,
});
}
Ok(ListSessionsResult {
sessions,
total: None,
})
}
async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError> {
let conn = self.connection().await?;
let mut rows = conn
.query(
"SELECT id, title, visibility, COALESCE(status, 'ACTIVE') as status, cwd, created_at, updated_at FROM sessions WHERE id = ?",
[session_id.to_string()],
)
.await
.map_err(|e| StorageError::Internal(e.to_string()))?;
if let Ok(Some(row)) = rows.next().await {
let id: String = row
.get(0)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let title: String = row
.get(1)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let visibility: String = row
.get(2)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let status: String = row
.get(3)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let cwd: Option<String> = row.get(4).ok();
let created_at: String = row
.get(5)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let updated_at: String = row
.get(6)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let active_checkpoint =
Self::get_latest_checkpoint_for_session_inner(&conn, session_id)
.await
.ok();
Ok(Session {
id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
title,
visibility: parse_visibility(&visibility),
status: parse_status(&status),
cwd,
created_at: parse_datetime(&created_at)?,
updated_at: parse_datetime(&updated_at)?,
active_checkpoint,
})
} else {
Err(StorageError::NotFound(format!(
"Session {} not found",
session_id
)))
}
}
async fn create_session(
&self,
request: &CreateSessionRequest,
) -> Result<CreateSessionResult, StorageError> {
let now = Utc::now();
let session_id = Uuid::new_v4();
let checkpoint_id = Uuid::new_v4();
let conn = self.connection().await?;
conn.execute(
"INSERT INTO sessions (id, title, visibility, status, cwd, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
(
session_id.to_string(),
request.title.as_str(),
request.visibility.to_string(),
"ACTIVE",
request.cwd.as_deref(),
now.to_rfc3339(),
now.to_rfc3339(),
),
)
.await
.map_err(|e| StorageError::Internal(e.to_string()))?;
let state_json = serde_json::to_string(&request.initial_state)
.map_err(|e| StorageError::Internal(e.to_string()))?;
conn.execute(
"INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
(
checkpoint_id.to_string(),
session_id.to_string(),
None::<String>,
state_json,
now.to_rfc3339(),
now.to_rfc3339(),
),
)
.await
.map_err(|e| StorageError::Internal(e.to_string()))?;
Ok(CreateSessionResult {
session_id,
checkpoint: Checkpoint {
id: checkpoint_id,
session_id,
parent_id: None,
state: request.initial_state.clone(),
created_at: now,
updated_at: now,
},
})
}
async fn update_session(
&self,
session_id: Uuid,
request: &UpdateSessionRequest,
) -> Result<Session, StorageError> {
let now = Utc::now();
{
let conn = self.connection().await?;
if let Some(title) = &request.title {
conn.execute(
"UPDATE sessions SET title = ?, updated_at = ? WHERE id = ?",
(title.as_str(), now.to_rfc3339(), session_id.to_string()),
)
.await
.map_err(|e| StorageError::Internal(e.to_string()))?;
}
if let Some(visibility) = &request.visibility {
conn.execute(
"UPDATE sessions SET visibility = ?, updated_at = ? WHERE id = ?",
(
visibility.to_string(),
now.to_rfc3339(),
session_id.to_string(),
),
)
.await
.map_err(|e| StorageError::Internal(e.to_string()))?;
}
}
self.get_session(session_id).await
}
async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError> {
let now = Utc::now();
let conn = self.connection().await?;
conn.execute(
"UPDATE sessions SET status = 'DELETED', updated_at = ? WHERE id = ?",
(now.to_rfc3339(), session_id.to_string()),
)
.await
.map_err(|e| StorageError::Internal(e.to_string()))?;
Ok(())
}
async fn list_checkpoints(
&self,
session_id: Uuid,
query: &ListCheckpointsQuery,
) -> Result<ListCheckpointsResult, StorageError> {
let limit = query.limit.unwrap_or(100);
let offset = query.offset.unwrap_or(0);
let sql = format!(
"SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints
WHERE session_id = ? ORDER BY created_at ASC LIMIT {} OFFSET {}",
limit, offset
);
let conn = self.connection().await?;
let mut rows = conn
.query(&sql, [session_id.to_string()])
.await
.map_err(|e| StorageError::Internal(e.to_string()))?;
let mut checkpoints = Vec::new();
while let Ok(Some(row)) = rows.next().await {
let id: String = row
.get(0)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let session_id: String = row
.get(1)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let parent_id: Option<String> = row.get(2).ok();
let state: Option<String> = row.get(3).ok();
let created_at: String = row
.get(4)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let updated_at: String = row
.get(5)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let state: CheckpointState = if let Some(state_str) = state {
serde_json::from_str(&state_str).unwrap_or_default()
} else {
CheckpointState::default()
};
checkpoints.push(CheckpointSummary {
id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
session_id: Uuid::from_str(&session_id)
.map_err(|e| StorageError::Internal(e.to_string()))?,
parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
message_count: state.messages.len() as u32,
created_at: parse_datetime(&created_at)?,
updated_at: parse_datetime(&updated_at)?,
});
}
Ok(ListCheckpointsResult {
checkpoints,
total: None,
})
}
async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError> {
let conn = self.connection().await?;
let mut rows = conn
.query(
"SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints WHERE id = ?",
[checkpoint_id.to_string()],
)
.await
.map_err(|e| StorageError::Internal(e.to_string()))?;
if let Ok(Some(row)) = rows.next().await {
let id: String = row
.get(0)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let session_id: String = row
.get(1)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let parent_id: Option<String> = row.get(2).ok();
let state: Option<String> = row.get(3).ok();
let created_at: String = row
.get(4)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let updated_at: String = row
.get(5)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let state: CheckpointState = if let Some(state_str) = state {
serde_json::from_str(&state_str).unwrap_or_default()
} else {
CheckpointState::default()
};
Ok(Checkpoint {
id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
session_id: Uuid::from_str(&session_id)
.map_err(|e| StorageError::Internal(e.to_string()))?,
parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
state,
created_at: parse_datetime(&created_at)?,
updated_at: parse_datetime(&updated_at)?,
})
} else {
Err(StorageError::NotFound(format!(
"Checkpoint {} not found",
checkpoint_id
)))
}
}
async fn create_checkpoint(
&self,
session_id: Uuid,
request: &CreateCheckpointRequest,
) -> Result<Checkpoint, StorageError> {
let now = Utc::now();
let checkpoint_id = Uuid::new_v4();
let state_json = serde_json::to_string(&request.state)
.map_err(|e| StorageError::Internal(e.to_string()))?;
let conn = self.connection().await?;
conn.execute(
"INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
(
checkpoint_id.to_string(),
session_id.to_string(),
request.parent_id.map(|id| id.to_string()),
state_json,
now.to_rfc3339(),
now.to_rfc3339(),
),
)
.await
.map_err(|e| StorageError::Internal(e.to_string()))?;
conn.execute(
"UPDATE sessions SET updated_at = ? WHERE id = ?",
(now.to_rfc3339(), session_id.to_string()),
)
.await
.map_err(|e| StorageError::Internal(e.to_string()))?;
Ok(Checkpoint {
id: checkpoint_id,
session_id,
parent_id: request.parent_id,
state: request.state.clone(),
created_at: now,
updated_at: now,
})
}
}
fn parse_visibility(s: &str) -> SessionVisibility {
match s.to_uppercase().as_str() {
"PUBLIC" => SessionVisibility::Public,
_ => SessionVisibility::Private,
}
}
fn parse_status(s: &str) -> SessionStatus {
match s.to_uppercase().as_str() {
"DELETED" => SessionStatus::Deleted,
_ => SessionStatus::Active,
}
}
fn parse_datetime(s: &str) -> Result<DateTime<Utc>, StorageError> {
DateTime::parse_from_rfc3339(s)
.map(|dt| dt.with_timezone(&Utc))
.map_err(|e| StorageError::Internal(format!("Failed to parse datetime: {}", e)))
}