use burn::tensor::{backend::Backend, Tensor, Int};
use crate::model::sensorlm::SensorLMModel;
/// Settings for building a [`ZeroShotClassifier`]: the candidate class labels and
/// the text prompt template each label is embedded with.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClassifierConfig {
    /// Candidate class labels. Their order defines the class indices reported by
    /// the classifier's predictions.
    pub class_names: Vec<String>,
    /// Prompt text in which every `{label}` placeholder is replaced by a class
    /// name (see [`ClassifierConfig::prompt_for`]).
    pub prompt_template: String,
}
impl ClassifierConfig {
    /// Renders the text prompt for a single class.
    ///
    /// Every occurrence of the `{label}` placeholder in `prompt_template` is
    /// replaced by `label`. A template without the placeholder is returned
    /// unchanged.
    pub fn prompt_for(&self, label: &str) -> String {
        const PLACEHOLDER: &str = "{label}";
        let template: &str = &self.prompt_template;
        template.replace(PLACEHOLDER, label)
    }
}
impl Default for ClassifierConfig {
    /// Five built-in labels (`walking`, `running`, `cycling`, `sleeping`,
    /// `sedentary`) rendered through the template `"The person is {label}."`.
    fn default() -> Self {
        const LABELS: [&str; 5] = ["walking", "running", "cycling", "sleeping", "sedentary"];
        let class_names: Vec<String> = LABELS.iter().map(|&label| label.to_owned()).collect();
        let prompt_template = String::from("The person is {label}.");
        Self {
            class_names,
            prompt_template,
        }
    }
}
/// Classifies sensor inputs by comparing their embeddings against text
/// embeddings of one prompt per class. The class embeddings are computed once,
/// at construction.
pub struct ZeroShotClassifier<B>
where
    B: Backend,
{
    /// Model providing both the text encoder (used at construction) and the
    /// sensor encoder (used at prediction time).
    model: SensorLMModel<B>,
    /// Per-class text embeddings, concatenated along dimension 0 in the same
    /// order as `class_names`.
    class_embeddings: Tensor<B, 2>,
    /// Class labels; entry `i` names row `i` of `class_embeddings`.
    class_names: Vec<String>,
}
impl<B: Backend> ZeroShotClassifier<B> {
    /// Builds a classifier by encoding one text prompt per class name.
    ///
    /// Each class name is rendered with [`ClassifierConfig::prompt_for`],
    /// tokenized with `tokenize`, and passed through `model.encode_text`. The
    /// resulting embeddings are concatenated along dimension 0 and cached, so
    /// text encoding happens once here rather than on every prediction.
    ///
    /// `tokenize` must return the `(ids, mask)` pair that
    /// `SensorLMModel::encode_text` expects. NOTE(review): the class indices
    /// reported by `predict` assume each `encode_text` call yields exactly one
    /// row per prompt (`[1, dim]`). Confirm this against the model.
    ///
    /// # Panics
    ///
    /// Panics if `cfg.class_names` is empty, because a classifier with no
    /// classes cannot produce a prediction.
    pub fn new<F>(model: SensorLMModel<B>, cfg: &ClassifierConfig, tokenize: F) -> Self
    where
        F: Fn(&str) -> (Tensor<B, 2, Int>, Tensor<B, 2, Int>),
    {
        // Fail early with a clear message instead of an opaque panic later in
        // `Tensor::cat` or during prediction.
        assert!(
            !cfg.class_names.is_empty(),
            "ZeroShotClassifier requires at least one class name"
        );
        let embeddings: Vec<Tensor<B, 2>> = cfg
            .class_names
            .iter()
            .map(|name| {
                let (ids, mask) = tokenize(&cfg.prompt_for(name));
                model.encode_text(ids, mask)
            })
            .collect();
        let class_embeddings = Tensor::cat(embeddings, 0);
        Self {
            model,
            class_embeddings,
            class_names: cfg.class_names.clone(),
        }
    }

    /// Returns the best-scoring class for every row of the score matrix.
    ///
    /// The result has one entry per sensor sample:
    /// `(class_index, class_name, score)`. The score is the dot product between
    /// the sensor embedding and the class embedding. Whether that equals cosine
    /// similarity depends on whether the encoders normalize their outputs, which
    /// is not visible here.
    ///
    /// NaN scores rank below every real score and never cause a panic. Ties go
    /// to the lowest class index, which matches the first entry of
    /// [`Self::predict_topk`].
    pub fn predict(&self, sensor: Tensor<B, 3>) -> Vec<(usize, String, f32)> {
        let (scores, num_classes) = self.class_scores(sensor);
        scores
            .chunks_exact(num_classes)
            .map(|row| {
                // Strictly-greater replacement keeps the first maximum on ties.
                let (best_idx, best_score) = row.iter().copied().enumerate().fold(
                    (0, row[0]),
                    |best, (idx, score)| {
                        if Self::score_order(score, best.1).is_gt() {
                            (idx, score)
                        } else {
                            best
                        }
                    },
                );
                (best_idx, self.class_names[best_idx].clone(), best_score)
            })
            .collect()
    }

    /// Returns up to `k` classes per sensor sample, ordered from highest to
    /// lowest score.
    ///
    /// Each inner vector holds `(class_index, class_name, score)` entries. If
    /// `k` exceeds the number of classes, every class is returned. `k == 0`
    /// yields empty inner vectors. NaN scores rank last. Tied scores keep
    /// ascending class-index order, so the first entry always agrees with
    /// [`Self::predict`].
    pub fn predict_topk(
        &self,
        sensor: Tensor<B, 3>,
        k: usize,
    ) -> Vec<Vec<(usize, String, f32)>> {
        let (scores, num_classes) = self.class_scores(sensor);
        scores
            .chunks_exact(num_classes)
            .map(|row| {
                let mut ranked: Vec<(usize, f32)> = row.iter().copied().enumerate().collect();
                // `sort_by` is stable, so equal scores stay in class-index order.
                ranked.sort_by(|(_, lhs), (_, rhs)| Self::score_order(*rhs, *lhs));
                ranked
                    .into_iter()
                    .take(k)
                    .map(|(idx, score)| (idx, self.class_names[idx].clone(), score))
                    .collect()
            })
            .collect()
    }

    /// Encodes `sensor` and scores it against every cached class embedding.
    ///
    /// Returns the row-major score matrix (one row per sample) flattened into a
    /// `Vec<f32>`, together with its row length, i.e. the number of classes.
    /// NOTE(review): `sensor` is presumably `[batch, time, channels]`. That
    /// layout is dictated by `SensorLMModel::encode_sensor` and is not shown here.
    fn class_scores(&self, sensor: Tensor<B, 3>) -> (Vec<f32>, usize) {
        let z_sensor = self.model.encode_sensor(sensor);
        let sim = z_sensor.matmul(self.class_embeddings.clone().transpose());
        let [_, num_classes] = sim.dims();
        let scores = sim
            .into_data()
            // The backend's float element need not be f32. Convert explicitly
            // instead of letting `to_vec::<f32>` fail and yield an empty buffer.
            .convert::<f32>()
            .to_vec::<f32>()
            .expect("similarity scores should be readable as f32 after conversion");
        (scores, num_classes)
    }

    /// Total ordering for scores that never panics: NaN compares below every
    /// real number (and equal to other NaNs). Real numbers compare numerically.
    fn score_order(a: f32, b: f32) -> core::cmp::Ordering {
        use core::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
        }
    }
}