use crate::db::Pool;
use crate::db::database::interact_err;
use crate::db::models::Message;
use anyhow::{Context, Result};
use rusqlite::params;
use uuid::Uuid;
/// Data-access handle for the `messages` table.
///
/// The only state is the connection pool, so cloning a repository clones
/// that `Pool` value and nothing else.
#[derive(Clone)]
pub struct MessageRepository {
/// Pool that the query methods call `get()` on to obtain a connection.
pool: Pool,
}
impl MessageRepository {
pub fn new(pool: Pool) -> Self {
Self { pool }
}
/// Looks up a single message by its id.
///
/// Returns `Ok(None)` when no row matches: the "query returned no rows"
/// error is folded into `None` by `optional()`.
///
/// # Errors
/// Fails if a connection cannot be obtained from the pool, if the
/// `interact` call fails, or if the query itself fails.
pub async fn find_by_id(&self, id: Uuid) -> Result<Option<Message>> {
    let key = id.to_string();
    let outcome = self
        .pool
        .get()
        .await
        .context("Failed to get connection")?
        .interact(move |db| {
            let mut stmt = db.prepare_cached("SELECT * FROM messages WHERE id = ?1")?;
            stmt.query_row(params![key], Message::from_row).optional()
        })
        .await;
    outcome
        .map_err(interact_err)?
        .context("Failed to find message")
}
pub async fn find_by_session(&self, session_id: Uuid) -> Result<Vec<Message>> {
let sid = session_id.to_string();
self.pool
.get()
.await
.context("Failed to get connection")?
.interact(move |conn| {
let mut stmt = conn.prepare_cached(
"SELECT * FROM messages WHERE session_id = ?1 ORDER BY sequence ASC",
)?;
let rows = stmt.query_map(params![sid], Message::from_row)?;
rows.collect::<std::result::Result<Vec<_>, _>>()
})
.await
.map_err(interact_err)?
.context("Failed to find messages by session")
}
/// Inserts `message` as a new row of the `messages` table.
///
/// `id` and `session_id` are stored in their string form and `created_at`
/// as the value of `created_at.timestamp()`. `input_tokens` is written too,
/// so a value already present on `message` is not silently dropped; the
/// column is the same one `update` writes.
///
/// # Errors
/// Fails if a connection cannot be obtained from the pool, if the
/// `interact` call fails, or if the INSERT itself fails.
pub async fn create(&self, message: &Message) -> Result<()> {
    // The `move` closure needs its own copy of the data.
    let m = message.clone();
    self.pool
        .get()
        .await
        .context("Failed to get connection")?
        .interact(move |conn| {
            conn.execute(
                "INSERT INTO messages (id, session_id, role, content, sequence, \
                 created_at, token_count, cost, input_tokens) \
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
                params![
                    m.id.to_string(),
                    m.session_id.to_string(),
                    m.role,
                    m.content,
                    m.sequence,
                    m.created_at.timestamp(),
                    m.token_count,
                    m.cost,
                    m.input_tokens,
                ],
            )
        })
        .await
        .map_err(interact_err)?
        .context("Failed to create message")?;
    tracing::debug!(
        "Created message: {} in session: {}",
        message.id,
        message.session_id
    );
    Ok(())
}
/// Overwrites `content`, `token_count`, `cost` and `input_tokens` of the
/// row whose id equals `message.id`. No other column is touched.
///
/// If no row matches, the call still returns `Ok(())` (unchanged contract),
/// but a warning is logged instead of the misleading "Updated message" line.
///
/// # Errors
/// Fails if a connection cannot be obtained from the pool, if the
/// `interact` call fails, or if the UPDATE itself fails.
pub async fn update(&self, message: &Message) -> Result<()> {
    // The `move` closure needs its own copy of the data.
    let m = message.clone();
    let rows_changed = self
        .pool
        .get()
        .await
        .context("Failed to get connection")?
        .interact(move |conn| {
            conn.execute(
                "UPDATE messages \
                 SET content = ?1, token_count = ?2, cost = ?3, input_tokens = ?4 \
                 WHERE id = ?5",
                params![
                    m.content,
                    m.token_count,
                    m.cost,
                    m.input_tokens,
                    m.id.to_string()
                ],
            )
        })
        .await
        .map_err(interact_err)?
        .context("Failed to update message")?;
    if rows_changed == 0 {
        // Make the silent no-op visible without changing the return contract.
        tracing::warn!("Update matched no message with id: {}", message.id);
    } else {
        tracing::debug!("Updated message: {}", message.id);
    }
    Ok(())
}
/// Returns the `input_tokens` value of the most recent assistant message
/// (highest `sequence`) in `session_id` that has a non-NULL `input_tokens`.
///
/// Yields `Ok(None)` when the session has no such message.
///
/// # Errors
/// Fails if a connection cannot be obtained from the pool, if the
/// `interact` call fails, or if the query itself fails.
pub async fn last_assistant_input_tokens(&self, session_id: Uuid) -> Result<Option<i32>> {
    const LATEST_INPUT_TOKENS: &str = "SELECT input_tokens FROM messages \
         WHERE session_id = ?1 AND role = 'assistant' \
         AND input_tokens IS NOT NULL \
         ORDER BY sequence DESC LIMIT 1";

    let owner = session_id.to_string();
    let outcome = self
        .pool
        .get()
        .await
        .context("Failed to get connection")?
        .interact(move |db| {
            db.query_row(LATEST_INPUT_TOKENS, params![owner], |row| {
                row.get::<_, Option<i32>>(0)
            })
            .optional()
        })
        .await;
    let found = outcome
        .map_err(interact_err)?
        .context("Failed to query last assistant input_tokens")?;
    // Outer `None`: no matching row. Inner `None`: column read as NULL.
    Ok(found.unwrap_or(None))
}
/// Appends `content_to_append` to the stored `content` of message `id`,
/// using SQL `||` concatenation inside a single UPDATE.
///
/// If no row matches `id`, the call still returns `Ok(())` (unchanged
/// contract), but a warning is logged instead of claiming content was
/// appended.
///
/// # Errors
/// Fails if a connection cannot be obtained from the pool, if the
/// `interact` call fails, or if the UPDATE itself fails.
pub async fn append_content(&self, id: Uuid, content_to_append: &str) -> Result<()> {
    let id_str = id.to_string();
    // Owned copy: the `move` closure cannot borrow the caller's `&str`.
    let content = content_to_append.to_string();
    let rows_changed = self
        .pool
        .get()
        .await
        .context("Failed to get connection")?
        .interact(move |conn| {
            conn.execute(
                "UPDATE messages SET content = content || ?1 WHERE id = ?2",
                params![content, id_str],
            )
        })
        .await
        .map_err(interact_err)?
        .context("Failed to append to message")?;
    if rows_changed == 0 {
        // Make the silent no-op visible without changing the return contract.
        tracing::warn!("Append matched no message with id: {}", id);
    } else {
        tracing::debug!("Appended content to message: {}", id);
    }
    Ok(())
}
/// Deletes the message whose id equals `id`.
///
/// The number of rows removed is not inspected, so deleting an id that
/// does not exist also returns `Ok(())`.
///
/// # Errors
/// Fails if a connection cannot be obtained from the pool, if the
/// `interact` call fails, or if the DELETE itself fails.
pub async fn delete(&self, id: Uuid) -> Result<()> {
    let key = id.to_string();
    let outcome = self
        .pool
        .get()
        .await
        .context("Failed to get connection")?
        .interact(move |db| db.execute("DELETE FROM messages WHERE id = ?1", params![key]))
        .await;
    outcome
        .map_err(interact_err)?
        .context("Failed to delete message")?;
    tracing::debug!("Deleted message: {}", id);
    Ok(())
}
/// Alias for [`Self::find_by_session`]: returns the messages of
/// `session_id` exactly as that method does, errors included.
pub async fn list_by_session(&self, session_id: Uuid) -> Result<Vec<Message>> {
    let messages = self.find_by_session(session_id).await?;
    Ok(messages)
}
/// Counts the rows of `messages` that belong to `session_id`.
///
/// # Errors
/// Fails if a connection cannot be obtained from the pool, if the
/// `interact` call fails, or if the COUNT query itself fails.
pub async fn count_by_session(&self, session_id: Uuid) -> Result<i64> {
    let owner = session_id.to_string();
    let outcome = self
        .pool
        .get()
        .await
        .context("Failed to get connection")?
        .interact(move |db| {
            db.query_row(
                "SELECT COUNT(*) FROM messages WHERE session_id = ?1",
                params![owner],
                |row| row.get::<_, i64>(0),
            )
        })
        .await;
    outcome
        .map_err(interact_err)?
        .context("Failed to count messages")
}
/// Returns the message with the highest `sequence` in `session_id`, or
/// `Ok(None)` when the session has no messages.
///
/// # Errors
/// Fails if a connection cannot be obtained from the pool, if the
/// `interact` call fails, or if the query itself fails.
pub async fn get_last_message(&self, session_id: Uuid) -> Result<Option<Message>> {
    let owner = session_id.to_string();
    let outcome = self
        .pool
        .get()
        .await
        .context("Failed to get connection")?
        .interact(move |db| {
            let mut stmt = db.prepare_cached(
                "SELECT * FROM messages WHERE session_id = ?1 ORDER BY sequence DESC LIMIT 1",
            )?;
            stmt.query_row(params![owner], Message::from_row).optional()
        })
        .await;
    outcome
        .map_err(interact_err)?
        .context("Failed to get last message")
}
/// Deletes every message that belongs to `session_id`.
///
/// The number of rows removed is not inspected, so a session without
/// messages also returns `Ok(())`.
///
/// # Errors
/// Fails if a connection cannot be obtained from the pool, if the
/// `interact` call fails, or if the DELETE itself fails.
pub async fn delete_by_session(&self, session_id: Uuid) -> Result<()> {
    let owner = session_id.to_string();
    let outcome = self
        .pool
        .get()
        .await
        .context("Failed to get connection")?
        .interact(move |db| {
            db.execute("DELETE FROM messages WHERE session_id = ?1", params![owner])
        })
        .await;
    outcome
        .map_err(interact_err)?
        .context("Failed to delete session messages")?;
    tracing::debug!("Deleted all messages for session: {}", session_id);
    Ok(())
}
}
trait OptionalExt<T> {
fn optional(self) -> rusqlite::Result<Option<T>>;
}
impl<T> OptionalExt<T> for rusqlite::Result<T> {
fn optional(self) -> rusqlite::Result<Option<T>> {
match self {
Ok(v) => Ok(Some(v)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(e),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db::Database;
use crate::db::models::Session;
use crate::db::repository::SessionRepository;
/// Round-trips one message through create, read, update, read, delete and
/// read on an in-memory database.
#[tokio::test]
async fn test_message_crud() {
    let database = Database::connect_in_memory()
        .await
        .expect("Failed to create database");
    database
        .run_migrations()
        .await
        .expect("Failed to run migrations");

    let sessions = SessionRepository::new(database.pool().clone());
    let repo = MessageRepository::new(database.pool().clone());

    // A message is created under a session, so create that first.
    let session = Session::new(Some(String::from("Test")), Some(String::from("model")), None);
    sessions
        .create(&session)
        .await
        .expect("Failed to create session");

    // Create, then read back.
    let original = Message::new(session.id, String::from("user"), String::from("Hello!"), 1);
    repo.create(&original)
        .await
        .expect("Failed to create message");
    let fetched = repo.find_by_id(original.id).await.expect("Failed to find");
    assert!(fetched.is_some());
    assert_eq!(fetched.unwrap().content, "Hello!");

    // Update the content, then read back.
    let mut edited = original.clone();
    edited.content = String::from("Updated content");
    repo.update(&edited).await.expect("Failed to update");
    let fetched = repo.find_by_id(original.id).await.expect("Failed to find");
    assert_eq!(fetched.unwrap().content, "Updated content");

    // Delete, then confirm the row is gone.
    repo.delete(original.id).await.expect("Failed to delete");
    let fetched = repo.find_by_id(original.id).await.expect("Failed to find");
    assert!(fetched.is_none());
}
/// Checks that `list_by_session` returns every message of a session in
/// ascending `sequence` order and that `count_by_session` agrees.
#[tokio::test]
async fn test_message_list_by_session() {
    let db = Database::connect_in_memory()
        .await
        .expect("Failed to create database");
    db.run_migrations().await.expect("Failed to run migrations");
    let session_repo = SessionRepository::new(db.pool().clone());
    let message_repo = MessageRepository::new(db.pool().clone());
    let session = Session::new(Some("Test".to_string()), Some("model".to_string()), None);
    session_repo
        .create(&session)
        .await
        .expect("Failed to create session");

    // Insert in descending sequence order, so the order assertion below
    // only passes if the query really sorts by `sequence`.
    for i in (0..3).rev() {
        let msg = Message::new(
            session.id,
            "user".to_string(),
            format!("Message {}", i),
            i + 1,
        );
        message_repo
            .create(&msg)
            .await
            .expect("Failed to create message");
    }

    let messages = message_repo
        .list_by_session(session.id)
        .await
        .expect("Failed to list");
    assert_eq!(messages.len(), 3);
    let contents: Vec<&str> = messages.iter().map(|m| m.content.as_str()).collect();
    assert_eq!(contents, ["Message 0", "Message 1", "Message 2"]);

    let count = message_repo
        .count_by_session(session.id)
        .await
        .expect("Failed to count");
    assert_eq!(count, 3);
}
}