1use std::collections::HashMap;
27
28use pyo3::prelude::*;
29
30use lindera::dictionary::{FieldDefinition, FieldType, Schema};
31
/// Kind of a dictionary schema field, exposed to Python as `FieldType`.
///
/// The variants correspond one-to-one to `lindera::dictionary::FieldType`;
/// `From` conversions between the two enums exist in both directions.
#[pyclass(name = "FieldType")]
#[derive(Debug, Clone)]
pub enum PyFieldType {
    /// Rendered as `"surface"` by `__str__`; assigned to the field at schema
    /// position 0 by `PySchema::get_field_by_name`.
    Surface,
    /// Rendered as `"left_context_id"` by `__str__`; assigned to the field at
    /// schema position 1 by `PySchema::get_field_by_name`.
    LeftContextId,
    /// Rendered as `"right_context_id"` by `__str__`; assigned to the field at
    /// schema position 2 by `PySchema::get_field_by_name`.
    RightContextId,
    /// Rendered as `"cost"` by `__str__`; assigned to the field at schema
    /// position 3 by `PySchema::get_field_by_name`.
    Cost,
    /// Rendered as `"custom"` by `__str__`; assigned to every field at schema
    /// position 4 or later by `PySchema::get_field_by_name`.
    Custom,
}
49
50#[pymethods]
51impl PyFieldType {
52 fn __str__(&self) -> &str {
53 match self {
54 PyFieldType::Surface => "surface",
55 PyFieldType::LeftContextId => "left_context_id",
56 PyFieldType::RightContextId => "right_context_id",
57 PyFieldType::Cost => "cost",
58 PyFieldType::Custom => "custom",
59 }
60 }
61
62 fn __repr__(&self) -> String {
63 format!("FieldType.{self:?}")
64 }
65}
66
67impl From<FieldType> for PyFieldType {
68 fn from(field_type: FieldType) -> Self {
69 match field_type {
70 FieldType::Surface => PyFieldType::Surface,
71 FieldType::LeftContextId => PyFieldType::LeftContextId,
72 FieldType::RightContextId => PyFieldType::RightContextId,
73 FieldType::Cost => PyFieldType::Cost,
74 FieldType::Custom => PyFieldType::Custom,
75 }
76 }
77}
78
79impl From<PyFieldType> for FieldType {
80 fn from(field_type: PyFieldType) -> Self {
81 match field_type {
82 PyFieldType::Surface => FieldType::Surface,
83 PyFieldType::LeftContextId => FieldType::LeftContextId,
84 PyFieldType::RightContextId => FieldType::RightContextId,
85 PyFieldType::Cost => FieldType::Cost,
86 PyFieldType::Custom => FieldType::Custom,
87 }
88 }
89}
90
/// Description of a single schema field, exposed to Python as
/// `FieldDefinition`.
///
/// Carries the same four fields as `lindera::dictionary::FieldDefinition`;
/// `From` conversions between the two structs exist in both directions.
/// Each field is exposed to Python through a getter only (`#[pyo3(get)]`).
#[pyclass(name = "FieldDefinition")]
#[derive(Debug, Clone)]
pub struct PyFieldDefinition {
    /// Position of the field within its schema.
    #[pyo3(get)]
    pub index: usize,
    /// Name of the field.
    #[pyo3(get)]
    pub name: String,
    /// Kind of the field.
    #[pyo3(get)]
    pub field_type: PyFieldType,
    /// Optional free-form description; `None` when not provided.
    #[pyo3(get)]
    pub description: Option<String>,
}
106
107#[pymethods]
108impl PyFieldDefinition {
109 #[new]
110 pub fn new(
111 index: usize,
112 name: String,
113 field_type: PyFieldType,
114 description: Option<String>,
115 ) -> Self {
116 Self {
117 index,
118 name,
119 field_type,
120 description,
121 }
122 }
123
124 fn __str__(&self) -> String {
125 format!("FieldDefinition(index={}, name={})", self.index, self.name)
126 }
127
128 fn __repr__(&self) -> String {
129 format!(
130 "FieldDefinition(index={}, name='{}', field_type={:?}, description={:?})",
131 self.index, self.name, self.field_type, self.description
132 )
133 }
134}
135
136impl From<FieldDefinition> for PyFieldDefinition {
137 fn from(field_def: FieldDefinition) -> Self {
138 PyFieldDefinition {
139 index: field_def.index,
140 name: field_def.name,
141 field_type: field_def.field_type.into(),
142 description: field_def.description,
143 }
144 }
145}
146
147impl From<PyFieldDefinition> for FieldDefinition {
148 fn from(field_def: PyFieldDefinition) -> Self {
149 FieldDefinition {
150 index: field_def.index,
151 name: field_def.name,
152 field_type: field_def.field_type.into(),
153 description: field_def.description,
154 }
155 }
156}
157
/// Ordered list of dictionary field names, exposed to Python as `Schema`.
///
/// Converts to and from `lindera::dictionary::Schema` through the `From`
/// implementations in this file.
#[pyclass(name = "Schema")]
#[derive(Debug, Clone)]
pub struct PySchema {
    /// Field names in schema order; exposed to Python through a getter only.
    #[pyo3(get)]
    pub fields: Vec<String>,
    /// Lookup from field name to its position in `fields`. Set to `Some` by
    /// `build_index_map`, which `new` calls right after construction.
    field_index_map: Option<HashMap<String, usize>>,
}
179
180#[pymethods]
181impl PySchema {
182 #[new]
183 pub fn new(fields: Vec<String>) -> Self {
184 let mut schema = Self {
185 fields,
186 field_index_map: None,
187 };
188 schema.build_index_map();
189 schema
190 }
191
192 #[staticmethod]
193 pub fn create_default() -> Self {
194 Self::new(vec![
195 "surface".to_string(),
196 "left_context_id".to_string(),
197 "right_context_id".to_string(),
198 "cost".to_string(),
199 "major_pos".to_string(),
200 "middle_pos".to_string(),
201 "small_pos".to_string(),
202 "fine_pos".to_string(),
203 "conjugation_type".to_string(),
204 "conjugation_form".to_string(),
205 "base_form".to_string(),
206 "reading".to_string(),
207 "pronunciation".to_string(),
208 ])
209 }
210
211 pub fn get_field_index(&self, field_name: &str) -> Option<usize> {
212 self.field_index_map
213 .as_ref()
214 .and_then(|map| map.get(field_name))
215 .copied()
216 }
217
218 pub fn field_count(&self) -> usize {
219 self.get_all_fields().len()
220 }
221
222 pub fn get_field_name(&self, index: usize) -> Option<&str> {
223 self.fields.get(index).map(|s| s.as_str())
224 }
225
226 pub fn get_custom_fields(&self) -> Vec<String> {
227 if self.fields.len() > 4 {
228 self.fields[4..].to_vec()
229 } else {
230 Vec::new()
231 }
232 }
233
234 pub fn get_all_fields(&self) -> Vec<String> {
235 self.fields.clone()
236 }
237
238 pub fn get_field_by_name(&self, name: &str) -> Option<PyFieldDefinition> {
239 self.get_field_index(name).map(|index| {
240 let field_type = if index < 4 {
241 match index {
242 0 => PyFieldType::Surface,
243 1 => PyFieldType::LeftContextId,
244 2 => PyFieldType::RightContextId,
245 3 => PyFieldType::Cost,
246 _ => unreachable!(),
247 }
248 } else {
249 PyFieldType::Custom
250 };
251
252 PyFieldDefinition {
253 index,
254 name: name.to_string(),
255 field_type,
256 description: None,
257 }
258 })
259 }
260
261 pub fn validate_record(&self, record: Vec<String>) -> PyResult<()> {
262 if record.len() < self.fields.len() {
263 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
264 "CSV row has {} fields but schema requires {} fields",
265 record.len(),
266 self.fields.len()
267 )));
268 }
269
270 for (index, field_name) in self.fields.iter().enumerate() {
272 if index < record.len() && record[index].trim().is_empty() {
273 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
274 "Field {field_name} is missing or empty"
275 )));
276 }
277 }
278
279 Ok(())
280 }
281
282 fn __str__(&self) -> String {
283 format!("Schema(fields={})", self.fields.len())
284 }
285
286 fn __repr__(&self) -> String {
287 format!("Schema(fields={:?})", self.fields)
288 }
289
290 fn __len__(&self) -> usize {
291 self.fields.len()
292 }
293}
294
295impl PySchema {
296 fn build_index_map(&mut self) {
297 let mut map = HashMap::new();
298 for (i, field) in self.fields.iter().enumerate() {
299 map.insert(field.clone(), i);
300 }
301 self.field_index_map = Some(map);
302 }
303}
304
305impl From<PySchema> for Schema {
306 fn from(schema: PySchema) -> Self {
307 Schema::new(schema.fields)
308 }
309}
310
311impl From<Schema> for PySchema {
312 fn from(schema: Schema) -> Self {
313 PySchema::new(schema.get_all_fields().to_vec())
314 }
315}