use std::path::PathBuf;
use std::time::Instant;
use onde::mistralrs::{GgufModelBuilder, RequestBuilder, Response, TextMessageRole, TokenSource};
use tokio::sync::mpsc;
/// Events emitted by the chat worker (`start_chat`) to the consumer side of
/// the progress channel.
#[derive(Debug, Clone)]
pub enum ChatProgress {
    /// Model loading has started.
    LoadingModel,
    /// Model finished loading; `model_name` is the GGUF file name.
    Ready { model_name: String },
    /// A user message was accepted and inference is about to start.
    Thinking,
    /// Final event for a turn, sent after streaming completes with a
    /// non-empty reply.
    Reply {
        // NOTE(review): the leading underscore suggests receivers currently
        // ignore the full text (they already saw it via `StreamDelta`) —
        // confirm against the consumer before renaming.
        _text: String,
        /// Human-readable elapsed time for the turn, e.g. "1m 4.2s".
        duration_display: String,
    },
    /// One incremental chunk of the assistant reply as it streams in.
    StreamDelta(String),
    /// Any failure: path/model-load errors, stream start failures, or
    /// model-side errors during streaming.
    Error(String),
}
/// A single message in a conversation transcript.
#[derive(Debug, Clone)]
pub struct ChatMessage {
    // Who authored the message.
    pub role: ChatRole,
    // The message text.
    pub content: String,
}
impl ChatMessage {
pub fn user(content: impl Into<String>) -> Self {
Self {
role: ChatRole::User,
content: content.into(),
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: ChatRole::Assistant,
content: content.into(),
}
}
}
/// Author of a `ChatMessage`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChatRole {
    User,
    Assistant,
}
/// Commands accepted by the chat worker over its command channel.
#[derive(Debug)]
pub enum ChatCommand {
    /// Submit a user message for inference.
    SendMessage(String),
    /// Stop the worker loop.
    Quit,
}
/// Formats an elapsed duration as a short human-readable string with seconds
/// shown to one decimal place, e.g. `"4.6s"` or `"2m 5.3s"`.
///
/// Rounding to tenths happens *before* splitting into minutes. The previous
/// implementation let `{:.1}` round after the split, so 59.97s displayed as
/// "60.0s" (instead of "1m 0.0s") and 119.97s as "1m 60.0s" (instead of
/// "2m 0.0s").
fn format_duration(elapsed: std::time::Duration) -> String {
    // Work in integer tenths-of-a-second so the display never carries over
    // past the rounding step.
    let total_tenths = (elapsed.as_secs_f64() * 10.0).round() as u64;
    let mins = total_tenths / 600;
    let tenths = total_tenths % 600;
    if mins > 0 {
        format!("{}m {}.{}s", mins, tenths / 10, tenths % 10)
    } else {
        format!("{}.{}s", tenths / 10, tenths % 10)
    }
}
/// Maps this module's `ChatRole` onto mistral.rs's `TextMessageRole`.
/// Compiled only for tests; the worker stores `TextMessageRole` directly.
#[cfg(test)]
fn to_mistral_role(role: &ChatRole) -> TextMessageRole {
    // `ChatRole` has exactly two variants, so an equality test covers both.
    if *role == ChatRole::User {
        TextMessageRole::User
    } else {
        TextMessageRole::Assistant
    }
}
/// Assembles one streaming chat request: fixed sampler settings, the entire
/// prior conversation in order, and the new user message appended last.
fn build_request(history: &[(TextMessageRole, String)], user_message: &str) -> RequestBuilder {
    let base = RequestBuilder::new()
        .set_sampler_temperature(0.7)
        .set_sampler_max_len(512);
    history
        .iter()
        .fold(base, |req, (role, content)| {
            req.add_message(role.clone(), content)
        })
        .add_message(TextMessageRole::User, user_message)
}
/// Chat worker entry point: loads the GGUF model at `gguf_path`, then serves
/// `ChatCommand`s from `command_rx` until a `Quit` arrives or the channel
/// closes, reporting every state change and streamed token via `progress_tx`.
///
/// Send results on `progress_tx` are deliberately discarded (`let _ =`): if
/// the receiving side has hung up there is nobody left to inform.
pub async fn start_chat(
    gguf_path: PathBuf,
    progress_tx: mpsc::UnboundedSender<ChatProgress>,
    mut command_rx: mpsc::UnboundedReceiver<ChatCommand>,
) {
    let _ = progress_tx.send(ChatProgress::LoadingModel);
    // GgufModelBuilder takes the directory and file name separately, so split
    // the path up front; either part missing is reported as an Error event
    // and aborts the worker.
    let model_dir = match gguf_path.parent() {
        Some(p) => p.to_string_lossy().to_string(),
        None => {
            let _ = progress_tx.send(ChatProgress::Error(format!(
                "Cannot determine parent directory of: {}",
                gguf_path.display()
            )));
            return;
        }
    };
    let file_name = match gguf_path.file_name() {
        Some(n) => n.to_string_lossy().to_string(),
        None => {
            let _ = progress_tx.send(ChatProgress::Error(format!(
                "Cannot determine file name of: {}",
                gguf_path.display()
            )));
            return;
        }
    };
    // Local file only — TokenSource::None means no hub token is used.
    let model = match GgufModelBuilder::new(&model_dir, vec![file_name.clone()])
        .with_token_source(TokenSource::None)
        .build()
        .await
    {
        Ok(m) => m,
        Err(e) => {
            let _ = progress_tx.send(ChatProgress::Error(format!(
                "Failed to load model \"{file_name}\": {e}"
            )));
            return;
        }
    };
    let _ = progress_tx.send(ChatProgress::Ready {
        model_name: file_name,
    });
    // Conversation context, replayed into every request by `build_request`.
    let mut history: Vec<(TextMessageRole, String)> = Vec::new();
    loop {
        // `None` means the command channel closed: shut down quietly.
        let command = match command_rx.recv().await {
            Some(cmd) => cmd,
            None => break,
        };
        match command {
            ChatCommand::Quit => break,
            ChatCommand::SendMessage(user_text) => {
                let _ = progress_tx.send(ChatProgress::Thinking);
                let request = build_request(&history, &user_text);
                let start = Instant::now();
                let mut stream = match model.stream_chat_request(request).await {
                    Ok(s) => s,
                    Err(e) => {
                        let _ = progress_tx.send(ChatProgress::Error(format!(
                            "Failed to start streaming: {e}"
                        )));
                        // Worker stays alive; wait for the next command.
                        continue;
                    }
                };
                // Accumulate the full reply while forwarding each delta.
                let mut reply_buf = String::new();
                while let Some(response) = stream.next().await {
                    match response {
                        Response::Chunk(chunk) => {
                            // Let-chain (stable since Rust 2024 edition):
                            // only the first choice is consumed.
                            if let Some(choice) = chunk.choices.first()
                                && let Some(ref delta) = choice.delta.content
                            {
                                reply_buf.push_str(delta);
                                let _ = progress_tx.send(ChatProgress::StreamDelta(delta.clone()));
                            }
                        }
                        Response::Done(_) => break,
                        // On any error variant we stop consuming the stream;
                        // NOTE(review): no event explicitly clears the
                        // `Thinking` state here — presumably the UI treats
                        // `Error` as end-of-turn; confirm.
                        Response::ModelError(msg, _) => {
                            let _ = progress_tx
                                .send(ChatProgress::Error(format!("Model error: {msg}")));
                            break;
                        }
                        Response::InternalError(e) => {
                            let _ = progress_tx
                                .send(ChatProgress::Error(format!("Internal error: {e}")));
                            break;
                        }
                        Response::ValidationError(e) => {
                            let _ = progress_tx
                                .send(ChatProgress::Error(format!("Validation error: {e}")));
                            break;
                        }
                        Response::CompletionModelError(msg, _) => {
                            let _ = progress_tx
                                .send(ChatProgress::Error(format!("Completion error: {msg}")));
                            break;
                        }
                        Response::CompletionDone(_) => break,
                        // Non-chat response kinds are not expected for this
                        // request type and are ignored.
                        Response::CompletionChunk(_) => {}
                        Response::ImageGeneration(_) => {}
                        Response::Speech { .. } => {}
                        Response::Raw { .. } => {}
                        Response::Embeddings { .. } => {}
                    }
                }
                let elapsed = start.elapsed();
                // Only successful (non-empty) turns enter the history; the
                // user message of a failed turn is dropped as well.
                // NOTE(review): dropping the user message on failure looks
                // intentional (avoids replaying an unanswered prompt) —
                // confirm.
                if !reply_buf.is_empty() {
                    history.push((TextMessageRole::User, user_text));
                    history.push((TextMessageRole::Assistant, reply_buf.trim().to_string()));
                    let _ = progress_tx.send(ChatProgress::Reply {
                        _text: reply_buf.trim().to_string(),
                        duration_display: format_duration(elapsed),
                    });
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn format_duration_under_one_minute() {
        let shown = format_duration(Duration::from_secs_f64(4.567));
        assert_eq!(shown, "4.6s");
    }

    #[test]
    fn format_duration_exactly_one_minute() {
        let shown = format_duration(Duration::from_secs(60));
        assert_eq!(shown, "1m 0.0s");
    }

    #[test]
    fn format_duration_over_one_minute() {
        let shown = format_duration(Duration::from_secs_f64(125.3));
        assert_eq!(shown, "2m 5.3s");
    }

    #[test]
    fn chat_message_user_role() {
        let msg = ChatMessage::user("hello");
        assert_eq!((msg.role, msg.content), (ChatRole::User, "hello".to_string()));
    }

    #[test]
    fn chat_message_assistant_role() {
        let msg = ChatMessage::assistant("hi there");
        assert_eq!((msg.role, msg.content), (ChatRole::Assistant, "hi there".to_string()));
    }

    #[test]
    fn to_mistral_role_maps_correctly() {
        assert!(matches!(to_mistral_role(&ChatRole::User), TextMessageRole::User));
        assert!(matches!(to_mistral_role(&ChatRole::Assistant), TextMessageRole::Assistant));
    }
}