// solo-storage 0.3.6
//
// Solo: SQLite + SQLCipher persistence layer
// Documentation
// SPDX-License-Identifier: Apache-2.0

//! BGE-M3 forward pass via candle-transformers + tokenizers (commit 1.4.b).
//!
//! Wraps `candle_transformers::models::xlm_roberta::XLMRobertaModel` (the
//! base architecture BGE-M3 is built on) and the HuggingFace
//! `tokenizers::Tokenizer` for text→tokens. Produces dense F32 embeddings
//! by CLS-pooling the last hidden state and L2-normalising — the same
//! pipeline `BGEM3FlagModel.encode` (the FlagEmbedding reference impl)
//! uses for its dense head.
//!
//! ## Model directory expectations
//!
//! Same as `BgeM3Loader::open` validates:
//!   - `config.json`        — XLM-RoBERTa config (parsed twice: once
//!                            into our `BgeM3Config` for dim, once into
//!                            candle's `xlm_roberta::Config` for model
//!                            construction).
//!   - `tokenizer.json`     — HuggingFace tokenizer.
//!   - `model.safetensors`  — preferred. `pytorch_model.bin` is also
//!                            accepted by `BgeM3Loader::open`, but
//!                            **only safetensors is supported by this
//!                            inference path**. Pytorch-bin loading
//!                            via candle's `from_pth` works in theory
//!                            but the BGE-M3 official artefact is
//!                            safetensors and we keep the surface narrow.
//!
//! ## Device
//!
//! CPU only for v0.1 (`Device::Cpu`). GPU support is a v0.2 follow-up
//! once we measure throughput needs and decide on CUDA / Metal targets.
//!
//! ## Throughput notes
//!
//! BGE-M3 on CPU: ~50-200 ms per single text on a modern laptop. Batching
//! recovers most of that — `embed_batch` pads sequences to the longest
//! in the batch and runs one forward pass. For Solo's typical workload
//! (interactive remember/recall, ~10 calls/min) latency is fine; bulk
//! re-embedding (`solo reembed`) will be the slow path.
//!
//! ## Concurrency
//!
//! `BgeM3Inference` is `Send + Sync`. The candle XLMRobertaModel holds
//! `Arc<...>` weights internally; concurrent `embed` calls share the
//! same model. The tokenizer is also clone-cheap.

use std::sync::Arc;

use async_trait::async_trait;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{Config as XLMRobertaConfig, XLMRobertaModel};
use solo_core::{Embedder, Embedding, EmbeddingDtype, Error, Result};
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};

use super::bge_m3::{BGE_M3_NAME, BGE_M3_VERSION, BgeM3Manifest};

/// Loaded BGE-M3 model + tokenizer ready to serve `embed` calls.
///
/// Constructed by [`BgeM3Inference::load`]; serves requests through the
/// `Embedder` trait implementation.
pub struct BgeM3Inference {
    /// XLM-RoBERTa encoder built in `load` from the memory-mapped
    /// safetensors weights.
    model: XLMRobertaModel,
    /// Shared tokenizer; `load` configures batch-longest padding and
    /// truncation on it before wrapping it in an `Arc`.
    tokenizer: Arc<Tokenizer>,
    /// Device the input tensors are created on (`load` always uses
    /// `Device::Cpu`).
    device: Device,
    /// Embedding width, taken from the candle config's `hidden_size`.
    dim: usize,
    /// Identity string returned by `Embedder::name`.
    name: String,
    /// Identity string returned by `Embedder::version`.
    version: String,
}

