use crate::db::Pool;
use crate::db::database::interact_err;
use crate::db::models::Message;
use anyhow::{Context, Result};
use rusqlite::params;
use uuid::Uuid;
/// Data-access object for rows of the `messages` table.
///
/// Every query checks a connection out of the shared [`Pool`] and runs the
/// SQLite work inside that connection's `interact` closure.
#[derive(Clone)]
pub struct MessageRepository {
    /// Connection pool used by every query this repository issues.
    pool: Pool,
}
/// Query and mutation methods for `messages`.
///
/// Unless stated otherwise, each method returns an error when no pooled
/// connection can be obtained, when the `interact` call itself fails, or when
/// SQLite reports an error for the statement.
impl MessageRepository {
    /// Creates a repository that runs its queries on connections from `pool`.
    pub fn new(pool: Pool) -> Self {
        Self { pool }
    }

    /// Fetches the message whose `id` column equals `id`.
    ///
    /// Returns `Ok(None)` when no such row exists.
    pub async fn find_by_id(&self, id: Uuid) -> Result<Option<Message>> {
        let id_str = id.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.prepare_cached("SELECT * FROM messages WHERE id = ?1")?
                    .query_row(params![id_str], Message::from_row)
                    .optional()
            })
            .await
            .map_err(interact_err)?
            .context("Failed to find message")
    }

    /// Returns every message of `session_id`, ordered by ascending `sequence`.
    pub async fn find_by_session(&self, session_id: Uuid) -> Result<Vec<Message>> {
        let sid = session_id.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                let mut stmt = conn.prepare_cached(
                    "SELECT * FROM messages WHERE session_id = ?1 ORDER BY sequence ASC",
                )?;
                let rows = stmt.query_map(params![sid], Message::from_row)?;
                rows.collect::<std::result::Result<Vec<_>, _>>()
            })
            .await
            .map_err(interact_err)?
            .context("Failed to find messages by session")
    }

    /// Inserts `message` and sets the owning session's `updated_at` to the
    /// message's `created_at` (as a Unix timestamp).
    ///
    /// Both statements run in a single transaction, so either both writes are
    /// committed or neither is.
    pub async fn create(&self, message: &Message) -> Result<()> {
        // The interact closure is moved to another thread, so it needs an
        // owned copy of the message.
        let m = message.clone();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                let tx = conn.transaction()?;
                tx.execute(
                    "INSERT INTO messages (id, session_id, role, content, sequence, \
                     created_at, token_count, cost, input_tokens, \
                     cache_creation_tokens, cache_read_tokens, thinking) \
                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
                    params![
                        m.id.to_string(),
                        m.session_id.to_string(),
                        m.role,
                        m.content,
                        m.sequence,
                        m.created_at.timestamp(),
                        m.token_count,
                        m.cost,
                        m.input_tokens,
                        m.cache_creation_tokens,
                        m.cache_read_tokens,
                        m.thinking,
                    ],
                )?;
                tx.execute(
                    "UPDATE sessions SET updated_at = ?1 WHERE id = ?2",
                    params![m.created_at.timestamp(), m.session_id.to_string()],
                )?;
                // Dropping `tx` without committing rolls back, so any `?` above
                // leaves the database untouched.
                tx.commit()
            })
            .await
            .map_err(interact_err)?
            .context("Failed to create message")?;
        tracing::debug!(
            "Created message: {} in session: {}",
            message.id,
            message.session_id
        );
        Ok(())
    }

    /// Overwrites the mutable columns of the row with `message.id`:
    /// `content`, `token_count`, `cost`, `input_tokens`,
    /// `cache_creation_tokens`, `cache_read_tokens` and `thinking`.
    ///
    /// `role`, `sequence`, `session_id` and `created_at` are not touched. No
    /// error is returned when no row matches the id.
    pub async fn update(&self, message: &Message) -> Result<()> {
        let m = message.clone();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.execute(
                    "UPDATE messages \
                     SET content = ?1, token_count = ?2, cost = ?3, input_tokens = ?4, \
                     cache_creation_tokens = ?5, cache_read_tokens = ?6, thinking = ?7 \
                     WHERE id = ?8",
                    params![
                        m.content,
                        m.token_count,
                        m.cost,
                        m.input_tokens,
                        m.cache_creation_tokens,
                        m.cache_read_tokens,
                        m.thinking,
                        m.id.to_string()
                    ],
                )
            })
            .await
            .map_err(interact_err)?
            .context("Failed to update message")?;
        tracing::debug!("Updated message: {}", message.id);
        Ok(())
    }

    /// Returns `input_tokens` of the highest-`sequence` assistant message in
    /// `session_id` whose `input_tokens` is not NULL, or `None` if there is none.
    pub async fn last_assistant_input_tokens(&self, session_id: Uuid) -> Result<Option<i32>> {
        let sid = session_id.to_string();
        let tokens = self
            .pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.query_row(
                    "SELECT input_tokens FROM messages \
                     WHERE session_id = ?1 AND role = 'assistant' \
                     AND input_tokens IS NOT NULL \
                     ORDER BY sequence DESC LIMIT 1",
                    params![sid],
                    |r| r.get::<_, Option<i32>>(0),
                )
                .optional()
            })
            .await
            .map_err(interact_err)?
            .context("Failed to query last assistant input_tokens")?;
        // Outer `None`: no matching row; inner `None`: NULL column value.
        Ok(tokens.flatten())
    }

    /// Appends `content_to_append` to the message's `content` using SQL `||`.
    ///
    /// No error is returned when no row matches `id`.
    pub async fn append_content(&self, id: Uuid, content_to_append: &str) -> Result<()> {
        let id_str = id.to_string();
        let content = content_to_append.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.execute(
                    "UPDATE messages SET content = content || ?1 WHERE id = ?2",
                    params![content, id_str],
                )
            })
            .await
            .map_err(interact_err)?
            .context("Failed to append to message")?;
        tracing::debug!("Appended content to message: {}", id);
        Ok(())
    }

    /// Replaces the message's `thinking` column with `thinking`.
    pub async fn set_thinking(&self, id: Uuid, thinking: &str) -> Result<()> {
        let id_str = id.to_string();
        let thinking_val = thinking.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.execute(
                    "UPDATE messages SET thinking = ?1 WHERE id = ?2",
                    params![thinking_val, id_str],
                )
            })
            .await
            .map_err(interact_err)?
            .context("Failed to set thinking on message")?;
        tracing::debug!(
            "Set thinking ({} chars) on message: {}",
            thinking.chars().count(),
            id
        );
        Ok(())
    }

    /// Appends `thinking_to_append` to the message's `thinking` column.
    ///
    /// If the column is NULL it is set to the new text. Otherwise the new text
    /// is joined to the existing value with a blank line (two `\n`).
    pub async fn append_thinking(&self, id: Uuid, thinking_to_append: &str) -> Result<()> {
        let id_str = id.to_string();
        let thinking = thinking_to_append.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.execute(
                    "UPDATE messages SET thinking = CASE \
                     WHEN thinking IS NULL THEN ?1 \
                     ELSE thinking || char(10) || char(10) || ?1 \
                     END \
                     WHERE id = ?2",
                    params![thinking, id_str],
                )
            })
            .await
            .map_err(interact_err)?
            .context("Failed to append thinking to message")?;
        tracing::debug!(
            "Appended thinking ({} chars) to message: {}",
            thinking_to_append.chars().count(),
            id
        );
        Ok(())
    }

    /// Deletes the message with `id`. Succeeds even if no row matched.
    pub async fn delete(&self, id: Uuid) -> Result<()> {
        let id_str = id.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.execute("DELETE FROM messages WHERE id = ?1", params![id_str])
            })
            .await
            .map_err(interact_err)?
            .context("Failed to delete message")?;
        tracing::debug!("Deleted message: {}", id);
        Ok(())
    }

    /// Alias for [`Self::find_by_session`].
    pub async fn list_by_session(&self, session_id: Uuid) -> Result<Vec<Message>> {
        self.find_by_session(session_id).await
    }

    /// Builds a `LIKE` pattern (for `ESCAPE '\'`) that matches any text
    /// containing `query` literally.
    ///
    /// The escape character itself is escaped first. Otherwise a user-supplied
    /// `\` would escape the following character. For example, a trailing `\`
    /// would turn the closing `%` wildcard into a literal percent sign.
    fn like_contains_pattern(query: &str) -> String {
        let escaped = query
            .replace('\\', "\\\\")
            .replace('%', "\\%")
            .replace('_', "\\_");
        format!("%{}%", escaped)
    }

    /// Returns up to `limit` messages whose `content` contains `query`
    /// literally, newest `created_at` first.
    ///
    /// `%`, `_` and `\` in `query` are matched literally. When `session_ids`
    /// is `Some`, only those sessions are searched, and an empty slice yields
    /// an empty result without querying. `None` searches every session.
    pub async fn search_by_content(
        &self,
        session_ids: Option<&[Uuid]>,
        query: &str,
        limit: usize,
    ) -> Result<Vec<Message>> {
        let pattern = Self::like_contains_pattern(query);
        // Saturate rather than wrap: a wrapped negative LIMIT means "no limit"
        // in SQLite, and i64::MAX is the closest honest equivalent.
        let limit_i = i64::try_from(limit).unwrap_or(i64::MAX);
        let id_strs: Option<Vec<String>> =
            session_ids.map(|ids| ids.iter().map(|i| i.to_string()).collect());
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                // NOTE(review): SQLite's built-in LIKE already ignores ASCII
                // case. The `COLLATE NOCASE` below binds to the ESCAPE operand
                // and is likely a no-op. Confirm before relying on it.
                if let Some(ids) = id_strs {
                    if ids.is_empty() {
                        return Ok::<Vec<Message>, rusqlite::Error>(Vec::new());
                    }
                    // One positional placeholder per session id. NOTE(review):
                    // very large id lists can exceed SQLite's host-parameter
                    // limit.
                    let placeholders = vec!["?"; ids.len()].join(",");
                    let sql = format!(
                        "SELECT * FROM messages \
                         WHERE session_id IN ({}) AND content LIKE ? ESCAPE '\\' \
                         COLLATE NOCASE \
                         ORDER BY created_at DESC LIMIT ?",
                        placeholders
                    );
                    let mut stmt = conn.prepare(&sql)?;
                    let mut params_vec: Vec<&dyn rusqlite::ToSql> =
                        Vec::with_capacity(ids.len() + 2);
                    for id in &ids {
                        params_vec.push(id);
                    }
                    params_vec.push(&pattern);
                    params_vec.push(&limit_i);
                    let rows = stmt.query_map(params_vec.as_slice(), Message::from_row)?;
                    rows.collect::<std::result::Result<Vec<_>, _>>()
                } else {
                    let mut stmt = conn.prepare_cached(
                        "SELECT * FROM messages \
                         WHERE content LIKE ?1 ESCAPE '\\' COLLATE NOCASE \
                         ORDER BY created_at DESC LIMIT ?2",
                    )?;
                    let rows = stmt.query_map(params![pattern, limit_i], Message::from_row)?;
                    rows.collect::<std::result::Result<Vec<_>, _>>()
                }
            })
            .await
            .map_err(interact_err)?
            .context("Failed to search messages by content")
    }

    /// Returns the number of messages stored for `session_id`.
    pub async fn count_by_session(&self, session_id: Uuid) -> Result<i64> {
        let sid = session_id.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.query_row(
                    "SELECT COUNT(*) FROM messages WHERE session_id = ?1",
                    params![sid],
                    |row| row.get(0),
                )
            })
            .await
            .map_err(interact_err)?
            .context("Failed to count messages")
    }

    /// Returns the highest-`sequence` message of `session_id`, if any.
    pub async fn get_last_message(&self, session_id: Uuid) -> Result<Option<Message>> {
        let sid = session_id.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.prepare_cached(
                    "SELECT * FROM messages WHERE session_id = ?1 ORDER BY sequence DESC LIMIT 1",
                )?
                .query_row(params![sid], Message::from_row)
                .optional()
            })
            .await
            .map_err(interact_err)?
            .context("Failed to get last message")
    }

    /// Returns the highest-`sequence` message with role `'assistant'` in
    /// `session_id`, if any.
    pub async fn get_last_assistant_message(&self, session_id: Uuid) -> Result<Option<Message>> {
        let sid = session_id.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.prepare_cached(
                    "SELECT * FROM messages WHERE session_id = ?1 AND role = 'assistant' \
                     ORDER BY sequence DESC LIMIT 1",
                )?
                .query_row(params![sid], Message::from_row)
                .optional()
            })
            .await
            .map_err(interact_err)?
            .context("Failed to get last assistant message")
    }

    /// Deletes every message belonging to `session_id`.
    pub async fn delete_by_session(&self, session_id: Uuid) -> Result<()> {
        let sid = session_id.to_string();
        self.pool
            .get()
            .await
            .context("Failed to get connection")?
            .interact(move |conn| {
                conn.execute("DELETE FROM messages WHERE session_id = ?1", params![sid])
            })
            .await
            .map_err(interact_err)?
            .context("Failed to delete session messages")?;
        tracing::debug!("Deleted all messages for session: {}", session_id);
        Ok(())
    }
}
trait OptionalExt<T> {
fn optional(self) -> rusqlite::Result<Option<T>>;
}
impl<T> OptionalExt<T> for rusqlite::Result<T> {
fn optional(self) -> rusqlite::Result<Option<T>> {
match self {
Ok(v) => Ok(Some(v)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(e),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::Database;
    use crate::db::models::Session;
    use crate::db::repository::SessionRepository;

    /// Builds a message repository on `db`'s pool and persists one fresh
    /// session for the messages under test to belong to.
    async fn repos_with_session(db: &Database) -> (MessageRepository, Session) {
        let pool = db.pool();
        let sessions = SessionRepository::new(pool.clone());
        let messages = MessageRepository::new(pool.clone());
        let session = Session::new(Some("Test".to_string()), Some("model".to_string()), None);
        sessions
            .create(&session)
            .await
            .expect("Failed to create session");
        (messages, session)
    }

    /// Round-trips one message through create, read, update and delete.
    #[tokio::test]
    async fn test_message_crud() {
        let db = Database::connect_in_memory()
            .await
            .expect("Failed to create database");
        db.run_migrations().await.expect("Failed to run migrations");
        let (repo, session) = repos_with_session(&db).await;

        let original = Message::new(session.id, "user".to_string(), "Hello!".to_string(), 1);
        repo.create(&original)
            .await
            .expect("Failed to create message");

        let fetched = repo.find_by_id(original.id).await.expect("Failed to find");
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().content, "Hello!");

        let mut edited = original.clone();
        edited.content = "Updated content".to_string();
        repo.update(&edited).await.expect("Failed to update");

        let refetched = repo.find_by_id(original.id).await.expect("Failed to find");
        assert_eq!(refetched.unwrap().content, "Updated content");

        repo.delete(original.id).await.expect("Failed to delete");

        let gone = repo.find_by_id(original.id).await.expect("Failed to find");
        assert!(gone.is_none());
    }

    /// Three messages in one session are both listed and counted.
    #[tokio::test]
    async fn test_message_list_by_session() {
        let db = Database::connect_in_memory()
            .await
            .expect("Failed to create database");
        db.run_migrations().await.expect("Failed to run migrations");
        let (repo, session) = repos_with_session(&db).await;

        for n in 0..3 {
            let msg = Message::new(
                session.id,
                "user".to_string(),
                format!("Message {}", n),
                n + 1,
            );
            repo.create(&msg)
                .await
                .expect("Failed to create message");
        }

        let listed = repo
            .list_by_session(session.id)
            .await
            .expect("Failed to list");
        assert_eq!(listed.len(), 3);

        let total = repo
            .count_by_session(session.id)
            .await
            .expect("Failed to count");
        assert_eq!(total, 3);
    }
}