lling-llang 0.1.0

WFST framework for text normalization and grammar correction
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
//! Token graph variants for CTC-like training.
//!
//! This module provides various token graph constructions that encode
//! different prior assumptions about label alignments in sequence-to-sequence
//! training.
//!
//! ## Token Graph Variants
//!
//! 1. **Standard CTC**: Allows any number of blank/non-blank repetitions
//! 2. **Spike CTC**: Single emission per non-blank token
//! 3. **Duration-Limited CTC**: Limits each token to at most `max_duration` consecutive frames
//! 4. **Equally Spaced CTC**: Fixed distance between non-blank tokens
//!
//! ## Prior Encoding
//!
//! The token graph encodes prior beliefs about:
//! - Token duration distribution
//! - Alignment sparsity
//! - Temporal spacing of emissions
//!
//! Different priors suit different data characteristics:
//! - Tight handwriting → shorter token durations → Spike CTC
//! - Loose handwriting → longer durations → Standard CTC
//!
//! ## References
//!
//! - Hannun et al., "Differentiable Weighted Finite-State Transducers" (ICML 2020, arXiv:2010.01003)
//! - Collobert et al., "Wav2Letter" (2016)

use crate::semiring::{LogWeight, Semiring};
use crate::wfst::{MutableWfst, StateId, VectorWfst, Wfst};

/// Integer label identifying one token in a token graph.
pub type TokenId = u32;

/// Label conventionally reserved for the CTC blank; it is the blank used by
/// `TokenGraphConfig::default()`.
pub const BLANK_TOKEN: TokenId = 0;

/// Topology selector for the per-token graph built by `build_token_graph`.
///
/// Each variant encodes a different prior over how many frames a token may
/// occupy and how its emissions are spaced.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TokenGraphType {
    /// Standard CTC: a token may repeat for an unbounded number of frames.
    Standard,
    /// Spike CTC: a token occupies exactly one frame.
    Spike,
    /// Duration-limited CTC: a token occupies a bounded number of consecutive frames.
    DurationLimited {
        /// Upper bound on the number of consecutive frames one token may occupy.
        max_duration: usize,
    },
    /// Equally spaced CTC: each token emission is followed by a fixed run of blanks.
    EquallySpaced {
        /// Length of the blank run that follows each token emission.
        blank_count: usize,
    },
}

/// Parameters shared by all token-graph builders in this module.
#[derive(Clone, Debug)]
pub struct TokenGraphConfig {
    /// Topology to build for each token.
    pub graph_type: TokenGraphType,
    /// Whether the builders add optional blank arcs around token emissions.
    /// `build_vocabulary_graph` also uses it to decide whether to keep an ID
    /// free for the blank.
    pub include_blank: bool,
    /// Label used for the blank token.
    pub blank_id: TokenId,
    /// Value passed to `LogWeight::new` for the arcs the per-token builders create.
    pub init_weight: f64,
}

impl Default for TokenGraphConfig {
    fn default() -> Self {
        Self {
            graph_type: TokenGraphType::Standard,
            include_blank: true,
            blank_id: BLANK_TOKEN,
            init_weight: 0.0,
        }
    }
}

/// Build the token graph for `token` using the topology selected in `config`.
///
/// # Arguments
///
/// * `token` - Label the graph accepts; it appears once on the output side
/// * `config` - Topology selection plus blank and weight settings
///
/// # Returns
///
/// A `VectorWfst` over `LogWeight` describing the single token.
pub fn build_token_graph(
    token: TokenId,
    config: &TokenGraphConfig,
) -> VectorWfst<TokenId, LogWeight> {
    let graph_type = config.graph_type;
    match graph_type {
        TokenGraphType::EquallySpaced { blank_count } => {
            build_equally_spaced_token_graph(token, blank_count, config)
        }
        TokenGraphType::DurationLimited { max_duration } => {
            build_duration_limited_token_graph(token, max_duration, config)
        }
        TokenGraphType::Spike => build_spike_token_graph(token, config),
        TokenGraphType::Standard => build_standard_token_graph(token, config),
    }
}

