hermes_python/
lib.rs

1//! Python bindings for Hermes search engine
2
3use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::Arc;
6
7use pyo3::exceptions::{PyRuntimeError, PyValueError};
8use pyo3::prelude::*;
9use tokio::runtime::Runtime;
10
11use hermes_core::{
12    BooleanQuery, FieldValue, FsDirectory, Index, IndexConfig, Schema, TermQuery, search_segment,
13};
14
15/// Create a tokio runtime for async operations
16fn get_runtime() -> &'static Runtime {
17    use std::sync::OnceLock;
18    static RUNTIME: OnceLock<Runtime> = OnceLock::new();
19    RUNTIME.get_or_init(|| Runtime::new().expect("Failed to create Tokio runtime"))
20}
21
22/// Python wrapper for Hermes Index (read-only)
23#[pyclass]
24struct HermesIndex {
25    index: Arc<Index<FsDirectory>>,
26    schema: Arc<Schema>,
27}
28
29#[pymethods]
30impl HermesIndex {
31    /// Open an existing index
32    #[staticmethod]
33    fn open(path: &str) -> PyResult<Self> {
34        let rt = get_runtime();
35
36        rt.block_on(async {
37            let dir = FsDirectory::new(PathBuf::from(path));
38
39            let config = IndexConfig::default();
40            let index = Index::open(dir, config)
41                .await
42                .map_err(|e| PyRuntimeError::new_err(format!("Failed to open index: {}", e)))?;
43
44            let schema = Arc::new(index.schema().clone());
45
46            Ok(HermesIndex {
47                index: Arc::new(index),
48                schema,
49            })
50        })
51    }
52
53    /// Get the number of documents in the index
54    fn num_docs(&self) -> u32 {
55        self.index.num_docs()
56    }
57
58    /// Get the number of segments
59    fn num_segments(&self) -> usize {
60        self.index.segment_readers().len()
61    }
62
63    /// Get field names
64    fn field_names(&self) -> Vec<String> {
65        self.schema
66            .fields()
67            .map(|(_, entry)| entry.name.clone())
68            .collect()
69    }
70
71    /// Get a document by ID
72    fn get_document(&self, doc_id: u32) -> PyResult<Option<HashMap<String, Py<PyAny>>>> {
73        let rt = get_runtime();
74
75        rt.block_on(async {
76            let doc =
77                self.index.doc(doc_id).await.map_err(|e| {
78                    PyRuntimeError::new_err(format!("Failed to get document: {}", e))
79                })?;
80
81            match doc {
82                Some(doc) => Python::attach(|py| {
83                    let mut result = HashMap::new();
84                    for (field, value) in doc.field_values() {
85                        if let Some(entry) = self.schema.get_field_entry(*field) {
86                            let py_value = field_value_to_py(py, value);
87                            result.insert(entry.name.clone(), py_value);
88                        }
89                    }
90                    Ok(Some(result))
91                }),
92                None => Ok(None),
93            }
94        })
95    }
96
97    /// Search the index with a term query
98    fn search_term(
99        &self,
100        field: &str,
101        term: &str,
102        limit: Option<usize>,
103    ) -> PyResult<Vec<(u32, f32)>> {
104        let field_id = self
105            .schema
106            .get_field(field)
107            .ok_or_else(|| PyValueError::new_err(format!("Field '{}' not found", field)))?;
108
109        let query = TermQuery::text(field_id, term);
110        let limit = limit.unwrap_or(10);
111
112        let rt = get_runtime();
113
114        rt.block_on(async {
115            let mut all_results = Vec::new();
116
117            for segment in self.index.segment_readers() {
118                let results = search_segment(&segment, &query, limit)
119                    .await
120                    .map_err(|e| PyRuntimeError::new_err(format!("Search failed: {}", e)))?;
121
122                for result in results {
123                    all_results.push((result.doc_id + segment.doc_id_offset(), result.score));
124                }
125            }
126
127            // Sort by score descending
128            all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
129            all_results.truncate(limit);
130
131            Ok(all_results)
132        })
133    }
134
135    /// Search with a boolean query
136    fn search_boolean(
137        &self,
138        must: Option<Vec<(String, String)>>,
139        should: Option<Vec<(String, String)>>,
140        must_not: Option<Vec<(String, String)>>,
141        limit: Option<usize>,
142    ) -> PyResult<Vec<(u32, f32)>> {
143        let mut query = BooleanQuery::new();
144
145        if let Some(must_terms) = must {
146            for (field, term) in must_terms {
147                let field_id = self
148                    .schema
149                    .get_field(&field)
150                    .ok_or_else(|| PyValueError::new_err(format!("Field '{}' not found", field)))?;
151                query = query.must(TermQuery::text(field_id, &term));
152            }
153        }
154
155        if let Some(should_terms) = should {
156            for (field, term) in should_terms {
157                let field_id = self
158                    .schema
159                    .get_field(&field)
160                    .ok_or_else(|| PyValueError::new_err(format!("Field '{}' not found", field)))?;
161                query = query.should(TermQuery::text(field_id, &term));
162            }
163        }
164
165        if let Some(must_not_terms) = must_not {
166            for (field, term) in must_not_terms {
167                let field_id = self
168                    .schema
169                    .get_field(&field)
170                    .ok_or_else(|| PyValueError::new_err(format!("Field '{}' not found", field)))?;
171                query = query.must_not(TermQuery::text(field_id, &term));
172            }
173        }
174
175        let limit = limit.unwrap_or(10);
176        let rt = get_runtime();
177
178        rt.block_on(async {
179            let mut all_results = Vec::new();
180
181            for segment in self.index.segment_readers() {
182                let results = search_segment(&segment, &query, limit)
183                    .await
184                    .map_err(|e| PyRuntimeError::new_err(format!("Search failed: {}", e)))?;
185
186                for result in results {
187                    all_results.push((result.doc_id + segment.doc_id_offset(), result.score));
188                }
189            }
190
191            all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
192            all_results.truncate(limit);
193
194            Ok(all_results)
195        })
196    }
197
198    /// Reload the index to see new segments
199    fn reload(&self) -> PyResult<()> {
200        let rt = get_runtime();
201        rt.block_on(async {
202            self.index
203                .reload()
204                .await
205                .map_err(|e| PyRuntimeError::new_err(format!("Reload failed: {}", e)))
206        })
207    }
208}
209
210fn field_value_to_py(py: Python<'_>, value: &FieldValue) -> Py<PyAny> {
211    match value {
212        FieldValue::Text(s) => s.into_pyobject(py).unwrap().into_any().unbind(),
213        FieldValue::U64(n) => n.into_pyobject(py).unwrap().into_any().unbind(),
214        FieldValue::I64(n) => n.into_pyobject(py).unwrap().into_any().unbind(),
215        FieldValue::F64(n) => n.into_pyobject(py).unwrap().into_any().unbind(),
216        FieldValue::Bytes(b) => b.into_pyobject(py).unwrap().into_any().unbind(),
217    }
218}
219
220/// Python module
221#[pymodule]
222fn hermes_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
223    m.add_class::<HermesIndex>()?;
224    Ok(())
225}