1use crate::sejong::SejongConverter;
28use crate::tokenizer::{Token, Tokenizer};
29use std::collections::HashMap;
30use std::fs::File;
31use std::io::{BufRead, BufReader};
32use std::path::Path;
33use thiserror::Error;
34
/// Errors produced while loading or evaluating a gold-standard test dataset.
#[derive(Error, Debug)]
pub enum EvaluateError {
    /// Underlying I/O failure while opening or reading a dataset file.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// A line or token did not match the expected `surface/pos` or TSV format.
    #[error("Parse error: {0}")]
    Parse(String),

    /// Data was syntactically valid but unusable (e.g. empty token list or dataset).
    #[error("Data error: {0}")]
    Data(String),
}

/// Convenience result alias used throughout this module.
pub type Result<T> = std::result::Result<T, EvaluateError>;
53
/// A single gold-standard morpheme: a surface form paired with its POS tag.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GoldToken {
    /// Surface form (jamo-normalized when built via [`GoldToken::parse`]).
    pub surface: String,
    /// Part-of-speech tag, e.g. `NNG`, `JX`.
    pub pos: String,
}
62
63impl GoldToken {
64 #[must_use]
71 pub const fn new(surface: String, pos: String) -> Self {
72 Self { surface, pos }
73 }
74
75 pub fn parse(s: &str) -> Result<Self> {
85 let parts: Vec<&str> = s.split('/').collect();
86 if parts.len() != 2 {
87 return Err(EvaluateError::Parse(format!(
88 "Invalid token format: {s} (expected surface/pos)"
89 )));
90 }
91
92 Ok(Self {
93 surface: SejongConverter::normalize_jamo(parts[0]),
94 pos: parts[1].to_string(),
95 })
96 }
97}
98
/// A gold-standard sentence: raw text plus its reference token sequence.
#[derive(Debug, Clone)]
pub struct GoldSentence {
    /// The original sentence text fed to the tokenizer.
    pub text: String,
    /// Reference tokens in order; never empty when built via `parse_tsv_line`.
    pub tokens: Vec<GoldToken>,
}
107
108impl GoldSentence {
109 #[must_use]
116 pub const fn new(text: String, tokens: Vec<GoldToken>) -> Self {
117 Self { text, tokens }
118 }
119
120 pub fn parse_tsv_line(line: &str) -> Result<Self> {
133 let parts: Vec<&str> = line.split('\t').collect();
134 if parts.len() != 2 {
135 return Err(EvaluateError::Parse(format!(
136 "Invalid TSV line: {line} (expected text\\ttokens)"
137 )));
138 }
139
140 let text = parts[0].trim().to_string();
141 let tokens_str = parts[1].trim();
142
143 let tokens = tokens_str
144 .split_whitespace()
145 .map(GoldToken::parse)
146 .collect::<Result<Vec<_>>>()?;
147
148 if tokens.is_empty() {
149 return Err(EvaluateError::Data(format!(
150 "Empty gold tokens for text: {text}"
151 )));
152 }
153
154 Ok(Self { text, tokens })
155 }
156}
157
/// A collection of gold-standard sentences loaded for evaluation.
#[derive(Debug, Clone)]
pub struct TestDataset {
    /// All sentences in file order.
    pub sentences: Vec<GoldSentence>,
}
164
165impl TestDataset {
166 #[must_use]
168 pub const fn new() -> Self {
169 Self {
170 sentences: Vec::new(),
171 }
172 }
173
174 pub fn from_tsv<P: AsRef<Path>>(path: P) -> Result<Self> {
190 let file = File::open(path)?;
191 let reader = BufReader::new(file);
192
193 let mut sentences = Vec::new();
194
195 for (line_num, line) in reader.lines().enumerate() {
196 let line = line?;
197 let trimmed = line.trim();
198
199 if trimmed.is_empty() || trimmed.starts_with('#') {
201 continue;
202 }
203
204 let sentence = GoldSentence::parse_tsv_line(trimmed)
205 .map_err(|e| EvaluateError::Parse(format!("Line {}: {}", line_num + 1, e)))?;
206
207 sentences.push(sentence);
208 }
209
210 if sentences.is_empty() {
211 return Err(EvaluateError::Data("Empty dataset".to_string()));
212 }
213
214 Ok(Self { sentences })
215 }
216
217 pub fn add_sentence(&mut self, sentence: GoldSentence) {
223 self.sentences.push(sentence);
224 }
225
226 #[must_use]
228 pub fn len(&self) -> usize {
229 self.sentences.len()
230 }
231
232 #[must_use]
234 pub fn is_empty(&self) -> bool {
235 self.sentences.is_empty()
236 }
237}
238
239impl Default for TestDataset {
240 fn default() -> Self {
241 Self::new()
242 }
243}
244
/// Aggregate metrics produced by one evaluation run over a dataset.
#[derive(Debug, Clone)]
pub struct EvaluationResult {
    /// Number of sentences in the evaluated dataset.
    pub total_sentences: usize,
    /// Total gold-standard tokens across all sentences.
    pub total_gold_tokens: usize,
    /// Total predicted tokens across all sentences.
    pub total_pred_tokens: usize,

