1use std::cmp::Ordering;
71use std::collections::HashSet;
72use std::fmt::Debug;
73#[cfg(feature = "python")]
74use std::fmt::Error;
75use std::fs::{self, File};
76use std::hash::Hash;
77use std::io::Read;
78use std::path::Path;
79
80use indexmap::{indexmap, IndexMap, IndexSet};
81use nom::branch::alt;
82use nom::bytes::complete::{tag, take_until1};
83use nom::multi::many_m_n;
84use nom::Parser;
85use std::sync::{Arc, LazyLock, RwLock};
86use xz2::read::{XzDecoder, XzEncoder};
87
88#[cfg(feature = "python")]
89use pyo3::create_exception;
90#[cfg(feature = "python")]
91use pyo3::exceptions::{PyIOError, PyValueError};
92#[cfg(feature = "python")]
93use pyo3::types::{PyDict, PyNone, PyTuple};
94#[cfg(feature = "python")]
95use pyo3::{prelude::*, py_run, IntoPyObjectExt};
96
/// Crate-wide result type.
///
/// With the `python` feature enabled the error side is a Python exception (`PyErr`), so it can
/// be raised straight back into Python; without it, errors are plain human-readable messages.
#[cfg(feature = "python")]
type KFSTResult<T> = PyResult<T>;
#[cfg(not(feature = "python"))]
type KFSTResult<T> = Result<T, String>;
105
106#[cfg(feature = "python")]
107fn value_error<T>(msg: String) -> KFSTResult<T> {
108 KFSTResult::Err(PyErr::new::<PyValueError, _>(msg))
109}
110#[cfg(not(feature = "python"))]
111fn value_error<T>(msg: String) -> KFSTResult<T> {
112 KFSTResult::Err(msg)
113}
114
115#[cfg(feature = "python")]
116fn io_error<T>(msg: String) -> KFSTResult<T> {
117 use pyo3::exceptions::PyIOError;
118
119 KFSTResult::Err(PyErr::new::<PyIOError, _>(msg))
120}
121#[cfg(not(feature = "python"))]
122fn io_error<T>(msg: String) -> KFSTResult<T> {
123 KFSTResult::Err(msg)
124}
125
126#[cfg(feature = "python")]
127fn tokenization_exception<T>(msg: String) -> KFSTResult<T> {
128 KFSTResult::Err(PyErr::new::<TokenizationException, _>(msg))
129}
130#[cfg(not(feature = "python"))]
131fn tokenization_exception<T>(msg: String) -> KFSTResult<T> {
132 KFSTResult::Err(msg)
133}
134
// Python exception class used by `tokenization_exception` when the `python` feature is on;
// it subclasses Python's built-in `Exception`.
// NOTE(review): the first macro argument is the module name recorded on the exception class —
// confirm it matches the extension module that exports it.
#[cfg(feature = "python")]
create_exception!(
    kfst_rs,
    TokenizationException,
    pyo3::exceptions::PyException
);
141
/// Process-wide string interner shared by all symbols.
///
/// A string's position in the set is its `u32` id (see `intern` / `deintern`). Entries are only
/// ever inserted in the visible code, so ids stay valid for the life of the process.
/// Readers must not take the read lock re-entrantly (e.g. nested `with_deinterned` calls):
/// `RwLock::read` may panic or deadlock in that case.
static STRING_INTERNER: LazyLock<RwLock<IndexSet<String>>> =
    LazyLock::new(|| RwLock::new(IndexSet::new()));
146
147fn intern(s: String) -> u32 {
148 u32::try_from(STRING_INTERNER.write().unwrap().insert_full(s).0).unwrap()
149}
150
151fn deintern(idx: u32) -> String {
154 with_deinterned(idx, |x| x.to_string())
155}
156
157fn with_deinterned<F, X>(idx: u32, f: F) -> X
160where
161 F: FnOnce(&str) -> X,
162{
163 f(STRING_INTERNER
164 .read()
165 .unwrap()
166 .get_index(idx.try_into().unwrap())
167 .unwrap())
168}
169
/// A symbol identified by an opaque 15-byte payload rather than by text.
///
/// Bit 0 of `value[0]` marks epsilon and bit 1 marks unknown (see `is_epsilon` /
/// `is_unknown`); no other bits are interpreted by the code in this module.
#[cfg_attr(
    feature = "python",
    pyclass(str = "RawSymbol({value:?})", eq, ord, frozen, hash, get_all)
)]
#[derive(Clone, Copy, Hash, PartialEq, PartialOrd, Ord, Eq, Debug)]
#[readonly::make]
pub struct RawSymbol {
    /// Raw payload; the low two bits of the first byte are the epsilon/unknown flags.
    pub value: [u8; 15],
}
189
190#[cfg_attr(feature = "python", pymethods)]
191impl RawSymbol {
192 pub fn is_epsilon(&self) -> bool {
195 (self.value[0] & 1) != 0
196 }
197
198 pub fn is_unknown(&self) -> bool {
201 (self.value[0] & 2) != 0
202 }
203
204 pub fn get_symbol(&self) -> String {
206 format!("RawSymbol({:?})", self.value)
207 }
208
209 #[cfg(feature = "python")]
210 #[new]
211 fn new(value: [u8; 15]) -> Self {
212 RawSymbol { value }
213 }
214
215 #[cfg(not(feature = "python"))]
217 pub fn new(value: [u8; 15]) -> Self {
218 RawSymbol { value }
219 }
220
221 #[deprecated]
222 pub fn __repr__(&self) -> String {
224 format!("RawSymbol({:?})", self.value)
225 }
226}
227
/// A symbol implemented by an arbitrary Python object.
///
/// The trait impls and query methods below call back into Python (`__eq__`, `__hash__`,
/// `__gt__`, `is_epsilon`, `get_symbol`, ...) under the GIL and panic if the object lacks the
/// method or returns an unexpected type.
#[cfg(feature = "python")]
struct PyObjectSymbol {
    value: PyObject,
}
232
#[cfg(feature = "python")]
impl Debug for PyObjectSymbol {
    /// Writes the wrapped object's `__repr__()`.
    ///
    /// Panics if `__repr__` is missing, raises, or returns a non-string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Python::with_gil(|py| {
            let repr_method = self.value.getattr(py, "__repr__").unwrap();
            let rendered: String = repr_method.call0(py).unwrap().extract(py).unwrap();
            f.write_str(&rendered)
        })
    }
}
251
#[cfg(feature = "python")]
impl PartialEq for PyObjectSymbol {
    /// Delegates to `self.__eq__(other)`.
    ///
    /// Panics if `__eq__` is missing, raises, or returns a non-bool.
    fn eq(&self, other: &Self) -> bool {
        Python::with_gil(|py| {
            let eq_method = self.value.getattr(py, "__eq__").unwrap_or_else(|_| {
                panic!(
                    "Symbol {} doesn't have an __eq__ implementation.",
                    self.value
                )
            });
            let outcome = eq_method
                .call1(py, (other.value.clone_ref(py),))
                .unwrap_or_else(|_| {
                    panic!("__eq__ on symbol {} failed to return a value.", self.value)
                });
            outcome
                .extract(py)
                .unwrap_or_else(|_| panic!("__eq__ on symbol {} didn't return a bool.", self.value))
        })
    }
}
273
#[cfg(feature = "python")]
impl Hash for PyObjectSymbol {
    /// Feeds the wrapped object's Python `__hash__()` value into `state`.
    ///
    /// Panics if `__hash__` is missing, raises, or returns something not convertible to `i128`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let python_hash: i128 = Python::with_gil(|py| {
            let hash_method = self.value.getattr(py, "__hash__").unwrap_or_else(|_| {
                panic!(
                    "Symbol {} doesn't have a __hash__ implementation.",
                    self.value
                )
            });
            let result = hash_method.call0(py).unwrap_or_else(|_| {
                panic!(
                    "__hash__ on symbol {} failed to return a value.",
                    self.value
                )
            });
            result.extract(py).unwrap_or_else(|_| {
                panic!("__hash__ on symbol {} didn't return an int.", self.value)
            })
        });
        state.write_i128(python_hash);
    }
}
300
// Marker impl: equality comes from the wrapped object's `__eq__` (see `PartialEq` above);
// treating it as a total equivalence is an assumption about the Python implementation.
#[cfg(feature = "python")]
impl Eq for PyObjectSymbol {}
303
#[cfg(feature = "python")]
impl PartialOrd for PyObjectSymbol {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

#[cfg(feature = "python")]
impl Ord for PyObjectSymbol {
    /// Asks Python `self.__gt__(other)` first and then `self.__eq__(other)`; anything neither
    /// greater nor equal counts as less.
    ///
    /// Panics if either dunder is missing, raises, or returns a non-bool.
    fn cmp(&self, other: &Self) -> Ordering {
        Python::with_gil(|py| {
            let is_greater = self
                .value
                .getattr(py, "__gt__")
                .unwrap_or_else(|_| {
                    panic!(
                        "Symbol {} doesn't have a __gt__ implementation.",
                        self.value
                    )
                })
                .call1(py, (other.value.clone_ref(py),))
                .unwrap_or_else(|_| {
                    panic!("__gt__ on symbol {} failed to return a value.", self.value)
                })
                .extract::<bool>(py)
                .unwrap_or_else(|_| panic!("__gt__ on symbol {} didn't return a bool.", self.value));
            if is_greater {
                return Ordering::Greater;
            }

            let is_equal = self
                .value
                .getattr(py, "__eq__")
                .unwrap_or_else(|_| {
                    panic!(
                        "Symbol {} doesn't have an __eq__ implementation.",
                        self.value
                    )
                })
                .call1(py, (other.value.clone_ref(py),))
                .unwrap_or_else(|_| {
                    panic!("__eq__ on symbol {} failed to return a value.", self.value)
                })
                .extract::<bool>(py)
                .unwrap_or_else(|_| {
                    panic!("__eq__ on symbol {} didn't return a bool.", self.value)
                });
            if is_equal {
                Ordering::Equal
            } else {
                Ordering::Less
            }
        })
    }
}
358
#[cfg(feature = "python")]
impl Clone for PyObjectSymbol {
    /// Produces a new reference to the same Python object (no Python-level copy).
    fn clone(&self) -> Self {
        let value = Python::with_gil(|py| self.value.clone_ref(py));
        Self { value }
    }
}
367
#[cfg(feature = "python")]
impl FromPyObject<'_> for PyObjectSymbol {
    /// Wraps any Python object; this extraction never fails.
    fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
        let owned = ob.clone();
        Ok(PyObjectSymbol {
            value: owned.unbind(),
        })
    }
}
376
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for PyObjectSymbol {
    type Target = PyAny;

    type Output = Bound<'py, Self::Target>;

    type Error = pyo3::PyErr;

    /// Hands the wrapped Python object back as-is.
    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let PyObjectSymbol { value } = self;
        value.into_bound_py_any(py)
    }
}
389
#[cfg(feature = "python")]
impl PyObjectSymbol {
    /// Calls the wrapped object's `is_epsilon()`.
    ///
    /// Panics if the method is missing, raises, or returns a non-bool.
    fn is_epsilon(&self) -> bool {
        Python::with_gil(|py| {
            self.value
                .getattr(py, "is_epsilon")
                .unwrap_or_else(|_| {
                    panic!(
                        "Symbol {} doesn't have an is_epsilon implementation.",
                        self.value
                    )
                })
                .call(py, (), None)
                .unwrap_or_else(|_| {
                    panic!(
                        "is_epsilon on symbol {} failed to return a value.",
                        self.value
                    )
                })
                .extract(py)
                .unwrap_or_else(|_| {
                    panic!("is_epsilon on symbol {} didn't return a bool.", self.value)
                })
        })
    }

    /// Calls the wrapped object's `is_unknown()`.
    ///
    /// Panics if the method is missing, raises, or returns a non-bool.
    fn is_unknown(&self) -> bool {
        Python::with_gil(|py| {
            self.value
                .getattr(py, "is_unknown")
                .unwrap_or_else(|_| {
                    panic!(
                        "Symbol {} doesn't have an is_unknown implementation.",
                        self.value
                    )
                })
                .call(py, (), None)
                .unwrap_or_else(|_| {
                    panic!(
                        "is_unknown on symbol {} failed to return a value.",
                        self.value
                    )
                })
                .extract(py)
                .unwrap_or_else(|_| {
                    panic!("is_unknown on symbol {} didn't return a bool.", self.value)
                })
        })
    }

