Skip to main content

oxicuda_seq/crf/
neural_crf.rs

1//! Neural linear-chain Conditional Random Field.
2//!
3//! Reference: Collobert, R., Weston, J., Bottou, L., Karlen, M., Kavukcuoglu, K.
4//! & Kuksa, P. (2011). *Natural Language Processing (Almost) from Scratch*.
5//! JMLR 12, 2493–2537 — the "sentence-level log-likelihood" (SLL) network, which
6//! couples a neural feature extractor with a CRF-style transition matrix and
7//! trains end-to-end with the forward (partition-function) algorithm.
8//!
9//! # Model
10//!
11//! A linear-chain CRF whose **emission / unary scores come from a multilayer
12//! perceptron (MLP)** over per-position input features rather than from a linear
13//! score over sparse features:
14//!
15//! ```text
16//! h_t   = tanh(W1 · x_t + b1)              (hidden activations, per position t)
17//! e_t   = W2 · h_t + b2                    (emission scores, one per tag k)
18//! ```
19//!
20//! together with a learned `K × K` transition matrix `A` (`A[i][j]` = score of a
21//! transition from tag `i` to tag `j`). The score of a full tag sequence is
22//!
23//! ```text
24//! s(y, x) = Σ_t e_t[y_t]  +  Σ_{t>0} A[y_{t-1}][y_t]
25//! ```
26//!
27//! and the conditional probability is `p(y|x) = exp(s(y, x)) / Z(x)` where the
28//! log-partition `log Z(x)` is computed by the **forward algorithm in log-space**
29//! (log-sum-exp). Decoding uses **Viterbi**. Training minimises the negative
30//! log-likelihood (NLL); the emission gradient
31//! `∂NLL/∂e_t[k] = p(y_t = k | x) − 1[gold_t = k]` (from forward–backward
32//! marginals) is back-propagated through the MLP, and
33//! `∂NLL/∂A[i][j] = Σ_t p(y_{t-1}=i, y_t=j | x) − count_gold(i → j)`.
34//!
35//! The CRF inference layer (log-space forward, backward, Viterbi) is implemented
36//! here directly on the dense emission tensor `e[t][k]`, mirroring the score-space
37//! forward–backward used by [`crate::crf::crf_train`] but operating on neural
38//! emissions instead of linear feature scores.
39//!
40//! Production code never panics: every fallible path validates its inputs and
41//! returns [`SeqError`].
42
43use crate::error::{SeqError, SeqResult};
44use crate::handle::LcgRng;
45use crate::hmm::forward_backward::logsumexp;
46
/// Neural linear-chain CRF: an MLP that produces per-position emission scores,
/// paired with a dense `K × K` transition matrix, trained jointly on the CRF
/// negative log-likelihood.
///
/// Every parameter array is a flat, row-major `Vec<f64>`:
///
/// | field         | shape                    | element                          |
/// |---------------|--------------------------|----------------------------------|
/// | `w1`          | `hidden_dim × input_dim` | `w1[h * input_dim + d]`          |
/// | `b1`          | `hidden_dim`             | `b1[h]`                          |
/// | `w2`          | `n_tags × hidden_dim`    | `w2[k * hidden_dim + h]`         |
/// | `b2`          | `n_tags`                 | `b2[k]`                          |
/// | `transitions` | `n_tags × n_tags`        | `transitions[i * n_tags + j]` = score of `i → j` |
#[derive(Clone, Debug)]
pub struct NeuralCrf {
    /// Tag-set size `K`.
    pub n_tags: usize,
    /// Length of each position's input feature vector.
    pub input_dim: usize,
    /// Number of hidden units in the emission MLP.
    pub hidden_dim: usize,
    /// First-layer weights, `hidden_dim` rows of `input_dim` entries.
    pub w1: Vec<f64>,
    /// First-layer bias, one entry per hidden unit.
    pub b1: Vec<f64>,
    /// Second-layer weights, `n_tags` rows of `hidden_dim` entries.
    pub w2: Vec<f64>,
    /// Second-layer bias, one entry per tag.
    pub b2: Vec<f64>,
    /// Transition scores; the row index is the previous tag, the column the next.
    pub transitions: Vec<f64>,
}
71    /// Tag bias (`n_tags`).
72    pub b2: Vec<f64>,
73    /// Transition score matrix (`n_tags × n_tags`), row = previous tag.
74    pub transitions: Vec<f64>,
75}
76
77/// Gradients of the NLL with respect to every parameter of a [`NeuralCrf`].
78///
79/// Field shapes mirror the corresponding [`NeuralCrf`] parameter arrays exactly.
80#[derive(Debug, Clone)]
81pub struct NeuralCrfGrad {
82    /// Gradient w.r.t. `w1`.
83    pub w1: Vec<f64>,
84    /// Gradient w.r.t. `b1`.
85    pub b1: Vec<f64>,
86    /// Gradient w.r.t. `w2`.
87    pub w2: Vec<f64>,
88    /// Gradient w.r.t. `b2`.
89    pub b2: Vec<f64>,
90    /// Gradient w.r.t. `transitions`.
91    pub transitions: Vec<f64>,
92}
93
94/// Cached intermediates from a forward pass, reused by the backward pass.
95///
96/// Holds the per-position hidden activations and emission scores so the backward
97/// pass can back-propagate the emission gradient through the MLP without a second
98/// forward evaluation.
99#[derive(Debug, Clone)]
100pub struct NeuralCrfForward {
101    /// Number of positions `T`.
102    pub t_max: usize,
103    /// Hidden activations after `tanh`, flattened `T × hidden_dim`.
104    pub hidden: Vec<f64>,
105    /// Emission scores `e[t][k]`, flattened `T × n_tags`.
106    pub emit: Vec<f64>,
107}
108
109impl NeuralCrf {
110    /// Construct a zero-initialised neural CRF.
111    ///
112    /// All dimensions must be positive; otherwise [`SeqError::InvalidConfiguration`].
113    pub fn zeros(n_tags: usize, input_dim: usize, hidden_dim: usize) -> SeqResult<Self> {
114        if n_tags == 0 || input_dim == 0 || hidden_dim == 0 {
115            return Err(SeqError::InvalidConfiguration(
116                "n_tags, input_dim and hidden_dim must all be > 0".to_string(),
117            ));
118        }
119        Ok(Self {
120            n_tags,
121            input_dim,
122            hidden_dim,
123            w1: vec![0.0; hidden_dim * input_dim],
124            b1: vec![0.0; hidden_dim],
125            w2: vec![0.0; n_tags * hidden_dim],
126            b2: vec![0.0; n_tags],
127            transitions: vec![0.0; n_tags * n_tags],
128        })
129    }
130
131    /// Construct a neural CRF with small random weights drawn from a seeded LCG.
132    ///
133    /// Weights are sampled `~ U(-scale, scale)`; biases and the transition matrix
134    /// start at zero. `scale` must be finite and positive.
135    pub fn new(
136        n_tags: usize,
137        input_dim: usize,
138        hidden_dim: usize,
139        scale: f64,
140        rng: &mut LcgRng,
141    ) -> SeqResult<Self> {
142        if !scale.is_finite() || scale <= 0.0 {
143            return Err(SeqError::InvalidParameter {
144                name: "scale".to_string(),
145                value: scale,
146            });
147        }
148        let mut net = Self::zeros(n_tags, input_dim, hidden_dim)?;
149        for v in net.w1.iter_mut() {
150            *v = rng.next_range(-scale, scale);
151        }
152        for v in net.w2.iter_mut() {
153            *v = rng.next_range(-scale, scale);
154        }
155        Ok(net)
156    }
157
158    /// Total number of free parameters.
159    pub fn param_count(&self) -> usize {
160        self.w1.len() + self.b1.len() + self.w2.len() + self.b2.len() + self.transitions.len()
161    }
162
163    /// Validate that an input feature buffer has the expected length for `t_max`
164    /// positions and return `t_max`.
165    fn check_input(&self, x: &[f64]) -> SeqResult<usize> {
166        if x.is_empty() {
167            return Err(SeqError::EmptyInput);
168        }
169        if x.len() % self.input_dim != 0 {
170            return Err(SeqError::DimensionMismatch {
171                a: x.len(),
172                b: self.input_dim,
173            });
174        }
175        Ok(x.len() / self.input_dim)
176    }
177
178    /// Run the MLP emission scorer over an input feature matrix `x`
179    /// (`T × input_dim`, row-major), returning cached hidden activations and the
180    /// emission tensor `e[t][k]` (`T × n_tags`).
181    pub fn forward(&self, x: &[f64]) -> SeqResult<NeuralCrfForward> {
182        let t_max = self.check_input(x)?;
183        let d = self.input_dim;
184        let hh = self.hidden_dim;
185        let k = self.n_tags;
186        let mut hidden = vec![0.0; t_max * hh];
187        let mut emit = vec![0.0; t_max * k];
188        for t in 0..t_max {
189            let xt = &x[t * d..(t + 1) * d];
190            // Hidden layer: h = tanh(W1 x + b1)
191            for h in 0..hh {
192                let mut acc = self.b1[h];
193                let row = h * d;
194                for (dd, &xv) in xt.iter().enumerate() {
195                    acc += self.w1[row + dd] * xv;
196                }
197                hidden[t * hh + h] = acc.tanh();
198            }
199            // Output layer: e = W2 h + b2
200            for tag in 0..k {
201                let mut acc = self.b2[tag];
202                let row = tag * hh;
203                for h in 0..hh {
204                    acc += self.w2[row + h] * hidden[t * hh + h];
205                }
206                emit[t * k + tag] = acc;
207            }
208        }
209        Ok(NeuralCrfForward {
210            t_max,
211            hidden,
212            emit,
213        })
214    }
215
216    /// Score of a full tag sequence given emission scores: Σ emissions + Σ transitions.
217    fn sequence_score(&self, emit: &[f64], y: &[usize]) -> SeqResult<f64> {
218        let k = self.n_tags;
219        let t_max = y.len();
220        if t_max == 0 {
221            return Err(SeqError::EmptyInput);
222        }
223        if emit.len() != t_max * k {
224            return Err(SeqError::ShapeMismatch {
225                expected: t_max * k,
226                got: emit.len(),
227            });
228        }
229        let mut s = 0.0;
230        for t in 0..t_max {
231            let yt = y[t];
232            if yt >= k {
233                return Err(SeqError::IndexOutOfBounds { index: yt, len: k });
234            }
235            s += emit[t * k + yt];
236            if t > 0 {
237                s += self.transitions[y[t - 1] * k + yt];
238            }
239        }
240        Ok(s)
241    }
242
243    /// Log-partition `log Z(x)` via the forward algorithm in log-space.
244    ///
245    /// `alpha_t(j) = logsumexp_i(alpha_{t-1}(i) + A[i][j]) + e_t[j]`,
246    /// `log Z = logsumexp_j alpha_{T-1}(j)`.
247    pub fn log_partition(&self, emit: &[f64]) -> SeqResult<f64> {
248        let alpha = self.forward_scores(emit)?;
249        let k = self.n_tags;
250        let t_max = emit.len() / k;
251        Ok(logsumexp(&alpha[(t_max - 1) * k..]))
252    }
253
254    /// Forward log-scores `alpha[t][j]` over the dense emission tensor.
255    fn forward_scores(&self, emit: &[f64]) -> SeqResult<Vec<f64>> {
256        let k = self.n_tags;
257        if emit.is_empty() || emit.len() % k != 0 {
258            return Err(SeqError::DimensionMismatch {
259                a: emit.len(),
260                b: k,
261            });
262        }
263        let t_max = emit.len() / k;
264        let mut alpha = vec![f64::NEG_INFINITY; t_max * k];
265        alpha[..k].copy_from_slice(&emit[..k]);
266        let mut tmp = vec![0.0; k];
267        for t in 1..t_max {
268            for j in 0..k {
269                for i in 0..k {
270                    tmp[i] = alpha[(t - 1) * k + i] + self.transitions[i * k + j];
271                }
272                alpha[t * k + j] = logsumexp(&tmp) + emit[t * k + j];
273            }
274        }
275        Ok(alpha)
276    }
277
278    /// Backward log-scores `beta[t][i]` over the dense emission tensor.
279    fn backward_scores(&self, emit: &[f64]) -> Vec<f64> {
280        let k = self.n_tags;
281        let t_max = emit.len() / k;
282        let mut beta = vec![0.0; t_max * k];
283        let mut tmp = vec![0.0; k];
284        for t in (0..t_max.saturating_sub(1)).rev() {
285            for i in 0..k {
286                for j in 0..k {
287                    tmp[j] =
288                        self.transitions[i * k + j] + emit[(t + 1) * k + j] + beta[(t + 1) * k + j];
289                }
290                beta[t * k + i] = logsumexp(&tmp);
291            }
292        }
293        beta
294    }
295
296    /// Negative log-likelihood `−log p(y | x) = log Z(x) − s(y, x)` from a cached
297    /// forward pass.
298    pub fn nll_from_forward(&self, fwd: &NeuralCrfForward, y: &[usize]) -> SeqResult<f64> {
299        if y.len() != fwd.t_max {
300            return Err(SeqError::LengthMismatch {
301                a: y.len(),
302                b: fwd.t_max,
303            });
304        }
305        let score = self.sequence_score(&fwd.emit, y)?;
306        let log_z = self.log_partition(&fwd.emit)?;
307        Ok(log_z - score)
308    }
309
310    /// Negative log-likelihood of a gold tag sequence given input features `x`.
311    pub fn nll(&self, x: &[f64], y: &[usize]) -> SeqResult<f64> {
312        let fwd = self.forward(x)?;
313        self.nll_from_forward(&fwd, y)
314    }
315
316    /// Decode the highest-scoring tag path with Viterbi over the emission tensor.
317    pub fn decode(&self, x: &[f64]) -> SeqResult<Vec<usize>> {
318        let fwd = self.forward(x)?;
319        self.viterbi(&fwd.emit)
320    }
321
322    /// Viterbi decoding directly on a dense emission tensor `e[t][k]`.
323    fn viterbi(&self, emit: &[f64]) -> SeqResult<Vec<usize>> {
324        let k = self.n_tags;
325        if emit.is_empty() || emit.len() % k != 0 {
326            return Err(SeqError::DimensionMismatch {
327                a: emit.len(),
328                b: k,
329            });
330        }
331        let t_max = emit.len() / k;
332        let mut delta = vec![f64::NEG_INFINITY; t_max * k];
333        let mut psi = vec![0usize; t_max * k];
334        delta[..k].copy_from_slice(&emit[..k]);
335        for t in 1..t_max {
336            for j in 0..k {
337                let mut best = f64::NEG_INFINITY;
338                let mut argmax = 0usize;
339                for i in 0..k {
340                    let v = delta[(t - 1) * k + i] + self.transitions[i * k + j];
341                    if v > best {
342                        best = v;
343                        argmax = i;
344                    }
345                }
346                delta[t * k + j] = best + emit[t * k + j];
347                psi[t * k + j] = argmax;
348            }
349        }
350        let mut best = f64::NEG_INFINITY;
351        let mut last = 0usize;
352        for j in 0..k {
353            let v = delta[(t_max - 1) * k + j];
354            if v > best {
355                best = v;
356                last = j;
357            }
358        }
359        let mut path = vec![0usize; t_max];
360        path[t_max - 1] = last;
361        for t in (1..t_max).rev() {
362            path[t - 1] = psi[t * k + path[t]];
363        }
364        Ok(path)
365    }
366
367    /// Node and edge posterior marginals from forward–backward.
368    ///
369    /// Returns `(p_node, p_edge)` where `p_node[t][j] = p(y_t = j | x)`
370    /// (`T × n_tags`) and `p_edge[t][i][j] = p(y_t = i, y_{t+1} = j | x)`
371    /// (`(T−1) × n_tags × n_tags`).
372    fn marginals(&self, emit: &[f64]) -> SeqResult<(Vec<f64>, Vec<f64>)> {
373        let k = self.n_tags;
374        let alpha = self.forward_scores(emit)?;
375        let beta = self.backward_scores(emit);
376        let t_max = emit.len() / k;
377        let log_z = logsumexp(&alpha[(t_max - 1) * k..]);
378
379        let mut p_node = vec![0.0; t_max * k];
380        for t in 0..t_max {
381            for j in 0..k {
382                p_node[t * k + j] = (alpha[t * k + j] + beta[t * k + j] - log_z).exp();
383            }
384            let s: f64 = p_node[t * k..t * k + k].iter().sum();
385            if s > 0.0 {
386                for v in p_node[t * k..t * k + k].iter_mut() {
387                    *v /= s;
388                }
389            }
390        }
391
392        let edges = t_max.saturating_sub(1);
393        let mut p_edge = vec![0.0; edges * k * k];
394        for t in 0..edges {
395            let mut s = 0.0;
396            for i in 0..k {
397                for j in 0..k {
398                    let v = (alpha[t * k + i]
399                        + self.transitions[i * k + j]
400                        + emit[(t + 1) * k + j]
401                        + beta[(t + 1) * k + j]
402                        - log_z)
403                        .exp();
404                    p_edge[t * k * k + i * k + j] = v;
405                    s += v;
406                }
407            }
408            if s > 0.0 {
409                for v in p_edge[t * k * k..(t + 1) * k * k].iter_mut() {
410                    *v /= s;
411                }
412            }
413        }
414        Ok((p_node, p_edge))
415    }
416
417    /// Back-propagate the NLL gradient given a cached forward pass and the gold tags.
418    ///
419    /// Returns the NLL and a [`NeuralCrfGrad`]. The emission gradient
420    /// `g_e[t][k] = p(y_t = k | x) − 1[gold_t = k]` is back-propagated through the
421    /// MLP (`tanh` derivative `1 − h²`); the transition gradient is
422    /// `Σ_t p(y_{t-1}=i, y_t=j | x) − count_gold(i → j)`.
423    pub fn backward(
424        &self,
425        x: &[f64],
426        fwd: &NeuralCrfForward,
427        y: &[usize],
428    ) -> SeqResult<(f64, NeuralCrfGrad)> {
429        let t_max = self.check_input(x)?;
430        if t_max != fwd.t_max {
431            return Err(SeqError::LengthMismatch {
432                a: t_max,
433                b: fwd.t_max,
434            });
435        }
436        if y.len() != t_max {
437            return Err(SeqError::LengthMismatch {
438                a: y.len(),
439                b: t_max,
440            });
441        }
442        let k = self.n_tags;
443        let hh = self.hidden_dim;
444        let d = self.input_dim;
445        for &yt in y {
446            if yt >= k {
447                return Err(SeqError::IndexOutOfBounds { index: yt, len: k });
448            }
449        }
450
451        let (p_node, p_edge) = self.marginals(&fwd.emit)?;
452        let nll = self.nll_from_forward(fwd, y)?;
453
454        // Emission-score gradient: g_e[t][k] = p(y_t = k) − 1[gold_t = k].
455        let mut g_emit = p_node.clone();
456        for t in 0..t_max {
457            g_emit[t * k + y[t]] -= 1.0;
458        }
459
460        // Transition gradient: Σ_t p_edge[t][i][j] − count_gold(i → j).
461        let mut g_trans = vec![0.0; k * k];
462        for t in 0..t_max.saturating_sub(1) {
463            for i in 0..k {
464                for j in 0..k {
465                    g_trans[i * k + j] += p_edge[t * k * k + i * k + j];
466                }
467            }
468            g_trans[y[t] * k + y[t + 1]] -= 1.0;
469        }
470
471        // Back-propagate g_emit through the output and hidden layers.
472        let mut g_w1 = vec![0.0; hh * d];
473        let mut g_b1 = vec![0.0; hh];
474        let mut g_w2 = vec![0.0; k * hh];
475        let mut g_b2 = vec![0.0; k];
476
477        for t in 0..t_max {
478            let xt = &x[t * d..(t + 1) * d];
479            let h_t = &fwd.hidden[t * hh..(t + 1) * hh];
480            // Output layer: e = W2 h + b2.
481            for tag in 0..k {
482                let ge = g_emit[t * k + tag];
483                g_b2[tag] += ge;
484                let row = tag * hh;
485                for h in 0..hh {
486                    g_w2[row + h] += ge * h_t[h];
487                }
488            }
489            // Gradient flowing into hidden activations, then through tanh.
490            for h in 0..hh {
491                let mut g_h = 0.0;
492                for tag in 0..k {
493                    g_h += g_emit[t * k + tag] * self.w2[tag * hh + h];
494                }
495                // d tanh / d pre = 1 − tanh²; here h_t[h] is tanh(pre).
496                let g_pre = g_h * (1.0 - h_t[h] * h_t[h]);
497                g_b1[h] += g_pre;
498                let row = h * d;
499                for (dd, &xv) in xt.iter().enumerate() {
500                    g_w1[row + dd] += g_pre * xv;
501                }
502            }
503        }
504
505        Ok((
506            nll,
507            NeuralCrfGrad {
508                w1: g_w1,
509                b1: g_b1,
510                w2: g_w2,
511                b2: g_b2,
512                transitions: g_trans,
513            },
514        ))
515    }
516
517    /// Apply one gradient-descent step with learning rate `lr` on a single example.
518    ///
519    /// Computes the forward pass, back-propagates the NLL gradient, updates every
520    /// parameter in place, and returns the NLL *before* the update.
521    pub fn step(&mut self, x: &[f64], y: &[usize], lr: f64) -> SeqResult<f64> {
522        if !lr.is_finite() || lr <= 0.0 {
523            return Err(SeqError::InvalidParameter {
524                name: "lr".to_string(),
525                value: lr,
526            });
527        }
528        let fwd = self.forward(x)?;
529        let (nll, grad) = self.backward(x, &fwd, y)?;
530        for (w, g) in self.w1.iter_mut().zip(grad.w1.iter()) {
531            *w -= lr * g;
532        }
533        for (w, g) in self.b1.iter_mut().zip(grad.b1.iter()) {
534            *w -= lr * g;
535        }
536        for (w, g) in self.w2.iter_mut().zip(grad.w2.iter()) {
537            *w -= lr * g;
538        }
539        for (w, g) in self.b2.iter_mut().zip(grad.b2.iter()) {
540            *w -= lr * g;
541        }
542        for (w, g) in self.transitions.iter_mut().zip(grad.transitions.iter()) {
543            *w -= lr * g;
544        }
545        Ok(nll)
546    }
547}
548
549#[cfg(test)]
550mod tests {
551    use super::*;
552
553    /// Brute-force log-partition: log-sum-exp of the score over all `K^T` paths.
554    fn brute_log_partition(net: &NeuralCrf, emit: &[f64]) -> f64 {
555        let k = net.n_tags;
556        let t_max = emit.len() / k;
557        let mut scores: Vec<f64> = Vec::new();
558        let mut y = vec![0usize; t_max];
559        loop {
560            let s = net.sequence_score(emit, &y).expect("score");
561            scores.push(s);
562            // Odometer increment over the K^T tag grid.
563            let mut pos = 0;
564            loop {
565                if pos == t_max {
566                    return logsumexp(&scores);
567                }
568                y[pos] += 1;
569                if y[pos] < k {
570                    break;
571                }
572                y[pos] = 0;
573                pos += 1;
574            }
575        }
576    }
577
578    /// Brute-force argmax path by exhaustive enumeration.
579    fn brute_viterbi(net: &NeuralCrf, emit: &[f64]) -> Vec<usize> {
580        let k = net.n_tags;
581        let t_max = emit.len() / k;
582        let mut best_y = vec![0usize; t_max];
583        let mut best_s = f64::NEG_INFINITY;
584        let mut y = vec![0usize; t_max];
585        loop {
586            let s = net.sequence_score(emit, &y).expect("score");
587            if s > best_s {
588                best_s = s;
589                best_y = y.clone();
590            }
591            let mut pos = 0;
592            loop {
593                if pos == t_max {
594                    return best_y;
595                }
596                y[pos] += 1;
597                if y[pos] < k {
598                    break;
599                }
600                y[pos] = 0;
601                pos += 1;
602            }
603        }
604    }
605
606    fn toy_net() -> NeuralCrf {
607        let mut rng = LcgRng::new(7);
608        let mut net = NeuralCrf::new(3, 4, 5, 0.4, &mut rng).expect("net");
609        for (i, v) in net.transitions.iter_mut().enumerate() {
610            *v = ((i as f64) * 0.13 - 0.2).sin() * 0.3;
611        }
612        for v in net.b2.iter_mut() {
613            *v = 0.1;
614        }
615        net
616    }
617
618    fn toy_features(net: &NeuralCrf, t_max: usize, seed: u64) -> Vec<f64> {
619        let mut rng = LcgRng::new(seed);
620        (0..t_max * net.input_dim)
621            .map(|_| rng.next_range(-1.0, 1.0))
622            .collect()
623    }
624
625    #[test]
626    fn construct_validates_dims() {
627        assert!(NeuralCrf::zeros(0, 2, 2).is_err());
628        assert!(NeuralCrf::zeros(2, 0, 2).is_err());
629        assert!(NeuralCrf::zeros(2, 2, 0).is_err());
630        let net = NeuralCrf::zeros(3, 4, 5).expect("ok");
631        assert_eq!(net.param_count(), 5 * 4 + 5 + 3 * 5 + 3 + 3 * 3);
632    }
633
634    #[test]
635    fn new_rejects_bad_scale() {
636        let mut rng = LcgRng::new(1);
637        assert!(NeuralCrf::new(2, 2, 2, 0.0, &mut rng).is_err());
638        assert!(NeuralCrf::new(2, 2, 2, -1.0, &mut rng).is_err());
639        assert!(NeuralCrf::new(2, 2, 2, f64::NAN, &mut rng).is_err());
640    }
641
642    #[test]
643    fn forward_shapes_and_emit_match_manual() {
644        let net = toy_net();
645        let x = toy_features(&net, 4, 11);
646        let fwd = net.forward(&x).expect("fwd");
647        assert_eq!(fwd.t_max, 4);
648        assert_eq!(fwd.hidden.len(), 4 * net.hidden_dim);
649        assert_eq!(fwd.emit.len(), 4 * net.n_tags);
650        // Recompute emission for (t=2, tag=1) by hand.
651        let d = net.input_dim;
652        let hh = net.hidden_dim;
653        let t = 2usize;
654        let tag = 1usize;
655        let mut acc = net.b2[tag];
656        for h in 0..hh {
657            let mut pre = net.b1[h];
658            for dd in 0..d {
659                pre += net.w1[h * d + dd] * x[t * d + dd];
660            }
661            acc += net.w2[tag * hh + h] * pre.tanh();
662        }
663        assert!((acc - fwd.emit[t * net.n_tags + tag]).abs() < 1e-12);
664    }
665
666    #[test]
667    fn log_partition_matches_brute_force() {
668        let net = toy_net();
669        for (seed, t_max) in [(3u64, 2usize), (5, 3), (9, 4)] {
670            let x = toy_features(&net, t_max, seed);
671            let fwd = net.forward(&x).expect("fwd");
672            let via_forward = net.log_partition(&fwd.emit).expect("logz");
673            let via_brute = brute_log_partition(&net, &fwd.emit);
674            assert!(
675                (via_forward - via_brute).abs() < 1e-9,
676                "T={t_max}: forward={via_forward}, brute={via_brute}"
677            );
678        }
679    }
680
681    #[test]
682    fn viterbi_matches_brute_force_argmax() {
683        let net = toy_net();
684        for (seed, t_max) in [(2u64, 2usize), (4, 3), (6, 4), (8, 5)] {
685            let x = toy_features(&net, t_max, seed);
686            let fwd = net.forward(&x).expect("fwd");
687            let path = net.viterbi(&fwd.emit).expect("viterbi");
688            let brute = brute_viterbi(&net, &fwd.emit);
689            // Scores must coincide (path may differ only on exact ties).
690            let s_path = net.sequence_score(&fwd.emit, &path).expect("s");
691            let s_brute = net.sequence_score(&fwd.emit, &brute).expect("s");
692            assert!((s_path - s_brute).abs() < 1e-9, "T={t_max}");
693            assert_eq!(path, brute, "T={t_max}");
694        }
695    }
696
697    #[test]
698    fn decode_returns_in_range_path() {
699        let net = toy_net();
700        let x = toy_features(&net, 6, 21);
701        let path = net.decode(&x).expect("decode");
702        assert_eq!(path.len(), 6);
703        assert!(path.iter().all(|&p| p < net.n_tags));
704    }
705
706    #[test]
707    fn nll_is_nonnegative_and_consistent() {
708        let net = toy_net();
709        let x = toy_features(&net, 4, 31);
710        let y = vec![0usize, 2, 1, 0];
711        let direct = net.nll(&x, &y).expect("nll");
712        let fwd = net.forward(&x).expect("fwd");
713        let cached = net.nll_from_forward(&fwd, &y).expect("nll2");
714        assert!((direct - cached).abs() < 1e-12);
715        // NLL = log Z − score ≥ 0 since the gold score ≤ log Z.
716        assert!(direct >= -1e-9, "nll={direct}");
717    }
718
719    #[test]
720    fn emission_and_transition_gradients_match_finite_difference() {
721        let net = toy_net();
722        let x = toy_features(&net, 4, 41);
723        let y = vec![1usize, 0, 2, 1];
724        let fwd = net.forward(&x).expect("fwd");
725        let (_, grad) = net.backward(&x, &fwd, &y).expect("bwd");
726
727        let eps = 1e-6;
728        // Helper closures perturbing a chosen parameter array.
729        let central = |perturb: &dyn Fn(&mut NeuralCrf, f64)| -> f64 {
730            let mut up = net.clone();
731            perturb(&mut up, eps);
732            let mut dn = net.clone();
733            perturb(&mut dn, -eps);
734            let lp = up.nll(&x, &y).expect("nll+");
735            let lm = dn.nll(&x, &y).expect("nll-");
736            (lp - lm) / (2.0 * eps)
737        };
738
739        for idx in 0..net.w1.len() {
740            let num = central(&|n, e| n.w1[idx] += e);
741            assert!(
742                (num - grad.w1[idx]).abs() < 1e-4,
743                "w1[{idx}] num={num} ana={}",
744                grad.w1[idx]
745            );
746        }
747        for idx in 0..net.w2.len() {
748            let num = central(&|n, e| n.w2[idx] += e);
749            assert!(
750                (num - grad.w2[idx]).abs() < 1e-4,
751                "w2[{idx}] num={num} ana={}",
752                grad.w2[idx]
753            );
754        }
755        for idx in 0..net.b1.len() {
756            let num = central(&|n, e| n.b1[idx] += e);
757            assert!(
758                (num - grad.b1[idx]).abs() < 1e-4,
759                "b1[{idx}] num={num} ana={}",
760                grad.b1[idx]
761            );
762        }
763        for idx in 0..net.b2.len() {
764            let num = central(&|n, e| n.b2[idx] += e);
765            assert!(
766                (num - grad.b2[idx]).abs() < 1e-4,
767                "b2[{idx}] num={num} ana={}",
768                grad.b2[idx]
769            );
770        }
771        for idx in 0..net.transitions.len() {
772            let num = central(&|n, e| n.transitions[idx] += e);
773            assert!(
774                (num - grad.transitions[idx]).abs() < 1e-4,
775                "trans[{idx}] num={num} ana={}",
776                grad.transitions[idx]
777            );
778        }
779    }
780
781    #[test]
782    fn training_reduces_nll_on_toy_sequence() {
783        let mut net = toy_net();
784        let x = toy_features(&net, 5, 51);
785        let y = vec![0usize, 1, 2, 1, 0];
786        let nll0 = net.nll(&x, &y).expect("nll0");
787        for _ in 0..200 {
788            net.step(&x, &y, 0.05).expect("step");
789        }
790        let nll1 = net.nll(&x, &y).expect("nll1");
791        assert!(nll1 < nll0 - 1e-3, "nll0={nll0}, nll1={nll1}");
792        // After training the decoded path should match the gold sequence.
793        let path = net.decode(&x).expect("decode");
794        assert_eq!(path, y);
795    }
796
797    #[test]
798    fn step_validates_learning_rate() {
799        let mut net = toy_net();
800        let x = toy_features(&net, 3, 61);
801        let y = vec![0usize, 1, 2];
802        assert!(net.step(&x, &y, 0.0).is_err());
803        assert!(net.step(&x, &y, -0.1).is_err());
804    }
805
806    #[test]
807    fn input_validation_paths() {
808        let net = toy_net();
809        // Empty input.
810        assert!(net.forward(&[]).is_err());
811        // Ragged input length (not a multiple of input_dim).
812        let bad = vec![0.0; net.input_dim * 2 + 1];
813        assert!(net.forward(&bad).is_err());
814        // Gold tag out of range.
815        let x = toy_features(&net, 2, 71);
816        assert!(net.nll(&x, &[0, net.n_tags]).is_err());
817        // Length mismatch between y and T.
818        assert!(net.nll(&x, &[0]).is_err());
819    }
820
821    #[test]
822    fn marginals_form_valid_distributions() {
823        let net = toy_net();
824        let x = toy_features(&net, 4, 81);
825        let fwd = net.forward(&x).expect("fwd");
826        let (p_node, p_edge) = net.marginals(&fwd.emit).expect("marg");
827        let k = net.n_tags;
828        for t in 0..fwd.t_max {
829            let s: f64 = p_node[t * k..t * k + k].iter().sum();
830            assert!((s - 1.0).abs() < 1e-9, "node t={t} sum={s}");
831            assert!(p_node[t * k..t * k + k].iter().all(|&p| p >= -1e-12));
832        }
833        for t in 0..fwd.t_max - 1 {
834            let s: f64 = p_edge[t * k * k..(t + 1) * k * k].iter().sum();
835            assert!((s - 1.0).abs() < 1e-9, "edge t={t} sum={s}");
836        }
837    }
838}