#[cfg(any(
target_os = "macos",
target_os = "ios",
target_os = "tvos",
target_os = "windows",
target_os = "linux",
target_os = "android"
))]
use std::sync::Arc;
#[cfg(any(
target_os = "macos",
target_os = "ios",
target_os = "tvos",
target_os = "windows",
target_os = "linux",
target_os = "android"
))]
use tokio::sync::Mutex;
use super::types::*;
#[cfg(any(
target_os = "macos",
target_os = "ios",
target_os = "tvos",
target_os = "windows",
target_os = "linux",
target_os = "android"
))]
use mistralrs::{GgufModelBuilder, Model, RequestBuilder, TextMessageRole};
#[cfg(target_os = "macos")]
use mistralrs::{IsqBits, TextModelBuilder};
/// Configuration of the currently loaded model, tagged by the loader that
/// produced it.
#[cfg(any(
target_os = "macos",
target_os = "ios",
target_os = "tvos",
target_os = "windows",
target_os = "linux",
target_os = "android"
))]
enum LoadedModelConfig {
/// Config of a model that was built through `GgufModelBuilder`.
Gguf(GgufModelConfig),
/// Config of a model that was built through `TextModelBuilder` with ISQ;
/// this variant is only compiled on macOS.
#[cfg(target_os = "macos")]
Isq(IsqModelConfig),
}
#[cfg(any(
target_os = "macos",
target_os = "ios",
target_os = "tvos",
target_os = "windows",
target_os = "linux",
target_os = "android"
))]
impl LoadedModelConfig {
fn display_name(&self) -> &str {
match self {
LoadedModelConfig::Gguf(c) => &c.display_name,
#[cfg(target_os = "macos")]
LoadedModelConfig::Isq(c) => &c.display_name,
}
}
fn approx_memory(&self) -> &str {
match self {
LoadedModelConfig::Gguf(c) => &c.approx_memory,
#[cfg(target_os = "macos")]
LoadedModelConfig::Isq(c) => &c.approx_memory,
}
}
}
/// Everything the engine keeps for one loaded model: the model handle plus
/// the conversation state that is replayed into each request.
#[cfg(any(
target_os = "macos",
target_os = "ios",
target_os = "tvos",
target_os = "windows",
target_os = "linux",
target_os = "android"
))]
struct LoadedModel {
// Shared handle; cloned out of the lock so inference runs without holding it.
model: Arc<Model>,
// Config the model was loaded from (source of display name / memory text).
config: LoadedModelConfig,
// Prior turns; `build_request` replays them in order before the new message.
history: Vec<ChatMessage>,
// Optional system message that is added ahead of the history.
system_prompt: Option<String>,
// Sampling parameters applied to requests built from this state.
sampling: SamplingConfig,
}
#[cfg(any(
target_os = "macos",
target_os = "ios",
target_os = "tvos",
target_os = "windows",
target_os = "linux",
target_os = "android"
))]
pub struct ChatEngine {
inner: Mutex<Option<LoadedModel>>,
}
#[cfg(any(
target_os = "macos",
target_os = "ios",
target_os = "tvos",
target_os = "windows",
target_os = "linux",
target_os = "android"
))]
impl ChatEngine {
pub fn new() -> Self {
Self {
inner: Mutex::new(None),
}
}
pub async fn load_gguf_model(
&self,
config: GgufModelConfig,
system_prompt: Option<String>,
sampling: Option<SamplingConfig>,
) -> Result<std::time::Duration, InferenceError> {
use log::info;
use std::time::Instant;
info!(
"ChatEngine: loading GGUF model {} (files: {:?})",
config.model_id, config.files
);
#[cfg(any(target_os = "android", target_os = "ios", target_os = "tvos"))]
{
let hf_home = std::env::var("HF_HOME")
.map(std::path::PathBuf::from)
.map_err(|_| InferenceError::ModelBuild {
reason: "HF_HOME is not set — cannot initialise HF cache. \
On iOS/tvOS/Android, call configure_cache_dir() or \
download_model(app_data_dir:) before load_gguf_model()."
.to_string(),
})?;
let hf_hub_cache = hf_home.join("hub");
if let Err(e) = std::fs::create_dir_all(&hf_hub_cache) {
log::warn!(
"ChatEngine: could not create HF hub cache dir {}: {}",
hf_hub_cache.display(),
e
);
}
mistralrs_core::GLOBAL_HF_CACHE
.get_or_init(|| hf_hub::Cache::new(hf_hub_cache.clone()));
std::env::set_var("HF_HUB_CACHE", &hf_hub_cache);
log::debug!(
"ChatEngine: GLOBAL_HF_CACHE seeded at {}",
hf_hub_cache.display()
);
}
crate::hf_cache::clean_stale_lock_files(&config.model_id);
crate::hf_cache::repair_hf_cache_symlinks(&config.model_id);
let start = Instant::now();
let mut builder = GgufModelBuilder::new(&config.model_id, config.files.clone())
.with_token_source(super::token::hf_token_source())
.with_logging();
if let Some(ref tok_id) = config.tok_model_id {
builder = builder.with_tok_model_id(tok_id);
}
let model = builder
.build()
.await
.map_err(|e| InferenceError::ModelBuild {
reason: format!("Failed to build {} model: {}", config.display_name, e),
})?;
let elapsed = start.elapsed();
let sampling = sampling.unwrap_or_else(|| {
if cfg!(any(
target_os = "ios",
target_os = "tvos",
target_os = "android"
)) {
SamplingConfig::mobile()
} else {
SamplingConfig::default()
}
});
info!(
"ChatEngine: model {} loaded in {} (sampling: temp={:?}, max_tokens={:?})",
config.display_name,
format_duration(elapsed),
sampling.temperature,
sampling.max_tokens,
);
let mut guard = self.inner.lock().await;
*guard = Some(LoadedModel {
model: Arc::new(model),
config: LoadedModelConfig::Gguf(config),
history: Vec::new(),
system_prompt,
sampling,
});
Ok(elapsed)
}
#[cfg(target_os = "macos")]
pub async fn load_isq_model(
&self,
config: IsqModelConfig,
system_prompt: Option<String>,
sampling: Option<SamplingConfig>,
) -> Result<std::time::Duration, InferenceError> {
use log::info;
use std::time::Instant;
info!(
"ChatEngine: loading ISQ model {} (bits={})",
config.model_id, config.isq_bits
);
crate::hf_cache::clean_stale_lock_files(&config.model_id);
crate::hf_cache::repair_hf_cache_symlinks(&config.model_id);
let start = Instant::now();
let isq_bits = match config.isq_bits {
8 => IsqBits::Eight,
_ => IsqBits::Four, };
let model = TextModelBuilder::new(&config.model_id)
.with_token_source(super::token::hf_token_source())
.with_auto_isq(isq_bits)
.with_logging()
.build()
.await
.map_err(|e| InferenceError::ModelBuild {
reason: format!("Failed to build ISQ model {}: {}", config.display_name, e),
})?;
let elapsed = start.elapsed();
let sampling = sampling.unwrap_or_default();
info!(
"ChatEngine: ISQ model {} loaded in {} (sampling: temp={:?}, max_tokens={:?})",
config.display_name,
format_duration(elapsed),
sampling.temperature,
sampling.max_tokens,
);
let mut guard = self.inner.lock().await;
*guard = Some(LoadedModel {
model: Arc::new(model),
config: LoadedModelConfig::Isq(config),
history: Vec::new(),
system_prompt,
sampling,
});
Ok(elapsed)
}
pub async fn unload_model(&self) -> Option<String> {
let mut guard = self.inner.lock().await;
if let Some(loaded) = guard.take() {
let name = loaded.config.display_name().to_string();
log::info!("ChatEngine: unloading model {}", name);
Some(name)
} else {
log::debug!("ChatEngine: unload_model called but no model was loaded.");
None
}
}
pub async fn is_loaded(&self) -> bool {
self.inner.lock().await.is_some()
}
pub async fn info(&self) -> EngineInfo {
let guard = self.inner.lock().await;
match guard.as_ref() {
Some(loaded) => EngineInfo {
status: EngineStatus::Ready,
model_name: Some(loaded.config.display_name().to_string()),
approx_memory: Some(loaded.config.approx_memory().to_string()),
history_length: loaded.history.len() as u64,
},
None => EngineInfo {
status: EngineStatus::Unloaded,
model_name: None,
approx_memory: None,
history_length: 0u64,
},
}
}
pub async fn set_system_prompt(&self, prompt: impl Into<String>) {
if let Some(loaded) = self.inner.lock().await.as_mut() {
loaded.system_prompt = Some(prompt.into());
}
}
pub async fn clear_system_prompt(&self) {
if let Some(loaded) = self.inner.lock().await.as_mut() {
loaded.system_prompt = None;
}
}
pub async fn set_sampling(&self, sampling: SamplingConfig) {
if let Some(loaded) = self.inner.lock().await.as_mut() {
loaded.sampling = sampling;
}
}
pub async fn history(&self) -> Vec<ChatMessage> {
let guard = self.inner.lock().await;
match guard.as_ref() {
Some(loaded) => loaded.history.clone(),
None => Vec::new(),
}
}
pub async fn clear_history(&self) -> usize {
let mut guard = self.inner.lock().await;
match guard.as_mut() {
Some(loaded) => {
let count = loaded.history.len();
loaded.history.clear();
log::info!("ChatEngine: cleared {} history turns.", count);
count
}
None => 0,
}
}
pub async fn push_history(&self, message: ChatMessage) {
if let Some(loaded) = self.inner.lock().await.as_mut() {
loaded.history.push(message);
}
}
pub async fn send_message(
&self,
user_message: impl Into<String>,
) -> Result<InferenceResult, InferenceError> {
let user_message = user_message.into();
let (model, request) = {
let guard = self.inner.lock().await;
let loaded = guard.as_ref().ok_or(InferenceError::NoModelLoaded)?;
let request = self::build_request(loaded, &user_message);
(loaded.model.clone(), request)
};
log::info!(
"ChatEngine: inference START — message: \"{}\"",
truncate_for_log(&user_message, 100)
);
let start = std::time::Instant::now();
let response =
model
.send_chat_request(request)
.await
.map_err(|e| InferenceError::Inference {
reason: e.to_string(),
})?;
let elapsed = start.elapsed();
let reply = response.choices[0]
.message
.content
.as_ref()
.map(|c| c.trim().to_string())
.unwrap_or_else(|| "(empty response)".to_string());
let finish_reason = response.choices[0].finish_reason.clone();
log::info!(
"ChatEngine: inference END — {} — reply: \"{}\"",
format_duration(elapsed),
truncate_for_log(&reply, 100)
);
{
let mut guard = self.inner.lock().await;
if let Some(loaded) = guard.as_mut() {
loaded.history.push(ChatMessage::user(user_message));
loaded.history.push(ChatMessage::assistant(reply.clone()));
}
}
Ok(InferenceResult {
text: reply,
duration_secs: elapsed.as_secs_f64(),
duration_display: format_duration(elapsed),
finish_reason,
})
}
pub async fn generate(
&self,
messages: Vec<ChatMessage>,
sampling: Option<SamplingConfig>,
) -> Result<InferenceResult, InferenceError> {
let (model, request) = {
let guard = self.inner.lock().await;
let loaded = guard.as_ref().ok_or(InferenceError::NoModelLoaded)?;
let sampling = sampling.as_ref().unwrap_or(&loaded.sampling);
let mut req = RequestBuilder::new();
req = apply_sampling(req, sampling);
if let Some(ref sp) = loaded.system_prompt {
req = req.add_message(TextMessageRole::System, sp);
}
for msg in &messages {
req = req.add_message(chat_role_to_mistral(&msg.role), &msg.content);
}
(loaded.model.clone(), req)
};
let start = std::time::Instant::now();
let response =
model
.send_chat_request(request)
.await
.map_err(|e| InferenceError::Inference {
reason: e.to_string(),
})?;
let elapsed = start.elapsed();
let reply = response.choices[0]
.message
.content
.as_ref()
.map(|c| c.trim().to_string())
.unwrap_or_else(|| "(empty response)".to_string());
let finish_reason = response.choices[0].finish_reason.clone();
Ok(InferenceResult {
text: reply,
duration_secs: elapsed.as_secs_f64(),
duration_display: format_duration(elapsed),
finish_reason,
})
}
pub async fn stream_message(
&self,
user_message: impl Into<String>,
) -> Result<tokio::sync::mpsc::Receiver<StreamChunk>, InferenceError> {
let user_message = user_message.into();
let (model, request) = {
let guard = self.inner.lock().await;
let loaded = guard.as_ref().ok_or(InferenceError::NoModelLoaded)?;
let request = self::build_request(loaded, &user_message);
(loaded.model.clone(), request)
};
let (tx, rx) = tokio::sync::mpsc::channel::<StreamChunk>(64);
let inner_ptr = &self.inner as *const Mutex<Option<LoadedModel>>;
let inner_ref: &'static Mutex<Option<LoadedModel>> = unsafe { &*inner_ptr };
let user_msg_clone = user_message.clone();
tokio::task::spawn(async move {
let stream_result = model.stream_chat_request(request).await;
match stream_result {
Ok(mut stream) => {
let mut assembled = String::new();
let mut last_finish_reason = None;
while let Some(response) = stream.next().await {
match response {
mistralrs::Response::Chunk(chunk) => {
if let Some(choice) = chunk.choices.first() {
if let Some(ref text) = choice.delta.content {
assembled.push_str(text);
let _ = tx
.send(StreamChunk {
delta: text.clone(),
done: false,
finish_reason: None,
})
.await;
}
if let Some(ref reason) = choice.finish_reason {
last_finish_reason = Some(reason.clone());
}
}
}
mistralrs::Response::Done(_) => {
break;
}
mistralrs::Response::InternalError(e) => {
log::error!("ChatEngine stream internal error: {}", e);
break;
}
mistralrs::Response::ValidationError(e) => {
log::error!("ChatEngine stream validation error: {}", e);
break;
}
mistralrs::Response::ModelError(msg, _) => {
log::error!("ChatEngine stream model error: {}", msg);
break;
}
_ => {
break;
}
}
}
{
let mut guard = inner_ref.lock().await;
if let Some(loaded) = guard.as_mut() {
loaded.history.push(ChatMessage::user(user_msg_clone));
loaded
.history
.push(ChatMessage::assistant(assembled.trim()));
}
}
let _ = tx
.send(StreamChunk {
delta: String::new(),
done: true,
finish_reason: last_finish_reason,
})
.await;
}
Err(e) => {
log::error!("ChatEngine: stream_chat_request failed: {}", e);
let _ = tx
.send(StreamChunk {
delta: String::new(),
done: true,
finish_reason: Some(format!("error: {}", e)),
})
.await;
}
}
});
Ok(rx)
}
}
#[cfg(any(
target_os = "macos",
target_os = "ios",
target_os = "tvos",
target_os = "windows",
target_os = "linux",
target_os = "android"
))]
impl Default for ChatEngine {
fn default() -> Self {
Self::new()
}
}
/// Assembles a chat request from the loaded state: sampling settings first,
/// then the optional system prompt, every stored history turn in order, and
/// finally `user_message` as a user turn.
#[cfg(any(
    target_os = "macos",
    target_os = "ios",
    target_os = "tvos",
    target_os = "windows",
    target_os = "linux",
    target_os = "android"
))]
fn build_request(loaded: &LoadedModel, user_message: &str) -> RequestBuilder {
    let base = apply_sampling(RequestBuilder::new(), &loaded.sampling);
    let with_system = match loaded.system_prompt {
        Some(ref prompt) => base.add_message(TextMessageRole::System, prompt),
        None => base,
    };
    loaded
        .history
        .iter()
        .fold(with_system, |req, turn| {
            req.add_message(chat_role_to_mistral(&turn.role), &turn.content)
        })
        .add_message(TextMessageRole::User, user_message)
}
/// Copies every sampling field that is `Some` onto the request builder and
/// leaves the builder untouched for fields that are `None`.
#[cfg(any(
    target_os = "macos",
    target_os = "ios",
    target_os = "tvos",
    target_os = "windows",
    target_os = "linux",
    target_os = "android"
))]
fn apply_sampling(req: RequestBuilder, sampling: &SamplingConfig) -> RequestBuilder {
    let req = match sampling.temperature {
        Some(temperature) => req.set_sampler_temperature(temperature),
        None => req,
    };
    let req = match sampling.top_p {
        Some(p) => req.set_sampler_topp(p),
        None => req,
    };
    let req = match sampling.top_k {
        Some(k) => req.set_sampler_topk(k as usize),
        None => req,
    };
    let req = match sampling.min_p {
        Some(p) => req.set_sampler_minp(p),
        None => req,
    };
    let req = match sampling.max_tokens {
        Some(limit) => req.set_sampler_max_len(limit as usize),
        None => req,
    };
    let req = match sampling.frequency_penalty {
        Some(penalty) => req.set_sampler_frequency_penalty(penalty),
        None => req,
    };
    match sampling.presence_penalty {
        Some(penalty) => req.set_sampler_presence_penalty(penalty),
        None => req,
    }
}
/// Maps the crate's `ChatRole` onto the equivalent `TextMessageRole`.
#[cfg(any(
    target_os = "macos",
    target_os = "ios",
    target_os = "tvos",
    target_os = "windows",
    target_os = "linux",
    target_os = "android"
))]
fn chat_role_to_mistral(role: &ChatRole) -> TextMessageRole {
    match *role {
        ChatRole::Assistant => TextMessageRole::Assistant,
        ChatRole::User => TextMessageRole::User,
        ChatRole::System => TextMessageRole::System,
    }
}
/// Shortens `s` for logging.
///
/// Strings of at most `max_len` bytes are returned unchanged. Longer strings
/// are cut to at most `max_len` bytes and suffixed with `"..."`. The cut is
/// moved back to the nearest UTF-8 character boundary, so the result may be a
/// few bytes shorter than `max_len` but slicing never panics on multi-byte
/// text.
fn truncate_for_log(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }
    // `max_len < s.len()` here; index 0 is always a boundary, so this
    // terminates.
    let mut end = max_len;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &s[..end])
}
/// Stateless placeholder engine compiled for every `target_os` outside the
/// list below.
#[cfg(not(any(
target_os = "macos",
target_os = "ios",
target_os = "tvos",
target_os = "windows",
target_os = "linux",
target_os = "android"
)))]
pub struct ChatEngine;
#[cfg(not(any(
    target_os = "macos",
    target_os = "ios",
    target_os = "tvos",
    target_os = "windows",
    target_os = "linux",
    target_os = "android"
)))]
impl ChatEngine {
    /// Builds the error returned by every inference entry point of this
    /// placeholder implementation.
    fn unsupported() -> InferenceError {
        InferenceError::Other {
            reason: "LLM inference is not supported on this platform.".into(),
        }
    }