/// Build the standard CTC token graph for `token`.
///
/// Structure (`a` = `token`, `<b>` = `config.blank_id`, `((n))` = final state):
/// ```text
///   <b>:ε loop      a:ε loop          <b>:ε loop
///      (0) --a:a--> ((1)) --<b>:ε--> ((2))
/// ```
///
/// The token is written to the output exactly once, on the `a:a` arc. Every
/// further frame of the same token is consumed by the `a:ε` self-loop on
/// state 1, so the token may last any number of frames.
///
/// Trailing blanks move to state 2, which has no `a:ε` loop. A blank
/// therefore always ends the emission: `a <b> a` must be read as two passes
/// through the graph (output `aa`), never collapsed into one `a`. This
/// matches the CTC rule that a blank separates repeated labels.
///
/// State 2 and all blank arcs are only created when `config.include_blank`
/// is set. Otherwise the graph has just states 0 and 1.
fn build_standard_token_graph(
    token: TokenId,
    config: &TokenGraphConfig,
) -> VectorWfst<TokenId, LogWeight> {
    let mut fst = VectorWfst::new();

    let s0 = fst.add_state();
    let s1 = fst.add_state();

    fst.set_start(s0);
    fst.set_final(s1, LogWeight::one());

    // Emit the token once: token:token.
    fst.add_arc(
        s0,
        Some(token),
        Some(token),
        s1,
        LogWeight::new(config.init_weight),
    );

    // Further frames of the same token are absorbed without output: token:ε.
    fst.add_arc(
        s1,
        Some(token),
        None,
        s1,
        LogWeight::new(config.init_weight),
    );

    if config.include_blank {
        // Leading blanks before the token.
        fst.add_arc(
            s0,
            Some(config.blank_id),
            None,
            s0,
            LogWeight::new(config.init_weight),
        );

        // Trailing blanks live in their own final state. Looping them on s1,
        // next to the token:ε loop, would let `a <b> a` collapse to a single
        // `a`, which CTC forbids.
        let s2 = fst.add_state();
        fst.set_final(s2, LogWeight::one());
        fst.add_arc(
            s1,
            Some(config.blank_id),
            None,
            s2,
            LogWeight::new(config.init_weight),
        );
        fst.add_arc(
            s2,
            Some(config.blank_id),
            None,
            s2,
            LogWeight::new(config.init_weight),
        );
    }

    fst
}

/// Build the spike CTC token graph for `token`.
///
/// Structure (`a` = `token`, `((n))` = final state):
/// ```text
/// (0) --a:a--> ((1))
/// ```
///
/// The token occupies exactly one frame; there is no `a:ε` repetition loop.
/// When `config.include_blank` is set, both states get a `<blank>:ε`
/// self-loop, so blanks may precede and follow the single emission.
fn build_spike_token_graph(
    token: TokenId,
    config: &TokenGraphConfig,
) -> VectorWfst<TokenId, LogWeight> {
    let arc_weight = || LogWeight::new(config.init_weight);
    let mut fst = VectorWfst::new();

    let before = fst.add_state();
    let after = fst.add_state();
    fst.set_start(before);
    fst.set_final(after, LogWeight::one());

    // The one and only emission of the token.
    fst.add_arc(before, Some(token), Some(token), after, arc_weight());

    if config.include_blank {
        // Leading blanks on `before`, trailing blanks on `after`.
        for state in [before, after] {
            fst.add_arc(state, Some(config.blank_id), None, state, arc_weight());
        }
    }

    fst
}

