kfst_rs/
lib.rs

1/*
2 This file is part of KFST.
3
4 (c) 2023-2025 Iikka Hauhio <iikka.hauhio@helsinki.fi> and Théo Salmenkivi-Friberg <theo.friberg@helsinki.fi>
5
6 KFST is free software: you can redistribute it and/or modify it under the
7 terms of the GNU Lesser General Public License as published by the Free
8 Software Foundation, either version 3 of the License, or (at your option) any
9 later version.
10
11 KFST is distributed in the hope that it will be useful, but WITHOUT ANY
12 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
13 FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
14 details.
15
16 You should have received a copy of the GNU Lesser General Public License
17 along with KFST. If not, see <https://www.gnu.org/licenses/>.
18*/
19
20//! Fast and portable HFST-compatible finite-state transducers.
21//!
22//! An implementation of finite-state transducers mostly compatible with [HFST](https://hfst.github.io/).
23//! Provides the optional accelerated back-end for [kfst](https://github.com/fergusq/fst-python).
24//! Able to load and execute [Voikko](https://voikko.puimula.org/) and [Omorfi](https://github.com/flammie/omorfi):
25//! see [kfst](https://github.com/fergusq/fst-python) for transducers converted to a compatible format as well as Python bindings.
26//! Supports the ATT format and its own KFST format.
27//!
28//! To convert HFST (optimized lookup or otherwise) to ATT using HFST's tools, do:
29//!
30//! ```bash
31//! hfst-fst2txt transducer.hfst -o transducer.att
32//! ```
33//!
34//! Given the Voikko transducer in KFST or ATT format, one could create a simple analyzer like this:
35//!
36//! ```rust
37//! use kfst_rs::{FSTState, FST};
38//! use std::io::{self, Write};
39//!
40//! // Read in transducer
41//!
42//! # let pathtovoikko = "../pyvoikko/pyvoikko/voikko.kfst".to_string();
43//! let fst = FST::from_kfst_file(pathtovoikko, true).unwrap();
44//! // Alternatively, for ATT use FST::from_att_file
45//!
46//! // Read in word to analyze
47//!
48//! let mut buffer = String::new();
49//! let stdin = io::stdin();
50//! stdin.read_line(&mut buffer).unwrap();
51//! buffer = buffer.trim().to_string();
52//!
53//! // Do analysis proper
54//!
55//! match fst.lookup(&buffer, FSTState::<()>::default(), true) {
56//!     Ok(result) => {
57//!         for (i, analysis) in result.into_iter().enumerate() {
58//!             println!("Analysis {}: {} ({})", i+1, analysis.0, analysis.1)
59//!         }
60//!     },
61//!     Err(err) => println!("No analysis: {:?}", err),
62//! }
63//! ```
64//! Given the input "lentokoneessa", this gives the following analysis:
65//!
66//! ```text
67//! Analysis 1: [Lt][Xp]lentää[X]len[Ln][Xj]to[X]to[Sn][Ny][Bh][Bc][Ln][Xp]kone[X]konee[Sine][Ny]ssa (0)
68//! ```
69
70use std::cmp::Ordering;
71use std::collections::HashSet;
72use std::fmt::Debug;
73#[cfg(feature = "python")]
74use std::fmt::Error;
75use std::fs::{self, File};
76use std::hash::Hash;
77use std::io::Read;
78use std::path::Path;
79
80use indexmap::{indexmap, IndexMap, IndexSet};
81use nom::branch::alt;
82use nom::bytes::complete::{tag, take_until1};
83use nom::multi::many_m_n;
84use nom::Parser;
85use std::sync::{Arc, LazyLock, RwLock};
86use xz2::read::{XzDecoder, XzEncoder};
87
88#[cfg(feature = "python")]
89use pyo3::create_exception;
90#[cfg(feature = "python")]
91use pyo3::exceptions::{PyIOError, PyValueError};
92#[cfg(feature = "python")]
93use pyo3::types::{PyDict, PyNone, PyTuple};
94#[cfg(feature = "python")]
95use pyo3::{prelude::*, py_run, IntoPyObjectExt};
96
97// We have result types that kinda depend on the target
98// If we target pyo3, we want python results and errors
99// Otherwise, we want stdlib errors
100
#[cfg(feature = "python")]
type KFSTResult<T> = PyResult<T>;
#[cfg(not(feature = "python"))]
type KFSTResult<T> = std::result::Result<T, String>;

/// Produce a failed [`KFSTResult`] that surfaces in Python as a `ValueError` carrying `msg`.
#[cfg(feature = "python")]
fn value_error<T>(msg: String) -> KFSTResult<T> {
    Err(PyValueError::new_err(msg))
}
/// Produce a failed [`KFSTResult`] whose error is simply `msg`
/// (builds without the `python` feature use plain `String` errors).
#[cfg(not(feature = "python"))]
fn value_error<T>(msg: String) -> KFSTResult<T> {
    Err(msg)
}
114
115#[cfg(feature = "python")]
116fn io_error<T>(msg: String) -> KFSTResult<T> {
117    use pyo3::exceptions::PyIOError;
118
119    KFSTResult::Err(PyErr::new::<PyIOError, _>(msg))
120}
121#[cfg(not(feature = "python"))]
122fn io_error<T>(msg: String) -> KFSTResult<T> {
123    KFSTResult::Err(msg)
124}
125
126#[cfg(feature = "python")]
127fn tokenization_exception<T>(msg: String) -> KFSTResult<T> {
128    KFSTResult::Err(PyErr::new::<TokenizationException, _>(msg))
129}
130#[cfg(not(feature = "python"))]
131fn tokenization_exception<T>(msg: String) -> KFSTResult<T> {
132    KFSTResult::Err(msg)
133}
134
// Declares the Python exception type `kfst_rs.TokenizationException`, deriving from
// Python's base `Exception`. It only exists in builds with the `python` feature and is
// the error type wrapped by `tokenization_exception`.
#[cfg(feature = "python")]
create_exception!(
    kfst_rs,
    TokenizationException,
    pyo3::exceptions::PyException
);
141
142// Symbol interning
143
/// Global, lazily created string interner shared by all symbol types.
///
/// Strings are stored in an insertion-ordered set; the position of a string in the set is
/// the `u32` handle returned by [`intern`] and accepted by [`deintern`] / [`with_deinterned`].
/// Access is guarded by an [`RwLock`]: interning takes the write lock, look-ups the read lock.
static STRING_INTERNER: LazyLock<RwLock<IndexSet<String>>> =
    LazyLock::new(|| RwLock::new(IndexSet::new()));
146
147fn intern(s: String) -> u32 {
148    u32::try_from(STRING_INTERNER.write().unwrap().insert_full(s).0).unwrap()
149}
150
151/// Get the string in the string interner at a given position.
152/// Waits on interner read lock
153fn deintern(idx: u32) -> String {
154    with_deinterned(idx, |x| x.to_string())
155}
156
157/// Perform an operation on the string nterned at idx
158/// Notably: the read lock is held for the whole duration of the operation.
159fn with_deinterned<F, X>(idx: u32, f: F) -> X
160where
161    F: FnOnce(&str) -> X,
162{
163    f(STRING_INTERNER
164        .read()
165        .unwrap()
166        .get_index(idx.try_into().unwrap())
167        .unwrap())
168}
169
#[cfg_attr(
    feature = "python",
    pyclass(str = "RawSymbol({value:?})", eq, ord, frozen, hash, get_all)
)]
#[derive(Clone, Copy, Hash, PartialEq, PartialOrd, Ord, Eq, Debug)]
#[readonly::make]
/// A Symbol type that has a signaling byte (the first one) and 14 other bytes to dispose of as the caller wishes.
/// This odd size is such that [Symbol] can be 16 bytes long: a 1-byte discriminant + 15 bytes.
/// (The [Symbol::Flag] variant forces [Symbol] to be at least 16 bytes.)
pub struct RawSymbol {
    /// The least-significant bit of the first byte should be 1 if the symbol is to be seen as epsilon (see [is_epsilon](RawSymbol::is_epsilon)).
    ///
    /// The second least-significant bit of the first byte should be 1 if the symbol is to be seen as unknown (see [is_unknown](RawSymbol::is_unknown)).
    ///
    /// The remainder of the first byte is reserved.
    ///
    /// The following 14 bytes are caller-defined.
    pub value: [u8; 15],
}
189
190#[cfg_attr(feature = "python", pymethods)]
191impl RawSymbol {
192    /// Whether this symbol should be seen as ε. (See [Symbol::is_epsilon] for more general information on this)
193    /// Returns true is the least-significant bit of the first byte of [RawSymbol::value] is set. Returns false otherwise.
194    pub fn is_epsilon(&self) -> bool {
195        (self.value[0] & 1) != 0
196    }
197
198    /// Whether this symbol should be seen as unknown. (See [Symbol::is_unknown] for more general information on this)
199    /// Returns true is the second least-significant bit of the first byte of [RawSymbol::value] is set. Returns false otherwise.
200    pub fn is_unknown(&self) -> bool {
201        (self.value[0] & 2) != 0
202    }
203
204    /// A textual representation of this symbol. (See [Symbol::get_symbol] for more general information on this)
205    pub fn get_symbol(&self) -> String {
206        format!("RawSymbol({:?})", self.value)
207    }
208
209    #[cfg(feature = "python")]
210    #[new]
211    fn new(value: [u8; 15]) -> Self {
212        RawSymbol { value }
213    }
214
215    /// Construct an instance of RawSymbol; simply sets [RawSymbol::value] to the provided value.
216    #[cfg(not(feature = "python"))]
217    pub fn new(value: [u8; 15]) -> Self {
218        RawSymbol { value }
219    }
220
221    #[deprecated]
222    /// Python-style string representation.
223    pub fn __repr__(&self) -> String {
224        format!("RawSymbol({:?})", self.value)
225    }
226}
227
// A symbol backed by an arbitrary Python object (only available with the `python` feature).
// The trait impls and methods on this type forward to the object's Python-level methods
// (`__eq__`, `__hash__`, `__gt__`, `__repr__`, `is_epsilon`, `is_unknown`, `get_symbol`).
#[cfg(feature = "python")]
struct PyObjectSymbol {
    // The wrapped Python object.
    value: PyObject,
}
232
#[cfg(feature = "python")]
impl Debug for PyObjectSymbol {
    /// Writes the result of calling the wrapped object's `__repr__`.
    /// Panics if `__repr__` is missing, raises, or doesn't return a string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Python::with_gil(|py| {
            let repr_method = self.value.getattr(py, "__repr__").unwrap();
            let repr_object = repr_method.call0(py).unwrap();
            // Seemingly compiling for CPython3.13t (=free-threaded) doesn't allow extracting
            // to a &str, so an owned String is extracted instead.
            let text: String = repr_object.extract(py).unwrap();
            f.write_str(&text)
        })
    }
}
251
#[cfg(feature = "python")]
impl PartialEq for PyObjectSymbol {
    /// Equality as decided by the wrapped object's `__eq__`, called with `other`'s object.
    /// Panics if `__eq__` is missing, raises, or doesn't return a bool.
    fn eq(&self, other: &Self) -> bool {
        Python::with_gil(|py| {
            let eq_method = self.value.getattr(py, "__eq__").unwrap_or_else(|_| {
                panic!(
                    "Symbol {} doesn't have an __eq__ implementation.",
                    self.value
                )
            });
            let verdict = eq_method
                .call1(py, (other.value.clone_ref(py),))
                .unwrap_or_else(|_| {
                    panic!("__eq__ on symbol {} failed to return a value.", self.value)
                });
            verdict
                .extract(py)
                .unwrap_or_else(|_| panic!("__eq__ on symbol {} didn't return a bool.", self.value))
        })
    }
}
273
#[cfg(feature = "python")]
impl Hash for PyObjectSymbol {
    /// Feeds the result of the wrapped object's `__hash__` (extracted as an `i128`) into `state`.
    /// Panics if `__hash__` is missing, raises, or doesn't return an int.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let python_hash: i128 = Python::with_gil(|py| {
            let hash_method = self.value.getattr(py, "__hash__").unwrap_or_else(|_| {
                panic!(
                    "Symbol {} doesn't have a __hash__ implementation.",
                    self.value
                )
            });
            let hash_object = hash_method.call0(py).unwrap_or_else(|_| {
                panic!(
                    "__hash__ on symbol {} failed to return a value.",
                    self.value
                )
            });
            hash_object.extract(py).unwrap_or_else(|_| {
                panic!("__hash__ on symbol {} didn't return an int.", self.value)
            })
        });
        state.write_i128(python_hash);
    }
}
300
// Marker impl: equality is whatever the wrapped object's `__eq__` reports (see the
// `PartialEq` impl). NOTE(review): this presumes that `__eq__` behaves as an equivalence
// relation — nothing here enforces it.
#[cfg(feature = "python")]
impl Eq for PyObjectSymbol {}
303
#[cfg(feature = "python")]
impl PartialOrd for PyObjectSymbol {
    /// Always `Some`: defers to the total order defined by the [`Ord`] impl.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
310
#[cfg(feature = "python")]
impl Ord for PyObjectSymbol {
    /// Orders two symbols through the wrapped object's Python comparison methods:
    /// `Greater` if `__gt__` says so, otherwise `Equal` if `__eq__` says so, otherwise `Less`.
    /// Panics if the consulted method is missing, raises, or doesn't return a bool.
    fn cmp(&self, other: &Self) -> Ordering {
        Python::with_gil(|py| {
            let gt_method = self.value.getattr(py, "__gt__").unwrap_or_else(|_| {
                panic!(
                    "Symbol {} doesn't have a __gt__ implementation.",
                    self.value
                )
            });
            let is_greater = gt_method
                .call1(py, (other.value.clone_ref(py),))
                .unwrap_or_else(|_| {
                    panic!("__gt__ on symbol {} failed to return a value.", self.value)
                })
                .extract::<bool>(py)
                .unwrap_or_else(|_| {
                    panic!("__gt__ on symbol {} didn't return a bool.", self.value)
                });
            if is_greater {
                return Ordering::Greater;
            }
            // The `PartialEq` impl performs exactly the `__eq__` call (and panics with the
            // same messages) that is needed to tell `Equal` from `Less`.
            if self == other {
                Ordering::Equal
            } else {
                Ordering::Less
            }
        })
    }
}
358
#[cfg(feature = "python")]
impl Clone for PyObjectSymbol {
    /// Creates another handle to the same Python object (acquires the GIL to do so).
    fn clone(&self) -> Self {
        let value = Python::with_gil(|py| self.value.clone_ref(py));
        Self { value }
    }
}
367
#[cfg(feature = "python")]
impl FromPyObject<'_> for PyObjectSymbol {
    /// Wraps any Python object as a symbol; this conversion never fails.
    fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
        // The clone here is technical: it yields another handle, no value is copied.
        let value = ob.clone().unbind();
        Ok(Self { value })
    }
}
376
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for PyObjectSymbol {
    type Target = PyAny;

    type Output = Bound<'py, Self::Target>;

    type Error = pyo3::PyErr;

    /// Hands the wrapped Python object back to Python, bound to `py`.
    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let Self { value } = self;
        value.into_bound_py_any(py)
    }
}
389
#[cfg(feature = "python")]
impl PyObjectSymbol {
    /// Calls the wrapped object's `is_epsilon()` method and returns its result.
    /// Panics if the method is missing, raises, or doesn't return a bool.
    fn is_epsilon(&self) -> bool {
        Python::with_gil(|py| {
            self.value
                .getattr(py, "is_epsilon")
                .unwrap_or_else(|_| {
                    panic!(
                        "Symbol {} doesn't have an is_epsilon implementation.",
                        self.value
                    )
                })
                .call0(py)
                .unwrap_or_else(|_| {
                    panic!(
                        "is_epsilon on symbol {} failed to return a value.",
                        self.value
                    )
                })
                .extract(py)
                .unwrap_or_else(|_| {
                    panic!("is_epsilon on symbol {} didn't return a bool.", self.value)
                })
        })
    }

    /// Calls the wrapped object's `is_unknown()` method and returns its result.
    /// Panics if the method is missing, raises, or doesn't return a bool.
    fn is_unknown(&self) -> bool {
        Python::with_gil(|py| {
            self.value
                .getattr(py, "is_unknown")
                .unwrap_or_else(|_| {
                    panic!(
                        "Symbol {} doesn't have an is_unknown implementation.",
                        self.value
                    )
                })
                .call0(py)
                .unwrap_or_else(|_| {
                    panic!(
                        "is_unknown on symbol {} failed to return a value.",
                        self.value
                    )
                })
                .extract(py)
                .unwrap_or_else(|_| {
                    panic!("is_unknown on symbol {} didn't return a bool.", self.value)
                })
        })
    }

    /// Calls the wrapped object's `get_symbol()` method and returns its result as a `String`.
    /// Panics if the method is missing, raises, or doesn't return a string.
    fn get_symbol(&self) -> String {
        Python::with_gil(|py| {
            self.value
                .getattr(py, "get_symbol")
                .unwrap_or_else(|_| {
                    panic!(
                        "Symbol {} doesn't have a get_symbol implementation.",
                        self.value
                    )
                })
                .call0(py)
                .unwrap_or_else(|_| {
                    panic!(
                        "get_symbol on symbol {} failed to return a value.",
                        self.value
                    )
                })
                .extract(py)
                .unwrap_or_else(|_| {
                    // The extraction target is a String, so the message must say so
                    // (it used to claim a bool was expected).
                    panic!("get_symbol on symbol {} didn't return a str.", self.value)
                })
        })
    }
}
464
#[cfg_attr(
    feature = "python",
    pyclass(
        str = "StringSymbol({string:?}, {unknown})",
        eq,
        ord,
        frozen,
        hash,
        get_all
    )
)]
#[derive(Clone, Copy, Hash, PartialEq, Eq)]
#[readonly::make]
/// A symbol that holds an interned string and the information of whether it should be seen as unknown (see [is_unknown](StringSymbol::is_unknown)).
/// A copy of the interned string is **held until the end of the program**.
pub struct StringSymbol {
    // Handle of the symbol's text in the global string interner (not the text itself).
    string: u32,
    /// Whether this symbol is considered unknown.
    pub unknown: bool,
}
485
486impl StringSymbol {
487    /// Parse a [&str] into a StringSymbol carrying the same text. Returns a known symbol. Fails if given an empty string.
488    ///
489    /// ```
490    /// use kfst_rs::StringSymbol;
491    ///
492    /// StringSymbol::parse("kissa").unwrap(); // Parses into a symbol
493    /// assert!(StringSymbol::parse("").is_err()); // Fails because of empty string
494    /// ```
495    ///
496    /// This is a [nom]-style parser that returns the unparsed part of the string alongside the parsed [StringSymbol].
497    /// However, it gobbles up the whole input string and is guaranteed to return something of the form (assuming that it returns Ok at all)
498    ///
499    /// ```no_test
500    /// Ok(("", StringSymbol { ... }))
501    /// ```
502    pub fn parse(symbol: &str) -> nom::IResult<&str, StringSymbol> {
503        if symbol.is_empty() {
504            return nom::IResult::Err(nom::Err::Error(nom::error::Error::new(
505                symbol,
506                nom::error::ErrorKind::Fail,
507            )));
508        }
509        Ok((
510            "",
511            StringSymbol {
512                string: intern(symbol.to_string()),
513                unknown: false,
514            },
515        ))
516    }
517
518    /// Perform a computation on a non-owned version of the symbol
519    /// Saves a clone
520    fn with_symbol<F, X>(&self, f: F) -> X
521    where
522        F: FnOnce(&str) -> X,
523    {
524        with_deinterned(self.string, f)
525    }
526}
527
528impl PartialOrd for StringSymbol {
529    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
530        Some(self.cmp(other))
531    }
532}
533
534impl Ord for StringSymbol {
535    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
536        with_deinterned(other.string, |other_str| {
537            with_deinterned(self.string, |self_str| {
538                (other_str.chars().count(), &self_str, self.unknown).cmp(&(
539                    self_str.chars().count(),
540                    &other_str,
541                    other.unknown,
542                ))
543            })
544        })
545    }
546}
547#[cfg_attr(feature = "python", pymethods)]
548impl StringSymbol {
549    /// Is this an ε symbol? (See [Symbol::is_epsilon] for more details on the general case)
550    /// Always returns false.
551    pub fn is_epsilon(&self) -> bool {
552        false
553    }
554
555    /// Is this an unknown symbol? (See [Symbol::is_unknown] for more details on the general case)
556    /// Returns the value of [StringSymbol::unknown].
557    pub fn is_unknown(&self) -> bool {
558        self.unknown
559    }
560
561    /// String representation of this symbol (returns the string from which the symbol was constructed)
562    pub fn get_symbol(&self) -> String {
563        deintern(self.string)
564    }
565
566    #[cfg(feature = "python")]
567    #[new]
568    #[pyo3(signature = (string, unknown = false))]
569    fn new(string: String, unknown: bool) -> Self {
570        StringSymbol {
571            string: intern(string),
572            unknown,
573        }
574    }
575
576    #[cfg(not(feature = "python"))]
577    /// Creates a new string symbol. Notably, this **interns the string for the program's runtime**.
578    pub fn new(string: String, unknown: bool) -> Self {
579        StringSymbol {
580            string: intern(string),
581            unknown,
582        }
583    }
584
585    #[deprecated]
586    /// Python-style string representation.
587    pub fn __repr__(&self) -> String {
588        format!("StringSymbol({:?}, {})", self.string, self.unknown)
589    }
590}
591
#[cfg_attr(feature = "python", pyclass(eq, ord, frozen))]
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy, Hash)]
/// The kinds of flag diacritic that kfst_rs supports.
pub enum FlagDiacriticType {
    /// Unification diacritic.
    U,
    /// Requirement diacritic.
    R,
    /// Denial diacritic.
    D,
    /// Clearing diacritic. The transition can always be taken, and the associated flag is cleared.
    C,
    /// Positive setting diacritic.
    P,
    /// Negative setting diacritic. The transition can always be taken, and the associated flag is negatively set.
    /// Eg. `@N.X.Y@` means that `X` is set to a value that is guaranteed to not unify with `Y`.
    N,
}

