solo-storage 0.3.5

Solo: SQLite + SQLCipher persistence layer
Documentation
// SPDX-License-Identifier: Apache-2.0

//! BGE-M3 model directory loader.
//!
//! This module owns the **file-discovery and config-parsing** surface for
//! BGE-M3. Commit 1.4.a shipped it on its own; commit 1.4.b added the
//! forward pass (candle-core + candle-transformers `xlm_roberta` + the
//! `tokenizers` crate) in `crate::embedder::bge_m3_inference`, reached via
//! [`BgeM3Loader::into_embedder`]. The work was split this way to:
//!
//!   - keep the public API surface stable (commit 1.5 daemon main can
//!     refer to `BgeM3Loader` and `BgeM3Config` immediately);
//!   - let the daemon use [`crate::embedder::StubEmbedder`] in the
//!     interim;
//!   - avoid forcing 1.2 GB of model weights into the test loop;
//!   - decouple the candle-transformers dep risk (mid-2024 era APIs that
//!     may have shifted) from the rest of the storage layer.
//!
//! ## Expected directory layout (HuggingFace `BAAI/bge-m3` mirror)
//!
//! ```text
//! <model_dir>/
//!     config.json          // model architecture / dim
//!     tokenizer.json       // HuggingFace tokenizer
//!     model.safetensors    // weights; required by `into_embedder`
//!                          // (`pytorch_model.bin` passes `open` only)
//! ```
//!
//! Optional: `tokenizer_config.json`, `special_tokens_map.json`. We don't
//! validate these at load time; the tokenizer crate handles them lazily
//! when present.

use std::fs;
use std::path::{Path, PathBuf};

use serde::{Deserialize, Serialize};
use solo_core::{Error, Result};

/// Default HuggingFace model identifier for BGE-M3.
pub const BGE_M3_NAME: &str = "BAAI/bge-m3";

/// Embedding-version tag for BGE-M3 output. `solo reembed` keys on this
/// string to detect when stored embeddings need regeneration; bump it on
/// any change that produces different vectors for the same input.
pub const BGE_M3_VERSION: &str = "v1";

/// BGE-M3 dense-output dimension (matches the public model card).
pub const BGE_M3_DIM: usize = 1024;

/// Model-architecture config file; required in the model directory.
const FILE_CONFIG: &str = "config.json";
/// HuggingFace tokenizer definition; required in the model directory.
const FILE_TOKENIZER: &str = "tokenizer.json";
/// Preferred weights file. Either this or `FILE_WEIGHTS_PYTORCH` must exist.
const FILE_WEIGHTS_SAFETENSORS: &str = "model.safetensors";
/// Fallback weights filename produced by older HuggingFace exports.
const FILE_WEIGHTS_PYTORCH: &str = "pytorch_model.bin";

/// The slice of a HuggingFace `config.json` that the embedder reads.
///
/// Keys not listed here are tolerated and dropped during deserialization
/// (serde's default; this struct does not set `deny_unknown_fields`). The
/// fields kept are enough to cross-check the model's output dimension
/// against `solo.config.toml.embedder.dim` and the HNSW snapshot.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BgeM3Config {
    /// Architecture family, e.g. `"xlm-roberta"` for the BGE-M3 base.
    /// Required: deserialization fails if the key is absent.
    pub model_type: String,

    /// Hidden-layer width, which is also the output embedding dimension
    /// (1024 for BGE-M3). Required.
    pub hidden_size: usize,

    /// Vocabulary size, passed through to the tokenizer and model loader.
    /// Falls back to 0 when the key is absent.
    #[serde(default)]
    pub vocab_size: usize,

    /// Longest input the model was trained on (BGE-M3 supports 8192
    /// tokens). Recorded only; nothing in this type enforces it. Falls back
    /// to 0 when the key is absent.
    #[serde(default)]
    pub max_position_embeddings: usize,
}

