1use crate::error::{SeqError, SeqResult};
32use crate::handle::LcgRng;
33
/// Parameters of a pointer network: an additive-attention scorer plus a simple
/// tanh recurrent encoder.
///
/// All matrices are stored flat in row-major order. `PartialEq` is derived so
/// two parameter sets can be compared directly (e.g. to check determinism).
#[derive(Debug, Clone, PartialEq)]
pub struct PointerNetwork {
    /// Length of every encoder state vector and of every query vector.
    pub hidden_dim: usize,
    /// Width of the space that states and queries are projected into before scoring.
    pub attn_dim: usize,
    /// Length of every raw input vector consumed by `encode`.
    pub input_dim: usize,
    /// `attn_dim x hidden_dim` projection applied to encoder states.
    pub w1: Vec<f64>,
    /// `attn_dim x hidden_dim` projection applied to queries.
    pub w2: Vec<f64>,
    /// Length-`attn_dim` vector that reduces the tanh activations to one logit.
    pub v: Vec<f64>,
    /// `hidden_dim x input_dim` input-to-hidden weights of the encoder.
    pub enc_wx: Vec<f64>,
    /// `hidden_dim x hidden_dim` hidden-to-hidden (recurrent) weights of the encoder.
    pub enc_wh: Vec<f64>,
    /// Length-`hidden_dim` encoder bias.
    pub enc_b: Vec<f64>,
}
65
/// Gradients of the attention parameters, laid out exactly like the
/// `PointerNetwork` fields of the same name.
///
/// Only `w1`, `w2` and `v` have gradients here; the encoder weights are not
/// differentiated. `PartialEq` is derived so gradients can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct PointerGrad {
    /// Gradient with respect to the encoder-state projection (`attn_dim x hidden_dim`).
    pub w1: Vec<f64>,
    /// Gradient with respect to the query projection (`attn_dim x hidden_dim`).
    pub w2: Vec<f64>,
    /// Gradient with respect to the scoring vector (length `attn_dim`).
    pub v: Vec<f64>,
}
80
81impl PointerNetwork {
82 pub fn zeros(hidden_dim: usize, attn_dim: usize, input_dim: usize) -> SeqResult<Self> {
86 if hidden_dim == 0 || attn_dim == 0 || input_dim == 0 {
87 return Err(SeqError::InvalidConfiguration(
88 "hidden_dim, attn_dim and input_dim must all be > 0".to_string(),
89 ));
90 }
91 Ok(Self {
92 hidden_dim,
93 attn_dim,
94 input_dim,
95 w1: vec![0.0; attn_dim * hidden_dim],
96 w2: vec![0.0; attn_dim * hidden_dim],
97 v: vec![0.0; attn_dim],
98 enc_wx: vec![0.0; hidden_dim * input_dim],
99 enc_wh: vec![0.0; hidden_dim * hidden_dim],
100 enc_b: vec![0.0; hidden_dim],
101 })
102 }
103
104 pub fn new(
109 hidden_dim: usize,
110 attn_dim: usize,
111 input_dim: usize,
112 scale: f64,
113 rng: &mut LcgRng,
114 ) -> SeqResult<Self> {
115 if !scale.is_finite() || scale <= 0.0 {
116 return Err(SeqError::InvalidParameter {
117 name: "scale".to_string(),
118 value: scale,
119 });
120 }
121 let mut net = Self::zeros(hidden_dim, attn_dim, input_dim)?;
122 for buf in [&mut net.w1, &mut net.w2, &mut net.enc_wx, &mut net.enc_wh] {
123 for v in buf.iter_mut() {
124 *v = rng.next_range(-scale, scale);
125 }
126 }
127 for v in net.v.iter_mut() {
128 *v = rng.next_range(-scale, scale);
129 }
130 Ok(net)
131 }
132
133 fn n_positions(&self, encoder_states: &[f64]) -> SeqResult<usize> {
135 if encoder_states.is_empty() {
136 return Err(SeqError::EmptyInput);
137 }
138 if encoder_states.len() % self.hidden_dim != 0 {
139 return Err(SeqError::DimensionMismatch {
140 a: encoder_states.len(),
141 b: self.hidden_dim,
142 });
143 }
144 Ok(encoder_states.len() / self.hidden_dim)
145 }
146
147 pub fn encode(&self, inputs: &[f64]) -> SeqResult<Vec<f64>> {
154 if inputs.is_empty() {
155 return Err(SeqError::EmptyInput);
156 }
157 if inputs.len() % self.input_dim != 0 {
158 return Err(SeqError::DimensionMismatch {
159 a: inputs.len(),
160 b: self.input_dim,
161 });
162 }
163 let n = inputs.len() / self.input_dim;
164 let hh = self.hidden_dim;
165 let d = self.input_dim;
166 let mut states = vec![0.0; n * hh];
167 let mut prev = vec![0.0; hh];
168 for t in 0..n {
169 let xt = &inputs[t * d..(t + 1) * d];
170 for h in 0..hh {
171 let mut acc = self.enc_b[h];
172 let rx = h * d;
173 for (dd, &xv) in xt.iter().enumerate() {
174 acc += self.enc_wx[rx + dd] * xv;
175 }
176 let rh = h * hh;
177 for (h2, &pv) in prev.iter().enumerate() {
178 acc += self.enc_wh[rh + h2] * pv;
179 }
180 states[t * hh + h] = acc.tanh();
181 }
182 prev.copy_from_slice(&states[t * hh..(t + 1) * hh]);
183 }
184 Ok(states)
185 }
186
187 fn project_encoder(&self, encoder_states: &[f64]) -> SeqResult<Vec<f64>> {
191 let n = self.n_positions(encoder_states)?;
192 let a = self.attn_dim;
193 let hh = self.hidden_dim;
194 let mut proj = vec![0.0; n * a];
195 for j in 0..n {
196 let ej = &encoder_states[j * hh..(j + 1) * hh];
197 for aa in 0..a {
198 let mut acc = 0.0;
199 let row = aa * hh;
200 for (h, &ev) in ej.iter().enumerate() {
201 acc += self.w1[row + h] * ev;
202 }
203 proj[j * a + aa] = acc;
204 }
205 }
206 Ok(proj)
207 }
208
209 fn project_query(&self, query: &[f64]) -> SeqResult<Vec<f64>> {
211 if query.len() != self.hidden_dim {
212 return Err(SeqError::ShapeMismatch {
213 expected: self.hidden_dim,
214 got: query.len(),
215 });
216 }
217 let a = self.attn_dim;
218 let hh = self.hidden_dim;
219 let mut q = vec![0.0; a];
220 for aa in 0..a {
221 let mut acc = 0.0;
222 let row = aa * hh;
223 for (h, &qv) in query.iter().enumerate() {
224 acc += self.w2[row + h] * qv;
225 }
226 q[aa] = acc;
227 }
228 Ok(q)
229 }
230
231 pub fn attention_logits(&self, encoder_states: &[f64], query: &[f64]) -> SeqResult<Vec<f64>> {
235 let proj = self.project_encoder(encoder_states)?;
236 let qp = self.project_query(query)?;
237 let n = self.n_positions(encoder_states)?;
238 let a = self.attn_dim;
239 let mut logits = vec![0.0; n];
240 for j in 0..n {
241 let mut acc = 0.0;
242 for aa in 0..a {
243 acc += self.v[aa] * (proj[j * a + aa] + qp[aa]).tanh();
244 }
245 logits[j] = acc;
246 }
247 Ok(logits)
248 }
249
250 fn softmax(logits: &[f64]) -> Vec<f64> {
252 let mut max = f64::NEG_INFINITY;
253 for &z in logits {
254 if z > max {
255 max = z;
256 }
257 }
258 if !max.is_finite() {
259 let n = logits.len().max(1);
261 return vec![1.0 / n as f64; logits.len()];
262 }
263 let mut probs: Vec<f64> = logits.iter().map(|&z| (z - max).exp()).collect();
264 let s: f64 = probs.iter().sum();
265 if s > 0.0 {
266 for p in probs.iter_mut() {
267 *p /= s;
268 }
269 }
270 probs
271 }
272
273 pub fn pointer_distribution(
276 &self,
277 encoder_states: &[f64],
278 query: &[f64],
279 ) -> SeqResult<Vec<f64>> {
280 let logits = self.attention_logits(encoder_states, query)?;
281 Ok(Self::softmax(&logits))
282 }
283
284 pub fn forward(&self, encoder_states: &[f64], queries: &[f64]) -> SeqResult<Vec<f64>> {
289 let n = self.n_positions(encoder_states)?;
290 let hh = self.hidden_dim;
291 if queries.is_empty() {
292 return Err(SeqError::EmptyInput);
293 }
294 if queries.len() % hh != 0 {
295 return Err(SeqError::DimensionMismatch {
296 a: queries.len(),
297 b: hh,
298 });
299 }
300 let m = queries.len() / hh;
301 let proj = self.project_encoder(encoder_states)?;
302 let a = self.attn_dim;
303 let mut out = vec![0.0; m * n];
304 for i in 0..m {
305 let qp = self.project_query(&queries[i * hh..(i + 1) * hh])?;
306 let mut logits = vec![0.0; n];
307 for j in 0..n {
308 let mut acc = 0.0;
309 for aa in 0..a {
310 acc += self.v[aa] * (proj[j * a + aa] + qp[aa]).tanh();
311 }
312 logits[j] = acc;
313 }
314 let probs = Self::softmax(&logits);
315 out[i * n..(i + 1) * n].copy_from_slice(&probs);
316 }
317 Ok(out)
318 }
319
320 pub fn decode(&self, encoder_states: &[f64], queries: &[f64]) -> SeqResult<Vec<usize>> {
324 let n = self.n_positions(encoder_states)?;
325 let probs = self.forward(encoder_states, queries)?;
326 let m = probs.len() / n;
327 let mut out = vec![0usize; m];
328 for i in 0..m {
329 let mut best = f64::NEG_INFINITY;
330 let mut argmax = 0usize;
331 for j in 0..n {
332 let p = probs[i * n + j];
333 if p > best {
334 best = p;
335 argmax = j;
336 }
337 }
338 out[i] = argmax;
339 }
340 Ok(out)
341 }
342
343 pub fn nll(
348 &self,
349 encoder_states: &[f64],
350 queries: &[f64],
351 targets: &[usize],
352 ) -> SeqResult<f64> {
353 let n = self.n_positions(encoder_states)?;
354 let probs = self.forward(encoder_states, queries)?;
355 let m = probs.len() / n;
356 if targets.len() != m {
357 return Err(SeqError::LengthMismatch {
358 a: targets.len(),
359 b: m,
360 });
361 }
362 let mut nll = 0.0;
363 for i in 0..m {
364 let tgt = targets[i];
365 if tgt >= n {
366 return Err(SeqError::IndexOutOfBounds { index: tgt, len: n });
367 }
368 let p = probs[i * n + tgt].max(1e-300);
369 nll -= p.ln();
370 }
371 Ok(nll)
372 }
373
374 pub fn backward(
381 &self,
382 encoder_states: &[f64],
383 queries: &[f64],
384 targets: &[usize],
385 ) -> SeqResult<(f64, PointerGrad)> {
386 let n = self.n_positions(encoder_states)?;
387 let hh = self.hidden_dim;
388 if queries.is_empty() || queries.len() % hh != 0 {
389 return Err(SeqError::DimensionMismatch {
390 a: queries.len(),
391 b: hh,
392 });
393 }
394 let m = queries.len() / hh;
395 if targets.len() != m {
396 return Err(SeqError::LengthMismatch {
397 a: targets.len(),
398 b: m,
399 });
400 }
401 for &t in targets {
402 if t >= n {
403 return Err(SeqError::IndexOutOfBounds { index: t, len: n });
404 }
405 }
406 let a = self.attn_dim;
407 let proj = self.project_encoder(encoder_states)?;
408
409 let mut g_w1 = vec![0.0; a * hh];
410 let mut g_w2 = vec![0.0; a * hh];
411 let mut g_v = vec![0.0; a];
412 let mut nll = 0.0;
413
414 for i in 0..m {
415 let qi = &queries[i * hh..(i + 1) * hh];
416 let qp = self.project_query(qi)?;
417 let mut s = vec![0.0; n * a];
419 let mut logits = vec![0.0; n];
420 for j in 0..n {
421 let mut acc = 0.0;
422 for aa in 0..a {
423 let pre = proj[j * a + aa] + qp[aa];
424 let th = pre.tanh();
425 s[j * a + aa] = th;
426 acc += self.v[aa] * th;
427 }
428 logits[j] = acc;
429 }
430 let probs = Self::softmax(&logits);
431 let tgt = targets[i];
432 nll -= probs[tgt].max(1e-300).ln();
433
434 for j in 0..n {
436 let d_logit = probs[j] - if j == tgt { 1.0 } else { 0.0 };
437 let ej = &encoder_states[j * hh..(j + 1) * hh];
438 for aa in 0..a {
439 g_v[aa] += d_logit * s[j * a + aa];
441 let d_pre = d_logit * self.v[aa] * (1.0 - s[j * a + aa] * s[j * a + aa]);
443 let row = aa * hh;
444 for h in 0..hh {
445 g_w1[row + h] += d_pre * ej[h];
447 g_w2[row + h] += d_pre * qi[h];
449 }
450 }
451 }
452 }
453
454 Ok((
455 nll,
456 PointerGrad {
457 w1: g_w1,
458 w2: g_w2,
459 v: g_v,
460 },
461 ))
462 }
463
464 pub fn step(
467 &mut self,
468 encoder_states: &[f64],
469 queries: &[f64],
470 targets: &[usize],
471 lr: f64,
472 ) -> SeqResult<f64> {
473 if !lr.is_finite() || lr <= 0.0 {
474 return Err(SeqError::InvalidParameter {
475 name: "lr".to_string(),
476 value: lr,
477 });
478 }
479 let (nll, grad) = self.backward(encoder_states, queries, targets)?;
480 for (w, g) in self.w1.iter_mut().zip(grad.w1.iter()) {
481 *w -= lr * g;
482 }
483 for (w, g) in self.w2.iter_mut().zip(grad.w2.iter()) {
484 *w -= lr * g;
485 }
486 for (w, g) in self.v.iter_mut().zip(grad.v.iter()) {
487 *w -= lr * g;
488 }
489 Ok(nll)
490 }
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// Small randomly initialised network (hidden 3, attention 4, input 2).
    fn rand_net(seed: u64) -> PointerNetwork {
        let mut rng = LcgRng::new(seed);
        PointerNetwork::new(3, 4, 2, 0.5, &mut rng).expect("net")
    }

    /// `n` random encoder states sized for `net`, flattened.
    fn rand_states(net: &PointerNetwork, n: usize, seed: u64) -> Vec<f64> {
        let mut rng = LcgRng::new(seed);
        (0..n * net.hidden_dim)
            .map(|_| rng.next_range(-1.0, 1.0))
            .collect()
    }

    /// `m` random queries sized for `net`, flattened.
    fn rand_queries(net: &PointerNetwork, m: usize, seed: u64) -> Vec<f64> {
        let mut rng = LcgRng::new(seed);
        (0..m * net.hidden_dim)
            .map(|_| rng.next_range(-1.0, 1.0))
            .collect()
    }

    #[test]
    fn construct_validates_dims() {
        assert!(PointerNetwork::zeros(0, 2, 2).is_err());
        assert!(PointerNetwork::zeros(2, 0, 2).is_err());
        assert!(PointerNetwork::zeros(2, 2, 0).is_err());
        let mut rng = LcgRng::new(1);
        assert!(PointerNetwork::new(2, 2, 2, 0.0, &mut rng).is_err());
        assert!(PointerNetwork::new(2, 2, 2, f64::INFINITY, &mut rng).is_err());
    }

    #[test]
    fn pointer_distribution_is_valid_simplex() {
        let net = rand_net(2);
        let states = rand_states(&net, 5, 3);
        let query = rand_queries(&net, 1, 4);
        let dist = net.pointer_distribution(&states, &query).expect("dist");
        assert_eq!(dist.len(), 5);
        assert!(dist.iter().all(|&p| (0.0..=1.0).contains(&p)));
        let s: f64 = dist.iter().sum();
        assert!((s - 1.0).abs() < 1e-12, "sum={s}");
    }

    #[test]
    fn attention_shapes_correct() {
        let net = rand_net(5);
        let n = 6usize;
        let m = 4usize;
        let states = rand_states(&net, n, 6);
        let queries = rand_queries(&net, m, 7);
        let logits = net
            .attention_logits(&states, &queries[..net.hidden_dim])
            .expect("logits");
        assert_eq!(logits.len(), n);
        let probs = net.forward(&states, &queries).expect("fwd");
        assert_eq!(probs.len(), m * n);
        for i in 0..m {
            let s: f64 = probs[i * n..(i + 1) * n].iter().sum();
            assert!((s - 1.0).abs() < 1e-12, "row {i} sum={s}");
        }
    }

    #[test]
    fn decode_yields_in_range_indices() {
        let net = rand_net(8);
        let n = 7usize;
        let states = rand_states(&net, n, 9);
        let queries = rand_queries(&net, 5, 10);
        let path = net.decode(&states, &queries).expect("decode");
        assert_eq!(path.len(), 5);
        assert!(path.iter().all(|&p| p < n));
    }

    #[test]
    fn decode_is_deterministic() {
        let net = rand_net(11);
        let states = rand_states(&net, 6, 12);
        let queries = rand_queries(&net, 4, 13);
        let p1 = net.decode(&states, &queries).expect("d1");
        let p2 = net.decode(&states, &queries).expect("d2");
        assert_eq!(p1, p2);
        let f1 = net.forward(&states, &queries).expect("f1");
        let f2 = net.forward(&states, &queries).expect("f2");
        assert_eq!(f1, f2);
    }

    #[test]
    fn gradient_matches_finite_difference() {
        let net = rand_net(14);
        let n = 5usize;
        let states = rand_states(&net, n, 15);
        let queries = rand_queries(&net, 3, 16);
        let targets = vec![2usize, 0, 4];
        let (_, grad) = net.backward(&states, &queries, &targets).expect("bwd");

        // Central difference of the loss along one perturbed parameter.
        let eps = 1e-6;
        let central = |perturb: &dyn Fn(&mut PointerNetwork, f64)| -> f64 {
            let mut up = net.clone();
            perturb(&mut up, eps);
            let mut dn = net.clone();
            perturb(&mut dn, -eps);
            let lp = up.nll(&states, &queries, &targets).expect("nll+");
            let lm = dn.nll(&states, &queries, &targets).expect("nll-");
            (lp - lm) / (2.0 * eps)
        };

        for idx in 0..net.w1.len() {
            let num = central(&|p, e| p.w1[idx] += e);
            assert!(
                (num - grad.w1[idx]).abs() < 1e-4,
                "w1[{idx}] num={num} ana={}",
                grad.w1[idx]
            );
        }
        for idx in 0..net.w2.len() {
            let num = central(&|p, e| p.w2[idx] += e);
            assert!(
                (num - grad.w2[idx]).abs() < 1e-4,
                "w2[{idx}] num={num} ana={}",
                grad.w2[idx]
            );
        }
        for idx in 0..net.v.len() {
            let num = central(&|p, e| p.v[idx] += e);
            assert!(
                (num - grad.v[idx]).abs() < 1e-4,
                "v[{idx}] num={num} ana={}",
                grad.v[idx]
            );
        }
    }

    #[test]
    fn constructed_weights_point_to_argmax_by_key() {
        let mut net = PointerNetwork::zeros(2, 2, 2).expect("net");
        // Only the first state component feeds the first attention unit, so
        // the logit grows monotonically with the key stored there.
        let c = 3.0;
        net.w1[0] = c;
        net.v[0] = 1.0;
        let keys = [0.2_f64, 0.9, 0.1, 0.5];
        let n = keys.len();
        let mut states = vec![0.0; n * net.hidden_dim];
        for (j, &kk) in keys.iter().enumerate() {
            states[j * net.hidden_dim] = kk;
        }
        let query = vec![0.0; net.hidden_dim];
        let dist = net.pointer_distribution(&states, &query).expect("dist");
        let argmax = dist
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).expect("cmp"))
            .map(|(j, _)| j)
            .expect("nonempty");
        assert_eq!(argmax, 1);
    }

    #[test]
    fn training_reduces_nll_on_selection_task() {
        let mut rng = LcgRng::new(21);
        let mut net = PointerNetwork::new(2, 3, 2, 0.3, &mut rng).expect("net");
        let keys = [0.1_f64, 0.4, 0.95, 0.2, 0.6];
        let n = keys.len();
        let mut states = vec![0.0; n * net.hidden_dim];
        for (j, &kk) in keys.iter().enumerate() {
            states[j * net.hidden_dim] = kk;
            states[j * net.hidden_dim + 1] = 1.0;
        }
        let queries = vec![1.0; net.hidden_dim];
        let targets = vec![2usize];
        let nll0 = net.nll(&states, &queries, &targets).expect("nll0");
        for _ in 0..400 {
            net.step(&states, &queries, &targets, 0.2).expect("step");
        }
        let nll1 = net.nll(&states, &queries, &targets).expect("nll1");
        assert!(nll1 < nll0 - 1e-3, "nll0={nll0}, nll1={nll1}");
        let path = net.decode(&states, &queries).expect("decode");
        assert_eq!(path, targets);
    }

    #[test]
    fn nll_validates_targets() {
        let net = rand_net(30);
        let n = 4usize;
        let states = rand_states(&net, n, 31);
        let queries = rand_queries(&net, 2, 32);
        // Target index out of range.
        assert!(net.nll(&states, &queries, &[0, n]).is_err());
        // Fewer targets than queries.
        assert!(net.nll(&states, &queries, &[0]).is_err());
    }

    #[test]
    fn backward_validates_inputs() {
        let net = rand_net(70);
        let n = 4usize;
        let states = rand_states(&net, n, 71);
        let queries = rand_queries(&net, 2, 72);
        // No queries at all.
        assert!(net.backward(&states, &[], &[]).is_err());
        // Query buffer that is not a whole number of queries.
        assert!(net
            .backward(&states, &queries[..net.hidden_dim + 1], &[0])
            .is_err());
        // Fewer targets than queries.
        assert!(net.backward(&states, &queries, &[0]).is_err());
        // Target index out of range.
        assert!(net.backward(&states, &queries, &[0, n]).is_err());
        // Empty encoder states.
        assert!(net.backward(&[], &queries, &[0, 1]).is_err());
    }

    #[test]
    fn backward_loss_matches_nll() {
        let net = rand_net(80);
        let states = rand_states(&net, 5, 81);
        let queries = rand_queries(&net, 3, 82);
        let targets = [1usize, 4, 0];
        let (loss, grad) = net.backward(&states, &queries, &targets).expect("bwd");
        let nll = net.nll(&states, &queries, &targets).expect("nll");
        assert!((loss - nll).abs() < 1e-12, "loss={loss}, nll={nll}");
        assert_eq!(grad.w1.len(), net.w1.len());
        assert_eq!(grad.w2.len(), net.w2.len());
        assert_eq!(grad.v.len(), net.v.len());
    }

    #[test]
    fn input_validation_paths() {
        let net = rand_net(40);
        let zero_query = vec![0.0; net.hidden_dim];
        // Empty encoder states.
        assert!(net.pointer_distribution(&[], &zero_query).is_err());
        // Encoder buffer that is not a whole number of states.
        let bad = vec![0.0; net.hidden_dim * 2 + 1];
        assert!(net.attention_logits(&bad, &zero_query).is_err());
        let states = rand_states(&net, 3, 41);
        // Query of the wrong length.
        assert!(net.pointer_distribution(&states, &[0.0, 0.0]).is_err());
        // No queries at all.
        assert!(net.forward(&states, &[]).is_err());
    }

    #[test]
    fn encoder_runs_and_shapes_match() {
        let net = rand_net(50);
        let n = 4usize;
        let inputs: Vec<f64> = {
            let mut rng = LcgRng::new(51);
            (0..n * net.input_dim)
                .map(|_| rng.next_range(-1.0, 1.0))
                .collect()
        };
        let states = net.encode(&inputs).expect("encode");
        assert_eq!(states.len(), n * net.hidden_dim);
        assert!(states.iter().all(|v| v.is_finite()));
        // Encoder output must be usable directly as attention input.
        let query = vec![0.5; net.hidden_dim];
        let dist = net.pointer_distribution(&states, &query).expect("dist");
        let s: f64 = dist.iter().sum();
        assert!((s - 1.0).abs() < 1e-12);
    }

    #[test]
    fn step_validates_learning_rate() {
        let mut net = rand_net(60);
        let states = rand_states(&net, 3, 61);
        let queries = rand_queries(&net, 2, 62);
        let targets = vec![0usize, 1];
        assert!(net.step(&states, &queries, &targets, 0.0).is_err());
        assert!(net.step(&states, &queries, &targets, -1.0).is_err());
    }

    #[test]
    fn step_leaves_encoder_weights_untouched() {
        let mut net = rand_net(90);
        let states = rand_states(&net, 3, 91);
        let queries = rand_queries(&net, 2, 92);
        let targets = vec![2usize, 0];
        let enc_wx = net.enc_wx.clone();
        let enc_wh = net.enc_wh.clone();
        let enc_b = net.enc_b.clone();
        net.step(&states, &queries, &targets, 0.1).expect("step");
        assert_eq!(net.enc_wx, enc_wx);
        assert_eq!(net.enc_wh, enc_wh);
        assert_eq!(net.enc_b, enc_b);
    }
}