1use std::cmp::Ordering;
71use std::collections::HashSet;
72use std::fmt::Debug;
73#[cfg(feature = "python")]
74use std::fmt::Error;
75use std::fs::{self, File};
76use std::hash::Hash;
77use std::io::Read;
78use std::path::Path;
79
80use indexmap::{indexmap, IndexMap, IndexSet};
81use nom::branch::alt;
82use nom::bytes::complete::{tag, take_until1};
83use nom::multi::many_m_n;
84use nom::Parser;
85use std::sync::{Arc, LazyLock, RwLock};
86use xz2::read::{XzDecoder, XzEncoder};
87
88#[cfg(feature = "python")]
89use pyo3::create_exception;
90#[cfg(feature = "python")]
91use pyo3::exceptions::{PyIOError, PyValueError};
92#[cfg(feature = "python")]
93use pyo3::types::{PyDict, PyNone, PyTuple};
94#[cfg(feature = "python")]
95use pyo3::{prelude::*, py_run, IntoPyObjectExt};
96
/// Result type used throughout this module.
///
/// With the `python` feature the error is a `PyErr`, so it can be raised directly as a
/// Python exception.
#[cfg(feature = "python")]
type KFSTResult<T> = PyResult<T>;
/// Result type used throughout this module when the `python` feature is disabled.
///
/// Errors are plain `String` messages.
#[cfg(not(feature = "python"))]
type KFSTResult<T> = std::result::Result<T, String>;
105
106#[cfg(feature = "python")]
107fn value_error<T>(msg: String) -> KFSTResult<T> {
108 KFSTResult::Err(PyErr::new::<PyValueError, _>(msg))
109}
110#[cfg(not(feature = "python"))]
111fn value_error<T>(msg: String) -> KFSTResult<T> {
112 KFSTResult::Err(msg)
113}
114
115#[cfg(feature = "python")]
116fn io_error<T>(msg: String) -> KFSTResult<T> {
117 use pyo3::exceptions::PyIOError;
118
119 KFSTResult::Err(PyErr::new::<PyIOError, _>(msg))
120}
121#[cfg(not(feature = "python"))]
122fn io_error<T>(msg: String) -> KFSTResult<T> {
123 KFSTResult::Err(msg)
124}
125
126#[cfg(feature = "python")]
127fn tokenization_exception<T>(msg: String) -> KFSTResult<T> {
128 KFSTResult::Err(PyErr::new::<TokenizationException, _>(msg))
129}
130#[cfg(not(feature = "python"))]
131fn tokenization_exception<T>(msg: String) -> KFSTResult<T> {
132 KFSTResult::Err(msg)
133}
134
// Python exception type `TokenizationException` in module `kfst_rs`. It derives from Python's
// `Exception` and is raised through `tokenization_exception`.
#[cfg(feature = "python")]
create_exception!(
    kfst_rs,
    TokenizationException,
    pyo3::exceptions::PyException
);
141
/// Process-wide string interner.
///
/// `intern` stores a string in this `IndexSet` and returns its position as a `u32` handle.
/// `deintern` and `with_deinterned` map a handle back to its text.
///
/// The lock is a `std::sync::RwLock`, which is not reentrant. Code should not acquire it again
/// while already holding a guard on the same thread.
static STRING_INTERNER: LazyLock<RwLock<IndexSet<String>>> =
    LazyLock::new(|| RwLock::new(IndexSet::new()));
146
147fn intern(s: String) -> u32 {
148 u32::try_from(STRING_INTERNER.write().unwrap().insert_full(s).0).unwrap()
149}
150
151fn deintern(idx: u32) -> String {
154 with_deinterned(idx, |x| x.to_string())
155}
156
157fn with_deinterned<F, X>(idx: u32, f: F) -> X
160where
161 F: FnOnce(&str) -> X,
162{
163 f(STRING_INTERNER
164 .read()
165 .unwrap()
166 .get_index(idx.try_into().unwrap())
167 .unwrap())
168}
169
/// A symbol carried as 15 raw bytes rather than as text.
///
/// Bit 0 of the first byte marks an epsilon symbol and bit 1 marks an unknown symbol (see
/// `is_epsilon` / `is_unknown`). The meaning of the remaining bytes is not visible in this module
/// section.
#[cfg_attr(
    feature = "python",
    pyclass(str = "RawSymbol({value:?})", eq, ord, frozen, hash, get_all)
)]
#[derive(Clone, Copy, Hash, PartialEq, PartialOrd, Ord, Eq, Debug)]
#[readonly::make]
pub struct RawSymbol {
    /// Raw byte payload. `readonly::make` makes this field read-only to outside code.
    pub value: [u8; 15],
}
189
190#[cfg_attr(feature = "python", pymethods)]
191impl RawSymbol {
192 pub fn is_epsilon(&self) -> bool {
195 (self.value[0] & 1) != 0
196 }
197
198 pub fn is_unknown(&self) -> bool {
201 (self.value[0] & 2) != 0
202 }
203
204 pub fn get_symbol(&self) -> String {
206 format!("RawSymbol({:?})", self.value)
207 }
208
209 #[cfg(feature = "python")]
210 #[new]
211 fn new(value: [u8; 15]) -> Self {
212 RawSymbol { value }
213 }
214
215 #[cfg(not(feature = "python"))]
217 pub fn new(value: [u8; 15]) -> Self {
218 RawSymbol { value }
219 }
220
221 #[deprecated]
222 pub fn __repr__(&self) -> String {
224 format!("RawSymbol({:?})", self.value)
225 }
226}
227
/// A symbol implemented by an arbitrary Python object.
///
/// All symbol behaviour (equality, hashing, ordering, `is_epsilon`, ...) is delegated to methods
/// on the wrapped object. Each delegation acquires the GIL.
#[cfg(feature = "python")]
struct PyObjectSymbol {
    value: PyObject,
}
232
#[cfg(feature = "python")]
impl Debug for PyObjectSymbol {
    /// Formats the wrapped object using its Python `__repr__`.
    ///
    /// Panics if `__repr__` is missing, raises, or does not return a string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Python::with_gil(|py| {
            let s: String = self
                .value
                .getattr(py, "__repr__")
                .unwrap()
                .call0(py)
                .unwrap()
                .extract(py)
                .unwrap();
            f.write_str(&s)
        })
    }
}
251
#[cfg(feature = "python")]
impl PartialEq for PyObjectSymbol {
    /// Calls the wrapped object's Python `__eq__` with `other`'s object.
    ///
    /// # Panics
    ///
    /// Panics if `__eq__` is missing, if the call raises, or if the result does not extract as
    /// a `bool`.
    fn eq(&self, other: &Self) -> bool {
        Python::with_gil(|py| {
            self.value
                .getattr(py, "__eq__")
                .unwrap_or_else(|_| {
                    panic!(
                        "Symbol {} doesn't have an __eq__ implementation.",
                        self.value
                    )
                })
                .call1(py, (other.value.clone_ref(py),))
                .unwrap_or_else(|_| {
                    panic!("__eq__ on symbol {} failed to return a value.", self.value)
                })
                .extract(py)
                .unwrap_or_else(|_| panic!("__eq__ on symbol {} didn't return a bool.", self.value))
        })
    }
}
273
#[cfg(feature = "python")]
impl Hash for PyObjectSymbol {
    /// Feeds the result of the wrapped object's Python `__hash__` into `state`.
    ///
    /// The result is extracted as an `i128` so that it is consistent with `PartialEq` delegating
    /// to `__eq__`. Panics if `__hash__` is missing, raises, or does not return an int.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        state.write_i128(Python::with_gil(|py| {
            self.value
                .getattr(py, "__hash__")
                .unwrap_or_else(|_| {
                    panic!(
                        "Symbol {} doesn't have a __hash__ implementation.",
                        self.value
                    )
                })
                .call0(py)
                .unwrap_or_else(|_| {
                    panic!(
                        "__hash__ on symbol {} failed to return a value.",
                        self.value
                    )
                })
                .extract(py)
                .unwrap_or_else(|_| {
                    panic!("__hash__ on symbol {} didn't return an int.", self.value)
                })
        }))
    }
}
300
// NOTE(review): this assumes the wrapped object's Python `__eq__` is a true equivalence
// relation. Nothing here enforces that.
#[cfg(feature = "python")]
impl Eq for PyObjectSymbol {}
303
#[cfg(feature = "python")]
impl PartialOrd for PyObjectSymbol {
    /// Delegates to the total order in `Ord::cmp`, so this never returns `None`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
310
#[cfg(feature = "python")]
impl Ord for PyObjectSymbol {
    /// Orders via the wrapped objects' Python comparison methods.
    ///
    /// Returns `Greater` if `self.__gt__(other)` is true. Otherwise returns `Equal` if
    /// `self.__eq__(other)` is true, and `Less` in every other case. Panics if either method is
    /// missing, raises, or returns a non-`bool`.
    fn cmp(&self, other: &Self) -> Ordering {
        Python::with_gil(|py| {
            match self
                .value
                .getattr(py, "__gt__")
                .unwrap_or_else(|_| {
                    panic!(
                        "Symbol {} doesn't have a __gt__ implementation.",
                        self.value
                    )
                })
                .call1(py, (other.value.clone_ref(py),))
                .unwrap_or_else(|_| {
                    panic!("__gt__ on symbol {} failed to return a value.", self.value)
                })
                .extract::<bool>(py)
                .unwrap_or_else(|_| panic!("__gt__ on symbol {} didn't return a bool.", self.value))
            {
                true => Ordering::Greater,
                false => {
                    match self
                        .value
                        .getattr(py, "__eq__")
                        .unwrap_or_else(|_| {
                            panic!(
                                "Symbol {} doesn't have an __eq__ implementation.",
                                self.value
                            )
                        })
                        .call1(py, (other.value.clone_ref(py),))
                        .unwrap_or_else(|_| {
                            panic!("__eq__ on symbol {} failed to return a value.", self.value)
                        })
                        .extract::<bool>(py)
                        .unwrap_or_else(|_| {
                            panic!("__eq__ on symbol {} didn't return a bool.", self.value)
                        }) {
                        true => Ordering::Equal,
                        false => Ordering::Less,
                    }
                }
            }
        })
    }
}
358
#[cfg(feature = "python")]
impl Clone for PyObjectSymbol {
    /// Produces a new reference to the same Python object (requires the GIL).
    fn clone(&self) -> Self {
        let value = Python::with_gil(|py| self.value.clone_ref(py));
        PyObjectSymbol { value }
    }
}
367
#[cfg(feature = "python")]
impl<'py> FromPyObject<'_, 'py> for PyObjectSymbol {
    type Error = PyErr;

    /// Wraps any Python object. This never fails; the object is stored as-is and only inspected
    /// later, when symbol methods are called on it.
    fn extract(ob: Borrowed<'_, 'py, PyAny>) -> PyResult<Self> {
        Ok(PyObjectSymbol {
            value: <pyo3::Bound<'_, pyo3::PyAny> as Clone>::clone(&ob).unbind(),
        })
    }
}
378
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for PyObjectSymbol {
    type Target = PyAny;

    type Output = Bound<'py, Self::Target>;

    type Error = pyo3::PyErr;

    /// Hands back the wrapped Python object itself, bound to `py`.
    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let bound = self.value.into_bound(py);
        Ok(bound)
    }
}
391
#[cfg(feature = "python")]
impl PyObjectSymbol {
    /// Calls the wrapped object's `is_epsilon()`.
    ///
    /// Panics if the method is missing, raises, or does not return a `bool`.
    fn is_epsilon(&self) -> bool {
        Python::with_gil(|py| {
            self.value
                .getattr(py, "is_epsilon")
                .unwrap_or_else(|_| {
                    panic!(
                        "Symbol {} doesn't have an is_epsilon implementation.",
                        self.value
                    )
                })
                .call(py, (), None)
                .unwrap_or_else(|_| {
                    panic!(
                        "is_epsilon on symbol {} failed to return a value.",
                        self.value
                    )
                })
                .extract(py)
                .unwrap_or_else(|_| {
                    panic!("is_epsilon on symbol {} didn't return a bool.", self.value)
                })
        })
    }

    /// Calls the wrapped object's `is_unknown()`.
    ///
    /// Panics if the method is missing, raises, or does not return a `bool`.
    fn is_unknown(&self) -> bool {
        Python::with_gil(|py| {
            self.value
                .getattr(py, "is_unknown")
                .unwrap_or_else(|_| {
                    panic!(
                        "Symbol {} doesn't have an is_unknown implementation.",
                        self.value
                    )
                })
                .call(py, (), None)
                .unwrap_or_else(|_| {
                    panic!(
                        "is_unknown on symbol {} failed to return a value.",
                        self.value
                    )
                })
                .extract(py)
                .unwrap_or_else(|_| {
                    panic!("is_unknown on symbol {} didn't return a bool.", self.value)
                })
        })
    }

    /// Calls the wrapped object's `get_symbol()`.
    ///
    /// Panics if the method is missing, raises, or does not return a string.
    fn get_symbol(&self) -> String {
        Python::with_gil(|py| {
            self.value
                .getattr(py, "get_symbol")
                .unwrap_or_else(|_| {
                    panic!(
                        "Symbol {} doesn't have a get_symbol implementation.",
                        self.value
                    )
                })
                .call(py, (), None)
                .unwrap_or_else(|_| {
                    panic!(
                        "get_symbol on symbol {} failed to return a value.",
                        self.value
                    )
                })
                .extract(py)
                // The extraction target is `String`, so the failure is a non-string result
                // (the old message wrongly said "bool").
                .unwrap_or_else(|_| {
                    panic!("get_symbol on symbol {} didn't return a string.", self.value)
                })
        })
    }
}
466
467#[cfg_attr(
468 feature = "python",
469 pyclass(
470 str = "StringSymbol({string:?}, {unknown})",
471 eq,
472 ord,
473 frozen,
474 hash,
475 get_all
476 )
477)]
478#[derive(Clone, Copy, Hash, PartialEq, Eq)]
479#[readonly::make]
480pub struct StringSymbol {
483 string: u32,
484 pub unknown: bool,
486}
487
488impl StringSymbol {
489 pub fn parse(symbol: &str) -> nom::IResult<&str, StringSymbol> {
505 if symbol.is_empty() {
506 return nom::IResult::Err(nom::Err::Error(nom::error::Error::new(
507 symbol,
508 nom::error::ErrorKind::Fail,
509 )));
510 }
511 Ok((
512 "",
513 StringSymbol {
514 string: intern(symbol.to_string()),
515 unknown: false,
516 },
517 ))
518 }
519
520 fn with_symbol<F, X>(&self, f: F) -> X
523 where
524 F: FnOnce(&str) -> X,
525 {
526 with_deinterned(self.string, f)
527 }
528}
529
530impl PartialOrd for StringSymbol {
531 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
532 Some(self.cmp(other))
533 }
534}
535
536impl Ord for StringSymbol {
537 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
538 with_deinterned(other.string, |other_str| {
539 with_deinterned(self.string, |self_str| {
540 (other_str.chars().count(), &self_str, self.unknown).cmp(&(
541 self_str.chars().count(),
542 &other_str,
543 other.unknown,
544 ))
545 })
546 })
547 }
548}
549#[cfg_attr(feature = "python", pymethods)]
550impl StringSymbol {
551 pub fn is_epsilon(&self) -> bool {
554 false
555 }
556
557 pub fn is_unknown(&self) -> bool {
560 self.unknown
561 }
562
563 pub fn get_symbol(&self) -> String {
565 deintern(self.string)
566 }
567
568 #[cfg(feature = "python")]
569 #[new]
570 #[pyo3(signature = (string, unknown = false))]
571 fn new(string: String, unknown: bool) -> Self {
572 StringSymbol {
573 string: intern(string),
574 unknown,
575 }
576 }
577
578 #[cfg(not(feature = "python"))]
579 pub fn new(string: String, unknown: bool) -> Self {
581 StringSymbol {
582 string: intern(string),
583 unknown,
584 }
585 }
586
587 #[deprecated]
588 pub fn __repr__(&self) -> String {
590 format!("StringSymbol({:?}, {})", self.string, self.unknown)
591 }
592}
593
/// Operator letter of a flag diacritic, as written right after the opening `@`
/// (`@U.…@`, `@R.…@`, `@D.…@`, `@C.…@`, `@P.…@`, `@N.…@`).
#[cfg_attr(feature = "python", pyclass(eq, ord, frozen))]
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy, Hash)]
pub enum FlagDiacriticType {
    U,
    R,
    D,
    C,
    P,
    N,
}

