Skip to main content

pineapple_neural/
nn.rs

// Copyright (c) 2025, Tom Ouellette
// Licensed under the BSD 3-Clause License
3
4use candle_core::Device;
5use candle_core::{Module, Result, Tensor};
6
7use pineapple_core::im::PineappleImage;
8
9use crate::models::DinoVisionTransformer;
10use crate::models::StandardVisionTransformer;
11
12use crate::load::{
13    load_dinobloom_vit_base, load_dinov2_vit_base, load_dinov2_vit_small, load_scdino_vit_small,
14    load_subcell_vit_base,
15};
16
17use crate::preprocess::{preprocess_imagenet, preprocess_subcell};
18
19pub enum Models {
20    DinoVitSmall(DinoVisionTransformer),
21    DinoVitBase(DinoVisionTransformer),
22    DinobloomVitBase(DinoVisionTransformer),
23    ScdinoVitSmall(StandardVisionTransformer),
24    SubcellVitSmall(StandardVisionTransformer),
25}
26
27impl Models {
28    pub fn load(model_name: &str, device: &Device, verbose: bool) -> Self {
29        match model_name {
30            "dino_vit_small" => {
31                let model = load_dinov2_vit_small(device, verbose).unwrap();
32                Models::DinoVitSmall(model)
33            }
34            "dino_vit_base" => {
35                let model = load_dinov2_vit_base(device, verbose).unwrap();
36                Models::DinoVitBase(model)
37            }
38            "dinobloom_vit_base" => {
39                let model = load_dinobloom_vit_base(device, verbose).unwrap();
40                Models::DinobloomVitBase(model)
41            }
42            "scdino_vit_small" => {
43                let model = load_scdino_vit_small(device, verbose).unwrap();
44                Models::ScdinoVitSmall(model)
45            }
46            "subcell_vit_base" => {
47                let model = load_subcell_vit_base(device, verbose).unwrap();
48                Models::SubcellVitSmall(model)
49            }
50            _ => {
51                eprintln!("[pineapple::nn::models] Model name not found.");
52                std::process::exit(1);
53            }
54        }
55    }
56
57    pub fn preprocess(&self, image: &PineappleImage, device: &Device) -> Result<Tensor> {
58        match self {
59            Models::DinoVitSmall(_) => preprocess_imagenet(image, device),
60            Models::DinoVitBase(_) => preprocess_imagenet(image, device),
61            Models::DinobloomVitBase(_) => preprocess_imagenet(image, device),
62            Models::ScdinoVitSmall(_) => preprocess_imagenet(image, device),
63            Models::SubcellVitSmall(_) => preprocess_subcell(image, device),
64        }
65    }
66
67    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
68        let input = input.unsqueeze(0).unwrap();
69        match self {
70            Models::DinoVitSmall(model) => model.forward(&input),
71            Models::DinoVitBase(model) => model.forward(&input),
72            Models::DinobloomVitBase(model) => model.forward(&input),
73            Models::ScdinoVitSmall(model) => model.forward(&input),
74            Models::SubcellVitSmall(model) => model.forward(&input),
75        }
76    }
77}
78
#[cfg(test)]
mod test {
    use super::*;
    use pineapple_core::im::PineappleImage;

    /// Opens the RGB fixture.
    ///
    /// The path is anchored on the crate manifest directory so the test does not depend
    /// on the process working directory. This matches the old relative path under
    /// `cargo test`.
    fn load_rgb() -> PineappleImage {
        PineappleImage::open(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../data/tests/test_rgb.png"
        ))
        .unwrap()
    }

    /// Opens the grayscale fixture, with the path anchored the same way as [`load_rgb`].
    fn load_grayscale() -> PineappleImage {
        PineappleImage::open(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../data/tests/test_grayscale.png"
        ))
        .unwrap()
    }

    /// Loads model `name` on the CPU and embeds the fixture selected by `color`.
    ///
    /// `color` must be `"rgb"` or `"grayscale"`. Any other value panics, so a typo cannot
    /// silently test the wrong image. The test then asserts that the output is a
    /// `(1, n_embed)` tensor.
    fn test_model(name: &str, color: &str, n_embed: usize) {
        let image = match color {
            "rgb" => load_rgb(),
            "grayscale" => load_grayscale(),
            other => panic!("unsupported test image color: {other}"),
        };

        let device = Device::Cpu;
        let model = Models::load(name, &device, true);
        let input = model.preprocess(&image, &device).unwrap();
        let embedding = model.forward(&input).unwrap();

        let (n_rows, n_columns) = embedding.shape().dims2().unwrap();

        assert_eq!(n_rows, 1);
        assert_eq!(n_columns, n_embed);
    }

    #[test]
    fn test_dinov2_small_rgb() {
        test_model("dino_vit_small", "rgb", 384);
    }

    #[test]
    fn test_dinov2_small_grayscale() {
        test_model("dino_vit_small", "grayscale", 384);
    }

    #[test]
    fn test_dinov2_base_rgb() {
        test_model("dino_vit_base", "rgb", 768);
    }

    #[test]
    fn test_dinov2_base_grayscale() {
        test_model("dino_vit_base", "grayscale", 768);
    }

    #[test]
    fn test_dinobloom_rgb() {
        test_model("dinobloom_vit_base", "rgb", 768);
    }

    #[test]
    fn test_dinobloom_grayscale() {
        test_model("dinobloom_vit_base", "grayscale", 768);
    }

    #[test]
    fn test_subcell_rgb() {
        test_model("subcell_vit_base", "rgb", 768);
    }

    #[test]
    fn test_subcell_grayscale() {
        test_model("subcell_vit_base", "grayscale", 768);
    }

    #[test]
    fn test_scdino_rgb() {
        test_model("scdino_vit_small", "rgb", 384);
    }

    #[test]
    fn test_scdino_grayscale() {
        test_model("scdino_vit_small", "grayscale", 384);
    }
}