use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use crate::dawg::Dawg;
use crate::trie::Trie;
/// Python-facing wrapper around the crate's [`Trie`].
///
/// Exposed to Python under the class name `Trie`; every method on this type
/// forwards to the wrapped `inner` value.
#[pyclass(name = "Trie")]
pub struct PyTrie {
// The wrapped Rust trie that all Python-visible methods delegate to.
inner: Trie,
}
#[pymethods]
impl PyTrie {
    /// Build a wrapper around a fresh `Trie::new()`.
    #[new]
    fn new() -> Self {
        Self { inner: Trie::new() }
    }

    /// Backs Python's `word in trie`; forwards to the inner `contains`.
    fn __contains__(&self, word: &str) -> bool {
        self.inner.contains(word)
    }

    /// Backs Python's `len(trie)`.
    ///
    /// Note that this reports the inner node count, not the word count; use
    /// `get_word_count` for the latter.
    fn __len__(&self) -> usize {
        self.inner.node_count()
    }

    /// Backs Python's `repr(trie)`, showing the word count and node count.
    fn __repr__(&self) -> String {
        let words = self.inner.word_count();
        let nodes = self.inner.node_count();
        format!("Trie(words={}, nodes={})", words, nodes)
    }

    /// Add `word` with the given `count` (defaults to 1 from Python).
    ///
    /// Any error returned by the inner `add` is re-raised as `ValueError`
    /// carrying the error's string form.
    #[pyo3(signature = (word, count=1))]
    fn add(&mut self, word: &str, count: usize) -> PyResult<()> {
        let outcome = self.inner.add(word, count);
        outcome.map_err(|err| PyValueError::new_err(err.to_string()))
    }

    /// Add every item of the Python iterable `source`, each with count 1.
    ///
    /// Items are extracted as `String` one at a time and inserted immediately,
    /// so words preceding a failing item have already been added when the
    /// error is raised. Inner `add` failures surface as `ValueError`;
    /// iteration and extraction failures propagate unchanged.
    fn add_all(&mut self, source: &Bound<'_, PyAny>) -> PyResult<()> {
        for entry in source.try_iter()? {
            let word = entry?.extract::<String>()?;
            if let Err(err) = self.inner.add(&word, 1) {
                return Err(PyValueError::new_err(err.to_string()));
            }
        }
        Ok(())
    }

    /// Forward `path` to the inner `add_from_file`.
    ///
    /// Any error it returns is re-raised as `ValueError`.
    fn add_from_file(&mut self, path: &str) -> PyResult<()> {
        let outcome = self.inner.add_from_file(path);
        outcome.map_err(|err| PyValueError::new_err(err.to_string()))
    }

    /// Forward to the inner `contains_prefix`.
    fn contains_prefix(&self, prefix: &str) -> bool {
        self.inner.contains_prefix(prefix)
    }

    /// Return the inner word count.
    fn get_word_count(&self) -> usize {
        self.inner.word_count()
    }

    /// Search using `pattern`.
    ///
    /// With `with_count=False` (the default) the result of the inner `search`
    /// is returned; with `with_count=True` the result of the inner
    /// `search_with_count` is returned instead. Either way the value is
    /// converted to a Python object. An error from the inner search is
    /// re-raised as `ValueError`.
    #[pyo3(signature = (pattern, with_count=false))]
    fn search(&self, py: Python<'_>, pattern: &str, with_count: bool) -> PyResult<Py<PyAny>> {
        if !with_count {
            let words = self
                .inner
                .search(pattern)
                .map_err(|err| PyValueError::new_err(err.to_string()))?;
            return Ok(words.into_pyobject(py)?.into_any().unbind());
        }

        let counted = self
            .inner
            .search_with_count(pattern)
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        Ok(counted.into_pyobject(py)?.into_any().unbind())
    }

    /// Search by `prefix`.
    ///
    /// Returns the converted result of the inner `search_with_prefix`, or of
    /// `search_with_prefix_count` when `with_count=True`.
    #[pyo3(signature = (prefix, with_count=false))]
    fn search_with_prefix(
        &self,
        py: Python<'_>,
        prefix: &str,
        with_count: bool,
    ) -> PyResult<Py<PyAny>> {
        if !with_count {
            let words = self.inner.search_with_prefix(prefix).into_pyobject(py)?;
            return Ok(words.into());
        }

        let counted = self
            .inner
            .search_with_prefix_count(prefix)
            .into_pyobject(py)?;
        Ok(counted.into())
    }

    /// Search for `word` allowing a distance of up to `dist` (default 0).
    ///
    /// Returns the converted result of the inner `search_within_distance`, or
    /// of `search_within_distance_count` when `with_count=True`. The distance
    /// metric itself is defined by the inner type.
    #[pyo3(signature = (word, dist=0, with_count=false))]
    fn search_within_distance(
        &self,
        py: Python<'_>,
        word: &str,
        dist: usize,
        with_count: bool,
    ) -> PyResult<Py<PyAny>> {
        if !with_count {
            let words = self
                .inner
                .search_within_distance(word, dist)
                .into_pyobject(py)?;
            return Ok(words.into());
        }

        let counted = self
            .inner
            .search_within_distance_count(word, dist)
            .into_pyobject(py)?;
        Ok(counted.into())
    }
}
/// Python-facing wrapper around the crate's [`Dawg`].
///
/// Exposed to Python under the class name `DAWG`; every method on this type
/// forwards to the wrapped `inner` value.
#[pyclass(name = "DAWG")]
pub struct PyDAWG {
// The wrapped Rust DAWG that all Python-visible methods delegate to.
inner: Dawg,
}
#[pymethods]
impl PyDAWG {
    /// Build a wrapper around a fresh `Dawg::new()`.
    #[new]
    fn new() -> Self {
        Self { inner: Dawg::new() }
    }

