Skip to main content

mecab_ko_core/viterbi/
mod.rs

1//! Viterbi 알고리즘
2//!
3//! 최적 형태소 분석 경로를 찾는 Viterbi 알고리즘을 구현합니다.
4//!
5//! # 개요
6//!
7//! Viterbi 알고리즘은 Lattice에서 최소 비용 경로를 찾는 동적 프로그래밍 알고리즘입니다.
8//!
9//! ```text
10//! 총 비용 = Σ(단어 비용) + Σ(연접 비용) + Σ(띄어쓰기 패널티)
11//! ```
12//!
13//! # 알고리즘
14//!
15//! 1. **Forward Pass**: BOS에서 시작하여 각 노드까지의 최소 비용 계산
16//! 2. **Backward Pass**: EOS에서 BOS까지 역추적하여 최적 경로 추출
17//!
18//! # 한국어 특화
19//!
20//! - `left-space-penalty-factor`: 띄어쓰기 후 특정 품사 시작 시 페널티 부여
21//! - 조사(JK*), 어미(E*) 등이 띄어쓰기 직후 시작하면 높은 페널티
22//!
23//! # Example
24//!
25//! ```rust,no_run
26//! use mecab_ko_core::viterbi::{ViterbiSearcher, SpacePenalty};
27//! use mecab_ko_core::lattice::Lattice;
28//!
29//! let mut lattice = Lattice::new("아버지가방에");
30//! // ... 노드 추가 후 검색 ...
31//!
32//! let searcher = ViterbiSearcher::new()
33//!     .with_space_penalty(SpacePenalty::korean_default());
34//! ```
35
36use crate::lattice::{Lattice, Node, NodeId, NodeType, INVALID_NODE_ID};
37use std::cmp::Ordering;
38use std::collections::BinaryHeap;
39use std::rc::Rc;
40
41// SIMD 최적화 모듈
42#[cfg(feature = "simd")]
43pub mod simd;
44
45#[cfg(feature = "simd")]
46pub use simd::{simd_forward_pass_position, simd_update_node_cost};
47
48/// 여러 값의 포화 덧셈 (체인)
49///
50/// 오버플로우 방지를 위해 포화 연산 사용
51#[inline(always)]
52const fn saturating_add_chain(a: i32, b: i32, c: i32, d: i32) -> i32 {
53    a.saturating_add(b).saturating_add(c).saturating_add(d)
54}
55
56/// 연접 비용 조회 인터페이스
57///
58/// 두 형태소 간의 연결 비용을 반환합니다.
59/// 이 비용은 matrix.def에서 학습된 값입니다.
60pub trait ConnectionCost {
61    /// 두 문맥 ID 간의 연접 비용 반환
62    ///
63    /// # Arguments
64    ///
65    /// * `right_id` - 이전 노드의 우문맥 ID
66    /// * `left_id` - 현재 노드의 좌문맥 ID
67    ///
68    /// # Returns
69    ///
70    /// 연접 비용 (낮을수록 좋음)
71    fn cost(&self, right_id: u16, left_id: u16) -> i32;
72}
73
74/// 더미 연접 비용 (테스트용)
75///
76/// 모든 연접에 대해 0을 반환합니다.
77#[derive(Debug, Clone, Default)]
78pub struct ZeroConnectionCost;
79
80impl ConnectionCost for ZeroConnectionCost {
81    #[inline(always)]
82    fn cost(&self, _right_id: u16, _left_id: u16) -> i32 {
83        0
84    }
85}
86
87/// 고정 연접 비용 (테스트용)
88#[derive(Debug, Clone)]
89pub struct FixedConnectionCost {
90    /// 기본 비용
91    pub default_cost: i32,
92}
93
94impl FixedConnectionCost {
95    /// 새 고정 비용 생성
96    #[must_use]
97    pub const fn new(cost: i32) -> Self {
98        Self { default_cost: cost }
99    }
100}
101
102impl ConnectionCost for FixedConnectionCost {
103    #[inline(always)]
104    fn cost(&self, _right_id: u16, _left_id: u16) -> i32 {
105        self.default_cost
106    }
107}
108
109/// mecab-ko-dict의 `Matrix` trait에 대한 `ConnectionCost` 구현
110///
111/// 사전 모듈의 연접 비용 행렬을 Viterbi 알고리즘에서 직접 사용할 수 있습니다.
112impl<T: mecab_ko_dict::Matrix> ConnectionCost for T {
113    #[inline(always)]
114    fn cost(&self, right_id: u16, left_id: u16) -> i32 {
115        self.get(right_id, left_id)
116    }
117}
118
119/// 띄어쓰기 패널티 설정
120///
121/// mecab-ko의 `left-space-penalty-factor` 기능을 구현합니다.
122/// 띄어쓰기 직후에 특정 품사가 오면 페널티를 부여하여 오분석을 방지합니다.
123///
124/// # Example
125///
126/// ```rust
127/// use mecab_ko_core::viterbi::SpacePenalty;
128///
129/// // mecab-ko 기본 설정
130/// let penalty = SpacePenalty::korean_default();
131///
132/// // dicrc 형식에서 생성
133/// let penalty = SpacePenalty::from_dicrc("1785,6000;1786,6000");
134/// ```
135#[derive(Debug, Clone, Default)]
136pub struct SpacePenalty {
137    /// 페널티를 적용할 품사 ID 목록과 페널티 값
138    /// `(left_id, penalty)`
139    penalties: Vec<(u16, i32)>,
140}
141
142impl SpacePenalty {
143    /// 빈 페널티 설정 생성
144    #[must_use]
145    pub fn new() -> Self {
146        Self::default()
147    }
148
149    /// 한국어 기본 페널티 설정
150    ///
151    /// 조사(JK*)와 어미(E*)가 띄어쓰기 직후 나타나면 높은 페널티를 부여합니다.
152    /// 이는 "아버지가방에" → "아버지가 방에"로 분석하는 데 도움이 됩니다.
153    #[must_use]
154    pub fn korean_default() -> Self {
155        // mecab-ko-dic의 left-id 기준 (실제 값은 사전에 따라 다름)
156        // 여기서는 대표적인 조사/어미 ID 범위를 사용
157        // Build ranges in sorted order so binary_search in get() works correctly.
158
159        // 어미 계열 (EP, EF, EC, ETN, ETM): 1700~1759
160        // 조사 계열 (JKS, JKC, JKG, JKO, JKB, JKV, JKQ, JX, JC): 1780~1809
161        let mut penalties: Vec<(u16, i32)> = (1700u16..1760)
162            .chain(1780..1810)
163            .map(|id| (id, 6000))
164            .collect();
165
166        // Ensure sorted for binary_search
167        penalties.sort_unstable_by_key(|&(id, _)| id);
168        Self { penalties }
169    }
170
171    /// mecab-ko의 dicrc 설정에서 생성
172    ///
173    /// # Format
174    ///
175    /// `left_id,penalty;left_id,penalty;...`
176    ///
177    /// # Example
178    ///
179    /// ```rust
180    /// use mecab_ko_core::viterbi::SpacePenalty;
181    ///
182    /// let penalty = SpacePenalty::from_dicrc("1785,6000;1786,6000;1787,5000");
183    /// assert_eq!(penalty.get(1785), 6000);
184    /// assert_eq!(penalty.get(1786), 6000);
185    /// assert_eq!(penalty.get(9999), 0);  // 미등록
186    /// ```
187    #[must_use]
188    pub fn from_dicrc(config: &str) -> Self {
189        let mut penalties = Vec::new();
190
191        for part in config.split(';') {
192            let parts: Vec<&str> = part.trim().split(',').collect();
193            if parts.len() == 2 {
194                if let (Ok(id), Ok(penalty)) = (
195                    parts[0].trim().parse::<u16>(),
196                    parts[1].trim().parse::<i32>(),
197                ) {
198                    penalties.push((id, penalty));
199                }
200            }
201        }
202
203        // Keep sorted for binary search in get()
204        penalties.sort_unstable_by_key(|&(id, _)| id);
205        Self { penalties }
206    }
207
208    /// 페널티 추가
209    pub fn add(&mut self, left_id: u16, penalty: i32) {
210        // Insert in sorted position for binary search correctness
211        let pos = self.penalties.partition_point(|&(id, _)| id < left_id);
212        self.penalties.insert(pos, (left_id, penalty));
213    }
214
215    /// 특정 품사 ID에 대한 페널티 조회
216    ///
217    /// # Returns
218    ///
219    /// 해당 ID에 설정된 페널티, 없으면 0
220    #[must_use]
221    #[inline]
222    pub fn get(&self, left_id: u16) -> i32 {
223        // Binary search on sorted penalties for O(log n) instead of O(n)
224        self.penalties
225            .binary_search_by_key(&left_id, |&(id, _)| id)
226            .map_or(0, |idx| self.penalties[idx].1)
227    }
228
229    /// 페널티가 설정되어 있는지 확인
230    #[must_use]
231    #[inline]
232    pub fn is_empty(&self) -> bool {
233        self.penalties.is_empty()
234    }
235
236    /// 설정된 페널티 개수
237    #[must_use]
238    #[inline]
239    pub fn len(&self) -> usize {
240        self.penalties.len()
241    }
242}
243
244/// Viterbi 탐색기
245///
246/// Lattice에서 최적 경로를 찾는 Viterbi 알고리즘을 구현합니다.
247#[derive(Debug, Clone)]
248pub struct ViterbiSearcher {
249    /// 띄어쓰기 패널티 설정
250    pub space_penalty: SpacePenalty,
251}
252
253impl Default for ViterbiSearcher {
254    fn default() -> Self {
255        Self::new()
256    }
257}
258
259impl ViterbiSearcher {
260    /// 새 탐색기 생성
261    #[must_use]
262    pub fn new() -> Self {
263        Self {
264            space_penalty: SpacePenalty::default(),
265        }
266    }
267
268    /// 띄어쓰기 패널티 설정
269    #[must_use]
270    pub fn with_space_penalty(mut self, penalty: SpacePenalty) -> Self {
271        self.space_penalty = penalty;
272        self
273    }
274
275    /// 최적 경로 탐색 (Forward-Backward)
276    ///
277    /// # Arguments
278    ///
279    /// * `lattice` - 노드가 추가된 Lattice
280    /// * `conn_cost` - 연접 비용 조회 인터페이스
281    ///
282    /// # Returns
283    ///
284    /// 최적 경로의 노드 ID 목록 (BOS, EOS 제외)
285    ///
286    /// # Example
287    ///
288    /// ```rust,no_run
289    /// # use mecab_ko_core::viterbi::{ViterbiSearcher, SpacePenalty};
290    /// # use mecab_ko_core::lattice::Lattice;
291    /// # let searcher = ViterbiSearcher::new();
292    /// # let conn_cost = mecab_ko_dict::matrix::DenseMatrix::new(1, 1, 0);
293    /// # let mut lattice = Lattice::new("test");
294    /// let path = searcher.search(&mut lattice, &conn_cost);
295    /// for node_id in path {
296    ///     let node = lattice.node(node_id).unwrap();
297    ///     println!("{}: {}", node.surface, node.word_cost);
298    /// }
299    /// ```
300    pub fn search<C: ConnectionCost>(&self, lattice: &mut Lattice, conn_cost: &C) -> Vec<NodeId> {
301        // Forward pass
302        self.forward_pass(lattice, conn_cost);
303
304        // Backward pass
305        Self::backward_pass(lattice)
306    }
307
308    /// Forward Pass: 각 노드의 최소 비용 계산
309    ///
310    /// BOS에서 시작하여 각 위치의 노드들에 대해 최소 비용을 계산합니다.
311    fn forward_pass<C: ConnectionCost>(&self, lattice: &mut Lattice, conn_cost: &C) {
312        let char_len = lattice.char_len();
313
314        // Reusable scratch buffers to avoid per-position Vec allocations.
315        // We collect (node_id) for the starting nodes and (id, total_cost, right_id)
316        // for ending nodes into these, clearing between positions.
317        let mut starting_ids: Vec<NodeId> = Vec::new();
318        let mut ending_nodes: Vec<(NodeId, i32, u16)> = Vec::new();
319
320        // 위치 0부터 끝까지 순회
321        for pos in 0..=char_len {
322            // Collect starting node IDs (need ownership before mutating lattice)
323            starting_ids.clear();
324            starting_ids.extend(lattice.nodes_starting_at(pos).map(|n| n.id));
325
326            // Collect ending node data once per position, reused for every
327            // starting node at this position.
328            ending_nodes.clear();
329            ending_nodes.extend(
330                lattice
331                    .nodes_ending_at(pos)
332                    .map(|n| (n.id, n.total_cost, n.right_id)),
333            );
334
335            for &node_id in &starting_ids {
336                self.update_node_cost_with_endings(lattice, conn_cost, node_id, &ending_nodes);
337            }
338        }
339    }
340
341    /// 단일 노드의 최소 비용 계산 및 업데이트 (사전 수집된 `ending_nodes` 사용)
342    ///
343    /// Hot path: 성능 최적화를 위해 인라인 처리
344    /// SIMD 최적화: 8개 이상의 이전 노드가 있으면 SIMD 배치 처리 사용
345    #[inline]
346    fn update_node_cost_with_endings<C: ConnectionCost>(
347        &self,
348        lattice: &mut Lattice,
349        conn_cost: &C,
350        node_id: NodeId,
351        ending_nodes: &[(NodeId, i32, u16)],
352    ) {
353        // SIMD 최적화: 8개 이상의 이전 노드가 있으면 SIMD 사용
354        #[cfg(feature = "simd")]
355        if ending_nodes.len() >= 8 {
356            let (best_cost, best_prev) = simd::simd_update_node_cost(
357                lattice,
358                conn_cost,
359                node_id,
360                ending_nodes,
361                &self.space_penalty,
362            );
363            if let Some(node) = lattice.node_mut(node_id) {
364                node.total_cost = best_cost;
365                node.prev_node_id = best_prev;
366            }
367            return;
368        }
369
370        // 현재 노드 정보 추출
371        let (left_id, word_cost, has_space) = {
372            let Some(node) = lattice.node(node_id) else {
373                return;
374            };
375            (node.left_id, node.word_cost, node.has_space_before)
376        };
377
378        // 띄어쓰기 패널티는 left_id에 대해 한 번만 조회
379        let space_penalty = if has_space {
380            self.space_penalty.get(left_id)
381        } else {
382            0
383        };
384
385        let mut best_cost = i32::MAX;
386        let mut best_prev = INVALID_NODE_ID;
387
388        for &(prev_id, prev_cost, prev_right_id) in ending_nodes {
389            // 이전 노드까지의 비용이 무한대면 스킵
390            if prev_cost == i32::MAX {
391                continue;
392            }
393
394            // 연접 비용 계산
395            let connection = conn_cost.cost(prev_right_id, left_id);
396
397            // 총 비용 = 이전 비용 + 연접 비용 + 단어 비용 + 띄어쓰기 패널티
398            let total = saturating_add_chain(prev_cost, connection, word_cost, space_penalty);
399
400            if total < best_cost {
401                best_cost = total;
402                best_prev = prev_id;
403            }
404        }
405
406        // 노드 업데이트
407        if let Some(node) = lattice.node_mut(node_id) {
408            node.total_cost = best_cost;
409            node.prev_node_id = best_prev;
410        }
411    }
412
413    /// 단일 노드의 최소 비용 계산 및 업데이트 (레거시, 테스트용으로 유지)
414    #[cfg(test)]
415    #[allow(dead_code)]
416    fn update_node_cost<C: ConnectionCost>(
417        &self,
418        lattice: &mut Lattice,
419        conn_cost: &C,
420        node_id: NodeId,
421        pos: usize,
422    ) {
423        // 현재 노드 정보 추출
424        let (left_id, word_cost, has_space) = {
425            let Some(node) = lattice.node(node_id) else {
426                return;
427            };
428            (node.left_id, node.word_cost, node.has_space_before)
429        };
430
431        // 이 노드로 연결될 수 있는 이전 노드들 (pos에서 끝나는 노드들)
432        let ending_nodes: Vec<(NodeId, i32, u16)> = lattice
433            .nodes_ending_at(pos)
434            .map(|n| (n.id, n.total_cost, n.right_id))
435            .collect();
436
437        let mut best_cost = i32::MAX;
438        let mut best_prev = INVALID_NODE_ID;
439
440        for (prev_id, prev_cost, prev_right_id) in ending_nodes {
441            if prev_cost == i32::MAX {
442                continue;
443            }
444
445            let connection = conn_cost.cost(prev_right_id, left_id);
446
447            let space_penalty = if has_space {
448                self.space_penalty.get(left_id)
449            } else {
450                0
451            };
452
453            let total = prev_cost
454                .saturating_add(connection)
455                .saturating_add(word_cost)
456                .saturating_add(space_penalty);
457
458            if total < best_cost {
459                best_cost = total;
460                best_prev = prev_id;
461            }
462        }
463
464        // 노드 업데이트
465        if let Some(node) = lattice.node_mut(node_id) {
466            node.total_cost = best_cost;
467            node.prev_node_id = best_prev;
468        }
469    }
470
471    /// Backward Pass: EOS에서 BOS까지 역추적
472    ///
473    /// 최적 경로의 노드 ID 목록을 반환합니다 (BOS, EOS 제외).
474    fn backward_pass(lattice: &Lattice) -> Vec<NodeId> {
475        let mut path = Vec::new();
476        let mut current_id = lattice.eos().id;
477
478        while current_id != INVALID_NODE_ID {
479            if let Some(node) = lattice.node(current_id) {
480                // BOS, EOS는 결과에서 제외
481                if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos {
482                    path.push(current_id);
483                }
484                current_id = node.prev_node_id;
485            } else {
486                break;
487            }
488        }
489
490        path.reverse();
491        path
492    }
493
494    /// 최적 경로의 총 비용 조회
495    #[must_use]
496    pub fn get_best_cost(&self, lattice: &Lattice) -> i32 {
497        lattice.eos().total_cost
498    }
499
500    /// 경로가 유효한지 확인
501    ///
502    /// EOS까지의 경로가 존재하는지 확인합니다.
503    #[must_use]
504    pub fn has_valid_path(&self, lattice: &Lattice) -> bool {
505        lattice.eos().total_cost != i32::MAX && lattice.eos().prev_node_id != INVALID_NODE_ID
506    }
507}
508
509// ============================================
510// N-best 지원
511// ============================================
512
513/// N-best 경로 노드 (링크드 리스트)
514///
515/// 경로를 Rc로 연결하여 클론 비용을 줄입니다.
516/// 전체 경로를 복사하는 대신 참조 카운트만 증가시킵니다.
517#[derive(Debug, Clone)]
518struct PathNode {
519    /// 현재 노드 ID
520    node_id: NodeId,
521    /// 이전 경로 노드 (Rc로 공유)
522    prev: Option<Rc<Self>>,
523}
524
525impl PathNode {
526    /// 새 경로 노드 생성
527    ///
528    /// Note: Cannot be const due to `Rc<Self>` parameter
529    #[allow(clippy::missing_const_for_fn)]
530    fn new(node_id: NodeId, prev: Option<Rc<Self>>) -> Self {
531        Self { node_id, prev }
532    }
533
534    /// 경로를 Vec로 변환 (BOS에서 현재 노드까지)
535    fn to_vec(&self) -> Vec<NodeId> {
536        let mut path = Vec::new();
537        let mut current = Some(self);
538
539        while let Some(node) = current {
540            path.push(node.node_id);
541            current = node.prev.as_ref().map(std::convert::AsRef::as_ref);
542        }
543
544        path.reverse();
545        path
546    }
547}
548
549/// N-best 경로 후보
550#[derive(Debug, Clone)]
551struct NbestCandidate {
552    /// 노드 ID
553    node_id: NodeId,
554    /// 총 비용
555    cost: i32,
556    /// 이전 경로 (Rc로 공유되는 링크드 리스트)
557    path: Option<Rc<PathNode>>,
558}
559
560impl Eq for NbestCandidate {}
561
562impl PartialEq for NbestCandidate {
563    fn eq(&self, other: &Self) -> bool {
564        self.cost == other.cost
565    }
566}
567
568impl Ord for NbestCandidate {
569    fn cmp(&self, other: &Self) -> Ordering {
570        // Min-heap: 비용이 낮은 것이 우선
571        other.cost.cmp(&self.cost)
572    }
573}
574
575impl PartialOrd for NbestCandidate {
576    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
577        Some(self.cmp(other))
578    }
579}
580
581/// N-best 탐색기
582///
583/// 상위 N개의 최적 경로를 찾습니다.
584#[derive(Debug, Clone)]
585pub struct NbestSearcher {
586    /// 기본 Viterbi 탐색기
587    viterbi: ViterbiSearcher,
588    /// 최대 결과 수
589    max_results: usize,
590}
591
592impl NbestSearcher {
593    /// 새 N-best 탐색기 생성
594    #[must_use]
595    pub fn new(n: usize) -> Self {
596        Self {
597            viterbi: ViterbiSearcher::new(),
598            max_results: n,
599        }
600    }
601
602    /// 띄어쓰기 패널티 설정
603    #[must_use]
604    pub fn with_space_penalty(mut self, penalty: SpacePenalty) -> Self {
605        self.viterbi.space_penalty = penalty;
606        self
607    }
608
609    /// N-best 경로 탐색
610    ///
611    /// # Arguments
612    ///
613    /// * `lattice` - 노드가 추가된 Lattice
614    /// * `conn_cost` - 연접 비용 조회 인터페이스
615    ///
616    /// # Returns
617    ///
618    /// (경로, 비용) 쌍의 벡터, 비용 오름차순
619    pub fn search<C: ConnectionCost>(
620        &self,
621        lattice: &mut Lattice,
622        conn_cost: &C,
623    ) -> Vec<(Vec<NodeId>, i32)> {
624        // 먼저 Forward pass 실행
625        self.viterbi.forward_pass(lattice, conn_cost);
626
627        // 최적 경로가 없으면 빈 결과 반환
628        if !self.viterbi.has_valid_path(lattice) {
629            return Vec::new();
630        }
631
632        // 1-best인 경우 단순 backward pass
633        if self.max_results == 1 {
634            let path = ViterbiSearcher::backward_pass(lattice);
635            let cost = self.viterbi.get_best_cost(lattice);
636            return vec![(path, cost)];
637        }
638
639        // N-best: A* 유사 알고리즘
640        self.search_nbest(lattice, conn_cost)
641    }
642
643    /// N-best 경로 탐색 (A* 기반)
644    ///
645    /// # 최적화
646    ///
647    /// 경로를 `Rc<PathNode>`로 표현하여 클론 비용을 최소화합니다.
648    /// 전체 Vec를 복사하는 대신 참조 카운트만 증가시켜 O(1) 클론을 달성합니다.
649    fn search_nbest<C: ConnectionCost>(
650        &self,
651        lattice: &Lattice,
652        _conn_cost: &C,
653    ) -> Vec<(Vec<NodeId>, i32)> {
654        let mut results: Vec<(Vec<NodeId>, i32)> = Vec::new();
655        let mut heap: BinaryHeap<NbestCandidate> = BinaryHeap::new();
656
657        // EOS에서 시작
658        let eos = lattice.eos();
659        if eos.total_cost == i32::MAX {
660            return results;
661        }
662
663        heap.push(NbestCandidate {
664            node_id: eos.id,
665            cost: eos.total_cost,
666            path: None,
667        });
668
669        while let Some(candidate) = heap.pop() {
670            if results.len() >= self.max_results {
671                break;
672            }
673
674            let Some(node) = lattice.node(candidate.node_id) else {
675                continue;
676            };
677
678            // 현재까지의 경로 (Rc 클론은 O(1))
679            let current_path = if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos
680            {
681                // BOS, EOS가 아니면 경로에 추가
682                Some(Rc::new(PathNode::new(candidate.node_id, candidate.path)))
683            } else {
684                candidate.path
685            };
686
687            // BOS에 도달하면 결과에 추가
688            if node.node_type == NodeType::Bos {
689                // 경로를 Vec로 변환 (완료된 경로만)
690                let path_vec = current_path.map_or_else(Vec::new, |path_node| path_node.to_vec());
691                results.push((path_vec, candidate.cost));
692                continue;
693            }
694
695            // 이전 노드로 계속 탐색
696            if node.prev_node_id != INVALID_NODE_ID {
697                heap.push(NbestCandidate {
698                    node_id: node.prev_node_id,
699                    cost: candidate.cost,
700                    path: current_path,
701                });
702            }
703        }
704
705        results
706    }
707}
708
709/// Viterbi 결과를 Token으로 변환하는 헬퍼
710pub struct ViterbiResult<'a> {
711    /// Lattice 참조
712    lattice: &'a Lattice,
713    /// 최적 경로 노드 ID
714    path: Vec<NodeId>,
715    /// 총 비용
716    total_cost: i32,
717}
718
719impl<'a> ViterbiResult<'a> {
720    /// 결과 생성
721    #[must_use]
722    pub const fn new(lattice: &'a Lattice, path: Vec<NodeId>, total_cost: i32) -> Self {
723        Self {
724            lattice,
725            path,
726            total_cost,
727        }
728    }
729
730    /// 경로의 노드들 반복
731    pub fn nodes(&self) -> impl Iterator<Item = &'a Node> + '_ {
732        self.path.iter().filter_map(|&id| self.lattice.node(id))
733    }
734
735    /// 총 비용
736    #[must_use]
737    pub const fn cost(&self) -> i32 {
738        self.total_cost
739    }
740
741    /// 노드 개수
742    #[must_use]
743    pub fn len(&self) -> usize {
744        self.path.len()
745    }
746
747    /// 비어있는지 확인
748    #[must_use]
749    pub fn is_empty(&self) -> bool {
750        self.path.is_empty()
751    }
752
753    /// 표면형 목록
754    #[must_use]
755    pub fn surfaces(&self) -> Vec<&str> {
756        self.nodes().map(|n| n.surface.as_ref()).collect()
757    }
758}
759
760#[cfg(test)]
761#[allow(clippy::unwrap_used)]
762mod tests {
763    use super::*;
764    use crate::lattice::NodeBuilder;
765
766    /// 테스트용 연접 비용 행렬
767    struct TestConnectionCost {
768        costs: std::collections::HashMap<(u16, u16), i32>,
769        default: i32,
770    }
771
772    impl TestConnectionCost {
773        fn new(default: i32) -> Self {
774            Self {
775                costs: std::collections::HashMap::new(),
776                default,
777            }
778        }
779
780        fn set(&mut self, right_id: u16, left_id: u16, cost: i32) {
781            self.costs.insert((right_id, left_id), cost);
782        }
783    }
784
785    impl ConnectionCost for TestConnectionCost {
786        fn cost(&self, right_id: u16, left_id: u16) -> i32 {
787            self.costs
788                .get(&(right_id, left_id))
789                .copied()
790                .unwrap_or(self.default)
791        }
792    }
793
794    #[test]
795    fn test_space_penalty_default() {
796        let penalty = SpacePenalty::default();
797        assert!(penalty.is_empty());
798        assert_eq!(penalty.get(100), 0);
799    }
800
801    #[test]
802    fn test_space_penalty_from_dicrc() {
803        let penalty = SpacePenalty::from_dicrc("100,5000;200,3000;300,1000");
804
805        assert_eq!(penalty.len(), 3);
806        assert_eq!(penalty.get(100), 5000);
807        assert_eq!(penalty.get(200), 3000);
808        assert_eq!(penalty.get(300), 1000);
809        assert_eq!(penalty.get(999), 0); // 미등록
810    }
811
812    #[test]
813    fn test_space_penalty_korean_default() {
814        let penalty = SpacePenalty::korean_default();
815        assert!(!penalty.is_empty());
816
817        // 조사 범위에 대해 페널티가 설정되어 있어야 함
818        assert!(penalty.get(1785) > 0);
819    }
820
821    #[test]
822    fn test_viterbi_simple_path() {
823        // 간단한 Lattice: "AB"
824        // BOS -> [A] -> [B] -> EOS
825        let mut lattice = Lattice::new("AB");
826
827        // A 노드 (위치 0-1)
828        lattice.add_node(
829            NodeBuilder::new("A", 0, 1)
830                .left_id(1)
831                .right_id(1)
832                .word_cost(100),
833        );
834
835        // B 노드 (위치 1-2)
836        lattice.add_node(
837            NodeBuilder::new("B", 1, 2)
838                .left_id(2)
839                .right_id(2)
840                .word_cost(200),
841        );
842
843        let conn_cost = ZeroConnectionCost;
844        let searcher = ViterbiSearcher::new();
845
846        let path = searcher.search(&mut lattice, &conn_cost);
847
848        assert_eq!(path.len(), 2);
849
850        // 첫 번째 노드는 "A"
851        let first = lattice.node(path[0]).unwrap();
852        assert_eq!(first.surface.as_ref(), "A");
853
854        // 두 번째 노드는 "B"
855        let second = lattice.node(path[1]).unwrap();
856        assert_eq!(second.surface.as_ref(), "B");
857
858        // 총 비용 확인
859        let total_cost = searcher.get_best_cost(&lattice);
860        assert_eq!(total_cost, 300); // 100 + 200
861    }
862
863    #[test]
864    fn test_viterbi_choose_best_path() {
865        // 두 가지 경로가 있는 Lattice: "AB"
866        // 경로 1: BOS -> [AB] -> EOS (비용: 500)
867        // 경로 2: BOS -> [A] -> [B] -> EOS (비용: 100 + 200 = 300)
868        let mut lattice = Lattice::new("AB");
869
870        // AB 노드 (위치 0-2) - 비용 높음
871        lattice.add_node(
872            NodeBuilder::new("AB", 0, 2)
873                .left_id(1)
874                .right_id(1)
875                .word_cost(500),
876        );
877
878        // A 노드 (위치 0-1)
879        lattice.add_node(
880            NodeBuilder::new("A", 0, 1)
881                .left_id(2)
882                .right_id(2)
883                .word_cost(100),
884        );
885
886        // B 노드 (위치 1-2)
887        lattice.add_node(
888            NodeBuilder::new("B", 1, 2)
889                .left_id(3)
890                .right_id(3)
891                .word_cost(200),
892        );
893
894        let conn_cost = ZeroConnectionCost;
895        let searcher = ViterbiSearcher::new();
896
897        let path = searcher.search(&mut lattice, &conn_cost);
898
899        // 더 낮은 비용의 경로 선택: A + B
900        assert_eq!(path.len(), 2);
901        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "A");
902        assert_eq!(lattice.node(path[1]).unwrap().surface.as_ref(), "B");
903    }
904
905    #[test]
906    fn test_viterbi_with_connection_cost() {
907        // 연접 비용이 경로 선택에 영향
908        // 경로 1: BOS -> [AB] -> EOS (단어: 300, 연접: 0)
909        // 경로 2: BOS -> [A] -> [B] -> EOS (단어: 100+100=200, 연접: 500)
910        let mut lattice = Lattice::new("AB");
911
912        // AB 노드
913        lattice.add_node(
914            NodeBuilder::new("AB", 0, 2)
915                .left_id(1)
916                .right_id(1)
917                .word_cost(300),
918        );
919
920        // A 노드
921        lattice.add_node(
922            NodeBuilder::new("A", 0, 1)
923                .left_id(2)
924                .right_id(2)
925                .word_cost(100),
926        );
927
928        // B 노드
929        lattice.add_node(
930            NodeBuilder::new("B", 1, 2)
931                .left_id(3)
932                .right_id(3)
933                .word_cost(100),
934        );
935
936        let mut conn_cost = TestConnectionCost::new(0);
937        // A -> B 연접에 높은 비용 설정
938        conn_cost.set(2, 3, 500);
939
940        let searcher = ViterbiSearcher::new();
941        let path = searcher.search(&mut lattice, &conn_cost);
942
943        // 연접 비용 때문에 AB 선택: 300 < 200 + 500
944        assert_eq!(path.len(), 1);
945        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "AB");
946    }
947
948    #[test]
949    fn test_viterbi_with_space_penalty() {
950        // 띄어쓰기 패널티 테스트
951        // "A B" (공백 있음)
952        // B의 left_id에 패널티가 있으면 다른 경로 선택
953        let mut lattice = Lattice::new("A B");
954        // 공백 제거 후 "AB"
955
956        // AB 노드 (전체)
957        lattice.add_node(
958            NodeBuilder::new("AB", 0, 2)
959                .left_id(1)
960                .right_id(1)
961                .word_cost(500),
962        );
963
964        // A 노드
965        lattice.add_node(
966            NodeBuilder::new("A", 0, 1)
967                .left_id(2)
968                .right_id(2)
969                .word_cost(100),
970        );
971
972        // B 노드 (공백 뒤에서 시작)
973        lattice.add_node(
974            NodeBuilder::new("B", 1, 2)
975                .left_id(100) // 페널티가 적용될 ID
976                .right_id(3)
977                .word_cost(100)
978                .has_space_before(true),
979        );
980
981        // B의 left_id에 높은 페널티 설정
982        let mut penalty = SpacePenalty::new();
983        penalty.add(100, 1000);
984
985        let conn_cost = ZeroConnectionCost;
986        let searcher = ViterbiSearcher::new().with_space_penalty(penalty);
987
988        let path = searcher.search(&mut lattice, &conn_cost);
989
990        // 페널티 때문에 AB 선택: 500 < 100 + 100 + 1000
991        assert_eq!(path.len(), 1);
992        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "AB");
993    }
994
995    #[test]
996    fn test_viterbi_korean_example() {
997        // 한국어 예시: "아버지가"
998        let mut lattice = Lattice::new("아버지가");
999
1000        // 경로 1: "아버지" + "가" (조사)
1001        lattice.add_node(
1002            NodeBuilder::new("아버지", 0, 3)
1003                .left_id(1)
1004                .right_id(1)
1005                .word_cost(1000),
1006        );
1007        lattice.add_node(
1008            NodeBuilder::new("가", 3, 4)
1009                .left_id(100) // 조사
1010                .right_id(100)
1011                .word_cost(500),
1012        );
1013
1014        // 경로 2: "아버" + "지가"
1015        lattice.add_node(
1016            NodeBuilder::new("아버", 0, 2)
1017                .left_id(2)
1018                .right_id(2)
1019                .word_cost(3000),
1020        );
1021        lattice.add_node(
1022            NodeBuilder::new("지가", 2, 4)
1023                .left_id(3)
1024                .right_id(3)
1025                .word_cost(3000),
1026        );
1027
1028        let conn_cost = ZeroConnectionCost;
1029        let searcher = ViterbiSearcher::new();
1030
1031        let path = searcher.search(&mut lattice, &conn_cost);
1032
1033        // "아버지" + "가" 선택 (비용: 1500 < 6000)
1034        assert_eq!(path.len(), 2);
1035        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "아버지");
1036        assert_eq!(lattice.node(path[1]).unwrap().surface.as_ref(), "가");
1037    }
1038
1039    #[test]
1040    fn test_viterbi_empty_lattice() {
1041        let mut lattice = Lattice::new("");
1042
1043        let conn_cost = ZeroConnectionCost;
1044        let searcher = ViterbiSearcher::new();
1045
1046        let path = searcher.search(&mut lattice, &conn_cost);
1047
1048        // 빈 텍스트는 빈 경로
1049        assert!(path.is_empty());
1050    }
1051
1052    #[test]
1053    fn test_viterbi_no_path() {
1054        // 노드가 연결되지 않는 경우
1055        let mut lattice = Lattice::new("ABC");
1056
1057        // A만 있고 B, C 없음 -> EOS에 도달 불가
1058        lattice.add_node(
1059            NodeBuilder::new("A", 0, 1)
1060                .left_id(1)
1061                .right_id(1)
1062                .word_cost(100),
1063        );
1064
1065        let conn_cost = ZeroConnectionCost;
1066        let searcher = ViterbiSearcher::new();
1067
1068        let path = searcher.search(&mut lattice, &conn_cost);
1069
1070        // 유효한 경로 없음
1071        assert!(!searcher.has_valid_path(&lattice));
1072        assert!(path.is_empty());
1073    }
1074
1075    #[test]
1076    fn test_nbest_single() {
1077        let mut lattice = Lattice::new("AB");
1078
1079        lattice.add_node(
1080            NodeBuilder::new("A", 0, 1)
1081                .left_id(1)
1082                .right_id(1)
1083                .word_cost(100),
1084        );
1085        lattice.add_node(
1086            NodeBuilder::new("B", 1, 2)
1087                .left_id(2)
1088                .right_id(2)
1089                .word_cost(200),
1090        );
1091
1092        let conn_cost = ZeroConnectionCost;
1093        let searcher = NbestSearcher::new(1);
1094
1095        let results = searcher.search(&mut lattice, &conn_cost);
1096
1097        assert_eq!(results.len(), 1);
1098        assert_eq!(results[0].1, 300); // 비용
1099    }
1100
1101    #[test]
1102    fn test_viterbi_result_helper() {
1103        let mut lattice = Lattice::new("AB");
1104
1105        let _id1 = lattice.add_node(
1106            NodeBuilder::new("A", 0, 1)
1107                .left_id(1)
1108                .right_id(1)
1109                .word_cost(100),
1110        );
1111        let _id2 = lattice.add_node(
1112            NodeBuilder::new("B", 1, 2)
1113                .left_id(2)
1114                .right_id(2)
1115                .word_cost(200),
1116        );
1117
1118        let conn_cost = ZeroConnectionCost;
1119        let searcher = ViterbiSearcher::new();
1120        let path = searcher.search(&mut lattice, &conn_cost);
1121        let cost = searcher.get_best_cost(&lattice);
1122
1123        let result = ViterbiResult::new(&lattice, path, cost);
1124
1125        assert_eq!(result.len(), 2);
1126        assert_eq!(result.cost(), 300);
1127        assert_eq!(result.surfaces(), vec!["A", "B"]);
1128    }
1129
1130    #[test]
1131    fn test_viterbi_with_dense_matrix() {
1132        use mecab_ko_dict::DenseMatrix;
1133
1134        // 3x3 연접 비용 행렬 생성
1135        // left_id: 0=BOS, 1=명사, 2=조사
1136        // right_id: 0=EOS, 1=명사, 2=조사
1137        let mut matrix = DenseMatrix::new(3, 3, 0);
1138
1139        // 연접 비용 설정
1140        // BOS -> 명사: 낮은 비용 (자연스러움)
1141        matrix.set(0, 1, 100);
1142        // 명사 -> 조사: 낮은 비용 (자연스러움)
1143        matrix.set(1, 2, 50);
1144        // 조사 -> EOS: 낮은 비용
1145        matrix.set(2, 0, 30);
1146
1147        // BOS -> 조사: 높은 비용 (부자연스러움)
1148        matrix.set(0, 2, 5000);
1149        // 명사 -> EOS: 중간 비용
1150        matrix.set(1, 0, 200);
1151
1152        let mut lattice = Lattice::new("책을");
1153
1154        // "책" (명사) - 문자 위치 0..1
1155        lattice.add_node(
1156            NodeBuilder::new("책", 0, 1)
1157                .left_id(1) // 명사 left_id
1158                .right_id(1) // 명사 right_id
1159                .word_cost(500),
1160        );
1161
1162        // "을" (조사) - 문자 위치 1..2
1163        lattice.add_node(
1164            NodeBuilder::new("을", 1, 2)
1165                .left_id(2) // 조사 left_id
1166                .right_id(2) // 조사 right_id
1167                .word_cost(100),
1168        );
1169
1170        let searcher = ViterbiSearcher::new();
1171        let path = searcher.search(&mut lattice, &matrix);
1172
1173        // BOS -> 명사 -> 조사 -> EOS 경로 확인
1174        assert!(!path.is_empty());
1175
1176        let result = ViterbiResult::new(&lattice, path, searcher.get_best_cost(&lattice));
1177        assert_eq!(result.surfaces(), vec!["책", "을"]);
1178
1179        // 총 비용: BOS->명사(100) + 명사비용(500) + 명사->조사(50) + 조사비용(100) + 조사->EOS(30)
1180        // = 100 + 500 + 50 + 100 + 30 = 780
1181        assert_eq!(result.cost(), 780);
1182    }
1183}