Skip to main content

robson_core/entities/
conversation.rs

1use anyhow::Result;
2use chrono::Utc;
3use sea_orm::entity::prelude::*;
4use sea_orm::ActiveValue::NotSet;
5use sea_orm::{ActiveValue::Set, QueryOrder};
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)]
9#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
10pub enum ConversationRole {
11    #[sea_orm(string_value = "user")]
12    User,
13    #[sea_orm(string_value = "assistant")]
14    Assistant,
15    #[sea_orm(string_value = "system")]
16    System,
17}
18
19impl std::fmt::Display for ConversationRole {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        match self {
22            ConversationRole::User => write!(f, "user"),
23            ConversationRole::Assistant => write!(f, "assistant"),
24            ConversationRole::System => write!(f, "system"),
25        }
26    }
27}
28
29#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Serialize, Deserialize)]
30#[sea_orm(table_name = "conversations")]
31pub struct Model {
32    #[sea_orm(primary_key)]
33    pub id: i32,
34    /// The channel or callback URL for the originating gateway.
35    /// Previously named `channel_id`; renamed in migration 006.
36    pub gateway_channel_id: String,
37    pub thread_ts: String,
38    pub user_id: String,
39    pub role: ConversationRole,
40    pub content: String,
41    pub created_at: String,
42    /// False for newly inserted rows; true once handled by the SensoriumLoop.
43    pub processed: bool,
44    /// FK to `gateways.id` — identifies which gateway originated this conversation.
45    /// NULL for rows created before migration 006 via the legacy AgentGateway path.
46    pub gateway_id: Option<i32>,
47}
48
49#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
50pub enum Relation {}
51
52impl ActiveModelBehavior for ActiveModel {}
53
54impl Model {
55    pub async fn insert(
56        db: &DatabaseConnection,
57        gateway_id: Option<i32>,
58        gateway_channel_id: &str,
59        thread_ts: &str,
60        user_id: &str,
61        role: ConversationRole,
62        content: &str,
63    ) -> Result<Self> {
64        let now = Utc::now().to_rfc3339();
65        let active = ActiveModel {
66            id: NotSet,
67            gateway_channel_id: Set(gateway_channel_id.to_string()),
68            thread_ts: Set(thread_ts.to_string()),
69            user_id: Set(user_id.to_string()),
70            role: Set(role),
71            content: Set(content.to_string()),
72            created_at: Set(now),
73            processed: Set(false),
74            gateway_id: Set(gateway_id),
75        };
76        Ok(active.insert(db).await?)
77    }
78
79    /// Returns all conversations that have not yet been processed by the SensoriumLoop.
80    pub async fn find_unprocessed(db: &DatabaseConnection) -> Result<Vec<Model>> {
81        let rows = Entity::find()
82            .filter(Column::Processed.eq(false))
83            .order_by_asc(Column::CreatedAt)
84            .all(db)
85            .await?;
86        Ok(rows)
87    }
88
89    /// Mark a conversation row as processed so it won't be picked up again.
90    pub async fn mark_processed(db: &DatabaseConnection, id: i32) -> Result<()> {
91        let active = ActiveModel {
92            id: Set(id),
93            processed: Set(true),
94            gateway_channel_id: NotSet,
95            thread_ts: NotSet,
96            user_id: NotSet,
97            role: NotSet,
98            content: NotSet,
99            created_at: NotSet,
100            gateway_id: NotSet,
101        };
102        active.update(db).await?;
103        Ok(())
104    }
105
106    pub async fn find_by_thread(
107        db: &DatabaseConnection,
108        gateway_channel_id: &str,
109        thread_ts: &str,
110    ) -> Result<Vec<Model>> {
111        let rows = Entity::find()
112            .filter(Column::GatewayChannelId.eq(gateway_channel_id))
113            .filter(Column::ThreadTs.eq(thread_ts))
114            .order_by_asc(Column::CreatedAt)
115            .all(db)
116            .await?;
117        Ok(rows)
118    }
119
120    /// Look up a single conversation row by primary key.
121    pub async fn find_by_id(db: &DatabaseConnection, id: i32) -> Result<Option<Model>> {
122        Ok(Entity::find_by_id(id).one(db).await?)
123    }
124
125    /// Keep only the last `keep_last_n` rows for a given thread, deleting the rest.
126    pub async fn delete_old_turns(
127        db: &DatabaseConnection,
128        gateway_channel_id: &str,
129        thread_ts: &str,
130        keep_last_n: u64,
131    ) -> Result<()> {
132        let all = Self::find_by_thread(db, gateway_channel_id, thread_ts).await?;
133        let total = all.len() as u64;
134        if total <= keep_last_n {
135            return Ok(());
136        }
137        let to_delete = total - keep_last_n;
138        let ids_to_delete: Vec<i32> = all
139            .into_iter()
140            .take(to_delete as usize)
141            .map(|m| m.id)
142            .collect();
143
144        Entity::delete_many()
145            .filter(Column::Id.is_in(ids_to_delete))
146            .exec(db)
147            .await?;
148        Ok(())
149    }
150}