    /// Calls the wrapped object's `get_symbol()`.
    ///
    /// Panics if the method is missing, raises, or returns a non-string.
    fn get_symbol(&self) -> String {
        Python::with_gil(|py| {
            self.value
                .getattr(py, "get_symbol")
                .unwrap_or_else(|_| {
                    panic!(
                        "Symbol {} doesn't have a get_symbol implementation.",
                        self.value
                    )
                })
                .call(py, (), None)
                .unwrap_or_else(|_| {
                    panic!(
                        "get_symbol on symbol {} failed to return a value.",
                        self.value
                    )
                })
                .extract(py)
                .unwrap_or_else(|_| {
                    // The extraction target here is `String`, so report a string mismatch.
                    panic!("get_symbol on symbol {} didn't return a string.", self.value)
                })
        })
    }
}
464
465#[cfg_attr(
466 feature = "python",
467 pyclass(
468 str = "StringSymbol({string:?}, {unknown})",
469 eq,
470 ord,
471 frozen,
472 hash,
473 get_all
474 )
475)]
476#[derive(Clone, Copy, Hash, PartialEq, Eq)]
477#[readonly::make]
478pub struct StringSymbol {
481 string: u32,
482 pub unknown: bool,
484}
485
486impl StringSymbol {
487 pub fn parse(symbol: &str) -> nom::IResult<&str, StringSymbol> {
503 if symbol.is_empty() {
504 return nom::IResult::Err(nom::Err::Error(nom::error::Error::new(
505 symbol,
506 nom::error::ErrorKind::Fail,
507 )));
508 }
509 Ok((
510 "",
511 StringSymbol {
512 string: intern(symbol.to_string()),
513 unknown: false,
514 },
515 ))
516 }
517
518 fn with_symbol<F, X>(&self, f: F) -> X
521 where
522 F: FnOnce(&str) -> X,
523 {
524 with_deinterned(self.string, f)
525 }
526}
527
528impl PartialOrd for StringSymbol {
529 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
530 Some(self.cmp(other))
531 }
532}
533
534impl Ord for StringSymbol {
535 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
536 with_deinterned(other.string, |other_str| {
537 with_deinterned(self.string, |self_str| {
538 (other_str.chars().count(), &self_str, self.unknown).cmp(&(
539 self_str.chars().count(),
540 &other_str,
541 other.unknown,
542 ))
543 })
544 })
545 }
546}
547#[cfg_attr(feature = "python", pymethods)]
548impl StringSymbol {
549 pub fn is_epsilon(&self) -> bool {
552 false
553 }
554
555 pub fn is_unknown(&self) -> bool {
558 self.unknown
559 }
560
561 pub fn get_symbol(&self) -> String {
563 deintern(self.string)
564 }
565
566 #[cfg(feature = "python")]
567 #[new]
568 #[pyo3(signature = (string, unknown = false))]
569 fn new(string: String, unknown: bool) -> Self {
570 StringSymbol {
571 string: intern(string),
572 unknown,
573 }
574 }
575
576 #[cfg(not(feature = "python"))]
577 pub fn new(string: String, unknown: bool) -> Self {
579 StringSymbol {
580 string: intern(string),
581 unknown,
582 }
583 }
584
585 #[deprecated]
586 pub fn __repr__(&self) -> String {
588 format!("StringSymbol({:?}, {})", self.string, self.unknown)
589 }
590}
591
/// The operation a flag diacritic performs: the letter right after the leading `@` in
/// `@U.KEY.VALUE@`.
///
/// NOTE(review): the per-variant meanings below follow the usual HFST/Xerox flag-diacritic
/// convention; the checking logic is not in this chunk — confirm against it.
#[cfg_attr(feature = "python", pyclass(eq, ord, frozen))]
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy, Hash)]
pub enum FlagDiacriticType {
    /// `U` — presumably "unify".
    U,
    /// `R` — presumably "require".
    R,
    /// `D` — presumably "disallow".
    D,
    /// `C` — presumably "clear".
    C,
    /// `P` — presumably "positive set".
    P,
    /// `N` — presumably "negative set".
    N,
}
610
611impl FlagDiacriticType {
612 pub fn from_str(s: &str) -> Option<Self> {
615 match s {
616 "U" => Some(FlagDiacriticType::U),
617 "R" => Some(FlagDiacriticType::R),
618 "D" => Some(FlagDiacriticType::D),
619 "C" => Some(FlagDiacriticType::C),
620 "P" => Some(FlagDiacriticType::P),
621 "N" => Some(FlagDiacriticType::N),
622 _ => None,
623 }
624 }
625}
626
627#[cfg_attr(feature = "python", pymethods)]
628impl FlagDiacriticType {
629 #[deprecated]
630 pub fn __repr__(&self) -> String {
632 format!("{:?}", &self)
633 }
634}
635
636#[cfg_attr(
637 feature = "python",
638 pyclass(
639 str = "FlagDiacriticSymbol({flag_type:?}, {key:?}, {value:?})",
640 eq,
641 ord,
642 frozen,
643 hash
644 )
645)]
646#[derive(PartialEq, Eq, Clone, Copy, Hash)]
647#[readonly::make]
648pub struct FlagDiacriticSymbol {
655 pub flag_type: FlagDiacriticType,
657 key: u32,
658 value: u32,
659}
660
661impl PartialOrd for FlagDiacriticSymbol {
662 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
663 Some(self.cmp(other))
664 }
665}
666
667impl Ord for FlagDiacriticSymbol {
668 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
669 let other_str = other.get_symbol();
671 let self_str = self.get_symbol();
672 (other_str.chars().count(), &self_str).cmp(&(self_str.chars().count(), &other_str))
673 }
674}
675
676impl FlagDiacriticSymbol {
677 pub fn parse(symbol: &str) -> nom::IResult<&str, FlagDiacriticSymbol> {
679 let mut parser = (
680 tag("@"),
681 alt((tag("U"), tag("R"), tag("D"), tag("C"), tag("P"), tag("N"))),
682 tag("."),
683 many_m_n(0, 1, (take_until1("."), tag("."))),
684 take_until1("@"),
685 tag("@"),
686 );
687 let (input, (_, diacritic_type, _, named_piece_1, named_piece_2, _)) =
688 parser.parse(symbol)?;
689 let diacritic_type = match FlagDiacriticType::from_str(diacritic_type) {
690 Some(x) => x,
691 None => {
692 return Err(nom::Err::Error(nom::error::Error::new(
693 diacritic_type,
694 nom::error::ErrorKind::Fail,
695 )))
696 }
697 };
698
699 let (name, value) = if !named_piece_1.is_empty() {
700 (named_piece_1[0].0, intern(named_piece_2.to_string()))
701 } else {
702 (named_piece_2, u32::MAX)
703 };
704
705 Ok((
706 input,
707 FlagDiacriticSymbol {
708 flag_type: diacritic_type,
709 key: intern(name.to_string()),
710 value,
711 },
712 ))
713 }
714}
715
716impl FlagDiacriticSymbol {
720 fn _from_symbol_string(symbol: &str) -> KFSTResult<Self> {
721 match FlagDiacriticSymbol::parse(symbol) {
722 Ok(("", symbol)) => KFSTResult::Ok(symbol),
723 Ok((rest, _)) => value_error(format!("String {symbol:?} contains a valid FlagDiacriticSymbol, but it has unparseable text at the end: {rest:?}")),
724 _ => value_error(format!("Not a valid FlagDiacriticSymbol: {symbol:?}"))
725 }
726 }
727
728 #[cfg(not(feature = "python"))]
729 #[deprecated]
730 pub fn from_symbol_string(symbol: &str) -> KFSTResult<Self> {
732 FlagDiacriticSymbol::_from_symbol_string(symbol)
733 }
734
735 fn _new(flag_type: String, key: String, value: Option<String>) -> KFSTResult<Self> {
736 let flag_type = match FlagDiacriticType::from_str(&flag_type) {
737 Some(x) => x,
738 None => value_error(format!(
739 "String {flag_type:?} is not a valid FlagDiacriticType specifier"
740 ))?,
741 };
742 Ok(FlagDiacriticSymbol {
743 flag_type,
744 key: intern(key),
745 value: value.map(intern).unwrap_or(u32::MAX),
746 })
747 }
748
749 #[cfg(not(feature = "python"))]
750 pub fn new(flag_type: String, key: String, value: Option<String>) -> KFSTResult<Self> {
752 FlagDiacriticSymbol::_new(flag_type, key, value)
753 }
754
755 #[cfg(not(feature = "python"))]
756 pub fn key(self) -> String {
758 deintern(self.key)
759 }
760
761 #[cfg(not(feature = "python"))]
762 pub fn value(self) -> String {
764 deintern(self.value)
765 }
766}
767
768#[cfg_attr(feature = "python", pymethods)]
769impl FlagDiacriticSymbol {
770 pub fn is_epsilon(&self) -> bool {
773 true
774 }
775
776 pub fn is_unknown(&self) -> bool {
779 false
780 }
781
782 pub fn get_symbol(&self) -> String {
783 match self.value {
784 u32::MAX => with_deinterned(self.key, |key| format!("@{:?}.{}@", self.flag_type, key)),
785 value => with_deinterned(self.key, |key| {
786 with_deinterned(value, |value| {
787 format!("@{:?}.{}.{}@", self.flag_type, key, value)
788 })
789 }),
790 }
791 }
792
793 #[cfg(feature = "python")]
794 #[getter]
795 fn flag_type(&self) -> String {
796 format!("{:?}", self.flag_type)
797 }
798
799 #[cfg(not(feature = "python"))]
800 pub fn flag_type(&self) -> String {
802 format!("{:?}", self.flag_type)
803 }
804
805 #[cfg(feature = "python")]
806 #[getter]
807 fn key(&self) -> String {
808 deintern(self.key)
809 }
810
811 #[cfg(feature = "python")]
812 #[getter]
813 fn value(&self) -> String {
814 deintern(self.value)
815 }
816
817 #[cfg(feature = "python")]
818 #[new]
819 fn new(flag_type: String, key: String, value: Option<String>) -> KFSTResult<Self> {
820 FlagDiacriticSymbol::_new(flag_type, key, value)
821 }
822
823 #[deprecated]
824 pub fn __repr__(&self) -> String {
826 match self.value {
827 u32::MAX => with_deinterned(self.key, |key| {
828 format!("FlagDiacriticSymbol({:?}, {:?})", self.flag_type, key)
829 }),
830 value => with_deinterned(self.key, |key| {
831 with_deinterned(value, |value| {
832 format!(
833 "FlagDiacriticSymbol({:?}, {:?}, {:?})",
834 self.flag_type, key, value
835 )
836 })
837 }),
838 }
839 }
840
841 #[cfg(feature = "python")]
842 #[staticmethod]
843 fn is_flag_diacritic(symbol: &str) -> bool {
844 matches!(FlagDiacriticSymbol::parse(symbol), Ok(("", _)))
845 }
846
847 #[cfg(not(feature = "python"))]
848 pub fn is_flag_diacritic(symbol: &str) -> bool {
851 matches!(FlagDiacriticSymbol::parse(symbol), Ok(("", _)))
852 }
853
854 #[cfg(feature = "python")]
855 #[staticmethod]
856 fn from_symbol_string(symbol: &str) -> KFSTResult<Self> {
857 FlagDiacriticSymbol::_from_symbol_string(symbol)
858 }
859}
860
861impl std::fmt::Debug for FlagDiacriticSymbol {
862 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
863 write!(f, "{}", self.get_symbol())
864 }
865}
866
/// Reserved symbols with a fixed spelling.
#[cfg_attr(feature = "python", pyclass(eq, ord, frozen, hash))]
#[derive(PartialEq, Eq, Clone, Hash, Copy)]
pub enum SpecialSymbol {
    /// Epsilon; parsed from `@_EPSILON_SYMBOL_@` or the short form `@0@`.
    EPSILON,
    /// Parsed from `@_IDENTITY_SYMBOL_@`.
    IDENTITY,
    /// Parsed from `@_UNKNOWN_SYMBOL_@`. Note that `is_unknown` still reports `false` for it.
    UNKNOWN,
}
883
884impl SpecialSymbol {
885 pub fn parse(symbol: &str) -> nom::IResult<&str, SpecialSymbol> {
894 let (rest, value) = alt((
895 tag("@_EPSILON_SYMBOL_@"),
896 tag("@0@"),
897 tag("@_IDENTITY_SYMBOL_@"),
898 tag("@_UNKNOWN_SYMBOL_@"),
899 ))
900 .parse(symbol)?;
901
902 let sym = match value {
903 "@_EPSILON_SYMBOL_@" => SpecialSymbol::EPSILON,
904 "@0@" => SpecialSymbol::EPSILON,
905 "@_IDENTITY_SYMBOL_@" => SpecialSymbol::IDENTITY,
906 "@_UNKNOWN_SYMBOL_@" => SpecialSymbol::UNKNOWN,
907 _ => panic!(),
908 };
909 Ok((rest, sym))
910 }
911
912 fn _from_symbol_string(symbol: &str) -> KFSTResult<Self> {
913 match SpecialSymbol::parse(symbol) {
914 Ok(("", result)) => KFSTResult::Ok(result),
915 _ => value_error(format!("Not a valid SpecialSymbol: {symbol:?}")),
916 }
917 }
918
919 fn with_symbol<F, X>(&self, f: F) -> X
922 where
923 F: FnOnce(&str) -> X,
924 {
925 match self {
926 SpecialSymbol::EPSILON => f("@_EPSILON_SYMBOL_@"),
927 SpecialSymbol::IDENTITY => f("@_IDENTITY_SYMBOL_@"),
928 SpecialSymbol::UNKNOWN => f("@_UNKNOWN_SYMBOL_@"),
929 }
930 }
931
932 #[cfg(not(feature = "python"))]
933 pub fn from_symbol_string(symbol: &str) -> KFSTResult<Self> {
946 SpecialSymbol::_from_symbol_string(symbol)
947 }
948}
949
950impl PartialOrd for SpecialSymbol {
951 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
952 Some(self.cmp(other))
953 }
954}
955
956impl Ord for SpecialSymbol {
957 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
958 self.with_symbol(|self_str| {
960 other.with_symbol(|other_str| {
961 (other_str.chars().count(), &self_str).cmp(&(self_str.chars().count(), &other_str))
962 })
963 })
964 }
965}
966
967#[cfg_attr(feature = "python", pymethods)]
968impl SpecialSymbol {
969 pub fn is_epsilon(&self) -> bool {
973 self == &SpecialSymbol::EPSILON
974 }
975
976 pub fn is_unknown(&self) -> bool {
980 false
981 }
982
983 pub fn get_symbol(&self) -> String {
990 self.with_symbol(|x| x.to_string())
991 }
992
993 #[cfg(feature = "python")]
994 #[staticmethod]
995 fn from_symbol_string(symbol: &str) -> KFSTResult<Self> {
996 SpecialSymbol::_from_symbol_string(symbol)
997 }
998
999 #[cfg(feature = "python")]
1000 #[staticmethod]
1001 fn is_special_symbol(symbol: &str) -> bool {
1002 SpecialSymbol::from_symbol_string(symbol).is_ok()
1003 }
1004
1005 #[cfg(not(feature = "python"))]
1006 pub fn is_special_symbol(symbol: &str) -> bool {
1018 SpecialSymbol::from_symbol_string(symbol).is_ok()
1019 }
1020}
1021
1022impl std::fmt::Debug for SpecialSymbol {
1023 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1024 self.with_symbol(|symbol| write!(f, "{symbol}"))
1025 }
1026}
1027
1028impl std::fmt::Debug for StringSymbol {
1029 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1030 self.with_symbol(|symbol| write!(f, "{symbol}"))
1031 }
1032}
1033
/// Python entry point: parses `symbol` into the matching symbol class (flag diacritic,
/// special symbol, or plain string symbol).
///
/// # Errors
/// Raises `ValueError` when `symbol` is not a valid symbol (e.g. the empty string) instead of
/// panicking the interpreter.
#[cfg(feature = "python")]
#[pyfunction]
fn from_symbol_string(symbol: &str, py: Python) -> PyResult<Py<PyAny>> {
    match Symbol::parse(symbol) {
        Ok((_, parsed)) => parsed.into_py_any(py),
        Err(_) => value_error(format!("Not a valid symbol: {symbol:?}")),
    }
}
1039
1040#[cfg(not(feature = "python"))]
1041pub fn from_symbol_string(symbol: &str) -> Option<Symbol> {
1050 Symbol::parse(symbol).ok().map(|(_, sym)| sym)
1051}
1052
/// Union of every supported symbol kind.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Symbol {
    /// A reserved symbol such as epsilon (`@_EPSILON_SYMBOL_@` / `@0@`).
    Special(SpecialSymbol),
    /// A flag diacritic such as `@U.KEY.VALUE@`.
    Flag(FlagDiacriticSymbol),
    /// Any other literal text.
    String(StringSymbol),
    /// A symbol implemented by a Python object (only with the `python` feature).
    #[cfg(feature = "python")]
    External(PyObjectSymbol),
    /// An opaque 15-byte symbol.
    Raw(RawSymbol),
}
1069
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for Symbol {
    type Target = PyAny;

    type Output = Bound<'py, Self::Target>;

    type Error = pyo3::PyErr;

    /// Converts to the Python object of the wrapped variant (external symbols are returned
    /// unchanged).
    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        match self {
            Symbol::Special(inner) => inner.into_bound_py_any(py),
            Symbol::Flag(inner) => inner.into_bound_py_any(py),
            Symbol::String(inner) => inner.into_bound_py_any(py),
            Symbol::External(inner) => inner.into_bound_py_any(py),
            Symbol::Raw(inner) => inner.into_bound_py_any(py),
        }
    }
}
1088
1089impl PartialOrd for Symbol {
1090 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1091 Some(self.cmp(other))
1092 }
1093}
1094
1095impl Ord for Symbol {
1096 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
1106 let self_is_tokenizable = matches!(
1109 self,
1110 Symbol::Special(_) | Symbol::Flag(_) | Symbol::String(_)
1111 );
1112
1113 let other_is_tokenizable = matches!(
1114 other,
1115 Symbol::Special(_) | Symbol::Flag(_) | Symbol::String(_)
1116 );
1117
1118 match (self_is_tokenizable, other_is_tokenizable) {
1119 (true, true) => {
1121 let result = self.with_symbol(|self_sym| {
1125 other.with_symbol(|other_sym| {
1126 (other_sym.chars().count(), &self_sym)
1127 .cmp(&(self_sym.chars().count(), &other_sym))
1128 })
1129 });
1130 if result != Ordering::Equal {
1131 result
1132 } else {
1133 match (self, other) {
1134 (Symbol::Special(_), Symbol::Flag(_)) => Ordering::Greater,
1136 (Symbol::Special(_), Symbol::String(_)) => Ordering::Less,
1137 (Symbol::Flag(_), Symbol::Special(_)) => Ordering::Less,
1138 (Symbol::Flag(_), Symbol::String(_)) => Ordering::Less,
1139 (Symbol::String(_), Symbol::Special(_)) => Ordering::Greater,
1140 (Symbol::String(_), Symbol::Flag(_)) => Ordering::Greater,
1141
1142 (Symbol::Special(a), Symbol::Special(b)) => a.cmp(b),
1144 (Symbol::Flag(a), Symbol::Flag(b)) => a.cmp(b),
1145 (Symbol::String(a), Symbol::String(b)) => a.cmp(b),
1146
1147 _ => unreachable!(),
1148 }
1149 }
1150 }
1151 _ => {
1153 match (self, other) {
1154 #[cfg(feature = "python")]
1155 (Symbol::External(left), right) => {
1156 Python::with_gil(|py| {
1157 if left
1160 .value
1161 .getattr(py, "__lt__")
1162 .unwrap_or_else(|_| {
1163 panic!(
1164 "Symbol {} doesn't have a __lt__ implementation.",
1165 left.value
1166 )
1167 })
1168 .call1(py, (right.clone().into_pyobject(py).unwrap(),))
1169 .unwrap_or_else(|_| {
1170 panic!(
1171 "__lt__ on symbol {} failed to return a value.",
1172 left.value
1173 )
1174 })
1175 .extract::<bool>(py)
1176 .unwrap_or_else(|_| {
1177 panic!("__lt__ on symbol {} didn't return a bool.", left.value)
1178 })
1179 {
1180 return Ordering::Less;
1181 }
1182
1183 if left
1186 .value
1187 .getattr(py, "__eq__")
1188 .unwrap_or_else(|_| {
1189 panic!(
1190 "Symbol {} doesn't have a __eq__ implementation.",
1191 left.value
1192 )
1193 })
1194 .call1(py, (right.clone().into_pyobject(py).unwrap(),))
1195 .unwrap_or_else(|_| {
1196 panic!(
1197 "__eq__ on symbol {} failed to return a value.",
1198 left.value
1199 )
1200 })
1201 .extract::<bool>(py)
1202 .unwrap_or_else(|_| {
1203 panic!("__eq__ on symbol {} didn't return a bool.", left.value)
1204 })
1205 {
1206 return Ordering::Equal;
1207 }
1208
1209 Ordering::Greater
1212 })
1213 }
1214 #[cfg(feature = "python")]
1215 (left, Symbol::External(right)) => {
1216 Python::with_gil(|py| {
1217 if right
1220 .value
1221 .getattr(py, "__lt__")
1222 .unwrap_or_else(|_| {
1223 panic!(
1224 "Symbol {} doesn't have a __lt__ implementation.",
1225 right.value
1226 )
1227 })
1228 .call1(py, (left.clone().into_pyobject(py).unwrap(),))
1229 .unwrap_or_else(|_| {
1230 panic!(
1231 "__lt__ on symbol {} failed to return a value.",
1232 right.value
1233 )
1234 })
1235 .extract::<bool>(py)
1236 .unwrap_or_else(|_| {
1237 panic!("__lt__ on symbol {} didn't return a bool.", right.value)
1238 })
1239 {
1240 return Ordering::Greater;
1241 }
1242
1243 if right
1246 .value
1247 .getattr(py, "__eq__")
1248 .unwrap_or_else(|_| {
1249 panic!(
1250 "Symbol {} doesn't have a __eq__ implementation.",
1251 right.value
1252 )
1253 })
1254 .call1(py, (left.clone().into_pyobject(py).unwrap(),))
1255 .unwrap_or_else(|_| {
1256 panic!(
1257 "__eq__ on symbol {} failed to return a value.",
1258 right.value
1259 )
1260 })
1261 .extract::<bool>(py)
1262 .unwrap_or_else(|_| {
1263 panic!("__eq__ on symbol {} didn't return a bool.", right.value)
1264 })
1265 {
1266 return Ordering::Equal;
1267 }
1268
1269 Ordering::Less
1272 })
1273 }
1274
1275 (Symbol::Raw(a), Symbol::Raw(b)) => a.cmp(b),
1277
1278 (Symbol::Raw(_), _) => Ordering::Less,
1280 (_, Symbol::Raw(_)) => Ordering::Greater,
1281
1282 _ => unreachable!(),
1283 }
1284 }
1285 }
1286 }
1287}
1288
1289impl Symbol {
1290 pub fn is_epsilon(&self) -> bool {
1297 match self {
1298 Symbol::Special(special_symbol) => special_symbol.is_epsilon(),
1299 Symbol::Flag(flag_diacritic_symbol) => flag_diacritic_symbol.is_epsilon(),
1300 Symbol::String(string_symbol) => string_symbol.is_epsilon(),
1301 #[cfg(feature = "python")]
1302 Symbol::External(py_object_symbol) => py_object_symbol.is_epsilon(),
1303 Symbol::Raw(raw_symbol) => raw_symbol.is_epsilon(),
1304 }
1305 }
1306
1307 pub fn is_unknown(&self) -> bool {
1310 match self {
1311 Symbol::Special(special_symbol) => special_symbol.is_unknown(),
1312 Symbol::Flag(flag_diacritic_symbol) => flag_diacritic_symbol.is_unknown(),
1313 Symbol::String(string_symbol) => string_symbol.is_unknown(),
1314 #[cfg(feature = "python")]
1315 Symbol::External(py_object_symbol) => py_object_symbol.is_unknown(),
1316 Symbol::Raw(raw_symbol) => raw_symbol.is_unknown(),
1317 }
1318 }
1319
1320 pub fn get_symbol(&self) -> String {
1322 match self {
1323 Symbol::Special(special_symbol) => special_symbol.get_symbol(),
1324 Symbol::Flag(flag_diacritic_symbol) => flag_diacritic_symbol.get_symbol(),
1325 Symbol::String(string_symbol) => string_symbol.get_symbol(),
1326 #[cfg(feature = "python")]
1327 Symbol::External(py_object_symbol) => py_object_symbol.get_symbol(),
1328 Symbol::Raw(raw_symbol) => raw_symbol.get_symbol(),
1329 }
1330 }
1331
1332 pub fn with_symbol<F, X>(&self, f: F) -> X
1335 where
1336 F: FnOnce(&str) -> X,
1337 {
1338 match self {
1339 Symbol::Special(special_symbol) => special_symbol.with_symbol(f),
1340 Symbol::Flag(flag_diacritic_symbol) => f(&flag_diacritic_symbol.get_symbol()),
1341 Symbol::String(string_symbol) => string_symbol.with_symbol(f),
1342 #[cfg(feature = "python")]
1343 Symbol::External(py_object_symbol) => f(&py_object_symbol.get_symbol()),
1344 Symbol::Raw(raw_symbol) => f(&raw_symbol.get_symbol()),
1345 }
1346 }
1347}
1348
1349impl Symbol {
1350 pub fn parse(symbol: &str) -> nom::IResult<&str, Symbol> {
1368 let mut parser = alt((
1369 |x| {
1370 (FlagDiacriticSymbol::parse, nom::combinator::eof)
1371 .parse(x)
1372 .map(|y| (y.0, Symbol::Flag(y.1 .0)))
1373 },
1374 |x| {
1375 (SpecialSymbol::parse, nom::combinator::eof)
1376 .parse(x)
1377 .map(|y| (y.0, Symbol::Special(y.1 .0)))
1378 },
1379 |x| StringSymbol::parse(x).map(|y| (y.0, Symbol::String(y.1))),
1380 ));
1381 parser.parse(symbol)
1382 }
1383}
1384
#[cfg(feature = "python")]
impl FromPyObject<'_> for Symbol {
    /// Tries each built-in symbol class in turn (special, flag, string, raw). Anything else is
    /// wrapped as an external Python-implemented symbol, so this never fails.
    fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
        if let Ok(special) = ob.extract::<SpecialSymbol>() {
            return Ok(Symbol::Special(special));
        }
        if let Ok(flag) = ob.extract::<FlagDiacriticSymbol>() {
            return Ok(Symbol::Flag(flag));
        }
        if let Ok(string) = ob.extract::<StringSymbol>() {
            return Ok(Symbol::String(string));
        }
        if let Ok(raw) = ob.extract::<RawSymbol>() {
            return Ok(Symbol::Raw(raw));
        }
        ob.extract::<PyObjectSymbol>().map(Symbol::External)
    }
}
/// Immutable-style map from flag key to `(bool, value)`, stored as interner ids.
///
/// Entries are `(key_id, flag, value_id)` triples kept sorted (every constructor and `insert`
/// maintain the order), so lookups use binary search. Updating methods return a new map.
/// NOTE(review): the meaning of the `bool` is defined by the flag-diacritic logic elsewhere
/// (presumably positive vs. negative setting) — confirm.
#[derive(Clone, Debug, PartialEq, Hash, PartialOrd, Eq, Ord)]
#[readonly::make]
pub struct FlagMap(Vec<(u32, bool, u32)>);
1405
1406impl FromIterator<(String, (bool, String))> for FlagMap {
1409 fn from_iter<T: IntoIterator<Item = (String, (bool, String))>>(iter: T) -> Self {
1410 let mut vals: Vec<_> = iter
1411 .into_iter()
1412 .map(|(a, (b, c))| (intern(a), b, intern(c)))
1413 .collect();
1414 vals.sort();
1415 FlagMap(vals)
1416 }
1417}
1418
1419impl<T> From<T> for FlagMap
1422where
1423 T: IntoIterator<Item = (String, (bool, String))>,
1424{
1425 fn from(value: T) -> Self {
1426 FlagMap::from_iter(value)
1427 }
1428}
1429
1430impl FlagMap {
1431 pub fn new() -> FlagMap {
1433 FlagMap(vec![])
1434 }
1435
1436 pub fn remove(&self, flag: u32) -> FlagMap {
1438 let pp = self.0.partition_point(|v| v.0 < flag);
1439 if pp < self.0.len() && self.0[pp].0 == flag {
1440 let mut new_vals = self.0.clone();
1441 new_vals.remove(pp);
1442 FlagMap(new_vals)
1443 } else {
1444 self.clone()
1445 }
1446 }
1447
1448 pub fn get(&self, flag: u32) -> Option<(bool, u32)> {
1450 let pp = self.0.partition_point(|v| v.0 < flag);
1451 if pp < self.0.len() && self.0[pp].0 == flag {
1452 Some((self.0[pp].1, self.0[pp].2))
1453 } else {
1454 None
1455 }
1456 }
1457
1458 pub fn insert(&self, flag: u32, value: (bool, u32)) -> FlagMap {
1461 let pp = self.0.partition_point(|v| v.0 < flag);
1462 let mut new_vals = self.0.clone();
1463 if pp == self.0.len() || self.0[pp].0 != flag {
1464 new_vals.insert(pp, (flag, value.0, value.1));
1465 } else {
1466 new_vals[pp] = (flag, value.0, value.1);
1467 }
1468 FlagMap(new_vals)
1469 }
1470}
1471
1472impl Default for FlagMap {
1473 fn default() -> Self {
1475 Self::new()
1476 }
1477}
1478
#[cfg(feature = "python")]
impl FromPyObject<'_> for FlagMap {
    /// Builds a map from any Python mapping whose `items()` yields `(str, (bool, str))`.
    ///
    /// # Errors
    /// Propagates the Python error if `items()` is missing or fails, if iteration raises, or
    /// if an item has the wrong shape. Previously these cases panicked via `unwrap()`.
    fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
        let mut as_map = ob
            .getattr("items")?
            .call0()?
            .try_iter()?
            .map(|item| -> PyResult<(u32, bool, u32)> {
                let (key, (flag, value)): (String, (bool, String)) = item?.extract()?;
                Ok((intern(key), flag, intern(value)))
            })
            .collect::<PyResult<Vec<(u32, bool, u32)>>>()?;
        // Keep the sorted invariant the lookup methods rely on.
        as_map.sort();
        Ok(FlagMap(as_map))
    }
}
1493
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for FlagMap {
    type Target = PyAny;

    type Output = Bound<'py, Self::Target>;

    type Error = pyo3::PyErr;

    /// Converts to an `immutables.Map` of `key -> (flag, value)` with the interned strings
    /// resolved back to text.
    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let pairs: Vec<(String, (bool, String))> = self
            .0
            .into_iter()
            .map(|(key, set, value)| (deintern(key), (set, deintern(value))))
            .collect();
        let collection = pairs.into_pyobject(py)?;
        let map_class = PyModule::import(py, "immutables")?.getattr("Map")?;
        // Construct in two explicit steps: allocate via `Map.__new__`, then populate via
        // `Map.__init__`.
        let instance = map_class.call_method1("__new__", (&map_class,))?;
        map_class.call_method1("__init__", (&instance, collection))?;
        Ok(instance)
    }
}
1516
/// Persistent singly linked list of `(symbol, T)` pairs with shared (`Arc`) tails.
///
/// `FromIterator` pushes each item onto the head, so the head holds the most recently added
/// element; `to_vec` reverses this back into insertion order.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
pub enum FSTLinkedList<T> {
    /// A `(symbol, T)` entry followed by the (shared) rest of the list.
    Some((Symbol, T), Arc<FSTLinkedList<T>>),
    /// The empty list.
    None,
}
1528
1529impl<T: Clone + Default> IntoIterator for FSTLinkedList<T> {
1530 type Item = (Symbol, T);
1531
1532 type IntoIter = <Vec<<FSTLinkedList<T> as IntoIterator>::Item> as IntoIterator>::IntoIter;
1533
1534 fn into_iter(self) -> Self::IntoIter {
1535 self.to_vec().into_iter()
1536 }
1537}
1538
1539impl<T> FromIterator<(Symbol, T)> for FSTLinkedList<T> {
1540 fn from_iter<I: IntoIterator<Item = (Symbol, T)>>(iter: I) -> Self {
1541 let mut symbol_mappings = FSTLinkedList::None;
1542 for (symbol, input_index) in iter {
1543 symbol_mappings = FSTLinkedList::Some((symbol, input_index), Arc::new(symbol_mappings));
1544 }
1545 symbol_mappings
1546 }
1547}
1548
1549impl<T: Clone + Default> FSTLinkedList<T> {
1550 fn to_vec(self) -> Vec<(Symbol, T)> {
1551 let mut symbol_mappings = vec![];
1552 let mut symbol_mapping_state = &self;
1553 while let FSTLinkedList::Some(value, list) = symbol_mapping_state {
1554 symbol_mapping_state = list;
1555 symbol_mappings.push(value.clone());
1556 }
1557 symbol_mappings.into_iter().rev().collect()
1558 }
1559
1560 fn from_vec(output_symbols: Vec<Symbol>, input_indices: Vec<T>) -> FSTLinkedList<T> {
1561 if input_indices.is_empty() {
1562 output_symbols
1563 .into_iter()
1564 .zip(std::iter::repeat(T::default()))
1565 .collect()
1566 } else if input_indices.len() == output_symbols.len() {
1567 output_symbols.into_iter().zip(input_indices).collect()
1568 } else {
1569 panic!("Mismatch in input index and output symbol len: {} output symbols and {} input indices.", output_symbols.len(), input_indices.len());
1570 }
1571 }
1572}
1573
/// Without the Python bindings, the public state type is simply the generic internal state.
#[cfg(not(feature = "python"))]
pub type FSTState<T> = InternalFSTState<T>;
1576
/// Python-facing FST state wrapper.
///
/// `Result` is used here only as a two-way either type, not to signal failure. `Ok` holds a
/// state that tracks input indices (`InternalFSTState<usize>`); `Err` holds one that does not
/// (`InternalFSTState<()>`), for which the `input_indices` getter returns `None`.
#[cfg(feature = "python")]
#[cfg_attr(feature = "python", pyclass(frozen, eq, hash))]
#[derive(Clone, Debug, PartialEq, PartialOrd, Hash)]
pub struct FSTState {
    payload: Result<InternalFSTState<usize>, InternalFSTState<()>>,
}
1583
#[cfg(feature = "python")]
#[pymethods]
impl FSTState {
    /// Python constructor.
    ///
    /// Passing `input_indices=None` creates a state that does not track input indices. Any list
    /// (including the default empty one) creates an index-tracking state. An empty list pairs
    /// every output symbol with index `0`.
    ///
    /// # Panics
    /// Panics when `input_indices` is a non-empty list whose length differs from
    /// `output_symbols` (pyo3 surfaces this to Python as a `PanicException`).
    #[new]
    #[pyo3(signature = (state_num, path_weight=0.0, input_flags = FlagMap::new(), output_flags = FlagMap::new(), output_symbols=vec![], input_indices=Some(vec![])))]
    fn new(
        state_num: u64,
        path_weight: f64,
        input_flags: FlagMap,
        output_flags: FlagMap,
        output_symbols: Vec<Symbol>,
        input_indices: Option<Vec<usize>>,
    ) -> Self {
        let payload: Result<InternalFSTState<usize>, InternalFSTState<()>> = match input_indices {
            Some(indices) => Ok(InternalFSTState::new(
                state_num,
                path_weight,
                input_flags,
                output_flags,
                output_symbols,
                indices,
            )),
            None => Err(InternalFSTState::new(
                state_num,
                path_weight,
                input_flags,
                output_flags,
                output_symbols,
                Vec::new(),
            )),
        };
        FSTState { payload }
    }

