use std::sync::Arc;
use super::engine::ChatEngine;
use super::types::*;
/// Callback interface through which a caller receives streamed chunks.
///
/// Exported with `#[uniffi::export(callback_interface)]`. In this file it is
/// driven by `stream_chat_message`, which calls `on_chunk` once per chunk it
/// receives.
#[uniffi::export(callback_interface)]
pub trait StreamChunkListener: Send + Sync {
/// Handles one streamed chunk.
///
/// Return `true` to keep receiving chunks. Returning `false` makes
/// `stream_chat_message` stop delivering further chunks.
fn on_chunk(&self, chunk: StreamChunk) -> bool;
}
/// UniFFI object (`#[derive(uniffi::Object)]`) that wraps a `ChatEngine`.
///
/// The exported methods on this type forward to the wrapped engine.
#[derive(uniffi::Object)]
pub struct OndeChatEngine {
// The wrapped engine; private, so it is only reachable from within this module.
inner: ChatEngine,
}
#[uniffi::export]
impl OndeChatEngine {
    /// Creates a new engine handle backed by a fresh `ChatEngine::new()`.
    ///
    /// Exposed as the object's constructor via `#[uniffi::constructor]`.
    #[uniffi::constructor]
    pub fn new() -> Arc<Self> {
        let inner = ChatEngine::new();
        Arc::new(Self { inner })
    }

    /// Loads the GGUF model described by `config`.
    ///
    /// `system_prompt` and `sampling` are forwarded unchanged to the inner
    /// engine.
    ///
    /// Returns the duration reported by the inner engine as fractional
    /// seconds (presumably the time the load took; confirm against
    /// `ChatEngine::load_gguf_model`).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner engine.
    pub async fn load_gguf_model(
        &self,
        config: GgufModelConfig,
        system_prompt: Option<String>,
        sampling: Option<SamplingConfig>,
    ) -> Result<f64, InferenceError> {
        Ok(self
            .inner
            .load_gguf_model(config, system_prompt, sampling)
            .await?
            .as_secs_f64())
    }

    /// Loads the model described by `GgufModelConfig::platform_default()`.
    ///
    /// Equivalent to calling [`Self::load_gguf_model`] with that
    /// configuration; the return value and errors are the same.
    pub async fn load_default_model(
        &self,
        system_prompt: Option<String>,
        sampling: Option<SamplingConfig>,
    ) -> Result<f64, InferenceError> {
        self.load_gguf_model(GgufModelConfig::platform_default(), system_prompt, sampling)
            .await
    }

    /// Asks the inner engine to load the model assigned to the given app
    /// credentials.
    ///
    /// The request always uses `smbcloud_gresiq_sdk::Environment::Production`;
    /// the environment is not configurable through this binding.
    ///
    /// Returns the duration reported by the inner engine as fractional
    /// seconds.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner engine.
    pub async fn load_assigned_model(
        &self,
        app_id: String,
        app_secret: String,
        system_prompt: Option<String>,
        sampling: Option<SamplingConfig>,
    ) -> Result<f64, InferenceError> {
        let environment = smbcloud_gresiq_sdk::Environment::Production;
        Ok(self
            .inner
            .load_assigned_model(environment, &app_id, &app_secret, system_prompt, sampling)
            .await?
            .as_secs_f64())
    }

    /// Forwards to the inner engine's `unload_model`.
    ///
    /// The unit tests in this file expect `None` when no model was loaded.
    /// The meaning of the returned string is defined by `ChatEngine`
    /// (presumably the name of the model that was unloaded; confirm there).
    pub async fn unload_model(&self) -> Option<String> {
        self.inner.unload_model().await
    }

    /// Returns the inner engine's `is_loaded` flag.
    pub async fn is_loaded(&self) -> bool {
        self.inner.is_loaded().await
    }

    /// Returns the inner engine's current `EngineInfo`.
    pub async fn info(&self) -> EngineInfo {
        self.inner.info().await
    }

    /// Forwards `prompt` to the inner engine's `set_system_prompt`.
    pub async fn set_system_prompt(&self, prompt: String) {
        self.inner.set_system_prompt(prompt).await;
    }

