use crate::core::backend::{
ChatMessage, ChatOptions, ChatResponse, ChatRole, LoadConfig, ModelInfo, NativeModelKind,
VisionImage,
};
use crate::core::storage::extract_quantization;
use crate::provider::native::error::NativeError;
use crate::provider::native::traits::InferenceBackend;
use futures::Stream;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
#[cfg(feature = "native-inference")]
use crate::util::constants::INFER_TIMEOUT;
#[cfg(feature = "native-inference")]
use parking_lot::Mutex;
#[cfg(feature = "native-inference")]
use std::path::Path;
#[cfg(feature = "native-inference")]
use tracing::{debug, info, warn};
#[cfg(feature = "native-inference")]
#[allow(unused_imports)] use futures::StreamExt as _;
#[cfg(feature = "native-inference")]
use mistralrs::{
GgufModelBuilder, MemoryGpuConfig, Model, ModelCategory, PagedAttentionMetaBuilder,
RequestBuilder, TextMessageRole, TextMessages, VisionMessages, VisionModelBuilder,
};
#[cfg(feature = "native-inference")]
use tokio::sync::RwLock;
/// Bookkeeping entry for one spawned streaming task: the join handle used by
/// `shutdown` to await completion, paired with the child cancellation token
/// that can stop the task individually.
#[cfg(feature = "native-inference")]
struct TrackedTask {
    // Awaited during shutdown; `is_finished()` is polled for cleanup.
    #[allow(dead_code)] handle: JoinHandle<()>,
    // Child of the runtime's root token; cancelled when the parent cancels.
    #[allow(dead_code)] token: CancellationToken,
}
/// In-process inference runtime wrapping a single optional mistral.rs model.
///
/// Cloning a `NativeRuntime` produces a handle that shares the model, the
/// task list, and the cancellation token with the original.
#[allow(dead_code)] pub struct NativeRuntime {
    // The loaded model, if any; shared so streaming tasks can hold it.
    #[cfg(feature = "native-inference")]
    model: Option<Arc<RwLock<Model>>>,
    // Metadata captured at load time (name, size, quantization, ...).
    model_info: Option<ModelInfo>,
    // Local file path for GGUF loads; `None` for hub-loaded vision models.
    model_path: Option<PathBuf>,
    // The `LoadConfig` used for the current load, if a model is loaded.
    config: Option<LoadConfig>,
    // Whether the loaded model reported a Vision category at load time.
    is_vision: bool,
    // Root token; `cancel_all` cancels it and every child token.
    cancellation_token: CancellationToken,
    // Spawned streaming tasks still being tracked (pruned opportunistically).
    #[cfg(feature = "native-inference")]
    tasks: Arc<Mutex<Vec<TrackedTask>>>,
}
/// Manual `Debug`: the model handle itself is not `Debug`, so we report
/// derived state (loaded/vision/cancelled flags and task count) instead.
impl std::fmt::Debug for NativeRuntime {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut s = f.debug_struct("NativeRuntime");
        s.field("model_info", &self.model_info);
        s.field("model_path", &self.model_path);
        s.field("config", &self.config);
        s.field("is_loaded", &self.is_loaded());
        s.field("is_vision", &self.is_vision);
        s.field("is_cancelled", &self.cancellation_token.is_cancelled());
        // Task tracking only exists when native inference is compiled in.
        #[cfg(feature = "native-inference")]
        s.field("active_tasks", &self.tasks.lock().len());
        s.finish()
    }
}
/// Manual `Clone`: every clone shares the model `Arc`, the task list, and the
/// same cancellation token — cancelling one handle cancels them all (see
/// `test_clone_shares_cancellation_token`).
// NOTE(review): a derived Clone would produce the same field-wise clones;
// presumably the manual impl is kept for clarity around the cfg-gated
// fields — confirm before replacing with #[derive(Clone)].
impl Clone for NativeRuntime {
    fn clone(&self) -> Self {
        Self {
            #[cfg(feature = "native-inference")]
            model: self.model.clone(),
            model_info: self.model_info.clone(),
            model_path: self.model_path.clone(),
            config: self.config.clone(),
            is_vision: self.is_vision,
            // Shared (not child) token: cancellation propagates both ways.
            cancellation_token: self.cancellation_token.clone(),
            #[cfg(feature = "native-inference")]
            tasks: Arc::clone(&self.tasks),
        }
    }
}
impl NativeRuntime {
    /// Creates an empty runtime: no model loaded, nothing cancelled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "native-inference")]
            model: None,
            model_info: None,
            model_path: None,
            config: None,
            is_vision: false,
            cancellation_token: CancellationToken::new(),
            #[cfg(feature = "native-inference")]
            tasks: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Cancels the root token, which propagates to every child token handed
    /// out via [`Self::child_token`] and to every clone of this runtime.
    ///
    /// Cancellation is permanent for this runtime instance: the token is
    /// never reset.
    pub fn cancel_all(&self) {
        // Fully-qualified `tracing::` path so this compiles even without the
        // feature-gated `debug` import.
        tracing::debug!("Cancelling all native inference tasks");
        self.cancellation_token.cancel();
    }