    /// Returns a copy of this state without input indices; unchanged if it already has none.
    fn strip_indices(&self) -> FSTState {
        if let Ok(indexed) = &self.payload {
            FSTState {
                payload: Err(indexed.clone().convert_indices()),
            }
        } else {
            self.clone()
        }
    }

    /// Returns a copy of this state that tracks input indices. A state that had none gets
    /// every index set to `0`.
    fn ensure_indices(&self) -> FSTState {
        if let Err(unindexed) = &self.payload {
            FSTState {
                payload: Ok(unindexed.clone().convert_indices()),
            }
        } else {
            self.clone()
        }
    }

    /// Output symbols produced so far, oldest first, as a Python tuple.
    #[getter]
    fn output_symbols<'a>(&'a self, py: Python<'a>) -> Result<Bound<'a, PyTuple>, PyErr> {
        let symbols = match &self.payload {
            Ok(indexed) => indexed.output_symbols(),
            Err(unindexed) => unindexed.output_symbols(),
        };
        PyTuple::new(py, symbols)
    }

    /// Input index for each output symbol as a tuple, or `None` if indices are not tracked.
    #[getter]
    fn input_indices<'a>(&'a self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
        let Ok(indexed) = &self.payload else {
            return PyNone::get(py).into_bound_py_any(py);
        };
        PyTuple::new(py, indexed.input_indices().unwrap())?.into_bound_py_any(py)
    }

    #[deprecated]
    pub fn __repr__(&self) -> String {
        match self.payload {
            Ok(ref indexed) => indexed.__repr__(),
            Err(ref unindexed) => unindexed.__repr__(),
        }
    }

    /// Number of the FST state this configuration is in.
    #[getter]
    pub fn state_num(&self) -> u64 {
        match self.payload {
            Ok(ref indexed) => indexed.state_num,
            Err(ref unindexed) => unindexed.state_num,
        }
    }

    /// Accumulated weight of the path that led here.
    #[getter]
    pub fn path_weight(&self) -> f64 {
        match self.payload {
            Ok(ref indexed) => indexed.path_weight,
            Err(ref unindexed) => unindexed.path_weight,
        }
    }

    /// Flag-diacritic bindings on the input side (a copy).
    #[getter]
    pub fn input_flags(&self) -> FlagMap {
        match self.payload {
            Ok(ref indexed) => indexed.input_flags.clone(),
            Err(ref unindexed) => unindexed.input_flags.clone(),
        }
    }

    /// Flag-diacritic bindings on the output side (a copy).
    #[getter]
    pub fn output_flags(&self) -> FlagMap {
        match self.payload {
            Ok(ref indexed) => indexed.output_flags.clone(),
            Err(ref unindexed) => unindexed.output_flags.clone(),
        }
    }
}
1699
#[cfg(feature = "python")]
impl<T: UsizeOrUnit + Default> FromPyObject<'_> for InternalFSTState<T> {
    /// Extracts a Python `FSTState` and converts its payload to the requested index type `T`,
    /// whether or not the Python object carried input indices.
    fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
        let wrapper: FSTState = ob.extract()?;
        let converted = match wrapper.payload {
            Ok(indexed) => indexed.convert_indices(),
            Err(unindexed) => unindexed.convert_indices(),
        };
        Ok(converted)
    }
}
1710
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for InternalFSTState<usize> {
    type Target = FSTState;

    type Output = Bound<'py, Self::Target>;

    type Error = pyo3::PyErr;

    /// Wraps this index-tracking state as a Python `FSTState`.
    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let wrapper = FSTState { payload: Ok(self) };
        wrapper.into_pyobject(py)
    }
}
1723
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for InternalFSTState<()> {
    type Target = FSTState;

    type Output = Bound<'py, Self::Target>;

    type Error = pyo3::PyErr;

    /// Wraps this state, which does not track indices, as a Python `FSTState`.
    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let wrapper = FSTState { payload: Err(self) };
        wrapper.into_pyobject(py)
    }
}
1736
/// One configuration reached while traversing an FST.
///
/// `T` is the per-output-symbol payload. It is `usize` when the input index that produced each
/// output symbol is tracked, and `()` otherwise (see [`UsizeOrUnit`]).
#[derive(Clone, Debug, PartialEq, PartialOrd)]
#[readonly::make]
pub struct InternalFSTState<T> {
    /// Number of the FST state this configuration is in.
    pub state_num: u64,
    /// Sum of the weights of the transitions taken so far, plus the final weight once accepted.
    pub path_weight: f64,
    /// Flag-diacritic bindings on the input side.
    pub input_flags: FlagMap,
    /// Flag-diacritic bindings on the output side.
    pub output_flags: FlagMap,
    /// Output symbols emitted so far (newest at the head), each paired with its payload `T`.
    pub symbol_mappings: FSTLinkedList<T>,
}
1759
1760impl<T: Clone + Default + Debug + UsizeOrUnit> InternalFSTState<T> {
1761 fn convert_indices<T2: UsizeOrUnit + Default>(self) -> InternalFSTState<T2> {
1762 let new_mappings = self
1763 .symbol_mappings
1764 .into_iter()
1765 .map(|(sym, idx)| (sym, T2::convert(idx)))
1766 .collect();
1767 InternalFSTState {
1768 state_num: self.state_num,
1769 path_weight: self.path_weight,
1770 input_flags: self.input_flags,
1771 output_flags: self.output_flags,
1772 symbol_mappings: new_mappings,
1773 }
1774 }
1775}
1776
1777impl<T> Default for InternalFSTState<T> {
1778 fn default() -> Self {
1780 Self {
1781 state_num: 0,
1782 path_weight: 0.0,
1783 input_flags: FlagMap::new(),
1784 output_flags: FlagMap::new(),
1785 symbol_mappings: FSTLinkedList::None,
1786 }
1787 }
1788}
1789
1790impl<T: Hash> Hash for InternalFSTState<T> {
1791 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
1792 self.state_num.hash(state);
1793 self.path_weight.to_be_bytes().hash(state);
1794 self.input_flags.hash(state);
1795 self.output_flags.hash(state);
1796 self.symbol_mappings.hash(state);
1797 }
1798}
1799
/// Checks a stored `(polarity, value)` flag binding against `queried_val`.
///
/// With polarity `true` the check passes when the values are equal; with `false` it passes when
/// they differ.
fn _test_flag(stored_val: &(bool, u32), queried_val: u32) -> bool {
    let &(polarity, value) = stored_val;
    let equal = value == queried_val;
    polarity == equal
}
1803
1804impl<T: Clone + Default> InternalFSTState<T> {
1805 fn _new(state: u64) -> Self {
1806 InternalFSTState {
1807 state_num: state,
1808 path_weight: 0.0,
1809 input_flags: FlagMap::new(),
1810 output_flags: FlagMap::new(),
1811 symbol_mappings: FSTLinkedList::None,
1812 }
1813 }
1814
1815 fn __new<F>(
1816 state: u64,
1817 path_weight: f64,
1818 input_flags: F,
1819 output_flags: F,
1820 output_symbols: Vec<Symbol>,
1821 input_indices: Vec<T>,
1822 ) -> Self
1823 where
1824 F: Into<FlagMap>,
1825 {
1826 InternalFSTState {
1827 state_num: state,
1828 path_weight,
1829 input_flags: input_flags.into(),
1830 output_flags: output_flags.into(),
1831 symbol_mappings: FSTLinkedList::from_vec(output_symbols, input_indices),
1832 }
1833 }
1834
1835 #[cfg(not(feature = "python"))]
1836 pub fn new<F>(
1840 state_num: u64,
1841 path_weight: f64,
1842 input_flags: F,
1843 output_flags: F,
1844 output_symbols: Vec<Symbol>,
1845 input_indices: Vec<T>,
1846 ) -> Self
1847 where
1848 F: Into<FlagMap>,
1849 {
1850 InternalFSTState::__new(
1851 state_num,
1852 path_weight,
1853 input_flags,
1854 output_flags,
1855 output_symbols,
1856 input_indices,
1857 )
1858 }
1859}
1860
1861impl<T: Clone + Debug + Default + UsizeOrUnit> InternalFSTState<T> {
1862 #[cfg(feature = "python")]
1863 fn new(
1864 state_num: u64,
1865 path_weight: f64,
1866 input_flags: FlagMap,
1867 output_flags: FlagMap,
1868 output_symbols: Vec<Symbol>,
1869 input_indices: Vec<T>,
1870 ) -> Self {
1871 InternalFSTState {
1872 state_num,
1873 path_weight,
1874 input_flags,
1875 output_flags,
1876 symbol_mappings: FSTLinkedList::from_vec(output_symbols, input_indices),
1877 }
1878 }
1879
1880 pub fn output_symbols(&self) -> Vec<Symbol> {
1894 self.symbol_mappings
1895 .clone()
1896 .into_iter()
1897 .map(|x| x.0)
1898 .collect()
1899 }
1900
1901 pub fn input_indices(&self) -> Option<Vec<usize>> {
1917 T::branch(
1918 || {
1919 Some(
1920 self.symbol_mappings
1921 .clone()
1922 .into_iter()
1923 .map(|x| x.1.as_usize())
1924 .collect(),
1925 )
1926 },
1927 || None,
1928 )
1929 }
1930
1931 #[deprecated]
1932 pub fn __repr__(&self) -> String {
1934 format!(
1935 "FSTState({}, {}, {:?}, {:?}, {:?}, {:?})",
1936 self.state_num,
1937 self.path_weight,
1938 self.input_flags,
1939 self.output_flags,
1940 self.symbol_mappings
1941 .clone()
1942 .into_iter()
1943 .map(|x| x.0)
1944 .collect::<Vec<_>>(),
1945 self.symbol_mappings
1946 .clone()
1947 .into_iter()
1948 .map(|x| x.1)
1949 .collect::<Vec<_>>()
1950 )
1951 }
1952}
1953
/// Reverses AT&T symbol escaping: `@_TAB_@` becomes a tab and `@_SPACE_@` becomes a space.
fn unescape_att_symbol(att_symbol: &str) -> String {
    const ESCAPES: [(&str, &str); 2] = [("@_TAB_@", "\t"), ("@_SPACE_@", " ")];
    ESCAPES
        .iter()
        .fold(att_symbol.to_string(), |text, &(escaped, raw)| {
            text.replace(escaped, raw)
        })
}
1961
/// Escapes a symbol for AT&T output: a tab becomes `@_TAB_@` and a space becomes `@_SPACE_@`
/// (tabs are the AT&T column separator).
fn escape_att_symbol(symbol: &str) -> String {
    const ESCAPES: [(&str, &str); 2] = [("\t", "@_TAB_@"), (" ", "@_SPACE_@")];
    ESCAPES
        .iter()
        .fold(symbol.to_string(), |text, &(raw, escaped)| {
            text.replace(raw, escaped)
        })
}
1967
/// Abstraction over the two per-output-symbol payload types used by [`InternalFSTState`]:
/// `usize` (records the input index each output symbol came from) and `()` (records nothing).
pub trait UsizeOrUnit: Sized {
    /// Zero value of the payload (`0` / `()`).
    /// NOTE(review): not used in the visible code; confirm where callers rely on it.
    const ZERO: Self;
    /// Step value of the payload (`1` / `()`).
    /// NOTE(review): not used in the visible code; confirm where callers rely on it.
    const INCREMENT: Self;
    /// Static dispatch on the payload type: the `usize` impl runs `f1`, the `()` impl runs `f2`.
    fn branch<F1, F2, B>(f1: F1, f2: F2) -> B
    where
        F1: FnOnce() -> B,
        F2: FnOnce() -> B;
    /// The payload as an index; `()` reports `0`.
    fn as_usize(&self) -> usize;
    /// Converts a payload of another type (`usize` takes its `as_usize`; `()` discards it).
    fn convert<T: UsizeOrUnit>(x: T) -> Self;
}
1978
1979impl UsizeOrUnit for usize {
1980 const ZERO: Self = 0;
1981
1982 const INCREMENT: Self = 1;
1983
1984 fn as_usize(&self) -> usize {
1985 *self
1986 }
1987
1988 fn convert<T: UsizeOrUnit>(x: T) -> Self {
1989 x.as_usize()
1990 }
1991
1992 fn branch<F1, F2, B>(f1: F1, _f2: F2) -> B
1993 where
1994 F1: FnOnce() -> B,
1995 F2: FnOnce() -> B,
1996 {
1997 f1()
1998 }
1999}
2000
2001impl UsizeOrUnit for () {
2002 fn as_usize(&self) -> usize {
2003 0
2004 }
2005
2006 const ZERO: Self = ();
2007
2008 const INCREMENT: Self = ();
2009
2010 fn convert<T: UsizeOrUnit>(_x: T) -> Self {}
2011
2012 fn branch<F1, F2, B>(_f1: F1, f2: F2) -> B
2013 where
2014 F1: FnOnce() -> B,
2015 F2: FnOnce() -> B,
2016 {
2017 f2()
2018 }
2019}
2020
/// A weighted finite-state transducer.
#[cfg_attr(feature = "python", pyclass(frozen, get_all))]
#[readonly::make]
pub struct FST {
    /// Final (accepting) states mapped to their final weight.
    pub final_states: IndexMap<u64, f64>,
    /// Transition table: source state -> input symbol -> `(target state, output symbol, weight)`
    /// list. Within each state, `_from_rules` moves epsilon-input entries to the front.
    pub rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>>,
    /// Every symbol of the transducer, sorted by `_from_rules`. Tokenization binary-searches it.
    pub symbols: Vec<Symbol>,
    /// Legacy debug switch; deprecated.
    #[deprecated]
    pub debug: bool,
}
2072
2073impl FST {
2074 fn _run_fst<T: Clone + UsizeOrUnit>(
2075 &self,
2076 input_symbols: &[Symbol],
2077 state: &InternalFSTState<T>,
2078 post_input_advance: bool,
2079 result: &mut Vec<(bool, bool, InternalFSTState<T>)>,
2080 input_symbol_index: usize,
2081 keep_non_final: bool,
2082 ) {
2083 let transitions = self.rules.get(&state.state_num);
2084 let isymbol = if input_symbols.len() - input_symbol_index == 0 {
2085 match self.final_states.get(&state.state_num) {
2086 Some(&weight) => {
2087 result.push((
2089 true,
2090 post_input_advance,
2091 InternalFSTState {
2092 state_num: state.state_num,
2093 path_weight: state.path_weight + weight,
2094 input_flags: state.input_flags.clone(),
2095 output_flags: state.output_flags.clone(),
2096 symbol_mappings: state.symbol_mappings.clone(),
2097 },
2098 ));
2099 }
2100 None => {
2101 if keep_non_final {
2103 result.push((false, post_input_advance, state.clone()));
2104 }
2105 }
2106 }
2107 None
2108 } else {
2109 Some(&input_symbols[input_symbol_index])
2110 };
2111 if let Some(transitions) = transitions {
2112 for transition_isymbol in transitions.keys() {
2113 if transition_isymbol.is_epsilon() || isymbol == Some(transition_isymbol) {
2114 self._transition(
2115 input_symbols,
2116 input_symbol_index,
2117 state,
2118 &transitions[transition_isymbol],
2119 isymbol,
2120 transition_isymbol,
2121 result,
2122 keep_non_final,
2123 );
2124 }
2125 }
2126 if let Some(isymbol) = isymbol {
2127 if isymbol.is_unknown() {
2128 if let Some(transition_list) =
2129 transitions.get(&Symbol::Special(SpecialSymbol::UNKNOWN))
2130 {
2131 self._transition(
2132 input_symbols,
2133 input_symbol_index,
2134 state,
2135 transition_list,
2136 Some(isymbol),
2137 &Symbol::Special(SpecialSymbol::UNKNOWN),
2138 result,
2139 keep_non_final,
2140 );
2141 }
2142
2143 if let Some(transition_list) =
2144 transitions.get(&Symbol::Special(SpecialSymbol::IDENTITY))
2145 {
2146 self._transition(
2147 input_symbols,
2148 input_symbol_index,
2149 state,
2150 transition_list,
2151 Some(isymbol),
2152 &Symbol::Special(SpecialSymbol::IDENTITY),
2153 result,
2154 keep_non_final,
2155 );
2156 }
2157 }
2158 }
2159 }
2160 }
2161
2162 fn _transition<T: Clone + UsizeOrUnit>(
2163 &self,
2164 input_symbols: &[Symbol],
2165 input_symbol_index: usize,
2166 state: &InternalFSTState<T>,
2167 transitions: &[(u64, Symbol, f64)],
2168 isymbol: Option<&Symbol>,
2169 transition_isymbol: &Symbol,
2170 result: &mut Vec<(bool, bool, InternalFSTState<T>)>,
2171 keep_non_final: bool,
2172 ) {
2173 for (next_state, osymbol, weight) in transitions.iter() {
2174 let new_output_flags = _update_flags(osymbol, &state.output_flags);
2175 let new_input_flags = _update_flags(transition_isymbol, &state.input_flags);
2176
2177 match (new_output_flags, new_input_flags) {
2178 (Some(new_output_flags), Some(new_input_flags)) => {
2179 let new_osymbol = match (isymbol, osymbol) {
2180 (Some(isymbol), Symbol::Special(SpecialSymbol::IDENTITY)) => isymbol,
2181 _ => osymbol,
2182 };
2183
2184 let new_symbol_mapping: FSTLinkedList<T> = if new_osymbol.is_epsilon() {
2188 state.symbol_mappings.clone() } else {
2190 FSTLinkedList::Some(
2191 (new_osymbol.clone(), T::convert(input_symbol_index)),
2192 Arc::new(state.symbol_mappings.clone()),
2193 )
2194 };
2195 let new_state = InternalFSTState {
2196 state_num: *next_state,
2197 path_weight: state.path_weight + *weight,
2198 input_flags: new_input_flags,
2199 output_flags: new_output_flags,
2200 symbol_mappings: new_symbol_mapping,
2201 };
2202 if transition_isymbol.is_epsilon() {
2203 self._run_fst(
2204 input_symbols,
2205 &new_state,
2206 input_symbols.is_empty(),
2207 result,
2208 input_symbol_index,
2209 keep_non_final,
2210 );
2211 } else {
2212 self._run_fst(
2213 input_symbols,
2214 &new_state,
2215 false,
2216 result,
2217 input_symbol_index + 1,
2218 keep_non_final,
2219 );
2220 }
2221 }
2222 _ => continue,
2223 }
2224 }
2225 }
2226
2227 pub fn from_att_rows(
2235 rows: Vec<Result<(u64, f64), (u64, u64, Symbol, Symbol, f64)>>,
2236 debug: bool,
2237 ) -> FST {
2238 let mut final_states: IndexMap<u64, f64> = IndexMap::new();
2239 let mut rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>> = IndexMap::new();
2240 let mut symbols: IndexSet<Symbol> = IndexSet::new();
2241 for line in rows.into_iter() {
2242 match line {
2243 Ok((state_number, state_weight)) => {
2244 final_states.insert(state_number, state_weight);
2245 }
2246 Err((state_1, state_2, top_symbol, bottom_symbol, weight)) => {
2247 rules.entry(state_1).or_default();
2248 let handle = rules.get_mut(&state_1).unwrap();
2249 if !handle.contains_key(&top_symbol) {
2250 handle.insert(top_symbol.clone(), vec![]);
2251 }
2252 handle.get_mut(&top_symbol).unwrap().push((
2253 state_2,
2254 bottom_symbol.clone(),
2255 weight,
2256 ));
2257 symbols.insert(top_symbol);
2258 symbols.insert(bottom_symbol);
2259 }
2260 }
2261 }
2262 FST::from_rules(
2263 final_states,
2264 rules,
2265 symbols.into_iter().collect(),
2266 Some(debug),
2267 )
2268 }
2269
2270 fn _from_kfst_bytes(kfst_bytes: &[u8]) -> Result<FST, String> {
2271 let mut header = nom::sequence::preceded(
2277 nom::bytes::complete::tag("KFST"),
2278 nom::number::complete::be_u16::<&[u8], ()>,
2279 );
2280 let (rest, version) = header
2281 .parse(kfst_bytes)
2282 .map_err(|_| "Failed to parse header")?;
2283 assert!(version == 0);
2284
2285 let mut metadata = (
2288 nom::number::complete::be_u16::<&[u8], ()>,
2289 nom::number::complete::be_u32,
2290 nom::number::complete::be_u32,
2291 nom::number::complete::u8,
2292 );
2293 let (rest, (num_symbols, num_transitions, num_final_states, is_weighted)) = metadata
2294 .parse(rest)
2295 .map_err(|_| "Failed to parse metadata")?;
2296 let num_transitions: usize = num_transitions
2297 .try_into()
2298 .map_err(|_| "usize too small to represent transitions")?;
2299 let num_final_states: usize = num_final_states
2300 .try_into()
2301 .map_err(|_| "usize too small to represent final states")?;
2302 let is_weighted: bool = is_weighted != 0u8;
2304
2305 let mut symbol = nom::multi::count(
2308 nom::sequence::terminated(nom::bytes::complete::take_until1("\0"), tag("\0")),
2309 num_symbols.into(),
2310 );
2311 let (rest, symbols) = symbol
2312 .parse(rest)
2313 .map_err(|_: nom::Err<()>| "Failed to parse symbol list")?;
2314 let symbol_strings: Vec<&str> = symbols
2315 .into_iter()
2316 .map(|x| std::str::from_utf8(x))
2317 .collect::<Result<Vec<&str>, _>>()
2318 .map_err(|x| format!("Some symbol was not valid utf-8: {x}"))?;
2319 let symbol_list: Vec<Symbol> = symbol_strings
2320 .iter()
2321 .map(|x| {
2322 Symbol::parse(x)
2323 .map_err(|x| {
2324 format!(
2325 "Some symbol while valid utf8 was not a valid symbol specifier: {x}"
2326 )
2327 })
2328 .and_then(|(extra, sym)| {
2329 if extra.is_empty() {
2330 Ok(sym)
2331 } else {
2332 Err(format!(
2333 "Extra data after end of symbol {}: {extra:?}",
2334 sym.get_symbol(),
2335 ))
2336 }
2337 })
2338 })
2339 .collect::<Result<Vec<Symbol>, _>>()?;
2340 let symbol_objs: IndexSet<Symbol> = symbol_list.iter().cloned().collect();
2341
2342 let mut decomp: Vec<u8> = Vec::new();
2345 let mut decoder = XzDecoder::new(rest);
2346 decoder
2347 .read_to_end(&mut decomp)
2348 .map_err(|_| "Failed to xz-decompress remainder of file")?;
2349
2350 let transition_syntax = (
2354 nom::number::complete::be_u32::<&[u8], ()>,
2355 nom::number::complete::be_u32,
2356 nom::number::complete::be_u16,
2357 nom::number::complete::be_u16,
2358 );
2359 let weight_parser = if is_weighted {
2360 nom::number::complete::be_f64
2361 } else {
2362 |input| Ok((input, 0.0)) };
2364 let (rest, file_rules) = many_m_n(
2365 num_transitions,
2366 num_transitions,
2367 (transition_syntax, weight_parser),
2368 )
2369 .parse(decomp.as_slice())
2370 .map_err(|_| "Broken transition table")?;
2371
2372 let (rest, final_states) = many_m_n(
2373 num_final_states,
2374 num_final_states,
2375 (nom::number::complete::be_u32, weight_parser),
2376 )
2377 .parse(rest)
2378 .map_err(|_| "Broken final states")?;
2379
2380 if !rest.is_empty() {
2381 Err(format!("lzma-compressed payload is {} bytes long when decompressed but given the header, there seems to be {} bytes extra.", decomp.len(), rest.len()))?;
2382 }
2383
2384 let final_states = final_states
2387 .into_iter()
2388 .map(|(a, b)| (a.into(), b))
2389 .collect();
2390
2391 let symbols = symbol_objs.into_iter().collect();
2394
2395 let mut rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>> = IndexMap::new();
2398 for ((from_state, to_state, top_symbol_idx, bottom_symbol_idx), weight) in
2399 file_rules.into_iter()
2400 {
2401 let from_state = from_state.into();
2402 let to_state = to_state.into();
2403 let top_symbol_idx: usize = top_symbol_idx.into();
2404 let bottom_symbol_idx: usize = bottom_symbol_idx.into();
2405 let top_symbol = symbol_list[top_symbol_idx].clone();
2406 let bottom_symbol = symbol_list[bottom_symbol_idx].clone();
2407 let handle = rules.entry(from_state).or_default();
2408 let vec = handle.entry(top_symbol).or_default();
2409 vec.push((to_state, bottom_symbol, weight));
2410 }
2411
2412 Ok(FST::from_rules(final_states, rules, symbols, None))
2413 }
2414
2415 fn _to_kfst_bytes(&self) -> Result<Vec<u8>, String> {
2416 let mut weighted = false;
2419
2420 for (_, &weight) in self.final_states.iter() {
2421 if weight != 0.0 {
2422 weighted = true;
2423 break;
2424 }
2425 }
2426
2427 let mut transitions: u32 = 0;
2428
2429 for (_, transition_table) in self.rules.iter() {
2430 for transition in transition_table.values() {
2431 for (_, _, weight) in transition.iter() {
2432 if (*weight) != 0.0 {
2433 weighted = true;
2434 }
2435 transitions += 1;
2436 }
2437 }
2438 }
2439
2440 let mut result: Vec<u8> = "KFST".into();
2443 result.extend(0u16.to_be_bytes());
2444 let symbol_len: u16 = self
2445 .symbols
2446 .len()
2447 .try_into()
2448 .map_err(|x| format!("Too many symbols to represent as u16: {x}"))?;
2449 result.extend(symbol_len.to_be_bytes());
2450 result.extend(transitions.to_be_bytes());
2451 let num_states: u32 = self
2452 .final_states
2453 .len()
2454 .try_into()
2455 .map_err(|x| format!("Too many final states to represent as u32: {x}"))?;
2456 result.extend(num_states.to_be_bytes());
2457 result.push(weighted.into()); let mut sorted_syms: Vec<_> = self.symbols.iter().collect();
2462 sorted_syms.sort();
2463 for symbol in sorted_syms.iter() {
2464 result.extend(symbol.get_symbol().into_bytes());
2465 result.push(0); }
2467
2468 let mut to_compress: Vec<u8> = vec![];
2471
2472 for (source_state, transition_table) in self.rules.iter() {
2475 for (top_symbol, transition) in transition_table.iter() {
2476 for (target_state, bottom_symbol, weight) in transition.iter() {
2477 let source_state: u32 = (*source_state).try_into().map_err(|x| {
2478 format!("Can't represent source state {source_state} as u32: {x}")
2479 })?;
2480 let target_state: u32 = (*target_state).try_into().map_err(|x| {
2481 format!("Can't represent target state {target_state} as u32: {x}")
2482 })?;
2483 let top_index: u16 = sorted_syms
2484 .binary_search(&top_symbol)
2485 .map_err(|_| {
2486 format!("Top symbol {top_symbol:?} not found in FST symbol list")
2487 })
2488 .and_then(|x| {
2489 x.try_into().map_err(|x| {
2490 format!("Can't represent top symbol index as u16: {x}")
2491 })
2492 })?;
2493 let bottom_index: u16 = sorted_syms
2494 .binary_search(&bottom_symbol)
2495 .map_err(|_| {
2496 format!("Bottom symbol {bottom_symbol:?} not found in FST symbol list")
2497 })
2498 .and_then(|x| {
2499 x.try_into().map_err(|x| {
2500 format!("Can't represent bottom symbol index as u16: {x}")
2501 })
2502 })?;
2503 to_compress.extend(source_state.to_be_bytes());
2504 to_compress.extend(target_state.to_be_bytes());
2505 to_compress.extend(top_index.to_be_bytes());
2506 to_compress.extend(bottom_index.to_be_bytes());
2507 if weighted {
2508 to_compress.extend(weight.to_be_bytes());
2509 } else {
2510 assert!(*weight == 0.0);
2511 }
2512 }
2513 }
2514 }
2515
2516 for (&final_state, weight) in self.final_states.iter() {
2519 let final_state: u32 = final_state
2520 .try_into()
2521 .map_err(|x| format!("Can't represent final state index as u32: {x}"))?;
2522 to_compress.extend(final_state.to_be_bytes());
2523 if weighted {
2524 to_compress.extend(weight.to_be_bytes());
2525 } else {
2526 assert!(*weight == 0.0);
2527 }
2528 }
2529
2530 let mut compressed = vec![];
2533
2534 let mut encoder = XzEncoder::new(to_compress.as_slice(), 9);
2535 encoder
2536 .read_to_end(&mut compressed)
2537 .map_err(|x| format!("Failed while compressing with lzma_rs: {x}"))?;
2538 result.extend(compressed);
2539
2540 Ok(result)
2541 }
2542
2543 fn _from_rules(
2544 final_states: IndexMap<u64, f64>,
2545 rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>>,
2546 symbols: HashSet<Symbol>,
2547 debug: Option<bool>,
2548 ) -> FST {
2549 let mut new_symbols: Vec<Symbol> = symbols.into_iter().collect();
2550 new_symbols.sort();
2553 let mut new_rules = IndexMap::new();
2555 for (target_node, rulebook) in rules {
2556 let mut new_rulebook = rulebook;
2557 new_rulebook.sort_by(|a, _, b, _| b.is_epsilon().cmp(&a.is_epsilon()));
2558 new_rules.insert(target_node, new_rulebook);
2559 }
2560 FST {
2561 final_states,
2562 rules: new_rules,
2563 symbols: new_symbols,
2564 debug: debug.unwrap_or(false),
2565 }
2566 }
2567
2568 #[cfg(not(feature = "python"))]
2570 pub fn from_rules(
2571 final_states: IndexMap<u64, f64>,
2572 rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>>,
2573 symbols: HashSet<Symbol>,
2574 debug: Option<bool>,
2575 ) -> FST {
2576 FST::_from_rules(final_states, rules, symbols, debug)
2577 }
2578
2579 fn _from_att_file(att_file: String, debug: bool) -> KFSTResult<FST> {
2580 match File::open(Path::new(&att_file)) {
2582 Ok(mut file) => {
2583 let mut att_code = String::new();
2584 file.read_to_string(&mut att_code).map_err(|err| {
2585 io_error::<()>(format!("Failed to read from file {att_file}:\n{err}"))
2586 .unwrap_err()
2587 })?;
2588 FST::from_att_code(att_code, debug)
2589 }
2590 Err(err) => io_error(format!("Failed to open file {att_file}:\n{err}")),
2591 }
2592 }
2593
2594 #[cfg(not(feature = "python"))]
2595 pub fn from_att_file(att_file: String, debug: bool) -> KFSTResult<FST> {
2598 FST::_from_att_file(att_file, debug)
2599 }
2600
2601 fn _from_att_code(att_code: String, debug: bool) -> KFSTResult<FST> {
2602 let mut rows: Vec<Result<(u64, f64), (u64, u64, Symbol, Symbol, f64)>> = vec![];
2603
2604 for (lineno, line) in att_code.lines().enumerate() {
2605 let elements: Vec<&str> = line.split("\t").collect();
2606 if elements.len() == 1 || elements.len() == 2 {
2607 let state = elements[0].parse::<u64>().ok();
2608 let weight = if elements.len() == 1 {
2609 Some(0.0)
2610 } else {
2611 elements[1].parse::<f64>().ok()
2612 };
2613 match (state, weight) {
2614 (Some(state), Some(weight)) => {
2615 rows.push(Ok((state, weight)));
2616 }
2617 _ => {
2618 return value_error(format!(
2619 "Failed to parse att code on line {lineno}:\n{line}",
2620 ))
2621 }
2622 }
2623 } else if elements.len() == 4 || elements.len() == 5 {
2624 let state_1 = elements[0].parse::<u64>().ok();
2625 let state_2 = elements[1].parse::<u64>().ok();
2626 let unescaped_sym_1 = unescape_att_symbol(elements[2]);
2627 let symbol_1 = Symbol::parse(&unescaped_sym_1).ok();
2628 let unescaped_sym_2 = unescape_att_symbol(elements[3]);
2629 let symbol_2 = Symbol::parse(&unescaped_sym_2).ok();
2630 let weight = if elements.len() == 4 {
2631 Some(0.0)
2632 } else {
2633 elements[4].parse::<f64>().ok()
2634 };
2635 match (state_1, state_2, symbol_1, symbol_2, weight) {
2636 (
2637 Some(state_1),
2638 Some(state_2),
2639 Some(("", symbol_1)),
2640 Some(("", symbol_2)),
2641 Some(weight),
2642 ) => {
2643 rows.push(Err((state_1, state_2, symbol_1, symbol_2, weight)));
2644 }
2645 _ => {
2646 return value_error(format!(
2647 "Failed to parse att code on line {lineno}:\n{line}",
2648 ));
2649 }
2650 }
2651 }
2652 }
2653 KFSTResult::Ok(FST::from_att_rows(rows, debug))
2654 }
2655
2656 #[cfg(not(feature = "python"))]
2657 pub fn from_att_code(att_code: String, debug: bool) -> KFSTResult<FST> {
2721 FST::_from_att_code(att_code, debug)
2722 }
2723
2724 fn _from_kfst_file(kfst_file: String, debug: bool) -> KFSTResult<FST> {
2725 match File::open(Path::new(&kfst_file)) {
2726 Ok(mut file) => {
2727 let mut kfst_bytes: Vec<u8> = vec![];
2728 file.read_to_end(&mut kfst_bytes).map_err(|err| {
2729 io_error::<()>(format!("Failed to read from file {kfst_file}:\n{err}"))
2730 .unwrap_err()
2731 })?;
2732 FST::from_kfst_bytes(&kfst_bytes, debug)
2733 }
2734 Err(err) => io_error(format!("Failed to open file {kfst_file}:\n{err}")),
2735 }
2736 }
2737
2738 #[cfg(not(feature = "python"))]
2742 pub fn from_kfst_file(kfst_file: String, debug: bool) -> KFSTResult<FST> {
2743 FST::_from_kfst_file(kfst_file, debug)
2744 }
2745
2746 #[allow(unused)]
2747 fn __from_kfst_bytes(kfst_bytes: &[u8], debug: bool) -> KFSTResult<FST> {
2748 match FST::_from_kfst_bytes(kfst_bytes) {
2749 Ok(x) => Ok(x),
2750 Err(x) => value_error(x),
2751 }
2752 }
2753
2754 #[cfg(not(feature = "python"))]
2758 pub fn from_kfst_bytes(kfst_bytes: &[u8], debug: bool) -> KFSTResult<FST> {
2759 FST::__from_kfst_bytes(kfst_bytes, debug)
2760 }
2761
2762 fn _split_to_symbols(&self, text: &str, allow_unknown: bool) -> Option<Vec<Symbol>> {
2763 let mut result = vec![];
2764 let max_byte_len = self
2765 .symbols
2766 .iter()
2767 .find(|x| matches!(x, Symbol::String(_) | Symbol::Special(_) | Symbol::Flag(_)))
2768 .map(|x| x.with_symbol(|s| s.len()))
2769 .unwrap_or(0);
2770 let mut slice = text;
2771 while !slice.is_empty() {
2772 let mut found = false;
2773 for length in (0..std::cmp::min(max_byte_len, slice.len()) + 1).rev() {
2774 if !slice.is_char_boundary(length) {
2775 continue;
2776 }
2777 let key = &slice[..length];
2778 let pp = self.symbols.partition_point(|x| {
2779 x.with_symbol(|y| (key.chars().count(), y) < (y.chars().count(), key))
2780 });
2781 if let Some(sym) = self.symbols[pp..]
2782 .iter()
2783 .find(|x| matches!(x, Symbol::String(_) | Symbol::Special(_) | Symbol::Flag(_)))
2784 {
2785 if sym.with_symbol(|s| s == key) {
2786 result.push(sym.clone());
2787 slice = &slice[length..];
2788 found = true;
2789 break;
2790 }
2791 }
2792 }
2793 if (!found) && allow_unknown {
2794 let char = slice.chars().next().unwrap();
2795 slice = &slice[char.len_utf8()..];
2796 result.push(Symbol::String(StringSymbol::new(char.to_string(), false)));
2797 found = true;
2798 }
2799 if !found {
2800 return None;
2801 }
2802 }
2803 Some(result)
2804 }
2805
2806 #[cfg(not(feature = "python"))]
2811 pub fn split_to_symbols(&self, text: &str, allow_unknown: bool) -> Option<Vec<Symbol>> {
2812 self._split_to_symbols(text, allow_unknown)
2813 }
2814
2815 fn __run_fst<T: UsizeOrUnit + Clone>(
2816 &self,
2817 input_symbols: Vec<Symbol>,
2818 state: InternalFSTState<T>,
2819 post_input_advance: bool,
2820 input_symbol_index: usize,
2821 keep_non_final: bool,
2822 ) -> Vec<(bool, bool, InternalFSTState<T>)> {
2823 let mut result = vec![];
2824 self._run_fst(
2825 input_symbols.as_slice(),
2826 &state,
2827 post_input_advance,
2828 &mut result,
2829 input_symbol_index,
2830 keep_non_final,
2831 );
2832 result
2833 }
2834
2835 #[cfg(not(feature = "python"))]
2836 pub fn run_fst<T: Clone + UsizeOrUnit>(
2844 &self,
2845 input_symbols: Vec<Symbol>,
2846 state: InternalFSTState<T>,
2847 post_input_advance: bool,
2848 input_symbol_index: Option<usize>,
2849 keep_non_final: bool,
2850 ) -> Vec<(bool, bool, InternalFSTState<T>)> {
2851 self.__run_fst(
2852 input_symbols,
2853 state,
2854 post_input_advance,
2855 input_symbol_index.unwrap_or(0),
2856 keep_non_final,
2857 )
2858 }
2859
2860 fn _lookup<T: UsizeOrUnit + Clone + Default + Debug>(
2861 &self,
2862 input: &str,
2863 state: InternalFSTState<T>,
2864 allow_unknown: bool,
2865 ) -> KFSTResult<Vec<(String, f64)>> {
2866 let input_symbols = self.split_to_symbols(input, allow_unknown);
2867 match input_symbols {
2868 None => tokenization_exception(format!("Input cannot be split into symbols: {input}")),
2869 Some(input_symbols) => {
2870 let mut dedup: IndexSet<String> = IndexSet::new();
2871 let mut result: Vec<(String, f64)> = vec![];
2872 let mut finished_paths: Vec<(bool, bool, InternalFSTState<()>)> = self.__run_fst(
2873 input_symbols.clone(),
2874 state.convert_indices(),
2875 false,
2876 0,
2877 false,
2878 );
2879 finished_paths
2880 .sort_by(|a, b| a.2.path_weight.partial_cmp(&b.2.path_weight).unwrap());
2881 for finished in finished_paths {
2882 let output_string: String = finished
2883 .2
2884 .symbol_mappings
2885 .to_vec()
2886 .iter()
2887 .map(|x| x.0.get_symbol())
2888 .collect::<Vec<String>>()
2889 .join("");
2890 if dedup.contains(&output_string) {
2891 continue;
2892 }
2893 dedup.insert(output_string.clone());
2894 result.push((output_string, finished.2.path_weight));
2895 }
2896 Ok(result)
2897 }
2898 }
2899 }
2900
2901 fn _lookup_aligned<T: UsizeOrUnit + Clone + Default + Debug>(
2902 &self,
2903 input: &str,
2904 state: InternalFSTState<T>,
2905 allow_unknown: bool,
2906 ) -> KFSTResult<Vec<(Vec<(usize, Symbol)>, f64)>> {
2907 let input_symbols = self.split_to_symbols(input, allow_unknown);
2908 match input_symbols {
2909 None => tokenization_exception(format!("Input cannot be split into symbols: {input}")),
2910 Some(input_symbols) => {
2911 let mut dedup: IndexSet<Vec<(usize, Symbol)>> = IndexSet::new();
2912 let mut result: Vec<(Vec<(usize, Symbol)>, f64)> = vec![];
2913 let mut finished_paths: Vec<_> = self
2914 .__run_fst(
2915 input_symbols.clone(),
2916 state.convert_indices(),
2917 false,
2918 0,
2919 false,
2920 )
2921 .into_iter()
2922 .filter(|(finished, _, _)| *finished)
2923 .collect();
2924 finished_paths
2925 .sort_by(|a, b| a.2.path_weight.partial_cmp(&b.2.path_weight).unwrap());
2926 for finished in finished_paths {
2927 let output_vec: Vec<(usize, Symbol)> = finished
2928 .2
2929 .symbol_mappings
2930 .to_vec()
2931 .into_iter()
2932 .map(|(a, b)| (b, a))
2933 .collect();
2934 if dedup.contains(&output_vec) {
2935 continue;
2936 }
2937 dedup.insert(output_vec.clone());
2938 result.push((output_vec, finished.2.path_weight));
2939 }
2940 Ok(result)
2941 }
2942 }
2943 }
2944
2945 #[cfg(not(feature = "python"))]
2946 pub fn lookup<T: UsizeOrUnit + Clone + Default + Debug>(
2954 &self,
2955 input: &str,
2956 state: InternalFSTState<T>,
2957 allow_unknown: bool,
2958 ) -> KFSTResult<Vec<(String, f64)>> {
2959 self._lookup(input, state, allow_unknown)
2960 }
2961
2962 #[cfg(not(feature = "python"))]
2963 pub fn lookup_aligned<T: UsizeOrUnit + Clone + Default + Debug>(
3029 &self,
3030 input: &str,
3031 state: InternalFSTState<T>,
3032 allow_unknown: bool,
3033 ) -> KFSTResult<Vec<(Vec<(usize, Symbol)>, f64)>> {
3034 self._lookup_aligned(input, state, allow_unknown)
3035 }
3036}
3037
3038fn _update_flags(symbol: &Symbol, flags: &FlagMap) -> Option<FlagMap> {
3039 if let Symbol::Flag(flag_diacritic_symbol) = symbol {
3040 match flag_diacritic_symbol.flag_type {
3041 FlagDiacriticType::U => {
3042 let value = flag_diacritic_symbol.value;
3043
3044 if let Some((currently_set, current_value)) = flags.get(flag_diacritic_symbol.key) {
3048 if (currently_set && current_value != value)
3049 || (!currently_set && current_value == value)
3050 {
3051 return None;
3052 }
3053 }
3054
3055 Some(flags.insert(flag_diacritic_symbol.key, (true, value)))
3058 }
3059 FlagDiacriticType::R => {
3060 match flag_diacritic_symbol.value {
3063 u32::MAX => {
3064 if flags.get(flag_diacritic_symbol.key).is_some() {
3065 Some(flags.clone())
3066 } else {
3067 None
3068 }
3069 }
3070 value => {
3071 if flags
3072 .get(flag_diacritic_symbol.key)
3073 .map(|stored| _test_flag(&stored, value))
3074 .unwrap_or(false)
3075 {
3076 Some(flags.clone())
3077 } else {
3078 None
3079 }
3080 }
3081 }
3082 }
3083 FlagDiacriticType::D => {
3084 match (
3085 flag_diacritic_symbol.value,
3086 flags.get(flag_diacritic_symbol.key),
3087 ) {
3088 (u32::MAX, None) => Some(flags.clone()),
3089 (u32::MAX, _) => None,
3090 (_, None) => Some(flags.clone()),
3091 (query, Some(stored)) => {
3092 if _test_flag(&stored, query) {
3093 None
3094 } else {
3095 Some(flags.clone())
3096 }
3097 }
3098 }
3099 }
3100 FlagDiacriticType::C => Some(flags.remove(flag_diacritic_symbol.key)),
3101 FlagDiacriticType::P => {
3102 let value = flag_diacritic_symbol.value;
3103 Some(flags.insert(flag_diacritic_symbol.key, (true, value)))
3104 }
3105 FlagDiacriticType::N => {
3106 let value = flag_diacritic_symbol.value;
3107 Some(flags.insert(flag_diacritic_symbol.key, (false, value)))
3108 }
3109 }
3110 } else {
3111 Some(flags.clone())
3112 }
3113}
3114
3115#[cfg_attr(feature = "python", pymethods)]
3116impl FST {
3117 #[cfg(feature = "python")]
3118 #[staticmethod]
3119 #[pyo3(signature = (final_states, rules, symbols, debug = false))]
3120 fn from_rules(
3121 final_states: IndexMap<u64, f64>,
3122 rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>>,
3123 symbols: HashSet<Symbol>,
3124 debug: Option<bool>,
3125 ) -> FST {
3126 FST::_from_rules(final_states, rules, symbols, debug)
3127 }
3128
3129 #[cfg(feature = "python")]
3130 #[staticmethod]
3131 #[pyo3(signature = (att_file, debug = false))]
3132 fn from_att_file(py: Python<'_>, att_file: PyObject, debug: bool) -> KFSTResult<FST> {
3133 FST::_from_att_file(att_file.call_method0(py, "__str__")?.extract(py)?, debug)
3134 }
3135
3136 #[cfg(feature = "python")]
3137 #[staticmethod]
3138 #[pyo3(signature = (att_code, debug = false))]
3139 fn from_att_code(att_code: String, debug: bool) -> KFSTResult<FST> {
3140 FST::_from_att_code(att_code, debug)
3141 }
3142
3143 #[cfg(feature = "python")]
3144 pub fn to_att_file(&self, py: Python<'_>, att_file: PyObject) -> KFSTResult<()> {
3145 let path: String = att_file.call_method0(py, "__str__")?.extract(py)?;
3146 fs::write(Path::new(&path), self.to_att_code()).map_err(|err| {
3147 io_error::<()>(format!("Failed to write to file {path}:\n{err}")).unwrap_err()
3148 })
3149 }
3150
3151 #[cfg(not(feature = "python"))]
3153 pub fn to_att_file(&self, att_file: String) -> KFSTResult<()> {
3154 fs::write(Path::new(&att_file), self.to_att_code()).map_err(|err| {
3155 io_error::<()>(format!("Failed to write to file {att_file}:\n{err}")).unwrap_err()
3156 })
3157 }
3158
3159 pub fn to_att_code(&self) -> String {
3182 let mut rows: Vec<String> = vec![];
3183 for (state, weight) in self.final_states.iter() {
3184 match weight {
3185 0.0 => {
3186 rows.push(format!("{state}"));
3187 }
3188 _ => {
3189 rows.push(format!("{state}\t{weight}"));
3190 }
3191 }
3192 }
3193 for (from_state, rules) in self.rules.iter() {
3194 for (top_symbol, transitions) in rules.iter() {
3195 for (to_state, bottom_symbol, weight) in transitions.iter() {
3196 match weight {
3197 0.0 => {
3198 rows.push(format!(
3199 "{}\t{}\t{}\t{}",
3200 from_state,
3201 to_state,
3202 escape_att_symbol(&top_symbol.get_symbol()),
3203 escape_att_symbol(&bottom_symbol.get_symbol())
3204 ));
3205 }
3206 _ => {
3207 rows.push(format!(
3208 "{}\t{}\t{}\t{}\t{}",
3209 from_state,
3210 to_state,
3211 escape_att_symbol(&top_symbol.get_symbol()),
3212 escape_att_symbol(&bottom_symbol.get_symbol()),
3213 weight
3214 ));
3215 }
3216 }
3217 }
3218 }
3219 }
3220 rows.join("\n")
3221 }
3222
3223 #[cfg(feature = "python")]
3224 #[staticmethod]
3225 #[pyo3(signature = (kfst_file, debug = false))]
3226 fn from_kfst_file(py: Python<'_>, kfst_file: PyObject, debug: bool) -> KFSTResult<FST> {
3227 FST::_from_kfst_file(kfst_file.call_method0(py, "__str__")?.extract(py)?, debug)
3228 }
3229
3230 #[cfg(feature = "python")]
3231 #[staticmethod]
3232 #[pyo3(signature = (kfst_bytes, debug = false))]
3233 fn from_kfst_bytes(kfst_bytes: &[u8], debug: bool) -> KFSTResult<FST> {
3234 FST::__from_kfst_bytes(kfst_bytes, debug)
3235 }
3236
3237 #[cfg(feature = "python")]
3238 pub fn to_kfst_file(&self, py: Python<'_>, kfst_file: PyObject) -> KFSTResult<()> {
3239 let bytes = self.to_kfst_bytes()?;
3240 let path: String = kfst_file.call_method0(py, "__str__")?.extract(py)?;
3241 fs::write(Path::new(&path), bytes).map_err(|err| {
3242 io_error::<()>(format!("Failed to write to file {path}:\n{err}")).unwrap_err()
3243 })
3244 }
3245
3246 #[cfg(not(feature = "python"))]
3247 pub fn to_kfst_file(&self, kfst_file: String) -> KFSTResult<()> {
3249 let bytes = self.to_kfst_bytes()?;
3250 fs::write(Path::new(&kfst_file), bytes).map_err(|err| {
3251 io_error::<()>(format!("Failed to write to file {kfst_file}:\n{err}")).unwrap_err()
3252 })
3253 }
3254
3255 pub fn to_kfst_bytes(&self) -> KFSTResult<Vec<u8>> {
3257 match self._to_kfst_bytes() {
3258 Ok(x) => Ok(x),
3259 Err(x) => value_error(x),
3260 }
3261 }
3262
3263 #[deprecated]
3264 pub fn __repr__(&self) -> String {
3266 format!(
3267 "FST(final_states: {:?}, rules: {:?}, symbols: {:?}, debug: {:?})",
3268 self.final_states, self.rules, self.symbols, self.debug
3269 )
3270 }
3271
3272 #[cfg(feature = "python")]
3273 #[pyo3(signature = (text, allow_unknown = true))]
3274 fn split_to_symbols(&self, text: &str, allow_unknown: bool) -> Option<Vec<Symbol>> {
3275 self._split_to_symbols(text, allow_unknown)
3276 }
3277
3278 #[cfg(feature = "python")]
3279 #[pyo3(signature = (input_symbols, state = FSTState { payload: Ok(InternalFSTState::_new(0)) }, post_input_advance = false, input_symbol_index = None, keep_non_final = true))]
3280 fn run_fst(
3281 &self,
3282 input_symbols: Vec<Symbol>,
3283 state: FSTState,
3284 post_input_advance: bool,
3285 input_symbol_index: Option<usize>,
3286 keep_non_final: bool,
3287 ) -> Vec<(bool, bool, FSTState)> {
3288 match state.payload {
3289 Ok(p) => self
3290 .__run_fst(
3291 input_symbols,
3292 p,
3293 post_input_advance,
3294 input_symbol_index.unwrap_or(0),
3295 keep_non_final,
3296 )
3297 .into_iter()
3298 .map(|(a, b, c)| (a, b, FSTState { payload: Ok(c) }))
3299 .collect(),
3300 Err(p) => self
3301 .__run_fst(
3302 input_symbols,
3303 p,
3304 post_input_advance,
3305 input_symbol_index.unwrap_or(0),
3306 keep_non_final,
3307 )
3308 .into_iter()
3309 .map(|(a, b, c)| (a, b, FSTState { payload: Err(c) }))
3310 .collect(),
3311 }
3312 }
3313
3314 #[cfg(feature = "python")]
3315 #[pyo3(signature = (input, state=FSTState { payload: Err(InternalFSTState::_new(0)) }, allow_unknown=true))]
3316 fn lookup(
3317 &self,
3318 input: &str,
3319 state: FSTState,
3320 allow_unknown: bool,
3321 ) -> KFSTResult<Vec<(String, f64)>> {
3322 match state.payload {
3323 Ok(p) => self._lookup(input, p, allow_unknown),
3324 Err(p) => self._lookup(input, p, allow_unknown),
3325 }
3326 }
3327
3328 #[cfg(feature = "python")]
3329 #[pyo3(signature = (input, state=FSTState { payload: Ok(InternalFSTState::_new(0)) }, allow_unknown=true))]
3330 pub fn lookup_aligned(
3331 &self,
3332 input: &str,
3333 state: FSTState,
3334 allow_unknown: bool,
3335 ) -> KFSTResult<Vec<(Vec<(usize, Symbol)>, f64)>> {
3336 use pyo3::exceptions::PyValueError;
3337
3338 match state.payload {
3339 Ok(p) => self._lookup_aligned(input, p, allow_unknown),
3340 Err(p) => PyResult::Err(PyErr::new::<PyValueError, _>(
3341 format!("lookup_aligned refuses to work with states with input_indices=None (passed state {}). Manually convert it to an indexed state by calling ensure_indices()", p.__repr__())
3342 )),
3343 }
3344 }
3345
3346 #[cfg(feature = "python")]
3347 pub fn get_input_symbols(&self, state: FSTState) -> HashSet<Symbol> {
3348 match state.payload {
3349 Ok(p) => self._get_input_symbols(p),
3350 Err(p) => self._get_input_symbols(p),
3351 }
3352 }
3353}
3354
3355impl FST {
3356 fn _get_input_symbols<T>(&self, state: InternalFSTState<T>) -> HashSet<Symbol> {
3357 self.rules
3358 .get(&state.state_num)
3359 .map(|x| x.keys().cloned().collect())
3360 .unwrap_or_default()
3361 }
3362
3363 #[cfg(not(feature = "python"))]
3364 pub fn get_input_symbols<T>(&self, state: FSTState<T>) -> HashSet<Symbol> {
3382 self._get_input_symbols(state)
3383 }
3384}
3385
// Only build these in `cargo test`; without the `test` cfg the module (and its `use crate::*`)
// was compiled into every non-Python library build.
#[cfg(all(test, not(feature = "python")))]
mod tests {
    use crate::*;