impl FlagDiacriticType {
    /// Maps one of the strings `U`, `R`, `D`, `C`, `P`, `N` to the diacritic type of the same name;
    /// any other string yields `None`.
    /// The potentially confusing name (see [std::str::FromStr::from_str]) is kept for Python compatibility.
    pub fn from_str(s: &str) -> Option<Self> {
        let flag_type = match s {
            "U" => Self::U,
            "R" => Self::R,
            "D" => Self::D,
            "C" => Self::C,
            "P" => Self::P,
            "N" => Self::N,
            _ => return None,
        };
        Some(flag_type)
    }
}

#[cfg_attr(feature = "python", pymethods)]
impl FlagDiacriticType {
    #[deprecated]
    /// Python-style string representation: the name of the variant.
    pub fn __repr__(&self) -> String {
        format!("{self:?}")
    }
}
635
#[cfg_attr(
    feature = "python",
    pyclass(
        str = "FlagDiacriticSymbol({flag_type:?}, {key:?}, {value:?})",
        eq,
        ord,
        frozen,
        hash
    )
)]
#[derive(PartialEq, Eq, Clone, Copy, Hash)]
#[readonly::make]
/// A [Symbol] representing a flag diacritic.
/// Flag diacritics allow making state transitions depend on externally kept state, thus often making transducers smaller.
/// The symbol consists of three parts:
/// 1. The FlagType; see [FlagDiacriticType] for possible options
/// 2. The name of the flag itself (accessible via [FlagDiacriticSymbol::key])
/// 3. The value of the flag (accessible via [FlagDiacriticSymbol::value])
pub struct FlagDiacriticSymbol {
    /// The type of the flag; see [FlagDiacriticType] for all possible values.
    pub flag_type: FlagDiacriticType,
    // Interner handle of the flag's name.
    key: u32,
    // Interner handle of the flag's value; `u32::MAX` marks a flag without a value.
    value: u32,
}
660
661impl PartialOrd for FlagDiacriticSymbol {
662    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
663        Some(self.cmp(other))
664    }
665}
666
667impl Ord for FlagDiacriticSymbol {
668    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
669        // This should be clean; there is a bijection between all flag diacritics and a subset of strings
670        let other_str = other.get_symbol();
671        let self_str = self.get_symbol();
672        (other_str.chars().count(), &self_str).cmp(&(self_str.chars().count(), &other_str))
673    }
674}
675
676impl FlagDiacriticSymbol {
677    /// Parse a flag diacritic from a string representation of the form @SYMBOL_TYPE.KEY.VALUE@ or @SYMBOL_TYPE.KEY@.
678    pub fn parse(symbol: &str) -> nom::IResult<&str, FlagDiacriticSymbol> {
679        let mut parser = (
680            tag("@"),
681            alt((tag("U"), tag("R"), tag("D"), tag("C"), tag("P"), tag("N"))),
682            tag("."),
683            many_m_n(0, 1, (take_until1("."), tag("."))),
684            take_until1("@"),
685            tag("@"),
686        );
687        let (input, (_, diacritic_type, _, named_piece_1, named_piece_2, _)) =
688            parser.parse(symbol)?;
689        let diacritic_type = match FlagDiacriticType::from_str(diacritic_type) {
690            Some(x) => x,
691            None => {
692                return Err(nom::Err::Error(nom::error::Error::new(
693                    diacritic_type,
694                    nom::error::ErrorKind::Fail,
695                )))
696            }
697        };
698
699        let (name, value) = if !named_piece_1.is_empty() {
700            (named_piece_1[0].0, intern(named_piece_2.to_string()))
701        } else {
702            (named_piece_2, u32::MAX)
703        };
704
705        Ok((
706            input,
707            FlagDiacriticSymbol {
708                flag_type: diacritic_type,
709                key: intern(name.to_string()),
710                value,
711            },
712        ))
713    }
714}
715
716// These functions have some non-trivial pyo3-attributes that cannot be cfg_attr'ed in and non-trivial content
717// Need to be specified in separate impl block
718
719impl FlagDiacriticSymbol {
720    fn _from_symbol_string(symbol: &str) -> KFSTResult<Self> {
721        match FlagDiacriticSymbol::parse(symbol) {
722            Ok(("", symbol)) => KFSTResult::Ok(symbol),
723            Ok((rest, _)) => value_error(format!("String {symbol:?} contains a valid FlagDiacriticSymbol, but it has unparseable text at the end: {rest:?}")),
724            _ => value_error(format!("Not a valid FlagDiacriticSymbol: {symbol:?}"))
725        }
726    }
727
728    #[cfg(not(feature = "python"))]
729    #[deprecated]
730    /// Parse from symbol string; exists for Python compatibility, prefer [FlagDiacriticSymbol::parse].
731    pub fn from_symbol_string(symbol: &str) -> KFSTResult<Self> {
732        FlagDiacriticSymbol::_from_symbol_string(symbol)
733    }
734
735    fn _new(flag_type: String, key: String, value: Option<String>) -> KFSTResult<Self> {
736        let flag_type = match FlagDiacriticType::from_str(&flag_type) {
737            Some(x) => x,
738            None => value_error(format!(
739                "String {flag_type:?} is not a valid FlagDiacriticType specifier"
740            ))?,
741        };
742        Ok(FlagDiacriticSymbol {
743            flag_type,
744            key: intern(key),
745            value: value.map(intern).unwrap_or(u32::MAX),
746        })
747    }
748
749    #[cfg(not(feature = "python"))]
750    /// Construct flag diacritic from a [String] representation of flag type, key and value.
751    pub fn new(flag_type: String, key: String, value: Option<String>) -> KFSTResult<Self> {
752        FlagDiacriticSymbol::_new(flag_type, key, value)
753    }
754
755    #[cfg(not(feature = "python"))]
756    /// Deintern the key
757    pub fn key(self) -> String {
758        deintern(self.key)
759    }
760
761    #[cfg(not(feature = "python"))]
762    /// Deintern the value
763    pub fn value(self) -> String {
764        deintern(self.value)
765    }
766}
767
768#[cfg_attr(feature = "python", pymethods)]
769impl FlagDiacriticSymbol {
770    /// Is this symbol to be treated as an ε symbol? Flag diacritics are always ε; this method is guaranteed to return true.
771    /// See [Symbol::is_epsilon] for a more in-depth explanation of what it means to be ε.
772    pub fn is_epsilon(&self) -> bool {
773        true
774    }
775
776    /// Is this symbol to be treated as an unknown symbol?
777    /// See [Symbol::is_epsilon] for a more in-depth explanation of what it means to be unknown.
778    pub fn is_unknown(&self) -> bool {
779        false
780    }
781
782    pub fn get_symbol(&self) -> String {
783        match self.value {
784            u32::MAX => with_deinterned(self.key, |key| format!("@{:?}.{}@", self.flag_type, key)),
785            value => with_deinterned(self.key, |key| {
786                with_deinterned(value, |value| {
787                    format!("@{:?}.{}.{}@", self.flag_type, key, value)
788                })
789            }),
790        }
791    }
792
793    #[cfg(feature = "python")]
794    #[getter]
795    fn flag_type(&self) -> String {
796        format!("{:?}", self.flag_type)
797    }
798
799    #[cfg(not(feature = "python"))]
800    /// Get the flag_type as a string.
801    pub fn flag_type(&self) -> String {
802        format!("{:?}", self.flag_type)
803    }
804
805    #[cfg(feature = "python")]
806    #[getter]
807    fn key(&self) -> String {
808        deintern(self.key)
809    }
810
811    #[cfg(feature = "python")]
812    #[getter]
813    fn value(&self) -> String {
814        deintern(self.value)
815    }
816
817    #[cfg(feature = "python")]
818    #[new]
819    fn new(flag_type: String, key: String, value: Option<String>) -> KFSTResult<Self> {
820        FlagDiacriticSymbol::_new(flag_type, key, value)
821    }
822
823    #[deprecated]
824    /// Python-style string representation.
825    pub fn __repr__(&self) -> String {
826        match self.value {
827            u32::MAX => with_deinterned(self.key, |key| {
828                format!("FlagDiacriticSymbol({:?}, {:?})", self.flag_type, key)
829            }),
830            value => with_deinterned(self.key, |key| {
831                with_deinterned(value, |value| {
832                    format!(
833                        "FlagDiacriticSymbol({:?}, {:?}, {:?})",
834                        self.flag_type, key, value
835                    )
836                })
837            }),
838        }
839    }
840
841    #[cfg(feature = "python")]
842    #[staticmethod]
843    fn is_flag_diacritic(symbol: &str) -> bool {
844        matches!(FlagDiacriticSymbol::parse(symbol), Ok(("", _)))
845    }
846
847    #[cfg(not(feature = "python"))]
848    /// Check if a string is a [FlagDiacriticSymbol], ie. of the form `@X.Y@` or `@X.Y.Z@` for arbitrary `Y` and `Z` and an `X` that is a [FlagDiacriticType];
849    /// see [FlagDiacriticType::from_str]
850    pub fn is_flag_diacritic(symbol: &str) -> bool {
851        matches!(FlagDiacriticSymbol::parse(symbol), Ok(("", _)))
852    }
853
854    #[cfg(feature = "python")]
855    #[staticmethod]
856    fn from_symbol_string(symbol: &str) -> KFSTResult<Self> {
857        FlagDiacriticSymbol::_from_symbol_string(symbol)
858    }
859}
860
861impl std::fmt::Debug for FlagDiacriticSymbol {
862    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
863        write!(f, "{}", self.get_symbol())
864    }
865}
866
#[cfg_attr(feature = "python", pyclass(eq, ord, frozen, hash))]
#[derive(PartialEq, Eq, Clone, Hash, Copy)]
/// The three possible HFST special symbols.
pub enum SpecialSymbol {
    /// The simplest possible ε-symbol.
    /// In transition position, it can always be followed and it doesn't modify flag state.
    /// If placed in output position, it is removed from the output string.
    EPSILON,
    /// The identity special symbol.
    /// It should only appear in transition position. It accepts any unknown symbol, ie. it accepts a symbol if [Symbol::is_unknown] returns `true` for it.
    /// It transduces an input symbol into the same symbol on the output side. (Hence the name "identity")
    IDENTITY,
    /// The unknown special symbol.
    /// It should only appear in transition position. It matches any unknown symbol, ie. it accepts a symbol if [Symbol::is_unknown] returns `true` for it.
    UNKNOWN,
}
883
884impl SpecialSymbol {
885    /// Parses this symbol from (the beginning of) a string representation.
886    /// Accepts:
887    ///
888    /// * `@_EPSILON_SYMBOL_@` and `@0@` for ε ([SpecialSymbol::EPSILON])
889    /// * `@_IDENTITY_SYMBOL_@` for identity ([SpecialSymbol::IDENTITY])
890    /// * `@_UNKNOWN_SYMBOL_@` for unknown ([SpecialSymbol::UNKNOWN])
891    ///
892    /// Returns a result value (Err if the given &str didn't start with any of the given symbols) containing the remainder of the string and the parsed symbol.
893    pub fn parse(symbol: &str) -> nom::IResult<&str, SpecialSymbol> {
894        let (rest, value) = alt((
895            tag("@_EPSILON_SYMBOL_@"),
896            tag("@0@"),
897            tag("@_IDENTITY_SYMBOL_@"),
898            tag("@_UNKNOWN_SYMBOL_@"),
899        ))
900        .parse(symbol)?;
901
902        let sym = match value {
903            "@_EPSILON_SYMBOL_@" => SpecialSymbol::EPSILON,
904            "@0@" => SpecialSymbol::EPSILON,
905            "@_IDENTITY_SYMBOL_@" => SpecialSymbol::IDENTITY,
906            "@_UNKNOWN_SYMBOL_@" => SpecialSymbol::UNKNOWN,
907            _ => panic!(),
908        };
909        Ok((rest, sym))
910    }
911
912    fn _from_symbol_string(symbol: &str) -> KFSTResult<Self> {
913        match SpecialSymbol::parse(symbol) {
914            Ok(("", result)) => KFSTResult::Ok(result),
915            _ => value_error(format!("Not a valid SpecialSymbol: {symbol:?}")),
916        }
917    }
918
919    /// Perform a computation on a non-owned version of the symbol
920    /// Saves a clone
921    fn with_symbol<F, X>(&self, f: F) -> X
922    where
923        F: FnOnce(&str) -> X,
924    {
925        match self {
926            SpecialSymbol::EPSILON => f("@_EPSILON_SYMBOL_@"),
927            SpecialSymbol::IDENTITY => f("@_IDENTITY_SYMBOL_@"),
928            SpecialSymbol::UNKNOWN => f("@_UNKNOWN_SYMBOL_@"),
929        }
930    }
931
932    #[cfg(not(feature = "python"))]
933    /// Parse a special symbol from a text representation.
934    ///
935    /// ```rust
936    /// use kfst_rs::SpecialSymbol;
937    ///
938    /// assert_eq!(SpecialSymbol::from_symbol_string("@_EPSILON_SYMBOL_@"), Ok(SpecialSymbol::EPSILON));
939    /// // Or alternatively
940    /// assert_eq!(SpecialSymbol::from_symbol_string("@0@"), Ok(SpecialSymbol::EPSILON));
941    /// assert_eq!(SpecialSymbol::from_symbol_string("@_IDENTITY_SYMBOL_@"), Ok(SpecialSymbol::IDENTITY));
942    /// assert_eq!(SpecialSymbol::from_symbol_string("@_UNKNOWN_SYMBOL_@"), Ok(SpecialSymbol::UNKNOWN));
943    /// assert_eq!(SpecialSymbol::from_symbol_string("@_GARBAGE_SYMBOL_@"), Err("Not a valid SpecialSymbol: \"@_GARBAGE_SYMBOL_@\"".to_string()));
944    /// ```
945    pub fn from_symbol_string(symbol: &str) -> KFSTResult<Self> {
946        SpecialSymbol::_from_symbol_string(symbol)
947    }
948}
949
950impl PartialOrd for SpecialSymbol {
951    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
952        Some(self.cmp(other))
953    }
954}
955
956impl Ord for SpecialSymbol {
957    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
958        // This should be clean; there is a bijection between all special symbols and a subset of strings
959        self.with_symbol(|self_str| {
960            other.with_symbol(|other_str| {
961                (other_str.chars().count(), &self_str).cmp(&(self_str.chars().count(), &other_str))
962            })
963        })
964    }
965}
966
967#[cfg_attr(feature = "python", pymethods)]
968impl SpecialSymbol {
969    /// Whether this symbol is ε. (See [Symbol::is_epsilon] for the general case)
970    ///
971    /// Returns true for [SpecialSymbol::EPSILON] and false otherwise.
972    pub fn is_epsilon(&self) -> bool {
973        self == &SpecialSymbol::EPSILON
974    }
975
976    /// Whether this symbol is unknown. (See [Symbol::is_unknown] for the general case)
977    ///
978    /// Always returns false.
979    pub fn is_unknown(&self) -> bool {
980        false
981    }
982
983    /// Textual representation of this symbol. Note that the `@0@` synonym for `@_EPSILON_SYMBOL_@` is always converted to the long form.
984    ///
985    /// ```rust
986    /// use kfst_rs::SpecialSymbol;
987    /// assert_eq!(SpecialSymbol::from_symbol_string("@0@").unwrap().get_symbol(), "@_EPSILON_SYMBOL_@".to_string())
988    /// ```
989    pub fn get_symbol(&self) -> String {
990        self.with_symbol(|x| x.to_string())
991    }
992
993    #[cfg(feature = "python")]
994    #[staticmethod]
995    fn from_symbol_string(symbol: &str) -> KFSTResult<Self> {
996        SpecialSymbol::_from_symbol_string(symbol)
997    }
998
999    #[cfg(feature = "python")]
1000    #[staticmethod]
1001    fn is_special_symbol(symbol: &str) -> bool {
1002        SpecialSymbol::from_symbol_string(symbol).is_ok()
1003    }
1004
1005    #[cfg(not(feature = "python"))]
1006    /// Is `symbol` a valid [SpecialSymbol]?
1007    /// Attempts to parse `symbol` using [SpecialSymbol::from_symbol_string] and returns `true` if this succeeds.
1008    /// ```rust
1009    /// use kfst_rs::SpecialSymbol;
1010    ///
1011    /// assert!(SpecialSymbol::is_special_symbol("@0@"));
1012    /// assert!(SpecialSymbol::is_special_symbol("@_EPSILON_SYMBOL_@"));
1013    /// assert!(SpecialSymbol::is_special_symbol("@_IDENTITY_SYMBOL_@"));
1014    /// assert!(SpecialSymbol::is_special_symbol("@_UNKNOWN_SYMBOL_@"));
1015    /// assert_eq!(SpecialSymbol::is_special_symbol("@_GARBAGE_SYMBOL_@"), false);
1016    /// ```
1017    pub fn is_special_symbol(symbol: &str) -> bool {
1018        SpecialSymbol::from_symbol_string(symbol).is_ok()
1019    }
1020}
1021
1022impl std::fmt::Debug for SpecialSymbol {
1023    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1024        self.with_symbol(|symbol| write!(f, "{symbol}"))
1025    }
1026}
1027
1028impl std::fmt::Debug for StringSymbol {
1029    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1030        self.with_symbol(|symbol| write!(f, "{symbol}"))
1031    }
1032}
1033
#[cfg(feature = "python")]
#[pyfunction]
fn from_symbol_string(symbol: &str, py: Python) -> PyResult<Py<PyAny>> {
    // Symbol::parse fails on input it cannot turn into any symbol (the empty string).
    // Report that to Python as a ValueError instead of panicking inside the extension module.
    match Symbol::parse(symbol) {
        Ok((_, parsed)) => parsed.into_py_any(py),
        Err(_) => Err(pyo3::exceptions::PyValueError::new_err(format!(
            "Not a valid Symbol: {symbol:?}"
        ))),
    }
}
1039
1040#[cfg(not(feature = "python"))]
1041/// Parse a string into a Symbol; see [Symbol::parse] for implementation details.
1042/// ```rust
1043/// use kfst_rs::{from_symbol_string, Symbol, StringSymbol, SpecialSymbol};
1044///
1045/// assert_eq!(from_symbol_string("example").unwrap(), Symbol::String(StringSymbol::parse("example").unwrap().1));
1046/// assert_eq!(from_symbol_string("@_EPSILON_SYMBOL_@").unwrap(), Symbol::Special(SpecialSymbol::EPSILON));
1047///
1048/// ```
1049pub fn from_symbol_string(symbol: &str) -> Option<Symbol> {
1050    Symbol::parse(symbol).ok().map(|(_, sym)| sym)
1051}
1052
/// A wrapper enum for different concrete symbol types. It exists to provide a dense tagged union avoiding dynamic dispatch.
/// It also deals with converting symbols between Rust and Python when using kfst_rs as a Python library. (crate feature "python")
///
/// [Symbol::parse] only ever produces the `Special`, `Flag` and `String` variants.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Symbol {
    /// Wrapper around [SpecialSymbol].
    Special(SpecialSymbol),
    /// Wrapper around [FlagDiacriticSymbol].
    Flag(FlagDiacriticSymbol),
    /// Wrapper around [StringSymbol].
    String(StringSymbol),
    #[cfg(feature = "python")]
    /// Wrapper around [PyObjectSymbol] (only built with crate feature "python")
    External(PyObjectSymbol),
    /// Wrapper around [RawSymbol].
    Raw(RawSymbol),
}
1069
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for Symbol {
    type Target = PyAny;
    type Output = Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    // Unwraps the enum and hands the concrete symbol to Python.
    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        match self {
            Self::Special(inner) => inner.into_bound_py_any(py),
            Self::Flag(inner) => inner.into_bound_py_any(py),
            Self::String(inner) => inner.into_bound_py_any(py),
            Self::External(inner) => inner.into_bound_py_any(py),
            Self::Raw(inner) => inner.into_bound_py_any(py),
        }
    }
}
1088
1089impl PartialOrd for Symbol {
1090    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1091        Some(self.cmp(other))
1092    }
1093}
1094
impl Ord for Symbol {
    /// cmp for Symbol tries to formalize the following:
    /// - There is a set of *tokenizable symbol types*: Special, Flag and String (those that split_to_symbols can return).
    /// - Between those types, there exists a slightly modified lexicographical ordering as such:
    ///   primarily: a < b if the string representing a is longer than the string representing b
    ///   secondarily: a < b if the equal-length string representing a is ordered before the string representing b per string's cmp
    ///   tertiarily: flag < special < string
    ///   quaternarily: within a single symbol type, its own sort implementation holds
    /// - Raw symbols are lesser than any other built-in symbols (external symbols can do what they want); they are internally ordered per their own cmp
    /// - external symbols defer to their own sorting logic; if one member of the comparison is an external symbol, it is called upon to do the comparison
    ///
    /// # Panics
    ///
    /// With crate feature "python", panics if an external symbol has no `__lt__` / `__eq__`,
    /// if calling either of them fails, or if the returned value cannot be extracted as a `bool`.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // 1. Is the input made up of tokenizable symbols only?

