opentslm 0.1.0

Rust implementation of OpenTSLM using Burn, WGPU, and llama.cpp
//! Thin safe wrapper around the `llama-cpp-4` crate.
//!
//! [`LlamaCppBackend`] provides the three operations needed by the OpenTSLM
//! training and inference pipeline:
//!
//! 1. **Tokenisation / de-tokenisation** — [`tokenize`](LlamaCppBackend::tokenize) /
//!    [`detokenize`](LlamaCppBackend::detokenize)
//! 2. **Forward pass for training** — [`answer_logits`](LlamaCppBackend::answer_logits):
//!    runs a single causal forward pass over `[prompt | answer]` and returns
//!    the base logit vectors that predict each answer token (read at the last
//!    prompt position and at every answer position except the final one).
//!    These are treated as *constants* in the Burn autodiff graph; gradients
//!    flow only through the additive logit bias produced by the encoder.
//! 3. **Autoregressive generation** — [`generate`](LlamaCppBackend::generate):
//!    greedy argmax decoding with an optional additive `logit_bias` applied
//!    at every step to condition generation on the time series.
//!
//! # Design notes
//!
//! `LlamaContext<'_>` carries a lifetime tied to `&LlamaModel`, so it cannot
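//! be stored alongside the model in the same struct.  A fresh context is
//! created on demand inside every [`answer_logits`](LlamaCppBackend::answer_logits)
//! and [`generate`](LlamaCppBackend::generate) call and dropped before
//! returning.  The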
//! KV-cache allocation cost is acceptable because each call spawns at most
//! one context and contexts are small relative to the multi-GB model weights.
//!
//! The LLM is **fully frozen**.  llama.cpp provides:
//! - native GGUF quantisation (Q4_K_M, Q5_K_M, Q8_0, …)
//! - accelerated inference on the platform GPU (see below)
//! - a rich tokeniser that matches the model's built-in vocabulary exactly
//!
//! # GPU backend selection
//!
//! The correct GPU backend is compiled automatically — no flags needed for
//! the defaults:
//!
//! | Platform | Default backend | How it is enabled |
//! |----------|----------------|-------------------|
//! | macOS (Apple Silicon / Intel) | Metal | always compiled by llama.cpp on Apple targets |
//! | Linux / Windows | Vulkan | `[target]` section in `Cargo.toml` |
//!
//! ## Optional overrides
//!
//! ```text
//! # Vulkan on any platform (macOS requires MoltenVK — see below):
//! cargo build --release --features vulkan
//!
//! # CUDA on Linux/Windows (CUDA toolkit required):
//! cargo build --release --features cuda
//!
//! # Both Vulkan and CUDA simultaneously:
//! cargo build --release --features vulkan --features cuda
//! ```
//!
//! ### Vulkan on macOS (MoltenVK)
//!
//! Vulkan on macOS is provided by MoltenVK, which ships inside the
//! [LunarG Vulkan SDK](https://vulkan.lunarg.com/sdk/home#mac).
//! Install the SDK and source `setup-env.sh` before building:
//! ```text
//! source $HOME/VulkanSDK/<ver>/setup-env.sh
//! cargo build --release --features vulkan
//! ```
//! `build.rs` checks for `VULKAN_SDK` / `VK_ICD_FILENAMES` / `VK_LAYER_PATH`
//! and prints a clear error with install instructions when the SDK is absent.
//! Note: because `llama-cpp-sys-4`'s CMake invocation runs in parallel with
//! `build.rs`, a CMake "Could NOT find Vulkan" error will also appear above
//! the diagnostic — both errors describe the same missing SDK.
//!
//! ## Layer offload
//!
//! The number of transformer layers sent to the GPU is set at runtime by
//! [`N_GPU_LAYERS`].  Use a large value (e.g. `999`) to offload all layers,
//! or `0` for CPU-only inference.
//!
//! [`N_GPU_LAYERS`]: crate::config::N_GPU_LAYERS

use std::{num::NonZeroU32, path::Path};

use anyhow::{Context, Result};
use llama_cpp_4::{
    context::params::LlamaContextParams,
    llama_backend::LlamaBackend,
    llama_batch::LlamaBatch,
    model::{params::LlamaModelParams, AddBos, LlamaModel, Special},
    token::LlamaToken,
};
use tracing::info;

// ── Backend ───────────────────────────────────────────────────────────────────

/// Frozen GGUF LLM backend.
///
/// Owns the llama.cpp process-global [`LlamaBackend`] and the loaded
/// [`LlamaModel`].  This struct is intentionally **not** a Burn [`Module`]
/// — it is kept outside Burn's autodiff graph so that its parameters are
/// never updated during training.
///
/// Field order is significant: Rust drops struct fields in declaration
/// order, so `model` is declared before `llama_backend` to guarantee the
/// model is freed while the backend it was loaded through is still alive.
///
/// [`Module`]: burn::module::Module
pub struct LlamaCppBackend {
    /// Loaded GGUF model (weights, tokeniser, config).
    ///
    /// Declared first so that it is dropped before `llama_backend`.
    pub model:         LlamaModel,
    /// Process-global llama.cpp backend (initialised once).
    pub llama_backend: LlamaBackend,
    /// Vocabulary size; used to size the logit-bias tensor.
    pub n_vocab:       usize,
    /// Embedding dimensionality of the LLM (unused in SP variant but exposed
    /// for informational purposes and potential future Flamingo-style use).
    pub n_embd:        usize,
    /// Minimum context window size (tokens) when creating inference contexts.
    /// The actual context created is `max(ctx_size, sequence_length + 1)`.
    pub ctx_size:      usize,
}

impl LlamaCppBackend {
    // ── Constructor ───────────────────────────────────────────────────────

    /// Load a GGUF model file.
    ///
    /// `n_gpu_layers` — number of transformer layers to offload to the GPU
    /// (Metal / CUDA).  Use a large value such as `999` to offload all layers.
    ///
    /// `ctx_size` — default KV-cache window size in tokens.
    pub fn load(gguf_path: &Path, n_gpu_layers: u32, ctx_size: usize) -> Result<Self> {
        info!("Loading GGUF model from {:?}", gguf_path);

        let mut llama_backend = LlamaBackend::init()
            .map_err(|e| anyhow::anyhow!("LlamaBackend::init failed: {e}"))?;

        // Suppress llama.cpp / ggml's verbose C-level stderr chatter
        // (ggml_metal_init, llama_kv_cache layer lines, sched_reserve, etc.).
        // All meaningful status is already reported through Rust's `tracing`.
        llama_backend.void_logs();

        let model_params = LlamaModelParams::default()
            .with_n_gpu_layers(n_gpu_layers);

        let model = LlamaModel::load_from_file(&llama_backend, gguf_path, &model_params)
            .with_context(|| format!("Failed to load model from {gguf_path:?}"))?;

        let n_vocab = usize::try_from(model.n_vocab())
            .context("n_vocab overflows usize")?;
        let n_embd = usize::try_from(model.n_embd())
            .context("n_embd overflows usize")?;

        info!(
            "  vocab={n_vocab}  embd={n_embd}  ctx_train={}",
            model.n_ctx_train()
        );

        Ok(Self { llama_backend, model, n_vocab, n_embd, ctx_size })
    }

    // ── Tokenisation ──────────────────────────────────────────────────────

    /// Tokenise `text`.  BOS is prepended when `add_bos = true`.
    pub fn tokenize(&self, text: &str, add_bos: bool) -> Result<Vec<LlamaToken>> {
        let bos = if add_bos { AddBos::Always } else { AddBos::Never };
        self.model
            .str_to_token(text, bos)
            .map_err(|e| anyhow::anyhow!("tokenise failed: {e}"))
    }

    /// Decode a sequence of tokens back to a UTF-8 string.
    pub fn detokenize(&self, tokens: &[LlamaToken]) -> String {
        self.model
            .tokens_to_str(tokens, Special::Tokenize)
            .unwrap_or_default()
    }

    /// Return the model's end-of-sequence token.
    pub fn eos_token(&self) -> LlamaToken { self.model.token_eos() }
    /// Return the model's beginning-of-sequence token.
    pub fn bos_token(&self) -> LlamaToken { self.model.token_bos() }

    // ── Forward pass for training ─────────────────────────────────────────

    /// Run a single forward pass over `[prompt | answer]` and return the LLM's
    /// logit vectors at the positions that predict each answer token.
    ///
    /// Returns `Vec<Vec<f32>>` of length `answer_tokens.len()`, where element
    /// `j` is the `n_vocab`-long logit vector that predicts `answer_tokens[j]`.
    ///
    /// A context is created and destroyed inside this call — the model's KV
    /// cache does not persist between calls.
    pub fn answer_logits(
        &self,
        prompt_tokens: &[LlamaToken],
        answer_tokens: &[LlamaToken],
    ) -> Result<Vec<Vec<f32>>> {
        if answer_tokens.is_empty() {
            return Ok(vec![]);
        }

        let p_len = prompt_tokens.len();
        let a_len = answer_tokens.len();
        let total  = p_len + a_len;

        // Context must fit the entire sequence.
        let n_ctx = (total + 1).max(self.ctx_size);
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(NonZeroU32::new(n_ctx as u32));

        let mut ctx = self.model
            .new_context(&self.llama_backend, ctx_params)
            .map_err(|e| anyhow::anyhow!("context creation failed: {e}"))?;

        let mut batch = LlamaBatch::new(total, 1);

        // Prompt tokens — only the LAST prompt token needs logits (it predicts a0).
        for (i, &token) in prompt_tokens.iter().enumerate() {
            batch.add(token, i as i32, &[0], i == p_len - 1)?;
        }

        // Answer tokens — tokens a0..a(M-2) need logits (they predict a1..a(M-1)).
        // The final answer token a(M-1) does NOT need logits (nothing to predict after it).
        for (i, &token) in answer_tokens.iter().enumerate() {
            let need = i < a_len - 1;
            batch.add(token, (p_len + i) as i32, &[0], need)?;
        }

        ctx.decode(&mut batch)
            .map_err(|e| anyhow::anyhow!("decode failed: {e}"))?;

        // Collect:
        //   position p_len-1            → predicts a0
        //   position p_len + 0          → predicts a1
        //        //   position p_len + (a_len-2)  → predicts a(M-1)
        let mut result = Vec::with_capacity(a_len);

        // Logit at position (p_len-1) predicts answer_tokens[0].
        result.push(ctx.get_logits_ith((p_len as i32) - 1).to_vec());

        // Logit at position (p_len + i) predicts answer_tokens[i+1].
        for i in 0..(a_len - 1) {
            result.push(ctx.get_logits_ith((p_len + i) as i32).to_vec());
        }

        Ok(result)
    }

    // ── Greedy generation ─────────────────────────────────────────────────

    /// Autoregressively generate up to `max_new_tokens` tokens given
    /// `prompt_tokens`.
    ///
    /// `logit_bias` — optional additive bias (`n_vocab` floats) applied to the
    /// raw LLM logits at every generation step.  Pass the encoder's output to
    /// condition generation on the time series.
    pub fn generate(
        &self,
        prompt_tokens: &[LlamaToken],
        max_new_tokens: usize,
        logit_bias: Option<&[f32]>,
    ) -> Result<Vec<LlamaToken>> {
        if prompt_tokens.is_empty() {
            return Ok(vec![]);
        }

        let capacity = prompt_tokens.len() + max_new_tokens + 1;
        let n_ctx    = capacity.max(self.ctx_size);
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(NonZeroU32::new(n_ctx as u32));

        let mut ctx = self.model
            .new_context(&self.llama_backend, ctx_params)
            .map_err(|e| anyhow::anyhow!("context creation failed: {e}"))?;

        let mut batch = LlamaBatch::new(capacity, 1);

        // ── Process the initial prompt ────────────────────────────────────
        let p_last = (prompt_tokens.len() as i32) - 1;
        for (i, &token) in prompt_tokens.iter().enumerate() {
            batch.add(token, i as i32, &[0], i as i32 == p_last)?;
        }
        ctx.decode(&mut batch)
            .map_err(|e| anyhow::anyhow!("decode (prompt) failed: {e}"))?;

        // ── Autoregressive loop ───────────────────────────────────────────
        let mut generated   = Vec::with_capacity(max_new_tokens);
        let mut n_cur: i32  = prompt_tokens.len() as i32;
        let mut logit_idx   = p_last; // batch-index with logits after first decode

        for _ in 0..max_new_tokens {
            let logits = ctx.get_logits_ith(logit_idx);
            let next = greedy_sample(logits, logit_bias);

            if self.model.is_eog_token(next) {
                break;
            }

            generated.push(next);

            // Append the new token to the sequence.
            batch.clear();
            batch.add(next, n_cur, &[0], true)?;
            ctx.decode(&mut batch)
                .map_err(|e| anyhow::anyhow!("decode (step {n_cur}) failed: {e}"))?;

            logit_idx = 0; // only one token in the batch from now on
            n_cur    += 1;
        }

        Ok(generated)
    }
}

// ── Helpers ───────────────────────────────────────────────────────────────────

/// Greedy argmax over `logits` with an optional additive per-token `bias`.
///
/// When `bias` is `Some`, the effective score for token `i` is
/// `logits[i] + bias[i]`.  `bias` must be the same length as `logits`
/// (i.e. `n_vocab`); if it is shorter, only the first `bias.len()` tokens
/// are considered.
///
/// NaN scores never win.  Among equal maxima the highest index is chosen.
/// If there is no non-NaN score (including empty `logits`), token `0` is
/// returned.
fn greedy_sample(logits: &[f32], bias: Option<&[f32]>) -> LlamaToken {
    let idx = match bias {
        None => argmax_ignoring_nan(logits.iter().copied()),
        Some(b) => argmax_ignoring_nan(logits.iter().zip(b).map(|(l, b)| l + b)),
    };
    // Vocabulary indices originate from llama.cpp's i32 token ids.
    LlamaToken::new(idx as i32)
}

/// Index of the largest non-NaN score (last one on ties), or `0` if none.
///
/// NaNs are filtered out first: comparing them with `partial_cmp` would
/// yield `None`, which a naive `unwrap_or(Equal)` turns into a bogus "tie"
/// that lets a NaN displace the real maximum.
fn argmax_ignoring_nan(scores: impl Iterator<Item = f32>) -> usize {
    scores
        .enumerate()
        .filter(|(_, s)| !s.is_nan())
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map_or(0, |(i, _)| i)
}