    /// Location of the Voikko Finnish transducer used by the larger regression tests.
    /// NOTE(review): these tests need the sibling `pyvoikko` checkout to be present.
    const VOIKKO_KFST: &str = "../pyvoikko/pyvoikko/voikko.kfst";

    /// Loads the Voikko transducer without debug output.
    fn load_voikko() -> FST {
        FST::_from_kfst_file(VOIKKO_KFST.to_string(), false).expect("voikko.kfst should load")
    }

    /// Builds the expected tokenization of `text` as one known string symbol per `char`.
    fn known_symbols(text: &str) -> Vec<Symbol> {
        text.chars()
            .map(|c| {
                Symbol::String(StringSymbol {
                    string: intern(c.to_string()),
                    unknown: false,
                })
            })
            .collect()
    }

    #[test]
    fn test_att_trivial() {
        let fst = FST::from_att_code("1\n0\t1\ta\tb".to_string(), false).unwrap();
        assert_eq!(
            fst.lookup("a", FSTState::<()>::default(), false).unwrap(),
            vec![("b".to_string(), 0.0)]
        );
    }

    #[test]
    fn test_att_slightly_less_trivial() {
        let fst = FST::from_att_code("2\n0\t1\ta\tb\n1\t2\tc\td".to_string(), false).unwrap();
        assert_eq!(
            fst.lookup("ac", FSTState::<()>::default(), false).unwrap(),
            vec![("bd".to_string(), 0.0)]
        );
    }

