1pub mod augment;
39pub mod blocks;
40pub mod clip;
41pub mod convnext;
42pub mod detection;
43pub mod error;
44pub mod fpn;
45pub mod handle;
46pub mod imgproc;
47pub mod losses;
48pub mod optimize;
49pub mod patch_embed;
50pub mod pointcloud;
51pub mod ptx_kernels;
52pub mod segmentation;
53pub mod ssl;
54pub mod text;
55pub mod vit;
56
57pub use error::{VisionError, VisionResult};
58pub use handle::{LcgRng, SmVersion, VisionHandle};
59
60pub mod prelude {
63 pub use crate::augment::{AugOp, MixOutput, Pipeline, cutmix, mixup};
64 pub use crate::clip::{
65 ClipVisionConfig, ClipVisionEncoder, ProjectionHead, contrastive::info_nce_loss,
66 };
67 pub use crate::convnext::block::{ConvNextBlock, ConvNextConfig};
68 pub use crate::detection::{
69 AnchorConfig, AnchorGenerator, BBox, DetrConfig, DetrDecoder, IouBox, IouLossKind,
70 MaskHead, MaskHeadConfig, OwlVit, OwlVitConfig, OwlVitOutput, RtmDet, RtmDetConfig,
71 RtmDetOutput, bipartite_match, ciou_loss, decode_level, diou_loss, giou_loss, iou,
72 iou_loss, iou_loss_pairs, nms, roi_align, simota_cost, soft_nms,
73 };
74 pub use crate::error::{VisionError, VisionResult};
75 pub use crate::fpn::{FeatureMap, Fpn, FpnConfig};
76 pub use crate::handle::{LcgRng, SmVersion, VisionHandle};
77 pub use crate::imgproc::connected_components::{
78 ComponentLabels, Connectivity, connected_components,
79 };
80 pub use crate::imgproc::edges::{SobelOutput, canny, sobel_gradients};
81 pub use crate::imgproc::hough::{
82 HoughAccumulator, HoughConfig, HoughLine, hough_accumulate, hough_lines,
83 };
84 pub use crate::imgproc::morphology::{
85 StructuringElement, close, dilate, erode, morphological_gradient, open,
86 };
87 pub use crate::losses::dice::{dice_loss, dice_loss_default, dice_loss_squared};
88 pub use crate::losses::focal::{Reduction, binary_focal_loss, multiclass_focal_loss};
89 pub use crate::losses::quality::{ms_ssim, mse, psnr, ssim, ssim_default};
90 pub use crate::patch_embed::{
91 LearnablePosEmbed, PatchEmbed, PatchEmbedConfig, add_pos_embed, pos_2d_sincos, prepend_cls,
92 };
93 pub use crate::pointcloud::{PointAttention, PointTransformerConfig, PointTransformerLayer};
94 pub use crate::ptx_kernels::{
95 adaptive_avg_pool_ptx, bilinear_interp_ptx, contrastive_loss_ptx, focal_loss_ptx,
96 image_normalize_ptx, patch_embed_ptx, roi_align_ptx,
97 };
98 pub use crate::segmentation::{
99 MaskPrediction, Sam, SamConfig, TwoWayAttentionBlock, TwoWayTransformer,
100 };
101 pub use crate::ssl::{
102 BackboneOutput, CenteringBuffer, DinoBackbone, DinoHead, cross_entropy, dino_loss,
103 ibot_loss, student_softmax, teacher_softmax,
104 };
105 pub use crate::text::{ClipTextConfig, ClipTextEncoder};
106 pub use crate::vit::swin::{SwinBlock, SwinConfig, SwinWeights};
107 pub use crate::vit::vit_patch::{VitPatchConfig, VitPatchEmbed};
108 pub use crate::vit::{ViTConfig, ViTEncoder, ViTModel};
109}
110
#[cfg(test)]
mod tests {
    //! End-to-end smoke tests that exercise the public surface of each submodule
    //! on small, CPU-only inputs.

