use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use tokio::runtime::Runtime;
use hermes_core::{
BooleanQuery, FieldValue, FsDirectory, Index, IndexConfig, Schema, TermQuery, search_segment,
};
/// Returns the Tokio runtime shared by every caller in this module.
///
/// The runtime is stored in a function-local static and built the first time
/// this function is called; later calls return the same instance.
///
/// # Panics
///
/// Panics with "Failed to create Tokio runtime" if the runtime cannot be built.
fn get_runtime() -> &'static Runtime {
    static SHARED_RUNTIME: std::sync::OnceLock<Runtime> = std::sync::OnceLock::new();
    // Only invoked by `get_or_init` when the cell is still empty.
    let build = || Runtime::new().expect("Failed to create Tokio runtime");
    SHARED_RUNTIME.get_or_init(build)
}
/// Handle to an opened index, exposed to Python as a class via `#[pyclass]`.
///
/// Instances are created with `HermesIndex::open`.
#[pyclass]
struct HermesIndex {
/// The opened index, stored on an `FsDirectory`.
index: Arc<Index<FsDirectory>>,
/// Schema cloned from the index when it was opened; used to resolve field
/// names. `reload` does not update this copy.
schema: Arc<Schema>,
}
#[pymethods]
impl HermesIndex {
    /// Opens the index stored in the directory at `path`.
    ///
    /// The index is opened with `IndexConfig::default()`, and its schema is
    /// cloned once so that later field-name lookups use that copy.
    ///
    /// Raises `RuntimeError` ("Failed to open index: ...") if opening fails.
    #[staticmethod]
    fn open(path: &str) -> PyResult<Self> {
        let index = get_runtime()
            .block_on(async {
                let dir = FsDirectory::new(PathBuf::from(path));
                Index::open(dir, IndexConfig::default()).await
            })
            .map_err(|e| PyRuntimeError::new_err(format!("Failed to open index: {}", e)))?;
        let schema = Arc::new(index.schema().clone());
        Ok(HermesIndex {
            index: Arc::new(index),
            schema,
        })
    }

    /// Returns the document count reported by the index.
    fn num_docs(&self) -> u32 {
        self.index.num_docs()
    }

    /// Returns how many segment readers the index currently exposes.
    fn num_segments(&self) -> usize {
        self.index.segment_readers().len()
    }

    /// Returns the names of all fields in the schema captured at open time.
    fn field_names(&self) -> Vec<String> {
        self.schema
            .fields()
            .map(|(_, entry)| entry.name.clone())
            .collect()
    }

    /// Fetches the document with id `doc_id` as a `{field name: value}` dict.
    ///
    /// Returns `None` when the index has no such document. Field values whose
    /// field has no entry in the schema are left out of the result.
    ///
    /// Raises `RuntimeError` ("Failed to get document: ...") if the lookup fails.
    fn get_document(&self, doc_id: u32) -> PyResult<Option<HashMap<String, Py<PyAny>>>> {
        // Only the index lookup needs the async runtime; the conversion to
        // Python objects happens afterwards, outside the future.
        let fetched = get_runtime()
            .block_on(async { self.index.doc(doc_id).await })
            .map_err(|e| PyRuntimeError::new_err(format!("Failed to get document: {}", e)))?;
        let Some(doc) = fetched else {
            return Ok(None);
        };
        Python::attach(|py| {
            let mut fields = HashMap::new();
            for (field, value) in doc.field_values() {
                if let Some(entry) = self.schema.get_field_entry(*field) {
                    fields.insert(entry.name.clone(), field_value_to_py(py, value));
                }
            }
            Ok(Some(fields))
        })
    }

    /// Searches `field` for `term` and returns up to `limit` `(doc_id, score)` pairs.
    ///
    /// `limit` defaults to 10. Every segment is queried for up to `limit` hits;
    /// each hit's id is the segment-local id plus the segment's
    /// `doc_id_offset()`. The merged hits are sorted by descending score and
    /// cut to `limit`.
    ///
    /// Raises `ValueError` if `field` is not in the schema and `RuntimeError`
    /// ("Search failed: ...") if a segment search fails.
    fn search_term(
        &self,
        field: &str,
        term: &str,
        limit: Option<usize>,
    ) -> PyResult<Vec<(u32, f32)>> {
        let field_id = self
            .schema
            .get_field(field)
            .ok_or_else(|| PyValueError::new_err(format!("Field '{}' not found", field)))?;
        let query = TermQuery::text(field_id, term);
        let limit = limit.unwrap_or(10);
        get_runtime().block_on(async {
            let mut hits = Vec::new();
            for segment in self.index.segment_readers() {
                let results = search_segment(&segment, &query, limit)
                    .await
                    .map_err(|e| PyRuntimeError::new_err(format!("Search failed: {}", e)))?;
                for result in results {
                    hits.push((result.doc_id + segment.doc_id_offset(), result.score));
                }
            }
            Ok(rank_hits(hits, limit))
        })
    }

