use colored::*;
use rustyline::DefaultEditor;
use rustyline::error::ReadlineError;
use std::io::{Write, stdout};
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
use std::time::Duration;
use crate::discover::Model;
use crate::engine::InferenceEngine;
use crate::{engine::EngineConfig, template::*};
/// Handle to a terminal spinner animated on a background thread.
///
/// Created by `Spinner::new`; the animation runs until `Spinner::stop` is
/// called.
struct Spinner {
/// Join handle of the animation thread. Wrapped in `Option` so that `stop`
/// can take it out and join the thread.
handle: Option<thread::JoinHandle<()>>,
/// Flag shared with the animation thread; storing `true` asks the thread to
/// leave its drawing loop.
stop_signal: Arc<AtomicBool>,
}
impl Spinner {
    /// Starts a background thread that animates a spinner next to `message`.
    ///
    /// The thread hides the terminal cursor, redraws `message` followed by the
    /// next animation frame about every 80 ms and, once asked to stop, blanks
    /// the line and makes the cursor visible again.
    ///
    /// Terminal writes are best-effort: the spinner is purely cosmetic, so a
    /// failing stdout must not panic the thread (and, through `join`, the
    /// caller).
    fn new(message: String) -> Self {
        const FRAMES: [char; 10] = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'];
        const FRAME_INTERVAL: Duration = Duration::from_millis(80);

        let stop_signal = Arc::new(AtomicBool::new(false));
        let signal_clone = Arc::clone(&stop_signal);
        let handle = thread::spawn(move || {
            let mut out = stdout();
            // Hide the cursor while the animation is running.
            let _ = write!(out, "\x1B[?25l");
            let _ = out.flush();
            for frame in FRAMES.iter().cycle() {
                if signal_clone.load(Ordering::Acquire) {
                    break;
                }
                let _ = write!(out, "\r{} {}", message.dimmed(), frame);
                let _ = out.flush();
                // Parking (instead of sleeping) lets `shutdown` wake the thread
                // immediately via `unpark`. A spurious wake-up merely draws the
                // next frame a little early.
                thread::park_timeout(FRAME_INTERVAL);
            }
            // Blank out the spinner line, return to column 0 and show the
            // cursor again.
            let _ = write!(out, "\r{}\r", " ".repeat(message.len() + 5));
            let _ = write!(out, "\x1B[?25h");
            let _ = out.flush();
        });
        Self {
            handle: Some(handle),
            stop_signal,
        }
    }

    /// Stops the animation and waits until the line has been cleared and the
    /// cursor restored.
    fn stop(mut self) {
        self.shutdown();
    }

    /// Signals the animation thread to finish and joins it.
    ///
    /// Idempotent: once the handle has been taken, further calls do nothing
    /// beyond re-setting the flag.
    fn shutdown(&mut self) {
        self.stop_signal.store(true, Ordering::Release);
        if let Some(handle) = self.handle.take() {
            // Cut the current frame delay short so the caller is not blocked
            // for the remainder of the interval.
            handle.thread().unpark();
            // The result is ignored on purpose: this also runs from `drop`,
            // where panicking during an unwind would abort the process.
            let _ = handle.join();
        }
    }
}

/// Guarantees that the animation thread is stopped and the cursor restored
/// even when the spinner is dropped without `stop` having been called, e.g.
/// on an early return or while unwinding from a panic.
impl Drop for Spinner {
    fn drop(&mut self) {
        self.shutdown();
    }
}
/// State of one interactive chat: the inference engine, the conversation so
/// far and the active system prompt.
struct ChatSession {
/// Engine used to generate the assistant's replies.
engine: Box<InferenceEngine>,
/// Conversation history: user and assistant messages in the order they were
/// exchanged. Emptied by the `.clear` command.
data: Vec<Message>,
/// System prompt passed to the prompt template on every turn; can be viewed
/// or replaced with the `.system` command.
system_prompt: String,
}
impl ChatSession {
    /// Creates a session around `engine` with an empty history and the
    /// default system prompt.
    fn new(engine: Box<InferenceEngine>) -> Self {
        Self {
            engine,
            data: Vec::new(),
            system_prompt: String::from("You are a helpful, respectful and honest AI assistant."),
        }
    }

    /// Prints the banner and the short usage hints shown when a session starts.
    fn print_welcome_message(&self) {
        let rule = "========================================".cyan();
        println!("{}", rule);
        println!("{}", " Welcome to Tllama!".cyan().bold());
        println!("{}", rule);
        println!("Type your message and press Enter to chat with the AI.");
        println!("Type `.help` for more commands.");
        println!("Type `.exit` or press Ctrl+C to quit.");
        println!();
    }

