use anyhow::Result;
use chrono::Utc;
use sea_orm::entity::prelude::*;
use sea_orm::ActiveValue::NotSet;
use sea_orm::{ActiveValue::Set, QueryOrder};
use serde::{Deserialize, Serialize};
/// Role attached to a single conversation row.
///
/// Mapped by SeaORM to a string column (`rs_type = "String"`); each variant
/// is stored as the lowercase literal given in its `string_value` attribute.
#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)]
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
pub enum ConversationRole {
/// Stored in the database as `"user"`.
#[sea_orm(string_value = "user")]
User,
/// Stored in the database as `"assistant"`.
#[sea_orm(string_value = "assistant")]
Assistant,
/// Stored in the database as `"system"`.
#[sea_orm(string_value = "system")]
System,
}
impl std::fmt::Display for ConversationRole {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConversationRole::User => write!(f, "user"),
ConversationRole::Assistant => write!(f, "assistant"),
ConversationRole::System => write!(f, "system"),
}
}
}
/// One row of the `conversations` table.
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Serialize, Deserialize)]
#[sea_orm(table_name = "conversations")]
pub struct Model {
/// Primary key. `Model::insert` leaves it unset, so it is expected to be
/// generated by the database.
#[sea_orm(primary_key)]
pub id: i32,
/// Channel identifier; combined with `thread_ts` to select the rows of one thread.
pub gateway_channel_id: String,
/// Thread identifier within the channel.
/// NOTE(review): the name suggests a chat-platform thread timestamp — confirm.
pub thread_ts: String,
/// Identifier of the user associated with this row.
pub user_id: String,
/// Role recorded for this row.
pub role: ConversationRole,
/// Text content of this row.
pub content: String,
/// Creation time kept as a string; `Model::insert` writes the current UTC
/// time in RFC 3339 form. Queries in this file sort on this column.
pub created_at: String,
/// `false` when inserted; set to `true` by `Model::mark_processed`.
pub processed: bool,
/// Optional gateway id.
/// NOTE(review): presumably refers to a gateway record, but no relation is
/// declared in this entity — confirm against the schema.
pub gateway_id: Option<i32>,
}
/// Relations from this entity to others; empty because none are declared here.
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
// No custom hooks: the trait's default method implementations are used as-is.
impl ActiveModelBehavior for ActiveModel {}
impl Model {
pub async fn insert(
db: &DatabaseConnection,
gateway_id: Option<i32>,
gateway_channel_id: &str,
thread_ts: &str,
user_id: &str,
role: ConversationRole,
content: &str,
) -> Result<Self> {
let now = Utc::now().to_rfc3339();
let active = ActiveModel {
id: NotSet,
gateway_channel_id: Set(gateway_channel_id.to_string()),
thread_ts: Set(thread_ts.to_string()),
user_id: Set(user_id.to_string()),
role: Set(role),
content: Set(content.to_string()),
created_at: Set(now),
processed: Set(false),
gateway_id: Set(gateway_id),
};
Ok(active.insert(db).await?)
}
pub async fn find_unprocessed(db: &DatabaseConnection) -> Result<Vec<Model>> {
let rows = Entity::find()
.filter(Column::Processed.eq(false))
.order_by_asc(Column::CreatedAt)
.all(db)
.await?;
Ok(rows)
}
pub async fn mark_processed(db: &DatabaseConnection, id: i32) -> Result<()> {
let active = ActiveModel {
id: Set(id),
processed: Set(true),
gateway_channel_id: NotSet,
thread_ts: NotSet,
user_id: NotSet,
role: NotSet,
content: NotSet,
created_at: NotSet,
gateway_id: NotSet,
};
active.update(db).await?;
Ok(())
}
pub async fn find_by_thread(
db: &DatabaseConnection,
gateway_channel_id: &str,
thread_ts: &str,
) -> Result<Vec<Model>> {
let rows = Entity::find()
.filter(Column::GatewayChannelId.eq(gateway_channel_id))
.filter(Column::ThreadTs.eq(thread_ts))
.order_by_asc(Column::CreatedAt)
.all(db)
.await?;
Ok(rows)
}
pub async fn find_by_id(db: &DatabaseConnection, id: i32) -> Result<Option<Model>> {
Ok(Entity::find_by_id(id).one(db).await?)
}
pub async fn delete_old_turns(
db: &DatabaseConnection,
gateway_channel_id: &str,
thread_ts: &str,
keep_last_n: u64,
) -> Result<()> {
let all = Self::find_by_thread(db, gateway_channel_id, thread_ts).await?;
let total = all.len() as u64;
if total <= keep_last_n {
return Ok(());
}
let to_delete = total - keep_last_n;
let ids_to_delete: Vec<i32> = all
.into_iter()
.take(to_delete as usize)
.map(|m| m.id)
.collect();
Entity::delete_many()
.filter(Column::Id.is_in(ids_to_delete))
.exec(db)
.await?;
Ok(())
}
}