    /// Creates the placeholder engine.
    pub fn new() -> Self {
        ChatEngine
    }

    /// Always fails with the "not supported" error; arguments are ignored.
    pub async fn load_gguf_model(
        &self,
        _config: GgufModelConfig,
        _system_prompt: Option<String>,
        _sampling: Option<SamplingConfig>,
    ) -> Result<std::time::Duration, InferenceError> {
        Err(Self::unsupported())
    }

    /// Always returns `None`.
    pub async fn unload_model(&self) -> Option<String> {
        None
    }

    /// Always returns `false`.
    pub async fn is_loaded(&self) -> bool {
        false
    }

    /// Always reports an unloaded engine with empty fields.
    pub async fn info(&self) -> EngineInfo {
        let history_length = 0u64;
        EngineInfo {
            status: EngineStatus::Unloaded,
            model_name: None,
            approx_memory: None,
            history_length,
        }
    }

    /// No-op.
    pub async fn set_system_prompt(&self, _prompt: impl Into<String>) {}

    /// No-op.
    pub async fn clear_system_prompt(&self) {}

    /// No-op.
    pub async fn set_sampling(&self, _sampling: SamplingConfig) {}

    /// Always returns an empty history.
    pub async fn history(&self) -> Vec<ChatMessage> {
        Vec::new()
    }

