use std::path::Path;
use std::str::FromStr;

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyDict;

use lindera::mode::Mode;
use lindera::segmenter::Segmenter;
use lindera::tokenizer::{Tokenizer, TokenizerBuilder};

use crate::dictionary::{PyDictionary, PyUserDictionary};
use crate::util::{pydict_to_value, value_to_pydict};

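/// Shorthand for a borrowed, GIL-bound reference to a Python `dict`.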
pub type PyDictRef<'a> = &'a Bound<'a, PyDict>;

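/// Python wrapper around lindera's `TokenizerBuilder`.
///
/// The configuration methods return the builder itself, so calls can be
/// chained from Python before `build()` produces a `Tokenizer`.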
#[pyclass(name = "TokenizerBuilder")]
pub struct PyTokenizerBuilder {
    pub inner: TokenizerBuilder,
}

#[pymethods]
impl PyTokenizerBuilder {
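    /// Creates a builder with lindera's default configuration.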
    #[new]
    #[pyo3(signature = ())]
    fn new() -> PyResult<Self> {
        let inner = TokenizerBuilder::new().map_err(|err| {
            PyValueError::new_err(format!("Failed to create TokenizerBuilder: {err}"))
        })?;

        Ok(Self { inner })
    }

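    /// Loads a builder configuration from the file at `file_path`.
    ///
    /// This returns a fresh builder and leaves `self` untouched; the
    /// receiver exists only so the method can be called on an instance.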
    #[pyo3(signature = (file_path))]
    #[allow(clippy::wrong_self_convention)]
    fn from_file(&self, file_path: &str) -> PyResult<Self> {
        let inner = TokenizerBuilder::from_file(Path::new(file_path)).map_err(|err| {
            PyValueError::new_err(format!("Failed to load config from file: {err}"))
        })?;

        Ok(Self { inner })
    }

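    /// Sets the segmentation mode, e.g. "normal" or "decompose".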
    #[pyo3(signature = (mode))]
    fn set_mode<'a>(mut slf: PyRefMut<'a, Self>, mode: &str) -> PyResult<PyRefMut<'a, Self>> {
        let m = Mode::from_str(mode)
            .map_err(|err| PyValueError::new_err(format!("Failed to create mode: {err}")))?;

        slf.inner.set_segmenter_mode(&m);

        Ok(slf)
    }

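    /// Sets the dictionary used for segmentation, identified by `path`.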
    #[pyo3(signature = (path))]
    fn set_dictionary<'a>(mut slf: PyRefMut<'a, Self>, path: &str) -> PyResult<PyRefMut<'a, Self>> {
        slf.inner.set_segmenter_dictionary(path);

        Ok(slf)
    }

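    /// Sets a user dictionary, identified by `uri`.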
    #[pyo3(signature = (uri))]
    fn set_user_dictionary<'a>(
        mut slf: PyRefMut<'a, Self>,
        uri: &str,
    ) -> PyResult<PyRefMut<'a, Self>> {
        slf.inner.set_segmenter_user_dictionary(uri);

        Ok(slf)
    }

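    /// Appends a character filter of the given `kind`.
    ///
    /// `args` is an optional dict of filter options; when omitted, an
    /// empty JSON object is passed to the filter.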
    #[pyo3(signature = (kind, args=None))]
    fn append_character_filter<'a>(
        mut slf: PyRefMut<'a, Self>,
        kind: &str,
        args: Option<&Bound<'_, PyDict>>,
    ) -> PyResult<PyRefMut<'a, Self>> {
        let filter_args = if let Some(dict) = args {
            pydict_to_value(dict)?
        } else {
            serde_json::Value::Object(serde_json::Map::new())
        };

        slf.inner.append_character_filter(kind, &filter_args);

        Ok(slf)
    }

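    /// Appends a token filter of the given `kind`, with optional `args`
    /// handled the same way as in `append_character_filter`.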
    #[pyo3(signature = (kind, args=None))]
    fn append_token_filter<'a>(
        mut slf: PyRefMut<'a, Self>,
        kind: &str,
        args: Option<&Bound<'_, PyDict>>,
    ) -> PyResult<PyRefMut<'a, Self>> {
        let filter_args = if let Some(dict) = args {
            pydict_to_value(dict)?
        } else {
            serde_json::Value::Object(serde_json::Map::new())
        };

        slf.inner.append_token_filter(kind, &filter_args);

        Ok(slf)
    }

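    /// Builds a `Tokenizer` from the accumulated configuration.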
    #[pyo3(signature = ())]
    fn build(&self) -> PyResult<PyTokenizer> {
        let tokenizer = self
            .inner
            .build()
            .map_err(|err| PyValueError::new_err(format!("Failed to build tokenizer: {err}")))?;

        Ok(PyTokenizer { inner: tokenizer })
    }
}

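/// Python wrapper around lindera's `Tokenizer`.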
#[pyclass(name = "Tokenizer")]
pub struct PyTokenizer {
    inner: Tokenizer,
}

#[pymethods]
impl PyTokenizer {
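    /// Creates a tokenizer from a dictionary, a segmentation mode
    /// (defaulting to "normal"), and an optional user dictionary.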
    #[new]
    #[pyo3(signature = (dictionary, mode="normal", user_dictionary=None))]
    fn new(
        dictionary: PyDictionary,
        mode: &str,
        user_dictionary: Option<PyUserDictionary>,
    ) -> PyResult<Self> {
        let m = Mode::from_str(mode)
            .map_err(|err| PyValueError::new_err(format!("Failed to create mode: {err}")))?;

        let dict = dictionary.inner;
        let user_dict = user_dictionary.map(|d| d.inner);

        let segmenter = Segmenter::new(m, dict, user_dict);
        let tokenizer = Tokenizer::new(segmenter);

        Ok(Self { inner: tokenizer })
    }

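    /// Tokenizes `text` and returns one Python dict per token.
    ///
    /// The tokens are iterated mutably because `as_value` takes the
    /// token by `&mut` reference.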
    #[pyo3(signature = (text))]
    fn tokenize(&self, py: Python<'_>, text: &str) -> PyResult<Vec<Py<PyAny>>> {
        let mut tokens = self
            .inner
            .tokenize(text)
            .map_err(|err| PyValueError::new_err(format!("Failed to tokenize text: {err}")))?;

        let py_tokens: Vec<Py<PyAny>> = tokens
            .iter_mut()
            .map(|t| {
                let v = t.as_value();
                value_to_pydict(py, &v).map_err(|err| {
                    PyValueError::new_err(format!("Failed to convert token to dict: {err}"))
                })
            })
            .collect::<Result<Vec<_>, _>>()?;

        Ok(py_tokens)
    }
}