impl BgeM3Inference {
    /// Construct from a [`BgeM3Manifest`] (already-validated paths +
    /// parsed solo-side config). Loads weights, tokenizer, builds the
    /// model on CPU.
    pub fn load(manifest: &BgeM3Manifest) -> Result<Self> {
        let device = Device::Cpu;

        // Weights: only safetensors supported in this inference path.
        // Check before config parse so a clear "wrong weights format"
        // error wins over a less-specific config-parse error.
        let weights_path = &manifest.weights_path;
        if !weights_path
            .file_name()
            .map(|n| n.to_string_lossy().ends_with(".safetensors"))
            .unwrap_or(false)
        {
            return Err(Error::embedder(format!(
                "BGE-M3 inference requires model.safetensors; got {weights_path:?}"
            )));
        }

        // Re-parse config.json into candle's XLMRoberta Config — has
        // additional fields (attention_probs_dropout_prob, hidden_act,
        // etc.) that our solo-side BgeM3Config doesn't carry.
        let config_raw = std::fs::read_to_string(&manifest.config_path)
            .map_err(|e| Error::embedder(format!("read config.json: {e}")))?;
        let candle_cfg: XLMRobertaConfig = serde_json::from_str(&config_raw)
            .map_err(|e| Error::embedder(format!("parse XLMRoberta Config from config.json: {e}")))?;

        // SAFETY: mmap is unsafe because the underlying file could be
        // truncated/modified while held. We open it read-only and never
        // share the path with anyone who would mutate it.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path.as_path()], DType::F32, &device)
                .map_err(|e| Error::embedder(format!("VarBuilder::from_mmaped_safetensors: {e}")))?
        };

        let model = XLMRobertaModel::new(&candle_cfg, vb)
            .map_err(|e| Error::embedder(format!("XLMRobertaModel::new: {e}")))?;

        let mut tokenizer = Tokenizer::from_file(&manifest.tokenizer_path)
            .map_err(|e| Error::embedder(format!("Tokenizer::from_file: {e}")))?;

        // Configure padding (BatchLongest) + truncation to the model's
        // max_position_embeddings. BGE-M3 supports up to 8194 tokens; we
        // honour whatever the config declares.
        tokenizer
            .with_padding(Some(PaddingParams {
                strategy: PaddingStrategy::BatchLongest,
                pad_id: candle_cfg.pad_token_id,
                pad_token: "<pad>".to_string(),
                ..Default::default()
            }))
            .with_truncation(Some(TruncationParams {
                max_length: candle_cfg.max_position_embeddings.saturating_sub(2).max(1),
                ..Default::default()
            }))
            .map_err(|e| Error::embedder(format!("tokenizer truncation config: {e}")))?;

        Ok(Self {
            model,
            tokenizer: Arc::new(tokenizer),
            device,
            dim: candle_cfg.hidden_size,
            name: manifest.config.model_type.clone(),
            version: BGE_M3_VERSION.to_string(),
        })
    }

    fn forward_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // Tokenize batch (with padding to the longest in batch).
        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| Error::embedder(format!("tokenizer.encode_batch: {e}")))?;

        let batch = encodings.len();
        let seq_len = encodings.first().map(|e| e.len()).unwrap_or(0);
        if seq_len == 0 {
            return Err(Error::embedder("tokenizer produced empty sequences"));
        }

        // Build [batch, seq] u32 tensors for input_ids + attention_mask
        // + token_type_ids (zeros — XLM-RoBERTa uses a single segment).
        let mut input_ids: Vec<u32> = Vec::with_capacity(batch * seq_len);
        let mut attention_mask: Vec<u32> = Vec::with_capacity(batch * seq_len);
        for enc in &encodings {
            input_ids.extend_from_slice(enc.get_ids());
            attention_mask.extend_from_slice(enc.get_attention_mask());
        }
        let token_type_ids = vec![0u32; batch * seq_len];

        let input_ids = Tensor::from_vec(input_ids, (batch, seq_len), &self.device)
            .map_err(|e| Error::embedder(format!("input_ids tensor: {e}")))?;
        let attention_mask = Tensor::from_vec(attention_mask, (batch, seq_len), &self.device)
            .map_err(|e| Error::embedder(format!("attention_mask tensor: {e}")))?;
        let token_type_ids = Tensor::from_vec(token_type_ids, (batch, seq_len), &self.device)
            .map_err(|e| Error::embedder(format!("token_type_ids tensor: {e}")))?;

        // Forward pass — returns [batch, seq, hidden].
        let hidden = self
            .model
            .forward(&input_ids, &attention_mask, &token_type_ids, None, None, None)
            .map_err(|e| Error::embedder(format!("XLMRobertaModel::forward: {e}")))?;

        // CLS pooling: take the first token of each sequence → [batch, hidden].
        let cls = hidden
            .narrow(1, 0, 1)
            .and_then(|t| t.squeeze(1))
            .map_err(|e| Error::embedder(format!("CLS pool: {e}")))?;

        // L2-normalize each row.
        let norm = cls
            .sqr()
            .and_then(|t| t.sum_keepdim(1))
            .and_then(|t| t.sqrt())
            .map_err(|e| Error::embedder(format!("L2 norm: {e}")))?;
        let normalized = cls
            .broadcast_div(&norm)
            .map_err(|e| Error::embedder(format!("normalize divide: {e}")))?;

        // Convert to Vec<Vec<f32>>.
        let normalized = normalized
            .to_dtype(DType::F32)
            .map_err(|e| Error::embedder(format!("to_dtype f32: {e}")))?;
        let rows = normalized
            .to_vec2::<f32>()
            .map_err(|e| Error::embedder(format!("to_vec2: {e}")))?;
        Ok(rows)
    }
}

