lindera_dictionary/dictionary/
schema.rs1use csv::StringRecord;
2use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use crate::LinderaResult;
7use crate::error::LinderaErrorKind;
8
9#[derive(Debug, Clone, Serialize, Archive, RkyvSerialize, RkyvDeserialize)]
11
12pub struct Schema {
13 pub fields: Vec<String>,
15 #[serde(skip)]
17 field_index_map: Option<HashMap<String, usize>>,
18}
19
20impl<'de> serde::Deserialize<'de> for Schema {
21 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
22 where
23 D: serde::Deserializer<'de>,
24 {
25 #[derive(Deserialize)]
26 struct DictionarySchemaHelper {
27 fields: Vec<String>,
28 }
29
30 let helper = <DictionarySchemaHelper as serde::Deserialize>::deserialize(deserializer)?;
31 let mut schema = Schema {
32 fields: helper.fields,
33 field_index_map: None,
34 };
35 schema.build_index_map();
36 Ok(schema)
37 }
38}
39
40impl Default for Schema {
41 fn default() -> Self {
42 let fields = vec![
43 "surface",
44 "left_context_id",
45 "right_context_id",
46 "cost",
47 "major_pos",
48 "pos_detail_1",
49 "pos_detail_2",
50 "pos_detail_3",
51 "conjugation_type",
52 "conjugation_form",
53 "base_form",
54 "reading",
55 "pronunciation",
56 ]
57 .into_iter()
58 .map(|s| s.to_string())
59 .collect();
60
61 let mut schema = Self {
62 fields,
63 field_index_map: None,
64 };
65 schema.build_index_map();
66 schema
67 }
68}
69
70impl Schema {
71 pub fn new(fields: Vec<String>) -> Self {
73 let mut schema = Self {
74 fields,
75 field_index_map: None,
76 };
77 schema.build_index_map();
78 schema
79 }
80
81 fn build_index_map(&mut self) {
83 let mut map = HashMap::new();
84
85 for (i, field) in self.fields.iter().enumerate() {
87 map.insert(field.clone(), i);
88 }
89
90 self.field_index_map = Some(map);
91 }
92
93 pub fn get_field_index(&self, field_name: &str) -> Option<usize> {
95 self.field_index_map
96 .as_ref()
97 .and_then(|map| map.get(field_name))
98 .copied()
99 }
100
101 pub fn field_count(&self) -> usize {
103 self.get_all_fields().len()
104 }
105
106 pub fn get_field_name(&self, index: usize) -> Option<&str> {
108 self.fields.get(index).map(|s| s.as_str())
109 }
110
111 pub fn get_custom_fields(&self) -> &[String] {
113 if self.fields.len() > 4 {
114 &self.fields[4..]
115 } else {
116 &[]
117 }
118 }
119
120 pub fn get_all_fields(&self) -> &[String] {
122 &self.fields
123 }
124
125 pub fn validate_fields(&self, row: &StringRecord) -> LinderaResult<()> {
127 if row.len() < self.fields.len() {
128 return Err(LinderaErrorKind::Content.with_error(anyhow::anyhow!(
129 "CSV row has {} fields but schema requires {} fields",
130 row.len(),
131 self.fields.len()
132 )));
133 }
134
135 for (index, field_name) in self.fields.iter().enumerate() {
137 if index < row.len() && row[index].trim().is_empty() {
138 return Err(LinderaErrorKind::Content
139 .with_error(anyhow::anyhow!("Field {field_name} is missing or empty")));
140 }
141 }
142
143 Ok(())
144 }
145}
146
147impl Schema {
149 pub fn get_field_by_name(&self, name: &str) -> Option<FieldDefinition> {
151 self.get_field_index(name).map(|index| FieldDefinition {
152 index,
153 name: name.to_string(),
154 field_type: if index < 4 {
155 match index {
156 0 => FieldType::Surface,
157 1 => FieldType::LeftContextId,
158 2 => FieldType::RightContextId,
159 3 => FieldType::Cost,
160 _ => unreachable!(),
161 }
162 } else {
163 FieldType::Custom
164 },
165 description: None,
166 })
167 }
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
172
173pub struct FieldDefinition {
174 pub index: usize,
175 pub name: String,
176 pub field_type: FieldType,
177 pub description: Option<String>,
178}
179
180#[derive(
181 Debug, Clone, Serialize, Deserialize, PartialEq, Archive, RkyvSerialize, RkyvDeserialize,
182)]
183
184pub enum FieldType {
185 Surface,
186 LeftContextId,
187 RightContextId,
188 Cost,
189 Custom,
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed row with one non-blank value per default-schema column.
    const COMPLETE_ROW: [&str; 13] = [
        "surface_form",
        "123",
        "456",
        "789",
        "名詞",
        "一般",
        "*",
        "*",
        "*",
        "*",
        "surface_form",
        "読み",
        "発音",
    ];

    /// Name/position pairs for the four fixed columns.
    const FIXED_COLUMNS: [(&str, usize); 4] = [
        ("surface", 0),
        ("left_context_id", 1),
        ("right_context_id", 2),
        ("cost", 3),
    ];

    #[test]
    fn test_new_schema() {
        let schema = Schema::new(vec![String::from("field1"), String::from("field2")]);

        assert_eq!(schema.fields.len(), 2);
        assert!(schema.field_index_map.is_some());
    }

    #[test]
    fn test_field_index_lookup() {
        let schema = Schema::default();

        // Fixed columns.
        for (name, position) in FIXED_COLUMNS {
            assert_eq!(schema.get_field_index(name), Some(position));
        }

        // A sample of custom columns.
        for (name, position) in [("major_pos", 4), ("base_form", 10), ("pronunciation", 12)] {
            assert_eq!(schema.get_field_index(name), Some(position));
        }

        // Unknown names yield nothing.
        assert_eq!(schema.get_field_index("nonexistent"), None);
    }

    #[test]
    fn test_field_name_lookup() {
        let schema = Schema::default();

        let expectations = [
            (0, Some("surface")),
            (3, Some("cost")),
            (4, Some("major_pos")),
            (12, Some("pronunciation")),
            (13, None),
        ];
        for (position, expected) in expectations {
            assert_eq!(schema.get_field_name(position), expected);
        }
    }

    #[test]
    fn test_default_schema() {
        let schema = Schema::default();

        assert_eq!(schema.field_count(), 13);
        assert_eq!(schema.fields.len(), 13);
        assert_eq!(schema.get_custom_fields().len(), 9);
    }

    #[test]
    fn test_field_access() {
        let schema = Schema::default();

        for (name, position) in FIXED_COLUMNS {
            assert_eq!(schema.get_field_index(name), Some(position));
        }
    }

    #[test]
    fn test_validate_fields_success() {
        let schema = Schema::default();
        let record = StringRecord::from(COMPLETE_ROW.to_vec());

        assert!(schema.validate_fields(&record).is_ok());
    }

    #[test]
    fn test_validate_fields_empty_field() {
        let schema = Schema::default();

        // Same as the complete row, but with a blank first value.
        let mut values = COMPLETE_ROW;
        values[0] = "";
        let record = StringRecord::from(values.to_vec());

        assert!(schema.validate_fields(&record).is_err());
    }

    #[test]
    fn test_validate_fields_missing_field() {
        let schema = Schema::default();

        // Only one value for a thirteen-column schema.
        let record = StringRecord::from(vec!["surface_form"]);

        assert!(schema.validate_fields(&record).is_err());
    }

    #[test]
    fn test_backward_compatibility() {
        let schema = Schema::default();

        let surface = schema.get_field_by_name("surface").unwrap();
        assert_eq!(surface.index, 0);
        assert_eq!(surface.name, "surface");
        assert_eq!(surface.field_type, FieldType::Surface);

        let major_pos = schema.get_field_by_name("major_pos").unwrap();
        assert_eq!(major_pos.index, 4);
        assert_eq!(major_pos.name, "major_pos");
        assert_eq!(major_pos.field_type, FieldType::Custom);
    }

    #[test]
    fn test_custom_fields() {
        let schema = Schema::default();
        let custom = schema.get_custom_fields();

        assert_eq!(custom.len(), 9);
        assert_eq!(custom[0], "major_pos");
        assert_eq!(custom[8], "pronunciation");
    }
}