    /// Tokens where both surface and POS matched.
    pub true_positives: usize,
    /// Predicted tokens that were not true positives.
    pub false_positives: usize,
    /// Gold tokens that were not true positives.
    pub false_negatives: usize,

    /// Sentences whose prediction matched the gold sequence exactly.
    pub exact_match_sentences: usize,

    /// `true_positives / total_gold_tokens`.
    pub token_accuracy: f64,
    /// `exact_match_sentences / total_sentences`.
    pub sentence_accuracy: f64,
    /// Correct POS tags over all gold tokens, aggregated from `pos_stats`.
    pub pos_accuracy: f64,
    /// `true_positives / total_pred_tokens`.
    pub precision: f64,
    /// `true_positives / total_gold_tokens` (equals `token_accuracy` here).
    pub recall: f64,
    /// Harmonic mean of `precision` and `recall`.
    pub f1_score: f64,

    /// Per-POS breakdown, keyed by gold POS tag.
    pub pos_stats: HashMap<String, PosStats>,
}
281
/// Counters and accuracy for a single POS tag.
#[derive(Debug, Clone, Default)]
pub struct PosStats {
    /// Gold tokens carrying this POS tag.
    pub gold_count: usize,
    /// Same-index predictions whose surface matched a gold token of this tag.
    pub pred_count: usize,
    /// Surface matches where the predicted POS also matched.
    pub correct: usize,
    /// `correct / gold_count` (stays 0.0 when `gold_count` is 0).
    pub accuracy: f64,
}
294
295impl EvaluationResult {
296 #[must_use]
298 pub fn new() -> Self {
299 Self {
300 total_sentences: 0,
301 total_gold_tokens: 0,
302 total_pred_tokens: 0,
303 true_positives: 0,
304 false_positives: 0,
305 false_negatives: 0,
306 exact_match_sentences: 0,
307 token_accuracy: 0.0,
308 sentence_accuracy: 0.0,
309 pos_accuracy: 0.0,
310 precision: 0.0,
311 recall: 0.0,
312 f1_score: 0.0,
313 pos_stats: HashMap::new(),
314 }
315 }
316
317 #[must_use]
323 #[allow(clippy::cast_precision_loss, clippy::unwrap_used)]
324 pub fn format_report(&self) -> String {
325 use std::fmt::Write;
326
327 let mut report = String::new();
328
329 report.push_str("=== 정확도 평가 결과 ===\n");
330 writeln!(report, "테스트 문장: {}", self.total_sentences).unwrap();
331 writeln!(
332 report,
333 "Token Accuracy: {:.1}%",
334 self.token_accuracy * 100.0
335 )
336 .unwrap();
337 writeln!(
338 report,
339 "Sentence Accuracy: {:.1}%",
340 self.sentence_accuracy * 100.0
341 )
342 .unwrap();
343 writeln!(report, "POS Accuracy: {:.1}%", self.pos_accuracy * 100.0).unwrap();
344 writeln!(report, "Precision: {:.3}", self.precision).unwrap();
345 writeln!(report, "Recall: {:.3}", self.recall).unwrap();
346 writeln!(report, "F1 Score: {:.3}", self.f1_score).unwrap();
347 report.push('\n');
348
349 report.push_str("토큰 통계:\n");
350 writeln!(report, " 정답 토큰: {}", self.total_gold_tokens).unwrap();
351 writeln!(report, " 예측 토큰: {}", self.total_pred_tokens).unwrap();
352 writeln!(
353 report,
354 " 완전 일치 문장: {} / {} ({:.1}%)",
355 self.exact_match_sentences,
356 self.total_sentences,
357 (self.exact_match_sentences as f64 / self.total_sentences as f64) * 100.0
358 )
359 .unwrap();
360 report.push('\n');
361
362 let mut pos_sorted: Vec<_> = self.pos_stats.iter().collect();
364 pos_sorted.sort_by(|a, b| b.1.gold_count.cmp(&a.1.gold_count));
365
366 if !pos_sorted.is_empty() {
367 report.push_str("품사별 정확도:\n");
368 for (pos, stats) in pos_sorted.iter().take(15) {
369 writeln!(
370 report,
371 " {pos:<6} ({}개): {:.1}%",
372 stats.gold_count,
373 stats.accuracy * 100.0
374 )
375 .unwrap();
376 }
377
378 if pos_sorted.len() > 15 {
379 writeln!(report, " ... 외 {}개 품사", pos_sorted.len() - 15).unwrap();
380 }
381 }
382
383 report
384 }
385}
386
387impl Default for EvaluationResult {
388 fn default() -> Self {
389 Self::new()
390 }
391}
392
393#[must_use]
404pub fn evaluate_tokens(
405 gold_tokens: &[GoldToken],
406 pred_tokens: &[Token],
407) -> (usize, usize, usize, usize) {
408 let min_len = gold_tokens.len().min(pred_tokens.len());
409
410 let mut true_positives = 0;
411 let mut pos_match = 0;
412
413 for i in 0..min_len {
415 let gold = &gold_tokens[i];
416 let pred = &pred_tokens[i];
417
418 if gold.surface == pred.surface && gold.pos == pred.pos {
419 true_positives += 1;
420 pos_match += 1;
421 } else if gold.surface == pred.surface {
422 pos_match += 1;
423 }
424 }
425
426 let false_positives = pred_tokens.len().saturating_sub(true_positives);
427 let false_negatives = gold_tokens.len().saturating_sub(true_positives);
428
429 (true_positives, false_positives, false_negatives, pos_match)
430}
431
/// Scores predicted tokens against gold tokens with a small re-alignment
/// window that absorbs segmentation drift between the two sequences.
///
/// The sequences are walked in parallel. When the current surfaces agree,
/// the pair is consumed and counted (`pos_match` for the surface match,
/// plus a true positive when the POS also agrees). On a surface mismatch,
/// up to 3 tokens ahead are scanned — first in the predictions, then in
/// the gold sequence — for a surface equal to the other side's current
/// token; if found, the lagging index jumps forward so the streams
/// resynchronize (the rediscovered pair is counted on the next loop
/// iteration). If neither scan finds a match, both indices advance by one.
///
/// Returns `(true_positives, false_positives, false_negatives, pos_match)`;
/// FP/FN are derived from the sequence totals minus the true positives,
/// so tokens skipped during resynchronization land there implicitly.
#[must_use]
pub fn evaluate_tokens_aligned(
    gold_tokens: &[GoldToken],
    pred_tokens: &[Token],
) -> (usize, usize, usize, usize) {
    let mut true_positives = 0;
    let mut pos_match = 0;

    let mut gold_idx = 0;
    let mut pred_idx = 0;

    while gold_idx < gold_tokens.len() && pred_idx < pred_tokens.len() {
        let gold = &gold_tokens[gold_idx];
        let pred = &pred_tokens[pred_idx];

        if gold.surface == pred.surface {
            // Surfaces agree at the current position: consume the pair.
            pos_match += 1;
            if gold.pos == pred.pos {
                true_positives += 1;
            }
            gold_idx += 1;
            pred_idx += 1;
        } else {
            let mut found = false;

            // First try to skip extra predicted tokens (look ahead up to 3).
            for look_ahead in 1..=3 {
                if pred_idx + look_ahead < pred_tokens.len()
                    && pred_tokens[pred_idx + look_ahead].surface == gold.surface
                {
                    pred_idx += look_ahead;
                    found = true;
                    break;
                }
            }

            // Otherwise try to skip extra gold tokens the same way.
            if !found {
                for look_ahead in 1..=3 {
                    if gold_idx + look_ahead < gold_tokens.len()
                        && gold_tokens[gold_idx + look_ahead].surface == pred.surface
                    {
                        gold_idx += look_ahead;
                        found = true;
                        break;
                    }
                }
            }

            // No resynchronization point within the window: give up on this
            // pair and move both cursors forward.
            if !found {
                gold_idx += 1;
                pred_idx += 1;
            }
        }
    }

    let false_positives = pred_tokens.len().saturating_sub(true_positives);
    let false_negatives = gold_tokens.len().saturating_sub(true_positives);

    (true_positives, false_positives, false_negatives, pos_match)
}
511
/// Evaluates `tokenizer` against `dataset` using strict positional token
/// comparison (no tag-set conversion, no re-alignment).
///
/// Fills every field of [`EvaluationResult`]: raw counts, the exact-match
/// sentence count, per-POS statistics, and the derived accuracy /
/// precision / recall / F1 metrics. Under this positional scheme, recall
/// equals token accuracy (both are TP over gold tokens).
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn evaluate_dataset(tokenizer: &mut Tokenizer, dataset: &TestDataset) -> EvaluationResult {
    let mut result = EvaluationResult::new();
    result.total_sentences = dataset.len();

    for gold_sentence in &dataset.sentences {
        let pred_tokens = tokenizer.tokenize(&gold_sentence.text);

        result.total_gold_tokens += gold_sentence.tokens.len();
        result.total_pred_tokens += pred_tokens.len();

        // Positional scoring: token i of the prediction vs token i of gold.
        let (tp, fp, fn_, _pos_match) = evaluate_tokens(&gold_sentence.tokens, &pred_tokens);

        result.true_positives += tp;
        result.false_positives += fp;
        result.false_negatives += fn_;

        // Exact match: same length and every position a surface+POS match.
        if gold_sentence.tokens.len() == pred_tokens.len() && tp == gold_sentence.tokens.len() {
            result.exact_match_sentences += 1;
        }

        // Per-POS bookkeeping, also positional: a gold token only earns
        // pred/correct credit when the same-index prediction shares its
        // surface.
        for (i, gold_token) in gold_sentence.tokens.iter().enumerate() {
            let pos_stat = result.pos_stats.entry(gold_token.pos.clone()).or_default();

            pos_stat.gold_count += 1;

            if i < pred_tokens.len() {
                let pred_token = &pred_tokens[i];
                if gold_token.surface == pred_token.surface {
                    pos_stat.pred_count += 1;
                    if gold_token.pos == pred_token.pos {
                        pos_stat.correct += 1;
                    }
                }
            }
        }
    }

    // Derived metrics; every divisor is guarded against zero.
    let total_tokens = result.total_gold_tokens;
    if total_tokens > 0 {
        result.token_accuracy = result.true_positives as f64 / total_tokens as f64;
    }

    if result.total_sentences > 0 {
        result.sentence_accuracy =
            result.exact_match_sentences as f64 / result.total_sentences as f64;
    }

    let total_pred = result.total_pred_tokens;
    if total_pred > 0 {
        result.precision = result.true_positives as f64 / total_pred as f64;
    }

    if total_tokens > 0 {
        result.recall = result.true_positives as f64 / total_tokens as f64;
    }

    if result.precision + result.recall > 0.0 {
        result.f1_score =
            2.0 * (result.precision * result.recall) / (result.precision + result.recall);
    }

    // Finalize per-POS accuracy and the gold-weighted overall POS accuracy.
    let mut total_pos_correct = 0;
    let mut total_pos_gold = 0;

    for pos_stat in result.pos_stats.values_mut() {
        if pos_stat.gold_count > 0 {
            pos_stat.accuracy = pos_stat.correct as f64 / pos_stat.gold_count as f64;
        }
        total_pos_correct += pos_stat.correct;
        total_pos_gold += pos_stat.gold_count;
    }

    if total_pos_gold > 0 {
        result.pos_accuracy = total_pos_correct as f64 / total_pos_gold as f64;
    }

    result
}
606
607#[allow(clippy::cast_precision_loss)]
621pub fn evaluate_dataset_sejong(
622 tokenizer: &mut Tokenizer,
623 dataset: &TestDataset,
624) -> EvaluationResult {
625 let converter = SejongConverter::new();
626 let mut result = EvaluationResult::new();
627 result.total_sentences = dataset.len();
628
629 for gold_sentence in &dataset.sentences {
630 let pred_tokens = tokenizer.tokenize(&gold_sentence.text);
631
632 let sejong_tokens = converter.convert_tokens(&pred_tokens);
634
635 let converted_pred: Vec<Token> = sejong_tokens
637 .iter()
638 .map(|st| Token {
639 surface: SejongConverter::normalize_jamo(&st.surface),
640 pos: st.pos.clone(),
641 start_pos: st.start_pos,
642 end_pos: st.end_pos,
643 start_byte: 0,
644 end_byte: 0,
645 reading: None,
646 lemma: None,
647 cost: 0,
648 features: String::new(),
649 normalized: None,
650 })
651 .collect();
652
653 result.total_gold_tokens += gold_sentence.tokens.len();
654 result.total_pred_tokens += converted_pred.len();
655
656 let (tp, fp, fn_, _pos_match) =
657 evaluate_tokens_aligned(&gold_sentence.tokens, &converted_pred);
658
659 result.true_positives += tp;
660 result.false_positives += fp;
661 result.false_negatives += fn_;
662
663 if gold_sentence.tokens.len() == converted_pred.len() && tp == gold_sentence.tokens.len() {
665 result.exact_match_sentences += 1;
666 }
667
668 for (i, gold_token) in gold_sentence.tokens.iter().enumerate() {
670 let pos_stat = result
671 .pos_stats
672 .entry(gold_token.pos.clone())
673 .or_insert_with(|| PosStats {
674 gold_count: 0,
675 pred_count: 0,
676 correct: 0,
677 accuracy: 0.0,
678 });
679 pos_stat.gold_count += 1;
680
681 if i < converted_pred.len() {
682 let pred_token = &converted_pred[i];
683 if gold_token.surface == pred_token.surface {
684 pos_stat.pred_count += 1;
685 if gold_token.pos == pred_token.pos {
686 pos_stat.correct += 1;
687 }
688 }
689 }
690 }
691 }
692
693 let total_tokens = result.total_gold_tokens;
695 if total_tokens > 0 {
696 result.token_accuracy = result.true_positives as f64 / total_tokens as f64;
697 }
698
699 if result.total_sentences > 0 {
700 result.sentence_accuracy =
701 result.exact_match_sentences as f64 / result.total_sentences as f64;
702 }
703
704 let total_pred = result.total_pred_tokens;
705 if total_pred > 0 {
706 result.precision = result.true_positives as f64 / total_pred as f64;
707 }
708
709 if total_tokens > 0 {
710 result.recall = result.true_positives as f64 / total_tokens as f64;
711 }
712
713 if result.precision + result.recall > 0.0 {
714 result.f1_score =
715 2.0 * (result.precision * result.recall) / (result.precision + result.recall);
716 }
717
718 let mut total_pos_correct = 0;
720 let mut total_pos_gold = 0;
721
722 for pos_stat in result.pos_stats.values_mut() {
723 if pos_stat.gold_count > 0 {
724 pos_stat.accuracy = pos_stat.correct as f64 / pos_stat.gold_count as f64;
725 }
726 total_pos_correct += pos_stat.correct;
727 total_pos_gold += pos_stat.gold_count;
728 }
729
730 if total_pos_gold > 0 {
731 result.pos_accuracy = total_pos_correct as f64 / total_pos_gold as f64;
732 }
733
734 result
735}
736
#[cfg(test)]
mod tests {
    use super::*;