impl BgeM3Config {
    /// Parse `config.json` from disk.
    pub fn read(path: &Path) -> Result<Self> {
        let raw = fs::read_to_string(path)
            .map_err(|e| Error::embedder(format!("read {path:?}: {e}")))?;
        serde_json::from_str(&raw).map_err(Error::Serde)
    }
}

/// File layout of a validated BGE-M3 model directory together with its
/// parsed config. Only produced by [`BgeM3Loader::open`].
#[derive(Clone, Debug)]
pub struct BgeM3Manifest {
    /// The directory that was handed to [`BgeM3Loader::open`].
    pub model_dir: PathBuf,
    /// `<model_dir>/config.json`.
    pub config_path: PathBuf,
    /// `<model_dir>/tokenizer.json`.
    pub tokenizer_path: PathBuf,
    /// Chosen weights file: `model.safetensors` when present, otherwise
    /// `pytorch_model.bin`.
    pub weights_path: PathBuf,
    /// Parsed contents of `config_path`.
    pub config: BgeM3Config,
}

/// Loader for a BGE-M3 model directory.
///
/// [`BgeM3Loader::open`] checks that the expected files are present and
/// parses `config.json`, giving the daemon a clear "model files present and
/// consistent" check at startup. [`BgeM3Loader::into_embedder`] then builds
/// the live [`solo_core::Embedder`] from the resolved manifest.
#[derive(Debug, Clone)]
pub struct BgeM3Loader {
    /// Paths and parsed config resolved by `open`.
    manifest: BgeM3Manifest,
}

impl BgeM3Loader {
    /// Discover and validate the model files in `model_dir`. Returns Err
    /// with a clear message on missing files or unparseable config.
    ///
    /// Weights resolution order: `model.safetensors` (preferred) →
    /// `pytorch_model.bin` (legacy). We don't support sharded weights at
    /// this layer; if the user has a sharded export they'll need to
    /// pre-merge or wait for commit 1.4.b's loader to grow that path.
    pub fn open(model_dir: impl Into<PathBuf>) -> Result<Self> {
        let model_dir = model_dir.into();
        if !model_dir.is_dir() {
            return Err(Error::embedder(format!(
                "BGE-M3 model dir not found or not a directory: {model_dir:?}"
            )));
        }

        let config_path = model_dir.join(FILE_CONFIG);
        let tokenizer_path = model_dir.join(FILE_TOKENIZER);

        for (name, p) in [
            (FILE_CONFIG, &config_path),
            (FILE_TOKENIZER, &tokenizer_path),
        ] {
            if !p.is_file() {
                return Err(Error::embedder(format!(
                    "BGE-M3 model dir is missing required file `{name}`: \
                     expected at {p:?}"
                )));
            }
        }

        let safetensors = model_dir.join(FILE_WEIGHTS_SAFETENSORS);
        let pytorch = model_dir.join(FILE_WEIGHTS_PYTORCH);
        let weights_path = if safetensors.is_file() {
            safetensors
        } else if pytorch.is_file() {
            pytorch
        } else {
            return Err(Error::embedder(format!(
                "BGE-M3 model dir has neither `{FILE_WEIGHTS_SAFETENSORS}` nor \
                 `{FILE_WEIGHTS_PYTORCH}`: expected one in {model_dir:?}"
            )));
        };

        let config = BgeM3Config::read(&config_path)?;

        // Dimensional sanity: BGE-M3 published dim is 1024. We don't *force*
        // 1024 in case the user is experimenting with a fine-tuned variant
        // that genuinely changed hidden_size, but we do refuse a clearly
        // broken zero / suspiciously-small config.
        if config.hidden_size == 0 {
            return Err(Error::embedder(format!(
                "BGE-M3 config.json reports hidden_size=0 — corrupt or wrong file? at {config_path:?}"
            )));
        }

        Ok(Self {
            manifest: BgeM3Manifest {
                model_dir,
                config_path,
                tokenizer_path,
                weights_path,
                config,
            },
        })
    }

    pub fn manifest(&self) -> &BgeM3Manifest {
        &self.manifest
    }

