Skip to main content

oxicuda_seq/decoders/
pointer_network.rs

1//! Pointer Network.
2//!
3//! Reference: Vinyals, O., Fortunato, M. & Jaitly, N. (2015). *Pointer Networks*.
4//! NeurIPS 28 (arXiv 1506.03134). <https://arxiv.org/abs/1506.03134>.
5//!
6//! # Model
7//!
8//! A Pointer Network is a sequence-to-sequence model whose attention mechanism
9//! **points to positions in the input** rather than emitting a token from a fixed
10//! output vocabulary. This makes the output vocabulary equal to the (variable)
11//! input length `n`, which is exactly what combinatorial tasks such as sorting,
12//! convex hull and the travelling-salesman problem require.
13//!
14//! Given encoder hidden states `e_1 … e_n` and a decoder query `d_i`, the
15//! content-based attention score for pointing at input position `j` is
16//!
17//! ```text
18//! u^i_j = vᵀ tanh(W1 e_j + W2 d_i)
19//! ```
20//!
21//! and the **pointer distribution** over input positions is
22//! `p^i = softmax(u^i)`. Greedy decoding emits `argmax_j p^i_j` at each step.
23//!
24//! Here the encoder states are provided directly (or produced by a minimal
25//! Elman/`tanh` RNN encoder, [`PointerNetwork::encode`]) and the decoder queries
26//! are likewise provided per step, so the module is a faithful CPU reference for
27//! the pointer attention head and its training objective (teacher-forced NLL with
28//! a finite-difference-verified gradient) without committing to any one recurrent
29//! cell. Production code never panics: all fallible paths return [`SeqError`].
30
31use crate::error::{SeqError, SeqResult};
32use crate::handle::LcgRng;
33
/// Pointer-attention head bundled with the weights of an optional Elman encoder.
///
/// Every parameter buffer is a flat, row-major `Vec<f64>`. Element `(r, c)` of a
/// matrix with `cols` columns is stored at index `r * cols + c`:
///
/// * `w1` — `attn_dim × hidden_dim`, applied to encoder states
/// * `w2` — `attn_dim × hidden_dim`, applied to decoder queries
/// * `v` — `attn_dim`, combines the `tanh` activations into one logit
/// * `enc_wx` — `hidden_dim × input_dim`, encoder input→hidden
/// * `enc_wh` — `hidden_dim × hidden_dim`, encoder hidden→hidden
/// * `enc_b` — `hidden_dim`, encoder bias
#[derive(Debug, Clone)]
pub struct PointerNetwork {
    /// Width of every encoder state and decoder query vector.
    pub hidden_dim: usize,
    /// Inner width of the additive attention (number of rows of `w1`/`w2`).
    pub attn_dim: usize,
    /// Width of one raw input feature vector consumed by the Elman encoder.
    pub input_dim: usize,
    /// `W1`: projection of encoder states, `attn_dim × hidden_dim`.
    pub w1: Vec<f64>,
    /// `W2`: projection of decoder queries, `attn_dim × hidden_dim`.
    pub w2: Vec<f64>,
    /// `v`: attention combination vector of length `attn_dim`.
    pub v: Vec<f64>,
    /// Elman encoder input→hidden weights, `hidden_dim × input_dim`.
    pub enc_wx: Vec<f64>,
    /// Elman encoder hidden→hidden weights, `hidden_dim × hidden_dim`.
    pub enc_wh: Vec<f64>,
    /// Elman encoder bias of length `hidden_dim`.
    pub enc_b: Vec<f64>,
}
65
/// Gradient of the teacher-forced NLL for the three attention parameters.
///
/// Each buffer has the same length and layout as the parameter of the same name
/// on [`PointerNetwork`]. The encoder weights have no entry here: the gradient
/// code treats encoder states as fixed inputs.
#[derive(Debug, Clone)]
pub struct PointerGrad {
    /// `∂NLL/∂w1`, laid out like `PointerNetwork::w1`.
    pub w1: Vec<f64>,
    /// `∂NLL/∂w2`, laid out like `PointerNetwork::w2`.
    pub w2: Vec<f64>,
    /// `∂NLL/∂v`, laid out like `PointerNetwork::v`.
    pub v: Vec<f64>,
}
80
81impl PointerNetwork {
82    /// Construct a zero-initialised pointer network.
83    ///
84    /// All dimensions must be positive; otherwise [`SeqError::InvalidConfiguration`].
85    pub fn zeros(hidden_dim: usize, attn_dim: usize, input_dim: usize) -> SeqResult<Self> {
86        if hidden_dim == 0 || attn_dim == 0 || input_dim == 0 {
87            return Err(SeqError::InvalidConfiguration(
88                "hidden_dim, attn_dim and input_dim must all be > 0".to_string(),
89            ));
90        }
91        Ok(Self {
92            hidden_dim,
93            attn_dim,
94            input_dim,
95            w1: vec![0.0; attn_dim * hidden_dim],
96            w2: vec![0.0; attn_dim * hidden_dim],
97            v: vec![0.0; attn_dim],
98            enc_wx: vec![0.0; hidden_dim * input_dim],
99            enc_wh: vec![0.0; hidden_dim * hidden_dim],
100            enc_b: vec![0.0; hidden_dim],
101        })
102    }
103
104    /// Construct a pointer network with small random weights from a seeded LCG.
105    ///
106    /// All weight matrices and `v` are sampled `~ U(-scale, scale)`; biases start
107    /// at zero. `scale` must be finite and positive.
108    pub fn new(
109        hidden_dim: usize,
110        attn_dim: usize,
111        input_dim: usize,
112        scale: f64,
113        rng: &mut LcgRng,
114    ) -> SeqResult<Self> {
115        if !scale.is_finite() || scale <= 0.0 {
116            return Err(SeqError::InvalidParameter {
117                name: "scale".to_string(),
118                value: scale,
119            });
120        }
121        let mut net = Self::zeros(hidden_dim, attn_dim, input_dim)?;
122        for buf in [&mut net.w1, &mut net.w2, &mut net.enc_wx, &mut net.enc_wh] {
123            for v in buf.iter_mut() {
124                *v = rng.next_range(-scale, scale);
125            }
126        }
127        for v in net.v.iter_mut() {
128            *v = rng.next_range(-scale, scale);
129        }
130        Ok(net)
131    }
132
133    /// Number of input positions implied by an encoder-state buffer.
134    fn n_positions(&self, encoder_states: &[f64]) -> SeqResult<usize> {
135        if encoder_states.is_empty() {
136            return Err(SeqError::EmptyInput);
137        }
138        if encoder_states.len() % self.hidden_dim != 0 {
139            return Err(SeqError::DimensionMismatch {
140                a: encoder_states.len(),
141                b: self.hidden_dim,
142            });
143        }
144        Ok(encoder_states.len() / self.hidden_dim)
145    }
146
147    /// Encode an input feature sequence with a minimal Elman (`tanh`) RNN.
148    ///
149    /// `inputs` is `n × input_dim` row-major; the returned buffer is
150    /// `n × hidden_dim` row-major encoder states `e_1 … e_n`. Provided as a
151    /// convenience; callers may instead pass their own embeddings to the
152    /// attention methods directly.
153    pub fn encode(&self, inputs: &[f64]) -> SeqResult<Vec<f64>> {
154        if inputs.is_empty() {
155            return Err(SeqError::EmptyInput);
156        }
157        if inputs.len() % self.input_dim != 0 {
158            return Err(SeqError::DimensionMismatch {
159                a: inputs.len(),
160                b: self.input_dim,
161            });
162        }
163        let n = inputs.len() / self.input_dim;
164        let hh = self.hidden_dim;
165        let d = self.input_dim;
166        let mut states = vec![0.0; n * hh];
167        let mut prev = vec![0.0; hh];
168        for t in 0..n {
169            let xt = &inputs[t * d..(t + 1) * d];
170            for h in 0..hh {
171                let mut acc = self.enc_b[h];
172                let rx = h * d;
173                for (dd, &xv) in xt.iter().enumerate() {
174                    acc += self.enc_wx[rx + dd] * xv;
175                }
176                let rh = h * hh;
177                for (h2, &pv) in prev.iter().enumerate() {
178                    acc += self.enc_wh[rh + h2] * pv;
179                }
180                states[t * hh + h] = acc.tanh();
181            }
182            prev.copy_from_slice(&states[t * hh..(t + 1) * hh]);
183        }
184        Ok(states)
185    }
186
187    /// Pre-compute the encoder projections `W1 e_j` for all positions `j`.
188    ///
189    /// Returns an `n × attn_dim` row-major buffer reused across decoder steps.
190    fn project_encoder(&self, encoder_states: &[f64]) -> SeqResult<Vec<f64>> {
191        let n = self.n_positions(encoder_states)?;
192        let a = self.attn_dim;
193        let hh = self.hidden_dim;
194        let mut proj = vec![0.0; n * a];
195        for j in 0..n {
196            let ej = &encoder_states[j * hh..(j + 1) * hh];
197            for aa in 0..a {
198                let mut acc = 0.0;
199                let row = aa * hh;
200                for (h, &ev) in ej.iter().enumerate() {
201                    acc += self.w1[row + h] * ev;
202                }
203                proj[j * a + aa] = acc;
204            }
205        }
206        Ok(proj)
207    }
208
209    /// Project a decoder query `d_i` to `W2 d_i` (length `attn_dim`).
210    fn project_query(&self, query: &[f64]) -> SeqResult<Vec<f64>> {
211        if query.len() != self.hidden_dim {
212            return Err(SeqError::ShapeMismatch {
213                expected: self.hidden_dim,
214                got: query.len(),
215            });
216        }
217        let a = self.attn_dim;
218        let hh = self.hidden_dim;
219        let mut q = vec![0.0; a];
220        for aa in 0..a {
221            let mut acc = 0.0;
222            let row = aa * hh;
223            for (h, &qv) in query.iter().enumerate() {
224                acc += self.w2[row + h] * qv;
225            }
226            q[aa] = acc;
227        }
228        Ok(q)
229    }
230
231    /// Raw attention logits `u^i_j = vᵀ tanh(W1 e_j + W2 d_i)` for one query.
232    ///
233    /// `encoder_states` is `n × hidden_dim`; `query` is length `hidden_dim`.
234    pub fn attention_logits(&self, encoder_states: &[f64], query: &[f64]) -> SeqResult<Vec<f64>> {
235        let proj = self.project_encoder(encoder_states)?;
236        let qp = self.project_query(query)?;
237        let n = self.n_positions(encoder_states)?;
238        let a = self.attn_dim;
239        let mut logits = vec![0.0; n];
240        for j in 0..n {
241            let mut acc = 0.0;
242            for aa in 0..a {
243                acc += self.v[aa] * (proj[j * a + aa] + qp[aa]).tanh();
244            }
245            logits[j] = acc;
246        }
247        Ok(logits)
248    }
249
250    /// Numerically-stable softmax of `logits` in place into a fresh vector.
251    fn softmax(logits: &[f64]) -> Vec<f64> {
252        let mut max = f64::NEG_INFINITY;
253        for &z in logits {
254            if z > max {
255                max = z;
256            }
257        }
258        if !max.is_finite() {
259            // All −inf or empty: fall back to uniform over the support.
260            let n = logits.len().max(1);
261            return vec![1.0 / n as f64; logits.len()];
262        }
263        let mut probs: Vec<f64> = logits.iter().map(|&z| (z - max).exp()).collect();
264        let s: f64 = probs.iter().sum();
265        if s > 0.0 {
266            for p in probs.iter_mut() {
267                *p /= s;
268            }
269        }
270        probs
271    }
272
273    /// Pointer distribution `softmax(u^i)` over the `n` input positions for one
274    /// decoder query.
275    pub fn pointer_distribution(
276        &self,
277        encoder_states: &[f64],
278        query: &[f64],
279    ) -> SeqResult<Vec<f64>> {
280        let logits = self.attention_logits(encoder_states, query)?;
281        Ok(Self::softmax(&logits))
282    }
283
284    /// Forward pass over a sequence of decoder queries.
285    ///
286    /// `queries` is `m × hidden_dim` row-major. Returns an `m × n` row-major
287    /// matrix of pointer distributions (row `i` is the distribution for query `i`).
288    pub fn forward(&self, encoder_states: &[f64], queries: &[f64]) -> SeqResult<Vec<f64>> {
289        let n = self.n_positions(encoder_states)?;
290        let hh = self.hidden_dim;
291        if queries.is_empty() {
292            return Err(SeqError::EmptyInput);
293        }
294        if queries.len() % hh != 0 {
295            return Err(SeqError::DimensionMismatch {
296                a: queries.len(),
297                b: hh,
298            });
299        }
300        let m = queries.len() / hh;
301        let proj = self.project_encoder(encoder_states)?;
302        let a = self.attn_dim;
303        let mut out = vec![0.0; m * n];
304        for i in 0..m {
305            let qp = self.project_query(&queries[i * hh..(i + 1) * hh])?;
306            let mut logits = vec![0.0; n];
307            for j in 0..n {
308                let mut acc = 0.0;
309                for aa in 0..a {
310                    acc += self.v[aa] * (proj[j * a + aa] + qp[aa]).tanh();
311                }
312                logits[j] = acc;
313            }
314            let probs = Self::softmax(&logits);
315            out[i * n..(i + 1) * n].copy_from_slice(&probs);
316        }
317        Ok(out)
318    }
319
320    /// Greedy decode: emit `argmax_j p^i_j` for each decoder query.
321    ///
322    /// Returns a sequence of `m` pointed input indices, each in `0..n`.
323    pub fn decode(&self, encoder_states: &[f64], queries: &[f64]) -> SeqResult<Vec<usize>> {
324        let n = self.n_positions(encoder_states)?;
325        let probs = self.forward(encoder_states, queries)?;
326        let m = probs.len() / n;
327        let mut out = vec![0usize; m];
328        for i in 0..m {
329            let mut best = f64::NEG_INFINITY;
330            let mut argmax = 0usize;
331            for j in 0..n {
332                let p = probs[i * n + j];
333                if p > best {
334                    best = p;
335                    argmax = j;
336                }
337            }
338            out[i] = argmax;
339        }
340        Ok(out)
341    }
342
343    /// Teacher-forced negative log-likelihood of a target index sequence.
344    ///
345    /// `targets[i]` is the gold input position pointed to at decoder step `i`;
346    /// `NLL = − Σ_i log p^i_{targets[i]}`. Targets must be in `0..n`.
347    pub fn nll(
348        &self,
349        encoder_states: &[f64],
350        queries: &[f64],
351        targets: &[usize],
352    ) -> SeqResult<f64> {
353        let n = self.n_positions(encoder_states)?;
354        let probs = self.forward(encoder_states, queries)?;
355        let m = probs.len() / n;
356        if targets.len() != m {
357            return Err(SeqError::LengthMismatch {
358                a: targets.len(),
359                b: m,
360            });
361        }
362        let mut nll = 0.0;
363        for i in 0..m {
364            let tgt = targets[i];
365            if tgt >= n {
366                return Err(SeqError::IndexOutOfBounds { index: tgt, len: n });
367            }
368            let p = probs[i * n + tgt].max(1e-300);
369            nll -= p.ln();
370        }
371        Ok(nll)
372    }
373
374    /// Gradient of the teacher-forced NLL w.r.t. the attention parameters
375    /// (`w1`, `w2`, `v`).
376    ///
377    /// Uses the softmax-cross-entropy identity `∂NLL/∂u^i_j = p^i_j − 1[j = tgt_i]`
378    /// and back-propagates through `s_{j,a} = tanh(W1 e_j + W2 d_i)_a`. Returns the
379    /// NLL and a [`PointerGrad`].
380    pub fn backward(
381        &self,
382        encoder_states: &[f64],
383        queries: &[f64],
384        targets: &[usize],
385    ) -> SeqResult<(f64, PointerGrad)> {
386        let n = self.n_positions(encoder_states)?;
387        let hh = self.hidden_dim;
388        if queries.is_empty() || queries.len() % hh != 0 {
389            return Err(SeqError::DimensionMismatch {
390                a: queries.len(),
391                b: hh,
392            });
393        }
394        let m = queries.len() / hh;
395        if targets.len() != m {
396            return Err(SeqError::LengthMismatch {
397                a: targets.len(),
398                b: m,
399            });
400        }
401        for &t in targets {
402            if t >= n {
403                return Err(SeqError::IndexOutOfBounds { index: t, len: n });
404            }
405        }
406        let a = self.attn_dim;
407        let proj = self.project_encoder(encoder_states)?;
408
409        let mut g_w1 = vec![0.0; a * hh];
410        let mut g_w2 = vec![0.0; a * hh];
411        let mut g_v = vec![0.0; a];
412        let mut nll = 0.0;
413
414        for i in 0..m {
415            let qi = &queries[i * hh..(i + 1) * hh];
416            let qp = self.project_query(qi)?;
417            // Per-position pre-activation s and its tanh, plus logits.
418            let mut s = vec![0.0; n * a];
419            let mut logits = vec![0.0; n];
420            for j in 0..n {
421                let mut acc = 0.0;
422                for aa in 0..a {
423                    let pre = proj[j * a + aa] + qp[aa];
424                    let th = pre.tanh();
425                    s[j * a + aa] = th;
426                    acc += self.v[aa] * th;
427                }
428                logits[j] = acc;
429            }
430            let probs = Self::softmax(&logits);
431            let tgt = targets[i];
432            nll -= probs[tgt].max(1e-300).ln();
433
434            // d_logit[j] = p_j − 1[j == tgt]
435            for j in 0..n {
436                let d_logit = probs[j] - if j == tgt { 1.0 } else { 0.0 };
437                let ej = &encoder_states[j * hh..(j + 1) * hh];
438                for aa in 0..a {
439                    // v gradient: u depends on v_a via s_{j,a}.
440                    g_v[aa] += d_logit * s[j * a + aa];
441                    // Through tanh into the pre-activation.
442                    let d_pre = d_logit * self.v[aa] * (1.0 - s[j * a + aa] * s[j * a + aa]);
443                    let row = aa * hh;
444                    for h in 0..hh {
445                        // W1 e_j contribution.
446                        g_w1[row + h] += d_pre * ej[h];
447                        // W2 d_i contribution.
448                        g_w2[row + h] += d_pre * qi[h];
449                    }
450                }
451            }
452        }
453
454        Ok((
455            nll,
456            PointerGrad {
457                w1: g_w1,
458                w2: g_w2,
459                v: g_v,
460            },
461        ))
462    }
463
464    /// Apply one gradient-descent step on the attention parameters with learning
465    /// rate `lr`. Returns the NLL *before* the update.
466    pub fn step(
467        &mut self,
468        encoder_states: &[f64],
469        queries: &[f64],
470        targets: &[usize],
471        lr: f64,
472    ) -> SeqResult<f64> {
473        if !lr.is_finite() || lr <= 0.0 {
474            return Err(SeqError::InvalidParameter {
475                name: "lr".to_string(),
476                value: lr,
477            });
478        }
479        let (nll, grad) = self.backward(encoder_states, queries, targets)?;
480        for (w, g) in self.w1.iter_mut().zip(grad.w1.iter()) {
481            *w -= lr * g;
482        }
483        for (w, g) in self.w2.iter_mut().zip(grad.w2.iter()) {
484            *w -= lr * g;
485        }
486        for (w, g) in self.v.iter_mut().zip(grad.v.iter()) {
487            *w -= lr * g;
488        }
489        Ok(nll)
490    }
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// `len` values drawn with `next_range(-1.0, 1.0)` from a freshly seeded LCG.
    fn rand_buffer(len: usize, seed: u64) -> Vec<f64> {
        let mut rng = LcgRng::new(seed);
        let mut buf = Vec::with_capacity(len);
        for _ in 0..len {
            buf.push(rng.next_range(-1.0, 1.0));
        }
        buf
    }

    /// Small random network: hidden 3, attention 4, input 2, scale 0.5.
    fn rand_net(seed: u64) -> PointerNetwork {
        let mut rng = LcgRng::new(seed);
        PointerNetwork::new(3, 4, 2, 0.5, &mut rng).expect("net")
    }

    /// Random `n × hidden_dim` encoder states.
    fn rand_states(net: &PointerNetwork, n: usize, seed: u64) -> Vec<f64> {
        rand_buffer(n * net.hidden_dim, seed)
    }

    /// Random `m × hidden_dim` decoder queries.
    fn rand_queries(net: &PointerNetwork, m: usize, seed: u64) -> Vec<f64> {
        rand_buffer(m * net.hidden_dim, seed)
    }

    #[test]
    fn construct_validates_dims() {
        // A zero in any dimension slot is rejected.
        for (h, a, d) in [(0, 2, 2), (2, 0, 2), (2, 2, 0)] {
            assert!(PointerNetwork::zeros(h, a, d).is_err());
        }
        // So are non-positive and non-finite scales.
        let mut rng = LcgRng::new(1);
        for scale in [0.0, f64::INFINITY] {
            assert!(PointerNetwork::new(2, 2, 2, scale, &mut rng).is_err());
        }
    }

    #[test]
    fn pointer_distribution_is_valid_simplex() {
        let net = rand_net(2);
        let states = rand_states(&net, 5, 3);
        let query = rand_queries(&net, 1, 4);
        let dist = net.pointer_distribution(&states, &query).expect("dist");
        assert_eq!(dist.len(), 5);
        for &p in &dist {
            assert!((0.0..=1.0).contains(&p));
        }
        let s: f64 = dist.iter().sum();
        assert!((s - 1.0).abs() < 1e-12, "sum={s}");
    }

    #[test]
    fn attention_shapes_correct() {
        let net = rand_net(5);
        let (n, m) = (6usize, 4usize);
        let states = rand_states(&net, n, 6);
        let queries = rand_queries(&net, m, 7);

        let first_query = &queries[..net.hidden_dim];
        let logits = net.attention_logits(&states, first_query).expect("logits");
        assert_eq!(logits.len(), n);

        let probs = net.forward(&states, &queries).expect("fwd");
        assert_eq!(probs.len(), m * n);
        // Each row is a simplex.
        for (i, row) in probs.chunks(n).enumerate() {
            let s: f64 = row.iter().sum();
            assert!((s - 1.0).abs() < 1e-12, "row {i} sum={s}");
        }
    }

    #[test]
    fn decode_yields_in_range_indices() {
        let net = rand_net(8);
        let n = 7usize;
        let states = rand_states(&net, n, 9);
        let queries = rand_queries(&net, 5, 10);
        let path = net.decode(&states, &queries).expect("decode");
        assert_eq!(path.len(), 5);
        for &p in &path {
            assert!(p < n);
        }
    }

    #[test]
    fn decode_is_deterministic() {
        let net = rand_net(11);
        let states = rand_states(&net, 6, 12);
        let queries = rand_queries(&net, 4, 13);

        let p1 = net.decode(&states, &queries).expect("d1");
        let p2 = net.decode(&states, &queries).expect("d2");
        assert_eq!(p1, p2);

        let f1 = net.forward(&states, &queries).expect("f1");
        let f2 = net.forward(&states, &queries).expect("f2");
        assert_eq!(f1, f2);
    }

    #[test]
    fn gradient_matches_finite_difference() {
        let net = rand_net(14);
        let n = 5usize;
        let states = rand_states(&net, n, 15);
        let queries = rand_queries(&net, 3, 16);
        let targets = vec![2usize, 0, 4];
        let (_, grad) = net.backward(&states, &queries, &targets).expect("bwd");

        let eps = 1e-6;
        // Central difference of the NLL along one parameter coordinate; `bump`
        // adds its second argument to that coordinate.
        let central = |bump: &dyn Fn(&mut PointerNetwork, f64)| -> f64 {
            let loss_at = |delta: f64, tag: &str| -> f64 {
                let mut shifted = net.clone();
                bump(&mut shifted, delta);
                shifted.nll(&states, &queries, &targets).expect(tag)
            };
            let lp = loss_at(eps, "nll+");
            let lm = loss_at(-eps, "nll-");
            (lp - lm) / (2.0 * eps)
        };

        for (idx, &ana) in grad.w1.iter().enumerate() {
            let num = central(&|p, e| p.w1[idx] += e);
            assert!((num - ana).abs() < 1e-4, "w1[{idx}] num={num} ana={}", ana);
        }
        for (idx, &ana) in grad.w2.iter().enumerate() {
            let num = central(&|p, e| p.w2[idx] += e);
            assert!((num - ana).abs() < 1e-4, "w2[{idx}] num={num} ana={}", ana);
        }
        for (idx, &ana) in grad.v.iter().enumerate() {
            let num = central(&|p, e| p.v[idx] += e);
            assert!((num - ana).abs() < 1e-4, "v[{idx}] num={num} ana={}", ana);
        }
    }

    #[test]
    fn constructed_weights_point_to_argmax_by_key() {
        // Selection task: e_j = [key_j, 0], v = [1, 0], W1 = [[c, 0], [0, 0]] and
        // W2 = 0 give u_j = tanh(c · key_j). That is monotone in key_j, so the
        // most probable position is the one holding the largest key.
        let mut net = PointerNetwork::zeros(2, 2, 2).expect("net");
        let c = 3.0;
        net.w1[0] = c; // attention row 0 reads hidden dim 0 (the key)
        net.v[0] = 1.0;

        let keys = [0.2_f64, 0.9, 0.1, 0.5];
        let hidden = net.hidden_dim;
        let mut states = vec![0.0; keys.len() * hidden];
        for (slot, &kk) in states.chunks_mut(hidden).zip(keys.iter()) {
            slot[0] = kk;
        }
        let query = vec![0.0; hidden];
        let dist = net.pointer_distribution(&states, &query).expect("dist");

        // The largest key sits at position 1.
        let (argmax, _) = dist
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).expect("cmp"))
            .expect("nonempty");
        assert_eq!(argmax, 1);
    }

    #[test]
    fn training_reduces_nll_on_selection_task() {
        // Teach the net to point at the highest-key position. The single decoder
        // query is constant; each encoder state carries its key in hidden dim 0
        // and a bias-like constant 1.0 in hidden dim 1.
        let mut rng = LcgRng::new(21);
        let mut net = PointerNetwork::new(2, 3, 2, 0.3, &mut rng).expect("net");

        let keys = [0.1_f64, 0.4, 0.95, 0.2, 0.6];
        let hidden = net.hidden_dim;
        let mut states = vec![0.0; keys.len() * hidden];
        for (slot, &kk) in states.chunks_mut(hidden).zip(keys.iter()) {
            slot[0] = kk;
            slot[1] = 1.0;
        }
        let queries = vec![1.0; hidden]; // single decoder step
        let targets = vec![2usize]; // the largest key is at position 2

        let nll0 = net.nll(&states, &queries, &targets).expect("nll0");
        let mut remaining = 400;
        while remaining > 0 {
            net.step(&states, &queries, &targets, 0.2).expect("step");
            remaining -= 1;
        }
        let nll1 = net.nll(&states, &queries, &targets).expect("nll1");
        assert!(nll1 < nll0 - 1e-3, "nll0={nll0}, nll1={nll1}");

        let path = net.decode(&states, &queries).expect("decode");
        assert_eq!(path, targets);
    }

    #[test]
    fn nll_validates_targets() {
        let net = rand_net(30);
        let n = 4usize;
        let states = rand_states(&net, n, 31);
        let queries = rand_queries(&net, 2, 32);

        // Second target is one past the last valid position.
        let out_of_range = [0, n];
        assert!(net.nll(&states, &queries, &out_of_range).is_err());
        // Two decoder steps but only one target.
        let too_few = [0];
        assert!(net.nll(&states, &queries, &too_few).is_err());
    }

    #[test]
    fn input_validation_paths() {
        let net = rand_net(40);
        let zero_query = vec![0.0; net.hidden_dim];

        // Empty encoder states.
        assert!(net.pointer_distribution(&[], &zero_query).is_err());

        // Ragged encoder states: one value more than two full rows.
        let bad = vec![0.0; net.hidden_dim * 2 + 1];
        assert!(net.attention_logits(&bad, &zero_query).is_err());

        // Query of wrong length.
        let states = rand_states(&net, 3, 41);
        assert!(net.pointer_distribution(&states, &[0.0, 0.0]).is_err());

        // Empty queries.
        assert!(net.forward(&states, &[]).is_err());
    }

    #[test]
    fn encoder_runs_and_shapes_match() {
        let net = rand_net(50);
        let n = 4usize;
        let inputs = rand_buffer(n * net.input_dim, 51);

        let states = net.encode(&inputs).expect("encode");
        assert_eq!(states.len(), n * net.hidden_dim);
        for v in &states {
            assert!(v.is_finite());
        }

        // The encoder output feeds straight into the attention head.
        let query = vec![0.5; net.hidden_dim];
        let dist = net.pointer_distribution(&states, &query).expect("dist");
        let s: f64 = dist.iter().sum();
        assert!((s - 1.0).abs() < 1e-12);
    }

    #[test]
    fn step_validates_learning_rate() {
        let mut net = rand_net(60);
        let states = rand_states(&net, 3, 61);
        let queries = rand_queries(&net, 2, 62);
        let targets = vec![0usize, 1];
        for lr in [0.0, -1.0] {
            assert!(net.step(&states, &queries, &targets, lr).is_err());
        }
    }
}
743}