impl FlagDiacriticType {
    /// Parses a one-letter operator code.
    ///
    /// Returns `None` unless `s` is exactly one of `"U"`, `"R"`, `"D"`, `"C"`, `"P"` or `"N"`.
    pub fn from_str(s: &str) -> Option<Self> {
        let parsed = match s {
            "U" => Self::U,
            "R" => Self::R,
            "D" => Self::D,
            "C" => Self::C,
            "P" => Self::P,
            "N" => Self::N,
            _ => return None,
        };
        Some(parsed)
    }
}

#[cfg_attr(feature = "python", pymethods)]
impl FlagDiacriticType {
    /// The bare operator letter, e.g. `"U"`.
    #[deprecated]
    pub fn __repr__(&self) -> String {
        format!("{self:?}")
    }
}
637
638#[cfg_attr(
639 feature = "python",
640 pyclass(
641 str = "FlagDiacriticSymbol({flag_type:?}, {key:?}, {value:?})",
642 eq,
643 ord,
644 frozen,
645 hash
646 )
647)]
648#[derive(PartialEq, Eq, Clone, Copy, Hash)]
649#[readonly::make]
650pub struct FlagDiacriticSymbol {
657 pub flag_type: FlagDiacriticType,
659 key: u32,
660 value: u32,
661}
662
663impl PartialOrd for FlagDiacriticSymbol {
664 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
665 Some(self.cmp(other))
666 }
667}
668
669impl Ord for FlagDiacriticSymbol {
670 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
671 let other_str = other.get_symbol();
673 let self_str = self.get_symbol();
674 (other_str.chars().count(), &self_str).cmp(&(self_str.chars().count(), &other_str))
675 }
676}
677
impl FlagDiacriticSymbol {
    /// Parses one flag diacritic from the start of `symbol`.
    ///
    /// Accepts `@T.KEY@` and `@T.KEY.VALUE@`, where `T` is one of `U`, `R`, `D`, `C`, `P`, `N`.
    /// Returns the unconsumed remainder together with the symbol. A flag without a value gets
    /// `value == u32::MAX`.
    ///
    /// NOTE(review): `take_until1(".")` searches past the closing `@`. An input like
    /// `@U.A@b.c@` therefore parses as key `A@b`, value `c`. Confirm whether that is acceptable.
    pub fn parse(symbol: &str) -> nom::IResult<&str, FlagDiacriticSymbol> {
        // Grammar: "@" TYPE "." [NAME "."] REST "@".
        let mut parser = (
            tag("@"),
            alt((tag("U"), tag("R"), tag("D"), tag("C"), tag("P"), tag("N"))),
            tag("."),
            many_m_n(0, 1, (take_until1("."), tag("."))),
            take_until1("@"),
            tag("@"),
        );
        let (input, (_, diacritic_type, _, named_piece_1, named_piece_2, _)) =
            parser.parse(symbol)?;
        // Cannot fail in practice: the `alt` above only accepts the six valid letters.
        let diacritic_type = match FlagDiacriticType::from_str(diacritic_type) {
            Some(x) => x,
            None => {
                return Err(nom::Err::Error(nom::error::Error::new(
                    diacritic_type,
                    nom::error::ErrorKind::Fail,
                )))
            }
        };

        // If the optional "NAME." piece is present, NAME is the key and REST is the value.
        // Otherwise REST is the key and the value is the `u32::MAX` "no value" sentinel.
        let (name, value) = if !named_piece_1.is_empty() {
            (named_piece_1[0].0, intern(named_piece_2.to_string()))
        } else {
            (named_piece_2, u32::MAX)
        };

        Ok((
            input,
            FlagDiacriticSymbol {
                flag_type: diacritic_type,
                key: intern(name.to_string()),
                value,
            },
        ))
    }
}
717
718impl FlagDiacriticSymbol {
722 fn _from_symbol_string(symbol: &str) -> KFSTResult<Self> {
723 match FlagDiacriticSymbol::parse(symbol) {
724 Ok(("", symbol)) => KFSTResult::Ok(symbol),
725 Ok((rest, _)) => value_error(format!("String {symbol:?} contains a valid FlagDiacriticSymbol, but it has unparseable text at the end: {rest:?}")),
726 _ => value_error(format!("Not a valid FlagDiacriticSymbol: {symbol:?}"))
727 }
728 }
729
730 #[cfg(not(feature = "python"))]
731 #[deprecated]
732 pub fn from_symbol_string(symbol: &str) -> KFSTResult<Self> {
734 FlagDiacriticSymbol::_from_symbol_string(symbol)
735 }
736
737 fn _new(flag_type: String, key: String, value: Option<String>) -> KFSTResult<Self> {
738 let flag_type = match FlagDiacriticType::from_str(&flag_type) {
739 Some(x) => x,
740 None => value_error(format!(
741 "String {flag_type:?} is not a valid FlagDiacriticType specifier"
742 ))?,
743 };
744 Ok(FlagDiacriticSymbol {
745 flag_type,
746 key: intern(key),
747 value: value.map(intern).unwrap_or(u32::MAX),
748 })
749 }
750
751 #[cfg(not(feature = "python"))]
752 pub fn new(flag_type: String, key: String, value: Option<String>) -> KFSTResult<Self> {
754 FlagDiacriticSymbol::_new(flag_type, key, value)
755 }
756
757 #[cfg(not(feature = "python"))]
758 pub fn key(self) -> String {
760 deintern(self.key)
761 }
762
763 #[cfg(not(feature = "python"))]
764 pub fn value(self) -> String {
766 deintern(self.value)
767 }
768}
769
770#[cfg_attr(feature = "python", pymethods)]
771impl FlagDiacriticSymbol {
772 pub fn is_epsilon(&self) -> bool {
775 true
776 }
777
778 pub fn is_unknown(&self) -> bool {
781 false
782 }
783
784 pub fn get_symbol(&self) -> String {
785 match self.value {
786 u32::MAX => with_deinterned(self.key, |key| format!("@{:?}.{}@", self.flag_type, key)),
787 value => with_deinterned(self.key, |key| {
788 with_deinterned(value, |value| {
789 format!("@{:?}.{}.{}@", self.flag_type, key, value)
790 })
791 }),
792 }
793 }
794
795 #[cfg(feature = "python")]
796 #[getter]
797 fn flag_type(&self) -> String {
798 format!("{:?}", self.flag_type)
799 }
800
801 #[cfg(not(feature = "python"))]
802 pub fn flag_type(&self) -> String {
804 format!("{:?}", self.flag_type)
805 }
806
807 #[cfg(feature = "python")]
808 #[getter]
809 fn key(&self) -> String {
810 deintern(self.key)
811 }
812
813 #[cfg(feature = "python")]
814 #[getter]
815 fn value(&self) -> String {
816 deintern(self.value)
817 }
818
819 #[cfg(feature = "python")]
820 #[new]
821 fn new(flag_type: String, key: String, value: Option<String>) -> KFSTResult<Self> {
822 FlagDiacriticSymbol::_new(flag_type, key, value)
823 }
824
825 #[deprecated]
826 pub fn __repr__(&self) -> String {
828 match self.value {
829 u32::MAX => with_deinterned(self.key, |key| {
830 format!("FlagDiacriticSymbol({:?}, {:?})", self.flag_type, key)
831 }),
832 value => with_deinterned(self.key, |key| {
833 with_deinterned(value, |value| {
834 format!(
835 "FlagDiacriticSymbol({:?}, {:?}, {:?})",
836 self.flag_type, key, value
837 )
838 })
839 }),
840 }
841 }
842
843 #[cfg(feature = "python")]
844 #[staticmethod]
845 fn is_flag_diacritic(symbol: &str) -> bool {
846 matches!(FlagDiacriticSymbol::parse(symbol), Ok(("", _)))
847 }
848
849 #[cfg(not(feature = "python"))]
850 pub fn is_flag_diacritic(symbol: &str) -> bool {
853 matches!(FlagDiacriticSymbol::parse(symbol), Ok(("", _)))
854 }
855
856 #[cfg(feature = "python")]
857 #[staticmethod]
858 fn from_symbol_string(symbol: &str) -> KFSTResult<Self> {
859 FlagDiacriticSymbol::_from_symbol_string(symbol)
860 }
861}
862
863impl std::fmt::Debug for FlagDiacriticSymbol {
864 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
865 write!(f, "{}", self.get_symbol())
866 }
867}
868
/// One of the reserved `@…@` symbols.
#[cfg_attr(feature = "python", pyclass(eq, ord, frozen, hash))]
#[derive(PartialEq, Eq, Clone, Hash, Copy)]
pub enum SpecialSymbol {
    /// Parsed from `@_EPSILON_SYMBOL_@` or the short form `@0@`; rendered as `@_EPSILON_SYMBOL_@`.
    EPSILON,
    /// Parsed from and rendered as `@_IDENTITY_SYMBOL_@`.
    IDENTITY,
    /// Parsed from and rendered as `@_UNKNOWN_SYMBOL_@`.
    UNKNOWN,
}
885
886impl SpecialSymbol {
887 pub fn parse(symbol: &str) -> nom::IResult<&str, SpecialSymbol> {
896 let (rest, value) = alt((
897 tag("@_EPSILON_SYMBOL_@"),
898 tag("@0@"),
899 tag("@_IDENTITY_SYMBOL_@"),
900 tag("@_UNKNOWN_SYMBOL_@"),
901 ))
902 .parse(symbol)?;
903
904 let sym = match value {
905 "@_EPSILON_SYMBOL_@" => SpecialSymbol::EPSILON,
906 "@0@" => SpecialSymbol::EPSILON,
907 "@_IDENTITY_SYMBOL_@" => SpecialSymbol::IDENTITY,
908 "@_UNKNOWN_SYMBOL_@" => SpecialSymbol::UNKNOWN,
909 _ => panic!(),
910 };
911 Ok((rest, sym))
912 }
913
914 fn _from_symbol_string(symbol: &str) -> KFSTResult<Self> {
915 match SpecialSymbol::parse(symbol) {
916 Ok(("", result)) => KFSTResult::Ok(result),
917 _ => value_error(format!("Not a valid SpecialSymbol: {symbol:?}")),
918 }
919 }
920
921 fn with_symbol<F, X>(&self, f: F) -> X
924 where
925 F: FnOnce(&str) -> X,
926 {
927 match self {
928 SpecialSymbol::EPSILON => f("@_EPSILON_SYMBOL_@"),
929 SpecialSymbol::IDENTITY => f("@_IDENTITY_SYMBOL_@"),
930 SpecialSymbol::UNKNOWN => f("@_UNKNOWN_SYMBOL_@"),
931 }
932 }
933
934 #[cfg(not(feature = "python"))]
935 pub fn from_symbol_string(symbol: &str) -> KFSTResult<Self> {
948 SpecialSymbol::_from_symbol_string(symbol)
949 }
950}
951
952impl PartialOrd for SpecialSymbol {
953 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
954 Some(self.cmp(other))
955 }
956}
957
958impl Ord for SpecialSymbol {
959 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
960 self.with_symbol(|self_str| {
962 other.with_symbol(|other_str| {
963 (other_str.chars().count(), &self_str).cmp(&(self_str.chars().count(), &other_str))
964 })
965 })
966 }
967}
968
969#[cfg_attr(feature = "python", pymethods)]
970impl SpecialSymbol {
971 pub fn is_epsilon(&self) -> bool {
975 self == &SpecialSymbol::EPSILON
976 }
977
978 pub fn is_unknown(&self) -> bool {
982 false
983 }
984
985 pub fn get_symbol(&self) -> String {
992 self.with_symbol(|x| x.to_string())
993 }
994
995 #[cfg(feature = "python")]
996 #[staticmethod]
997 fn from_symbol_string(symbol: &str) -> KFSTResult<Self> {
998 SpecialSymbol::_from_symbol_string(symbol)
999 }
1000
1001 #[cfg(feature = "python")]
1002 #[staticmethod]
1003 fn is_special_symbol(symbol: &str) -> bool {
1004 SpecialSymbol::from_symbol_string(symbol).is_ok()
1005 }
1006
1007 #[cfg(not(feature = "python"))]
1008 pub fn is_special_symbol(symbol: &str) -> bool {
1020 SpecialSymbol::from_symbol_string(symbol).is_ok()
1021 }
1022}
1023
1024impl std::fmt::Debug for SpecialSymbol {
1025 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1026 self.with_symbol(|symbol| write!(f, "{symbol}"))
1027 }
1028}
1029
1030impl std::fmt::Debug for StringSymbol {
1031 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1032 self.with_symbol(|symbol| write!(f, "{symbol}"))
1033 }
1034}
1035
/// Parses `symbol` into the matching symbol class (flag, special or string) and returns it as a
/// Python object.
///
/// Raises `ValueError` when `symbol` is not a valid symbol (for example, the empty string).
/// Previously this case panicked, which surfaced in Python as a `PanicException`.
#[cfg(feature = "python")]
#[pyfunction]
fn from_symbol_string(symbol: &str, py: Python) -> PyResult<Py<PyAny>> {
    match Symbol::parse(symbol) {
        Ok((_, parsed)) => parsed.into_py_any(py),
        Err(_) => value_error(format!("Not a valid symbol: {symbol:?}")),
    }
}
1041
1042#[cfg(not(feature = "python"))]
1043pub fn from_symbol_string(symbol: &str) -> Option<Symbol> {
1052 Symbol::parse(symbol).ok().map(|(_, sym)| sym)
1053}
1054
/// Any symbol that can appear on a transition.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Symbol {
    /// A reserved `@…@` symbol (epsilon, identity, unknown).
    Special(SpecialSymbol),
    /// A flag diacritic such as `@U.KEY.VALUE@`.
    Flag(FlagDiacriticSymbol),
    /// An ordinary interned string symbol.
    String(StringSymbol),
    /// A symbol implemented by an arbitrary Python object (Python builds only).
    #[cfg(feature = "python")]
    External(PyObjectSymbol),
    /// A symbol carried as raw bytes.
    Raw(RawSymbol),
}
1071
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for Symbol {
    type Target = PyAny;

    type Output = Bound<'py, Self::Target>;

    type Error = pyo3::PyErr;

    /// Converts the wrapped variant. Built-in kinds become their pyclass instances; external
    /// symbols give back their original Python object.
    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        match self {
            Symbol::Special(inner) => inner.into_bound_py_any(py),
            Symbol::Flag(inner) => inner.into_bound_py_any(py),
            Symbol::String(inner) => inner.into_bound_py_any(py),
            Symbol::External(inner) => inner.into_bound_py_any(py),
            Symbol::Raw(inner) => inner.into_bound_py_any(py),
        }
    }
}
1090
1091impl PartialOrd for Symbol {
1092 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1093 Some(self.cmp(other))
1094 }
1095}
1096
1097impl Ord for Symbol {
1098 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
1108 let self_is_tokenizable = matches!(
1111 self,
1112 Symbol::Special(_) | Symbol::Flag(_) | Symbol::String(_)
1113 );
1114
1115 let other_is_tokenizable = matches!(
1116 other,
1117 Symbol::Special(_) | Symbol::Flag(_) | Symbol::String(_)
1118 );
1119
1120 match (self_is_tokenizable, other_is_tokenizable) {
1121 (true, true) => {
1123 let result = self.with_symbol(|self_sym| {
1127 other.with_symbol(|other_sym| {
1128 (other_sym.chars().count(), &self_sym)
1129 .cmp(&(self_sym.chars().count(), &other_sym))
1130 })
1131 });
1132 if result != Ordering::Equal {
1133 result
1134 } else {
1135 match (self, other) {
1136 (Symbol::Special(_), Symbol::Flag(_)) => Ordering::Greater,
1138 (Symbol::Special(_), Symbol::String(_)) => Ordering::Less,
1139 (Symbol::Flag(_), Symbol::Special(_)) => Ordering::Less,
1140 (Symbol::Flag(_), Symbol::String(_)) => Ordering::Less,
1141 (Symbol::String(_), Symbol::Special(_)) => Ordering::Greater,
1142 (Symbol::String(_), Symbol::Flag(_)) => Ordering::Greater,
1143
1144 (Symbol::Special(a), Symbol::Special(b)) => a.cmp(b),
1146 (Symbol::Flag(a), Symbol::Flag(b)) => a.cmp(b),
1147 (Symbol::String(a), Symbol::String(b)) => a.cmp(b),
1148
1149 _ => unreachable!(),
1150 }
1151 }
1152 }
1153 _ => {
1155 match (self, other) {
1156 #[cfg(feature = "python")]
1157 (Symbol::External(left), right) => {
1158 Python::with_gil(|py| {
1159 if left
1162 .value
1163 .getattr(py, "__lt__")
1164 .unwrap_or_else(|_| {
1165 panic!(
1166 "Symbol {} doesn't have a __lt__ implementation.",
1167 left.value
1168 )
1169 })
1170 .call1(py, (right.clone().into_pyobject(py).unwrap(),))
1171 .unwrap_or_else(|_| {
1172 panic!(
1173 "__lt__ on symbol {} failed to return a value.",
1174 left.value
1175 )
1176 })
1177 .extract::<bool>(py)
1178 .unwrap_or_else(|_| {
1179 panic!("__lt__ on symbol {} didn't return a bool.", left.value)
1180 })
1181 {
1182 return Ordering::Less;
1183 }
1184
1185 if left
1188 .value
1189 .getattr(py, "__eq__")
1190 .unwrap_or_else(|_| {
1191 panic!(
1192 "Symbol {} doesn't have a __eq__ implementation.",
1193 left.value
1194 )
1195 })
1196 .call1(py, (right.clone().into_pyobject(py).unwrap(),))
1197 .unwrap_or_else(|_| {
1198 panic!(
1199 "__eq__ on symbol {} failed to return a value.",
1200 left.value
1201 )
1202 })
1203 .extract::<bool>(py)
1204 .unwrap_or_else(|_| {
1205 panic!("__eq__ on symbol {} didn't return a bool.", left.value)
1206 })
1207 {
1208 return Ordering::Equal;
1209 }
1210
1211 Ordering::Greater
1214 })
1215 }
1216 #[cfg(feature = "python")]
1217 (left, Symbol::External(right)) => {
1218 Python::with_gil(|py| {
1219 if right
1222 .value
1223 .getattr(py, "__lt__")
1224 .unwrap_or_else(|_| {
1225 panic!(
1226 "Symbol {} doesn't have a __lt__ implementation.",
1227 right.value
1228 )
1229 })
1230 .call1(py, (left.clone().into_pyobject(py).unwrap(),))
1231 .unwrap_or_else(|_| {
1232 panic!(
1233 "__lt__ on symbol {} failed to return a value.",
1234 right.value
1235 )
1236 })
1237 .extract::<bool>(py)
1238 .unwrap_or_else(|_| {
1239 panic!("__lt__ on symbol {} didn't return a bool.", right.value)
1240 })
1241 {
1242 return Ordering::Greater;
1243 }
1244
1245 if right
1248 .value
1249 .getattr(py, "__eq__")
1250 .unwrap_or_else(|_| {
1251 panic!(
1252 "Symbol {} doesn't have a __eq__ implementation.",
1253 right.value
1254 )
1255 })
1256 .call1(py, (left.clone().into_pyobject(py).unwrap(),))
1257 .unwrap_or_else(|_| {
1258 panic!(
1259 "__eq__ on symbol {} failed to return a value.",
1260 right.value
1261 )
1262 })
1263 .extract::<bool>(py)
1264 .unwrap_or_else(|_| {
1265 panic!("__eq__ on symbol {} didn't return a bool.", right.value)
1266 })
1267 {
1268 return Ordering::Equal;
1269 }
1270
1271 Ordering::Less
1274 })
1275 }
1276
1277 (Symbol::Raw(a), Symbol::Raw(b)) => a.cmp(b),
1279
1280 (Symbol::Raw(_), _) => Ordering::Less,
1282 (_, Symbol::Raw(_)) => Ordering::Greater,
1283
1284 _ => unreachable!(),
1285 }
1286 }
1287 }
1288 }
1289}
1290
1291impl Symbol {
1292 pub fn is_epsilon(&self) -> bool {
1299 match self {
1300 Symbol::Special(special_symbol) => special_symbol.is_epsilon(),
1301 Symbol::Flag(flag_diacritic_symbol) => flag_diacritic_symbol.is_epsilon(),
1302 Symbol::String(string_symbol) => string_symbol.is_epsilon(),
1303 #[cfg(feature = "python")]
1304 Symbol::External(py_object_symbol) => py_object_symbol.is_epsilon(),
1305 Symbol::Raw(raw_symbol) => raw_symbol.is_epsilon(),
1306 }
1307 }
1308
1309 pub fn is_unknown(&self) -> bool {
1312 match self {
1313 Symbol::Special(special_symbol) => special_symbol.is_unknown(),
1314 Symbol::Flag(flag_diacritic_symbol) => flag_diacritic_symbol.is_unknown(),
1315 Symbol::String(string_symbol) => string_symbol.is_unknown(),
1316 #[cfg(feature = "python")]
1317 Symbol::External(py_object_symbol) => py_object_symbol.is_unknown(),
1318 Symbol::Raw(raw_symbol) => raw_symbol.is_unknown(),
1319 }
1320 }
1321
1322 pub fn get_symbol(&self) -> String {
1324 match self {
1325 Symbol::Special(special_symbol) => special_symbol.get_symbol(),
1326 Symbol::Flag(flag_diacritic_symbol) => flag_diacritic_symbol.get_symbol(),
1327 Symbol::String(string_symbol) => string_symbol.get_symbol(),
1328 #[cfg(feature = "python")]
1329 Symbol::External(py_object_symbol) => py_object_symbol.get_symbol(),
1330 Symbol::Raw(raw_symbol) => raw_symbol.get_symbol(),
1331 }
1332 }
1333
1334 pub fn with_symbol<F, X>(&self, f: F) -> X
1337 where
1338 F: FnOnce(&str) -> X,
1339 {
1340 match self {
1341 Symbol::Special(special_symbol) => special_symbol.with_symbol(f),
1342 Symbol::Flag(flag_diacritic_symbol) => f(&flag_diacritic_symbol.get_symbol()),
1343 Symbol::String(string_symbol) => string_symbol.with_symbol(f),
1344 #[cfg(feature = "python")]
1345 Symbol::External(py_object_symbol) => f(&py_object_symbol.get_symbol()),
1346 Symbol::Raw(raw_symbol) => f(&raw_symbol.get_symbol()),
1347 }
1348 }
1349}
1350
impl Symbol {
    /// Parses `symbol` into the first matching kind.
    ///
    /// Kinds are tried in this order:
    /// 1. a flag diacritic spanning the whole input;
    /// 2. a special symbol spanning the whole input;
    /// 3. otherwise, the whole input as a string symbol.
    ///
    /// As a result, text such as `@U.A@x` becomes a string symbol. Only the empty string fails
    /// the final fallback.
    pub fn parse(symbol: &str) -> nom::IResult<&str, Symbol> {
        let mut parser = alt((
            |x| {
                // `eof` forces the flag to cover the entire input.
                (FlagDiacriticSymbol::parse, nom::combinator::eof)
                    .parse(x)
                    .map(|y| (y.0, Symbol::Flag(y.1 .0)))
            },
            |x| {
                (SpecialSymbol::parse, nom::combinator::eof)
                    .parse(x)
                    .map(|y| (y.0, Symbol::Special(y.1 .0)))
            },
            |x| StringSymbol::parse(x).map(|y| (y.0, Symbol::String(y.1))),
        ));
        parser.parse(symbol)
    }
}
1386
#[cfg(feature = "python")]
impl<'py> FromPyObject<'_, 'py> for Symbol {
    type Error = PyErr;
    /// Converts a Python object by trying the built-in pyclasses in order: Special, Flag, String,
    /// Raw.
    ///
    /// Anything else is wrapped as an `External` symbol. `PyObjectSymbol` accepts any object, so
    /// this conversion does not fail in practice.
    fn extract(ob: Borrowed<'_, 'py, PyAny>) -> PyResult<Self> {
        ob.extract()
            .map(Symbol::Special)
            .or_else(|_| ob.extract().map(Symbol::Flag))
            .or_else(|_| ob.extract().map(Symbol::String))
            .or_else(|_| ob.extract().map(Symbol::Raw))
            .or_else(|_| ob.extract().map(Symbol::External))
    }
}
/// Persistent map from flag-diacritic key to `(bool, value)`.
///
/// It is stored as a `Vec` of `(key handle, bool, value handle)` triples kept sorted by key;
/// handles are [`STRING_INTERNER`] indices. `insert` / `remove` return a new map instead of
/// mutating in place. Lookups binary-search the sorted vector (`partition_point`).
///
/// NOTE(review): what the `bool` means is not visible in this section; confirm at the call sites.
#[derive(Clone, Debug, PartialEq, Hash, PartialOrd, Eq, Ord)]
#[readonly::make]
pub struct FlagMap(Vec<(u32, bool, u32)>);
1408
1409impl FromIterator<(String, (bool, String))> for FlagMap {
1412 fn from_iter<T: IntoIterator<Item = (String, (bool, String))>>(iter: T) -> Self {
1413 let mut vals: Vec<_> = iter
1414 .into_iter()
1415 .map(|(a, (b, c))| (intern(a), b, intern(c)))
1416 .collect();
1417 vals.sort();
1418 FlagMap(vals)
1419 }
1420}
1421
1422impl<T> From<T> for FlagMap
1425where
1426 T: IntoIterator<Item = (String, (bool, String))>,
1427{
1428 fn from(value: T) -> Self {
1429 FlagMap::from_iter(value)
1430 }
1431}
1432
1433impl FlagMap {
1434 pub fn new() -> FlagMap {
1436 FlagMap(vec![])
1437 }
1438
1439 pub fn remove(&self, flag: u32) -> FlagMap {
1441 let pp = self.0.partition_point(|v| v.0 < flag);
1442 if pp < self.0.len() && self.0[pp].0 == flag {
1443 let mut new_vals = self.0.clone();
1444 new_vals.remove(pp);
1445 FlagMap(new_vals)
1446 } else {
1447 self.clone()
1448 }
1449 }
1450
1451 pub fn get(&self, flag: u32) -> Option<(bool, u32)> {
1453 let pp = self.0.partition_point(|v| v.0 < flag);
1454 if pp < self.0.len() && self.0[pp].0 == flag {
1455 Some((self.0[pp].1, self.0[pp].2))
1456 } else {
1457 None
1458 }
1459 }
1460
1461 pub fn insert(&self, flag: u32, value: (bool, u32)) -> FlagMap {
1464 let pp = self.0.partition_point(|v| v.0 < flag);
1465 let mut new_vals = self.0.clone();
1466 if pp == self.0.len() || self.0[pp].0 != flag {
1467 new_vals.insert(pp, (flag, value.0, value.1));
1468 } else {
1469 new_vals[pp] = (flag, value.0, value.1);
1470 }
1471 FlagMap(new_vals)
1472 }
1473}
1474
1475impl Default for FlagMap {
1476 fn default() -> Self {
1478 Self::new()
1479 }
1480}
1481
#[cfg(feature = "python")]
impl<'py> FromPyObject<'_, 'py> for FlagMap {
    type Error = PyErr;
    /// Converts any Python mapping with `str` keys and `(bool, str)` values.
    ///
    /// The object is read through its `.items()`. Iteration and extraction errors are now
    /// returned as the Python exception. The previous version unwrapped them and panicked instead.
    fn extract(ob: Borrowed<'_, 'py, PyAny>) -> PyResult<Self> {
        let mut as_map = ob
            .getattr("items")?
            .call0()?
            .try_iter()?
            .map(|item| -> PyResult<(u32, bool, u32)> {
                let (key, (flag, value)): (String, (bool, String)) =
                    item?.extract().map_err(Into::<PyErr>::into)?;
                Ok((intern(key), flag, intern(value)))
            })
            .collect::<PyResult<Vec<(u32, bool, u32)>>>()?;
        // Keys coming from a mapping are unique, so sorting restores the key order that the
        // binary-search lookups in `FlagMap` expect.
        as_map.sort();
        Ok(FlagMap(as_map))
    }
}
1497
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for FlagMap {
    type Target = PyAny;

    type Output = Bound<'py, Self::Target>;

    type Error = pyo3::PyErr;

    /// Converts into an `immutables.Map` of `key -> (bool, value)` with deinterned strings.
    ///
    /// Requires the third-party Python package `immutables` to be importable.
    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let collection = self
            .0
            .into_iter()
            .map(|(a, b, c)| (deintern(a), (b, deintern(c))))
            .collect::<Vec<_>>()
            .into_pyobject(py)?;
        let immutables = PyModule::import(py, "immutables")?;
        let map_class = immutables.getattr("Map")?;
        // Allocate via `Map.__new__(Map)` and then initialise explicitly with the list of pairs.
        let new = map_class.call_method1("__new__", (&map_class,))?;
        map_class.call_method1("__init__", (&new, collection))?;
        Ok(new)
    }
}
1520
/// Immutable, structurally shared singly linked list of `(Symbol, T)` pairs.
///
/// Items are stored newest-first: `from_iter` pushes each item onto the front, and `to_vec`
/// reverses the list back into insertion order. Tails are shared through `Arc`, so prepending
/// never copies the existing list.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
pub enum FSTLinkedList<T> {
    /// A head item plus the (shared) rest of the list.
    Some((Symbol, T), Arc<FSTLinkedList<T>>),
    /// The empty list.
    None,
}
1532
1533impl<T: Clone + Default> IntoIterator for FSTLinkedList<T> {
1534 type Item = (Symbol, T);
1535
1536 type IntoIter = <Vec<<FSTLinkedList<T> as IntoIterator>::Item> as IntoIterator>::IntoIter;
1537
1538 fn into_iter(self) -> Self::IntoIter {
1539 self.to_vec().into_iter()
1540 }
1541}
1542
1543impl<T> FromIterator<(Symbol, T)> for FSTLinkedList<T> {
1544 fn from_iter<I: IntoIterator<Item = (Symbol, T)>>(iter: I) -> Self {
1545 let mut symbol_mappings = FSTLinkedList::None;
1546 for (symbol, input_index) in iter {
1547 symbol_mappings = FSTLinkedList::Some((symbol, input_index), Arc::new(symbol_mappings));
1548 }
1549 symbol_mappings
1550 }
1551}
1552
1553impl<T: Clone + Default> FSTLinkedList<T> {
1554 fn to_vec(self) -> Vec<(Symbol, T)> {
1555 let mut symbol_mappings = vec![];
1556 let mut symbol_mapping_state = &self;
1557 while let FSTLinkedList::Some(value, list) = symbol_mapping_state {
1558 symbol_mapping_state = list;
1559 symbol_mappings.push(value.clone());
1560 }
1561 symbol_mappings.into_iter().rev().collect()
1562 }
1563
1564 fn from_vec(output_symbols: Vec<Symbol>, input_indices: Vec<T>) -> FSTLinkedList<T> {
1565 if input_indices.is_empty() {
1566 output_symbols
1567 .into_iter()
1568 .zip(std::iter::repeat(T::default()))
1569 .collect()
1570 } else if input_indices.len() == output_symbols.len() {
1571 output_symbols.into_iter().zip(input_indices).collect()
1572 } else {
1573 panic!("Mismatch in input index and output symbol len: {} output symbols and {} input indices.", output_symbols.len(), input_indices.len());
1574 }
1575 }
1576}
1577
/// Without the Python bindings, `FSTState` is simply the generic internal state.
#[cfg(not(feature = "python"))]
pub type FSTState<T> = InternalFSTState<T>;
1580
/// Python-facing FST state that may or may not track input indices.
#[cfg(feature = "python")]
#[cfg_attr(feature = "python", pyclass(frozen, eq, hash))]
#[derive(Clone, Debug, PartialEq, PartialOrd, Hash)]
pub struct FSTState {
    /// `Ok` holds a state that tracks input indices; `Err` holds one that does
    /// not (created when `input_indices=None` is passed to the constructor).
    /// `Result` is used here as a two-way variant, not to signal an error.
    payload: Result<InternalFSTState<usize>, InternalFSTState<()>>,
}
1587
#[cfg(feature = "python")]
#[pymethods]
impl FSTState {
    /// Python constructor.
    ///
    /// Passing `input_indices=None` creates a state that does not track input
    /// indices; any list (including the default empty one) creates one that does.
    #[new]
    #[pyo3(signature = (state_num, path_weight=0.0, input_flags = FlagMap::new(), output_flags = FlagMap::new(), output_symbols=vec![], input_indices=Some(vec![])))]
    fn new(
        state_num: u64,
        path_weight: f64,
        input_flags: FlagMap,
        output_flags: FlagMap,
        output_symbols: Vec<Symbol>,
        input_indices: Option<Vec<usize>>,
    ) -> Self {
        let payload: Result<InternalFSTState<usize>, InternalFSTState<()>> =
            if let Some(indices) = input_indices {
                Ok(InternalFSTState::new(
                    state_num,
                    path_weight,
                    input_flags,
                    output_flags,
                    output_symbols,
                    indices,
                ))
            } else {
                Err(InternalFSTState::new(
                    state_num,
                    path_weight,
                    input_flags,
                    output_flags,
                    output_symbols,
                    vec![],
                ))
            };
        FSTState { payload }
    }