    #[test]
    fn test_kfst_voikko_kissa() {
        let fst = load_voikko();
        assert_eq!(
            fst.lookup("kissa", FSTState::<()>::_new(0), false).unwrap(),
            vec![("[Ln][Xp]kissa[X]kiss[Sn][Ny]a".to_string(), 0.0)]
        );
        assert_eq!(
            fst.lookup("kissojemmekaan", FSTState::<()>::_new(0), false)
                .unwrap(),
            vec![(
                "[Ln][Xp]kissa[X]kiss[Sg][Nm]oje[O1m]mme[Fkaan]kaan".to_string(),
                0.0
            )]
        );
    }

    #[test]
    fn test_that_weight_of_end_state_applies_correctly() {
        let code = "0\t1\ta\tb\n1\t1.0";
        let fst = FST::from_att_code(code.to_string(), false).unwrap();
        assert_eq!(
            fst.lookup("a", FSTState::<()>::_new(0), false).unwrap(),
            vec![("b".to_string(), 1.0)]
        );
    }

    #[test]
    fn test_kfst_voikko_correct_final_states() {
        let fst: FST = load_voikko();
        let map: IndexMap<_, _> = [(19, 0.0)].into_iter().collect();
        assert_eq!(fst.final_states, map);
    }

    #[test]
    fn test_kfst_voikko_split() {
        let fst: FST = load_voikko();
        assert_eq!(
            fst.split_to_symbols("lentokone", false).unwrap(),
            known_symbols("lentokone")
        );
        assert_eq!(
            fst.split_to_symbols("lentää", false).unwrap(),
            known_symbols("lentää")
        );
    }

