Skip to main content

mecab_ko_core/viterbi/
mod.rs

1//! Viterbi 알고리즘
2//!
3//! 최적 형태소 분석 경로를 찾는 Viterbi 알고리즘을 구현합니다.
4//!
5//! # 개요
6//!
7//! Viterbi 알고리즘은 Lattice에서 최소 비용 경로를 찾는 동적 프로그래밍 알고리즘입니다.
8//!
9//! ```text
10//! 총 비용 = Σ(단어 비용) + Σ(연접 비용) + Σ(띄어쓰기 패널티)
11//! ```
12//!
13//! # 알고리즘
14//!
15//! 1. **Forward Pass**: BOS에서 시작하여 각 노드까지의 최소 비용 계산
16//! 2. **Backward Pass**: EOS에서 BOS까지 역추적하여 최적 경로 추출
17//!
18//! # 한국어 특화
19//!
20//! - `left-space-penalty-factor`: 띄어쓰기 후 특정 품사 시작 시 페널티 부여
21//! - 조사(JK*), 어미(E*) 등이 띄어쓰기 직후 시작하면 높은 페널티
22//!
23//! # Example
24//!
25//! ```rust,no_run
26//! use mecab_ko_core::viterbi::{ViterbiSearcher, SpacePenalty};
27//! use mecab_ko_core::lattice::Lattice;
28//!
29//! let mut lattice = Lattice::new("아버지가방에");
30//! // ... 노드 추가 후 검색 ...
31//!
32//! let searcher = ViterbiSearcher::new()
33//!     .with_space_penalty(SpacePenalty::korean_default());
34//! ```
35
36use crate::lattice::{Lattice, Node, NodeId, NodeType, INVALID_NODE_ID};
37use std::cmp::Ordering;
38use std::collections::BinaryHeap;
39use std::rc::Rc;
40
41// SIMD 최적화 모듈
42#[cfg(feature = "simd")]
43pub mod simd;
44
45#[cfg(feature = "simd")]
46pub use simd::{simd_forward_pass_position, simd_update_node_cost};
47
48/// 여러 값의 포화 덧셈 (체인)
49///
50/// 오버플로우 방지를 위해 포화 연산 사용
51#[inline(always)]
52fn saturating_add_chain(a: i32, b: i32, c: i32, d: i32) -> i32 {
53    a.saturating_add(b).saturating_add(c).saturating_add(d)
54}
55
56/// 연접 비용 조회 인터페이스
57///
58/// 두 형태소 간의 연결 비용을 반환합니다.
59/// 이 비용은 matrix.def에서 학습된 값입니다.
60pub trait ConnectionCost {
61    /// 두 문맥 ID 간의 연접 비용 반환
62    ///
63    /// # Arguments
64    ///
65    /// * `right_id` - 이전 노드의 우문맥 ID
66    /// * `left_id` - 현재 노드의 좌문맥 ID
67    ///
68    /// # Returns
69    ///
70    /// 연접 비용 (낮을수록 좋음)
71    fn cost(&self, right_id: u16, left_id: u16) -> i32;
72}
73
74/// 더미 연접 비용 (테스트용)
75///
76/// 모든 연접에 대해 0을 반환합니다.
77#[derive(Debug, Clone, Default)]
78pub struct ZeroConnectionCost;
79
80impl ConnectionCost for ZeroConnectionCost {
81    #[inline(always)]
82    fn cost(&self, _right_id: u16, _left_id: u16) -> i32 {
83        0
84    }
85}
86
87/// 고정 연접 비용 (테스트용)
88#[derive(Debug, Clone)]
89pub struct FixedConnectionCost {
90    /// 기본 비용
91    pub default_cost: i32,
92}
93
94impl FixedConnectionCost {
95    /// 새 고정 비용 생성
96    #[must_use]
97    pub const fn new(cost: i32) -> Self {
98        Self { default_cost: cost }
99    }
100}
101
102impl ConnectionCost for FixedConnectionCost {
103    #[inline(always)]
104    fn cost(&self, _right_id: u16, _left_id: u16) -> i32 {
105        self.default_cost
106    }
107}
108
109/// mecab-ko-dict의 `Matrix` trait에 대한 `ConnectionCost` 구현
110///
111/// 사전 모듈의 연접 비용 행렬을 Viterbi 알고리즘에서 직접 사용할 수 있습니다.
112impl<T: mecab_ko_dict::Matrix> ConnectionCost for T {
113    #[inline(always)]
114    fn cost(&self, right_id: u16, left_id: u16) -> i32 {
115        self.get(right_id, left_id)
116    }
117}
118
119/// 띄어쓰기 패널티 설정
120///
121/// mecab-ko의 `left-space-penalty-factor` 기능을 구현합니다.
122/// 띄어쓰기 직후에 특정 품사가 오면 페널티를 부여하여 오분석을 방지합니다.
123///
124/// # Example
125///
126/// ```rust
127/// use mecab_ko_core::viterbi::SpacePenalty;
128///
129/// // mecab-ko 기본 설정
130/// let penalty = SpacePenalty::korean_default();
131///
132/// // dicrc 형식에서 생성
133/// let penalty = SpacePenalty::from_dicrc("1785,6000;1786,6000");
134/// ```
135#[derive(Debug, Clone, Default)]
136pub struct SpacePenalty {
137    /// 페널티를 적용할 품사 ID 목록과 페널티 값
138    /// `(left_id, penalty)`
139    penalties: Vec<(u16, i32)>,
140}
141
142impl SpacePenalty {
143    /// 빈 페널티 설정 생성
144    #[must_use]
145    pub fn new() -> Self {
146        Self::default()
147    }
148
149    /// 한국어 기본 페널티 설정
150    ///
151    /// 조사(JK*)와 어미(E*)가 띄어쓰기 직후 나타나면 높은 페널티를 부여합니다.
152    /// 이는 "아버지가방에" → "아버지가 방에"로 분석하는 데 도움이 됩니다.
153    #[must_use]
154    pub fn korean_default() -> Self {
155        // mecab-ko-dic의 left-id 기준 (실제 값은 사전에 따라 다름)
156        // 여기서는 대표적인 조사/어미 ID 범위를 사용
157        // Build ranges in sorted order so binary_search in get() works correctly.
158
159        // 어미 계열 (EP, EF, EC, ETN, ETM): 1700~1759
160        // 조사 계열 (JKS, JKC, JKG, JKO, JKB, JKV, JKQ, JX, JC): 1780~1809
161        let mut penalties: Vec<(u16, i32)> = (1700u16..1760)
162            .chain(1780..1810)
163            .map(|id| (id, 6000))
164            .collect();
165
166        // Ensure sorted for binary_search
167        penalties.sort_unstable_by_key(|&(id, _)| id);
168        Self { penalties }
169    }
170
171    /// mecab-ko의 dicrc 설정에서 생성
172    ///
173    /// # Format
174    ///
175    /// `left_id,penalty;left_id,penalty;...`
176    ///
177    /// # Example
178    ///
179    /// ```rust
180    /// use mecab_ko_core::viterbi::SpacePenalty;
181    ///
182    /// let penalty = SpacePenalty::from_dicrc("1785,6000;1786,6000;1787,5000");
183    /// assert_eq!(penalty.get(1785), 6000);
184    /// assert_eq!(penalty.get(1786), 6000);
185    /// assert_eq!(penalty.get(9999), 0);  // 미등록
186    /// ```
187    #[must_use]
188    pub fn from_dicrc(config: &str) -> Self {
189        let mut penalties = Vec::new();
190
191        for part in config.split(';') {
192            let parts: Vec<&str> = part.trim().split(',').collect();
193            if parts.len() == 2 {
194                if let (Ok(id), Ok(penalty)) = (
195                    parts[0].trim().parse::<u16>(),
196                    parts[1].trim().parse::<i32>(),
197                ) {
198                    penalties.push((id, penalty));
199                }
200            }
201        }
202
203        // Keep sorted for binary search in get()
204        penalties.sort_unstable_by_key(|&(id, _)| id);
205        Self { penalties }
206    }
207
208    /// 페널티 추가
209    pub fn add(&mut self, left_id: u16, penalty: i32) {
210        // Insert in sorted position for binary search correctness
211        let pos = self.penalties.partition_point(|&(id, _)| id < left_id);
212        self.penalties.insert(pos, (left_id, penalty));
213    }
214
215    /// 특정 품사 ID에 대한 페널티 조회
216    ///
217    /// # Returns
218    ///
219    /// 해당 ID에 설정된 페널티, 없으면 0
220    #[must_use]
221    #[inline]
222    pub fn get(&self, left_id: u16) -> i32 {
223        // Binary search on sorted penalties for O(log n) instead of O(n)
224        self.penalties
225            .binary_search_by_key(&left_id, |&(id, _)| id)
226            .map_or(0, |idx| self.penalties[idx].1)
227    }
228
229    /// 페널티가 설정되어 있는지 확인
230    #[must_use]
231    #[inline]
232    pub fn is_empty(&self) -> bool {
233        self.penalties.is_empty()
234    }
235
236    /// 설정된 페널티 개수
237    #[must_use]
238    #[inline]
239    pub fn len(&self) -> usize {
240        self.penalties.len()
241    }
242}
243
244/// Viterbi 탐색기
245///
246/// Lattice에서 최적 경로를 찾는 Viterbi 알고리즘을 구현합니다.
247#[derive(Debug, Clone)]
248pub struct ViterbiSearcher {
249    /// 띄어쓰기 패널티 설정
250    pub space_penalty: SpacePenalty,
251}
252
253impl Default for ViterbiSearcher {
254    fn default() -> Self {
255        Self::new()
256    }
257}
258
259impl ViterbiSearcher {
260    /// 새 탐색기 생성
261    #[must_use]
262    pub fn new() -> Self {
263        Self {
264            space_penalty: SpacePenalty::default(),
265        }
266    }
267
268    /// 띄어쓰기 패널티 설정
269    #[must_use]
270    pub fn with_space_penalty(mut self, penalty: SpacePenalty) -> Self {
271        self.space_penalty = penalty;
272        self
273    }
274
275    /// 최적 경로 탐색 (Forward-Backward)
276    ///
277    /// # Arguments
278    ///
279    /// * `lattice` - 노드가 추가된 Lattice
280    /// * `conn_cost` - 연접 비용 조회 인터페이스
281    ///
282    /// # Returns
283    ///
284    /// 최적 경로의 노드 ID 목록 (BOS, EOS 제외)
285    ///
286    /// # Example
287    ///
288    /// ```rust,no_run
289    /// # use mecab_ko_core::viterbi::{ViterbiSearcher, SpacePenalty};
290    /// # use mecab_ko_core::lattice::Lattice;
291    /// # let searcher = ViterbiSearcher::new();
292    /// # let conn_cost = mecab_ko_dict::matrix::DenseMatrix::new(1, 1, 0);
293    /// # let mut lattice = Lattice::new("test");
294    /// let path = searcher.search(&mut lattice, &conn_cost);
295    /// for node_id in path {
296    ///     let node = lattice.node(node_id).unwrap();
297    ///     println!("{}: {}", node.surface, node.word_cost);
298    /// }
299    /// ```
300    pub fn search<C: ConnectionCost>(&self, lattice: &mut Lattice, conn_cost: &C) -> Vec<NodeId> {
301        // Forward pass
302        self.forward_pass(lattice, conn_cost);
303
304        // Backward pass
305        Self::backward_pass(lattice)
306    }
307
308    /// Forward Pass: 각 노드의 최소 비용 계산
309    ///
310    /// BOS에서 시작하여 각 위치의 노드들에 대해 최소 비용을 계산합니다.
311    fn forward_pass<C: ConnectionCost>(&self, lattice: &mut Lattice, conn_cost: &C) {
312        let char_len = lattice.char_len();
313
314        // Reusable scratch buffers to avoid per-position Vec allocations.
315        // We collect (node_id) for the starting nodes and (id, total_cost, right_id)
316        // for ending nodes into these, clearing between positions.
317        let mut starting_ids: Vec<NodeId> = Vec::new();
318        let mut ending_nodes: Vec<(NodeId, i32, u16)> = Vec::new();
319
320        // 위치 0부터 끝까지 순회
321        for pos in 0..=char_len {
322            // Collect starting node IDs (need ownership before mutating lattice)
323            starting_ids.clear();
324            starting_ids.extend(lattice.nodes_starting_at(pos).map(|n| n.id));
325
326            // Collect ending node data once per position, reused for every
327            // starting node at this position.
328            ending_nodes.clear();
329            ending_nodes.extend(
330                lattice
331                    .nodes_ending_at(pos)
332                    .map(|n| (n.id, n.total_cost, n.right_id)),
333            );
334
335            for &node_id in &starting_ids {
336                self.update_node_cost_with_endings(lattice, conn_cost, node_id, &ending_nodes);
337            }
338        }
339    }
340
341    /// 단일 노드의 최소 비용 계산 및 업데이트 (사전 수집된 `ending_nodes` 사용)
342    ///
343    /// Hot path: 성능 최적화를 위해 인라인 처리
344    /// SIMD 최적화: 8개 이상의 이전 노드가 있으면 SIMD 배치 처리 사용
345    #[inline]
346    fn update_node_cost_with_endings<C: ConnectionCost>(
347        &self,
348        lattice: &mut Lattice,
349        conn_cost: &C,
350        node_id: NodeId,
351        ending_nodes: &[(NodeId, i32, u16)],
352    ) {
353        // SIMD 최적화: 8개 이상의 이전 노드가 있으면 SIMD 사용
354        #[cfg(feature = "simd")]
355        if ending_nodes.len() >= 8 {
356            let (best_cost, best_prev) =
357                simd::simd_update_node_cost(lattice, conn_cost, node_id, ending_nodes, &self.space_penalty);
358            if let Some(node) = lattice.node_mut(node_id) {
359                node.total_cost = best_cost;
360                node.prev_node_id = best_prev;
361            }
362            return;
363        }
364
365        // 현재 노드 정보 추출
366        let (left_id, word_cost, has_space) = {
367            let Some(node) = lattice.node(node_id) else {
368                return;
369            };
370            (node.left_id, node.word_cost, node.has_space_before)
371        };
372
373        // 띄어쓰기 패널티는 left_id에 대해 한 번만 조회
374        let space_penalty = if has_space {
375            self.space_penalty.get(left_id)
376        } else {
377            0
378        };
379
380        let mut best_cost = i32::MAX;
381        let mut best_prev = INVALID_NODE_ID;
382
383        for &(prev_id, prev_cost, prev_right_id) in ending_nodes {
384            // 이전 노드까지의 비용이 무한대면 스킵
385            if prev_cost == i32::MAX {
386                continue;
387            }
388
389            // 연접 비용 계산
390            let connection = conn_cost.cost(prev_right_id, left_id);
391
392            // 총 비용 = 이전 비용 + 연접 비용 + 단어 비용 + 띄어쓰기 패널티
393            let total = saturating_add_chain(prev_cost, connection, word_cost, space_penalty);
394
395            if total < best_cost {
396                best_cost = total;
397                best_prev = prev_id;
398            }
399        }
400
401        // 노드 업데이트
402        if let Some(node) = lattice.node_mut(node_id) {
403            node.total_cost = best_cost;
404            node.prev_node_id = best_prev;
405        }
406    }
407
408    /// 단일 노드의 최소 비용 계산 및 업데이트 (레거시, 테스트용으로 유지)
409    #[cfg(test)]
410    #[allow(dead_code)]
411    fn update_node_cost<C: ConnectionCost>(
412        &self,
413        lattice: &mut Lattice,
414        conn_cost: &C,
415        node_id: NodeId,
416        pos: usize,
417    ) {
418        // 현재 노드 정보 추출
419        let (left_id, word_cost, has_space) = {
420            let Some(node) = lattice.node(node_id) else {
421                return;
422            };
423            (node.left_id, node.word_cost, node.has_space_before)
424        };
425
426        // 이 노드로 연결될 수 있는 이전 노드들 (pos에서 끝나는 노드들)
427        let ending_nodes: Vec<(NodeId, i32, u16)> = lattice
428            .nodes_ending_at(pos)
429            .map(|n| (n.id, n.total_cost, n.right_id))
430            .collect();
431
432        let mut best_cost = i32::MAX;
433        let mut best_prev = INVALID_NODE_ID;
434
435        for (prev_id, prev_cost, prev_right_id) in ending_nodes {
436            if prev_cost == i32::MAX {
437                continue;
438            }
439
440            let connection = conn_cost.cost(prev_right_id, left_id);
441
442            let space_penalty = if has_space {
443                self.space_penalty.get(left_id)
444            } else {
445                0
446            };
447
448            let total = prev_cost
449                .saturating_add(connection)
450                .saturating_add(word_cost)
451                .saturating_add(space_penalty);
452
453            if total < best_cost {
454                best_cost = total;
455                best_prev = prev_id;
456            }
457        }
458
459        // 노드 업데이트
460        if let Some(node) = lattice.node_mut(node_id) {
461            node.total_cost = best_cost;
462            node.prev_node_id = best_prev;
463        }
464    }
465
466    /// Backward Pass: EOS에서 BOS까지 역추적
467    ///
468    /// 최적 경로의 노드 ID 목록을 반환합니다 (BOS, EOS 제외).
469    fn backward_pass(lattice: &Lattice) -> Vec<NodeId> {
470        let mut path = Vec::new();
471        let mut current_id = lattice.eos().id;
472
473        while current_id != INVALID_NODE_ID {
474            if let Some(node) = lattice.node(current_id) {
475                // BOS, EOS는 결과에서 제외
476                if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos {
477                    path.push(current_id);
478                }
479                current_id = node.prev_node_id;
480            } else {
481                break;
482            }
483        }
484
485        path.reverse();
486        path
487    }
488
489    /// 최적 경로의 총 비용 조회
490    #[must_use]
491    pub fn get_best_cost(&self, lattice: &Lattice) -> i32 {
492        lattice.eos().total_cost
493    }
494
495    /// 경로가 유효한지 확인
496    ///
497    /// EOS까지의 경로가 존재하는지 확인합니다.
498    #[must_use]
499    pub fn has_valid_path(&self, lattice: &Lattice) -> bool {
500        lattice.eos().total_cost != i32::MAX && lattice.eos().prev_node_id != INVALID_NODE_ID
501    }
502}
503
504// ============================================
505// N-best 지원
506// ============================================
507
508/// N-best 경로 노드 (링크드 리스트)
509///
510/// 경로를 Rc로 연결하여 클론 비용을 줄입니다.
511/// 전체 경로를 복사하는 대신 참조 카운트만 증가시킵니다.
512#[derive(Debug, Clone)]
513struct PathNode {
514    /// 현재 노드 ID
515    node_id: NodeId,
516    /// 이전 경로 노드 (Rc로 공유)
517    prev: Option<Rc<Self>>,
518}
519
520impl PathNode {
521    /// 새 경로 노드 생성
522    ///
523    /// Note: Cannot be const due to `Rc<Self>` parameter
524    #[allow(clippy::missing_const_for_fn)]
525    fn new(node_id: NodeId, prev: Option<Rc<Self>>) -> Self {
526        Self { node_id, prev }
527    }
528
529    /// 경로를 Vec로 변환 (BOS에서 현재 노드까지)
530    fn to_vec(&self) -> Vec<NodeId> {
531        let mut path = Vec::new();
532        let mut current = Some(self);
533
534        while let Some(node) = current {
535            path.push(node.node_id);
536            current = node.prev.as_ref().map(std::convert::AsRef::as_ref);
537        }
538
539        path.reverse();
540        path
541    }
542}
543
544/// N-best 경로 후보
545#[derive(Debug, Clone)]
546struct NbestCandidate {
547    /// 노드 ID
548    node_id: NodeId,
549    /// 총 비용
550    cost: i32,
551    /// 이전 경로 (Rc로 공유되는 링크드 리스트)
552    path: Option<Rc<PathNode>>,
553}
554
555impl Eq for NbestCandidate {}
556
557impl PartialEq for NbestCandidate {
558    fn eq(&self, other: &Self) -> bool {
559        self.cost == other.cost
560    }
561}
562
563impl Ord for NbestCandidate {
564    fn cmp(&self, other: &Self) -> Ordering {
565        // Min-heap: 비용이 낮은 것이 우선
566        other.cost.cmp(&self.cost)
567    }
568}
569
570impl PartialOrd for NbestCandidate {
571    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
572        Some(self.cmp(other))
573    }
574}
575
576/// N-best 탐색기
577///
578/// 상위 N개의 최적 경로를 찾습니다.
579#[derive(Debug, Clone)]
580pub struct NbestSearcher {
581    /// 기본 Viterbi 탐색기
582    viterbi: ViterbiSearcher,
583    /// 최대 결과 수
584    max_results: usize,
585}
586
587impl NbestSearcher {
588    /// 새 N-best 탐색기 생성
589    #[must_use]
590    pub fn new(n: usize) -> Self {
591        Self {
592            viterbi: ViterbiSearcher::new(),
593            max_results: n,
594        }
595    }
596
597    /// 띄어쓰기 패널티 설정
598    #[must_use]
599    pub fn with_space_penalty(mut self, penalty: SpacePenalty) -> Self {
600        self.viterbi.space_penalty = penalty;
601        self
602    }
603
604    /// N-best 경로 탐색
605    ///
606    /// # Arguments
607    ///
608    /// * `lattice` - 노드가 추가된 Lattice
609    /// * `conn_cost` - 연접 비용 조회 인터페이스
610    ///
611    /// # Returns
612    ///
613    /// (경로, 비용) 쌍의 벡터, 비용 오름차순
614    pub fn search<C: ConnectionCost>(
615        &self,
616        lattice: &mut Lattice,
617        conn_cost: &C,
618    ) -> Vec<(Vec<NodeId>, i32)> {
619        // 먼저 Forward pass 실행
620        self.viterbi.forward_pass(lattice, conn_cost);
621
622        // 최적 경로가 없으면 빈 결과 반환
623        if !self.viterbi.has_valid_path(lattice) {
624            return Vec::new();
625        }
626
627        // 1-best인 경우 단순 backward pass
628        if self.max_results == 1 {
629            let path = ViterbiSearcher::backward_pass(lattice);
630            let cost = self.viterbi.get_best_cost(lattice);
631            return vec![(path, cost)];
632        }
633
634        // N-best: A* 유사 알고리즘
635        self.search_nbest(lattice, conn_cost)
636    }
637
638    /// N-best 경로 탐색 (A* 기반)
639    ///
640    /// # 최적화
641    ///
642    /// 경로를 `Rc<PathNode>`로 표현하여 클론 비용을 최소화합니다.
643    /// 전체 Vec를 복사하는 대신 참조 카운트만 증가시켜 O(1) 클론을 달성합니다.
644    fn search_nbest<C: ConnectionCost>(
645        &self,
646        lattice: &Lattice,
647        _conn_cost: &C,
648    ) -> Vec<(Vec<NodeId>, i32)> {
649        let mut results: Vec<(Vec<NodeId>, i32)> = Vec::new();
650        let mut heap: BinaryHeap<NbestCandidate> = BinaryHeap::new();
651
652        // EOS에서 시작
653        let eos = lattice.eos();
654        if eos.total_cost == i32::MAX {
655            return results;
656        }
657
658        heap.push(NbestCandidate {
659            node_id: eos.id,
660            cost: eos.total_cost,
661            path: None,
662        });
663
664        while let Some(candidate) = heap.pop() {
665            if results.len() >= self.max_results {
666                break;
667            }
668
669            let Some(node) = lattice.node(candidate.node_id) else {
670                continue;
671            };
672
673            // 현재까지의 경로 (Rc 클론은 O(1))
674            let current_path = if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos
675            {
676                // BOS, EOS가 아니면 경로에 추가
677                Some(Rc::new(PathNode::new(candidate.node_id, candidate.path)))
678            } else {
679                candidate.path
680            };
681
682            // BOS에 도달하면 결과에 추가
683            if node.node_type == NodeType::Bos {
684                // 경로를 Vec로 변환 (완료된 경로만)
685                let path_vec = current_path.map_or_else(Vec::new, |path_node| path_node.to_vec());
686                results.push((path_vec, candidate.cost));
687                continue;
688            }
689
690            // 이전 노드로 계속 탐색
691            if node.prev_node_id != INVALID_NODE_ID {
692                heap.push(NbestCandidate {
693                    node_id: node.prev_node_id,
694                    cost: candidate.cost,
695                    path: current_path,
696                });
697            }
698        }
699
700        results
701    }
702}
703
704/// Viterbi 결과를 Token으로 변환하는 헬퍼
705pub struct ViterbiResult<'a> {
706    /// Lattice 참조
707    lattice: &'a Lattice,
708    /// 최적 경로 노드 ID
709    path: Vec<NodeId>,
710    /// 총 비용
711    total_cost: i32,
712}
713
714impl<'a> ViterbiResult<'a> {
715    /// 결과 생성
716    #[must_use]
717    pub const fn new(lattice: &'a Lattice, path: Vec<NodeId>, total_cost: i32) -> Self {
718        Self {
719            lattice,
720            path,
721            total_cost,
722        }
723    }
724
725    /// 경로의 노드들 반복
726    pub fn nodes(&self) -> impl Iterator<Item = &'a Node> + '_ {
727        self.path.iter().filter_map(|&id| self.lattice.node(id))
728    }
729
730    /// 총 비용
731    #[must_use]
732    pub const fn cost(&self) -> i32 {
733        self.total_cost
734    }
735
736    /// 노드 개수
737    #[must_use]
738    pub fn len(&self) -> usize {
739        self.path.len()
740    }
741
742    /// 비어있는지 확인
743    #[must_use]
744    pub fn is_empty(&self) -> bool {
745        self.path.is_empty()
746    }
747
748    /// 표면형 목록
749    #[must_use]
750    pub fn surfaces(&self) -> Vec<&str> {
751        self.nodes().map(|n| n.surface.as_ref()).collect()
752    }
753}
754
755#[cfg(test)]
756#[allow(clippy::unwrap_used)]
757mod tests {
758    use super::*;
759    use crate::lattice::NodeBuilder;
760
761    /// 테스트용 연접 비용 행렬
762    struct TestConnectionCost {
763        costs: std::collections::HashMap<(u16, u16), i32>,
764        default: i32,
765    }
766
767    impl TestConnectionCost {
768        fn new(default: i32) -> Self {
769            Self {
770                costs: std::collections::HashMap::new(),
771                default,
772            }
773        }
774
775        fn set(&mut self, right_id: u16, left_id: u16, cost: i32) {
776            self.costs.insert((right_id, left_id), cost);
777        }
778    }
779
780    impl ConnectionCost for TestConnectionCost {
781        fn cost(&self, right_id: u16, left_id: u16) -> i32 {
782            self.costs
783                .get(&(right_id, left_id))
784                .copied()
785                .unwrap_or(self.default)
786        }
787    }
788
789    #[test]
790    fn test_space_penalty_default() {
791        let penalty = SpacePenalty::default();
792        assert!(penalty.is_empty());
793        assert_eq!(penalty.get(100), 0);
794    }
795
796    #[test]
797    fn test_space_penalty_from_dicrc() {
798        let penalty = SpacePenalty::from_dicrc("100,5000;200,3000;300,1000");
799
800        assert_eq!(penalty.len(), 3);
801        assert_eq!(penalty.get(100), 5000);
802        assert_eq!(penalty.get(200), 3000);
803        assert_eq!(penalty.get(300), 1000);
804        assert_eq!(penalty.get(999), 0); // 미등록
805    }
806
807    #[test]
808    fn test_space_penalty_korean_default() {
809        let penalty = SpacePenalty::korean_default();
810        assert!(!penalty.is_empty());
811
812        // 조사 범위에 대해 페널티가 설정되어 있어야 함
813        assert!(penalty.get(1785) > 0);
814    }
815
816    #[test]
817    fn test_viterbi_simple_path() {
818        // 간단한 Lattice: "AB"
819        // BOS -> [A] -> [B] -> EOS
820        let mut lattice = Lattice::new("AB");
821
822        // A 노드 (위치 0-1)
823        lattice.add_node(
824            NodeBuilder::new("A", 0, 1)
825                .left_id(1)
826                .right_id(1)
827                .word_cost(100),
828        );
829
830        // B 노드 (위치 1-2)
831        lattice.add_node(
832            NodeBuilder::new("B", 1, 2)
833                .left_id(2)
834                .right_id(2)
835                .word_cost(200),
836        );
837
838        let conn_cost = ZeroConnectionCost;
839        let searcher = ViterbiSearcher::new();
840
841        let path = searcher.search(&mut lattice, &conn_cost);
842
843        assert_eq!(path.len(), 2);
844
845        // 첫 번째 노드는 "A"
846        let first = lattice.node(path[0]).unwrap();
847        assert_eq!(first.surface.as_ref(), "A");
848
849        // 두 번째 노드는 "B"
850        let second = lattice.node(path[1]).unwrap();
851        assert_eq!(second.surface.as_ref(), "B");
852
853        // 총 비용 확인
854        let total_cost = searcher.get_best_cost(&lattice);
855        assert_eq!(total_cost, 300); // 100 + 200
856    }
857
858    #[test]
859    fn test_viterbi_choose_best_path() {
860        // 두 가지 경로가 있는 Lattice: "AB"
861        // 경로 1: BOS -> [AB] -> EOS (비용: 500)
862        // 경로 2: BOS -> [A] -> [B] -> EOS (비용: 100 + 200 = 300)
863        let mut lattice = Lattice::new("AB");
864
865        // AB 노드 (위치 0-2) - 비용 높음
866        lattice.add_node(
867            NodeBuilder::new("AB", 0, 2)
868                .left_id(1)
869                .right_id(1)
870                .word_cost(500),
871        );
872
873        // A 노드 (위치 0-1)
874        lattice.add_node(
875            NodeBuilder::new("A", 0, 1)
876                .left_id(2)
877                .right_id(2)
878                .word_cost(100),
879        );
880
881        // B 노드 (위치 1-2)
882        lattice.add_node(
883            NodeBuilder::new("B", 1, 2)
884                .left_id(3)
885                .right_id(3)
886                .word_cost(200),
887        );
888
889        let conn_cost = ZeroConnectionCost;
890        let searcher = ViterbiSearcher::new();
891
892        let path = searcher.search(&mut lattice, &conn_cost);
893
894        // 더 낮은 비용의 경로 선택: A + B
895        assert_eq!(path.len(), 2);
896        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "A");
897        assert_eq!(lattice.node(path[1]).unwrap().surface.as_ref(), "B");
898    }
899
900    #[test]
901    fn test_viterbi_with_connection_cost() {
902        // 연접 비용이 경로 선택에 영향
903        // 경로 1: BOS -> [AB] -> EOS (단어: 300, 연접: 0)
904        // 경로 2: BOS -> [A] -> [B] -> EOS (단어: 100+100=200, 연접: 500)
905        let mut lattice = Lattice::new("AB");
906
907        // AB 노드
908        lattice.add_node(
909            NodeBuilder::new("AB", 0, 2)
910                .left_id(1)
911                .right_id(1)
912                .word_cost(300),
913        );
914
915        // A 노드
916        lattice.add_node(
917            NodeBuilder::new("A", 0, 1)
918                .left_id(2)
919                .right_id(2)
920                .word_cost(100),
921        );
922
923        // B 노드
924        lattice.add_node(
925            NodeBuilder::new("B", 1, 2)
926                .left_id(3)
927                .right_id(3)
928                .word_cost(100),
929        );
930
931        let mut conn_cost = TestConnectionCost::new(0);
932        // A -> B 연접에 높은 비용 설정
933        conn_cost.set(2, 3, 500);
934
935        let searcher = ViterbiSearcher::new();
936        let path = searcher.search(&mut lattice, &conn_cost);
937
938        // 연접 비용 때문에 AB 선택: 300 < 200 + 500
939        assert_eq!(path.len(), 1);
940        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "AB");
941    }
942
943    #[test]
944    fn test_viterbi_with_space_penalty() {
945        // 띄어쓰기 패널티 테스트
946        // "A B" (공백 있음)
947        // B의 left_id에 패널티가 있으면 다른 경로 선택
948        let mut lattice = Lattice::new("A B");
949        // 공백 제거 후 "AB"
950
951        // AB 노드 (전체)
952        lattice.add_node(
953            NodeBuilder::new("AB", 0, 2)
954                .left_id(1)
955                .right_id(1)
956                .word_cost(500),
957        );
958
959        // A 노드
960        lattice.add_node(
961            NodeBuilder::new("A", 0, 1)
962                .left_id(2)
963                .right_id(2)
964                .word_cost(100),
965        );
966
967        // B 노드 (공백 뒤에서 시작)
968        lattice.add_node(
969            NodeBuilder::new("B", 1, 2)
970                .left_id(100) // 페널티가 적용될 ID
971                .right_id(3)
972                .word_cost(100)
973                .has_space_before(true),
974        );
975
976        // B의 left_id에 높은 페널티 설정
977        let mut penalty = SpacePenalty::new();
978        penalty.add(100, 1000);
979
980        let conn_cost = ZeroConnectionCost;
981        let searcher = ViterbiSearcher::new().with_space_penalty(penalty);
982
983        let path = searcher.search(&mut lattice, &conn_cost);
984
985        // 페널티 때문에 AB 선택: 500 < 100 + 100 + 1000
986        assert_eq!(path.len(), 1);
987        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "AB");
988    }
989
990    #[test]
991    fn test_viterbi_korean_example() {
992        // 한국어 예시: "아버지가"
993        let mut lattice = Lattice::new("아버지가");
994
995        // 경로 1: "아버지" + "가" (조사)
996        lattice.add_node(
997            NodeBuilder::new("아버지", 0, 3)
998                .left_id(1)
999                .right_id(1)
1000                .word_cost(1000),
1001        );
1002        lattice.add_node(
1003            NodeBuilder::new("가", 3, 4)
1004                .left_id(100) // 조사
1005                .right_id(100)
1006                .word_cost(500),
1007        );
1008
1009        // 경로 2: "아버" + "지가"
1010        lattice.add_node(
1011            NodeBuilder::new("아버", 0, 2)
1012                .left_id(2)
1013                .right_id(2)
1014                .word_cost(3000),
1015        );
1016        lattice.add_node(
1017            NodeBuilder::new("지가", 2, 4)
1018                .left_id(3)
1019                .right_id(3)
1020                .word_cost(3000),
1021        );
1022
1023        let conn_cost = ZeroConnectionCost;
1024        let searcher = ViterbiSearcher::new();
1025
1026        let path = searcher.search(&mut lattice, &conn_cost);
1027
1028        // "아버지" + "가" 선택 (비용: 1500 < 6000)
1029        assert_eq!(path.len(), 2);
1030        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "아버지");
1031        assert_eq!(lattice.node(path[1]).unwrap().surface.as_ref(), "가");
1032    }
1033
1034    #[test]
1035    fn test_viterbi_empty_lattice() {
1036        let mut lattice = Lattice::new("");
1037
1038        let conn_cost = ZeroConnectionCost;
1039        let searcher = ViterbiSearcher::new();
1040
1041        let path = searcher.search(&mut lattice, &conn_cost);
1042
1043        // 빈 텍스트는 빈 경로
1044        assert!(path.is_empty());
1045    }
1046
1047    #[test]
1048    fn test_viterbi_no_path() {
1049        // 노드가 연결되지 않는 경우
1050        let mut lattice = Lattice::new("ABC");
1051
1052        // A만 있고 B, C 없음 -> EOS에 도달 불가
1053        lattice.add_node(
1054            NodeBuilder::new("A", 0, 1)
1055                .left_id(1)
1056                .right_id(1)
1057                .word_cost(100),
1058        );
1059
1060        let conn_cost = ZeroConnectionCost;
1061        let searcher = ViterbiSearcher::new();
1062
1063        let path = searcher.search(&mut lattice, &conn_cost);
1064
1065        // 유효한 경로 없음
1066        assert!(!searcher.has_valid_path(&lattice));
1067        assert!(path.is_empty());
1068    }
1069
1070    #[test]
1071    fn test_nbest_single() {
1072        let mut lattice = Lattice::new("AB");
1073
1074        lattice.add_node(
1075            NodeBuilder::new("A", 0, 1)
1076                .left_id(1)
1077                .right_id(1)
1078                .word_cost(100),
1079        );
1080        lattice.add_node(
1081            NodeBuilder::new("B", 1, 2)
1082                .left_id(2)
1083                .right_id(2)
1084                .word_cost(200),
1085        );
1086
1087        let conn_cost = ZeroConnectionCost;
1088        let searcher = NbestSearcher::new(1);
1089
1090        let results = searcher.search(&mut lattice, &conn_cost);
1091
1092        assert_eq!(results.len(), 1);
1093        assert_eq!(results[0].1, 300); // 비용
1094    }
1095
1096    #[test]
1097    fn test_viterbi_result_helper() {
1098        let mut lattice = Lattice::new("AB");
1099
1100        let _id1 = lattice.add_node(
1101            NodeBuilder::new("A", 0, 1)
1102                .left_id(1)
1103                .right_id(1)
1104                .word_cost(100),
1105        );
1106        let _id2 = lattice.add_node(
1107            NodeBuilder::new("B", 1, 2)
1108                .left_id(2)
1109                .right_id(2)
1110                .word_cost(200),
1111        );
1112
1113        let conn_cost = ZeroConnectionCost;
1114        let searcher = ViterbiSearcher::new();
1115        let path = searcher.search(&mut lattice, &conn_cost);
1116        let cost = searcher.get_best_cost(&lattice);
1117
1118        let result = ViterbiResult::new(&lattice, path, cost);
1119
1120        assert_eq!(result.len(), 2);
1121        assert_eq!(result.cost(), 300);
1122        assert_eq!(result.surfaces(), vec!["A", "B"]);
1123    }
1124
1125    #[test]
1126    fn test_viterbi_with_dense_matrix() {
1127        use mecab_ko_dict::DenseMatrix;
1128
1129        // 3x3 연접 비용 행렬 생성
1130        // left_id: 0=BOS, 1=명사, 2=조사
1131        // right_id: 0=EOS, 1=명사, 2=조사
1132        let mut matrix = DenseMatrix::new(3, 3, 0);
1133
1134        // 연접 비용 설정
1135        // BOS -> 명사: 낮은 비용 (자연스러움)
1136        matrix.set(0, 1, 100);
1137        // 명사 -> 조사: 낮은 비용 (자연스러움)
1138        matrix.set(1, 2, 50);
1139        // 조사 -> EOS: 낮은 비용
1140        matrix.set(2, 0, 30);
1141
1142        // BOS -> 조사: 높은 비용 (부자연스러움)
1143        matrix.set(0, 2, 5000);
1144        // 명사 -> EOS: 중간 비용
1145        matrix.set(1, 0, 200);
1146
1147        let mut lattice = Lattice::new("책을");
1148
1149        // "책" (명사) - 문자 위치 0..1
1150        lattice.add_node(
1151            NodeBuilder::new("책", 0, 1)
1152                .left_id(1) // 명사 left_id
1153                .right_id(1) // 명사 right_id
1154                .word_cost(500),
1155        );
1156
1157        // "을" (조사) - 문자 위치 1..2
1158        lattice.add_node(
1159            NodeBuilder::new("을", 1, 2)
1160                .left_id(2) // 조사 left_id
1161                .right_id(2) // 조사 right_id
1162                .word_cost(100),
1163        );
1164
1165        let searcher = ViterbiSearcher::new();
1166        let path = searcher.search(&mut lattice, &matrix);
1167
1168        // BOS -> 명사 -> 조사 -> EOS 경로 확인
1169        assert!(!path.is_empty());
1170
1171        let result = ViterbiResult::new(&lattice, path, searcher.get_best_cost(&lattice));
1172        assert_eq!(result.surfaces(), vec!["책", "을"]);
1173
1174        // 총 비용: BOS->명사(100) + 명사비용(500) + 명사->조사(50) + 조사비용(100) + 조사->EOS(30)
1175        // = 100 + 500 + 50 + 100 + 30 = 780
1176        assert_eq!(result.cost(), 780);
1177    }
1178}