use std::path::Path;
use std::str::FromStr;

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyDict;

use lindera::mode::Mode;
use lindera::tokenizer::{Tokenizer, TokenizerBuilder};

use crate::segmenter::PySegmenter;
use crate::token::PyToken;
use crate::util::pydict_to_value;

pub type PyDictRef<'a> = &'a Bound<'a, PyDict>;
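
/// Python-facing builder for lindera tokenizers, exposed as `TokenizerBuilder`.
///
/// The configuration methods return the builder itself, so calls can be chained
/// before `build()` produces a `Tokenizer`.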
#[pyclass(name = "TokenizerBuilder")]
pub struct PyTokenizerBuilder {
    pub inner: TokenizerBuilder,
}

#[pymethods]
impl PyTokenizerBuilder {
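    /// Creates a builder with a fresh lindera `TokenizerBuilder`;
    /// raises `ValueError` if lindera fails to initialize it.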
    #[new]
    #[pyo3(signature = ())]
    fn new() -> PyResult<Self> {
        let inner = TokenizerBuilder::new().map_err(|err| {
            PyValueError::new_err(format!("Failed to create TokenizerBuilder: {err}"))
        })?;

        Ok(Self { inner })
    }
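
    /// Loads builder settings from the configuration file at `file_path`.
    /// Note that this returns a new builder rather than modifying `self`.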
    #[pyo3(signature = (file_path))]
    #[allow(clippy::wrong_self_convention)]
    fn from_file(&self, file_path: &str) -> PyResult<Self> {
        let inner = TokenizerBuilder::from_file(Path::new(file_path)).map_err(|err| {
            PyValueError::new_err(format!("Failed to load config from file: {err}"))
        })?;

        Ok(Self { inner })
    }
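
    /// Sets the segmenter mode; the string is parsed with `Mode::from_str`,
    /// and an unrecognized value raises `ValueError`.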
    #[pyo3(signature = (mode))]
    fn set_mode<'a>(mut slf: PyRefMut<'a, Self>, mode: &str) -> PyResult<PyRefMut<'a, Self>> {
        let m = Mode::from_str(mode)
            .map_err(|err| PyValueError::new_err(format!("Failed to create mode: {err}")))?;

        slf.inner.set_segmenter_mode(&m);

        Ok(slf)
    }
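
    /// Sets the dictionary used by the segmenter.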
    #[pyo3(signature = (path))]
    fn set_dictionary<'a>(mut slf: PyRefMut<'a, Self>, path: &str) -> PyResult<PyRefMut<'a, Self>> {
        slf.inner.set_segmenter_dictionary(path);

        Ok(slf)
    }
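
    /// Sets the segmenter's user dictionary, referenced by `uri`.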
    #[pyo3(signature = (uri))]
    fn set_user_dictionary<'a>(
        mut slf: PyRefMut<'a, Self>,
        uri: &str,
    ) -> PyResult<PyRefMut<'a, Self>> {
        slf.inner.set_segmenter_user_dictionary(uri);
        Ok(slf)
    }
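
    /// Appends a character filter of the given `kind`. `args` is an optional
    /// dict of filter settings; when omitted, an empty JSON object is used.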
    #[pyo3(signature = (kind, args=None))]
    fn append_character_filter<'a>(
        mut slf: PyRefMut<'a, Self>,
        kind: &str,
        args: Option<&Bound<'_, PyDict>>,
    ) -> PyResult<PyRefMut<'a, Self>> {
        let filter_args = if let Some(dict) = args {
            pydict_to_value(dict)?
        } else {
            serde_json::Value::Object(serde_json::Map::new())
        };

        slf.inner.append_character_filter(kind, &filter_args);

        Ok(slf)
    }
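
    /// Appends a token filter of the given `kind`. `args` is an optional
    /// dict of filter settings; when omitted, an empty JSON object is used.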
    #[pyo3(signature = (kind, args=None))]
    fn append_token_filter<'a>(
        mut slf: PyRefMut<'a, Self>,
        kind: &str,
        args: Option<&Bound<'_, PyDict>>,
    ) -> PyResult<PyRefMut<'a, Self>> {
        let filter_args = if let Some(dict) = args {
            pydict_to_value(dict)?
        } else {
            serde_json::Value::Object(serde_json::Map::new())
        };

        slf.inner.append_token_filter(kind, &filter_args);

        Ok(slf)
    }
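
    /// Builds a `Tokenizer` from the current configuration;
    /// raises `ValueError` if the build fails.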
    #[pyo3(signature = ())]
    fn build(&self) -> PyResult<PyTokenizer> {
        let tokenizer = self
            .inner
            .build()
            .map_err(|err| PyValueError::new_err(format!("Failed to build tokenizer: {err}")))?;

        Ok(PyTokenizer { inner: tokenizer })
    }
}
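
/// Python-facing tokenizer, exposed as `Tokenizer`, wrapping lindera's tokenizer.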
#[pyclass(name = "Tokenizer")]
pub struct PyTokenizer {
    inner: Tokenizer,
}

#[pymethods]
impl PyTokenizer {
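    /// Creates a tokenizer from an existing `Segmenter`.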
    #[new]
    #[pyo3(signature = (segmenter))]
    fn new(segmenter: PySegmenter) -> PyResult<Self> {
        Ok(Self {
            inner: Tokenizer::new(segmenter.inner),
        })
    }
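
    /// Creates a tokenizer from a configuration dict.
    /// Note that this returns a new tokenizer rather than modifying `self`.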
    #[pyo3(signature = (config))]
    #[allow(clippy::wrong_self_convention)]
    fn from_config(&self, config: &Bound<'_, PyDict>) -> PyResult<Self> {
        let config_value = pydict_to_value(config)?;
        let tokenizer = Tokenizer::from_config(&config_value)
            .map_err(|err| PyValueError::new_err(format!("Failed to create tokenizer: {err}")))?;

        Ok(Self { inner: tokenizer })
    }
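
    /// Tokenizes `text`, returning tokens with their surface text, byte offsets,
    /// positions, and dictionary details.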
    #[pyo3(signature = (text))]
    fn tokenize(&self, text: &str) -> PyResult<Vec<PyToken>> {
        let mut tokens = self
            .inner
            .tokenize(text)
            .map_err(|err| PyValueError::new_err(format!("Failed to tokenize text: {err}")))?;

        let py_tokens: Vec<PyToken> = tokens
            .iter_mut()
            .map(|t| PyToken {
                text: t.text.to_string(),
                byte_start: t.byte_start,
                byte_end: t.byte_end,
                position: t.position,
                position_length: t.position_length,
                details: t.details().iter().map(|d| d.to_string()).collect(),
            })
            .collect();

        Ok(py_tokens)
    }
}