    /// Always returns 0.
    pub async fn clear_history(&self) -> usize {
        0
    }

    /// No-op.
    pub async fn push_history(&self, _message: ChatMessage) {}

    /// Always fails with the "not supported" error.
    pub async fn send_message(
        &self,
        _user_message: impl Into<String>,
    ) -> Result<InferenceResult, InferenceError> {
        Err(Self::unsupported())
    }

    /// Always fails with the "not supported" error.
    pub async fn generate(
        &self,
        _messages: Vec<ChatMessage>,
        _sampling: Option<SamplingConfig>,
    ) -> Result<InferenceResult, InferenceError> {
        Err(Self::unsupported())
    }

    /// Always fails with the "not supported" error.
    pub async fn stream_message(
        &self,
        _user_message: impl Into<String>,
    ) -> Result<tokio::sync::mpsc::Receiver<StreamChunk>, InferenceError> {
        Err(Self::unsupported())
    }
}
#[cfg(not(any(
    target_os = "macos",
    target_os = "ios",
    target_os = "tvos",
    target_os = "windows",
    target_os = "linux",
    target_os = "android"
)))]
impl Default for ChatEngine {
    /// Delegates to [`ChatEngine::new`].
    fn default() -> Self {
        ChatEngine::new()
    }
}
impl GgufModelConfig {
    /// Preset for Qwen 2.5 1.5B Instruct (single GGUF file). A separate
    /// tokenizer model id is only set when compiling for Android.
    pub fn qwen25_1_5b() -> Self {
        Self {
            model_id: super::models::BARTOWSKI_QWEN25_1_5B_INSTRUCT_GGUF.into(),
            files: vec![super::models::QWEN25_1_5B_GGUF_FILE.into()],
            tok_model_id: cfg!(target_os = "android")
                .then(|| super::models::QWEN25_1_5B_TOK_MODEL_ID.into()),
            display_name: "Qwen 2.5 1.5B".into(),
            approx_memory: "~941 MB (GGUF Q4_K_M)".into(),
        }
    }

