use crate::ir::MessageRole;
use crate::runtime_adapter::llm::{
ChatMessage, GenerationConfig, GenerationOutput, LlmBackend, LlmConfig, LlmResult,
PartialToken, StreamingCallback,
};
#[cfg(feature = "llm-mistral")]
use crate::runtime_adapter::llm_telemetry::compute_streaming_fields;
use crate::runtime_adapter::AdapterError;
use crate::tracing as xybrid_trace;
#[cfg(feature = "llm-mistral")]
use mistralrs::{
GgufModelBuilder, Model, PagedAttentionMetaBuilder, RequestBuilder, Response, TextMessageRole,
};
#[cfg(feature = "llm-mistral")]
/// Keeps `v` only when it is strictly greater than zero.
///
/// Zero, negative values and NaN (which never compares greater than zero)
/// all map to `None`.
fn nonzero(v: f32) -> Option<f32> {
    Some(v).filter(|rate| *rate > 0.0)
}
/// Accumulator filled in by `handle_response` while a chat stream is consumed.
#[cfg(feature = "llm-mistral")]
struct StreamState {
/// Concatenation of every non-empty content delta seen so far.
text: String,
/// Finish reason; starts as `"unknown"` and is overwritten when a chunk or
/// the final response carries one.
finish_reason: String,
/// `completion_tokens` taken from the backend's usage block, if one arrived.
tokens_reported: Option<usize>,
/// `prompt_tokens` taken from the backend's usage block, if one arrived.
prompt_tokens_reported: Option<usize>,
/// One timestamp per chunk that carried non-empty content.
chunk_ts: Vec<std::time::Instant>,
/// `avg_compl_tok_per_sec` from usage, kept only when greater than zero.
decode_tps_reported: Option<f32>,
/// `avg_prompt_tok_per_sec` from usage, kept only when greater than zero.
prefill_tps_reported: Option<f32>,
/// Set once a chunk with usage or a finish reason, or a `Done` response, was seen.
saw_terminal: bool,
}
#[cfg(feature = "llm-mistral")]
impl StreamState {
    /// Creates an empty accumulator: no text, no timestamps, nothing reported
    /// by the backend yet, and `"unknown"` as the placeholder finish reason.
    fn new() -> Self {
        let placeholder_reason = "unknown".to_owned();
        Self {
            saw_terminal: false,
            finish_reason: placeholder_reason,
            text: String::new(),
            chunk_ts: Vec::new(),
            tokens_reported: None,
            prompt_tokens_reported: None,
            decode_tps_reported: None,
            prefill_tps_reported: None,
        }
    }
}
/// Folds one streamed `Response` into `state`.
///
/// * `Response::Chunk` — appends the first choice's non-empty delta content to
///   `state.text` and records `ts` for it. A chunk that carries usage or a
///   finish reason marks the stream as terminal and copies those values over.
///   Returns `Ok(false)`.
/// * `Response::Done` — marks the stream as terminal, copies usage and the
///   first choice's finish reason, and returns `Ok(true)` so the caller stops
///   polling.
///
/// # Errors
///
/// * `AdapterError::InferenceFailed` for `ModelError` (including a preview of
///   any partial text), `InternalError`, and any other unexpected variant.
/// * `AdapterError::InvalidInput` for `ValidationError`.
#[cfg(feature = "llm-mistral")]
fn handle_response(
    response: Response,
    state: &mut StreamState,
    ts: std::time::Instant,
) -> Result<bool, AdapterError> {
    /// Upper bound, in bytes, on the partial-text preview embedded in model errors.
    const MAX_PARTIAL_PREVIEW_BYTES: usize = 200;

    match response {
        Response::Chunk(chunk) => {
            let delta = chunk
                .choices
                .first()
                .and_then(|c| c.delta.content.as_deref());
            // Only chunks that actually carry text count towards timing data.
            if let Some(body) = delta.filter(|s| !s.is_empty()) {
                state.chunk_ts.push(ts);
                state.text.push_str(body);
            }
            let chunk_fr = chunk.choices.first().and_then(|c| c.finish_reason.as_ref());
            if chunk.usage.is_some() || chunk_fr.is_some() {
                state.saw_terminal = true;
                if let Some(u) = chunk.usage.as_ref() {
                    state.tokens_reported = Some(u.completion_tokens);
                    state.prompt_tokens_reported = Some(u.prompt_tokens);
                    state.decode_tps_reported = nonzero(u.avg_compl_tok_per_sec);
                    state.prefill_tps_reported = nonzero(u.avg_prompt_tok_per_sec);
                }
                if let Some(fr) = chunk_fr {
                    state.finish_reason = fr.clone();
                }
            }
            Ok(false)
        }
        Response::Done(final_resp) => {
            state.saw_terminal = true;
            state.tokens_reported = Some(final_resp.usage.completion_tokens);
            state.prompt_tokens_reported = Some(final_resp.usage.prompt_tokens);
            state.decode_tps_reported = nonzero(final_resp.usage.avg_compl_tok_per_sec);
            state.prefill_tps_reported = nonzero(final_resp.usage.avg_prompt_tok_per_sec);
            if let Some(choice) = final_resp.choices.first() {
                state.finish_reason = choice.finish_reason.clone();
            }
            Ok(true)
        }
        Response::ModelError(msg, partial) => {
            let preview = partial
                .choices
                .first()
                .and_then(|c| c.message.content.as_deref())
                .map(|s| {
                    // Slicing a `str` at an arbitrary byte offset panics when the
                    // offset lands inside a multi-byte character, so back off to
                    // the nearest boundary. Index 0 is always a boundary, so the
                    // loop terminates.
                    let mut end = s.len().min(MAX_PARTIAL_PREVIEW_BYTES);
                    while !s.is_char_boundary(end) {
                        end -= 1;
                    }
                    format!(" (partial: {:?})", &s[..end])
                })
                .unwrap_or_default();
            Err(AdapterError::InferenceFailed(format!(
                "model: {}{}",
                msg, preview
            )))
        }
        Response::InternalError(e) => {
            Err(AdapterError::InferenceFailed(format!("internal: {}", e)))
        }
        Response::ValidationError(e) => {
            Err(AdapterError::InvalidInput(format!("validation: {}", e)))
        }
        _ => Err(AdapterError::InferenceFailed(
            "unexpected stream response variant for chat".to_string(),
        )),
    }
}
/// Forwards text appended to `state.text` since `before_len` to `on_token`.
///
/// When the text did not grow, nothing is emitted and `token_index` is left
/// untouched. Otherwise the new suffix is delivered as one `PartialToken`
/// together with the full text so far, and `token_index` is advanced only
/// after the callback succeeded.
///
/// # Errors
///
/// Returns the callback's failure converted through
/// `AdapterError::from_streaming_callback_error`.
#[cfg(feature = "llm-mistral")]
fn emit_new_text_if_any(
    state: &StreamState,
    before_len: usize,
    token_index: &mut usize,
    on_token: &mut StreamingCallback<'_>,
) -> Result<(), AdapterError> {
    if state.text.len() <= before_len {
        return Ok(());
    }
    let fresh = state.text[before_len..].to_string();
    let index = *token_index;
    let cumulative = state.text.clone();
    on_token(PartialToken::new(fresh, index, cumulative))
        .map_err(AdapterError::from_streaming_callback_error)?;
    *token_index = index.saturating_add(1);
    Ok(())
}
/// Sends a closing, empty `PartialToken` carrying the finish reason.
///
/// Nothing is sent when `token_index` is zero, i.e. when no token was
/// delivered to the callback beforehand.
///
/// # Errors
///
/// Returns the callback's failure converted through
/// `AdapterError::from_streaming_callback_error`.
#[cfg(feature = "llm-mistral")]
fn emit_final_token_if_needed(
    state: &StreamState,
    token_index: usize,
    on_token: &mut StreamingCallback<'_>,
) -> Result<(), AdapterError> {
    if token_index == 0 {
        return Ok(());
    }
    let full_text = state.text.clone();
    let reason = state.finish_reason.clone();
    let closing = PartialToken::new(String::new(), token_index, full_text).with_finish_reason(reason);
    on_token(closing).map_err(AdapterError::from_streaming_callback_error)?;
    Ok(())
}
/// LLM backend that runs GGUF models through the `mistralrs` crate.
#[cfg(feature = "llm-mistral")]
pub struct MistralBackend {
/// The loaded model; `None` until `load` succeeds and again after `unload`.
model: Option<Model>,
/// Copy of the configuration passed to the last successful `load`.
config: Option<LlmConfig>,
/// Value returned by `context_length`; `load` resets it to `None`.
// NOTE(review): nothing in this file ever assigns `Some(..)` here — confirm intended.
effective_context_length: Option<usize>,
/// Tokio runtime used to drive the async mistralrs calls via `block_on`.
runtime: tokio::runtime::Runtime,
}
#[cfg(feature = "llm-mistral")]
impl MistralBackend {
    /// Creates a backend with no model loaded and its own tokio runtime.
    ///
    /// # Errors
    ///
    /// Returns `AdapterError::RuntimeError` when the tokio runtime cannot be
    /// created.
    pub fn new() -> LlmResult<Self> {
        match tokio::runtime::Runtime::new() {
            Ok(runtime) => Ok(Self {
                runtime,
                model: None,
                config: None,
                effective_context_length: None,
            }),
            Err(e) => Err(AdapterError::RuntimeError(format!(
                "Failed to create tokio runtime: {}",
                e
            ))),
        }
    }