    /// Returns a copy of this state without input-index tracking.
    fn strip_indices(&self) -> FSTState {
        if let Ok(indexed) = &self.payload {
            FSTState {
                payload: Err(indexed.clone().convert_indices()),
            }
        } else {
            self.clone()
        }
    }

    /// Returns a copy of this state with input-index tracking (indices of an
    /// untracked state are converted from `()`).
    fn ensure_indices(&self) -> FSTState {
        if let Err(unindexed) = &self.payload {
            FSTState {
                payload: Ok(unindexed.clone().convert_indices()),
            }
        } else {
            self.clone()
        }
    }

    /// Output symbols emitted so far, as a tuple.
    #[getter]
    fn output_symbols<'a>(&'a self, py: Python<'a>) -> Result<Bound<'a, PyTuple>, PyErr> {
        let symbols = match &self.payload {
            Ok(indexed) => indexed.output_symbols(),
            Err(unindexed) => unindexed.output_symbols(),
        };
        PyTuple::new(py, symbols)
    }

    /// Input indices as a tuple, or `None` when indices are not tracked.
    #[getter]
    fn input_indices<'a>(&'a self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
        if let Ok(indexed) = &self.payload {
            let indices = indexed.input_indices().unwrap();
            PyTuple::new(py, indices)?.into_bound_py_any(py)
        } else {
            PyNone::get(py).into_bound_py_any(py)
        }
    }

    #[deprecated]
    pub fn __repr__(&self) -> String {
        match &self.payload {
            Ok(indexed) => indexed.__repr__(),
            Err(unindexed) => unindexed.__repr__(),
        }
    }

    #[getter]
    pub fn state_num(&self) -> u64 {
        self.payload
            .as_ref()
            .map_or_else(|unindexed| unindexed.state_num, |indexed| indexed.state_num)
    }

    #[getter]
    pub fn path_weight(&self) -> f64 {
        self.payload
            .as_ref()
            .map_or_else(|unindexed| unindexed.path_weight, |indexed| indexed.path_weight)
    }

    #[getter]
    pub fn input_flags(&self) -> FlagMap {
        self.payload.as_ref().map_or_else(
            |unindexed| unindexed.input_flags.clone(),
            |indexed| indexed.input_flags.clone(),
        )
    }

    #[getter]
    pub fn output_flags(&self) -> FlagMap {
        self.payload.as_ref().map_or_else(
            |unindexed| unindexed.output_flags.clone(),
            |indexed| indexed.output_flags.clone(),
        )
    }
}
1703
/// Extracts a Python `FSTState` and converts its indices to `T`, whichever
/// index representation the Python object carried.
#[cfg(feature = "python")]
impl<'py, T: UsizeOrUnit + Default> FromPyObject<'_, 'py> for InternalFSTState<T> {
    type Error = PyErr;
    fn extract(ob: Borrowed<'_, 'py, PyAny>) -> PyResult<Self> {
        let wrapper: FSTState = ob.extract()?;
        let converted = match wrapper.payload {
            Ok(indexed) => indexed.convert_indices(),
            Err(unindexed) => unindexed.convert_indices(),
        };
        Ok(converted)
    }
}
1715
/// Wraps an index-tracking state in the Python `FSTState` class.
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for InternalFSTState<usize> {
    type Target = FSTState;

    type Output = Bound<'py, Self::Target>;

    type Error = pyo3::PyErr;

    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let wrapper = FSTState { payload: Ok(self) };
        wrapper.into_pyobject(py)
    }
}
1728
/// Wraps a state without input indices in the Python `FSTState` class.
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for InternalFSTState<()> {
    type Target = FSTState;

    type Output = Bound<'py, Self::Target>;

    type Error = pyo3::PyErr;

    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let wrapper = FSTState { payload: Err(self) };
        wrapper.into_pyobject(py)
    }
}
1741
/// Search state used while running an FST: the current state number, the
/// accumulated weight, the flag values, and the output produced so far.
///
/// `T` is the per-output input-index type: `usize` when indices are tracked,
/// `()` when they are not.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
#[readonly::make]
pub struct InternalFSTState<T> {
    /// Number of the FST state this search state is at.
    pub state_num: u64,
    /// Sum of the weights of the transitions taken so far; a final state's
    /// weight is added when a path finishes there.
    pub path_weight: f64,
    /// Flag values updated by input-side symbols (via `_update_flags`).
    pub input_flags: FlagMap,
    /// Flag values updated by output-side symbols (via `_update_flags`).
    pub output_flags: FlagMap,
    /// Emitted output symbols, most recent first, each paired with the index of
    /// the input symbol being consumed when it was emitted.
    pub symbol_mappings: FSTLinkedList<T>,
}
1764
1765impl<T: Clone + Default + Debug + UsizeOrUnit> InternalFSTState<T> {
1766 fn convert_indices<T2: UsizeOrUnit + Default>(self) -> InternalFSTState<T2> {
1767 let new_mappings = self
1768 .symbol_mappings
1769 .into_iter()
1770 .map(|(sym, idx)| (sym, T2::convert(idx)))
1771 .collect();
1772 InternalFSTState {
1773 state_num: self.state_num,
1774 path_weight: self.path_weight,
1775 input_flags: self.input_flags,
1776 output_flags: self.output_flags,
1777 symbol_mappings: new_mappings,
1778 }
1779 }
1780}
1781
1782impl<T> Default for InternalFSTState<T> {
1783 fn default() -> Self {
1785 Self {
1786 state_num: 0,
1787 path_weight: 0.0,
1788 input_flags: FlagMap::new(),
1789 output_flags: FlagMap::new(),
1790 symbol_mappings: FSTLinkedList::None,
1791 }
1792 }
1793}
1794
1795impl<T: Hash> Hash for InternalFSTState<T> {
1796 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
1797 self.state_num.hash(state);
1798 self.path_weight.to_be_bytes().hash(state);
1799 self.input_flags.hash(state);
1800 self.output_flags.hash(state);
1801 self.symbol_mappings.hash(state);
1802 }
1803}
1804
/// Checks a queried flag value against a stored `(positive, value)` entry.
///
/// A positive entry matches only the stored value; a negative entry matches
/// every value except the stored one.
fn _test_flag(stored_val: &(bool, u32), queried_val: u32) -> bool {
    let &(positive, stored) = stored_val;
    let same_value = stored == queried_val;
    if positive {
        same_value
    } else {
        !same_value
    }
}
1808
1809impl<T: Clone + Default> InternalFSTState<T> {
1810 fn _new(state: u64) -> Self {
1811 InternalFSTState {
1812 state_num: state,
1813 path_weight: 0.0,
1814 input_flags: FlagMap::new(),
1815 output_flags: FlagMap::new(),
1816 symbol_mappings: FSTLinkedList::None,
1817 }
1818 }
1819
1820 fn __new<F>(
1821 state: u64,
1822 path_weight: f64,
1823 input_flags: F,
1824 output_flags: F,
1825 output_symbols: Vec<Symbol>,
1826 input_indices: Vec<T>,
1827 ) -> Self
1828 where
1829 F: Into<FlagMap>,
1830 {
1831 InternalFSTState {
1832 state_num: state,
1833 path_weight,
1834 input_flags: input_flags.into(),
1835 output_flags: output_flags.into(),
1836 symbol_mappings: FSTLinkedList::from_vec(output_symbols, input_indices),
1837 }
1838 }
1839
1840 #[cfg(not(feature = "python"))]
1841 pub fn new<F>(
1845 state_num: u64,
1846 path_weight: f64,
1847 input_flags: F,
1848 output_flags: F,
1849 output_symbols: Vec<Symbol>,
1850 input_indices: Vec<T>,
1851 ) -> Self
1852 where
1853 F: Into<FlagMap>,
1854 {
1855 InternalFSTState::__new(
1856 state_num,
1857 path_weight,
1858 input_flags,
1859 output_flags,
1860 output_symbols,
1861 input_indices,
1862 )
1863 }
1864}
1865
1866impl<T: Clone + Debug + Default + UsizeOrUnit> InternalFSTState<T> {
1867 #[cfg(feature = "python")]
1868 fn new(
1869 state_num: u64,
1870 path_weight: f64,
1871 input_flags: FlagMap,
1872 output_flags: FlagMap,
1873 output_symbols: Vec<Symbol>,
1874 input_indices: Vec<T>,
1875 ) -> Self {
1876 InternalFSTState {
1877 state_num,
1878 path_weight,
1879 input_flags,
1880 output_flags,
1881 symbol_mappings: FSTLinkedList::from_vec(output_symbols, input_indices),
1882 }
1883 }
1884
1885 pub fn output_symbols(&self) -> Vec<Symbol> {
1899 self.symbol_mappings
1900 .clone()
1901 .into_iter()
1902 .map(|x| x.0)
1903 .collect()
1904 }
1905
1906 pub fn input_indices(&self) -> Option<Vec<usize>> {
1922 T::branch(
1923 || {
1924 Some(
1925 self.symbol_mappings
1926 .clone()
1927 .into_iter()
1928 .map(|x| x.1.as_usize())
1929 .collect(),
1930 )
1931 },
1932 || None,
1933 )
1934 }
1935
1936 #[deprecated]
1937 pub fn __repr__(&self) -> String {
1939 format!(
1940 "FSTState({}, {}, {:?}, {:?}, {:?}, {:?})",
1941 self.state_num,
1942 self.path_weight,
1943 self.input_flags,
1944 self.output_flags,
1945 self.symbol_mappings
1946 .clone()
1947 .into_iter()
1948 .map(|x| x.0)
1949 .collect::<Vec<_>>(),
1950 self.symbol_mappings
1951 .clone()
1952 .into_iter()
1953 .map(|x| x.1)
1954 .collect::<Vec<_>>()
1955 )
1956 }
1957}
1958
/// Reverses the AT&T escaping of whitespace: `@_TAB_@` becomes a tab, then
/// `@_SPACE_@` becomes a space.
fn unescape_att_symbol(att_symbol: &str) -> String {
    let tabs_restored = att_symbol.replace("@_TAB_@", "\t");
    tabs_restored.replace("@_SPACE_@", " ")
}
1966
/// Escapes whitespace for the tab-separated AT&T format: tabs become
/// `@_TAB_@` and spaces become `@_SPACE_@`.
fn escape_att_symbol(symbol: &str) -> String {
    let tabs_escaped = symbol.replace('\t', "@_TAB_@");
    tabs_escaped.replace(' ', "@_SPACE_@")
}
1972
/// Abstraction over the two input-index types used by the FST states in this
/// file: `usize` (indices tracked) and `()` (indices not tracked).
pub trait UsizeOrUnit: Sized {
    /// Zero value (`0` for `usize`, `()` for the unit type in this file's impls).
    const ZERO: Self;
    /// Step value (`1` for `usize`, `()` for the unit type in this file's impls).
    const INCREMENT: Self;
    /// Type-level branch: the `usize` impl calls `f1`, the `()` impl calls
    /// `f2`; only one closure is invoked.
    fn branch<F1, F2, B>(f1: F1, f2: F2) -> B
    where
        F1: FnOnce() -> B,
        F2: FnOnce() -> B;
    /// The value as a `usize` (the `()` impl reports `0`).
    fn as_usize(&self) -> usize;
    /// Converts from any other `UsizeOrUnit` type.
    fn convert<T: UsizeOrUnit>(x: T) -> Self;
}
1983
1984impl UsizeOrUnit for usize {
1985 const ZERO: Self = 0;
1986
1987 const INCREMENT: Self = 1;
1988
1989 fn as_usize(&self) -> usize {
1990 *self
1991 }
1992
1993 fn convert<T: UsizeOrUnit>(x: T) -> Self {
1994 x.as_usize()
1995 }
1996
1997 fn branch<F1, F2, B>(f1: F1, _f2: F2) -> B
1998 where
1999 F1: FnOnce() -> B,
2000 F2: FnOnce() -> B,
2001 {
2002 f1()
2003 }
2004}
2005
2006impl UsizeOrUnit for () {
2007 fn as_usize(&self) -> usize {
2008 0
2009 }
2010
2011 const ZERO: Self = ();
2012
2013 const INCREMENT: Self = ();
2014
2015 fn convert<T: UsizeOrUnit>(_x: T) -> Self {}
2016
2017 fn branch<F1, F2, B>(_f1: F1, f2: F2) -> B
2018 where
2019 F1: FnOnce() -> B,
2020 F2: FnOnce() -> B,
2021 {
2022 f2()
2023 }
2024}
2025
/// A weighted finite-state transducer.
#[cfg_attr(feature = "python", pyclass(frozen, get_all))]
#[readonly::make]
pub struct FST {
    /// Final states mapped to their final weight, which is added to a path's
    /// weight when the input ends in that state.
    pub final_states: IndexMap<u64, f64>,
    /// Transition table: source state -> input symbol -> list of
    /// `(target state, output symbol, weight)`. `_from_rules` orders each
    /// state's epsilon-input entries first.
    pub rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>>,
    /// All symbols of the transducer, sorted by `_from_rules`; tokenization
    /// (`_split_to_symbols`) binary-searches this list.
    pub symbols: Vec<Symbol>,
    /// Legacy debug flag; deprecated.
    #[deprecated]
    pub debug: bool,
}
2077
2078impl FST {
2079 fn _run_fst<T: Clone + UsizeOrUnit>(
2080 &self,
2081 input_symbols: &[Symbol],
2082 state: &InternalFSTState<T>,
2083 post_input_advance: bool,
2084 result: &mut Vec<(bool, bool, InternalFSTState<T>)>,
2085 input_symbol_index: usize,
2086 keep_non_final: bool,
2087 ) {
2088 let transitions = self.rules.get(&state.state_num);
2089 let isymbol = if input_symbols.len() - input_symbol_index == 0 {
2090 match self.final_states.get(&state.state_num) {
2091 Some(&weight) => {
2092 result.push((
2094 true,
2095 post_input_advance,
2096 InternalFSTState {
2097 state_num: state.state_num,
2098 path_weight: state.path_weight + weight,
2099 input_flags: state.input_flags.clone(),
2100 output_flags: state.output_flags.clone(),
2101 symbol_mappings: state.symbol_mappings.clone(),
2102 },
2103 ));
2104 }
2105 None => {
2106 if keep_non_final {
2108 result.push((false, post_input_advance, state.clone()));
2109 }
2110 }
2111 }
2112 None
2113 } else {
2114 Some(&input_symbols[input_symbol_index])
2115 };
2116 if let Some(transitions) = transitions {
2117 for transition_isymbol in transitions.keys() {
2118 if transition_isymbol.is_epsilon() || isymbol == Some(transition_isymbol) {
2119 self._transition(
2120 input_symbols,
2121 input_symbol_index,
2122 state,
2123 &transitions[transition_isymbol],
2124 isymbol,
2125 transition_isymbol,
2126 result,
2127 keep_non_final,
2128 );
2129 }
2130 }
2131 if let Some(isymbol) = isymbol {
2132 if isymbol.is_unknown() {
2133 if let Some(transition_list) =
2134 transitions.get(&Symbol::Special(SpecialSymbol::UNKNOWN))
2135 {
2136 self._transition(
2137 input_symbols,
2138 input_symbol_index,
2139 state,
2140 transition_list,
2141 Some(isymbol),
2142 &Symbol::Special(SpecialSymbol::UNKNOWN),
2143 result,
2144 keep_non_final,
2145 );
2146 }
2147
2148 if let Some(transition_list) =
2149 transitions.get(&Symbol::Special(SpecialSymbol::IDENTITY))
2150 {
2151 self._transition(
2152 input_symbols,
2153 input_symbol_index,
2154 state,
2155 transition_list,
2156 Some(isymbol),
2157 &Symbol::Special(SpecialSymbol::IDENTITY),
2158 result,
2159 keep_non_final,
2160 );
2161 }
2162 }
2163 }
2164 }
2165 }
2166
2167 fn _transition<T: Clone + UsizeOrUnit>(
2168 &self,
2169 input_symbols: &[Symbol],
2170 input_symbol_index: usize,
2171 state: &InternalFSTState<T>,
2172 transitions: &[(u64, Symbol, f64)],
2173 isymbol: Option<&Symbol>,
2174 transition_isymbol: &Symbol,
2175 result: &mut Vec<(bool, bool, InternalFSTState<T>)>,
2176 keep_non_final: bool,
2177 ) {
2178 for (next_state, osymbol, weight) in transitions.iter() {
2179 let new_output_flags = _update_flags(osymbol, &state.output_flags);
2180 let new_input_flags = _update_flags(transition_isymbol, &state.input_flags);
2181
2182 match (new_output_flags, new_input_flags) {
2183 (Some(new_output_flags), Some(new_input_flags)) => {
2184 let new_osymbol = match (isymbol, osymbol) {
2185 (Some(isymbol), Symbol::Special(SpecialSymbol::IDENTITY)) => isymbol,
2186 _ => osymbol,
2187 };
2188
2189 let new_symbol_mapping: FSTLinkedList<T> = if new_osymbol.is_epsilon() {
2193 state.symbol_mappings.clone() } else {
2195 FSTLinkedList::Some(
2196 (new_osymbol.clone(), T::convert(input_symbol_index)),
2197 Arc::new(state.symbol_mappings.clone()),
2198 )
2199 };
2200 let new_state = InternalFSTState {
2201 state_num: *next_state,
2202 path_weight: state.path_weight + *weight,
2203 input_flags: new_input_flags,
2204 output_flags: new_output_flags,
2205 symbol_mappings: new_symbol_mapping,
2206 };
2207 if transition_isymbol.is_epsilon() {
2208 self._run_fst(
2209 input_symbols,
2210 &new_state,
2211 input_symbols.is_empty(),
2212 result,
2213 input_symbol_index,
2214 keep_non_final,
2215 );
2216 } else {
2217 self._run_fst(
2218 input_symbols,
2219 &new_state,
2220 false,
2221 result,
2222 input_symbol_index + 1,
2223 keep_non_final,
2224 );
2225 }
2226 }
2227 _ => continue,
2228 }
2229 }
2230 }
2231
2232 pub fn from_att_rows(
2240 rows: Vec<Result<(u64, f64), (u64, u64, Symbol, Symbol, f64)>>,
2241 debug: bool,
2242 ) -> FST {
2243 let mut final_states: IndexMap<u64, f64> = IndexMap::new();
2244 let mut rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>> = IndexMap::new();
2245 let mut symbols: IndexSet<Symbol> = IndexSet::new();
2246 for line in rows.into_iter() {
2247 match line {
2248 Ok((state_number, state_weight)) => {
2249 final_states.insert(state_number, state_weight);
2250 }
2251 Err((state_1, state_2, top_symbol, bottom_symbol, weight)) => {
2252 rules.entry(state_1).or_default();
2253 let handle = rules.get_mut(&state_1).unwrap();
2254 if !handle.contains_key(&top_symbol) {
2255 handle.insert(top_symbol.clone(), vec![]);
2256 }
2257 handle.get_mut(&top_symbol).unwrap().push((
2258 state_2,
2259 bottom_symbol.clone(),
2260 weight,
2261 ));
2262 symbols.insert(top_symbol);
2263 symbols.insert(bottom_symbol);
2264 }
2265 }
2266 }
2267 FST::from_rules(
2268 final_states,
2269 rules,
2270 symbols.into_iter().collect(),
2271 Some(debug),
2272 )
2273 }
2274
2275 fn _from_kfst_bytes(kfst_bytes: &[u8]) -> Result<FST, String> {
2276 let mut header = nom::sequence::preceded(
2282 nom::bytes::complete::tag("KFST"),
2283 nom::number::complete::be_u16::<&[u8], ()>,
2284 );
2285 let (rest, version) = header
2286 .parse(kfst_bytes)
2287 .map_err(|_| "Failed to parse header")?;
2288 assert!(version == 0);
2289
2290 let mut metadata = (
2293 nom::number::complete::be_u16::<&[u8], ()>,
2294 nom::number::complete::be_u32,
2295 nom::number::complete::be_u32,
2296 nom::number::complete::u8,
2297 );
2298 let (rest, (num_symbols, num_transitions, num_final_states, is_weighted)) = metadata
2299 .parse(rest)
2300 .map_err(|_| "Failed to parse metadata")?;
2301 let num_transitions: usize = num_transitions
2302 .try_into()
2303 .map_err(|_| "usize too small to represent transitions")?;
2304 let num_final_states: usize = num_final_states
2305 .try_into()
2306 .map_err(|_| "usize too small to represent final states")?;
2307 let is_weighted: bool = is_weighted != 0u8;
2309
2310 let mut symbol = nom::multi::count(
2313 nom::sequence::terminated(nom::bytes::complete::take_until1("\0"), tag("\0")),
2314 num_symbols.into(),
2315 );
2316 let (rest, symbols) = symbol
2317 .parse(rest)
2318 .map_err(|_: nom::Err<()>| "Failed to parse symbol list")?;
2319 let symbol_strings: Vec<&str> = symbols
2320 .into_iter()
2321 .map(|x| std::str::from_utf8(x))
2322 .collect::<Result<Vec<&str>, _>>()
2323 .map_err(|x| format!("Some symbol was not valid utf-8: {x}"))?;
2324 let symbol_list: Vec<Symbol> = symbol_strings
2325 .iter()
2326 .map(|x| {
2327 Symbol::parse(x)
2328 .map_err(|x| {
2329 format!(
2330 "Some symbol while valid utf8 was not a valid symbol specifier: {x}"
2331 )
2332 })
2333 .and_then(|(extra, sym)| {
2334 if extra.is_empty() {
2335 Ok(sym)
2336 } else {
2337 Err(format!(
2338 "Extra data after end of symbol {}: {extra:?}",
2339 sym.get_symbol(),
2340 ))
2341 }
2342 })
2343 })
2344 .collect::<Result<Vec<Symbol>, _>>()?;
2345 let symbol_objs: IndexSet<Symbol> = symbol_list.iter().cloned().collect();
2346
2347 let mut decomp: Vec<u8> = Vec::new();
2350 let mut decoder = XzDecoder::new(rest);
2351 decoder
2352 .read_to_end(&mut decomp)
2353 .map_err(|_| "Failed to xz-decompress remainder of file")?;
2354
2355 let transition_syntax = (
2359 nom::number::complete::be_u32::<&[u8], ()>,
2360 nom::number::complete::be_u32,
2361 nom::number::complete::be_u16,
2362 nom::number::complete::be_u16,
2363 );
2364 let weight_parser = if is_weighted {
2365 nom::number::complete::be_f64
2366 } else {
2367 |input| Ok((input, 0.0)) };
2369 let (rest, file_rules) = many_m_n(
2370 num_transitions,
2371 num_transitions,
2372 (transition_syntax, weight_parser),
2373 )
2374 .parse(decomp.as_slice())
2375 .map_err(|_| "Broken transition table")?;
2376
2377 let (rest, final_states) = many_m_n(
2378 num_final_states,
2379 num_final_states,
2380 (nom::number::complete::be_u32, weight_parser),
2381 )
2382 .parse(rest)
2383 .map_err(|_| "Broken final states")?;
2384
2385 if !rest.is_empty() {
2386 Err(format!("lzma-compressed payload is {} bytes long when decompressed but given the header, there seems to be {} bytes extra.", decomp.len(), rest.len()))?;
2387 }
2388
2389 let final_states = final_states
2392 .into_iter()
2393 .map(|(a, b)| (a.into(), b))
2394 .collect();
2395
2396 let symbols = symbol_objs.into_iter().collect();
2399
2400 let mut rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>> = IndexMap::new();
2403 for ((from_state, to_state, top_symbol_idx, bottom_symbol_idx), weight) in
2404 file_rules.into_iter()
2405 {
2406 let from_state = from_state.into();
2407 let to_state = to_state.into();
2408 let top_symbol_idx: usize = top_symbol_idx.into();
2409 let bottom_symbol_idx: usize = bottom_symbol_idx.into();
2410 let top_symbol = symbol_list[top_symbol_idx].clone();
2411 let bottom_symbol = symbol_list[bottom_symbol_idx].clone();
2412 let handle = rules.entry(from_state).or_default();
2413 let vec = handle.entry(top_symbol).or_default();
2414 vec.push((to_state, bottom_symbol, weight));
2415 }
2416
2417 Ok(FST::from_rules(final_states, rules, symbols, None))
2418 }
2419
2420 fn _to_kfst_bytes(&self) -> Result<Vec<u8>, String> {
2421 let mut weighted = false;
2424
2425 for (_, &weight) in self.final_states.iter() {
2426 if weight != 0.0 {
2427 weighted = true;
2428 break;
2429 }
2430 }
2431
2432 let mut transitions: u32 = 0;
2433
2434 for (_, transition_table) in self.rules.iter() {
2435 for transition in transition_table.values() {
2436 for (_, _, weight) in transition.iter() {
2437 if (*weight) != 0.0 {
2438 weighted = true;
2439 }
2440 transitions += 1;
2441 }
2442 }
2443 }
2444
2445 let mut result: Vec<u8> = "KFST".into();
2448 result.extend(0u16.to_be_bytes());
2449 let symbol_len: u16 = self
2450 .symbols
2451 .len()
2452 .try_into()
2453 .map_err(|x| format!("Too many symbols to represent as u16: {x}"))?;
2454 result.extend(symbol_len.to_be_bytes());
2455 result.extend(transitions.to_be_bytes());
2456 let num_states: u32 = self
2457 .final_states
2458 .len()
2459 .try_into()
2460 .map_err(|x| format!("Too many final states to represent as u32: {x}"))?;
2461 result.extend(num_states.to_be_bytes());
2462 result.push(weighted.into()); let mut sorted_syms: Vec<_> = self.symbols.iter().collect();
2467 sorted_syms.sort();
2468 for symbol in sorted_syms.iter() {
2469 result.extend(symbol.get_symbol().into_bytes());
2470 result.push(0); }
2472
2473 let mut to_compress: Vec<u8> = vec![];
2476
2477 for (source_state, transition_table) in self.rules.iter() {
2480 for (top_symbol, transition) in transition_table.iter() {
2481 for (target_state, bottom_symbol, weight) in transition.iter() {
2482 let source_state: u32 = (*source_state).try_into().map_err(|x| {
2483 format!("Can't represent source state {source_state} as u32: {x}")
2484 })?;
2485 let target_state: u32 = (*target_state).try_into().map_err(|x| {
2486 format!("Can't represent target state {target_state} as u32: {x}")
2487 })?;
2488 let top_index: u16 = sorted_syms
2489 .binary_search(&top_symbol)
2490 .map_err(|_| {
2491 format!("Top symbol {top_symbol:?} not found in FST symbol list")
2492 })
2493 .and_then(|x| {
2494 x.try_into().map_err(|x| {
2495 format!("Can't represent top symbol index as u16: {x}")
2496 })
2497 })?;
2498 let bottom_index: u16 = sorted_syms
2499 .binary_search(&bottom_symbol)
2500 .map_err(|_| {
2501 format!("Bottom symbol {bottom_symbol:?} not found in FST symbol list")
2502 })
2503 .and_then(|x| {
2504 x.try_into().map_err(|x| {
2505 format!("Can't represent bottom symbol index as u16: {x}")
2506 })
2507 })?;
2508 to_compress.extend(source_state.to_be_bytes());
2509 to_compress.extend(target_state.to_be_bytes());
2510 to_compress.extend(top_index.to_be_bytes());
2511 to_compress.extend(bottom_index.to_be_bytes());
2512 if weighted {
2513 to_compress.extend(weight.to_be_bytes());
2514 } else {
2515 assert!(*weight == 0.0);
2516 }
2517 }
2518 }
2519 }
2520
2521 for (&final_state, weight) in self.final_states.iter() {
2524 let final_state: u32 = final_state
2525 .try_into()
2526 .map_err(|x| format!("Can't represent final state index as u32: {x}"))?;
2527 to_compress.extend(final_state.to_be_bytes());
2528 if weighted {
2529 to_compress.extend(weight.to_be_bytes());
2530 } else {
2531 assert!(*weight == 0.0);
2532 }
2533 }
2534
2535 let mut compressed = vec![];
2538
2539 let mut encoder = XzEncoder::new(to_compress.as_slice(), 9);
2540 encoder
2541 .read_to_end(&mut compressed)
2542 .map_err(|x| format!("Failed while compressing with lzma_rs: {x}"))?;
2543 result.extend(compressed);
2544
2545 Ok(result)
2546 }
2547
2548 fn _from_rules(
2549 final_states: IndexMap<u64, f64>,
2550 rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>>,
2551 symbols: HashSet<Symbol>,
2552 debug: Option<bool>,
2553 ) -> FST {
2554 let mut new_symbols: Vec<Symbol> = symbols.into_iter().collect();
2555 new_symbols.sort();
2558 let mut new_rules = IndexMap::new();
2560 for (target_node, rulebook) in rules {
2561 let mut new_rulebook = rulebook;
2562 new_rulebook.sort_by(|a, _, b, _| b.is_epsilon().cmp(&a.is_epsilon()));
2563 new_rules.insert(target_node, new_rulebook);
2564 }
2565 FST {
2566 final_states,
2567 rules: new_rules,
2568 symbols: new_symbols,
2569 debug: debug.unwrap_or(false),
2570 }
2571 }
2572
2573 #[cfg(not(feature = "python"))]
2575 pub fn from_rules(
2576 final_states: IndexMap<u64, f64>,
2577 rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>>,
2578 symbols: HashSet<Symbol>,
2579 debug: Option<bool>,
2580 ) -> FST {
2581 FST::_from_rules(final_states, rules, symbols, debug)
2582 }
2583
2584 fn _from_att_file(att_file: String, debug: bool) -> KFSTResult<FST> {
2585 match File::open(Path::new(&att_file)) {
2587 Ok(mut file) => {
2588 let mut att_code = String::new();
2589 file.read_to_string(&mut att_code).map_err(|err| {
2590 io_error::<()>(format!("Failed to read from file {att_file}:\n{err}"))
2591 .unwrap_err()
2592 })?;
2593 FST::from_att_code(att_code, debug)
2594 }
2595 Err(err) => io_error(format!("Failed to open file {att_file}:\n{err}")),
2596 }
2597 }
2598
2599 #[cfg(not(feature = "python"))]
2600 pub fn from_att_file(att_file: String, debug: bool) -> KFSTResult<FST> {
2603 FST::_from_att_file(att_file, debug)
2604 }
2605
2606 fn _from_att_code(att_code: String, debug: bool) -> KFSTResult<FST> {
2607 let mut rows: Vec<Result<(u64, f64), (u64, u64, Symbol, Symbol, f64)>> = vec![];
2608
2609 for (lineno, line) in att_code.lines().enumerate() {
2610 let elements: Vec<&str> = line.split("\t").collect();
2611 if elements.len() == 1 || elements.len() == 2 {
2612 let state = elements[0].parse::<u64>().ok();
2613 let weight = if elements.len() == 1 {
2614 Some(0.0)
2615 } else {
2616 elements[1].parse::<f64>().ok()
2617 };
2618 match (state, weight) {
2619 (Some(state), Some(weight)) => {
2620 rows.push(Ok((state, weight)));
2621 }
2622 _ => {
2623 return value_error(format!(
2624 "Failed to parse att code on line {lineno}:\n{line}",
2625 ))
2626 }
2627 }
2628 } else if elements.len() == 4 || elements.len() == 5 {
2629 let state_1 = elements[0].parse::<u64>().ok();
2630 let state_2 = elements[1].parse::<u64>().ok();
2631 let unescaped_sym_1 = unescape_att_symbol(elements[2]);
2632 let symbol_1 = Symbol::parse(&unescaped_sym_1).ok();
2633 let unescaped_sym_2 = unescape_att_symbol(elements[3]);
2634 let symbol_2 = Symbol::parse(&unescaped_sym_2).ok();
2635 let weight = if elements.len() == 4 {
2636 Some(0.0)
2637 } else {
2638 elements[4].parse::<f64>().ok()
2639 };
2640 match (state_1, state_2, symbol_1, symbol_2, weight) {
2641 (
2642 Some(state_1),
2643 Some(state_2),
2644 Some(("", symbol_1)),
2645 Some(("", symbol_2)),
2646 Some(weight),
2647 ) => {
2648 rows.push(Err((state_1, state_2, symbol_1, symbol_2, weight)));
2649 }
2650 _ => {
2651 return value_error(format!(
2652 "Failed to parse att code on line {lineno}:\n{line}",
2653 ));
2654 }
2655 }
2656 }
2657 }
2658 KFSTResult::Ok(FST::from_att_rows(rows, debug))
2659 }
2660
2661 #[cfg(not(feature = "python"))]
2662 pub fn from_att_code(att_code: String, debug: bool) -> KFSTResult<FST> {
2726 FST::_from_att_code(att_code, debug)
2727 }
2728
2729 fn _from_kfst_file(kfst_file: String, debug: bool) -> KFSTResult<FST> {
2730 match File::open(Path::new(&kfst_file)) {
2731 Ok(mut file) => {
2732 let mut kfst_bytes: Vec<u8> = vec![];
2733 file.read_to_end(&mut kfst_bytes).map_err(|err| {
2734 io_error::<()>(format!("Failed to read from file {kfst_file}:\n{err}"))
2735 .unwrap_err()
2736 })?;
2737 FST::from_kfst_bytes(&kfst_bytes, debug)
2738 }
2739 Err(err) => io_error(format!("Failed to open file {kfst_file}:\n{err}")),
2740 }
2741 }
2742
2743 #[cfg(not(feature = "python"))]
2747 pub fn from_kfst_file(kfst_file: String, debug: bool) -> KFSTResult<FST> {
2748 FST::_from_kfst_file(kfst_file, debug)
2749 }
2750
2751 #[allow(unused)]
2752 fn __from_kfst_bytes(kfst_bytes: &[u8], debug: bool) -> KFSTResult<FST> {
2753 match FST::_from_kfst_bytes(kfst_bytes) {
2754 Ok(x) => Ok(x),
2755 Err(x) => value_error(x),
2756 }
2757 }
2758
2759 #[cfg(not(feature = "python"))]
2763 pub fn from_kfst_bytes(kfst_bytes: &[u8], debug: bool) -> KFSTResult<FST> {
2764 FST::__from_kfst_bytes(kfst_bytes, debug)
2765 }
2766
2767 fn _split_to_symbols(&self, text: &str, allow_unknown: bool) -> Option<Vec<Symbol>> {
2768 let mut result = vec![];
2769 let max_byte_len = self
2770 .symbols
2771 .iter()
2772 .find(|x| matches!(x, Symbol::String(_) | Symbol::Special(_) | Symbol::Flag(_)))
2773 .map(|x| x.with_symbol(|s| s.len()))
2774 .unwrap_or(0);
2775 let mut slice = text;
2776 while !slice.is_empty() {
2777 let mut found = false;
2778 for length in (0..std::cmp::min(max_byte_len, slice.len()) + 1).rev() {
2779 if !slice.is_char_boundary(length) {
2780 continue;
2781 }
2782 let key = &slice[..length];
2783 let pp = self.symbols.partition_point(|x| {
2784 x.with_symbol(|y| (key.chars().count(), y) < (y.chars().count(), key))
2785 });
2786 if let Some(sym) = self.symbols[pp..]
2787 .iter()
2788 .find(|x| matches!(x, Symbol::String(_) | Symbol::Special(_) | Symbol::Flag(_)))
2789 {
2790 if sym.with_symbol(|s| s == key) {
2791 result.push(sym.clone());
2792 slice = &slice[length..];
2793 found = true;
2794 break;
2795 }
2796 }
2797 }
2798 if (!found) && allow_unknown {
2799 let char = slice.chars().next().unwrap();
2800 slice = &slice[char.len_utf8()..];
2801 result.push(Symbol::String(StringSymbol::new(char.to_string(), true)));
2802 found = true;
2803 }
2804 if !found {
2805 return None;
2806 }
2807 }
2808 Some(result)
2809 }
2810
2811 #[cfg(not(feature = "python"))]
2816 pub fn split_to_symbols(&self, text: &str, allow_unknown: bool) -> Option<Vec<Symbol>> {
2817 self._split_to_symbols(text, allow_unknown)
2818 }
2819
2820 fn __run_fst<T: UsizeOrUnit + Clone>(
2821 &self,
2822 input_symbols: Vec<Symbol>,
2823 state: InternalFSTState<T>,
2824 post_input_advance: bool,
2825 input_symbol_index: usize,
2826 keep_non_final: bool,
2827 ) -> Vec<(bool, bool, InternalFSTState<T>)> {
2828 let mut result = vec![];
2829 self._run_fst(
2830 input_symbols.as_slice(),
2831 &state,
2832 post_input_advance,
2833 &mut result,
2834 input_symbol_index,
2835 keep_non_final,
2836 );
2837 result
2838 }
2839
2840 #[cfg(not(feature = "python"))]
2841 pub fn run_fst<T: Clone + UsizeOrUnit>(
2849 &self,
2850 input_symbols: Vec<Symbol>,
2851 state: InternalFSTState<T>,
2852 post_input_advance: bool,
2853 input_symbol_index: Option<usize>,
2854 keep_non_final: bool,
2855 ) -> Vec<(bool, bool, InternalFSTState<T>)> {
2856 self.__run_fst(
2857 input_symbols,
2858 state,
2859 post_input_advance,
2860 input_symbol_index.unwrap_or(0),
2861 keep_non_final,
2862 )
2863 }
2864
2865 fn _lookup<T: UsizeOrUnit + Clone + Default + Debug>(
2866 &self,
2867 input: &str,
2868 state: InternalFSTState<T>,
2869 allow_unknown: bool,
2870 ) -> KFSTResult<Vec<(String, f64)>> {
2871 let input_symbols = self.split_to_symbols(input, allow_unknown);
2872 match input_symbols {
2873 None => tokenization_exception(format!("Input cannot be split into symbols: {input}")),
2874 Some(input_symbols) => {
2875 let mut dedup: IndexSet<String> = IndexSet::new();
2876 let mut result: Vec<(String, f64)> = vec![];
2877 let mut finished_paths: Vec<(bool, bool, InternalFSTState<()>)> = self.__run_fst(
2878 input_symbols.clone(),
2879 state.convert_indices(),
2880 false,
2881 0,
2882 false,
2883 );
2884 finished_paths
2885 .sort_by(|a, b| a.2.path_weight.partial_cmp(&b.2.path_weight).unwrap());
2886 for finished in finished_paths {
2887 let output_string: String = finished
2888 .2
2889 .symbol_mappings
2890 .to_vec()
2891 .iter()
2892 .map(|x| x.0.get_symbol())
2893 .collect::<Vec<String>>()
2894 .join("");
2895 if dedup.contains(&output_string) {
2896 continue;
2897 }
2898 dedup.insert(output_string.clone());
2899 result.push((output_string, finished.2.path_weight));
2900 }
2901 Ok(result)
2902 }
2903 }
2904 }
2905
2906 fn _lookup_aligned<T: UsizeOrUnit + Clone + Default + Debug>(
2907 &self,
2908 input: &str,
2909 state: InternalFSTState<T>,
2910 allow_unknown: bool,
2911 ) -> KFSTResult<Vec<(Vec<(usize, Symbol)>, f64)>> {
2912 let input_symbols = self.split_to_symbols(input, allow_unknown);
2913 match input_symbols {
2914 None => tokenization_exception(format!("Input cannot be split into symbols: {input}")),
2915 Some(input_symbols) => {
2916 let mut dedup: IndexSet<Vec<(usize, Symbol)>> = IndexSet::new();
2917 let mut result: Vec<(Vec<(usize, Symbol)>, f64)> = vec![];
2918 let mut finished_paths: Vec<_> = self
2919 .__run_fst(
2920 input_symbols.clone(),
2921 state.convert_indices(),
2922 false,
2923 0,
2924 false,
2925 )
2926 .into_iter()
2927 .filter(|(finished, _, _)| *finished)
2928 .collect();
2929 finished_paths
2930 .sort_by(|a, b| a.2.path_weight.partial_cmp(&b.2.path_weight).unwrap());
2931 for finished in finished_paths {
2932 let output_vec: Vec<(usize, Symbol)> = finished
2933 .2
2934 .symbol_mappings
2935 .to_vec()
2936 .into_iter()
2937 .map(|(a, b)| (b, a))
2938 .collect();
2939 if dedup.contains(&output_vec) {
2940 continue;
2941 }
2942 dedup.insert(output_vec.clone());
2943 result.push((output_vec, finished.2.path_weight));
2944 }
2945 Ok(result)
2946 }
2947 }
2948 }
2949
2950 #[cfg(not(feature = "python"))]
2951 pub fn lookup<T: UsizeOrUnit + Clone + Default + Debug>(
2959 &self,
2960 input: &str,
2961 state: InternalFSTState<T>,
2962 allow_unknown: bool,
2963 ) -> KFSTResult<Vec<(String, f64)>> {
2964 self._lookup(input, state, allow_unknown)
2965 }
2966
2967 #[cfg(not(feature = "python"))]
2968 pub fn lookup_aligned<T: UsizeOrUnit + Clone + Default + Debug>(
3034 &self,
3035 input: &str,
3036 state: InternalFSTState<T>,
3037 allow_unknown: bool,
3038 ) -> KFSTResult<Vec<(Vec<(usize, Symbol)>, f64)>> {
3039 self._lookup_aligned(input, state, allow_unknown)
3040 }
3041}
3042
3043fn _update_flags(symbol: &Symbol, flags: &FlagMap) -> Option<FlagMap> {
3044 if let Symbol::Flag(flag_diacritic_symbol) = symbol {
3045 match flag_diacritic_symbol.flag_type {
3046 FlagDiacriticType::U => {
3047 let value = flag_diacritic_symbol.value;
3048
3049 if let Some((currently_set, current_value)) = flags.get(flag_diacritic_symbol.key) {
3053 if (currently_set && current_value != value)
3054 || (!currently_set && current_value == value)
3055 {
3056 return None;
3057 }
3058 }
3059
3060 Some(flags.insert(flag_diacritic_symbol.key, (true, value)))
3063 }
3064 FlagDiacriticType::R => {
3065 match flag_diacritic_symbol.value {
3068 u32::MAX => {
3069 if flags.get(flag_diacritic_symbol.key).is_some() {
3070 Some(flags.clone())
3071 } else {
3072 None
3073 }
3074 }
3075 value => {
3076 if flags
3077 .get(flag_diacritic_symbol.key)
3078 .map(|stored| _test_flag(&stored, value))
3079 .unwrap_or(false)
3080 {
3081 Some(flags.clone())
3082 } else {
3083 None
3084 }
3085 }
3086 }
3087 }
3088 FlagDiacriticType::D => {
3089 match (
3090 flag_diacritic_symbol.value,
3091 flags.get(flag_diacritic_symbol.key),
3092 ) {
3093 (u32::MAX, None) => Some(flags.clone()),
3094 (u32::MAX, _) => None,
3095 (_, None) => Some(flags.clone()),
3096 (query, Some(stored)) => {
3097 if _test_flag(&stored, query) {
3098 None
3099 } else {
3100 Some(flags.clone())
3101 }
3102 }
3103 }
3104 }
3105 FlagDiacriticType::C => Some(flags.remove(flag_diacritic_symbol.key)),
3106 FlagDiacriticType::P => {
3107 let value = flag_diacritic_symbol.value;
3108 Some(flags.insert(flag_diacritic_symbol.key, (true, value)))
3109 }
3110 FlagDiacriticType::N => {
3111 let value = flag_diacritic_symbol.value;
3112 Some(flags.insert(flag_diacritic_symbol.key, (false, value)))
3113 }
3114 }
3115 } else {
3116 Some(flags.clone())
3117 }
3118}
3119
3120#[cfg_attr(feature = "python", pymethods)]
3121impl FST {
3122 #[cfg(feature = "python")]
3123 #[staticmethod]
3124 #[pyo3(signature = (final_states, rules, symbols, debug = false))]
3125 fn from_rules(
3126 final_states: IndexMap<u64, f64>,
3127 rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>>,
3128 symbols: HashSet<Symbol>,
3129 debug: Option<bool>,
3130 ) -> FST {
3131 FST::_from_rules(final_states, rules, symbols, debug)
3132 }
3133
3134 #[cfg(feature = "python")]
3135 #[staticmethod]
3136 #[pyo3(signature = (att_file, debug = false))]
3137 fn from_att_file(py: Python<'_>, att_file: PyObject, debug: bool) -> KFSTResult<FST> {
3138 FST::_from_att_file(att_file.call_method0(py, "__str__")?.extract(py)?, debug)
3139 }
3140
3141 #[cfg(feature = "python")]
3142 #[staticmethod]
3143 #[pyo3(signature = (att_code, debug = false))]
3144 fn from_att_code(att_code: String, debug: bool) -> KFSTResult<FST> {
3145 FST::_from_att_code(att_code, debug)
3146 }
3147
3148 #[cfg(feature = "python")]
3149 pub fn to_att_file(&self, py: Python<'_>, att_file: PyObject) -> KFSTResult<()> {
3150 let path: String = att_file.call_method0(py, "__str__")?.extract(py)?;
3151 fs::write(Path::new(&path), self.to_att_code()).map_err(|err| {
3152 io_error::<()>(format!("Failed to write to file {path}:\n{err}")).unwrap_err()
3153 })
3154 }
3155
3156 #[cfg(not(feature = "python"))]
3158 pub fn to_att_file(&self, att_file: String) -> KFSTResult<()> {
3159 fs::write(Path::new(&att_file), self.to_att_code()).map_err(|err| {
3160 io_error::<()>(format!("Failed to write to file {att_file}:\n{err}")).unwrap_err()
3161 })
3162 }
3163
3164 pub fn to_att_code(&self) -> String {
3187 let mut rows: Vec<String> = vec![];
3188 for (state, weight) in self.final_states.iter() {
3189 match weight {
3190 0.0 => {
3191 rows.push(format!("{state}"));
3192 }
3193 _ => {
3194 rows.push(format!("{state}\t{weight}"));
3195 }
3196 }
3197 }
3198 for (from_state, rules) in self.rules.iter() {
3199 for (top_symbol, transitions) in rules.iter() {
3200 for (to_state, bottom_symbol, weight) in transitions.iter() {
3201 match weight {
3202 0.0 => {
3203 rows.push(format!(
3204 "{}\t{}\t{}\t{}",
3205 from_state,
3206 to_state,
3207 escape_att_symbol(&top_symbol.get_symbol()),
3208 escape_att_symbol(&bottom_symbol.get_symbol())
3209 ));
3210 }
3211 _ => {
3212 rows.push(format!(
3213 "{}\t{}\t{}\t{}\t{}",
3214 from_state,
3215 to_state,
3216 escape_att_symbol(&top_symbol.get_symbol()),
3217 escape_att_symbol(&bottom_symbol.get_symbol()),
3218 weight
3219 ));
3220 }
3221 }
3222 }
3223 }
3224 }
3225 rows.join("\n")
3226 }
3227
3228 #[cfg(feature = "python")]
3229 #[staticmethod]
3230 #[pyo3(signature = (kfst_file, debug = false))]
3231 fn from_kfst_file(py: Python<'_>, kfst_file: PyObject, debug: bool) -> KFSTResult<FST> {
3232 FST::_from_kfst_file(kfst_file.call_method0(py, "__str__")?.extract(py)?, debug)
3233 }
3234
3235 #[cfg(feature = "python")]
3236 #[staticmethod]
3237 #[pyo3(signature = (kfst_bytes, debug = false))]
3238 fn from_kfst_bytes(kfst_bytes: &[u8], debug: bool) -> KFSTResult<FST> {
3239 FST::__from_kfst_bytes(kfst_bytes, debug)
3240 }
3241
3242 #[cfg(feature = "python")]
3243 pub fn to_kfst_file(&self, py: Python<'_>, kfst_file: PyObject) -> KFSTResult<()> {
3244 let bytes = self.to_kfst_bytes()?;
3245 let path: String = kfst_file.call_method0(py, "__str__")?.extract(py)?;
3246 fs::write(Path::new(&path), bytes).map_err(|err| {
3247 io_error::<()>(format!("Failed to write to file {path}:\n{err}")).unwrap_err()
3248 })
3249 }
3250
3251 #[cfg(not(feature = "python"))]
3252 pub fn to_kfst_file(&self, kfst_file: String) -> KFSTResult<()> {
3254 let bytes = self.to_kfst_bytes()?;
3255 fs::write(Path::new(&kfst_file), bytes).map_err(|err| {
3256 io_error::<()>(format!("Failed to write to file {kfst_file}:\n{err}")).unwrap_err()
3257 })
3258 }
3259
3260 pub fn to_kfst_bytes(&self) -> KFSTResult<Vec<u8>> {
3262 match self._to_kfst_bytes() {
3263 Ok(x) => Ok(x),
3264 Err(x) => value_error(x),
3265 }
3266 }
3267
3268 #[deprecated]
3269 pub fn __repr__(&self) -> String {
3271 format!(
3272 "FST(final_states: {:?}, rules: {:?}, symbols: {:?}, debug: {:?})",
3273 self.final_states, self.rules, self.symbols, self.debug
3274 )
3275 }
3276
3277 #[cfg(feature = "python")]
3278 #[pyo3(signature = (text, allow_unknown = true))]
3279 fn split_to_symbols(&self, text: &str, allow_unknown: bool) -> Option<Vec<Symbol>> {
3280 self._split_to_symbols(text, allow_unknown)
3281 }
3282
3283 #[cfg(feature = "python")]
3284 #[pyo3(signature = (input_symbols, state = FSTState { payload: Ok(InternalFSTState::_new(0)) }, post_input_advance = false, input_symbol_index = None, keep_non_final = true))]
3285 fn run_fst(
3286 &self,
3287 input_symbols: Vec<Symbol>,
3288 state: FSTState,
3289 post_input_advance: bool,
3290 input_symbol_index: Option<usize>,
3291 keep_non_final: bool,
3292 ) -> Vec<(bool, bool, FSTState)> {
3293 match state.payload {
3294 Ok(p) => self
3295 .__run_fst(
3296 input_symbols,
3297 p,
3298 post_input_advance,
3299 input_symbol_index.unwrap_or(0),
3300 keep_non_final,
3301 )
3302 .into_iter()
3303 .map(|(a, b, c)| (a, b, FSTState { payload: Ok(c) }))
3304 .collect(),
3305 Err(p) => self
3306 .__run_fst(
3307 input_symbols,
3308 p,
3309 post_input_advance,
3310 input_symbol_index.unwrap_or(0),
3311 keep_non_final,
3312 )
3313 .into_iter()
3314 .map(|(a, b, c)| (a, b, FSTState { payload: Err(c) }))
3315 .collect(),
3316 }
3317 }
3318
3319 #[cfg(feature = "python")]
3320 #[pyo3(signature = (input, state=FSTState { payload: Err(InternalFSTState::_new(0)) }, allow_unknown=true))]
3321 fn lookup(
3322 &self,
3323 input: &str,
3324 state: FSTState,
3325 allow_unknown: bool,
3326 ) -> KFSTResult<Vec<(String, f64)>> {
3327 match state.payload {
3328 Ok(p) => self._lookup(input, p, allow_unknown),
3329 Err(p) => self._lookup(input, p, allow_unknown),
3330 }
3331 }
3332
3333 #[cfg(feature = "python")]
3334 #[pyo3(signature = (input, state=FSTState { payload: Ok(InternalFSTState::_new(0)) }, allow_unknown=true))]
3335 pub fn lookup_aligned(
3336 &self,
3337 input: &str,
3338 state: FSTState,
3339 allow_unknown: bool,
3340 ) -> KFSTResult<Vec<(Vec<(usize, Symbol)>, f64)>> {
3341 use pyo3::exceptions::PyValueError;
3342
3343 match state.payload {
3344 Ok(p) => self._lookup_aligned(input, p, allow_unknown),
3345 Err(p) => PyResult::Err(PyErr::new::<PyValueError, _>(
3346 format!("lookup_aligned refuses to work with states with input_indices=None (passed state {}). Manually convert it to an indexed state by calling ensure_indices()", p.__repr__())
3347 )),
3348 }
3349 }
3350
3351 #[cfg(feature = "python")]
3352 pub fn get_input_symbols(&self, state: FSTState) -> HashSet<Symbol> {
3353 match state.payload {
3354 Ok(p) => self._get_input_symbols(p),
3355 Err(p) => self._get_input_symbols(p),
3356 }
3357 }
3358}
3359
3360impl FST {
3361 fn _get_input_symbols<T>(&self, state: InternalFSTState<T>) -> HashSet<Symbol> {
3362 self.rules
3363 .get(&state.state_num)
3364 .map(|x| x.keys().cloned().collect())
3365 .unwrap_or_default()
3366 }
3367
3368 #[cfg(not(feature = "python"))]
3369 pub fn get_input_symbols<T>(&self, state: FSTState<T>) -> HashSet<Symbol> {
3387 self._get_input_symbols(state)
3388 }
3389}
3390
/// Regression tests for the pure-Rust (non-Python) API.
///
/// Several tests load `../pyvoikko/pyvoikko/voikko.kfst`, relative to the working directory
/// (the package root under `cargo test`). They fail if that file is absent.
#[cfg(all(test, not(feature = "python")))]
mod tests {
    use crate::*;

