Skip to main content

oxicuda_vision/
lib.rs

1//! `oxicuda-vision` — Vision Transformer & CLIP primitives for OxiCUDA.
2//!
3//! Pure-Rust CPU reference implementation providing:
4//! - **`patch_embed`**: strided Conv2D patch embedder, sinusoidal & learnable
5//!   positional encodings.
6//! - **`vit`**: ViT block (pre-norm MHSA + MLP), encoder stack, full ViT model,
7//!   and the Swin Transformer windowed / shifted-window block.
8//! - **`convnext`**: ConvNeXt modern-CNN block (depthwise conv + channel
9//!   LayerNorm + inverted-bottleneck + layer scale).
10//! - **`clip`**: CLIP vision encoder, projection head, InfoNCE contrastive loss.
11//! - **`augment`**: geometric, photometric, and normalisation image augmentations,
12//!   plus MixUp / CutMix batch mixing regularisers.
13//! - **`imgproc`**: classical image processing — Sobel gradients and the Canny
14//!   edge detector, binary/grayscale morphology, union-find connected-component
15//!   labelling, and the Hough line transform.
16//! - **`fpn`**: Feature Pyramid Network (lateral 1×1 convolutions + top-down pathway).
17//! - **`detection`**: RoI Align, DETR decoder, bipartite set matching,
18//!   IoU / GIoU / DIoU / CIoU box-regression losses, the RTMDet detector
19//!   (CSPNeXt backbone + PAFPN neck + decoupled head + SimOTA-lite cost), and
20//!   the OWL-ViT open-vocabulary detector (per-patch image-text matching).
21//! - **`segmentation`**: the Segment Anything Model (SAM) — ViT image encoder,
22//!   prompt encoder, and a two-way transformer mask decoder.
23//! - **`ssl`**: the DINOv2 self-supervised distillation recipe — ViT backbone
24//!   (`[CLS]` + patch tokens), weight-normalised prototype head, centred /
25//!   sharpened teacher-student DINO loss, EMA teacher, centering buffer, and the
26//!   iBOT masked-patch term.
27//! - **`text`**: the CLIP Transformer text encoder — token + positional
28//!   embeddings, causal self-attention blocks, EOS pooling, and joint-space
29//!   projection.
30//! - **`pointcloud`**: the Point Transformer vector self-attention layer over
31//!   kNN neighbourhoods.
//! - **`losses`**: focal loss (sigmoid & softmax), soft Dice segmentation loss,
//!   and image-quality metrics (MSE, PSNR, SSIM, MS-SSIM).
33//! - **`ptx_kernels`**: 7 GPU PTX kernel string generators (SM 7.5–12.0).
34//!
35//! No CUDA SDK dependency; all forward passes run on CPU `f32` tensors
36//! using flat row-major `Vec<f32>` layouts.
37
38pub mod augment;
39pub mod blocks;
40pub mod clip;
41pub mod convnext;
42pub mod detection;
43pub mod error;
44pub mod fpn;
45pub mod handle;
46pub mod imgproc;
47pub mod losses;
48pub mod optimize;
49pub mod patch_embed;
50pub mod pointcloud;
51pub mod ptx_kernels;
52pub mod segmentation;
53pub mod ssl;
54pub mod text;
55pub mod vit;
56
57pub use error::{VisionError, VisionResult};
58pub use handle::{LcgRng, SmVersion, VisionHandle};
59
60// ─── Prelude ─────────────────────────────────────────────────────────────────
61
62pub mod prelude {
63    pub use crate::augment::{AugOp, MixOutput, Pipeline, cutmix, mixup};
64    pub use crate::clip::{
65        ClipVisionConfig, ClipVisionEncoder, ProjectionHead, contrastive::info_nce_loss,
66    };
67    pub use crate::convnext::block::{ConvNextBlock, ConvNextConfig};
68    pub use crate::detection::{
69        AnchorConfig, AnchorGenerator, BBox, DetrConfig, DetrDecoder, IouBox, IouLossKind,
70        MaskHead, MaskHeadConfig, OwlVit, OwlVitConfig, OwlVitOutput, RtmDet, RtmDetConfig,
71        RtmDetOutput, bipartite_match, ciou_loss, decode_level, diou_loss, giou_loss, iou,
72        iou_loss, iou_loss_pairs, nms, roi_align, simota_cost, soft_nms,
73    };
74    pub use crate::error::{VisionError, VisionResult};
75    pub use crate::fpn::{FeatureMap, Fpn, FpnConfig};
76    pub use crate::handle::{LcgRng, SmVersion, VisionHandle};
77    pub use crate::imgproc::connected_components::{
78        ComponentLabels, Connectivity, connected_components,
79    };
80    pub use crate::imgproc::edges::{SobelOutput, canny, sobel_gradients};
81    pub use crate::imgproc::hough::{
82        HoughAccumulator, HoughConfig, HoughLine, hough_accumulate, hough_lines,
83    };
84    pub use crate::imgproc::morphology::{
85        StructuringElement, close, dilate, erode, morphological_gradient, open,
86    };
87    pub use crate::losses::dice::{dice_loss, dice_loss_default, dice_loss_squared};
88    pub use crate::losses::focal::{Reduction, binary_focal_loss, multiclass_focal_loss};
89    pub use crate::losses::quality::{ms_ssim, mse, psnr, ssim, ssim_default};
90    pub use crate::patch_embed::{
91        LearnablePosEmbed, PatchEmbed, PatchEmbedConfig, add_pos_embed, pos_2d_sincos, prepend_cls,
92    };
93    pub use crate::pointcloud::{PointAttention, PointTransformerConfig, PointTransformerLayer};
94    pub use crate::ptx_kernels::{
95        adaptive_avg_pool_ptx, bilinear_interp_ptx, contrastive_loss_ptx, focal_loss_ptx,
96        image_normalize_ptx, patch_embed_ptx, roi_align_ptx,
97    };
98    pub use crate::segmentation::{
99        MaskPrediction, Sam, SamConfig, TwoWayAttentionBlock, TwoWayTransformer,
100    };
101    pub use crate::ssl::{
102        BackboneOutput, CenteringBuffer, DinoBackbone, DinoHead, cross_entropy, dino_loss,
103        ibot_loss, student_softmax, teacher_softmax,
104    };
105    pub use crate::text::{ClipTextConfig, ClipTextEncoder};
106    pub use crate::vit::swin::{SwinBlock, SwinConfig, SwinWeights};
107    pub use crate::vit::vit_patch::{VitPatchConfig, VitPatchEmbed};
108    pub use crate::vit::{ViTConfig, ViTEncoder, ViTModel};
109}
110
111// ─── End-to-end integration tests ────────────────────────────────────────────
112
113#[cfg(test)]
114mod tests {
115    use crate::{
116        augment::{AugOp, Pipeline},
117        clip::contrastive::info_nce_loss,
118        clip::{ClipVisionConfig, ClipVisionEncoder, ProjectionHead},
119        detection::{DetrConfig, DetrDecoder, bipartite_match, roi_align},
120        error::VisionError,
121        fpn::{FeatureMap, Fpn, FpnConfig, LateralConv1x1},
122        handle::{LcgRng, SmVersion, VisionHandle},
123        patch_embed::{
124            LearnablePosEmbed, PatchEmbed, PatchEmbedConfig, add_pos_embed, pos_2d_sincos,
125            prepend_cls,
126        },
127        ptx_kernels::{
128            adaptive_avg_pool_ptx, bilinear_interp_ptx, contrastive_loss_ptx, focal_loss_ptx,
129            image_normalize_ptx, patch_embed_ptx, roi_align_ptx,
130        },
131        vit::{ViTConfig, ViTModel},
132    };
133
134    // ── PTX kernel end-to-end ─────────────────────────────────────────────────
135
136    #[test]
137    #[allow(clippy::type_complexity)]
138    fn e2e_ptx_kernels_all_sm_versions() {
139        const SM_VERSIONS: &[u32] = &[75, 80, 86, 90, 100, 120];
140        let kernel_generators: &[(&str, fn(u32) -> String)] = &[
141            ("patch_embed_ptx", patch_embed_ptx),
142            ("bilinear_interp_ptx", bilinear_interp_ptx),
143            ("contrastive_loss_ptx", contrastive_loss_ptx),
144            ("roi_align_ptx", roi_align_ptx),
145            ("image_normalize_ptx", image_normalize_ptx),
146            ("adaptive_avg_pool_ptx", adaptive_avg_pool_ptx),
147            ("focal_loss_ptx", focal_loss_ptx),
148        ];
149        for &(name, kernel_fn) in kernel_generators {
150            for &sm in SM_VERSIONS {
151                let ptx = kernel_fn(sm);
152                let expected_target = format!(".target sm_{sm}");
153                assert!(
154                    ptx.contains(&expected_target),
155                    "kernel {name} sm={sm}: missing '{expected_target}' in PTX"
156                );
157                assert!(
158                    ptx.contains(".version"),
159                    "kernel {name} sm={sm}: missing .version directive"
160                );
161            }
162        }
163    }
164
165    // ── Handle ────────────────────────────────────────────────────────────────
166
167    #[test]
168    fn e2e_handle_default() {
169        let h = VisionHandle::default_handle();
170        assert_eq!(h.device(), 0);
171        assert_eq!(h.sm_version(), SmVersion(80));
172    }
173
174    #[test]
175    fn e2e_lcg_rng_reproducibility() {
176        let mut r1 = LcgRng::new(42);
177        let mut r2 = LcgRng::new(42);
178        for _ in 0..200 {
179            assert_eq!(r1.next_u32(), r2.next_u32());
180        }
181    }
182
183    // ── Patch embedding ───────────────────────────────────────────────────────
184
185    #[test]
186    fn e2e_patch_embed_shape() {
187        // 3×32×32 image, patch_size=4 → (32/4)²=64 patches, embed_dim=16
188        let cfg = PatchEmbedConfig::new(32, 4, 3, 16).expect("valid config");
189        let mut rng = LcgRng::new(1);
190        let pe = PatchEmbed::new(cfg.clone(), &mut rng);
191        let image = vec![0.5f32; 3 * 32 * 32];
192        let tokens = pe.forward(&image).expect("forward ok");
193        assert_eq!(tokens.len(), cfg.n_patches() * cfg.embed_dim);
194        assert_eq!(cfg.n_patches(), 64);
195    }
196
197    #[test]
198    fn e2e_patch_embed_cls_prepend() {
199        let cfg = PatchEmbedConfig::new(16, 4, 3, 8).expect("valid config");
200        let mut rng = LcgRng::new(2);
201        let pe = PatchEmbed::new(cfg.clone(), &mut rng);
202        let image = vec![0.0f32; 3 * 16 * 16];
203        let tokens = pe.forward(&image).expect("forward ok");
204        let with_cls =
205            prepend_cls(&tokens, &pe.weights.cls_token, cfg.embed_dim).expect("prepend ok");
206        assert_eq!(with_cls.len(), (cfg.n_patches() + 1) * cfg.embed_dim);
207    }
208
209    #[test]
210    fn e2e_pos_embed_2d_sincos_periodicity() {
211        // With grid 4×1 and dim=4, the first sine band (k=0, freq=1) encodes row.
212        let pe = pos_2d_sincos(4, 1, 4).expect("ok");
213        // Position h=1 row index 0: sin(1*1) = sin(1)
214        let diff = (pe[4] - 1.0_f32.sin()).abs();
215        assert!(diff < 1e-5, "periodicity check failed: diff={diff}");
216    }
217
218    // ── ViT ───────────────────────────────────────────────────────────────────
219
220    #[test]
221    fn e2e_vit_block_forward_finite() {
222        use crate::vit::{ViTBlock, ViTBlockConfig};
223        let cfg = ViTBlockConfig::new(32, 4, 4).expect("valid");
224        let mut rng = LcgRng::new(3);
225        let block = ViTBlock::new(cfg, &mut rng);
226        let n_tokens = 8;
227        let mut tokens = vec![0.0f32; n_tokens * 32];
228        rng.fill_normal(&mut tokens);
229        let out = block.forward(&tokens, n_tokens).expect("forward ok");
230        assert!(
231            out.iter().all(|v| v.is_finite()),
232            "non-finite ViT block output"
233        );
234        assert_eq!(out.len(), n_tokens * 32);
235    }
236
237    #[test]
238    fn e2e_vit_model_classify_tiny() {
239        let cfg = ViTConfig::tiny();
240        let mut rng = LcgRng::new(4);
241        let model = ViTModel::new(cfg, &mut rng).expect("model ok");
242        let image = vec![0.5f32; 3 * 32 * 32];
243        let logits = model.forward(&image).expect("forward ok");
244        assert_eq!(logits.len(), 10, "expected 10 logits from tiny config");
245        assert!(logits.iter().all(|v| v.is_finite()), "non-finite logits");
246    }
247
248    // ── CLIP ──────────────────────────────────────────────────────────────────
249
250    #[test]
251    fn e2e_clip_vision_encoder_pool_shape() {
252        let vit_cfg = ViTConfig::tiny();
253        let embed_dim = vit_cfg.embed_dim;
254        let cfg = ClipVisionConfig::new(vit_cfg);
255        let mut rng = LcgRng::new(5);
256        let enc = ClipVisionEncoder::new(cfg, &mut rng).expect("encoder ok");
257        let image = vec![0.1f32; 3 * 32 * 32];
258        let emb = enc.forward_single(&image).expect("forward ok");
259        assert_eq!(emb.len(), embed_dim, "CLS pool output must be [embed_dim]");
260        assert!(emb.iter().all(|v| v.is_finite()), "non-finite embedding");
261    }
262
263    #[test]
264    fn e2e_clip_proj_l2_unit_norm() {
265        let embed_dim = 32;
266        let proj_dim = 16;
267        let mut rng = LcgRng::new(6);
268        let head = ProjectionHead::new(embed_dim, proj_dim, &mut rng).expect("ok");
269        let mut x = vec![0.0f32; embed_dim];
270        rng.fill_normal(&mut x);
271        let z = head.project(&x).expect("project ok");
272        let norm: f32 = z.iter().map(|&v| v * v).sum::<f32>().sqrt();
273        assert!(
274            (norm - 1.0).abs() < 1e-5,
275            "projected embedding not unit-norm; ‖z‖={norm}"
276        );
277    }
278
279    #[test]
280    fn e2e_clip_info_nce_symmetric() {
281        let embed_dim = 16;
282        let batch = 4;
283        let mut rng = LcgRng::new(7);
284        let mut img_e = vec![0.0f32; batch * embed_dim];
285        let mut txt_e = vec![0.0f32; batch * embed_dim];
286        rng.fill_normal(&mut img_e);
287        rng.fill_normal(&mut txt_e);
288
289        let (loss_it, _) = info_nce_loss(&img_e, &txt_e, embed_dim, 0.1).expect("ok");
290        let (loss_ti, _) = info_nce_loss(&txt_e, &img_e, embed_dim, 0.1).expect("ok");
291
292        assert!(loss_it.is_finite(), "image→text loss is not finite");
293        assert!(loss_ti.is_finite(), "text→image loss is not finite");
294        assert!(
295            (loss_it - loss_ti).abs() < 1e-4,
296            "symmetric loss mismatch: {loss_it} vs {loss_ti}"
297        );
298    }
299
300    // ── Augmentation ──────────────────────────────────────────────────────────
301
302    #[test]
303    fn e2e_augment_random_crop_dims() {
304        let img = vec![0.5f32; 3 * 64 * 64];
305        let mut rng = LcgRng::new(8);
306        let op = AugOp::RandomCrop { crop_size: 48 };
307        let (out, new_h, new_w) = op.apply(&img, 3, 64, 64, &mut rng).expect("ok");
308        assert_eq!((new_h, new_w), (48, 48));
309        assert_eq!(out.len(), 3 * 48 * 48);
310    }
311
312    #[test]
313    fn e2e_augment_normalize_imagenet() {
314        use crate::augment::normalize::{IMAGENET_MEAN, IMAGENET_STD, normalize_chw};
315        // Build image whose per-channel mean matches imagenet mean
316        let h = 8;
317        let w = 8;
318        let hw = h * w;
319        let mut img = vec![0.0f32; 3 * hw];
320        for c in 0..3 {
321            for p in 0..hw {
322                img[c * hw + p] = IMAGENET_MEAN[c];
323            }
324        }
325        let out = normalize_chw(&img, 3, h, w, &IMAGENET_MEAN, &IMAGENET_STD).expect("ok");
326        // After normalizing: all pixels ≈ 0 (mean removed)
327        let max_abs = out.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
328        assert!(
329            max_abs < 1e-5,
330            "normalized constant-mean image should be ~0; max={max_abs}"
331        );
332    }
333
334    #[test]
335    fn e2e_augment_pipeline_chain() {
336        let img = vec![0.5f32; 3 * 64 * 64];
337        let mut rng = LcgRng::new(9);
338        let pipeline = Pipeline::new()
339            .push(AugOp::Resize { target: 48 })
340            .push(AugOp::RandomCrop { crop_size: 32 })
341            .push(AugOp::HorizontalFlip { prob: 0.5 });
342        let (out, new_h, new_w) = pipeline.apply(&img, 3, 64, 64, &mut rng).expect("ok");
343        assert_eq!((new_h, new_w), (32, 32));
344        assert!(
345            out.iter().all(|v| v.is_finite()),
346            "pipeline output must be finite"
347        );
348    }
349
350    // ── FPN ───────────────────────────────────────────────────────────────────
351
352    #[test]
353    fn e2e_fpn_top_down_shape_consistency() {
354        let mut rng = LcgRng::new(10);
355        // 3 levels: [128, 64, 32] channels, sizes [4×4, 8×8, 16×16]
356        let in_channels = vec![128usize, 64, 32];
357        let out_channels = 16;
358        let cfg = FpnConfig::new(in_channels.clone(), out_channels).expect("config ok");
359        let fpn = Fpn::new(cfg, &mut rng).expect("fpn ok");
360
361        let features = vec![
362            FeatureMap::new(vec![0.1f32; 128 * 4 * 4], 128, 4, 4).expect("ok"),
363            FeatureMap::new(vec![0.1f32; 64 * 8 * 8], 64, 8, 8).expect("ok"),
364            FeatureMap::new(vec![0.1f32; 32 * 16 * 16], 32, 16, 16).expect("ok"),
365        ];
366        let pyramid = fpn.forward(features).expect("fpn forward ok");
367
368        assert_eq!(pyramid.len(), 3);
369        for fm in &pyramid {
370            assert_eq!(
371                fm.channels, out_channels,
372                "all FPN levels must have out_channels"
373            );
374        }
375        assert!(
376            pyramid
377                .iter()
378                .all(|fm| fm.data.iter().all(|v| v.is_finite()))
379        );
380    }
381
382    // ── Detection ─────────────────────────────────────────────────────────────
383
384    #[test]
385    fn e2e_roi_align_unit_box_identity() {
386        // Feature map: 1 channel, 4×4, all 1.0
387        let c = 1;
388        let h = 4;
389        let w = 4;
390        let feat = vec![1.0f32; c * h * w];
391        // RoI covering the entire feature map: [x1=0, y1=0, x2=4, y2=4]
392        let rois = vec![0.0f32, 0.0, 4.0, 4.0];
393        let out = roi_align(&feat, c, h, w, &rois, 1, 1, 1, 2).expect("ok");
394        assert_eq!(out.len(), 1);
395        // Bilinear samples of a constant-1 map → mean = 1
396        assert!(
397            (out[0] - 1.0).abs() < 1e-5,
398            "unit box over constant map should return 1.0; got {}",
399            out[0]
400        );
401    }
402
403    #[test]
404    fn e2e_detr_decoder_query_shape() {
405        let cfg = DetrConfig::tiny();
406        let mut rng = LcgRng::new(11);
407        let decoder = DetrDecoder::new(cfg.clone(), &mut rng).expect("ok");
408        let n_queries = cfg.n_queries;
409        let embed_dim = cfg.embed_dim;
410        let n_enc = 16;
411
412        let queries = vec![0.1f32; n_queries * embed_dim];
413        let enc_feats = vec![0.2f32; n_enc * embed_dim];
414        let out = decoder
415            .forward(&queries, &enc_feats, n_enc)
416            .expect("forward ok");
417
418        assert_eq!(
419            out.len(),
420            n_queries * embed_dim,
421            "decoder must preserve query shape"
422        );
423        assert!(
424            out.iter().all(|v| v.is_finite()),
425            "decoder output contains non-finite"
426        );
427    }
428
429    #[test]
430    fn e2e_set_match_self_assignment() {
431        // Cost matrix: diagonal = 0, off-diagonal = 1
432        // Greedy should find the diagonal matching
433        let n = 4;
434        let mut cost = vec![1.0f32; n * n];
435        for i in 0..n {
436            cost[i * n + i] = 0.0;
437        }
438        let matching = bipartite_match(&cost, n, n).expect("ok");
439        assert_eq!(matching.len(), n);
440        // Every diagonal pair should be matched
441        let mut assigned: Vec<(usize, usize)> = matching.clone();
442        assigned.sort_unstable();
443        for i in 0..n {
444            assert!(
445                assigned.contains(&(i, i)),
446                "identity cost matrix: query {i} should match target {i}"
447            );
448        }
449    }
450
451    #[test]
452    fn e2e_focal_loss_positive_only() {
453        // For gamma=0, alpha=1: focal_loss = -log(p), which is standard BCE
454        // The PTX kernel embeds alpha=0.25, gamma=2 — we just verify it generates valid PTX.
455        // Verify positive case: for p ≈ 1 the focal loss → 0.
456        // We use the CPU formula directly here as a sanity check.
457        let p: f32 = 0.99;
458        let alpha: f32 = 1.0;
459        let gamma: f32 = 0.0;
460        let fl = -alpha * (1.0 - p).powf(gamma) * p.ln();
461        let standard_bce = -p.ln();
462        assert!(
463            (fl - standard_bce).abs() < 1e-5,
464            "at gamma=0 focal loss == BCE; got fl={fl}, bce={standard_bce}"
465        );
466    }
467
468    // ── Learnable pos embed ───────────────────────────────────────────────────
469
470    #[test]
471    fn e2e_learnable_pos_embed_and_add() {
472        let n = 17; // 16 patches + CLS
473        let d = 32;
474        let mut rng = LcgRng::new(12);
475        let lpe = LearnablePosEmbed::new(n, d, &mut rng).expect("ok");
476        let mut tokens = vec![0.0f32; n * d];
477        add_pos_embed(&mut tokens, &lpe.table, d).expect("add ok");
478        // tokens should now equal the pos embedding
479        for (t, p) in tokens.iter().zip(lpe.table.iter()) {
480            assert!((t - p).abs() < 1e-6, "add_pos_embed mismatch");
481        }
482    }
483
484    // ── Lateral conv ─────────────────────────────────────────────────────────
485
486    #[test]
487    fn e2e_lateral_conv_output_shape() {
488        let mut rng = LcgRng::new(13);
489        let lat = LateralConv1x1::new(64, 16, &mut rng).expect("ok");
490        let feat = vec![0.5f32; 64 * 8 * 8];
491        let out = lat.forward(&feat, 8, 8).expect("ok");
492        assert_eq!(out.len(), 16 * 8 * 8);
493        assert!(out.iter().all(|v| v.is_finite()));
494    }
495
496    // ── Non-positive temperature rejects ─────────────────────────────────────
497
498    #[test]
499    fn e2e_clip_nce_nonpositive_temp_errors() {
500        let img = vec![1.0f32; 4 * 16];
501        let txt = vec![1.0f32; 4 * 16];
502        let r = info_nce_loss(&img, &txt, 16, 0.0);
503        assert!(matches!(r, Err(VisionError::NonPositiveTemperature(_))));
504    }
505}