sensorlm-rs 0.1.0

SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
Documentation
//! Cross-modal retrieval: sensor → text and text → sensor.
//!
//! Given a database of pre-computed embeddings, retrieval finds the top-k
//! most similar items from the other modality.
//!
//! # Use cases
//!
//! * **Sensor → Text**: "Given this 24-hour recording, find the most similar
//!   textual descriptions."
//! * **Text → Sensor**: "Given this query text, find the most similar sensor
//!   recordings."
//!
//! Both directions use the same L2-normalised embedding space, so cosine
//! similarity equals dot product.

use burn::tensor::{backend::Backend, Tensor, Int};

use crate::model::sensorlm::SensorLMModel;
use crate::loss::recall_at_k;

// ---------------------------------------------------------------------------
// Result type
// ---------------------------------------------------------------------------

/// One hit returned by a retrieval query.
#[derive(Clone, Debug)]
pub struct RetrievalResult {
    /// Row of the searched database (the sensor or text index) this hit refers to.
    pub index: usize,
    /// Dot-product similarity between the query embedding and this row; this is
    /// cosine similarity in `[-1, 1]` when the embeddings are L2-normalised.
    pub score: f32,
}

// ---------------------------------------------------------------------------
// Retrieval engine
// ---------------------------------------------------------------------------

/// Cross-modal retrieval engine built on a [`SensorLMModel`].
///
/// Holds the model together with an optional indexed embedding matrix per
/// modality. Queries from one modality are scored against the other
/// modality's index by dot product.
pub struct RetrievalEngine<B: Backend> {
    /// Model used to embed both indexed items and incoming queries.
    model: SensorLMModel<B>,
    /// Indexed sensor embeddings `(N, D)`; `None` until sensor data is indexed.
    sensor_embeddings: Option<Tensor<B, 2>>,
    /// Indexed text embeddings `(M, D)`; `None` until text is indexed.
    text_embeddings: Option<Tensor<B, 2>>,
}

impl<B: Backend> RetrievalEngine<B> {
    /// Create a retrieval engine around `model` with empty sensor and text indices.
    pub fn new(model: SensorLMModel<B>) -> Self {
        Self {
            model,
            sensor_embeddings: None,
            text_embeddings: None,
        }
    }

    /// Encode sensor batches and store the result as the sensor index.
    ///
    /// Any previously indexed sensor embeddings are replaced. Per-batch
    /// embeddings are concatenated along dimension 0 in iteration order, so a
    /// recording's row in the index (and its [`RetrievalResult::index`]) is its
    /// position across all batches. If no batches are supplied the sensor
    /// index is cleared instead of panicking inside `Tensor::cat`.
    ///
    /// # Arguments
    ///
    /// * `sensor_batches` – Iterator of `(B, T, C)` sensor tensors.
    pub fn index_sensor(&mut self, sensor_batches: impl IntoIterator<Item = Tensor<B, 3>>) {
        let embeddings: Vec<Tensor<B, 2>> = sensor_batches
            .into_iter()
            .map(|batch| self.model.encode_sensor(batch))
            .collect();
        self.sensor_embeddings = Self::concat_embeddings(embeddings);
    }

    /// Encode batches of token sequences and store the result as the text index.
    ///
    /// Any previously indexed text embeddings are replaced. Batches are
    /// concatenated along dimension 0 in iteration order. If no batches are
    /// supplied the text index is cleared instead of panicking inside
    /// `Tensor::cat`.
    ///
    /// # Arguments
    ///
    /// * `text_batches` – Iterator of `(ids, mask)` pairs, each a 2-D integer
    ///   tensor (presumably `(B, L)`, matching [`Self::evaluate_recall`]).
    pub fn index_text(
        &mut self,
        text_batches: impl IntoIterator<Item = (Tensor<B, 2, Int>, Tensor<B, 2, Int>)>,
    ) {
        let embeddings: Vec<Tensor<B, 2>> = text_batches
            .into_iter()
            .map(|(ids, mask)| self.model.encode_text(ids, mask))
            .collect();
        self.text_embeddings = Self::concat_embeddings(embeddings);
    }

