1use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::Arc;
6
7use pyo3::exceptions::{PyRuntimeError, PyValueError};
8use pyo3::prelude::*;
9use tokio::runtime::Runtime;
10
11use hermes_core::{
12 BooleanQuery, FieldValue, FsDirectory, Index, IndexConfig, Schema, TermQuery, search_segment,
13};
14
15fn get_runtime() -> &'static Runtime {
17 use std::sync::OnceLock;
18 static RUNTIME: OnceLock<Runtime> = OnceLock::new();
19 RUNTIME.get_or_init(|| Runtime::new().expect("Failed to create Tokio runtime"))
20}
21
22#[pyclass]
24struct HermesIndex {
25 index: Arc<Index<FsDirectory>>,
26 schema: Arc<Schema>,
27}
28
29#[pymethods]
30impl HermesIndex {
31 #[staticmethod]
33 fn open(path: &str) -> PyResult<Self> {
34 let rt = get_runtime();
35
36 rt.block_on(async {
37 let dir = FsDirectory::new(PathBuf::from(path));
38
39 let config = IndexConfig::default();
40 let index = Index::open(dir, config)
41 .await
42 .map_err(|e| PyRuntimeError::new_err(format!("Failed to open index: {}", e)))?;
43
44 let schema = Arc::new(index.schema().clone());
45
46 Ok(HermesIndex {
47 index: Arc::new(index),
48 schema,
49 })
50 })
51 }
52
53 fn num_docs(&self) -> u32 {
55 self.index.num_docs()
56 }
57
58 fn num_segments(&self) -> usize {
60 self.index.segment_readers().len()
61 }
62
63 fn field_names(&self) -> Vec<String> {
65 self.schema
66 .fields()
67 .map(|(_, entry)| entry.name.clone())
68 .collect()
69 }
70
71 fn get_document(&self, doc_id: u32) -> PyResult<Option<HashMap<String, Py<PyAny>>>> {
73 let rt = get_runtime();
74
75 rt.block_on(async {
76 let doc =
77 self.index.doc(doc_id).await.map_err(|e| {
78 PyRuntimeError::new_err(format!("Failed to get document: {}", e))
79 })?;
80
81 match doc {
82 Some(doc) => Python::attach(|py| {
83 let mut result = HashMap::new();
84 for (field, value) in doc.field_values() {
85 if let Some(entry) = self.schema.get_field_entry(*field) {
86 let py_value = field_value_to_py(py, value);
87 result.insert(entry.name.clone(), py_value);
88 }
89 }
90 Ok(Some(result))
91 }),
92 None => Ok(None),
93 }
94 })
95 }
96
97 fn search_term(
99 &self,
100 field: &str,
101 term: &str,
102 limit: Option<usize>,
103 ) -> PyResult<Vec<(u32, f32)>> {
104 let field_id = self
105 .schema
106 .get_field(field)
107 .ok_or_else(|| PyValueError::new_err(format!("Field '{}' not found", field)))?;
108
109 let query = TermQuery::text(field_id, term);
110 let limit = limit.unwrap_or(10);
111
112 let rt = get_runtime();
113
114 rt.block_on(async {
115 let mut all_results = Vec::new();
116
117 for segment in self.index.segment_readers() {
118 let results = search_segment(&segment, &query, limit)
119 .await
120 .map_err(|e| PyRuntimeError::new_err(format!("Search failed: {}", e)))?;
121
122 for result in results {
123 all_results.push((result.doc_id + segment.doc_id_offset(), result.score));
124 }
125 }
126
127 all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
129 all_results.truncate(limit);
130
131 Ok(all_results)
132 })
133 }
134
135 fn search_boolean(
137 &self,
138 must: Option<Vec<(String, String)>>,
139 should: Option<Vec<(String, String)>>,
140 must_not: Option<Vec<(String, String)>>,
141 limit: Option<usize>,
142 ) -> PyResult<Vec<(u32, f32)>> {
143 let mut query = BooleanQuery::new();
144
145 if let Some(must_terms) = must {
146 for (field, term) in must_terms {
147 let field_id = self
148 .schema
149 .get_field(&field)
150 .ok_or_else(|| PyValueError::new_err(format!("Field '{}' not found", field)))?;
151 query = query.must(TermQuery::text(field_id, &term));
152 }
153 }
154
155 if let Some(should_terms) = should {
156 for (field, term) in should_terms {
157 let field_id = self
158 .schema
159 .get_field(&field)
160 .ok_or_else(|| PyValueError::new_err(format!("Field '{}' not found", field)))?;
161 query = query.should(TermQuery::text(field_id, &term));
162 }
163 }
164
165 if let Some(must_not_terms) = must_not {
166 for (field, term) in must_not_terms {
167 let field_id = self
168 .schema
169 .get_field(&field)
170 .ok_or_else(|| PyValueError::new_err(format!("Field '{}' not found", field)))?;
171 query = query.must_not(TermQuery::text(field_id, &term));
172 }
173 }
174
175 let limit = limit.unwrap_or(10);
176 let rt = get_runtime();
177
178 rt.block_on(async {
179 let mut all_results = Vec::new();
180
181 for segment in self.index.segment_readers() {
182 let results = search_segment(&segment, &query, limit)
183 .await
184 .map_err(|e| PyRuntimeError::new_err(format!("Search failed: {}", e)))?;
185
186 for result in results {
187 all_results.push((result.doc_id + segment.doc_id_offset(), result.score));
188 }
189 }
190
191 all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
192 all_results.truncate(limit);
193
194 Ok(all_results)
195 })
196 }
197
198 fn reload(&self) -> PyResult<()> {
200 let rt = get_runtime();
201 rt.block_on(async {
202 self.index
203 .reload()
204 .await
205 .map_err(|e| PyRuntimeError::new_err(format!("Reload failed: {}", e)))
206 })
207 }
208}
209
210fn field_value_to_py(py: Python<'_>, value: &FieldValue) -> Py<PyAny> {
211 match value {
212 FieldValue::Text(s) => s.into_pyobject(py).unwrap().into_any().unbind(),
213 FieldValue::U64(n) => n.into_pyobject(py).unwrap().into_any().unbind(),
214 FieldValue::I64(n) => n.into_pyobject(py).unwrap().into_any().unbind(),
215 FieldValue::F64(n) => n.into_pyobject(py).unwrap().into_any().unbind(),
216 FieldValue::Bytes(b) => b.into_pyobject(py).unwrap().into_any().unbind(),
217 }
218}
219
220#[pymodule]
222fn hermes_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
223 m.add_class::<HermesIndex>()?;
224 Ok(())
225}