#[async_trait]
impl Embedder for BgeM3Inference {
    fn name(&self) -> &str {
        &self.name
    }
    fn version(&self) -> &str {
        &self.version
    }
    fn dim(&self) -> usize {
        self.dim
    }
    fn dtype(&self) -> EmbeddingDtype {
        EmbeddingDtype::F32
    }

    /// Embeds `texts` in one padded forward pass.
    ///
    /// Returns one L2-normalised F32 embedding per input, in input order.
    /// Each embedding is serialised as little-endian bytes.
    ///
    /// Works on any tokio runtime flavour, or outside tokio entirely (see
    /// the comment in the body).
    ///
    /// # Errors
    ///
    /// Propagates tokenizer / candle failures from the forward pass, and
    /// returns an error if any output row is not `dim` wide.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let dim = self.dim;

        // candle's CPU forward pass is synchronous compute that can take
        // 50-500 ms per batch.
        //
        // `XLMRobertaModel` is not `Clone`, so the work cannot be moved into
        // `spawn_blocking`, which needs owned `'static` captures. Instead we
        // borrow `self` and, on a multi-thread runtime, use `block_in_place`.
        // That turns this worker into a blocking thread and moves its other
        // tasks to other workers, so the runtime is not starved.
        //
        // `block_in_place` panics on a current-thread runtime (e.g. a default
        // `#[tokio::test]`). There, and outside any runtime, the pass runs
        // inline. On a current-thread runtime this blocks it for the duration
        // of the batch.
        let on_multi_thread = tokio::runtime::Handle::try_current()
            .map(|handle| {
                matches!(
                    handle.runtime_flavor(),
                    tokio::runtime::RuntimeFlavor::MultiThread
                )
            })
            .unwrap_or(false);
        let rows = if on_multi_thread {
            tokio::task::block_in_place(|| self.forward_batch(texts))
        } else {
            self.forward_batch(texts)
        }?;

        // Serialise each row as packed little-endian f32s. A width mismatch
        // aborts the whole batch.
        rows.into_iter()
            .map(|row| {
                if row.len() != dim {
                    return Err(Error::embedder(format!(
                        "BGE-M3 produced {} dims, expected {}",
                        row.len(),
                        dim
                    )));
                }
                let mut data = Vec::with_capacity(dim * std::mem::size_of::<f32>());
                for v in &row {
                    data.extend_from_slice(&v.to_le_bytes());
                }
                Ok(Embedding {
                    dtype: EmbeddingDtype::F32,
                    dim,
                    data,
                })
            })
            .collect()
    }
}

