1use crate::{
43 error::{VisionError, VisionResult},
44 handle::LcgRng,
45 vit::vit_block::linear,
46};
47
/// Number of coordinates stored per point. Every point occupies
/// `COORD_DIM` consecutive `f32` values in the flat `points` buffers.
const COORD_DIM: usize = 3;
50
/// Parameters of a dense affine layer.
///
/// `weight` is a flat buffer of `n_in * n_out` values and `bias` holds one
/// value per output. NOTE(review): the element order inside `weight` is
/// whatever `vit_block::linear` expects — confirm there before indexing it.
#[derive(Debug, Clone)]
struct Linear {
    /// Flattened weight matrix (`n_in * n_out` values).
    weight: Vec<f32>,
    /// Per-output bias (`n_out` values).
    bias: Vec<f32>,
    /// Input width.
    n_in: usize,
    /// Output width.
    n_out: usize,
}
64
65impl Linear {
66 fn new(n_in: usize, n_out: usize, scale: f32, rng: &mut LcgRng) -> Self {
68 let mut weight = vec![0.0f32; n_in * n_out];
69 rng.fill_normal(&mut weight);
70 for w in &mut weight {
71 *w *= scale;
72 }
73 Self {
74 weight,
75 bias: vec![0.0f32; n_out],
76 n_in,
77 n_out,
78 }
79 }
80
81 #[inline]
83 fn apply(&self, x: &[f32]) -> Vec<f32> {
84 linear(x, &self.weight, &self.bias, self.n_in, self.n_out)
85 }
86}
87
88#[derive(Debug, Clone)]
93struct Mlp {
94 fc1: Linear,
95 fc2: Linear,
96}
97
98impl Mlp {
99 fn new(n_in: usize, hidden: usize, n_out: usize, rng: &mut LcgRng) -> Self {
100 let s1 = (2.0 / n_in as f32).sqrt();
102 let s2 = (2.0 / hidden as f32).sqrt();
103 Self {
104 fc1: Linear::new(n_in, hidden, s1, rng),
105 fc2: Linear::new(hidden, n_out, s2, rng),
106 }
107 }
108
109 #[inline]
111 fn apply(&self, x: &[f32]) -> Vec<f32> {
112 let mut h = self.fc1.apply(x);
113 for v in &mut h {
114 *v = v.max(0.0); }
116 self.fc2.apply(&h)
117 }
118}
119
/// Hyper-parameters for a Point Transformer layer.
///
/// Prefer `PointTransformerConfig::new`, which rejects zero-sized values.
/// The fields are public, so building the struct literally bypasses that
/// validation.
#[derive(Debug, Clone, PartialEq)]
pub struct PointTransformerConfig {
    /// Width of each input feature vector.
    pub in_dim: usize,
    /// Width of the internal attention space (outputs of φ, ψ, α, θ, γ).
    pub dim: usize,
    /// Width of each output feature vector.
    pub out_dim: usize,
    /// Hidden width of the positional-encoding MLP θ.
    pub pos_hidden: usize,
    /// Hidden width of the attention MLP γ.
    pub attn_hidden: usize,
    /// Neighbourhood size; clamped to the number of points at run time.
    pub k: usize,
}
138
139impl PointTransformerConfig {
140 pub fn new(
147 in_dim: usize,
148 dim: usize,
149 out_dim: usize,
150 pos_hidden: usize,
151 attn_hidden: usize,
152 k: usize,
153 ) -> VisionResult<Self> {
154 if in_dim == 0 {
155 return Err(VisionError::InvalidEmbedDim(in_dim));
156 }
157 if dim == 0 {
158 return Err(VisionError::InvalidEmbedDim(dim));
159 }
160 if out_dim == 0 {
161 return Err(VisionError::InvalidEmbedDim(out_dim));
162 }
163 if k == 0 {
164 return Err(VisionError::EmptyInput("point transformer k"));
165 }
166 if pos_hidden == 0 {
167 return Err(VisionError::EmptyInput("point transformer pos_hidden"));
168 }
169 if attn_hidden == 0 {
170 return Err(VisionError::EmptyInput("point transformer attn_hidden"));
171 }
172 Ok(Self {
173 in_dim,
174 dim,
175 out_dim,
176 pos_hidden,
177 attn_hidden,
178 k,
179 })
180 }
181
182 #[must_use]
185 pub fn tiny() -> Self {
186 Self {
187 in_dim: 8,
188 dim: 8,
189 out_dim: 8,
190 pos_hidden: 8,
191 attn_hidden: 8,
192 k: 4,
193 }
194 }
195}
196
/// Full result of a Point Transformer forward pass.
///
/// Besides the output features, it carries the attention internals, which
/// are useful for inspection and testing.
#[derive(Debug, Clone)]
pub struct PointAttention {
    /// Output features, flat row-major `[n_points, out_dim]`.
    pub features: Vec<f32>,
    /// Neighbour indices, flat row-major `[n_points, k]`.
    ///
    /// Within each row, neighbours are ordered by increasing squared distance
    /// to the query point.
    pub neighbors: Vec<usize>,
    /// Per-channel softmax weights, flat `[n_points, k, dim]`.
    ///
    /// Entry `(i * k + s) * dim + c` is the weight of neighbour `s` of point
    /// `i` on channel `c`. For each `(i, c)`, the weights over `s` sum to one.
    pub weights: Vec<f32>,
    /// Number of points in the cloud.
    pub n_points: usize,
    /// Effective neighbourhood size (configured `k` clamped to `n_points`).
    pub k: usize,
    /// Attention channel width (the configured `dim`).
    pub dim: usize,
}
219
220fn knn(points: &[f32], n: usize, i: usize, k: usize) -> Vec<usize> {
228 let pi = &points[i * COORD_DIM..i * COORD_DIM + COORD_DIM];
229 let mut dists: Vec<(f32, usize)> = (0..n)
230 .map(|j| {
231 let pj = &points[j * COORD_DIM..j * COORD_DIM + COORD_DIM];
232 let mut d = 0.0f32;
233 for c in 0..COORD_DIM {
234 let diff = pi[c] - pj[c];
235 d += diff * diff;
236 }
237 (d, j)
238 })
239 .collect();
240 dists.sort_by(|a, b| {
242 a.0.partial_cmp(&b.0)
243 .unwrap_or(std::cmp::Ordering::Equal)
244 .then(a.1.cmp(&b.1))
245 });
246 let kk = k.min(n);
247 dists.into_iter().take(kk).map(|(_, j)| j).collect()
248}
249
250pub struct PointTransformerLayer {
254 cfg: PointTransformerConfig,
255 phi: Linear,
257 psi: Linear,
259 alpha: Linear,
261 theta: Mlp,
263 gamma: Mlp,
265 out_proj: Linear,
267}
268
269impl PointTransformerLayer {
270 pub fn new(cfg: PointTransformerConfig, rng: &mut LcgRng) -> Self {
272 let proj_scale = 1.0 / (cfg.in_dim as f32).sqrt();
273 let phi = Linear::new(cfg.in_dim, cfg.dim, proj_scale, rng);
274 let psi = Linear::new(cfg.in_dim, cfg.dim, proj_scale, rng);
275 let alpha = Linear::new(cfg.in_dim, cfg.dim, proj_scale, rng);
276 let theta = Mlp::new(COORD_DIM, cfg.pos_hidden, cfg.dim, rng);
277 let gamma = Mlp::new(cfg.dim, cfg.attn_hidden, cfg.dim, rng);
278 let out_proj = Linear::new(cfg.dim, cfg.out_dim, 1.0 / (cfg.dim as f32).sqrt(), rng);
279 Self {
280 cfg,
281 phi,
282 psi,
283 alpha,
284 theta,
285 gamma,
286 out_proj,
287 }
288 }
289
290 #[must_use]
292 #[inline]
293 pub fn config(&self) -> &PointTransformerConfig {
294 &self.cfg
295 }
296
297 pub fn forward(
302 &self,
303 points: &[f32],
304 features: &[f32],
305 n_points: usize,
306 ) -> VisionResult<Vec<f32>> {
307 Ok(self.compute(points, features, n_points, true)?.features)
308 }
309
310 pub fn forward_detailed(
318 &self,
319 points: &[f32],
320 features: &[f32],
321 n_points: usize,
322 ) -> VisionResult<PointAttention> {
323 self.compute(points, features, n_points, true)
324 }
325
326 pub fn forward_zero_position(
332 &self,
333 points: &[f32],
334 features: &[f32],
335 n_points: usize,
336 ) -> VisionResult<PointAttention> {
337 self.compute(points, features, n_points, false)
338 }
339
340 fn compute(
343 &self,
344 points: &[f32],
345 features: &[f32],
346 n_points: usize,
347 use_delta: bool,
348 ) -> VisionResult<PointAttention> {
349 if n_points == 0 {
350 return Err(VisionError::EmptyInput("point transformer points"));
351 }
352 if points.len() != n_points * COORD_DIM {
353 return Err(VisionError::DimensionMismatch {
354 expected: n_points * COORD_DIM,
355 got: points.len(),
356 });
357 }
358 if features.len() != n_points * self.cfg.in_dim {
359 return Err(VisionError::DimensionMismatch {
360 expected: n_points * self.cfg.in_dim,
361 got: features.len(),
362 });
363 }
364
365 let d = self.cfg.dim;
366 let din = self.cfg.in_dim;
367 let k = self.cfg.k.min(n_points);
368
369 let mut phi_all = vec![0.0f32; n_points * d];
371 let mut psi_all = vec![0.0f32; n_points * d];
372 let mut alpha_all = vec![0.0f32; n_points * d];
373 for p in 0..n_points {
374 let xf = &features[p * din..(p + 1) * din];
375 phi_all[p * d..(p + 1) * d].copy_from_slice(&self.phi.apply(xf));
376 psi_all[p * d..(p + 1) * d].copy_from_slice(&self.psi.apply(xf));
377 alpha_all[p * d..(p + 1) * d].copy_from_slice(&self.alpha.apply(xf));
378 }
379
380 let mut out_features = vec![0.0f32; n_points * self.cfg.out_dim];
381 let mut all_neighbors = vec![0usize; n_points * k];
382 let mut all_weights = vec![0.0f32; n_points * k * d];
383
384 for i in 0..n_points {
385 let neighbors = knn(points, n_points, i, self.cfg.k);
386 debug_assert_eq!(neighbors.len(), k);
387 all_neighbors[i * k..(i + 1) * k].copy_from_slice(&neighbors);
388
389 let phi_i = &phi_all[i * d..(i + 1) * d];
390 let pi = &points[i * COORD_DIM..i * COORD_DIM + COORD_DIM];
391
392 let mut deltas = vec![0.0f32; k * d];
395 let mut logits = vec![0.0f32; k * d];
396 let mut values = vec![0.0f32; k * d];
397
398 for (s, &j) in neighbors.iter().enumerate() {
399 let pj = &points[j * COORD_DIM..j * COORD_DIM + COORD_DIM];
401 let rel = [pi[0] - pj[0], pi[1] - pj[1], pi[2] - pj[2]];
402 let delta = if use_delta {
403 self.theta.apply(&rel)
404 } else {
405 vec![0.0f32; d]
406 };
407
408 let psi_j = &psi_all[j * d..(j + 1) * d];
410 let alpha_j = &alpha_all[j * d..(j + 1) * d];
411 let mut relation = vec![0.0f32; d];
412 for c in 0..d {
413 relation[c] = phi_i[c] - psi_j[c] + delta[c];
414 }
415 let g = self.gamma.apply(&relation);
416
417 let row = s * d;
418 for c in 0..d {
419 logits[row + c] = g[c];
420 values[row + c] = alpha_j[c] + delta[c];
421 deltas[row + c] = delta[c];
422 }
423 }
424 let _ = &deltas; softmax_over_neighbors(&mut logits, k, d);
428 all_weights[i * k * d..(i + 1) * k * d].copy_from_slice(&logits);
429
430 let mut y_i = vec![0.0f32; d];
432 for s in 0..k {
433 let row = s * d;
434 for c in 0..d {
435 y_i[c] += logits[row + c] * values[row + c];
436 }
437 }
438
439 let proj = self.out_proj.apply(&y_i);
440 out_features[i * self.cfg.out_dim..(i + 1) * self.cfg.out_dim].copy_from_slice(&proj);
441 }
442
443 if out_features.iter().any(|v| !v.is_finite()) {
444 return Err(VisionError::NonFinite("point transformer output"));
445 }
446
447 Ok(PointAttention {
448 features: out_features,
449 neighbors: all_neighbors,
450 weights: all_weights,
451 n_points,
452 k,
453 dim: d,
454 })
455 }
456}
457
/// In-place softmax over neighbours, computed independently for each channel.
///
/// `logits` is a row-major `[k, d]` buffer: entry `s * d + c` is neighbour
/// `s` on channel `c`. Each channel column is shifted by its maximum before
/// exponentiating, for numerical stability. A column that sums to zero is
/// left unscaled; this happens only when `k == 0`, in which case there is
/// nothing to scale.
fn softmax_over_neighbors(logits: &mut [f32], k: usize, d: usize) {
    for channel in 0..d {
        let at = |slot: usize| slot * d + channel;

        let peak = (0..k)
            .map(|slot| logits[at(slot)])
            .fold(f32::NEG_INFINITY, f32::max);

        let mut total = 0.0f32;
        for slot in 0..k {
            let e = (logits[at(slot)] - peak).exp();
            logits[at(slot)] = e;
            total += e;
        }

        let scale = if total > 0.0 { 1.0 / total } else { 1.0 };
        (0..k).for_each(|slot| logits[at(slot)] *= scale);
    }
}
481
/// Unit tests: configuration validation, kNN selection, output shapes, softmax
/// normalisation, permutation equivariance, the effect of δ, translation
/// invariance and determinism.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a seeded cloud of `n` points with 8-wide features. The width 8
    /// matches `PointTransformerConfig::tiny().in_dim`.
    fn make_cloud(n: usize, seed: u64) -> (Vec<f32>, Vec<f32>) {
        let mut rng = LcgRng::new(seed);
        let mut points = vec![0.0f32; n * COORD_DIM];
        for (idx, p) in points.iter_mut().enumerate() {
            // NOTE(review): the index-dependent offset presumably guards
            // against exact distance ties, which kNN breaks by index and which
            // would spoil the permutation-equivariance test.
            *p = rng.next_f32() * 10.0 + idx as f32 * 0.01;
        }
        let mut feats = vec![0.0f32; n * 8];
        rng.fill_normal(&mut feats);
        (points, feats)
    }

    /// `tiny()` exposes the documented small sizes.
    #[test]
    fn config_tiny_valid() {
        let cfg = PointTransformerConfig::tiny();
        assert_eq!(cfg.dim, 8);
        assert_eq!(cfg.k, 4);
    }

    /// Zero widths and a zero neighbourhood size are rejected with the
    /// expected error variants.
    #[test]
    fn config_zero_dim_errors() {
        assert!(matches!(
            PointTransformerConfig::new(0, 8, 8, 8, 8, 4),
            Err(VisionError::InvalidEmbedDim(0))
        ));
        assert!(matches!(
            PointTransformerConfig::new(8, 8, 8, 8, 8, 0),
            Err(VisionError::EmptyInput(_))
        ));
    }

    /// kNN on five collinear points at x = 0..4 returns the truly nearest
    /// indices, with the query point itself first.
    #[test]
    fn knn_picks_genuine_nearest() {
        let points = vec![
            0.0f32, 0.0, 0.0, // point 0
            1.0, 0.0, 0.0, // point 1
            2.0, 0.0, 0.0, // point 2
            3.0, 0.0, 0.0, // point 3
            4.0, 0.0, 0.0, // point 4
        ];
        let nn0 = knn(&points, 5, 0, 3);
        assert_eq!(nn0, vec![0, 1, 2], "point 0 nearest set");

        let nn2 = knn(&points, 5, 2, 3);
        assert_eq!(nn2[0], 2, "self is nearest");
        // Points 1 and 3 are equidistant from point 2; only membership is checked.
        assert!(nn2.contains(&1) && nn2.contains(&3), "both unit neighbours");

        let nn4 = knn(&points, 5, 4, 2);
        assert_eq!(nn4, vec![4, 3], "point 4 nearest set");
    }

    /// Asking for more neighbours than there are points returns every point.
    #[test]
    fn knn_clamps_k_to_n() {
        let points = vec![0.0f32, 0.0, 0.0, 1.0, 0.0, 0.0];
        let nn = knn(&points, 2, 0, 10);
        assert_eq!(nn.len(), 2, "k clamped to n_points");
    }

    /// Output buffers have the documented sizes and finite values.
    #[test]
    fn forward_shapes_and_finite() {
        let n = 16;
        let (points, feats) = make_cloud(n, 1);
        let mut rng = LcgRng::new(2);
        let layer = PointTransformerLayer::new(PointTransformerConfig::tiny(), &mut rng);
        let out = layer.forward_detailed(&points, &feats, n).expect("ok");
        assert_eq!(out.features.len(), n * 8);
        assert_eq!(out.neighbors.len(), n * 4);
        assert_eq!(out.weights.len(), n * 4 * 8);
        assert!(out.features.iter().all(|v| v.is_finite()));
    }

    /// A feature buffer of the wrong length is reported, not indexed.
    #[test]
    fn forward_wrong_feature_len_errors() {
        let n = 8;
        let (points, _) = make_cloud(n, 3);
        let mut rng = LcgRng::new(4);
        let layer = PointTransformerLayer::new(PointTransformerConfig::tiny(), &mut rng);
        // 4 values per point instead of in_dim = 8.
        let bad = vec![0.0f32; n * 4];
        let r = layer.forward(&points, &bad, n);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    /// For every point and channel, the attention weights form a probability
    /// distribution over the k neighbours.
    #[test]
    fn attention_weights_nonneg_and_sum_to_one_per_channel() {
        let n = 12;
        let (points, feats) = make_cloud(n, 5);
        let mut rng = LcgRng::new(6);
        let layer = PointTransformerLayer::new(PointTransformerConfig::tiny(), &mut rng);
        let out = layer.forward_detailed(&points, &feats, n).expect("ok");
        let k = out.k;
        let d = out.dim;
        for i in 0..n {
            for c in 0..d {
                let mut sum = 0.0f32;
                for s in 0..k {
                    let w = out.weights[(i * k + s) * d + c];
                    assert!(w >= 0.0, "weight must be non-negative, got {w}");
                    sum += w;
                }
                assert!(
                    (sum - 1.0).abs() < 1e-4,
                    "point {i} channel {c} weights sum {sum} != 1"
                );
            }
        }
    }

    /// Reordering the input points reorders the output rows the same way.
    #[test]
    fn permutation_equivariance() {
        let n = 16;
        let (points, feats) = make_cloud(n, 7);
        let mut rng = LcgRng::new(8);
        let layer = PointTransformerLayer::new(PointTransformerConfig::tiny(), &mut rng);
        let din = 8;
        let dout = 8;

        let base = layer.forward(&points, &feats, n).expect("ok");

        // Row r of the permuted cloud is row perm[r] of the original.
        let mut perm: Vec<usize> = (0..n).collect();
        let mut prng = LcgRng::new(123);
        prng.shuffle(&mut perm);

        let mut p_points = vec![0.0f32; n * COORD_DIM];
        let mut p_feats = vec![0.0f32; n * din];
        for (r, &src) in perm.iter().enumerate() {
            p_points[r * COORD_DIM..(r + 1) * COORD_DIM]
                .copy_from_slice(&points[src * COORD_DIM..(src + 1) * COORD_DIM]);
            p_feats[r * din..(r + 1) * din].copy_from_slice(&feats[src * din..(src + 1) * din]);
        }

        let permuted = layer.forward(&p_points, &p_feats, n).expect("ok");

        for (r, &src) in perm.iter().enumerate() {
            for c in 0..dout {
                let a = permuted[r * dout + c];
                let b = base[src * dout + c];
                assert!(
                    (a - b).abs() < 1e-4,
                    "equivariance broken at row {r} ch {c}: {a} vs {b}"
                );
            }
        }
    }

    /// Zeroing δ must measurably change the output, i.e. θ actually matters.
    #[test]
    fn position_encoding_changes_output() {
        let n = 14;
        let (points, feats) = make_cloud(n, 9);
        let mut rng = LcgRng::new(10);
        let layer = PointTransformerLayer::new(PointTransformerConfig::tiny(), &mut rng);
        let with_pos = layer.forward_detailed(&points, &feats, n).expect("ok");
        let no_pos = layer.forward_zero_position(&points, &feats, n).expect("ok");
        let diff: f32 = with_pos
            .features
            .iter()
            .zip(no_pos.features.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-3,
            "position encoding δ should change the output, diff={diff}"
        );
    }

    /// Shifting every point by a constant offset leaves neighbours, weights
    /// and outputs unchanged, up to float tolerance.
    #[test]
    fn translation_leaves_relative_attention_unchanged() {
        let n = 16;
        let (points, feats) = make_cloud(n, 11);
        let mut rng = LcgRng::new(12);
        let layer = PointTransformerLayer::new(PointTransformerConfig::tiny(), &mut rng);

        let base = layer.forward_detailed(&points, &feats, n).expect("ok");

        let mut shifted = points.clone();
        let offset = [3.5f32, -2.0, 7.25];
        for p in 0..n {
            for c in 0..COORD_DIM {
                shifted[p * COORD_DIM + c] += offset[c];
            }
        }
        let moved = layer.forward_detailed(&shifted, &feats, n).expect("ok");

        // Neighbour sets must match exactly; weights and outputs only
        // approximately, since shifted coordinates round differently.
        assert_eq!(
            base.neighbors, moved.neighbors,
            "kNN changed under translation"
        );
        for (a, b) in base.weights.iter().zip(moved.weights.iter()) {
            assert!(
                (a - b).abs() < 1e-5,
                "attention weights changed under translation: {a} vs {b}"
            );
        }
        for (a, b) in base.features.iter().zip(moved.features.iter()) {
            assert!((a - b).abs() < 1e-4, "output changed under translation");
        }
    }

    /// Two layers built from the same seed produce bit-identical outputs.
    #[test]
    fn deterministic_same_seed() {
        let n = 10;
        let (points, feats) = make_cloud(n, 13);
        let mut rng_a = LcgRng::new(55);
        let mut rng_b = LcgRng::new(55);
        let la = PointTransformerLayer::new(PointTransformerConfig::tiny(), &mut rng_a);
        let lb = PointTransformerLayer::new(PointTransformerConfig::tiny(), &mut rng_b);
        let oa = la.forward(&points, &feats, n).expect("ok");
        let ob = lb.forward(&points, &feats, n).expect("ok");
        assert_eq!(oa, ob, "same seed must produce identical output");
    }
}