/// Build the duration-limited CTC token graph for `token`.
///
/// Structure for `max_duration = 2` (`a` = `token`, `((n))` = final state):
/// ```text
/// (0) --a:a--> ((1)) --a:ε--> ((2))
/// ```
///
/// State `i` (for `1 <= i <= max_duration`) means the token has occupied `i`
/// frames. The token is written to the output once, on the `a:a` arc, and
/// each extra frame is consumed by an `a:ε` arc. The token therefore lasts
/// between 1 and `max_duration` frames, and every one of those states is
/// final.
///
/// When `config.include_blank` is set:
/// - state 0 gets a `<blank>:ε` self-loop for leading blanks;
/// - every emitting state gets a `<blank>:ε` arc into one extra final state,
///   which loops on `<blank>:ε` for trailing blanks.
///
/// That extra state has no token arc, so a blank always ends the emission,
/// whatever duration the token reached.
///
/// A `max_duration` of 0 is treated as 1, because a token must occupy at
/// least one frame.
fn build_duration_limited_token_graph(
    token: TokenId,
    max_duration: usize,
    config: &TokenGraphConfig,
) -> VectorWfst<TokenId, LogWeight> {
    // Without this clamp, a zero duration would index `states[1]` out of bounds.
    let max_duration = max_duration.max(1);
    let arc_weight = || LogWeight::new(config.init_weight);

    let mut fst = VectorWfst::new();

    // states[0] is the start; states[i] (i >= 1) = "token seen for i frames".
    let mut states = Vec::with_capacity(max_duration + 1);
    for _ in 0..=max_duration {
        states.push(fst.add_state());
    }

    fst.set_start(states[0]);

    // Any duration from 1 to max_duration is an acceptable stopping point.
    for &state in &states[1..] {
        fst.set_final(state, LogWeight::one());
    }

    // The first frame emits the token.
    fst.add_arc(states[0], Some(token), Some(token), states[1], arc_weight());

    // Each further frame advances the duration counter without output (token:ε).
    for step in states[1..].windows(2) {
        fst.add_arc(step[0], Some(token), None, step[1], arc_weight());
    }

    if config.include_blank {
        // Leading blanks before the token.
        fst.add_arc(states[0], Some(config.blank_id), None, states[0], arc_weight());

        // Trailing blanks get a dedicated final state that is reachable after
        // any allowed duration, not only after exactly max_duration frames.
        let trailing = fst.add_state();
        fst.set_final(trailing, LogWeight::one());
        for &state in &states[1..] {
            fst.add_arc(state, Some(config.blank_id), None, trailing, arc_weight());
        }
        fst.add_arc(trailing, Some(config.blank_id), None, trailing, arc_weight());
    }

    fst
}

/// Build the equally spaced CTC token graph for `token`.
///
/// Structure for `blank_count = 2` (`a` = `token`, `<b>` = `config.blank_id`,
/// `((n))` = final state):
/// ```text
/// (0) --a:a--> ((1)) --<b>:ε--> (2) --<b>:ε--> ((3))
/// ```
///
/// After the token is emitted, a path may stop at state 1. Otherwise it
/// must pass through exactly `blank_count` blanks to reach the last state.
/// The blank arcs are added regardless of `config.include_blank`.
///
/// NOTE(review): because state 1 is final, a closure over these graphs (such
/// as `build_vocabulary_graph`) can move on to the next token without the
/// blank run. The spacing is therefore not enforced between tokens; confirm
/// this is intended.
fn build_equally_spaced_token_graph(
    token: TokenId,
    blank_count: usize,
    config: &TokenGraphConfig,
) -> VectorWfst<TokenId, LogWeight> {
    let mut fst = VectorWfst::new();

    // Chain of states: start, post-token, then one state per required blank.
    let chain: Vec<_> = (0..blank_count + 2).map(|_| fst.add_state()).collect();
    let last = chain.len() - 1;

    fst.set_start(chain[0]);
    fst.set_final(chain[last], LogWeight::one());

    // The token itself.
    fst.add_arc(
        chain[0],
        Some(token),
        Some(token),
        chain[1],
        LogWeight::new(config.init_weight),
    );

    // One <blank>:ε hop between each consecutive pair of post-token states.
    for hop in chain[1..].windows(2) {
        fst.add_arc(
            hop[0],
            Some(config.blank_id),
            None,
            hop[1],
            LogWeight::new(config.init_weight),
        );
    }

    // Stopping right after the token (no trailing blank run) is also accepted.
    fst.set_final(chain[1], LogWeight::one());

    fst
}

