1use crate::lattice::{Lattice, Node, NodeId, NodeType, INVALID_NODE_ID};
35use crate::viterbi::{ConnectionCost, SpacePenalty};
36use std::cmp::Ordering;
37use std::collections::BinaryHeap;
38
39#[derive(Debug, Clone)]
41pub struct NbestPath {
42 pub node_ids: Vec<NodeId>,
44 pub total_cost: i32,
46 pub rank: usize,
48}
49
50impl NbestPath {
51 #[must_use]
53 pub const fn new(node_ids: Vec<NodeId>, total_cost: i32, rank: usize) -> Self {
54 Self {
55 node_ids,
56 total_cost,
57 rank,
58 }
59 }
60
61 #[must_use]
63 pub fn is_empty(&self) -> bool {
64 self.node_ids.is_empty()
65 }
66
67 #[must_use]
69 pub fn len(&self) -> usize {
70 self.node_ids.len()
71 }
72
73 #[must_use]
75 pub const fn cost(&self) -> i32 {
76 self.total_cost
77 }
78
79 pub fn nodes<'a>(&'a self, lattice: &'a Lattice) -> impl Iterator<Item = &'a Node> + 'a {
81 self.node_ids.iter().filter_map(|&id| lattice.node(id))
82 }
83
84 #[must_use]
86 pub fn surfaces<'a>(&'a self, lattice: &'a Lattice) -> Vec<&'a str> {
87 self.nodes(lattice).map(|n| n.surface.as_ref()).collect()
88 }
89
90 #[must_use]
92 pub fn pos_tags<'a>(&'a self, lattice: &'a Lattice) -> Vec<&'a str> {
93 self.nodes(lattice)
94 .map(|n| n.feature.split(',').next().unwrap_or_default())
95 .collect()
96 }
97}
98
99#[derive(Debug, Clone, Default)]
101pub struct NbestResult {
102 paths: Vec<NbestPath>,
104}
105
106impl NbestResult {
107 #[must_use]
109 pub const fn new(paths: Vec<NbestPath>) -> Self {
110 Self { paths }
111 }
112
113 #[must_use]
115 pub fn is_empty(&self) -> bool {
116 self.paths.is_empty()
117 }
118
119 #[must_use]
121 pub fn len(&self) -> usize {
122 self.paths.len()
123 }
124
125 #[must_use]
127 pub fn best(&self) -> Option<&NbestPath> {
128 self.paths.first()
129 }
130
131 #[must_use]
133 pub fn get(&self, index: usize) -> Option<&NbestPath> {
134 self.paths.get(index)
135 }
136
137 pub fn iter(&self) -> impl Iterator<Item = &NbestPath> {
139 self.paths.iter()
140 }
141
142 #[must_use]
144 pub fn into_paths(self) -> Vec<NbestPath> {
145 self.paths
146 }
147
148 #[must_use]
150 pub fn to_pairs(&self) -> Vec<(Vec<NodeId>, i32)> {
151 self.paths
152 .iter()
153 .map(|p| (p.node_ids.clone(), p.total_cost))
154 .collect()
155 }
156}
157
158impl IntoIterator for NbestResult {
159 type Item = NbestPath;
160 type IntoIter = std::vec::IntoIter<NbestPath>;
161
162 fn into_iter(self) -> Self::IntoIter {
163 self.paths.into_iter()
164 }
165}
166
167#[derive(Debug, Clone)]
169struct NodeCandidate {
170 cost: i32,
172 prev_node_id: NodeId,
174 prev_candidate_idx: usize,
176}
177
178impl Eq for NodeCandidate {}
179
180impl PartialEq for NodeCandidate {
181 fn eq(&self, other: &Self) -> bool {
182 self.cost == other.cost
183 }
184}
185
186impl Ord for NodeCandidate {
187 fn cmp(&self, other: &Self) -> Ordering {
188 other.cost.cmp(&self.cost)
190 }
191}
192
193impl PartialOrd for NodeCandidate {
194 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
195 Some(self.cmp(other))
196 }
197}
198
199#[derive(Debug, Clone)]
201struct BackwardCandidate {
202 node_id: NodeId,
204 candidate_idx: usize,
206 cost: i32,
208 path: Vec<NodeId>,
210}
211
212impl Eq for BackwardCandidate {}
213
214impl PartialEq for BackwardCandidate {
215 fn eq(&self, other: &Self) -> bool {
216 self.cost == other.cost
217 }
218}
219
220impl Ord for BackwardCandidate {
221 fn cmp(&self, other: &Self) -> Ordering {
222 other.cost.cmp(&self.cost)
224 }
225}
226
227impl PartialOrd for BackwardCandidate {
228 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
229 Some(self.cmp(other))
230 }
231}
232
233#[derive(Debug, Clone)]
237pub struct ImprovedNbestSearcher {
238 max_results: usize,
240 max_candidates_per_node: usize,
243 space_penalty: SpacePenalty,
245}
246
247impl ImprovedNbestSearcher {
248 #[must_use]
254 pub fn new(n: usize) -> Self {
255 Self {
256 max_results: n,
257 max_candidates_per_node: n.max(2) * 2,
259 space_penalty: SpacePenalty::default(),
260 }
261 }
262
263 #[must_use]
265 pub const fn with_max_candidates(mut self, k: usize) -> Self {
266 self.max_candidates_per_node = k;
267 self
268 }
269
270 #[must_use]
272 pub fn with_space_penalty(mut self, penalty: SpacePenalty) -> Self {
273 self.space_penalty = penalty;
274 self
275 }
276
277 pub fn search<C: ConnectionCost>(&self, lattice: &mut Lattice, conn_cost: &C) -> NbestResult {
288 if lattice.node_count() <= 2 {
289 return NbestResult::default();
291 }
292
293 let candidates = self.forward_pass_kbest(lattice, conn_cost);
295
296 self.backward_pass_nbest(lattice, &candidates)
298 }
299
300 fn forward_pass_kbest<C: ConnectionCost>(
304 &self,
305 lattice: &mut Lattice,
306 conn_cost: &C,
307 ) -> Vec<Vec<NodeCandidate>> {
308 let node_count = lattice.node_count();
309 let char_len = lattice.char_len();
310
311 let mut candidates: Vec<Vec<NodeCandidate>> = vec![Vec::new(); node_count];
313
314 let bos_id = lattice.bos().id;
316 candidates[bos_id as usize].push(NodeCandidate {
317 cost: 0,
318 prev_node_id: INVALID_NODE_ID,
319 prev_candidate_idx: 0,
320 });
321
322 let mut starting_ids: Vec<NodeId> = Vec::new();
324 let mut ending_data: Vec<(NodeId, u16)> = Vec::new();
325
326 for pos in 0..=char_len {
328 starting_ids.clear();
330 starting_ids.extend(lattice.nodes_starting_at(pos).map(|n| n.id));
331
332 ending_data.clear();
334 ending_data.extend(lattice.nodes_ending_at(pos).map(|n| (n.id, n.right_id)));
335
336 for &node_id in &starting_ids {
337 let (left_id, word_cost, has_space) = {
338 let Some(node) = lattice.node(node_id) else {
339 continue;
340 };
341 (node.left_id, node.word_cost, node.has_space_before)
342 };
343
344 let space_penalty = if has_space {
346 self.space_penalty.get(left_id)
347 } else {
348 0
349 };
350
351 let mut new_candidates: BinaryHeap<NodeCandidate> = BinaryHeap::new();
353
354 for &(prev_id, prev_right_id) in &ending_data {
355 let prev_candidates = &candidates[prev_id as usize];
356 if prev_candidates.is_empty() {
357 continue;
358 }
359
360 let connection = conn_cost.cost(prev_right_id, left_id);
361
362 for (idx, prev_cand) in prev_candidates.iter().enumerate() {
363 if prev_cand.cost == i32::MAX {
364 continue;
365 }
366
367 let total = prev_cand
368 .cost
369 .saturating_add(connection)
370 .saturating_add(word_cost)
371 .saturating_add(space_penalty);
372
373 new_candidates.push(NodeCandidate {
374 cost: total,
375 prev_node_id: prev_id,
376 prev_candidate_idx: idx,
377 });
378 }
379 }
380
381 let k = self.max_candidates_per_node;
383 let mut selected: Vec<NodeCandidate> = Vec::with_capacity(k);
384
385 while selected.len() < k {
386 if let Some(cand) = new_candidates.pop() {
387 selected.push(cand);
388 } else {
389 break;
390 }
391 }
392
393 candidates[node_id as usize] = selected;
394
395 if let Some(best) = candidates[node_id as usize].first() {
397 if let Some(node) = lattice.node_mut(node_id) {
398 node.total_cost = best.cost;
399 node.prev_node_id = best.prev_node_id;
400 }
401 }
402 }
403 }
404
405 candidates
406 }
407
408 fn backward_pass_nbest(
412 &self,
413 lattice: &Lattice,
414 candidates: &[Vec<NodeCandidate>],
415 ) -> NbestResult {
416 let eos = lattice.eos();
417 let eos_candidates = &candidates[eos.id as usize];
418
419 if eos_candidates.is_empty() {
420 return NbestResult::default();
421 }
422
423 let mut results: Vec<NbestPath> = Vec::with_capacity(self.max_results);
424 let mut heap: BinaryHeap<BackwardCandidate> = BinaryHeap::new();
425
426 for (idx, cand) in eos_candidates.iter().enumerate() {
428 heap.push(BackwardCandidate {
429 node_id: eos.id,
430 candidate_idx: idx,
431 cost: cand.cost,
432 path: Vec::new(),
433 });
434 }
435
436 while let Some(current) = heap.pop() {
437 if results.len() >= self.max_results {
438 break;
439 }
440
441 let node_cands = &candidates[current.node_id as usize];
442 if current.candidate_idx >= node_cands.len() {
443 continue;
444 }
445
446 let cand = &node_cands[current.candidate_idx];
447
448 if cand.prev_node_id == INVALID_NODE_ID {
450 let mut path = current.path;
452 path.reverse();
453
454 results.push(NbestPath::new(path, current.cost, results.len()));
455 continue;
456 }
457
458 let Some(node) = lattice.node(current.node_id) else {
460 continue;
461 };
462
463 let mut new_path = current.path.clone();
464 if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos {
466 new_path.push(current.node_id);
467 }
468
469 heap.push(BackwardCandidate {
470 node_id: cand.prev_node_id,
471 candidate_idx: cand.prev_candidate_idx,
472 cost: current.cost,
473 path: new_path,
474 });
475 }
476
477 NbestResult::new(results)
478 }
479}
480
481impl ImprovedNbestSearcher {
483 pub fn search_pairs<C: ConnectionCost>(
485 &self,
486 lattice: &mut Lattice,
487 conn_cost: &C,
488 ) -> Vec<(Vec<NodeId>, i32)> {
489 self.search(lattice, conn_cost).to_pairs()
490 }
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lattice::NodeBuilder;
    use crate::viterbi::ZeroConnectionCost;

    /// A lattice with a single segmentation yields exactly one path.
    #[test]
    fn test_nbest_single_path() {
        let mut lattice = Lattice::new("AB");

        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(100),
        );
        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(200),
        );

        let zero_cost = ZeroConnectionCost;
        let found = ImprovedNbestSearcher::new(5).search(&mut lattice, &zero_cost);

        assert_eq!(found.len(), 1);
        let top = found.best().unwrap();
        assert_eq!(top.cost(), 300);
    }

    /// "A"+"B" (300) must come before the single node "AB" (350).
    #[test]
    fn test_nbest_multiple_paths() {
        let mut lattice = Lattice::new("AB");

        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(100),
        );
        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(200),
        );
        lattice.add_node(
            NodeBuilder::new("AB", 0, 2)
                .left_id(3)
                .right_id(3)
                .word_cost(350),
        );

        let zero_cost = ZeroConnectionCost;
        let found = ImprovedNbestSearcher::new(5).search(&mut lattice, &zero_cost);

        // Exactly two paths, cheapest first.
        let costs: Vec<i32> = found.iter().map(NbestPath::cost).collect();
        assert_eq!(costs, vec![300, 350]);
    }

    /// Two competing segmentations of a Korean word, checked by cost and
    /// by surface forms.
    #[test]
    fn test_nbest_korean_example() {
        let mut lattice = Lattice::new("아버지가");

        // Cheap segmentation: "아버지" + "가".
        lattice.add_node(
            NodeBuilder::new("아버지", 0, 3)
                .left_id(1)
                .right_id(1)
                .word_cost(1000),
        );
        lattice.add_node(
            NodeBuilder::new("가", 3, 4)
                .left_id(2)
                .right_id(2)
                .word_cost(500),
        );

        // Expensive segmentation: "아버" + "지가".
        lattice.add_node(
            NodeBuilder::new("아버", 0, 2)
                .left_id(3)
                .right_id(3)
                .word_cost(3000),
        );
        lattice.add_node(
            NodeBuilder::new("지가", 2, 4)
                .left_id(4)
                .right_id(4)
                .word_cost(3000),
        );

        let zero_cost = ZeroConnectionCost;
        let found = ImprovedNbestSearcher::new(3).search(&mut lattice, &zero_cost);

        assert!(found.len() >= 2);

        let winner = found.best().unwrap();
        assert_eq!(winner.cost(), 1500);
        assert_eq!(winner.surfaces(&lattice), ["아버지", "가"]);

        let runner_up = found.get(1).unwrap();
        assert_eq!(runner_up.cost(), 6000);
        assert_eq!(runner_up.surfaces(&lattice), ["아버", "지가"]);
    }

    /// Exercises both the borrowing and the consuming iteration APIs.
    #[test]
    fn test_nbest_result_api() {
        let mut lattice = Lattice::new("AB");

        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(100),
        );
        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(200),
        );

        let searcher = ImprovedNbestSearcher::new(5);
        let zero_cost = ZeroConnectionCost;

        let borrowed = searcher.search(&mut lattice, &zero_cost);
        assert!(borrowed
            .iter()
            .all(|path| !path.is_empty() && path.cost() > 0));

        let consumed = searcher.search(&mut lattice, &zero_cost);
        assert!(consumed.into_iter().all(|path| !path.is_empty()));
    }

    /// An empty input string produces an empty result.
    #[test]
    fn test_nbest_empty_lattice() {
        let mut lattice = Lattice::new("");
        let zero_cost = ZeroConnectionCost;

        let found = ImprovedNbestSearcher::new(5).search(&mut lattice, &zero_cost);

        assert!(found.is_empty());
    }

    /// `search_pairs` exposes the same data as plain tuples.
    #[test]
    fn test_nbest_compatibility_pairs() {
        let mut lattice = Lattice::new("AB");

        lattice.add_node(
            NodeBuilder::new("AB", 0, 2)
                .left_id(1)
                .right_id(1)
                .word_cost(300),
        );

        let zero_cost = ZeroConnectionCost;
        let pairs = ImprovedNbestSearcher::new(5).search_pairs(&mut lattice, &zero_cost);

        assert_eq!(pairs.len(), 1);
        let (_, only_cost) = &pairs[0];
        assert_eq!(*only_cost, 300);
    }

    /// With a custom candidate limit the results stay sorted by cost.
    #[test]
    fn test_nbest_with_max_candidates() {
        let mut lattice = Lattice::new("ABC");

        lattice.add_node(NodeBuilder::new("A", 0, 1).word_cost(100));
        lattice.add_node(NodeBuilder::new("B", 1, 2).word_cost(100));
        lattice.add_node(NodeBuilder::new("C", 2, 3).word_cost(100));
        lattice.add_node(NodeBuilder::new("AB", 0, 2).word_cost(180));
        lattice.add_node(NodeBuilder::new("BC", 1, 3).word_cost(180));
        lattice.add_node(NodeBuilder::new("ABC", 0, 3).word_cost(250));

        let searcher = ImprovedNbestSearcher::new(5).with_max_candidates(10);
        let zero_cost = ZeroConnectionCost;
        let found = searcher.search(&mut lattice, &zero_cost);

        assert!(found.len() >= 2);

        // Costs must be non-decreasing from one rank to the next.
        let costs: Vec<i32> = found.iter().map(NbestPath::cost).collect();
        assert!(costs.windows(2).all(|pair| pair[0] <= pair[1]));
    }
}