    /// Preset for Qwen 2.5 3B Instruct (single GGUF file). A separate
    /// tokenizer model id is only set when compiling for Android.
    pub fn qwen25_3b() -> Self {
        Self {
            model_id: super::models::BARTOWSKI_QWEN25_3B_INSTRUCT_GGUF.into(),
            files: vec![super::models::QWEN25_3B_GGUF_FILE.into()],
            tok_model_id: cfg!(target_os = "android")
                .then(|| super::models::QWEN25_3B_TOK_MODEL_ID.into()),
            display_name: "Qwen 2.5 3B".into(),
            approx_memory: "~1.93 GB (GGUF Q4_K_M)".into(),
        }
    }

    /// Preset for Qwen 2.5 Coder 1.5B Instruct (single GGUF file). A separate
    /// tokenizer model id is only set when compiling for Android.
    pub fn qwen25_coder_1_5b() -> Self {
        Self {
            model_id: super::models::BARTOWSKI_QWEN25_CODER_1_5B_INSTRUCT_GGUF.into(),
            files: vec![super::models::QWEN25_CODER_1_5B_GGUF_FILE.into()],
            tok_model_id: cfg!(target_os = "android")
                .then(|| super::models::QWEN25_CODER_1_5B_TOK_MODEL_ID.into()),
            display_name: "Qwen 2.5 Coder 1.5B".into(),
            approx_memory: "~941 MB (GGUF Q4_K_M)".into(),
        }
    }

