use std::{
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use mistralrs::{
Constraint, IsqType, MultimodalMessages, MultimodalModelBuilder, RequestBuilder, TextMessageRole,
};
use tracing::{debug, info, instrument};
use crate::error::{Error, LoadError};
use llmtask::Task;
/// Inference timeout that [`EngineOptions::new`] installs by default: 300 seconds.
pub const DEFAULT_INFERENCE_TIMEOUT: Duration = Duration::from_secs(300);
/// Configuration used to load an engine and to drive its requests.
///
/// Start from [`EngineOptions::new`], then adjust individual settings with the
/// consuming `with_*` builders or the in-place `set_*` mutators. Every setting
/// also has a getter named after the field.
#[derive(Debug, Clone)]
pub struct EngineOptions {
  /// Location of the model to load.
  model_path: PathBuf,
  /// Quantization type requested for the model.
  quantization: IsqType,
  /// Token budget for generated output.
  max_tokens: usize,
  /// Sampling options stored alongside the engine settings.
  request: RequestOptions,
  /// Time budget for a single inference request.
  inference_timeout: Duration,
}

impl EngineOptions {
  /// Creates options for the model at `model_path`.
  ///
  /// Everything else starts at its default: `IsqType::Q4K` quantization,
  /// 1024 max tokens, the [`RequestOptions::deterministic`] sampling preset and
  /// [`DEFAULT_INFERENCE_TIMEOUT`].
  pub fn new(model_path: impl Into<PathBuf>) -> Self {
    let model_path = model_path.into();
    Self {
      model_path,
      quantization: IsqType::Q4K,
      max_tokens: 1024,
      request: RequestOptions::deterministic(),
      inference_timeout: DEFAULT_INFERENCE_TIMEOUT,
    }
  }

  /// Returns the configured model path.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub fn model_path(&self) -> &Path {
    self.model_path.as_path()
  }

  /// Builder form: replaces the model path and returns the updated options.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub fn with_model_path(self, val: impl Into<PathBuf>) -> Self {
    Self {
      model_path: val.into(),
      ..self
    }
  }

  /// In-place form: replaces the model path and returns `self` for chaining.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub fn set_model_path(&mut self, val: impl Into<PathBuf>) -> &mut Self {
    self.model_path = val.into();
    self
  }

  /// Returns the configured quantization type.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub const fn quantization(&self) -> IsqType {
    self.quantization
  }

  /// Builder form: replaces the quantization type and returns the updated options.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub const fn with_quantization(mut self, val: IsqType) -> Self {
    self.quantization = val;
    self
  }

  /// In-place form: replaces the quantization type and returns `self` for chaining.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub const fn set_quantization(&mut self, val: IsqType) -> &mut Self {
    self.quantization = val;
    self
  }

  /// Returns the configured token budget.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub const fn max_tokens(&self) -> usize {
    self.max_tokens
  }

  /// Builder form: replaces the token budget and returns the updated options.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub const fn with_max_tokens(mut self, val: usize) -> Self {
    self.max_tokens = val;
    self
  }

  /// In-place form: replaces the token budget and returns `self` for chaining.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub const fn set_max_tokens(&mut self, val: usize) -> &mut Self {
    self.max_tokens = val;
    self
  }

  /// Returns the stored sampling options.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub const fn request(&self) -> &RequestOptions {
    &self.request
  }

  /// Builder form: replaces the sampling options and returns the updated options.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub fn with_request(self, val: RequestOptions) -> Self {
    Self {
      request: val,
      ..self
    }
  }

  /// In-place form: replaces the sampling options and returns `self` for chaining.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub fn set_request(&mut self, val: RequestOptions) -> &mut Self {
    self.request = val;
    self
  }

  /// Returns the configured inference timeout.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub const fn inference_timeout(&self) -> Duration {
    self.inference_timeout
  }

  /// Builder form: replaces the inference timeout and returns the updated options.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub const fn with_inference_timeout(mut self, val: Duration) -> Self {
    self.inference_timeout = val;
    self
  }

  /// In-place form: replaces the inference timeout and returns `self` for chaining.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub const fn set_inference_timeout(&mut self, val: Duration) -> &mut Self {
    self.inference_timeout = val;
    self
  }
}
#[derive(Debug, Clone)]
pub struct RequestOptions {
temperature: f64,
top_p: f64,
top_k: usize,
presence_penalty: f32,
}
impl RequestOptions {
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn new() -> Self {
Self {
temperature: 0.7,
top_p: 0.8,
top_k: 20,
presence_penalty: 1.5,
}
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn deterministic() -> Self {
Self {
temperature: 0.0,
top_p: 1.0,
top_k: 1,
presence_penalty: 1.5,
}
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn temperature(&self) -> f64 {
self.temperature
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn with_temperature(mut self, val: f64) -> Self {
self.temperature = val;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_temperature(&mut self, val: f64) -> &mut Self {
self.temperature = val;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn top_p(&self) -> f64 {
self.top_p
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn with_top_p(mut self, val: f64) -> Self {
self.top_p = val;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_top_p(&mut self, val: f64) -> &mut Self {
self.top_p = val;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn top_k(&self) -> usize {
self.top_k
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn with_top_k(mut self, val: usize) -> Self {
self.top_k = val;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_top_k(&mut self, val: usize) -> &mut Self {
self.top_k = val;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn presence_penalty(&self) -> f32 {
self.presence_penalty
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn with_presence_penalty(mut self, val: f32) -> Self {
self.presence_penalty = val;
self
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn set_presence_penalty(&mut self, val: f32) -> &mut Self {
self.presence_penalty = val;
self
}
pub const fn validate(&self) -> Result<(), Error> {
if !self.temperature.is_finite() || self.temperature < 0.0 {
return Err(Error::InvalidRequest(
"temperature must be finite and >= 0.0",
));
}
if !self.top_p.is_finite() || self.top_p <= 0.0 || self.top_p > 1.0 {
return Err(Error::InvalidRequest(
"top_p must be finite and in (0.0, 1.0]",
));
}
if self.top_k == 0 {
return Err(Error::InvalidRequest("top_k must be >= 1"));
}
if !self.presence_penalty.is_finite() {
return Err(Error::InvalidRequest("presence_penalty must be finite"));
}
Ok(())
}
}
impl Default for RequestOptions {
fn default() -> Self {
Self::new()
}
}
#[derive(Clone)]
pub struct Engine {
model: Arc<mistralrs::Model>,
options: EngineOptions,
}
impl Engine {
#[instrument(name = "qwen3_vl::load", skip(opts), fields(model_path = %opts.model_path().display(), quantization = ?opts.quantization()))]
pub async fn load(opts: EngineOptions) -> Result<Self, LoadError> {
if !opts.model_path().exists() {
return Err(LoadError::NotFound(opts.model_path().to_path_buf()));
}
let started = std::time::Instant::now();
info!("loading Qwen3-VL model");
let model_id = opts.model_path().to_string_lossy().into_owned();
let model = MultimodalModelBuilder::new(model_id)
.with_isq(opts.quantization())
.build()
.await
.map_err(|e| LoadError::Build(e.to_string()))?;
info!(
elapsed_ms = started.elapsed().as_millis() as u64,
"model loaded"
);
Ok(Self {
model: Arc::new(model),
options: opts,
})
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn model_path(&self) -> &Path {
self.options.model_path()
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn quantization(&self) -> IsqType {
self.options.quantization()
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn max_tokens(&self) -> usize {
self.options.max_tokens()
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn request(&self) -> &RequestOptions {
self.options.request()
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn inference_timeout(&self) -> Duration {
self.options.inference_timeout()
}
#[instrument(name = "qwen3_vl::warmup", skip(self))]
pub async fn warmup(&self) -> Result<(), Error> {
use image::{DynamicImage, RgbImage};
let blank = DynamicImage::ImageRgb8(RgbImage::new(1, 1));
self.warmup_with_image(blank).await
}
#[instrument(name = "qwen3_vl::warmup_with_image", skip(self, image))]
pub async fn warmup_with_image(&self, image: image::DynamicImage) -> Result<(), Error> {
let started = std::time::Instant::now();
let messages = MultimodalMessages::new().add_image_message(
TextMessageRole::User,
"Reply with: ok",
vec![image],
);
let request = RequestBuilder::from(messages)
.set_sampler_max_len(4)
.enable_thinking(false);
let timeout = self.options.inference_timeout();
let _ = tokio::time::timeout(timeout, self.model.send_chat_request(request))
.await
.map_err(|_| Error::InferenceTimeout(timeout))?
.map_err(|e| Error::Inference(e.to_string()))?;
debug!(
elapsed_ms = started.elapsed().as_millis() as u64,
"warmup complete"
);
Ok(())
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub async fn run<T: Task>(
&self,
task: &T,
images: Vec<image::DynamicImage>,
) -> Result<T::Output, Error>
where
T::ParseError: Send + Sync + 'static,
{
self.run_with(task, images, self.options.request()).await
}
#[instrument(
name = "qwen3_vl::run_with",
skip(self, task, images, opts),
fields(
task_kind = std::any::type_name::<T>(),
image_count = images.len(),
max_tokens = self.options.max_tokens(),
temperature = opts.temperature(),
),
)]
pub async fn run_with<T: Task>(
&self,
task: &T,
images: Vec<image::DynamicImage>,
opts: &RequestOptions,
) -> Result<T::Output, Error>
where
T::ParseError: Send + Sync + 'static,
{
if images.is_empty() {
return Err(Error::NoImages);
}
opts.validate()?;
let grammar = task.grammar();
let schema = grammar
.as_json_schema()
.ok_or_else(|| {
Error::UnsupportedGrammar(llmtask::UnsupportedGrammar::new(
grammar.kind(),
"json_schema",
))
})?
.clone();
let messages =
MultimodalMessages::new().add_image_message(TextMessageRole::User, task.prompt(), images);
let request = RequestBuilder::from(messages)
.set_sampler_max_len(self.options.max_tokens().max(1))
.enable_thinking(false)
.set_constraint(Constraint::JsonSchema(schema))
.set_sampler_temperature(opts.temperature())
.set_sampler_topp(opts.top_p())
.set_sampler_topk(opts.top_k())
.set_sampler_presence_penalty(opts.presence_penalty());
let started = std::time::Instant::now();
let timeout = self.options.inference_timeout();
let response = tokio::time::timeout(timeout, self.model.send_chat_request(request))
.await
.map_err(|_| Error::InferenceTimeout(timeout))?
.map_err(|e| Error::Inference(e.to_string()))?;
debug!(
elapsed_ms = started.elapsed().as_millis() as u64,
"inference complete"
);
let choice = response.choices.first().ok_or(Error::Empty)?;
if choice.finish_reason != "stop" {
let raw_len = choice
.message
.content
.as_ref()
.map(|s| s.len())
.unwrap_or(0);
return Err(Error::Truncated {
finish_reason: choice.finish_reason.clone(),
raw_len,
});
}
let text = choice
.message
.content
.clone()
.filter(|s| !s.trim().is_empty())
.ok_or(Error::Empty)?;
#[cfg(feature = "trace-output")]
tracing::trace!(raw = %text, "model output");
task.parse(&text).map_err(|e| Error::Parse(Box::new(e)))
}
}
#[cfg(test)]
mod tests {
  use super::*;

  /// Asserts that `opts` fails validation with `Error::InvalidRequest`.
  ///
  /// `#[track_caller]` makes a failure point at the calling test line.
  #[track_caller]
  fn assert_rejected(opts: RequestOptions) {
    assert!(matches!(opts.validate(), Err(Error::InvalidRequest(_))));
  }

  /// A fresh `EngineOptions` carries the deterministic request preset.
  #[test]
  fn engine_options_defaults_to_deterministic_request() {
    let opts = EngineOptions::new("/tmp/model");
    assert_eq!(opts.model_path(), Path::new("/tmp/model"));
    assert!(matches!(opts.quantization(), IsqType::Q4K));
    assert_eq!(opts.max_tokens(), 1024);
    let req = opts.request();
    assert_eq!(req.temperature(), 0.0);
    assert_eq!(req.top_p(), 1.0);
    assert_eq!(req.top_k(), 1);
    assert_eq!(
      req.presence_penalty(),
      1.5,
      "deterministic preset must keep presence_penalty 1.5 — greedy \
       without it falls into token loops"
    );
  }

  /// Every `with_*` builder changes its own field.
  #[test]
  fn engine_options_with_chains() {
    // Each value differs from the default so a no-op builder cannot pass.
    let opts = EngineOptions::new("/tmp/a")
      .with_model_path("/tmp/b")
      .with_quantization(IsqType::Q8_0)
      .with_max_tokens(2048)
      .with_request(RequestOptions::new());
    assert_eq!(opts.model_path(), Path::new("/tmp/b"));
    assert!(matches!(opts.quantization(), IsqType::Q8_0));
    assert_eq!(opts.max_tokens(), 2048);
    assert_eq!(opts.request().temperature(), 0.7);
  }

  /// Every `set_*` mutator changes its own field and can be chained.
  #[test]
  fn engine_options_set_chains() {
    // Each value differs from the default so a no-op setter cannot pass.
    let mut opts = EngineOptions::new("/tmp/a");
    opts
      .set_model_path("/tmp/b")
      .set_quantization(IsqType::Q8_0)
      .set_max_tokens(2048)
      .set_request(RequestOptions::new());
    assert_eq!(opts.model_path(), Path::new("/tmp/b"));
    assert!(matches!(opts.quantization(), IsqType::Q8_0));
    assert_eq!(opts.max_tokens(), 2048);
    assert_eq!(opts.request().temperature(), 0.7);
  }

  #[test]
  fn request_options_defaults_match_model_card() {
    let opts = RequestOptions::new();
    assert_eq!(opts.temperature(), 0.7);
    assert_eq!(opts.top_p(), 0.8);
    assert_eq!(opts.top_k(), 20);
    assert_eq!(opts.presence_penalty(), 1.5);
  }

  #[test]
  fn request_options_default_eq_new() {
    let new_opts = RequestOptions::new();
    let default_opts = RequestOptions::default();
    assert_eq!(new_opts.temperature(), default_opts.temperature());
    assert_eq!(new_opts.top_p(), default_opts.top_p());
    assert_eq!(new_opts.top_k(), default_opts.top_k());
    assert_eq!(new_opts.presence_penalty(), default_opts.presence_penalty());
  }

  #[test]
  fn request_options_with_chains() {
    let opts = RequestOptions::new()
      .with_temperature(0.3)
      .with_top_p(0.95)
      .with_top_k(50)
      .with_presence_penalty(0.0);
    assert_eq!(opts.temperature(), 0.3);
    assert_eq!(opts.top_p(), 0.95);
    assert_eq!(opts.top_k(), 50);
    assert_eq!(opts.presence_penalty(), 0.0);
  }

  #[test]
  fn request_options_set_chains() {
    let mut opts = RequestOptions::new();
    opts
      .set_temperature(0.3)
      .set_top_p(0.95)
      .set_top_k(50)
      .set_presence_penalty(0.0);
    assert_eq!(opts.temperature(), 0.3);
    assert_eq!(opts.top_p(), 0.95);
    assert_eq!(opts.top_k(), 50);
    assert_eq!(opts.presence_penalty(), 0.0);
  }

  #[test]
  fn request_options_deterministic_preset() {
    let opts = RequestOptions::deterministic();
    assert_eq!(opts.temperature(), 0.0);
    assert_eq!(opts.top_p(), 1.0);
    assert_eq!(opts.top_k(), 1);
    assert_eq!(opts.presence_penalty(), 1.5);
  }

  #[test]
  fn request_options_validate_accepts_presets() {
    assert!(RequestOptions::new().validate().is_ok());
    assert!(RequestOptions::deterministic().validate().is_ok());
  }

  #[test]
  fn request_options_validate_rejects_negative_temperature() {
    assert_rejected(RequestOptions::new().with_temperature(-0.1));
  }

  #[test]
  fn request_options_validate_rejects_non_finite_temperature() {
    assert_rejected(RequestOptions::new().with_temperature(f64::NAN));
    assert_rejected(RequestOptions::new().with_temperature(f64::INFINITY));
  }

  #[test]
  fn request_options_validate_rejects_top_p_out_of_range() {
    assert_rejected(RequestOptions::new().with_top_p(0.0));
    assert_rejected(RequestOptions::new().with_top_p(1.5));
    assert_rejected(RequestOptions::new().with_top_p(-0.1));
    assert_rejected(RequestOptions::new().with_top_p(f64::NAN));
  }

  #[test]
  fn request_options_validate_accepts_top_p_one() {
    assert!(RequestOptions::new().with_top_p(1.0).validate().is_ok());
  }

  #[test]
  fn request_options_validate_rejects_top_k_zero() {
    assert_rejected(RequestOptions::new().with_top_k(0));
  }

  #[test]
  fn request_options_validate_rejects_non_finite_presence_penalty() {
    assert_rejected(RequestOptions::new().with_presence_penalty(f32::NAN));
    assert_rejected(RequestOptions::new().with_presence_penalty(f32::INFINITY));
  }

  #[test]
  fn request_options_validate_accepts_negative_presence_penalty() {
    assert!(
      RequestOptions::new()
        .with_presence_penalty(-1.0)
        .validate()
        .is_ok()
    );
  }

  #[test]
  fn engine_options_default_inference_timeout() {
    let opts = EngineOptions::new("/nonexistent");
    assert_eq!(opts.inference_timeout(), DEFAULT_INFERENCE_TIMEOUT);
    assert_eq!(opts.inference_timeout(), Duration::from_secs(300));
  }

  #[test]
  fn engine_options_with_inference_timeout() {
    let opts = EngineOptions::new("/nonexistent").with_inference_timeout(Duration::from_secs(10));
    assert_eq!(opts.inference_timeout(), Duration::from_secs(10));
  }

  /// The in-place timeout mutator updates the field and returns `self`.
  #[test]
  fn engine_options_set_inference_timeout() {
    let mut opts = EngineOptions::new("/nonexistent");
    opts
      .set_inference_timeout(Duration::from_secs(10))
      .set_max_tokens(16);
    assert_eq!(opts.inference_timeout(), Duration::from_secs(10));
    assert_eq!(opts.max_tokens(), 16);
  }

  #[test]
  fn engine_options_default_max_tokens_bumped_to_1024() {
    let opts = EngineOptions::new("/nonexistent");
    assert_eq!(opts.max_tokens(), 1024);
  }
}