    /// Returns true once [`Self::cancel_all`] (or `shutdown`) has run.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.cancellation_token.is_cancelled()
    }

    /// Creates a child token: cancelled when the runtime is cancelled, while
    /// cancelling the child does not affect the runtime or sibling tokens.
    #[must_use]
    pub fn child_token(&self) -> CancellationToken {
        self.cancellation_token.child_token()
    }

    /// Drops bookkeeping entries for tasks that have already finished.
    #[cfg(feature = "native-inference")]
    fn cleanup_completed_tasks(&self) {
        let mut tasks = self.tasks.lock();
        tasks.retain(|task| !task.handle.is_finished());
    }

    /// Number of spawned streaming tasks that are still running.
    #[cfg(feature = "native-inference")]
    #[must_use]
    pub fn active_task_count(&self) -> usize {
        // Prune and count under a single lock acquisition; the previous
        // cleanup-then-relock sequence could return a count that was already
        // stale relative to the cleanup pass.
        let mut tasks = self.tasks.lock();
        tasks.retain(|task| !task.handle.is_finished());
        tasks.len()
    }

    /// Cancels all tasks, waits up to `timeout` for them to finish, then
    /// unloads the model. A timeout while waiting is logged, not an error.
    ///
    /// # Errors
    /// Propagates any error from [`InferenceBackend::unload`].
    #[cfg(feature = "native-inference")]
    pub async fn shutdown(&mut self, timeout: std::time::Duration) -> Result<(), NativeError> {
        use tokio::time::timeout as tokio_timeout;
        tracing::info!("Shutting down native runtime");
        self.cancel_all();
        // Take ownership of the handles in a short critical section so the
        // mutex is never held across an await point.
        let tasks: Vec<_> = {
            let mut guard = self.tasks.lock();
            std::mem::take(&mut *guard)
        };
        if !tasks.is_empty() {
            tracing::debug!(task_count = tasks.len(), "Waiting for tasks to complete");
            let wait_future = async {
                for task in tasks {
                    let _ = task.handle.await;
                }
            };
            if tokio_timeout(timeout, wait_future).await.is_err() {
                tracing::warn!("Timeout waiting for tasks to complete during shutdown");
            }
        }
        self.unload().await?;
        tracing::info!("Native runtime shutdown complete");
        Ok(())
    }

    /// Path of the currently loaded GGUF model, if any. Hub-loaded vision
    /// models have no local path and return `None` here.
    #[must_use]
    pub fn model_path(&self) -> Option<&PathBuf> {
        self.model_path.as_ref()
    }

    /// Configuration used for the current load, if a model is loaded.
    #[must_use]
    pub fn config(&self) -> Option<&LoadConfig> {
        self.config.as_ref()
    }
}
impl Default for NativeRuntime {
fn default() -> Self {
Self::new()
}
}
/// Derives a quantization label from a model file's name, if the name
/// contains one recognized by `extract_quantization`.
#[cfg(feature = "native-inference")]
fn extract_quantization_from_path(path: &Path) -> Option<String> {
    path.file_name()
        .and_then(|name| extract_quantization(&name.to_string_lossy()))
}
/// Copies whichever sampling parameters are present in `options` onto the
/// request builder, leaving absent ones at the builder's defaults.
#[cfg(feature = "native-inference")]
fn apply_sampling_params(request: RequestBuilder, options: &ChatOptions) -> RequestBuilder {
    let request = match options.temperature {
        Some(temp) => request.set_sampler_temperature(f64::from(temp)),
        None => request,
    };
    let request = match options.max_tokens {
        Some(max_tokens) => request.set_sampler_max_len(max_tokens as usize),
        None => request,
    };
    let request = match options.top_p {
        Some(top_p) => request.set_sampler_topp(f64::from(top_p)),
        None => request,
    };
    match options.top_k {
        Some(top_k) => request.set_sampler_topk(top_k as usize),
        None => request,
    }
}
/// Converts a mistral.rs chat completion into our [`ChatResponse`], logging
/// token-usage statistics along the way.
///
/// # Errors
/// `InferenceFailed` when the response carries no choices or no content.
#[cfg(feature = "native-inference")]
fn parse_chat_completion(
    response: &mistralrs::ChatCompletionResponse,
) -> Result<ChatResponse, NativeError> {
    let first = response.choices.first();
    let Some(content) = first.and_then(|c| c.message.content.clone()) else {
        return Err(NativeError::InferenceFailed(
            "Model returned empty response (no choices)".to_string(),
        ));
    };
    let usage = &response.usage;
    debug!(
        prompt_tokens = usage.prompt_tokens,
        completion_tokens = usage.completion_tokens,
        avg_prompt_tok_per_sec = ?usage.avg_prompt_tok_per_sec,
        avg_compl_tok_per_sec = ?usage.avg_compl_tok_per_sec,
        "Inference completed"
    );
    Ok(ChatResponse {
        message: ChatMessage {
            role: ChatRole::Assistant,
            content,
        },
        done: true,
        total_duration: None,
        prompt_eval_count: Some(usage.prompt_tokens as u64),
        eval_count: Some(usage.completion_tokens as u64),
    })
}
/// Decodes every raw image payload into a `DynamicImage`, short-circuiting
/// with `InferenceFailed` on the first payload that fails to decode.
#[cfg(feature = "native-inference")]
fn decode_vision_images(images: &[VisionImage]) -> Result<Vec<image::DynamicImage>, NativeError> {
    let mut decoded = Vec::with_capacity(images.len());
    for (i, img) in images.iter().enumerate() {
        let dynamic = image::load_from_memory(&img.bytes).map_err(|e| {
            NativeError::InferenceFailed(format!(
                "Failed to decode image {} ({}): {}",
                i, img.media_type, e
            ))
        })?;
        decoded.push(dynamic);
    }
    Ok(decoded)
}
/// Returns true when the model's config reports a `Vision` category.
/// A config read failure is logged and treated as "not vision".
#[cfg(feature = "native-inference")]
fn detect_vision_capability(model: &Model) -> bool {
    model
        .config()
        .map(|config| matches!(config.category, ModelCategory::Vision { .. }))
        .unwrap_or_else(|e| {
            warn!(
                "Could not read model config to detect vision capability: {}",
                e
            );
            false
        })
}
/// Spawns a background task that drives a mistral.rs streaming chat request
/// and forwards each text delta through a bounded channel; the receiving end
/// is returned to the caller as a `Stream` of `Result<String, NativeError>`.
///
/// The spawned task is registered in `runtime.tasks` (so `shutdown` can await
/// it) and observes a child of the runtime's cancellation token.
///
/// # Errors
/// Returns `NativeError::Cancelled` if the runtime is already cancelled.
#[cfg(feature = "native-inference")]
fn spawn_stream_task(
    runtime: &NativeRuntime,
    model_arc: Arc<RwLock<Model>>,
    request: RequestBuilder,
) -> Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError> {
    use crate::util::constants::STREAM_CHUNK_TIMEOUT;
    use async_stream::stream;
    use mistralrs::Response;
    use tokio::sync::mpsc;
    if runtime.cancellation_token.is_cancelled() {
        return Err(NativeError::Cancelled);
    }
    // Bounded channel: applies backpressure if the consumer lags the model.
    let (tx, mut rx) = mpsc::channel::<Result<String, NativeError>>(32);
    let task_token = runtime.child_token();
    let task_token_clone = task_token.clone();
    let handle = tokio::spawn(async move {
        // Read lock held for the life of the stream (tokio RwLock, so
        // holding across awaits is fine).
        let model = model_arc.read().await;
        match model.stream_chat_request(request).await {
            Ok(mut stream) => {
                loop {
                    // Fast-path check before blocking on the next chunk.
                    if task_token_clone.is_cancelled() {
                        debug!("Streaming task cancelled");
                        let _ = tx.send(Err(NativeError::Cancelled)).await;
                        break;
                    }
                    // `biased` so cancellation is always polled before the
                    // chunk/timeout future.
                    let chunk_result = tokio::select! {
                        biased;
                        _ = task_token_clone.cancelled() => {
                            debug!("Streaming task cancelled during chunk wait");
                            let _ = tx.send(Err(NativeError::Cancelled)).await;
                            break;
                        }
                        result = tokio::time::timeout(
                            STREAM_CHUNK_TIMEOUT,
                            stream.next(),
                        ) => {
                            result
                        }
                    };
                    let chunk = match chunk_result {
                        Ok(Some(c)) => c,
                        // Upstream closed without an explicit Done response.
                        Ok(None) => break,
                        // No chunk arrived within STREAM_CHUNK_TIMEOUT.
                        Err(_) => {
                            let _ = tx
                                .send(Err(NativeError::InferenceTimeout {
                                    timeout_secs: STREAM_CHUNK_TIMEOUT.as_secs(),
                                }))
                                .await;
                            break;
                        }
                    };
                    match chunk {
                        Response::Chunk(chunk_response) => {
                            if let Some(choice) = chunk_response.choices.first() {
                                if let Some(text) = &choice.delta.content {
                                    // send() only fails when the receiver was
                                    // dropped; stop producing in that case.
                                    if tx.send(Ok(text.clone())).await.is_err() {
                                        break;
                                    }
                                }
                            }
                        }
                        Response::Done(_) => {
                            debug!("Streaming completed");
                            break;
                        }
                        Response::ModelError(msg, _) => {
                            let _ = tx
                                .send(Err(NativeError::InferenceFailed(format!(
                                    "Model error: {}",
                                    msg
                                ))))
                                .await;
                            break;
                        }
                        Response::ValidationError(err) => {
                            let _ = tx
                                .send(Err(NativeError::InferenceFailed(format!(
                                    "Validation error: {:?}",
                                    err
                                ))))
                                .await;
                            break;
                        }
                        Response::InternalError(err) => {
                            let _ = tx
                                .send(Err(NativeError::InferenceFailed(format!(
                                    "Internal error: {:?}",
                                    err
                                ))))
                                .await;
                            break;
                        }
                        // Any other response variant is ignored; keep reading.
                        _ => {
                        }
                    }
                }
            }
            Err(e) => {
                let _ = tx
                    .send(Err(NativeError::InferenceFailed(format!(
                        "Failed to start streaming: {}",
                        e
                    ))))
                    .await;
            }
        }
    });
    // Register the task so shutdown() can cancel and await it.
    {
        let mut tasks = runtime.tasks.lock();
        tasks.push(TrackedTask {
            handle,
            token: task_token,
        });
    }
    runtime.cleanup_completed_tasks();
    // Adapter: drain the channel into an async stream for the caller.
    Ok(stream! {
        while let Some(result) = rx.recv().await {
            yield result;
        }
    })
}
#[cfg(feature = "native-inference")]
impl InferenceBackend for NativeRuntime {
    /// Loads a model according to `config.model_kind`, replacing any model
    /// that is already loaded.
    ///
    /// # Errors
    /// - `ModelNotFound` when a local GGUF path does not exist.
    /// - `InvalidConfig` for bad paths, invalid ISQ values, PagedAttention
    ///   configuration errors, or model build failures.
    async fn load(&mut self, model_path: PathBuf, config: LoadConfig) -> Result<(), NativeError> {
        // A runtime holds at most one model; drop the old one first.
        if self.model.is_some() {
            self.unload().await?;
        }
        let model_kind = config.model_kind.clone();
        match &model_kind {
            NativeModelKind::TextGguf => {
                info!(?model_path, "Loading GGUF model");
                if !model_path.exists() {
                    return Err(NativeError::ModelNotFound {
                        repo: "local".to_string(),
                        filename: model_path.to_string_lossy().to_string(),
                    });
                }
                // GgufModelBuilder wants the containing directory and the
                // file name as separate pieces.
                let parent = model_path
                    .parent()
                    .map(|p| p.to_string_lossy().to_string())
                    .unwrap_or_else(|| ".".to_string());
                let filename = model_path
                    .file_name()
                    .map(|f| f.to_string_lossy().to_string())
                    .ok_or_else(|| {
                        NativeError::InvalidConfig("Invalid model path: no filename".to_string())
                    })?;
                debug!(gpu_layers = config.gpu_layers, %parent, %filename, "Building GGUF model");
                let context_size = config.context_size.unwrap_or(2048);
                let model = GgufModelBuilder::new(parent, vec![filename])
                    .with_logging()
                    .with_paged_attn(|| {
                        PagedAttentionMetaBuilder::default()
                            .with_block_size(32)
                            .with_gpu_memory(MemoryGpuConfig::ContextSize(context_size as usize))
                            .build()
                    })
                    .map_err(|e| {
                        NativeError::InvalidConfig(format!("PagedAttention config error: {e}"))
                    })?
                    .build()
                    .await
                    .map_err(|e| {
                        NativeError::InvalidConfig(format!("Failed to build model: {e}"))
                    })?;
                // Probe the built model rather than assuming GGUF == text-only.
                let is_vision = detect_vision_capability(&model);
                let info = ModelInfo {
                    name: model_path
                        .file_stem()
                        .map(|s| s.to_string_lossy().to_string())
                        .unwrap_or_else(|| "unknown".to_string()),
                    // Best-effort: a metadata failure reports size 0 rather
                    // than failing an otherwise successful load.
                    size: tokio::fs::metadata(&model_path)
                        .await
                        .map(|m| m.len())
                        .unwrap_or(0),
                    quantization: extract_quantization_from_path(&model_path),
                    parameters: None,
                    digest: None,
                };
                self.model = Some(Arc::new(RwLock::new(model)));
                self.model_info = Some(info);
                self.model_path = Some(model_path);
                self.is_vision = is_vision;
                self.config = Some(config);
                info!(is_vision, "GGUF model loaded successfully");
            }
            NativeModelKind::VisionHf { model_id, isq } => {
                info!(%model_id, ?isq, "Loading HuggingFace vision model");
                let context_size = config.context_size.unwrap_or(4096);
                let mut builder = VisionModelBuilder::new(model_id).with_logging();
                // Optional in-situ quantization of the downloaded weights.
                if let Some(isq_str) = isq {
                    let isq_type = mistralrs::parse_isq_value(isq_str, None).map_err(|e| {
                        NativeError::InvalidConfig(format!("Invalid ISQ type '{}': {}", isq_str, e))
                    })?;
                    debug!(?isq_type, "Applying ISQ quantization");
                    builder = builder.with_isq(isq_type);
                }
                builder = builder
                    .with_paged_attn(|| {
                        PagedAttentionMetaBuilder::default()
                            .with_block_size(32)
                            .with_gpu_memory(MemoryGpuConfig::ContextSize(context_size as usize))
                            .build()
                    })
                    .map_err(|e| {
                        NativeError::InvalidConfig(format!("PagedAttention config error: {e}"))
                    })?;
                let model = builder.build().await.map_err(|e| {
                    NativeError::InvalidConfig(format!(
                        "Failed to build vision model '{}': {}",
                        model_id, e
                    ))
                })?;
                let is_vision = detect_vision_capability(&model);
                if !is_vision {
                    // Loaded anyway, but vision APIs will be rejected later
                    // by the `is_vision` checks in infer_vision*.
                    warn!(
                        %model_id,
                        "Model loaded via VisionHf path but does not report Vision category"
                    );
                }
                let info = ModelInfo {
                    // Size on disk is not tracked for hub-loaded models.
                    name: model_id.clone(),
                    size: 0, quantization: isq.clone(),
                    parameters: None,
                    digest: None,
                };
                self.model = Some(Arc::new(RwLock::new(model)));
                self.model_info = Some(info);
                // Hub models have no single local file path.
                self.model_path = None;
                self.is_vision = is_vision;
                self.config = Some(config);
                info!(is_vision, %model_id, "Vision model loaded successfully");
            }
        }
        Ok(())
    }