    // `surface/pos` pairs parse; anything without exactly one '/' is rejected.
    #[test]
    fn test_gold_token_parse() {
        let token = GoldToken::parse("나/NP").unwrap();
        assert_eq!(token.surface, "나");
        assert_eq!(token.pos, "NP");

        assert!(GoldToken::parse("invalid").is_err());
        assert!(GoldToken::parse("too/many/parts").is_err());
    }

    // A well-formed TSV line yields the text column plus its token list.
    #[test]
    fn test_gold_sentence_parse() {
        let sentence =
            GoldSentence::parse_tsv_line("나는 학생이다\t나/NP 는/JX 학생/NNG 이/VCP 다/EF")
                .unwrap();
        assert_eq!(sentence.text, "나는 학생이다");
        assert_eq!(sentence.tokens.len(), 5);
        assert_eq!(sentence.tokens[0].surface, "나");
        assert_eq!(sentence.tokens[0].pos, "NP");
    }

    // Identical sequences: all true positives, no FP/FN.
    #[test]
    fn test_evaluate_tokens_perfect_match() {
        let gold = vec![
            GoldToken::new("나".to_string(), "NP".to_string()),
            GoldToken::new("는".to_string(), "JX".to_string()),
        ];

        let pred = vec![
            Token {
                surface: "나".to_string(),
                pos: "NP".to_string(),
                start_pos: 0,
                end_pos: 1,
                start_byte: 0,
                end_byte: 3,
                reading: None,
                lemma: None,
                cost: 0,
                features: String::new(),
                normalized: None,
            },
            Token {
                surface: "는".to_string(),
                pos: "JX".to_string(),
                start_pos: 1,
                end_pos: 2,
                start_byte: 3,
                end_byte: 6,
                reading: None,
                lemma: None,
                cost: 0,
                features: String::new(),
                normalized: None,
            },
        ];

        let (tp, fp, fn_, _) = evaluate_tokens(&gold, &pred);
        assert_eq!(tp, 2);
        assert_eq!(fp, 0);
        assert_eq!(fn_, 0);
    }

