atomr_infer_runtime_mistralrs/
lib.rs1#![forbid(unsafe_code)]
24#![deny(rust_2018_idioms)]
25
26use async_trait::async_trait;
27use serde::{Deserialize, Serialize};
28
29use atomr_infer_core::batch::ExecuteBatch;
30use atomr_infer_core::error::{InferenceError, InferenceResult};
31use atomr_infer_core::runner::{ModelRunner, RunHandle, SessionRebuildCause};
32use atomr_infer_core::runtime::{RuntimeKind, TransportKind};
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct MistralRsConfig {
36 pub model_id: String,
38 #[serde(default)]
41 pub quant: Option<String>,
42 #[serde(default)]
44 pub hf_revision: Option<String>,
45 #[serde(default)]
47 pub force_cpu: bool,
48 #[serde(default)]
51 pub max_num_seqs: Option<usize>,
52}
53
54pub struct MistralRsRunner {
55 #[cfg_attr(not(feature = "mistralrs"), allow(dead_code))]
56 config: MistralRsConfig,
57 #[cfg(feature = "mistralrs")]
58 model: tokio::sync::OnceCell<std::sync::Arc<mistralrs::Model>>,
59}
60
61impl MistralRsRunner {
62 pub fn new(config: MistralRsConfig) -> Self {
63 Self {
64 config,
65 #[cfg(feature = "mistralrs")]
66 model: tokio::sync::OnceCell::new(),
67 }
68 }
69
70 #[cfg(feature = "mistralrs")]
71 async fn ensure_model(&self) -> InferenceResult<std::sync::Arc<mistralrs::Model>> {
72 self.model
73 .get_or_try_init(|| async {
74 let mut builder = mistralrs::TextModelBuilder::new(&self.config.model_id);
75 if self.config.force_cpu {
76 builder = builder.with_force_cpu();
77 }
78 if let Some(rev) = &self.config.hf_revision {
79 builder = builder.with_hf_revision(rev.clone());
80 }
81 if let Some(max_seqs) = self.config.max_num_seqs {
82 builder = builder.with_max_num_seqs(max_seqs);
83 }
84 if let Some(q) = &self.config.quant {
85 let isq = mistralrs::parse_isq_value(q, None)
86 .map_err(|e| InferenceError::Internal(format!("mistralrs: bad quant '{q}': {e}")))?;
87 builder = builder.with_isq(isq);
88 }
89 let model = builder.build().await.map_err(|e| {
90 InferenceError::Internal(format!(
91 "mistralrs: failed to build model '{}': {e}",
92 self.config.model_id
93 ))
94 })?;
95 Ok(std::sync::Arc::new(model))
96 })
97 .await
98 .cloned()
99 }
100}
101
/// Translates a core chat role into the mistralrs message role.
#[cfg(feature = "mistralrs")]
fn map_role(role: atomr_infer_core::batch::Role) -> mistralrs::TextMessageRole {
    use atomr_infer_core::batch::Role;
    use mistralrs::TextMessageRole as Target;

    match role {
        Role::System => Target::System,
        Role::Assistant => Target::Assistant,
        Role::Tool => Target::Tool,
        // `Role::User`, plus any role variant not listed above, is sent as a
        // user turn.
        _ => Target::User,
    }
}
115
/// Flattens a message's content into plain text.
///
/// Plain-text content is returned as-is. Multi-part content keeps only its
/// text parts, separated by `'\n'`; non-text parts are dropped. Any other
/// content kind yields an empty string.
#[cfg(feature = "mistralrs")]
fn message_text(message: &atomr_infer_core::batch::Message) -> String {
    use atomr_infer_core::batch::{ContentPart, MessageContent};

    let parts = match &message.content {
        MessageContent::Text(text) => return text.clone(),
        MessageContent::Parts(parts) => parts,
        _ => return String::new(),
    };

    let text_parts = parts.iter().filter_map(|part| match part {
        ContentPart::Text { text } => Some(text.as_str()),
        _ => None,
    });

    let mut joined = String::new();
    for (index, piece) in text_parts.enumerate() {
        if index > 0 {
            joined.push('\n');
        }
        joined.push_str(piece);
    }
    joined
}
133
134#[async_trait]
135impl ModelRunner for MistralRsRunner {
136 #[cfg_attr(
137 feature = "mistralrs",
138 tracing::instrument(skip(self, _batch), fields(model = %self.config.model_id))
139 )]
140 async fn execute(&mut self, _batch: ExecuteBatch) -> InferenceResult<RunHandle> {
141 #[cfg(not(feature = "mistralrs"))]
142 {
143 Err(InferenceError::Internal(
144 "mistralrs feature disabled at build time — rebuild with --features mistralrs".into(),
145 ))
146 }
147 #[cfg(feature = "mistralrs")]
148 {
149 use atomr_infer_core::tokens::{FinishReason, TokenChunk};
150 use futures::StreamExt;
151
152 let model = self.ensure_model().await?;
153 let request_id = _batch.request_id.clone();
154
155 let mut messages = mistralrs::TextMessages::new();
156 for m in &_batch.messages {
157 messages = messages.add_message(map_role(m.role), message_text(m));
158 }
159
160 let (tx, rx) = tokio::sync::mpsc::channel::<InferenceResult<TokenChunk>>(64);
161 let req_id_for_task = request_id.clone();
162 tokio::spawn(async move {
163 let mut stream = match model.stream_chat_request(messages).await {
164 Ok(s) => s,
165 Err(e) => {
166 let _ = tx
167 .send(Err(InferenceError::Internal(format!(
168 "mistralrs: stream_chat_request failed: {e}"
169 ))))
170 .await;
171 return;
172 }
173 };
174 while let Some(resp) = stream.next().await {
175 match resp {
176 mistralrs::Response::Chunk(chunk) => {
177 let choice = chunk.choices.first();
178 let text_delta = choice.and_then(|c| c.delta.content.clone()).unwrap_or_default();
179 let finish_reason = choice
180 .and_then(|c| c.finish_reason.as_deref())
181 .map(map_finish_reason);
182 let chunk_out = TokenChunk {
183 request_id: req_id_for_task.clone(),
184 text_delta,
185 tool_call_delta: None,
186 usage: None,
187 finish_reason,
188 };
189 if tx.send(Ok(chunk_out)).await.is_err() {
190 break;
191 }
192 }
193 mistralrs::Response::Done(full) => {
194 let usage = atomr_infer_core::tokens::TokenUsage {
195 input_tokens: full.usage.prompt_tokens as u32,
196 output_tokens: full.usage.completion_tokens as u32,
197 ..Default::default()
198 };
199 let finish_reason = full
200 .choices
201 .first()
202 .map(|c| map_finish_reason(c.finish_reason.as_str()));
203 let chunk_out = TokenChunk {
204 request_id: req_id_for_task.clone(),
205 text_delta: String::new(),
206 tool_call_delta: None,
207 usage: Some(usage),
208 finish_reason: finish_reason.or(Some(FinishReason::Stop)),
209 };
210 let _ = tx.send(Ok(chunk_out)).await;
211 break;
212 }
213 mistralrs::Response::ModelError(msg, _partial) => {
214 let _ = tx
215 .send(Err(InferenceError::Internal(format!(
216 "mistralrs model error: {msg}"
217 ))))
218 .await;
219 break;
220 }
221 mistralrs::Response::InternalError(e) | mistralrs::Response::ValidationError(e) => {
222 let _ = tx
223 .send(Err(InferenceError::Internal(format!("mistralrs error: {e}"))))
224 .await;
225 break;
226 }
227 _ => {
233 let _ = tx
234 .send(Err(InferenceError::Internal(
235 "mistralrs: unexpected response variant".into(),
236 )))
237 .await;
238 break;
239 }
240 }
241 }
242 });
243
244 let stream = tokio_stream::wrappers::ReceiverStream::new(rx).boxed();
245 Ok(RunHandle::streaming(stream))
246 }
247 }
248
249 async fn rebuild_session(&mut self, cause: SessionRebuildCause) -> InferenceResult<()> {
250 #[cfg(feature = "mistralrs")]
251 {
252 if matches!(
256 cause,
257 SessionRebuildCause::CudaContextPoisoned | SessionRebuildCause::Manual
258 ) {
259 self.model = tokio::sync::OnceCell::new();
260 }
261 }
262 let _ = cause;
263 Ok(())
264 }
265
266 fn runtime_kind(&self) -> RuntimeKind {
267 RuntimeKind::MistralRs
268 }
269 fn transport_kind(&self) -> TransportKind {
270 TransportKind::LocalGpu
271 }
272}
273
/// Converts a mistralrs finish-reason string into the core `FinishReason`.
///
/// `"length"`, `"tool_calls"` and `"content_filter"` map to their namesakes.
/// `"stop"` and any unrecognised string map to `FinishReason::Stop`.
#[cfg(feature = "mistralrs")]
fn map_finish_reason(s: &str) -> atomr_infer_core::tokens::FinishReason {
    use atomr_infer_core::tokens::FinishReason as Reason;

    match s {
        "length" => Reason::Length,
        "tool_calls" => Reason::ToolCalls,
        "content_filter" => Reason::ContentFilter,
        _ => Reason::Stop,
    }
}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with only `model_id` set and every optional knob at its default.
    fn minimal_config(model_id: &str) -> MistralRsConfig {
        MistralRsConfig {
            model_id: model_id.into(),
            quant: None,
            hf_revision: None,
            force_cpu: false,
            max_num_seqs: None,
        }
    }

    #[test]
    fn config_round_trips_through_serde() {
        // Every field is set to a non-default value so a lost field is detected.
        let cfg = MistralRsConfig {
            model_id: "mistralai/Mistral-7B-Instruct-v0.3".into(),
            quant: Some("Q4K".into()),
            hf_revision: Some("main".into()),
            force_cpu: true,
            max_num_seqs: Some(16),
        };
        let json = serde_json::to_string(&cfg).expect("serialize");
        let back: MistralRsConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back.model_id, cfg.model_id);
        assert_eq!(back.quant, cfg.quant);
        assert_eq!(back.hf_revision, cfg.hf_revision);
        assert_eq!(back.force_cpu, cfg.force_cpu);
        assert_eq!(back.max_num_seqs, cfg.max_num_seqs);
    }

    #[test]
    fn config_optional_fields_default_when_absent() {
        let cfg: MistralRsConfig =
            serde_json::from_str(r#"{"model_id":"test"}"#).expect("deserialize");
        assert_eq!(cfg.model_id, "test");
        assert_eq!(cfg.quant, None);
        assert_eq!(cfg.hf_revision, None);
        assert!(!cfg.force_cpu);
        assert_eq!(cfg.max_num_seqs, None);
    }

    #[test]
    fn config_requires_model_id() {
        assert!(serde_json::from_str::<MistralRsConfig>("{}").is_err());
    }

    #[test]
    fn runner_reports_runtime_kind() {
        let runner = MistralRsRunner::new(minimal_config("test"));
        assert_eq!(runner.runtime_kind(), RuntimeKind::MistralRs);
        assert_eq!(runner.transport_kind(), TransportKind::LocalGpu);
    }

    #[cfg(not(feature = "mistralrs"))]
    #[tokio::test]
    async fn execute_without_feature_returns_internal_error() {
        use atomr_infer_core::batch::SamplingParams;

        let mut runner = MistralRsRunner::new(minimal_config("test"));
        let batch = ExecuteBatch {
            request_id: "test".into(),
            model: "test".into(),
            messages: vec![],
            sampling: SamplingParams::default(),
            stream: false,
            estimated_tokens: 1,
        };
        let result = runner.execute(batch).await;
        assert!(matches!(result, Err(InferenceError::Internal(_))));
    }
}
342}