    /// Drops the model and clears all associated state. Idempotent.
    async fn unload(&mut self) -> Result<(), NativeError> {
        if self.model.is_some() {
            info!("Unloading model");
            self.model = None;
            self.model_info = None;
            self.model_path = None;
            self.config = None;
            self.is_vision = false;
        }
        Ok(())
    }

    fn is_loaded(&self) -> bool {
        self.model.is_some()
    }

    fn model_info(&self) -> Option<&ModelInfo> {
        self.model_info.as_ref()
    }

    fn supports_vision(&self) -> bool {
        self.is_vision
    }

    /// Runs a blocking (non-streaming) text inference for one user prompt.
    ///
    /// # Errors
    /// `ModelNotLoaded`, `InferenceTimeout` after `INFER_TIMEOUT`, or
    /// `InferenceFailed` on a model-side error.
    async fn infer(&self, prompt: &str, options: ChatOptions) -> Result<ChatResponse, NativeError> {
        let model = self.model.as_ref().ok_or(NativeError::ModelNotLoaded)?;
        // tokio RwLock read guard: held across the await on purpose.
        let model = model.read().await;
        let messages = TextMessages::new().add_message(TextMessageRole::User, prompt);
        debug!(
            temperature = options.temperature,
            max_tokens = options.max_tokens,
            "Running text inference"
        );
        let request = apply_sampling_params(RequestBuilder::from(messages), &options);
        let response = tokio::time::timeout(INFER_TIMEOUT, model.send_chat_request(request))
            .await
            .map_err(|_| NativeError::InferenceTimeout {
                timeout_secs: INFER_TIMEOUT.as_secs(),
            })?
            .map_err(|e| NativeError::InferenceFailed(format!("Inference failed: {e}")))?;
        parse_chat_completion(&response)
    }

