lindera_dictionary/dictionary/
schema.rs1use csv::StringRecord;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5use crate::LinderaResult;
6use crate::error::LinderaErrorKind;
7
8#[derive(Debug, Clone, Serialize)]
10pub struct Schema {
11 pub fields: Vec<String>,
13 #[serde(skip)]
15 field_index_map: Option<HashMap<String, usize>>,
16}
17
18impl<'de> serde::Deserialize<'de> for Schema {
19 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
20 where
21 D: serde::Deserializer<'de>,
22 {
23 #[derive(Deserialize)]
24 struct DictionarySchemaHelper {
25 fields: Vec<String>,
26 }
27
28 let helper = DictionarySchemaHelper::deserialize(deserializer)?;
29 let mut schema = Schema {
30 fields: helper.fields,
31 field_index_map: None,
32 };
33 schema.build_index_map();
34 Ok(schema)
35 }
36}
37
38impl Schema {
39 pub fn new(fields: Vec<String>) -> Self {
41 let mut schema = Self {
42 fields,
43 field_index_map: None,
44 };
45 schema.build_index_map();
46 schema
47 }
48
49 fn build_index_map(&mut self) {
51 let mut map = HashMap::new();
52
53 for (i, field) in self.fields.iter().enumerate() {
55 map.insert(field.clone(), i);
56 }
57
58 self.field_index_map = Some(map);
59 }
60
61 pub fn get_field_index(&self, field_name: &str) -> Option<usize> {
63 self.field_index_map
64 .as_ref()
65 .and_then(|map| map.get(field_name))
66 .copied()
67 }
68
69 pub fn field_count(&self) -> usize {
71 self.get_all_fields().len()
72 }
73
74 pub fn get_field_name(&self, index: usize) -> Option<&str> {
76 self.fields.get(index).map(|s| s.as_str())
77 }
78
79 pub fn get_custom_fields(&self) -> &[String] {
81 if self.fields.len() > 4 {
82 &self.fields[4..]
83 } else {
84 &[]
85 }
86 }
87
88 pub fn get_all_fields(&self) -> &[String] {
90 &self.fields
91 }
92
93 pub fn validate_fields(&self, row: &StringRecord) -> LinderaResult<()> {
95 if row.len() < self.fields.len() {
96 return Err(LinderaErrorKind::Content.with_error(anyhow::anyhow!(
97 "CSV row has {} fields but schema requires {} fields",
98 row.len(),
99 self.fields.len()
100 )));
101 }
102
103 for (index, field_name) in self.fields.iter().enumerate() {
105 if index < row.len() && row[index].trim().is_empty() {
106 return Err(LinderaErrorKind::Content
107 .with_error(anyhow::anyhow!("Field {} is missing or empty", field_name)));
108 }
109 }
110
111 Ok(())
112 }
113}
114
115impl Schema {
117 pub fn get_field_by_name(&self, name: &str) -> Option<FieldDefinition> {
119 self.get_field_index(name).map(|index| FieldDefinition {
120 index,
121 name: name.to_string(),
122 field_type: if index < 4 {
123 match index {
124 0 => FieldType::Surface,
125 1 => FieldType::LeftContextId,
126 2 => FieldType::RightContextId,
127 3 => FieldType::Cost,
128 _ => unreachable!(),
129 }
130 } else {
131 FieldType::Custom
132 },
133 description: None,
134 })
135 }
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct FieldDefinition {
141 pub index: usize,
142 pub name: String,
143 pub field_type: FieldType,
144 pub description: Option<String>,
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
148pub enum FieldType {
149 Surface,
150 LeftContextId,
151 RightContextId,
152 Cost,
153 Custom,
154}
155
156impl Default for Schema {
157 fn default() -> Self {
158 Self::new(vec![
159 "surface".to_string(),
160 "left_context_id".to_string(),
161 "right_context_id".to_string(),
162 "cost".to_string(),
163 "major_pos".to_string(),
164 "middle_pos".to_string(),
165 "small_pos".to_string(),
166 "fine_pos".to_string(),
167 "conjugation_type".to_string(),
168 "conjugation_form".to_string(),
169 "base_form".to_string(),
170 "reading".to_string(),
171 "pronunciation".to_string(),
172 ])
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 13-column record shaped like the default schema, with the
    /// given value in the first (surface) column.
    fn default_shaped_record(surface: &str) -> StringRecord {
        StringRecord::from(vec![
            surface,
            "123",
            "456",
            "789",
            "名詞",
            "一般",
            "*",
            "*",
            "*",
            "*",
            "surface_form",
            "読み",
            "発音",
        ])
    }

    #[test]
    fn test_new_schema() {
        let schema = Schema::new(vec![String::from("field1"), String::from("field2")]);

        assert_eq!(schema.fields.len(), 2);
        assert!(schema.field_index_map.is_some());
    }

    #[test]
    fn test_field_index_lookup() {
        let schema = Schema::default();

        // (field name, expected position) pairs: the four leading fields,
        // a sample of the later ones, and an unknown name.
        let expectations = [
            ("surface", Some(0)),
            ("left_context_id", Some(1)),
            ("right_context_id", Some(2)),
            ("cost", Some(3)),
            ("major_pos", Some(4)),
            ("base_form", Some(10)),
            ("pronunciation", Some(12)),
            ("nonexistent", None),
        ];

        for (name, expected) in expectations {
            assert_eq!(schema.get_field_index(name), expected);
        }
    }

    #[test]
    fn test_field_name_lookup() {
        let schema = Schema::default();

        // (position, expected field name) pairs, ending one past the last field.
        let expectations = [
            (0, Some("surface")),
            (3, Some("cost")),
            (4, Some("major_pos")),
            (12, Some("pronunciation")),
            (13, None),
        ];

        for (index, expected) in expectations {
            assert_eq!(schema.get_field_name(index), expected);
        }
    }

    #[test]
    fn test_default_schema() {
        let schema = Schema::default();

        assert_eq!(schema.field_count(), 13);
        assert_eq!(schema.fields.len(), 13);
        assert_eq!(schema.get_custom_fields().len(), 9);
    }

    #[test]
    fn test_field_access() {
        let schema = Schema::default();
        let leading = ["surface", "left_context_id", "right_context_id", "cost"];

        for (position, name) in leading.iter().enumerate() {
            assert_eq!(schema.get_field_index(name), Some(position));
        }
    }

    #[test]
    fn test_validate_fields_success() {
        let schema = Schema::default();
        let record = default_shaped_record("surface_form");

        assert!(schema.validate_fields(&record).is_ok());
    }

    #[test]
    fn test_validate_fields_empty_field() {
        let schema = Schema::default();
        // Same shape as the valid record, but the surface column is empty.
        let record = default_shaped_record("");

        assert!(schema.validate_fields(&record).is_err());
    }

    #[test]
    fn test_validate_fields_missing_field() {
        let schema = Schema::default();
        // Only one column, far fewer than the schema requires.
        let record = StringRecord::from(vec!["surface_form"]);

        assert!(schema.validate_fields(&record).is_err());
    }

    #[test]
    fn test_backward_compatibility() {
        let schema = Schema::default();

        let surface = schema.get_field_by_name("surface").unwrap();
        assert_eq!(surface.index, 0);
        assert_eq!(surface.name, "surface");
        assert_eq!(surface.field_type, FieldType::Surface);

        let major_pos = schema.get_field_by_name("major_pos").unwrap();
        assert_eq!(major_pos.index, 4);
        assert_eq!(major_pos.name, "major_pos");
        assert_eq!(major_pos.field_type, FieldType::Custom);
    }

    #[test]
    fn test_custom_fields() {
        let schema = Schema::default();
        let custom_fields = schema.get_custom_fields();

        assert_eq!(custom_fields.len(), 9);
        assert_eq!(custom_fields[0], "major_pos");
        assert_eq!(custom_fields[8], "pronunciation");
    }
}