    #[test]
    fn test_kfst_voikko() {
        let fst = load_voikko();
        assert_eq!(
            fst.lookup("lentokone", FSTState::<()>::_new(0), false)
                .unwrap(),
            vec![(
                "[Lt][Xp]lentää[X]len[Ln][Xj]to[X]to[Sn][Ny][Bh][Bc][Ln][Xp]kone[X]kone[Sn][Ny]"
                    .to_string(),
                0.0
            )]
        );
    }

    #[test]
    fn test_kfst_voikko_lentää() {
        let fst = load_voikko();
        let mut sys = fst
            .lookup("lentää", FSTState::<()>::_new(0), false)
            .unwrap();
        sys.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(
            sys,
            vec![
                ("[Lt][Xp]lentää[X]len[Tn1][Eb]tää".to_string(), 0.0),
                (
                    "[Lt][Xp]lentää[X]len[Tt][Ap][P3][Ny][Ef]tää".to_string(),
                    0.0
                )
            ]
        );
    }

    #[test]
    fn test_kfst_voikko_lentää_correct_states() {
        let fst = load_voikko();
        let input_symbols = fst.split_to_symbols("lentää", false).unwrap();

        // results[i] = sorted state numbers reachable after consuming the first i symbols.
        let results = [
            vec![
                0, 1, 1810, 1946, 1961, 1962, 1963, 1964, 1965, 1966, 2665, 2969, 2970, 3104, 3295,
                3484, 3678, 3870, 4064, 4260, 4454, 4648, 4842, 5036, 5230, 5454, 5645, 5839, 6031,
                6225, 6419, 6613, 6807, 7001, 7195, 7389, 7579, 12479, 13348, 13444, 13541, 13636,
                13733, 13830, 13925, 14028, 14131, 14234, 14331, 14426, 14525, 14622, 14723, 14826,
                14929, 15024, 15127, 15230, 15333, 15433, 15526,
            ],
            vec![
                10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1878, 2840, 17295, 25716, 31090, 40909,
                85222, 204950, 216255, 217894, 254890, 256725, 256726, 256727, 256728, 256729,
                256730, 256731, 256732, 256733, 256734, 256735, 256736, 280866, 281235, 281479,
                281836, 281876, 281877, 288536, 355529, 378467,
            ],
            vec![
                17459, 17898, 17899, 26065, 26066, 26067, 26068, 26069, 31245, 42140, 87151,
                134039, 134040, 205452, 219693, 219694, 259005, 259666, 259667, 259668, 259669,
                259670, 259671, 259672, 280894, 281857, 289402, 356836, 378621, 378750, 378773,
                386786, 388199, 388200, 388201, 388202, 388203,
            ],
            vec![
                17458, 17459, 17899, 19455, 26192, 26214, 26215, 26216, 26217, 42361, 87536,
                118151, 205474, 216303, 220614, 220615, 220616, 220617, 220618, 220619, 220620,
                220621, 220629, 228443, 228444, 228445, 259219, 259220, 259221, 259222, 259223,
                259224, 259225, 356941, 387264,
            ],
            vec![
                42362, 102258, 216304, 216309, 216312, 216317, 217230, 356942, 387265,
            ],
            vec![
                211149, 212998, 212999, 213000, 213001, 213002, 216305, 216310, 216313, 216318,
            ],
            vec![
                12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 210815, 210816,
                211139, 211140, 214985, 216311, 216314, 216315, 216316,
            ],
        ];
        // Guard against a tokenization change silently skipping (or overrunning) expectations.
        assert_eq!(results.len(), input_symbols.len() + 1);

        for i in 0..=input_symbols.len() {
            let subsequence = &input_symbols[..i];
            let mut states: Vec<_> = fst
                .run_fst(
                    subsequence.to_vec(),
                    FSTState::<()>::_new(0),
                    false,
                    None,
                    true,
                )
                .into_iter()
                .map(|(_, _, x)| x.state_num)
                .collect();
            states.sort();
            assert_eq!(states, results[i], "reachable states after {i} input symbols");
        }
    }