    /// Starts a streaming text inference; deltas arrive on the returned
    /// stream, produced by a task spawned via `spawn_stream_task`.
    async fn infer_stream(
        &self,
        prompt: &str,
        options: ChatOptions,
    ) -> Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError> {
        if self.cancellation_token.is_cancelled() {
            return Err(NativeError::Cancelled);
        }
        let model = self.model.as_ref().ok_or(NativeError::ModelNotLoaded)?;
        let model_arc = Arc::clone(model);
        let messages = TextMessages::new().add_message(TextMessageRole::User, prompt);
        let request = apply_sampling_params(RequestBuilder::from(messages), &options);
        spawn_stream_task(self, model_arc, request)
    }

    /// Runs a blocking vision inference over `prompt` plus at least one image.
    ///
    /// # Errors
    /// `ModelNotLoaded`; `InvalidConfig` when the loaded model is not a
    /// vision model or `images` is empty; timeout/inference errors otherwise.
    async fn infer_vision(
        &self,
        prompt: &str,
        images: Vec<VisionImage>,
        options: ChatOptions,
    ) -> Result<ChatResponse, NativeError> {
        let model_lock = self.model.as_ref().ok_or(NativeError::ModelNotLoaded)?;
        if !self.is_vision {
            return Err(NativeError::InvalidConfig(
                "Loaded model does not support vision. Load a vision model via \
                 NativeModelKind::VisionHf"
                    .to_string(),
            ));
        }
        if images.is_empty() {
            return Err(NativeError::InvalidConfig(
                "infer_vision requires at least one image".to_string(),
            ));
        }
        let model = model_lock.read().await;
        debug!(
            image_count = images.len(),
            temperature = options.temperature,
            max_tokens = options.max_tokens,
            "Running vision inference"
        );
        let dynamic_images = decode_vision_images(&images)?;
        let vision_messages = VisionMessages::new()
            .add_image_message(TextMessageRole::User, prompt, dynamic_images, &model)
            .map_err(|e| {
                NativeError::InferenceFailed(format!("Failed to build vision message: {}", e))
            })?;
        let request = apply_sampling_params(RequestBuilder::from(vision_messages), &options);
        let response = tokio::time::timeout(INFER_TIMEOUT, model.send_chat_request(request))
            .await
            .map_err(|_| NativeError::InferenceTimeout {
                timeout_secs: INFER_TIMEOUT.as_secs(),
            })?
            .map_err(|e| NativeError::InferenceFailed(format!("Vision inference failed: {e}")))?;
        parse_chat_completion(&response)
    }