    /// Maps a crate-level message role onto the matching mistralrs role.
    fn convert_role(role: &MessageRole) -> TextMessageRole {
        match role {
            MessageRole::User => TextMessageRole::User,
            MessageRole::Assistant => TextMessageRole::Assistant,
            MessageRole::System => TextMessageRole::System,
        }
    }

    /// Builds a mistralrs request from the chat history and sampling settings.
    ///
    /// A temperature of zero or below selects the deterministic sampler;
    /// otherwise temperature, top-p and top-k are applied. The max-length
    /// setter is chained after the sampler selection in both cases.
    fn build_request(messages: &[ChatMessage], config: &GenerationConfig) -> RequestBuilder {
        let with_history = messages.iter().fold(RequestBuilder::new(), |builder, msg| {
            builder.add_message(Self::convert_role(&msg.role), &msg.content)
        });
        let greedy = config.temperature <= 0.0;
        let with_sampler = if greedy {
            with_history.set_deterministic_sampler()
        } else {
            with_history
                .set_sampler_temperature(config.temperature as f64)
                .set_sampler_topp(config.top_p as f64)
                .set_sampler_topk(config.top_k)
        };
        with_sampler.set_sampler_max_len(config.max_tokens)
    }
}
#[cfg(feature = "llm-mistral")]
impl Default for MistralBackend {
    /// Builds a backend through [`MistralBackend::new`].
    ///
    /// # Panics
    ///
    /// Panics when `new` returns an error.
    fn default() -> Self {
        let created: LlmResult<Self> = Self::new();
        created.expect("Failed to create MistralBackend")
    }
}
#[cfg(feature = "llm-mistral")]
impl LlmBackend for MistralBackend {
/// Returns the backend identifier, `"mistral"`.
fn name(&self) -> &str {
"mistral"
}
/// Returns the wire label for this backend, always `Some("mistralrs")`.
fn wire_label(&self) -> Option<&'static str> {
Some("mistralrs")
}
/// Lists the model formats this backend accepts; only `"gguf"`.
fn supported_formats(&self) -> Vec<&'static str> {
vec!["gguf"]
}
/// Loads a GGUF model from `config.model_path`.
///
/// The path may point at a `.gguf` file or at a directory containing one or
/// more `.gguf` files. For a directory, the lexicographically smallest file
/// name is used so the choice does not depend on directory iteration order.
/// `context_length` and `gpu_layers` are not forwarded to mistralrs; a
/// warning is logged when they differ from 4096 and 0 respectively.
///
/// On success the model and a copy of `config` are stored on `self`.
///
/// # Errors
///
/// * `AdapterError::ModelNotFound` when the path does not exist or a
///   directory contains no `.gguf` file.
/// * `AdapterError::InvalidInput` when the file path has no parent or no
///   valid UTF-8 file name.
/// * `AdapterError::IOError` when the directory cannot be read.
/// * `AdapterError::RuntimeError` when paged-attention setup or model
///   building fails.
fn load(&mut self, config: &LlmConfig) -> LlmResult<()> {
    use std::path::Path;
    let model_path = Path::new(&config.model_path);
    if !model_path.exists() {
        return Err(AdapterError::ModelNotFound(config.model_path.clone()));
    }
    const DEFAULT_CONTEXT_LENGTH: usize = 4096;
    if config.context_length != DEFAULT_CONTEXT_LENGTH {
        log::warn!(
            "LlmConfig.context_length={} is not wired to the mistralrs GGUF backend; \
             requested value ignored. Context length is derived from the GGUF file.",
            config.context_length
        );
    }
    if config.gpu_layers != 0 {
        log::warn!(
            "LlmConfig.gpu_layers={} is not wired to mistralrs; build with the \
             appropriate feature flag (llm-mistral-metal/llm-mistral-cuda) instead.",
            config.gpu_layers
        );
    }
    self.effective_context_length = None;
    let (model_dir, model_file) = if model_path.is_file() {
        let dir = match model_path.parent() {
            None => return Err(AdapterError::InvalidInput("Invalid model path".to_string())),
            // A bare relative file name such as `model.gguf` has an empty
            // parent; use the current directory instead of an empty string.
            Some(parent) if parent.as_os_str().is_empty() => ".".to_string(),
            Some(parent) => parent.to_string_lossy().to_string(),
        };
        let file = model_path
            .file_name()
            .and_then(|s| s.to_str())
            .ok_or_else(|| AdapterError::InvalidInput("Invalid model filename".to_string()))?;
        (dir, file.to_string())
    } else {
        let mut gguf_names: Vec<String> = std::fs::read_dir(model_path)
            .map_err(AdapterError::IOError)?
            .filter_map(|e| e.ok())
            .filter(|e| {
                e.path()
                    .extension()
                    .map(|ext| ext == "gguf")
                    .unwrap_or(false)
            })
            .map(|e| e.file_name().to_string_lossy().to_string())
            .collect();
        // `read_dir` yields entries in a platform-dependent order; sort so the
        // same directory always resolves to the same file.
        gguf_names.sort_unstable();
        let file = gguf_names.into_iter().next().ok_or_else(|| {
            AdapterError::ModelNotFound(format!(
                "No .gguf files found in {}",
                config.model_path
            ))
        })?;
        (config.model_path.clone(), file)
    };
    let model = self.runtime.block_on(async {
        let mut builder = GgufModelBuilder::new(&model_dir, vec![model_file]);
        if let Some(ref template) = config.chat_template {
            builder = builder.with_chat_template(template);
        }
        if config.logging {
            builder = builder.with_logging();
        }
        if config.paged_attention {
            let paged_cfg = PagedAttentionMetaBuilder::default().build().map_err(|e| {
                AdapterError::RuntimeError(format!("Paged attention setup failed: {}", e))
            })?;
            builder = builder.with_paged_attn(paged_cfg);
        }
        builder
            .build()
            .await
            .map_err(|e| AdapterError::RuntimeError(format!("Model loading failed: {}", e)))
    })?;
    self.model = Some(model);
    self.config = Some(config.clone());
    Ok(())
}
/// Returns `true` while a model is held by this backend.
fn is_loaded(&self) -> bool {
self.model.is_some()
}
/// Drops the loaded model and its stored configuration. Always succeeds.
fn unload(&mut self) -> LlmResult<()> {
    drop(self.model.take());
    drop(self.config.take());
    Ok(())
}
/// Runs a chat completion to the end and returns the collected output.
///
/// The request is consumed as a stream and every response is folded into a
/// `StreamState` until a `Done` response arrives or the stream ends. Token
/// counts and throughput reported by the backend take precedence over values
/// derived from the locally recorded chunk timestamps.
///
/// # Errors
///
/// * `AdapterError::ModelNotLoaded` when no model is loaded.
/// * `AdapterError::InferenceFailed` when the stream cannot be started or
///   ends without a terminal chunk.
/// * Any error produced by `handle_response`.
fn generate(
    &self,
    messages: &[ChatMessage],
    config: &GenerationConfig,
) -> LlmResult<GenerationOutput> {
    let model = match self.model.as_ref() {
        Some(loaded) => loaded,
        None => {
            return Err(AdapterError::ModelNotLoaded(
                "No model loaded. Call load() first.".to_string(),
            ))
        }
    };
    let request = Self::build_request(messages, config);
    let start = std::time::Instant::now();
    let collected = self.runtime.block_on(async move {
        let mut stream = model
            .stream_chat_request(request)
            .await
            .map_err(|e| AdapterError::InferenceFailed(format!("stream init: {}", e)))?;
        let mut acc = StreamState::new();
        loop {
            let response = match stream.next().await {
                Some(response) => response,
                None => break,
            };
            if handle_response(response, &mut acc, std::time::Instant::now())? {
                break;
            }
        }
        if acc.saw_terminal {
            Ok::<_, AdapterError>(acc)
        } else {
            Err(AdapterError::InferenceFailed(
                "stream closed before terminal chunk (no usage or finish_reason)".to_string(),
            ))
        }
    })?;
    // Fall back to the number of content chunks when usage was not reported.
    let tokens_generated = match collected.tokens_reported {
        Some(reported) => reported,
        None => collected.chunk_ts.len(),
    };
    let prompt_token_count = collected.prompt_tokens_reported.unwrap_or_default();
    let fields = compute_streaming_fields(
        start,
        &collected.chunk_ts,
        prompt_token_count,
        tokens_generated,
    );
    xybrid_trace::add_metadata("tokens_in", prompt_token_count.to_string());
    Ok(GenerationOutput {
        text: collected.text,
        tokens_generated,
        generation_time_ms: fields.generation_time_ms,
        tokens_per_second: fields.tokens_per_second,
        finish_reason: collected.finish_reason,
        ttft_ms: fields.ttft_ms,
        mean_itl_ms: fields.mean_itl_ms,
        p95_itl_ms: fields.p95_itl_ms,
        emitted_chunks: fields.emitted_chunks,
        inter_chunk_ms: fields.inter_chunk_ms,
        decode_tps: collected.decode_tps_reported.or(fields.decode_tps),
        prefill_tps: collected.prefill_tps_reported.or(fields.prefill_tps),
    })
}
/// Wraps `prompt` in a single user message and delegates to `generate`.
fn generate_raw(&self, prompt: &str, config: &GenerationConfig) -> LlmResult<GenerationOutput> {
    let single_turn = [ChatMessage::user(prompt)];
    self.generate(&single_turn, config)
}
/// Runs a chat completion, handing each piece of new text to `on_token`.
///
/// After every response the text appended since the previous one is sent to
/// the callback. Once the stream finished with a terminal chunk, a closing
/// token carrying the finish reason is sent, provided at least one token was
/// delivered before. The returned output is assembled the same way as in
/// `generate`.
///
/// # Errors
///
/// * `AdapterError::ModelNotLoaded` when no model is loaded.
/// * `AdapterError::InferenceFailed` when the stream cannot be started or
///   ends without a terminal chunk.
/// * Any error produced by `handle_response` or by the callback.
fn generate_streaming(
    &self,
    messages: &[ChatMessage],
    config: &GenerationConfig,
    on_token: StreamingCallback<'_>,
) -> LlmResult<GenerationOutput> {
    let model = match self.model.as_ref() {
        Some(loaded) => loaded,
        None => {
            return Err(AdapterError::ModelNotLoaded(
                "No model loaded. Call load() first.".to_string(),
            ))
        }
    };
    let request = Self::build_request(messages, config);
    let start = std::time::Instant::now();
    let mut sink = on_token;
    let collected = self.runtime.block_on(async {
        let mut stream = model
            .stream_chat_request(request)
            .await
            .map_err(|e| AdapterError::InferenceFailed(format!("stream init: {}", e)))?;
        let mut acc = StreamState::new();
        let mut emitted = 0usize;
        loop {
            let response = match stream.next().await {
                Some(response) => response,
                None => break,
            };
            let seen_len = acc.text.len();
            let finished = handle_response(response, &mut acc, std::time::Instant::now())?;
            emit_new_text_if_any(&acc, seen_len, &mut emitted, &mut sink)?;
            if finished {
                break;
            }
        }
        if !acc.saw_terminal {
            return Err(AdapterError::InferenceFailed(
                "stream closed before terminal chunk (no usage or finish_reason)".to_string(),
            ));
        }
        emit_final_token_if_needed(&acc, emitted, &mut sink)?;
        Ok::<_, AdapterError>(acc)
    })?;
    // Fall back to the number of content chunks when usage was not reported.
    let tokens_generated = match collected.tokens_reported {
        Some(reported) => reported,
        None => collected.chunk_ts.len(),
    };
    let prompt_token_count = collected.prompt_tokens_reported.unwrap_or_default();
    let fields = compute_streaming_fields(
        start,
        &collected.chunk_ts,
        prompt_token_count,
        tokens_generated,
    );
    xybrid_trace::add_metadata("tokens_in", prompt_token_count.to_string());
    Ok(GenerationOutput {
        text: collected.text,
        tokens_generated,
        generation_time_ms: fields.generation_time_ms,
        tokens_per_second: fields.tokens_per_second,
        finish_reason: collected.finish_reason,
        ttft_ms: fields.ttft_ms,
        mean_itl_ms: fields.mean_itl_ms,
        p95_itl_ms: fields.p95_itl_ms,
        emitted_chunks: fields.emitted_chunks,
        inter_chunk_ms: fields.inter_chunk_ms,
        decode_tps: collected.decode_tps_reported.or(fields.decode_tps),
        prefill_tps: collected.prefill_tps_reported.or(fields.prefill_tps),
    })
}
/// Reports that this backend implements streaming generation; always `true`.
fn supports_streaming(&self) -> bool {
true
}
/// Memory usage is not tracked by this backend; always returns `None`.
fn memory_usage(&self) -> Option<u64> {
None
}
/// Returns the stored effective context length.
// NOTE(review): `load` only ever resets this to `None` in this file, so this
// currently always yields `None` — confirm whether a real value should be set.
fn context_length(&self) -> Option<usize> {
self.effective_context_length
}
}
/// Placeholder backend compiled when the `llm-mistral` feature is disabled.
/// Its constructor and its load/generate operations return an error.
#[cfg(not(feature = "llm-mistral"))]
pub struct MistralBackend;
#[cfg(not(feature = "llm-mistral"))]
impl MistralBackend {
    /// Always fails: the backend is unavailable without the `llm-mistral`
    /// feature.
    ///
    /// # Errors
    ///
    /// Always returns `AdapterError::RuntimeError` with a build hint.
    pub fn new() -> LlmResult<Self> {
        let reason =
            String::from("llm-mistral feature not enabled. Build with --features llm-mistral");
        Err(AdapterError::RuntimeError(reason))
    }
}
/// Error text shared by every stubbed operation below.
#[cfg(not(feature = "llm-mistral"))]
const FEATURE_DISABLED_MSG: &str = "llm-mistral feature not enabled";