    /// Path of the Voikko transducer used by the larger regression tests.
    const VOIKKO_KFST: &str = "../pyvoikko/pyvoikko/voikko.kfst";

    /// Load the Voikko transducer, panicking with a clear message if it is unavailable.
    fn load_voikko() -> FST {
        FST::_from_kfst_file(VOIKKO_KFST.to_string(), false).expect("failed to load voikko.kfst")
    }

    /// Build the known (non-unknown) string symbols `split_to_symbols` should produce.
    fn known_string_symbols(parts: &[&str]) -> Vec<Symbol> {
        parts
            .iter()
            .map(|part| {
                Symbol::String(StringSymbol {
                    string: intern(part.to_string()),
                    unknown: false,
                })
            })
            .collect()
    }

    #[test]
    fn test_att_trivial() {
        let fst = FST::from_att_code("1\n0\t1\ta\tb".to_string(), false).unwrap();
        assert_eq!(
            fst.lookup("a", FSTState::<()>::default(), false).unwrap(),
            vec![("b".to_string(), 0.0)]
        );
    }

    #[test]
    fn test_att_slightly_less_trivial() {
        let fst = FST::from_att_code("2\n0\t1\ta\tb\n1\t2\tc\td".to_string(), false).unwrap();
        assert_eq!(
            fst.lookup("ac", FSTState::<()>::default(), false).unwrap(),
            vec![("bd".to_string(), 0.0)]
        );
    }

