Skip to main content

mxr_store/
semantic.rs

1use mxr_core::id::*;
2use mxr_core::types::*;
3use sqlx::Row;
4
5impl super::Store {
6    pub async fn list_semantic_profiles(&self) -> Result<Vec<SemanticProfileRecord>, sqlx::Error> {
7        let rows = sqlx::query(
8            r#"SELECT id, profile_name, backend, model_revision, dimensions, status,
9                      installed_at, activated_at, last_indexed_at,
10                      progress_completed, progress_total, last_error
11               FROM semantic_profiles
12               ORDER BY profile_name ASC"#,
13        )
14        .fetch_all(self.reader())
15        .await?;
16
17        Ok(rows.into_iter().map(row_to_semantic_profile).collect())
18    }
19
20    pub async fn get_semantic_profile(
21        &self,
22        profile: SemanticProfile,
23    ) -> Result<Option<SemanticProfileRecord>, sqlx::Error> {
24        let row = sqlx::query(
25            r#"SELECT id, profile_name, backend, model_revision, dimensions, status,
26                      installed_at, activated_at, last_indexed_at,
27                      progress_completed, progress_total, last_error
28               FROM semantic_profiles
29               WHERE profile_name = ?"#,
30        )
31        .bind(profile.as_str())
32        .fetch_optional(self.reader())
33        .await?;
34
35        Ok(row.map(row_to_semantic_profile))
36    }
37
38    pub async fn upsert_semantic_profile(
39        &self,
40        profile: &SemanticProfileRecord,
41    ) -> Result<(), sqlx::Error> {
42        sqlx::query(
43            r#"INSERT INTO semantic_profiles
44               (id, profile_name, backend, model_revision, dimensions, status,
45                installed_at, activated_at, last_indexed_at,
46                progress_completed, progress_total, last_error)
47               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
48               ON CONFLICT(id) DO UPDATE SET
49                   profile_name = excluded.profile_name,
50                   backend = excluded.backend,
51                   model_revision = excluded.model_revision,
52                   dimensions = excluded.dimensions,
53                   status = excluded.status,
54                   installed_at = excluded.installed_at,
55                   activated_at = excluded.activated_at,
56                   last_indexed_at = excluded.last_indexed_at,
57                   progress_completed = excluded.progress_completed,
58                   progress_total = excluded.progress_total,
59                   last_error = excluded.last_error"#,
60        )
61        .bind(profile.id.as_str())
62        .bind(profile.profile.as_str())
63        .bind(&profile.backend)
64        .bind(&profile.model_revision)
65        .bind(profile.dimensions as i64)
66        .bind(serde_json::to_string(&profile.status).unwrap())
67        .bind(profile.installed_at.map(|v| v.timestamp()))
68        .bind(profile.activated_at.map(|v| v.timestamp()))
69        .bind(profile.last_indexed_at.map(|v| v.timestamp()))
70        .bind(profile.progress_completed as i64)
71        .bind(profile.progress_total as i64)
72        .bind(&profile.last_error)
73        .execute(self.writer())
74        .await?;
75
76        Ok(())
77    }
78
79    pub async fn replace_semantic_message_data(
80        &self,
81        message_id: &MessageId,
82        profile_id: &SemanticProfileId,
83        chunks: &[SemanticChunkRecord],
84        embeddings: &[SemanticEmbeddingRecord],
85    ) -> Result<(), sqlx::Error> {
86        let mut tx = self.writer().begin().await?;
87        let message_id_str = message_id.as_str();
88        let profile_id_str = profile_id.as_str();
89
90        sqlx::query(
91            r#"DELETE FROM semantic_embeddings
92               WHERE profile_id = ?
93                 AND chunk_id IN (
94                    SELECT id FROM semantic_chunks WHERE message_id = ?
95               )"#,
96        )
97        .bind(profile_id_str)
98        .bind(&message_id_str)
99        .execute(&mut *tx)
100        .await?;
101
102        sqlx::query("DELETE FROM semantic_chunks WHERE message_id = ?")
103            .bind(&message_id_str)
104            .execute(&mut *tx)
105            .await?;
106
107        for chunk in chunks {
108            sqlx::query(
109                r#"INSERT INTO semantic_chunks
110                   (id, message_id, source_kind, ordinal, normalized, content_hash, created_at, updated_at)
111                   VALUES (?, ?, ?, ?, ?, ?, ?, ?)"#,
112            )
113            .bind(chunk.id.as_str())
114            .bind(chunk.message_id.as_str())
115            .bind(serde_json::to_string(&chunk.source_kind).unwrap())
116            .bind(chunk.ordinal as i64)
117            .bind(&chunk.normalized)
118            .bind(&chunk.content_hash)
119            .bind(chunk.created_at.timestamp())
120            .bind(chunk.updated_at.timestamp())
121            .execute(&mut *tx)
122            .await?;
123        }
124
125        for embedding in embeddings {
126            sqlx::query(
127                r#"INSERT INTO semantic_embeddings
128                   (chunk_id, profile_id, dimensions, vector_blob, status, created_at, updated_at)
129                   VALUES (?, ?, ?, ?, ?, ?, ?)"#,
130            )
131            .bind(embedding.chunk_id.as_str())
132            .bind(embedding.profile_id.as_str())
133            .bind(embedding.dimensions as i64)
134            .bind(&embedding.vector)
135            .bind(serde_json::to_string(&embedding.status).unwrap())
136            .bind(embedding.created_at.timestamp())
137            .bind(embedding.updated_at.timestamp())
138            .execute(&mut *tx)
139            .await?;
140        }
141
142        tx.commit().await?;
143        Ok(())
144    }
145
146    pub async fn list_semantic_embeddings(
147        &self,
148        profile_id: &SemanticProfileId,
149    ) -> Result<Vec<(SemanticChunkRecord, SemanticEmbeddingRecord)>, sqlx::Error> {
150        let rows = sqlx::query(
151            r#"SELECT
152                   c.id as chunk_id,
153                   c.message_id,
154                   c.source_kind,
155                   c.ordinal,
156                   c.normalized,
157                   c.content_hash,
158                   c.created_at as chunk_created_at,
159                   c.updated_at as chunk_updated_at,
160                   e.profile_id,
161                   e.dimensions,
162                   e.vector_blob,
163                   e.status,
164                   e.created_at as embedding_created_at,
165                   e.updated_at as embedding_updated_at
166               FROM semantic_embeddings e
167               JOIN semantic_chunks c ON c.id = e.chunk_id
168               WHERE e.profile_id = ?
169               ORDER BY c.message_id ASC, c.ordinal ASC"#,
170        )
171        .bind(profile_id.as_str())
172        .fetch_all(self.reader())
173        .await?;
174
175        Ok(rows
176            .into_iter()
177            .map(|row| {
178                let chunk = SemanticChunkRecord {
179                    id: SemanticChunkId::from_uuid(
180                        uuid::Uuid::parse_str(&row.get::<String, _>("chunk_id")).unwrap(),
181                    ),
182                    message_id: MessageId::from_uuid(
183                        uuid::Uuid::parse_str(&row.get::<String, _>("message_id")).unwrap(),
184                    ),
185                    source_kind: serde_json::from_str(&row.get::<String, _>("source_kind"))
186                        .unwrap(),
187                    ordinal: row.get::<i64, _>("ordinal") as u32,
188                    normalized: row.get::<String, _>("normalized"),
189                    content_hash: row.get::<String, _>("content_hash"),
190                    created_at: chrono::DateTime::from_timestamp(
191                        row.get::<i64, _>("chunk_created_at"),
192                        0,
193                    )
194                    .unwrap_or_default(),
195                    updated_at: chrono::DateTime::from_timestamp(
196                        row.get::<i64, _>("chunk_updated_at"),
197                        0,
198                    )
199                    .unwrap_or_default(),
200                };
201                let embedding = SemanticEmbeddingRecord {
202                    chunk_id: chunk.id.clone(),
203                    profile_id: SemanticProfileId::from_uuid(
204                        uuid::Uuid::parse_str(&row.get::<String, _>("profile_id")).unwrap(),
205                    ),
206                    dimensions: row.get::<i64, _>("dimensions") as u32,
207                    vector: row.get::<Vec<u8>, _>("vector_blob"),
208                    status: serde_json::from_str(&row.get::<String, _>("status")).unwrap(),
209                    created_at: chrono::DateTime::from_timestamp(
210                        row.get::<i64, _>("embedding_created_at"),
211                        0,
212                    )
213                    .unwrap_or_default(),
214                    updated_at: chrono::DateTime::from_timestamp(
215                        row.get::<i64, _>("embedding_updated_at"),
216                        0,
217                    )
218                    .unwrap_or_default(),
219                };
220                (chunk, embedding)
221            })
222            .collect())
223    }
224}
225
226fn row_to_semantic_profile(row: sqlx::sqlite::SqliteRow) -> SemanticProfileRecord {
227    SemanticProfileRecord {
228        id: SemanticProfileId::from_uuid(
229            uuid::Uuid::parse_str(&row.get::<String, _>("id")).unwrap(),
230        ),
231        profile: serde_json::from_str(&format!("\"{}\"", row.get::<String, _>("profile_name")))
232            .unwrap(),
233        backend: row.get::<String, _>("backend"),
234        model_revision: row.get::<String, _>("model_revision"),
235        dimensions: row.get::<i64, _>("dimensions") as u32,
236        status: serde_json::from_str(&row.get::<String, _>("status")).unwrap_or_default(),
237        installed_at: row
238            .get::<Option<i64>, _>("installed_at")
239            .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)),
240        activated_at: row
241            .get::<Option<i64>, _>("activated_at")
242            .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)),
243        last_indexed_at: row
244            .get::<Option<i64>, _>("last_indexed_at")
245            .and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)),
246        progress_completed: row.get::<i64, _>("progress_completed") as u32,
247        progress_total: row.get::<i64, _>("progress_total") as u32,
248        last_error: row.get::<Option<String>, _>("last_error"),
249    }
250}