use candle_core::{DType, Device, Result};
use candle_nn::{Module, VarBuilder};
use serde::{Deserialize, Serialize};
use crate::{
    models::{load_image_from_buffer, mobilenetv2::Mobilenetv2},
    utils::auto_crop_image_content,
};
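/// One entry in a top-k prediction: the class label, its softmax score, and the
/// raw class index.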
#[derive(Serialize, Deserialize)]
pub struct Topk {
    pub label: String,
    pub score: f32,
    pub class_idx: usize,
}
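/// CPU inference wrapper around a MobileNetV2 classifier loaded from
/// safetensors weights.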
pub struct Inference {
    model: Mobilenetv2,
}
impl Inference {
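    /// Builds the CPU model from raw safetensors bytes.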
    pub fn load_model(weights: &[u8]) -> Result<Self> {
        let dev = &Device::Cpu;
        // `from_buffered_safetensors` takes ownership of the bytes, hence the copy.
        let vb = VarBuilder::from_buffered_safetensors(weights.to_vec(), DType::F32, dev)?;
        // 4037 output classes; label.txt is expected to carry one label per class.
        let model = Mobilenetv2::new(vb, 4037)?;
        Ok(Self { model })
    }
    pub fn get_labels(&self) -> Result<Vec<String>> {
        let label_text = include_str!("../label.txt");
        // Keep only the first tab-separated field of each line; every line yields
        // an entry so the vector index stays aligned with the model's class index.
        let labels = label_text
            .lines()
            .filter_map(|line| line.trim().split('\t').next())
            .map(str::to_string)
            .collect();
        Ok(labels)
    }
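    /// Runs the classifier on raw image bytes and returns the `topk` best
    /// matches (default 5), sorted by descending probability.
    ///
    /// A minimal usage sketch (assumes the caller provides `weights` and
    /// `image_bytes`, e.g. read from disk):
    ///
    /// ```ignore
    /// let inference = Inference::load_model(&weights)?;
    /// for p in inference.predict(image_bytes, Some(3))? {
    ///     println!("{} ({:.1}%)", p.label, p.score * 100.0);
    /// }
    /// ```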
    pub fn predict(&self, image: Vec<u8>, topk: Option<usize>) -> anyhow::Result<Vec<Topk>> {
        // Auto-crop to the image content, decode into a tensor, and add a batch dimension.
        let image = auto_crop_image_content(&image)?;
        let image = load_image_from_buffer(&image, &Device::Cpu)?;
        let image = image.unsqueeze(0)?;
        // Turn the logits into probabilities over the class dimension.
        let output = self.model.forward(&image)?;
        let output = candle_nn::ops::softmax(&output, 1)?;
        // Pair each probability with its class index, then sort by descending score.
        let mut predictions = output
            .flatten_all()?
            .to_vec1::<f32>()?
            .into_iter()
            .enumerate()
            .collect::<Vec<_>>();
        // `total_cmp` gives a total order over floats, avoiding the panic that
        // `partial_cmp(..).unwrap()` would hit on NaN.
        predictions.sort_by(|(_, a), (_, b)| b.total_cmp(a));
        let labels = self.get_labels()?;
        let topk = topk.unwrap_or(5);
        // Indexing into `labels` assumes label.txt has one line per model class.
        let mut results = Vec::with_capacity(topk);
        for (i, (class_idx, score)) in predictions.into_iter().take(topk).enumerate() {
            println!("{}. Class {}: {:.2}%", i + 1, labels[class_idx], score * 100.0);
            results.push(Topk {
                label: labels[class_idx].clone(),
                score,
                class_idx,
            });
        }
        Ok(results)
}
}