1use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
46
/// Borrowed inputs for [`interchange_decode_forward`].
///
/// The forward pass computes `out = (z ⊙ gate) · weightsᵀ + bias`, producing a
/// `(B, D)` matrix. The shapes listed on the fields are the ones the forward pass
/// validates before computing anything.
#[derive(Debug, Clone, Copy)]
pub struct InterchangeDecodeForward<'a> {
    /// Latent batch, shape `(B, F)`; must be non-empty and contain only finite values.
    pub z: ArrayView2<'a, f64>,
    /// Decoder matrix, shape `(D, F)`; must contain only finite values.
    pub weights: ArrayView2<'a, f64>,
    /// Per-feature multiplicative gate, length `F`; must contain only finite values.
    pub gate: ArrayView1<'a, f64>,
    /// Optional additive offset, length `D`; must contain only finite values when present.
    pub bias: Option<ArrayView1<'a, f64>>,
}
55
/// Borrowed inputs for [`interchange_swap_forward`].
///
/// Column `j` of the decoded latent is taken from `z_a` when `mask[j]` is `true` and
/// from `z_b` otherwise; the mixed latent is then decoded with `weights`, `gate`, and
/// `bias` exactly as [`InterchangeDecodeForward`] describes.
#[derive(Debug, Clone, Copy)]
pub struct InterchangeSwapForward<'a> {
    /// First latent, shape `(B, F)`; every entry must be finite.
    pub z_a: ArrayView2<'a, f64>,
    /// Second latent, same shape as `z_a`; every entry must be finite, including
    /// columns the mask does not select.
    pub z_b: ArrayView2<'a, f64>,
    /// Per-column source selector of length `F`: `true` takes the column from `z_a`,
    /// `false` takes it from `z_b`.
    pub mask: ArrayView1<'a, bool>,
    /// Decoder matrix, shape `(D, F)`.
    pub weights: ArrayView2<'a, f64>,
    /// Per-feature multiplicative gate, length `F`.
    pub gate: ArrayView1<'a, f64>,
    /// Optional additive offset, length `D`.
    pub bias: Option<ArrayView1<'a, f64>>,
}
66
/// Gradients returned by [`interchange_decode_backward`].
///
/// Each field is the gradient of the scalar `sum(out ⊙ grad_out)` with respect to the
/// matching input of [`interchange_decode_forward`], i.e. the vector-Jacobian product
/// of the forward pass with `grad_out`.
#[derive(Debug, Clone)]
pub struct InterchangeDecodeBackward {
    /// Gradient with respect to the latent `z`, shape `(B, F)`.
    pub grad_z: Array2<f64>,
    /// Gradient with respect to `weights`, shape `(D, F)`.
    pub grad_weights: Array2<f64>,
    /// Gradient with respect to `gate`, length `F`.
    pub grad_gate: Array1<f64>,
    /// Gradient with respect to `bias`, length `D`; `Some` only when the backward pass
    /// was called with `with_bias = true`.
    pub grad_bias: Option<Array1<f64>>,
}
75
/// Gradients returned by [`interchange_swap_backward`].
///
/// The latent gradient of the mixed input is split by the mask: each column's
/// gradient goes to the latent that column was read from, and the other latent
/// receives zeros in that column.
#[derive(Debug, Clone)]
pub struct InterchangeSwapBackward {
    /// Gradient with respect to `z_a`, shape `(B, F)`; zero in every column where
    /// `mask[j]` is `false`.
    pub grad_z_a: Array2<f64>,
    /// Gradient with respect to `z_b`, shape `(B, F)`; zero in every column where
    /// `mask[j]` is `true`.
    pub grad_z_b: Array2<f64>,
    /// Gradient with respect to `weights`, shape `(D, F)`.
    pub grad_weights: Array2<f64>,
    /// Gradient with respect to `gate`, length `F`.
    pub grad_gate: Array1<f64>,
    /// Gradient with respect to `bias`, length `D`; `Some` only when requested.
    pub grad_bias: Option<Array1<f64>>,
}
85
86fn check_shapes_forward(
87 z_rows: usize,
88 z_cols: usize,
89 weights: ArrayView2<'_, f64>,
90 gate: ArrayView1<'_, f64>,
91 bias: Option<ArrayView1<'_, f64>>,
92) -> Result<(), String> {
93 let (d, f_weights) = weights.dim();
94 if f_weights != z_cols {
95 return Err(format!(
96 "interchange_decode: weights has F={f_weights}, expected {z_cols}"
97 ));
98 }
99 if gate.len() != z_cols {
100 return Err(format!(
101 "interchange_decode: gate has length {}, expected {z_cols}",
102 gate.len()
103 ));
104 }
105 if let Some(b) = bias
106 && b.len() != d
107 {
108 return Err(format!(
109 "interchange_decode: bias has length {}, expected D={d}",
110 b.len()
111 ));
112 }
113 if z_rows == 0 || z_cols == 0 {
114 return Err("interchange_decode: latent must be non-empty".to_string());
115 }
116 if !weights.iter().all(|v| v.is_finite()) {
117 return Err("interchange_decode: weights must be finite".to_string());
118 }
119 if !gate.iter().all(|v| v.is_finite()) {
120 return Err("interchange_decode: gate must be finite".to_string());
121 }
122 if let Some(b) = bias
123 && !b.iter().all(|v| v.is_finite())
124 {
125 return Err("interchange_decode: bias must be finite".to_string());
126 }
127 Ok(())
128}
129
130pub fn interchange_decode_forward(
132 inputs: InterchangeDecodeForward<'_>,
133) -> Result<Array2<f64>, String> {
134 let (b_rows, f) = inputs.z.dim();
135 check_shapes_forward(b_rows, f, inputs.weights, inputs.gate, inputs.bias)?;
136 if !inputs.z.iter().all(|v| v.is_finite()) {
137 return Err("interchange_decode: latent must be finite".to_string());
138 }
139
140 let d = inputs.weights.nrows();
141 let mut z_gated = Array2::<f64>::zeros((b_rows, f));
142 for i in 0..b_rows {
143 for j in 0..f {
144 z_gated[[i, j]] = inputs.z[[i, j]] * inputs.gate[j];
145 }
146 }
147 let mut out = z_gated.dot(&inputs.weights.t());
149 if let Some(bias) = inputs.bias {
150 for i in 0..b_rows {
151 for k in 0..d {
152 out[[i, k]] += bias[k];
153 }
154 }
155 }
156 Ok(out)
157}
158
159pub fn interchange_swap_forward(inputs: InterchangeSwapForward<'_>) -> Result<Array2<f64>, String> {
161 if inputs.z_a.dim() != inputs.z_b.dim() {
162 return Err(format!(
163 "interchange_swap: z_a {:?} and z_b {:?} must have the same shape",
164 inputs.z_a.dim(),
165 inputs.z_b.dim()
166 ));
167 }
168 let (b_rows, f) = inputs.z_a.dim();
169 if inputs.mask.len() != f {
170 return Err(format!(
171 "interchange_swap: mask length {} must equal F={f}",
172 inputs.mask.len()
173 ));
174 }
175 if !inputs.z_a.iter().all(|v| v.is_finite()) || !inputs.z_b.iter().all(|v| v.is_finite()) {
176 return Err("interchange_swap: latents must be finite".to_string());
177 }
178 let mut z_eff = Array2::<f64>::zeros((b_rows, f));
179 for j in 0..f {
180 let take_a = inputs.mask[j];
181 if take_a {
182 for i in 0..b_rows {
183 z_eff[[i, j]] = inputs.z_a[[i, j]];
184 }
185 } else {
186 for i in 0..b_rows {
187 z_eff[[i, j]] = inputs.z_b[[i, j]];
188 }
189 }
190 }
191 interchange_decode_forward(InterchangeDecodeForward {
192 z: z_eff.view(),
193 weights: inputs.weights,
194 gate: inputs.gate,
195 bias: inputs.bias,
196 })
197}
198
199pub fn interchange_decode_backward(
201 z: ArrayView2<'_, f64>,
202 weights: ArrayView2<'_, f64>,
203 gate: ArrayView1<'_, f64>,
204 grad_out: ArrayView2<'_, f64>,
205 with_bias: bool,
206) -> Result<InterchangeDecodeBackward, String> {
207 let (b_rows, f) = z.dim();
208 let (d, f_w) = weights.dim();
209 if f_w != f {
210 return Err(format!(
211 "interchange_decode_backward: weights has F={f_w}, expected {f}"
212 ));
213 }
214 if gate.len() != f {
215 return Err(format!(
216 "interchange_decode_backward: gate length {} != F={f}",
217 gate.len()
218 ));
219 }
220 if grad_out.dim() != (b_rows, d) {
221 return Err(format!(
222 "interchange_decode_backward: grad_out shape {:?} != ({b_rows}, {d})",
223 grad_out.dim()
224 ));
225 }
226
227 let g_mat = grad_out.dot(&weights); let mut grad_z = Array2::<f64>::zeros((b_rows, f));
232 for i in 0..b_rows {
233 for j in 0..f {
234 grad_z[[i, j]] = gate[j] * g_mat[[i, j]];
235 }
236 }
237
238 let mut grad_gate = Array1::<f64>::zeros(f);
240 for j in 0..f {
241 let mut acc = 0.0;
242 for i in 0..b_rows {
243 acc += z[[i, j]] * g_mat[[i, j]];
244 }
245 grad_gate[j] = acc;
246 }
247
248 let mut grad_weights = grad_out.t().dot(&z); for j in 0..f {
252 let scale = gate[j];
253 for k in 0..d {
254 grad_weights[[k, j]] *= scale;
255 }
256 }
257
258 let grad_bias = if with_bias {
259 let mut gb = Array1::<f64>::zeros(d);
260 for i in 0..b_rows {
261 for k in 0..d {
262 gb[k] += grad_out[[i, k]];
263 }
264 }
265 Some(gb)
266 } else {
267 None
268 };
269
270 Ok(InterchangeDecodeBackward {
271 grad_z,
272 grad_weights,
273 grad_gate,
274 grad_bias,
275 })
276}
277
278pub fn interchange_swap_backward(
280 z_a: ArrayView2<'_, f64>,
281 z_b: ArrayView2<'_, f64>,
282 mask: ArrayView1<'_, bool>,
283 weights: ArrayView2<'_, f64>,
284 gate: ArrayView1<'_, f64>,
285 grad_out: ArrayView2<'_, f64>,
286 with_bias: bool,
287) -> Result<InterchangeSwapBackward, String> {
288 if z_a.dim() != z_b.dim() {
289 return Err(format!(
290 "interchange_swap_backward: z_a {:?} and z_b {:?} must have the same shape",
291 z_a.dim(),
292 z_b.dim()
293 ));
294 }
295 let (b_rows, f) = z_a.dim();
296 if mask.len() != f {
297 return Err(format!(
298 "interchange_swap_backward: mask length {} != F={f}",
299 mask.len()
300 ));
301 }
302
303 let mut z_eff = Array2::<f64>::zeros((b_rows, f));
305 for j in 0..f {
306 let take_a = mask[j];
307 if take_a {
308 for i in 0..b_rows {
309 z_eff[[i, j]] = z_a[[i, j]];
310 }
311 } else {
312 for i in 0..b_rows {
313 z_eff[[i, j]] = z_b[[i, j]];
314 }
315 }
316 }
317 let inner = interchange_decode_backward(z_eff.view(), weights, gate, grad_out, with_bias)?;
318
319 let mut grad_z_a = Array2::<f64>::zeros((b_rows, f));
321 let mut grad_z_b = Array2::<f64>::zeros((b_rows, f));
322 for j in 0..f {
323 let take_a = mask[j];
324 if take_a {
325 for i in 0..b_rows {
326 grad_z_a[[i, j]] = inner.grad_z[[i, j]];
327 }
328 } else {
329 for i in 0..b_rows {
330 grad_z_b[[i, j]] = inner.grad_z[[i, j]];
331 }
332 }
333 }
334
335 Ok(InterchangeSwapBackward {
336 grad_z_a,
337 grad_z_b,
338 grad_weights: inner.grad_weights,
339 grad_gate: inner.grad_gate,
340 grad_bias: inner.grad_bias,
341 })
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2, array};

    /// Element-wise comparison with an absolute tolerance; shapes must match exactly.
    fn approx_eq(a: &Array2<f64>, b: &Array2<f64>, tol: f64) -> bool {
        a.dim() == b.dim() && a.iter().zip(b.iter()).all(|(x, y)| (x - y).abs() < tol)
    }

    /// Probe loss `sum(out ⊙ grad_out)`. Its gradient w.r.t. `out` is exactly
    /// `grad_out`, so finite differences of it check the analytic backward passes.
    fn weighted_loss(out: &Array2<f64>, grad_out: &Array2<f64>) -> f64 {
        out.iter().zip(grad_out.iter()).map(|(a, b)| a * b).sum()
    }

    /// Runs the decode forward pass on owned arrays, panicking on validation errors.
    fn decode(
        z: &Array2<f64>,
        w: &Array2<f64>,
        g: &Array1<f64>,
        bias: Option<&Array1<f64>>,
    ) -> Array2<f64> {
        interchange_decode_forward(InterchangeDecodeForward {
            z: z.view(),
            weights: w.view(),
            gate: g.view(),
            bias: bias.map(|b| b.view()),
        })
        .unwrap()
    }

    /// Runs the swap forward pass (no bias) on owned arrays, panicking on errors.
    fn swap(
        z_a: &Array2<f64>,
        z_b: &Array2<f64>,
        mask: &[bool],
        w: &Array2<f64>,
        g: &Array1<f64>,
    ) -> Array2<f64> {
        let mask = Array1::from(mask.to_vec());
        interchange_swap_forward(InterchangeSwapForward {
            z_a: z_a.view(),
            z_b: z_b.view(),
            mask: mask.view(),
            weights: w.view(),
            gate: g.view(),
            bias: None,
        })
        .unwrap()
    }

    /// Central difference `(L(+eps) - L(-eps)) / (2 eps)`, where `loss_at(delta)`
    /// evaluates the loss with one input entry perturbed by `delta`.
    fn central_diff(eps: f64, loss_at: impl Fn(f64) -> f64) -> f64 {
        (loss_at(eps) - loss_at(-eps)) / (2.0 * eps)
    }

    /// Copy of `a` with `delta` added at `idx`.
    fn bumped2(a: &Array2<f64>, idx: [usize; 2], delta: f64) -> Array2<f64> {
        let mut out = a.clone();
        out[idx] += delta;
        out
    }

    /// Copy of `a` with `delta` added at `idx`.
    fn bumped1(a: &Array1<f64>, idx: usize, delta: f64) -> Array1<f64> {
        let mut out = a.clone();
        out[idx] += delta;
        out
    }

    #[test]
    fn forward_matches_hand_recomputation() {
        let z = array![[1.0, -2.0, 0.5], [0.0, 3.0, -1.0]];
        let w = array![[0.1, 0.2, 0.3], [-0.4, 0.5, 0.6]];
        let g = array![1.0, 0.5, -1.0];
        let bias = array![0.01, -0.02];
        let out = decode(&z, &w, &g, Some(&bias));
        let expected = Array2::from_shape_fn((2, 2), |(i, k)| {
            (0..3).fold(bias[k], |acc, j| acc + g[j] * z[[i, j]] * w[[k, j]])
        });
        assert!(approx_eq(&out, &expected, 1e-12));
    }

    #[test]
    fn swap_all_true_matches_z_a_forward() {
        let z_a = array![[1.0, -2.0], [3.0, 0.5]];
        let z_b = array![[10.0, 20.0], [-30.0, 40.0]];
        let w = array![[0.1, 0.2], [0.3, -0.4], [0.5, 0.6]];
        let g = array![0.7, -0.3];
        let swapped = swap(&z_a, &z_b, &[true, true], &w, &g);
        assert!(approx_eq(&swapped, &decode(&z_a, &w, &g, None), 1e-12));
    }

    #[test]
    fn swap_all_false_matches_z_b_forward() {
        let z_a = array![[1.0, -2.0], [3.0, 0.5]];
        let z_b = array![[10.0, 20.0], [-30.0, 40.0]];
        let w = array![[0.1, 0.2], [0.3, -0.4]];
        let g = array![0.7, -0.3];
        let swapped = swap(&z_a, &z_b, &[false, false], &w, &g);
        assert!(approx_eq(&swapped, &decode(&z_b, &w, &g, None), 1e-12));
    }

    #[test]
    fn swap_mixed_mask_takes_each_column_from_its_source() {
        let z_a = array![[1.0, -2.0], [3.0, 0.5]];
        let z_b = array![[10.0, 20.0], [-30.0, 40.0]];
        let w = array![[0.1, 0.2], [0.3, -0.4]];
        let g = array![0.7, -0.3];
        // Column 0 from z_a, column 1 from z_b.
        let mixed = array![[1.0, 20.0], [3.0, 40.0]];
        let swapped = swap(&z_a, &z_b, &[true, false], &w, &g);
        assert!(approx_eq(&swapped, &decode(&mixed, &w, &g, None), 1e-12));
    }

    #[test]
    fn backward_matches_finite_differences() {
        let z = array![[0.4, -0.7, 1.1], [0.2, 0.8, -0.3]];
        let w = array![[0.1, 0.2, 0.3], [-0.4, 0.5, 0.6]];
        let g = array![0.6, -0.2, 1.3];
        let bias = array![0.05, -0.01];
        let grad_out = array![[1.0, -0.5], [0.3, 0.8]];
        let eps = 1e-6;
        let tol = 1e-7;

        let an = interchange_decode_backward(z.view(), w.view(), g.view(), grad_out.view(), true)
            .unwrap();
        let loss = |z_in: &Array2<f64>, w_in: &Array2<f64>, g_in: &Array1<f64>, b_in: &Array1<f64>| {
            weighted_loss(&decode(z_in, w_in, g_in, Some(b_in)), &grad_out)
        };

        for i in 0..z.nrows() {
            for j in 0..z.ncols() {
                let fd = central_diff(eps, |delta| loss(&bumped2(&z, [i, j], delta), &w, &g, &bias));
                let analytic = an.grad_z[[i, j]];
                assert!(
                    (analytic - fd).abs() < tol,
                    "grad_z mismatch at ({i},{j}): analytic {analytic} vs fd {fd}"
                );
            }
        }
        for j in 0..g.len() {
            let fd = central_diff(eps, |delta| loss(&z, &w, &bumped1(&g, j, delta), &bias));
            let analytic = an.grad_gate[j];
            assert!(
                (analytic - fd).abs() < tol,
                "grad_gate mismatch at {j}: analytic {analytic} vs fd {fd}"
            );
        }
        for k in 0..w.nrows() {
            for j in 0..w.ncols() {
                let fd = central_diff(eps, |delta| loss(&z, &bumped2(&w, [k, j], delta), &g, &bias));
                let analytic = an.grad_weights[[k, j]];
                assert!(
                    (analytic - fd).abs() < tol,
                    "grad_W mismatch at ({k},{j}): analytic {analytic} vs fd {fd}"
                );
            }
        }
        let bias_grad = an
            .grad_bias
            .as_ref()
            .expect("with_bias = true must yield a bias gradient");
        for k in 0..bias.len() {
            let fd = central_diff(eps, |delta| loss(&z, &w, &g, &bumped1(&bias, k, delta)));
            let analytic = bias_grad[k];
            assert!(
                (analytic - fd).abs() < tol,
                "grad_bias mismatch at {k}: analytic {analytic} vs fd {fd}"
            );
        }
    }

    #[test]
    fn swap_backward_matches_finite_differences() {
        let z_a = array![[0.4, -0.7, 1.1], [0.2, 0.8, -0.3]];
        let z_b = array![[-1.2, 0.9, 0.05], [0.6, -0.4, 2.0]];
        let w = array![[0.1, 0.2, 0.3], [-0.4, 0.5, 0.6]];
        let g = array![0.6, -0.2, 1.3];
        let mask = [true, false, true];
        let grad_out = array![[1.0, -0.5], [0.3, 0.8]];
        let eps = 1e-6;
        let tol = 1e-7;

        let mask_arr = Array1::from(mask.to_vec());
        let an = interchange_swap_backward(
            z_a.view(),
            z_b.view(),
            mask_arr.view(),
            w.view(),
            g.view(),
            grad_out.view(),
            false,
        )
        .unwrap();
        assert!(an.grad_bias.is_none());

        let loss = |za: &Array2<f64>, zb: &Array2<f64>| {
            weighted_loss(&swap(za, zb, &mask, &w, &g), &grad_out)
        };
        for i in 0..z_a.nrows() {
            for j in 0..z_a.ncols() {
                let fd_a = central_diff(eps, |delta| loss(&bumped2(&z_a, [i, j], delta), &z_b));
                let fd_b = central_diff(eps, |delta| loss(&z_a, &bumped2(&z_b, [i, j], delta)));
                let (an_a, an_b) = (an.grad_z_a[[i, j]], an.grad_z_b[[i, j]]);
                assert!(
                    (an_a - fd_a).abs() < tol,
                    "grad_z_a mismatch at ({i},{j}): analytic {an_a} vs fd {fd_a}"
                );
                assert!(
                    (an_b - fd_b).abs() < tol,
                    "grad_z_b mismatch at ({i},{j}): analytic {an_b} vs fd {fd_b}"
                );
            }
        }
    }

    #[test]
    fn swap_backward_routes_grad_through_mask() {
        let z_a = array![[1.0, 2.0, 3.0]];
        let z_b = array![[-1.0, -2.0, -3.0]];
        let w = array![[0.5, 0.25, -0.1]];
        let g = array![1.0, 0.5, -1.0];
        let mask = Array1::from(vec![true, false, true]);
        let grad_out = array![[1.0]];
        let bk = interchange_swap_backward(
            z_a.view(),
            z_b.view(),
            mask.view(),
            w.view(),
            g.view(),
            grad_out.view(),
            false,
        )
        .unwrap();
        // With a single output, grad_z[0, j] = gate[j] * w[0, j].
        assert!((bk.grad_z_a[[0, 0]] - 1.0 * 0.5).abs() < 1e-12);
        assert!((bk.grad_z_a[[0, 1]] - 0.0).abs() < 1e-12);
        assert!((bk.grad_z_a[[0, 2]] - (-1.0) * (-0.1)).abs() < 1e-12);
        assert!((bk.grad_z_b[[0, 0]] - 0.0).abs() < 1e-12);
        assert!((bk.grad_z_b[[0, 1]] - 0.5 * 0.25).abs() < 1e-12);
        assert!((bk.grad_z_b[[0, 2]] - 0.0).abs() < 1e-12);
    }

    #[test]
    fn rejects_invalid_inputs() {
        let z = array![[1.0, 2.0]];
        let w = array![[0.1, 0.2]];
        let g = array![1.0, 1.0];

        // Gate length must equal F.
        let long_gate = array![1.0, 1.0, 1.0];
        let err = interchange_decode_forward(InterchangeDecodeForward {
            z: z.view(),
            weights: w.view(),
            gate: long_gate.view(),
            bias: None,
        })
        .unwrap_err();
        assert!(err.contains("gate has length"), "{err}");

        // Non-finite gate values are rejected.
        let nan_gate = array![1.0, f64::NAN];
        let err = interchange_decode_forward(InterchangeDecodeForward {
            z: z.view(),
            weights: w.view(),
            gate: nan_gate.view(),
            bias: None,
        })
        .unwrap_err();
        assert!(err.contains("gate must be finite"), "{err}");

        // An empty batch is rejected.
        let empty = Array2::<f64>::zeros((0, 2));
        let err = interchange_decode_forward(InterchangeDecodeForward {
            z: empty.view(),
            weights: w.view(),
            gate: g.view(),
            bias: None,
        })
        .unwrap_err();
        assert!(err.contains("non-empty"), "{err}");

        // Swap latents must share a shape.
        let z_tall = array![[1.0, 2.0], [3.0, 4.0]];
        let mask = Array1::from(vec![true, false]);
        let err = interchange_swap_forward(InterchangeSwapForward {
            z_a: z.view(),
            z_b: z_tall.view(),
            mask: mask.view(),
            weights: w.view(),
            gate: g.view(),
            bias: None,
        })
        .unwrap_err();
        assert!(err.contains("same shape"), "{err}");

        // Backward: grad_out must be (B, D) = (1, 1) here.
        let bad_grad = array![[1.0, 2.0]];
        let err = interchange_decode_backward(z.view(), w.view(), g.view(), bad_grad.view(), false)
            .unwrap_err();
        assert!(err.contains("grad_out shape"), "{err}");
    }
}