/// Stub implementation used when the `llm-mistral` feature is disabled:
/// metadata getters answer normally, loading and generation always fail.
#[cfg(not(feature = "llm-mistral"))]
impl LlmBackend for MistralBackend {
    /// Returns the backend identifier, `"mistral"`.
    fn name(&self) -> &str {
        "mistral"
    }

    /// Lists the model formats the backend would accept; only `"gguf"`.
    fn supported_formats(&self) -> Vec<&'static str> {
        vec!["gguf"]
    }

    /// Always fails with `AdapterError::RuntimeError`.
    fn load(&mut self, _config: &LlmConfig) -> LlmResult<()> {
        Err(AdapterError::RuntimeError(FEATURE_DISABLED_MSG.to_string()))
    }

    /// A stub never holds a model.
    fn is_loaded(&self) -> bool {
        false
    }

    /// Nothing to release; always succeeds.
    fn unload(&mut self) -> LlmResult<()> {
        Ok(())
    }

    /// Always fails with `AdapterError::RuntimeError`.
    fn generate(
        &self,
        _messages: &[ChatMessage],
        _config: &GenerationConfig,
    ) -> LlmResult<GenerationOutput> {
        Err(AdapterError::RuntimeError(FEATURE_DISABLED_MSG.to_string()))
    }

    /// Always fails with `AdapterError::RuntimeError`.
    fn generate_raw(
        &self,
        _prompt: &str,
        _config: &GenerationConfig,
    ) -> LlmResult<GenerationOutput> {
        Err(AdapterError::RuntimeError(FEATURE_DISABLED_MSG.to_string()))
    }
}
#[cfg(test)]
mod tests {
#[cfg(feature = "llm-mistral")]
mod mock_stream {
use super::super::{
emit_final_token_if_needed, emit_new_text_if_any, handle_response, nonzero, StreamState,
};
use crate::abort::{AbortReason, CloudFallbackAbort};
use crate::runtime_adapter::llm::{PartialToken, StreamingCallback};
use crate::runtime_adapter::AdapterError;
use mistralrs::{ChatCompletionChunkResponse, ChunkChoice, Delta, Response, Usage};
use std::time::{Duration, Instant};
/// Builds a `Usage` fixture with 16 prompt tokens, the given completion
/// count and throughput figures, and fixed timing values.
fn usage(completion_tokens: usize, avg_prompt: f32, avg_compl: f32) -> Usage {
    let prompt_tokens = 16;
    let total_tokens = prompt_tokens + completion_tokens;
    let avg_tok_per_sec = (avg_prompt + avg_compl) / 2.0;
    Usage {
        prompt_tokens,
        completion_tokens,
        total_tokens,
        avg_tok_per_sec,
        avg_prompt_tok_per_sec: avg_prompt,
        avg_compl_tok_per_sec: avg_compl,
        total_prompt_time_sec: 0.1,
        total_completion_time_sec: 0.9,
        total_time_sec: 1.0,
    }
}
/// Builds a streaming chunk that carries `content` as delta text and has
/// neither a finish reason nor usage attached.
fn delta_chunk(content: &str) -> Response {
    let delta = Delta {
        content: Some(content.into()),
        role: "assistant".into(),
        tool_calls: None,
        reasoning_content: None,
    };
    let choice = ChunkChoice {
        finish_reason: None,
        index: 0,
        delta,
        logprobs: None,
    };
    let chunk = ChatCompletionChunkResponse {
        id: "test".into(),
        choices: vec![choice],
        created: 0,
        model: "test-model".into(),
        system_fingerprint: String::new(),
        object: "chat.completion.chunk".into(),
        usage: None,
    };
    Response::Chunk(chunk)
}
/// Builds a content-less streaming chunk that carries a finish reason and a
/// usage block, i.e. the terminal chunk of a stream.
fn terminal_chunk(
    completion_tokens: usize,
    avg_prompt: f32,
    avg_compl: f32,
    finish_reason: &str,
) -> Response {
    let delta = Delta {
        content: None,
        role: "assistant".into(),
        tool_calls: None,
        reasoning_content: None,
    };
    let choice = ChunkChoice {
        finish_reason: Some(finish_reason.into()),
        index: 0,
        delta,
        logprobs: None,
    };
    let reported = usage(completion_tokens, avg_prompt, avg_compl);
    let chunk = ChatCompletionChunkResponse {
        id: "test".into(),
        choices: vec![choice],
        created: 0,
        model: "test-model".into(),
        system_fingerprint: String::new(),
        object: "chat.completion.chunk".into(),
        usage: Some(reported),
    };
    Response::Chunk(chunk)
}
/// Three content chunks followed by a terminal chunk fill in text, timing,
/// usage and finish reason.
#[test]
fn happy_path_populates_all_fields() {
    let start = Instant::now();
    let mut state = StreamState::new();
    // (delta text, arrival offset in milliseconds)
    let deltas: [(&str, u64); 3] = [("Hel", 40), ("lo", 55), (" world", 70)];
    for (body, at_ms) in deltas {
        let arrived = start + Duration::from_millis(at_ms);
        let finished = handle_response(delta_chunk(body), &mut state, arrived).unwrap();
        assert!(!finished);
    }
    let terminal_finished = handle_response(
        terminal_chunk(3, 250.0, 120.0, "stop"),
        &mut state,
        start + Duration::from_millis(85),
    )
    .unwrap();
    assert!(!terminal_finished);

    assert_eq!(state.text, "Hello world");
    assert_eq!(state.chunk_ts.len(), 3);
    assert!(state.saw_terminal);
    assert_eq!(state.tokens_reported, Some(3));
    assert_eq!(state.decode_tps_reported, Some(120.0));
    assert_eq!(state.prefill_tps_reported, Some(250.0));
    assert_eq!(state.finish_reason, "stop");

    let first_chunk_at = state.chunk_ts[0];
    let ttft = first_chunk_at.duration_since(start).as_millis() as u64;
    assert_eq!(ttft, 40);
}
/// A terminal chunk without content must not add a timestamp.
#[test]
fn terminal_only_chunk_does_not_inflate_chunk_ts() {
    let mut state = StreamState::new();
    let terminal = terminal_chunk(5, 180.0, 90.0, "stop");
    let now = Instant::now();
    handle_response(terminal, &mut state, now).unwrap();
    assert!(state.saw_terminal);
    assert_eq!(state.chunk_ts.len(), 0);
    assert_eq!(state.tokens_reported, Some(5));
}
/// Throughput values of zero in the usage block are stored as `None`.
#[test]
fn zero_tps_values_collapse_to_none() {
    let mut state = StreamState::new();
    let terminal = terminal_chunk(1, 0.0, 0.0, "stop");
    let now = Instant::now();
    handle_response(terminal, &mut state, now).unwrap();
    assert!(state.saw_terminal);
    assert_eq!(state.decode_tps_reported, None);
    assert_eq!(state.prefill_tps_reported, None);
}
/// Content-only chunks accumulate text but never mark the stream terminal.
#[test]
fn stream_without_terminal_leaves_saw_terminal_false() {
    let mut state = StreamState::new();
    for piece in ["Partial", " response"] {
        handle_response(delta_chunk(piece), &mut state, Instant::now()).unwrap();
    }
    assert!(!state.saw_terminal);
    assert_eq!(state.text, "Partial response");
    assert_eq!(state.tokens_reported, None);
    assert_eq!(state.finish_reason, "unknown");
}
/// A `Response::Done` copies usage and finish reason and asks the caller to
/// stop polling.
#[test]
fn done_response_populates_and_stops() {
    use mistralrs::{ChatCompletionResponse, Choice, ResponseMessage};
    let message = ResponseMessage {
        content: Some("done body".into()),
        role: "assistant".into(),
        tool_calls: None,
        reasoning_content: None,
    };
    let choice = Choice {
        finish_reason: "length".into(),
        index: 0,
        message,
        logprobs: None,
    };
    let final_resp = ChatCompletionResponse {
        id: "test".into(),
        choices: vec![choice],
        created: 0,
        model: "test-model".into(),
        system_fingerprint: String::new(),
        object: "chat.completion".into(),
        usage: usage(42, 180.0, 95.0),
    };
    let mut state = StreamState::new();
    let outcome = handle_response(Response::Done(final_resp), &mut state, Instant::now());
    let stop = outcome.unwrap();
    assert!(
        stop,
        "Response::Done must return true so the outer loop breaks"
    );
    assert!(state.saw_terminal);
    assert_eq!(state.tokens_reported, Some(42));
    assert_eq!(state.decode_tps_reported, Some(95.0));
    assert_eq!(state.prefill_tps_reported, Some(180.0));
    assert_eq!(state.finish_reason, "length");
}
/// A validation error from the backend surfaces as `AdapterError::InvalidInput`.
#[test]
fn error_variants_produce_adapter_error() {
    let mut state = StreamState::new();
    let response = Response::ValidationError("bad request".to_string().into());
    let outcome = handle_response(response, &mut state, Instant::now());
    let err = outcome.unwrap_err();
    assert!(matches!(err, AdapterError::InvalidInput(_)));
}
/// A model error keeps both the backend message and the partial text in the
/// resulting `InferenceFailed` message.
#[test]
fn model_error_preserves_partial_text_in_message() {
    use mistralrs::{ChatCompletionResponse, Choice, ResponseMessage};
    let message = ResponseMessage {
        content: Some("Hello wor".into()),
        role: "assistant".into(),
        tool_calls: None,
        reasoning_content: None,
    };
    let choice = Choice {
        finish_reason: "error".into(),
        index: 0,
        message,
        logprobs: None,
    };
    let partial = ChatCompletionResponse {
        id: "test".into(),
        choices: vec![choice],
        created: 0,
        model: "test-model".into(),
        system_fingerprint: String::new(),
        object: "chat.completion".into(),
        usage: usage(2, 0.0, 0.0),
    };
    let mut state = StreamState::new();
    let response = Response::ModelError("kernel died".into(), partial);
    let err = handle_response(response, &mut state, Instant::now()).unwrap_err();
    let msg = match err {
        AdapterError::InferenceFailed(m) => m,
        other => panic!("expected InferenceFailed, got {:?}", other),
    };
    assert!(msg.contains("kernel died"), "got: {msg}");
    assert!(msg.contains("Hello wor"), "partial preview missing: {msg}");
}
/// `nonzero` keeps positive values and drops zero and negative ones.
#[test]
fn nonzero_filter() {
    let cases: [(f32, Option<f32>); 3] = [(0.0, None), (-1.0, None), (42.5, Some(42.5))];
    for (input, expected) in cases {
        assert_eq!(nonzero(input), expected);
    }
}
/// The backend must advertise streaming support.
#[test]
fn backend_reports_true_streaming_for_sdk_cancellation_gate() {
    use crate::runtime_adapter::llm::LlmBackend;
    let backend = super::super::MistralBackend::new().unwrap();
    let streams = backend.supports_streaming();
    assert!(
        streams,
        "mistral must stay on the true streaming path so SDK abort checks can interrupt generation"
    );
}
/// A callback that fails with `CloudFallbackAbort` must surface as an error
/// that still exposes the abort reason, within 50 ms, and without advancing
/// the token index.
#[test]
fn callback_cloud_fallback_abort_is_preserved_as_typed_adapter_error() {
let mut state = StreamState::new();
// Length before the chunk is applied, so the whole chunk counts as new text.
let before_len = state.text.len();
handle_response(delta_chunk("stop here"), &mut state, Instant::now()).unwrap();
let mut token_index = 0usize;
// Callback that rejects every token with a typed abort.
let mut callback: StreamingCallback<'_> = Box::new(|_token: PartialToken| {
Err(Box::new(CloudFallbackAbort::new(AbortReason::StressMemory)))
});
// Only the emit call itself is timed.
let started = Instant::now();
let err = emit_new_text_if_any(&state, before_len, &mut token_index, &mut callback)
.expect_err("callback abort must stop the mistral stream");
let elapsed = started.elapsed();
assert_eq!(
err.cloud_fallback_abort_reason(),
Some(AbortReason::StressMemory),
"mistral callback errors must preserve the typed CloudFallbackAbort marker"
);
assert!(
elapsed <= Duration::from_millis(50),
"mistral callback abort exceeded M-series cancellation budget: {:?}",
elapsed
);
assert_eq!(
token_index, 0,
"token index must not advance after a failed callback"
);
}
/// A callback failure on the closing token keeps its typed abort reason.
#[test]
fn final_token_callback_error_is_preserved() {
    let state = {
        let mut prepared = StreamState::new();
        prepared.text = "done".to_string();
        prepared.finish_reason = "stop".to_string();
        prepared
    };
    let mut callback: StreamingCallback<'_> = Box::new(|_token: PartialToken| {
        Err(Box::new(CloudFallbackAbort::new(
            AbortReason::StressThermal,
        )))
    });
    let outcome = emit_final_token_if_needed(&state, 1, &mut callback);
    let err = outcome.expect_err("final callback abort must stop the mistral stream");
    assert_eq!(
        err.cloud_fallback_abort_reason(),
        Some(AbortReason::StressThermal),
        "final-token callback errors must preserve the typed CloudFallbackAbort marker"
    );
}
}
}