1use std::collections::HashMap;
2
3use pyo3::prelude::*;
4
5use lindera::dictionary::{FieldDefinition, FieldType, Schema};
6
7#[pyclass(name = "FieldType")]
8#[derive(Debug, Clone)]
9pub enum PyFieldType {
10 Surface,
11 LeftContextId,
12 RightContextId,
13 Cost,
14 Custom,
15}
16
17#[pymethods]
18impl PyFieldType {
19 fn __str__(&self) -> &str {
20 match self {
21 PyFieldType::Surface => "surface",
22 PyFieldType::LeftContextId => "left_context_id",
23 PyFieldType::RightContextId => "right_context_id",
24 PyFieldType::Cost => "cost",
25 PyFieldType::Custom => "custom",
26 }
27 }
28
29 fn __repr__(&self) -> String {
30 format!("FieldType.{self:?}")
31 }
32}
33
34impl From<FieldType> for PyFieldType {
35 fn from(field_type: FieldType) -> Self {
36 match field_type {
37 FieldType::Surface => PyFieldType::Surface,
38 FieldType::LeftContextId => PyFieldType::LeftContextId,
39 FieldType::RightContextId => PyFieldType::RightContextId,
40 FieldType::Cost => PyFieldType::Cost,
41 FieldType::Custom => PyFieldType::Custom,
42 }
43 }
44}
45
46impl From<PyFieldType> for FieldType {
47 fn from(field_type: PyFieldType) -> Self {
48 match field_type {
49 PyFieldType::Surface => FieldType::Surface,
50 PyFieldType::LeftContextId => FieldType::LeftContextId,
51 PyFieldType::RightContextId => FieldType::RightContextId,
52 PyFieldType::Cost => FieldType::Cost,
53 PyFieldType::Custom => FieldType::Custom,
54 }
55 }
56}
57
58#[pyclass(name = "FieldDefinition")]
59#[derive(Debug, Clone)]
60pub struct PyFieldDefinition {
61 #[pyo3(get)]
62 pub index: usize,
63 #[pyo3(get)]
64 pub name: String,
65 #[pyo3(get)]
66 pub field_type: PyFieldType,
67 #[pyo3(get)]
68 pub description: Option<String>,
69}
70
71#[pymethods]
72impl PyFieldDefinition {
73 #[new]
74 pub fn new(
75 index: usize,
76 name: String,
77 field_type: PyFieldType,
78 description: Option<String>,
79 ) -> Self {
80 Self {
81 index,
82 name,
83 field_type,
84 description,
85 }
86 }
87
88 fn __str__(&self) -> String {
89 format!("FieldDefinition(index={}, name={})", self.index, self.name)
90 }
91
92 fn __repr__(&self) -> String {
93 format!(
94 "FieldDefinition(index={}, name='{}', field_type={:?}, description={:?})",
95 self.index, self.name, self.field_type, self.description
96 )
97 }
98}
99
100impl From<FieldDefinition> for PyFieldDefinition {
101 fn from(field_def: FieldDefinition) -> Self {
102 PyFieldDefinition {
103 index: field_def.index,
104 name: field_def.name,
105 field_type: field_def.field_type.into(),
106 description: field_def.description,
107 }
108 }
109}
110
111impl From<PyFieldDefinition> for FieldDefinition {
112 fn from(field_def: PyFieldDefinition) -> Self {
113 FieldDefinition {
114 index: field_def.index,
115 name: field_def.name,
116 field_type: field_def.field_type.into(),
117 description: field_def.description,
118 }
119 }
120}
121
122#[pyclass(name = "Schema")]
123#[derive(Debug, Clone)]
124pub struct PySchema {
125 #[pyo3(get)]
126 pub fields: Vec<String>,
127 field_index_map: Option<HashMap<String, usize>>,
128}
129
130#[pymethods]
131impl PySchema {
132 #[new]
133 pub fn new(fields: Vec<String>) -> Self {
134 let mut schema = Self {
135 fields,
136 field_index_map: None,
137 };
138 schema.build_index_map();
139 schema
140 }
141
142 #[staticmethod]
143 pub fn create_default() -> Self {
144 Self::new(vec![
145 "surface".to_string(),
146 "left_context_id".to_string(),
147 "right_context_id".to_string(),
148 "cost".to_string(),
149 "major_pos".to_string(),
150 "middle_pos".to_string(),
151 "small_pos".to_string(),
152 "fine_pos".to_string(),
153 "conjugation_type".to_string(),
154 "conjugation_form".to_string(),
155 "base_form".to_string(),
156 "reading".to_string(),
157 "pronunciation".to_string(),
158 ])
159 }
160
161 pub fn get_field_index(&self, field_name: &str) -> Option<usize> {
162 self.field_index_map
163 .as_ref()
164 .and_then(|map| map.get(field_name))
165 .copied()
166 }
167
168 pub fn field_count(&self) -> usize {
169 self.get_all_fields().len()
170 }
171
172 pub fn get_field_name(&self, index: usize) -> Option<&str> {
173 self.fields.get(index).map(|s| s.as_str())
174 }
175
176 pub fn get_custom_fields(&self) -> Vec<String> {
177 if self.fields.len() > 4 {
178 self.fields[4..].to_vec()
179 } else {
180 Vec::new()
181 }
182 }
183
184 pub fn get_all_fields(&self) -> Vec<String> {
185 self.fields.clone()
186 }
187
188 pub fn get_field_by_name(&self, name: &str) -> Option<PyFieldDefinition> {
189 self.get_field_index(name).map(|index| {
190 let field_type = if index < 4 {
191 match index {
192 0 => PyFieldType::Surface,
193 1 => PyFieldType::LeftContextId,
194 2 => PyFieldType::RightContextId,
195 3 => PyFieldType::Cost,
196 _ => unreachable!(),
197 }
198 } else {
199 PyFieldType::Custom
200 };
201
202 PyFieldDefinition {
203 index,
204 name: name.to_string(),
205 field_type,
206 description: None,
207 }
208 })
209 }
210
211 pub fn validate_record(&self, record: Vec<String>) -> PyResult<()> {
212 if record.len() < self.fields.len() {
213 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
214 "CSV row has {} fields but schema requires {} fields",
215 record.len(),
216 self.fields.len()
217 )));
218 }
219
220 for (index, field_name) in self.fields.iter().enumerate() {
222 if index < record.len() && record[index].trim().is_empty() {
223 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
224 "Field {field_name} is missing or empty"
225 )));
226 }
227 }
228
229 Ok(())
230 }
231
232 fn __str__(&self) -> String {
233 format!("Schema(fields={})", self.fields.len())
234 }
235
236 fn __repr__(&self) -> String {
237 format!("Schema(fields={:?})", self.fields)
238 }
239
240 fn __len__(&self) -> usize {
241 self.fields.len()
242 }
243}
244
245impl PySchema {
246 fn build_index_map(&mut self) {
247 let mut map = HashMap::new();
248 for (i, field) in self.fields.iter().enumerate() {
249 map.insert(field.clone(), i);
250 }
251 self.field_index_map = Some(map);
252 }
253}
254
255impl From<PySchema> for Schema {
256 fn from(schema: PySchema) -> Self {
257 Schema::new(schema.fields)
258 }
259}
260
261impl From<Schema> for PySchema {
262 fn from(schema: Schema) -> Self {
263 PySchema::new(schema.get_all_fields().to_vec())
264 }
265}