    // A missing prediction shows up as a false negative, not a false positive.
    #[test]
    fn test_evaluate_tokens_mismatch() {
        let gold = vec![
            GoldToken::new("나".to_string(), "NP".to_string()),
            GoldToken::new("는".to_string(), "JX".to_string()),
        ];

        let pred = vec![Token {
            surface: "나".to_string(),
            pos: "NP".to_string(),
            start_pos: 0,
            end_pos: 1,
            start_byte: 0,
            end_byte: 3,
            reading: None,
            lemma: None,
            cost: 0,
            features: String::new(),
            normalized: None,
        }];

        let (tp, fp, fn_, _) = evaluate_tokens(&gold, &pred);
        assert_eq!(tp, 1);
        assert_eq!(fp, 0);
        assert_eq!(fn_, 1);
    }

    // The rendered report contains the headline figures with the expected
    // formatting (percentages to one decimal, F1 to three).
    #[test]
    fn test_evaluation_result_format() {
        let mut result = EvaluationResult::new();
        result.total_sentences = 10;
        result.total_gold_tokens = 50;
        result.total_pred_tokens = 48;
        result.true_positives = 45;
        result.false_positives = 3;
        result.false_negatives = 5;
        result.exact_match_sentences = 7;
        result.token_accuracy = 0.9;
        result.sentence_accuracy = 0.7;
        result.pos_accuracy = 0.92;
        result.precision = 0.9375;
        result.recall = 0.9;
        result.f1_score = 0.9184;

        let report = result.format_report();
        assert!(report.contains("테스트 문장: 10"));
        assert!(report.contains("Token Accuracy: 90.0%"));
        assert!(report.contains("F1 Score: 0.918"));
    }

    // End-to-end TSV loading: comments and blank lines are skipped and the
    // remaining lines become sentences in file order.
    #[test]
    #[cfg(feature = "test-utils")]
    fn test_dataset_from_tsv() {
        use std::io::Write;

        let mut file = tempfile::NamedTempFile::new().unwrap();
        writeln!(file, "# 주석").unwrap();
        writeln!(file, "").unwrap();
        writeln!(file, "나는 학생\t나/NP 는/JX 학생/NNG").unwrap();
        writeln!(file, "오늘 날씨\t오늘/NNG 날씨/NNG").unwrap();
        file.flush().unwrap();

        let dataset = TestDataset::from_tsv(file.path()).unwrap();
        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.sentences[0].text, "나는 학생");
        assert_eq!(dataset.sentences[0].tokens.len(), 3);
        assert_eq!(dataset.sentences[1].text, "오늘 날씨");
        assert_eq!(dataset.sentences[1].tokens.len(), 2);
    }
}