use std::sync::Arc;
use super::engine::ChatEngine;
use super::types::*;
/// Callback implemented by foreign code to receive a streamed chat response.
///
/// [`stream_chat_message`] invokes [`on_chunk`](Self::on_chunk) once for every
/// chunk it receives from the engine, in order.
#[uniffi::export(callback_interface)]
pub trait StreamChunkListener: Send + Sync {
    /// Handles one streamed chunk.
    ///
    /// Return `true` to keep receiving chunks or `false` to stop early. Delivery
    /// also stops after a chunk whose `done` flag is set, regardless of the
    /// return value.
    fn on_chunk(&self, chunk: StreamChunk) -> bool;
}
/// Chat engine object exported to foreign languages through UniFFI.
///
/// A thin wrapper that owns a [`ChatEngine`] and forwards calls to it; the
/// constructor hands it out as an `Arc<Self>` so foreign code can share it.
#[derive(uniffi::Object)]
pub struct OndeChatEngine {
    /// The Rust engine that performs the actual loading and inference.
    inner: ChatEngine,
}
#[uniffi::export(async_runtime = "tokio")]
impl OndeChatEngine {
    /// Creates a new engine with no model loaded and an empty history.
    #[uniffi::constructor]
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            inner: ChatEngine::new(),
        })
    }

    /// Loads the GGUF model described by `config`.
    ///
    /// `system_prompt` and `sampling` are forwarded unchanged to the underlying
    /// engine. Returns how long the load took, in seconds.
    ///
    /// # Errors
    ///
    /// Returns the [`InferenceError`] reported by the underlying engine if the
    /// model cannot be loaded.
    pub async fn load_gguf_model(
        &self,
        config: GgufModelConfig,
        system_prompt: Option<String>,
        sampling: Option<SamplingConfig>,
    ) -> Result<f64, InferenceError> {
        let elapsed = self
            .inner
            .load_gguf_model(config, system_prompt, sampling)
            .await?;
        Ok(elapsed.as_secs_f64())
    }

    /// Loads the platform-default model ([`GgufModelConfig::platform_default`]).
    ///
    /// Equivalent to [`Self::load_gguf_model`] with that config; delegating keeps
    /// both entry points on a single load path. Returns the load time in seconds.
    ///
    /// # Errors
    ///
    /// Same as [`Self::load_gguf_model`].
    pub async fn load_default_model(
        &self,
        system_prompt: Option<String>,
        sampling: Option<SamplingConfig>,
    ) -> Result<f64, InferenceError> {
        self.load_gguf_model(GgufModelConfig::platform_default(), system_prompt, sampling)
            .await
    }

    /// Loads the model assigned to the app identified by `app_id`/`app_secret`.
    ///
    /// Resolution is performed by the underlying engine against the GresIQ
    /// `Production` environment (hard-coded here). Returns the load time in
    /// seconds.
    ///
    /// # Errors
    ///
    /// Returns the [`InferenceError`] reported by the underlying engine.
    pub async fn load_assigned_model(
        &self,
        app_id: String,
        app_secret: String,
        system_prompt: Option<String>,
        sampling: Option<SamplingConfig>,
    ) -> Result<f64, InferenceError> {
        let elapsed = self
            .inner
            .load_assigned_model(
                smbcloud_gresiq_sdk::Environment::Production,
                &app_id,
                &app_secret,
                system_prompt,
                sampling,
            )
            .await?;
        Ok(elapsed.as_secs_f64())
    }

    /// Unloads the current model, if any.
    ///
    /// Returns `None` when no model was loaded; presumably `Some` carries an
    /// identifier of the unloaded model (see `ChatEngine::unload_model`).
    pub async fn unload_model(&self) -> Option<String> {
        self.inner.unload_model().await
    }

    /// Reports whether a model is currently loaded.
    pub async fn is_loaded(&self) -> bool {
        self.inner.is_loaded().await
    }

    /// Returns a status snapshot of the engine.
    pub async fn info(&self) -> EngineInfo {
        self.inner.info().await
    }

    /// Sets the system prompt on the underlying engine.
    pub async fn set_system_prompt(&self, prompt: String) {
        self.inner.set_system_prompt(prompt).await;
    }

    /// Removes the system prompt from the underlying engine.
    pub async fn clear_system_prompt(&self) {
        self.inner.clear_system_prompt().await;
    }

    /// Replaces the sampling settings on the underlying engine.
    pub async fn set_sampling(&self, sampling: SamplingConfig) {
        self.inner.set_sampling(sampling).await;
    }

    /// Returns the conversation history.
    pub async fn history(&self) -> Vec<ChatMessage> {
        self.inner.history().await
    }

    /// Clears the conversation history and returns the count reported by the
    /// engine (presumably the number of messages removed).
    pub async fn clear_history(&self) -> u64 {
        // usize -> u64 is lossless on every target Rust supports (usize <= 64 bits).
        self.inner.clear_history().await as u64
    }

    /// Appends `message` to the conversation history.
    pub async fn push_history(&self, message: ChatMessage) {
        self.inner.push_history(message).await;
    }

    /// Sends `message` to the loaded model and returns the inference result.
    ///
    /// # Errors
    ///
    /// Fails when no model is loaded or when inference fails.
    pub async fn send_message(&self, message: String) -> Result<InferenceResult, InferenceError> {
        self.inner.send_message(message).await
    }

    /// Runs inference over an explicit list of `messages`.
    ///
    /// `sampling`, when provided, is passed to the engine for this call
    /// (presumably overriding the configured settings; verify in `ChatEngine`).
    ///
    /// # Errors
    ///
    /// Returns the [`InferenceError`] reported by the underlying engine.
    pub async fn generate(
        &self,
        messages: Vec<ChatMessage>,
        sampling: Option<SamplingConfig>,
    ) -> Result<InferenceResult, InferenceError> {
        self.inner.generate(messages, sampling).await
    }
}
/// Streams the engine's reply to `message`, passing each chunk to `listener`.
///
/// Delivery ends after a chunk with `done` set has been handed over, as soon as
/// the listener returns `false`, or when the receiver yields no more chunks.
/// All of those end with `Ok(())`.
///
/// # Errors
///
/// Returns an error only if the stream could not be started.
#[uniffi::export(async_runtime = "tokio")]
pub async fn stream_chat_message(
    engine: Arc<OndeChatEngine>,
    message: String,
    listener: Box<dyn StreamChunkListener>,
) -> Result<(), InferenceError> {
    let mut chunks = engine.inner.stream_message(message).await?;
    loop {
        let chunk = match chunks.recv().await {
            Some(chunk) => chunk,
            // Producer side is finished: nothing more will arrive.
            None => break,
        };
        // Capture the terminal flag before the chunk is moved into the callback.
        let is_last = chunk.done;
        let wants_more = listener.on_chunk(chunk);
        if is_last || !wants_more {
            break;
        }
    }
    Ok(())
}
/// Returns the platform-default GGUF model configuration
/// ([`GgufModelConfig::platform_default`]) for foreign callers.
#[uniffi::export]
pub fn default_model_config() -> GgufModelConfig {
    GgufModelConfig::platform_default()
}
/// Returns the Qwen2.5 0.5B preset ([`GgufModelConfig::qwen25_0_5b`]).
#[uniffi::export]
pub fn qwen25_0_5b_config() -> GgufModelConfig {
    GgufModelConfig::qwen25_0_5b()
}
/// Returns the Qwen2.5 1.5B preset ([`GgufModelConfig::qwen25_1_5b`]).
#[uniffi::export]
pub fn qwen25_1_5b_config() -> GgufModelConfig {
    GgufModelConfig::qwen25_1_5b()
}
/// Returns the Qwen2.5 3B preset ([`GgufModelConfig::qwen25_3b`]).
#[uniffi::export]
pub fn qwen25_3b_config() -> GgufModelConfig {
    GgufModelConfig::qwen25_3b()
}
/// Returns the default sampling settings ([`SamplingConfig::default`]) for
/// foreign callers.
#[uniffi::export]
pub fn default_sampling_config() -> SamplingConfig {
    SamplingConfig::default()
}
/// Returns the deterministic sampling preset ([`SamplingConfig::deterministic`]).
#[uniffi::export]
pub fn deterministic_sampling_config() -> SamplingConfig {
    SamplingConfig::deterministic()
}
/// Returns the mobile sampling preset ([`SamplingConfig::mobile`]).
#[uniffi::export]
pub fn mobile_sampling_config() -> SamplingConfig {
    SamplingConfig::mobile()
}
/// Builds a system-role [`ChatMessage`] containing `content`.
#[uniffi::export]
pub fn system_message(content: String) -> ChatMessage {
    ChatMessage::system(content)
}
/// Builds a user-role [`ChatMessage`] containing `content`.
#[uniffi::export]
pub fn user_message(content: String) -> ChatMessage {
    ChatMessage::user(content)
}
/// Builds an assistant-role [`ChatMessage`] containing `content`.
#[uniffi::export]
pub fn assistant_message(content: String) -> ChatMessage {
    ChatMessage::assistant(content)
}
/// Points the Hugging Face cache into `path`.
///
/// Sets `HF_HOME` to `<path>/models` and `HF_HUB_CACHE` to `<path>/models/hub`,
/// creating the hub directory first. Directory creation is best-effort: a
/// failure is logged as a warning and the variables are still set.
///
/// An empty or whitespace-only `path` is rejected with a warning and leaves the
/// environment untouched. Otherwise it would resolve to a `models` directory
/// relative to the process's current working directory.
///
/// This mutates the process-wide environment via [`std::env::set_var`], which is
/// not synchronized with environment reads on other threads on some platforms.
/// Call it early, before other threads may be reading the environment.
#[uniffi::export]
pub fn configure_cache_dir(path: String) {
    if path.trim().is_empty() {
        log::warn!("configure_cache_dir: empty path, leaving HF_HOME/HF_HUB_CACHE unchanged");
        return;
    }
    // `path` is owned, so build the PathBuf from it directly instead of copying.
    let hf_home = std::path::PathBuf::from(path).join("models");
    let hf_hub_cache = hf_home.join("hub");
    if let Err(e) = std::fs::create_dir_all(&hf_hub_cache) {
        log::warn!(
            "configure_cache_dir: could not create {}: {}",
            hf_hub_cache.display(),
            e
        );
    }
    std::env::set_var("HF_HOME", &hf_home);
    std::env::set_var("HF_HUB_CACHE", &hf_hub_cache);
    log::info!(
        "configure_cache_dir: HF_HOME={}, HF_HUB_CACHE={}",
        hf_home.display(),
        hf_hub_cache.display()
    );
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn free_functions_return_valid_configs() {
        let cfg = default_model_config();
        assert!(!cfg.model_id.is_empty());
        let cfg = qwen25_0_5b_config();
        assert!(!cfg.model_id.is_empty());
        let cfg = qwen25_1_5b_config();
        assert!(cfg.model_id.contains("1.5B"));
        let cfg = qwen25_3b_config();
        assert!(cfg.model_id.contains("3B"));
    }

    #[test]
    fn free_functions_return_valid_sampling() {
        let s = default_sampling_config();
        assert_eq!(s.temperature, Some(0.7));
        let s = deterministic_sampling_config();
        assert_eq!(s.temperature, Some(0.0));
        let s = mobile_sampling_config();
        assert_eq!(s.max_tokens, Some(128));
    }

    #[test]
    fn message_helpers() {
        let m = system_message("You are helpful.".into());
        assert_eq!(m.role, ChatRole::System);
        assert_eq!(m.content, "You are helpful.");
        let m = user_message("Hello".into());
        assert_eq!(m.role, ChatRole::User);
        let m = assistant_message("Hi!".into());
        assert_eq!(m.role, ChatRole::Assistant);
    }

    #[test]
    fn configure_cache_dir_sets_env() {
        // Per-process name so concurrent test runs don't share (and delete) one dir.
        let tmp = std::env::temp_dir()
            .join(format!("onde-test-cache-dir-{}", std::process::id()));
        let base = tmp.to_string_lossy().into_owned();
        configure_cache_dir(base.clone());

        let expected_home = std::path::PathBuf::from(&base).join("models");
        let expected_hub = expected_home.join("hub");
        let hf_home = std::env::var_os("HF_HOME").map(std::path::PathBuf::from);
        let hf_hub = std::env::var_os("HF_HUB_CACHE").map(std::path::PathBuf::from);
        let hub_created = expected_hub.is_dir();

        // Clean up before asserting so a failure doesn't leak the directory.
        let _ = std::fs::remove_dir_all(&tmp);

        assert_eq!(hf_home, Some(expected_home));
        assert_eq!(hf_hub, Some(expected_hub));
        assert!(hub_created, "hub cache directory was not created");
    }

    #[tokio::test]
    async fn onde_chat_engine_new_is_unloaded() {
        let engine = OndeChatEngine::new();
        assert!(!engine.is_loaded().await);
        let info = engine.info().await;
        assert_eq!(info.status, EngineStatus::Unloaded);
        assert_eq!(info.history_length, 0);
    }

    #[tokio::test]
    async fn onde_chat_engine_send_without_model_errors() {
        let engine = OndeChatEngine::new();
        let result = engine.send_message("hello".into()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn onde_chat_engine_history_empty() {
        let engine = OndeChatEngine::new();
        assert!(engine.history().await.is_empty());
        assert_eq!(engine.clear_history().await, 0);
    }

    #[tokio::test]
    async fn onde_chat_engine_unload_when_none() {
        let engine = OndeChatEngine::new();
        assert!(engine.unload_model().await.is_none());
    }
}