    /// Backs Python's `word in dawg`; forwards to the inner `contains`.
    fn __contains__(&self, word: &str) -> bool {
        self.inner.contains(word)
    }

    /// Backs Python's `len(dawg)`.
    ///
    /// Note that this reports the inner node count, not the word count; use
    /// `get_word_count` for the latter.
    fn __len__(&self) -> usize {
        self.inner.node_count()
    }

    /// Backs Python's `repr(dawg)`, showing the word count and node count.
    fn __repr__(&self) -> String {
        let words = self.inner.word_count();
        let nodes = self.inner.node_count();
        format!("DAWG(words={}, nodes={})", words, nodes)
    }

    /// Add `word` with the given `count` (defaults to 1 from Python).
    ///
    /// Any error returned by the inner `add` is re-raised as `ValueError`
    /// carrying the error's string form. This method does not call `reduce`.
    #[pyo3(signature = (word, count=1))]
    fn add(&mut self, word: &str, count: usize) -> PyResult<()> {
        let outcome = self.inner.add(word, count);
        outcome.map_err(|err| PyValueError::new_err(err.to_string()))
    }

    /// Add every item of the Python iterable `source`, each with count 1.
    ///
    /// All items are extracted as `String` up front, so an iteration or
    /// extraction failure is raised before anything is inserted. The words are
    /// then sorted, inserted in that order, and the inner `reduce` is called.
    /// If an insertion fails, a `ValueError` is raised immediately and
    /// `reduce` is not called.
    fn add_all(&mut self, source: &Bound<'_, PyAny>) -> PyResult<()> {
        let mut words: Vec<String> = Vec::new();
        for entry in source.try_iter()? {
            words.push(entry?.extract::<String>()?);
        }

        // Equal strings are indistinguishable, so an unstable sort yields the
        // same sequence as a stable one.
        // NOTE(review): presumably the inner DAWG expects ordered insertion —
        // confirm against `Dawg::add`.
        words.sort_unstable();

        for word in &words {
            if let Err(err) = self.inner.add(word, 1) {
                return Err(PyValueError::new_err(err.to_string()));
            }
        }

        self.inner.reduce();
        Ok(())
    }

    /// Forward `path` to the inner `add_from_file`.
    ///
    /// Any error it returns is re-raised as `ValueError`. This wrapper does not
    /// call `reduce` itself.
    fn add_from_file(&mut self, path: &str) -> PyResult<()> {
        let outcome = self.inner.add_from_file(path);
        outcome.map_err(|err| PyValueError::new_err(err.to_string()))
    }

    /// Forward to the inner `reduce`.
    fn reduce(&mut self) {
        self.inner.reduce();
    }

    /// Forward to the inner `contains_prefix`.
    fn contains_prefix(&self, prefix: &str) -> bool {
        self.inner.contains_prefix(prefix)
    }

    /// Return the inner word count.
    fn get_word_count(&self) -> usize {
        self.inner.word_count()
    }

    /// Search using `pattern`.
    ///
    /// With `with_count=False` (the default) the result of the inner `search`
    /// is returned; with `with_count=True` the result of the inner
    /// `search_with_count` is returned instead. Either way the value is
    /// converted to a Python object. An error from the inner search is
    /// re-raised as `ValueError`.
    #[pyo3(signature = (pattern, with_count=false))]
    fn search(&self, py: Python<'_>, pattern: &str, with_count: bool) -> PyResult<Py<PyAny>> {
        if !with_count {
            let words = self
                .inner
                .search(pattern)
                .map_err(|err| PyValueError::new_err(err.to_string()))?;
            return Ok(words.into_pyobject(py)?.into_any().unbind());
        }

        let counted = self
            .inner
            .search_with_count(pattern)
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        Ok(counted.into_pyobject(py)?.into_any().unbind())
    }

    /// Search by `prefix`.
    ///
    /// Returns the converted result of the inner `search_with_prefix`, or of
    /// `search_with_prefix_count` when `with_count=True`.
    #[pyo3(signature = (prefix, with_count=false))]
    fn search_with_prefix(
        &self,
        py: Python<'_>,
        prefix: &str,
        with_count: bool,
    ) -> PyResult<Py<PyAny>> {
        if !with_count {
            let words = self.inner.search_with_prefix(prefix).into_pyobject(py)?;
            return Ok(words.into());
        }

        let counted = self
            .inner
            .search_with_prefix_count(prefix)
            .into_pyobject(py)?;
        Ok(counted.into())
    }

    /// Search for `word` allowing a distance of up to `dist` (default 0).
    ///
    /// Returns the converted result of the inner `search_within_distance`, or
    /// of `search_within_distance_count` when `with_count=True`. The distance
    /// metric itself is defined by the inner type.
    #[pyo3(signature = (word, dist=0, with_count=false))]
    fn search_within_distance(
        &self,
        py: Python<'_>,
        word: &str,
        dist: usize,
        with_count: bool,
    ) -> PyResult<Py<PyAny>> {
        if !with_count {
            let words = self
                .inner
                .search_within_distance(word, dist)
                .into_pyobject(py)?;
            return Ok(words.into());
        }

        let counted = self
            .inner
            .search_within_distance_count(word, dist)
            .into_pyobject(py)?;
        Ok(counted.into())
    }
}
/// Python extension module `lexrs`.
///
/// Registers the `Trie` and `DAWG` classes on the module; the first
/// registration failure is returned and stops further registration.
#[pymodule]
fn lexrs(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyTrie>()
        .and_then(|()| m.add_class::<PyDAWG>())
}