    /// Preset for Qwen 2.5 Coder 3B Instruct (single GGUF file). A separate
    /// tokenizer model id is only set when compiling for Android.
    pub fn qwen25_coder_3b() -> Self {
        Self {
            model_id: super::models::BARTOWSKI_QWEN25_CODER_3B_INSTRUCT_GGUF.into(),
            files: vec![super::models::QWEN25_CODER_3B_GGUF_FILE.into()],
            tok_model_id: cfg!(target_os = "android")
                .then(|| super::models::QWEN25_CODER_3B_TOK_MODEL_ID.into()),
            display_name: "Qwen 2.5 Coder 3B".into(),
            approx_memory: "~1.93 GB (GGUF Q4_K_M)".into(),
        }
    }

    /// Picks the default preset for the compilation target: the Coder 1.5B
    /// model on iOS/tvOS/Android, the Coder 3B model everywhere else.
    pub fn platform_default() -> Self {
        let is_mobile = cfg!(any(
            target_os = "ios",
            target_os = "tvos",
            target_os = "android"
        ));
        if is_mobile {
            return Self::qwen25_coder_1_5b();
        }
        Self::qwen25_coder_3b()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Input shorter than the limit is returned unchanged.
    #[test]
    fn truncate_for_log_short() {
        let out = truncate_for_log("hello", 10);
        assert_eq!(out, "hello");
    }

    /// Over-long input is cut to the limit and suffixed with "...".
    #[test]
    fn truncate_for_log_long() {
        let input = "a".repeat(200);
        let out = truncate_for_log(&input, 50);
        assert!(out.ends_with("..."));
        // 50 kept bytes plus the three-byte ellipsis.
        assert_eq!(out.len(), 53);
    }

    /// The 1.5B preset points at a 1.5B repo and lists exactly one file.
    #[test]
    fn gguf_model_config_qwen25_1_5b() {
        let preset = GgufModelConfig::qwen25_1_5b();
        assert!(preset.model_id.contains("1.5B"));
        assert_eq!(preset.files.len(), 1);
    }

    /// The 3B preset points at a 3B repo and lists exactly one file.
    #[test]
    fn gguf_model_config_qwen25_3b() {
        let preset = GgufModelConfig::qwen25_3b();
        assert!(preset.model_id.contains("3B"));
        assert_eq!(preset.files.len(), 1);
    }

    /// Whatever the platform default is, it names a repo and a file.
    #[test]
    fn gguf_model_config_platform_default() {
        let preset = GgufModelConfig::platform_default();
        assert!(!preset.model_id.is_empty());
        assert!(!preset.files.is_empty());
    }

    /// A fresh engine reports itself as unloaded with no history.
    #[tokio::test]
    async fn engine_new_is_unloaded() {
        let engine = ChatEngine::new();
        assert!(!engine.is_loaded().await);
        let snapshot = engine.info().await;
        assert_eq!(snapshot.status, EngineStatus::Unloaded);
        assert_eq!(snapshot.history_length, 0);
    }

    /// Sending without a loaded model yields `NoModelLoaded`.
    #[tokio::test]
    async fn engine_send_without_model_errors() {
        let engine = ChatEngine::new();
        let outcome = engine.send_message("hello").await;
        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        if !matches!(err, InferenceError::NoModelLoaded) {
            panic!("Expected NoModelLoaded, got: {:?}", err);
        }
    }

    /// History accessors behave as empty when nothing is loaded.
    #[tokio::test]
    async fn engine_history_empty_when_no_model() {
        let engine = ChatEngine::new();
        let turns = engine.history().await;
        assert!(turns.is_empty());
        let cleared = engine.clear_history().await;
        assert_eq!(cleared, 0);
    }

    /// Unloading an empty engine returns `None`.
    #[tokio::test]
    async fn engine_unload_when_none() {
        let engine = ChatEngine::new();
        let unloaded = engine.unload_model().await;
        assert!(unloaded.is_none());
    }
}