1pub mod index;
12
13use std::collections::HashMap;
14use std::sync::Mutex;
15
16use kyu_extension::{Extension, ProcColumn, ProcParam, ProcRow, ProcedureSignature};
17use kyu_types::{LogicalType, TypedValue};
18use smol_str::SmolStr;
19
20use crate::index::FtsIndex;
21
22pub struct FtsExtension {
24 state: Mutex<FtsIndex>,
25}
26
27impl FtsExtension {
28 pub fn new() -> Self {
29 Self {
30 state: Mutex::new(FtsIndex::new()),
31 }
32 }
33}
34
35impl Default for FtsExtension {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl Extension for FtsExtension {
42 fn name(&self) -> &str {
43 "fts"
44 }
45
46 fn needs_graph(&self) -> bool {
47 false
48 }
49
50 fn procedures(&self) -> Vec<ProcedureSignature> {
51 vec![
52 ProcedureSignature {
53 name: "add".into(),
54 params: vec![ProcParam {
55 name: "content".into(),
56 type_desc: "STRING".into(),
57 }],
58 columns: vec![ProcColumn {
59 name: "doc_id".into(),
60 data_type: LogicalType::Int64,
61 }],
62 },
63 ProcedureSignature {
64 name: "search".into(),
65 params: vec![
66 ProcParam {
67 name: "query".into(),
68 type_desc: "STRING".into(),
69 },
70 ProcParam {
71 name: "limit".into(),
72 type_desc: "INT64".into(),
73 },
74 ],
75 columns: vec![
76 ProcColumn {
77 name: "doc_id".into(),
78 data_type: LogicalType::Int64,
79 },
80 ProcColumn {
81 name: "score".into(),
82 data_type: LogicalType::Double,
83 },
84 ProcColumn {
85 name: "snippet".into(),
86 data_type: LogicalType::String,
87 },
88 ],
89 },
90 ProcedureSignature {
91 name: "clear".into(),
92 params: vec![],
93 columns: vec![ProcColumn {
94 name: "status".into(),
95 data_type: LogicalType::String,
96 }],
97 },
98 ]
99 }
100
101 fn execute(
102 &self,
103 procedure: &str,
104 args: &[String],
105 _adjacency: &HashMap<i64, Vec<(i64, f64)>>,
106 ) -> Result<Vec<ProcRow>, String> {
107 let mut index = self.state.lock().map_err(|e| format!("lock error: {e}"))?;
108
109 match procedure {
110 "add" => {
111 let content = args.first().ok_or("fts.add requires a content argument")?;
112 let doc_id = index.add_document(content).map_err(|e| e.to_string())?;
113 Ok(vec![vec![TypedValue::Int64(doc_id as i64)]])
114 }
115 "search" => {
116 let query = args.first().ok_or("fts.search requires a query argument")?;
117 let limit = args
118 .get(1)
119 .and_then(|s| s.parse::<usize>().ok())
120 .unwrap_or(10);
121 let results = index.search(query, limit).map_err(|e| e.to_string())?;
122 Ok(results
123 .into_iter()
124 .map(|(doc_id, score, snippet)| {
125 vec![
126 TypedValue::Int64(doc_id as i64),
127 TypedValue::Double(score as f64),
128 TypedValue::String(SmolStr::new(snippet)),
129 ]
130 })
131 .collect())
132 }
133 "clear" => {
134 index.clear().map_err(|e| e.to_string())?;
135 Ok(vec![vec![TypedValue::String(SmolStr::new("ok"))]])
136 }
137 _ => Err(format!("unknown procedure: {procedure}")),
138 }
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Invokes `procedure` on `ext` with the given string arguments and an empty
    /// adjacency map, panicking if the call returns an error.
    fn call(ext: &FtsExtension, procedure: &str, args: &[&str]) -> Vec<ProcRow> {
        let owned: Vec<String> = args.iter().map(|a| a.to_string()).collect();
        ext.execute(procedure, &owned, &HashMap::new()).unwrap()
    }

    /// Collects the first (`doc_id`) column of every row, requiring it to be an `Int64`.
    fn doc_ids(rows: &[ProcRow]) -> Vec<i64> {
        rows.iter()
            .map(|row| match row[0] {
                TypedValue::Int64(id) => id,
                _ => panic!("expected Int64"),
            })
            .collect()
    }

    #[test]
    fn extension_metadata() {
        let ext = FtsExtension::new();
        assert_eq!(ext.name(), "fts");
        assert!(!ext.needs_graph());
        assert_eq!(ext.procedures().len(), 3);
    }

    #[test]
    fn execute_add_and_search() {
        let ext = FtsExtension::new();

        let added = call(&ext, "add", &["the quick brown fox"]);
        assert_eq!(added.len(), 1);
        assert_eq!(added[0][0], TypedValue::Int64(0));

        let hits = call(&ext, "search", &["fox", "10"]);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0][0], TypedValue::Int64(0));
        assert!(matches!(hits[0][1], TypedValue::Double(s) if s > 0.0));
    }

    #[test]
    fn execute_search_no_results() {
        let ext = FtsExtension::new();
        call(&ext, "add", &["hello world"]);

        let hits = call(&ext, "search", &["quantum", "10"]);
        assert!(hits.is_empty());
    }

    #[test]
    fn execute_clear() {
        let ext = FtsExtension::new();
        call(&ext, "add", &["testing clear"]);

        let status = call(&ext, "clear", &[]);
        assert_eq!(status[0][0], TypedValue::String(SmolStr::new("ok")));

        let hits = call(&ext, "search", &["testing", "10"]);
        assert!(hits.is_empty());
    }

    #[test]
    fn execute_unknown_procedure() {
        let ext = FtsExtension::new();
        let outcome = ext.execute("nonexistent", &[], &HashMap::new());
        assert!(outcome.is_err());
    }

    #[test]
    fn multiple_documents_ranked() {
        let ext = FtsExtension::new();
        for doc in [
            "python for data science",
            "rust systems programming language",
            "rust rust rust all about rust",
        ] {
            call(&ext, "add", &[doc]);
        }

        let hits = call(&ext, "search", &["rust", "10"]);
        assert!(hits.len() >= 2);
        // Document 0 never mentions "rust", so it must not appear among the hits.
        assert!(!doc_ids(&hits).contains(&0));
    }
}