    #[test]
    fn test_minimal_r_diacritic() {
        let code = "0\t1\t@P.V_SALLITTU.T@\tasetus\n1\t2\t@R.V_SALLITTU.T@\ttarkistus\n2";
        let fst = FST::from_att_code(code.to_string(), false).unwrap();
        // Drive the internal runner directly too; the printed paths help when the assert fails.
        let mut result = vec![];
        fst._run_fst(&[], &FSTState::<()>::_new(0), false, &mut result, 0, true);
        for x in result {
            println!("{x:?}");
        }
        assert_eq!(
            fst.lookup("", FSTState::<()>::_new(0), false).unwrap(),
            vec![("asetustarkistus".to_string(), 0.0)]
        );
    }

    #[test]
    fn test_kfst_voikko_lentää_result_count() {
        let fst = load_voikko();
        let input_symbols = fst.split_to_symbols("lentää", false).unwrap();

        let results = [61, 42, 37, 35, 9, 10, 25];
        assert_eq!(results.len(), input_symbols.len() + 1);

        for i in 0..=input_symbols.len() {
            let subsequence = &input_symbols[..i];
            assert_eq!(
                fst.run_fst(
                    subsequence.to_vec(),
                    FSTState::<()>::_new(0),
                    false,
                    None,
                    true
                )
                .len(),
                results[i],
                "result count after {i} input symbols"
            );
        }
    }

    #[test]
    fn does_not_crash_on_unknown() {
        let fst = FST::from_att_code("0\t1\ta\tb\n1".to_string(), false).unwrap();
        assert_eq!(
            fst.lookup("c", FSTState::<()>::_new(0), true).unwrap(),
            vec![]
        );
        assert!(fst.lookup("c", FSTState::<()>::_new(0), false).is_err());
    }

