sensorlm-rs 0.1.0

SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
Documentation
//! Zero-shot classification for wearable sensor data.
//!
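//! Zero-shot recognition works by encoding a set of candidate class-name
//! prompts with the text encoder and scoring each sensor embedding against
//! every prompt embedding.  The score is the dot product of the two
//! embeddings, which equals the cosine similarity when the encoders return
//! L2-normalised vectors.  The class with the highest score is predicted.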
//!
//! # Example
//!
//! ```rust,no_run
//! use sensorlm::inference::zero_shot::{ZeroShotClassifier, ClassifierConfig};
//!
//! let cfg = ClassifierConfig {
//!     class_names: vec!["walking".into(), "running".into(), "sleeping".into()],
//!     prompt_template: "The person is {label}.".into(),
//! };
//! // `tokenize` is a closure mapping a prompt to `(token_ids, attention_mask)`.
//! // let clf = ZeroShotClassifier::new(model, &cfg, tokenize);
//! // let predictions = clf.predict(sensor_batch);
//! ```

use burn::tensor::{backend::Backend, Tensor, Int};

use crate::model::sensorlm::SensorLMModel;

// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------

/// Configuration for zero-shot classification.
///
/// Holds the candidate labels and the template used to turn each label into
/// a text prompt (see [`ClassifierConfig::prompt_for`]).
#[derive(Debug, Clone)]
pub struct ClassifierConfig {
    /// Human-readable class labels, one entry per candidate class.
    pub class_names: Vec<String>,
    /// Prompt template.  The substring `{label}` is replaced with each class
    /// name before tokenisation.  Every occurrence is replaced; a template
    /// without the placeholder yields the same prompt for every class.
    pub prompt_template: String,
}

impl ClassifierConfig {
    /// Render the text prompt for a single class label.
    ///
    /// Every occurrence of the literal placeholder `{label}` in
    /// `prompt_template` is substituted with `label`.  The substitution is a
    /// single pass: placeholder text inside `label` itself is not expanded
    /// again, and a template without the placeholder is returned unchanged.
    pub fn prompt_for(&self, label: &str) -> String {
        const PLACEHOLDER: &str = "{label}";

        // Cutting the template at each placeholder and gluing the pieces back
        // together with the label substitutes all occurrences at once.
        let pieces: Vec<&str> = self.prompt_template.split(PLACEHOLDER).collect();
        pieces.join(label)
    }
}

impl Default for ClassifierConfig {
    /// Five everyday activity labels paired with an English sentence template.
    fn default() -> Self {
        let labels: [&str; 5] = ["walking", "running", "cycling", "sleeping", "sedentary"];

        let class_names: Vec<String> = labels.iter().copied().map(String::from).collect();
        let prompt_template = String::from("The person is {label}.");

        Self {
            class_names,
            prompt_template,
        }
    }
}

// ---------------------------------------------------------------------------
// Classifier
// ---------------------------------------------------------------------------

/// Zero-shot classifier backed by a SensorLM model.
///
/// The text prompts are encoded once at construction time; each prediction
/// call then only runs the sensor encoder and a matrix product against the
/// cached class embeddings.
pub struct ZeroShotClassifier<B: Backend> {
    /// Model used to encode sensor batches at prediction time.
    model: SensorLMModel<B>,
    /// Pre-computed text embeddings for all class prompts, shape `(K, D)`.
    /// Built by concatenating the per-class text embeddings along dim 0.
    class_embeddings: Tensor<B, 2>,
    /// Class names in the same order as `class_embeddings`.  Predicted class
    /// indices index into this vector.
    class_names: Vec<String>,
}

impl<B: Backend> ZeroShotClassifier<B> {
    /// Construct the classifier and pre-compute class embeddings.
    ///
    /// Each class name is expanded with [`ClassifierConfig::prompt_for`],
    /// tokenised, and encoded once with the text encoder.  The per-class
    /// embeddings are concatenated along dim 0 and cached.
    ///
    /// # Arguments
    ///
    /// * `model`     – A trained SensorLM model.
    /// * `cfg`       – Classifier configuration.
    /// * `tokenize`  – A closure that converts a prompt string into
    ///   `(token_ids, attention_mask)` tensors of shape `(1, L)`.
    ///
    /// # Panics
    ///
    /// Panics if `cfg.class_names` is empty, or if the encoded prompts do not
    /// produce exactly one embedding row per class name (for example when
    /// `tokenize` returns more than one row).
    pub fn new<F>(model: SensorLMModel<B>, cfg: &ClassifierConfig, tokenize: F) -> Self
    where
        F: Fn(&str) -> (Tensor<B, 2, Int>, Tensor<B, 2, Int>),
    {
        // Fail early with a clear message: with no classes there is nothing
        // to concatenate and no prediction could ever be produced.
        assert!(
            !cfg.class_names.is_empty(),
            "ZeroShotClassifier requires at least one class name"
        );

        let embeddings: Vec<Tensor<B, 2>> = cfg
            .class_names
            .iter()
            .map(|name| {
                let (ids, mask) = tokenize(&cfg.prompt_for(name));
                model.encode_text(ids, mask) // (1, D)
            })
            .collect();

        // Concatenate along dim 0 into (K, D).
        let class_embeddings = Tensor::cat(embeddings, 0);

        // Row `i` of the similarity matrix is later split into one score per
        // class name, so the two must line up exactly.
        assert_eq!(
            class_embeddings.dims()[0],
            cfg.class_names.len(),
            "expected exactly one text embedding row per class name"
        );

        Self {
            model,
            class_embeddings,
            class_names: cfg.class_names.clone(),
        }
    }

    /// Predict the class for each sensor sample in the batch.
    ///
    /// The score is the dot product between the sensor embedding and each
    /// cached class embedding.
    /// NOTE(review): this equals cosine similarity only if the encoders
    /// return L2-normalised embeddings — confirm against the model.
    ///
    /// # Arguments
    ///
    /// * `sensor` – `(B, T, C)` normalised sensor data.
    ///
    /// # Returns
    ///
    /// A vector of `B` `(class_index, class_name, similarity_score)` tuples.
    /// When several classes tie for the best score, the one with the highest
    /// index is returned.  NaN scores rank below every other score, so a NaN
    /// is only reported when all scores of a sample are NaN.
    ///
    /// # Panics
    ///
    /// Panics if the similarity matrix does not contain exactly `B * K`
    /// values.
    pub fn predict(&self, sensor: Tensor<B, 3>) -> Vec<(usize, String, f32)> {
        let scores = self.similarity_scores(sensor);

        scores
            .chunks_exact(self.class_names.len())
            .map(|row| {
                let (best_idx, best_score) = row
                    .iter()
                    .copied()
                    .enumerate()
                    .max_by(|&(_, a), &(_, b)| Self::cmp_scores(a, b))
                    .expect("constructor guarantees at least one class");
                (best_idx, self.class_names[best_idx].clone(), best_score)
            })
            .collect()
    }

    /// Predict and return the top-k predictions per sample.
    ///
    /// # Arguments
    ///
    /// * `sensor` – `(B, T, C)` normalised sensor data.
    /// * `k`      – Maximum number of predictions per sample.  Values larger
    ///   than the number of classes return every class; `0` returns empty
    ///   inner vectors.
    ///
    /// # Returns
    ///
    /// For each of the `B` samples, up to `k`
    /// `(class_index, class_name, similarity_score)` tuples sorted by
    /// descending score.  Equal scores keep ascending class-index order, and
    /// NaN scores are placed last.
    ///
    /// # Panics
    ///
    /// Panics if the similarity matrix does not contain exactly `B * K`
    /// values.
    pub fn predict_topk(&self, sensor: Tensor<B, 3>, k: usize) -> Vec<Vec<(usize, String, f32)>> {
        let scores = self.similarity_scores(sensor);

        scores
            .chunks_exact(self.class_names.len())
            .map(|row| {
                let mut ranked: Vec<(usize, f32)> = row.iter().copied().enumerate().collect();
                // Stable sort, arguments swapped for descending order.
                ranked.sort_by(|&(_, a), &(_, b)| Self::cmp_scores(b, a));
                ranked
                    .into_iter()
                    .take(k)
                    .map(|(idx, score)| (idx, self.class_names[idx].clone(), score))
                    .collect()
            })
            .collect()
    }

    /// Encode `sensor` and return the flattened row-major `(B, K)` similarity
    /// matrix as `f32` values.
    fn similarity_scores(&self, sensor: Tensor<B, 3>) -> Vec<f32> {
        let batch = sensor.dims()[0];
        let z_sensor = self.model.encode_sensor(sensor); // (B, D)

        // Similarity matrix: (B, K) = z_sensor @ class_embeddings.T
        let sim = z_sensor.matmul(self.class_embeddings.clone().transpose());

        // Convert explicitly so backends whose float element is not f32 still
        // yield usable scores instead of an empty vector.
        let scores = sim
            .into_data()
            .convert::<f32>()
            .to_vec::<f32>()
            .expect("similarity data was converted to f32");

        assert_eq!(
            scores.len(),
            batch * self.class_names.len(),
            "similarity matrix does not have one score per (sample, class) pair"
        );
        scores
    }

    /// Total order on scores: the usual numeric order, with NaN ranked below
    /// every other value (and equal to other NaNs) so comparisons never panic.
    fn cmp_scores(a: f32, b: f32) -> std::cmp::Ordering {
        a.partial_cmp(&b)
            .unwrap_or_else(|| b.is_nan().cmp(&a.is_nan()))
    }
}