        let self_is_tokenizable = matches!(
            self,
            Symbol::Special(_) | Symbol::Flag(_) | Symbol::String(_)
        );

        let other_is_tokenizable = matches!(
            other,
            Symbol::Special(_) | Symbol::Flag(_) | Symbol::String(_)
        );

        match (self_is_tokenizable, other_is_tokenizable) {
            // If both are tokenizable...
            (true, true) => {
                // Do we get an ordering from the strings? ("primarily" and "secondarily")

                // Use with_symbol to avoid a clone in the common cases of special symbols and string symbols
                // The lengths are deliberately crossed (other's length on the left): the longer string is the lesser one
                let result = self.with_symbol(|self_sym| {
                    other.with_symbol(|other_sym| {
                        (other_sym.chars().count(), &self_sym)
                            .cmp(&(self_sym.chars().count(), &other_sym))
                    })
                });
                if result != Ordering::Equal {
                    result
                } else {
                    match (self, other) {
                        // Do the types induce an ordering ("tertiarily")
                        (Symbol::Special(_), Symbol::Flag(_)) => Ordering::Greater,
                        (Symbol::Special(_), Symbol::String(_)) => Ordering::Less,
                        (Symbol::Flag(_), Symbol::Special(_)) => Ordering::Less,
                        (Symbol::Flag(_), Symbol::String(_)) => Ordering::Less,
                        (Symbol::String(_), Symbol::Special(_)) => Ordering::Greater,
                        (Symbol::String(_), Symbol::Flag(_)) => Ordering::Greater,

                        // Do we have two values of the same type => type-internal ordering holds
                        (Symbol::Special(a), Symbol::Special(b)) => a.cmp(b),
                        (Symbol::Flag(a), Symbol::Flag(b)) => a.cmp(b),
                        (Symbol::String(a), Symbol::String(b)) => a.cmp(b),

                        // Both sides were checked to be tokenizable above, so no other pairing can occur
                        _ => unreachable!(),
                    }
                }
            }
            // At least one is non-tokenizable; do we have external symbols?
            _ => {
                match (self, other) {
                    // External symbol on the left: ask it directly via left.__lt__(right) and left.__eq__(right)
                    #[cfg(feature = "python")]
                    (Symbol::External(left), right) => {
                        Python::with_gil(|py| {
                            // Strictly less than

                            if left
                                .value
                                .getattr(py, "__lt__")
                                .unwrap_or_else(|_| {
                                    panic!(
                                        "Symbol {} doesn't have a __lt__ implementation.",
                                        left.value
                                    )
                                })
                                .call1(py, (right.clone().into_pyobject(py).unwrap(),))
                                .unwrap_or_else(|_| {
                                    panic!(
                                        "__lt__ on symbol {} failed to return a value.",
                                        left.value
                                    )
                                })
                                .extract::<bool>(py)
                                .unwrap_or_else(|_| {
                                    panic!("__lt__ on symbol {} didn't return a bool.", left.value)
                                })
                            {
                                return Ordering::Less;
                            }

                            // Strictly equal

                            if left
                                .value
                                .getattr(py, "__eq__")
                                .unwrap_or_else(|_| {
                                    panic!(
                                        "Symbol {} doesn't have a __eq__ implementation.",
                                        left.value
                                    )
                                })
                                .call1(py, (right.clone().into_pyobject(py).unwrap(),))
                                .unwrap_or_else(|_| {
                                    panic!(
                                        "__eq__ on symbol {} failed to return a value.",
                                        left.value
                                    )
                                })
                                .extract::<bool>(py)
                                .unwrap_or_else(|_| {
                                    panic!("__eq__ on symbol {} didn't return a bool.", left.value)
                                })
                            {
                                return Ordering::Equal;
                            }

                            // Otherwise must be greater

                            Ordering::Greater
                        })
                    }
                    // External symbol only on the right: ask it via right.__lt__(left) and right.__eq__(left),
                    // then mirror the answer, since the result must describe `left` relative to `right`
                    #[cfg(feature = "python")]
                    (left, Symbol::External(right)) => {
                        Python::with_gil(|py| {
                            // right is strictly less than left => left is greater

                            if right
                                .value
                                .getattr(py, "__lt__")
                                .unwrap_or_else(|_| {
                                    panic!(
                                        "Symbol {} doesn't have a __lt__ implementation.",
                                        right.value
                                    )
                                })
                                .call1(py, (left.clone().into_pyobject(py).unwrap(),))
                                .unwrap_or_else(|_| {
                                    panic!(
                                        "__lt__ on symbol {} failed to return a value.",
                                        right.value
                                    )
                                })
                                .extract::<bool>(py)
                                .unwrap_or_else(|_| {
                                    panic!("__lt__ on symbol {} didn't return a bool.", right.value)
                                })
                            {
                                return Ordering::Greater;
                            }

                            // Strictly equal

                            if right
                                .value
                                .getattr(py, "__eq__")
                                .unwrap_or_else(|_| {
                                    panic!(
                                        "Symbol {} doesn't have a __eq__ implementation.",
                                        right.value
                                    )
                                })
                                .call1(py, (left.clone().into_pyobject(py).unwrap(),))
                                .unwrap_or_else(|_| {
                                    panic!(
                                        "__eq__ on symbol {} failed to return a value.",
                                        right.value
                                    )
                                })
                                .extract::<bool>(py)
                                .unwrap_or_else(|_| {
                                    panic!("__eq__ on symbol {} didn't return a bool.", right.value)
                                })
                            {
                                return Ordering::Equal;
                            }

                            // Otherwise right must be the greater one, ie. left is less

                            Ordering::Less
                        })
                    }

                    // Do we have two raw symbols?
                    (Symbol::Raw(a), Symbol::Raw(b)) => a.cmp(b),

                    // Raw symbols are lesser
                    (Symbol::Raw(_), _) => Ordering::Less,
                    (_, Symbol::Raw(_)) => Ordering::Greater,

                    // This branch is only entered when at least one side is non-tokenizable; those cases are handled above
                    _ => unreachable!(),
                }
            }
        }
    }
}
1288
1289impl Symbol {
1290    /// Is this symbol to be treated as an ε symbol?
1291    /// ε symbols get matched without consuming input.
1292    /// The simplest ε symbol is the one defined in [SpecialSymbol::EPSILON] and represented interchangeably by `@0@` and `@_EPSILON_SYMBOL_@`.
1293    /// All [FlagDiacriticSymbols](FlagDiacriticSymbol) are also ε symbols, as they do not consume input.
1294    /// Their string representations are of the form `@X.A@` or `@X.A.B@` where `X` is a [FlagDiacriticType] and `A` and `B` are arbitrary strings.
1295    /// [FST::run_fst] (and thus [FST::lookup]) drops any symbols on the output side for which this methods returns `true`.
1296    pub fn is_epsilon(&self) -> bool {
1297        match self {
1298            Symbol::Special(special_symbol) => special_symbol.is_epsilon(),
1299            Symbol::Flag(flag_diacritic_symbol) => flag_diacritic_symbol.is_epsilon(),
1300            Symbol::String(string_symbol) => string_symbol.is_epsilon(),
1301            #[cfg(feature = "python")]
1302            Symbol::External(py_object_symbol) => py_object_symbol.is_epsilon(),
1303            Symbol::Raw(raw_symbol) => raw_symbol.is_epsilon(),
1304        }
1305    }
1306
1307    /// Is this symbol to be treated as an unknown symbol?
1308    /// Unknown symbols are accepted by the [`@_IDENTITY_SYMBOL_@`](SpecialSymbol::IDENTITY) and [`@_UNKNOWN_SYMBOL_@`](SpecialSymbol::UNKNOWN) special symbols.
1309    pub fn is_unknown(&self) -> bool {
1310        match self {
1311            Symbol::Special(special_symbol) => special_symbol.is_unknown(),
1312            Symbol::Flag(flag_diacritic_symbol) => flag_diacritic_symbol.is_unknown(),
1313            Symbol::String(string_symbol) => string_symbol.is_unknown(),
1314            #[cfg(feature = "python")]
1315            Symbol::External(py_object_symbol) => py_object_symbol.is_unknown(),
1316            Symbol::Raw(raw_symbol) => raw_symbol.is_unknown(),
1317        }
1318    }
1319
1320    /// Get the string-representation of this symbol.
1321    pub fn get_symbol(&self) -> String {
1322        match self {
1323            Symbol::Special(special_symbol) => special_symbol.get_symbol(),
1324            Symbol::Flag(flag_diacritic_symbol) => flag_diacritic_symbol.get_symbol(),
1325            Symbol::String(string_symbol) => string_symbol.get_symbol(),
1326            #[cfg(feature = "python")]
1327            Symbol::External(py_object_symbol) => py_object_symbol.get_symbol(),
1328            Symbol::Raw(raw_symbol) => raw_symbol.get_symbol(),
1329        }
1330    }
1331
1332    /// Perform an operation on the &str representation of this symbol
1333    /// In some cases (StringSymbol and SpecialSymbol), this avoids a clone
1334    pub fn with_symbol<F, X>(&self, f: F) -> X
1335    where
1336        F: FnOnce(&str) -> X,
1337    {
1338        match self {
1339            Symbol::Special(special_symbol) => special_symbol.with_symbol(f),
1340            Symbol::Flag(flag_diacritic_symbol) => f(&flag_diacritic_symbol.get_symbol()),
1341            Symbol::String(string_symbol) => string_symbol.with_symbol(f),
1342            #[cfg(feature = "python")]
1343            Symbol::External(py_object_symbol) => f(&py_object_symbol.get_symbol()),
1344            Symbol::Raw(raw_symbol) => f(&raw_symbol.get_symbol()),
1345        }
1346    }
1347}
1348
1349impl Symbol {
1350    /// Parses a string into a [Symbol]. This tries the following conversions in order:
1351    ///
1352    /// 1. [FlagDiacriticSymbol] and the [Symbol::Flag] variant.
1353    /// 2. [SpecialSymbol] and the [Symbol::Special] variant.
1354    /// 3. [StringSymbol] and the [Symbol::String] variant.
1355    ///
1356    /// Therefore Symbol::Exernal (only built with feature "python") and [Symbol::Raw] variants cannot be constructed with this method.
1357    ///
1358    /// ```rust
1359    /// use kfst_rs::{Symbol, FlagDiacriticSymbol, SpecialSymbol, StringSymbol};
1360    ///
1361    /// assert_eq!(Symbol::parse("@D.X.Y@").unwrap().1, Symbol::Flag(FlagDiacriticSymbol::parse("@D.X.Y@").unwrap().1));
1362    /// assert_eq!(Symbol::parse("@_EPSILON_SYMBOL_@").unwrap().1, Symbol::Special(SpecialSymbol::parse("@_EPSILON_SYMBOL_@").unwrap().1));
1363    /// assert_eq!(Symbol::parse("ladybird").unwrap().1, Symbol::String(StringSymbol::parse("ladybird").unwrap().1));
1364    /// ```
1365    ///
1366    /// Fails when if and only if [StringSymbol::parse] fails: on an empty string.
1367    pub fn parse(symbol: &str) -> nom::IResult<&str, Symbol> {
1368        let mut parser = alt((
1369            |x| {
1370                (FlagDiacriticSymbol::parse, nom::combinator::eof)
1371                    .parse(x)
1372                    .map(|y| (y.0, Symbol::Flag(y.1 .0)))
1373            },
1374            |x| {
1375                (SpecialSymbol::parse, nom::combinator::eof)
1376                    .parse(x)
1377                    .map(|y| (y.0, Symbol::Special(y.1 .0)))
1378            },
1379            |x| StringSymbol::parse(x).map(|y| (y.0, Symbol::String(y.1))),
1380        ));
1381        parser.parse(symbol)
1382    }
1383}
1384
#[cfg(feature = "python")]
impl FromPyObject<'_> for Symbol {
    // Tries each concrete symbol type in turn: Special, Flag, String, Raw and finally External.
    // The error of the last attempt (External) is the one reported if nothing matches.
    fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
        if let Ok(special) = ob.extract::<SpecialSymbol>() {
            return Ok(Symbol::Special(special));
        }
        if let Ok(flag) = ob.extract::<FlagDiacriticSymbol>() {
            return Ok(Symbol::Flag(flag));
        }
        if let Ok(string) = ob.extract::<StringSymbol>() {
            return Ok(Symbol::String(string));
        }
        if let Ok(raw) = ob.extract::<RawSymbol>() {
            return Ok(Symbol::Raw(raw));
        }
        ob.extract::<PyObjectSymbol>().map(Symbol::External)
    }
}
#[derive(Clone, Debug, PartialEq, Hash, PartialOrd, Eq, Ord)]
#[readonly::make]
/// The flag state of an [FSTState]:
/// ```no_test
/// (name -> (direction of setting where true is positive, value))
/// ```
/// name and value are interned string indices.
/// This is generally an immutable collection.
///
/// Internally the entries are stored as `(name, direction, value)` triples in a vector ordered by `name`;
/// [FlagMap::get], [FlagMap::insert] and [FlagMap::remove] locate entries in it by binary search.
pub struct FlagMap(Vec<(u32, bool, u32)>);
1405
1406/// Construct a FlagMap form an iterator
1407/// Notably, an iterator over `(String, (bool, String))` is what `HashMap<String, (bool, String)>` offers.
1408impl FromIterator<(String, (bool, String))> for FlagMap {
1409    fn from_iter<T: IntoIterator<Item = (String, (bool, String))>>(iter: T) -> Self {
1410        let mut vals: Vec<_> = iter
1411            .into_iter()
1412            .map(|(a, (b, c))| (intern(a), b, intern(c)))
1413            .collect();
1414        vals.sort();
1415        FlagMap(vals)
1416    }
1417}
1418
1419/// Construct a FlagMap form an iterator
1420/// Notably, an iterator over `(String, (bool, String))` is what can be collected into a `HashMap<String, (bool, String)>`.
1421impl<T> From<T> for FlagMap
1422where
1423    T: IntoIterator<Item = (String, (bool, String))>,
1424{
1425    fn from(value: T) -> Self {
1426        FlagMap::from_iter(value)
1427    }
1428}
1429
1430impl FlagMap {
1431    /// Create a new empty FlagMap
1432    pub fn new() -> FlagMap {
1433        FlagMap(vec![])
1434    }
1435
1436    /// Create a clone of this FlagMap with a specific flag removed
1437    pub fn remove(&self, flag: u32) -> FlagMap {
1438        let pp = self.0.partition_point(|v| v.0 < flag);
1439        if pp < self.0.len() && self.0[pp].0 == flag {
1440            let mut new_vals = self.0.clone();
1441            new_vals.remove(pp);
1442            FlagMap(new_vals)
1443        } else {
1444            self.clone()
1445        }
1446    }
1447
1448    /// Get the value and direction of setting of a flag or `None` if the flag is not set.
1449    pub fn get(&self, flag: u32) -> Option<(bool, u32)> {
1450        let pp = self.0.partition_point(|v| v.0 < flag);
1451        if pp < self.0.len() && self.0[pp].0 == flag {
1452            Some((self.0[pp].1, self.0[pp].2))
1453        } else {
1454            None
1455        }
1456    }
1457
1458    /// Create a clone of this FlagMap with a specific flag inserted.
1459    /// Overwrites a flag if it already exists.
1460    pub fn insert(&self, flag: u32, value: (bool, u32)) -> FlagMap {
1461        let pp = self.0.partition_point(|v| v.0 < flag);
1462        let mut new_vals = self.0.clone();
1463        if pp == self.0.len() || self.0[pp].0 != flag {
1464            new_vals.insert(pp, (flag, value.0, value.1));
1465        } else {
1466            new_vals[pp] = (flag, value.0, value.1);
1467        }
1468        FlagMap(new_vals)
1469    }
1470}
1471
1472impl Default for FlagMap {
1473    /// Construct the empty flagmap
1474    fn default() -> Self {
1475        Self::new()
1476    }
1477}
1478
#[cfg(feature = "python")]
impl FromPyObject<'_> for FlagMap {
    // Builds a FlagMap from any Python object offering an `items()` method that yields
    // `(str, (bool, str))` pairs. Iteration and conversion failures are returned as Python
    // errors rather than panicking.
    fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
        let mut as_map = ob
            .getattr("items")?
            .call0()?
            .try_iter()?
            .map(|item| -> PyResult<(u32, bool, u32)> {
                let (key, (direction, value)): (String, (bool, String)) = item?.extract()?;
                Ok((intern(key), direction, intern(value)))
            })
            .collect::<PyResult<Vec<(u32, bool, u32)>>>()?;
        // Lookups binary-search the entries, so they must be in sorted order
        as_map.sort();
        Ok(FlagMap(as_map))
    }
}
1493
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for FlagMap {
    type Target = PyAny;
    type Output = Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    // Converts the map into an `immutables.Map` instance on the Python side.
    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        // Turn the interned triples back into (name, (direction, value)) pairs
        let mut pairs = Vec::with_capacity(self.0.len());
        for (name, direction, value) in self.0 {
            pairs.push((deintern(name), (direction, deintern(value))));
        }
        let pairs = pairs.into_pyobject(py)?;

        let map_class = PyModule::import(py, "immutables")?.getattr("Map")?;
        // Allocate with __new__ and then initialise with __init__ from the pairs
        let instance = map_class.call_method1("__new__", (&map_class,))?;
        map_class.call_method1("__init__", (&instance, pairs))?;
        Ok(instance)
    }
}
1516
1517// transducer.py
1518
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
/// The linked list used to represent the transduction outputs.
/// A linked list is used here, as the transduced sequences of states share prefixes.
/// It is internally in reverse order: the outermost node holds the most recently added item.
/// The IntoIterator implementation clones items into temporary storage.
pub enum FSTLinkedList<T> {
    /// A node: one `(symbol, value)` item followed by the shared rest of the list (the items added before it).
    Some((Symbol, T), Arc<FSTLinkedList<T>>),
    /// The empty list; also terminates every non-empty list.
    None,
}
1528
1529impl<T: Clone + Default> IntoIterator for FSTLinkedList<T> {
1530    type Item = (Symbol, T);
1531
1532    type IntoIter = <Vec<<FSTLinkedList<T> as IntoIterator>::Item> as IntoIterator>::IntoIter;
1533
1534    fn into_iter(self) -> Self::IntoIter {
1535        self.to_vec().into_iter()
1536    }
1537}
1538
1539impl<T> FromIterator<(Symbol, T)> for FSTLinkedList<T> {
1540    fn from_iter<I: IntoIterator<Item = (Symbol, T)>>(iter: I) -> Self {
1541        let mut symbol_mappings = FSTLinkedList::None;
1542        for (symbol, input_index) in iter {
1543            symbol_mappings = FSTLinkedList::Some((symbol, input_index), Arc::new(symbol_mappings));
1544        }
1545        symbol_mappings
1546    }
1547}
1548
1549impl<T: Clone + Default> FSTLinkedList<T> {
1550    fn to_vec(self) -> Vec<(Symbol, T)> {
1551        let mut symbol_mappings = vec![];
1552        let mut symbol_mapping_state = &self;
1553        while let FSTLinkedList::Some(value, list) = symbol_mapping_state {
1554            symbol_mapping_state = list;
1555            symbol_mappings.push(value.clone());
1556        }
1557        symbol_mappings.into_iter().rev().collect()
1558    }
1559
1560    fn from_vec(output_symbols: Vec<Symbol>, input_indices: Vec<T>) -> FSTLinkedList<T> {
1561        if input_indices.is_empty() {
1562            output_symbols
1563                .into_iter()
1564                .zip(std::iter::repeat(T::default()))
1565                .collect()
1566        } else if input_indices.len() == output_symbols.len() {
1567            output_symbols.into_iter().zip(input_indices).collect()
1568        } else {
1569            panic!("Mismatch in input index and output symbol len: {} output symbols and {} input indices.", output_symbols.len(), input_indices.len());
1570        }
1571    }
1572}
1573
#[cfg(not(feature = "python"))]
/// Without the "python" crate feature, `FSTState<T>` is simply an alias for [InternalFSTState].
pub type FSTState<T> = InternalFSTState<T>;
1576
// Python-facing wrapper around InternalFSTState (only built with crate feature "python").
#[cfg(feature = "python")]
#[cfg_attr(feature = "python", pyclass(frozen, eq, hash))]
#[derive(Clone, Debug, PartialEq, PartialOrd, Hash)]
pub struct FSTState {
    // `Ok` holds a state that carries input indices (`usize`);
    // `Err` holds a state without them (`()`), for which the Python `input_indices` getter returns None.
    payload: Result<InternalFSTState<usize>, InternalFSTState<()>>,
}
1583
#[cfg(feature = "python")]
#[pymethods]
impl FSTState {
    /// Create a state. Passing `None` as `input_indices` yields a state that
    /// does not track input indices; a list (the default is an empty one)
    /// yields a state that does.
    #[new]
    #[pyo3(signature = (state_num, path_weight=0.0, input_flags = FlagMap::new(), output_flags = FlagMap::new(), output_symbols=vec![], input_indices=Some(vec![])))]
    fn new(
        state_num: u64,
        path_weight: f64,
        input_flags: FlagMap,
        output_flags: FlagMap,
        output_symbols: Vec<Symbol>,
        input_indices: Option<Vec<usize>>,
    ) -> Self {
        let payload = match input_indices {
            Some(indices) => Ok(InternalFSTState::<usize>::new(
                state_num,
                path_weight,
                input_flags,
                output_flags,
                output_symbols,
                indices,
            )),
            None => Err(InternalFSTState::<()>::new(
                state_num,
                path_weight,
                input_flags,
                output_flags,
                output_symbols,
                vec![],
            )),
        };
        FSTState { payload }
    }

