// opentslm 0.1.0
//
// Rust implementation of OpenTSLM using Burn, WGPU, and llama.cpp
//! Core data structures shared across all dataset loaders.
//!
//! A [`Sample`] is the Rust equivalent of the per-sample `dict` passed
//! throughout the Python codebase:
//!
//! ```python
//! {
//!     "pre_prompt":       str,
//!     "time_series_text": List[str],
//!     "time_series":      List[List[float]],   # one 1-D series per entry
//!     "post_prompt":      str,
//!     "answer":           str,
//!     "label":            Optional[str],
//! }
//! ```
//!
//! [`extend_to_patch_multiple`] mirrors
//! `extend_time_series_to_match_patch_size_and_aggregate` from
//! `src/opentslm/time_series_datasets/util.py`.

use crate::config::PATCH_SIZE;

/// A single training or evaluation sample.
///
/// Mirrors the per-sample `dict` format used throughout the Python codebase.
/// Each [`Sample`] contains one multi-modal input: a structured text prompt,
/// zero or more raw time series (with associated text descriptions), and the
/// expected ground-truth answer.
///
/// The `time_series` and `time_series_text` vecs are parallel — element `i`
/// of `time_series_text` is the label/description for `time_series[i]`.
/// All fields are public and the type has no validating constructor, so this
/// pairing is a convention that whoever builds a [`Sample`] must uphold;
/// nothing in this type checks that the two vecs have equal length.
#[derive(Debug, Clone)]
pub struct Sample {
    /// Text shown to the LLM before the time-series block (task description,
    /// question, clinical context, etc.).
    pub pre_prompt: String,
    /// One short text description per time series, e.g.
    /// `"x-axis accelerometer, mean 0.12, std 0.88:"`.
    /// Entry `i` describes `time_series[i]`.
    pub time_series_text: Vec<String>,
    /// Raw 1-D time series values (f32), one inner vec per channel.  The
    /// inner vecs may have different lengths.  These are expected to be
    /// already normalised but **not yet patch-padded**; call
    /// [`extend_to_patch_multiple`] or [`collate`] before passing them to the
    /// encoder.
    pub time_series: Vec<Vec<f32>>,
    /// Text shown to the LLM after the time-series block (answer format
    /// instructions, candidate labels, etc.).
    pub post_prompt: String,
    /// Ground-truth answer text used to compute training loss.
    pub answer: String,
    /// Optional categorical label used to compute accuracy / macro recall
    /// during evaluation.  `None` for free-form captioning stages.
    pub label: Option<String>,
}

/// A batch of [`Sample`]s whose time series have been patch-padded.
///
/// Padding is applied **per sample**: within one [`Sample`] every series
/// shares a single length that is a multiple of the patch size, but two
/// samples in the same batch may still end up with different lengths (see
/// [`from_samples`](TokenizedBatch::from_samples)).
///
/// Created via [`from_samples`](TokenizedBatch::from_samples) or the
/// convenience wrapper [`collate`].  Because `samples` is a public field the
/// struct can also be built directly, in which case no padding is applied.
///
/// NOTE(review): despite the name, nothing visible in this file tokenizes
/// text — the batch only carries patch-padded samples.  Confirm whether
/// tokenization happens elsewhere before relying on the name.
#[derive(Debug, Clone)]
pub struct TokenizedBatch {
    /// The samples comprising this batch.  When built through
    /// [`from_samples`](TokenizedBatch::from_samples), the `time_series` vecs
    /// inside each sample have been padded to a common per-sample length that
    /// is a multiple of the `patch_size` passed there ([`PATCH_SIZE`] when
    /// built via [`collate`]).
    pub samples: Vec<Sample>,
}

impl TokenizedBatch {
    /// Build a [`TokenizedBatch`] by patch-padding every series in `samples`.
    ///
    /// Within each sample the series are padded to a common length that is
    /// the smallest multiple of `patch_size` ≥ the longest series in that
    /// sample.  Use [`collate`] to apply the global [`PATCH_SIZE`].
    pub fn from_samples(mut samples: Vec<Sample>, patch_size: usize) -> Self {
        for s in &mut samples {
            extend_to_patch_multiple(&mut s.time_series, patch_size);
        }
        Self { samples }
    }
}

/// Pad every 1-D series in `series_list` with trailing zeros so they all
/// share the same length: the smallest multiple of `patch_size` that is ≥ the
/// longest series.
///
/// Series are only ever extended, never shortened — the target length is by
/// construction at least as long as the longest input.  An empty
/// `series_list` is left untouched, and if every series is empty they all
/// stay empty (the target length is `0`).  This mirrors
/// `extend_time_series_to_match_patch_size_and_aggregate` in the Python codebase.
///
/// # Panics
///
/// Panics if `patch_size` is `0` while `series_list` is non-empty.
pub fn extend_to_patch_multiple(series_list: &mut Vec<Vec<f32>>, patch_size: usize) {
    // `max()` is `None` only for an empty list, in which case there is
    // nothing to pad (and `patch_size` is deliberately not inspected).
    let max_len = match series_list.iter().map(Vec::len).max() {
        Some(len) => len,
        None => return,
    };
    assert!(patch_size > 0, "patch_size must be greater than zero");

    // Round `max_len` up to the next multiple of `patch_size`.  Working from
    // the remainder avoids the intermediate `max_len + patch_size - 1`, which
    // can overflow for very large patch sizes.
    let padded_len = match max_len % patch_size {
        0 => max_len,
        rem => max_len + (patch_size - rem),
    };

    // `padded_len >= max_len`, so `resize` can only append zeros here.
    for series in series_list.iter_mut() {
        series.resize(padded_len, 0.0);
    }
}

/// Collate `samples` into a [`TokenizedBatch`], patch-padding with the
/// crate-wide [`PATCH_SIZE`] imported from [`crate::config`].
///
/// Shorthand for calling [`TokenizedBatch::from_samples`] with that constant
/// as the patch size.
pub fn collate(samples: Vec<Sample>) -> TokenizedBatch {
    let patch_size = PATCH_SIZE;
    TokenizedBatch::from_samples(samples, patch_size)
}

/// Normalise a slice of floats to zero mean and unit variance.
///
/// Returns `(normalised_values, mean, std)`.  The standard deviation is the
/// population form (sum of squared deviations divided by `n`, not `n - 1`)
/// and is clamped to at least `1e-6` to avoid division by zero on constant
/// series.
///
/// An empty slice yields `(vec![], 0.0, 1e-6)` rather than a `NaN` mean, so
/// the returned statistics are always finite for finite input.
pub fn normalize(v: &[f32]) -> (Vec<f32>, f32, f32) {
    /// Lower bound applied to the standard deviation.
    const MIN_STD: f32 = 1e-6;

    // Without this guard the mean below would be `0.0 / 0.0 == NaN`.
    if v.is_empty() {
        return (Vec::new(), 0.0, MIN_STD);
    }

    // Lossy only for lengths beyond f32's exact-integer range (2^24).
    let n = v.len() as f32;
    let mean = v.iter().sum::<f32>() / n;
    let var = v.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    let std = var.sqrt().max(MIN_STD);
    let normed = v.iter().map(|x| (x - mean) / std).collect();
    (normed, mean, std)
}