/// Build a complete token vocabulary graph.
///
/// Creates the union of token graphs for all tokens in the vocabulary.
///
/// # Arguments
///
/// * `vocab_size` - Number of tokens (excluding blank if separate)
/// * `config` - Configuration for token graphs
///
/// # Returns
///
/// A WFST representing (T₁ + T₂ + ... + T_n)* where T_i is the token graph for token i.
pub fn build_vocabulary_graph(
    vocab_size: usize,
    config: &TokenGraphConfig,
) -> VectorWfst<TokenId, LogWeight> {
    let mut fst = VectorWfst::new();

    // Create start state (also final for empty sequence)
    let start = fst.add_state();
    fst.set_start(start);
    fst.set_final(start, LogWeight::one());

    // Start token ID (skip blank if it's ID 0)
    let start_token = if config.include_blank { 1 } else { 0 };

    // Add each token's graph
    for token_id in start_token..(start_token + vocab_size as TokenId) {
        let token_graph = build_token_graph(token_id, config);

        // Embed token graph into main FST
        // Map states: token_graph.start -> new state, etc.
        let state_offset = fst.num_states() as StateId;

        // Add states for this token graph
        for _ in 0..token_graph.num_states() {
            fst.add_state();
        }

        // Copy arcs with state remapping
        for s in 0..token_graph.num_states() as StateId {
            for arc in token_graph.transitions(s) {
                fst.add_arc(
                    s + state_offset,
                    arc.input.clone(),
                    arc.output.clone(),
                    arc.to + state_offset,
                    arc.weight,
                );
            }
        }

        // Connect main start to token graph start
        let token_start = token_graph.start() + state_offset;
        fst.add_arc(start, None, None, token_start, LogWeight::one());

        // Connect token graph final states back to main start (for closure)
        for s in 0..token_graph.num_states() as StateId {
            if token_graph.is_final(s) {
                let mapped_state = s + state_offset;
                fst.add_arc(mapped_state, None, None, start, token_graph.final_weight(s));
            }
        }
    }

    fst
}

/// Build a single-state graph that accepts any run of blanks, including none.
///
/// The one state is both start and final and carries a `<blank>:ε` self-loop
/// weighted with `config.init_weight`, so every accepted input maps to the
/// empty output. `config.include_blank` is not consulted.
pub fn build_blank_graph(config: &TokenGraphConfig) -> VectorWfst<TokenId, LogWeight> {
    let mut graph = VectorWfst::new();

    let only = graph.add_state();
    graph.set_start(only);
    graph.set_final(only, LogWeight::one());

    let loop_weight = LogWeight::new(config.init_weight);
    graph.add_arc(only, Some(config.blank_id), None, only, loop_weight);

    graph
}

/// Size summary of a token graph.
#[derive(Clone, Debug, Default)]
pub struct TokenGraphStats {
    /// How many states the graph contains.
    pub num_states: usize,
    /// Total number of arcs summed over all states.
    pub num_arcs: usize,
    /// Topology the graph was built with, if known (`from_wfst` leaves it `None`).
    pub graph_type: Option<TokenGraphType>,
}

