#![forbid(unsafe_code)]
#![deny(rust_2018_idioms)]
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use atomr_infer_core::batch::ExecuteBatch;
use atomr_infer_core::error::{InferenceError, InferenceResult};
use atomr_infer_core::runner::{ModelRunner, RunHandle, SessionRebuildCause};
use atomr_infer_core::runtime::{RuntimeKind, TransportKind};
/// Settings used to construct a [`MistralRsRunner`].
///
/// Every field except `model_id` carries `#[serde(default)]`, so it may be
/// omitted from serialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MistralRsConfig {
/// Model identifier handed to `mistralrs::TextModelBuilder::new`.
pub model_id: String,
/// Optional quantization spec; parsed with `mistralrs::parse_isq_value`
/// when the model is built. An unparsable value fails the build.
#[serde(default)]
pub quant: Option<String>,
/// Optional revision forwarded to the builder via `with_hf_revision`.
#[serde(default)]
pub hf_revision: Option<String>,
/// When `true`, the builder is configured with `with_force_cpu`.
#[serde(default)]
pub force_cpu: bool,
/// Optional value forwarded to the builder via `with_max_num_seqs`.
#[serde(default)]
pub max_num_seqs: Option<usize>,
}
/// [`ModelRunner`] implementation backed by the `mistralrs` crate.
///
/// The model handle only exists when the crate is compiled with the
/// `mistralrs` feature; without it the runner stores just its configuration.
pub struct MistralRsRunner {
// Only read by feature-gated code, hence the conditional `dead_code` allowance.
#[cfg_attr(not(feature = "mistralrs"), allow(dead_code))]
config: MistralRsConfig,
// Starts empty and is filled on first use by `ensure_model`.
#[cfg(feature = "mistralrs")]
model: tokio::sync::OnceCell<std::sync::Arc<mistralrs::Model>>,
}
impl MistralRsRunner {
    /// Creates a runner from `config`.
    ///
    /// Nothing is loaded here: with the `mistralrs` feature enabled the model
    /// cell starts out empty and is filled on first use.
    pub fn new(config: MistralRsConfig) -> Self {
        Self {
            config,
            #[cfg(feature = "mistralrs")]
            model: tokio::sync::OnceCell::new(),
        }
    }

    /// Returns the shared model handle, building it on the first successful call.
    ///
    /// # Errors
    /// Propagates the [`InferenceError::Internal`] produced by the build step.
    #[cfg(feature = "mistralrs")]
    async fn ensure_model(&self) -> InferenceResult<std::sync::Arc<mistralrs::Model>> {
        let handle = self.model.get_or_try_init(|| self.build_model()).await?;
        Ok(std::sync::Arc::clone(handle))
    }

