Skip to main content

ext_fts/
lib.rs

//! ext-fts: Full-text search extension backed by tantivy.
//!
//! Provides BM25-ranked full-text search through three procedures:
//! - `fts.add(content)` → indexes a document and returns `{doc_id: INT64}`
//! - `fts.search(query, limit?)` → BM25-ranked results `{doc_id, score, snippet}`;
//!   `limit` defaults to 10 when omitted
//! - `fts.clear()` → resets the index and returns `{status: STRING}`
//!
//! Uses an in-memory tantivy index with batched commits (every 1000 documents).
//! All access is serialized through a `Mutex<FtsIndex>`, so the extension is thread-safe.
10
11pub mod index;
12
13use std::collections::HashMap;
14use std::sync::Mutex;
15
16use kyu_extension::{Extension, ProcColumn, ProcParam, ProcRow, ProcedureSignature};
17use kyu_types::{LogicalType, TypedValue};
18use smol_str::SmolStr;
19
20use crate::index::FtsIndex;
21
22/// Full-text search extension.
23pub struct FtsExtension {
24    state: Mutex<FtsIndex>,
25}
26
27impl FtsExtension {
28    pub fn new() -> Self {
29        Self {
30            state: Mutex::new(FtsIndex::new()),
31        }
32    }
33}
34
35impl Default for FtsExtension {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl Extension for FtsExtension {
42    fn name(&self) -> &str {
43        "fts"
44    }
45
46    fn needs_graph(&self) -> bool {
47        false
48    }
49
50    fn procedures(&self) -> Vec<ProcedureSignature> {
51        vec![
52            ProcedureSignature {
53                name: "add".into(),
54                params: vec![ProcParam {
55                    name: "content".into(),
56                    type_desc: "STRING".into(),
57                }],
58                columns: vec![ProcColumn {
59                    name: "doc_id".into(),
60                    data_type: LogicalType::Int64,
61                }],
62            },
63            ProcedureSignature {
64                name: "search".into(),
65                params: vec![
66                    ProcParam {
67                        name: "query".into(),
68                        type_desc: "STRING".into(),
69                    },
70                    ProcParam {
71                        name: "limit".into(),
72                        type_desc: "INT64".into(),
73                    },
74                ],
75                columns: vec![
76                    ProcColumn {
77                        name: "doc_id".into(),
78                        data_type: LogicalType::Int64,
79                    },
80                    ProcColumn {
81                        name: "score".into(),
82                        data_type: LogicalType::Double,
83                    },
84                    ProcColumn {
85                        name: "snippet".into(),
86                        data_type: LogicalType::String,
87                    },
88                ],
89            },
90            ProcedureSignature {
91                name: "clear".into(),
92                params: vec![],
93                columns: vec![ProcColumn {
94                    name: "status".into(),
95                    data_type: LogicalType::String,
96                }],
97            },
98        ]
99    }
100
101    fn execute(
102        &self,
103        procedure: &str,
104        args: &[String],
105        _adjacency: &HashMap<i64, Vec<(i64, f64)>>,
106    ) -> Result<Vec<ProcRow>, String> {
107        let mut index = self.state.lock().map_err(|e| format!("lock error: {e}"))?;
108
109        match procedure {
110            "add" => {
111                let content = args.first().ok_or("fts.add requires a content argument")?;
112                let doc_id = index.add_document(content).map_err(|e| e.to_string())?;
113                Ok(vec![vec![TypedValue::Int64(doc_id as i64)]])
114            }
115            "search" => {
116                let query = args.first().ok_or("fts.search requires a query argument")?;
117                let limit = args
118                    .get(1)
119                    .and_then(|s| s.parse::<usize>().ok())
120                    .unwrap_or(10);
121                let results = index.search(query, limit).map_err(|e| e.to_string())?;
122                Ok(results
123                    .into_iter()
124                    .map(|(doc_id, score, snippet)| {
125                        vec![
126                            TypedValue::Int64(doc_id as i64),
127                            TypedValue::Double(score as f64),
128                            TypedValue::String(SmolStr::new(snippet)),
129                        ]
130                    })
131                    .collect())
132            }
133            "clear" => {
134                index.clear().map_err(|e| e.to_string())?;
135                Ok(vec![vec![TypedValue::String(SmolStr::new("ok"))]])
136            }
137            _ => Err(format!("unknown procedure: {procedure}")),
138        }
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Invokes `procedure` on `ext` with string arguments, panicking on `Err`.
    fn call(ext: &FtsExtension, procedure: &str, args: &[&str]) -> Vec<ProcRow> {
        let owned: Vec<String> = args.iter().map(|a| a.to_string()).collect();
        ext.execute(procedure, &owned, &HashMap::new())
            .unwrap_or_else(|e| panic!("{procedure} failed: {e}"))
    }

    /// Extracts the `doc_id` column (first cell) of a result row.
    fn doc_id(row: &[TypedValue]) -> i64 {
        match row[0] {
            TypedValue::Int64(id) => id,
            _ => panic!("expected Int64"),
        }
    }

    #[test]
    fn extension_metadata() {
        let ext = FtsExtension::new();
        let signatures = ext.procedures();
        assert_eq!(ext.name(), "fts");
        assert!(!ext.needs_graph());
        assert_eq!(signatures.len(), 3);
    }

    #[test]
    fn execute_add_and_search() {
        let ext = FtsExtension::new();

        let added = call(&ext, "add", &["the quick brown fox"]);
        assert_eq!(added.len(), 1);
        assert_eq!(added[0][0], TypedValue::Int64(0));

        let hits = call(&ext, "search", &["fox", "10"]);
        assert_eq!(hits.len(), 1);
        assert_eq!(doc_id(&hits[0]), 0);
        // A genuine match must carry a positive BM25 score.
        assert!(matches!(hits[0][1], TypedValue::Double(s) if s > 0.0));
    }

    #[test]
    fn execute_search_no_results() {
        let ext = FtsExtension::new();
        call(&ext, "add", &["hello world"]);

        let hits = call(&ext, "search", &["quantum", "10"]);
        assert!(hits.is_empty());
    }

    #[test]
    fn execute_clear() {
        let ext = FtsExtension::new();
        call(&ext, "add", &["testing clear"]);

        let status = call(&ext, "clear", &[]);
        assert_eq!(status[0][0], TypedValue::String(SmolStr::new("ok")));

        // Once cleared, the previously indexed term is gone.
        let hits = call(&ext, "search", &["testing", "10"]);
        assert!(hits.is_empty());
    }

    #[test]
    fn execute_unknown_procedure() {
        let ext = FtsExtension::new();
        let outcome = ext.execute("nonexistent", &[], &HashMap::new());
        assert!(outcome.is_err());
    }

    #[test]
    fn multiple_documents_ranked() {
        let ext = FtsExtension::new();
        for text in [
            "python for data science",
            "rust systems programming language",
            "rust rust rust all about rust",
        ] {
            call(&ext, "add", &[text]);
        }

        let hits = call(&ext, "search", &["rust", "10"]);
        // Both rust documents should match...
        assert!(hits.len() >= 2);
        // ...while the python document (doc_id 0) must not.
        let ids: Vec<i64> = hits.iter().map(|row| doc_id(row)).collect();
        assert!(ids.iter().all(|&id| id != 0));
    }
}