use crate::scenes::MemScene;
use crate::types::{MemSceneId, MessageId};
#[allow(unused_imports)]
use zeph_db::sql;
use zeph_db::ActiveDialect;
use crate::error::MemoryError;
use super::SqliteStore;
impl SqliteStore {
    /// Returns up to `limit` semantic-tier messages that are not assigned to any scene yet.
    ///
    /// Only messages with `tier = 'semantic'` and `deleted_at IS NULL` are considered. A message
    /// counts as assigned as soon as any row in `mem_scene_members` references it. Results are
    /// ordered by ascending message id. A `limit` that does not fit in `i64` is clamped to
    /// `i64::MAX`.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError`] if the query fails.
    pub async fn find_unscened_semantic_messages(
        &self,
        limit: usize,
    ) -> Result<Vec<(MessageId, String)>, MemoryError> {
        let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
        // `NOT EXISTS` rather than `NOT IN (subquery)`: `NOT IN` evaluates to NULL (and thus
        // filters out every row) if the subquery ever yields a NULL `message_id`.
        // Wrapped in `sql!` like every other literal query in this file, which is presumably
        // what adapts the `?` placeholder to the active dialect.
        let rows: Vec<(i64, String)> = zeph_db::query_as(sql!(
            "SELECT m.id, m.content \
             FROM messages m \
             WHERE m.tier = 'semantic' \
               AND m.deleted_at IS NULL \
               AND NOT EXISTS ( \
                   SELECT 1 FROM mem_scene_members sm WHERE sm.message_id = m.id \
               ) \
             ORDER BY m.id ASC \
             LIMIT ?"
        ))
        .bind(limit_i64)
        .fetch_all(&self.pool)
        .await?;
        Ok(rows
            .into_iter()
            .map(|(id, content)| (MessageId(id), content))
            .collect())
    }