    /// Concatenate per-batch 2-D embeddings along dimension 0, or return `None`
    /// when there are no batches (`Tensor::cat` rejects an empty list).
    fn concat_embeddings(embeddings: Vec<Tensor<B, 2>>) -> Option<Tensor<B, 2>> {
        if embeddings.is_empty() {
            None
        } else {
            Some(Tensor::cat(embeddings, 0))
        }
    }

    /// Retrieve the top-k most similar texts for each sensor query.
    ///
    /// # Arguments
    ///
    /// * `query_sensor` – `(Q, T, C)` query sensor tensors.
    /// * `top_k`        – Maximum number of results to return per query.
    ///
    /// # Returns
    ///
    /// A vector of `Q` lists, each containing `min(top_k, M)`
    /// [`RetrievalResult`]s (where `M` is the number of indexed texts) sorted
    /// by descending similarity.
    ///
    /// # Panics
    ///
    /// Panics if no text embeddings have been indexed.
    pub fn sensor_to_text(
        &self,
        query_sensor: Tensor<B, 3>,
        top_k: usize,
    ) -> Vec<Vec<RetrievalResult>> {
        let db = self
            .text_embeddings
            .as_ref()
            .expect("No text embeddings indexed");

        let q_emb = self.model.encode_sensor(query_sensor); // (Q, D)
        // Burn tensors are reference-counted handles, so this clone is cheap.
        self.top_k_search(q_emb, db.clone(), top_k)
    }

    /// Retrieve the top-k most similar sensor recordings for each text query.
    ///
    /// # Arguments
    ///
    /// * `query_ids`  – Token IDs of the `Q` query texts.
    /// * `query_mask` – Attention mask matching `query_ids`.
    /// * `top_k`      – Maximum number of results to return per query.
    ///
    /// # Returns
    ///
    /// A vector of `Q` lists, each containing `min(top_k, N)`
    /// [`RetrievalResult`]s (where `N` is the number of indexed recordings)
    /// sorted by descending similarity.
    ///
    /// # Panics
    ///
    /// Panics if no sensor embeddings have been indexed.
    pub fn text_to_sensor(
        &self,
        query_ids: Tensor<B, 2, Int>,
        query_mask: Tensor<B, 2, Int>,
        top_k: usize,
    ) -> Vec<Vec<RetrievalResult>> {
        let db = self
            .sensor_embeddings
            .as_ref()
            .expect("No sensor embeddings indexed");

        let q_emb = self.model.encode_text(query_ids, query_mask); // (Q, D)
        // Burn tensors are reference-counted handles, so this clone is cheap.
        self.top_k_search(q_emb, db.clone(), top_k)
    }

    /// Core exhaustive top-k nearest-neighbour search by dot-product similarity.
    ///
    /// `queries` is `(Q, D)`, `database` is `(N, D)`. Returns `Q` lists, each
    /// holding at most `top_k` results (fewer when `N < top_k`) sorted by
    /// descending score. Ties are broken by ascending database index, and NaN
    /// scores rank below every other score instead of aborting the sort.
    ///
    /// # Panics
    ///
    /// Panics if the embedding widths of `queries` and `database` differ.
    fn top_k_search(
        &self,
        queries: Tensor<B, 2>,
        database: Tensor<B, 2>,
        top_k: usize,
    ) -> Vec<Vec<RetrievalResult>> {
        let [q, d_query] = queries.dims();
        let [n, d_db] = database.dims();
        assert_eq!(
            d_query, d_db,
            "Query embedding width ({d_query}) does not match database width ({d_db})"
        );

        // Nothing can be returned: skip the matmul and host transfer entirely
        // (this also keeps `chunks_exact(0)`, which panics, unreachable).
        if n == 0 || top_k == 0 {
            return vec![Vec::new(); q];
        }

        // (Q, N) similarity matrix copied to the host as row-major f32.
        // `convert` makes this work on backends whose float element is not f32;
        // a failure here is a bug, so surface it instead of defaulting to empty.
        let sim = queries.matmul(database.transpose());
        let data: Vec<f32> = sim
            .into_data()
            .convert::<f32>()
            .to_vec::<f32>()
            .expect("similarity matrix should convert to Vec<f32>");
        debug_assert_eq!(data.len(), q * n);

        data.chunks_exact(n)
            .map(|row| Self::top_k_row(row, top_k))
            .collect()
    }