    #[test]
    fn test_kfst_voikko_paragraph() {
        let words = [
            "on",
            "maanantaiaamu",
            "heinäkuussa",
            "aurinko",
            "paiskaa",
            "niin",
            "lämpöisesti",
            "heikon",
            "tuulen",
            "avulla",
            "ja",
            "peipposet",
            "kajahuttelevat",
            "ensimmäisiä",
            "kovia",
            "säveleitään",
            "tuoksuavissa",
            "koivuissa",
            "kirkon",
            "itäisellä",
            "seinuksella",
            "on",
            "kivipenkki",
            "juuri",
            "nyt",
            "saapuu",
            "keski-ikäinen",
            "työmies",
            "ja",
            "istuutuu",
            "penkille",
            "hän",
            "näyttää",
            "väsyneeltä",
            "alakuloiselta",
            "haluttomalla",
            "aivan",
            "kuin",
            "olisi",
            "vastikään",
            "tullut",
            "perheellisestä",
            "riidasta",
            "tahi",
            "jättänyt",
            "eilisen",
            "sapatinpäivän",
            "pyhittämättä",
        ];
        let gold: [Vec<(&str, i32)>; 48] = [
            vec![("[Lt][Xp]olla[X]o[Tt][Ap][P3][Ny][Ef]n", 0)],
            vec![("[Ln][Xp]maanantai[X]maanantai[Sn][Ny][Bh][Bc][Ln][Xp]aamu[X]aamu[Sn][Ny]", 0)],
            vec![("[Ln][Xp]heinä[X]hein[Sn][Ny]ä[Bh][Bc][Ln][Xp]kuu[X]kuu[Sine][Ny]ssa", 0)],
            vec![("[Ln][Xp]aurinko[X]aurinko[Sn][Ny]", 0), ("[Lem][Xp]Aurinko[X]aurinko[Sn][Ny]", 0), ("[Lee][Xp]Auri[X]aur[Sg][Ny]in[Fko][Ef]ko", 0)],
            vec![("[Lt][Xp]paiskata[X]paiska[Tt][Ap][P3][Ny][Eb]a", 0)],
            vec![("[Ls][Xp]niin[X]niin", 0)],
            vec![("[Ln][Xp]lämpö[X]lämpö[Ll][Xj]inen[X]ise[Ssti]sti", 0)],
            vec![("[Ll][Xp]heikko[X]heiko[Sg][Ny]n", 0)],
            vec![("[Ln][Xp]tuuli[X]tuul[Sg][Ny]en", 0)],
            vec![("[Ln][Xp]avu[X]avu[Sade][Ny]lla", 0), ("[Ln][Xp]apu[X]avu[Sade][Ny]lla", 0)],
            vec![("[Lc][Xp]ja[X]ja", 0)],
            vec![("[Ln][Xp]peipponen[X]peippo[Sn][Nm]set", 0)],
            vec![],
            vec![("[Lu][Xp]ensimmäinen[X]ensimmäi[Sp][Nm]siä", 0)],
            vec![("[Lnl][Xp]kova[X]kov[Sp][Nm]ia", 0)],
            vec![],
            vec![],
            vec![("[Ln][Xp]koivu[X]koivu[Sine][Nm]issa", 0), ("[Les][Xp]Koivu[X]koivu[Sine][Nm]issa", 0)],
            vec![("[Ln][Ica][Xp]kirkko[X]kirko[Sg][Ny]n", 0)],
            vec![("[Ln][De][Xp]itä[X]itä[Ll][Xj]inen[X]ise[Sade][Ny]llä", 0)],
            vec![("[Ln][Xp]seinus[X]seinukse[Sade][Ny]lla", 0)],
            vec![("[Lt][Xp]olla[X]o[Tt][Ap][P3][Ny][Ef]n", 0)],
            vec![("[Ln][Ica][Xp]kivi[X]kiv[Sn][Ny]i[Bh][Bc][Ln][Xp]penkki[X]penkk[Sn][Ny]i", 0)],
            vec![("[Ln][Xp]juuri[X]juur[Sn][Ny]i", 0), ("[Ls][Xp]juuri[X]juuri", 0), ("[Lt][Xp]juuria[X]juuri[Tk][Ap][P2][Ny][Eb]", 0), ("[Lt][Xp]juuria[X]juur[Tt][Ai][P3][Ny][Ef]i", 0)],
            vec![("[Ls][Xp]nyt[X]nyt", 0)],
            vec![("[Lt][Xp]saapua[X]saapuu[Tt][Ap][P3][Ny][Ef]", 0)],
            vec![("[Lp]keski[De]-[Bh][Bc][Ln][Xp]ikä[X]ikä[Ll][Xj]inen[X]i[Sn][Ny]nen", 0)],
            vec![("[Ln][Xp]työ[X]työ[Sn][Ny][Bh][Bc][Ln][Xp]mies[X]mies[Sn][Ny]", 0)],
            vec![("[Lc][Xp]ja[X]ja", 0)],
            vec![("[Lt][Xp]istuutua[X]istuutuu[Tt][Ap][P3][Ny][Ef]", 0)],
            vec![("[Ln][Xp]penkki[X]penki[Sall][Ny]lle", 0)],
            vec![("[Lr][Xp]hän[X]hä[Sn][Ny]n", 0)],
            vec![("[Lt][Xp]näyttää[X]näyttä[Tn1][Eb]ä", 0), ("[Lt][Xp]näyttää[X]näytt[Tt][Ap][P3][Ny][Ef]ää", 0)],
            vec![("[Lt][Irm][Xp]väsyä[X]väsy[Ll][Ru]n[Xj]yt[X]ee[Sabl][Ny]ltä", 0)],
            vec![("[Ln][De][Xp]ala[X]al[Sn][Ny]a[Bh][Bc][Lnl][Xp]kulo[X]kulo[Ll][Xj]inen[X]ise[Sabl][Ny]lta", 0)],
            vec![("[Ln][Xp]halu[X]halu[Ll][Xj]ton[X]ttoma[Sade][Ny]lla", 0)],
            vec![("[Ls][Xp]aivan[X]aivan", 0)],
            vec![("[Lc][Xp]kuin[X]kuin", 0), ("[Ln][Xp]kuu[X]ku[Sin][Nm]in", 0)],
            vec![("[Lt][Xp]olla[X]ol[Te][Ap][P3][Ny][Eb]isi", 0)],
            vec![("[Ls][Xp]vast=ikään[X]vast[Bm]ikään", 0)],
            vec![("[Lt][Xp]tulla[X]tul[Ll][Ru]l[Xj]ut[X][Sn][Ny]ut", 0), ("[Lt][Xp]tulla[X]tul[Ll][Rt][Xj]tu[X]lu[Sn][Nm]t", 0)],
            vec![("[Ln][Xp]perhe[X]perhee[Ll]lli[Xj]nen[X]se[Sela][Ny]stä", 0)],
            vec![("[Ln][Xp]riita[X]riida[Sela][Ny]sta", 0)],
            vec![("[Lc][Xp]tahi[X]tahi", 0)],
            vec![("[Lt][Xp]jättää[X]jättä[Ll][Ru]n[Xj]yt[X][Sn][Ny]yt", 0)],
            vec![("[Lnl][Xp]eilinen[X]eili[Sg][Ny]sen", 0)],
            vec![("[Ln][Xp]sapatti[X]sapat[Sg][Ny]in[Bh][Bc][Ln][Xp]päivä[X]päiv[Sg][Ny]än", 0)],
            vec![("[Lt][Xp]pyhittää[X]pyhittä[Ln]m[Xj]ä[X][Rm]ä[Sab][Ny]ttä", 0), ("[Lt][Xp]pyhittää[X]pyhittä[Tn3][Ny][Sab]mättä", 0)],
        ];
        let fst = load_voikko();
        for (idx, (word, gold)) in words.into_iter().zip(gold.into_iter()).enumerate() {
            let mut sys = fst.lookup(word, FSTState::<()>::_new(0), false).unwrap();
            sys.sort_by(|a, b| a.partial_cmp(b).unwrap());
            let mut gold_sorted = gold;
            gold_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
            let expected: Vec<(String, f64)> = gold_sorted
                .iter()
                .map(|(s, w)| (s.to_string(), f64::from(*w)))
                .collect();
            assert_eq!(sys, expected, "lookup mismatch for word {idx} ({word})");
        }
    }

    #[test]
    fn test_simple_unknown() {
        let code = "0\t1\t@_UNKNOWN_SYMBOL_@\ty\n1";
        let fst = FST::from_att_code(code.to_string(), false).unwrap();

        assert_eq!(
            fst.run_fst(
                vec![Symbol::String(StringSymbol::new("x".to_string(), false,))],
                FSTState::<()>::_new(0),
                false,
                None,
                true
            ),
            vec![]
        );

        assert_eq!(
            fst.run_fst(
                vec![Symbol::String(StringSymbol::new("x".to_string(), true,))],
                FSTState::_new(0),
                false,
                None,
                true
            ),
            vec![(
                true,
                false,
                FSTState {
                    state_num: 1,
                    path_weight: 0.0,
                    input_flags: FlagMap::new(),
                    output_flags: FlagMap::new(),
                    symbol_mappings: FSTLinkedList::Some(
                        (Symbol::String(StringSymbol::new("y".to_string(), false)), 0),
                        Arc::new(FSTLinkedList::None)
                    )
                }
            )]
        );
    }

    #[test]
    fn test_simple_identity() {
        let code = "0\t1\t@_IDENTITY_SYMBOL_@\t@_IDENTITY_SYMBOL_@\n1";
        let fst = FST::from_att_code(code.to_string(), false).unwrap();

        assert_eq!(
            fst.run_fst(
                vec![Symbol::String(StringSymbol::new("x".to_string(), false,))],
                FSTState::<()>::_new(0),
                false,
                None,
                true
            ),
            vec![]
        );

        assert_eq!(
            fst.run_fst(
                vec![Symbol::String(StringSymbol::new("x".to_string(), true,))],
                FSTState::_new(0),
                false,
                None,
                true
            ),
            vec![(
                true,
                false,
                FSTState {
                    state_num: 1,
                    path_weight: 0.0,
                    input_flags: FlagMap::new(),
                    output_flags: FlagMap::new(),
                    symbol_mappings: FSTLinkedList::Some(
                        (Symbol::String(StringSymbol::new("x".to_string(), true)), 0),
                        Arc::new(FSTLinkedList::None)
                    )
                }
            )]
        );
    }

    #[test]
    fn test_raw_symbols() {
        let mut rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>> = IndexMap::new();
        let sym_a = Symbol::Raw(RawSymbol {
            value: [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        });
        let sym_b = Symbol::Raw(RawSymbol {
            value: [0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        });
        let sym_c = Symbol::Raw(RawSymbol {
            value: [0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        });
        let special_epsilon = Symbol::Raw(RawSymbol {
            value: [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        });
        let sym_d = Symbol::Raw(RawSymbol {
            value: [0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        });
        let sym_d_unk = Symbol::Raw(RawSymbol {
            value: [2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        });
        rules.insert(0, indexmap!(sym_a.clone() => vec![(1, sym_a.clone(), 0.0)]));
        rules.insert(1, indexmap!(sym_b.clone() => vec![(0, sym_b.clone(), 0.0)], Symbol::Special(SpecialSymbol::IDENTITY) => vec![(2, Symbol::Special(SpecialSymbol::IDENTITY), 0.0)]));
        rules.insert(
            2,
            indexmap!(special_epsilon.clone() => vec![(3, sym_c.clone(), 0.0)]),
        );
        let symbols = vec![sym_a.clone(), sym_b.clone(), sym_c.clone(), special_epsilon];
        let fst = FST {
            final_states: indexmap! {3 => 0.0},
            rules,
            symbols,
            debug: false,
        };

        let result = fst.run_fst(
            vec![
                sym_a.clone(),
                sym_b.clone(),
                sym_a.clone(),
                sym_d_unk.clone(),
            ],
            FSTState::<()>::_new(0),
            false,
            None,
            true,
        );
        let filtered: Vec<_> = result.into_iter().filter(|x| x.0).collect();
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].2.state_num, 3);
        assert_eq!(
            filtered[0]
                .2
                .symbol_mappings
                .clone()
                .to_vec()
                .iter()
                .map(|x| x.0.clone())
                .collect::<Vec<_>>(),
            vec![
                sym_a.clone(),
                sym_b.clone(),
                sym_a.clone(),
                sym_d_unk.clone(),
                sym_c.clone()
            ]
        );

        assert_eq!(
            fst.run_fst(
                vec![sym_a.clone(), sym_b.clone(), sym_a.clone(), sym_d.clone()],
                FSTState::<()>::_new(0),
                false,
                None,
                true
            )
            .into_iter()
            .filter(|x| x.0)
            .count(),
            0,
        );
    }

    #[test]
    fn test_string_comparison_order_for_tokenizable_symbol_types() {
        assert!(
            StringSymbol::new("aa".to_string(), false) < StringSymbol::new("a".to_string(), false)
        );
        assert!(
            StringSymbol::new("aa".to_string(), true) > StringSymbol::new("aa".to_string(), false)
        );
        assert!(
            StringSymbol::new("ab".to_string(), false) > StringSymbol::new("aa".to_string(), false)
        );

        assert!(
            FlagDiacriticSymbol::parse("@U.aa@").unwrap()
                < FlagDiacriticSymbol::parse("@U.a@").unwrap()
        );
        assert!(
            FlagDiacriticSymbol::parse("@U.ab@").unwrap()
                > FlagDiacriticSymbol::parse("@U.aa@").unwrap()
        );
        assert!(SpecialSymbol::IDENTITY < SpecialSymbol::EPSILON);
    }

    #[test]
    fn fst_linked_list_conversion_correctness_internal_methods() {
        let orig = vec![
            Symbol::Special(SpecialSymbol::EPSILON),
            Symbol::Special(SpecialSymbol::IDENTITY),
            Symbol::Special(SpecialSymbol::UNKNOWN),
        ];
        assert_eq!(
            FSTLinkedList::from_vec(orig.clone(), vec![10, 20, 30]).to_vec(),
            vec![
                (Symbol::Special(SpecialSymbol::EPSILON), 10),
                (Symbol::Special(SpecialSymbol::IDENTITY), 20),
                (Symbol::Special(SpecialSymbol::UNKNOWN), 30),
            ]
        );

        // With no indices supplied, every symbol is paired with index 0.
        assert_eq!(
            FSTLinkedList::from_vec(orig, vec![]).to_vec(),
            vec![
                (Symbol::Special(SpecialSymbol::EPSILON), 0),
                (Symbol::Special(SpecialSymbol::IDENTITY), 0),
                (Symbol::Special(SpecialSymbol::UNKNOWN), 0),
            ]
        );
    }

    #[test]
    fn fst_linked_list_conversion_correctness_iterators_with_indices() {
        let orig = vec![
            (Symbol::Special(SpecialSymbol::EPSILON), 10),
            (Symbol::Special(SpecialSymbol::IDENTITY), 20),
            (Symbol::Special(SpecialSymbol::UNKNOWN), 30),
        ];
        let ll: FSTLinkedList<usize> = orig.clone().into_iter().collect();
        assert_eq!(ll.into_iter().collect::<Vec<_>>(), orig);
    }
}
4029
/// Python module initializer for `kfst_rs`.
///
/// Creates the `kfst_rs.symbols` and `kfst_rs.transducer` submodules, registers each under its
/// dotted name in `sys.modules`, and re-exports `FST` and `TokenizationException` at the top
/// level.
#[cfg(feature = "python")]
#[pymodule]
fn kfst_rs(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    // --- kfst_rs.symbols: the symbol classes plus the symbol-string parser. ---
    // The local must stay named `symbols`: `py_run!` binds it under that name in the Python code.
    let symbols = PyModule::new(py, "symbols")?;
    symbols.add_class::<StringSymbol>()?;
    symbols.add_class::<FlagDiacriticType>()?;
    symbols.add_class::<FlagDiacriticSymbol>()?;
    symbols.add_class::<SpecialSymbol>()?;
    symbols.add_class::<RawSymbol>()?;
    symbols.add_function(wrap_pyfunction!(from_symbol_string, m)?)?;
    // Putting the submodule in `sys.modules` makes `import kfst_rs.symbols` resolve to it.
    py_run!(
        py,
        symbols,
        "import sys; sys.modules['kfst_rs.symbols'] = symbols"
    );
    m.add_submodule(&symbols)?;

    // --- kfst_rs.transducer: the FST, its state type and the tokenization error. ---
    // As above, the name `transducer` is referenced inside the `py_run!` code string.
    let transducer = PyModule::new(py, "transducer")?;
    transducer.add_class::<FST>()?;
    transducer.add_class::<FSTState>()?;
    transducer.add(
        "TokenizationException",
        py.get_type::<TokenizationException>(),
    )?;
    py_run!(
        py,
        transducer,
        "import sys; sys.modules['kfst_rs.transducer'] = transducer"
    );
    m.add_submodule(&transducer)?;

    // --- Top-level re-exports. ---
    m.add(
        "TokenizationException",
        py.get_type::<TokenizationException>(),
    )?;
    m.add_class::<FST>()?;

    Ok(())
}