    /// Forwards to the inner engine's `clear_system_prompt`.
    pub async fn clear_system_prompt(&self) {
        self.inner.clear_system_prompt().await;
    }

    /// Forwards `sampling` to the inner engine's `set_sampling`.
    pub async fn set_sampling(&self, sampling: SamplingConfig) {
        self.inner.set_sampling(sampling).await;
    }

    /// Returns the message history held by the inner engine.
    pub async fn history(&self) -> Vec<ChatMessage> {
        self.inner.history().await
    }

    /// Clears the inner engine's history.
    ///
    /// Returns the count reported by the inner engine, cast to `u64`
    /// (presumably the number of messages removed; the tests in this file
    /// expect `0` for a fresh engine).
    pub async fn clear_history(&self) -> u64 {
        let cleared = self.inner.clear_history().await;
        cleared as u64
    }

    /// Forwards `message` to the inner engine's `push_history`.
    pub async fn push_history(&self, message: ChatMessage) {
        self.inner.push_history(message).await;
    }

    /// Sends `message` through the inner engine and returns its result.
    ///
    /// # Errors
    ///
    /// Returns the inner engine's error unchanged. The tests in this file
    /// expect an error when no model has been loaded.
    pub async fn send_message(&self, message: String) -> Result<InferenceResult, InferenceError> {
        self.inner.send_message(message).await
    }

