use std::sync::Mutex;
use anyhow::{anyhow, Context, Result};
use mistralrs::{
IsqType, RequestBuilder, TextMessageRole, TextModelBuilder,
};
use super::{Backend, GenAnswer, GenRequest, Generator};
/// Model identifier substituted by `MistralRsGenerator::new` when the caller
/// passes an empty model string.
const DEFAULT_MODEL: &str = "Qwen/Qwen2.5-0.5B-Instruct";
/// Text generator that runs a `mistralrs` model in-process.
///
/// The struct owns its own Tokio runtime so the async `mistralrs` calls can be
/// driven from the synchronous `Backend` / `Generator` trait methods. The model
/// is not built at construction time; the slot starts out as `None` and is
/// filled in by `complete`.
pub struct MistralRsGenerator {
/// Backend identifier, formatted as `mistralrs:<model_id>`.
id: String,
/// Model identifier handed to `TextModelBuilder` when the model is built.
model_id: String,
/// Current-thread Tokio runtime used to `block_on` the async model calls.
runtime: tokio::runtime::Runtime,
/// Lazily built model; `None` until the first `complete` call builds it.
model: Mutex<Option<mistralrs::Model>>,
}
impl MistralRsGenerator {
    /// Creates a generator for `model`, falling back to [`DEFAULT_MODEL`] when
    /// `model` is the empty string.
    ///
    /// Only the Tokio runtime is created here. The model slot is left empty and
    /// is populated later by [`Generator::complete`].
    ///
    /// # Errors
    ///
    /// Returns an error if the current-thread Tokio runtime cannot be built.
    pub fn new(model: &str) -> Result<Self> {
        let model_id = match model {
            "" => DEFAULT_MODEL,
            named => named,
        }
        .to_string();

        let mut runtime_builder = tokio::runtime::Builder::new_current_thread();
        runtime_builder.enable_all();
        let runtime = runtime_builder
            .build()
            .context("mistralrs: build tokio runtime")?;

        // The id has to be formatted before `model_id` is moved into the struct.
        let id = format!("mistralrs:{model_id}");
        Ok(Self {
            id,
            model_id,
            runtime,
            model: Mutex::new(None),
        })
    }

    /// Builds the text model named `model_id`, requesting `IsqType::Q4K`
    /// quantization and enabling the builder's logging.
    ///
    /// # Errors
    ///
    /// Returns an error naming `model_id` when the builder fails.
    async fn build_model(model_id: &str) -> Result<mistralrs::Model> {
        let pending = TextModelBuilder::new(model_id)
            .with_isq(IsqType::Q4K)
            .with_logging();
        match pending.build().await {
            Ok(model) => Ok(model),
            Err(e) => Err(anyhow!("mistralrs: build model `{model_id}`: {e}")),
        }
    }
}
impl Backend for MistralRsGenerator {
    /// Returns the backend identifier, formatted as `mistralrs:<model id>`.
    fn id(&self) -> &str {
        &self.id
    }

    /// Reports whether this backend can be used; always `true`.
    ///
    /// A successfully constructed generator already owns its runtime, and the
    /// model is only built on the first `complete` call, so there is nothing
    /// further to probe here. The runtime is deliberately not entered: blocking
    /// on it just to produce a constant could only add a panic when this method
    /// is called from inside an async context.
    fn available(&self) -> bool {
        true
    }
}
impl Generator for MistralRsGenerator {
    /// Runs one chat completion for `req` and returns the text plus usage stats.
    ///
    /// The model is built on the first call and cached for later calls. The
    /// model lock is held for the whole request, so concurrent callers are
    /// served one at a time. The reported latency starts before the (possible)
    /// model build, so the first call's `latency_ms` and `tokens_per_s` include
    /// that build time.
    ///
    /// This method blocks on the generator's own Tokio runtime and is meant to
    /// be called from synchronous code.
    ///
    /// # Errors
    ///
    /// Returns an error if the model mutex is poisoned, the model cannot be
    /// built, the chat request fails, or the response contains no choices.
    fn complete(&self, req: &GenRequest) -> Result<GenAnswer> {
        let started = std::time::Instant::now();

        // A poisoned lock is reported as an error rather than a panic, in line
        // with every other failure path of this function.
        let mut guard = self
            .model
            .lock()
            .map_err(|_| anyhow!("mistralrs: model mutex poisoned"))?;
        if guard.is_none() {
            let built = self.runtime.block_on(Self::build_model(&self.model_id))?;
            *guard = Some(built);
        }
        let model = guard.as_ref().expect("model set above");

        let mut builder = RequestBuilder::new()
            .set_sampler_max_len(req.max_tokens)
            .set_sampler_temperature(req.temperature as f64);
        // The optional system prompt goes first so it precedes the user turn.
        if let Some(sys) = &req.system {
            builder = builder.add_message(TextMessageRole::System, sys.clone());
        }
        builder = builder.add_message(TextMessageRole::User, req.prompt.clone());

        let response = self
            .runtime
            .block_on(model.send_chat_request(builder))
            .map_err(|e| anyhow!("mistralrs: chat request: {e}"))?;
        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow!("mistralrs: empty choices"))?;
        // A choice without content is treated as an empty answer.
        let text = choice.message.content.clone().unwrap_or_default();

        let elapsed_s = started.elapsed().as_secs_f64();
        let latency_ms = elapsed_s * 1000.0;
        let tokens_in = response.usage.prompt_tokens as i64;
        let tokens_out = response.usage.completion_tokens as i64;
        // Guard against a zero-length measurement to avoid dividing by zero.
        let tokens_per_s = if elapsed_s > 0.0 {
            tokens_out as f64 / elapsed_s
        } else {
            0.0
        };

        Ok(GenAnswer { text, tokens_in, tokens_out, tokens_per_s, latency_ms })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: locals are named `generator`, not `gen`, because `gen` is a
    // reserved keyword starting with the Rust 2024 edition.

    /// An empty model name falls back to the default model id.
    #[test]
    fn default_model_when_empty() {
        let generator = MistralRsGenerator::new("").unwrap();
        assert_eq!(generator.id(), "mistralrs:Qwen/Qwen2.5-0.5B-Instruct");
    }

    /// Construction and the availability probe must not require a model build.
    #[test]
    fn constructs_and_reports_availability_without_loading() {
        let generator = MistralRsGenerator::new("Qwen/Qwen2.5-0.5B-Instruct").unwrap();
        assert_eq!(generator.id(), "mistralrs:Qwen/Qwen2.5-0.5B-Instruct");
        assert!(generator.available());
    }

    /// End-to-end generation against a real model; ignored by default.
    #[test]
    #[ignore = "downloads + builds a real model"]
    fn real_generation_round_trips() {
        let generator = MistralRsGenerator::new("Qwen/Qwen2.5-0.5B-Instruct").unwrap();
        let req = GenRequest::new("Reply with the single word: pong").with_max_tokens(8);
        let ans = generator.complete(&req).unwrap();
        assert!(ans.tokens_out >= 0);
        assert!(!ans.text.is_empty());
    }
}