Skip to main content

atomr_infer_runtime_mistralrs/
lib.rs

1//! # inference-runtime-mistralrs
2//!
3//! `mistralrs` runner for atomr-infer. Wraps `mistralrs::Model` +
4//! `mistralrs::TextModelBuilder` behind the `ModelRunner` trait so
5//! Mistral.rs participates in the same `Deployment` actor topology as
6//! the OpenAI / Anthropic / vLLM / TensorRT runners. Doc §10.3.
7//!
8//! The model is loaded lazily on the first call to
9//! `ModelRunner::execute` (mistralrs's builder downloads from
10//! HuggingFace, which can take minutes for 7B+ models — eager loading
11//! would block the runner's constructor for too long).
12//!
//! With default features off, the crate compiles to a typed-error stub;
//! `cargo build --features remote-only` therefore pulls no candle /
//! cuda dependencies via this crate.
16//!
17//! ## MSRV note
18//!
19//! mistralrs 0.8 declares MSRV 1.88. The atomr-infer workspace MSRV
20//! is 1.78 for remote-only builds; operators enabling this runner
21//! need a toolchain that satisfies mistralrs's own MSRV.
22
23#![forbid(unsafe_code)]
24#![deny(rust_2018_idioms)]
25
26use async_trait::async_trait;
27use serde::{Deserialize, Serialize};
28
29use atomr_infer_core::batch::ExecuteBatch;
30use atomr_infer_core::error::{InferenceError, InferenceResult};
31use atomr_infer_core::runner::{ModelRunner, RunHandle, SessionRebuildCause};
32use atomr_infer_core::runtime::{RuntimeKind, TransportKind};
33
/// Configuration for [`MistralRsRunner`].
///
/// Only `model_id` is required when deserialising; every other field is
/// marked `#[serde(default)]`. The values are applied to
/// `mistralrs::TextModelBuilder` when the runner loads its model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MistralRsConfig {
    /// HuggingFace repo id, e.g. `"mistralai/Mistral-7B-Instruct-v0.3"`.
    /// Also used in the runner's error messages and tracing span.
    pub model_id: String,
    /// Optional in-situ quantisation. Passed verbatim to
    /// `mistralrs::parse_isq_value`; e.g. `"Q4K"` / `"Q8_0"`.
    /// A value that fails to parse makes the model load fail.
    #[serde(default)]
    pub quant: Option<String>,
    /// Optional HuggingFace revision (branch / tag / commit).
    #[serde(default)]
    pub hf_revision: Option<String>,
    /// Force CPU execution (skip CUDA / Metal device probing).
    #[serde(default)]
    pub force_cpu: bool,
    /// Maximum concurrent sequences the engine schedules. Defaults
    /// to the mistralrs builder default (32). When `None`, the
    /// builder's own setting is left untouched.
    #[serde(default)]
    pub max_num_seqs: Option<usize>,
}
53
/// [`ModelRunner`] implementation backed by a `mistralrs` text model.
///
/// The model handle lives in a `tokio::sync::OnceCell` that starts empty, so
/// constructing the runner does not load anything.
pub struct MistralRsRunner {
    /// Settings used to build the model. Only read by code gated on the
    /// `mistralrs` feature, hence the `dead_code` allowance for stub builds.
    #[cfg_attr(not(feature = "mistralrs"), allow(dead_code))]
    config: MistralRsConfig,
    /// Lazily initialised, shareable model handle. The field does not exist
    /// when the `mistralrs` feature is disabled.
    #[cfg(feature = "mistralrs")]
    model: tokio::sync::OnceCell<std::sync::Arc<mistralrs::Model>>,
}
60
61impl MistralRsRunner {
62    pub fn new(config: MistralRsConfig) -> Self {
63        Self {
64            config,
65            #[cfg(feature = "mistralrs")]
66            model: tokio::sync::OnceCell::new(),
67        }
68    }
69
70    #[cfg(feature = "mistralrs")]
71    async fn ensure_model(&self) -> InferenceResult<std::sync::Arc<mistralrs::Model>> {
72        self.model
73            .get_or_try_init(|| async {
74                let mut builder = mistralrs::TextModelBuilder::new(&self.config.model_id);
75                if self.config.force_cpu {
76                    builder = builder.with_force_cpu();
77                }
78                if let Some(rev) = &self.config.hf_revision {
79                    builder = builder.with_hf_revision(rev.clone());
80                }
81                if let Some(max_seqs) = self.config.max_num_seqs {
82                    builder = builder.with_max_num_seqs(max_seqs);
83                }
84                if let Some(q) = &self.config.quant {
85                    let isq = mistralrs::parse_isq_value(q, None)
86                        .map_err(|e| InferenceError::Internal(format!("mistralrs: bad quant '{q}': {e}")))?;
87                    builder = builder.with_isq(isq);
88                }
89                let model = builder.build().await.map_err(|e| {
90                    InferenceError::Internal(format!(
91                        "mistralrs: failed to build model '{}': {e}",
92                        self.config.model_id
93                    ))
94                })?;
95                Ok(std::sync::Arc::new(model))
96            })
97            .await
98            .cloned()
99    }
100}
101
/// Converts an atomr-infer chat role into the matching mistralrs role.
///
/// Roles without a direct counterpart are sent as `User`.
#[cfg(feature = "mistralrs")]
fn map_role(role: atomr_infer_core::batch::Role) -> mistralrs::TextMessageRole {
    use atomr_infer_core::batch::Role as Source;
    use mistralrs::TextMessageRole as Target;

    match role {
        Source::Assistant => Target::Assistant,
        Source::System => Target::System,
        Source::Tool => Target::Tool,
        Source::User => Target::User,
        // `Role` is `#[non_exhaustive]`: anything this crate does not know
        // about is forwarded as `User` so the request still reaches the model.
        _ => Target::User,
    }
}
115
/// Flattens a message's content into plain text.
///
/// Plain-text content is returned as is. Multi-part content keeps only its
/// text parts, joined with `'\n'`; all other parts are skipped. Content
/// variants unknown to this crate yield an empty string.
#[cfg(feature = "mistralrs")]
fn message_text(message: &atomr_infer_core::batch::Message) -> String {
    use atomr_infer_core::batch::{ContentPart, MessageContent};

    let parts = match &message.content {
        MessageContent::Text(whole) => return whole.clone(),
        MessageContent::Parts(parts) => parts,
        // Forward-compat: drop unknown variants.
        _ => return String::new(),
    };

    let mut joined = String::new();
    // Tracked separately from `joined.is_empty()` so that an empty leading
    // text part still gets a separator after it.
    let mut wrote_any = false;
    for part in parts.iter() {
        if let ContentPart::Text { text } = part {
            if wrote_any {
                joined.push('\n');
            }
            joined.push_str(text.as_str());
            wrote_any = true;
        }
    }
    joined
}
133
134#[async_trait]
135impl ModelRunner for MistralRsRunner {
136    #[cfg_attr(
137        feature = "mistralrs",
138        tracing::instrument(skip(self, _batch), fields(model = %self.config.model_id))
139    )]
140    async fn execute(&mut self, _batch: ExecuteBatch) -> InferenceResult<RunHandle> {
141        #[cfg(not(feature = "mistralrs"))]
142        {
143            Err(InferenceError::Internal(
144                "mistralrs feature disabled at build time — rebuild with --features mistralrs".into(),
145            ))
146        }
147        #[cfg(feature = "mistralrs")]
148        {
149            use atomr_infer_core::tokens::{FinishReason, TokenChunk};
150            use futures::StreamExt;
151
152            let model = self.ensure_model().await?;
153            let request_id = _batch.request_id.clone();
154
155            let mut messages = mistralrs::TextMessages::new();
156            for m in &_batch.messages {
157                messages = messages.add_message(map_role(m.role), message_text(m));
158            }
159
160            let (tx, rx) = tokio::sync::mpsc::channel::<InferenceResult<TokenChunk>>(64);
161            let req_id_for_task = request_id.clone();
162            tokio::spawn(async move {
163                let mut stream = match model.stream_chat_request(messages).await {
164                    Ok(s) => s,
165                    Err(e) => {
166                        let _ = tx
167                            .send(Err(InferenceError::Internal(format!(
168                                "mistralrs: stream_chat_request failed: {e}"
169                            ))))
170                            .await;
171                        return;
172                    }
173                };
174                while let Some(resp) = stream.next().await {
175                    match resp {
176                        mistralrs::Response::Chunk(chunk) => {
177                            let choice = chunk.choices.first();
178                            let text_delta = choice.and_then(|c| c.delta.content.clone()).unwrap_or_default();
179                            let finish_reason = choice
180                                .and_then(|c| c.finish_reason.as_deref())
181                                .map(map_finish_reason);
182                            let chunk_out = TokenChunk {
183                                request_id: req_id_for_task.clone(),
184                                text_delta,
185                                tool_call_delta: None,
186                                usage: None,
187                                finish_reason,
188                            };
189                            if tx.send(Ok(chunk_out)).await.is_err() {
190                                break;
191                            }
192                        }
193                        mistralrs::Response::Done(full) => {
194                            let usage = atomr_infer_core::tokens::TokenUsage {
195                                input_tokens: full.usage.prompt_tokens as u32,
196                                output_tokens: full.usage.completion_tokens as u32,
197                                ..Default::default()
198                            };
199                            let finish_reason = full
200                                .choices
201                                .first()
202                                .map(|c| map_finish_reason(c.finish_reason.as_str()));
203                            let chunk_out = TokenChunk {
204                                request_id: req_id_for_task.clone(),
205                                text_delta: String::new(),
206                                tool_call_delta: None,
207                                usage: Some(usage),
208                                finish_reason: finish_reason.or(Some(FinishReason::Stop)),
209                            };
210                            let _ = tx.send(Ok(chunk_out)).await;
211                            break;
212                        }
213                        mistralrs::Response::ModelError(msg, _partial) => {
214                            let _ = tx
215                                .send(Err(InferenceError::Internal(format!(
216                                    "mistralrs model error: {msg}"
217                                ))))
218                                .await;
219                            break;
220                        }
221                        mistralrs::Response::InternalError(e) | mistralrs::Response::ValidationError(e) => {
222                            let _ = tx
223                                .send(Err(InferenceError::Internal(format!("mistralrs error: {e}"))))
224                                .await;
225                            break;
226                        }
227                        // Other variants (Completion*, ImageGeneration, Speech, Raw,
228                        // Embeddings) are not produced by stream_chat_request on a
229                        // text model; if the engine ever surfaces one, drop the
230                        // stream rather than silently corrupting the token
231                        // sequence.
232                        _ => {
233                            let _ = tx
234                                .send(Err(InferenceError::Internal(
235                                    "mistralrs: unexpected response variant".into(),
236                                )))
237                                .await;
238                            break;
239                        }
240                    }
241                }
242            });
243
244            let stream = tokio_stream::wrappers::ReceiverStream::new(rx).boxed();
245            Ok(RunHandle::streaming(stream))
246        }
247    }
248
249    async fn rebuild_session(&mut self, cause: SessionRebuildCause) -> InferenceResult<()> {
250        #[cfg(feature = "mistralrs")]
251        {
252            // CUDA poisoning or a manual rebuild forces a fresh model
253            // load (and re-download if HF cache is gone). Auth /
254            // config-change causes are remote-only and ignored here.
255            if matches!(
256                cause,
257                SessionRebuildCause::CudaContextPoisoned | SessionRebuildCause::Manual
258            ) {
259                self.model = tokio::sync::OnceCell::new();
260            }
261        }
262        let _ = cause;
263        Ok(())
264    }
265
266    fn runtime_kind(&self) -> RuntimeKind {
267        RuntimeKind::MistralRs
268    }
269    fn transport_kind(&self) -> TransportKind {
270        TransportKind::LocalGpu
271    }
272}
273
/// Maps the engine's finish-reason label onto `FinishReason`.
///
/// `"length"`, `"tool_calls"` and `"content_filter"` have dedicated
/// variants; `"stop"` and every unrecognised label become `Stop`.
#[cfg(feature = "mistralrs")]
fn map_finish_reason(s: &str) -> atomr_infer_core::tokens::FinishReason {
    use atomr_infer_core::tokens::FinishReason as Reason;

    match s {
        "length" => Reason::Length,
        "tool_calls" => Reason::ToolCalls,
        "content_filter" => Reason::ContentFilter,
        // Covers the literal "stop" as well as any label not listed above.
        _ => Reason::Stop,
    }
}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with only `model_id` set; every optional knob is left off.
    fn bare_config(model_id: &str) -> MistralRsConfig {
        MistralRsConfig {
            model_id: model_id.into(),
            quant: None,
            hf_revision: None,
            force_cpu: false,
            max_num_seqs: None,
        }
    }

    /// Serialising to JSON and back keeps the populated fields intact.
    #[test]
    fn config_round_trips_through_serde() {
        let original = MistralRsConfig {
            quant: Some("Q4K".into()),
            max_num_seqs: Some(16),
            ..bare_config("mistralai/Mistral-7B-Instruct-v0.3")
        };

        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded: MistralRsConfig = serde_json::from_str(&encoded).expect("deserialize");

        assert_eq!(decoded.model_id, original.model_id);
        assert_eq!(decoded.quant, original.quant);
        assert_eq!(decoded.max_num_seqs, original.max_num_seqs);
    }

    /// The runner advertises the mistral.rs runtime on a local-GPU transport.
    #[test]
    fn runner_reports_runtime_kind() {
        let runner = MistralRsRunner::new(bare_config("test"));

        assert_eq!(runner.runtime_kind(), RuntimeKind::MistralRs);
        assert_eq!(runner.transport_kind(), TransportKind::LocalGpu);
    }

    /// Without the `mistralrs` feature, `execute` fails with an internal error.
    #[cfg(not(feature = "mistralrs"))]
    #[tokio::test]
    async fn execute_without_feature_returns_internal_error() {
        use atomr_infer_core::batch::SamplingParams;

        let batch = ExecuteBatch {
            request_id: "test".into(),
            model: "test".into(),
            messages: vec![],
            sampling: SamplingParams::default(),
            stream: false,
            estimated_tokens: 1,
        };
        let mut runner = MistralRsRunner::new(bare_config("test"));

        let result = runner.execute(batch).await;

        assert!(matches!(result, Err(InferenceError::Internal(_))));
    }
}