    /// Runs the inner engine's `generate` on the supplied `messages`, with an
    /// optional `sampling` argument.
    ///
    /// # Errors
    ///
    /// Returns the inner engine's error unchanged.
    pub async fn generate(
        &self,
        messages: Vec<ChatMessage>,
        sampling: Option<SamplingConfig>,
    ) -> Result<InferenceResult, InferenceError> {
        self.inner.generate(messages, sampling).await
    }
}
/// Streams the reply to `message`, handing each chunk to `listener`.
///
/// Starts a stream on the engine's inner `ChatEngine` and calls
/// `listener.on_chunk` for every chunk received. Delivery stops as soon as
/// one of the following happens:
///
/// * a chunk whose `done` flag is set has been delivered,
/// * the listener returns `false`,
/// * the receiver yields `None`.
///
/// # Errors
///
/// Returns an error only if starting the stream fails. Once the stream has
/// started this function always returns `Ok(())`.
#[uniffi::export]
pub async fn stream_chat_message(
    engine: Arc<OndeChatEngine>,
    message: String,
    listener: Box<dyn StreamChunkListener>,
) -> Result<(), InferenceError> {
    let mut chunks = engine.inner.stream_message(message).await?;

    while let Some(chunk) = chunks.recv().await {
        // Capture the flag first: the chunk is moved into the callback.
        let is_final = chunk.done;

        // The callback is the left operand, so it runs for every chunk,
        // including the final one.
        if !listener.on_chunk(chunk) || is_final {
            break;
        }
    }

    Ok(())
}
/// Returns the platform-default model configuration.
///
/// Thin exported wrapper around `GgufModelConfig::platform_default()`, the
/// same configuration that `OndeChatEngine::load_default_model` loads.
#[uniffi::export]
pub fn default_model_config() -> GgufModelConfig {
GgufModelConfig::platform_default()
}
/// Returns the configuration produced by `GgufModelConfig::qwen25_1_5b()`.
///
/// The unit tests in this file expect its `model_id` to contain `"1.5B"`.
#[uniffi::export]
pub fn qwen25_1_5b_config() -> GgufModelConfig {
GgufModelConfig::qwen25_1_5b()
}
/// Returns the configuration produced by `GgufModelConfig::qwen25_3b()`.
///
/// The unit tests in this file expect its `model_id` to contain `"3B"`.
#[uniffi::export]
pub fn qwen25_3b_config() -> GgufModelConfig {
GgufModelConfig::qwen25_3b()
}
/// Returns `SamplingConfig::default()`.
///
/// The unit tests in this file expect its `temperature` to be `Some(0.7)`.
#[uniffi::export]
pub fn default_sampling_config() -> SamplingConfig {
SamplingConfig::default()
}
/// Returns `SamplingConfig::deterministic()`.
///
/// The unit tests in this file expect its `temperature` to be `Some(0.0)`.
#[uniffi::export]
pub fn deterministic_sampling_config() -> SamplingConfig {
SamplingConfig::deterministic()
}
/// Returns `SamplingConfig::mobile()`.
///
/// The unit tests in this file expect its `max_tokens` to be `Some(128)`.
#[uniffi::export]
pub fn mobile_sampling_config() -> SamplingConfig {
SamplingConfig::mobile()
}
/// Builds a chat message from `content` via `ChatMessage::system`.
///
/// The unit tests in this file expect the result to have role
/// `ChatRole::System` and to carry `content` unchanged.
#[uniffi::export]
pub fn system_message(content: String) -> ChatMessage {
ChatMessage::system(content)
}
/// Builds a chat message from `content` via `ChatMessage::user`.
///
/// The unit tests in this file expect the result to have role
/// `ChatRole::User`.
#[uniffi::export]
pub fn user_message(content: String) -> ChatMessage {
ChatMessage::user(content)
}
/// Builds a chat message from `content` via `ChatMessage::assistant`.
///
/// The unit tests in this file expect the result to have role
/// `ChatRole::Assistant`.
#[uniffi::export]
pub fn assistant_message(content: String) -> ChatMessage {
ChatMessage::assistant(content)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every model-config helper yields a config with the expected model id.
    #[test]
    fn free_functions_return_valid_configs() {
        let platform_default = default_model_config();
        assert!(!platform_default.model_id.is_empty());

        let qwen_small = qwen25_1_5b_config();
        assert!(qwen_small.model_id.contains("1.5B"));

        let qwen_large = qwen25_3b_config();
        assert!(qwen_large.model_id.contains("3B"));
    }

    /// Every sampling helper yields the preset value it is known for.
    #[test]
    fn free_functions_return_valid_sampling() {
        let default_preset = default_sampling_config();
        assert_eq!(default_preset.temperature, Some(0.7));

        let deterministic_preset = deterministic_sampling_config();
        assert_eq!(deterministic_preset.temperature, Some(0.0));

        let mobile_preset = mobile_sampling_config();
        assert_eq!(mobile_preset.max_tokens, Some(128));
    }

    /// The message helpers tag messages with the matching role.
    #[test]
    fn message_helpers() {
        let system = system_message(String::from("You are helpful."));
        assert_eq!(system.role, ChatRole::System);
        assert_eq!(system.content, "You are helpful.");

        let user = user_message(String::from("Hello"));
        assert_eq!(user.role, ChatRole::User);

        let assistant = assistant_message(String::from("Hi!"));
        assert_eq!(assistant.role, ChatRole::Assistant);
    }

    /// A freshly constructed engine has no model and no history.
    #[tokio::test]
    async fn onde_chat_engine_new_is_unloaded() {
        let engine = OndeChatEngine::new();

        let loaded = engine.is_loaded().await;
        assert!(!loaded);

        let snapshot = engine.info().await;
        assert_eq!(snapshot.status, EngineStatus::Unloaded);
        assert_eq!(snapshot.history_length, 0);
    }

    /// Sending a message before loading a model is an error.
    #[tokio::test]
    async fn onde_chat_engine_send_without_model_errors() {
        let engine = OndeChatEngine::new();

        let outcome = engine.send_message(String::from("hello")).await;
        assert!(outcome.is_err());
    }

    /// A fresh engine has an empty history, and clearing it reports zero.
    #[tokio::test]
    async fn onde_chat_engine_history_empty() {
        let engine = OndeChatEngine::new();

        let messages = engine.history().await;
        assert!(messages.is_empty());

        let cleared = engine.clear_history().await;
        assert_eq!(cleared, 0);
    }

    /// Unloading when nothing is loaded returns `None`.
    #[tokio::test]
    async fn onde_chat_engine_unload_when_none() {
        let engine = OndeChatEngine::new();

        let unloaded = engine.unload_model().await;
        assert!(unloaded.is_none());
    }
}