    /// Creates a scene and links `member_ids` to it inside a single transaction.
    ///
    /// Membership rows are written with the dialect's insert-or-ignore form, so a conflicting
    /// row (e.g. a duplicate id, assuming a unique key on the membership table) is skipped
    /// instead of failing the call. The stored `member_count` is the number of membership rows
    /// actually inserted, not `member_ids.len()`.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError`] if any statement or the commit fails. An early `?` return drops
    /// the transaction without committing it.
    pub async fn insert_mem_scene(
        &self,
        label: &str,
        profile: &str,
        member_ids: &[MessageId],
    ) -> Result<MemSceneId, MemoryError> {
        let requested = i64::try_from(member_ids.len()).unwrap_or(i64::MAX);
        let mut tx = self.pool.begin().await?;
        let (scene_id,): (i64,) = zeph_db::query_as(sql!(
            "INSERT INTO mem_scenes (label, profile, member_count) VALUES (?, ?, ?) RETURNING id"
        ))
        .bind(label)
        .bind(profile)
        .bind(requested)
        .fetch_one(&mut *tx)
        .await?;

        // NOTE(review): this statement is built with `format!` and therefore is not passed
        // through `sql!`; confirm the raw `?` placeholders are accepted by every enabled dialect.
        let member_sql = format!(
            "{} INTO mem_scene_members (scene_id, message_id) VALUES (?, ?){}",
            <ActiveDialect as zeph_db::dialect::Dialect>::INSERT_IGNORE,
            <ActiveDialect as zeph_db::dialect::Dialect>::CONFLICT_NOTHING,
        );
        // Count rows actually written; ignored conflicts report zero affected rows.
        let mut inserted: u64 = 0;
        for &msg_id in member_ids {
            let result = zeph_db::query(&member_sql)
                .bind(scene_id)
                .bind(msg_id.0)
                .execute(&mut *tx)
                .await?;
            inserted += result.rows_affected();
        }

        // Keep `member_count` in sync with the membership table when some rows were skipped.
        let inserted = i64::try_from(inserted).unwrap_or(i64::MAX);
        if inserted != requested {
            zeph_db::query(sql!("UPDATE mem_scenes SET member_count = ? WHERE id = ?"))
                .bind(inserted)
                .bind(scene_id)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(MemSceneId(scene_id))
    }

    /// Lists every scene, newest `created_at` first.
    ///
    /// Scenes sharing the same `created_at` are ordered by descending id, so the result order is
    /// deterministic. A stored `member_count` outside the `u32` range is reported as `0`.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError`] if the query fails.
    pub async fn list_mem_scenes(&self) -> Result<Vec<MemScene>, MemoryError> {
        let rows: Vec<(i64, String, String, i64, i64, i64)> = zeph_db::query_as(sql!(
            "SELECT id, label, profile, member_count, created_at, updated_at \
             FROM mem_scenes ORDER BY created_at DESC, id DESC"
        ))
        .fetch_all(&self.pool)
        .await?;
        Ok(rows
            .into_iter()
            .map(
                |(id, label, profile, member_count, created_at, updated_at)| MemScene {
                    id: MemSceneId(id),
                    label,
                    profile,
                    member_count: u32::try_from(member_count).unwrap_or(0),
                    created_at,
                    updated_at,
                },
            )
            .collect())
    }

    /// Returns the ids of the messages linked to `scene_id`, in ascending id order.
    ///
    /// An unknown `scene_id` yields an empty vector.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError`] if the query fails.
    pub async fn scene_member_ids(
        &self,
        scene_id: MemSceneId,
    ) -> Result<Vec<MessageId>, MemoryError> {
        let rows: Vec<(i64,)> = zeph_db::query_as(sql!(
            "SELECT message_id FROM mem_scene_members WHERE scene_id = ? ORDER BY message_id ASC"
        ))
        .bind(scene_id.0)
        .fetch_all(&self.pool)
        .await?;
        Ok(rows.into_iter().map(|(id,)| MessageId(id)).collect())
    }

    /// Deletes every scene together with all membership rows.
    ///
    /// Returns the number of scenes deleted. Membership rows are removed explicitly, before the
    /// scenes, so the reset does not depend on an `ON DELETE CASCADE` (or foreign-key
    /// enforcement) being configured. Otherwise orphaned membership rows would keep messages
    /// excluded from [`SqliteStore::find_unscened_semantic_messages`].
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError`] if either delete or the commit fails. Nothing is committed in
    /// that case.
    pub async fn reset_mem_scenes(&self) -> Result<u64, MemoryError> {
        let mut tx = self.pool.begin().await?;
        zeph_db::query(sql!("DELETE FROM mem_scene_members"))
            .execute(&mut *tx)
            .await?;
        let result = zeph_db::query(sql!("DELETE FROM mem_scenes"))
            .execute(&mut *tx)
            .await?;
        tx.commit().await?;
        Ok(result.rows_affected())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory store with a single pooled connection.
    async fn make_store() -> SqliteStore {
        SqliteStore::with_pool_size(":memory:", 1).await.unwrap()
    }

    /// Creates a conversation with `n` user messages and returns their ids in insertion order.
    async fn seed_messages(store: &SqliteStore, n: usize) -> Vec<MessageId> {
        let cid = store.create_conversation().await.unwrap();
        let mut ids = Vec::with_capacity(n);
        for i in 0..n {
            let id = store
                .save_message(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
            ids.push(id);
        }
        ids
    }

    /// Moves the given messages to the semantic tier.
    async fn mark_semantic(store: &SqliteStore, ids: &[MessageId]) {
        for id in ids {
            zeph_db::query(sql!("UPDATE messages SET tier = 'semantic' WHERE id = ?"))
                .bind(id.0)
                .execute(store.pool())
                .await
                .unwrap();
        }
    }

    /// Inserts a scene row with explicit timestamps and returns its id.
    ///
    /// Uses `RETURNING id` rather than a follow-up `SELECT last_insert_rowid()`, which is only
    /// correct when both statements happen to run on the same pooled connection.
    async fn insert_scene_at(store: &SqliteStore, label: &str, profile: &str, ts: i64) -> i64 {
        let (id,): (i64,) = zeph_db::query_as(sql!(
            "INSERT INTO mem_scenes (label, profile, member_count, created_at, updated_at) \
             VALUES (?, ?, ?, ?, ?) RETURNING id"
        ))
        .bind(label)
        .bind(profile)
        .bind(1i64)
        .bind(ts)
        .bind(ts)
        .fetch_one(store.pool())
        .await
        .unwrap();
        id
    }

    /// Links one message to a scene directly, bypassing `insert_mem_scene`.
    async fn link_member(store: &SqliteStore, scene_id: i64, message_id: MessageId) {
        zeph_db::query(sql!(
            "INSERT INTO mem_scene_members (scene_id, message_id) VALUES (?, ?)"
        ))
        .bind(scene_id)
        .bind(message_id.0)
        .execute(store.pool())
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn insert_and_list_scene() {
        let store = make_store().await;
        let ids = seed_messages(&store, 2).await;
        let scene_id = store
            .insert_mem_scene("Rust Auth", "JWT tokens used for RS256.", &ids)
            .await
            .unwrap();
        assert!(scene_id.0 > 0, "scene id must be positive");
        let scenes = store.list_mem_scenes().await.unwrap();
        assert_eq!(scenes.len(), 1);
        assert_eq!(scenes[0].label, "Rust Auth");
        assert_eq!(scenes[0].member_count, 2);
    }

    #[tokio::test]
    async fn scene_member_ids_expansion() {
        let store = make_store().await;
        let ids = seed_messages(&store, 3).await;
        let scene_id = store
            .insert_mem_scene("Topic A", "Profile text.", &ids)
            .await
            .unwrap();
        let members = store.scene_member_ids(scene_id).await.unwrap();
        assert_eq!(members.len(), 3);
        for id in &ids {
            assert!(members.contains(id), "member {id} must be in expansion");
        }
    }

    #[tokio::test]
    async fn find_unscened_excludes_assigned_members() {
        let store = make_store().await;
        let ids = seed_messages(&store, 3).await;
        mark_semantic(&store, &ids).await;
        let unscened = store.find_unscened_semantic_messages(100).await.unwrap();
        assert_eq!(unscened.len(), 3);
        store
            .insert_mem_scene("Partial Scene", "Some profile", &ids[..2])
            .await
            .unwrap();
        let unscened_after = store.find_unscened_semantic_messages(100).await.unwrap();
        assert_eq!(unscened_after.len(), 1);
        assert_eq!(unscened_after[0].0, ids[2]);
    }

    #[tokio::test]
    async fn find_unscened_respects_limit_and_order() {
        let store = make_store().await;
        let ids = seed_messages(&store, 3).await;
        mark_semantic(&store, &ids).await;
        let unscened = store.find_unscened_semantic_messages(2).await.unwrap();
        let got: Vec<MessageId> = unscened.iter().map(|(id, _)| *id).collect();
        assert_eq!(got, ids[..2].to_vec());
    }

    #[tokio::test]
    async fn reset_scenes_clears_all() {
        let store = make_store().await;
        let ids1 = seed_messages(&store, 1).await;
        let ids2 = seed_messages(&store, 1).await;
        store
            .insert_mem_scene("Scene 1", "Profile 1", &ids1)
            .await
            .unwrap();
        store
            .insert_mem_scene("Scene 2", "Profile 2", &ids2)
            .await
            .unwrap();
        let deleted = store.reset_mem_scenes().await.unwrap();
        assert_eq!(deleted, 2);
        let scenes = store.list_mem_scenes().await.unwrap();
        assert!(scenes.is_empty());
    }

    #[tokio::test]
    async fn list_scenes_ordered_newest_first() {
        let store = make_store().await;
        let ids1 = seed_messages(&store, 1).await;
        let ids2 = seed_messages(&store, 1).await;
        let scene1_id = insert_scene_at(&store, "First", "Profile first", 1_000_000).await;
        let scene2_id = insert_scene_at(&store, "Second", "Profile second", 2_000_000).await;
        link_member(&store, scene1_id, ids1[0]).await;
        link_member(&store, scene2_id, ids2[0]).await;
        let scenes = store.list_mem_scenes().await.unwrap();
        assert_eq!(scenes.len(), 2);
        assert_eq!(scenes[0].label, "Second");
        assert_eq!(scenes[1].label, "First");
    }
}