Skip to main content

mecab_ko_core/
nbest.rs

1//! 개선된 N-best Viterbi 알고리즘
2//!
3//! 진정한 N-best 경로 탐색을 위한 알고리즘을 제공합니다.
4//!
5//! # 개요
6//!
7//! 기존 N-best 구현은 1-best forward pass 후 단일 경로만 추적했습니다.
8//! 이 모듈은 각 노드에서 K개의 최선 후보를 유지하여 진정한 N-best 결과를 제공합니다.
9//!
10//! # 알고리즘
11//!
12//! 1. **Forward Pass (K-best)**: 각 노드에서 상위 K개의 경로 후보를 유지
13//! 2. **Backward Pass (N-best)**: EOS에서 시작하여 N개의 최적 경로를 추출
14//!
15//! # Example
16//!
17//! ```rust,no_run
18//! use mecab_ko_core::nbest::{ImprovedNbestSearcher, NbestPath};
19//! use mecab_ko_core::lattice::Lattice;
20//! use mecab_ko_core::viterbi::ZeroConnectionCost;
21//!
22//! let mut lattice = Lattice::new("한국어");
23//! // ... 노드 추가 ...
24//!
25//! let searcher = ImprovedNbestSearcher::new(5);
26//! let conn_cost = ZeroConnectionCost;
27//! let results = searcher.search(&mut lattice, &conn_cost);
28//!
29//! for path in results.iter() {
30//!     println!("Cost: {}, Tokens: {:?}", path.cost(), path.surfaces(&lattice));
31//! }
32//! ```
33
34use crate::lattice::{Lattice, Node, NodeId, NodeType, INVALID_NODE_ID};
35use crate::viterbi::{ConnectionCost, SpacePenalty};
36use std::cmp::Ordering;
37use std::collections::BinaryHeap;
38
39/// N-best 경로 하나를 표현
40#[derive(Debug, Clone)]
41pub struct NbestPath {
42    /// 경로의 노드 ID 목록 (BOS, EOS 제외)
43    pub node_ids: Vec<NodeId>,
44    /// 총 비용
45    pub total_cost: i32,
46    /// 경로 순위 (0-based)
47    pub rank: usize,
48}
49
50impl NbestPath {
51    /// 새 N-best 경로 생성
52    #[must_use]
53    pub const fn new(node_ids: Vec<NodeId>, total_cost: i32, rank: usize) -> Self {
54        Self {
55            node_ids,
56            total_cost,
57            rank,
58        }
59    }
60
61    /// 경로가 비어있는지 확인
62    #[must_use]
63    pub fn is_empty(&self) -> bool {
64        self.node_ids.is_empty()
65    }
66
67    /// 경로의 노드 수
68    #[must_use]
69    pub fn len(&self) -> usize {
70        self.node_ids.len()
71    }
72
73    /// 총 비용
74    #[must_use]
75    pub const fn cost(&self) -> i32 {
76        self.total_cost
77    }
78
79    /// 경로의 노드들 반복자
80    pub fn nodes<'a>(&'a self, lattice: &'a Lattice) -> impl Iterator<Item = &'a Node> + 'a {
81        self.node_ids.iter().filter_map(|&id| lattice.node(id))
82    }
83
84    /// 표면형 목록 반환
85    #[must_use]
86    pub fn surfaces<'a>(&'a self, lattice: &'a Lattice) -> Vec<&'a str> {
87        self.nodes(lattice).map(|n| n.surface.as_ref()).collect()
88    }
89
90    /// 품사 태그 목록 반환
91    #[must_use]
92    pub fn pos_tags<'a>(&'a self, lattice: &'a Lattice) -> Vec<&'a str> {
93        self.nodes(lattice)
94            .map(|n| n.feature.split(',').next().unwrap_or_default())
95            .collect()
96    }
97}
98
99/// N-best 검색 결과
100#[derive(Debug, Clone, Default)]
101pub struct NbestResult {
102    /// 경로 목록 (비용 오름차순)
103    paths: Vec<NbestPath>,
104}
105
106impl NbestResult {
107    /// 새 결과 생성
108    #[must_use]
109    pub const fn new(paths: Vec<NbestPath>) -> Self {
110        Self { paths }
111    }
112
113    /// 결과가 비어있는지 확인
114    #[must_use]
115    pub fn is_empty(&self) -> bool {
116        self.paths.is_empty()
117    }
118
119    /// 결과 수
120    #[must_use]
121    pub fn len(&self) -> usize {
122        self.paths.len()
123    }
124
125    /// 최선 경로 (1-best)
126    #[must_use]
127    pub fn best(&self) -> Option<&NbestPath> {
128        self.paths.first()
129    }
130
131    /// 인덱스로 경로 조회
132    #[must_use]
133    pub fn get(&self, index: usize) -> Option<&NbestPath> {
134        self.paths.get(index)
135    }
136
137    /// 경로 반복자
138    pub fn iter(&self) -> impl Iterator<Item = &NbestPath> {
139        self.paths.iter()
140    }
141
142    /// 경로 벡터로 변환
143    #[must_use]
144    pub fn into_paths(self) -> Vec<NbestPath> {
145        self.paths
146    }
147
148    /// (노드 ID 목록, 비용) 쌍으로 변환 (호환성용)
149    #[must_use]
150    pub fn to_pairs(&self) -> Vec<(Vec<NodeId>, i32)> {
151        self.paths
152            .iter()
153            .map(|p| (p.node_ids.clone(), p.total_cost))
154            .collect()
155    }
156}
157
158impl IntoIterator for NbestResult {
159    type Item = NbestPath;
160    type IntoIter = std::vec::IntoIter<NbestPath>;
161
162    fn into_iter(self) -> Self::IntoIter {
163        self.paths.into_iter()
164    }
165}
166
167/// 노드별 K-best 후보 저장
168#[derive(Debug, Clone)]
169struct NodeCandidate {
170    /// 이 후보까지의 총 비용
171    cost: i32,
172    /// 이전 노드 ID
173    prev_node_id: NodeId,
174    /// 이전 노드에서의 후보 인덱스
175    prev_candidate_idx: usize,
176}
177
178impl Eq for NodeCandidate {}
179
180impl PartialEq for NodeCandidate {
181    fn eq(&self, other: &Self) -> bool {
182        self.cost == other.cost
183    }
184}
185
186impl Ord for NodeCandidate {
187    fn cmp(&self, other: &Self) -> Ordering {
188        // 비용이 낮은 것이 우선 (Min-heap처럼 동작하도록 역순)
189        other.cost.cmp(&self.cost)
190    }
191}
192
193impl PartialOrd for NodeCandidate {
194    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
195        Some(self.cmp(other))
196    }
197}
198
199/// Backward pass용 경로 후보
200#[derive(Debug, Clone)]
201struct BackwardCandidate {
202    /// 현재 노드 ID
203    node_id: NodeId,
204    /// 현재 노드에서의 후보 인덱스
205    candidate_idx: usize,
206    /// 총 비용
207    cost: i32,
208    /// 지금까지의 경로 (역순, BOS 제외)
209    path: Vec<NodeId>,
210}
211
212impl Eq for BackwardCandidate {}
213
214impl PartialEq for BackwardCandidate {
215    fn eq(&self, other: &Self) -> bool {
216        self.cost == other.cost
217    }
218}
219
220impl Ord for BackwardCandidate {
221    fn cmp(&self, other: &Self) -> Ordering {
222        // Min-heap: 비용이 낮은 것이 우선
223        other.cost.cmp(&self.cost)
224    }
225}
226
227impl PartialOrd for BackwardCandidate {
228    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
229        Some(self.cmp(other))
230    }
231}
232
233/// 개선된 N-best Viterbi 탐색기
234///
235/// 각 노드에서 K개의 최선 후보를 유지하여 진정한 N-best 결과를 제공합니다.
236#[derive(Debug, Clone)]
237pub struct ImprovedNbestSearcher {
238    /// 최대 결과 수 (N)
239    max_results: usize,
240    /// 각 노드에서 유지할 최대 후보 수 (K)
241    /// 일반적으로 N보다 크거나 같아야 좋은 결과를 얻음
242    max_candidates_per_node: usize,
243    /// 띄어쓰기 패널티 설정
244    space_penalty: SpacePenalty,
245}
246
247impl ImprovedNbestSearcher {
248    /// 새 N-best 탐색기 생성
249    ///
250    /// # Arguments
251    ///
252    /// * `n` - 반환할 최대 경로 수
253    #[must_use]
254    pub fn new(n: usize) -> Self {
255        Self {
256            max_results: n,
257            // K는 N의 2배로 설정 (더 많은 후보 탐색)
258            max_candidates_per_node: n.max(2) * 2,
259            space_penalty: SpacePenalty::default(),
260        }
261    }
262
263    /// 노드당 최대 후보 수 설정
264    #[must_use]
265    pub const fn with_max_candidates(mut self, k: usize) -> Self {
266        self.max_candidates_per_node = k;
267        self
268    }
269
270    /// 띄어쓰기 패널티 설정
271    #[must_use]
272    pub fn with_space_penalty(mut self, penalty: SpacePenalty) -> Self {
273        self.space_penalty = penalty;
274        self
275    }
276
277    /// N-best 경로 탐색
278    ///
279    /// # Arguments
280    ///
281    /// * `lattice` - 노드가 추가된 Lattice
282    /// * `conn_cost` - 연접 비용 조회 인터페이스
283    ///
284    /// # Returns
285    ///
286    /// N-best 검색 결과
287    pub fn search<C: ConnectionCost>(&self, lattice: &mut Lattice, conn_cost: &C) -> NbestResult {
288        if lattice.node_count() <= 2 {
289            // BOS, EOS만 있는 경우
290            return NbestResult::default();
291        }
292
293        // 1. Forward pass: 각 노드에서 K-best 후보 계산
294        let candidates = self.forward_pass_kbest(lattice, conn_cost);
295
296        // 2. Backward pass: EOS에서 시작하여 N-best 경로 추출
297        self.backward_pass_nbest(lattice, &candidates)
298    }
299
300    /// K-best Forward Pass
301    ///
302    /// 각 노드에서 상위 K개의 경로 후보를 유지합니다.
303    fn forward_pass_kbest<C: ConnectionCost>(
304        &self,
305        lattice: &mut Lattice,
306        conn_cost: &C,
307    ) -> Vec<Vec<NodeCandidate>> {
308        let node_count = lattice.node_count();
309        let char_len = lattice.char_len();
310
311        // 각 노드별 K-best 후보 저장
312        let mut candidates: Vec<Vec<NodeCandidate>> = vec![Vec::new(); node_count];
313
314        // BOS 노드의 초기 후보
315        let bos_id = lattice.bos().id;
316        candidates[bos_id as usize].push(NodeCandidate {
317            cost: 0,
318            prev_node_id: INVALID_NODE_ID,
319            prev_candidate_idx: 0,
320        });
321
322        // 재사용 가능한 버퍼
323        let mut starting_ids: Vec<NodeId> = Vec::new();
324        let mut ending_data: Vec<(NodeId, u16)> = Vec::new();
325
326        // 위치 0부터 끝까지 순회
327        for pos in 0..=char_len {
328            // 이 위치에서 시작하는 노드들
329            starting_ids.clear();
330            starting_ids.extend(lattice.nodes_starting_at(pos).map(|n| n.id));
331
332            // 이 위치에서 끝나는 노드들의 정보
333            ending_data.clear();
334            ending_data.extend(lattice.nodes_ending_at(pos).map(|n| (n.id, n.right_id)));
335
336            for &node_id in &starting_ids {
337                let (left_id, word_cost, has_space) = {
338                    let Some(node) = lattice.node(node_id) else {
339                        continue;
340                    };
341                    (node.left_id, node.word_cost, node.has_space_before)
342                };
343
344                // 띄어쓰기 패널티
345                let space_penalty = if has_space {
346                    self.space_penalty.get(left_id)
347                } else {
348                    0
349                };
350
351                // 이 노드로 올 수 있는 모든 후보 수집
352                let mut new_candidates: BinaryHeap<NodeCandidate> = BinaryHeap::new();
353
354                for &(prev_id, prev_right_id) in &ending_data {
355                    let prev_candidates = &candidates[prev_id as usize];
356                    if prev_candidates.is_empty() {
357                        continue;
358                    }
359
360                    let connection = conn_cost.cost(prev_right_id, left_id);
361
362                    for (idx, prev_cand) in prev_candidates.iter().enumerate() {
363                        if prev_cand.cost == i32::MAX {
364                            continue;
365                        }
366
367                        let total = prev_cand
368                            .cost
369                            .saturating_add(connection)
370                            .saturating_add(word_cost)
371                            .saturating_add(space_penalty);
372
373                        new_candidates.push(NodeCandidate {
374                            cost: total,
375                            prev_node_id: prev_id,
376                            prev_candidate_idx: idx,
377                        });
378                    }
379                }
380
381                // 상위 K개만 유지
382                let k = self.max_candidates_per_node;
383                let mut selected: Vec<NodeCandidate> = Vec::with_capacity(k);
384
385                while selected.len() < k {
386                    if let Some(cand) = new_candidates.pop() {
387                        selected.push(cand);
388                    } else {
389                        break;
390                    }
391                }
392
393                candidates[node_id as usize] = selected;
394
395                // 1-best 정보를 Lattice에도 업데이트 (기존 호환성)
396                if let Some(best) = candidates[node_id as usize].first() {
397                    if let Some(node) = lattice.node_mut(node_id) {
398                        node.total_cost = best.cost;
399                        node.prev_node_id = best.prev_node_id;
400                    }
401                }
402            }
403        }
404
405        candidates
406    }
407
408    /// N-best Backward Pass
409    ///
410    /// EOS에서 시작하여 N개의 최적 경로를 추출합니다.
411    fn backward_pass_nbest(
412        &self,
413        lattice: &Lattice,
414        candidates: &[Vec<NodeCandidate>],
415    ) -> NbestResult {
416        let eos = lattice.eos();
417        let eos_candidates = &candidates[eos.id as usize];
418
419        if eos_candidates.is_empty() {
420            return NbestResult::default();
421        }
422
423        let mut results: Vec<NbestPath> = Vec::with_capacity(self.max_results);
424        let mut heap: BinaryHeap<BackwardCandidate> = BinaryHeap::new();
425
426        // EOS의 모든 K-best 후보를 시작점으로 추가
427        for (idx, cand) in eos_candidates.iter().enumerate() {
428            heap.push(BackwardCandidate {
429                node_id: eos.id,
430                candidate_idx: idx,
431                cost: cand.cost,
432                path: Vec::new(),
433            });
434        }
435
436        while let Some(current) = heap.pop() {
437            if results.len() >= self.max_results {
438                break;
439            }
440
441            let node_cands = &candidates[current.node_id as usize];
442            if current.candidate_idx >= node_cands.len() {
443                continue;
444            }
445
446            let cand = &node_cands[current.candidate_idx];
447
448            // BOS에 도달했으면 경로 완성
449            if cand.prev_node_id == INVALID_NODE_ID {
450                // 경로를 뒤집어서 정상 순서로
451                let mut path = current.path;
452                path.reverse();
453
454                results.push(NbestPath::new(path, current.cost, results.len()));
455                continue;
456            }
457
458            // 이전 노드로 이동
459            let Some(node) = lattice.node(current.node_id) else {
460                continue;
461            };
462
463            let mut new_path = current.path.clone();
464            // BOS, EOS가 아닌 노드만 경로에 추가
465            if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos {
466                new_path.push(current.node_id);
467            }
468
469            heap.push(BackwardCandidate {
470                node_id: cand.prev_node_id,
471                candidate_idx: cand.prev_candidate_idx,
472                cost: current.cost,
473                path: new_path,
474            });
475        }
476
477        NbestResult::new(results)
478    }
479}
480
481/// 기존 `NbestSearcher`와의 호환성을 위한 래퍼
482impl ImprovedNbestSearcher {
483    /// 기존 API 호환: `(Vec<NodeId>, i32)` 쌍의 벡터 반환
484    pub fn search_pairs<C: ConnectionCost>(
485        &self,
486        lattice: &mut Lattice,
487        conn_cost: &C,
488    ) -> Vec<(Vec<NodeId>, i32)> {
489        self.search(lattice, conn_cost).to_pairs()
490    }
491}
492
493#[cfg(test)]
494mod tests {
495    use super::*;
496    use crate::lattice::NodeBuilder;
497    use crate::viterbi::ZeroConnectionCost;
498
499    #[test]
500    fn test_nbest_single_path() {
501        let mut lattice = Lattice::new("AB");
502
503        lattice.add_node(
504            NodeBuilder::new("A", 0, 1)
505                .left_id(1)
506                .right_id(1)
507                .word_cost(100),
508        );
509        lattice.add_node(
510            NodeBuilder::new("B", 1, 2)
511                .left_id(2)
512                .right_id(2)
513                .word_cost(200),
514        );
515
516        let searcher = ImprovedNbestSearcher::new(5);
517        let conn_cost = ZeroConnectionCost;
518        let results = searcher.search(&mut lattice, &conn_cost);
519
520        assert_eq!(results.len(), 1);
521        assert_eq!(results.best().unwrap().cost(), 300);
522    }
523
524    #[test]
525    fn test_nbest_multiple_paths() {
526        // 두 가지 경로가 있는 Lattice
527        // 경로 1: A -> B (비용: 100 + 200 = 300)
528        // 경로 2: AB (비용: 350)
529        let mut lattice = Lattice::new("AB");
530
531        lattice.add_node(
532            NodeBuilder::new("A", 0, 1)
533                .left_id(1)
534                .right_id(1)
535                .word_cost(100),
536        );
537        lattice.add_node(
538            NodeBuilder::new("B", 1, 2)
539                .left_id(2)
540                .right_id(2)
541                .word_cost(200),
542        );
543        lattice.add_node(
544            NodeBuilder::new("AB", 0, 2)
545                .left_id(3)
546                .right_id(3)
547                .word_cost(350),
548        );
549
550        let searcher = ImprovedNbestSearcher::new(5);
551        let conn_cost = ZeroConnectionCost;
552        let results = searcher.search(&mut lattice, &conn_cost);
553
554        // 두 가지 경로가 있어야 함
555        assert_eq!(results.len(), 2);
556
557        // 1-best는 A + B (300)
558        assert_eq!(results.get(0).unwrap().cost(), 300);
559
560        // 2-best는 AB (350)
561        assert_eq!(results.get(1).unwrap().cost(), 350);
562    }
563
564    #[test]
565    fn test_nbest_korean_example() {
566        // "아버지가" 예시
567        // 경로 1: "아버지" + "가" (1000 + 500 = 1500)
568        // 경로 2: "아버" + "지가" (3000 + 3000 = 6000)
569        let mut lattice = Lattice::new("아버지가");
570
571        // 경로 1
572        lattice.add_node(
573            NodeBuilder::new("아버지", 0, 3)
574                .left_id(1)
575                .right_id(1)
576                .word_cost(1000),
577        );
578        lattice.add_node(
579            NodeBuilder::new("가", 3, 4)
580                .left_id(2)
581                .right_id(2)
582                .word_cost(500),
583        );
584
585        // 경로 2
586        lattice.add_node(
587            NodeBuilder::new("아버", 0, 2)
588                .left_id(3)
589                .right_id(3)
590                .word_cost(3000),
591        );
592        lattice.add_node(
593            NodeBuilder::new("지가", 2, 4)
594                .left_id(4)
595                .right_id(4)
596                .word_cost(3000),
597        );
598
599        let searcher = ImprovedNbestSearcher::new(3);
600        let conn_cost = ZeroConnectionCost;
601        let results = searcher.search(&mut lattice, &conn_cost);
602
603        assert!(results.len() >= 2);
604
605        // 1-best는 "아버지" + "가"
606        let best = results.best().unwrap();
607        assert_eq!(best.cost(), 1500);
608        assert_eq!(best.surfaces(&lattice), vec!["아버지", "가"]);
609
610        // 2-best는 "아버" + "지가"
611        let second = results.get(1).unwrap();
612        assert_eq!(second.cost(), 6000);
613        assert_eq!(second.surfaces(&lattice), vec!["아버", "지가"]);
614    }
615
616    #[test]
617    fn test_nbest_result_api() {
618        let mut lattice = Lattice::new("AB");
619
620        lattice.add_node(
621            NodeBuilder::new("A", 0, 1)
622                .left_id(1)
623                .right_id(1)
624                .word_cost(100),
625        );
626        lattice.add_node(
627            NodeBuilder::new("B", 1, 2)
628                .left_id(2)
629                .right_id(2)
630                .word_cost(200),
631        );
632
633        let searcher = ImprovedNbestSearcher::new(5);
634        let conn_cost = ZeroConnectionCost;
635        let results = searcher.search(&mut lattice, &conn_cost);
636
637        // Iterator API
638        for path in results.iter() {
639            assert!(!path.is_empty());
640            assert!(path.cost() > 0);
641        }
642
643        // IntoIterator API
644        let results2 = searcher.search(&mut lattice, &conn_cost);
645        for path in results2 {
646            assert!(!path.is_empty());
647        }
648    }
649
650    #[test]
651    fn test_nbest_empty_lattice() {
652        let mut lattice = Lattice::new("");
653        let searcher = ImprovedNbestSearcher::new(5);
654        let conn_cost = ZeroConnectionCost;
655        let results = searcher.search(&mut lattice, &conn_cost);
656
657        assert!(results.is_empty());
658    }
659
660    #[test]
661    fn test_nbest_compatibility_pairs() {
662        let mut lattice = Lattice::new("AB");
663
664        lattice.add_node(
665            NodeBuilder::new("AB", 0, 2)
666                .left_id(1)
667                .right_id(1)
668                .word_cost(300),
669        );
670
671        let searcher = ImprovedNbestSearcher::new(5);
672        let conn_cost = ZeroConnectionCost;
673        let pairs = searcher.search_pairs(&mut lattice, &conn_cost);
674
675        assert_eq!(pairs.len(), 1);
676        assert_eq!(pairs[0].1, 300);
677    }
678
679    #[test]
680    fn test_nbest_with_max_candidates() {
681        let mut lattice = Lattice::new("ABC");
682
683        // 다양한 경로 추가
684        lattice.add_node(NodeBuilder::new("A", 0, 1).word_cost(100));
685        lattice.add_node(NodeBuilder::new("B", 1, 2).word_cost(100));
686        lattice.add_node(NodeBuilder::new("C", 2, 3).word_cost(100));
687        lattice.add_node(NodeBuilder::new("AB", 0, 2).word_cost(180));
688        lattice.add_node(NodeBuilder::new("BC", 1, 3).word_cost(180));
689        lattice.add_node(NodeBuilder::new("ABC", 0, 3).word_cost(250));
690
691        let searcher = ImprovedNbestSearcher::new(5).with_max_candidates(10);
692        let conn_cost = ZeroConnectionCost;
693        let results = searcher.search(&mut lattice, &conn_cost);
694
695        // 여러 경로가 있어야 함
696        assert!(results.len() >= 2);
697
698        // 비용이 오름차순으로 정렬되어야 함
699        let costs: Vec<i32> = results.iter().map(super::NbestPath::cost).collect();
700        for i in 1..costs.len() {
701            assert!(costs[i] >= costs[i - 1]);
702        }
703    }
704}