    /// Return a copy of this state that does not carry input indices.
    fn strip_indices(&self) -> FSTState {
        if let Ok(indexed) = &self.payload {
            return FSTState {
                payload: Err(indexed.clone().convert_indices()),
            };
        }
        self.clone()
    }

    /// Return a copy of this state that carries input indices.
    fn ensure_indices(&self) -> FSTState {
        if let Err(plain) = &self.payload {
            return FSTState {
                payload: Ok(plain.clone().convert_indices()),
            };
        }
        self.clone()
    }

    /// The output symbols accumulated so far, as a Python tuple.
    #[getter]
    fn output_symbols<'a>(&'a self, py: Python<'a>) -> Result<Bound<'a, PyTuple>, PyErr> {
        let symbols = match &self.payload {
            Ok(indexed) => indexed.output_symbols(),
            Err(plain) => plain.output_symbols(),
        };
        PyTuple::new(py, symbols)
    }

    /// The input indices as a Python tuple, or `None` if this state does not track them.
    #[getter]
    fn input_indices<'a>(&'a self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
        if let Ok(indexed) = &self.payload {
            let indices = indexed.input_indices().unwrap();
            return PyTuple::new(py, indices)?.into_bound_py_any(py);
        }
        PyNone::get(py).into_bound_py_any(py)
    }

    #[deprecated]
    /// Python-style string representation; delegates to the wrapped state.
    pub fn __repr__(&self) -> String {
        match &self.payload {
            Ok(indexed) => indexed.__repr__(),
            Err(plain) => plain.__repr__(),
        }
    }

    /// Number of the state in the FST.
    #[getter]
    pub fn state_num(&self) -> u64 {
        match &self.payload {
            Ok(indexed) => indexed.state_num,
            Err(plain) => plain.state_num,
        }
    }

    /// Weight of the path leading to this state.
    #[getter]
    pub fn path_weight(&self) -> f64 {
        match &self.payload {
            Ok(indexed) => indexed.path_weight,
            Err(plain) => plain.path_weight,
        }
    }

    /// A copy of the input-side flag map.
    #[getter]
    pub fn input_flags(&self) -> FlagMap {
        match &self.payload {
            Ok(indexed) => indexed.input_flags.clone(),
            Err(plain) => plain.input_flags.clone(),
        }
    }

    /// A copy of the output-side flag map.
    #[getter]
    pub fn output_flags(&self) -> FlagMap {
        match &self.payload {
            Ok(indexed) => indexed.output_flags.clone(),
            Err(plain) => plain.output_flags.clone(),
        }
    }
}
1699
#[cfg(feature = "python")]
impl<T: UsizeOrUnit + Default> FromPyObject<'_> for InternalFSTState<T> {
    /// Extract the Python-side [FSTState] wrapper and convert whichever
    /// variant it holds into the index type `T` requested by the caller.
    fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
        let payload = ob.extract::<FSTState>()?.payload;
        Ok(match payload {
            Ok(indexed) => indexed.convert_indices(),
            Err(plain) => plain.convert_indices(),
        })
    }
}
1710
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for InternalFSTState<usize> {
    type Target = FSTState;
    type Output = Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    /// Hand an index-tracking state to Python by wrapping it in the `Ok`
    /// variant of [FSTState].
    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let wrapper = FSTState { payload: Ok(self) };
        wrapper.into_pyobject(py)
    }
}
1723
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for InternalFSTState<()> {
    type Target = FSTState;
    type Output = Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    /// Hand an index-free state to Python by wrapping it in the `Err`
    /// variant of [FSTState].
    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let wrapper = FSTState { payload: Err(self) };
        wrapper.into_pyobject(py)
    }
}
1736
#[derive(Clone, Debug, PartialEq, PartialOrd)]
#[readonly::make]
/// A state in an [FST].
///
/// Besides the state number itself, this holds the path weight so far, the
/// output symbol sequence and the input and output flag state.
///
/// The type parameter `T` is the type of the input indices stored next to each
/// output symbol. The indices are useful in some cases, notably for
/// [FST::lookup_aligned]. If you do not need that method, you can pass the unit
/// type instead, which causes the book-keeping relating to indices to be
/// compiled away.
pub struct InternalFSTState<T> {
    /// Number of the state in the FST.
    pub state_num: u64,
    /// Sum of transition weights so far.
    pub path_weight: f64,
    /// Mapping from flags to what they are set to (input side).
    pub input_flags: FlagMap,
    /// Mapping from flags to what they are set to (output side).
    pub output_flags: FlagMap,
    /// Output side symbols and their alignments (input indices) for the transduction so far.
    pub symbol_mappings: FSTLinkedList<T>,
}
1759
1760impl<T: Clone + Default + Debug + UsizeOrUnit> InternalFSTState<T> {
1761    fn convert_indices<T2: UsizeOrUnit + Default>(self) -> InternalFSTState<T2> {
1762        let new_mappings = self
1763            .symbol_mappings
1764            .into_iter()
1765            .map(|(sym, idx)| (sym, T2::convert(idx)))
1766            .collect();
1767        InternalFSTState {
1768            state_num: self.state_num,
1769            path_weight: self.path_weight,
1770            input_flags: self.input_flags,
1771            output_flags: self.output_flags,
1772            symbol_mappings: new_mappings,
1773        }
1774    }
1775}
1776
1777impl<T> Default for InternalFSTState<T> {
1778    /// Produce a neutral start state: number 0, no weight, empty flags, empty input indices and empty output.
1779    fn default() -> Self {
1780        Self {
1781            state_num: 0,
1782            path_weight: 0.0,
1783            input_flags: FlagMap::new(),
1784            output_flags: FlagMap::new(),
1785            symbol_mappings: FSTLinkedList::None,
1786        }
1787    }
1788}
1789
1790impl<T: Hash> Hash for InternalFSTState<T> {
1791    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
1792        self.state_num.hash(state);
1793        self.path_weight.to_be_bytes().hash(state);
1794        self.input_flags.hash(state);
1795        self.output_flags.hash(state);
1796        self.symbol_mappings.hash(state);
1797    }
1798}
1799
/// Check a stored flag setting against a queried value.
///
/// `stored_val` is `(polarity, value)`: with polarity `true` the test passes
/// when the stored value equals `queried_val`; with polarity `false` it passes
/// when the two differ.
fn _test_flag(stored_val: &(bool, u32), queried_val: u32) -> bool {
    let &(polarity, stored) = stored_val;
    let values_match = stored == queried_val;
    polarity == values_match
}
1803
1804impl<T: Clone + Default> InternalFSTState<T> {
1805    fn _new(state: u64) -> Self {
1806        InternalFSTState {
1807            state_num: state,
1808            path_weight: 0.0,
1809            input_flags: FlagMap::new(),
1810            output_flags: FlagMap::new(),
1811            symbol_mappings: FSTLinkedList::None,
1812        }
1813    }
1814
1815    fn __new<F>(
1816        state: u64,
1817        path_weight: f64,
1818        input_flags: F,
1819        output_flags: F,
1820        output_symbols: Vec<Symbol>,
1821        input_indices: Vec<T>,
1822    ) -> Self
1823    where
1824        F: Into<FlagMap>,
1825    {
1826        InternalFSTState {
1827            state_num: state,
1828            path_weight,
1829            input_flags: input_flags.into(),
1830            output_flags: output_flags.into(),
1831            symbol_mappings: FSTLinkedList::from_vec(output_symbols, input_indices),
1832        }
1833    }
1834
1835    #[cfg(not(feature = "python"))]
1836    /// Construct a new FSTState. All arguments are per FSTState fields, except for the flag states.
1837    /// These are not a [FlagMap] but and IndexMap of (name -> (direction of setting where true is positively set, value))
1838    /// where name and value get interned.
1839    pub fn new<F>(
1840        state_num: u64,
1841        path_weight: f64,
1842        input_flags: F,
1843        output_flags: F,
1844        output_symbols: Vec<Symbol>,
1845        input_indices: Vec<T>,
1846    ) -> Self
1847    where
1848        F: Into<FlagMap>,
1849    {
1850        InternalFSTState::__new(
1851            state_num,
1852            path_weight,
1853            input_flags,
1854            output_flags,
1855            output_symbols,
1856            input_indices,
1857        )
1858    }
1859}
1860
1861impl<T: Clone + Debug + Default + UsizeOrUnit> InternalFSTState<T> {
1862    #[cfg(feature = "python")]
1863    fn new(
1864        state_num: u64,
1865        path_weight: f64,
1866        input_flags: FlagMap,
1867        output_flags: FlagMap,
1868        output_symbols: Vec<Symbol>,
1869        input_indices: Vec<T>,
1870    ) -> Self {
1871        InternalFSTState {
1872            state_num,
1873            path_weight,
1874            input_flags,
1875            output_flags,
1876            symbol_mappings: FSTLinkedList::from_vec(output_symbols, input_indices),
1877        }
1878    }
1879
1880    /// Return the output symbols for this state. Internally, they are (as of now) stored in a [FSTLinkedList]. Calling this method reconstructs a vector.
1881    /// ```rust
1882    /// use kfst_rs::{FSTState, FlagMap, Symbol, SpecialSymbol};
1883    ///
1884    /// let output_symbols = vec![
1885    ///     Symbol::Special(SpecialSymbol::EPSILON),
1886    ///     Symbol::Special(SpecialSymbol::UNKNOWN),
1887    ///     Symbol::Special(SpecialSymbol::IDENTITY)
1888    /// ];
1889    /// // The actual alignment (last argument ie. input_indices) here is nonsensical
1890    /// let state = FSTState::new(0, 0.0, FlagMap::new(), FlagMap::new(), output_symbols.clone(), vec![10, 20, 30]);
1891    // assert!(state.output_symbols() == orig);
1892    /// ```
1893    pub fn output_symbols(&self) -> Vec<Symbol> {
1894        self.symbol_mappings
1895            .clone()
1896            .into_iter()
1897            .map(|x| x.0)
1898            .collect()
1899    }
1900
1901    /// Return the output symbols for this state. Internally, they are (as of now) stored in a [FSTLinkedList]. Calling this method reconstructs a vector if T = usize and returns None otherwise.
1902    /// ```rust
1903    /// use kfst_rs::{FSTState, FlagMap, Symbol, SpecialSymbol};
1904    ///
1905    /// let output_symbols = vec![
1906    ///     Symbol::Special(SpecialSymbol::EPSILON),
1907    ///     Symbol::Special(SpecialSymbol::UNKNOWN),
1908    ///     Symbol::Special(SpecialSymbol::IDENTITY)
1909    /// ];
1910    /// // The actual alignment (last argument ie. input_indices) here is nonsensical
1911    /// let state = FSTState::new(0, 0.0, FlagMap::new(), FlagMap::new(), output_symbols.clone(), vec![10, 20, 30]);
1912    /// assert!(state.input_indices() == Some(vec![10, 20, 30]));
1913    /// let state2 = FSTState::new(0, 0.0, FlagMap::new(), FlagMap::new(), output_symbols.clone(), vec![(), (), ()]);
1914    /// assert!(state2.input_indices() == None);
1915    /// ```
1916    pub fn input_indices(&self) -> Option<Vec<usize>> {
1917        T::branch(
1918            || {
1919                Some(
1920                    self.symbol_mappings
1921                        .clone()
1922                        .into_iter()
1923                        .map(|x| x.1.as_usize())
1924                        .collect(),
1925                )
1926            },
1927            || None,
1928        )
1929    }
1930
1931    #[deprecated]
1932    /// Python-style string representation.
1933    pub fn __repr__(&self) -> String {
1934        format!(
1935            "FSTState({}, {}, {:?}, {:?}, {:?}, {:?})",
1936            self.state_num,
1937            self.path_weight,
1938            self.input_flags,
1939            self.output_flags,
1940            self.symbol_mappings
1941                .clone()
1942                .into_iter()
1943                .map(|x| x.0)
1944                .collect::<Vec<_>>(),
1945            self.symbol_mappings
1946                .clone()
1947                .into_iter()
1948                .map(|x| x.1)
1949                .collect::<Vec<_>>()
1950        )
1951    }
1952}
1953
/// Undo ATT escaping: turns `@_TAB_@` into a tab character and then
/// `@_SPACE_@` into a space character. The two substitutions are applied one
/// after the other over the whole string, in that order.
/// Open question: should newlines be handled somehow?
fn unescape_att_symbol(att_symbol: &str) -> String {
    const ESCAPES: [(&str, &str); 2] = [("@_TAB_@", "\t"), ("@_SPACE_@", " ")];
    ESCAPES
        .iter()
        .fold(att_symbol.to_string(), |text, &(escaped, literal)| {
            text.replace(escaped, literal)
        })
}
1961
/// Make a symbol safe for the ATT format: every tab becomes `@_TAB_@` and
/// every space becomes `@_SPACE_@`; all other characters pass through as-is.
/// Open question: should newlines be handled somehow?
fn escape_att_symbol(symbol: &str) -> String {
    let mut escaped = String::with_capacity(symbol.len());
    for ch in symbol.chars() {
        match ch {
            '\t' => escaped.push_str("@_TAB_@"),
            ' ' => escaped.push_str("@_SPACE_@"),
            other => escaped.push(other),
        }
    }
    escaped
}
1967
/// Abstraction over the two index types a state can carry: `usize` (indices
/// are tracked) and `()` (indices are not tracked).
pub trait UsizeOrUnit: Sized {
    /// The zero value of the index type.
    const ZERO: Self;
    /// The unit step of the index type.
    const INCREMENT: Self;
    /// Run exactly one of the two closures depending on the implementing type:
    /// the first for `usize`, the second for `()`.
    fn branch<F1, F2, B>(on_usize: F1, on_unit: F2) -> B
    where
        F1: FnOnce() -> B,
        F2: FnOnce() -> B;
    /// View the value as a `usize`; the unit type reports 0.
    fn as_usize(&self) -> usize;
    /// Convert a value of either index type into `Self`.
    fn convert<T: UsizeOrUnit>(x: T) -> Self;
}

impl UsizeOrUnit for usize {
    const ZERO: Self = 0;
    const INCREMENT: Self = 1;

    fn branch<F1, F2, B>(on_usize: F1, _on_unit: F2) -> B
    where
        F1: FnOnce() -> B,
        F2: FnOnce() -> B,
    {
        on_usize()
    }

    fn as_usize(&self) -> usize {
        *self
    }

    fn convert<T: UsizeOrUnit>(x: T) -> Self {
        T::as_usize(&x)
    }
}

impl UsizeOrUnit for () {
    const ZERO: Self = ();
    const INCREMENT: Self = ();

    fn branch<F1, F2, B>(_on_usize: F1, on_unit: F2) -> B
    where
        F1: FnOnce() -> B,
        F2: FnOnce() -> B,
    {
        on_unit()
    }

    fn as_usize(&self) -> usize {
        0
    }

