use pyo3::prelude::*;
use super::random_segmenter::{BaseRandomSegmenter, RandomSegmenter};
/// Python-facing wrapper around the Rust `RandomSegmenter`.
///
/// Exposed to Python under the class name `RandomSegmenter`; Python code may
/// subclass it.
// `from_py_object` together with `derive(Clone)`: presumably lets PyO3 extract
// an owned clone of this class from Python arguments — confirm against the
// PyO3 version pinned by this crate.
#[pyclass(name = "RandomSegmenter", subclass, from_py_object)]
#[derive(Clone)]
pub struct PyRandomSegmenter {
/// The wrapped Rust segmenter; every method on this type delegates to it.
pub inner: RandomSegmenter,
}
// Implementing the shared trait on the Python wrapper lets the trait's provided
// methods (such as `predict`, called from the `#[pymethods]` block) run directly
// on `PyRandomSegmenter` by forwarding to the wrapped segmenter.
impl BaseRandomSegmenter for PyRandomSegmenter {
    /// Returns the `prob` value reported by the wrapped `RandomSegmenter`.
    fn prob(&self) -> f64 {
        let Self { inner } = self;
        inner.prob()
    }

    /// Wraps a segmenter built by `RandomSegmenter::from_prob(prob)`.
    fn from_prob(prob: f64) -> Self {
        let inner = RandomSegmenter::from_prob(prob);
        Self { inner }
    }
}
#[pymethods]
impl PyRandomSegmenter {
    /// Create a random segmenter from the keyword-only argument `prob`.
    ///
    /// Any error reported by the Rust `RandomSegmenter::new` constructor is
    /// converted into a Python exception.
    #[new]
    #[pyo3(signature = (*, prob))]
    fn new(prob: f64) -> PyResult<Self> {
        // `?` applies the same `From<_> for PyErr` conversion that
        // `map_err(PyErr::from)` would.
        let inner = RandomSegmenter::new(prob)?;
        Ok(Self { inner })
    }

    /// Segment every string in `sent_strs` and return the result as a Python object.
    ///
    /// When the keyword-only flag `offsets` is true (default false), the
    /// segmentation is passed through the crate's `attach_offsets` helper
    /// before being converted to Python.
    #[pyo3(signature = (sent_strs, *, offsets=false))]
    fn predict(
        &self,
        py: Python<'_>,
        sent_strs: Vec<String>,
        offsets: bool,
    ) -> PyResult<Py<PyAny>> {
        // The segmentation logic itself lives in the `BaseRandomSegmenter` trait.
        let segmented = BaseRandomSegmenter::predict(self, sent_strs);
        let converted = if offsets {
            // NOTE(review): `attach_offsets` presumably pairs each word with its
            // position in the sentence — confirm in the parent module.
            super::attach_offsets(segmented).into_pyobject(py)?.into_any()
        } else {
            segmented.into_pyobject(py)?.into_any()
        };
        Ok(converted.unbind())
    }
}