    /// Executes a dot-command such as `.help` or `.system <prompt>`.
    ///
    /// The command name is separated from its optional argument by any run of
    /// whitespace, and the argument is stored without leading whitespace.
    ///
    /// Returns `Ok(false)` when the session should end (`.exit`, `.quit`,
    /// `.q`, `.bye`) and `Ok(true)` otherwise. Unknown commands are reported
    /// to the user and do not end the session.
    fn handle_command(&mut self, command: &str) -> Result<bool, Box<dyn std::error::Error>> {
        /// Command names and descriptions, in the order `.help` lists them.
        const HELP_ENTRIES: [(&str, &str); 5] = [
            (".help", "Show this help message."),
            (".system [prompt]", "View or set the system prompt."),
            (".clear", "Clear the conversation history."),
            (".history", "Show the conversation history."),
            (".exit", "Exit the chat session."),
        ];

        let trimmed = command.trim();
        // Split on the first whitespace character (space, tab, ...) rather than
        // on a single ' ', and drop any further whitespace before the argument.
        let (cmd, argument) = match trimmed.split_once(char::is_whitespace) {
            Some((name, rest)) => (name, Some(rest.trim_start())),
            None => (trimmed, None),
        };

        match cmd {
            ".exit" | ".quit" | ".q" | ".bye" => {
                println!("{}", "Goodbye!".yellow());
                return Ok(false);
            }
            ".help" => {
                println!("{}", "Available Commands:".green().bold());
                for (name, description) in HELP_ENTRIES.iter() {
                    // Width 16 fits the longest entry (`.system [prompt]`) so
                    // every description starts in the same column.
                    println!(" {:<16} {}", name, description);
                }
            }
            ".system" => match argument {
                Some(new_prompt) => {
                    self.system_prompt = new_prompt.to_string();
                    println!(
                        "{} {}",
                        "System prompt updated:".green(),
                        self.system_prompt
                    );
                }
                None => {
                    println!(
                        "{} {}",
                        "Current system prompt:".green(),
                        self.system_prompt
                    );
                }
            },
            ".clear" => {
                self.data.clear();
                println!("{}", "Conversation history cleared.".green());
            }
            ".history" => {
                println!("{}", "Conversation History:".green().bold());
                if self.data.is_empty() {
                    println!(" (No history yet)");
                }
                for msg in &self.data {
                    // Every role other than "user" is displayed as the AI.
                    let prefix = if msg.role == "user" {
                        "You".blue()
                    } else {
                        "AI".cyan()
                    };
                    println!("{}: {}", prefix, msg.content.as_deref().unwrap_or(""));
                }
            }
            _ => {
                println!("{}'{}'", "Unknown command: ".red(), command);
            }
        }
        Ok(true)
    }

    /// Runs the interactive read-eval-print loop until the user quits.
    ///
    /// Lines starting with `.` are treated as commands; every other non-empty
    /// line is sent to the model. Ctrl-C / Ctrl-D at the prompt end the loop
    /// normally.
    ///
    /// # Errors
    ///
    /// Returns an error if the line editor or the Ctrl-C handler cannot be set
    /// up, if a line cannot be added to the editor history, or if a command or
    /// chat turn fails.
    fn start(&mut self) -> Result<(), Box<dyn std::error::Error>> {
        let mut rl = DefaultEditor::new()?;
        ctrlc::set_handler(move || {
            println!("\n{}", "Received Ctrl-C. Exiting...".yellow());
            // A spinner may have hidden the cursor; make it visible again
            // before the process terminates.
            print!("\x1B[?25h");
            // Ignore flush failures: the process must still exit.
            let _ = stdout().flush();
            std::process::exit(0);
        })?;
        self.print_welcome_message();

        // Build the colored prompt once instead of on every iteration.
        let prompt = ">>> ".green().to_string();
        loop {
            let line = match rl.readline(&prompt) {
                Ok(line) => line,
                Err(ReadlineError::Interrupted) => {
                    println!("{}", "Received Ctrl-C. Exiting...".yellow());
                    break;
                }
                Err(ReadlineError::Eof) => {
                    println!("{}", "Received Ctrl-D. Exiting...".yellow());
                    break;
                }
                Err(err) => {
                    // Diagnostics go to stderr, like the other failure
                    // messages in this file.
                    eprintln!("Error: {:?}", err);
                    break;
                }
            };

            let input = line.trim();
            if input.is_empty() {
                continue;
            }
            rl.add_history_entry(input)?;

            if input.starts_with('.') {
                if !self.handle_command(input)? {
                    break;
                }
            } else {
                self.chat(input)?;
            }
        }
        Ok(())
    }