    use crate::{
        augment::{AugOp, Pipeline},
        clip::contrastive::info_nce_loss,
        clip::{ClipVisionConfig, ClipVisionEncoder, ProjectionHead},
        detection::{DetrConfig, DetrDecoder, bipartite_match, roi_align},
        error::VisionError,
        fpn::{FeatureMap, Fpn, FpnConfig, LateralConv1x1},
        handle::{LcgRng, SmVersion, VisionHandle},
        patch_embed::{
            LearnablePosEmbed, PatchEmbed, PatchEmbedConfig, add_pos_embed, pos_2d_sincos,
            prepend_cls,
        },
        ptx_kernels::{
            adaptive_avg_pool_ptx, bilinear_interp_ptx, contrastive_loss_ptx, focal_loss_ptx,
            image_normalize_ptx, patch_embed_ptx, roi_align_ptx,
        },
        vit::{ViTConfig, ViTModel},
    };

    /// Every PTX generator must emit a `.version` directive and a `.target sm_XX`
    /// line matching the requested architecture, for each listed SM version.
    #[test]
    fn e2e_ptx_kernels_all_sm_versions() {
        // A named alias keeps the table type readable, so no
        // `clippy::type_complexity` suppression is needed.
        type PtxGenerator = fn(u32) -> String;

        const SM_VERSIONS: &[u32] = &[75, 80, 86, 90, 100, 120];
        let kernel_generators: &[(&str, PtxGenerator)] = &[
            ("patch_embed_ptx", patch_embed_ptx),
            ("bilinear_interp_ptx", bilinear_interp_ptx),
            ("contrastive_loss_ptx", contrastive_loss_ptx),
            ("roi_align_ptx", roi_align_ptx),
            ("image_normalize_ptx", image_normalize_ptx),
            ("adaptive_avg_pool_ptx", adaptive_avg_pool_ptx),
            ("focal_loss_ptx", focal_loss_ptx),
        ];
        for &(name, kernel_fn) in kernel_generators {
            for &sm in SM_VERSIONS {
                let ptx = kernel_fn(sm);
                let expected_target = format!(".target sm_{sm}");
                assert!(
                    ptx.contains(&expected_target),
                    "kernel {name} sm={sm}: missing '{expected_target}' in PTX"
                );
                assert!(
                    ptx.contains(".version"),
                    "kernel {name} sm={sm}: missing .version directive"
                );
            }
        }
    }

    /// The default handle targets device 0 with `SmVersion(80)`.
    #[test]
    fn e2e_handle_default() {
        let h = VisionHandle::default_handle();
        assert_eq!(h.device(), 0);
        assert_eq!(h.sm_version(), SmVersion(80));
    }

    /// Two generators built from the same seed produce identical streams.
    #[test]
    fn e2e_lcg_rng_reproducibility() {
        let mut r1 = LcgRng::new(42);
        let mut r2 = LcgRng::new(42);
        for _ in 0..200 {
            assert_eq!(r1.next_u32(), r2.next_u32());
        }
    }

    /// A 32x32 image split into 4x4 patches yields (32/4)^2 = 64 tokens of `embed_dim`.
    #[test]
    fn e2e_patch_embed_shape() {
        let cfg = PatchEmbedConfig::new(32, 4, 3, 16).expect("valid config");
        let mut rng = LcgRng::new(1);
        let pe = PatchEmbed::new(cfg.clone(), &mut rng);
        let image = vec![0.5f32; 3 * 32 * 32];
        let tokens = pe.forward(&image).expect("forward ok");
        assert_eq!(tokens.len(), cfg.n_patches() * cfg.embed_dim);
        assert_eq!(cfg.n_patches(), 64);
    }

    /// Prepending the CLS token adds exactly one `embed_dim`-wide row.
    #[test]
    fn e2e_patch_embed_cls_prepend() {
        let cfg = PatchEmbedConfig::new(16, 4, 3, 8).expect("valid config");
        let mut rng = LcgRng::new(2);
        let pe = PatchEmbed::new(cfg.clone(), &mut rng);
        let image = vec![0.0f32; 3 * 16 * 16];
        let tokens = pe.forward(&image).expect("forward ok");
        let with_cls =
            prepend_cls(&tokens, &pe.weights.cls_token, cfg.embed_dim).expect("prepend ok");
        assert_eq!(with_cls.len(), (cfg.n_patches() + 1) * cfg.embed_dim);
    }

