1use candle_core::Device;
5use candle_core::{Module, Result, Tensor};
6
7use pineapple_core::im::PineappleImage;
8
9use crate::models::DinoVisionTransformer;
10use crate::models::StandardVisionTransformer;
11
12use crate::load::{
13 load_dinobloom_vit_base, load_dinov2_vit_base, load_dinov2_vit_small, load_scdino_vit_small,
14 load_subcell_vit_base,
15};
16
17use crate::preprocess::{preprocess_imagenet, preprocess_subcell};
18
19pub enum Models {
20 DinoVitSmall(DinoVisionTransformer),
21 DinoVitBase(DinoVisionTransformer),
22 DinobloomVitBase(DinoVisionTransformer),
23 ScdinoVitSmall(StandardVisionTransformer),
24 SubcellVitSmall(StandardVisionTransformer),
25}
26
27impl Models {
28 pub fn load(model_name: &str, device: &Device, verbose: bool) -> Self {
29 match model_name {
30 "dino_vit_small" => {
31 let model = load_dinov2_vit_small(device, verbose).unwrap();
32 Models::DinoVitSmall(model)
33 }
34 "dino_vit_base" => {
35 let model = load_dinov2_vit_base(device, verbose).unwrap();
36 Models::DinoVitBase(model)
37 }
38 "dinobloom_vit_base" => {
39 let model = load_dinobloom_vit_base(device, verbose).unwrap();
40 Models::DinobloomVitBase(model)
41 }
42 "scdino_vit_small" => {
43 let model = load_scdino_vit_small(device, verbose).unwrap();
44 Models::ScdinoVitSmall(model)
45 }
46 "subcell_vit_base" => {
47 let model = load_subcell_vit_base(device, verbose).unwrap();
48 Models::SubcellVitSmall(model)
49 }
50 _ => {
51 eprintln!("[pineapple::nn::models] Model name not found.");
52 std::process::exit(1);
53 }
54 }
55 }
56
57 pub fn preprocess(&self, image: &PineappleImage, device: &Device) -> Result<Tensor> {
58 match self {
59 Models::DinoVitSmall(_) => preprocess_imagenet(image, device),
60 Models::DinoVitBase(_) => preprocess_imagenet(image, device),
61 Models::DinobloomVitBase(_) => preprocess_imagenet(image, device),
62 Models::ScdinoVitSmall(_) => preprocess_imagenet(image, device),
63 Models::SubcellVitSmall(_) => preprocess_subcell(image, device),
64 }
65 }
66
67 pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
68 let input = input.unsqueeze(0).unwrap();
69 match self {
70 Models::DinoVitSmall(model) => model.forward(&input),
71 Models::DinoVitBase(model) => model.forward(&input),
72 Models::DinobloomVitBase(model) => model.forward(&input),
73 Models::ScdinoVitSmall(model) => model.forward(&input),
74 Models::SubcellVitSmall(model) => model.forward(&input),
75 }
76 }
77}
78
#[cfg(test)]
mod test {
    use super::*;
    use pineapple_core::im::PineappleImage;

    /// Which fixture image a model test feeds through the pipeline.
    #[derive(Clone, Copy)]
    enum Color {
        Rgb,
        Grayscale,
    }

    impl Color {
        /// Path of the fixture image for this color mode.
        ///
        /// Anchored at `CARGO_MANIFEST_DIR` so the tests do not depend on the working
        /// directory of the test runner (cargo uses the package root, which is the same
        /// directory, so behaviour under `cargo test` is unchanged).
        fn fixture_path(self) -> &'static str {
            match self {
                Color::Rgb => concat!(env!("CARGO_MANIFEST_DIR"), "/../data/tests/test_rgb.png"),
                Color::Grayscale => concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../data/tests/test_grayscale.png"
                ),
            }
        }
    }

    /// Loads model `name` on the CPU, runs one fixture image through it, and checks that the
    /// output is a single row of `n_embed` columns.
    fn test_model(name: &str, color: Color, n_embed: usize) {
        let image = PineappleImage::open(color.fixture_path()).unwrap();

        let model = Models::load(name, &Device::Cpu, true);
        let input = model.preprocess(&image, &Device::Cpu).unwrap();
        let logits = model.forward(&input).unwrap();

        let (n_rows, n_columns) = logits.shape().dims2().unwrap();

        assert_eq!(n_rows, 1, "{name}: expected a single output row");
        assert_eq!(n_columns, n_embed, "{name}: unexpected embedding width");
    }

    #[test]
    fn test_dinov2_small_rgb() {
        test_model("dino_vit_small", Color::Rgb, 384);
    }

    #[test]
    fn test_dinov2_small_grayscale() {
        test_model("dino_vit_small", Color::Grayscale, 384);
    }

    #[test]
    fn test_dinov2_base_rgb() {
        test_model("dino_vit_base", Color::Rgb, 768);
    }

    #[test]
    fn test_dinov2_base_grayscale() {
        test_model("dino_vit_base", Color::Grayscale, 768);
    }

    #[test]
    fn test_dinobloom_rgb() {
        test_model("dinobloom_vit_base", Color::Rgb, 768);
    }

    #[test]
    fn test_dinobloom_grayscale() {
        test_model("dinobloom_vit_base", Color::Grayscale, 768);
    }

    #[test]
    fn test_subcell_rgb() {
        test_model("subcell_vit_base", Color::Rgb, 768);
    }

    #[test]
    fn test_subcell_grayscale() {
        test_model("subcell_vit_base", Color::Grayscale, 768);
    }

    #[test]
    fn test_scdino_rgb() {
        test_model("scdino_vit_small", Color::Rgb, 384);
    }

    #[test]
    fn test_scdino_grayscale() {
        test_model("scdino_vit_small", Color::Grayscale, 384);
    }
}