    // The source value is dropped; the unit type carries no information.
    fn convert<T: UsizeOrUnit>(_x: T) -> Self {}
}
2020
#[cfg_attr(feature = "python", pyclass(frozen, get_all))]
#[readonly::make]
/// A finite state transducer.
/// Constructed using [FST::from_kfst_bytes] or [FST::from_att_rows] from an in-memory representation or [FST::from_att_file] and [FST::from_kfst_file] from the file system.
///
/// To run an existing transducer (here Voikko):
///
/// ```rust
/// use kfst_rs::{FSTState, FST};
/// use std::io::{self, Write};
///
/// // Read in transducer
///
/// # let pathtovoikko = "../pyvoikko/pyvoikko/voikko.kfst".to_string();
/// let fst = FST::from_kfst_file(pathtovoikko, true).unwrap();
/// // Alternatively, for ATT use FST::from_att_file
///
/// // Read in word to analyze
///
/// let mut buffer = String::new();
/// let stdin = io::stdin();
/// stdin.read_line(&mut buffer).unwrap();
/// buffer = buffer.trim().to_string();
///
/// // Do analysis proper
///
/// match fst.lookup(&buffer, FSTState::<()>::default(), true) {
///     Ok(result) => {
///         for (i, analysis) in result.into_iter().enumerate() {
///             println!("Analysis {}: {} ({})", i+1, analysis.0, analysis.1)
///         }
///     },
///     Err(err) => println!("No analysis: {:?}", err),
/// }
/// ```
/// Given the input "lentokoneessa", this gives the following analysis:
///
/// ```text
/// Analysis 1: [Lt][Xp]lentää[X]len[Ln][Xj]to[X]to[Sn][Ny][Bh][Bc][Ln][Xp]kone[X]konee[Sine][Ny]ssa (0)
/// ```
pub struct FST {
    /// A mapping from the index of a final state to its weight.
    pub final_states: IndexMap<u64, f64>,
    /// The transition rules of this FST: (state number -> (top symbol -> list of target state indices, bottom symbols and weights))
    pub rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>>,
    /// List of all the symbols in the transducer (useful for tokenization). Sorted in reverse order by length.
    pub symbols: Vec<Symbol>,
    /// Whether this FST is in debug mode; kept for compatibility with the Python implementation of KFST. Its effects on FST behaviour are undefined.
    #[deprecated]
    pub debug: bool,
}
2072
2073impl FST {
2074    fn _run_fst<T: Clone + UsizeOrUnit>(
2075        &self,
2076        input_symbols: &[Symbol],
2077        state: &InternalFSTState<T>,
2078        post_input_advance: bool,
2079        result: &mut Vec<(bool, bool, InternalFSTState<T>)>,
2080        input_symbol_index: usize,
2081        keep_non_final: bool,
2082    ) {
2083        let transitions = self.rules.get(&state.state_num);
2084        let isymbol = if input_symbols.len() - input_symbol_index == 0 {
2085            match self.final_states.get(&state.state_num) {
2086                Some(&weight) => {
2087                    // Update weight of state to account for weight of final state
2088                    result.push((
2089                        true,
2090                        post_input_advance,
2091                        InternalFSTState {
2092                            state_num: state.state_num,
2093                            path_weight: state.path_weight + weight,
2094                            input_flags: state.input_flags.clone(),
2095                            output_flags: state.output_flags.clone(),
2096                            symbol_mappings: state.symbol_mappings.clone(),
2097                        },
2098                    ));
2099                }
2100                None => {
2101                    // Not a final state
2102                    if keep_non_final {
2103                        result.push((false, post_input_advance, state.clone()));
2104                    }
2105                }
2106            }
2107            None
2108        } else {
2109            Some(&input_symbols[input_symbol_index])
2110        };
2111        if let Some(transitions) = transitions {
2112            for transition_isymbol in transitions.keys() {
2113                if transition_isymbol.is_epsilon() || isymbol == Some(transition_isymbol) {
2114                    self._transition(
2115                        input_symbols,
2116                        input_symbol_index,
2117                        state,
2118                        &transitions[transition_isymbol],
2119                        isymbol,
2120                        transition_isymbol,
2121                        result,
2122                        keep_non_final,
2123                    );
2124                }
2125            }
2126            if let Some(isymbol) = isymbol {
2127                if isymbol.is_unknown() {
2128                    if let Some(transition_list) =
2129                        transitions.get(&Symbol::Special(SpecialSymbol::UNKNOWN))
2130                    {
2131                        self._transition(
2132                            input_symbols,
2133                            input_symbol_index,
2134                            state,
2135                            transition_list,
2136                            Some(isymbol),
2137                            &Symbol::Special(SpecialSymbol::UNKNOWN),
2138                            result,
2139                            keep_non_final,
2140                        );
2141                    }
2142
2143                    if let Some(transition_list) =
2144                        transitions.get(&Symbol::Special(SpecialSymbol::IDENTITY))
2145                    {
2146                        self._transition(
2147                            input_symbols,
2148                            input_symbol_index,
2149                            state,
2150                            transition_list,
2151                            Some(isymbol),
2152                            &Symbol::Special(SpecialSymbol::IDENTITY),
2153                            result,
2154                            keep_non_final,
2155                        );
2156                    }
2157                }
2158            }
2159        }
2160    }
2161
2162    fn _transition<T: Clone + UsizeOrUnit>(
2163        &self,
2164        input_symbols: &[Symbol],
2165        input_symbol_index: usize,
2166        state: &InternalFSTState<T>,
2167        transitions: &[(u64, Symbol, f64)],
2168        isymbol: Option<&Symbol>,
2169        transition_isymbol: &Symbol,
2170        result: &mut Vec<(bool, bool, InternalFSTState<T>)>,
2171        keep_non_final: bool,
2172    ) {
2173        for (next_state, osymbol, weight) in transitions.iter() {
2174            let new_output_flags = _update_flags(osymbol, &state.output_flags);
2175            let new_input_flags = _update_flags(transition_isymbol, &state.input_flags);
2176
2177            match (new_output_flags, new_input_flags) {
2178                (Some(new_output_flags), Some(new_input_flags)) => {
2179                    let new_osymbol = match (isymbol, osymbol) {
2180                        (Some(isymbol), Symbol::Special(SpecialSymbol::IDENTITY)) => isymbol,
2181                        _ => osymbol,
2182                    };
2183
2184                    // Long and complicated maneuver to make a decision of what to do at compile time.
2185                    // If we want unit output anyway, we should just compile this condition away
2186
2187                    let new_symbol_mapping: FSTLinkedList<T> = if new_osymbol.is_epsilon() {
2188                        state.symbol_mappings.clone() // Easy case: nothing new to declare anyway
2189                    } else {
2190                        FSTLinkedList::Some(
2191                            (new_osymbol.clone(), T::convert(input_symbol_index)),
2192                            Arc::new(state.symbol_mappings.clone()),
2193                        )
2194                    };
2195                    let new_state = InternalFSTState {
2196                        state_num: *next_state,
2197                        path_weight: state.path_weight + *weight,
2198                        input_flags: new_input_flags,
2199                        output_flags: new_output_flags,
2200                        symbol_mappings: new_symbol_mapping,
2201                    };
2202                    if transition_isymbol.is_epsilon() {
2203                        self._run_fst(
2204                            input_symbols,
2205                            &new_state,
2206                            input_symbols.is_empty(),
2207                            result,
2208                            input_symbol_index,
2209                            keep_non_final,
2210                        );
2211                    } else {
2212                        self._run_fst(
2213                            input_symbols,
2214                            &new_state,
2215                            false,
2216                            result,
2217                            input_symbol_index + 1,
2218                            keep_non_final,
2219                        );
2220                    }
2221                }
2222                _ => continue,
2223            }
2224        }
2225    }
2226
2227    /// Construct an instance of FST from of rows matching those in an att file (see [FST::from_att_code]) that have been parsed into tuples.
2228    /// Thee representation is read:
2229    /// ```no_test
2230    /// Ok((number of a final state, weight of the final state))
2231    /// Err((source state of transition, target state of transition, top symbol, bottom symbol, weight))
2232    /// ```
2233    /// Debug is passed along to [FST::debug].
2234    pub fn from_att_rows(
2235        rows: Vec<Result<(u64, f64), (u64, u64, Symbol, Symbol, f64)>>,
2236        debug: bool,
2237    ) -> FST {
2238        let mut final_states: IndexMap<u64, f64> = IndexMap::new();
2239        let mut rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>> = IndexMap::new();
2240        let mut symbols: IndexSet<Symbol> = IndexSet::new();
2241        for line in rows.into_iter() {
2242            match line {
2243                Ok((state_number, state_weight)) => {
2244                    final_states.insert(state_number, state_weight);
2245                }
2246                Err((state_1, state_2, top_symbol, bottom_symbol, weight)) => {
2247                    rules.entry(state_1).or_default();
2248                    let handle = rules.get_mut(&state_1).unwrap();
2249                    if !handle.contains_key(&top_symbol) {
2250                        handle.insert(top_symbol.clone(), vec![]);
2251                    }
2252                    handle.get_mut(&top_symbol).unwrap().push((
2253                        state_2,
2254                        bottom_symbol.clone(),
2255                        weight,
2256                    ));
2257                    symbols.insert(top_symbol);
2258                    symbols.insert(bottom_symbol);
2259                }
2260            }
2261        }
2262        FST::from_rules(
2263            final_states,
2264            rules,
2265            symbols.into_iter().collect(),
2266            Some(debug),
2267        )
2268    }
2269
2270    fn _from_kfst_bytes(kfst_bytes: &[u8]) -> Result<FST, String> {
2271        // Ownership makes error handling such a pain that it makes more sense to just return an option
2272        // We need to parse part of the data from an owned buffer and it just makes this too comples
2273
2274        // Check that this is v0 kfst format
2275
2276        let mut header = nom::sequence::preceded(
2277            nom::bytes::complete::tag("KFST"),
2278            nom::number::complete::be_u16::<&[u8], ()>,
2279        );
2280        let (rest, version) = header
2281            .parse(kfst_bytes)
2282            .map_err(|_| "Failed to parse header")?;
2283        assert!(version == 0);
2284
2285        // Read metadata
2286
2287        let mut metadata = (
2288            nom::number::complete::be_u16::<&[u8], ()>,
2289            nom::number::complete::be_u32,
2290            nom::number::complete::be_u32,
2291            nom::number::complete::u8,
2292        );
2293        let (rest, (num_symbols, num_transitions, num_final_states, is_weighted)) = metadata
2294            .parse(rest)
2295            .map_err(|_| "Failed to parse metadata")?;
2296        let num_transitions: usize = num_transitions
2297            .try_into()
2298            .map_err(|_| "usize too small to represent transitions")?;
2299        let num_final_states: usize = num_final_states
2300            .try_into()
2301            .map_err(|_| "usize too small to represent final states")?;
2302        // Safest conversion I can think of; theoretically it should only be 1 or 0 but Python just defers to C and C doesn't have its act together on this.
2303        let is_weighted: bool = is_weighted != 0u8;
2304
2305        // Parse out symbols
2306
2307        let mut symbol = nom::multi::count(
2308            nom::sequence::terminated(nom::bytes::complete::take_until1("\0"), tag("\0")),
2309            num_symbols.into(),
2310        );
2311        let (rest, symbols) = symbol
2312            .parse(rest)
2313            .map_err(|_: nom::Err<()>| "Failed to parse symbol list")?;
2314        let symbol_strings: Vec<&str> = symbols
2315            .into_iter()
2316            .map(|x| std::str::from_utf8(x))
2317            .collect::<Result<Vec<&str>, _>>()
2318            .map_err(|x| format!("Some symbol was not valid utf-8: {x}"))?;
2319        let symbol_list: Vec<Symbol> = symbol_strings
2320            .iter()
2321            .map(|x| {
2322                Symbol::parse(x)
2323                    .map_err(|x| {
2324                        format!(
2325                            "Some symbol while valid utf8 was not a valid symbol specifier: {x}"
2326                        )
2327                    })
2328                    .and_then(|(extra, sym)| {
2329                        if extra.is_empty() {
2330                            Ok(sym)
2331                        } else {
2332                            Err(format!(
2333                                "Extra data after end of symbol {}: {extra:?}",
2334                                sym.get_symbol(),
2335                            ))
2336                        }
2337                    })
2338            })
2339            .collect::<Result<Vec<Symbol>, _>>()?;
2340        let symbol_objs: IndexSet<Symbol> = symbol_list.iter().cloned().collect();
2341
2342        // From here on, data is lzma-compressed
2343
2344        let mut decomp: Vec<u8> = Vec::new();
2345        let mut decoder = XzDecoder::new(rest);
2346        decoder
2347            .read_to_end(&mut decomp)
2348            .map_err(|_| "Failed to xz-decompress remainder of file")?;
2349
2350        // The decompressed data is - unavoidably - owned by the function
2351        // We promise an error type of &[u8], which we can't provide from here because of lifetimes
2352
2353        let transition_syntax = (
2354            nom::number::complete::be_u32::<&[u8], ()>,
2355            nom::number::complete::be_u32,
2356            nom::number::complete::be_u16,
2357            nom::number::complete::be_u16,
2358        );
2359        let weight_parser = if is_weighted {
2360            nom::number::complete::be_f64
2361        } else {
2362            |input| Ok((input, 0.0)) // Conjure up a default weight out of thin air
2363        };
2364        let (rest, file_rules) = many_m_n(
2365            num_transitions,
2366            num_transitions,
2367            (transition_syntax, weight_parser),
2368        )
2369        .parse(decomp.as_slice())
2370        .map_err(|_| "Broken transition table")?;
2371
2372        let (rest, final_states) = many_m_n(
2373            num_final_states,
2374            num_final_states,
2375            (nom::number::complete::be_u32, weight_parser),
2376        )
2377        .parse(rest)
2378        .map_err(|_| "Broken final states")?;
2379
2380        if !rest.is_empty() {
2381            Err(format!("lzma-compressed payload is {} bytes long when decompressed but given the header, there seems to be {} bytes extra.", decomp.len(), rest.len()))?;
2382        }
2383
2384        // We have a vec, we want a hash map and our numbers to be i64 instead of u32
2385
2386        let final_states = final_states
2387            .into_iter()
2388            .map(|(a, b)| (a.into(), b))
2389            .collect();
2390
2391        // These should be a hash map instead of a vector
2392
2393        let symbols = symbol_objs.into_iter().collect();
2394
2395        // We need to construct the right rule data structure
2396
2397        let mut rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>> = IndexMap::new();
2398        for ((from_state, to_state, top_symbol_idx, bottom_symbol_idx), weight) in
2399            file_rules.into_iter()
2400        {
2401            let from_state = from_state.into();
2402            let to_state = to_state.into();
2403            let top_symbol_idx: usize = top_symbol_idx.into();
2404            let bottom_symbol_idx: usize = bottom_symbol_idx.into();
2405            let top_symbol = symbol_list[top_symbol_idx].clone();
2406            let bottom_symbol = symbol_list[bottom_symbol_idx].clone();
2407            let handle = rules.entry(from_state).or_default();
2408            let vec = handle.entry(top_symbol).or_default();
2409            vec.push((to_state, bottom_symbol, weight));
2410        }
2411
2412        Ok(FST::from_rules(final_states, rules, symbols, None))
2413    }
2414
2415    fn _to_kfst_bytes(&self) -> Result<Vec<u8>, String> {
2416        // 1. Figure out if this transducer if weighted & count transitions
2417
2418        let mut weighted = false;
2419
2420        for (_, &weight) in self.final_states.iter() {
2421            if weight != 0.0 {
2422                weighted = true;
2423                break;
2424            }
2425        }
2426
2427        let mut transitions: u32 = 0;
2428
2429        for (_, transition_table) in self.rules.iter() {
2430            for transition in transition_table.values() {
2431                for (_, _, weight) in transition.iter() {
2432                    if (*weight) != 0.0 {
2433                        weighted = true;
2434                    }
2435                    transitions += 1;
2436                }
2437            }
2438        }
2439
2440        // Construct header
2441
2442        let mut result: Vec<u8> = "KFST".into();
2443        result.extend(0u16.to_be_bytes());
2444        let symbol_len: u16 = self
2445            .symbols
2446            .len()
2447            .try_into()
2448            .map_err(|x| format!("Too many symbols to represent as u16: {x}"))?;
2449        result.extend(symbol_len.to_be_bytes());
2450        result.extend(transitions.to_be_bytes());
2451        let num_states: u32 = self
2452            .final_states
2453            .len()
2454            .try_into()
2455            .map_err(|x| format!("Too many final states to represent as u32: {x}"))?;
2456        result.extend(num_states.to_be_bytes());
2457        result.push(weighted.into()); // Promises 0 for false and 1 for true
2458
2459        // Dump symbols
2460
2461        let mut sorted_syms: Vec<_> = self.symbols.iter().collect();
2462        sorted_syms.sort();
2463        for symbol in sorted_syms.iter() {
2464            result.extend(symbol.get_symbol().into_bytes());
2465            result.push(0); // Add null-terminators
2466        }
2467
2468        // lzma-compressed part of payload
2469
2470        let mut to_compress: Vec<u8> = vec![];
2471
2472        // Push transition table to compressible buffer
2473
2474        for (source_state, transition_table) in self.rules.iter() {
2475            for (top_symbol, transition) in transition_table.iter() {
2476                for (target_state, bottom_symbol, weight) in transition.iter() {
2477                    let source_state: u32 = (*source_state).try_into().map_err(|x| {
2478                        format!("Can't represent source state {source_state} as u32: {x}")
2479                    })?;
2480                    let target_state: u32 = (*target_state).try_into().map_err(|x| {
2481                        format!("Can't represent target state {target_state} as u32: {x}")
2482                    })?;
2483                    let top_index: u16 = sorted_syms
2484                        .binary_search(&top_symbol)
2485                        .map_err(|_| {
2486                            format!("Top symbol {top_symbol:?} not found in FST symbol list")
2487                        })
2488                        .and_then(|x| {
2489                            x.try_into().map_err(|x| {
2490                                format!("Can't represent top symbol index as u16: {x}")
2491                            })
2492                        })?;
2493                    let bottom_index: u16 = sorted_syms
2494                        .binary_search(&bottom_symbol)
2495                        .map_err(|_| {
2496                            format!("Bottom symbol {bottom_symbol:?} not found in FST symbol list")
2497                        })
2498                        .and_then(|x| {
2499                            x.try_into().map_err(|x| {
2500                                format!("Can't represent bottom symbol index as u16: {x}")
2501                            })
2502                        })?;
2503                    to_compress.extend(source_state.to_be_bytes());
2504                    to_compress.extend(target_state.to_be_bytes());
2505                    to_compress.extend(top_index.to_be_bytes());
2506                    to_compress.extend(bottom_index.to_be_bytes());
2507                    if weighted {
2508                        to_compress.extend(weight.to_be_bytes());
2509                    } else {
2510                        assert!(*weight == 0.0);
2511                    }
2512                }
2513            }
2514        }
2515
2516        // Push final states to compressible buffer
2517
2518        for (&final_state, weight) in self.final_states.iter() {
2519            let final_state: u32 = final_state
2520                .try_into()
2521                .map_err(|x| format!("Can't represent final state index as u32: {x}"))?;
2522            to_compress.extend(final_state.to_be_bytes());
2523            if weighted {
2524                to_compress.extend(weight.to_be_bytes());
2525            } else {
2526                assert!(*weight == 0.0);
2527            }
2528        }
2529
2530        // Compress compressible buffer
2531
2532        let mut compressed = vec![];
2533
2534        let mut encoder = XzEncoder::new(to_compress.as_slice(), 9);
2535        encoder
2536            .read_to_end(&mut compressed)
2537            .map_err(|x| format!("Failed while compressing with lzma_rs: {x}"))?;
2538        result.extend(compressed);
2539
2540        Ok(result)
2541    }
2542
2543    fn _from_rules(
2544        final_states: IndexMap<u64, f64>,
2545        rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>>,
2546        symbols: HashSet<Symbol>,
2547        debug: Option<bool>,
2548    ) -> FST {
2549        let mut new_symbols: Vec<Symbol> = symbols.into_iter().collect();
2550        // Sort by normal comparison but in reverse; this guarantees reverse order by length and also
2551        // That different-by-symbol-string symbols get treated differently
2552        new_symbols.sort();
2553        // Sort rules such that epsilons are at the start
2554        let mut new_rules = IndexMap::new();
2555        for (target_node, rulebook) in rules {
2556            let mut new_rulebook = rulebook;
2557            new_rulebook.sort_by(|a, _, b, _| b.is_epsilon().cmp(&a.is_epsilon()));
2558            new_rules.insert(target_node, new_rulebook);
2559        }
2560        FST {
2561            final_states,
2562            rules: new_rules,
2563            symbols: new_symbols,
2564            debug: debug.unwrap_or(false),
2565        }
2566    }
2567
    /// Construct an instance of FST from the fields that make up FST. (See [FST::final_states], [FST::rules], [FST::symbols] and [FST::debug] for more information.)
    ///
    /// The parts are normalised before they are stored: `symbols` is kept as a sorted vector,
    /// the transitions of every state are reordered so that epsilon input symbols come first,
    /// and `debug` defaults to `false` when `None` is given.
    #[cfg(not(feature = "python"))]
    pub fn from_rules(
        final_states: IndexMap<u64, f64>,
        rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>>,
        symbols: HashSet<Symbol>,
        debug: Option<bool>,
    ) -> FST {
        FST::_from_rules(final_states, rules, symbols, debug)
    }
2578
2579    fn _from_att_file(att_file: String, debug: bool) -> KFSTResult<FST> {
2580        // Debug should default to false, pyo3 doesn't make that particularly easy
2581        match File::open(Path::new(&att_file)) {
2582            Ok(mut file) => {
2583                let mut att_code = String::new();
2584                file.read_to_string(&mut att_code).map_err(|err| {
2585                    io_error::<()>(format!("Failed to read from file {att_file}:\n{err}"))
2586                        .unwrap_err()
2587                })?;
2588                FST::from_att_code(att_code, debug)
2589            }
2590            Err(err) => io_error(format!("Failed to open file {att_file}:\n{err}")),
2591        }
2592    }
2593
    #[cfg(not(feature = "python"))]
    /// Construct an instance of FST from ATT code that resides on the file system.
    /// See [FST::from_att_code] for more details of what ATT code is.
    ///
    /// `att_file` is the path of the file to read and `debug` is handed on to
    /// [FST::from_att_code]. An error is returned if the file cannot be opened or read,
    /// or if its contents fail to parse.
    pub fn from_att_file(att_file: String, debug: bool) -> KFSTResult<FST> {
        FST::_from_att_file(att_file, debug)
    }
2600
2601    fn _from_att_code(att_code: String, debug: bool) -> KFSTResult<FST> {
2602        let mut rows: Vec<Result<(u64, f64), (u64, u64, Symbol, Symbol, f64)>> = vec![];
2603
2604        for (lineno, line) in att_code.lines().enumerate() {
2605            let elements: Vec<&str> = line.split("\t").collect();
2606            if elements.len() == 1 || elements.len() == 2 {
2607                let state = elements[0].parse::<u64>().ok();
2608                let weight = if elements.len() == 1 {
2609                    Some(0.0)
2610                } else {
2611                    elements[1].parse::<f64>().ok()
2612                };
2613                match (state, weight) {
2614                    (Some(state), Some(weight)) => {
2615                        rows.push(Ok((state, weight)));
2616                    }
2617                    _ => {
2618                        return value_error(format!(
2619                            "Failed to parse att code on line {lineno}:\n{line}",
2620                        ))
2621                    }
2622                }
2623            } else if elements.len() == 4 || elements.len() == 5 {
2624                let state_1 = elements[0].parse::<u64>().ok();
2625                let state_2 = elements[1].parse::<u64>().ok();
2626                let unescaped_sym_1 = unescape_att_symbol(elements[2]);
2627                let symbol_1 = Symbol::parse(&unescaped_sym_1).ok();
2628                let unescaped_sym_2 = unescape_att_symbol(elements[3]);
2629                let symbol_2 = Symbol::parse(&unescaped_sym_2).ok();
2630                let weight = if elements.len() == 4 {
2631                    Some(0.0)
2632                } else {
2633                    elements[4].parse::<f64>().ok()
2634                };
2635                match (state_1, state_2, symbol_1, symbol_2, weight) {
2636                    (
2637                        Some(state_1),
2638                        Some(state_2),
2639                        Some(("", symbol_1)),
2640                        Some(("", symbol_2)),
2641                        Some(weight),
2642                    ) => {
2643                        rows.push(Err((state_1, state_2, symbol_1, symbol_2, weight)));
2644                    }
2645                    _ => {
2646                        return value_error(format!(
2647                            "Failed to parse att code on line {lineno}:\n{line}",
2648                        ));
2649                    }
2650                }
2651            }
2652        }
2653        KFSTResult::Ok(FST::from_att_rows(rows, debug))
2654    }
2655
    #[cfg(not(feature = "python"))]
    /// Construct an FST instance from the AT&T text representation. See e.g. [Apertium's wiki](https://wiki.apertium.org/wiki/ATT_format). The `debug` argument is passed to [FST::debug].
    /// Both the weighted and unweighted versions are supported:
    ///
    /// ```rust
    /// use kfst_rs::FST;
    ///
    /// // With weights
    ///
    /// let weighted = r#"0	1	c	c	1.000000
    /// 0	2	d	d	2.000000
    /// 1	3	a	a	0.000000
    /// 2	4	o	o	0.000000
    /// 3	5	t	t	0.000000
    /// 4	5	g	g	0.000000
    /// 5	6	s	s	10.000000
    /// 5	0.000000
    /// 6	0.000000"#;
    ///
    /// // to_att_code doesn't guarantee that the ATT file is laid out in the same order
    ///
    /// assert_eq!(FST::from_att_code(weighted.to_string(), false).unwrap().to_att_code(), r#"5
    /// 6
    /// 0	1	c	c	1
    /// 0	2	d	d	2
    /// 1	3	a	a
    /// 2	4	o	o
    /// 3	5	t	t
    /// 4	5	g	g
    /// 5	6	s	s	10"#);
    ///
    ///
    /// // Unweighted
    ///
    /// FST::from_att_code(r#"0	1	c	c
    /// 0	2	d	d
    /// 1	3	a	a
    /// 2	4	o	o
    /// 3	5	t	t
    /// 4	5	g	g
    /// 5	6	s	s
    /// 5
    /// 6"#.to_string(), false).unwrap();
    /// ```
    /// `debug` is passed along to [FST::debug].
    ///
    /// kfst attempts to maintain compatibility with the hfst interpretation of AT&T. This includes the `@_TAB_@` and `@_SPACE_@` special sequences.
    ///
    /// ```rust
    /// use kfst_rs::{FST, FSTState};
    ///
    /// // @_TAB_@ and @_SPACE_@ escapes can appear both as top and bottom symbols
    ///
    /// let f = FST::from_att_code(r#"4
    /// 0	1	@_TAB_@	a
    /// 1	2	b	@_TAB_@x
    /// 2	3	@_SPACE_@	c
    /// 3	4	d	@_SPACE_@
    /// "#.to_string(), false).unwrap();
    ///
    /// // The read-in transducer then correctly handles tabs and spaces
    ///
    /// assert_eq!(f.lookup("\tb d", FSTState::<()>::default(), false).unwrap(), vec![("a\txc ".to_string(), 0.0)]);
    /// ```
    ///
    /// # Errors
    ///
    /// Lines are split on tab characters. A line with 1, 2, 4 or 5 fields must parse as a final
    /// state or a transition; if one of its state numbers, symbols or weights cannot be parsed,
    /// an error quoting the offending line is returned. Lines with any other number of fields
    /// are skipped.
    pub fn from_att_code(att_code: String, debug: bool) -> KFSTResult<FST> {
        FST::_from_att_code(att_code, debug)
    }
2723
2724    fn _from_kfst_file(kfst_file: String, debug: bool) -> KFSTResult<FST> {
2725        match File::open(Path::new(&kfst_file)) {
2726            Ok(mut file) => {
2727                let mut kfst_bytes: Vec<u8> = vec![];
2728                file.read_to_end(&mut kfst_bytes).map_err(|err| {
2729                    io_error::<()>(format!("Failed to read from file {kfst_file}:\n{err}"))
2730                        .unwrap_err()
2731                })?;
2732                FST::from_kfst_bytes(&kfst_bytes, debug)
2733            }
2734            Err(err) => io_error(format!("Failed to open file {kfst_file}:\n{err}")),
2735        }
2736    }
2737
    /// Construct an FST instance from KFST binary representation that resides on the file system.
    /// See [FST::from_kfst_bytes] for converting memory-resident KFST binary representation into FST instances.
    /// `debug` is passed along to [FST::debug].
    ///
    /// `kfst_file` is the path of the file to read. An error is returned if the file cannot be
    /// opened or read, or if its contents cannot be decoded.
    #[cfg(not(feature = "python"))]
    pub fn from_kfst_file(kfst_file: String, debug: bool) -> KFSTResult<FST> {
        FST::_from_kfst_file(kfst_file, debug)
    }
2745
2746    #[allow(unused)]
2747    fn __from_kfst_bytes(kfst_bytes: &[u8], debug: bool) -> KFSTResult<FST> {
2748        match FST::_from_kfst_bytes(kfst_bytes) {
2749            Ok(x) => Ok(x),
2750            Err(x) => value_error(x),
2751        }
2752    }
2753
    /// Construct an FST instance from KFST binary representation that is resident in memory.
    /// The KFST binary representation is a mildly compressed way to represent a transducer.
    /// `debug` is passed along to [FST::debug].
    ///
    /// Returns an error if `kfst_bytes` cannot be decoded as a KFST transducer.
    #[cfg(not(feature = "python"))]
    pub fn from_kfst_bytes(kfst_bytes: &[u8], debug: bool) -> KFSTResult<FST> {
        FST::__from_kfst_bytes(kfst_bytes, debug)
    }
2761
2762    fn _split_to_symbols(&self, text: &str, allow_unknown: bool) -> Option<Vec<Symbol>> {
2763        let mut result = vec![];
2764        let max_byte_len = self
2765            .symbols
2766            .iter()
2767            .find(|x| matches!(x, Symbol::String(_) | Symbol::Special(_) | Symbol::Flag(_)))
2768            .map(|x| x.with_symbol(|s| s.len()))
2769            .unwrap_or(0);
2770        let mut slice = text;
2771        while !slice.is_empty() {
2772            let mut found = false;
2773            for length in (0..std::cmp::min(max_byte_len, slice.len()) + 1).rev() {
2774                if !slice.is_char_boundary(length) {
2775                    continue;
2776                }
2777                let key = &slice[..length];
2778                let pp = self.symbols.partition_point(|x| {
2779                    x.with_symbol(|y| (key.chars().count(), y) < (y.chars().count(), key))
2780                });
2781                if let Some(sym) = self.symbols[pp..]
2782                    .iter()
2783                    .find(|x| matches!(x, Symbol::String(_) | Symbol::Special(_) | Symbol::Flag(_)))
2784                {
2785                    if sym.with_symbol(|s| s == key) {
2786                        result.push(sym.clone());
2787                        slice = &slice[length..];
2788                        found = true;
2789                        break;
2790                    }
2791                }
2792            }
2793            if (!found) && allow_unknown {
2794                let char = slice.chars().next().unwrap();
2795                slice = &slice[char.len_utf8()..];
2796                result.push(Symbol::String(StringSymbol::new(char.to_string(), false)));
2797                found = true;
2798            }
2799            if !found {
2800                return None;
2801            }
2802        }
2803        Some(result)
2804    }
2805
    /// Tokenize a text into Symbol instances matching this transducer's alphabet ([FST::symbols]).
    /// At every position the longest matching symbol is preferred.
    /// The argument `allow_unknown` matters only if the text can not be cleanly tokenized:
    /// * If it is set to `true`, every character that cannot be matched is emitted as its own [Symbol::String] (see also [Symbol::is_unknown]).
    /// * If it is set to `false`, a value of [None] is returned.
    #[cfg(not(feature = "python"))]
    pub fn split_to_symbols(&self, text: &str, allow_unknown: bool) -> Option<Vec<Symbol>> {
        self._split_to_symbols(text, allow_unknown)
    }
2814
2815    fn __run_fst<T: UsizeOrUnit + Clone>(
2816        &self,
2817        input_symbols: Vec<Symbol>,
2818        state: InternalFSTState<T>,
2819        post_input_advance: bool,
2820        input_symbol_index: usize,
2821        keep_non_final: bool,
2822    ) -> Vec<(bool, bool, InternalFSTState<T>)> {
2823        let mut result = vec![];
2824        self._run_fst(
2825            input_symbols.as_slice(),
2826            &state,
2827            post_input_advance,
2828            &mut result,
2829            input_symbol_index,
2830            keep_non_final,
2831        );
2832        result
2833    }
2834
2835    #[cfg(not(feature = "python"))]
2836    /// Apply this FST to a sequence of symbols `input_symbols` starting from the state `state`.
2837    /// The members of the elements of the returned tuple are:
2838    ///   * finality of the state
2839    ///   * the value of `post_input_advance`
2840    ///   * the state proper from which an output symbol sequence can be deduced.
2841    ///
2842    /// Unless you use special token types or need to do complex token manipulation, you should probably be using [FST::lookup].
2843    pub fn run_fst<T: Clone + UsizeOrUnit>(
2844        &self,
2845        input_symbols: Vec<Symbol>,
2846        state: InternalFSTState<T>,
2847        post_input_advance: bool,
2848        input_symbol_index: Option<usize>,
2849        keep_non_final: bool,
2850    ) -> Vec<(bool, bool, InternalFSTState<T>)> {
2851        self.__run_fst(
2852            input_symbols,
2853            state,
2854            post_input_advance,
2855            input_symbol_index.unwrap_or(0),
2856            keep_non_final,
2857        )
2858    }
2859
2860    fn _lookup<T: UsizeOrUnit + Clone + Default + Debug>(
2861        &self,
2862        input: &str,
2863        state: InternalFSTState<T>,
2864        allow_unknown: bool,
2865    ) -> KFSTResult<Vec<(String, f64)>> {
2866        let input_symbols = self.split_to_symbols(input, allow_unknown);
2867        match input_symbols {
2868            None => tokenization_exception(format!("Input cannot be split into symbols: {input}")),
2869            Some(input_symbols) => {
2870                let mut dedup: IndexSet<String> = IndexSet::new();
2871                let mut result: Vec<(String, f64)> = vec![];
2872                let mut finished_paths: Vec<(bool, bool, InternalFSTState<()>)> = self.__run_fst(
2873                    input_symbols.clone(),
2874                    state.convert_indices(),
2875                    false,
2876                    0,
2877                    false,
2878                );
2879                finished_paths
2880                    .sort_by(|a, b| a.2.path_weight.partial_cmp(&b.2.path_weight).unwrap());
2881                for finished in finished_paths {
2882                    let output_string: String = finished
2883                        .2
2884                        .symbol_mappings
2885                        .to_vec()
2886                        .iter()
2887                        .map(|x| x.0.get_symbol())
2888                        .collect::<Vec<String>>()
2889                        .join("");
2890                    if dedup.contains(&output_string) {
2891                        continue;
2892                    }
2893                    dedup.insert(output_string.clone());
2894                    result.push((output_string, finished.2.path_weight));
2895                }
2896                Ok(result)
2897            }
2898        }
2899    }
2900
2901    fn _lookup_aligned<T: UsizeOrUnit + Clone + Default + Debug>(
2902        &self,
2903        input: &str,
2904        state: InternalFSTState<T>,
2905        allow_unknown: bool,
2906    ) -> KFSTResult<Vec<(Vec<(usize, Symbol)>, f64)>> {
2907        let input_symbols = self.split_to_symbols(input, allow_unknown);
2908        match input_symbols {
2909            None => tokenization_exception(format!("Input cannot be split into symbols: {input}")),
2910            Some(input_symbols) => {
2911                let mut dedup: IndexSet<Vec<(usize, Symbol)>> = IndexSet::new();
2912                let mut result: Vec<(Vec<(usize, Symbol)>, f64)> = vec![];
2913                let mut finished_paths: Vec<_> = self
2914                    .__run_fst(
2915                        input_symbols.clone(),
2916                        state.convert_indices(),
2917                        false,
2918                        0,
2919                        false,
2920                    )
2921                    .into_iter()
2922                    .filter(|(finished, _, _)| *finished)
2923                    .collect();
2924                finished_paths
2925                    .sort_by(|a, b| a.2.path_weight.partial_cmp(&b.2.path_weight).unwrap());
2926                for finished in finished_paths {
2927                    let output_vec: Vec<(usize, Symbol)> = finished
2928                        .2
2929                        .symbol_mappings
2930                        .to_vec()
2931                        .into_iter()
2932                        .map(|(a, b)| (b, a))
2933                        .collect();
2934                    if dedup.contains(&output_vec) {
2935                        continue;
2936                    }
2937                    dedup.insert(output_vec.clone());
2938                    result.push((output_vec, finished.2.path_weight));
2939                }
2940                Ok(result)
2941            }
2942        }
2943    }
2944
    #[cfg(not(feature = "python"))]
    /// Tokenize and transduce `input`, starting from the given `state` (note that [FSTState] implements [Default]) and either allowing or disallowing unknown tokens.
    /// (See [FST::split_to_symbols] for tokenization of unknown tokens.)
    ///
    /// If tokenization succeeds, returns a [Vec] of pairs of transduced strings and their weights.
    /// The pairs are sorted by ascending weight and every output string appears only once (with the first weight in that order).
    /// If tokenization fails, returns a [KFSTResult::Err] variant.
    ///
    /// If you need to know what symbol was transduced to what, look at [FST::lookup_aligned]. If you need more control over tokenization (or if your symbols just can not be parsed from a string representation), [FST::run_fst] might be what you are looking for.
    pub fn lookup<T: UsizeOrUnit + Clone + Default + Debug>(
        &self,
        input: &str,
        state: InternalFSTState<T>,
        allow_unknown: bool,
    ) -> KFSTResult<Vec<(String, f64)>> {
        self._lookup(input, state, allow_unknown)
    }
2961
    #[cfg(not(feature = "python"))]
    /// Tokenize and transduce `input`, starting from the given `state` (note that [FSTState] implements [Default]) and either allowing or disallowing unknown tokens.
    /// (See [FST::split_to_symbols] for tokenization of unknown tokens.)
    ///
    /// If tokenization succeeds, returns a [Vec] of pairs. On the left is a [Vec] of pairs of indices into the input symbol list and matching output symbols. On the right is the weight of this path.
    /// The pairs are sorted by ascending weight and identical alignments are reported only once.
    /// If tokenization fails, returns a [KFSTResult::Err] variant.
    ///
    /// If you just want strings in and strings out, look at [FST::lookup]. If you need more control over tokenization (or if your symbols just can not be parsed from a string representation), [FST::run_fst] might be what you are looking for.
    ///
    /// This method swallows output epsilons. Concretely, if an input character is transduced to an epsilon, that input character is seen as being part of the transduction of whatever the next non-epsilon output character is. (See example for more details)
    /// ```rust
    /// use kfst_rs::{FST, FSTState, Symbol, StringSymbol, SpecialSymbol};
    ///
    /// // We load the pykko parser to parse an actual finnish word
    ///
    /// let fst = FST::from_kfst_file("../pypykko/pypykko/fi-parser.kfst".to_string(), false).unwrap();
    ///
    /// // We parse "isonvarpaan" which is the genitive form of "isovarvas".
    /// // It is the compound of "iso" and "varvas" and notably it inflects in both components:
    /// // (iso -> ison and varvas -> varpaan)
    /// // We wish to recover the information regarding what ranges in the original word the compound components match.
    /// // Thus we need lookup_aligned.
    ///
    /// assert_eq!(
    ///   fst.lookup_aligned("isonvarpaan", FSTState::<usize>::default(), false).unwrap()[0].0, // Discard secondary interpretations and weight
    ///   vec![
    ///        // The first item of the tuple is the index in the source string;
    ///        // If it doesn't increment for a row, the output symbol came from an epsilon transition
    ///        (0, Symbol::String(StringSymbol::new("Lexicon".to_string(), false))),
    ///        (0, Symbol::String(StringSymbol::new("\t".to_string(), false))),
    ///
    ///        // Here we have a run of incrementing indices: i:i, s:s, o:o and n:@_EPSILON_SYMBOL_@. The last gets swallowed.
    ///
    ///        (0, Symbol::String(StringSymbol::new("i".to_string(), false))),
    ///        (1, Symbol::String(StringSymbol::new("s".to_string(), false))),
    ///        (2, Symbol::String(StringSymbol::new("o".to_string(), false))),
    ///
    ///        // Here we have a row that isn't incremented after, ie. the separating pipe comes from
    ///        // @_EPSILON_SYMBOL_@:|
    ///
    ///        (4, Symbol::String(StringSymbol::new("|".to_string(), false))),
    ///
    ///        // Here on we increment v:v, a:a, r:r, p:v, a:a
    ///
    ///        (4, Symbol::String(StringSymbol::new("v".to_string(), false))),
    ///        (5, Symbol::String(StringSymbol::new("a".to_string(), false))),
    ///        (6, Symbol::String(StringSymbol::new("r".to_string(), false))),
    ///        (7, Symbol::String(StringSymbol::new("v".to_string(), false))),
    ///        (8, Symbol::String(StringSymbol::new("a".to_string(), false))),
    ///
    ///        // These two are somewhat surprising: @_EPSILON_SYMBOL_@:s and  a:@_EPSILON_SYMBOL_@.
    ///        // Notably there is consonant gradation going on (varva -> varpa)
    ///        // As the output epsilon gets swallowed, this gets interpreted as a:s.
    ///
    ///        (9, Symbol::String(StringSymbol::new("s".to_string(), false))),
    ///
    ///        // We are out of the stem and in the genitive ending (-n)
    ///        // The final n is consumed by the +gen token
    ///
    ///        (10, Symbol::String(StringSymbol::new("\tnoun\t".to_string(), false))),
    ///        (10, Symbol::String(StringSymbol::new("\t".to_string(), false))),
    ///        (10, Symbol::String(StringSymbol::new("\t".to_string(), false))),
    ///        (10, Symbol::String(StringSymbol::new("+sg".to_string(), false))),
    ///        (10, Symbol::String(StringSymbol::new("+gen".to_string(), false)))]
    /// );
    /// ```
    pub fn lookup_aligned<T: UsizeOrUnit + Clone + Default + Debug>(
        &self,
        input: &str,
        state: InternalFSTState<T>,
        allow_unknown: bool,
    ) -> KFSTResult<Vec<(Vec<(usize, Symbol)>, f64)>> {
        self._lookup_aligned(input, state, allow_unknown)
    }
3036}
3037
3038fn _update_flags(symbol: &Symbol, flags: &FlagMap) -> Option<FlagMap> {
3039    if let Symbol::Flag(flag_diacritic_symbol) = symbol {
3040        match flag_diacritic_symbol.flag_type {
3041            FlagDiacriticType::U => {
3042                let value = flag_diacritic_symbol.value;
3043
3044                // Is the current state somehow in conflict?
3045                // It can be, if we are negatively set to what we try to unify to or we are positively set to sth else
3046
3047                if let Some((currently_set, current_value)) = flags.get(flag_diacritic_symbol.key) {
3048                    if (currently_set && current_value != value)
3049                        || (!currently_set && current_value == value)
3050                    {
3051                        return None;
3052                    }
3053                }
3054
3055                // Otherwise, update flag set
3056
3057                Some(flags.insert(flag_diacritic_symbol.key, (true, value)))
3058            }
3059            FlagDiacriticType::R => {
3060                // Param count matters
3061
3062                match flag_diacritic_symbol.value {
3063                    u32::MAX => {
3064                        if flags.get(flag_diacritic_symbol.key).is_some() {
3065                            Some(flags.clone())
3066                        } else {
3067                            None
3068                        }
3069                    }
3070                    value => {
3071                        if flags
3072                            .get(flag_diacritic_symbol.key)
3073                            .map(|stored| _test_flag(&stored, value))
3074                            .unwrap_or(false)
3075                        {
3076                            Some(flags.clone())
3077                        } else {
3078                            None
3079                        }
3080                    }
3081                }
3082            }
3083            FlagDiacriticType::D => {
3084                match (
3085                    flag_diacritic_symbol.value,
3086                    flags.get(flag_diacritic_symbol.key),
3087                ) {
3088                    (u32::MAX, None) => Some(flags.clone()),
3089                    (u32::MAX, _) => None,
3090                    (_, None) => Some(flags.clone()),
3091                    (query, Some(stored)) => {
3092                        if _test_flag(&stored, query) {
3093                            None
3094                        } else {
3095                            Some(flags.clone())
3096                        }
3097                    }
3098                }
3099            }
3100            FlagDiacriticType::C => Some(flags.remove(flag_diacritic_symbol.key)),
3101            FlagDiacriticType::P => {
3102                let value = flag_diacritic_symbol.value;
3103                Some(flags.insert(flag_diacritic_symbol.key, (true, value)))
3104            }
3105            FlagDiacriticType::N => {
3106                let value = flag_diacritic_symbol.value;
3107                Some(flags.insert(flag_diacritic_symbol.key, (false, value)))
3108            }
3109        }
3110    } else {
3111        Some(flags.clone())
3112    }
3113}
3114
3115#[cfg_attr(feature = "python", pymethods)]
3116impl FST {
3117    #[cfg(feature = "python")]
3118    #[staticmethod]
3119    #[pyo3(signature = (final_states, rules, symbols, debug = false))]
3120    fn from_rules(
3121        final_states: IndexMap<u64, f64>,
3122        rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>>,
3123        symbols: HashSet<Symbol>,
3124        debug: Option<bool>,
3125    ) -> FST {
3126        FST::_from_rules(final_states, rules, symbols, debug)
3127    }
3128
3129    #[cfg(feature = "python")]
3130    #[staticmethod]
3131    #[pyo3(signature = (att_file, debug = false))]
3132    fn from_att_file(py: Python<'_>, att_file: PyObject, debug: bool) -> KFSTResult<FST> {
3133        FST::_from_att_file(att_file.call_method0(py, "__str__")?.extract(py)?, debug)
3134    }
3135
3136    #[cfg(feature = "python")]
3137    #[staticmethod]
3138    #[pyo3(signature = (att_code, debug = false))]
3139    fn from_att_code(att_code: String, debug: bool) -> KFSTResult<FST> {
3140        FST::_from_att_code(att_code, debug)
3141    }
3142
3143    #[cfg(feature = "python")]
3144    pub fn to_att_file(&self, py: Python<'_>, att_file: PyObject) -> KFSTResult<()> {
3145        let path: String = att_file.call_method0(py, "__str__")?.extract(py)?;
3146        fs::write(Path::new(&path), self.to_att_code()).map_err(|err| {
3147            io_error::<()>(format!("Failed to write to file {path}:\n{err}")).unwrap_err()
3148        })
3149    }
3150
3151    /// Save the current transducer to a file in the ATT format. See [FST::from_att_code] for more details on the ATT format.
3152    #[cfg(not(feature = "python"))]
3153    pub fn to_att_file(&self, att_file: String) -> KFSTResult<()> {
3154        fs::write(Path::new(&att_file), self.to_att_code()).map_err(|err| {
3155            io_error::<()>(format!("Failed to write to file {att_file}:\n{err}")).unwrap_err()
3156        })
3157    }
3158
3159    /// Serialize the current transducer to a [String] in the AT&T format. See [FST::from_att_code] for more details on the AT&T format.
3160    /// kfst tries to adhere to hfst's behaviour as closely as possible. This introduces some corner cases and odd behaviours.
3161    /// AT&T format has significant line feeds and tabs. Tabs in the transducer are represented as `@_TAB_@` in AT&T format.
3162    /// Spaces are also escaped as `@_SPACE_@`. However, there is no escape for the @-character itself (so the litteral string `"@_TAB_@"` cannot appear in a transition.)
3163    /// Line feeds are simply not escaped at all.
3164    ///  
3165    /// ```rust
3166    /// use kfst_rs::FST;
3167    ///
3168    /// // @_TAB_@ and @_SPACE_@ escapes can appear both as top and bottom symbols
3169    ///
3170    /// let code = r#"4
3171    /// 0	1	@_TAB_@	a
3172    /// 1	2	b	@_TAB_@x
3173    /// 2	3	@_SPACE_@	c
3174    /// 3	4	d	@_SPACE_@"#;
3175    ///
3176    /// let f = FST::from_att_code(code.to_string(), false).unwrap();
3177    /// assert_eq!(f.to_att_code(), code.to_string());
3178    /// ```
3179    ///
3180    ///
3181    pub fn to_att_code(&self) -> String {
3182        let mut rows: Vec<String> = vec![];
3183        for (state, weight) in self.final_states.iter() {
3184            match weight {
3185                0.0 => {
3186                    rows.push(format!("{state}"));
3187                }
3188                _ => {
3189                    rows.push(format!("{state}\t{weight}"));
3190                }
3191            }
3192        }
3193        for (from_state, rules) in self.rules.iter() {
3194            for (top_symbol, transitions) in rules.iter() {
3195                for (to_state, bottom_symbol, weight) in transitions.iter() {
3196                    match weight {
3197                        0.0 => {
3198                            rows.push(format!(
3199                                "{}\t{}\t{}\t{}",
3200                                from_state,
3201                                to_state,
3202                                escape_att_symbol(&top_symbol.get_symbol()),
3203                                escape_att_symbol(&bottom_symbol.get_symbol())
3204                            ));
3205                        }
3206                        _ => {
3207                            rows.push(format!(
3208                                "{}\t{}\t{}\t{}\t{}",
3209                                from_state,
3210                                to_state,
3211                                escape_att_symbol(&top_symbol.get_symbol()),
3212                                escape_att_symbol(&bottom_symbol.get_symbol()),
3213                                weight
3214                            ));
3215                        }
3216                    }
3217                }
3218            }
3219        }
3220        rows.join("\n")
3221    }
3222
3223    #[cfg(feature = "python")]
3224    #[staticmethod]
3225    #[pyo3(signature = (kfst_file, debug = false))]
3226    fn from_kfst_file(py: Python<'_>, kfst_file: PyObject, debug: bool) -> KFSTResult<FST> {
3227        FST::_from_kfst_file(kfst_file.call_method0(py, "__str__")?.extract(py)?, debug)
3228    }
3229
3230    #[cfg(feature = "python")]
3231    #[staticmethod]
3232    #[pyo3(signature = (kfst_bytes, debug = false))]
3233    fn from_kfst_bytes(kfst_bytes: &[u8], debug: bool) -> KFSTResult<FST> {
3234        FST::__from_kfst_bytes(kfst_bytes, debug)
3235    }
3236
3237    #[cfg(feature = "python")]
3238    pub fn to_kfst_file(&self, py: Python<'_>, kfst_file: PyObject) -> KFSTResult<()> {
3239        let bytes = self.to_kfst_bytes()?;
3240        let path: String = kfst_file.call_method0(py, "__str__")?.extract(py)?;
3241        fs::write(Path::new(&path), bytes).map_err(|err| {
3242            io_error::<()>(format!("Failed to write to file {path}:\n{err}")).unwrap_err()
3243        })
3244    }
3245
3246    #[cfg(not(feature = "python"))]
3247    /// Save the current transducer to a file in the KFST format. See [FST::from_kfst_bytes] for more details on the KFST format.
3248    pub fn to_kfst_file(&self, kfst_file: String) -> KFSTResult<()> {
3249        let bytes = self.to_kfst_bytes()?;
3250        fs::write(Path::new(&kfst_file), bytes).map_err(|err| {
3251            io_error::<()>(format!("Failed to write to file {kfst_file}:\n{err}")).unwrap_err()
3252        })
3253    }
3254
3255    /// Serialize the current transducer to a bytestring in the KFST format. See [FST::from_kfst_bytes] for more details on the KFST format.
3256    pub fn to_kfst_bytes(&self) -> KFSTResult<Vec<u8>> {
3257        match self._to_kfst_bytes() {
3258            Ok(x) => Ok(x),
3259            Err(x) => value_error(x),
3260        }
3261    }
3262
3263    #[deprecated]
3264    /// Convert this FST into a somewhat human readable string representation. Exists for the Python API's sake.
3265    pub fn __repr__(&self) -> String {
3266        format!(
3267            "FST(final_states: {:?}, rules: {:?}, symbols: {:?}, debug: {:?})",
3268            self.final_states, self.rules, self.symbols, self.debug
3269        )
3270    }
3271
3272    #[cfg(feature = "python")]
3273    #[pyo3(signature = (text, allow_unknown = true))]
3274    fn split_to_symbols(&self, text: &str, allow_unknown: bool) -> Option<Vec<Symbol>> {
3275        self._split_to_symbols(text, allow_unknown)
3276    }
3277
3278    #[cfg(feature = "python")]
3279    #[pyo3(signature = (input_symbols, state = FSTState { payload: Ok(InternalFSTState::_new(0)) }, post_input_advance = false, input_symbol_index = None, keep_non_final = true))]
3280    fn run_fst(
3281        &self,
3282        input_symbols: Vec<Symbol>,
3283        state: FSTState,
3284        post_input_advance: bool,
3285        input_symbol_index: Option<usize>,
3286        keep_non_final: bool,
3287    ) -> Vec<(bool, bool, FSTState)> {
3288        match state.payload {
3289            Ok(p) => self
3290                .__run_fst(
3291                    input_symbols,
3292                    p,
3293                    post_input_advance,
3294                    input_symbol_index.unwrap_or(0),
3295                    keep_non_final,
3296                )
3297                .into_iter()
3298                .map(|(a, b, c)| (a, b, FSTState { payload: Ok(c) }))
3299                .collect(),
3300            Err(p) => self
3301                .__run_fst(
3302                    input_symbols,
3303                    p,
3304                    post_input_advance,
3305                    input_symbol_index.unwrap_or(0),
3306                    keep_non_final,
3307                )
3308                .into_iter()
3309                .map(|(a, b, c)| (a, b, FSTState { payload: Err(c) }))
3310                .collect(),
3311        }
3312    }
3313
3314    #[cfg(feature = "python")]
3315    #[pyo3(signature = (input, state=FSTState { payload: Err(InternalFSTState::_new(0)) }, allow_unknown=true))]
3316    fn lookup(
3317        &self,
3318        input: &str,
3319        state: FSTState,
3320        allow_unknown: bool,
3321    ) -> KFSTResult<Vec<(String, f64)>> {
3322        match state.payload {
3323            Ok(p) => self._lookup(input, p, allow_unknown),
3324            Err(p) => self._lookup(input, p, allow_unknown),
3325        }
3326    }
3327
3328    #[cfg(feature = "python")]
3329    #[pyo3(signature = (input, state=FSTState { payload: Ok(InternalFSTState::_new(0)) }, allow_unknown=true))]
3330    pub fn lookup_aligned(
3331        &self,
3332        input: &str,
3333        state: FSTState,
3334        allow_unknown: bool,
3335    ) -> KFSTResult<Vec<(Vec<(usize, Symbol)>, f64)>> {
3336        use pyo3::exceptions::PyValueError;
3337
3338        match state.payload {
3339            Ok(p) => self._lookup_aligned(input, p, allow_unknown),
3340            Err(p) => PyResult::Err(PyErr::new::<PyValueError, _>(
3341                format!("lookup_aligned refuses to work with states with input_indices=None (passed state {}). Manually convert it to an indexed state by calling ensure_indices()", p.__repr__())
3342            )),
3343        }
3344    }
3345
3346    #[cfg(feature = "python")]
3347    pub fn get_input_symbols(&self, state: FSTState) -> HashSet<Symbol> {
3348        match state.payload {
3349            Ok(p) => self._get_input_symbols(p),
3350            Err(p) => self._get_input_symbols(p),
3351        }
3352    }
3353}
3354
3355impl FST {
3356    fn _get_input_symbols<T>(&self, state: InternalFSTState<T>) -> HashSet<Symbol> {
3357        self.rules
3358            .get(&state.state_num)
3359            .map(|x| x.keys().cloned().collect())
3360            .unwrap_or_default()
3361    }
3362
3363    #[cfg(not(feature = "python"))]
3364    /// Equal to:
3365    /// ```no_test
3366    /// self.rules[&state.state_num].keys().cloned().collect()
3367    /// ```
3368    /// Exists as its own function to make getting the input symbols of a state fast when calling from Python.
3369    /// (Otherwise the whole [FST::rules] mapping needs to be converted into Python's representation, which is significantly slower)
3370    ///
3371    /// ```
3372    /// use kfst_rs::{FST, Symbol, FSTState};
3373    /// use std::collections::{HashSet, HashMap};
3374    ///
3375    /// let fst = FST::from_att_code("0\t1\ta\tb\n".to_string(), false).unwrap();
3376    /// let mut expected = HashSet::new();
3377    /// expected.insert(Symbol::parse("a").unwrap().1);
3378    /// assert_eq!(fst.get_input_symbols(FSTState::<()>::new(0, 0.0, HashMap::new(), HashMap::new(), vec![], vec![])), expected);
3379    /// assert_eq!(fst.get_input_symbols(FSTState::<()>::new(1, 0.0, HashMap::new(), HashMap::new(), vec![], vec![])), HashSet::new());
3380    /// ```
3381    pub fn get_input_symbols<T>(&self, state: FSTState<T>) -> HashSet<Symbol> {
3382        self._get_input_symbols(state)
3383    }
3384}
3385
3386#[cfg(not(feature = "python"))]
3387mod tests {
3388    use crate::*;
3389
3390    #[test]
3391    fn test_att_trivial() {
3392        let fst = FST::from_att_code("1\n0\t1\ta\tb".to_string(), false).unwrap();
3393        assert_eq!(
3394            fst.lookup("a", FSTState::<()>::default(), false).unwrap(),
3395            vec![("b".to_string(), 0.0)]
3396        );
3397    }
3398
3399    #[test]
3400    fn test_att_slightly_less_trivial() {
3401        let fst = FST::from_att_code("2\n0\t1\ta\tb\n1\t2\tc\td".to_string(), false).unwrap();
3402        assert_eq!(
3403            fst.lookup("ac", FSTState::<()>::default(), false).unwrap(),
3404            vec![("bd".to_string(), 0.0)]
3405        );
3406    }
3407
3408    #[test]
3409    fn test_kfst_voikko_kissa() {
3410        let fst =
3411            FST::_from_kfst_file("../pyvoikko/pyvoikko/voikko.kfst".to_string(), false).unwrap();
3412        assert_eq!(
3413            fst.lookup("kissa", FSTState::<()>::_new(0), false).unwrap(),
3414            vec![("[Ln][Xp]kissa[X]kiss[Sn][Ny]a".to_string(), 0.0)]
3415        );
3416        assert_eq!(
3417            fst.lookup("kissojemmekaan", FSTState::<()>::_new(0), false)
3418                .unwrap(),
3419            vec![(
3420                "[Ln][Xp]kissa[X]kiss[Sg][Nm]oje[O1m]mme[Fkaan]kaan".to_string(),
3421                0.0
3422            )]
3423        );
3424    }
3425
3426    #[test]
3427    fn test_that_weight_of_end_state_applies_correctly() {
3428        let code = "0\t1\ta\tb\n1\t1.0";
3429        let fst = FST::from_att_code(code.to_string(), false).unwrap();
3430        assert_eq!(
3431            fst.lookup("a", FSTState::<()>::_new(0), false).unwrap(),
3432            vec![("b".to_string(), 1.0)]
3433        );
3434    }
3435
3436    #[test]
3437    fn test_kfst_voikko_correct_final_states() {
3438        let fst: FST =
3439            FST::_from_kfst_file("../pyvoikko/pyvoikko/voikko.kfst".to_string(), false).unwrap();
3440        let map: IndexMap<_, _> = [(19, 0.0)].into_iter().collect();
3441        assert_eq!(fst.final_states, map);
3442    }
3443
3444    #[test]
3445    fn test_kfst_voikko_split() {
3446        let fst: FST =
3447            FST::_from_kfst_file("../pyvoikko/pyvoikko/voikko.kfst".to_string(), false).unwrap();
3448        assert_eq!(
3449            fst.split_to_symbols("lentokone", false).unwrap(),
3450            vec![
3451                Symbol::String(StringSymbol {
3452                    string: intern("l".to_string()),
3453                    unknown: false
3454                }),
3455                Symbol::String(StringSymbol {
3456                    string: intern("e".to_string()),
3457                    unknown: false
3458                }),
3459                Symbol::String(StringSymbol {
3460                    string: intern("n".to_string()),
3461                    unknown: false
3462                }),
3463                Symbol::String(StringSymbol {
3464                    string: intern("t".to_string()),
3465                    unknown: false
3466                }),
3467                Symbol::String(StringSymbol {
3468                    string: intern("o".to_string()),
3469                    unknown: false
3470                }),
3471                Symbol::String(StringSymbol {
3472                    string: intern("k".to_string()),
3473                    unknown: false
3474                }),
3475                Symbol::String(StringSymbol {
3476                    string: intern("o".to_string()),
3477                    unknown: false
3478                }),
3479                Symbol::String(StringSymbol {
3480                    string: intern("n".to_string()),
3481                    unknown: false
3482                }),
3483                Symbol::String(StringSymbol {
3484                    string: intern("e".to_string()),
3485                    unknown: false
3486                }),
3487            ]
3488        );
3489
3490        assert_eq!(
3491            fst.split_to_symbols("lentää", false).unwrap(),
3492            vec![
3493                Symbol::String(StringSymbol {
3494                    string: intern("l".to_string()),
3495                    unknown: false
3496                }),
3497                Symbol::String(StringSymbol {
3498                    string: intern("e".to_string()),
3499                    unknown: false
3500                }),
3501                Symbol::String(StringSymbol {
3502                    string: intern("n".to_string()),
3503                    unknown: false
3504                }),
3505                Symbol::String(StringSymbol {
3506                    string: intern("t".to_string()),
3507                    unknown: false
3508                }),
3509                Symbol::String(StringSymbol {
3510                    string: intern("ä".to_string()),
3511                    unknown: false
3512                }),
3513                Symbol::String(StringSymbol {
3514                    string: intern("ä".to_string()),
3515                    unknown: false
3516                }),
3517            ]
3518        );
3519    }
3520
3521    #[test]
3522    fn test_kfst_voikko() {
3523        let fst =
3524            FST::_from_kfst_file("../pyvoikko/pyvoikko/voikko.kfst".to_string(), false).unwrap();
3525        assert_eq!(
3526            fst.lookup("lentokone", FSTState::<()>::_new(0), false)
3527                .unwrap(),
3528            vec![(
3529                "[Lt][Xp]lentää[X]len[Ln][Xj]to[X]to[Sn][Ny][Bh][Bc][Ln][Xp]kone[X]kone[Sn][Ny]"
3530                    .to_string(),
3531                0.0
3532            )]
3533        );
3534    }
3535
3536    #[test]
3537    fn test_kfst_voikko_lentää() {
3538        let fst =
3539            FST::_from_kfst_file("../pyvoikko/pyvoikko/voikko.kfst".to_string(), false).unwrap();
3540        let mut sys = fst
3541            .lookup("lentää", FSTState::<()>::_new(0), false)
3542            .unwrap();
3543        sys.sort_by(|a, b| a.partial_cmp(b).unwrap());
3544        assert_eq!(
3545            sys,
3546            vec![
3547                ("[Lt][Xp]lentää[X]len[Tn1][Eb]tää".to_string(), 0.0),
3548                (
3549                    "[Lt][Xp]lentää[X]len[Tt][Ap][P3][Ny][Ef]tää".to_string(),
3550                    0.0
3551                )
3552            ]
3553        );
3554    }
3555
3556    #[test]
3557    fn test_kfst_voikko_lentää_correct_states() {
3558        let fst =
3559            FST::_from_kfst_file("../pyvoikko/pyvoikko/voikko.kfst".to_string(), false).unwrap();
3560        let input_symbols = fst.split_to_symbols("lentää", false).unwrap();
3561
3562        // Correct number of states for different subsequence lengths per KFST
3563
3564        let results = [
3565            vec![
3566                0, 1, 1810, 1946, 1961, 1962, 1963, 1964, 1965, 1966, 2665, 2969, 2970, 3104, 3295,
3567                3484, 3678, 3870, 4064, 4260, 4454, 4648, 4842, 5036, 5230, 5454, 5645, 5839, 6031,
3568                6225, 6419, 6613, 6807, 7001, 7195, 7389, 7579, 12479, 13348, 13444, 13541, 13636,
3569                13733, 13830, 13925, 14028, 14131, 14234, 14331, 14426, 14525, 14622, 14723, 14826,
3570                14929, 15024, 15127, 15230, 15333, 15433, 15526,
3571            ],
3572            vec![
3573                10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1878, 2840, 17295, 25716, 31090, 40909,
3574                85222, 204950, 216255, 217894, 254890, 256725, 256726, 256727, 256728, 256729,
3575                256730, 256731, 256732, 256733, 256734, 256735, 256736, 280866, 281235, 281479,
3576                281836, 281876, 281877, 288536, 355529, 378467,
3577            ],
3578            vec![
3579                17459, 17898, 17899, 26065, 26066, 26067, 26068, 26069, 31245, 42140, 87151,
3580                134039, 134040, 205452, 219693, 219694, 259005, 259666, 259667, 259668, 259669,
3581                259670, 259671, 259672, 280894, 281857, 289402, 356836, 378621, 378750, 378773,
3582                386786, 388199, 388200, 388201, 388202, 388203,
3583            ],
3584            vec![
3585                17458, 17459, 17899, 19455, 26192, 26214, 26215, 26216, 26217, 42361, 87536,
3586                118151, 205474, 216303, 220614, 220615, 220616, 220617, 220618, 220619, 220620,
3587                220621, 220629, 228443, 228444, 228445, 259219, 259220, 259221, 259222, 259223,
3588                259224, 259225, 356941, 387264,
3589            ],
3590            vec![
3591                42362, 102258, 216304, 216309, 216312, 216317, 217230, 356942, 387265,
3592            ],
3593            vec![
3594                211149, 212998, 212999, 213000, 213001, 213002, 216305, 216310, 216313, 216318,
3595            ],
3596            vec![
3597                12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 210815, 210816,
3598                211139, 211140, 214985, 216311, 216314, 216315, 216316,
3599            ],
3600        ];
3601
3602        for i in 0..=input_symbols.len() {
3603            let subsequence = &input_symbols[..i];
3604            let mut states: Vec<_> = fst
3605                .run_fst(
3606                    subsequence.to_vec(),
3607                    FSTState::<()>::_new(0),
3608                    false,
3609                    None,
3610                    true,
3611                )
3612                .into_iter()
3613                .map(|(_, _, x)| x.state_num)
3614                .collect();
3615            states.sort();
3616            assert_eq!(states, results[i]);
3617        }
3618    }
3619
3620    #[test]
3621    fn test_minimal_r_diacritic() {
3622        let code = "0\t1\t@P.V_SALLITTU.T@\tasetus\n1\t2\t@R.V_SALLITTU.T@\ttarkistus\n2";
3623        let fst = FST::from_att_code(code.to_string(), false).unwrap();
3624        let mut result = vec![];
3625        fst._run_fst(&[], &FSTState::<()>::_new(0), false, &mut result, 0, true);
3626        for x in result {
3627            println!("{x:?}");
3628        }
3629        assert_eq!(
3630            fst.lookup("", FSTState::<()>::_new(0), false).unwrap(),
3631            vec![("asetustarkistus".to_string(), 0.0)]
3632        );
3633    }
3634
3635    #[test]
3636    fn test_kfst_voikko_lentää_result_count() {
3637        let fst =
3638            FST::_from_kfst_file("../pyvoikko/pyvoikko/voikko.kfst".to_string(), false).unwrap();
3639        let input_symbols = fst.split_to_symbols("lentää", false).unwrap();
3640
3641        // Correct number of states for different subsequence lengths per KFST
3642
3643        let results = [61, 42, 37, 35, 9, 10, 25];
3644
3645        for i in 0..=input_symbols.len() {
3646            let subsequence = &input_symbols[..i];
3647            assert_eq!(
3648                fst.run_fst(
3649                    subsequence.to_vec(),
3650                    FSTState::<()>::_new(0),
3651                    false,
3652                    None,
3653                    true
3654                )
3655                .len(),
3656                results[i]
3657            );
3658        }
3659    }
3660
3661    #[test]
3662    fn does_not_crash_on_unknown() {
3663        let fst = FST::from_att_code("0\t1\ta\tb\n1".to_string(), false).unwrap();
3664        assert_eq!(
3665            fst.lookup("c", FSTState::<()>::_new(0), true).unwrap(),
3666            vec![]
3667        );
3668        assert!(fst.lookup("c", FSTState::<()>::_new(0), false).is_err());
3669    }
3670
3671    #[test]
3672    fn test_kfst_voikko_paragraph() {
3673        let words = [
3674            "on",
3675            "maanantaiaamu",
3676            "heinäkuussa",
3677            "aurinko",
3678            "paiskaa",
3679            "niin",
3680            "lämpöisesti",
3681            "heikon",
3682            "tuulen",
3683            "avulla",
3684            "ja",
3685            "peipposet",
3686            "kajahuttelevat",
3687            "ensimmäisiä",
3688            "kovia",
3689            "säveleitään",
3690            "tuoksuavissa",
3691            "koivuissa",
3692            "kirkon",
3693            "itäisellä",
3694            "seinuksella",
3695            "on",
3696            "kivipenkki",
3697            "juuri",
3698            "nyt",
3699            "saapuu",
3700            "keski-ikäinen",
3701            "työmies",
3702            "ja",
3703            "istuutuu",
3704            "penkille",
3705            "hän",
3706            "näyttää",
3707            "väsyneeltä",
3708            "alakuloiselta",
3709            "haluttomalla",
3710            "aivan",
3711            "kuin",
3712            "olisi",
3713            "vastikään",
3714            "tullut",
3715            "perheellisestä",
3716            "riidasta",
3717            "tahi",
3718            "jättänyt",
3719            "eilisen",
3720            "sapatinpäivän",
3721            "pyhittämättä",
3722        ];
3723        let gold: [Vec<(&str, i32)>; 48] = [
3724        vec![("[Lt][Xp]olla[X]o[Tt][Ap][P3][Ny][Ef]n", 0)],
3725        vec![("[Ln][Xp]maanantai[X]maanantai[Sn][Ny][Bh][Bc][Ln][Xp]aamu[X]aamu[Sn][Ny]", 0)],
3726        vec![("[Ln][Xp]heinä[X]hein[Sn][Ny]ä[Bh][Bc][Ln][Xp]kuu[X]kuu[Sine][Ny]ssa", 0)],
3727        vec![("[Ln][Xp]aurinko[X]aurinko[Sn][Ny]", 0), ("[Lem][Xp]Aurinko[X]aurinko[Sn][Ny]", 0), ("[Lee][Xp]Auri[X]aur[Sg][Ny]in[Fko][Ef]ko", 0)],
3728        vec![("[Lt][Xp]paiskata[X]paiska[Tt][Ap][P3][Ny][Eb]a", 0)],
3729        vec![("[Ls][Xp]niin[X]niin", 0)],
3730        vec![("[Ln][Xp]lämpö[X]lämpö[Ll][Xj]inen[X]ise[Ssti]sti", 0)],
3731        vec![("[Ll][Xp]heikko[X]heiko[Sg][Ny]n", 0)],
3732        vec![("[Ln][Xp]tuuli[X]tuul[Sg][Ny]en", 0)],
3733        vec![("[Ln][Xp]avu[X]avu[Sade][Ny]lla", 0), ("[Ln][Xp]apu[X]avu[Sade][Ny]lla", 0)],
3734        vec![("[Lc][Xp]ja[X]ja", 0)],
3735        vec![("[Ln][Xp]peipponen[X]peippo[Sn][Nm]set", 0)],
3736        vec![],
3737        vec![("[Lu][Xp]ensimmäinen[X]ensimmäi[Sp][Nm]siä", 0)],
3738        vec![("[Lnl][Xp]kova[X]kov[Sp][Nm]ia", 0)],
3739        vec![],
3740        vec![],
3741        vec![("[Ln][Xp]koivu[X]koivu[Sine][Nm]issa", 0), ("[Les][Xp]Koivu[X]koivu[Sine][Nm]issa", 0)],
3742        vec![("[Ln][Ica][Xp]kirkko[X]kirko[Sg][Ny]n", 0)],
3743        vec![("[Ln][De][Xp]itä[X]itä[Ll][Xj]inen[X]ise[Sade][Ny]llä", 0)],
3744        vec![("[Ln][Xp]seinus[X]seinukse[Sade][Ny]lla", 0)],
3745        vec![("[Lt][Xp]olla[X]o[Tt][Ap][P3][Ny][Ef]n", 0)],
3746        vec![("[Ln][Ica][Xp]kivi[X]kiv[Sn][Ny]i[Bh][Bc][Ln][Xp]penkki[X]penkk[Sn][Ny]i", 0)],
3747        vec![("[Ln][Xp]juuri[X]juur[Sn][Ny]i", 0), ("[Ls][Xp]juuri[X]juuri", 0), ("[Lt][Xp]juuria[X]juuri[Tk][Ap][P2][Ny][Eb]", 0), ("[Lt][Xp]juuria[X]juur[Tt][Ai][P3][Ny][Ef]i", 0)],
3748        vec![("[Ls][Xp]nyt[X]nyt", 0)],
3749        vec![("[Lt][Xp]saapua[X]saapuu[Tt][Ap][P3][Ny][Ef]", 0)],
3750        vec![("[Lp]keski[De]-[Bh][Bc][Ln][Xp]ikä[X]ikä[Ll][Xj]inen[X]i[Sn][Ny]nen", 0)],
3751        vec![("[Ln][Xp]työ[X]työ[Sn][Ny][Bh][Bc][Ln][Xp]mies[X]mies[Sn][Ny]", 0)],
3752        vec![("[Lc][Xp]ja[X]ja", 0)],
3753        vec![("[Lt][Xp]istuutua[X]istuutuu[Tt][Ap][P3][Ny][Ef]", 0)],
3754        vec![("[Ln][Xp]penkki[X]penki[Sall][Ny]lle", 0)],
3755        vec![("[Lr][Xp]hän[X]hä[Sn][Ny]n", 0)],
3756        vec![("[Lt][Xp]näyttää[X]näyttä[Tn1][Eb]ä", 0), ("[Lt][Xp]näyttää[X]näytt[Tt][Ap][P3][Ny][Ef]ää", 0)],
3757        vec![("[Lt][Irm][Xp]väsyä[X]väsy[Ll][Ru]n[Xj]yt[X]ee[Sabl][Ny]ltä", 0)],
3758        vec![("[Ln][De][Xp]ala[X]al[Sn][Ny]a[Bh][Bc][Lnl][Xp]kulo[X]kulo[Ll][Xj]inen[X]ise[Sabl][Ny]lta", 0)],
3759        vec![("[Ln][Xp]halu[X]halu[Ll][Xj]ton[X]ttoma[Sade][Ny]lla", 0)],
3760        vec![("[Ls][Xp]aivan[X]aivan", 0)],
3761        vec![("[Lc][Xp]kuin[X]kuin", 0), ("[Ln][Xp]kuu[X]ku[Sin][Nm]in", 0)],
3762        vec![("[Lt][Xp]olla[X]ol[Te][Ap][P3][Ny][Eb]isi", 0)],
3763        vec![("[Ls][Xp]vast=ikään[X]vast[Bm]ikään", 0)],
3764        vec![("[Lt][Xp]tulla[X]tul[Ll][Ru]l[Xj]ut[X][Sn][Ny]ut", 0), ("[Lt][Xp]tulla[X]tul[Ll][Rt][Xj]tu[X]lu[Sn][Nm]t", 0)],
3765        vec![("[Ln][Xp]perhe[X]perhee[Ll]lli[Xj]nen[X]se[Sela][Ny]stä", 0)],
3766        vec![("[Ln][Xp]riita[X]riida[Sela][Ny]sta", 0)],
3767        vec![("[Lc][Xp]tahi[X]tahi", 0)],
3768        vec![("[Lt][Xp]jättää[X]jättä[Ll][Ru]n[Xj]yt[X][Sn][Ny]yt", 0)],
3769        vec![("[Lnl][Xp]eilinen[X]eili[Sg][Ny]sen", 0)],
3770        vec![("[Ln][Xp]sapatti[X]sapat[Sg][Ny]in[Bh][Bc][Ln][Xp]päivä[X]päiv[Sg][Ny]än", 0)],
3771        vec![("[Lt][Xp]pyhittää[X]pyhittä[Ln]m[Xj]ä[X][Rm]ä[Sab][Ny]ttä", 0), ("[Lt][Xp]pyhittää[X]pyhittä[Tn3][Ny][Sab]mättä", 0)],
3772    ];
3773        let fst =
3774            FST::_from_kfst_file("../pyvoikko/pyvoikko/voikko.kfst".to_string(), false).unwrap();
3775        for (idx, (word, gold)) in words.into_iter().zip(gold.into_iter()).enumerate() {
3776            let mut sys = fst.lookup(word, FSTState::<()>::_new(0), false).unwrap();
3777            sys.sort_by(|a, b| a.partial_cmp(b).unwrap());
3778            let mut gold_sorted = gold;
3779            gold_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
3780            println!("Word at: {idx}");
3781            assert_eq!(
3782                sys,
3783                gold_sorted
3784                    .iter()
3785                    .map(|(s, w)| (s.to_string(), (*w).into()))
3786                    .collect::<Vec<_>>()
3787            );
3788        }
3789    }
3790
3791    #[test]
3792    fn test_simple_unknown() {
3793        let code = "0\t1\t@_UNKNOWN_SYMBOL_@\ty\n1";
3794        let fst = FST::from_att_code(code.to_string(), false).unwrap();
3795
3796        assert_eq!(
3797            fst.run_fst(
3798                vec![Symbol::String(StringSymbol::new("x".to_string(), false,))],
3799                FSTState::<()>::_new(0),
3800                false,
3801                None,
3802                true
3803            ),
3804            vec![]
3805        );
3806
3807        assert_eq!(
3808            fst.run_fst(
3809                vec![Symbol::String(StringSymbol::new("x".to_string(), true,))],
3810                FSTState::_new(0),
3811                false,
3812                None,
3813                true
3814            ),
3815            vec![(
3816                true,
3817                false,
3818                FSTState {
3819                    state_num: 1,
3820                    path_weight: 0.0,
3821                    input_flags: FlagMap::new(),
3822                    output_flags: FlagMap::new(),
3823                    symbol_mappings: FSTLinkedList::Some(
3824                        (Symbol::String(StringSymbol::new("y".to_string(), false)), 0),
3825                        Arc::new(FSTLinkedList::None)
3826                    )
3827                }
3828            )]
3829        );
3830    }
3831
3832    #[test]
3833    fn test_simple_identity() {
3834        let code = "0\t1\t@_IDENTITY_SYMBOL_@\t@_IDENTITY_SYMBOL_@\n1";
3835        let fst = FST::from_att_code(code.to_string(), false).unwrap();
3836
3837        assert_eq!(
3838            fst.run_fst(
3839                vec![Symbol::String(StringSymbol::new("x".to_string(), false,))],
3840                FSTState::<()>::_new(0),
3841                false,
3842                None,
3843                true
3844            ),
3845            vec![]
3846        );
3847
3848        assert_eq!(
3849            fst.run_fst(
3850                vec![Symbol::String(StringSymbol::new("x".to_string(), true,))],
3851                FSTState::_new(0),
3852                false,
3853                None,
3854                true
3855            ),
3856            vec![(
3857                true,
3858                false,
3859                FSTState {
3860                    state_num: 1,
3861                    path_weight: 0.0,
3862                    input_flags: FlagMap::new(),
3863                    output_flags: FlagMap::new(),
3864                    symbol_mappings: FSTLinkedList::Some(
3865                        (Symbol::String(StringSymbol::new("x".to_string(), true)), 0),
3866                        Arc::new(FSTLinkedList::None)
3867                    )
3868                }
3869            )]
3870        );
3871    }
3872
3873    #[test]
3874    fn test_raw_symbols() {
3875        // Construct simple transducer
3876
3877        let mut rules: IndexMap<u64, IndexMap<Symbol, Vec<(u64, Symbol, f64)>>> = IndexMap::new();
3878        let sym_a = Symbol::Raw(RawSymbol {
3879            value: [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
3880        });
3881        let sym_b = Symbol::Raw(RawSymbol {
3882            value: [0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
3883        });
3884        let sym_c = Symbol::Raw(RawSymbol {
3885            value: [0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
3886        });
3887        let special_epsilon = Symbol::Raw(RawSymbol {
3888            value: [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
3889        });
3890        let sym_d = Symbol::Raw(RawSymbol {
3891            value: [0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
3892        });
3893        let sym_d_unk = Symbol::Raw(RawSymbol {
3894            value: [2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
3895        });
3896        rules.insert(0, indexmap!(sym_a.clone() => vec![(1, sym_a.clone(), 0.0)]));
3897        rules.insert(1, indexmap!(sym_b.clone() => vec![(0, sym_b.clone(), 0.0)], Symbol::Special(SpecialSymbol::IDENTITY) => vec![(2, Symbol::Special(SpecialSymbol::IDENTITY), 0.0)]));
3898        rules.insert(
3899            2,
3900            indexmap!(special_epsilon.clone() => vec![(3, sym_c.clone(), 0.0)]),
3901        );
3902        let symbols = vec![sym_a.clone(), sym_b.clone(), sym_c.clone(), special_epsilon];
3903        let fst = FST {
3904            final_states: indexmap! {3 => 0.0},
3905            rules,
3906            symbols,
3907            debug: false,
3908        };
3909
3910        // Accepting example that tests epsilon + unknown bits
3911
3912        let result = fst.run_fst(
3913            vec![
3914                sym_a.clone(),
3915                sym_b.clone(),
3916                sym_a.clone(),
3917                sym_d_unk.clone(),
3918            ],
3919            FSTState::<()>::_new(0),
3920            false,
3921            None,
3922            true,
3923        );
3924        let filtered: Vec<_> = result.into_iter().filter(|x| x.0).collect();
3925        assert_eq!(filtered.len(), 1);
3926        assert_eq!(filtered[0].2.state_num, 3);
3927        assert_eq!(
3928            filtered[0]
3929                .2
3930                .symbol_mappings
3931                .clone()
3932                .to_vec()
3933                .iter()
3934                .map(|x| x.0.clone())
3935                .collect::<Vec<_>>(),
3936            vec![
3937                sym_a.clone(),
3938                sym_b.clone(),
3939                sym_a.clone(),
3940                sym_d_unk.clone(),
3941                sym_c.clone()
3942            ]
3943        );
3944
3945        // Rejecting example that further tests the unknown bit
3946
3947        assert_eq!(
3948            fst.run_fst(
3949                vec![sym_a.clone(), sym_b.clone(), sym_a.clone(), sym_d.clone()],
3950                FSTState::<()>::_new(0),
3951                false,
3952                None,
3953                true
3954            )
3955            .into_iter()
3956            .filter(|x| x.0)
3957            .count(),
3958            0,
3959        );
3960    }
3961
3962    #[test]
3963    fn test_string_comparison_order_for_tokenizable_symbol_types() {
3964        assert!(
3965            StringSymbol::new("aa".to_string(), false) < StringSymbol::new("a".to_string(), false)
3966        );
3967        assert!(
3968            StringSymbol::new("aa".to_string(), true) > StringSymbol::new("aa".to_string(), false)
3969        );
3970        assert!(
3971            StringSymbol::new("ab".to_string(), false) > StringSymbol::new("aa".to_string(), false)
3972        );
3973
3974        assert!(
3975            FlagDiacriticSymbol::parse("@U.aa@").unwrap()
3976                < FlagDiacriticSymbol::parse("@U.a@").unwrap()
3977        );
3978        assert!(
3979            FlagDiacriticSymbol::parse("@U.ab@").unwrap()
3980                > FlagDiacriticSymbol::parse("@U.aa@").unwrap()
3981        );
3982        assert!(SpecialSymbol::IDENTITY < SpecialSymbol::EPSILON); // "@0@" is the shorter string
3983    }
3984
3985    #[test]
3986    fn fst_linked_list_conversion_correctness_internal_methods() {
3987        let orig = vec![
3988            Symbol::Special(SpecialSymbol::EPSILON),
3989            Symbol::Special(SpecialSymbol::IDENTITY),
3990            Symbol::Special(SpecialSymbol::UNKNOWN),
3991        ];
3992        assert!(
3993            FSTLinkedList::from_vec(orig.clone(), vec![10, 20, 30]).to_vec()
3994                == vec![
3995                    (Symbol::Special(SpecialSymbol::EPSILON), 10),
3996                    (Symbol::Special(SpecialSymbol::IDENTITY), 20),
3997                    (Symbol::Special(SpecialSymbol::UNKNOWN), 30),
3998                ]
3999        );
4000
4001        assert!(
4002            FSTLinkedList::from_vec(orig.clone(), vec![]).to_vec()
4003                == vec![
4004                    (Symbol::Special(SpecialSymbol::EPSILON), 0),
4005                    (Symbol::Special(SpecialSymbol::IDENTITY), 0),
4006                    (Symbol::Special(SpecialSymbol::UNKNOWN), 0),
4007                ]
4008        );
4009    }
4010
4011    #[test]
4012    fn fst_linked_list_conversion_correctness_iterators_with_indices() {
4013        let orig = vec![
4014            (Symbol::Special(SpecialSymbol::EPSILON), 10),
4015            (Symbol::Special(SpecialSymbol::IDENTITY), 20),
4016            (Symbol::Special(SpecialSymbol::UNKNOWN), 30),
4017        ];
4018        let ll: FSTLinkedList<usize> = orig.clone().into_iter().collect();
4019        assert!(
4020            ll.into_iter().collect::<Vec<_>>()
4021                == vec![
4022                    (Symbol::Special(SpecialSymbol::EPSILON), 10),
4023                    (Symbol::Special(SpecialSymbol::IDENTITY), 20),
4024                    (Symbol::Special(SpecialSymbol::UNKNOWN), 30),
4025                ]
4026        );
4027    }
4028}
4029
/// A Python module implemented in Rust.
///
/// Builds two submodules, `symbols` and `transducer`, attaches them to the top-level
/// module and additionally stores them in `sys.modules` under the dotted names
/// `kfst_rs.symbols` and `kfst_rs.transducer`. `FST` and `TokenizationException` are
/// also exposed directly on the top-level module.
///
/// # Errors
///
/// Returns the underlying Python error if creating a submodule, registering a class,
/// function or exception, or updating `sys.modules` fails. The `sys.modules` entries are
/// written through the regular fallible API rather than `py_run!`, which panics when the
/// executed snippet raises.
#[cfg(feature = "python")]
#[pymodule]
fn kfst_rs(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Handle to `sys.modules`, used to register the submodules under their dotted names.
    let sys_modules = m.py().import("sys")?.getattr("modules")?;

    let symbols = PyModule::new(m.py(), "symbols")?;
    symbols.add_class::<StringSymbol>()?;
    symbols.add_class::<FlagDiacriticType>()?;
    symbols.add_class::<FlagDiacriticSymbol>()?;
    symbols.add_class::<SpecialSymbol>()?;
    symbols.add_class::<RawSymbol>()?;
    symbols.add_function(wrap_pyfunction!(from_symbol_string, m)?)?;

    // Equivalent to `import sys; sys.modules['kfst_rs.symbols'] = symbols`.
    sys_modules.set_item("kfst_rs.symbols", &symbols)?;

    m.add_submodule(&symbols)?;

    let transducer = PyModule::new(m.py(), "transducer")?;
    transducer.add_class::<FST>()?;
    transducer.add_class::<FSTState>()?;
    transducer.add(
        "TokenizationException",
        py.get_type::<TokenizationException>(),
    )?;

    // Equivalent to `import sys; sys.modules['kfst_rs.transducer'] = transducer`.
    sys_modules.set_item("kfst_rs.transducer", &transducer)?;

    m.add_submodule(&transducer)?;

    // Mimic re-imports: make these names available on the top-level module as well

    m.add(
        "TokenizationException",
        py.get_type::<TokenizationException>(),
    )?;
    m.add_class::<FST>()?;

    Ok(())
}