lindera_dictionary/dictionary/
schema.rs

1use csv::StringRecord;
2use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use crate::LinderaResult;
7use crate::error::LinderaErrorKind;
8
/// Dictionary schema that defines the structure of dictionary entries.
///
/// A schema is an ordered list of field (column) names. The position of a name
/// in `fields` is the column index used when reading a CSV row. The first four
/// positions are treated as the common fields (surface, left_context_id,
/// right_context_id, cost) by `get_custom_fields` and `get_field_by_name`;
/// everything from index 4 onward is considered a custom field.
///
/// `Serialize` is derived, while serde deserialization is implemented by hand
/// elsewhere in this file so that the lookup table can be rebuilt after loading.
#[derive(Debug, Clone, Serialize, Archive, RkyvSerialize, RkyvDeserialize)]

pub struct Schema {
    /// All field names including common fields (surface, left_context_id, right_context_id, cost, ...),
    /// stored in column order.
    pub fields: Vec<String>,
    /// Field name to index mapping for fast lookup.
    ///
    /// Excluded from serde output (`#[serde(skip)]`). `new`, `default` and the
    /// serde `Deserialize` impl in this file populate it via `build_index_map`.
    // NOTE(review): `fields` is public, so mutating it directly would leave this
    // map stale — confirm that callers never do that.
    #[serde(skip)]
    field_index_map: Option<HashMap<String, usize>>,
}
19
20impl<'de> serde::Deserialize<'de> for Schema {
21    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
22    where
23        D: serde::Deserializer<'de>,
24    {
25        #[derive(Deserialize)]
26        struct DictionarySchemaHelper {
27            fields: Vec<String>,
28        }
29
30        let helper = <DictionarySchemaHelper as serde::Deserialize>::deserialize(deserializer)?;
31        let mut schema = Schema {
32            fields: helper.fields,
33            field_index_map: None,
34        };
35        schema.build_index_map();
36        Ok(schema)
37    }
38}
39
40impl Default for Schema {
41    fn default() -> Self {
42        let fields = vec![
43            "surface",
44            "left_context_id",
45            "right_context_id",
46            "cost",
47            "major_pos",
48            "pos_detail_1",
49            "pos_detail_2",
50            "pos_detail_3",
51            "conjugation_type",
52            "conjugation_form",
53            "base_form",
54            "reading",
55            "pronunciation",
56        ]
57        .into_iter()
58        .map(|s| s.to_string())
59        .collect();
60
61        let mut schema = Self {
62            fields,
63            field_index_map: None,
64        };
65        schema.build_index_map();
66        schema
67    }
68}
69
70impl Schema {
71    /// Create a new dictionary schema
72    pub fn new(fields: Vec<String>) -> Self {
73        let mut schema = Self {
74            fields,
75            field_index_map: None,
76        };
77        schema.build_index_map();
78        schema
79    }
80
81    /// Build field name to index mapping
82    fn build_index_map(&mut self) {
83        let mut map = HashMap::new();
84
85        // All fields
86        for (i, field) in self.fields.iter().enumerate() {
87            map.insert(field.clone(), i);
88        }
89
90        self.field_index_map = Some(map);
91    }
92
93    /// Get field index by name
94    pub fn get_field_index(&self, field_name: &str) -> Option<usize> {
95        self.field_index_map
96            .as_ref()
97            .and_then(|map| map.get(field_name))
98            .copied()
99    }
100
101    /// Get total field count
102    pub fn field_count(&self) -> usize {
103        self.get_all_fields().len()
104    }
105
106    /// Get field name by index
107    pub fn get_field_name(&self, index: usize) -> Option<&str> {
108        self.fields.get(index).map(|s| s.as_str())
109    }
110
111    /// Get custom fields (index >= 4)
112    pub fn get_custom_fields(&self) -> &[String] {
113        if self.fields.len() > 4 {
114            &self.fields[4..]
115        } else {
116            &[]
117        }
118    }
119
120    /// Get all fields
121    pub fn get_all_fields(&self) -> &[String] {
122        &self.fields
123    }
124
125    /// Validate that CSV row has all required fields
126    pub fn validate_fields(&self, row: &StringRecord) -> LinderaResult<()> {
127        if row.len() < self.fields.len() {
128            return Err(LinderaErrorKind::Content.with_error(anyhow::anyhow!(
129                "CSV row has {} fields but schema requires {} fields",
130                row.len(),
131                self.fields.len()
132            )));
133        }
134
135        // Check that required fields are not empty
136        for (index, field_name) in self.fields.iter().enumerate() {
137            if index < row.len() && row[index].trim().is_empty() {
138                return Err(LinderaErrorKind::Content
139                    .with_error(anyhow::anyhow!("Field {field_name} is missing or empty")));
140            }
141        }
142
143        Ok(())
144    }
145}
146
147// Helper methods for backward compatibility
148impl Schema {
149    /// Find field by name (backward compatibility)
150    pub fn get_field_by_name(&self, name: &str) -> Option<FieldDefinition> {
151        self.get_field_index(name).map(|index| FieldDefinition {
152            index,
153            name: name.to_string(),
154            field_type: if index < 4 {
155                match index {
156                    0 => FieldType::Surface,
157                    1 => FieldType::LeftContextId,
158                    2 => FieldType::RightContextId,
159                    3 => FieldType::Cost,
160                    _ => unreachable!(),
161                }
162            } else {
163                FieldType::Custom
164            },
165            description: None,
166        })
167    }
168}
169
170// Backward compatibility types
/// Description of a single schema field, retained for backward compatibility.
///
/// Instances are produced by `Schema::get_field_by_name`.
#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]