/// Returns the canonical BGE-M3 identity as a `(name, version)` pair.
///
/// Useful for callers that need to recognise BGE-M3 embeddings without
/// loading weights or a tokenizer.
pub fn canonical_identity() -> (&'static str, &'static str) {
    let name = BGE_M3_NAME;
    let version = BGE_M3_VERSION;
    (name, version)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedder::bge_m3::{BgeM3Config, BgeM3Manifest};

    /// A BGE-M3-shaped model directory whose weights file is garbage must
    /// make `load` fail with an error naming the weights/tokenizer step.
    ///
    /// Real-weight smoke tests need a pre-downloaded BGE-M3 referenced by
    /// the `SOLO_BGE_M3_DIR` env var; that is out of band for CI.
    #[test]
    fn load_with_garbage_safetensors_fails_clearly() {
        use crate::embedder::bge_m3::BgeM3Loader;

        let scratch = tempfile::TempDir::new().unwrap();
        let root = scratch.path();
        let config_path = root.join("config.json");
        let tokenizer_path = root.join("tokenizer.json");
        let weights_path = root.join("model.safetensors");

        std::fs::write(
            &config_path,
            r#"{
              "model_type": "xlm-roberta",
              "hidden_size": 1024,
              "layer_norm_eps": 1e-5,
              "attention_probs_dropout_prob": 0.1,
              "hidden_dropout_prob": 0.1,
              "num_attention_heads": 16,
              "position_embedding_type": "absolute",
              "intermediate_size": 4096,
              "hidden_act": "gelu",
              "num_hidden_layers": 24,
              "vocab_size": 250002,
              "max_position_embeddings": 8194,
              "type_vocab_size": 1,
              "pad_token_id": 1
            }"#,
        )
        .unwrap();
        std::fs::write(&tokenizer_path, r#"{"version":"1.0","model":{}}"#).unwrap();
        std::fs::write(&weights_path, b"NOT A SAFETENSORS FILE").unwrap();

        // The loader never inspects weight bytes, so it would happily accept
        // this directory. The manifest is therefore assembled by hand; it is
        // `BgeM3Inference::load` that must reject the weights.
        let manifest = BgeM3Manifest {
            model_dir: root.to_path_buf(),
            config_path,
            tokenizer_path,
            weights_path,
            config: BgeM3Config {
                model_type: "xlm-roberta".into(),
                hidden_size: 1024,
                vocab_size: 250002,
                max_position_embeddings: 8194,
            },
        };
        let _ = BgeM3Loader::open(root).expect("loader sanity");

        let err = BgeM3Inference::load(&manifest).err().expect("expected Err");
        let msg = err.to_string();

        // The failure surfaces from candle's safetensors parser. Its exact
        // wording is a candle internal, so accept any message that names the
        // safetensors / VarBuilder / tokenizer loading step.
        let needles = ["VarBuilder", "safetensors", "Tokenizer", "from_mmaped"];
        assert!(
            needles.iter().any(|&needle| msg.contains(needle)),
            "unexpected error: {msg}"
        );
    }

    /// Non-safetensors weights are rejected up front, before the
    /// (deliberately invalid) config and tokenizer files are even read.
    #[test]
    fn rejects_pytorch_bin_weights() {
        let scratch = tempfile::TempDir::new().unwrap();
        let root = scratch.path();
        let config_path = root.join("config.json");
        let tokenizer_path = root.join("tokenizer.json");
        let weights_path = root.join("pytorch_model.bin");

        std::fs::write(&config_path, "{}").unwrap();
        std::fs::write(&tokenizer_path, "{}").unwrap();
        std::fs::write(&weights_path, b"fake").unwrap();

        let manifest = BgeM3Manifest {
            model_dir: root.to_path_buf(),
            config_path,
            tokenizer_path,
            weights_path,
            config: BgeM3Config {
                model_type: "xlm-roberta".into(),
                hidden_size: 1024,
                vocab_size: 0,
                max_position_embeddings: 0,
            },
        };

        let err = BgeM3Inference::load(&manifest).err().expect("expected Err");
        assert!(
            err.to_string().contains("requires model.safetensors"),
            "got: {err}"
        );
    }
}