lindera/
schema.rs

1use std::collections::HashMap;
2
3use pyo3::prelude::*;
4
5use lindera::dictionary::{FieldDefinition, FieldType, Schema};
6
/// Kind of a dictionary schema field, exposed to Python as `FieldType`.
///
/// Converts to and from lindera's `FieldType` through the `From`
/// implementations in this module; the variants map one-to-one.
#[pyclass(name = "FieldType")]
#[derive(Debug, Clone)]
pub enum PyFieldType {
    /// Rendered as `"surface"` by `__str__`.
    Surface,
    /// Rendered as `"left_context_id"` by `__str__`.
    LeftContextId,
    /// Rendered as `"right_context_id"` by `__str__`.
    RightContextId,
    /// Rendered as `"cost"` by `__str__`.
    Cost,
    /// Rendered as `"custom"` by `__str__`; used for any field that is not
    /// one of the four kinds above.
    Custom,
}
16
17#[pymethods]
18impl PyFieldType {
19    fn __str__(&self) -> &str {
20        match self {
21            PyFieldType::Surface => "surface",
22            PyFieldType::LeftContextId => "left_context_id",
23            PyFieldType::RightContextId => "right_context_id",
24            PyFieldType::Cost => "cost",
25            PyFieldType::Custom => "custom",
26        }
27    }
28
29    fn __repr__(&self) -> String {
30        format!("FieldType.{self:?}")
31    }
32}
33
34impl From<FieldType> for PyFieldType {
35    fn from(field_type: FieldType) -> Self {
36        match field_type {
37            FieldType::Surface => PyFieldType::Surface,
38            FieldType::LeftContextId => PyFieldType::LeftContextId,
39            FieldType::RightContextId => PyFieldType::RightContextId,
40            FieldType::Cost => PyFieldType::Cost,
41            FieldType::Custom => PyFieldType::Custom,
42        }
43    }
44}
45
46impl From<PyFieldType> for FieldType {
47    fn from(field_type: PyFieldType) -> Self {
48        match field_type {
49            PyFieldType::Surface => FieldType::Surface,
50            PyFieldType::LeftContextId => FieldType::LeftContextId,
51            PyFieldType::RightContextId => FieldType::RightContextId,
52            PyFieldType::Cost => FieldType::Cost,
53            PyFieldType::Custom => FieldType::Custom,
54        }
55    }
56}
57
/// Description of a single schema field, exposed to Python as
/// `FieldDefinition`.
///
/// Every field carries `#[pyo3(get)]`, so Python gets a getter for each
/// attribute; no setters are declared here.
#[pyclass(name = "FieldDefinition")]
#[derive(Debug, Clone)]
pub struct PyFieldDefinition {
    /// Position of the field within its schema.
    #[pyo3(get)]
    pub index: usize,
    /// Name of the field.
    #[pyo3(get)]
    pub name: String,
    /// Kind of the field.
    #[pyo3(get)]
    pub field_type: PyFieldType,
    /// Optional free-form description; `None` when absent.
    #[pyo3(get)]
    pub description: Option<String>,
}
70
71#[pymethods]
72impl PyFieldDefinition {
73    #[new]
74    pub fn new(
75        index: usize,
76        name: String,
77        field_type: PyFieldType,
78        description: Option<String>,
79    ) -> Self {
80        Self {
81            index,
82            name,
83            field_type,
84            description,
85        }
86    }
87
88    fn __str__(&self) -> String {
89        format!("FieldDefinition(index={}, name={})", self.index, self.name)
90    }
91
92    fn __repr__(&self) -> String {
93        format!(
94            "FieldDefinition(index={}, name='{}', field_type={:?}, description={:?})",
95            self.index, self.name, self.field_type, self.description
96        )
97    }
98}
99
100impl From<FieldDefinition> for PyFieldDefinition {
101    fn from(field_def: FieldDefinition) -> Self {
102        PyFieldDefinition {
103            index: field_def.index,
104            name: field_def.name,
105            field_type: field_def.field_type.into(),
106            description: field_def.description,
107        }
108    }
109}
110
111impl From<PyFieldDefinition> for FieldDefinition {
112    fn from(field_def: PyFieldDefinition) -> Self {
113        FieldDefinition {
114            index: field_def.index,
115            name: field_def.name,
116            field_type: field_def.field_type.into(),
117            description: field_def.description,
118        }
119    }
120}
121
/// Ordered list of dictionary field names, exposed to Python as `Schema`.
#[pyclass(name = "Schema")]
#[derive(Debug, Clone)]
pub struct PySchema {
    /// Field names in schema order; readable from Python through the
    /// generated getter.
    #[pyo3(get)]
    pub fields: Vec<String>,
    /// Lookup from field name to its position in `fields`. Starts as `None`
    /// in `new` and is filled by `build_index_map` before `new` returns.
    field_index_map: Option<HashMap<String, usize>>,
}
129
130#[pymethods]
131impl PySchema {
132    #[new]
133    pub fn new(fields: Vec<String>) -> Self {
134        let mut schema = Self {
135            fields,
136            field_index_map: None,
137        };
138        schema.build_index_map();
139        schema
140    }
141
142    #[staticmethod]
143    pub fn create_default() -> Self {
144        Self::new(vec![
145            "surface".to_string(),
146            "left_context_id".to_string(),
147            "right_context_id".to_string(),
148            "cost".to_string(),
149            "major_pos".to_string(),
150            "middle_pos".to_string(),
151            "small_pos".to_string(),
152            "fine_pos".to_string(),
153            "conjugation_type".to_string(),
154            "conjugation_form".to_string(),
155            "base_form".to_string(),
156            "reading".to_string(),
157            "pronunciation".to_string(),
158        ])
159    }
160
161    pub fn get_field_index(&self, field_name: &str) -> Option<usize> {
162        self.field_index_map
163            .as_ref()
164            .and_then(|map| map.get(field_name))
165            .copied()
166    }
167
168    pub fn field_count(&self) -> usize {
169        self.get_all_fields().len()
170    }
171
172    pub fn get_field_name(&self, index: usize) -> Option<&str> {
173        self.fields.get(index).map(|s| s.as_str())
174    }
175
176    pub fn get_custom_fields(&self) -> Vec<String> {
177        if self.fields.len() > 4 {
178            self.fields[4..].to_vec()
179        } else {
180            Vec::new()
181        }
182    }
183
184    pub fn get_all_fields(&self) -> Vec<String> {
185        self.fields.clone()
186    }
187
188    pub fn get_field_by_name(&self, name: &str) -> Option<PyFieldDefinition> {
189        self.get_field_index(name).map(|index| {
190            let field_type = if index < 4 {
191                match index {
192                    0 => PyFieldType::Surface,
193                    1 => PyFieldType::LeftContextId,
194                    2 => PyFieldType::RightContextId,
195                    3 => PyFieldType::Cost,
196                    _ => unreachable!(),
197                }
198            } else {
199                PyFieldType::Custom
200            };
201
202            PyFieldDefinition {
203                index,
204                name: name.to_string(),
205                field_type,
206                description: None,
207            }
208        })
209    }
210
211    pub fn validate_record(&self, record: Vec<String>) -> PyResult<()> {
212        if record.len() < self.fields.len() {
213            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
214                "CSV row has {} fields but schema requires {} fields",
215                record.len(),
216                self.fields.len()
217            )));
218        }
219
220        // Check that required fields are not empty
221        for (index, field_name) in self.fields.iter().enumerate() {
222            if index < record.len() && record[index].trim().is_empty() {
223                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
224                    "Field {field_name} is missing or empty"
225                )));
226            }
227        }
228
229        Ok(())
230    }
231
232    fn __str__(&self) -> String {
233        format!("Schema(fields={})", self.fields.len())
234    }
235
236    fn __repr__(&self) -> String {
237        format!("Schema(fields={:?})", self.fields)
238    }
239
240    fn __len__(&self) -> usize {
241        self.fields.len()
242    }
243}
244
245impl PySchema {
246    fn build_index_map(&mut self) {
247        let mut map = HashMap::new();
248        for (i, field) in self.fields.iter().enumerate() {
249            map.insert(field.clone(), i);
250        }
251        self.field_index_map = Some(map);
252    }
253}
254
255impl From<PySchema> for Schema {
256    fn from(schema: PySchema) -> Self {
257        Schema::new(schema.fields)
258    }
259}
260
261impl From<Schema> for PySchema {
262    fn from(schema: Schema) -> Self {
263        PySchema::new(schema.get_all_fields().to_vec())
264    }
265}