    #[test]
    fn test_kfst_voikko_kissa() {
        let fst = load_voikko();
        assert_eq!(
            fst.lookup("kissa", FSTState::<()>::_new(0), false).unwrap(),
            vec![("[Ln][Xp]kissa[X]kiss[Sn][Ny]a".to_string(), 0.0)]
        );
        assert_eq!(
            fst.lookup("kissojemmekaan", FSTState::<()>::_new(0), false)
                .unwrap(),
            vec![(
                "[Ln][Xp]kissa[X]kiss[Sg][Nm]oje[O1m]mme[Fkaan]kaan".to_string(),
                0.0
            )]
        );
    }

    #[test]
    fn test_that_weight_of_end_state_applies_correctly() {
        let code = "0\t1\ta\tb\n1\t1.0";
        let fst = FST::from_att_code(code.to_string(), false).unwrap();
        assert_eq!(
            fst.lookup("a", FSTState::<()>::_new(0), false).unwrap(),
            vec![("b".to_string(), 1.0)]
        );
    }

    #[test]
    fn test_kfst_voikko_correct_final_states() {
        let fst: FST = load_voikko();
        let map: IndexMap<_, _> = [(19, 0.0)].into_iter().collect();
        assert_eq!(fst.final_states, map);
    }

    #[test]
    fn test_kfst_voikko_split() {
        let fst: FST = load_voikko();
        assert_eq!(
            fst.split_to_symbols("lentokone", false).unwrap(),
            known_string_symbols(&["l", "e", "n", "t", "o", "k", "o", "n", "e"])
        );

        assert_eq!(
            fst.split_to_symbols("lentää", false).unwrap(),
            known_string_symbols(&["l", "e", "n", "t", "ä", "ä"])
        );
    }

