use candle_core::{DType, Device, Error, Result, Tensor};
pub mod mobilenetv2;
pub mod sequential;
/// Converts a raw interleaved RGB byte buffer into a normalized `f32` tensor.
///
/// `raw` is interpreted as a `(96, 96, 3)` height/width/channel array. `Tensor::from_vec`
/// rejects buffers whose length does not match that shape. The data is permuted to
/// channel-first order, giving shape `(3, 96, 96)`, and scaled from `0..=255` to
/// `[0, 1]`. Each channel is then normalized as `(x - 0.95) / 0.2`.
///
/// Despite the `64` in the function name, the expected side length is 96.
fn load_image64_raw(raw: Vec<u8>, device: &Device) -> Result<Tensor> {
    const SIDE: usize = 96;
    const CHANNELS: usize = 3;
    const MEAN: [f32; CHANNELS] = [0.95; CHANNELS];
    const STD: [f32; CHANNELS] = [0.2; CHANNELS];

    // HWC (as produced by `to_rgb8().into_raw()`) -> CHW.
    let chw = Tensor::from_vec(raw, (SIDE, SIDE, CHANNELS), device)?.permute((2, 0, 1))?;

    // Per-channel statistics shaped (C, 1, 1) so they broadcast over H and W.
    let mean = Tensor::new(&MEAN, device)?.reshape((CHANNELS, 1, 1))?;
    let std = Tensor::new(&STD, device)?.reshape((CHANNELS, 1, 1))?;

    let scaled = (chw.to_dtype(DType::F32)? / 255.0)?;
    scaled.broadcast_sub(&mean)?.broadcast_div(&std)
}
pub fn load_image_from_buffer(buffer: &[u8], device: &Device) -> Result<Tensor> {
let img = image::load_from_memory(buffer)
.map_err(Error::wrap)?
.resize_to_fill(96, 96, image::imageops::FilterType::Triangle);
let img = img.to_rgb8();
load_image64_raw(img.into_raw(), device)
}
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
/// Reads a label file and returns the first tab-separated field of every line.
///
/// Each line is trimmed of surrounding whitespace before it is split. Every line
/// contributes exactly one entry, and blank lines yield an empty string. As a result,
/// the position of a label in the returned vector equals its zero-based line number.
///
/// # Errors
///
/// Returns any I/O error from opening the file or reading a line. Reading stops at the
/// first failing line.
// Within this file only the test module calls this, so non-test builds would otherwise
// emit a dead-code warning.
#[allow(dead_code)]
fn get_labels<P: AsRef<Path>>(path: P) -> std::io::Result<Vec<String>> {
    let reader = BufReader::new(File::open(path)?);
    reader
        .lines()
        .map(|line| {
            line.map(|text| {
                // `split` always yields at least one item, so the default is never used.
                text.trim().split('\t').next().unwrap_or_default().to_owned()
            })
        })
        .collect()
}
#[cfg(test)]
mod test {
    use std::path::Path;

    use anyhow::Context;
    use candle_nn::{ops::softmax, Module, VarBuilder};

    use super::*;

    /// End-to-end smoke test: loads the MobileNetV2 weights, classifies
    /// `testdata/zhi.png` and prints the five most probable labels.
    ///
    /// Expects `ochw_mobilenetv2_fp16.safetensors` and `testdata/label.txt` under
    /// `CARGO_MANIFEST_DIR`. Failures are reported through `anyhow` with context instead
    /// of panicking via `unwrap`.
    #[test]
    fn test_inference() -> anyhow::Result<()> {
        let device = &Device::Cpu;
        let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));

        let model_path = manifest_dir.join("ochw_mobilenetv2_fp16.safetensors");
        // SAFETY: the weights are memory-mapped, so the file must not be modified or
        // truncated while `vb` (and the model built from it) is alive. Test fixtures are
        // assumed to stay untouched for the duration of the test.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[model_path.as_path()], DType::F32, device)
        }
        .with_context(|| format!("failed to load weights from {}", model_path.display()))?;

        let nclasses = 4037;
        let model = mobilenetv2::Mobilenetv2::new(vb, nclasses)?;

        let image_data = include_bytes!("../../../testdata/zhi.png");
        // Add a batch dimension of 1 in front of the preprocessed image.
        let image = load_image_from_buffer(image_data, device)?.unsqueeze(0)?;
        let output = softmax(&model.forward(&image)?, 1)?;
        println!("{output}");

        let mut predictions: Vec<(usize, f32)> = output
            .flatten_all()?
            .to_vec1::<f32>()?
            .into_iter()
            .enumerate()
            .collect();
        // Descending by probability. `total_cmp` is a total order, so a NaN cannot panic
        // the sort the way `partial_cmp(..).unwrap()` would. A stable sort keeps ties in
        // class-index order.
        predictions.sort_by(|(_, a), (_, b)| b.total_cmp(a));

        let labels = get_labels(manifest_dir.join("testdata/label.txt"))?;
        for (rank, (class_idx, prob)) in predictions.iter().take(5).enumerate() {
            // Checked lookup: a short label file yields an error, not an index panic.
            let label = labels.get(*class_idx).with_context(|| {
                format!("no label for class {class_idx} ({} labels loaded)", labels.len())
            })?;
            println!("{}. Class {}: {:.2}%", rank + 1, label, prob * 100.0);
        }
        Ok(())
    }
}