    /// Translates the stored configuration into a `TextModelBuilder` and builds it.
    #[cfg(feature = "mistralrs")]
    async fn build_model(&self) -> InferenceResult<std::sync::Arc<mistralrs::Model>> {
        let cfg = &self.config;
        let mut builder = mistralrs::TextModelBuilder::new(&cfg.model_id);

        if cfg.force_cpu {
            builder = builder.with_force_cpu();
        }
        if let Some(rev) = &cfg.hf_revision {
            builder = builder.with_hf_revision(rev.clone());
        }
        if let Some(max_seqs) = cfg.max_num_seqs {
            builder = builder.with_max_num_seqs(max_seqs);
        }
        if let Some(q) = &cfg.quant {
            let isq = match mistralrs::parse_isq_value(q, None) {
                Ok(parsed) => parsed,
                Err(e) => {
                    return Err(InferenceError::Internal(format!(
                        "mistralrs: bad quant '{q}': {e}"
                    )));
                }
            };
            builder = builder.with_isq(isq);
        }

        match builder.build().await {
            Ok(model) => Ok(std::sync::Arc::new(model)),
            Err(e) => Err(InferenceError::Internal(format!(
                "mistralrs: failed to build model '{}': {e}",
                cfg.model_id
            ))),
        }
    }
}
/// Converts a core chat role into the matching `mistralrs` text-message role.
///
/// Roles without a dedicated counterpart are sent as user turns.
#[cfg(feature = "mistralrs")]
fn map_role(role: atomr_infer_core::batch::Role) -> mistralrs::TextMessageRole {
    use atomr_infer_core::batch::Role as Source;
    use mistralrs::TextMessageRole as Target;

    match role {
        Source::User => Target::User,
        Source::Assistant => Target::Assistant,
        Source::System => Target::System,
        Source::Tool => Target::Tool,
        // Fallback for any other role variant.
        _ => Target::User,
    }
}
/// Flattens a message's content into plain text.
///
/// Plain-text content is returned as is. For multi-part content only the text
/// parts are kept, separated by a newline; other part kinds are skipped. Any
/// other content shape yields an empty string.
#[cfg(feature = "mistralrs")]
fn message_text(message: &atomr_infer_core::batch::Message) -> String {
    use atomr_infer_core::batch::{ContentPart, MessageContent};

    let parts = match &message.content {
        MessageContent::Text(s) => return s.clone(),
        MessageContent::Parts(parts) => parts,
        _ => return String::new(),
    };

    let mut joined = String::new();
    // Tracked separately from `joined.is_empty()` so an empty leading text
    // part still gets a separator after it.
    let mut wrote_part = false;
    for part in parts.iter() {
        if let ContentPart::Text { text } = part {
            if wrote_part {
                joined.push('\n');
            }
            joined.push_str(text.as_str());
            wrote_part = true;
        }
    }
    joined
}
#[async_trait]
impl ModelRunner for MistralRsRunner {
    /// Runs one chat request and returns a streaming [`RunHandle`].
    ///
    /// Without the `mistralrs` feature this always fails with
    /// [`InferenceError::Internal`]. With the feature enabled the model is
    /// built on first use, the batch's messages are converted into
    /// `mistralrs::TextMessages`, and a spawned task forwards each response
    /// over a bounded channel as a `TokenChunk`.
    ///
    /// # Errors
    /// Returns an error when the feature is disabled or the model cannot be
    /// built. Failures after streaming has started are delivered as `Err`
    /// items on the returned stream.
    // NOTE(review): only `request_id` and `messages` are read from the batch;
    // `sampling` is not forwarded to mistralrs. Confirm whether that is intended.
    #[cfg_attr(
        feature = "mistralrs",
        tracing::instrument(skip(self, _batch), fields(model = %self.config.model_id))
    )]
    async fn execute(&mut self, _batch: ExecuteBatch) -> InferenceResult<RunHandle> {
        #[cfg(not(feature = "mistralrs"))]
        {
            Err(InferenceError::Internal(
                "mistralrs feature disabled at build time — rebuild with --features mistralrs".into(),
            ))
        }
        #[cfg(feature = "mistralrs")]
        {
            use atomr_infer_core::tokens::{FinishReason, TokenChunk};
            use futures::StreamExt;

            /// Converts a token count to `u32`, saturating at `u32::MAX`
            /// instead of wrapping the way an `as` cast would.
            fn saturating_token_count(count: usize) -> u32 {
                <u32 as std::convert::TryFrom<usize>>::try_from(count).unwrap_or(u32::MAX)
            }

            let model = self.ensure_model().await?;
            let request_id = _batch.request_id.clone();

            let mut messages = mistralrs::TextMessages::new();
            for m in &_batch.messages {
                messages = messages.add_message(map_role(m.role), message_text(m));
            }

            // Bounded channel: the producer task waits once 64 chunks are pending.
            let (tx, rx) = tokio::sync::mpsc::channel::<InferenceResult<TokenChunk>>(64);
            let req_id_for_task = request_id.clone();

            tokio::spawn(async move {
                let mut stream = match model.stream_chat_request(messages).await {
                    Ok(s) => s,
                    Err(e) => {
                        let _ = tx
                            .send(Err(InferenceError::Internal(format!(
                                "mistralrs: stream_chat_request failed: {e}"
                            ))))
                            .await;
                        return;
                    }
                };

                while let Some(resp) = stream.next().await {
                    match resp {
                        mistralrs::Response::Chunk(chunk) => {
                            // Only the first choice is forwarded.
                            let choice = chunk.choices.first();
                            let text_delta =
                                choice.and_then(|c| c.delta.content.clone()).unwrap_or_default();
                            let finish_reason = choice
                                .and_then(|c| c.finish_reason.as_deref())
                                .map(map_finish_reason);
                            let chunk_out = TokenChunk {
                                request_id: req_id_for_task.clone(),
                                text_delta,
                                tool_call_delta: None,
                                usage: None,
                                finish_reason,
                            };
                            // A send error means the receiver is gone; stop producing.
                            if tx.send(Ok(chunk_out)).await.is_err() {
                                break;
                            }
                        }
                        mistralrs::Response::Done(full) => {
                            let usage = atomr_infer_core::tokens::TokenUsage {
                                input_tokens: saturating_token_count(full.usage.prompt_tokens),
                                output_tokens: saturating_token_count(full.usage.completion_tokens),
                                ..Default::default()
                            };
                            let finish_reason = full
                                .choices
                                .first()
                                .map(|c| map_finish_reason(c.finish_reason.as_str()));
                            // Final chunk: no text, only usage and a finish reason
                            // (`Stop` when the response has no choices).
                            let chunk_out = TokenChunk {
                                request_id: req_id_for_task.clone(),
                                text_delta: String::new(),
                                tool_call_delta: None,
                                usage: Some(usage),
                                finish_reason: finish_reason.or(Some(FinishReason::Stop)),
                            };
                            let _ = tx.send(Ok(chunk_out)).await;
                            break;
                        }
                        mistralrs::Response::ModelError(msg, _partial) => {
                            let _ = tx
                                .send(Err(InferenceError::Internal(format!(
                                    "mistralrs model error: {msg}"
                                ))))
                                .await;
                            break;
                        }
                        mistralrs::Response::InternalError(e)
                        | mistralrs::Response::ValidationError(e) => {
                            let _ = tx
                                .send(Err(InferenceError::Internal(format!("mistralrs error: {e}"))))
                                .await;
                            break;
                        }
                        _ => {
                            let _ = tx
                                .send(Err(InferenceError::Internal(
                                    "mistralrs: unexpected response variant".into(),
                                )))
                                .await;
                            break;
                        }
                    }
                }
            });

            let stream = tokio_stream::wrappers::ReceiverStream::new(rx).boxed();
            Ok(RunHandle::streaming(stream))
        }
    }

    /// Discards the cached model for causes that require a fresh one.
    ///
    /// With the `mistralrs` feature enabled, `CudaContextPoisoned` and `Manual`
    /// replace the model cell with an empty one, so the next `execute` builds
    /// the model again. Every other cause, and every cause when the feature is
    /// disabled, is a no-op. Always returns `Ok(())`.
    async fn rebuild_session(&mut self, cause: SessionRebuildCause) -> InferenceResult<()> {
        #[cfg(feature = "mistralrs")]
        {
            if matches!(
                cause,
                SessionRebuildCause::CudaContextPoisoned | SessionRebuildCause::Manual
            ) {
                self.model = tokio::sync::OnceCell::new();
            }
        }
        // Keeps `cause` used when the feature-gated block is compiled out.
        let _ = cause;
        Ok(())
    }

    /// Identifies this runner as the mistral.rs runtime.
    fn runtime_kind(&self) -> RuntimeKind {
        RuntimeKind::MistralRs
    }

    /// Reports the transport kind as `LocalGpu`.
    fn transport_kind(&self) -> TransportKind {
        TransportKind::LocalGpu
    }
}
/// Maps a finish-reason string reported by mistralrs onto the core
/// `FinishReason` enum.
///
/// Unrecognised strings are treated as a normal stop.
#[cfg(feature = "mistralrs")]
fn map_finish_reason(s: &str) -> atomr_infer_core::tokens::FinishReason {
    use atomr_infer_core::tokens::FinishReason as Reason;

    if s == "stop" {
        return Reason::Stop;
    }
    if s == "length" {
        return Reason::Length;
    }
    if s == "tool_calls" {
        return Reason::ToolCalls;
    }
    if s == "content_filter" {
        return Reason::ContentFilter;
    }
    // Fallback for any other value.
    Reason::Stop
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal configuration for tests that only need a runner instance.
    fn bare_config() -> MistralRsConfig {
        MistralRsConfig {
            model_id: "test".into(),
            quant: None,
            hf_revision: None,
            force_cpu: false,
            max_num_seqs: None,
        }
    }

    /// Serializing to JSON and back keeps the checked fields intact.
    #[test]
    fn config_round_trips_through_serde() {
        let original = MistralRsConfig {
            model_id: "mistralai/Mistral-7B-Instruct-v0.3".into(),
            quant: Some("Q4K".into()),
            hf_revision: None,
            force_cpu: false,
            max_num_seqs: Some(16),
        };

        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded: MistralRsConfig = serde_json::from_str(&encoded).expect("deserialize");

        assert_eq!(decoded.model_id, original.model_id);
        assert_eq!(decoded.quant, original.quant);
        assert_eq!(decoded.max_num_seqs, original.max_num_seqs);
    }

    /// The runner reports its fixed runtime and transport kinds.
    #[test]
    fn runner_reports_runtime_kind() {
        let runner = MistralRsRunner::new(bare_config());

        assert_eq!(runner.runtime_kind(), RuntimeKind::MistralRs);
        assert_eq!(runner.transport_kind(), TransportKind::LocalGpu);
    }

    /// Without the `mistralrs` feature, `execute` fails with an internal error.
    #[cfg(not(feature = "mistralrs"))]
    #[tokio::test]
    async fn execute_without_feature_returns_internal_error() {
        use atomr_infer_core::batch::SamplingParams;

        let mut runner = MistralRsRunner::new(bare_config());
        let request = ExecuteBatch {
            request_id: "test".into(),
            model: "test".into(),
            messages: vec![],
            sampling: SamplingParams::default(),
            stream: false,
            estimated_tokens: 1,
        };

        let outcome = runner.execute(request).await;
        assert!(matches!(outcome, Err(InferenceError::Internal(_))));
    }
}