    #[test]
    fn test_kfst_voikko() {
        let fst = load_voikko();
        assert_eq!(
            fst.lookup("lentokone", FSTState::<()>::_new(0), false)
                .unwrap(),
            vec![(
                "[Lt][Xp]lentää[X]len[Ln][Xj]to[X]to[Sn][Ny][Bh][Bc][Ln][Xp]kone[X]kone[Sn][Ny]"
                    .to_string(),
                0.0
            )]
        );
    }

    #[test]
    fn test_kfst_voikko_lentää() {
        let fst = load_voikko();
        let mut sys = fst
            .lookup("lentää", FSTState::<()>::_new(0), false)
            .unwrap();
        sys.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(
            sys,
            vec![
                ("[Lt][Xp]lentää[X]len[Tn1][Eb]tää".to_string(), 0.0),
                (
                    "[Lt][Xp]lentää[X]len[Tt][Ap][P3][Ny][Ef]tää".to_string(),
                    0.0
                )
            ]
        );
    }

    /// Reachable state numbers after each prefix of "lentää" (index = prefix length).
    #[test]
    fn test_kfst_voikko_lentää_correct_states() {
        let fst = load_voikko();
        let input_symbols = fst.split_to_symbols("lentää", false).unwrap();

        let results = [
            vec![
                0, 1, 1810, 1946, 1961, 1962, 1963, 1964, 1965, 1966, 2665, 2969, 2970, 3104, 3295,
                3484, 3678, 3870, 4064, 4260, 4454, 4648, 4842, 5036, 5230, 5454, 5645, 5839, 6031,
                6225, 6419, 6613, 6807, 7001, 7195, 7389, 7579, 12479, 13348, 13444, 13541, 13636,
                13733, 13830, 13925, 14028, 14131, 14234, 14331, 14426, 14525, 14622, 14723, 14826,
                14929, 15024, 15127, 15230, 15333, 15433, 15526,
            ],
            vec![
                10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1878, 2840, 17295, 25716, 31090, 40909,
                85222, 204950, 216255, 217894, 254890, 256725, 256726, 256727, 256728, 256729,
                256730, 256731, 256732, 256733, 256734, 256735, 256736, 280866, 281235, 281479,
                281836, 281876, 281877, 288536, 355529, 378467,
            ],
            vec![
                17459, 17898, 17899, 26065, 26066, 26067, 26068, 26069, 31245, 42140, 87151,
                134039, 134040, 205452, 219693, 219694, 259005, 259666, 259667, 259668, 259669,
                259670, 259671, 259672, 280894, 281857, 289402, 356836, 378621, 378750, 378773,
                386786, 388199, 388200, 388201, 388202, 388203,
            ],
            vec![
                17458, 17459, 17899, 19455, 26192, 26214, 26215, 26216, 26217, 42361, 87536,
                118151, 205474, 216303, 220614, 220615, 220616, 220617, 220618, 220619, 220620,
                220621, 220629, 228443, 228444, 228445, 259219, 259220, 259221, 259222, 259223,
                259224, 259225, 356941, 387264,
            ],
            vec![
                42362, 102258, 216304, 216309, 216312, 216317, 217230, 356942, 387265,
            ],
            vec![
                211149, 212998, 212999, 213000, 213001, 213002, 216305, 216310, 216313, 216318,
            ],
            vec![
                12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 210815, 210816,
                211139, 211140, 214985, 216311, 216314, 216315, 216316,
            ],
        ];

        for i in 0..=input_symbols.len() {
            let subsequence = &input_symbols[..i];
            let mut states: Vec<_> = fst
                .run_fst(
                    subsequence.to_vec(),
                    FSTState::<()>::_new(0),
                    false,
                    None,
                    true,
                )
                .into_iter()
                .map(|(_, _, x)| x.state_num)
                .collect();
            states.sort();
            assert_eq!(states, results[i]);
        }
    }

