Skip to main content

lindera/
schema.rs

//! Dictionary schema definitions.
//!
//! This module provides schema structures that define the format and fields
//! of dictionary entries.
//!
//! # Examples
//!
//! ```python
//! # Create a custom schema
//! schema = lindera.Schema([
//!     "surface",
//!     "left_context_id",
//!     "right_context_id",
//!     "cost",
//!     "part_of_speech"
//! ])
//!
//! # Use default schema
//! schema = lindera.Schema.create_default()
//!
//! # Access field information
//! index = schema.get_field_index("surface")
//! field = schema.get_field_by_name("part_of_speech")
//! ```
25
26use std::collections::HashMap;
27
28use pyo3::prelude::*;
29
30use lindera::dictionary::{FieldDefinition, FieldType, Schema};
31
32/// Field type in dictionary schema.
33///
34/// Defines the type of a field in the dictionary entry.
35#[pyclass(name = "FieldType", from_py_object)]
36#[derive(Debug, Clone)]
37pub enum PyFieldType {
38    /// Surface form (word text)
39    Surface,
40    /// Left context ID for morphological analysis
41    LeftContextId,
42    /// Right context ID for morphological analysis
43    RightContextId,
44    /// Word cost (used in path selection)
45    Cost,
46    /// Custom field (morphological features)
47    Custom,
48}
49
50#[pymethods]
51impl PyFieldType {
52    fn __str__(&self) -> &str {
53        match self {
54            PyFieldType::Surface => "surface",
55            PyFieldType::LeftContextId => "left_context_id",
56            PyFieldType::RightContextId => "right_context_id",
57            PyFieldType::Cost => "cost",
58            PyFieldType::Custom => "custom",
59        }
60    }
61
62    fn __repr__(&self) -> String {
63        format!("FieldType.{self:?}")
64    }
65}
66
67impl From<FieldType> for PyFieldType {
68    fn from(field_type: FieldType) -> Self {
69        match field_type {
70            FieldType::Surface => PyFieldType::Surface,
71            FieldType::LeftContextId => PyFieldType::LeftContextId,
72            FieldType::RightContextId => PyFieldType::RightContextId,
73            FieldType::Cost => PyFieldType::Cost,
74            FieldType::Custom => PyFieldType::Custom,
75        }
76    }
77}
78
79impl From<PyFieldType> for FieldType {
80    fn from(field_type: PyFieldType) -> Self {
81        match field_type {
82            PyFieldType::Surface => FieldType::Surface,
83            PyFieldType::LeftContextId => FieldType::LeftContextId,
84            PyFieldType::RightContextId => FieldType::RightContextId,
85            PyFieldType::Cost => FieldType::Cost,
86            PyFieldType::Custom => FieldType::Custom,
87        }
88    }
89}
90
91/// Field definition in dictionary schema.
92///
93/// Describes a single field in the dictionary entry format.
94#[pyclass(name = "FieldDefinition", from_py_object)]
95#[derive(Debug, Clone)]
96pub struct PyFieldDefinition {
97    #[pyo3(get)]
98    pub index: usize,
99    #[pyo3(get)]
100    pub name: String,
101    #[pyo3(get)]
102    pub field_type: PyFieldType,
103    #[pyo3(get)]
104    pub description: Option<String>,
105}
106
107#[pymethods]
108impl PyFieldDefinition {
109    #[new]
110    pub fn new(
111        index: usize,
112        name: String,
113        field_type: PyFieldType,
114        description: Option<String>,
115    ) -> Self {
116        Self {
117            index,
118            name,
119            field_type,
120            description,
121        }
122    }
123
124    fn __str__(&self) -> String {
125        format!("FieldDefinition(index={}, name={})", self.index, self.name)
126    }
127
128    fn __repr__(&self) -> String {
129        format!(
130            "FieldDefinition(index={}, name='{}', field_type={:?}, description={:?})",
131            self.index, self.name, self.field_type, self.description
132        )
133    }
134}
135
136impl From<FieldDefinition> for PyFieldDefinition {
137    fn from(field_def: FieldDefinition) -> Self {
138        PyFieldDefinition {
139            index: field_def.index,
140            name: field_def.name,
141            field_type: field_def.field_type.into(),
142            description: field_def.description,
143        }
144    }
145}
146
147impl From<PyFieldDefinition> for FieldDefinition {
148    fn from(field_def: PyFieldDefinition) -> Self {
149        FieldDefinition {
150            index: field_def.index,
151            name: field_def.name,
152            field_type: field_def.field_type.into(),
153            description: field_def.description,
154        }
155    }
156}
157
158/// Dictionary schema definition.
159///
160/// Defines the structure and fields of dictionary entries.
161///
162/// # Examples
163///
164/// ```python
165/// # Create schema
166/// schema = lindera.Schema(["surface", "pos", "reading"])
167///
168/// # Query field information
169/// index = schema.get_field_index("pos")
170/// field = schema.get_field_by_name("reading")
171/// ```
172#[pyclass(name = "Schema", from_py_object)]
173#[derive(Debug, Clone)]
174pub struct PySchema {
175    #[pyo3(get)]
176    pub fields: Vec<String>,
177    field_index_map: Option<HashMap<String, usize>>,
178}
179
180#[pymethods]
181impl PySchema {
182    #[new]
183    pub fn new(fields: Vec<String>) -> Self {
184        let mut schema = Self {
185            fields,
186            field_index_map: None,
187        };
188        schema.build_index_map();
189        schema
190    }
191
192    #[staticmethod]
193    pub fn create_default() -> Self {
194        Self::new(vec![
195            "surface".to_string(),
196            "left_context_id".to_string(),
197            "right_context_id".to_string(),
198            "cost".to_string(),
199            "major_pos".to_string(),
200            "middle_pos".to_string(),
201            "small_pos".to_string(),
202            "fine_pos".to_string(),
203            "conjugation_type".to_string(),
204            "conjugation_form".to_string(),
205            "base_form".to_string(),
206            "reading".to_string(),
207            "pronunciation".to_string(),
208        ])
209    }
210
211    pub fn get_field_index(&self, field_name: &str) -> Option<usize> {
212        self.field_index_map
213            .as_ref()
214            .and_then(|map| map.get(field_name))
215            .copied()
216    }
217
218    pub fn field_count(&self) -> usize {
219        self.get_all_fields().len()
220    }
221
222    pub fn get_field_name(&self, index: usize) -> Option<&str> {
223        self.fields.get(index).map(|s| s.as_str())
224    }
225
226    pub fn get_custom_fields(&self) -> Vec<String> {
227        if self.fields.len() > 4 {
228            self.fields[4..].to_vec()
229        } else {
230            Vec::new()
231        }
232    }
233
234    pub fn get_all_fields(&self) -> Vec<String> {
235        self.fields.clone()
236    }
237
238    pub fn get_field_by_name(&self, name: &str) -> Option<PyFieldDefinition> {
239        self.get_field_index(name).map(|index| {
240            let field_type = if index < 4 {
241                match index {
242                    0 => PyFieldType::Surface,
243                    1 => PyFieldType::LeftContextId,
244                    2 => PyFieldType::RightContextId,
245                    3 => PyFieldType::Cost,
246                    _ => unreachable!(),
247                }
248            } else {
249                PyFieldType::Custom
250            };
251
252            PyFieldDefinition {
253                index,
254                name: name.to_string(),
255                field_type,
256                description: None,
257            }
258        })
259    }
260
261    pub fn validate_record(&self, record: Vec<String>) -> PyResult<()> {
262        if record.len() < self.fields.len() {
263            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
264                "CSV row has {} fields but schema requires {} fields",
265                record.len(),
266                self.fields.len()
267            )));
268        }
269
270        // Check that required fields are not empty
271        for (index, field_name) in self.fields.iter().enumerate() {
272            if index < record.len() && record[index].trim().is_empty() {
273                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
274                    "Field {field_name} is missing or empty"
275                )));
276            }
277        }
278
279        Ok(())
280    }
281
282    fn __str__(&self) -> String {
283        format!("Schema(fields={})", self.fields.len())
284    }
285
286    fn __repr__(&self) -> String {
287        format!("Schema(fields={:?})", self.fields)
288    }
289
290    fn __len__(&self) -> usize {
291        self.fields.len()
292    }
293}
294
295impl PySchema {
296    fn build_index_map(&mut self) {
297        let mut map = HashMap::new();
298        for (i, field) in self.fields.iter().enumerate() {
299            map.insert(field.clone(), i);
300        }
301        self.field_index_map = Some(map);
302    }
303}
304
305impl From<PySchema> for Schema {
306    fn from(schema: PySchema) -> Self {
307        Schema::new(schema.fields)
308    }
309}
310
311impl From<Schema> for PySchema {
312    fn from(schema: Schema) -> Self {
313        PySchema::new(schema.get_all_fields().to_vec())
314    }
315}
316
317pub fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
318    let py = parent_module.py();
319    let m = PyModule::new(py, "schema")?;
320    m.add_class::<PySchema>()?;
321    m.add_class::<PyFieldDefinition>()?;
322    m.add_class::<PyFieldType>()?;
323    parent_module.add_submodule(&m)?;
324    Ok(())
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use lindera::dictionary::{FieldDefinition, FieldType, Schema};

    /// The four system fields, in the order every schema starts with.
    const SYSTEM_FIELDS: [&str; 4] = ["surface", "left_context_id", "right_context_id", "cost"];

    /// Builds an owned field-name list from string literals.
    fn names(fields: &[&str]) -> Vec<String> {
        fields.iter().map(|field| field.to_string()).collect()
    }

    /// Builds a field list consisting of the system fields followed by `extra`.
    fn system_plus(extra: &[&str]) -> Vec<String> {
        SYSTEM_FIELDS
            .iter()
            .chain(extra)
            .map(|field| field.to_string())
            .collect()
    }

    #[test]
    fn test_pyfieldtype_surface_to_fieldtype() {
        let ft: FieldType = PyFieldType::Surface.into();
        assert!(matches!(ft, FieldType::Surface));
    }

    #[test]
    fn test_pyfieldtype_left_context_id_to_fieldtype() {
        let ft: FieldType = PyFieldType::LeftContextId.into();
        assert!(matches!(ft, FieldType::LeftContextId));
    }

    #[test]
    fn test_pyfieldtype_right_context_id_to_fieldtype() {
        let ft: FieldType = PyFieldType::RightContextId.into();
        assert!(matches!(ft, FieldType::RightContextId));
    }

    #[test]
    fn test_pyfieldtype_cost_to_fieldtype() {
        let ft: FieldType = PyFieldType::Cost.into();
        assert!(matches!(ft, FieldType::Cost));
    }

    #[test]
    fn test_pyfieldtype_custom_to_fieldtype() {
        let ft: FieldType = PyFieldType::Custom.into();
        assert!(matches!(ft, FieldType::Custom));
    }

    #[test]
    fn test_fieldtype_to_pyfieldtype_all_variants() {
        assert!(matches!(
            PyFieldType::from(FieldType::Surface),
            PyFieldType::Surface
        ));
        assert!(matches!(
            PyFieldType::from(FieldType::LeftContextId),
            PyFieldType::LeftContextId
        ));
        assert!(matches!(
            PyFieldType::from(FieldType::RightContextId),
            PyFieldType::RightContextId
        ));
        assert!(matches!(
            PyFieldType::from(FieldType::Cost),
            PyFieldType::Cost
        ));
        assert!(matches!(
            PyFieldType::from(FieldType::Custom),
            PyFieldType::Custom
        ));
    }

    #[test]
    fn test_pyfieldtype_str_names() {
        assert_eq!(PyFieldType::Surface.__str__(), "surface");
        assert_eq!(PyFieldType::LeftContextId.__str__(), "left_context_id");
        assert_eq!(PyFieldType::RightContextId.__str__(), "right_context_id");
        assert_eq!(PyFieldType::Cost.__str__(), "cost");
        assert_eq!(PyFieldType::Custom.__str__(), "custom");
    }

    #[test]
    fn test_pyfieldtype_repr() {
        assert_eq!(PyFieldType::Cost.__repr__(), "FieldType.Cost");
        assert_eq!(PyFieldType::Custom.__repr__(), "FieldType.Custom");
    }

    #[test]
    fn test_pyfielddefinition_to_fielddefinition() {
        let py_fd = PyFieldDefinition {
            index: 0,
            name: "surface".to_string(),
            field_type: PyFieldType::Surface,
            description: Some("Surface form".to_string()),
        };
        let fd: FieldDefinition = py_fd.into();
        assert_eq!(fd.index, 0);
        assert_eq!(fd.name, "surface");
        assert!(matches!(fd.field_type, FieldType::Surface));
        assert_eq!(fd.description, Some("Surface form".to_string()));
    }

    #[test]
    fn test_fielddefinition_to_pyfielddefinition() {
        let fd = FieldDefinition {
            index: 4,
            name: "pos".to_string(),
            field_type: FieldType::Custom,
            description: None,
        };
        let py_fd: PyFieldDefinition = fd.into();
        assert_eq!(py_fd.index, 4);
        assert_eq!(py_fd.name, "pos");
        assert!(matches!(py_fd.field_type, PyFieldType::Custom));
        assert!(py_fd.description.is_none());
    }

    #[test]
    fn test_pyschema_to_schema() {
        let py_schema = PySchema::new(system_plus(&["pos"]));
        let schema: Schema = py_schema.into();
        let fields = schema.get_all_fields();
        assert_eq!(fields.len(), 5);
        assert_eq!(fields[0], "surface");
        assert_eq!(fields[4], "pos");
    }

    #[test]
    fn test_schema_to_pyschema() {
        let schema = Schema::new(system_plus(&[]));
        let py_schema: PySchema = schema.into();
        assert_eq!(py_schema.fields.len(), 4);
        assert_eq!(py_schema.fields[0], "surface");
    }

    #[test]
    fn test_pyschema_new_builds_index_map() {
        let schema = PySchema::new(names(&["surface", "pos", "reading"]));
        assert_eq!(schema.get_field_index("surface"), Some(0));
        assert_eq!(schema.get_field_index("pos"), Some(1));
        assert_eq!(schema.get_field_index("reading"), Some(2));
    }

    #[test]
    fn test_pyschema_get_field_index_existing() {
        let schema = PySchema::new(names(&["surface", "cost"]));
        assert_eq!(schema.get_field_index("surface"), Some(0));
        assert_eq!(schema.get_field_index("cost"), Some(1));
    }

    #[test]
    fn test_pyschema_get_field_index_nonexistent() {
        let schema = PySchema::new(names(&["surface"]));
        assert_eq!(schema.get_field_index("nonexistent"), None);
    }

    #[test]
    fn test_pyschema_field_count() {
        let schema = PySchema::new(names(&["a", "b", "c"]));
        assert_eq!(schema.field_count(), 3);
    }

    #[test]
    fn test_pyschema_len_matches_field_count() {
        let schema = PySchema::create_default();
        assert_eq!(schema.__len__(), schema.field_count());
    }

    #[test]
    fn test_pyschema_empty_schema() {
        let schema = PySchema::new(Vec::new());
        assert_eq!(schema.field_count(), 0);
        assert!(schema.get_custom_fields().is_empty());
        assert!(schema.get_field_by_name("surface").is_none());
        assert_eq!(schema.get_field_name(0), None);
    }

    #[test]
    fn test_pyschema_get_custom_fields() {
        let schema = PySchema::new(system_plus(&["major_pos", "reading"]));
        let custom = schema.get_custom_fields();
        assert_eq!(custom.len(), 2);
        assert_eq!(custom[0], "major_pos");
        assert_eq!(custom[1], "reading");
    }

    #[test]
    fn test_pyschema_get_custom_fields_no_custom() {
        let schema = PySchema::new(system_plus(&[]));
        assert!(schema.get_custom_fields().is_empty());
    }

    #[test]
    fn test_pyschema_get_custom_fields_fewer_than_four() {
        let schema = PySchema::new(names(&["surface", "cost"]));
        assert!(schema.get_custom_fields().is_empty());
    }

    #[test]
    fn test_pyschema_create_default_has_13_fields() {
        let schema = PySchema::create_default();
        assert_eq!(schema.field_count(), 13);
        assert_eq!(schema.fields[0], "surface");
        assert_eq!(schema.fields[12], "pronunciation");
    }

    #[test]
    fn test_pyschema_create_default_index_map() {
        let schema = PySchema::create_default();
        assert_eq!(schema.get_field_index("surface"), Some(0));
        assert_eq!(schema.get_field_index("cost"), Some(3));
        assert_eq!(schema.get_field_index("pronunciation"), Some(12));
        assert_eq!(schema.get_field_index("nonexistent"), None);
    }

    #[test]
    fn test_pyschema_get_field_name() {
        let schema = PySchema::new(names(&["surface", "pos"]));
        assert_eq!(schema.get_field_name(0), Some("surface"));
        assert_eq!(schema.get_field_name(1), Some("pos"));
        assert_eq!(schema.get_field_name(2), None);
    }

    #[test]
    fn test_pyschema_get_field_by_name_system_field() {
        let schema = PySchema::create_default();
        let field = schema.get_field_by_name("surface").unwrap();
        assert_eq!(field.index, 0);
        assert_eq!(field.name, "surface");
        assert!(matches!(field.field_type, PyFieldType::Surface));
    }

    #[test]
    fn test_pyschema_get_field_by_name_maps_every_system_field() {
        let schema = PySchema::create_default();
        let field_type = |name: &str| schema.get_field_by_name(name).unwrap().field_type;
        assert!(matches!(field_type("surface"), PyFieldType::Surface));
        assert!(matches!(
            field_type("left_context_id"),
            PyFieldType::LeftContextId
        ));
        assert!(matches!(
            field_type("right_context_id"),
            PyFieldType::RightContextId
        ));
        assert!(matches!(field_type("cost"), PyFieldType::Cost));
        assert!(matches!(field_type("pronunciation"), PyFieldType::Custom));
    }

    #[test]
    fn test_pyschema_get_field_by_name_custom_field() {
        let schema = PySchema::create_default();
        let field = schema.get_field_by_name("major_pos").unwrap();
        assert_eq!(field.index, 4);
        assert_eq!(field.name, "major_pos");
        assert!(matches!(field.field_type, PyFieldType::Custom));
    }

    #[test]
    fn test_pyschema_get_field_by_name_nonexistent() {
        let schema = PySchema::create_default();
        assert!(schema.get_field_by_name("nonexistent").is_none());
    }

    #[test]
    fn test_pyschema_str_and_repr() {
        let schema = PySchema::new(names(&["surface", "cost"]));
        assert_eq!(schema.__str__(), "Schema(fields=2)");
        assert_eq!(schema.__repr__(), r#"Schema(fields=["surface", "cost"])"#);
    }

    #[test]
    fn test_pyschema_roundtrip() {
        let fields = system_plus(&["pos"]);
        let py_schema = PySchema::new(fields.clone());
        let schema: Schema = py_schema.into();
        let roundtripped: PySchema = schema.into();
        assert_eq!(roundtripped.fields, fields);
    }
}