    /// Return the `top_k` best-ranked entries of one similarity row.
    ///
    /// Uses a partial selection so only the kept entries are fully sorted:
    /// O(N + k log k) rather than O(N log N) per query.
    fn top_k_row(row: &[f32], top_k: usize) -> Vec<RetrievalResult> {
        let k = top_k.min(row.len());
        if k == 0 {
            return Vec::new();
        }

        let mut ranked: Vec<(usize, f32)> = row.iter().copied().enumerate().collect();
        if k < ranked.len() {
            // Move the k best entries to the front (unordered) and drop the rest.
            ranked.select_nth_unstable_by(k - 1, Self::rank_desc);
            ranked.truncate(k);
        }
        ranked.sort_unstable_by(Self::rank_desc);

        ranked
            .into_iter()
            .map(|(index, score)| RetrievalResult { index, score })
            .collect()
    }

    /// Total order for `(index, score)` pairs: higher score first (NaN treated
    /// as negative infinity), then lower index first. The index tie-break keeps
    /// the result deterministic and mirrors a stable sort on score alone.
    fn rank_desc(a: &(usize, f32), b: &(usize, f32)) -> std::cmp::Ordering {
        let key = |s: f32| if s.is_nan() { f32::NEG_INFINITY } else { s };
        key(b.1).total_cmp(&key(a.1)).then(a.0.cmp(&b.0))
    }

    /// Evaluate Recall@k in both directions on a paired (sensor, text) set.
    ///
    /// Ground truth is assumed to be the diagonal: sensor `i` corresponds to
    /// text `i`. The indexed databases are not used; both modalities are
    /// encoded from the arguments.
    ///
    /// # Arguments
    ///
    /// * `sensor`  – `(N, T, C)` evaluation sensor data.
    /// * `ids`     – `(N, L)` token IDs.
    /// * `mask`    – `(N, L)` attention mask.
    /// * `k`       – Recall@k.
    ///
    /// # Returns
    ///
    /// `(sensor→text, text→sensor)` recall as computed by [`recall_at_k`] on
    /// the `(N, N)` similarity matrix and its transpose.
    ///
    /// # Panics
    ///
    /// Panics if `sensor` and `ids` have different batch sizes, since the
    /// diagonal ground truth requires a square similarity matrix.
    pub fn evaluate_recall(
        &self,
        sensor: Tensor<B, 3>,
        ids: Tensor<B, 2, Int>,
        mask: Tensor<B, 2, Int>,
        k: usize,
    ) -> (f32, f32) {
        let (n_sensor, n_text) = (sensor.dims()[0], ids.dims()[0]);
        assert_eq!(
            n_sensor, n_text,
            "Paired evaluation needs equal batch sizes (sensor: {n_sensor}, text: {n_text})"
        );

        let z_s = self.model.encode_sensor(sensor); // (N, D)
        let z_t = self.model.encode_text(ids, mask); // (N, D)

        let logits = z_s.matmul(z_t.transpose()); // (N, N)
        let r_s2t = recall_at_k(logits.clone(), k);
        let r_t2s = recall_at_k(logits.transpose(), k);

        (r_s2t, r_t2s)
    }
}