nornir 0.4.34

Companion to cargo: dependency tracking, release gating, deploy, benchmarks, and documentation assembly. Project-agnostic.
//! `mistralrs` generative backend (`gen-mistralrs`) — in-process generation via
//! the candle-based [`mistralrs`] crate.
//!
//! `mistralrs` builds a [`Model`] from a HF model id (it fetches weights +
//! tokenizer itself) and runs chat completions in-process — no daemon, no HTTP.
//! The generator wraps that `Model` behind the [`Generator`] trait. Because
//! `mistralrs` is async, the backend drives it on a small current-thread tokio
//! runtime so [`complete`](Generator::complete) keeps the trait's sync shape.
//!
//! ## Model spec
//! `mistralrs:<hf-model-id>` (e.g. `mistralrs:Qwen/Qwen2.5-0.5B-Instruct`). An
//! empty model uses a small default. The heavy build+fetch happens lazily on the
//! first `complete`, so constructing the generator + probing `available()` is
//! cheap.
//!
//! ## `available()`
//! Reports `true` when a trivial future can be driven on the tokio runtime that
//! `new` already built (the only hard runtime prerequisite the in-process engine
//! needs to start; `new` itself fails if that runtime cannot be created). Model
//! weights are fetched lazily on first use, so a not-yet-downloaded model still
//! constructs and generates on demand.

use std::sync::Mutex;

use anyhow::{anyhow, Context, Result};
use mistralrs::{
    IsqType, RequestBuilder, TextMessageRole, TextModelBuilder,
};

use super::{Backend, GenAnswer, GenRequest, Generator};

/// Default HF model id when the spec is `mistralrs:` (a small instruct model).
///
/// `MistralRsGenerator::new` substitutes this value when it is handed an empty
/// model string.
const DEFAULT_MODEL: &str = "Qwen/Qwen2.5-0.5B-Instruct";

/// The mistralrs generator. Holds the HF model id + a lazily-built
/// [`mistralrs::Model`] (behind a mutex so the generator is `Sync`) + the tokio
/// runtime that drives the async engine.
pub struct MistralRsGenerator {
    /// Backend identifier, formatted as `mistralrs:<model_id>` at construction.
    id: String,
    /// HF model id handed to the model builder on the first `complete`.
    model_id: String,
    /// Current-thread tokio runtime, built in `new`, used to `block_on` the
    /// async model build and chat requests.
    runtime: tokio::runtime::Runtime,
    /// Lazily-built model: `None` until the first `complete` fills it in.
    model: Mutex<Option<mistralrs::Model>>,
}

impl MistralRsGenerator {
    /// Create a generator for the HF model id `model`; an empty string selects
    /// [`DEFAULT_MODEL`].
    ///
    /// Only the current-thread tokio runtime is created here. The model slot
    /// starts out empty and is filled in by the first `complete`.
    ///
    /// # Errors
    /// Fails when the tokio runtime cannot be built.
    pub fn new(model: &str) -> Result<Self> {
        let model_id = match model {
            "" => DEFAULT_MODEL,
            explicit => explicit,
        }
        .to_string();
        let id = format!("mistralrs:{model_id}");

        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .context("mistralrs: build tokio runtime")?;

        Ok(Self { id, model_id, runtime, model: Mutex::new(None) })
    }

    /// Construct the in-process model for `model_id` (per the module docs, this
    /// is the step where weights + tokenizer are fetched).
    ///
    /// The builder is configured with `IsqType::Q4K` and logging enabled.
    ///
    /// # Errors
    /// Any builder failure is re-wrapped with the offending model id.
    async fn build_model(model_id: &str) -> Result<mistralrs::Model> {
        let builder = TextModelBuilder::new(model_id)
            .with_isq(IsqType::Q4K)
            .with_logging();
        match builder.build().await {
            Ok(model) => Ok(model),
            Err(e) => Err(anyhow!("mistralrs: build model `{model_id}`: {e}")),
        }
    }
}

impl Backend for MistralRsGenerator {
    /// Backend identifier (`mistralrs:<model-id>`), fixed at construction.
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Readiness probe: drives a trivially-true future on the runtime built in
    /// `new` and returns its result.
    ///
    /// The model slot is never touched, so this reports readiness-to-start
    /// rather than model presence; a failed model build surfaces later as a
    /// `complete` error.
    ///
    /// NOTE(review): `block_on` blocks the calling thread until the future
    /// resolves, and tokio documents it as panicking when called from within an
    /// asynchronous execution context — confirm callers only probe from sync
    /// code.
    fn available(&self) -> bool {
        let handle = self.runtime.handle();
        handle.block_on(async { true })
    }
}

impl Generator for MistralRsGenerator {
    fn complete(&self, req: &GenRequest) -> Result<GenAnswer> {
        let started = std::time::Instant::now();
        let model_id = self.model_id.clone();

        let mut guard = self.model.lock().expect("mistralrs model mutex");
        if guard.is_none() {
            let built = self.runtime.block_on(Self::build_model(&model_id))?;
            *guard = Some(built);
        }
        let model = guard.as_ref().expect("model set above");

        // Build the chat request: optional system turn + the user prompt.
        let mut builder = RequestBuilder::new()
            .set_sampler_max_len(req.max_tokens)
            .set_sampler_temperature(req.temperature as f64);
        if let Some(sys) = &req.system {
            builder = builder.add_message(TextMessageRole::System, sys.clone());
        }
        builder = builder.add_message(TextMessageRole::User, req.prompt.clone());

        let response = self
            .runtime
            .block_on(model.send_chat_request(builder))
            .map_err(|e| anyhow!("mistralrs: chat request: {e}"))?;

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow!("mistralrs: empty choices"))?;
        let text = choice.message.content.clone().unwrap_or_default();

        let latency_ms = started.elapsed().as_secs_f64() * 1000.0;
        let tokens_in = response.usage.prompt_tokens as i64;
        let tokens_out = response.usage.completion_tokens as i64;
        let tokens_per_s = if latency_ms > 0.0 {
            tokens_out as f64 / (latency_ms / 1000.0)
        } else {
            0.0
        };
        Ok(GenAnswer { text, tokens_in, tokens_out, tokens_per_s, latency_ms })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Model id the tests pass explicitly.
    const MODEL: &str = "Qwen/Qwen2.5-0.5B-Instruct";
    /// Backend id expected for that model (spelled out so the tests pin the
    /// literal value rather than re-deriving it).
    const EXPECTED_ID: &str = "mistralrs:Qwen/Qwen2.5-0.5B-Instruct";

    /// An empty model string falls back to the default model id.
    #[test]
    fn default_model_when_empty() {
        let generator = MistralRsGenerator::new("").unwrap();
        assert_eq!(generator.id(), EXPECTED_ID);
    }

    /// Construction and the availability probe work without building a model.
    #[test]
    fn constructs_and_reports_availability_without_loading() {
        let generator = MistralRsGenerator::new(MODEL).unwrap();
        assert_eq!(generator.id(), EXPECTED_ID);
        // The runtime exists, so the readiness probe is true; no model is built.
        assert!(generator.available());
    }

    /// Heavy: builds the real in-process model (network + GBs) and generates.
    #[test]
    #[ignore = "downloads + builds a real model"]
    fn real_generation_round_trips() {
        let generator = MistralRsGenerator::new(MODEL).unwrap();
        let request = GenRequest::new("Reply with the single word: pong").with_max_tokens(8);
        let answer = generator.complete(&request).unwrap();
        assert!(answer.tokens_out >= 0);
        assert!(!answer.text.is_empty());
    }
}