// lindera_py/segmenter.rs

1use std::borrow::Cow;
2use std::str::FromStr;
3
4use pyo3::exceptions::PyValueError;
5use pyo3::prelude::*;
6use pyo3::types::PyDict;
7
8use lindera::mode::Mode;
9use lindera::segmenter::Segmenter;
10
11use crate::dictionary::{PyDictionary, PyUserDictionary};
12use crate::token::PyToken;
13use crate::util::pydict_to_value;
14
15#[pyclass(name = "Segmenter")]
16#[derive(Clone)]
17pub struct PySegmenter {
18    pub inner: Segmenter,
19}
20
21#[pymethods]
22impl PySegmenter {
23    #[new]
24    #[pyo3(signature = (mode, dictionary, user_dictionary=None))]
25    fn new(
26        mode: &str,
27        dictionary: PyDictionary,
28        user_dictionary: Option<PyUserDictionary>,
29    ) -> PyResult<Self> {
30        let m = Mode::from_str(mode)
31            .map_err(|err| PyValueError::new_err(format!("Failed to create mode: {err}")))?;
32        let d = dictionary.inner;
33        let u = user_dictionary.map(|d| d.inner);
34
35        let segmenter = Segmenter::new(m, d, u);
36
37        Ok(Self { inner: segmenter })
38    }
39
40    #[pyo3(signature = (config))]
41    #[allow(clippy::wrong_self_convention)]
42    fn from_config(&self, config: &Bound<'_, PyDict>) -> PyResult<Self> {
43        let config_value = pydict_to_value(config)?;
44        let segmenter = Segmenter::from_config(&config_value)
45            .map_err(|err| PyValueError::new_err(format!("Failed to create tokenizer: {err}")))?;
46
47        Ok(Self { inner: segmenter })
48    }
49
50    #[pyo3(signature = (text))]
51    fn segment(&self, text: &str) -> PyResult<Vec<PyToken>> {
52        let mut tokens = self
53            .inner
54            .segment(Cow::Borrowed(text))
55            .map_err(|err| PyValueError::new_err(format!("Failed to tokenize text: {err}")))?;
56
57        Ok(tokens
58            .iter_mut()
59            .map(|t| PyToken {
60                #[allow(clippy::suspicious_to_owned)]
61                text: t.text.to_owned().to_string(),
62                byte_start: t.byte_start,
63                byte_end: t.byte_end,
64                position: t.position,
65                position_length: t.position_length,
66                details: t.details().iter().map(|d| d.to_string()).collect(),
67            })
68            .collect())
69    }
70}