use crate::db::Pool;
use crate::db::database::interact_err;
use crate::db::models::Message;
use anyhow::{Context, Result};
use rusqlite::params;
use uuid::Uuid;
/// Data-access object for rows of the `messages` table.
///
/// Every method checks out a connection from the shared `Pool` and runs its SQL inside
/// the closure passed to the connection's `interact` call. `create` additionally touches
/// `sessions.updated_at`.
/// NOTE(review): `Clone` is presumably cheap because `Pool` is a shared handle — confirm.
#[derive(Clone)]
pub struct MessageRepository {
// Connection pool that every query method draws from.
pool: Pool,
}
impl MessageRepository {
    /// Creates a repository that draws its connections from `pool`.
    pub fn new(pool: Pool) -> Self {
        Self { pool }
    }

    /// Fetches the message with the given `id`, or `None` when no row matches.
    ///
    /// # Errors
    /// Fails if a pooled connection cannot be obtained, the `interact` call fails, or the
    /// query / row mapping fails.
    pub async fn find_by_id(&self, id: Uuid) -> Result<Option<Message>> {
        let id_str = id.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.prepare_cached("SELECT * FROM messages WHERE id = ?1")?
                    .query_row(params![id_str], Message::from_row)
                    .optional()
            })
            .await
            .map_err(interact_err)?
            .context("Failed to find message")
    }

    /// Returns every message of `session_id`, ordered by `sequence` ascending.
    ///
    /// # Errors
    /// Fails on pool, `interact`, query or row-mapping errors.
    pub async fn find_by_session(&self, session_id: Uuid) -> Result<Vec<Message>> {
        let sid = session_id.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                let mut stmt = conn.prepare_cached(
                    "SELECT * FROM messages WHERE session_id = ?1 ORDER BY sequence ASC",
                )?;
                let rows = stmt.query_map(params![sid], Message::from_row)?;
                rows.collect::<std::result::Result<Vec<_>, _>>()
            })
            .await
            .map_err(interact_err)?
            .context("Failed to find messages by session")
    }

    /// Inserts `message` and sets the owning session's `updated_at` to the message's
    /// `created_at.timestamp()`.
    ///
    /// Both statements run in a single transaction, so a failure in either one leaves
    /// the database unchanged. This prevents a persisted message whose session was not
    /// touched while the caller is told the operation failed.
    ///
    /// # Errors
    /// Fails on pool or `interact` errors, or if either statement or the commit fails.
    pub async fn create(&self, message: &Message) -> Result<()> {
        let m = message.clone();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                // Dropping `tx` without `commit` rolls back, so any `?` below undoes the insert.
                let tx = conn.transaction()?;
                tx.execute(
                    "INSERT INTO messages (id, session_id, role, content, sequence,
                     created_at, token_count, cost, input_tokens,
                     cache_creation_tokens, cache_read_tokens, thinking)
                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
                    params![
                        m.id.to_string(),
                        m.session_id.to_string(),
                        m.role,
                        m.content,
                        m.sequence,
                        m.created_at.timestamp(),
                        m.token_count,
                        m.cost,
                        m.input_tokens,
                        m.cache_creation_tokens,
                        m.cache_read_tokens,
                        m.thinking,
                    ],
                )?;
                tx.execute(
                    "UPDATE sessions SET updated_at = ?1 WHERE id = ?2",
                    params![m.created_at.timestamp(), m.session_id.to_string()],
                )?;
                tx.commit()
            })
            .await
            .map_err(interact_err)?
            .context("Failed to create message")?;
        tracing::debug!(
            "Created message: {} in session: {}",
            message.id,
            message.session_id
        );
        Ok(())
    }

    /// Overwrites the content, token/cost counters and thinking text of the row whose id
    /// equals `message.id`.
    ///
    /// Matching zero rows is not treated as an error (the affected-row count is ignored).
    ///
    /// # Errors
    /// Fails on pool, `interact` or statement errors.
    pub async fn update(&self, message: &Message) -> Result<()> {
        let m = message.clone();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.execute(
                    "UPDATE messages
                     SET content = ?1, token_count = ?2, cost = ?3, input_tokens = ?4,
                     cache_creation_tokens = ?5, cache_read_tokens = ?6, thinking = ?7
                     WHERE id = ?8",
                    params![
                        m.content,
                        m.token_count,
                        m.cost,
                        m.input_tokens,
                        m.cache_creation_tokens,
                        m.cache_read_tokens,
                        m.thinking,
                        m.id.to_string()
                    ],
                )
            })
            .await
            .map_err(interact_err)?
            .context("Failed to update message")?;
        tracing::debug!("Updated message: {}", message.id);
        Ok(())
    }

    /// Returns `input_tokens` of the highest-`sequence` assistant message in `session_id`
    /// that has a non-NULL `input_tokens`, or `None` if there is no such message.
    ///
    /// # Errors
    /// Fails on pool, `interact` or query errors.
    pub async fn last_assistant_input_tokens(&self, session_id: Uuid) -> Result<Option<i32>> {
        let sid = session_id.to_string();
        let tokens = self
            .pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.query_row(
                    "SELECT input_tokens FROM messages \
                     WHERE session_id = ?1 AND role = 'assistant' \
                     AND input_tokens IS NOT NULL \
                     ORDER BY sequence DESC LIMIT 1",
                    params![sid],
                    |r| r.get::<_, Option<i32>>(0),
                )
                .optional()
            })
            .await
            .map_err(interact_err)?
            .context("Failed to query last assistant input_tokens")?;
        // Outer `Option`: was a row found; inner: the column value (filtered non-NULL above).
        Ok(tokens.flatten())
    }

    /// Appends `content_to_append` to the message's existing `content` using SQL `||`.
    ///
    /// NOTE(review): in SQLite `NULL || x` is NULL, so this assumes `content` is never
    /// NULL — confirm against the schema.
    ///
    /// # Errors
    /// Fails on pool, `interact` or statement errors; a missing id is not an error.
    pub async fn append_content(&self, id: Uuid, content_to_append: &str) -> Result<()> {
        let id_str = id.to_string();
        let content = content_to_append.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.execute(
                    "UPDATE messages SET content = content || ?1 WHERE id = ?2",
                    params![content, id_str],
                )
            })
            .await
            .map_err(interact_err)?
            .context("Failed to append to message")?;
        tracing::debug!("Appended content to message: {}", id);
        Ok(())
    }

    /// Replaces the message's `thinking` text with `thinking`.
    ///
    /// # Errors
    /// Fails on pool, `interact` or statement errors; a missing id is not an error.
    pub async fn set_thinking(&self, id: Uuid, thinking: &str) -> Result<()> {
        let id_str = id.to_string();
        let thinking_val = thinking.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.execute(
                    "UPDATE messages SET thinking = ?1 WHERE id = ?2",
                    params![thinking_val, id_str],
                )
            })
            .await
            .map_err(interact_err)?
            .context("Failed to set thinking on message")?;
        // The message says "chars", so count chars rather than UTF-8 bytes.
        tracing::debug!(
            "Set thinking ({} chars) on message: {}",
            thinking.chars().count(),
            id
        );
        Ok(())
    }

    /// Appends `thinking_to_append` to the message's `thinking` text.
    ///
    /// If `thinking` is NULL it is set to the new text. Otherwise the new text is joined
    /// to the old one with a blank line (two `char(10)` newlines).
    ///
    /// # Errors
    /// Fails on pool, `interact` or statement errors; a missing id is not an error.
    pub async fn append_thinking(&self, id: Uuid, thinking_to_append: &str) -> Result<()> {
        let id_str = id.to_string();
        let thinking = thinking_to_append.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.execute(
                    "UPDATE messages SET thinking = CASE \
                     WHEN thinking IS NULL THEN ?1 \
                     ELSE thinking || char(10) || char(10) || ?1 \
                     END \
                     WHERE id = ?2",
                    params![thinking, id_str],
                )
            })
            .await
            .map_err(interact_err)?
            .context("Failed to append thinking to message")?;
        // The message says "chars", so count chars rather than UTF-8 bytes.
        tracing::debug!(
            "Appended thinking ({} chars) to message: {}",
            thinking_to_append.chars().count(),
            id
        );
        Ok(())
    }

    /// Deletes the message with the given `id`; deleting a missing id is not an error.
    ///
    /// # Errors
    /// Fails on pool, `interact` or statement errors.
    pub async fn delete(&self, id: Uuid) -> Result<()> {
        let id_str = id.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.execute("DELETE FROM messages WHERE id = ?1", params![id_str])
            })
            .await
            .map_err(interact_err)?
            .context("Failed to delete message")?;
        tracing::debug!("Deleted message: {}", id);
        Ok(())
    }

    /// Alias for [`Self::find_by_session`].
    pub async fn list_by_session(&self, session_id: Uuid) -> Result<Vec<Message>> {
        self.find_by_session(session_id).await
    }

    /// Returns up to `limit` messages whose `content` contains `query` as a literal
    /// substring, newest (`created_at` descending) first.
    ///
    /// * `session_ids = None` searches every session.
    /// * `session_ids = Some(&[])` returns an empty list without touching the database.
    ///
    /// `%`, `_` and `\` in `query` are escaped, so they match themselves rather than acting
    /// as LIKE wildcards.
    ///
    /// NOTE(review): the `IN (...)` list uses one host parameter per id, so very large
    /// `session_ids` slices can hit SQLite's bound-parameter limit.
    ///
    /// # Errors
    /// Fails on pool, `interact`, query or row-mapping errors.
    pub async fn search_by_content(
        &self,
        session_ids: Option<&[Uuid]>,
        query: &str,
        limit: usize,
    ) -> Result<Vec<Message>> {
        let pattern = format!("%{}%", Self::escape_like(query));
        // Saturate rather than let `as` wrap a huge `usize` into a negative LIMIT.
        let limit_i = limit.min(i64::MAX as usize) as i64;
        let id_strs: Option<Vec<String>> =
            session_ids.map(|ids| ids.iter().map(|i| i.to_string()).collect());
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                // NOTE(review): `COLLATE NOCASE` below binds to the ESCAPE literal; LIKE's
                // ASCII case-insensitivity comes from SQLite's default LIKE semantics.
                if let Some(ids) = id_strs {
                    if ids.is_empty() {
                        return Ok::<Vec<Message>, rusqlite::Error>(Vec::new());
                    }
                    let placeholders = vec!["?"; ids.len()].join(",");
                    let sql = format!(
                        "SELECT * FROM messages \
                         WHERE session_id IN ({}) AND content LIKE ? ESCAPE '\\' \
                         COLLATE NOCASE \
                         ORDER BY created_at DESC LIMIT ?",
                        placeholders
                    );
                    // Not `prepare_cached`: the SQL text varies with the number of ids.
                    let mut stmt = conn.prepare(&sql)?;
                    let mut params_vec: Vec<&dyn rusqlite::ToSql> =
                        Vec::with_capacity(ids.len() + 2);
                    params_vec.extend(ids.iter().map(|id| id as &dyn rusqlite::ToSql));
                    params_vec.push(&pattern);
                    params_vec.push(&limit_i);
                    let rows = stmt.query_map(params_vec.as_slice(), Message::from_row)?;
                    rows.collect::<std::result::Result<Vec<_>, _>>()
                } else {
                    let mut stmt = conn.prepare_cached(
                        "SELECT * FROM messages \
                         WHERE content LIKE ?1 ESCAPE '\\' COLLATE NOCASE \
                         ORDER BY created_at DESC LIMIT ?2",
                    )?;
                    let rows = stmt.query_map(params![pattern, limit_i], Message::from_row)?;
                    rows.collect::<std::result::Result<Vec<_>, _>>()
                }
            })
            .await
            .map_err(interact_err)?
            .context("Failed to search messages by content")
    }

    /// Escapes `query` for use inside a `LIKE ... ESCAPE '\'` pattern so that it matches
    /// literally.
    ///
    /// The escape character itself must be doubled. Otherwise a user-supplied `\` would
    /// escape the next pattern character; for example, a trailing `\` would turn the
    /// closing `%` wildcard into a literal percent sign.
    fn escape_like(query: &str) -> String {
        let mut escaped = String::with_capacity(query.len());
        for ch in query.chars() {
            if matches!(ch, '\\' | '%' | '_') {
                escaped.push('\\');
            }
            escaped.push(ch);
        }
        escaped
    }

    /// Returns the number of messages stored for `session_id`.
    ///
    /// # Errors
    /// Fails on pool, `interact` or query errors.
    pub async fn count_by_session(&self, session_id: Uuid) -> Result<i64> {
        let sid = session_id.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.query_row(
                    "SELECT COUNT(*) FROM messages WHERE session_id = ?1",
                    params![sid],
                    |row| row.get(0),
                )
            })
            .await
            .map_err(interact_err)?
            .context("Failed to count messages")
    }

    /// Returns the highest-`sequence` message of `session_id`, or `None` if it has none.
    ///
    /// # Errors
    /// Fails on pool, `interact`, query or row-mapping errors.
    pub async fn get_last_message(&self, session_id: Uuid) -> Result<Option<Message>> {
        let sid = session_id.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.prepare_cached(
                    "SELECT * FROM messages WHERE session_id = ?1 ORDER BY sequence DESC LIMIT 1",
                )?
                .query_row(params![sid], Message::from_row)
                .optional()
            })
            .await
            .map_err(interact_err)?
            .context("Failed to get last message")
    }

    /// Returns the highest-`sequence` message with `role = 'assistant'` in `session_id`,
    /// or `None` if there is none.
    ///
    /// # Errors
    /// Fails on pool, `interact`, query or row-mapping errors.
    pub async fn get_last_assistant_message(&self, session_id: Uuid) -> Result<Option<Message>> {
        let sid = session_id.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.prepare_cached(
                    "SELECT * FROM messages WHERE session_id = ?1 AND role = 'assistant' \
                     ORDER BY sequence DESC LIMIT 1",
                )?
                .query_row(params![sid], Message::from_row)
                .optional()
            })
            .await
            .map_err(interact_err)?
            .context("Failed to get last assistant message")
    }

    /// Deletes every message belonging to `session_id`; deleting none is not an error.
    ///
    /// # Errors
    /// Fails on pool, `interact` or statement errors.
    pub async fn delete_by_session(&self, session_id: Uuid) -> Result<()> {
        let sid = session_id.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.execute("DELETE FROM messages WHERE session_id = ?1", params![sid])
            })
            .await
            .map_err(interact_err)?
            .context("Failed to delete session messages")?;
        tracing::debug!("Deleted all messages for session: {}", session_id);
        Ok(())
    }
}
trait OptionalExt<T> {
fn optional(self) -> rusqlite::Result<Option<T>>;
}
impl<T> OptionalExt<T> for rusqlite::Result<T> {
fn optional(self) -> rusqlite::Result<Option<T>> {
match self {
Ok(v) => Ok(Some(v)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(e),
}
}
}