// lindera_dictionary/dictionary/schema.rs

1use csv::StringRecord;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5use crate::LinderaResult;
6use crate::error::LinderaErrorKind;
7
8/// Dictionary schema that defines the structure of dictionary entries
9#[derive(Debug, Clone, Serialize)]
10pub struct Schema {
11    /// All field names including common fields (surface, left_context_id, right_context_id, cost, ...)
12    pub fields: Vec<String>,
13    /// Field name to index mapping for fast lookup
14    #[serde(skip)]
15    field_index_map: Option<HashMap<String, usize>>,
16}
17
18impl<'de> serde::Deserialize<'de> for Schema {
19    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
20    where
21        D: serde::Deserializer<'de>,
22    {
23        #[derive(Deserialize)]
24        struct DictionarySchemaHelper {
25            fields: Vec<String>,
26        }
27
28        let helper = DictionarySchemaHelper::deserialize(deserializer)?;
29        let mut schema = Schema {
30            fields: helper.fields,
31            field_index_map: None,
32        };
33        schema.build_index_map();
34        Ok(schema)
35    }
36}
37
38impl Schema {
39    /// Create a new dictionary schema
40    pub fn new(fields: Vec<String>) -> Self {
41        let mut schema = Self {
42            fields,
43            field_index_map: None,
44        };
45        schema.build_index_map();
46        schema
47    }
48
49    /// Build field name to index mapping
50    fn build_index_map(&mut self) {
51        let mut map = HashMap::new();
52
53        // All fields
54        for (i, field) in self.fields.iter().enumerate() {
55            map.insert(field.clone(), i);
56        }
57
58        self.field_index_map = Some(map);
59    }
60
61    /// Get field index by name
62    pub fn get_field_index(&self, field_name: &str) -> Option<usize> {
63        self.field_index_map
64            .as_ref()
65            .and_then(|map| map.get(field_name))
66            .copied()
67    }
68
69    /// Get total field count
70    pub fn field_count(&self) -> usize {
71        self.get_all_fields().len()
72    }
73
74    /// Get field name by index
75    pub fn get_field_name(&self, index: usize) -> Option<&str> {
76        self.fields.get(index).map(|s| s.as_str())
77    }
78
79    /// Get custom fields (index >= 4)
80    pub fn get_custom_fields(&self) -> &[String] {
81        if self.fields.len() > 4 {
82            &self.fields[4..]
83        } else {
84            &[]
85        }
86    }
87
88    /// Get all fields
89    pub fn get_all_fields(&self) -> &[String] {
90        &self.fields
91    }
92
93    /// Validate that CSV row has all required fields
94    pub fn validate_fields(&self, row: &StringRecord) -> LinderaResult<()> {
95        if row.len() < self.fields.len() {
96            return Err(LinderaErrorKind::Content.with_error(anyhow::anyhow!(
97                "CSV row has {} fields but schema requires {} fields",
98                row.len(),
99                self.fields.len()
100            )));
101        }
102
103        // Check that required fields are not empty
104        for (index, field_name) in self.fields.iter().enumerate() {
105            if index < row.len() && row[index].trim().is_empty() {
106                return Err(LinderaErrorKind::Content
107                    .with_error(anyhow::anyhow!("Field {} is missing or empty", field_name)));
108            }
109        }
110
111        Ok(())
112    }
113}
114
115// Helper methods for backward compatibility
116impl Schema {
117    /// Find field by name (backward compatibility)
118    pub fn get_field_by_name(&self, name: &str) -> Option<FieldDefinition> {
119        self.get_field_index(name).map(|index| FieldDefinition {
120            index,
121            name: name.to_string(),
122            field_type: if index < 4 {
123                match index {
124                    0 => FieldType::Surface,
125                    1 => FieldType::LeftContextId,
126                    2 => FieldType::RightContextId,
127                    3 => FieldType::Cost,
128                    _ => unreachable!(),
129                }
130            } else {
131                FieldType::Custom
132            },
133            description: None,
134        })
135    }
136}
137
138// Backward compatibility types
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct FieldDefinition {
141    pub index: usize,
142    pub name: String,
143    pub field_type: FieldType,
144    pub description: Option<String>,
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
148pub enum FieldType {
149    Surface,
150    LeftContextId,
151    RightContextId,
152    Cost,
153    Custom,
154}
155
156impl Default for Schema {
157    fn default() -> Self {
158        Self::new(vec![
159            "surface".to_string(),
160            "left_context_id".to_string(),
161            "right_context_id".to_string(),
162            "cost".to_string(),
163            "major_pos".to_string(),
164            "middle_pos".to_string(),
165            "small_pos".to_string(),
166            "fine_pos".to_string(),
167            "conjugation_type".to_string(),
168            "conjugation_form".to_string(),
169            "base_form".to_string(),
170            "reading".to_string(),
171            "pronunciation".to_string(),
172        ])
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Names of the four common fields, in index order.
    const COMMON_FIELDS: [&str; 4] = ["surface", "left_context_id", "right_context_id", "cost"];

    /// Build a 13-column record that lines up with the default schema, using
    /// `surface` as the value of the first column.
    fn record_with_surface(surface: &str) -> StringRecord {
        StringRecord::from(vec![
            surface,
            "123",
            "456",
            "789",
            "名詞",
            "一般",
            "*",
            "*",
            "*",
            "*",
            "surface_form",
            "読み",
            "発音",
        ])
    }

    #[test]
    fn test_new_schema() {
        let schema = Schema::new(vec![String::from("field1"), String::from("field2")]);

        assert_eq!(schema.fields.len(), 2);
        assert!(schema.field_index_map.is_some());
    }

    #[test]
    fn test_field_index_lookup() {
        let schema = Schema::default();

        // Common fields first, then a sample of custom fields
        let expected = [
            ("surface", 0),
            ("left_context_id", 1),
            ("right_context_id", 2),
            ("cost", 3),
            ("major_pos", 4),
            ("base_form", 10),
            ("pronunciation", 12),
        ];
        for (name, index) in expected {
            assert_eq!(schema.get_field_index(name), Some(index), "field {name}");
        }

        // Non-existent field
        assert_eq!(schema.get_field_index("nonexistent"), None);
    }

    #[test]
    fn test_field_name_lookup() {
        let schema = Schema::default();

        let expected = [
            (0, Some("surface")),
            (3, Some("cost")),
            (4, Some("major_pos")),
            (12, Some("pronunciation")),
            (13, None),
        ];
        for (index, name) in expected {
            assert_eq!(schema.get_field_name(index), name, "index {index}");
        }
    }

    #[test]
    fn test_default_schema() {
        let schema = Schema::default();

        // Counts cover all fields, common ones included
        assert_eq!(schema.field_count(), 13);
        assert_eq!(schema.fields.len(), 13);
        assert_eq!(schema.get_custom_fields().len(), 9);
    }

    #[test]
    fn test_field_access() {
        let schema = Schema::default();

        for (index, name) in COMMON_FIELDS.iter().enumerate() {
            assert_eq!(schema.get_field_index(name), Some(index), "field {name}");
        }
    }

    #[test]
    fn test_validate_fields_success() {
        let schema = Schema::default();
        let record = record_with_surface("surface_form");

        assert!(schema.validate_fields(&record).is_ok());
    }

    #[test]
    fn test_validate_fields_empty_field() {
        let schema = Schema::default();
        // Empty surface
        let record = record_with_surface("");

        assert!(schema.validate_fields(&record).is_err());
    }

    #[test]
    fn test_validate_fields_missing_field() {
        let schema = Schema::default();
        // Only the first field is present
        let record = StringRecord::from(vec!["surface_form"]);

        assert!(schema.validate_fields(&record).is_err());
    }

    #[test]
    fn test_backward_compatibility() {
        let schema = Schema::default();

        // get_field_by_name for one common field and one custom field
        let cases = [
            ("surface", 0, FieldType::Surface),
            ("major_pos", 4, FieldType::Custom),
        ];
        for (name, index, field_type) in cases {
            let field = schema.get_field_by_name(name).unwrap();
            assert_eq!(field.index, index);
            assert_eq!(field.name, name);
            assert_eq!(field.field_type, field_type);
        }
    }

    #[test]
    fn test_custom_fields() {
        let schema = Schema::default();
        let custom_fields = schema.get_custom_fields();

        assert_eq!(custom_fields.len(), 9);
        assert_eq!(custom_fields.first().map(String::as_str), Some("major_pos"));
        assert_eq!(custom_fields.last().map(String::as_str), Some("pronunciation"));
    }
}