use csv::StringRecord;
use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::LinderaResult;
use crate::error::LinderaErrorKind;
/// Ordered list of field names describing the columns of a record
/// (for example a CSV row checked by `validate_fields`).
///
/// Only `fields` takes part in serde (de)serialization; the lookup map is
/// rebuilt from `fields` whenever a schema is constructed or deserialized.
#[derive(Debug, Clone, Serialize, Archive, RkyvSerialize, RkyvDeserialize)]
pub struct Schema {
/// Field names in column order; a name's position in this vector is its index.
pub fields: Vec<String>,
/// Name-to-index lookup derived from `fields` by `build_index_map`.
/// Skipped by serde, so it never appears in serialized output.
#[serde(skip)]
field_index_map: Option<HashMap<String, usize>>,
}
/// Hand-written `serde` deserialization for [`Schema`].
///
/// Only the `fields` list is read from the input. The name-to-index map is
/// not part of the serialized form, so it is rebuilt before the schema is
/// returned.
impl<'de> serde::Deserialize<'de> for Schema {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Mirrors the serialized shape of `Schema`: just the field list.
        #[derive(Deserialize)]
        struct DictionarySchemaHelper {
            fields: Vec<String>,
        }

        let DictionarySchemaHelper { fields } =
            <DictionarySchemaHelper as serde::Deserialize>::deserialize(deserializer)?;

        let mut rebuilt = Schema {
            fields,
            field_index_map: None,
        };
        rebuilt.build_index_map();
        Ok(rebuilt)
    }
}
/// The default schema consists of thirteen fields, starting with `surface`,
/// `left_context_id`, `right_context_id` and `cost`, followed by nine
/// further columns ending with `pronunciation`.
impl Default for Schema {
    fn default() -> Self {
        // Default field names, listed in column order.
        const DEFAULT_FIELD_NAMES: [&str; 13] = [
            "surface",
            "left_context_id",
            "right_context_id",
            "cost",
            "major_pos",
            "pos_detail_1",
            "pos_detail_2",
            "pos_detail_3",
            "conjugation_type",
            "conjugation_form",
            "base_form",
            "reading",
            "pronunciation",
        ];

        let mut schema = Self {
            fields: DEFAULT_FIELD_NAMES
                .iter()
                .map(|&name| String::from(name))
                .collect(),
            field_index_map: None,
        };
        schema.build_index_map();
        schema
    }
}
impl Schema {
    /// Creates a schema from field names given in column order and builds
    /// the name-to-index lookup for them.
    pub fn new(fields: Vec<String>) -> Self {
        let mut created = Self {
            fields,
            field_index_map: None,
        };
        created.build_index_map();
        created
    }

    /// Rebuilds the name-to-index lookup from the current `fields`.
    ///
    /// If a name occurs more than once, the index of its last occurrence is
    /// the one that is kept.
    fn build_index_map(&mut self) {
        let lookup: HashMap<String, usize> = self
            .fields
            .iter()
            .enumerate()
            .map(|(position, name)| (name.clone(), position))
            .collect();
        self.field_index_map = Some(lookup);
    }

    /// Returns the index of the field called `field_name`, or `None` when
    /// the name is unknown or no lookup map has been built.
    pub fn get_field_index(&self, field_name: &str) -> Option<usize> {
        let lookup = self.field_index_map.as_ref()?;
        lookup.get(field_name).copied()
    }

    /// Returns the number of fields in the schema.
    pub fn field_count(&self) -> usize {
        let all_fields = self.get_all_fields();
        all_fields.len()
    }

    /// Returns the name of the field at `index`, or `None` when `index` is
    /// out of range.
    pub fn get_field_name(&self, index: usize) -> Option<&str> {
        self.fields.get(index).map(String::as_str)
    }

    /// Returns every field after the first four.
    ///
    /// The result is empty when the schema has four fields or fewer.
    pub fn get_custom_fields(&self) -> &[String] {
        // `get(4..)` yields `None` only when there are fewer than four fields.
        self.fields.get(4..).unwrap_or(&[])
    }

    /// Returns all field names in column order.
    pub fn get_all_fields(&self) -> &[String] {
        self.fields.as_slice()
    }

