1use std::collections::HashMap;
7
8use magnus::prelude::*;
9use magnus::{Error, RArray, Ruby, function, method};
10
11use lindera::dictionary::{FieldDefinition, FieldType, Schema};
12
13#[magnus::wrap(class = "Lindera::FieldType", free_immediately, size)]
17#[derive(Debug, Clone)]
18pub struct RbFieldType {
19 inner: RbFieldTypeKind,
21}
22
/// The field kinds an `RbFieldType` can hold; each variant corresponds one-to-one
/// with a variant of `lindera::dictionary::FieldType`.
#[derive(Clone, Debug)]
enum RbFieldTypeKind {
    /// Surface form of the entry.
    Surface,
    /// Left context id.
    LeftContextId,
    /// Right context id.
    RightContextId,
    /// Word cost.
    Cost,
    /// Any field beyond the fixed ones.
    Custom,
}
37
38impl RbFieldType {
39 fn to_s(&self) -> &str {
41 match self.inner {
42 RbFieldTypeKind::Surface => "surface",
43 RbFieldTypeKind::LeftContextId => "left_context_id",
44 RbFieldTypeKind::RightContextId => "right_context_id",
45 RbFieldTypeKind::Cost => "cost",
46 RbFieldTypeKind::Custom => "custom",
47 }
48 }
49
50 fn inspect(&self) -> String {
52 format!("#<Lindera::FieldType: {}>", self.to_s())
53 }
54}
55
56impl From<FieldType> for RbFieldType {
57 fn from(field_type: FieldType) -> Self {
58 let kind = match field_type {
59 FieldType::Surface => RbFieldTypeKind::Surface,
60 FieldType::LeftContextId => RbFieldTypeKind::LeftContextId,
61 FieldType::RightContextId => RbFieldTypeKind::RightContextId,
62 FieldType::Cost => RbFieldTypeKind::Cost,
63 FieldType::Custom => RbFieldTypeKind::Custom,
64 };
65 RbFieldType { inner: kind }
66 }
67}
68
69impl From<RbFieldType> for FieldType {
70 fn from(field_type: RbFieldType) -> Self {
71 match field_type.inner {
72 RbFieldTypeKind::Surface => FieldType::Surface,
73 RbFieldTypeKind::LeftContextId => FieldType::LeftContextId,
74 RbFieldTypeKind::RightContextId => FieldType::RightContextId,
75 RbFieldTypeKind::Cost => FieldType::Cost,
76 RbFieldTypeKind::Custom => FieldType::Custom,
77 }
78 }
79}
80
81#[magnus::wrap(class = "Lindera::FieldDefinition", free_immediately, size)]
85#[derive(Debug, Clone)]
86pub struct RbFieldDefinition {
87 pub index: usize,
89 pub name: String,
91 pub field_type: RbFieldType,
93 pub description: Option<String>,
95}
96
97impl RbFieldDefinition {
98 fn index(&self) -> usize {
100 self.index
101 }
102
103 fn name(&self) -> String {
105 self.name.clone()
106 }
107
108 fn field_type(&self) -> RbFieldType {
110 self.field_type.clone()
111 }
112
113 fn description(&self) -> Option<String> {
115 self.description.clone()
116 }
117
118 fn to_s(&self) -> String {
120 format!("FieldDefinition(index={}, name={})", self.index, self.name)
121 }
122
123 fn inspect(&self) -> String {
125 format!(
126 "#<Lindera::FieldDefinition: index={}, name='{}', field_type={:?}, description={:?}>",
127 self.index, self.name, self.field_type.inner, self.description
128 )
129 }
130}
131
132impl From<FieldDefinition> for RbFieldDefinition {
133 fn from(field_def: FieldDefinition) -> Self {
134 RbFieldDefinition {
135 index: field_def.index,
136 name: field_def.name,
137 field_type: field_def.field_type.into(),
138 description: field_def.description,
139 }
140 }
141}
142
143impl From<RbFieldDefinition> for FieldDefinition {
144 fn from(field_def: RbFieldDefinition) -> Self {
145 FieldDefinition {
146 index: field_def.index,
147 name: field_def.name,
148 field_type: field_def.field_type.into(),
149 description: field_def.description,
150 }
151 }
152}
153
154#[magnus::wrap(class = "Lindera::Schema", free_immediately, size)]
158#[derive(Debug, Clone)]
159pub struct RbSchema {
160 pub fields: Vec<String>,
162 field_index_map: HashMap<String, usize>,
164}
165
166impl RbSchema {
167 fn new(fields: Vec<String>) -> Self {
177 let mut field_index_map = HashMap::new();
178 for (i, field) in fields.iter().enumerate() {
179 field_index_map.insert(field.clone(), i);
180 }
181 Self {
182 fields,
183 field_index_map,
184 }
185 }
186
187 fn create_default() -> Self {
193 Self::new(vec![
194 "surface".to_string(),
195 "left_context_id".to_string(),
196 "right_context_id".to_string(),
197 "cost".to_string(),
198 "major_pos".to_string(),
199 "middle_pos".to_string(),
200 "small_pos".to_string(),
201 "fine_pos".to_string(),
202 "conjugation_type".to_string(),
203 "conjugation_form".to_string(),
204 "base_form".to_string(),
205 "reading".to_string(),
206 "pronunciation".to_string(),
207 ])
208 }
209
210 fn fields(&self) -> Vec<String> {
212 self.fields.clone()
213 }
214
215 fn get_field_index(&self, field_name: String) -> Option<usize> {
225 self.field_index_map.get(&field_name).copied()
226 }
227
228 fn field_count(&self) -> usize {
234 self.fields.len()
235 }
236
237 fn get_field_name(&self, index: usize) -> Option<String> {
247 self.fields.get(index).cloned()
248 }
249
250 fn get_custom_fields(&self) -> Vec<String> {
256 if self.fields.len() > 4 {
257 self.fields[4..].to_vec()
258 } else {
259 Vec::new()
260 }
261 }
262
263 fn get_all_fields(&self) -> Vec<String> {
269 self.fields.clone()
270 }
271
272 fn get_field_by_name(&self, name: String) -> Option<RbFieldDefinition> {
282 self.field_index_map.get(&name).map(|&index| {
283 let field_type = if index < 4 {
284 match index {
285 0 => RbFieldType {
286 inner: RbFieldTypeKind::Surface,
287 },
288 1 => RbFieldType {
289 inner: RbFieldTypeKind::LeftContextId,
290 },
291 2 => RbFieldType {
292 inner: RbFieldTypeKind::RightContextId,
293 },
294 3 => RbFieldType {
295 inner: RbFieldTypeKind::Cost,
296 },
297 _ => unreachable!(),
298 }
299 } else {
300 RbFieldType {
301 inner: RbFieldTypeKind::Custom,
302 }
303 };
304
305 RbFieldDefinition {
306 index,
307 name: name.clone(),
308 field_type,
309 description: None,
310 }
311 })
312 }
313
314 fn validate_record(&self, record: RArray) -> Result<(), Error> {
324 let ruby = Ruby::get().expect("Ruby runtime not initialized");
325 let values: Vec<String> = record.to_vec()?;
326
327 if values.len() < self.fields.len() {
328 return Err(Error::new(
329 ruby.exception_arg_error(),
330 format!(
331 "CSV row has {} fields but schema requires {} fields",
332 values.len(),
333 self.fields.len()
334 ),
335 ));
336 }
337
338 for (index, field_name) in self.fields.iter().enumerate() {
339 if index < values.len() && values[index].trim().is_empty() {
340 return Err(Error::new(
341 ruby.exception_arg_error(),
342 format!("Field {field_name} is missing or empty"),
343 ));
344 }
345 }
346
347 Ok(())
348 }
349
350 fn to_s(&self) -> String {
352 format!("Schema(fields={})", self.fields.len())
353 }
354
355 fn inspect(&self) -> String {
357 format!("#<Lindera::Schema: fields={:?}>", self.fields)
358 }
359}
360
361impl RbSchema {
362 pub fn new_internal(fields: Vec<String>) -> Self {
364 Self::new(fields)
365 }
366
367 pub fn create_default_internal() -> Self {
369 Self::create_default()
370 }
371}
372
373impl From<RbSchema> for Schema {
374 fn from(schema: RbSchema) -> Self {
375 Schema::new(schema.fields)
376 }
377}
378
379impl From<Schema> for RbSchema {
380 fn from(schema: Schema) -> Self {
381 RbSchema::new(schema.get_all_fields().to_vec())
382 }
383}
384
385pub fn define(ruby: &Ruby, module: &magnus::RModule) -> Result<(), Error> {
396 let field_type_class = module.define_class("FieldType", ruby.class_object())?;
397 field_type_class.define_method("to_s", method!(RbFieldType::to_s, 0))?;
398 field_type_class.define_method("inspect", method!(RbFieldType::inspect, 0))?;
399
400 let field_def_class = module.define_class("FieldDefinition", ruby.class_object())?;
401 field_def_class.define_method("index", method!(RbFieldDefinition::index, 0))?;
402 field_def_class.define_method("name", method!(RbFieldDefinition::name, 0))?;
403 field_def_class.define_method("field_type", method!(RbFieldDefinition::field_type, 0))?;
404 field_def_class.define_method("description", method!(RbFieldDefinition::description, 0))?;
405 field_def_class.define_method("to_s", method!(RbFieldDefinition::to_s, 0))?;
406 field_def_class.define_method("inspect", method!(RbFieldDefinition::inspect, 0))?;
407
408 let schema_class = module.define_class("Schema", ruby.class_object())?;
409 schema_class.define_singleton_method("new", function!(RbSchema::new, 1))?;
410 schema_class
411 .define_singleton_method("create_default", function!(RbSchema::create_default, 0))?;
412 schema_class.define_method("fields", method!(RbSchema::fields, 0))?;
413 schema_class.define_method("get_field_index", method!(RbSchema::get_field_index, 1))?;
414 schema_class.define_method("field_count", method!(RbSchema::field_count, 0))?;
415 schema_class.define_method("get_field_name", method!(RbSchema::get_field_name, 1))?;
416 schema_class.define_method("get_custom_fields", method!(RbSchema::get_custom_fields, 0))?;
417 schema_class.define_method("get_all_fields", method!(RbSchema::get_all_fields, 0))?;
418 schema_class.define_method("get_field_by_name", method!(RbSchema::get_field_by_name, 1))?;
419 schema_class.define_method("validate_record", method!(RbSchema::validate_record, 1))?;
420 schema_class.define_method("to_s", method!(RbSchema::to_s, 0))?;
421 schema_class.define_method("inspect", method!(RbSchema::inspect, 0))?;
422
423 Ok(())
424}
425
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned field list the schema API expects.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|&name| String::from(name)).collect()
    }

    /// Wraps `inner` in an `RbFieldType` and converts it into Lindera's `FieldType`.
    fn rb_to_lindera(inner: RbFieldTypeKind) -> FieldType {
        FieldType::from(RbFieldType { inner })
    }

    /// Converts a Lindera `FieldType` and returns the resulting Ruby-side kind.
    fn lindera_to_rb(field_type: FieldType) -> RbFieldTypeKind {
        RbFieldType::from(field_type).inner
    }

    #[test]
    fn test_rb_field_type_surface_to_lindera() {
        assert!(matches!(
            rb_to_lindera(RbFieldTypeKind::Surface),
            FieldType::Surface
        ));
    }

    #[test]
    fn test_rb_field_type_left_context_id_to_lindera() {
        assert!(matches!(
            rb_to_lindera(RbFieldTypeKind::LeftContextId),
            FieldType::LeftContextId
        ));
    }

    #[test]
    fn test_rb_field_type_right_context_id_to_lindera() {
        assert!(matches!(
            rb_to_lindera(RbFieldTypeKind::RightContextId),
            FieldType::RightContextId
        ));
    }

    #[test]
    fn test_rb_field_type_cost_to_lindera() {
        assert!(matches!(
            rb_to_lindera(RbFieldTypeKind::Cost),
            FieldType::Cost
        ));
    }

    #[test]
    fn test_rb_field_type_custom_to_lindera() {
        assert!(matches!(
            rb_to_lindera(RbFieldTypeKind::Custom),
            FieldType::Custom
        ));
    }

    #[test]
    fn test_lindera_field_type_surface_to_rb() {
        assert!(matches!(
            lindera_to_rb(FieldType::Surface),
            RbFieldTypeKind::Surface
        ));
    }

    #[test]
    fn test_lindera_field_type_left_context_id_to_rb() {
        assert!(matches!(
            lindera_to_rb(FieldType::LeftContextId),
            RbFieldTypeKind::LeftContextId
        ));
    }

    #[test]
    fn test_lindera_field_type_right_context_id_to_rb() {
        assert!(matches!(
            lindera_to_rb(FieldType::RightContextId),
            RbFieldTypeKind::RightContextId
        ));
    }

    #[test]
    fn test_lindera_field_type_cost_to_rb() {
        assert!(matches!(
            lindera_to_rb(FieldType::Cost),
            RbFieldTypeKind::Cost
        ));
    }

    #[test]
    fn test_lindera_field_type_custom_to_rb() {
        assert!(matches!(
            lindera_to_rb(FieldType::Custom),
            RbFieldTypeKind::Custom
        ));
    }

    #[test]
    fn test_rb_schema_new_builds_index_map() {
        let schema = RbSchema::new_internal(owned(&["a", "b", "c"]));
        for (expected, name) in ["a", "b", "c"].iter().enumerate() {
            assert_eq!(schema.get_field_index(String::from(*name)), Some(expected));
        }
        assert_eq!(schema.get_field_index(String::from("d")), None);
    }

    #[test]
    fn test_rb_schema_field_count() {
        let schema = RbSchema::new_internal(owned(&["x", "y"]));
        assert_eq!(schema.field_count(), 2);
    }

    #[test]
    fn test_rb_schema_get_custom_fields_with_more_than_4() {
        let schema = RbSchema::new_internal(owned(&[
            "surface",
            "left_context_id",
            "right_context_id",
            "cost",
            "major_pos",
            "reading",
        ]));
        assert_eq!(schema.get_custom_fields(), vec!["major_pos", "reading"]);
    }

    #[test]
    fn test_rb_schema_get_custom_fields_with_4_or_fewer() {
        let schema = RbSchema::new_internal(owned(&[
            "surface",
            "left_context_id",
            "right_context_id",
            "cost",
        ]));
        assert!(schema.get_custom_fields().is_empty());
    }

    #[test]
    fn test_rb_schema_get_custom_fields_empty() {
        let schema = RbSchema::new_internal(Vec::new());
        assert!(schema.get_custom_fields().is_empty());
    }

    #[test]
    fn test_rb_schema_create_default_has_13_fields() {
        assert_eq!(RbSchema::create_default_internal().field_count(), 13);
    }

    #[test]
    fn test_rb_schema_create_default_first_fields() {
        let schema = RbSchema::create_default_internal();
        let leading = ["surface", "left_context_id", "right_context_id", "cost"];
        for (position, name) in leading.iter().enumerate() {
            assert_eq!(schema.fields[position], *name);
        }
    }

    #[test]
    fn test_rb_schema_to_lindera_schema() {
        let fields = owned(&["a", "b", "c"]);
        let lindera_schema = Schema::from(RbSchema::new_internal(fields.clone()));
        assert_eq!(lindera_schema.get_all_fields(), &fields);
    }

    #[test]
    fn test_lindera_schema_to_rb_schema() {
        let fields = owned(&["x", "y", "z"]);
        let rb_schema = RbSchema::from(Schema::new(fields.clone()));
        assert_eq!(rb_schema.fields, fields);
        for (expected, name) in fields.iter().enumerate() {
            assert_eq!(rb_schema.get_field_index(name.clone()), Some(expected));
        }
    }

    #[test]
    fn test_rb_schema_roundtrip() {
        let fields = owned(&[
            "surface",
            "left_context_id",
            "right_context_id",
            "cost",
            "reading",
        ]);
        let lindera_schema: Schema = RbSchema::new_internal(fields.clone()).into();
        let back = RbSchema::from(lindera_schema);
        assert_eq!(back.fields, fields);
        assert_eq!(back.field_count(), 5);
    }

    #[test]
    fn test_rb_field_definition_to_lindera() {
        let lindera_def = FieldDefinition::from(RbFieldDefinition {
            index: 2,
            name: String::from("right_context_id"),
            field_type: RbFieldType {
                inner: RbFieldTypeKind::RightContextId,
            },
            description: Some(String::from("Right context ID")),
        });
        assert_eq!(lindera_def.index, 2);
        assert_eq!(lindera_def.name, "right_context_id");
        assert!(matches!(lindera_def.field_type, FieldType::RightContextId));
        assert_eq!(lindera_def.description.as_deref(), Some("Right context ID"));
    }

    #[test]
    fn test_lindera_field_definition_to_rb() {
        let rb_def = RbFieldDefinition::from(FieldDefinition {
            index: 4,
            name: String::from("major_pos"),
            field_type: FieldType::Custom,
            description: None,
        });
        assert_eq!(rb_def.index, 4);
        assert_eq!(rb_def.name, "major_pos");
        assert!(matches!(rb_def.field_type.inner, RbFieldTypeKind::Custom));
        assert!(rb_def.description.is_none());
    }
}