Skip to main content

docx_core/store/
surreal.rs

1use std::{error::Error, fmt, str::FromStr, sync::Arc};
2
3use docx_store::models::{
4    DocBlock,
5    DocChunk,
6    DocSource,
7    Ingest,
8    Project,
9    RelationRecord,
10    Symbol,
11};
12use docx_store::schema::{
13    TABLE_DOC_BLOCK,
14    TABLE_DOC_SOURCE,
15    TABLE_INGEST,
16    TABLE_PROJECT,
17    TABLE_SYMBOL,
18    make_record_id,
19};
20use surrealdb::{Connection, Surreal};
21use surrealdb::sql::{Regex, Thing};
22use uuid::Uuid;
23
24/// Errors returned by the `SurrealDB` store implementation.
25#[derive(Debug)]
26pub enum StoreError {
27    Surreal(Box<surrealdb::Error>),
28    InvalidInput(String),
29}
30
31impl fmt::Display for StoreError {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        match self {
34            Self::Surreal(err) => write!(f, "SurrealDB error: {err}"),
35            Self::InvalidInput(message) => write!(f, "Invalid input: {message}"),
36        }
37    }
38}
39
40impl Error for StoreError {}
41
42impl From<surrealdb::Error> for StoreError {
43    fn from(err: surrealdb::Error) -> Self {
44        Self::Surreal(Box::new(err))
45    }
46}
47
48pub type StoreResult<T> = Result<T, StoreError>;
49
50/// Store implementation backed by `SurrealDB`.
51pub struct SurrealDocStore<C: Connection> {
52    db: Arc<Surreal<C>>,
53}
54
55impl<C: Connection> Clone for SurrealDocStore<C> {
56    fn clone(&self) -> Self {
57        Self {
58            db: self.db.clone(),
59        }
60    }
61}
62
63impl<C: Connection> SurrealDocStore<C> {
64    #[must_use]
65    pub fn new(db: Surreal<C>) -> Self {
66        Self {
67            db: Arc::new(db),
68        }
69    }
70
71    #[must_use]
72    pub const fn from_arc(db: Arc<Surreal<C>>) -> Self {
73        Self { db }
74    }
75
76    #[must_use]
77    pub fn db(&self) -> &Surreal<C> {
78        &self.db
79    }
80
81    /// Upserts a project record by id.
82    ///
83    /// # Errors
84    /// Returns `StoreError` if validation fails or the database write fails.
85    pub async fn upsert_project(&self, project: Project) -> StoreResult<Project> {
86        ensure_non_empty(&project.project_id, "project_id")?;
87        let fallback = project.clone();
88        let record: Option<Project> = self
89            .db
90            .update((TABLE_PROJECT, project.project_id.clone()))
91            .content(project)
92            .await?;
93        Ok(record.unwrap_or(fallback))
94    }
95
96    /// Fetches a project by id.
97    ///
98    /// # Errors
99    /// Returns `StoreError` if the database query fails.
100    pub async fn get_project(&self, project_id: &str) -> StoreResult<Option<Project>> {
101        let record: Option<Project> = self.db.select((TABLE_PROJECT, project_id)).await?;
102        Ok(record)
103    }
104
105    /// Fetches an ingest record by id.
106    ///
107    /// # Errors
108    /// Returns `StoreError` if the database query fails.
109    pub async fn get_ingest(&self, ingest_id: &str) -> StoreResult<Option<Ingest>> {
110        let record: Option<Ingest> = self.db.select((TABLE_INGEST, ingest_id)).await?;
111        Ok(record)
112    }
113
114    /// Lists projects up to the provided limit.
115    ///
116    /// # Errors
117    /// Returns `StoreError` if the limit is invalid or the database query fails.
118    pub async fn list_projects(&self, limit: usize) -> StoreResult<Vec<Project>> {
119        let limit = limit_to_i64(limit)?;
120        let query = "SELECT * FROM project LIMIT $limit;";
121        let mut response = self.db.query(query).bind(("limit", limit)).await?;
122        let records: Vec<Project> = response.take(0)?;
123        Ok(records)
124    }
125
126    /// Searches projects by name or alias pattern.
127    ///
128    /// # Errors
129    /// Returns `StoreError` if the limit or pattern is invalid or the database query fails.
130    pub async fn search_projects(&self, pattern: &str, limit: usize) -> StoreResult<Vec<Project>> {
131        let Some(pattern) = normalize_pattern(pattern) else {
132            return self.list_projects(limit).await;
133        };
134        let limit = limit_to_i64(limit)?;
135        let regex = build_project_regex(&pattern)?;
136        let query = "SELECT * FROM project WHERE search_text != NONE AND string::matches(search_text, $pattern) LIMIT $limit;";
137        let mut response = self
138            .db
139            .query(query)
140            .bind(("pattern", regex))
141            .bind(("limit", limit))
142            .await?;
143        let records: Vec<Project> = response.take(0)?;
144        Ok(records)
145    }
146
147    /// Lists ingest records for a project.
148    ///
149    /// # Errors
150    /// Returns `StoreError` if the limit is invalid or the database query fails.
151    pub async fn list_ingests(&self, project_id: &str, limit: usize) -> StoreResult<Vec<Ingest>> {
152        let project_id = project_id.to_string();
153        let limit = limit_to_i64(limit)?;
154        let query =
155            "SELECT * FROM ingest WHERE project_id = $project_id ORDER BY ingested_at DESC LIMIT $limit;";
156        let mut response = self
157            .db
158            .query(query)
159            .bind(("project_id", project_id))
160            .bind(("limit", limit))
161            .await?;
162        let records: Vec<Ingest> = response.take(0)?;
163        Ok(records)
164    }
165
166    /// Creates an ingest record.
167    ///
168    /// # Errors
169    /// Returns `StoreError` if the database write fails.
170    pub async fn create_ingest(&self, mut ingest: Ingest) -> StoreResult<Ingest> {
171        let id = ingest.id.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
172        ingest.id = Some(id.clone());
173        let record = Thing::from((TABLE_INGEST, id.as_str()));
174        self.db
175            .query("UPSERT $record CONTENT $data RETURN NONE;")
176            .bind(("record", record))
177            .bind(("data", ingest.clone()))
178            .await?;
179        Ok(ingest)
180    }
181
182    /// Creates a document source record.
183    ///
184    /// # Errors
185    /// Returns `StoreError` if the database write fails.
186    pub async fn create_doc_source(&self, mut source: DocSource) -> StoreResult<DocSource> {
187        let id = source.id.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
188        source.id = Some(id.clone());
189        self.db
190            .query("CREATE doc_source CONTENT $data RETURN NONE;")
191            .bind(("data", source.clone()))
192            .await?;
193        Ok(source)
194    }
195
196    /// Upserts a symbol record by symbol key.
197    ///
198    /// # Errors
199    /// Returns `StoreError` if validation fails or the database write fails.
200    pub async fn upsert_symbol(&self, mut symbol: Symbol) -> StoreResult<Symbol> {
201        ensure_non_empty(&symbol.symbol_key, "symbol_key")?;
202        let id = symbol
203            .id
204            .clone()
205            .unwrap_or_else(|| symbol.symbol_key.clone());
206        symbol.id = Some(id.clone());
207        let record = Thing::from((TABLE_SYMBOL, id.as_str()));
208        self.db
209            .query("UPSERT $record CONTENT $data RETURN NONE;")
210            .bind(("record", record))
211            .bind(("data", symbol.clone()))
212            .await?;
213        Ok(symbol)
214    }
215
216    /// Creates a document block record.
217    ///
218    /// # Errors
219    /// Returns `StoreError` if the database write fails.
220    pub async fn create_doc_block(&self, mut block: DocBlock) -> StoreResult<DocBlock> {
221        let id = block.id.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
222        block.id = Some(id.clone());
223        self.db
224            .query("CREATE doc_block CONTENT $data RETURN NONE;")
225            .bind(("data", block.clone()))
226            .await?;
227        Ok(block)
228    }
229
230    /// Creates document block records.
231    ///
232    /// # Errors
233    /// Returns `StoreError` if the database write fails.
234    pub async fn create_doc_blocks(&self, blocks: Vec<DocBlock>) -> StoreResult<Vec<DocBlock>> {
235        if blocks.is_empty() {
236            return Ok(Vec::new());
237        }
238        let mut stored = Vec::with_capacity(blocks.len());
239        for block in blocks {
240            stored.push(self.create_doc_block(block).await?);
241        }
242        Ok(stored)
243    }
244
245    /// Creates document chunk records.
246    ///
247    /// # Errors
248    /// Returns `StoreError` if the database write fails.
249    pub async fn create_doc_chunks(&self, chunks: Vec<DocChunk>) -> StoreResult<Vec<DocChunk>> {
250        if chunks.is_empty() {
251            return Ok(Vec::new());
252        }
253        let mut stored = Vec::with_capacity(chunks.len());
254        for mut chunk in chunks {
255            let id = chunk.id.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
256            chunk.id = Some(id.clone());
257            self.db
258                .query("CREATE doc_chunk CONTENT $data RETURN NONE;")
259                .bind(("data", chunk.clone()))
260                .await?;
261            stored.push(chunk);
262        }
263        Ok(stored)
264    }
265
266    /// Creates a relation record in the specified table.
267    ///
268    /// # Errors
269    /// Returns `StoreError` if the database write fails.
270    pub async fn create_relation(
271        &self,
272        table: &str,
273        mut relation: RelationRecord,
274    ) -> StoreResult<RelationRecord> {
275        let id = relation.id.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
276        relation.id = Some(id.clone());
277        let statement = format!("CREATE {table} CONTENT $data RETURN NONE;");
278        self.db
279            .query(statement)
280            .bind(("data", relation.clone()))
281            .await?;
282        Ok(relation)
283    }
284
285    /// Creates relation records in the specified table.
286    ///
287    /// # Errors
288    /// Returns `StoreError` if the database write fails.
289    pub async fn create_relations(
290        &self,
291        table: &str,
292        relations: Vec<RelationRecord>,
293    ) -> StoreResult<Vec<RelationRecord>> {
294        if relations.is_empty() {
295            return Ok(Vec::new());
296        }
297        let mut stored = Vec::with_capacity(relations.len());
298        for mut relation in relations {
299            let id = relation.id.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
300            relation.id = Some(id.clone());
301            let statement = format!("CREATE {table} CONTENT $data RETURN NONE;");
302            self.db
303                .query(statement)
304                .bind(("data", relation.clone()))
305                .await?;
306            stored.push(relation);
307        }
308        Ok(stored)
309    }
310
311    /// Fetches a symbol by key.
312    ///
313    /// # Errors
314    /// Returns `StoreError` if the database query fails.
315    pub async fn get_symbol(&self, symbol_key: &str) -> StoreResult<Option<Symbol>> {
316        let record: Option<Symbol> = self.db.select((TABLE_SYMBOL, symbol_key)).await?;
317        Ok(record)
318    }
319
320    /// Fetches a symbol by project id and key.
321    ///
322    /// # Errors
323    /// Returns `StoreError` if the database query fails.
324    pub async fn get_symbol_by_project(
325        &self,
326        project_id: &str,
327        symbol_key: &str,
328    ) -> StoreResult<Option<Symbol>> {
329        let project_id = project_id.to_string();
330        let symbol_key = symbol_key.to_string();
331        let query = "SELECT * FROM symbol WHERE project_id = $project_id AND symbol_key = $symbol_key LIMIT 1;";
332        let mut response = self
333            .db
334            .query(query)
335            .bind(("project_id", project_id))
336            .bind(("symbol_key", symbol_key))
337            .await?;
338        let mut records: Vec<Symbol> = response.take(0)?;
339        Ok(records.pop())
340    }
341
342    /// Lists symbols by name match within a project.
343    ///
344    /// # Errors
345    /// Returns `StoreError` if the limit is invalid or the database query fails.
346    pub async fn list_symbols_by_name(
347        &self,
348        project_id: &str,
349        name: &str,
350        limit: usize,
351    ) -> StoreResult<Vec<Symbol>> {
352        let project_id = project_id.to_string();
353        let name = name.to_string();
354        let limit = limit_to_i64(limit)?;
355        let query = "SELECT * FROM symbol WHERE project_id = $project_id AND name CONTAINS $name LIMIT $limit;";
356        let mut response = self
357            .db
358            .query(query)
359            .bind(("project_id", project_id))
360            .bind(("name", name))
361            .bind(("limit", limit))
362            .await?;
363        let records: Vec<Symbol> = response.take(0)?;
364        Ok(records)
365    }
366
367    /// Lists distinct symbol kinds for a project.
368    ///
369    /// # Errors
370    /// Returns `StoreError` if the database query fails.
371    pub async fn list_symbol_kinds(&self, project_id: &str) -> StoreResult<Vec<String>> {
372        let project_id = project_id.to_string();
373        let query = "SELECT kind FROM symbol WHERE project_id = $project_id GROUP BY kind;";
374        let mut response = self
375            .db
376            .query(query)
377            .bind(("project_id", project_id))
378            .await?;
379        let records: Vec<SymbolKindRow> = response.take(0)?;
380        let mut kinds: Vec<String> = records
381            .into_iter()
382            .filter_map(|row| row.kind)
383            .filter(|value| !value.trim().is_empty())
384            .collect();
385        kinds.sort();
386        kinds.dedup();
387        Ok(kinds)
388    }
389
390    /// Lists members by scope prefix or glob pattern.
391    ///
392    /// # Errors
393    /// Returns `StoreError` if the scope or limit is invalid or the database query fails.
394    pub async fn list_members_by_scope(
395        &self,
396        project_id: &str,
397        scope: &str,
398        limit: usize,
399    ) -> StoreResult<Vec<Symbol>> {
400        let Some(scope) = normalize_pattern(scope) else {
401            return Ok(Vec::new());
402        };
403        let project_id = project_id.to_string();
404        let limit = limit_to_i64(limit)?;
405        let mut response = if scope.contains('*') {
406            let regex = build_scope_regex(&scope)?;
407            let query = "SELECT * FROM symbol WHERE project_id = $project_id AND qualified_name != NONE AND string::matches(string::lowercase(qualified_name), $pattern) LIMIT $limit;";
408            self.db
409                .query(query)
410                .bind(("project_id", project_id))
411                .bind(("pattern", regex))
412                .bind(("limit", limit))
413                .await?
414        } else {
415            let query = "SELECT * FROM symbol WHERE project_id = $project_id AND qualified_name != NONE AND string::starts_with(string::lowercase(qualified_name), $scope) LIMIT $limit;";
416            self.db
417                .query(query)
418                .bind(("project_id", project_id))
419                .bind(("scope", scope))
420                .bind(("limit", limit))
421                .await?
422        };
423        let records: Vec<Symbol> = response.take(0)?;
424        Ok(records)
425    }
426
427    /// Lists document blocks for a symbol, optionally filtering by ingest id.
428    ///
429    /// # Errors
430    /// Returns `StoreError` if the database query fails.
431    pub async fn list_doc_blocks(
432        &self,
433        project_id: &str,
434        symbol_key: &str,
435        ingest_id: Option<&str>,
436    ) -> StoreResult<Vec<DocBlock>> {
437        let project_id = project_id.to_string();
438        let symbol_key = symbol_key.to_string();
439        let (query, binds) = ingest_id.map_or(
440            (
441                "SELECT * FROM doc_block WHERE project_id = $project_id AND symbol_key = $symbol_key;",
442                None,
443            ),
444            |ingest_id| (
445                "SELECT * FROM doc_block WHERE project_id = $project_id AND symbol_key = $symbol_key AND ingest_id = $ingest_id;",
446                Some(ingest_id.to_string()),
447            ),
448        );
449        let response = self
450            .db
451            .query(query)
452            .bind(("project_id", project_id))
453            .bind(("symbol_key", symbol_key));
454        let mut response = if let Some(ingest_id) = binds {
455            response.bind(("ingest_id", ingest_id)).await?
456        } else {
457            response.await?
458        };
459        let records: Vec<DocBlock> = response.take(0)?;
460        Ok(records)
461    }
462
463    /// Searches document blocks by text within a project.
464    ///
465    /// # Errors
466    /// Returns `StoreError` if the limit is invalid or the database query fails.
467    pub async fn search_doc_blocks(
468        &self,
469        project_id: &str,
470        text: &str,
471        limit: usize,
472    ) -> StoreResult<Vec<DocBlock>> {
473        let project_id = project_id.to_string();
474        let text = text.to_string();
475        let limit = limit_to_i64(limit)?;
476        let query = "SELECT * FROM doc_block WHERE project_id = $project_id AND (summary CONTAINS $text OR remarks CONTAINS $text OR returns CONTAINS $text) LIMIT $limit;";
477        let mut response = self
478            .db
479            .query(query)
480            .bind(("project_id", project_id))
481            .bind(("text", text))
482            .bind(("limit", limit))
483            .await?;
484        let records: Vec<DocBlock> = response.take(0)?;
485        Ok(records)
486    }
487
488    /// Lists document sources by project and ingest ids.
489    ///
490    /// # Errors
491    /// Returns `StoreError` if the database query fails.
492    pub async fn list_doc_sources(
493        &self,
494        project_id: &str,
495        ingest_ids: &[String],
496    ) -> StoreResult<Vec<DocSource>> {
497        if ingest_ids.is_empty() {
498            return Ok(Vec::new());
499        }
500        let project_id = project_id.to_string();
501        let ingest_ids = ingest_ids.to_vec();
502        let query = "SELECT * FROM doc_source WHERE project_id = $project_id AND ingest_id IN $ingest_ids;";
503        let mut response = self
504            .db
505            .query(query)
506            .bind(("project_id", project_id))
507            .bind(("ingest_ids", ingest_ids))
508            .await?;
509        let records: Vec<DocSource> = response.take(0)?;
510        Ok(records)
511    }
512
513    /// Fetches a document source by id.
514    ///
515    /// # Errors
516    /// Returns `StoreError` if the database query fails.
517    pub async fn get_doc_source(&self, doc_source_id: &str) -> StoreResult<Option<DocSource>> {
518        let record: Option<DocSource> = self.db.select((TABLE_DOC_SOURCE, doc_source_id)).await?;
519        Ok(record)
520    }
521
522    /// Lists document sources for a project, optionally filtered by ingest id.
523    ///
524    /// # Errors
525    /// Returns `StoreError` if the limit is invalid or the database query fails.
526    pub async fn list_doc_sources_by_project(
527        &self,
528        project_id: &str,
529        ingest_id: Option<&str>,
530        limit: usize,
531    ) -> StoreResult<Vec<DocSource>> {
532        let project_id = project_id.to_string();
533        let limit = limit_to_i64(limit)?;
534        let (query, binds) = ingest_id.map_or(
535            (
536                "SELECT * FROM doc_source WHERE project_id = $project_id ORDER BY source_modified_at DESC LIMIT $limit;",
537                None,
538            ),
539            |ingest_id| (
540                "SELECT * FROM doc_source WHERE project_id = $project_id AND ingest_id = $ingest_id ORDER BY source_modified_at DESC LIMIT $limit;",
541                Some(ingest_id.to_string()),
542            ),
543        );
544        let response = self
545            .db
546            .query(query)
547            .bind(("project_id", project_id))
548            .bind(("limit", limit));
549        let mut response = if let Some(ingest_id) = binds {
550            response.bind(("ingest_id", ingest_id)).await?
551        } else {
552            response.await?
553        };
554        let records: Vec<DocSource> = response.take(0)?;
555        Ok(records)
556    }
557
558    /// Lists relation records in a table where the symbol is the source (outgoing).
559    ///
560    /// # Errors
561    /// Returns `StoreError` if the database query fails.
562    pub async fn list_relations_from_symbol(
563        &self,
564        table: &str,
565        project_id: &str,
566        symbol_id: &str,
567        limit: usize,
568    ) -> StoreResult<Vec<RelationRecord>> {
569        let project_id = project_id.to_string();
570        let limit = limit_to_i64(limit)?;
571        let record_id = make_record_id(TABLE_SYMBOL, symbol_id);
572        let query = format!(
573            "SELECT * FROM {table} WHERE project_id = $project_id AND out = $record_id LIMIT $limit;"
574        );
575        let mut response = self
576            .db
577            .query(query)
578            .bind(("project_id", project_id))
579            .bind(("record_id", record_id))
580            .bind(("limit", limit))
581            .await?;
582        let records: Vec<RelationRecord> = response.take(0)?;
583        Ok(records)
584    }
585
586    /// Lists relation records in a table where the symbol is the target (incoming).
587    ///
588    /// # Errors
589    /// Returns `StoreError` if the database query fails.
590    pub async fn list_relations_to_symbol(
591        &self,
592        table: &str,
593        project_id: &str,
594        symbol_id: &str,
595        limit: usize,
596    ) -> StoreResult<Vec<RelationRecord>> {
597        let project_id = project_id.to_string();
598        let limit = limit_to_i64(limit)?;
599        let record_id = make_record_id(TABLE_SYMBOL, symbol_id);
600        let query = format!(
601            "SELECT * FROM {table} WHERE project_id = $project_id AND in = $record_id LIMIT $limit;"
602        );
603        let mut response = self
604            .db
605            .query(query)
606            .bind(("project_id", project_id))
607            .bind(("record_id", record_id))
608            .bind(("limit", limit))
609            .await?;
610        let records: Vec<RelationRecord> = response.take(0)?;
611        Ok(records)
612    }
613
614    /// Lists relation records for a document block id.
615    ///
616    /// # Errors
617    /// Returns `StoreError` if the database query fails.
618    pub async fn list_relations_from_doc_block(
619        &self,
620        table: &str,
621        project_id: &str,
622        doc_block_id: &str,
623        limit: usize,
624    ) -> StoreResult<Vec<RelationRecord>> {
625        let project_id = project_id.to_string();
626        let limit = limit_to_i64(limit)?;
627        let record_id = make_record_id(TABLE_DOC_BLOCK, doc_block_id);
628        let query = format!(
629            "SELECT * FROM {table} WHERE project_id = $project_id AND in = $record_id LIMIT $limit;"
630        );
631        let mut response = self
632            .db
633            .query(query)
634            .bind(("project_id", project_id))
635            .bind(("record_id", record_id))
636            .bind(("limit", limit))
637            .await?;
638        let records: Vec<RelationRecord> = response.take(0)?;
639        Ok(records)
640    }
641}
642
643fn ensure_non_empty(value: &str, field: &str) -> StoreResult<()> {
644    if value.is_empty() {
645        return Err(StoreError::InvalidInput(format!("{field} is required")));
646    }
647    Ok(())
648}
649
650#[derive(serde::Deserialize)]
651struct SymbolKindRow {
652    kind: Option<String>,
653}
654
/// Trims and lower-cases a user-supplied pattern.
///
/// Returns `None` when nothing is left after trimming, which callers treat as
/// "no pattern given".
fn normalize_pattern(pattern: &str) -> Option<String> {
    let lowered = pattern.trim().to_lowercase();
    (!lowered.is_empty()).then_some(lowered)
}
663
664fn limit_to_i64(limit: usize) -> StoreResult<i64> {
665    i64::try_from(limit).map_err(|_| {
666        StoreError::InvalidInput("limit exceeds supported range".to_string())
667    })
668}
669
670fn build_project_regex(pattern: &str) -> StoreResult<Regex> {
671    let body = glob_to_regex_body(pattern);
672    let regex = format!(r"(^|\|){body}(\||$)");
673    Regex::from_str(&regex).map_err(|err| {
674        StoreError::InvalidInput(format!("Invalid project search pattern: {err}"))
675    })
676}
677
678fn build_scope_regex(pattern: &str) -> StoreResult<Regex> {
679    let body = glob_to_regex_body(pattern);
680    let regex = format!(r"^{body}$");
681    Regex::from_str(&regex).map_err(|err| {
682        StoreError::InvalidInput(format!("Invalid scope search pattern: {err}"))
683    })
684}
685
/// Translates a glob pattern into an unanchored regex fragment.
///
/// `*` becomes `.*`; the characters `. + ? ( ) [ ] { } | ^ $ \` are prefixed
/// with a backslash so they match literally; everything else is copied as-is.
fn glob_to_regex_body(pattern: &str) -> String {
    // Characters that must be escaped to be taken literally.
    const META: &str = ".+?()[]{}|^$\\";
    pattern
        .chars()
        .fold(String::with_capacity(pattern.len()), |mut out, ch| {
            if ch == '*' {
                out.push_str(".*");
            } else {
                if META.contains(ch) {
                    out.push('\\');
                }
                out.push(ch);
            }
            out
        })
}