use pyo3::prelude::*;
use std::path::PathBuf;
use super::{BaseLongestStringMatching, LongestStringMatching};
use crate::persistence::pathbuf_to_string;
use crate::trie::Trie;
/// Python-facing wrapper around [`LongestStringMatching`].
///
/// Exposed to Python under the class name `LongestStringMatching`; the `subclass`
/// option allows Python code to derive from it.
/// NOTE(review): `from_py_object` presumably opts into extracting this class from
/// Python arguments by value (cloning, hence the `Clone` derive) — confirm against
/// the PyO3 version in use.
#[pyclass(name = "LongestStringMatching", subclass, from_py_object)]
#[derive(Clone)]
pub struct PyLongestStringMatching {
    /// The wrapped Rust matcher that owns the actual state (its max word length
    /// and character trie).
    pub inner: LongestStringMatching,
}
/// Lets the shared [`BaseLongestStringMatching`] logic (e.g. the `fit` and `predict`
/// invoked by the Python methods) run directly on the wrapper, by forwarding these
/// accessors and the constructor to the wrapped [`LongestStringMatching`].
impl BaseLongestStringMatching for PyLongestStringMatching {
    /// Builds a wrapper around a matcher assembled from an existing max word length
    /// and trie.
    fn from_parts(max_word_length: usize, trie: Trie<char, ()>) -> Self {
        let inner = LongestStringMatching::from_parts(max_word_length, trie);
        Self { inner }
    }

    /// Maximum word length reported by the wrapped matcher.
    fn max_word_length(&self) -> usize {
        let inner = &self.inner;
        inner.max_word_length()
    }

    /// Shared access to the wrapped matcher's `char`-keyed trie.
    fn trie(&self) -> &Trie<char, ()> {
        let inner = &self.inner;
        inner.trie()
    }

    /// Exclusive access to the wrapped matcher's `char`-keyed trie.
    fn trie_mut(&mut self) -> &mut Trie<char, ()> {
        let inner = &mut self.inner;
        inner.trie_mut()
    }
}
#[pymethods]
impl PyLongestStringMatching {
#[new]
#[pyo3(signature = (*, max_word_length))]
fn new(max_word_length: usize) -> PyResult<Self> {
LongestStringMatching::new(max_word_length)
.map(|inner| Self { inner })
.map_err(PyErr::from)
}
fn fit(&mut self, sents: Vec<Vec<String>>) {
BaseLongestStringMatching::fit(self, sents);
}
#[pyo3(signature = (sent_strs, *, offsets=false))]
fn predict(
&self,
py: Python<'_>,
sent_strs: Vec<String>,
offsets: bool,
) -> PyResult<Py<PyAny>> {
let words = BaseLongestStringMatching::predict(self, sent_strs);
if offsets {
let with_offsets = super::super::attach_offsets(words);
Ok(with_offsets.into_pyobject(py)?.into_any().unbind())
} else {
Ok(words.into_pyobject(py)?.into_any().unbind())
}
}
fn save(&self, path: PathBuf) -> PyResult<()> {
let path = pathbuf_to_string(path)?;
self.save_to_path(&path).map_err(PyErr::from)
}
fn load(&mut self, path: PathBuf) -> PyResult<()> {
let path = pathbuf_to_string(path)?;
self.load_from_path(&path).map_err(PyErr::from)
}
}