    /// Streaming variant of [`Self::infer_vision`].
    async fn infer_vision_stream(
        &self,
        prompt: &str,
        images: Vec<VisionImage>,
        options: ChatOptions,
    ) -> Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError> {
        let model_lock = self.model.as_ref().ok_or(NativeError::ModelNotLoaded)?;
        if !self.is_vision {
            return Err(NativeError::InvalidConfig(
                "Loaded model does not support vision. Load a vision model via \
                 NativeModelKind::VisionHf"
                    .to_string(),
            ));
        }
        if images.is_empty() {
            return Err(NativeError::InvalidConfig(
                "infer_vision_stream requires at least one image".to_string(),
            ));
        }
        debug!(
            image_count = images.len(),
            temperature = options.temperature,
            max_tokens = options.max_tokens,
            "Starting vision inference stream"
        );
        let dynamic_images = decode_vision_images(&images)?;
        // Scope the read lock: build the request, then release it before
        // spawning the streaming task (which re-acquires the lock itself).
        let request = {
            let model = model_lock.read().await;
            let vision_messages = VisionMessages::new()
                .add_image_message(TextMessageRole::User, prompt, dynamic_images, &model)
                .map_err(|e| {
                    NativeError::InferenceFailed(format!("Failed to build vision message: {}", e))
                })?;
            apply_sampling_params(RequestBuilder::from(vision_messages), &options)
        };
        let model_arc = Arc::clone(model_lock);
        spawn_stream_task(self, model_arc, request)
    }
}
/// Builds the shared error returned by every stub method below: the crate was
/// compiled without native inference support.
#[cfg(not(feature = "native-inference"))]
fn native_inference_disabled() -> NativeError {
    NativeError::InvalidConfig(
        "Inference feature not enabled. Rebuild with --features native-inference".to_string(),
    )
}

