1use crate::lattice::{Lattice, Node, NodeId, NodeType, INVALID_NODE_ID};
35use crate::viterbi::{ConnectionCost, SpacePenalty};
36use std::cmp::Ordering;
37use std::collections::BinaryHeap;
38
39#[derive(Debug, Clone)]
41pub struct NbestPath {
42 pub node_ids: Vec<NodeId>,
44 pub total_cost: i32,
46 pub rank: usize,
48}
49
50impl NbestPath {
51 #[must_use]
53 pub const fn new(node_ids: Vec<NodeId>, total_cost: i32, rank: usize) -> Self {
54 Self {
55 node_ids,
56 total_cost,
57 rank,
58 }
59 }
60
61 #[must_use]
63 pub fn is_empty(&self) -> bool {
64 self.node_ids.is_empty()
65 }
66
67 #[must_use]
69 pub fn len(&self) -> usize {
70 self.node_ids.len()
71 }
72
73 #[must_use]
75 pub const fn cost(&self) -> i32 {
76 self.total_cost
77 }
78
79 pub fn nodes<'a>(&'a self, lattice: &'a Lattice) -> impl Iterator<Item = &'a Node> + 'a {
81 self.node_ids.iter().filter_map(|&id| lattice.node(id))
82 }
83
84 #[must_use]
86 pub fn surfaces<'a>(&'a self, lattice: &'a Lattice) -> Vec<&'a str> {
87 self.nodes(lattice).map(|n| n.surface.as_ref()).collect()
88 }
89
90 #[must_use]
92 pub fn pos_tags<'a>(&'a self, lattice: &'a Lattice) -> Vec<&'a str> {
93 self.nodes(lattice)
94 .map(|n| {
95 n.feature
96 .split(',')
97 .next()
98 .unwrap_or_default()
99 })
100 .collect()
101 }
102}
103
104#[derive(Debug, Clone, Default)]
106pub struct NbestResult {
107 paths: Vec<NbestPath>,
109}
110
111impl NbestResult {
112 #[must_use]
114 pub const fn new(paths: Vec<NbestPath>) -> Self {
115 Self { paths }
116 }
117
118 #[must_use]
120 pub fn is_empty(&self) -> bool {
121 self.paths.is_empty()
122 }
123
124 #[must_use]
126 pub fn len(&self) -> usize {
127 self.paths.len()
128 }
129
130 #[must_use]
132 pub fn best(&self) -> Option<&NbestPath> {
133 self.paths.first()
134 }
135
136 #[must_use]
138 pub fn get(&self, index: usize) -> Option<&NbestPath> {
139 self.paths.get(index)
140 }
141
142 pub fn iter(&self) -> impl Iterator<Item = &NbestPath> {
144 self.paths.iter()
145 }
146
147 #[must_use]
149 pub fn into_paths(self) -> Vec<NbestPath> {
150 self.paths
151 }
152
153 #[must_use]
155 pub fn to_pairs(&self) -> Vec<(Vec<NodeId>, i32)> {
156 self.paths
157 .iter()
158 .map(|p| (p.node_ids.clone(), p.total_cost))
159 .collect()
160 }
161}
162
163impl IntoIterator for NbestResult {
164 type Item = NbestPath;
165 type IntoIter = std::vec::IntoIter<NbestPath>;
166
167 fn into_iter(self) -> Self::IntoIter {
168 self.paths.into_iter()
169 }
170}
171
172#[derive(Debug, Clone)]
174struct NodeCandidate {
175 cost: i32,
177 prev_node_id: NodeId,
179 prev_candidate_idx: usize,
181}
182
183impl Eq for NodeCandidate {}
184
185impl PartialEq for NodeCandidate {
186 fn eq(&self, other: &Self) -> bool {
187 self.cost == other.cost
188 }
189}
190
191impl Ord for NodeCandidate {
192 fn cmp(&self, other: &Self) -> Ordering {
193 other.cost.cmp(&self.cost)
195 }
196}
197
198impl PartialOrd for NodeCandidate {
199 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
200 Some(self.cmp(other))
201 }
202}
203
204#[derive(Debug, Clone)]
206struct BackwardCandidate {
207 node_id: NodeId,
209 candidate_idx: usize,
211 cost: i32,
213 path: Vec<NodeId>,
215}
216
217impl Eq for BackwardCandidate {}
218
219impl PartialEq for BackwardCandidate {
220 fn eq(&self, other: &Self) -> bool {
221 self.cost == other.cost
222 }
223}
224
225impl Ord for BackwardCandidate {
226 fn cmp(&self, other: &Self) -> Ordering {
227 other.cost.cmp(&self.cost)
229 }
230}
231
232impl PartialOrd for BackwardCandidate {
233 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
234 Some(self.cmp(other))
235 }
236}
237
238#[derive(Debug, Clone)]
242pub struct ImprovedNbestSearcher {
243 max_results: usize,
245 max_candidates_per_node: usize,
248 space_penalty: SpacePenalty,
250}
251
252impl ImprovedNbestSearcher {
253 #[must_use]
259 pub fn new(n: usize) -> Self {
260 Self {
261 max_results: n,
262 max_candidates_per_node: n.max(2) * 2,
264 space_penalty: SpacePenalty::default(),
265 }
266 }
267
268 #[must_use]
270 pub const fn with_max_candidates(mut self, k: usize) -> Self {
271 self.max_candidates_per_node = k;
272 self
273 }
274
275 #[must_use]
277 pub fn with_space_penalty(mut self, penalty: SpacePenalty) -> Self {
278 self.space_penalty = penalty;
279 self
280 }
281
282 pub fn search<C: ConnectionCost>(
293 &self,
294 lattice: &mut Lattice,
295 conn_cost: &C,
296 ) -> NbestResult {
297 if lattice.node_count() <= 2 {
298 return NbestResult::default();
300 }
301
302 let candidates = self.forward_pass_kbest(lattice, conn_cost);
304
305 self.backward_pass_nbest(lattice, &candidates)
307 }
308
309 fn forward_pass_kbest<C: ConnectionCost>(
313 &self,
314 lattice: &mut Lattice,
315 conn_cost: &C,
316 ) -> Vec<Vec<NodeCandidate>> {
317 let node_count = lattice.node_count();
318 let char_len = lattice.char_len();
319
320 let mut candidates: Vec<Vec<NodeCandidate>> = vec![Vec::new(); node_count];
322
323 let bos_id = lattice.bos().id;
325 candidates[bos_id as usize].push(NodeCandidate {
326 cost: 0,
327 prev_node_id: INVALID_NODE_ID,
328 prev_candidate_idx: 0,
329 });
330
331 let mut starting_ids: Vec<NodeId> = Vec::new();
333 let mut ending_data: Vec<(NodeId, u16)> = Vec::new();
334
335 for pos in 0..=char_len {
337 starting_ids.clear();
339 starting_ids.extend(lattice.nodes_starting_at(pos).map(|n| n.id));
340
341 ending_data.clear();
343 ending_data.extend(
344 lattice
345 .nodes_ending_at(pos)
346 .map(|n| (n.id, n.right_id)),
347 );
348
349 for &node_id in &starting_ids {
350 let (left_id, word_cost, has_space) = {
351 let Some(node) = lattice.node(node_id) else {
352 continue;
353 };
354 (node.left_id, node.word_cost, node.has_space_before)
355 };
356
357 let space_penalty = if has_space {
359 self.space_penalty.get(left_id)
360 } else {
361 0
362 };
363
364 let mut new_candidates: BinaryHeap<NodeCandidate> = BinaryHeap::new();
366
367 for &(prev_id, prev_right_id) in &ending_data {
368 let prev_candidates = &candidates[prev_id as usize];
369 if prev_candidates.is_empty() {
370 continue;
371 }
372
373 let connection = conn_cost.cost(prev_right_id, left_id);
374
375 for (idx, prev_cand) in prev_candidates.iter().enumerate() {
376 if prev_cand.cost == i32::MAX {
377 continue;
378 }
379
380 let total = prev_cand
381 .cost
382 .saturating_add(connection)
383 .saturating_add(word_cost)
384 .saturating_add(space_penalty);
385
386 new_candidates.push(NodeCandidate {
387 cost: total,
388 prev_node_id: prev_id,
389 prev_candidate_idx: idx,
390 });
391 }
392 }
393
394 let k = self.max_candidates_per_node;
396 let mut selected: Vec<NodeCandidate> = Vec::with_capacity(k);
397
398 while selected.len() < k {
399 if let Some(cand) = new_candidates.pop() {
400 selected.push(cand);
401 } else {
402 break;
403 }
404 }
405
406 candidates[node_id as usize] = selected;
407
408 if let Some(best) = candidates[node_id as usize].first() {
410 if let Some(node) = lattice.node_mut(node_id) {
411 node.total_cost = best.cost;
412 node.prev_node_id = best.prev_node_id;
413 }
414 }
415 }
416 }
417
418 candidates
419 }
420
421 fn backward_pass_nbest(
425 &self,
426 lattice: &Lattice,
427 candidates: &[Vec<NodeCandidate>],
428 ) -> NbestResult {
429 let eos = lattice.eos();
430 let eos_candidates = &candidates[eos.id as usize];
431
432 if eos_candidates.is_empty() {
433 return NbestResult::default();
434 }
435
436 let mut results: Vec<NbestPath> = Vec::with_capacity(self.max_results);
437 let mut heap: BinaryHeap<BackwardCandidate> = BinaryHeap::new();
438
439 for (idx, cand) in eos_candidates.iter().enumerate() {
441 heap.push(BackwardCandidate {
442 node_id: eos.id,
443 candidate_idx: idx,
444 cost: cand.cost,
445 path: Vec::new(),
446 });
447 }
448
449 while let Some(current) = heap.pop() {
450 if results.len() >= self.max_results {
451 break;
452 }
453
454 let node_cands = &candidates[current.node_id as usize];
455 if current.candidate_idx >= node_cands.len() {
456 continue;
457 }
458
459 let cand = &node_cands[current.candidate_idx];
460
461 if cand.prev_node_id == INVALID_NODE_ID {
463 let mut path = current.path;
465 path.reverse();
466
467 results.push(NbestPath::new(path, current.cost, results.len()));
468 continue;
469 }
470
471 let Some(node) = lattice.node(current.node_id) else {
473 continue;
474 };
475
476 let mut new_path = current.path.clone();
477 if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos {
479 new_path.push(current.node_id);
480 }
481
482 heap.push(BackwardCandidate {
483 node_id: cand.prev_node_id,
484 candidate_idx: cand.prev_candidate_idx,
485 cost: current.cost,
486 path: new_path,
487 });
488 }
489
490 NbestResult::new(results)
491 }
492}
493
494impl ImprovedNbestSearcher {
496 pub fn search_pairs<C: ConnectionCost>(
498 &self,
499 lattice: &mut Lattice,
500 conn_cost: &C,
501 ) -> Vec<(Vec<NodeId>, i32)> {
502 self.search(lattice, conn_cost).to_pairs()
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lattice::NodeBuilder;
    use crate::viterbi::ZeroConnectionCost;

    /// A single segmentation yields exactly one path whose cost is the sum of
    /// its word costs (connection costs are zero).
    #[test]
    fn test_nbest_single_path() {
        let mut lattice = Lattice::new("AB");

        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(100),
        );
        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(200),
        );

        let searcher = ImprovedNbestSearcher::new(5);
        let conn_cost = ZeroConnectionCost;
        let results = searcher.search(&mut lattice, &conn_cost);

        assert_eq!(results.len(), 1);
        let best = results.best().unwrap();
        assert_eq!(best.cost(), 300);
        assert_eq!(best.rank, 0);
        assert_eq!(best.surfaces(&lattice), vec!["A", "B"]);
    }

    /// Two competing segmentations are both returned, cheapest first.
    #[test]
    fn test_nbest_multiple_paths() {
        let mut lattice = Lattice::new("AB");

        // Path 1: A + B = 300
        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(100),
        );
        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(200),
        );
        // Path 2: AB = 350
        lattice.add_node(
            NodeBuilder::new("AB", 0, 2)
                .left_id(3)
                .right_id(3)
                .word_cost(350),
        );

        let searcher = ImprovedNbestSearcher::new(5);
        let conn_cost = ZeroConnectionCost;
        let results = searcher.search(&mut lattice, &conn_cost);

        assert_eq!(results.len(), 2);

        let first = results.get(0).unwrap();
        assert_eq!(first.cost(), 300);
        assert_eq!(first.rank, 0);
        assert_eq!(first.surfaces(&lattice), vec!["A", "B"]);

        let second = results.get(1).unwrap();
        assert_eq!(second.cost(), 350);
        assert_eq!(second.rank, 1);
        assert_eq!(second.surfaces(&lattice), vec!["AB"]);
    }

    /// Korean example: "아버지+가" (1500) beats "아버+지가" (6000), and no other
    /// segmentation exists in this lattice.
    #[test]
    fn test_nbest_korean_example() {
        let mut lattice = Lattice::new("아버지가");

        // Preferred segmentation: 아버지 + 가
        lattice.add_node(
            NodeBuilder::new("아버지", 0, 3)
                .left_id(1)
                .right_id(1)
                .word_cost(1000),
        );
        lattice.add_node(
            NodeBuilder::new("가", 3, 4)
                .left_id(2)
                .right_id(2)
                .word_cost(500),
        );

        // Alternative segmentation: 아버 + 지가
        lattice.add_node(
            NodeBuilder::new("아버", 0, 2)
                .left_id(3)
                .right_id(3)
                .word_cost(3000),
        );
        lattice.add_node(
            NodeBuilder::new("지가", 2, 4)
                .left_id(4)
                .right_id(4)
                .word_cost(3000),
        );

        let searcher = ImprovedNbestSearcher::new(3);
        let conn_cost = ZeroConnectionCost;
        let results = searcher.search(&mut lattice, &conn_cost);

        assert_eq!(results.len(), 2);

        let best = results.best().unwrap();
        assert_eq!(best.cost(), 1500);
        assert_eq!(best.surfaces(&lattice), vec!["아버지", "가"]);

        let second = results.get(1).unwrap();
        assert_eq!(second.cost(), 6000);
        assert_eq!(second.surfaces(&lattice), vec!["아버", "지가"]);
    }

    /// Borrowing and consuming iteration both expose non-empty, positive-cost paths.
    #[test]
    fn test_nbest_result_api() {
        let mut lattice = Lattice::new("AB");

        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(100),
        );
        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(200),
        );

        let searcher = ImprovedNbestSearcher::new(5);
        let conn_cost = ZeroConnectionCost;
        let results = searcher.search(&mut lattice, &conn_cost);

        assert!(!results.is_empty());
        for path in results.iter() {
            assert!(!path.is_empty());
            assert!(path.cost() > 0);
        }

        // Searching the same lattice again must work (the forward pass recomputes).
        let results2 = searcher.search(&mut lattice, &conn_cost);
        assert_eq!(results2.len(), results.len());
        for path in results2 {
            assert!(!path.is_empty());
        }
    }

    /// A lattice with no word nodes produces no paths.
    #[test]
    fn test_nbest_empty_lattice() {
        let mut lattice = Lattice::new("");
        let searcher = ImprovedNbestSearcher::new(5);
        let conn_cost = ZeroConnectionCost;
        let results = searcher.search(&mut lattice, &conn_cost);

        assert!(results.is_empty());
    }

    /// `search_pairs` returns `(node_ids, cost)` with BOS/EOS excluded.
    #[test]
    fn test_nbest_compatibility_pairs() {
        let mut lattice = Lattice::new("AB");

        lattice.add_node(
            NodeBuilder::new("AB", 0, 2)
                .left_id(1)
                .right_id(1)
                .word_cost(300),
        );

        let searcher = ImprovedNbestSearcher::new(5);
        let conn_cost = ZeroConnectionCost;
        let pairs = searcher.search_pairs(&mut lattice, &conn_cost);

        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0].0.len(), 1);
        assert_eq!(pairs[0].1, 300);
    }

    /// With a generous candidate budget, every segmentation of "ABC" is found
    /// and returned in ascending cost order with consecutive ranks.
    #[test]
    fn test_nbest_with_max_candidates() {
        let mut lattice = Lattice::new("ABC");

        lattice.add_node(NodeBuilder::new("A", 0, 1).word_cost(100));
        lattice.add_node(NodeBuilder::new("B", 1, 2).word_cost(100));
        lattice.add_node(NodeBuilder::new("C", 2, 3).word_cost(100));
        lattice.add_node(NodeBuilder::new("AB", 0, 2).word_cost(180));
        lattice.add_node(NodeBuilder::new("BC", 1, 3).word_cost(180));
        lattice.add_node(NodeBuilder::new("ABC", 0, 3).word_cost(250));

        let searcher = ImprovedNbestSearcher::new(5).with_max_candidates(10);
        let conn_cost = ZeroConnectionCost;
        let results = searcher.search(&mut lattice, &conn_cost);

        // ABC = 250, AB+C = 280, A+BC = 280, A+B+C = 300.
        let costs: Vec<i32> = results.iter().map(|p| p.cost()).collect();
        assert_eq!(costs, vec![250, 280, 280, 300]);
        assert!(results.iter().enumerate().all(|(i, p)| p.rank == i));

        assert_eq!(results.best().unwrap().surfaces(&lattice), vec!["ABC"]);
        assert_eq!(
            results.get(3).unwrap().surfaces(&lattice),
            vec!["A", "B", "C"]
        );
    }
}