impl TokenGraphStats {
    /// Compute statistics for a token graph.
    pub fn from_wfst<L: Clone + Send + Sync>(fst: &VectorWfst<L, LogWeight>) -> Self {
        let num_states = fst.num_states();
        let num_arcs: usize = (0..num_states as StateId)
            .map(|s| fst.transitions(s).len())
            .sum();

        Self {
            num_states,
            num_arcs,
            graph_type: None,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::wfst::NO_STATE;

    /// Config for `graph_type` with blanks disabled, so counts reflect only token arcs.
    fn no_blank_config(graph_type: TokenGraphType) -> TokenGraphConfig {
        TokenGraphConfig {
            graph_type,
            include_blank: false,
            ..Default::default()
        }
    }

    #[test]
    fn test_token_graph_config_default() {
        let config = TokenGraphConfig::default();
        assert_eq!(config.graph_type, TokenGraphType::Standard);
        assert!(config.include_blank);
        assert_eq!(config.blank_id, BLANK_TOKEN);
    }

    #[test]
    fn test_standard_token_graph() {
        let config = TokenGraphConfig::default();
        let graph = build_token_graph(1, &config);

        assert!(graph.start() != NO_STATE);
        assert!(graph.num_states() >= 2);

        // Should have self-loops for repetitions
        let stats = TokenGraphStats::from_wfst(&graph);
        assert!(stats.num_arcs >= 2);
    }

    #[test]
    fn test_standard_token_graph_without_blank() {
        let graph = build_token_graph(1, &no_blank_config(TokenGraphType::Standard));
        let stats = TokenGraphStats::from_wfst(&graph);

        // Emission arc plus the token:ε repetition loop, nothing else.
        assert_eq!(stats.num_states, 2);
        assert_eq!(stats.num_arcs, 2);
        assert!(!graph.is_final(0));
        assert!(graph.is_final(1));
    }

    #[test]
    fn test_spike_token_graph() {
        let config = TokenGraphConfig {
            graph_type: TokenGraphType::Spike,
            ..Default::default()
        };
        let graph = build_token_graph(1, &config);
        assert_eq!(graph.num_states(), 2);

        // Without blanks the only arc is the single emission: no repetition loop.
        let bare = build_token_graph(1, &no_blank_config(TokenGraphType::Spike));
        assert_eq!(TokenGraphStats::from_wfst(&bare).num_arcs, 1);
    }

    #[test]
    fn test_duration_limited_token_graph() {
        let config = no_blank_config(TokenGraphType::DurationLimited { max_duration: 3 });
        let graph = build_token_graph(1, &config);

        // Should have max_duration + 1 states
        assert_eq!(graph.num_states(), 4);
        // One emission arc plus max_duration - 1 repetition arcs.
        assert_eq!(TokenGraphStats::from_wfst(&graph).num_arcs, 3);
    }

    #[test]
    fn test_equally_spaced_token_graph() {
        let config = TokenGraphConfig {
            graph_type: TokenGraphType::EquallySpaced { blank_count: 2 },
            ..Default::default()
        };
        let graph = build_token_graph(1, &config);

        // Should have blank_count + 2 states
        assert_eq!(graph.num_states(), 4);
        // A path may stop right after the token or after the full blank run.
        assert!(!graph.is_final(0));
        assert!(graph.is_final(1));
        assert!(!graph.is_final(2));
        assert!(graph.is_final(3));
    }

    #[test]
    fn test_vocabulary_graph() {
        let config = TokenGraphConfig {
            graph_type: TokenGraphType::Spike,
            include_blank: true,
            blank_id: 0,
            init_weight: 0.0,
        };

        let graph = build_vocabulary_graph(3, &config);

        assert!(graph.num_states() > 1);
        assert!(graph.start() != NO_STATE);
        // The shared start state accepts the empty sequence.
        assert!(graph.is_final(graph.start()));

        // 1 start state + 3 spike graphs of 2 states each. Each spike graph
        // contributes 3 arcs (emission + two blank loops) plus an ε entry
        // and an ε return arc.
        let stats = TokenGraphStats::from_wfst(&graph);
        assert_eq!(stats.num_states, 1 + 3 * 2);
        assert_eq!(stats.num_arcs, 3 * (3 + 2));
    }

    #[test]
    fn test_blank_graph() {
        let config = TokenGraphConfig::default();
        let graph = build_blank_graph(&config);

        assert_eq!(graph.num_states(), 1);
        assert!(graph.is_final(0));
        assert_eq!(TokenGraphStats::from_wfst(&graph).num_arcs, 1);
    }

    #[test]
    fn test_token_graph_stats() {
        let config = TokenGraphConfig::default();
        let graph = build_token_graph(1, &config);
        let stats = TokenGraphStats::from_wfst(&graph);

        assert!(stats.num_states > 0);
        assert!(stats.num_arcs > 0);
        assert!(stats.graph_type.is_none());
    }

    #[test]
    fn test_duration_limited_all_states_reachable() {
        let config = no_blank_config(TokenGraphType::DurationLimited { max_duration: 2 });
        let graph = build_token_graph(1, &config);

        // The start state is not final, but every emitting state is.
        assert!(!graph.is_final(0));
        assert!(graph.is_final(1)); // After 1 emission
        assert!(graph.is_final(2)); // After 2 emissions
    }

    #[test]
    fn test_equally_spaced_requires_blanks() {
        let config = TokenGraphConfig {
            graph_type: TokenGraphType::EquallySpaced { blank_count: 2 },
            include_blank: true,
            blank_id: 0,
            ..Default::default()
        };
        let graph = build_token_graph(5, &config);

        // Exactly one token arc plus one arc per required blank.
        let stats = TokenGraphStats::from_wfst(&graph);
        assert_eq!(stats.num_arcs, 3);
    }
}