    /// Sends `user_input` to the model, streams the reply to stdout and
    /// records both sides of the exchange in the history.
    ///
    /// # Errors
    ///
    /// Returns the error from prompt rendering or inference. In that case the
    /// user message is removed from the history again, so the history never
    /// ends with a user turn that has no reply.
    fn chat(&mut self, user_input: &str) -> Result<(), Box<dyn std::error::Error>> {
        // The user message must be part of the history before the prompt is
        // rendered, because the prompt is built from `self.data`.
        self.data.push(Message {
            role: "user".to_string(),
            content: Some(user_input.to_string()),
            tool_calls: None,
            name: None,
        });

        match self.generate_reply() {
            Ok(reply) => {
                self.data.push(Message {
                    role: "assistant".to_string(),
                    content: Some(reply),
                    tool_calls: None,
                    name: None,
                });
                Ok(())
            }
            Err(err) => {
                // Roll back the message pushed above.
                self.data.pop();
                Err(err)
            }
        }
    }

    /// Renders the prompt from the system prompt and the current history,
    /// runs inference while echoing tokens to stdout, and returns the full
    /// reply.
    fn generate_reply(&mut self) -> Result<String, Box<dyn std::error::Error>> {
        let prompt_data = TemplateData::new()
            .with_system(Some(self.system_prompt.clone()))
            .with_messages(Some(self.data.clone()));
        let prompt = render_chatml_template(&prompt_data)?;

        // The spinner is shared with the token callback so it can be stopped
        // as soon as the first token arrives.
        let spinner = Arc::new(Mutex::new(Some(Spinner::new("".to_string()))));
        let spinner_clone = spinner.clone();
        let mut first_token = true;
        let result = self.engine.infer(
            &prompt,
            None,
            crate::def_callback!(|token| {
                if first_token {
                    let mut spinner_guard = spinner_clone.lock().unwrap();
                    if let Some(s) = spinner_guard.take() {
                        s.stop();
                    }
                    stdout().flush().unwrap();
                    first_token = false;
                }
                print!("{}", token);
                stdout().flush().unwrap();
            }),
        );

        // If no token was ever produced the spinner is still running; stop it
        // before printing anything else.
        if let Some(s) = spinner.lock().unwrap().take() {
            s.stop();
        }
        println!();
        Ok(result?)
    }
}
/// Entry point of the chat command: resolves the requested model, loads it
/// into an inference engine (showing a spinner meanwhile) and runs an
/// interactive [`ChatSession`].
///
/// `args.model` is handed to `Model::from_path` when it starts with `.` or
/// `/`; otherwise it is looked up through the global model discoverer.
///
/// # Errors
///
/// Returns an error if the model cannot be found, if the engine fails to
/// load it, or if the chat session itself fails.
pub fn chat_session(args: crate::cli::ChatArgs) -> Result<(), Box<dyn std::error::Error>> {
    #[cfg(feature = "engine-llama-cpp")]
    llama_cpp_2::send_logs_to_tracing(llama_cpp_2::LogOptions::default().with_logs_enabled(false));

    let given_as_path = args.model.starts_with('.') || args.model.starts_with('/');
    let model_path = if given_as_path {
        Model::from_path(&args.model)
    } else {
        crate::discover::MODEL_DISCOVERER
            .lock()
            .unwrap()
            .find_model(&args.model)?
    };

    let engine_config = EngineConfig {
        n_ctx: 4096,
        n_len: None,
        temperature: 0.8,
        top_k: 40,
        top_p: 0.9,
        repeat_penalty: 1.1,
    };

    // Keep the spinner running only for the duration of the model load.
    let loading = Spinner::new(String::from("Loading model..."));
    let load_outcome = crate::engine::InferenceEngine::new(&engine_config, &model_path);
    loading.stop();

    match load_outcome {
        Ok(engine) => {
            println!("{} {}", "Model loaded successfully.".dimmed(), "✔".green());
            let mut session = ChatSession::new(Box::new(engine));
            session.start()
        }
        Err(e) => {
            eprintln!("\n{} {}", "Failed to load model.".red().bold(), "✖".red());
            Err(e.into())
        }
    }
}