    /// Spot-checks one entry of the 2-D sin/cos table against `sin(1)`.
    #[test]
    fn e2e_pos_embed_2d_sincos_periodicity() {
        let pe = pos_2d_sincos(4, 1, 4).expect("ok");
        // NOTE(review): presumably index 4 is the first (sin) channel of the
        // second position with unit base frequency, i.e. sin(1.0) — confirm
        // against the table layout produced by `pos_2d_sincos`.
        let diff = (pe[4] - 1.0_f32.sin()).abs();
        assert!(diff < 1e-5, "periodicity check failed: diff={diff}");
    }

    /// A single ViT block preserves the token shape and produces finite values.
    #[test]
    fn e2e_vit_block_forward_finite() {
        use crate::vit::{ViTBlock, ViTBlockConfig};
        let cfg = ViTBlockConfig::new(32, 4, 4).expect("valid");
        let mut rng = LcgRng::new(3);
        let block = ViTBlock::new(cfg, &mut rng);
        let n_tokens = 8;
        let mut tokens = vec![0.0f32; n_tokens * 32];
        rng.fill_normal(&mut tokens);
        let out = block.forward(&tokens, n_tokens).expect("forward ok");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "non-finite ViT block output"
        );
        assert_eq!(out.len(), n_tokens * 32);
    }

    /// The tiny ViT classifier maps a 32x32 RGB image to 10 finite logits.
    #[test]
    fn e2e_vit_model_classify_tiny() {
        let cfg = ViTConfig::tiny();
        let mut rng = LcgRng::new(4);
        let model = ViTModel::new(cfg, &mut rng).expect("model ok");
        let image = vec![0.5f32; 3 * 32 * 32];
        let logits = model.forward(&image).expect("forward ok");
        assert_eq!(logits.len(), 10, "expected 10 logits from tiny config");
        assert!(logits.iter().all(|v| v.is_finite()), "non-finite logits");
    }

    /// The CLIP vision encoder pools a single image to one `embed_dim` vector.
    #[test]
    fn e2e_clip_vision_encoder_pool_shape() {
        let vit_cfg = ViTConfig::tiny();
        let embed_dim = vit_cfg.embed_dim;
        let cfg = ClipVisionConfig::new(vit_cfg);
        let mut rng = LcgRng::new(5);
        let enc = ClipVisionEncoder::new(cfg, &mut rng).expect("encoder ok");
        let image = vec![0.1f32; 3 * 32 * 32];
        let emb = enc.forward_single(&image).expect("forward ok");
        assert_eq!(emb.len(), embed_dim, "CLS pool output must be [embed_dim]");
        assert!(emb.iter().all(|v| v.is_finite()), "non-finite embedding");
    }

    /// The projection head's output has unit L2 norm.
    #[test]
    fn e2e_clip_proj_l2_unit_norm() {
        let embed_dim = 32;
        let proj_dim = 16;
        let mut rng = LcgRng::new(6);
        let head = ProjectionHead::new(embed_dim, proj_dim, &mut rng).expect("ok");
        let mut x = vec![0.0f32; embed_dim];
        rng.fill_normal(&mut x);
        let z = head.project(&x).expect("project ok");
        let norm: f32 = z.iter().map(|&v| v * v).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "projected embedding not unit-norm; ‖z‖={norm}"
        );
    }

    /// Swapping the image and text arguments of `info_nce_loss` leaves the loss unchanged.
    #[test]
    fn e2e_clip_info_nce_symmetric() {
        let embed_dim = 16;
        let batch = 4;
        let mut rng = LcgRng::new(7);
        let mut img_e = vec![0.0f32; batch * embed_dim];
        let mut txt_e = vec![0.0f32; batch * embed_dim];
        rng.fill_normal(&mut img_e);
        rng.fill_normal(&mut txt_e);

        let (loss_it, _) = info_nce_loss(&img_e, &txt_e, embed_dim, 0.1).expect("ok");
        let (loss_ti, _) = info_nce_loss(&txt_e, &img_e, embed_dim, 0.1).expect("ok");

        assert!(loss_it.is_finite(), "image→text loss is not finite");
        assert!(loss_ti.is_finite(), "text→image loss is not finite");
        assert!(
            (loss_it - loss_ti).abs() < 1e-4,
            "symmetric loss mismatch: {loss_it} vs {loss_ti}"
        );
    }

    /// `RandomCrop` reports and produces a `crop_size` x `crop_size` output.
    #[test]
    fn e2e_augment_random_crop_dims() {
        let img = vec![0.5f32; 3 * 64 * 64];
        let mut rng = LcgRng::new(8);
        let op = AugOp::RandomCrop { crop_size: 48 };
        let (out, new_h, new_w) = op.apply(&img, 3, 64, 64, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (48, 48));
        assert_eq!(out.len(), 3 * 48 * 48);
    }

    /// An image whose every channel equals the ImageNet mean normalizes to ~0.
    #[test]
    fn e2e_augment_normalize_imagenet() {
        use crate::augment::normalize::{IMAGENET_MEAN, IMAGENET_STD, normalize_chw};
        let h = 8;
        let w = 8;
        let hw = h * w;
        // CHW layout: one contiguous `hw`-long plane per channel.
        let mut img = vec![0.0f32; 3 * hw];
        for (plane, &mean) in img.chunks_exact_mut(hw).zip(IMAGENET_MEAN.iter()) {
            plane.fill(mean);
        }
        let out = normalize_chw(&img, 3, h, w, &IMAGENET_MEAN, &IMAGENET_STD).expect("ok");
        let max_abs = out.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
        assert!(
            max_abs < 1e-5,
            "normalized constant-mean image should be ~0; max={max_abs}"
        );
    }

    /// Resize -> crop -> flip ends at the crop size with finite pixels.
    #[test]
    fn e2e_augment_pipeline_chain() {
        let img = vec![0.5f32; 3 * 64 * 64];
        let mut rng = LcgRng::new(9);
        let pipeline = Pipeline::new()
            .push(AugOp::Resize { target: 48 })
            .push(AugOp::RandomCrop { crop_size: 32 })
            .push(AugOp::HorizontalFlip { prob: 0.5 });
        let (out, new_h, new_w) = pipeline.apply(&img, 3, 64, 64, &mut rng).expect("ok");
        assert_eq!((new_h, new_w), (32, 32));
        assert!(
            out.iter().all(|v| v.is_finite()),
            "pipeline output must be finite"
        );
    }

    /// Every FPN output level carries `out_channels` channels and finite data.
    #[test]
    fn e2e_fpn_top_down_shape_consistency() {
        let mut rng = LcgRng::new(10);
        let in_channels = vec![128usize, 64, 32];
        let out_channels = 16;
        // `in_channels` is not needed afterwards, so it is moved instead of cloned.
        let cfg = FpnConfig::new(in_channels, out_channels).expect("config ok");
        let fpn = Fpn::new(cfg, &mut rng).expect("fpn ok");

        let features = vec![
            FeatureMap::new(vec![0.1f32; 128 * 4 * 4], 128, 4, 4).expect("ok"),
            FeatureMap::new(vec![0.1f32; 64 * 8 * 8], 64, 8, 8).expect("ok"),
            FeatureMap::new(vec![0.1f32; 32 * 16 * 16], 32, 16, 16).expect("ok"),
        ];
        let pyramid = fpn.forward(features).expect("fpn forward ok");

        assert_eq!(pyramid.len(), 3);
        for fm in &pyramid {
            assert_eq!(
                fm.channels, out_channels,
                "all FPN levels must have out_channels"
            );
        }
        assert!(
            pyramid
                .iter()
                .all(|fm| fm.data.iter().all(|v| v.is_finite())),
            "FPN output contains non-finite values"
        );
    }

    /// RoIAlign over a constant feature map returns that constant.
    #[test]
    fn e2e_roi_align_unit_box_identity() {
        let c = 1;
        let h = 4;
        let w = 4;
        let feat = vec![1.0f32; c * h * w];
        // A single RoI spanning the whole 4x4 map.
        let rois = vec![0.0f32, 0.0, 4.0, 4.0];
        let out = roi_align(&feat, c, h, w, &rois, 1, 1, 1, 2).expect("ok");
        assert_eq!(out.len(), 1);
        assert!(
            (out[0] - 1.0).abs() < 1e-5,
            "unit box over constant map should return 1.0; got {}",
            out[0]
        );
    }

    /// The DETR decoder maps `n_queries` queries to the same `[n_queries, embed_dim]` shape.
    #[test]
    fn e2e_detr_decoder_query_shape() {
        let cfg = DetrConfig::tiny();
        let mut rng = LcgRng::new(11);
        let decoder = DetrDecoder::new(cfg.clone(), &mut rng).expect("ok");
        let n_queries = cfg.n_queries;
        let embed_dim = cfg.embed_dim;
        let n_enc = 16;

        let queries = vec![0.1f32; n_queries * embed_dim];
        let enc_feats = vec![0.2f32; n_enc * embed_dim];
        let out = decoder
            .forward(&queries, &enc_feats, n_enc)
            .expect("forward ok");

        assert_eq!(
            out.len(),
            n_queries * embed_dim,
            "decoder must preserve query shape"
        );
        assert!(
            out.iter().all(|v| v.is_finite()),
            "decoder output contains non-finite"
        );
    }

    /// With zero cost on the diagonal and 1.0 elsewhere, the optimal matching is the identity.
    #[test]
    fn e2e_set_match_self_assignment() {
        let n = 4;
        let mut cost = vec![1.0f32; n * n];
        for i in 0..n {
            cost[i * n + i] = 0.0;
        }
        let matching = bipartite_match(&cost, n, n).expect("ok");
        assert_eq!(matching.len(), n);
        // Membership checks need neither a copy nor a sorted order.
        for i in 0..n {
            assert!(
                matching.contains(&(i, i)),
                "identity cost matrix: query {i} should match target {i}"
            );
        }
    }

    /// At `gamma = 0` and `alpha = 1` the focal-loss formula reduces to binary
    /// cross-entropy.
    ///
    /// NOTE(review): this evaluates the closed-form expression inline and never
    /// calls the crate's focal-loss implementation, so it cannot catch a
    /// regression there — consider routing it through the real API.
    #[test]
    fn e2e_focal_loss_positive_only() {
        let p: f32 = 0.99;
        let alpha: f32 = 1.0;
        let gamma: f32 = 0.0;
        let fl = -alpha * (1.0 - p).powf(gamma) * p.ln();
        let standard_bce = -p.ln();
        assert!(
            (fl - standard_bce).abs() < 1e-5,
            "at gamma=0 focal loss == BCE; got fl={fl}, bce={standard_bce}"
        );
    }

    /// Adding a learnable table to all-zero tokens yields exactly the table.
    #[test]
    fn e2e_learnable_pos_embed_and_add() {
        let n = 17;
        let d = 32;
        let mut rng = LcgRng::new(12);
        let lpe = LearnablePosEmbed::new(n, d, &mut rng).expect("ok");
        let mut tokens = vec![0.0f32; n * d];
        add_pos_embed(&mut tokens, &lpe.table, d).expect("add ok");
        // `zip` stops at the shorter input; pin the lengths first so a short
        // table cannot make the element-wise check pass vacuously.
        assert_eq!(
            lpe.table.len(),
            tokens.len(),
            "pos-embed table must cover every token"
        );
        for (t, p) in tokens.iter().zip(lpe.table.iter()) {
            assert!((t - p).abs() < 1e-6, "add_pos_embed mismatch");
        }
    }

    /// A 1x1 lateral conv remaps 64 input channels to 16 while keeping H x W.
    #[test]
    fn e2e_lateral_conv_output_shape() {
        let mut rng = LcgRng::new(13);
        let lat = LateralConv1x1::new(64, 16, &mut rng).expect("ok");
        let feat = vec![0.5f32; 64 * 8 * 8];
        let out = lat.forward(&feat, 8, 8).expect("ok");
        assert_eq!(out.len(), 16 * 8 * 8);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    /// Zero and negative temperatures are both rejected with `NonPositiveTemperature`.
    #[test]
    fn e2e_clip_nce_nonpositive_temp_errors() {
        let img = vec![1.0f32; 4 * 16];
        let txt = vec![1.0f32; 4 * 16];
        for temperature in [0.0, -0.5] {
            let r = info_nce_loss(&img, &txt, 16, temperature);
            assert!(
                matches!(r, Err(VisionError::NonPositiveTemperature(_))),
                "temperature {temperature} must be rejected"
            );
        }
    }
}