pub struct FieldDefinition {
    /// Zero-based position of the field within the schema.
    pub index: usize,
    /// Field name as it appears in the schema.
    pub name: String,
    /// Category of the field (one of the common fields, or `Custom`).
    pub field_type: FieldType,
    /// Optional free-form description; `Schema::get_field_by_name` always leaves it `None`.
    pub description: Option<String>,
}
179
/// Category of a schema field, retained for backward compatibility.
///
/// `Schema::get_field_by_name` assigns the variant purely from the field's
/// position: indices 0-3 map to the four common variants in declaration order,
/// and any other index maps to `Custom`.
#[derive(
    Debug, Clone, Serialize, Deserialize, PartialEq, Archive, RkyvSerialize, RkyvDeserialize,
)]

pub enum FieldType {
    /// Field at index 0.
    Surface,
    /// Field at index 1.
    LeftContextId,
    /// Field at index 2.
    RightContextId,
    /// Field at index 3.
    Cost,
    /// Any field at index 4 or above.
    Custom,
}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// The four common field names, in column order.
    const COMMON_FIELDS: [&str; 4] = ["surface", "left_context_id", "right_context_id", "cost"];

    /// Builds a 13-column record matching the default schema, using `surface`
    /// as the value of the first column.
    fn record_with_surface(surface: &str) -> StringRecord {
        StringRecord::from(vec![
            surface,
            "123",
            "456",
            "789",
            "名詞",
            "一般",
            "*",
            "*",
            "*",
            "*",
            "surface_form",
            "読み",
            "発音",
        ])
    }

    /// Asserts that each common field resolves to its expected column index.
    fn assert_common_field_indices(schema: &Schema) {
        for (expected_index, name) in COMMON_FIELDS.iter().enumerate() {
            assert_eq!(schema.get_field_index(name), Some(expected_index));
        }
    }

    #[test]
    fn test_new_schema() {
        let names: Vec<String> = ["field1", "field2"]
            .iter()
            .map(|name| name.to_string())
            .collect();
        let schema = Schema::new(names);

        assert_eq!(schema.fields.len(), 2);
        assert!(schema.field_index_map.is_some());
    }

    #[test]
    fn test_field_index_lookup() {
        let schema = Schema::default();

        // Common fields occupy indices 0..=3.
        assert_common_field_indices(&schema);

        // Custom fields follow the common ones.
        let custom_expectations = [("major_pos", 4), ("base_form", 10), ("pronunciation", 12)];
        for (name, expected_index) in custom_expectations {
            assert_eq!(schema.get_field_index(name), Some(expected_index));
        }

        // Unknown names are not resolved.
        assert_eq!(schema.get_field_index("nonexistent"), None);
    }

    #[test]
    fn test_field_name_lookup() {
        let schema = Schema::default();

        let expectations = [
            (0, Some("surface")),
            (3, Some("cost")),
            (4, Some("major_pos")),
            (12, Some("pronunciation")),
            (13, None),
        ];
        for (index, expected_name) in expectations {
            assert_eq!(schema.get_field_name(index), expected_name);
        }
    }

    #[test]
    fn test_default_schema() {
        let schema = Schema::default();

        // Thirteen fields in total, of which nine are custom.
        assert_eq!(schema.field_count(), 13);
        assert_eq!(schema.fields.len(), 13);
        assert_eq!(schema.get_custom_fields().len(), 9);
    }

    #[test]
    fn test_field_access() {
        let schema = Schema::default();

        assert_common_field_indices(&schema);
    }

    #[test]
    fn test_validate_fields_success() {
        let schema = Schema::default();
        let record = record_with_surface("surface_form");

        assert!(schema.validate_fields(&record).is_ok());
    }

    #[test]
    fn test_validate_fields_empty_field() {
        let schema = Schema::default();
        // A blank surface must be rejected.
        let record = record_with_surface("");

        assert!(schema.validate_fields(&record).is_err());
    }

    #[test]
    fn test_validate_fields_missing_field() {
        let schema = Schema::default();
        // Only the first column is present.
        let record = StringRecord::from(vec!["surface_form"]);

        assert!(schema.validate_fields(&record).is_err());
    }

    #[test]
    fn test_backward_compatibility() {
        let schema = Schema::default();

        let expectations = [
            ("surface", 0, FieldType::Surface),
            ("major_pos", 4, FieldType::Custom),
        ];
        for (name, expected_index, expected_type) in expectations {
            let field = schema.get_field_by_name(name).unwrap();
            assert_eq!(field.index, expected_index);
            assert_eq!(field.name, name);
            assert_eq!(field.field_type, expected_type);
        }
    }

    #[test]
    fn test_custom_fields() {
        let schema = Schema::default();
        let custom_fields = schema.get_custom_fields();

        assert_eq!(custom_fields.len(), 9);
        assert_eq!(custom_fields[0], "major_pos");
        assert_eq!(custom_fields[8], "pronunciation");
    }
}