lling-llang 0.1.0

WFST framework for text normalization and grammar correction
# Differentiable WFSTs

Differentiable WFSTs enable automatic differentiation through WFST operations, allowing gradient-based training with WFST-based loss functions. This bridges the gap between traditional WFST algorithms and modern deep learning frameworks.

## Concepts

### What Is a Differentiable WFST?

A differentiable WFST extends the standard WFST with the ability to compute gradients with respect to arc weights. This enables:

1. **End-to-end training**: Backpropagation through WFST operations
2. **Sequence-level losses**: CTC, ASG, and other alignment-based losses
3. **Integration with neural networks**: WFSTs as trainable layers

```
Neural Network          WFST Operations          Loss
     │                       │                     │
     ▼                       ▼                     ▼
  logits ───────────► forward_score ─────────► -log p(y|x)
     ▲                       │                     │
     │                       ▼                     │
gradients ◄─────────── backward ◄────────────────┘
```

### The Gradient Graph

Every WFST operation returns a graph through which gradients can be computed. The key insight is that **the gradient of a WFST is also a WFST**: it has the same topology, with gradient values in place of weights.

```
Original WFST:                 Gradient WFST:
    w₁=1.0                        g₁=0.73
  0 ────────► 1                 0 ────────► 1
  │           │                 │           │
  │ w₂=2.0    │ (final)         │ g₂=0.27   │ (final)
  └──────────►│                 └──────────►│

Weights are negative log-probs    Gradients are arc posteriors
```

### Forward and Backward Passes

The differentiation follows the forward-backward algorithm:

**Forward Pass (α)**:
- `α[start] = 1̄` (log semiring one = 0.0)
- `α[t] = α[t] ⊕ (α[s] ⊗ w)` for each arc (s, t, w)
- Total score Z = ⊕_{f ∈ F} (α[f] ⊗ ρ[f])

**Backward Pass (β)**:
- `β[f] = ρ[f]` for final states
- `β[s] = β[s] ⊕ (w ⊗ β[t])` for each arc (s, t, w)

**Arc Gradients**:
```
∂Z/∂w = exp(α[s] + w + β[t] - Z)
```

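This is the **posterior probability** that the arc lies on a path sampled from the path distribution defined by the weights.

These formulas are written for weights that act as log-probabilities (larger is more likely). When weights are negative log-probabilities, as in the examples below where `LogWeight::new(1.0)` stands for probability e⁻¹, the same recursions apply with `logadd(a, b) = -log(exp(-a) + exp(-b))`, semiring zero `+∞`, and posterior `exp(Z - α[s] - w - β[t])`.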
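A short derivation may help show why the arc gradient is a posterior. It is a sketch under stated assumptions: the graph is acyclic (so each path uses an arc at most once), weights are log-domain scores combined with ⊕ = logadd, and `w(π)` is a symbol introduced here for the total weight of an accepting path π, including its final weight.

```latex
\begin{aligned}
Z &= \log \sum_{\pi} \exp\bigl(w(\pi)\bigr) \\
\frac{\partial Z}{\partial w_e}
  &= \frac{\sum_{\pi \ni e} \exp\bigl(w(\pi)\bigr)}{\sum_{\pi} \exp\bigl(w(\pi)\bigr)}
  && \text{since } \partial w(\pi)/\partial w_e = 1 \text{ iff } e \in \pi \\
  &= \frac{\exp(\alpha[s]) \, \exp(w_e) \, \exp(\beta[t])}{\exp(Z)}
  && \text{for } e = (s, t, w_e) \\
  &= \exp\bigl(\alpha[s] + w_e + \beta[t] - Z\bigr)
\end{aligned}
```

The third line uses the fact that every path through arc `e` splits into a prefix ending at `s` (summed by α), the arc itself, and a suffix starting at `t` (summed by β).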

## Core API

### Types

```rust
/// Index identifying an arc in a WFST.
pub struct ArcIndex {
    pub from: StateId,
    pub arc_idx: usize,
}

/// Gradient associated with a single arc.
pub struct ArcGradient {
    pub arc: ArcIndex,
    pub gradient: f64,
}

/// Accumulated gradients for all arcs in a WFST.
pub struct GradientAccumulator {
    pub arc_gradients: Vec<ArcGradient>,
    pub num_arcs: usize,
}

/// A WFST with gradient tracking for automatic differentiation.
pub struct GradientWfst<L: Clone> {
    fst: VectorWfst<L, LogWeight>,
    forward_scores: Vec<LogWeight>,   // α values
    backward_scores: Vec<LogWeight>,  // β values
    // ...
}

/// Result of Viterbi path computation with gradients.
pub struct ViterbiGradResult {
    pub score: LogWeight,
    pub path: Vec<ArcIndex>,
    pub gradients: GradientAccumulator,
}
```

### Functions

```rust
/// Compute forward score (log-sum-exp over all paths)
pub fn forward_score<L>(grad_fst: &GradientWfst<L>) -> LogWeight;

/// Alias for forward_score emphasizing the operation
pub fn log_sum_exp_paths<L>(grad_fst: &GradientWfst<L>) -> LogWeight;

/// Compute Viterbi (best path) score
pub fn viterbi_score<L>(grad_fst: &GradientWfst<L>) -> LogWeight;

/// Compute Viterbi path with gradients
pub fn viterbi_path_with_grad<L>(grad_fst: &GradientWfst<L>) -> ViterbiGradResult;

/// Compute backward pass gradients
pub fn backward<L>(grad_fst: &GradientWfst<L>) -> GradientAccumulator;
```

## Examples

### Basic Forward Score and Gradients

```rust
use lling_llang::differentiable::{forward_score, backward, GradientWfst};
use lling_llang::wfst::{VectorWfst, MutableWfst};
use lling_llang::semiring::{LogWeight, Semiring};

// Create a WFST with two parallel paths
let mut fst = VectorWfst::<char, LogWeight>::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
fst.set_start(s0);
fst.set_final(s1, LogWeight::one());
fst.add_arc(s0, Some('a'), Some('a'), s1, LogWeight::new(1.0)); // prob e⁻¹
fst.add_arc(s0, Some('b'), Some('b'), s1, LogWeight::new(2.0)); // prob e⁻²

// Wrap in gradient-tracking structure
let grad_fst = GradientWfst::from_wfst(&fst);

// Compute forward score (log of total probability)
let score = forward_score(&grad_fst);
// score ≈ 0.687 = -log(e⁻¹ + e⁻²)

// Compute gradients via backward pass
let gradients = backward(&grad_fst);

// Gradient for arc 0: exp(-1) / (exp(-1) + exp(-2)) ≈ 0.73
// Gradient for arc 1: exp(-2) / (exp(-1) + exp(-2)) ≈ 0.27
```

### Viterbi Score with Path

```rust
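use lling_llang::differentiable::{viterbi_score, viterbi_path_with_grad, GradientWfst};
use lling_llang::wfst::{VectorWfst, MutableWfst};
use lling_llang::semiring::{LogWeight, Semiring};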

// WFST with two paths of different weights
let mut fst = VectorWfst::<char, LogWeight>::new();
// ... build fst ...

let grad_fst = GradientWfst::from_wfst(&fst);

// Get just the best score
let best_score = viterbi_score(&grad_fst);

// Get score, path, and gradients
let result = viterbi_path_with_grad(&grad_fst);
println!("Best score: {}", result.score.value());
println!("Best path length: {}", result.path.len());

// Gradients are 1.0 for arcs on best path, 0.0 otherwise
for arc in &result.path {
    let grad = result.gradients.get_gradient(*arc);
    assert!((grad - 1.0).abs() < 1e-6);
}
```
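To make the "1.0 on the best path, 0.0 elsewhere" behaviour concrete, here is a minimal standalone sketch of a Viterbi pass with back-pointers. It does not use the crate's types: `ToyArc`, `viterbi_sketch`, and `viterbi_gradients` are hypothetical names introduced for illustration. It assumes an acyclic graph whose state ids are already in topological order, and it treats weights as costs (lower is better).

```rust
/// Toy arc: (source state, target state, cost). Hypothetical type, not the crate's.
type ToyArc = (usize, usize, f64);

/// Minimum-cost path from `start` to any final state.
/// Assumes every arc goes from a lower state id to a higher one.
/// Returns the best cost and the indices of the arcs on the best path.
fn viterbi_sketch(
    num_states: usize,
    start: usize,
    arcs: &[ToyArc],
    finals: &[(usize, f64)],
) -> Option<(f64, Vec<usize>)> {
    // Visiting arcs by source id is a topological order under the assumption above.
    let mut order: Vec<usize> = (0..arcs.len()).collect();
    order.sort_by_key(|&i| arcs[i].0);

    let mut best = vec![f64::INFINITY; num_states];
    let mut back: Vec<Option<usize>> = vec![None; num_states];
    best[start] = 0.0;
    for &i in &order {
        let (s, t, w) = arcs[i];
        if best[s] + w < best[t] {
            best[t] = best[s] + w;
            back[t] = Some(i); // remember the arc that reached t most cheaply
        }
    }

    // Choose the cheapest reachable final state, including its final weight.
    let mut end: Option<(usize, f64)> = None;
    for &(f, rho) in finals {
        let total = best[f] + rho;
        if total.is_finite() && end.map_or(true, |(_, c)| total < c) {
            end = Some((f, total));
        }
    }
    let (mut state, cost) = end?;

    // Follow back-pointers from the final state to the start.
    let mut path = Vec::new();
    while let Some(i) = back[state] {
        path.push(i);
        state = arcs[i].0;
    }
    path.reverse();
    Some((cost, path))
}

/// One-hot gradients: 1.0 for arcs on the best path, 0.0 for all others.
fn viterbi_gradients(num_arcs: usize, path: &[usize]) -> Vec<f64> {
    let mut grads = vec![0.0; num_arcs];
    for &i in path {
        grads[i] = 1.0;
    }
    grads
}
```

Because only one path contributes, the gradient is an indicator vector rather than a posterior, which is why Viterbi gradients are sparse.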

### CTC Loss Computation

```rust
use lling_llang::differentiable::{forward_score, backward, GradientWfst};
use lling_llang::ctc::compact_ctc;
use lling_llang::composition::compose;
use lling_llang::semiring::LogWeight;

// Neural network emissions (T frames × V vocabulary).
// `build_emissions_graph` and `build_target_graph` are helpers not shown here.
let emissions = build_emissions_graph(&logits);

// CTC topology (defines valid alignments)
let ctc = compact_ctc::<LogWeight>(vocab_size);

// Target sequence
let target = build_target_graph(&labels);

// Normalization graph: all possible alignments
let normalization = compose(&emissions, &ctc);

// Constrained graph: valid alignments for this target.
// Reuses the normalization graph instead of composing emissions with ctc twice.
let constrained = compose(&normalization, &target);

// Wrap for differentiation
let constrained_grad = GradientWfst::from_wfst(&constrained);
let normalization_grad = GradientWfst::from_wfst(&normalization);

// CTC loss = log Z_norm - log Z_constrained
let norm_score = forward_score(&normalization_grad);
let constrained_score = forward_score(&constrained_grad);
let loss = norm_score.value() - constrained_score.value();

// Backward pass for gradients
let constrained_grads = backward(&constrained_grad);
let normalization_grads = backward(&normalization_grad);

// Gradient for each arc: grad_norm - grad_constrained
```
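For comparison, the same quantity can be computed without building graphs, using the standard CTC forward recursion over the blank-extended target. The sketch below is standalone and makes several assumptions that are not taken from the crate: the name `ctc_nll` is hypothetical, label `0` is the blank, and `log_probs[t][v]` holds per-frame normalized log-probabilities. Under that normalization the sum over all alignments is 1, so `log Z_norm = 0` and the loss reduces to `-log Z_constrained`.

```rust
/// log(exp(a) + exp(b)), treating -inf as the semiring zero.
fn logadd(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        return b;
    }
    if b == f64::NEG_INFINITY {
        return a;
    }
    let (hi, lo) = if a > b { (a, b) } else { (b, a) };
    hi + (lo - hi).exp().ln_1p()
}

/// Negative log-likelihood of `target` under CTC (hypothetical helper).
/// Assumes label 0 is the blank and each frame of `log_probs` is normalized.
fn ctc_nll(log_probs: &[Vec<f64>], target: &[usize]) -> f64 {
    const BLANK: usize = 0;
    if log_probs.is_empty() {
        return if target.is_empty() { 0.0 } else { f64::INFINITY };
    }
    // Extended sequence: blank, y1, blank, y2, ..., blank.
    let mut ext = vec![BLANK];
    for &y in target {
        ext.push(y);
        ext.push(BLANK);
    }
    let n = ext.len();

    // First frame: start in the leading blank or in the first label.
    let mut alpha = vec![f64::NEG_INFINITY; n];
    alpha[0] = log_probs[0][BLANK];
    if n > 1 {
        alpha[1] = log_probs[0][ext[1]];
    }

    for frame in &log_probs[1..] {
        let mut next = vec![f64::NEG_INFINITY; n];
        for j in 0..n {
            let mut a = alpha[j]; // stay on the same symbol
            if j >= 1 {
                a = logadd(a, alpha[j - 1]); // advance by one position
            }
            // Skip the blank between two different labels.
            if j >= 2 && ext[j] != BLANK && ext[j] != ext[j - 2] {
                a = logadd(a, alpha[j - 2]);
            }
            next[j] = a + frame[ext[j]];
        }
        alpha = next;
    }

    // Valid alignments end in the last label or the trailing blank.
    let mut total = alpha[n - 1];
    if n > 1 {
        total = logadd(total, alpha[n - 2]);
    }
    -total
}
```

With two frames, a two-symbol vocabulary, and uniform probabilities, the target `[1]` has three valid alignments (`1 1`, `blank 1`, `1 blank`), each with probability 0.25, so the result is `-ln(0.75)`.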

### Sequence-Level Training

```rust
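use lling_llang::differentiable::{GradientWfst, GradientAccumulator, forward_score, backward};
use lling_llang::composition::compose;
use lling_llang::wfst::VectorWfst;
use lling_llang::semiring::LogWeight;

// General form of sequence-level loss:
// loss = -log p(y|X) = log Z_norm - log Z_constrained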

fn sequence_loss<L: Clone + Send + Sync>(
    emissions: &VectorWfst<L, LogWeight>,
    transitions: &VectorWfst<L, LogWeight>,
    target: &VectorWfst<L, LogWeight>,
) -> (f64, GradientAccumulator) {
    // Constrained: valid alignments for target
    let constrained = compose(&compose(emissions, transitions), target);
    let constrained_grad = GradientWfst::from_wfst(&constrained);

    // Normalization: all alignments
    let normalization = compose(emissions, transitions);
    let normalization_grad = GradientWfst::from_wfst(&normalization);

    // Scores
    let z_constrained = forward_score(&constrained_grad);
    let z_norm = forward_score(&normalization_grad);

    // Loss
    let loss = z_norm.value() - z_constrained.value();

    // Gradients (difference of posteriors)
    let grad_constrained = backward(&constrained_grad);
    let grad_norm = backward(&normalization_grad);

    // Combine gradients: ∂loss/∂w = p(arc|all) - p(arc|target)
    let mut combined = grad_norm.clone();
    for g in &grad_constrained.arc_gradients {
        combined.add_gradient(g.arc, -g.gradient);
    }

    (loss, combined)
}
```

## Algorithm Details

### Forward Score Computation

```
Algorithm: FORWARD_SCORE(fst)
  1. Initialize α[start] = 0.0 (log one), α[other] = -∞ (log zero)
  2. topo_order = topological_sort(fst)
  3. For each state s in topo_order:
       For each arc (s, t, w):
         α[t] = logadd(α[t], α[s] + w)
  4. Z = logadd_{f ∈ finals}(α[f] + ρ[f])
  5. Return Z
```

Where `logadd(a, b) = log(exp(a) + exp(b))`.
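The pseudocode above can be sketched in plain Rust without any crate types. This is an illustration under assumptions, not the library's implementation: `ToyArc` and `forward_score_sketch` are hypothetical names, the graph is acyclic, and state ids are assumed to already be in topological order so that sorting arcs by source state replaces an explicit topological sort.

```rust
/// Toy arc: (source state, target state, log-domain weight). Hypothetical type.
type ToyArc = (usize, usize, f64);

/// log(exp(a) + exp(b)), treating -inf as the semiring zero.
fn logadd(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        return b;
    }
    if b == f64::NEG_INFINITY {
        return a;
    }
    let (hi, lo) = if a > b { (a, b) } else { (b, a) };
    hi + (lo - hi).exp().ln_1p()
}

/// Forward score Z of an acyclic graph.
/// Assumes every arc goes from a lower state id to a higher one.
/// `finals` lists (state, final weight ρ) pairs.
fn forward_score_sketch(
    num_states: usize,
    start: usize,
    arcs: &[ToyArc],
    finals: &[(usize, f64)],
) -> f64 {
    // Step 1: α[start] = 0 (log one), everything else -inf (log zero).
    let mut alpha = vec![f64::NEG_INFINITY; num_states];
    alpha[start] = 0.0;

    // Steps 2-3: relax arcs in topological order of their source state.
    let mut sorted = arcs.to_vec();
    sorted.sort_by_key(|&(s, _, _)| s);
    for (s, t, w) in sorted {
        alpha[t] = logadd(alpha[t], alpha[s] + w);
    }

    // Step 4: combine final states.
    finals
        .iter()
        .fold(f64::NEG_INFINITY, |z, &(f, rho)| logadd(z, alpha[f] + rho))
}
```

If no final state is reachable, the fold returns `-inf`, which corresponds to the semiring zero.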

### Backward Pass

```
Algorithm: BACKWARD(fst, α, Z)
  1. Initialize β[f] = ρ[f] for finals, β[other] = -∞
  2. topo_order = topological_sort(fst)
  3. For each state s in reverse(topo_order):
       For each arc (s, t, w):
         β[s] = logadd(β[s], w + β[t])
  4. For each arc (s, t, w):
       gradient[arc] = exp(α[s] + w + β[t] - Z)
  5. Return gradients
```
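A compact standalone sketch of the backward pass and the resulting arc posteriors follows. As before, it is illustrative rather than the crate's code: `ToyArc` and `arc_posteriors` are hypothetical names, state ids are assumed to be topologically ordered, and at least one accepting path is assumed so that Z is finite.

```rust
/// Toy arc: (source state, target state, log-domain weight). Hypothetical type.
type ToyArc = (usize, usize, f64);

/// log(exp(a) + exp(b)), treating -inf as the semiring zero.
fn logadd(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        return b;
    }
    if b == f64::NEG_INFINITY {
        return a;
    }
    let (hi, lo) = if a > b { (a, b) } else { (b, a) };
    hi + (lo - hi).exp().ln_1p()
}

/// Returns (Z, per-arc posteriors) for an acyclic graph.
/// Assumes every arc goes from a lower state id to a higher one
/// and that Z is finite (some final state is reachable).
fn arc_posteriors(
    num_states: usize,
    start: usize,
    arcs: &[ToyArc],
    finals: &[(usize, f64)],
) -> (f64, Vec<f64>) {
    // Arc indices ordered by source state: a topological order under the assumption.
    let mut order: Vec<usize> = (0..arcs.len()).collect();
    order.sort_by_key(|&i| arcs[i].0);

    // Forward pass (α).
    let mut alpha = vec![f64::NEG_INFINITY; num_states];
    alpha[start] = 0.0;
    for &i in &order {
        let (s, t, w) = arcs[i];
        alpha[t] = logadd(alpha[t], alpha[s] + w);
    }
    let z = finals
        .iter()
        .fold(f64::NEG_INFINITY, |z, &(f, rho)| logadd(z, alpha[f] + rho));

    // Backward pass (β), visiting arcs in reverse topological order.
    let mut beta = vec![f64::NEG_INFINITY; num_states];
    for &(f, rho) in finals {
        beta[f] = rho;
    }
    for &i in order.iter().rev() {
        let (s, t, w) = arcs[i];
        beta[s] = logadd(beta[s], w + beta[t]);
    }

    // Arc gradient: exp(α[s] + w + β[t] - Z).
    let grads: Vec<f64> = arcs
        .iter()
        .map(|&(s, t, w)| (alpha[s] + w + beta[t] - z).exp())
        .collect();
    (z, grads)
}
```

On a diamond-shaped graph the posteriors of the two arcs leaving the start state sum to 1, since every accepting path uses exactly one of them.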

### Gradient Interpretation

The gradient `∂Z/∂w = exp(α[s] + w + β[t] - Z)` equals the **posterior probability** that arc (s,t) is used when an accepting path is sampled with probability proportional to the exponential of its total weight.

```
                α[s]                β[t]
Paths to s ─────────► s ──w──► t ─────────► Final

Gradient = P(path uses arc (s,t) | all paths)
         = (paths through arc) / (all paths)
         = exp(α[s] + w + β[t]) / exp(Z)
```
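One way to sanity-check this interpretation is a finite-difference test on the smallest interesting graph: two parallel arcs from the start state to a final state with final weight 0. The function names below are hypothetical, and weights are treated as log-domain scores, so `Z(w1, w2) = log(exp(w1) + exp(w2))`.

```rust
/// Z for two parallel arcs with log-domain weights w1 and w2.
fn z_two_arcs(w1: f64, w2: f64) -> f64 {
    let m = w1.max(w2);
    m + ((w1 - m).exp() + (w2 - m).exp()).ln()
}

/// Analytic gradient dZ/dw1: the posterior of the first arc.
/// Here α[s] = 0 and β[t] = 0, so the formula reduces to exp(w1 - Z).
fn posterior_arc1(w1: f64, w2: f64) -> f64 {
    (w1 - z_two_arcs(w1, w2)).exp()
}

/// Central finite-difference estimate of dZ/dw1.
fn numeric_grad_arc1(w1: f64, w2: f64, eps: f64) -> f64 {
    (z_two_arcs(w1 + eps, w2) - z_two_arcs(w1 - eps, w2)) / (2.0 * eps)
}
```

For `w1 = -1.0` and `w2 = -2.0` the analytic value is `e⁻¹ / (e⁻¹ + e⁻²) ≈ 0.73`, and the numeric estimate should agree to several decimal places for a small `eps` such as `1e-5`.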

## Complexity

| Operation | Time | Space |
|-----------|------|-------|
| Forward score (acyclic) | O(\|Q\| + \|E\|) | O(\|Q\|) |
| Forward score (cyclic) | O(\|Q\|²) | O(\|Q\|) |
| Backward pass | O(\|Q\| + \|E\|) | O(\|Q\| + \|E\|) |
| Viterbi score | O(\|Q\| + \|E\|) | O(\|Q\|) |
| Viterbi path | O(\|Q\| + \|E\|) | O(\|Q\|) |

## Semiring Considerations

### Log Semiring for Forward Score

The log semiring is used for computing total path weight:

```
⊕ = logadd (log of sum)
⊗ = +      (log of product)
0̄ = -∞     (log of 0)
1̄ = 0      (log of 1)
```

This gives the log of the **total probability** of all paths when weights are log-probabilities.

### Tropical Semiring for Viterbi

The tropical semiring gives the best single path, with weights read as costs (lower is better):

```
⊕ = min
⊗ = +
0̄ = +∞
1̄ = 0
```

**Critical difference**: Forward score sums over paths; Viterbi takes the best.
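The contrast can be sketched with a small trait that follows the two operator tables above literally. Nothing here is the crate's API: `ToySemiring`, `LogW`, `TropW`, and `total_weight` are hypothetical names, and paths are given explicitly as lists of arc weights rather than as a graph.

```rust
/// Minimal semiring interface (hypothetical, for illustration only).
trait ToySemiring: Copy {
    fn zero() -> Self;
    fn one() -> Self;
    fn plus(self, other: Self) -> Self;
    fn times(self, other: Self) -> Self;
}

/// Log semiring as tabulated above: ⊕ = logadd, ⊗ = +, zero = -inf, one = 0.
#[derive(Clone, Copy, Debug, PartialEq)]
struct LogW(f64);

/// Tropical semiring: ⊕ = min, ⊗ = +, zero = +inf, one = 0.
#[derive(Clone, Copy, Debug, PartialEq)]
struct TropW(f64);

impl ToySemiring for LogW {
    fn zero() -> Self { LogW(f64::NEG_INFINITY) }
    fn one() -> Self { LogW(0.0) }
    fn plus(self, o: Self) -> Self {
        let (hi, lo) = if self.0 > o.0 { (self.0, o.0) } else { (o.0, self.0) };
        if lo == f64::NEG_INFINITY {
            LogW(hi) // adding the zero element changes nothing
        } else {
            LogW(hi + (lo - hi).exp().ln_1p())
        }
    }
    fn times(self, o: Self) -> Self { LogW(self.0 + o.0) }
}

impl ToySemiring for TropW {
    fn zero() -> Self { TropW(f64::INFINITY) }
    fn one() -> Self { TropW(0.0) }
    fn plus(self, o: Self) -> Self { TropW(self.0.min(o.0)) }
    fn times(self, o: Self) -> Self { TropW(self.0 + o.0) }
}

/// ⊕-sum over paths of the ⊗-product of each path's arc weights.
fn total_weight<W: ToySemiring>(paths: &[Vec<W>]) -> W {
    paths.iter().fold(W::zero(), |acc, path| {
        let w = path.iter().fold(W::one(), |p, &a| p.times(a));
        acc.plus(w)
    })
}
```

The same generic function sums over all paths when instantiated with `LogW` and selects the single cheapest path when instantiated with `TropW`; only the `plus` operation differs.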

## Common Patterns

### Loss Function Template

```rust
fn differentiable_loss<L>(
    constrained: &VectorWfst<L, LogWeight>,
    normalization: &VectorWfst<L, LogWeight>,
) -> (f64, Vec<ArcGradient>) {
    let c = GradientWfst::from_wfst(constrained);
    let n = GradientWfst::from_wfst(normalization);

    let loss = forward_score(&n).value() - forward_score(&c).value();

    let grad_n = backward(&n);
    let grad_c = backward(&c);

    // Subtract constrained gradients from normalization gradients
    let combined = combine_gradients(&grad_n, &grad_c, -1.0);

    (loss, combined.arc_gradients)
}
```

### Gradient Accumulation

```rust
use lling_llang::differentiable::GradientAccumulator;

// Accumulate gradients across batches (`compute_batch_loss` is a helper not shown here)
let mut total_grads = GradientAccumulator::new();
let mut total_loss = 0.0;

for batch in &batches {
    let (loss, grads) = compute_batch_loss(batch);
    total_grads.merge(&grads);
    total_loss += loss;
}

// Average gradients
for g in &mut total_grads.arc_gradients {
    g.gradient /= batches.len() as f64;
}
```

### Gradient Clipping

```rust
fn clip_gradients(grads: &mut GradientAccumulator, max_norm: f64) {
    // Compute gradient norm
    let norm: f64 = grads.arc_gradients
        .iter()
        .map(|g| g.gradient * g.gradient)
        .sum::<f64>()
        .sqrt();

    // Clip if exceeds max
    if norm > max_norm {
        let scale = max_norm / norm;
        for g in &mut grads.arc_gradients {
            g.gradient *= scale;
        }
    }
}
```

## Numerical Stability

### Log-Space Computation

All operations are performed in log space to avoid underflow:

```rust
// Instead of: prob = prob1 * prob2
// We compute: log_prob = log_prob1 + log_prob2

// Instead of: prob = prob1 + prob2
// We compute: log_prob = logadd(log_prob1, log_prob2)
```
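A small standalone demonstration of the underflow problem may be useful. The function names are hypothetical; the example multiplies 100 probabilities of `1e-5`, whose true product `1e-500` is far below the smallest positive `f64`.

```rust
/// Multiplies probabilities directly; long products underflow to 0.0.
fn path_prob_linear(probs: &[f64]) -> f64 {
    probs.iter().product()
}

/// Sums log-probabilities instead; the result stays finite.
fn path_prob_log(probs: &[f64]) -> f64 {
    probs.iter().map(|p| p.ln()).sum()
}
```

For `vec![1e-5; 100]`, `path_prob_linear` returns exactly `0.0`, while `path_prob_log` returns `100 * ln(1e-5)`, roughly `-1151.29`, which can still be compared against other path scores.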

### LogAdd Implementation

```rust
fn logadd(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        b
    } else if b == f64::NEG_INFINITY {
        a
    } else if a > b {
        a + (b - a).exp().ln_1p()
    } else {
        b + (a - b).exp().ln_1p()
    }
}
```

The `ln_1p` function computes `ln(1 + x)` more accurately for small `x`.
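The same idea extends to more than two terms, for example when combining several final states at once. The sketch below is a standalone helper with a hypothetical name, `log_sum_exp`; it subtracts the maximum before exponentiating so that the largest term becomes `exp(0) = 1`. It assumes no input is `+inf`.

```rust
/// log(sum_i exp(xs[i])) computed stably by factoring out the maximum.
/// Returns -inf for an empty slice or when every input is -inf.
fn log_sum_exp(xs: &[f64]) -> f64 {
    let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    if max == f64::NEG_INFINITY {
        return f64::NEG_INFINITY;
    }
    max + xs.iter().map(|x| (x - max).exp()).sum::<f64>().ln()
}
```

A naive `xs.iter().map(|x| x.exp()).sum::<f64>().ln()` would return `-inf` for inputs such as `[-1000.0, -1000.0]`, because each `exp` underflows to zero; the shifted version returns `-1000 + ln 2`.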

## Visualization

### Forward-Backward on a Diamond

```
           α=0.0
            [0]
           /   \
     w=1.0       w=2.0
         ↓         ↓
        (1)       (2)
         │         │
    w=0.5     w=0.3
         ↓         ↓
            [3]
          β=0.0

Forward (α):                    Backward (β):
  α[0] = 0.0                      β[3] = 0.0
  α[1] = 0.0 + 1.0 = 1.0          β[1] = 0.5 + 0.0 = 0.5
  α[2] = 0.0 + 2.0 = 2.0          β[2] = 0.3 + 0.0 = 0.3
  α[3] = logadd(1.5, 2.3)         β[0] = logadd(1.0+0.5, 2.0+0.3)
       = 2.67                          = 2.67

Z = 2.67

Gradients:
  g(0→1) = exp(0 + 1.0 + 0.5 - 2.67) = 0.31
  g(0→2) = exp(0 + 2.0 + 0.3 - 2.67) = 0.69
  g(1→3) = exp(1.0 + 0.5 + 0 - 2.67) = 0.31
  g(2→3) = exp(2.0 + 0.3 + 0 - 2.67) = 0.69

Note: logadd(a, b) = log(exp(a) + exp(b)), so the heavier path is the more likely one.
      g(0→1) + g(0→2) = 1 because every path leaves state 0 through exactly one of these arcs.
```

## Error Handling

```rust
use lling_llang::differentiable::{forward_score, GradientWfst};

let grad_fst = GradientWfst::from_wfst(&fst);
let score = forward_score(&grad_fst);

if score.is_zero() {
    // No paths from start to final states
    // This can happen with:
    // - Empty WFST
    // - Disconnected start/final
    // - Empty intersection
    println!("Warning: No valid paths in WFST");
} else if score.value().is_nan() || score.value().is_infinite() {
    // Numerical issues. Checked in a separate branch because the semiring zero is itself infinite.
    println!("Warning: Numerical instability detected");
}
```

## Performance Tips

1. **Use topological order**: For acyclic graphs, processing states in topological order gives O(|Q| + |E|) complexity
2. **Batch operations**: Compute multiple forward scores before backward passes
3. **Cache forward scores**: The backward pass reuses α values
4. **Consider Viterbi**: For max-margin training, Viterbi gradients are sparse (1.0 or 0.0)
5. **Reset between uses**: Call `grad_fst.reset()` when reusing with different inputs

## Next Steps

- [Deep Learning Integration](deep-learning.md): Using differentiable WFSTs with neural networks
- [CTC Topologies](ctc-topologies.md): Building CTC loss functions
- [Weight Pushing](../algorithms/weight-pushing.md): Optimizing WFSTs for differentiable ops
- [ASR Pipeline](asr-pipeline.md): End-to-end speech recognition training