Skip to main content

oxicuda_vision/pointcloud/
point_transformer.rs

1//! Point Transformer — vector self-attention over kNN neighbourhoods.
2//!
3//! Reference: Zhao, Jiang, Jia, Torr & Koltun, *"Point Transformer"*
4//! (ICCV 2021).
5//!
6//! Unlike the scalar dot-product attention used by language / image
7//! transformers, the Point Transformer layer uses **vector attention** with a
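//! *subtraction relation*. For a query point `i` (coordinates `p_i`, features
//! `x_i`) and a neighbour `j` drawn from its k-nearest-neighbour set `N(i)` —
//! the `k` points closest to `p_i` in Euclidean distance, `i` itself included —
//! the per-channel attention vector is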
11//!
12//! ```text
13//! γ( φ(x_i) − ψ(x_j) + δ_ij )          (a vector in R^d, not a scalar)
14//! ```
15//!
16//! where `φ`, `ψ` are linear query / key projections, `δ_ij = θ(p_i − p_j)` is a
17//! learned **position encoding** of the *relative* coordinate offset, and `γ` is
18//! a small MLP. The weights are normalised with a **per-channel softmax over the
19//! neighbourhood** and aggregate the value vectors:
20//!
21//! ```text
22//! y_i = Σ_{j∈N(i)} ρ( γ( φ(x_i) − ψ(x_j) + δ_ij ) ) ⊙ ( α(x_j) + δ_ij )
23//! ```
24//!
25//! with `ρ = softmax_j`, `α` the value projection and `⊙` the Hadamard product.
26//! A final linear projection maps `y_i` to the output dimension.
27//!
28//! ## Tensor layout
29//! - `points`:   flat `[n_points × 3]` row-major XYZ coordinates.
30//! - `features`: flat `[n_points × in_dim]` row-major.
31//! - outputs:    flat `[n_points × out_dim]` row-major.
32//!
33//! ## Key properties (exercised by the tests)
34//! - **Permutation equivariance**: permuting the input points permutes the
35//!   outputs identically — the defining property of a point-cloud network.
//! - **Translation invariance of the relations**: because `δ` only sees the
//!   *relative* offset `p_i − p_j`, and the kNN neighbourhoods depend only on
//!   pairwise distances, translating every point by a constant leaves the
//!   attention weights unchanged (up to floating-point rounding).
39//! - **Per-channel softmax**: the vector-attention weights are non-negative and
40//!   sum to 1 over the neighbourhood, independently for every channel.
41
42use crate::{
43    error::{VisionError, VisionResult},
44    handle::LcgRng,
45    vit::vit_block::linear,
46};
47
/// Spatial coordinate dimension (XYZ).
///
/// Every `points` buffer in this module is a flat row-major array holding this
/// many values per point, so point `i` occupies
/// `points[i * COORD_DIM..(i + 1) * COORD_DIM]`.
const COORD_DIM: usize = 3;
50
51// ─── Linear ────────────────────────────────────────────────────────────────────
52
53/// A dense linear projection `y = x W^T + b` with `[n_out × n_in]` weights.
54///
55/// Reuses the crate-wide [`crate::vit::vit_block::linear`] kernel so the matmul
56/// is not duplicated.
57#[derive(Debug, Clone)]
58struct Linear {
59    weight: Vec<f32>,
60    bias: Vec<f32>,
61    n_in: usize,
62    n_out: usize,
63}
64
65impl Linear {
66    /// Random init with `N(0, scale)` weights and zero bias.
67    fn new(n_in: usize, n_out: usize, scale: f32, rng: &mut LcgRng) -> Self {
68        let mut weight = vec![0.0f32; n_in * n_out];
69        rng.fill_normal(&mut weight);
70        for w in &mut weight {
71            *w *= scale;
72        }
73        Self {
74            weight,
75            bias: vec![0.0f32; n_out],
76            n_in,
77            n_out,
78        }
79    }
80
81    /// Apply to a single `[n_in]` vector → `[n_out]`.
82    #[inline]
83    fn apply(&self, x: &[f32]) -> Vec<f32> {
84        linear(x, &self.weight, &self.bias, self.n_in, self.n_out)
85    }
86}
87
88// ─── Mlp (2-layer, ReLU) ────────────────────────────────────────────────────────
89
90/// Two-layer perceptron `Linear → ReLU → Linear` used for both the position
91/// encoder `θ` (`δ`) and the attention mapping `γ`.
92#[derive(Debug, Clone)]
93struct Mlp {
94    fc1: Linear,
95    fc2: Linear,
96}
97
98impl Mlp {
99    fn new(n_in: usize, hidden: usize, n_out: usize, rng: &mut LcgRng) -> Self {
100        // Kaiming-style scale for the ReLU non-linearity.
101        let s1 = (2.0 / n_in as f32).sqrt();
102        let s2 = (2.0 / hidden as f32).sqrt();
103        Self {
104            fc1: Linear::new(n_in, hidden, s1, rng),
105            fc2: Linear::new(hidden, n_out, s2, rng),
106        }
107    }
108
109    /// Apply to a single vector.
110    #[inline]
111    fn apply(&self, x: &[f32]) -> Vec<f32> {
112        let mut h = self.fc1.apply(x);
113        for v in &mut h {
114            *v = v.max(0.0); // ReLU
115        }
116        self.fc2.apply(&h)
117    }
118}
119
120// ─── Config ──────────────────────────────────────────────────────────────────
121
122/// Configuration for a [`PointTransformerLayer`].
123#[derive(Debug, Clone, PartialEq)]
124pub struct PointTransformerConfig {
125    /// Input feature dimension `d_in`.
126    pub in_dim: usize,
127    /// Attention (query / key / value) dimension `d`.
128    pub dim: usize,
129    /// Output feature dimension `d_out`.
130    pub out_dim: usize,
131    /// Hidden width of the position-encoding MLP `θ`.
132    pub pos_hidden: usize,
133    /// Hidden width of the attention MLP `γ`.
134    pub attn_hidden: usize,
135    /// Number of neighbours `k` (includes the point itself).
136    pub k: usize,
137}
138
139impl PointTransformerConfig {
140    /// Create and validate a configuration.
141    ///
142    /// # Errors
143    /// - [`VisionError::InvalidEmbedDim`] if `in_dim`, `dim` or `out_dim` is 0.
144    /// - [`VisionError::EmptyInput`] if `k == 0`, `pos_hidden == 0` or
145    ///   `attn_hidden == 0`.
146    pub fn new(
147        in_dim: usize,
148        dim: usize,
149        out_dim: usize,
150        pos_hidden: usize,
151        attn_hidden: usize,
152        k: usize,
153    ) -> VisionResult<Self> {
154        if in_dim == 0 {
155            return Err(VisionError::InvalidEmbedDim(in_dim));
156        }
157        if dim == 0 {
158            return Err(VisionError::InvalidEmbedDim(dim));
159        }
160        if out_dim == 0 {
161            return Err(VisionError::InvalidEmbedDim(out_dim));
162        }
163        if k == 0 {
164            return Err(VisionError::EmptyInput("point transformer k"));
165        }
166        if pos_hidden == 0 {
167            return Err(VisionError::EmptyInput("point transformer pos_hidden"));
168        }
169        if attn_hidden == 0 {
170            return Err(VisionError::EmptyInput("point transformer attn_hidden"));
171        }
172        Ok(Self {
173            in_dim,
174            dim,
175            out_dim,
176            pos_hidden,
177            attn_hidden,
178            k,
179        })
180    }
181
182    /// A tiny configuration for unit tests:
183    /// `in_dim=8, dim=8, out_dim=8, pos_hidden=8, attn_hidden=8, k=4`.
184    #[must_use]
185    pub fn tiny() -> Self {
186        Self {
187            in_dim: 8,
188            dim: 8,
189            out_dim: 8,
190            pos_hidden: 8,
191            attn_hidden: 8,
192            k: 4,
193        }
194    }
195}
196
197// ─── Detailed output ──────────────────────────────────────────────────────────
198
/// Detailed per-point attention output, exposing the neighbourhood and the
/// per-channel vector-attention weights (useful for inspection / tests).
///
/// Returned by [`PointTransformerLayer::forward_detailed`] and
/// [`PointTransformerLayer::forward_zero_position`].
#[derive(Debug, Clone)]
pub struct PointAttention {
    /// Aggregated output features: flat `[n_points × out_dim]`, row-major.
    pub features: Vec<f32>,
    /// Neighbour indices: flat `[n_points × k]`; each row is sorted
    /// nearest-first, with ties broken by the lower point index. A point always
    /// appears in its own row (first, unless a lower-indexed point coincides
    /// with it).
    pub neighbors: Vec<usize>,
    /// Per-channel softmax weights: flat `[n_points × k × dim]`.
    ///
    /// For point `i`, neighbour slot `s`, channel `c`:
    /// `weights[(i * k + s) * dim + c]`. For a fixed `i` and `c` the `k`
    /// weights are non-negative and sum to 1 (up to rounding, for finite
    /// logits).
    pub weights: Vec<f32>,
    /// Number of points.
    pub n_points: usize,
    /// Neighbours per point: the configured `k` clamped to `n_points`.
    pub k: usize,
    /// Attention dimension `d`.
    pub dim: usize,
}
219
220// ─── k-nearest-neighbours ──────────────────────────────────────────────────────
221
222/// Indices of the `k` nearest points to point `i` (including `i` itself),
223/// sorted by ascending squared Euclidean distance with the point index as a
224/// deterministic tie-break.
225///
226/// `points` is flat `[n × 3]`.
227fn knn(points: &[f32], n: usize, i: usize, k: usize) -> Vec<usize> {
228    let pi = &points[i * COORD_DIM..i * COORD_DIM + COORD_DIM];
229    let mut dists: Vec<(f32, usize)> = (0..n)
230        .map(|j| {
231            let pj = &points[j * COORD_DIM..j * COORD_DIM + COORD_DIM];
232            let mut d = 0.0f32;
233            for c in 0..COORD_DIM {
234                let diff = pi[c] - pj[c];
235                d += diff * diff;
236            }
237            (d, j)
238        })
239        .collect();
240    // Ascending by distance, ties broken by lower index → fully deterministic.
241    dists.sort_by(|a, b| {
242        a.0.partial_cmp(&b.0)
243            .unwrap_or(std::cmp::Ordering::Equal)
244            .then(a.1.cmp(&b.1))
245    });
246    let kk = k.min(n);
247    dists.into_iter().take(kk).map(|(_, j)| j).collect()
248}
249
250// ─── PointTransformerLayer ──────────────────────────────────────────────────────
251
252/// A single Point Transformer vector-attention layer.
253pub struct PointTransformerLayer {
254    cfg: PointTransformerConfig,
255    /// Query projection `φ`: `in_dim → dim`.
256    phi: Linear,
257    /// Key projection `ψ`: `in_dim → dim`.
258    psi: Linear,
259    /// Value projection `α`: `in_dim → dim`.
260    alpha: Linear,
261    /// Position-encoding MLP `θ`: `3 → dim` (`δ`).
262    theta: Mlp,
263    /// Attention MLP `γ`: `dim → dim`.
264    gamma: Mlp,
265    /// Output projection: `dim → out_dim`.
266    out_proj: Linear,
267}
268
269impl PointTransformerLayer {
270    /// Construct a layer with randomly-initialised weights.
271    pub fn new(cfg: PointTransformerConfig, rng: &mut LcgRng) -> Self {
272        let proj_scale = 1.0 / (cfg.in_dim as f32).sqrt();
273        let phi = Linear::new(cfg.in_dim, cfg.dim, proj_scale, rng);
274        let psi = Linear::new(cfg.in_dim, cfg.dim, proj_scale, rng);
275        let alpha = Linear::new(cfg.in_dim, cfg.dim, proj_scale, rng);
276        let theta = Mlp::new(COORD_DIM, cfg.pos_hidden, cfg.dim, rng);
277        let gamma = Mlp::new(cfg.dim, cfg.attn_hidden, cfg.dim, rng);
278        let out_proj = Linear::new(cfg.dim, cfg.out_dim, 1.0 / (cfg.dim as f32).sqrt(), rng);
279        Self {
280            cfg,
281            phi,
282            psi,
283            alpha,
284            theta,
285            gamma,
286            out_proj,
287        }
288    }
289
290    /// Read-only configuration access.
291    #[must_use]
292    #[inline]
293    pub fn config(&self) -> &PointTransformerConfig {
294        &self.cfg
295    }
296
297    /// Forward pass returning only the output features `[n_points × out_dim]`.
298    ///
299    /// # Errors
300    /// See [`PointTransformerLayer::forward_detailed`].
301    pub fn forward(
302        &self,
303        points: &[f32],
304        features: &[f32],
305        n_points: usize,
306    ) -> VisionResult<Vec<f32>> {
307        Ok(self.compute(points, features, n_points, true)?.features)
308    }
309
310    /// Forward pass exposing neighbourhoods and per-channel attention weights.
311    ///
312    /// # Errors
313    /// - [`VisionError::EmptyInput`] if `n_points == 0`.
314    /// - [`VisionError::DimensionMismatch`] if `points.len() != n_points * 3` or
315    ///   `features.len() != n_points * in_dim`.
316    /// - [`VisionError::NonFinite`] if any output is non-finite.
317    pub fn forward_detailed(
318        &self,
319        points: &[f32],
320        features: &[f32],
321        n_points: usize,
322    ) -> VisionResult<PointAttention> {
323        self.compute(points, features, n_points, true)
324    }
325
326    /// Forward pass with the position encoding `δ` forced to zero — used to
327    /// verify that `δ` genuinely influences the output.
328    ///
329    /// # Errors
330    /// Same as [`PointTransformerLayer::forward_detailed`].
331    pub fn forward_zero_position(
332        &self,
333        points: &[f32],
334        features: &[f32],
335        n_points: usize,
336    ) -> VisionResult<PointAttention> {
337        self.compute(points, features, n_points, false)
338    }
339
340    /// Core computation. When `use_delta` is false the position encoding `δ` is
341    /// dropped from both the attention relation and the value aggregation.
342    fn compute(
343        &self,
344        points: &[f32],
345        features: &[f32],
346        n_points: usize,
347        use_delta: bool,
348    ) -> VisionResult<PointAttention> {
349        if n_points == 0 {
350            return Err(VisionError::EmptyInput("point transformer points"));
351        }
352        if points.len() != n_points * COORD_DIM {
353            return Err(VisionError::DimensionMismatch {
354                expected: n_points * COORD_DIM,
355                got: points.len(),
356            });
357        }
358        if features.len() != n_points * self.cfg.in_dim {
359            return Err(VisionError::DimensionMismatch {
360                expected: n_points * self.cfg.in_dim,
361                got: features.len(),
362            });
363        }
364
365        let d = self.cfg.dim;
366        let din = self.cfg.in_dim;
367        let k = self.cfg.k.min(n_points);
368
369        // Pre-compute φ / ψ / α for all points (each independent of neighbours).
370        let mut phi_all = vec![0.0f32; n_points * d];
371        let mut psi_all = vec![0.0f32; n_points * d];
372        let mut alpha_all = vec![0.0f32; n_points * d];
373        for p in 0..n_points {
374            let xf = &features[p * din..(p + 1) * din];
375            phi_all[p * d..(p + 1) * d].copy_from_slice(&self.phi.apply(xf));
376            psi_all[p * d..(p + 1) * d].copy_from_slice(&self.psi.apply(xf));
377            alpha_all[p * d..(p + 1) * d].copy_from_slice(&self.alpha.apply(xf));
378        }
379
380        let mut out_features = vec![0.0f32; n_points * self.cfg.out_dim];
381        let mut all_neighbors = vec![0usize; n_points * k];
382        let mut all_weights = vec![0.0f32; n_points * k * d];
383
384        for i in 0..n_points {
385            let neighbors = knn(points, n_points, i, self.cfg.k);
386            debug_assert_eq!(neighbors.len(), k);
387            all_neighbors[i * k..(i + 1) * k].copy_from_slice(&neighbors);
388
389            let phi_i = &phi_all[i * d..(i + 1) * d];
390            let pi = &points[i * COORD_DIM..i * COORD_DIM + COORD_DIM];
391
392            // Per-neighbour: position encoding δ, attention logits γ(relation),
393            // and the value vector (α(x_j) + δ).
394            let mut deltas = vec![0.0f32; k * d];
395            let mut logits = vec![0.0f32; k * d];
396            let mut values = vec![0.0f32; k * d];
397
398            for (s, &j) in neighbors.iter().enumerate() {
399                // δ_ij = θ(p_i − p_j)  (relative offset only).
400                let pj = &points[j * COORD_DIM..j * COORD_DIM + COORD_DIM];
401                let rel = [pi[0] - pj[0], pi[1] - pj[1], pi[2] - pj[2]];
402                let delta = if use_delta {
403                    self.theta.apply(&rel)
404                } else {
405                    vec![0.0f32; d]
406                };
407
408                // relation = φ(x_i) − ψ(x_j) + δ_ij
409                let psi_j = &psi_all[j * d..(j + 1) * d];
410                let alpha_j = &alpha_all[j * d..(j + 1) * d];
411                let mut relation = vec![0.0f32; d];
412                for c in 0..d {
413                    relation[c] = phi_i[c] - psi_j[c] + delta[c];
414                }
415                let g = self.gamma.apply(&relation);
416
417                let row = s * d;
418                for c in 0..d {
419                    logits[row + c] = g[c];
420                    values[row + c] = alpha_j[c] + delta[c];
421                    deltas[row + c] = delta[c];
422                }
423            }
424            let _ = &deltas; // retained for clarity of the value construction above
425
426            // Per-channel softmax over the k neighbours.
427            softmax_over_neighbors(&mut logits, k, d);
428            all_weights[i * k * d..(i + 1) * k * d].copy_from_slice(&logits);
429
430            // Aggregate: y_i[c] = Σ_s weight[s, c] · value[s, c].
431            let mut y_i = vec![0.0f32; d];
432            for s in 0..k {
433                let row = s * d;
434                for c in 0..d {
435                    y_i[c] += logits[row + c] * values[row + c];
436                }
437            }
438
439            let proj = self.out_proj.apply(&y_i);
440            out_features[i * self.cfg.out_dim..(i + 1) * self.cfg.out_dim].copy_from_slice(&proj);
441        }
442
443        if out_features.iter().any(|v| !v.is_finite()) {
444            return Err(VisionError::NonFinite("point transformer output"));
445        }
446
447        Ok(PointAttention {
448            features: out_features,
449            neighbors: all_neighbors,
450            weights: all_weights,
451            n_points,
452            k,
453            dim: d,
454        })
455    }
456}
457
/// Normalise `logits` in place with a softmax taken independently per channel.
///
/// `logits` is a row-major `[k × d]` matrix (row = neighbour slot, column =
/// channel). For each column, the column maximum is subtracted before
/// exponentiating so large logits cannot overflow. If a column's exponentials
/// do not sum to a positive value, that column is left unscaled.
fn softmax_over_neighbors(logits: &mut [f32], k: usize, d: usize) {
    for channel in 0..d {
        // Flat indices of this channel's k entries: channel, channel + d, ...
        let slots = move || (0..k).map(move |slot| slot * d + channel);

        let peak = slots().fold(f32::NEG_INFINITY, |acc, idx| acc.max(logits[idx]));

        let mut total = 0.0f32;
        for idx in slots() {
            let shifted = (logits[idx] - peak).exp();
            logits[idx] = shifted;
            total += shifted;
        }

        let norm = if total > 0.0 { 1.0 / total } else { 1.0 };
        slots().for_each(|idx| logits[idx] *= norm);
    }
}
481
482// ─── Tests ───────────────────────────────────────────────────────────────────
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Feature width of [`PointTransformerConfig::tiny`] (`in_dim == out_dim == 8`).
    const TINY_DIM: usize = 8;

    /// Deterministic pseudo-random point cloud: `n` XYZ points plus
    /// `n × TINY_DIM` normally-distributed features.
    fn make_cloud(n: usize, seed: u64) -> (Vec<f32>, Vec<f32>) {
        let mut rng = LcgRng::new(seed);
        let mut points = vec![0.0f32; n * COORD_DIM];
        // The small index-dependent nudge makes exact distance ties very
        // unlikely; any that remain are still broken deterministically by index.
        for (idx, p) in points.iter_mut().enumerate() {
            *p = rng.next_f32() * 10.0 + idx as f32 * 0.01;
        }
        let mut feats = vec![0.0f32; n * TINY_DIM];
        rng.fill_normal(&mut feats);
        (points, feats)
    }

    /// Layer with the tiny configuration, initialised from `seed`.
    fn tiny_layer(seed: u64) -> PointTransformerLayer {
        let mut rng = LcgRng::new(seed);
        PointTransformerLayer::new(PointTransformerConfig::tiny(), &mut rng)
    }

    // ── Config ─────────────────────────────────────────────────────────────────

    #[test]
    fn config_tiny_valid() {
        let cfg = PointTransformerConfig::tiny();
        assert_eq!(cfg.dim, 8);
        assert_eq!(cfg.k, 4);
        let built = PointTransformerConfig::new(8, 8, 8, 8, 8, 4).expect("valid config");
        assert_eq!(built, cfg, "tiny() must agree with the validated constructor");
    }

    #[test]
    fn config_zero_dim_errors() {
        for (in_dim, dim, out_dim) in [(0, 8, 8), (8, 0, 8), (8, 8, 0)] {
            assert!(matches!(
                PointTransformerConfig::new(in_dim, dim, out_dim, 8, 8, 4),
                Err(VisionError::InvalidEmbedDim(0))
            ));
        }
        for (pos_hidden, attn_hidden, k) in [(8, 8, 0), (0, 8, 4), (8, 0, 4)] {
            assert!(matches!(
                PointTransformerConfig::new(8, 8, 8, pos_hidden, attn_hidden, k),
                Err(VisionError::EmptyInput(_))
            ));
        }
    }

    // ── kNN ────────────────────────────────────────────────────────────────────

    #[test]
    fn knn_picks_genuine_nearest() {
        // Points on the X axis at 0,1,2,3,4 (hand-checked expectations).
        let points = vec![
            0.0f32, 0.0, 0.0, // 0
            1.0, 0.0, 0.0, // 1
            2.0, 0.0, 0.0, // 2
            3.0, 0.0, 0.0, // 3
            4.0, 0.0, 0.0, // 4
        ];
        let nn0 = knn(&points, 5, 0, 3);
        assert_eq!(nn0, vec![0, 1, 2], "point 0 nearest set");

        // Distances from 2: self 0, then 1 and 3 both at distance 1 (a tie);
        // the index tie-break must order 1 before 3.
        let nn2 = knn(&points, 5, 2, 3);
        assert_eq!(nn2, vec![2, 1, 3], "tie broken by lower index");

        let nn4 = knn(&points, 5, 4, 2);
        assert_eq!(nn4, vec![4, 3], "point 4 nearest set");
    }

    #[test]
    fn knn_clamps_k_to_n() {
        let points = vec![0.0f32, 0.0, 0.0, 1.0, 0.0, 0.0];
        let nn = knn(&points, 2, 0, 10);
        assert_eq!(nn.len(), 2, "k clamped to n_points");
    }

    // ── Shapes, finiteness & input validation ────────────────────────────────

    #[test]
    fn forward_shapes_and_finite() {
        let n = 16;
        let (points, feats) = make_cloud(n, 1);
        let layer = tiny_layer(2);
        let out = layer.forward_detailed(&points, &feats, n).expect("ok");
        assert_eq!(out.features.len(), n * TINY_DIM);
        assert_eq!(out.neighbors.len(), n * 4);
        assert_eq!(out.weights.len(), n * 4 * TINY_DIM);
        assert!(out.features.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn forward_wrong_feature_len_errors() {
        let n = 8;
        let (points, _) = make_cloud(n, 3);
        let layer = tiny_layer(4);
        let bad = vec![0.0f32; n * 4]; // in_dim is 8
        let r = layer.forward(&points, &bad, n);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn forward_wrong_point_len_errors() {
        let n = 8;
        let (points, feats) = make_cloud(n, 15);
        let layer = tiny_layer(16);
        let r = layer.forward(&points[..points.len() - 1], &feats, n);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn forward_empty_cloud_errors() {
        let layer = tiny_layer(14);
        let r = layer.forward(&[], &[], 0);
        assert!(matches!(r, Err(VisionError::EmptyInput(_))));
    }

    #[test]
    fn forward_clamps_k_to_cloud_size() {
        // tiny() asks for k = 4 but only 2 points exist.
        let n = 2;
        let (points, feats) = make_cloud(n, 17);
        let layer = tiny_layer(18);
        let out = layer.forward_detailed(&points, &feats, n).expect("ok");
        assert_eq!(out.k, n);
        assert_eq!(out.neighbors.len(), n * n);
        assert_eq!(out.weights.len(), n * n * TINY_DIM);
        assert_eq!(out.features.len(), n * TINY_DIM);
    }

    #[test]
    fn single_point_attends_only_to_itself() {
        let (points, feats) = make_cloud(1, 19);
        let layer = tiny_layer(20);
        let out = layer.forward_detailed(&points, &feats, 1).expect("ok");
        assert_eq!(out.neighbors, vec![0]);
        assert!(
            out.weights.iter().all(|&w| (w - 1.0).abs() < 1e-6),
            "a lone neighbour must receive all the weight"
        );
    }

    // ── Per-channel softmax weights ──────────────────────────────────────────

    #[test]
    fn attention_weights_nonneg_and_sum_to_one_per_channel() {
        let n = 12;
        let (points, feats) = make_cloud(n, 5);
        let layer = tiny_layer(6);
        let out = layer.forward_detailed(&points, &feats, n).expect("ok");
        let k = out.k;
        let d = out.dim;
        for i in 0..n {
            for c in 0..d {
                let mut sum = 0.0f32;
                for s in 0..k {
                    let w = out.weights[(i * k + s) * d + c];
                    assert!(w >= 0.0, "weight must be non-negative, got {w}");
                    sum += w;
                }
                assert!(
                    (sum - 1.0).abs() < 1e-4,
                    "point {i} channel {c} weights sum {sum} != 1"
                );
            }
        }
    }

    // ── Permutation equivariance ─────────────────────────────────────────────

    #[test]
    fn permutation_equivariance() {
        let n = 16;
        let (points, feats) = make_cloud(n, 7);
        let layer = tiny_layer(8);
        let din = TINY_DIM;
        let dout = TINY_DIM;

        let base = layer.forward(&points, &feats, n).expect("ok");

        // A non-trivial permutation of the point indices.
        let mut perm: Vec<usize> = (0..n).collect();
        let mut prng = LcgRng::new(123);
        prng.shuffle(&mut perm);

        // Build permuted point cloud: row r of the new arrays = row perm[r] of old.
        let mut p_points = vec![0.0f32; n * COORD_DIM];
        let mut p_feats = vec![0.0f32; n * din];
        for (r, &src) in perm.iter().enumerate() {
            p_points[r * COORD_DIM..(r + 1) * COORD_DIM]
                .copy_from_slice(&points[src * COORD_DIM..(src + 1) * COORD_DIM]);
            p_feats[r * din..(r + 1) * din].copy_from_slice(&feats[src * din..(src + 1) * din]);
        }

        let permuted = layer.forward(&p_points, &p_feats, n).expect("ok");

        // out_permuted[r] must equal out_base[perm[r]].
        for (r, &src) in perm.iter().enumerate() {
            for c in 0..dout {
                let a = permuted[r * dout + c];
                let b = base[src * dout + c];
                assert!(
                    (a - b).abs() < 1e-4,
                    "equivariance broken at row {r} ch {c}: {a} vs {b}"
                );
            }
        }
    }

    // ── Position encoding δ matters ──────────────────────────────────────────

    #[test]
    fn position_encoding_changes_output() {
        let n = 14;
        let (points, feats) = make_cloud(n, 9);
        let layer = tiny_layer(10);
        let with_pos = layer.forward_detailed(&points, &feats, n).expect("ok");
        let no_pos = layer.forward_zero_position(&points, &feats, n).expect("ok");
        let diff: f32 = with_pos
            .features
            .iter()
            .zip(no_pos.features.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-3,
            "position encoding δ should change the output, diff={diff}"
        );
    }

    // ── Translation invariance of the attention weights ──────────────────────

    #[test]
    fn translation_leaves_relative_attention_unchanged() {
        let n = 16;
        let (points, feats) = make_cloud(n, 11);
        let layer = tiny_layer(12);

        let base = layer.forward_detailed(&points, &feats, n).expect("ok");

        // Translate every point by a constant offset.
        let mut shifted = points.clone();
        let offset = [3.5f32, -2.0, 7.25];
        for p in 0..n {
            for c in 0..COORD_DIM {
                shifted[p * COORD_DIM + c] += offset[c];
            }
        }
        let moved = layer.forward_detailed(&shifted, &feats, n).expect("ok");

        // Relative offsets (p_i − p_j) are unchanged up to rounding → δ, γ and
        // the softmax weights must match, and so must the neighbourhoods.
        assert_eq!(
            base.neighbors, moved.neighbors,
            "kNN changed under translation"
        );
        for (a, b) in base.weights.iter().zip(moved.weights.iter()) {
            assert!(
                (a - b).abs() < 1e-5,
                "attention weights changed under translation: {a} vs {b}"
            );
        }
        // Outputs are translation-invariant too (features unchanged, δ unchanged).
        for (a, b) in base.features.iter().zip(moved.features.iter()) {
            assert!((a - b).abs() < 1e-4, "output changed under translation");
        }
    }

    // ── Determinism ───────────────────────────────────────────────────────────

    #[test]
    fn deterministic_same_seed() {
        let n = 10;
        let (points, feats) = make_cloud(n, 13);
        let la = tiny_layer(55);
        let lb = tiny_layer(55);
        let oa = la.forward(&points, &feats, n).expect("ok");
        let ob = lb.forward(&points, &feats, n).expect("ok");
        assert_eq!(oa, ob, "same seed must produce identical output");
    }
}