/// Stub backend compiled when `native-inference` is disabled: reports no
/// model loaded and fails every inference call with a rebuild hint.
#[cfg(not(feature = "native-inference"))]
impl InferenceBackend for NativeRuntime {
    async fn load(&mut self, _model_path: PathBuf, _config: LoadConfig) -> Result<(), NativeError> {
        Err(native_inference_disabled())
    }
    /// Nothing to unload; always succeeds.
    async fn unload(&mut self) -> Result<(), NativeError> {
        Ok(())
    }
    fn is_loaded(&self) -> bool {
        false
    }
    fn model_info(&self) -> Option<&ModelInfo> {
        None
    }
    fn supports_vision(&self) -> bool {
        false
    }
    async fn infer(
        &self,
        _prompt: &str,
        _options: ChatOptions,
    ) -> Result<ChatResponse, NativeError> {
        Err(native_inference_disabled())
    }
    async fn infer_stream(
        &self,
        _prompt: &str,
        _options: ChatOptions,
    ) -> Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError> {
        // Turbofish names a concrete (never-constructed) stream type so the
        // `impl Stream` return type is inferable from the Err-only body.
        Err::<futures::stream::Empty<Result<String, NativeError>>, _>(native_inference_disabled())
    }
    async fn infer_vision(
        &self,
        _prompt: &str,
        _images: Vec<VisionImage>,
        _options: ChatOptions,
    ) -> Result<ChatResponse, NativeError> {
        Err(native_inference_disabled())
    }
    async fn infer_vision_stream(
        &self,
        _prompt: &str,
        _images: Vec<VisionImage>,
        _options: ChatOptions,
    ) -> Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError> {
        Err::<futures::stream::Empty<Result<String, NativeError>>, _>(native_inference_disabled())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fresh runtimes start empty, uncancelled, and text-only.
    #[test]
    fn test_runtime_creation() {
        let rt = NativeRuntime::new();
        assert!(!rt.is_loaded());
        assert!(rt.model_info().is_none());
        assert!(rt.model_path().is_none());
        assert!(!rt.is_cancelled());
        assert!(!rt.supports_vision());
    }

    #[test]
    fn test_runtime_default() {
        let rt = NativeRuntime::default();
        assert!(!rt.is_loaded());
        assert!(!rt.is_cancelled());
        assert!(!rt.supports_vision());
    }

    #[test]
    fn test_cancel_all() {
        let rt = NativeRuntime::new();
        assert!(!rt.is_cancelled());
        rt.cancel_all();
        assert!(rt.is_cancelled());
    }

    // Cancelling the runtime propagates down to child tokens.
    #[test]
    fn test_child_token_cancelled_with_parent() {
        let rt = NativeRuntime::new();
        let child = rt.child_token();
        assert!(!child.is_cancelled());
        rt.cancel_all();
        assert!(child.is_cancelled());
    }

    // Cancelling a child token affects neither siblings nor the parent.
    #[test]
    fn test_child_token_independent() {
        let rt = NativeRuntime::new();
        let first = rt.child_token();
        let second = rt.child_token();
        first.cancel();
        assert!(first.is_cancelled());
        assert!(!second.is_cancelled());
        assert!(!rt.is_cancelled());
    }

    // Clones share the same root token, so cancellation is global.
    #[test]
    fn test_clone_shares_cancellation_token() {
        let original = NativeRuntime::new();
        let copy = original.clone();
        assert!(!original.is_cancelled());
        assert!(!copy.is_cancelled());
        original.cancel_all();
        assert!(original.is_cancelled());
        assert!(copy.is_cancelled());
    }

    #[test]
    fn test_debug_includes_cancellation_state() {
        let rt = NativeRuntime::new();
        let rendered = format!("{:?}", rt);
        assert!(rendered.contains("is_cancelled"));
        assert!(rendered.contains("is_vision"));
    }

    #[test]
    fn test_vision_image_construction() {
        let img = VisionImage::new(vec![0xFF, 0xD8], "image/jpeg");
        assert_eq!(img.bytes, vec![0xFF, 0xD8]);
        assert_eq!(img.media_type, "image/jpeg");
    }

    #[test]
    fn test_native_model_kind_default() {
        let kind = NativeModelKind::default();
        assert!(matches!(kind, NativeModelKind::TextGguf));
        assert!(!kind.is_vision());
    }

    #[test]
    fn test_native_model_kind_vision() {
        let kind = NativeModelKind::VisionHf {
            model_id: "test/model".to_string(),
            isq: Some("Q4K".to_string()),
        };
        assert!(kind.is_vision());
    }

    #[test]
    fn test_load_config_default_has_text_gguf() {
        let config = LoadConfig::default();
        assert!(matches!(config.model_kind, NativeModelKind::TextGguf));
    }

    // Without the feature, load must fail with the rebuild hint.
    #[tokio::test]
    #[cfg(not(feature = "native-inference"))]
    async fn test_load_without_feature() {
        let mut rt = NativeRuntime::new();
        let outcome = rt
            .load(PathBuf::from("test.gguf"), LoadConfig::default())
            .await;
        assert!(outcome.is_err());
        assert!(outcome
            .unwrap_err()
            .to_string()
            .contains("Inference feature not enabled"));
    }

    #[tokio::test]
    #[cfg(not(feature = "native-inference"))]
    async fn test_infer_without_feature() {
        let rt = NativeRuntime::new();
        let outcome = rt.infer("test", ChatOptions::default()).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    #[cfg(not(feature = "native-inference"))]
    async fn test_infer_vision_without_feature() {
        let rt = NativeRuntime::new();
        let imgs = vec![VisionImage::new(vec![0xFF], "image/png")];
        let outcome = rt
            .infer_vision("describe", imgs, ChatOptions::default())
            .await;
        assert!(outcome.is_err());
        assert!(outcome
            .unwrap_err()
            .to_string()
            .contains("Inference feature not enabled"));
    }

    // Tests that only make sense when the real backend is compiled in.
    #[cfg(feature = "native-inference")]
    mod native_inference_tests {
        use super::*;

        #[tokio::test]
        async fn test_active_task_count_starts_at_zero() {
            let rt = NativeRuntime::new();
            assert_eq!(rt.active_task_count(), 0);
        }

        #[tokio::test]
        async fn test_infer_stream_returns_cancelled_when_already_cancelled() {
            let rt = NativeRuntime::new();
            rt.cancel_all();
            let outcome = rt
                .infer_stream("test prompt", ChatOptions::default())
                .await;
            match outcome {
                Err(NativeError::Cancelled) => {}
                Err(other) => panic!("Expected NativeError::Cancelled, got {:?}", other),
                Ok(_) => panic!("Expected Err(Cancelled), got Ok(stream)"),
            }
        }

        #[tokio::test]
        async fn test_shutdown_cancels_token() {
            let mut rt = NativeRuntime::new();
            assert!(!rt.is_cancelled());
            let outcome = rt
                .shutdown(std::time::Duration::from_millis(100))
                .await;
            assert!(outcome.is_ok());
            assert!(rt.is_cancelled());
        }

        #[tokio::test]
        async fn test_infer_vision_requires_loaded_model() {
            let rt = NativeRuntime::new();
            let imgs = vec![VisionImage::new(vec![0xFF], "image/png")];
            let outcome = rt
                .infer_vision("describe", imgs, ChatOptions::default())
                .await;
            assert!(matches!(outcome, Err(NativeError::ModelNotLoaded)));
        }

        #[tokio::test]
        async fn test_infer_vision_stream_requires_loaded_model() {
            let rt = NativeRuntime::new();
            let imgs = vec![VisionImage::new(vec![0xFF], "image/png")];
            let outcome = rt
                .infer_vision_stream("describe", imgs, ChatOptions::default())
                .await;
            match outcome {
                Err(NativeError::ModelNotLoaded) => {}
                Err(other) => panic!("Expected ModelNotLoaded, got {:?}", other),
                Ok(_) => panic!("Expected Err, got Ok"),
            }
        }
    }
}