Skip to main content

mecab_ko_core/
nbest.rs

1//! 개선된 N-best Viterbi 알고리즘
2//!
3//! 진정한 N-best 경로 탐색을 위한 알고리즘을 제공합니다.
4//!
5//! # 개요
6//!
7//! 기존 N-best 구현은 1-best forward pass 후 단일 경로만 추적했습니다.
8//! 이 모듈은 각 노드에서 K개의 최선 후보를 유지하여 진정한 N-best 결과를 제공합니다.
9//!
10//! # 알고리즘
11//!
12//! 1. **Forward Pass (K-best)**: 각 노드에서 상위 K개의 경로 후보를 유지
13//! 2. **Backward Pass (N-best)**: EOS에서 시작하여 N개의 최적 경로를 추출
14//!
15//! # Example
16//!
17//! ```rust,no_run
18//! use mecab_ko_core::nbest::{ImprovedNbestSearcher, NbestPath};
19//! use mecab_ko_core::lattice::Lattice;
20//! use mecab_ko_core::viterbi::ZeroConnectionCost;
21//!
22//! let mut lattice = Lattice::new("한국어");
23//! // ... 노드 추가 ...
24//!
25//! let searcher = ImprovedNbestSearcher::new(5);
26//! let conn_cost = ZeroConnectionCost;
27//! let results = searcher.search(&mut lattice, &conn_cost);
28//!
29//! for path in results.iter() {
30//!     println!("Cost: {}, Tokens: {:?}", path.cost(), path.surfaces(&lattice));
31//! }
32//! ```
33
34use crate::lattice::{Lattice, Node, NodeId, NodeType, INVALID_NODE_ID};
35use crate::viterbi::{ConnectionCost, SpacePenalty};
36use std::cmp::Ordering;
37use std::collections::BinaryHeap;
38
39/// N-best 경로 하나를 표현
40#[derive(Debug, Clone)]
41pub struct NbestPath {
42    /// 경로의 노드 ID 목록 (BOS, EOS 제외)
43    pub node_ids: Vec<NodeId>,
44    /// 총 비용
45    pub total_cost: i32,
46    /// 경로 순위 (0-based)
47    pub rank: usize,
48}
49
50impl NbestPath {
51    /// 새 N-best 경로 생성
52    #[must_use]
53    pub const fn new(node_ids: Vec<NodeId>, total_cost: i32, rank: usize) -> Self {
54        Self {
55            node_ids,
56            total_cost,
57            rank,
58        }
59    }
60
61    /// 경로가 비어있는지 확인
62    #[must_use]
63    pub fn is_empty(&self) -> bool {
64        self.node_ids.is_empty()
65    }
66
67    /// 경로의 노드 수
68    #[must_use]
69    pub fn len(&self) -> usize {
70        self.node_ids.len()
71    }
72
73    /// 총 비용
74    #[must_use]
75    pub const fn cost(&self) -> i32 {
76        self.total_cost
77    }
78
79    /// 경로의 노드들 반복자
80    pub fn nodes<'a>(&'a self, lattice: &'a Lattice) -> impl Iterator<Item = &'a Node> + 'a {
81        self.node_ids.iter().filter_map(|&id| lattice.node(id))
82    }
83
84    /// 표면형 목록 반환
85    #[must_use]
86    pub fn surfaces<'a>(&'a self, lattice: &'a Lattice) -> Vec<&'a str> {
87        self.nodes(lattice).map(|n| n.surface.as_ref()).collect()
88    }
89
90    /// 품사 태그 목록 반환
91    #[must_use]
92    pub fn pos_tags<'a>(&'a self, lattice: &'a Lattice) -> Vec<&'a str> {
93        self.nodes(lattice)
94            .map(|n| {
95                n.feature
96                    .split(',')
97                    .next()
98                    .unwrap_or_default()
99            })
100            .collect()
101    }
102}
103
104/// N-best 검색 결과
105#[derive(Debug, Clone, Default)]
106pub struct NbestResult {
107    /// 경로 목록 (비용 오름차순)
108    paths: Vec<NbestPath>,
109}
110
111impl NbestResult {
112    /// 새 결과 생성
113    #[must_use]
114    pub const fn new(paths: Vec<NbestPath>) -> Self {
115        Self { paths }
116    }
117
118    /// 결과가 비어있는지 확인
119    #[must_use]
120    pub fn is_empty(&self) -> bool {
121        self.paths.is_empty()
122    }
123
124    /// 결과 수
125    #[must_use]
126    pub fn len(&self) -> usize {
127        self.paths.len()
128    }
129
130    /// 최선 경로 (1-best)
131    #[must_use]
132    pub fn best(&self) -> Option<&NbestPath> {
133        self.paths.first()
134    }
135
136    /// 인덱스로 경로 조회
137    #[must_use]
138    pub fn get(&self, index: usize) -> Option<&NbestPath> {
139        self.paths.get(index)
140    }
141
142    /// 경로 반복자
143    pub fn iter(&self) -> impl Iterator<Item = &NbestPath> {
144        self.paths.iter()
145    }
146
147    /// 경로 벡터로 변환
148    #[must_use]
149    pub fn into_paths(self) -> Vec<NbestPath> {
150        self.paths
151    }
152
153    /// (노드 ID 목록, 비용) 쌍으로 변환 (호환성용)
154    #[must_use]
155    pub fn to_pairs(&self) -> Vec<(Vec<NodeId>, i32)> {
156        self.paths
157            .iter()
158            .map(|p| (p.node_ids.clone(), p.total_cost))
159            .collect()
160    }
161}
162
163impl IntoIterator for NbestResult {
164    type Item = NbestPath;
165    type IntoIter = std::vec::IntoIter<NbestPath>;
166
167    fn into_iter(self) -> Self::IntoIter {
168        self.paths.into_iter()
169    }
170}
171
172/// 노드별 K-best 후보 저장
173#[derive(Debug, Clone)]
174struct NodeCandidate {
175    /// 이 후보까지의 총 비용
176    cost: i32,
177    /// 이전 노드 ID
178    prev_node_id: NodeId,
179    /// 이전 노드에서의 후보 인덱스
180    prev_candidate_idx: usize,
181}
182
183impl Eq for NodeCandidate {}
184
185impl PartialEq for NodeCandidate {
186    fn eq(&self, other: &Self) -> bool {
187        self.cost == other.cost
188    }
189}
190
191impl Ord for NodeCandidate {
192    fn cmp(&self, other: &Self) -> Ordering {
193        // 비용이 낮은 것이 우선 (Min-heap처럼 동작하도록 역순)
194        other.cost.cmp(&self.cost)
195    }
196}
197
198impl PartialOrd for NodeCandidate {
199    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
200        Some(self.cmp(other))
201    }
202}
203
204/// Backward pass용 경로 후보
205#[derive(Debug, Clone)]
206struct BackwardCandidate {
207    /// 현재 노드 ID
208    node_id: NodeId,
209    /// 현재 노드에서의 후보 인덱스
210    candidate_idx: usize,
211    /// 총 비용
212    cost: i32,
213    /// 지금까지의 경로 (역순, BOS 제외)
214    path: Vec<NodeId>,
215}
216
217impl Eq for BackwardCandidate {}
218
219impl PartialEq for BackwardCandidate {
220    fn eq(&self, other: &Self) -> bool {
221        self.cost == other.cost
222    }
223}
224
225impl Ord for BackwardCandidate {
226    fn cmp(&self, other: &Self) -> Ordering {
227        // Min-heap: 비용이 낮은 것이 우선
228        other.cost.cmp(&self.cost)
229    }
230}
231
232impl PartialOrd for BackwardCandidate {
233    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
234        Some(self.cmp(other))
235    }
236}
237
238/// 개선된 N-best Viterbi 탐색기
239///
240/// 각 노드에서 K개의 최선 후보를 유지하여 진정한 N-best 결과를 제공합니다.
241#[derive(Debug, Clone)]
242pub struct ImprovedNbestSearcher {
243    /// 최대 결과 수 (N)
244    max_results: usize,
245    /// 각 노드에서 유지할 최대 후보 수 (K)
246    /// 일반적으로 N보다 크거나 같아야 좋은 결과를 얻음
247    max_candidates_per_node: usize,
248    /// 띄어쓰기 패널티 설정
249    space_penalty: SpacePenalty,
250}
251
252impl ImprovedNbestSearcher {
253    /// 새 N-best 탐색기 생성
254    ///
255    /// # Arguments
256    ///
257    /// * `n` - 반환할 최대 경로 수
258    #[must_use]
259    pub fn new(n: usize) -> Self {
260        Self {
261            max_results: n,
262            // K는 N의 2배로 설정 (더 많은 후보 탐색)
263            max_candidates_per_node: n.max(2) * 2,
264            space_penalty: SpacePenalty::default(),
265        }
266    }
267
268    /// 노드당 최대 후보 수 설정
269    #[must_use]
270    pub const fn with_max_candidates(mut self, k: usize) -> Self {
271        self.max_candidates_per_node = k;
272        self
273    }
274
275    /// 띄어쓰기 패널티 설정
276    #[must_use]
277    pub fn with_space_penalty(mut self, penalty: SpacePenalty) -> Self {
278        self.space_penalty = penalty;
279        self
280    }
281
282    /// N-best 경로 탐색
283    ///
284    /// # Arguments
285    ///
286    /// * `lattice` - 노드가 추가된 Lattice
287    /// * `conn_cost` - 연접 비용 조회 인터페이스
288    ///
289    /// # Returns
290    ///
291    /// N-best 검색 결과
292    pub fn search<C: ConnectionCost>(
293        &self,
294        lattice: &mut Lattice,
295        conn_cost: &C,
296    ) -> NbestResult {
297        if lattice.node_count() <= 2 {
298            // BOS, EOS만 있는 경우
299            return NbestResult::default();
300        }
301
302        // 1. Forward pass: 각 노드에서 K-best 후보 계산
303        let candidates = self.forward_pass_kbest(lattice, conn_cost);
304
305        // 2. Backward pass: EOS에서 시작하여 N-best 경로 추출
306        self.backward_pass_nbest(lattice, &candidates)
307    }
308
309    /// K-best Forward Pass
310    ///
311    /// 각 노드에서 상위 K개의 경로 후보를 유지합니다.
312    fn forward_pass_kbest<C: ConnectionCost>(
313        &self,
314        lattice: &mut Lattice,
315        conn_cost: &C,
316    ) -> Vec<Vec<NodeCandidate>> {
317        let node_count = lattice.node_count();
318        let char_len = lattice.char_len();
319
320        // 각 노드별 K-best 후보 저장
321        let mut candidates: Vec<Vec<NodeCandidate>> = vec![Vec::new(); node_count];
322
323        // BOS 노드의 초기 후보
324        let bos_id = lattice.bos().id;
325        candidates[bos_id as usize].push(NodeCandidate {
326            cost: 0,
327            prev_node_id: INVALID_NODE_ID,
328            prev_candidate_idx: 0,
329        });
330
331        // 재사용 가능한 버퍼
332        let mut starting_ids: Vec<NodeId> = Vec::new();
333        let mut ending_data: Vec<(NodeId, u16)> = Vec::new();
334
335        // 위치 0부터 끝까지 순회
336        for pos in 0..=char_len {
337            // 이 위치에서 시작하는 노드들
338            starting_ids.clear();
339            starting_ids.extend(lattice.nodes_starting_at(pos).map(|n| n.id));
340
341            // 이 위치에서 끝나는 노드들의 정보
342            ending_data.clear();
343            ending_data.extend(
344                lattice
345                    .nodes_ending_at(pos)
346                    .map(|n| (n.id, n.right_id)),
347            );
348
349            for &node_id in &starting_ids {
350                let (left_id, word_cost, has_space) = {
351                    let Some(node) = lattice.node(node_id) else {
352                        continue;
353                    };
354                    (node.left_id, node.word_cost, node.has_space_before)
355                };
356
357                // 띄어쓰기 패널티
358                let space_penalty = if has_space {
359                    self.space_penalty.get(left_id)
360                } else {
361                    0
362                };
363
364                // 이 노드로 올 수 있는 모든 후보 수집
365                let mut new_candidates: BinaryHeap<NodeCandidate> = BinaryHeap::new();
366
367                for &(prev_id, prev_right_id) in &ending_data {
368                    let prev_candidates = &candidates[prev_id as usize];
369                    if prev_candidates.is_empty() {
370                        continue;
371                    }
372
373                    let connection = conn_cost.cost(prev_right_id, left_id);
374
375                    for (idx, prev_cand) in prev_candidates.iter().enumerate() {
376                        if prev_cand.cost == i32::MAX {
377                            continue;
378                        }
379
380                        let total = prev_cand
381                            .cost
382                            .saturating_add(connection)
383                            .saturating_add(word_cost)
384                            .saturating_add(space_penalty);
385
386                        new_candidates.push(NodeCandidate {
387                            cost: total,
388                            prev_node_id: prev_id,
389                            prev_candidate_idx: idx,
390                        });
391                    }
392                }
393
394                // 상위 K개만 유지
395                let k = self.max_candidates_per_node;
396                let mut selected: Vec<NodeCandidate> = Vec::with_capacity(k);
397
398                while selected.len() < k {
399                    if let Some(cand) = new_candidates.pop() {
400                        selected.push(cand);
401                    } else {
402                        break;
403                    }
404                }
405
406                candidates[node_id as usize] = selected;
407
408                // 1-best 정보를 Lattice에도 업데이트 (기존 호환성)
409                if let Some(best) = candidates[node_id as usize].first() {
410                    if let Some(node) = lattice.node_mut(node_id) {
411                        node.total_cost = best.cost;
412                        node.prev_node_id = best.prev_node_id;
413                    }
414                }
415            }
416        }
417
418        candidates
419    }
420
421    /// N-best Backward Pass
422    ///
423    /// EOS에서 시작하여 N개의 최적 경로를 추출합니다.
424    fn backward_pass_nbest(
425        &self,
426        lattice: &Lattice,
427        candidates: &[Vec<NodeCandidate>],
428    ) -> NbestResult {
429        let eos = lattice.eos();
430        let eos_candidates = &candidates[eos.id as usize];
431
432        if eos_candidates.is_empty() {
433            return NbestResult::default();
434        }
435
436        let mut results: Vec<NbestPath> = Vec::with_capacity(self.max_results);
437        let mut heap: BinaryHeap<BackwardCandidate> = BinaryHeap::new();
438
439        // EOS의 모든 K-best 후보를 시작점으로 추가
440        for (idx, cand) in eos_candidates.iter().enumerate() {
441            heap.push(BackwardCandidate {
442                node_id: eos.id,
443                candidate_idx: idx,
444                cost: cand.cost,
445                path: Vec::new(),
446            });
447        }
448
449        while let Some(current) = heap.pop() {
450            if results.len() >= self.max_results {
451                break;
452            }
453
454            let node_cands = &candidates[current.node_id as usize];
455            if current.candidate_idx >= node_cands.len() {
456                continue;
457            }
458
459            let cand = &node_cands[current.candidate_idx];
460
461            // BOS에 도달했으면 경로 완성
462            if cand.prev_node_id == INVALID_NODE_ID {
463                // 경로를 뒤집어서 정상 순서로
464                let mut path = current.path;
465                path.reverse();
466
467                results.push(NbestPath::new(path, current.cost, results.len()));
468                continue;
469            }
470
471            // 이전 노드로 이동
472            let Some(node) = lattice.node(current.node_id) else {
473                continue;
474            };
475
476            let mut new_path = current.path.clone();
477            // BOS, EOS가 아닌 노드만 경로에 추가
478            if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos {
479                new_path.push(current.node_id);
480            }
481
482            heap.push(BackwardCandidate {
483                node_id: cand.prev_node_id,
484                candidate_idx: cand.prev_candidate_idx,
485                cost: current.cost,
486                path: new_path,
487            });
488        }
489
490        NbestResult::new(results)
491    }
492}
493
494/// 기존 `NbestSearcher`와의 호환성을 위한 래퍼
495impl ImprovedNbestSearcher {
496    /// 기존 API 호환: `(Vec<NodeId>, i32)` 쌍의 벡터 반환
497    pub fn search_pairs<C: ConnectionCost>(
498        &self,
499        lattice: &mut Lattice,
500        conn_cost: &C,
501    ) -> Vec<(Vec<NodeId>, i32)> {
502        self.search(lattice, conn_cost).to_pairs()
503    }
504}
505
506#[cfg(test)]
507mod tests {
508    use super::*;
509    use crate::lattice::NodeBuilder;
510    use crate::viterbi::ZeroConnectionCost;
511
512    #[test]
513    fn test_nbest_single_path() {
514        let mut lattice = Lattice::new("AB");
515
516        lattice.add_node(
517            NodeBuilder::new("A", 0, 1)
518                .left_id(1)
519                .right_id(1)
520                .word_cost(100),
521        );
522        lattice.add_node(
523            NodeBuilder::new("B", 1, 2)
524                .left_id(2)
525                .right_id(2)
526                .word_cost(200),
527        );
528
529        let searcher = ImprovedNbestSearcher::new(5);
530        let conn_cost = ZeroConnectionCost;
531        let results = searcher.search(&mut lattice, &conn_cost);
532
533        assert_eq!(results.len(), 1);
534        assert_eq!(results.best().unwrap().cost(), 300);
535    }
536
537    #[test]
538    fn test_nbest_multiple_paths() {
539        // 두 가지 경로가 있는 Lattice
540        // 경로 1: A -> B (비용: 100 + 200 = 300)
541        // 경로 2: AB (비용: 350)
542        let mut lattice = Lattice::new("AB");
543
544        lattice.add_node(
545            NodeBuilder::new("A", 0, 1)
546                .left_id(1)
547                .right_id(1)
548                .word_cost(100),
549        );
550        lattice.add_node(
551            NodeBuilder::new("B", 1, 2)
552                .left_id(2)
553                .right_id(2)
554                .word_cost(200),
555        );
556        lattice.add_node(
557            NodeBuilder::new("AB", 0, 2)
558                .left_id(3)
559                .right_id(3)
560                .word_cost(350),
561        );
562
563        let searcher = ImprovedNbestSearcher::new(5);
564        let conn_cost = ZeroConnectionCost;
565        let results = searcher.search(&mut lattice, &conn_cost);
566
567        // 두 가지 경로가 있어야 함
568        assert_eq!(results.len(), 2);
569
570        // 1-best는 A + B (300)
571        assert_eq!(results.get(0).unwrap().cost(), 300);
572
573        // 2-best는 AB (350)
574        assert_eq!(results.get(1).unwrap().cost(), 350);
575    }
576
577    #[test]
578    fn test_nbest_korean_example() {
579        // "아버지가" 예시
580        // 경로 1: "아버지" + "가" (1000 + 500 = 1500)
581        // 경로 2: "아버" + "지가" (3000 + 3000 = 6000)
582        let mut lattice = Lattice::new("아버지가");
583
584        // 경로 1
585        lattice.add_node(
586            NodeBuilder::new("아버지", 0, 3)
587                .left_id(1)
588                .right_id(1)
589                .word_cost(1000),
590        );
591        lattice.add_node(
592            NodeBuilder::new("가", 3, 4)
593                .left_id(2)
594                .right_id(2)
595                .word_cost(500),
596        );
597
598        // 경로 2
599        lattice.add_node(
600            NodeBuilder::new("아버", 0, 2)
601                .left_id(3)
602                .right_id(3)
603                .word_cost(3000),
604        );
605        lattice.add_node(
606            NodeBuilder::new("지가", 2, 4)
607                .left_id(4)
608                .right_id(4)
609                .word_cost(3000),
610        );
611
612        let searcher = ImprovedNbestSearcher::new(3);
613        let conn_cost = ZeroConnectionCost;
614        let results = searcher.search(&mut lattice, &conn_cost);
615
616        assert!(results.len() >= 2);
617
618        // 1-best는 "아버지" + "가"
619        let best = results.best().unwrap();
620        assert_eq!(best.cost(), 1500);
621        assert_eq!(best.surfaces(&lattice), vec!["아버지", "가"]);
622
623        // 2-best는 "아버" + "지가"
624        let second = results.get(1).unwrap();
625        assert_eq!(second.cost(), 6000);
626        assert_eq!(second.surfaces(&lattice), vec!["아버", "지가"]);
627    }
628
629    #[test]
630    fn test_nbest_result_api() {
631        let mut lattice = Lattice::new("AB");
632
633        lattice.add_node(
634            NodeBuilder::new("A", 0, 1)
635                .left_id(1)
636                .right_id(1)
637                .word_cost(100),
638        );
639        lattice.add_node(
640            NodeBuilder::new("B", 1, 2)
641                .left_id(2)
642                .right_id(2)
643                .word_cost(200),
644        );
645
646        let searcher = ImprovedNbestSearcher::new(5);
647        let conn_cost = ZeroConnectionCost;
648        let results = searcher.search(&mut lattice, &conn_cost);
649
650        // Iterator API
651        for path in results.iter() {
652            assert!(!path.is_empty());
653            assert!(path.cost() > 0);
654        }
655
656        // IntoIterator API
657        let results2 = searcher.search(&mut lattice, &conn_cost);
658        for path in results2 {
659            assert!(!path.is_empty());
660        }
661    }
662
663    #[test]
664    fn test_nbest_empty_lattice() {
665        let mut lattice = Lattice::new("");
666        let searcher = ImprovedNbestSearcher::new(5);
667        let conn_cost = ZeroConnectionCost;
668        let results = searcher.search(&mut lattice, &conn_cost);
669
670        assert!(results.is_empty());
671    }
672
673    #[test]
674    fn test_nbest_compatibility_pairs() {
675        let mut lattice = Lattice::new("AB");
676
677        lattice.add_node(
678            NodeBuilder::new("AB", 0, 2)
679                .left_id(1)
680                .right_id(1)
681                .word_cost(300),
682        );
683
684        let searcher = ImprovedNbestSearcher::new(5);
685        let conn_cost = ZeroConnectionCost;
686        let pairs = searcher.search_pairs(&mut lattice, &conn_cost);
687
688        assert_eq!(pairs.len(), 1);
689        assert_eq!(pairs[0].1, 300);
690    }
691
692    #[test]
693    fn test_nbest_with_max_candidates() {
694        let mut lattice = Lattice::new("ABC");
695
696        // 다양한 경로 추가
697        lattice.add_node(NodeBuilder::new("A", 0, 1).word_cost(100));
698        lattice.add_node(NodeBuilder::new("B", 1, 2).word_cost(100));
699        lattice.add_node(NodeBuilder::new("C", 2, 3).word_cost(100));
700        lattice.add_node(NodeBuilder::new("AB", 0, 2).word_cost(180));
701        lattice.add_node(NodeBuilder::new("BC", 1, 3).word_cost(180));
702        lattice.add_node(NodeBuilder::new("ABC", 0, 3).word_cost(250));
703
704        let searcher = ImprovedNbestSearcher::new(5).with_max_candidates(10);
705        let conn_cost = ZeroConnectionCost;
706        let results = searcher.search(&mut lattice, &conn_cost);
707
708        // 여러 경로가 있어야 함
709        assert!(results.len() >= 2);
710
711        // 비용이 오름차순으로 정렬되어야 함
712        let costs: Vec<i32> = results.iter().map(|p| p.cost()).collect();
713        for i in 1..costs.len() {
714            assert!(costs[i] >= costs[i - 1]);
715        }
716    }
717}