1use crate::error::{SeqError, SeqResult};
44use crate::handle::LcgRng;
45use crate::hmm::forward_backward::logsumexp;
46
/// A linear-chain CRF whose per-step emission scores come from a one-hidden-layer
/// `tanh` MLP applied independently to each time step's feature vector.
///
/// All matrices are stored flat in row-major order:
/// - `w1`: `hidden_dim x input_dim`, `b1`: `hidden_dim`
/// - `w2`: `n_tags x hidden_dim`, `b2`: `n_tags`
/// - `transitions`: `n_tags x n_tags`; `transitions[i * n_tags + j]` is the score
///   added when tag `i` at step `t` is followed by tag `j` at step `t + 1`.
#[derive(Debug, Clone, PartialEq)]
pub struct NeuralCrf {
    /// Number of output tags.
    pub n_tags: usize,
    /// Length of each per-step input feature vector.
    pub input_dim: usize,
    /// Width of the hidden `tanh` layer.
    pub hidden_dim: usize,
    /// Input-to-hidden weights, row-major `hidden_dim x input_dim`.
    pub w1: Vec<f64>,
    /// Hidden-layer biases, length `hidden_dim`.
    pub b1: Vec<f64>,
    /// Hidden-to-emission weights, row-major `n_tags x hidden_dim`.
    pub w2: Vec<f64>,
    /// Emission biases, length `n_tags`.
    pub b2: Vec<f64>,
    /// Tag-to-tag transition scores, row-major `n_tags x n_tags` (from-row, to-column).
    pub transitions: Vec<f64>,
}
76
/// Gradient of the sequence negative log-likelihood with respect to every
/// parameter of a [`NeuralCrf`].
///
/// Each field has the same length and flat row-major layout as the field of the
/// same name on the model, so it can be applied element-wise.
#[derive(Debug, Clone, PartialEq)]
pub struct NeuralCrfGrad {
    /// Gradient for `w1` (`hidden_dim x input_dim`).
    pub w1: Vec<f64>,
    /// Gradient for `b1` (`hidden_dim`).
    pub b1: Vec<f64>,
    /// Gradient for `w2` (`n_tags x hidden_dim`).
    pub w2: Vec<f64>,
    /// Gradient for `b2` (`n_tags`).
    pub b2: Vec<f64>,
    /// Gradient for `transitions` (`n_tags x n_tags`).
    pub transitions: Vec<f64>,
}
93
/// Cached result of [`NeuralCrf::forward`] for one input sequence, reused by the
/// likelihood and gradient computations.
#[derive(Debug, Clone, PartialEq)]
pub struct NeuralCrfForward {
    /// Number of time steps in the sequence.
    pub t_max: usize,
    /// Post-`tanh` hidden activations, row-major `t_max x hidden_dim`.
    pub hidden: Vec<f64>,
    /// Emission scores, row-major `t_max x n_tags`.
    pub emit: Vec<f64>,
}
108
109impl NeuralCrf {
110 pub fn zeros(n_tags: usize, input_dim: usize, hidden_dim: usize) -> SeqResult<Self> {
114 if n_tags == 0 || input_dim == 0 || hidden_dim == 0 {
115 return Err(SeqError::InvalidConfiguration(
116 "n_tags, input_dim and hidden_dim must all be > 0".to_string(),
117 ));
118 }
119 Ok(Self {
120 n_tags,
121 input_dim,
122 hidden_dim,
123 w1: vec![0.0; hidden_dim * input_dim],
124 b1: vec![0.0; hidden_dim],
125 w2: vec![0.0; n_tags * hidden_dim],
126 b2: vec![0.0; n_tags],
127 transitions: vec![0.0; n_tags * n_tags],
128 })
129 }
130
131 pub fn new(
136 n_tags: usize,
137 input_dim: usize,
138 hidden_dim: usize,
139 scale: f64,
140 rng: &mut LcgRng,
141 ) -> SeqResult<Self> {
142 if !scale.is_finite() || scale <= 0.0 {
143 return Err(SeqError::InvalidParameter {
144 name: "scale".to_string(),
145 value: scale,
146 });
147 }
148 let mut net = Self::zeros(n_tags, input_dim, hidden_dim)?;
149 for v in net.w1.iter_mut() {
150 *v = rng.next_range(-scale, scale);
151 }
152 for v in net.w2.iter_mut() {
153 *v = rng.next_range(-scale, scale);
154 }
155 Ok(net)
156 }
157
158 pub fn param_count(&self) -> usize {
160 self.w1.len() + self.b1.len() + self.w2.len() + self.b2.len() + self.transitions.len()
161 }
162
163 fn check_input(&self, x: &[f64]) -> SeqResult<usize> {
166 if x.is_empty() {
167 return Err(SeqError::EmptyInput);
168 }
169 if x.len() % self.input_dim != 0 {
170 return Err(SeqError::DimensionMismatch {
171 a: x.len(),
172 b: self.input_dim,
173 });
174 }
175 Ok(x.len() / self.input_dim)
176 }
177
178 pub fn forward(&self, x: &[f64]) -> SeqResult<NeuralCrfForward> {
182 let t_max = self.check_input(x)?;
183 let d = self.input_dim;
184 let hh = self.hidden_dim;
185 let k = self.n_tags;
186 let mut hidden = vec![0.0; t_max * hh];
187 let mut emit = vec![0.0; t_max * k];
188 for t in 0..t_max {
189 let xt = &x[t * d..(t + 1) * d];
190 for h in 0..hh {
192 let mut acc = self.b1[h];
193 let row = h * d;
194 for (dd, &xv) in xt.iter().enumerate() {
195 acc += self.w1[row + dd] * xv;
196 }
197 hidden[t * hh + h] = acc.tanh();
198 }
199 for tag in 0..k {
201 let mut acc = self.b2[tag];
202 let row = tag * hh;
203 for h in 0..hh {
204 acc += self.w2[row + h] * hidden[t * hh + h];
205 }
206 emit[t * k + tag] = acc;
207 }
208 }
209 Ok(NeuralCrfForward {
210 t_max,
211 hidden,
212 emit,
213 })
214 }
215
216 fn sequence_score(&self, emit: &[f64], y: &[usize]) -> SeqResult<f64> {
218 let k = self.n_tags;
219 let t_max = y.len();
220 if t_max == 0 {
221 return Err(SeqError::EmptyInput);
222 }
223 if emit.len() != t_max * k {
224 return Err(SeqError::ShapeMismatch {
225 expected: t_max * k,
226 got: emit.len(),
227 });
228 }
229 let mut s = 0.0;
230 for t in 0..t_max {
231 let yt = y[t];
232 if yt >= k {
233 return Err(SeqError::IndexOutOfBounds { index: yt, len: k });
234 }
235 s += emit[t * k + yt];
236 if t > 0 {
237 s += self.transitions[y[t - 1] * k + yt];
238 }
239 }
240 Ok(s)
241 }
242
243 pub fn log_partition(&self, emit: &[f64]) -> SeqResult<f64> {
248 let alpha = self.forward_scores(emit)?;
249 let k = self.n_tags;
250 let t_max = emit.len() / k;
251 Ok(logsumexp(&alpha[(t_max - 1) * k..]))
252 }
253
254 fn forward_scores(&self, emit: &[f64]) -> SeqResult<Vec<f64>> {
256 let k = self.n_tags;
257 if emit.is_empty() || emit.len() % k != 0 {
258 return Err(SeqError::DimensionMismatch {
259 a: emit.len(),
260 b: k,
261 });
262 }
263 let t_max = emit.len() / k;
264 let mut alpha = vec![f64::NEG_INFINITY; t_max * k];
265 alpha[..k].copy_from_slice(&emit[..k]);
266 let mut tmp = vec![0.0; k];
267 for t in 1..t_max {
268 for j in 0..k {
269 for i in 0..k {
270 tmp[i] = alpha[(t - 1) * k + i] + self.transitions[i * k + j];
271 }
272 alpha[t * k + j] = logsumexp(&tmp) + emit[t * k + j];
273 }
274 }
275 Ok(alpha)
276 }
277
278 fn backward_scores(&self, emit: &[f64]) -> Vec<f64> {
280 let k = self.n_tags;
281 let t_max = emit.len() / k;
282 let mut beta = vec![0.0; t_max * k];
283 let mut tmp = vec![0.0; k];
284 for t in (0..t_max.saturating_sub(1)).rev() {
285 for i in 0..k {
286 for j in 0..k {
287 tmp[j] =
288 self.transitions[i * k + j] + emit[(t + 1) * k + j] + beta[(t + 1) * k + j];
289 }
290 beta[t * k + i] = logsumexp(&tmp);
291 }
292 }
293 beta
294 }
295
296 pub fn nll_from_forward(&self, fwd: &NeuralCrfForward, y: &[usize]) -> SeqResult<f64> {
299 if y.len() != fwd.t_max {
300 return Err(SeqError::LengthMismatch {
301 a: y.len(),
302 b: fwd.t_max,
303 });
304 }
305 let score = self.sequence_score(&fwd.emit, y)?;
306 let log_z = self.log_partition(&fwd.emit)?;
307 Ok(log_z - score)
308 }
309
310 pub fn nll(&self, x: &[f64], y: &[usize]) -> SeqResult<f64> {
312 let fwd = self.forward(x)?;
313 self.nll_from_forward(&fwd, y)
314 }
315
316 pub fn decode(&self, x: &[f64]) -> SeqResult<Vec<usize>> {
318 let fwd = self.forward(x)?;
319 self.viterbi(&fwd.emit)
320 }
321
322 fn viterbi(&self, emit: &[f64]) -> SeqResult<Vec<usize>> {
324 let k = self.n_tags;
325 if emit.is_empty() || emit.len() % k != 0 {
326 return Err(SeqError::DimensionMismatch {
327 a: emit.len(),
328 b: k,
329 });
330 }
331 let t_max = emit.len() / k;
332 let mut delta = vec![f64::NEG_INFINITY; t_max * k];
333 let mut psi = vec![0usize; t_max * k];
334 delta[..k].copy_from_slice(&emit[..k]);
335 for t in 1..t_max {
336 for j in 0..k {
337 let mut best = f64::NEG_INFINITY;
338 let mut argmax = 0usize;
339 for i in 0..k {
340 let v = delta[(t - 1) * k + i] + self.transitions[i * k + j];
341 if v > best {
342 best = v;
343 argmax = i;
344 }
345 }
346 delta[t * k + j] = best + emit[t * k + j];
347 psi[t * k + j] = argmax;
348 }
349 }
350 let mut best = f64::NEG_INFINITY;
351 let mut last = 0usize;
352 for j in 0..k {
353 let v = delta[(t_max - 1) * k + j];
354 if v > best {
355 best = v;
356 last = j;
357 }
358 }
359 let mut path = vec![0usize; t_max];
360 path[t_max - 1] = last;
361 for t in (1..t_max).rev() {
362 path[t - 1] = psi[t * k + path[t]];
363 }
364 Ok(path)
365 }
366
367 fn marginals(&self, emit: &[f64]) -> SeqResult<(Vec<f64>, Vec<f64>)> {
373 let k = self.n_tags;
374 let alpha = self.forward_scores(emit)?;
375 let beta = self.backward_scores(emit);
376 let t_max = emit.len() / k;
377 let log_z = logsumexp(&alpha[(t_max - 1) * k..]);
378
379 let mut p_node = vec![0.0; t_max * k];
380 for t in 0..t_max {
381 for j in 0..k {
382 p_node[t * k + j] = (alpha[t * k + j] + beta[t * k + j] - log_z).exp();
383 }
384 let s: f64 = p_node[t * k..t * k + k].iter().sum();
385 if s > 0.0 {
386 for v in p_node[t * k..t * k + k].iter_mut() {
387 *v /= s;
388 }
389 }
390 }
391
392 let edges = t_max.saturating_sub(1);
393 let mut p_edge = vec![0.0; edges * k * k];
394 for t in 0..edges {
395 let mut s = 0.0;
396 for i in 0..k {
397 for j in 0..k {
398 let v = (alpha[t * k + i]
399 + self.transitions[i * k + j]
400 + emit[(t + 1) * k + j]
401 + beta[(t + 1) * k + j]
402 - log_z)
403 .exp();
404 p_edge[t * k * k + i * k + j] = v;
405 s += v;
406 }
407 }
408 if s > 0.0 {
409 for v in p_edge[t * k * k..(t + 1) * k * k].iter_mut() {
410 *v /= s;
411 }
412 }
413 }
414 Ok((p_node, p_edge))
415 }
416
417 pub fn backward(
424 &self,
425 x: &[f64],
426 fwd: &NeuralCrfForward,
427 y: &[usize],
428 ) -> SeqResult<(f64, NeuralCrfGrad)> {
429 let t_max = self.check_input(x)?;
430 if t_max != fwd.t_max {
431 return Err(SeqError::LengthMismatch {
432 a: t_max,
433 b: fwd.t_max,
434 });
435 }
436 if y.len() != t_max {
437 return Err(SeqError::LengthMismatch {
438 a: y.len(),
439 b: t_max,
440 });
441 }
442 let k = self.n_tags;
443 let hh = self.hidden_dim;
444 let d = self.input_dim;
445 for &yt in y {
446 if yt >= k {
447 return Err(SeqError::IndexOutOfBounds { index: yt, len: k });
448 }
449 }
450
451 let (p_node, p_edge) = self.marginals(&fwd.emit)?;
452 let nll = self.nll_from_forward(fwd, y)?;
453
454 let mut g_emit = p_node.clone();
456 for t in 0..t_max {
457 g_emit[t * k + y[t]] -= 1.0;
458 }
459
460 let mut g_trans = vec![0.0; k * k];
462 for t in 0..t_max.saturating_sub(1) {
463 for i in 0..k {
464 for j in 0..k {
465 g_trans[i * k + j] += p_edge[t * k * k + i * k + j];
466 }
467 }
468 g_trans[y[t] * k + y[t + 1]] -= 1.0;
469 }
470
471 let mut g_w1 = vec![0.0; hh * d];
473 let mut g_b1 = vec![0.0; hh];
474 let mut g_w2 = vec![0.0; k * hh];
475 let mut g_b2 = vec![0.0; k];
476
477 for t in 0..t_max {
478 let xt = &x[t * d..(t + 1) * d];
479 let h_t = &fwd.hidden[t * hh..(t + 1) * hh];
480 for tag in 0..k {
482 let ge = g_emit[t * k + tag];
483 g_b2[tag] += ge;
484 let row = tag * hh;
485 for h in 0..hh {
486 g_w2[row + h] += ge * h_t[h];
487 }
488 }
489 for h in 0..hh {
491 let mut g_h = 0.0;
492 for tag in 0..k {
493 g_h += g_emit[t * k + tag] * self.w2[tag * hh + h];
494 }
495 let g_pre = g_h * (1.0 - h_t[h] * h_t[h]);
497 g_b1[h] += g_pre;
498 let row = h * d;
499 for (dd, &xv) in xt.iter().enumerate() {
500 g_w1[row + dd] += g_pre * xv;
501 }
502 }
503 }
504
505 Ok((
506 nll,
507 NeuralCrfGrad {
508 w1: g_w1,
509 b1: g_b1,
510 w2: g_w2,
511 b2: g_b2,
512 transitions: g_trans,
513 },
514 ))
515 }
516
517 pub fn step(&mut self, x: &[f64], y: &[usize], lr: f64) -> SeqResult<f64> {
522 if !lr.is_finite() || lr <= 0.0 {
523 return Err(SeqError::InvalidParameter {
524 name: "lr".to_string(),
525 value: lr,
526 });
527 }
528 let fwd = self.forward(x)?;
529 let (nll, grad) = self.backward(x, &fwd, y)?;
530 for (w, g) in self.w1.iter_mut().zip(grad.w1.iter()) {
531 *w -= lr * g;
532 }
533 for (w, g) in self.b1.iter_mut().zip(grad.b1.iter()) {
534 *w -= lr * g;
535 }
536 for (w, g) in self.w2.iter_mut().zip(grad.w2.iter()) {
537 *w -= lr * g;
538 }
539 for (w, g) in self.b2.iter_mut().zip(grad.b2.iter()) {
540 *w -= lr * g;
541 }
542 for (w, g) in self.transitions.iter_mut().zip(grad.transitions.iter()) {
543 *w -= lr * g;
544 }
545 Ok(nll)
546 }
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    /// Every tag sequence of length `t_max` over `k` tags, in odometer order
    /// (position 0 varies fastest). Exponential in `t_max`; only for tiny cases.
    fn all_labelings(k: usize, t_max: usize) -> Vec<Vec<usize>> {
        let mut out = Vec::new();
        let mut y = vec![0usize; t_max];
        loop {
            out.push(y.clone());
            // Advance like an odometer; once every digit has wrapped we are done.
            let mut pos = 0;
            loop {
                if pos == t_max {
                    return out;
                }
                y[pos] += 1;
                if y[pos] < k {
                    break;
                }
                y[pos] = 0;
                pos += 1;
            }
        }
    }

    /// log Z obtained by explicitly summing over every tag sequence.
    fn brute_log_partition(net: &NeuralCrf, emit: &[f64]) -> f64 {
        let k = net.n_tags;
        let scores: Vec<f64> = all_labelings(k, emit.len() / k)
            .iter()
            .map(|y| net.sequence_score(emit, y).expect("score"))
            .collect();
        logsumexp(&scores)
    }

    /// Highest-scoring tag sequence by exhaustive search (first one wins ties).
    fn brute_viterbi(net: &NeuralCrf, emit: &[f64]) -> Vec<usize> {
        let k = net.n_tags;
        let t_max = emit.len() / k;
        let mut best_y = vec![0usize; t_max];
        let mut best_s = f64::NEG_INFINITY;
        for y in all_labelings(k, t_max) {
            let s = net.sequence_score(emit, &y).expect("score");
            if s > best_s {
                best_s = s;
                best_y = y;
            }
        }
        best_y
    }

    /// Small deterministic model with non-trivial transitions and emission biases.
    fn toy_net() -> NeuralCrf {
        let mut rng = LcgRng::new(7);
        let mut net = NeuralCrf::new(3, 4, 5, 0.4, &mut rng).expect("net");
        for (i, v) in net.transitions.iter_mut().enumerate() {
            *v = ((i as f64) * 0.13 - 0.2).sin() * 0.3;
        }
        for v in net.b2.iter_mut() {
            *v = 0.1;
        }
        net
    }

    /// Deterministic features for `t_max` steps.
    fn toy_features(net: &NeuralCrf, t_max: usize, seed: u64) -> Vec<f64> {
        let mut rng = LcgRng::new(seed);
        (0..t_max * net.input_dim)
            .map(|_| rng.next_range(-1.0, 1.0))
            .collect()
    }

    #[test]
    fn construct_validates_dims() {
        assert!(NeuralCrf::zeros(0, 2, 2).is_err());
        assert!(NeuralCrf::zeros(2, 0, 2).is_err());
        assert!(NeuralCrf::zeros(2, 2, 0).is_err());
        let net = NeuralCrf::zeros(3, 4, 5).expect("ok");
        assert_eq!(net.param_count(), 5 * 4 + 5 + 3 * 5 + 3 + 3 * 3);
    }

    #[test]
    fn new_rejects_bad_scale() {
        let mut rng = LcgRng::new(1);
        assert!(NeuralCrf::new(2, 2, 2, 0.0, &mut rng).is_err());
        assert!(NeuralCrf::new(2, 2, 2, -1.0, &mut rng).is_err());
        assert!(NeuralCrf::new(2, 2, 2, f64::NAN, &mut rng).is_err());
    }

    #[test]
    fn forward_shapes_and_emit_match_manual() {
        let net = toy_net();
        let x = toy_features(&net, 4, 11);
        let fwd = net.forward(&x).expect("fwd");
        assert_eq!(fwd.t_max, 4);
        assert_eq!(fwd.hidden.len(), 4 * net.hidden_dim);
        assert_eq!(fwd.emit.len(), 4 * net.n_tags);
        // Recompute one emission score by hand.
        let d = net.input_dim;
        let hh = net.hidden_dim;
        let t = 2usize;
        let tag = 1usize;
        let mut acc = net.b2[tag];
        for h in 0..hh {
            let mut pre = net.b1[h];
            for dd in 0..d {
                pre += net.w1[h * d + dd] * x[t * d + dd];
            }
            acc += net.w2[tag * hh + h] * pre.tanh();
        }
        assert!((acc - fwd.emit[t * net.n_tags + tag]).abs() < 1e-12);
    }

    #[test]
    fn log_partition_matches_brute_force() {
        let net = toy_net();
        for (seed, t_max) in [(3u64, 2usize), (5, 3), (9, 4)] {
            let x = toy_features(&net, t_max, seed);
            let fwd = net.forward(&x).expect("fwd");
            let via_forward = net.log_partition(&fwd.emit).expect("logz");
            let via_brute = brute_log_partition(&net, &fwd.emit);
            assert!(
                (via_forward - via_brute).abs() < 1e-9,
                "T={t_max}: forward={via_forward}, brute={via_brute}"
            );
        }
    }

    #[test]
    fn log_partition_rejects_malformed_emissions() {
        let net = toy_net();
        assert!(net.log_partition(&[]).is_err());
        let ragged = vec![0.0; net.n_tags + 1];
        assert!(net.log_partition(&ragged).is_err());
    }

    #[test]
    fn viterbi_matches_brute_force_argmax() {
        let net = toy_net();
        for (seed, t_max) in [(2u64, 2usize), (4, 3), (6, 4), (8, 5)] {
            let x = toy_features(&net, t_max, seed);
            let fwd = net.forward(&x).expect("fwd");
            let path = net.viterbi(&fwd.emit).expect("viterbi");
            let brute = brute_viterbi(&net, &fwd.emit);
            let s_path = net.sequence_score(&fwd.emit, &path).expect("s");
            let s_brute = net.sequence_score(&fwd.emit, &brute).expect("s");
            assert!((s_path - s_brute).abs() < 1e-9, "T={t_max}");
            assert_eq!(path, brute, "T={t_max}");
        }
    }

    #[test]
    fn decode_returns_in_range_path() {
        let net = toy_net();
        let x = toy_features(&net, 6, 21);
        let path = net.decode(&x).expect("decode");
        assert_eq!(path.len(), 6);
        assert!(path.iter().all(|&p| p < net.n_tags));
    }

    #[test]
    fn nll_is_nonnegative_and_consistent() {
        let net = toy_net();
        let x = toy_features(&net, 4, 31);
        let y = vec![0usize, 2, 1, 0];
        let direct = net.nll(&x, &y).expect("nll");
        let fwd = net.forward(&x).expect("fwd");
        let cached = net.nll_from_forward(&fwd, &y).expect("nll2");
        assert!((direct - cached).abs() < 1e-12);
        assert!(direct >= -1e-9, "nll={direct}");
    }

    #[test]
    fn zero_model_is_uniform() {
        // With every parameter zero, all k^T sequences score 0, so
        // NLL = log(k^T) = T ln k and every marginal is uniform.
        let net = NeuralCrf::zeros(3, 2, 2).expect("net");
        let t_max = 4usize;
        let x = vec![0.5; t_max * net.input_dim];
        let y = vec![0usize, 2, 1, 1];
        let expected = t_max as f64 * (net.n_tags as f64).ln();
        let nll = net.nll(&x, &y).expect("nll");
        assert!((nll - expected).abs() < 1e-9, "nll={nll}, expected={expected}");

        let fwd = net.forward(&x).expect("fwd");
        let (p_node, p_edge) = net.marginals(&fwd.emit).expect("marg");
        let k = net.n_tags as f64;
        assert!(p_node.iter().all(|&p| (p - 1.0 / k).abs() < 1e-12));
        assert!(p_edge.iter().all(|&p| (p - 1.0 / (k * k)).abs() < 1e-12));
    }

    #[test]
    fn single_step_sequence_has_no_transition_gradient() {
        let net = toy_net();
        let x = toy_features(&net, 1, 91);
        let fwd = net.forward(&x).expect("fwd");
        let path = net.decode(&x).expect("decode");
        assert_eq!(path.len(), 1);
        assert_eq!(path, brute_viterbi(&net, &fwd.emit));
        let (nll, grad) = net.backward(&x, &fwd, &[2]).expect("bwd");
        assert!(nll >= -1e-9, "nll={nll}");
        assert!(grad.transitions.iter().all(|&g| g == 0.0));
    }

    #[test]
    fn emission_and_transition_gradients_match_finite_difference() {
        let net = toy_net();
        let x = toy_features(&net, 4, 41);
        let y = vec![1usize, 0, 2, 1];
        let fwd = net.forward(&x).expect("fwd");
        let (_, grad) = net.backward(&x, &fwd, &y).expect("bwd");

        let eps = 1e-6;
        // Central finite difference of the NLL under a parameter perturbation.
        let central = |perturb: &dyn Fn(&mut NeuralCrf, f64)| -> f64 {
            let mut up = net.clone();
            perturb(&mut up, eps);
            let mut dn = net.clone();
            perturb(&mut dn, -eps);
            let lp = up.nll(&x, &y).expect("nll+");
            let lm = dn.nll(&x, &y).expect("nll-");
            (lp - lm) / (2.0 * eps)
        };

        for idx in 0..net.w1.len() {
            let num = central(&|n, e| n.w1[idx] += e);
            assert!(
                (num - grad.w1[idx]).abs() < 1e-4,
                "w1[{idx}] num={num} ana={}",
                grad.w1[idx]
            );
        }
        for idx in 0..net.w2.len() {
            let num = central(&|n, e| n.w2[idx] += e);
            assert!(
                (num - grad.w2[idx]).abs() < 1e-4,
                "w2[{idx}] num={num} ana={}",
                grad.w2[idx]
            );
        }
        for idx in 0..net.b1.len() {
            let num = central(&|n, e| n.b1[idx] += e);
            assert!(
                (num - grad.b1[idx]).abs() < 1e-4,
                "b1[{idx}] num={num} ana={}",
                grad.b1[idx]
            );
        }
        for idx in 0..net.b2.len() {
            let num = central(&|n, e| n.b2[idx] += e);
            assert!(
                (num - grad.b2[idx]).abs() < 1e-4,
                "b2[{idx}] num={num} ana={}",
                grad.b2[idx]
            );
        }
        for idx in 0..net.transitions.len() {
            let num = central(&|n, e| n.transitions[idx] += e);
            assert!(
                (num - grad.transitions[idx]).abs() < 1e-4,
                "trans[{idx}] num={num} ana={}",
                grad.transitions[idx]
            );
        }
    }

    #[test]
    fn backward_rejects_inconsistent_inputs() {
        let net = toy_net();
        let x = toy_features(&net, 3, 101);
        let fwd = net.forward(&x).expect("fwd");
        // Label sequence of the wrong length.
        assert!(net.backward(&x, &fwd, &[0, 1]).is_err());
        // Label out of range.
        assert!(net.backward(&x, &fwd, &[0, 1, net.n_tags]).is_err());
        // Cached forward pass that belongs to a sequence of a different length.
        let shorter = toy_features(&net, 2, 102);
        assert!(net.backward(&shorter, &fwd, &[0, 1]).is_err());
    }

    #[test]
    fn training_reduces_nll_on_toy_sequence() {
        let mut net = toy_net();
        let x = toy_features(&net, 5, 51);
        let y = vec![0usize, 1, 2, 1, 0];
        let nll0 = net.nll(&x, &y).expect("nll0");
        for _ in 0..200 {
            net.step(&x, &y, 0.05).expect("step");
        }
        let nll1 = net.nll(&x, &y).expect("nll1");
        assert!(nll1 < nll0 - 1e-3, "nll0={nll0}, nll1={nll1}");
        let path = net.decode(&x).expect("decode");
        assert_eq!(path, y);
    }

    #[test]
    fn step_reports_pre_update_nll() {
        let mut net = toy_net();
        let x = toy_features(&net, 4, 111);
        let y = vec![2usize, 0, 1, 1];
        let before = net.nll(&x, &y).expect("nll");
        let reported = net.step(&x, &y, 0.05).expect("step");
        assert!((reported - before).abs() < 1e-12, "before={before}, reported={reported}");
    }

    #[test]
    fn step_validates_learning_rate() {
        let mut net = toy_net();
        let x = toy_features(&net, 3, 61);
        let y = vec![0usize, 1, 2];
        let before = net.clone();
        for lr in [0.0, -0.1, f64::NAN, f64::INFINITY] {
            assert!(net.step(&x, &y, lr).is_err(), "lr={lr}");
        }
        // A rejected step must leave every parameter untouched.
        assert_eq!(net.w1, before.w1);
        assert_eq!(net.b1, before.b1);
        assert_eq!(net.w2, before.w2);
        assert_eq!(net.b2, before.b2);
        assert_eq!(net.transitions, before.transitions);
    }

    #[test]
    fn input_validation_paths() {
        let net = toy_net();
        assert!(net.forward(&[]).is_err());
        // Feature buffer that is not a whole number of rows.
        let bad = vec![0.0; net.input_dim * 2 + 1];
        assert!(net.forward(&bad).is_err());
        let x = toy_features(&net, 2, 71);
        // Out-of-range tag, then wrong label length.
        assert!(net.nll(&x, &[0, net.n_tags]).is_err());
        assert!(net.nll(&x, &[0]).is_err());
    }

    #[test]
    fn marginals_form_valid_distributions() {
        let net = toy_net();
        let x = toy_features(&net, 4, 81);
        let fwd = net.forward(&x).expect("fwd");
        let (p_node, p_edge) = net.marginals(&fwd.emit).expect("marg");
        let k = net.n_tags;
        for t in 0..fwd.t_max {
            let s: f64 = p_node[t * k..t * k + k].iter().sum();
            assert!((s - 1.0).abs() < 1e-9, "node t={t} sum={s}");
            assert!(p_node[t * k..t * k + k].iter().all(|&p| p >= -1e-12));
        }
        for t in 0..fwd.t_max - 1 {
            let s: f64 = p_edge[t * k * k..(t + 1) * k * k].iter().sum();
            assert!((s - 1.0).abs() < 1e-9, "edge t={t} sum={s}");
        }
    }
}