    pub fn dim(&self) -> usize {
        self.manifest.config.hidden_size
    }

    /// Construct the live BGE-M3 embedder. Loads the model + tokenizer
    /// from the manifest's paths; returns a `Box<dyn Embedder>` ready to
    /// serve `embed` calls. Cost: 1-3 seconds (1.2 GB safetensors via
    /// mmap; tokenizer JSON parse).
    ///
    /// As of commit 1.4.b: only `model.safetensors` is supported; legacy
    /// `pytorch_model.bin` is rejected with a clear error.
    pub fn into_embedder(self) -> Result<Box<dyn solo_core::Embedder>> {
        let inference = crate::embedder::bge_m3_inference::BgeM3Inference::load(&self.manifest)?;
        Ok(Box::new(inference))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    fn write(path: &Path, body: &[u8]) {
        fs::write(path, body).unwrap();
    }

    /// Create a temp model directory holding `files`, each given as a
    /// `(basename, contents)` pair. Keep the returned guard alive for the
    /// whole test: dropping it deletes the directory.
    fn model_dir(files: &[(&str, &str)]) -> tempfile::TempDir {
        let tmp = tempfile::TempDir::new().unwrap();
        for &(name, body) in files {
            write(&tmp.path().join(name), body.as_bytes());
        }
        tmp
    }

    fn fake_config_json(hidden_size: usize) -> String {
        format!(
            r#"{{
              "model_type": "xlm-roberta",
              "hidden_size": {hidden_size},
              "vocab_size": 250002,
              "max_position_embeddings": 8194
            }}"#
        )
    }

    fn fake_tokenizer_json() -> &'static str {
        // Minimal JSON; real tokenizer.json is huge but we only validate
        // existence at this layer, not contents.
        r#"{"version": "1.0", "model": {}}"#
    }

    /// XLM-RoBERTa Config has more fields than our solo-side
    /// BgeM3Config. This fixture matches BAAI/bge-m3's published config.
    fn full_xlm_config_json(hidden_size: usize) -> String {
        format!(
            r#"{{
              "model_type": "xlm-roberta",
              "hidden_size": {hidden_size},
              "layer_norm_eps": 1e-5,
              "attention_probs_dropout_prob": 0.1,
              "hidden_dropout_prob": 0.1,
              "num_attention_heads": 16,
              "position_embedding_type": "absolute",
              "intermediate_size": 4096,
              "hidden_act": "gelu",
              "num_hidden_layers": 24,
              "vocab_size": 250002,
              "max_position_embeddings": 8194,
              "type_vocab_size": 1,
              "pad_token_id": 1
            }}"#
        )
    }

    #[test]
    fn open_validates_complete_safetensors_dir() {
        let tmp = model_dir(&[
            ("config.json", fake_config_json(1024).as_str()),
            ("tokenizer.json", fake_tokenizer_json()),
            ("model.safetensors", "<fake weights>"),
        ]);

        let loader = BgeM3Loader::open(tmp.path()).unwrap();
        assert_eq!(loader.dim(), 1024);
        assert_eq!(loader.manifest().config.model_type, "xlm-roberta");
        assert!(
            loader
                .manifest()
                .weights_path
                .ends_with("model.safetensors"),
            "expected safetensors path, got {:?}",
            loader.manifest().weights_path
        );
    }

    #[test]
    fn open_falls_back_to_pytorch_bin() {
        let tmp = model_dir(&[
            ("config.json", fake_config_json(1024).as_str()),
            ("tokenizer.json", fake_tokenizer_json()),
            ("pytorch_model.bin", "<fake>"),
        ]);

        let loader = BgeM3Loader::open(tmp.path()).unwrap();
        assert!(
            loader
                .manifest()
                .weights_path
                .ends_with("pytorch_model.bin")
        );
    }

    #[test]
    fn open_prefers_safetensors_when_both_weights_present() {
        // Pins the documented resolution order: safetensors wins over the
        // legacy pytorch export when a directory ships both.
        let tmp = model_dir(&[
            ("config.json", fake_config_json(1024).as_str()),
            ("tokenizer.json", fake_tokenizer_json()),
            ("model.safetensors", "<fake weights>"),
            ("pytorch_model.bin", "<fake>"),
        ]);

        let loader = BgeM3Loader::open(tmp.path()).unwrap();
        assert!(
            loader
                .manifest()
                .weights_path
                .ends_with("model.safetensors"),
            "expected safetensors to win, got {:?}",
            loader.manifest().weights_path
        );
    }

    #[test]
    fn open_rejects_missing_config() {
        let tmp = model_dir(&[
            ("tokenizer.json", fake_tokenizer_json()),
            ("model.safetensors", "<fake>"),
        ]);
        let err = BgeM3Loader::open(tmp.path()).unwrap_err();
        assert!(err.to_string().contains("config.json"), "got: {err}");
    }

    #[test]
    fn open_rejects_missing_tokenizer() {
        let tmp = model_dir(&[
            ("config.json", fake_config_json(1024).as_str()),
            ("model.safetensors", "<fake>"),
        ]);
        let err = BgeM3Loader::open(tmp.path()).unwrap_err();
        assert!(err.to_string().contains("tokenizer.json"), "got: {err}");
    }

    #[test]
    fn open_rejects_missing_weights() {
        let tmp = model_dir(&[
            ("config.json", fake_config_json(1024).as_str()),
            ("tokenizer.json", fake_tokenizer_json()),
        ]);
        let err = BgeM3Loader::open(tmp.path()).unwrap_err();
        assert!(
            err.to_string().contains("model.safetensors"),
            "got: {err}"
        );
    }

    #[test]
    fn open_rejects_unparseable_config() {
        let tmp = model_dir(&[
            ("config.json", "{ not json"),
            ("tokenizer.json", fake_tokenizer_json()),
            ("model.safetensors", "<fake>"),
        ]);
        assert!(BgeM3Loader::open(tmp.path()).is_err());
    }

    #[test]
    fn open_rejects_zero_hidden_size() {
        let tmp = model_dir(&[
            ("config.json", fake_config_json(0).as_str()),
            ("tokenizer.json", fake_tokenizer_json()),
            ("model.safetensors", "<fake>"),
        ]);
        let err = BgeM3Loader::open(tmp.path()).unwrap_err();
        assert!(err.to_string().contains("hidden_size=0"), "got: {err}");
    }

    #[test]
    fn open_rejects_nonexistent_dir() {
        let err = BgeM3Loader::open("/definitely/not/here/solo/bge-m3").unwrap_err();
        assert!(err.to_string().contains("not found"), "got: {err}");
    }

    #[test]
    fn open_rejects_dir_that_is_a_file() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let err = BgeM3Loader::open(tmp.path()).unwrap_err();
        assert!(err.to_string().contains("not a directory"), "got: {err}");
    }

    #[test]
    fn into_embedder_with_garbage_weights_fails_clearly() {
        // Fake weights (not real safetensors). into_embedder loads via
        // candle's mmap path; expect a clear error referencing
        // safetensors / VarBuilder. Real-weight smoke tests require
        // BAAI/bge-m3 predownloaded — out-of-band for CI.
        let tmp = model_dir(&[
            ("config.json", full_xlm_config_json(1024).as_str()),
            ("tokenizer.json", fake_tokenizer_json()),
            ("model.safetensors", "<fake>"),
        ]);

        let loader = BgeM3Loader::open(tmp.path()).unwrap();
        // `match` rather than `unwrap_err`: the Ok type is a boxed trait
        // object that may not implement Debug.
        let err = match loader.into_embedder() {
            Ok(_) => panic!("expected Err but got Ok"),
            Err(e) => e,
        };
        let msg = err.to_string();
        assert!(
            msg.contains("safetensors")
                || msg.contains("VarBuilder")
                || msg.contains("Tokenizer")
                || msg.contains("from_mmaped"),
            "expected safetensors/VarBuilder/Tokenizer error, got: {msg}"
        );
    }
}