    /// Runs a boolean query built from `(field, term)` pairs.
    ///
    /// Each pair in `must`, `should` and `must_not` becomes a text term query
    /// added to the matching clause of a `BooleanQuery`; omitted lists add
    /// nothing. `limit` defaults to 10. Results are gathered and ranked the
    /// same way as in `search_term`.
    ///
    /// Raises `ValueError` if any field name is not in the schema and
    /// `RuntimeError` ("Search failed: ...") if a segment search fails.
    fn search_boolean(
        &self,
        must: Option<Vec<(String, String)>>,
        should: Option<Vec<(String, String)>>,
        must_not: Option<Vec<(String, String)>>,
        limit: Option<usize>,
    ) -> PyResult<Vec<(u32, f32)>> {
        // One place for the name -> field lookup and its error message.
        let resolve = |name: &str| {
            self.schema
                .get_field(name)
                .ok_or_else(|| PyValueError::new_err(format!("Field '{}' not found", name)))
        };

        let mut query = BooleanQuery::new();
        for (field, term) in must.into_iter().flatten() {
            query = query.must(TermQuery::text(resolve(&field)?, &term));
        }
        for (field, term) in should.into_iter().flatten() {
            query = query.should(TermQuery::text(resolve(&field)?, &term));
        }
        for (field, term) in must_not.into_iter().flatten() {
            query = query.must_not(TermQuery::text(resolve(&field)?, &term));
        }

        let limit = limit.unwrap_or(10);
        get_runtime().block_on(async {
            let mut hits = Vec::new();
            for segment in self.index.segment_readers() {
                let results = search_segment(&segment, &query, limit)
                    .await
                    .map_err(|e| PyRuntimeError::new_err(format!("Search failed: {}", e)))?;
                for result in results {
                    hits.push((result.doc_id + segment.doc_id_offset(), result.score));
                }
            }
            Ok(rank_hits(hits, limit))
        })
    }

    /// Asks the index to reload itself.
    ///
    /// The schema copy held by this object is not refreshed.
    ///
    /// Raises `RuntimeError` ("Reload failed: ...") if the reload fails.
    fn reload(&self) -> PyResult<()> {
        get_runtime()
            .block_on(async { self.index.reload().await })
            .map_err(|e| PyRuntimeError::new_err(format!("Reload failed: {}", e)))
    }
}

/// Sorts `hits` by descending score and keeps at most `limit` of them.
///
/// Hits with a NaN score are placed after every hit with a real score. This
/// keeps the comparator a total order, which `sort_by` requires; treating NaN
/// as equal to everything does not satisfy that. Hits that compare equal keep
/// their incoming relative order because the sort is stable.
fn rank_hits(mut hits: Vec<(u32, f32)>, limit: usize) -> Vec<(u32, f32)> {
    hits.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
        // Neither is NaN, so `partial_cmp` always yields an ordering here.
        (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
        // `false < true`: a NaN entry sorts after a non-NaN one.
        (a_nan, b_nan) => a_nan.cmp(&b_nan),
    });
    hits.truncate(limit);
    hits
}
/// Converts a stored field value into an owned Python object.
///
/// The payload of whichever variant `value` holds is converted with
/// `into_pyobject`, widened to `PyAny`, and detached from the `py` token.
///
/// # Panics
///
/// Panics if the conversion into a Python object returns an error.
fn field_value_to_py(py: Python<'_>, value: &FieldValue) -> Py<PyAny> {
    let object = match value {
        FieldValue::Text(text) => text.into_pyobject(py).unwrap().into_any(),
        FieldValue::U64(number) => number.into_pyobject(py).unwrap().into_any(),
        FieldValue::I64(number) => number.into_pyobject(py).unwrap().into_any(),
        FieldValue::F64(number) => number.into_pyobject(py).unwrap().into_any(),
        FieldValue::Bytes(bytes) => bytes.into_pyobject(py).unwrap().into_any(),
    };
    object.unbind()
}
/// Module initialiser: registers the `HermesIndex` class on the module `m`.
#[pymodule]
fn hermes_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Registration is the only setup step, so its result is the module's result.
    m.add_class::<HermesIndex>()
}