    #[test]
    fn test_minimal_r_diacritic() {
        let code = "0\t1\t@P.V_SALLITTU.T@\tasetus\n1\t2\t@R.V_SALLITTU.T@\ttarkistus\n2";
        let fst = FST::from_att_code(code.to_string(), false).unwrap();
        // Dump the raw traversal; shown by the harness only if the assertion below fails.
        let mut result = vec![];
        fst._run_fst(&[], &FSTState::<()>::_new(0), false, &mut result, 0, true);
        for x in result {
            println!("{x:?}");
        }
        assert_eq!(
            fst.lookup("", FSTState::<()>::_new(0), false).unwrap(),
            vec![("asetustarkistus".to_string(), 0.0)]
        );
    }

    /// Number of run_fst results after each prefix of "lentää" (index = prefix length).
    #[test]
    fn test_kfst_voikko_lentää_result_count() {
        let fst = load_voikko();
        let input_symbols = fst.split_to_symbols("lentää", false).unwrap();

        let results = [61, 42, 37, 35, 9, 10, 25];

        for i in 0..=input_symbols.len() {
            let subsequence = &input_symbols[..i];
            assert_eq!(
                fst.run_fst(
                    subsequence.to_vec(),
                    FSTState::<()>::_new(0),
                    false,
                    None,
                    true
                )
                .len(),
                results[i]
            );
        }
    }

    #[test]
    fn does_not_crash_on_unknown() {
        let fst = FST::from_att_code("0\t1\ta\tb\n1".to_string(), false).unwrap();
        assert_eq!(
            fst.lookup("c", FSTState::<()>::_new(0), true).unwrap(),
            vec![]
        );
        assert!(fst.lookup("c", FSTState::<()>::_new(0), false).is_err());
    }

