1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
use std::ops::Deref;

use lindera_fst;
use lindera_fst::raw::Output;

use crate::core::word_entry::WordEntry;

/// Prefix-lookup dictionary: an FST paired with a flat buffer of serialized
/// word entries.
///
/// `Data` is the byte storage backing both fields; it defaults to an owned
/// `Vec<u8>`.
#[derive(Clone)]
pub struct PrefixDict<Data = Vec<u8>> {
    /// FST that `prefix` walks one input byte at a time. The output accumulated
    /// on reaching a final node tells `prefix` where to read in `vals_data`.
    pub fst: lindera_fst::raw::Fst<Data>,
    /// Raw bytes from which `prefix` deserializes `WordEntry` values, reading at
    /// multiples of `WordEntry::SERIALIZED_LEN`.
    pub vals_data: Data,
}

impl PrefixDict<&[u8]> {
    /// Builds an owned `PrefixDict` (backed by `Vec<u8>`) from two byte slices.
    ///
    /// Both slices are copied, so the returned dictionary does not borrow from
    /// the arguments. Despite the name, the slices are not required to be
    /// `'static`.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `lindera_fst::raw::Fst::new` when it
    /// rejects `fst_data`. In that case `vals_data` is never copied.
    pub fn from_static_slice(
        fst_data: &[u8],
        vals_data: &[u8],
    ) -> lindera_fst::Result<PrefixDict> {
        let fst_bytes: Vec<u8> = fst_data.to_owned();
        // The values buffer is only copied once the FST has been accepted.
        lindera_fst::raw::Fst::new(fst_bytes).map(|fst| PrefixDict {
            fst,
            vals_data: vals_data.to_owned(),
        })
    }
}

impl<D: Deref<Target = [u8]>> PrefixDict<D> {
    /// Lazily enumerates every dictionary entry whose key is a prefix of `s`.
    ///
    /// The FST is walked one byte of `s` at a time, starting at the root. The
    /// walk ends at the first byte that has no outgoing transition. Each time a
    /// final node is reached, the output accumulated so far is decoded into a
    /// run of entries in `vals_data`, and one item is yielded per entry.
    ///
    /// Each item is `(prefix_len, entry)`, where `prefix_len` is the number of
    /// bytes of `s` consumed to reach the final node.
    ///
    /// # Panics
    ///
    /// Slice indexing panics if a decoded position lies beyond the end of
    /// `vals_data`.
    pub fn prefix<'a>(&'a self, s: &'a str) -> impl Iterator<Item = (usize, WordEntry)> + 'a {
        // Walk state: (bytes consumed, current FST node, accumulated output).
        let start_state = (0, self.fst.root(), Output::zero());

        s.bytes()
            .scan(start_state, move |(consumed, current, acc), byte| {
                // A missing transition yields `None`, which terminates the scan.
                let slot = current.find_input(byte)?;
                let step = current.transition(slot);
                *consumed += 1;
                *acc = acc.cat(step.out);
                *current = self.fst.node(step.addr);
                Some((current.is_final(), *consumed, acc.value()))
            })
            // Only final nodes correspond to complete dictionary keys.
            .filter(|&(accepting, _, _)| accepting)
            .flat_map(move |(_, consumed, packed)| {
                // The low 5 bits of the output hold the entry count; the
                // remaining high bits hold the index of the first entry.
                const COUNT_BITS: u64 = 5;
                let entry_count = (packed & ((1u64 << COUNT_BITS) - 1u64)) as usize;
                let first_entry = (packed >> COUNT_BITS) as usize;
                let start = first_entry * WordEntry::SERIALIZED_LEN;
                let entries: &[u8] = &self.vals_data[start..];
                (0..entry_count).map(move |nth| {
                    let at = WordEntry::SERIALIZED_LEN * nth;
                    (consumed, WordEntry::deserialize(&entries[at..]))
                })
            })
    }
}

// NOTE(review): every test below is commented out, so this module currently
// compiles to nothing. They all call `PrefixDict::default()`, which this file
// does not define — presumably that is why they were disabled; confirm before
// re-enabling, and supply a real dictionary fixture when doing so.
#[cfg(test)]
mod tests {
    //    use crate::core::prefix_dict::PrefixDict;
    //
    //    #[test]
    //    fn test_fst_prefix_2() {
    //        let prefix_dict = PrefixDict::default();
    //        let count_prefix = prefix_dict.prefix("—でも").count();
    //        assert_eq!(count_prefix, 1);
    //    }
    //
    //    #[test]
    //    fn test_fst_prefix_tilde() {
    //        let prefix_dict = PrefixDict::default();
    //        let count_prefix = prefix_dict.prefix("〜").count();
    //        assert_eq!(count_prefix, 2);
    //    }
    //
    //    #[test]
    //    fn test_fst_ikkagetsu() {
    //        let prefix_dict = PrefixDict::default();
    //        let count_prefix = prefix_dict.prefix("ー").count();
    //        assert_eq!(count_prefix, 0);
    //
    //        let count_prefix = prefix_dict.prefix("ヶ月").count();
    //        assert_eq!(count_prefix, 1);
    //    }
    //
    //    #[test]
    //    fn test_fst_prefix_asterisk_symbol() {
    //        let prefix_dict = PrefixDict::default();
    //        let count_prefix = prefix_dict.prefix("※").count();
    //        assert_eq!(count_prefix, 1);
    //    }
}