1use std::collections::HashMap;
27
28use pyo3::prelude::*;
29
30use lindera::dictionary::{FieldDefinition, FieldType, Schema};
31
32#[pyclass(name = "FieldType", from_py_object)]
36#[derive(Debug, Clone)]
37pub enum PyFieldType {
38 Surface,
40 LeftContextId,
42 RightContextId,
44 Cost,
46 Custom,
48}
49
50#[pymethods]
51impl PyFieldType {
52 fn __str__(&self) -> &str {
53 match self {
54 PyFieldType::Surface => "surface",
55 PyFieldType::LeftContextId => "left_context_id",
56 PyFieldType::RightContextId => "right_context_id",
57 PyFieldType::Cost => "cost",
58 PyFieldType::Custom => "custom",
59 }
60 }
61
62 fn __repr__(&self) -> String {
63 format!("FieldType.{self:?}")
64 }
65}
66
67impl From<FieldType> for PyFieldType {
68 fn from(field_type: FieldType) -> Self {
69 match field_type {
70 FieldType::Surface => PyFieldType::Surface,
71 FieldType::LeftContextId => PyFieldType::LeftContextId,
72 FieldType::RightContextId => PyFieldType::RightContextId,
73 FieldType::Cost => PyFieldType::Cost,
74 FieldType::Custom => PyFieldType::Custom,
75 }
76 }
77}
78
79impl From<PyFieldType> for FieldType {
80 fn from(field_type: PyFieldType) -> Self {
81 match field_type {
82 PyFieldType::Surface => FieldType::Surface,
83 PyFieldType::LeftContextId => FieldType::LeftContextId,
84 PyFieldType::RightContextId => FieldType::RightContextId,
85 PyFieldType::Cost => FieldType::Cost,
86 PyFieldType::Custom => FieldType::Custom,
87 }
88 }
89}
90
91#[pyclass(name = "FieldDefinition", from_py_object)]
95#[derive(Debug, Clone)]
96pub struct PyFieldDefinition {
97 #[pyo3(get)]
98 pub index: usize,
99 #[pyo3(get)]
100 pub name: String,
101 #[pyo3(get)]
102 pub field_type: PyFieldType,
103 #[pyo3(get)]
104 pub description: Option<String>,
105}
106
107#[pymethods]
108impl PyFieldDefinition {
109 #[new]
110 pub fn new(
111 index: usize,
112 name: String,
113 field_type: PyFieldType,
114 description: Option<String>,
115 ) -> Self {
116 Self {
117 index,
118 name,
119 field_type,
120 description,
121 }
122 }
123
124 fn __str__(&self) -> String {
125 format!("FieldDefinition(index={}, name={})", self.index, self.name)
126 }
127
128 fn __repr__(&self) -> String {
129 format!(
130 "FieldDefinition(index={}, name='{}', field_type={:?}, description={:?})",
131 self.index, self.name, self.field_type, self.description
132 )
133 }
134}
135
136impl From<FieldDefinition> for PyFieldDefinition {
137 fn from(field_def: FieldDefinition) -> Self {
138 PyFieldDefinition {
139 index: field_def.index,
140 name: field_def.name,
141 field_type: field_def.field_type.into(),
142 description: field_def.description,
143 }
144 }
145}
146
147impl From<PyFieldDefinition> for FieldDefinition {
148 fn from(field_def: PyFieldDefinition) -> Self {
149 FieldDefinition {
150 index: field_def.index,
151 name: field_def.name,
152 field_type: field_def.field_type.into(),
153 description: field_def.description,
154 }
155 }
156}
157
158#[pyclass(name = "Schema", from_py_object)]
173#[derive(Debug, Clone)]
174pub struct PySchema {
175 #[pyo3(get)]
176 pub fields: Vec<String>,
177 field_index_map: Option<HashMap<String, usize>>,
178}
179
180#[pymethods]
181impl PySchema {
182 #[new]
183 pub fn new(fields: Vec<String>) -> Self {
184 let mut schema = Self {
185 fields,
186 field_index_map: None,
187 };
188 schema.build_index_map();
189 schema
190 }
191
192 #[staticmethod]
193 pub fn create_default() -> Self {
194 Self::new(vec![
195 "surface".to_string(),
196 "left_context_id".to_string(),
197 "right_context_id".to_string(),
198 "cost".to_string(),
199 "major_pos".to_string(),
200 "middle_pos".to_string(),
201 "small_pos".to_string(),
202 "fine_pos".to_string(),
203 "conjugation_type".to_string(),
204 "conjugation_form".to_string(),
205 "base_form".to_string(),
206 "reading".to_string(),
207 "pronunciation".to_string(),
208 ])
209 }
210
211 pub fn get_field_index(&self, field_name: &str) -> Option<usize> {
212 self.field_index_map
213 .as_ref()
214 .and_then(|map| map.get(field_name))
215 .copied()
216 }
217
218 pub fn field_count(&self) -> usize {
219 self.get_all_fields().len()
220 }
221
222 pub fn get_field_name(&self, index: usize) -> Option<&str> {
223 self.fields.get(index).map(|s| s.as_str())
224 }
225
226 pub fn get_custom_fields(&self) -> Vec<String> {
227 if self.fields.len() > 4 {
228 self.fields[4..].to_vec()
229 } else {
230 Vec::new()
231 }
232 }
233
234 pub fn get_all_fields(&self) -> Vec<String> {
235 self.fields.clone()
236 }
237
238 pub fn get_field_by_name(&self, name: &str) -> Option<PyFieldDefinition> {
239 self.get_field_index(name).map(|index| {
240 let field_type = if index < 4 {
241 match index {
242 0 => PyFieldType::Surface,
243 1 => PyFieldType::LeftContextId,
244 2 => PyFieldType::RightContextId,
245 3 => PyFieldType::Cost,
246 _ => unreachable!(),
247 }
248 } else {
249 PyFieldType::Custom
250 };
251
252 PyFieldDefinition {
253 index,
254 name: name.to_string(),
255 field_type,
256 description: None,
257 }
258 })
259 }
260
261 pub fn validate_record(&self, record: Vec<String>) -> PyResult<()> {
262 if record.len() < self.fields.len() {
263 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
264 "CSV row has {} fields but schema requires {} fields",
265 record.len(),
266 self.fields.len()
267 )));
268 }
269
270 for (index, field_name) in self.fields.iter().enumerate() {
272 if index < record.len() && record[index].trim().is_empty() {
273 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
274 "Field {field_name} is missing or empty"
275 )));
276 }
277 }
278
279 Ok(())
280 }
281
282 fn __str__(&self) -> String {
283 format!("Schema(fields={})", self.fields.len())
284 }
285
286 fn __repr__(&self) -> String {
287 format!("Schema(fields={:?})", self.fields)
288 }
289
290 fn __len__(&self) -> usize {
291 self.fields.len()
292 }
293}
294
295impl PySchema {
296 fn build_index_map(&mut self) {
297 let mut map = HashMap::new();
298 for (i, field) in self.fields.iter().enumerate() {
299 map.insert(field.clone(), i);
300 }
301 self.field_index_map = Some(map);
302 }
303}
304
305impl From<PySchema> for Schema {
306 fn from(schema: PySchema) -> Self {
307 Schema::new(schema.fields)
308 }
309}
310
311impl From<Schema> for PySchema {
312 fn from(schema: Schema) -> Self {
313 PySchema::new(schema.get_all_fields().to_vec())
314 }
315}
316
317pub fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
318 let py = parent_module.py();
319 let m = PyModule::new(py, "schema")?;
320 m.add_class::<PySchema>()?;
321 m.add_class::<PyFieldDefinition>()?;
322 m.add_class::<PyFieldType>()?;
323 parent_module.add_submodule(&m)?;
324 Ok(())
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use lindera::dictionary::{FieldDefinition, FieldType, Schema};

    /// Turns a slice of string literals into the owned list the schema
    /// constructors expect.
    fn names(items: &[&str]) -> Vec<String> {
        items.iter().map(|&item| item.to_owned()).collect()
    }

    // --- PyFieldType -> FieldType ---

    #[test]
    fn test_pyfieldtype_surface_to_fieldtype() {
        let converted = FieldType::from(PyFieldType::Surface);
        assert!(matches!(converted, FieldType::Surface));
    }

    #[test]
    fn test_pyfieldtype_left_context_id_to_fieldtype() {
        let converted = FieldType::from(PyFieldType::LeftContextId);
        assert!(matches!(converted, FieldType::LeftContextId));
    }

    #[test]
    fn test_pyfieldtype_right_context_id_to_fieldtype() {
        let converted = FieldType::from(PyFieldType::RightContextId);
        assert!(matches!(converted, FieldType::RightContextId));
    }

    #[test]
    fn test_pyfieldtype_cost_to_fieldtype() {
        let converted = FieldType::from(PyFieldType::Cost);
        assert!(matches!(converted, FieldType::Cost));
    }

    #[test]
    fn test_pyfieldtype_custom_to_fieldtype() {
        let converted = FieldType::from(PyFieldType::Custom);
        assert!(matches!(converted, FieldType::Custom));
    }

    // --- FieldType -> PyFieldType ---

    #[test]
    fn test_fieldtype_to_pyfieldtype_all_variants() {
        let surface: PyFieldType = FieldType::Surface.into();
        assert!(matches!(surface, PyFieldType::Surface));

        let left: PyFieldType = FieldType::LeftContextId.into();
        assert!(matches!(left, PyFieldType::LeftContextId));

        let right: PyFieldType = FieldType::RightContextId.into();
        assert!(matches!(right, PyFieldType::RightContextId));

        let cost: PyFieldType = FieldType::Cost.into();
        assert!(matches!(cost, PyFieldType::Cost));

        let custom: PyFieldType = FieldType::Custom.into();
        assert!(matches!(custom, PyFieldType::Custom));
    }

    // --- FieldDefinition conversions ---

    #[test]
    fn test_pyfielddefinition_to_fielddefinition() {
        let wrapper = PyFieldDefinition {
            index: 0,
            name: "surface".to_string(),
            field_type: PyFieldType::Surface,
            description: Some("Surface form".to_string()),
        };
        let native = FieldDefinition::from(wrapper);
        assert_eq!(native.index, 0);
        assert_eq!(native.name, "surface");
        assert!(matches!(native.field_type, FieldType::Surface));
        assert_eq!(native.description.as_deref(), Some("Surface form"));
    }

    #[test]
    fn test_fielddefinition_to_pyfielddefinition() {
        let native = FieldDefinition {
            index: 4,
            name: "pos".to_string(),
            field_type: FieldType::Custom,
            description: None,
        };
        let wrapper = PyFieldDefinition::from(native);
        assert_eq!(wrapper.index, 4);
        assert_eq!(wrapper.name, "pos");
        assert!(matches!(wrapper.field_type, PyFieldType::Custom));
        assert_eq!(wrapper.description, None);
    }

    // --- Schema conversions ---

    #[test]
    fn test_pyschema_to_schema() {
        let wrapper = PySchema::new(names(&[
            "surface",
            "left_context_id",
            "right_context_id",
            "cost",
            "pos",
        ]));
        let native = Schema::from(wrapper);
        let fields = native.get_all_fields();
        assert_eq!(fields.len(), 5);
        assert_eq!(fields[0], "surface");
        assert_eq!(fields[4], "pos");
    }

    #[test]
    fn test_schema_to_pyschema() {
        let native = Schema::new(names(&[
            "surface",
            "left_context_id",
            "right_context_id",
            "cost",
        ]));
        let wrapper = PySchema::from(native);
        assert_eq!(wrapper.fields.len(), 4);
        assert_eq!(wrapper.fields[0], "surface");
    }

    // --- Index lookup ---

    #[test]
    fn test_pyschema_new_builds_index_map() {
        let schema = PySchema::new(names(&["surface", "pos", "reading"]));
        assert_eq!(schema.get_field_index("surface"), Some(0));
        assert_eq!(schema.get_field_index("pos"), Some(1));
        assert_eq!(schema.get_field_index("reading"), Some(2));
    }

    #[test]
    fn test_pyschema_get_field_index_existing() {
        let schema = PySchema::new(names(&["surface", "cost"]));
        assert_eq!(schema.get_field_index("surface"), Some(0));
        assert_eq!(schema.get_field_index("cost"), Some(1));
    }

    #[test]
    fn test_pyschema_get_field_index_nonexistent() {
        let schema = PySchema::new(names(&["surface"]));
        assert_eq!(schema.get_field_index("nonexistent"), None);
    }

    #[test]
    fn test_pyschema_field_count() {
        let schema = PySchema::new(names(&["a", "b", "c"]));
        assert_eq!(schema.field_count(), 3);
    }

    // --- Custom fields ---

    #[test]
    fn test_pyschema_get_custom_fields() {
        let schema = PySchema::new(names(&[
            "surface",
            "left_context_id",
            "right_context_id",
            "cost",
            "major_pos",
            "reading",
        ]));
        assert_eq!(
            schema.get_custom_fields(),
            names(&["major_pos", "reading"])
        );
    }

    #[test]
    fn test_pyschema_get_custom_fields_no_custom() {
        let schema = PySchema::new(names(&[
            "surface",
            "left_context_id",
            "right_context_id",
            "cost",
        ]));
        assert!(schema.get_custom_fields().is_empty());
    }

    #[test]
    fn test_pyschema_get_custom_fields_fewer_than_four() {
        let schema = PySchema::new(names(&["surface", "cost"]));
        assert!(schema.get_custom_fields().is_empty());
    }

    // --- Default schema ---

    #[test]
    fn test_pyschema_create_default_has_13_fields() {
        let defaults = PySchema::create_default();
        assert_eq!(defaults.field_count(), 13);
        assert_eq!(defaults.fields[0], "surface");
        assert_eq!(defaults.fields[12], "pronunciation");
    }

    #[test]
    fn test_pyschema_create_default_index_map() {
        let defaults = PySchema::create_default();
        assert_eq!(defaults.get_field_index("surface"), Some(0));
        assert_eq!(defaults.get_field_index("cost"), Some(3));
        assert_eq!(defaults.get_field_index("pronunciation"), Some(12));
        assert_eq!(defaults.get_field_index("nonexistent"), None);
    }

    // --- Name and definition lookup ---

    #[test]
    fn test_pyschema_get_field_name() {
        let schema = PySchema::new(names(&["surface", "pos"]));
        assert_eq!(schema.get_field_name(0), Some("surface"));
        assert_eq!(schema.get_field_name(1), Some("pos"));
        assert_eq!(schema.get_field_name(2), None);
    }

    #[test]
    fn test_pyschema_get_field_by_name_system_field() {
        let defaults = PySchema::create_default();
        let found = defaults.get_field_by_name("surface").unwrap();
        assert_eq!(found.index, 0);
        assert_eq!(found.name, "surface");
        assert!(matches!(found.field_type, PyFieldType::Surface));
    }

    #[test]
    fn test_pyschema_get_field_by_name_custom_field() {
        let defaults = PySchema::create_default();
        let found = defaults.get_field_by_name("major_pos").unwrap();
        assert_eq!(found.index, 4);
        assert_eq!(found.name, "major_pos");
        assert!(matches!(found.field_type, PyFieldType::Custom));
    }

    #[test]
    fn test_pyschema_get_field_by_name_nonexistent() {
        let defaults = PySchema::create_default();
        assert!(defaults.get_field_by_name("nonexistent").is_none());
    }

    // --- Round trip through the lindera type ---

    #[test]
    fn test_pyschema_roundtrip() {
        let expected = names(&[
            "surface",
            "left_context_id",
            "right_context_id",
            "cost",
            "pos",
        ]);
        let native = Schema::from(PySchema::new(expected.clone()));
        let roundtripped = PySchema::from(native);
        assert_eq!(roundtripped.fields, expected);
    }
}