    #[test]
    fn test_kfst_voikko_paragraph() {
        let words = [
            "on",
            "maanantaiaamu",
            "heinäkuussa",
            "aurinko",
            "paiskaa",
            "niin",
            "lämpöisesti",
            "heikon",
            "tuulen",
            "avulla",
            "ja",
            "peipposet",
            "kajahuttelevat",
            "ensimmäisiä",
            "kovia",
            "säveleitään",
            "tuoksuavissa",
            "koivuissa",
            "kirkon",
            "itäisellä",
            "seinuksella",
            "on",
            "kivipenkki",
            "juuri",
            "nyt",
            "saapuu",
            "keski-ikäinen",
            "työmies",
            "ja",
            "istuutuu",
            "penkille",
            "hän",
            "näyttää",
            "väsyneeltä",
            "alakuloiselta",
            "haluttomalla",
            "aivan",
            "kuin",
            "olisi",
            "vastikään",
            "tullut",
            "perheellisestä",
            "riidasta",
            "tahi",
            "jättänyt",
            "eilisen",
            "sapatinpäivän",
            "pyhittämättä",
        ];
        let gold: [Vec<(&str, i32)>; 48] = [
            vec![("[Lt][Xp]olla[X]o[Tt][Ap][P3][Ny][Ef]n", 0)],
            vec![("[Ln][Xp]maanantai[X]maanantai[Sn][Ny][Bh][Bc][Ln][Xp]aamu[X]aamu[Sn][Ny]", 0)],
            vec![("[Ln][Xp]heinä[X]hein[Sn][Ny]ä[Bh][Bc][Ln][Xp]kuu[X]kuu[Sine][Ny]ssa", 0)],
            vec![("[Ln][Xp]aurinko[X]aurinko[Sn][Ny]", 0), ("[Lem][Xp]Aurinko[X]aurinko[Sn][Ny]", 0), ("[Lee][Xp]Auri[X]aur[Sg][Ny]in[Fko][Ef]ko", 0)],
            vec![("[Lt][Xp]paiskata[X]paiska[Tt][Ap][P3][Ny][Eb]a", 0)],
            vec![("[Ls][Xp]niin[X]niin", 0)],
            vec![("[Ln][Xp]lämpö[X]lämpö[Ll][Xj]inen[X]ise[Ssti]sti", 0)],
            vec![("[Ll][Xp]heikko[X]heiko[Sg][Ny]n", 0)],
            vec![("[Ln][Xp]tuuli[X]tuul[Sg][Ny]en", 0)],
            vec![("[Ln][Xp]avu[X]avu[Sade][Ny]lla", 0), ("[Ln][Xp]apu[X]avu[Sade][Ny]lla", 0)],
            vec![("[Lc][Xp]ja[X]ja", 0)],
            vec![("[Ln][Xp]peipponen[X]peippo[Sn][Nm]set", 0)],
            vec![],
            vec![("[Lu][Xp]ensimmäinen[X]ensimmäi[Sp][Nm]siä", 0)],
            vec![("[Lnl][Xp]kova[X]kov[Sp][Nm]ia", 0)],
            vec![],
            vec![],
            vec![("[Ln][Xp]koivu[X]koivu[Sine][Nm]issa", 0), ("[Les][Xp]Koivu[X]koivu[Sine][Nm]issa", 0)],
            vec![("[Ln][Ica][Xp]kirkko[X]kirko[Sg][Ny]n", 0)],
            vec![("[Ln][De][Xp]itä[X]itä[Ll][Xj]inen[X]ise[Sade][Ny]llä", 0)],
            vec![("[Ln][Xp]seinus[X]seinukse[Sade][Ny]lla", 0)],
            vec![("[Lt][Xp]olla[X]o[Tt][Ap][P3][Ny][Ef]n", 0)],
            vec![("[Ln][Ica][Xp]kivi[X]kiv[Sn][Ny]i[Bh][Bc][Ln][Xp]penkki[X]penkk[Sn][Ny]i", 0)],
            vec![("[Ln][Xp]juuri[X]juur[Sn][Ny]i", 0), ("[Ls][Xp]juuri[X]juuri", 0), ("[Lt][Xp]juuria[X]juuri[Tk][Ap][P2][Ny][Eb]", 0), ("[Lt][Xp]juuria[X]juur[Tt][Ai][P3][Ny][Ef]i", 0)],
            vec![("[Ls][Xp]nyt[X]nyt", 0)],
            vec![("[Lt][Xp]saapua[X]saapuu[Tt][Ap][P3][Ny][Ef]", 0)],
            vec![("[Lp]keski[De]-[Bh][Bc][Ln][Xp]ikä[X]ikä[Ll][Xj]inen[X]i[Sn][Ny]nen", 0)],
            vec![("[Ln][Xp]työ[X]työ[Sn][Ny][Bh][Bc][Ln][Xp]mies[X]mies[Sn][Ny]", 0)],
            vec![("[Lc][Xp]ja[X]ja", 0)],
            vec![("[Lt][Xp]istuutua[X]istuutuu[Tt][Ap][P3][Ny][Ef]", 0)],
            vec![("[Ln][Xp]penkki[X]penki[Sall][Ny]lle", 0)],
            vec![("[Lr][Xp]hän[X]hä[Sn][Ny]n", 0)],
            vec![("[Lt][Xp]näyttää[X]näyttä[Tn1][Eb]ä", 0), ("[Lt][Xp]näyttää[X]näytt[Tt][Ap][P3][Ny][Ef]ää", 0)],
            vec![("[Lt][Irm][Xp]väsyä[X]väsy[Ll][Ru]n[Xj]yt[X]ee[Sabl][Ny]ltä", 0)],
            vec![("[Ln][De][Xp]ala[X]al[Sn][Ny]a[Bh][Bc][Lnl][Xp]kulo[X]kulo[Ll][Xj]inen[X]ise[Sabl][Ny]lta", 0)],
            vec![("[Ln][Xp]halu[X]halu[Ll][Xj]ton[X]ttoma[Sade][Ny]lla", 0)],
            vec![("[Ls][Xp]aivan[X]aivan", 0)],
            vec![("[Lc][Xp]kuin[X]kuin", 0), ("[Ln][Xp]kuu[X]ku[Sin][Nm]in", 0)],
            vec![("[Lt][Xp]olla[X]ol[Te][Ap][P3][Ny][Eb]isi", 0)],
            vec![("[Ls][Xp]vast=ikään[X]vast[Bm]ikään", 0)],
            vec![("[Lt][Xp]tulla[X]tul[Ll][Ru]l[Xj]ut[X][Sn][Ny]ut", 0), ("[Lt][Xp]tulla[X]tul[Ll][Rt][Xj]tu[X]lu[Sn][Nm]t", 0)],
            vec![("[Ln][Xp]perhe[X]perhee[Ll]lli[Xj]nen[X]se[Sela][Ny]stä", 0)],
            vec![("[Ln][Xp]riita[X]riida[Sela][Ny]sta", 0)],
            vec![("[Lc][Xp]tahi[X]tahi", 0)],
            vec![("[Lt][Xp]jättää[X]jättä[Ll][Ru]n[Xj]yt[X][Sn][Ny]yt", 0)],
            vec![("[Lnl][Xp]eilinen[X]eili[Sg][Ny]sen", 0)],
            vec![("[Ln][Xp]sapatti[X]sapat[Sg][Ny]in[Bh][Bc][Ln][Xp]päivä[X]päiv[Sg][Ny]än", 0)],
            vec![("[Lt][Xp]pyhittää[X]pyhittä[Ln]m[Xj]ä[X][Rm]ä[Sab][Ny]ttä", 0), ("[Lt][Xp]pyhittää[X]pyhittä[Tn3][Ny][Sab]mättä", 0)],
        ];
        let fst = load_voikko();
        for (idx, (word, gold)) in words.into_iter().zip(gold.into_iter()).enumerate() {
            let mut sys = fst.lookup(word, FSTState::<()>::_new(0), false).unwrap();
            sys.sort_by(|a, b| a.partial_cmp(b).unwrap());
            let mut gold_sorted = gold;
            gold_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
            println!("Word at: {idx}");
            assert_eq!(
                sys,
                gold_sorted
                    .iter()
                    .map(|(s, w)| (s.to_string(), (*w).into()))
                    .collect::<Vec<_>>()
            );
        }
    }

    /// The unknown-symbol transition only fires for symbols marked unknown.
    #[test]
    fn test_simple_unknown() {
        let code = "0\t1\t@_UNKNOWN_SYMBOL_@\ty\n1";
        let fst = FST::from_att_code(code.to_string(), false).unwrap();

        assert_eq!(
            fst.run_fst(
                vec![Symbol::String(StringSymbol::new("x".to_string(), false,))],
                FSTState::<()>::_new(0),
                false,
                None,
                true
            ),
            vec![]
        );

        assert_eq!(
            fst.run_fst(
                vec![Symbol::String(StringSymbol::new("x".to_string(), true,))],
                FSTState::_new(0),
                false,
                None,
                true
            ),
            vec![(
                true,
                false,
                FSTState {
                    state_num: 1,
                    path_weight: 0.0,
                    input_flags: FlagMap::new(),
                    output_flags: FlagMap::new(),
                    symbol_mappings: FSTLinkedList::Some(
                        (Symbol::String(StringSymbol::new("y".to_string(), false)), 0),
                        Arc::new(FSTLinkedList::None)
                    )
                }
            )]
        );
    }

    /// The identity transition copies unknown symbols through unchanged.
    #[test]
    fn test_simple_identity() {
        let code = "0\t1\t@_IDENTITY_SYMBOL_@\t@_IDENTITY_SYMBOL_@\n2\t3\ta\ta\n1";
        let fst = FST::from_att_code(code.to_string(), false).unwrap();

        assert_eq!(
            fst.run_fst(
                vec![Symbol::String(StringSymbol::new("x".to_string(), false,))],
                FSTState::<()>::_new(0),
                false,
                None,
                true
            ),
            vec![]
        );

        assert_eq!(
            fst.run_fst(
                vec![Symbol::String(StringSymbol::new("x".to_string(), true,))],
                FSTState::_new(0),
                false,
                None,
                true
            ),
            vec![(
                true,
                false,
                FSTState {
                    state_num: 1,
                    path_weight: 0.0,
                    input_flags: FlagMap::new(),
                    output_flags: FlagMap::new(),
                    symbol_mappings: FSTLinkedList::Some(
                        (Symbol::String(StringSymbol::new("x".to_string(), true)), 0),
                        Arc::new(FSTLinkedList::None)
                    )
                }
            )]
        );

        assert_eq!(fst.lookup::<()>("x", FSTState::default(), true).unwrap(), vec![("x".to_string(), 0.0)]);
        assert_eq!(fst.lookup::<()>("a", FSTState::default(), true).unwrap(), vec![]);
    }

    #[test]
    fn test_raw_symbols() {
        let mut rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>> = IndexMap::new();
        let sym_a = Symbol::Raw(RawSymbol {
            value: [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        });
        let sym_b = Symbol::Raw(RawSymbol {
            value: [0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        });
        let sym_c = Symbol::Raw(RawSymbol {
            value: [0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        });
        let special_epsilon = Symbol::Raw(RawSymbol {
            value: [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        });
        let sym_d = Symbol::Raw(RawSymbol {
            value: [0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        });
        let sym_d_unk = Symbol::Raw(RawSymbol {
            value: [2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        });
        rules.insert(0, indexmap!(sym_a.clone() => vec![(1, sym_a.clone(), 0.0)]));
        rules.insert(1, indexmap!(sym_b.clone() => vec![(0, sym_b.clone(), 0.0)], Symbol::Special(SpecialSymbol::IDENTITY) => vec![(2, Symbol::Special(SpecialSymbol::IDENTITY), 0.0)]));
        rules.insert(
            2,
            indexmap!(special_epsilon.clone() => vec![(3, sym_c.clone(), 0.0)]),
        );
        let symbols = vec![sym_a.clone(), sym_b.clone(), sym_c.clone(), special_epsilon];
        let fst = FST {
            final_states: indexmap! {3 => 0.0},
            rules,
            symbols,
            debug: false,
        };

        let result = fst.run_fst(
            vec![
                sym_a.clone(),
                sym_b.clone(),
                sym_a.clone(),
                sym_d_unk.clone(),
            ],
            FSTState::<()>::_new(0),
            false,
            None,
            true,
        );
        let filtered: Vec<_> = result.into_iter().filter(|x| x.0).collect();
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].2.state_num, 3);
        assert_eq!(
            filtered[0]
                .2
                .symbol_mappings
                .clone()
                .to_vec()
                .iter()
                .map(|x| x.0.clone())
                .collect::<Vec<_>>(),
            vec![
                sym_a.clone(),
                sym_b.clone(),
                sym_a.clone(),
                sym_d_unk.clone(),
                sym_c.clone()
            ]
        );

        // The same raw symbol without the "unknown" marker must not take the identity arc.
        assert_eq!(
            fst.run_fst(
                vec![sym_a.clone(), sym_b.clone(), sym_a.clone(), sym_d.clone()],
                FSTState::<()>::_new(0),
                false,
                None,
                true
            )
            .into_iter()
            .filter(|x| x.0)
            .count(),
            0,
        );
    }

    #[test]
    fn test_string_comparison_order_for_tokenizable_symbol_types() {
        assert!(
            StringSymbol::new("aa".to_string(), false) < StringSymbol::new("a".to_string(), false)
        );
        assert!(
            StringSymbol::new("aa".to_string(), true) > StringSymbol::new("aa".to_string(), false)
        );
        assert!(
            StringSymbol::new("ab".to_string(), false) > StringSymbol::new("aa".to_string(), false)
        );

        assert!(
            FlagDiacriticSymbol::parse("@U.aa@").unwrap()
                < FlagDiacriticSymbol::parse("@U.a@").unwrap()
        );
        assert!(
            FlagDiacriticSymbol::parse("@U.ab@").unwrap()
                > FlagDiacriticSymbol::parse("@U.aa@").unwrap()
        );
        assert!(SpecialSymbol::IDENTITY < SpecialSymbol::EPSILON);
    }

    #[test]
    fn fst_linked_list_conversion_correctness_internal_methods() {
        let orig = vec![
            Symbol::Special(SpecialSymbol::EPSILON),
            Symbol::Special(SpecialSymbol::IDENTITY),
            Symbol::Special(SpecialSymbol::UNKNOWN),
        ];
        assert_eq!(
            FSTLinkedList::from_vec(orig.clone(), vec![10, 20, 30]).to_vec(),
            vec![
                (Symbol::Special(SpecialSymbol::EPSILON), 10),
                (Symbol::Special(SpecialSymbol::IDENTITY), 20),
                (Symbol::Special(SpecialSymbol::UNKNOWN), 30),
            ]
        );

        // An empty index vector means "all indices are zero".
        assert_eq!(
            FSTLinkedList::from_vec(orig.clone(), vec![]).to_vec(),
            vec![
                (Symbol::Special(SpecialSymbol::EPSILON), 0),
                (Symbol::Special(SpecialSymbol::IDENTITY), 0),
                (Symbol::Special(SpecialSymbol::UNKNOWN), 0),
            ]
        );
    }

    #[test]
    fn fst_linked_list_conversion_correctness_iterators_with_indices() {
        let orig = vec![
            (Symbol::Special(SpecialSymbol::EPSILON), 10),
            (Symbol::Special(SpecialSymbol::IDENTITY), 20),
            (Symbol::Special(SpecialSymbol::UNKNOWN), 30),
        ];
        let ll: FSTLinkedList<usize> = orig.clone().into_iter().collect();
        assert_eq!(
            ll.into_iter().collect::<Vec<_>>(),
            vec![
                (Symbol::Special(SpecialSymbol::EPSILON), 10),
                (Symbol::Special(SpecialSymbol::IDENTITY), 20),
                (Symbol::Special(SpecialSymbol::UNKNOWN), 30),
            ]
        );
    }
}
4042
/// Python extension module `kfst_rs`.
///
/// Registers two submodules:
/// * `kfst_rs.symbols`: the symbol classes and `from_symbol_string`.
/// * `kfst_rs.transducer`: `FST`, `FSTState` and `TokenizationException`.
///
/// `TokenizationException` and `FST` are also exposed at the top level.
#[cfg(feature = "python")]
#[pymodule]
fn kfst_rs(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Submodule `kfst_rs.symbols`: symbol types plus the symbol-string parser.
    let symbols = PyModule::new(m.py(), "symbols")?;
    symbols.add_class::<StringSymbol>()?;
    symbols.add_class::<FlagDiacriticType>()?;
    symbols.add_class::<FlagDiacriticSymbol>()?;
    symbols.add_class::<SpecialSymbol>()?;
    symbols.add_class::<RawSymbol>()?;
    symbols.add_function(wrap_pyfunction!(from_symbol_string, m)?)?;

    // Insert the submodule into `sys.modules` under its dotted name (presumably so that
    // `import kfst_rs.symbols` works). `py_run!` exposes the Rust local `symbols` to the
    // snippet under the same name, so the variable must not be renamed.
    py_run!(
        py,
        symbols,
        "import sys; sys.modules['kfst_rs.symbols'] = symbols"
    );

    m.add_submodule(&symbols)?;

    // Submodule `kfst_rs.transducer`: the transducer, its state type and the tokenization error.
    let transducer = PyModule::new(m.py(), "transducer")?;
    transducer.add_class::<FST>()?;
    transducer.add_class::<FSTState>()?;
    transducer.add(
        "TokenizationException",
        py.get_type::<TokenizationException>(),
    )?;

    // Same `sys.modules` registration as above; the Rust local `transducer` is the Python name.
    py_run!(
        py,
        transducer,
        "import sys; sys.modules['kfst_rs.transducer'] = transducer"
    );

    m.add_submodule(&transducer)?;

    // Top-level re-exports.
    m.add(
        "TokenizationException",
        py.get_type::<TokenizationException>(),
    )?;
    m.add_class::<FST>()?;

    Ok(())
}