    /// Checks that `row` supplies a non-blank value for every schema field.
    ///
    /// # Errors
    ///
    /// Returns a `LinderaErrorKind::Content` error when `row` has fewer
    /// values than the schema has fields, or when the value at any schema
    /// position is empty after trimming whitespace. Values beyond the
    /// schema's field count are not inspected.
    pub fn validate_fields(&self, row: &StringRecord) -> LinderaResult<()> {
        let required = self.fields.len();
        let supplied = row.len();
        if supplied < required {
            return Err(LinderaErrorKind::Content.with_error(anyhow::anyhow!(
                "CSV row has {} fields but schema requires {} fields",
                supplied,
                required
            )));
        }

        // The row is at least as long as the schema here, so zipping pairs
        // every schema field with the value at the same position.
        let first_blank = self
            .fields
            .iter()
            .zip(row.iter())
            .find(|(_, value)| value.trim().is_empty());

        match first_blank {
            Some((field_name, _)) => Err(LinderaErrorKind::Content
                .with_error(anyhow::anyhow!("Field {field_name} is missing or empty"))),
            None => Ok(()),
        }
    }
}
impl Schema {
    /// Looks up a field by name and returns a [`FieldDefinition`] for it.
    ///
    /// The first four indices map to `Surface`, `LeftContextId`,
    /// `RightContextId` and `Cost` respectively; every later index is
    /// reported as `Custom`. The `description` is always `None`.
    ///
    /// Returns `None` when `get_field_index` finds no field called `name`.
    pub fn get_field_by_name(&self, name: &str) -> Option<FieldDefinition> {
        let index = self.get_field_index(name)?;

        let field_type = match index {
            0 => FieldType::Surface,
            1 => FieldType::LeftContextId,
            2 => FieldType::RightContextId,
            3 => FieldType::Cost,
            _ => FieldType::Custom,
        };

        Some(FieldDefinition {
            index,
            name: name.to_string(),
            field_type,
            description: None,
        })
    }
}
/// Description of a single schema field, as produced by
/// `Schema::get_field_by_name`.
#[derive(Debug, Clone, Serialize, Deserialize, Archive, RkyvSerialize, RkyvDeserialize)]
pub struct FieldDefinition {
/// Zero-based position of the field within the schema's field list.
pub index: usize,
/// Name of the field.
pub name: String,
/// Category of the field; see `FieldType`.
pub field_type: FieldType,
/// Optional free-form description. `Schema::get_field_by_name` always
/// sets this to `None`.
pub description: Option<String>,
}
/// Category of a schema field.
///
/// `Schema::get_field_by_name` assigns the category purely from the field's
/// index: indices 0 to 3 get the four dedicated variants, and every other
/// index gets `Custom`.
#[derive(
Debug, Clone, Serialize, Deserialize, PartialEq, Archive, RkyvSerialize, RkyvDeserialize,
)]
pub enum FieldType {
/// Assigned to the field at index 0.
Surface,
/// Assigned to the field at index 1.
LeftContextId,
/// Assigned to the field at index 2.
RightContextId,
/// Assigned to the field at index 3.
Cost,
/// Assigned to every field at index 4 or higher.
Custom,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One non-empty value for each of the thirteen default schema fields.
    const COMPLETE_ROW: [&str; 13] = [
        "surface_form",
        "123",
        "456",
        "789",
        "名詞",
        "一般",
        "*",
        "*",
        "*",
        "*",
        "surface_form",
        "読み",
        "発音",
    ];

    /// A freshly constructed schema has its lookup map built.
    #[test]
    fn test_new_schema() {
        let schema = Schema::new(vec![String::from("field1"), String::from("field2")]);
        assert!(schema.field_index_map.is_some());
        assert_eq!(schema.fields.len(), 2);
    }

    /// Name-to-index lookup on the default schema, including an unknown name.
    #[test]
    fn test_field_index_lookup() {
        let schema = Schema::default();
        let expectations = [
            ("surface", Some(0)),
            ("left_context_id", Some(1)),
            ("right_context_id", Some(2)),
            ("cost", Some(3)),
            ("major_pos", Some(4)),
            ("base_form", Some(10)),
            ("pronunciation", Some(12)),
            ("nonexistent", None),
        ];
        for &(name, index) in &expectations {
            assert_eq!(schema.get_field_index(name), index);
        }
    }

    /// Index-to-name lookup on the default schema, including an out-of-range index.
    #[test]
    fn test_field_name_lookup() {
        let schema = Schema::default();
        let expectations = [
            (0, Some("surface")),
            (3, Some("cost")),
            (4, Some("major_pos")),
            (12, Some("pronunciation")),
            (13, None),
        ];
        for &(index, name) in &expectations {
            assert_eq!(schema.get_field_name(index), name);
        }
    }

    /// The default schema has thirteen fields, nine of them after the first four.
    #[test]
    fn test_default_schema() {
        let schema = Schema::default();
        assert_eq!(schema.field_count(), 13);
        assert_eq!(schema.fields.len(), 13);
        assert_eq!(schema.get_custom_fields().len(), 9);
    }

    /// The four leading fields sit at indices 0 through 3.
    #[test]
    fn test_field_access() {
        let schema = Schema::default();
        let leading = ["surface", "left_context_id", "right_context_id", "cost"];
        for (position, name) in leading.iter().enumerate() {
            assert_eq!(schema.get_field_index(name), Some(position));
        }
    }

    /// A row with a value for every field passes validation.
    #[test]
    fn test_validate_fields_success() {
        let schema = Schema::default();
        let record = StringRecord::from(COMPLETE_ROW.to_vec());
        assert!(schema.validate_fields(&record).is_ok());
    }

    /// A row whose first value is empty is rejected.
    #[test]
    fn test_validate_fields_empty_field() {
        let schema = Schema::default();
        let mut values = COMPLETE_ROW.to_vec();
        values[0] = "";
        let record = StringRecord::from(values);
        assert!(schema.validate_fields(&record).is_err());
    }

    /// A row with fewer values than the schema has fields is rejected.
    #[test]
    fn test_validate_fields_missing_field() {
        let schema = Schema::default();
        let record = StringRecord::from(vec!["surface_form"]);
        assert!(schema.validate_fields(&record).is_err());
    }

    /// `get_field_by_name` reports index, name and field type.
    #[test]
    fn test_backward_compatibility() {
        let schema = Schema::default();

        let surface = schema.get_field_by_name("surface").unwrap();
        assert_eq!(surface.index, 0);
        assert_eq!(surface.name, "surface");
        assert_eq!(surface.field_type, FieldType::Surface);

        let major_pos = schema.get_field_by_name("major_pos").unwrap();
        assert_eq!(major_pos.index, 4);
        assert_eq!(major_pos.name, "major_pos");
        assert_eq!(major_pos.field_type, FieldType::Custom);
    }

    /// Custom fields are everything after the first four.
    #[test]
    fn test_custom_fields() {
        let schema = Schema::default();
        let custom_fields = schema.get_custom_fields();
        assert_eq!(custom_fields.len(), 9);
        assert_eq!(custom_fields[0], "major_pos");
        assert_eq!(custom_fields[8], "pronunciation");
    }
}