use crate::error::CliError;
use crate::output;
#[cfg(feature = "inference")]
use aprender::text::llama_tokenizer::LlamaTokenizer;
use aprender::text::bpe::Qwen2BpeTokenizer;
use aprender::serialization::apr::AprReader;
use aprender::text::chat_template::{
auto_detect_template, detect_format_from_name, ChatMessage, ChatTemplateEngine, TemplateFormat,
};
use colored::Colorize;
use std::io::{self, Write};
use std::path::Path;
use std::time::Instant;
/// Runtime configuration for one interactive chat session.
///
/// Built by [`run`] from CLI flags; consumed by the welcome banner and the
/// REPL (included below from `chat_session.rs` / `chat_generate_session.rs`).
pub(crate) struct ChatConfig {
    /// Sampling temperature passed to generation.
    pub temperature: f32,
    /// Nucleus (top-p) sampling cutoff.
    pub top_p: f32,
    /// Maximum number of tokens to generate per response.
    pub max_tokens: usize,
    /// Optional system prompt for the conversation.
    pub system: Option<String>,
    /// When true, show token probabilities during generation (inspection mode).
    pub inspect: bool,
    /// Force CPU inference even if an accelerator is available.
    pub force_cpu: bool,
    /// Enable inference tracing (APR-TRACE-001).
    pub trace: bool,
    /// Optional file path that trace output is written to.
    pub trace_output: Option<std::path::PathBuf>,
}
impl Default for ChatConfig {
fn default() -> Self {
Self {
temperature: 0.7,
top_p: 0.9,
max_tokens: 512,
system: None,
inspect: false,
force_cpu: false, trace: false,
trace_output: None,
}
}
}
/// Entry point for the `chat` subcommand: validate the model path, report
/// unimplemented flags, print the banner, and hand off to the REPL.
///
/// # Errors
/// Returns [`CliError::FileNotFound`] when `path` does not exist; otherwise
/// propagates whatever `run_repl` returns.
#[allow(clippy::too_many_arguments)]
pub(crate) fn run(
    path: &Path,
    temperature: f32,
    top_p: f32,
    max_tokens: usize,
    system: Option<&str>,
    inspect: bool,
    force_cpu: bool,
    trace: bool,
    trace_steps: Option<&[String]>,
    trace_verbose: bool,
    trace_output: Option<std::path::PathBuf>,
    trace_level: &str,
    profile: bool,
) -> Result<(), CliError> {
    // Fail fast on a missing model file.
    if !path.exists() {
        return Err(CliError::FileNotFound(path.to_path_buf()));
    }

    // These flags are accepted for CLI parity but not wired up for chat yet;
    // warn once per flag that was actually supplied.
    let pending_flags = [
        (trace_steps.is_some(), "--trace-steps"),
        (trace_verbose, "--trace-verbose"),
        (profile, "--profile"),
    ];
    for (supplied, flag) in pending_flags {
        if supplied {
            eprintln!("Warning: {flag} is not yet implemented for chat. Flag ignored.");
        }
    }

    if trace {
        eprintln!(
            "{}",
            "Inference tracing enabled for chat (APR-TRACE-001)".cyan()
        );
        eprintln!(" Trace level: {}", trace_level);
        if let Some(steps) = trace_steps {
            eprintln!(" Trace steps: {}", steps.join(", "));
        }
        if trace_verbose {
            eprintln!(" Verbose mode enabled");
        }
        if let Some(ref path) = trace_output {
            eprintln!(" Output: {}", path.display());
        }
        if profile {
            eprintln!(" Roofline profiling enabled");
        }
    }
    // Keep the compiler quiet about these when the trace branch is off.
    let _ = (trace_steps, trace_verbose, trace_level, profile);

    let config = ChatConfig {
        temperature,
        top_p,
        max_tokens,
        system: system.map(String::from),
        inspect,
        force_cpu,
        trace,
        trace_output,
    };
    print_welcome_banner(path, &config);
    run_repl(path, &config)
}
/// Model container format, detected from the file extension (see
/// `detect_format`) or, with the `inference` feature, from magic bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelFormat {
    /// Native APR v2 format (loaded via mmap).
    Apr,
    /// GGUF format (run through the realizar inference engine).
    Gguf,
    /// HuggingFace SafeTensors format (loaded via mmap).
    SafeTensors,
    /// Unrecognized input — falls back to the tiny built-in demo model.
    Demo,
}
/// Map a model file's extension to its container format; anything
/// unrecognized (or extension-less) is treated as the demo model.
fn detect_format(path: &Path) -> ModelFormat {
    let ext = path.extension().and_then(std::ffi::OsStr::to_str);
    match ext.unwrap_or("") {
        "apr" => ModelFormat::Apr,
        "gguf" => ModelFormat::Gguf,
        "safetensors" => ModelFormat::SafeTensors,
        _ => ModelFormat::Demo,
    }
}
/// Attempt to load a Qwen2 BPE tokenizer from `path`, printing a colored
/// success or warning line; `label` annotates where the file came from
/// (e.g. " from HuggingFace cache"). Returns `None` on load failure.
fn try_load_tokenizer(path: &Path, label: &str) -> Option<Qwen2BpeTokenizer> {
    let loaded = Qwen2BpeTokenizer::from_file(path);
    match loaded {
        Err(e) => {
            println!(
                "{} {}",
                format!("Warning: Failed to load tokenizer{label}:").yellow(),
                e
            );
            None
        }
        Ok(tok) => {
            println!(
                "{} {} ({})",
                format!("Loaded tokenizer{label}:").green(),
                path.display(),
                format!("{} tokens", tok.vocab_size()).dimmed()
            );
            Some(tok)
        }
    }
}
/// Scan a HuggingFace hub cache directory for any `models--Qwen*` entry
/// containing a `snapshots/*/tokenizer.json`, returning the first tokenizer
/// that loads successfully.
fn search_hf_cache_tokenizer(hf_cache: &Path) -> Option<Qwen2BpeTokenizer> {
    let entries = std::fs::read_dir(hf_cache).ok()?;
    for entry in entries.flatten() {
        let name = entry.file_name();
        if !name.to_string_lossy().starts_with("models--Qwen") {
            continue;
        }
        let snapshots_dir = entry.path().join("snapshots");
        // BUG FIX: this used `.ok()?`, so a single unreadable `snapshots`
        // dir aborted the entire cache search. Skip the entry instead and
        // keep scanning the remaining model directories.
        let snapshots = match std::fs::read_dir(&snapshots_dir) {
            Ok(s) => s,
            Err(_) => continue,
        };
        for snapshot in snapshots.flatten() {
            let tokenizer_path = snapshot.path().join("tokenizer.json");
            if tokenizer_path.exists() {
                if let Some(tok) = try_load_tokenizer(&tokenizer_path, " from HuggingFace cache") {
                    return Some(tok);
                }
            }
        }
    }
    None
}
/// Load a tokenizer from `path` only if the file actually exists;
/// otherwise return `None` without printing anything.
fn try_tokenizer_at(path: &Path, label: &str) -> Option<Qwen2BpeTokenizer> {
    path.exists()
        .then(|| try_load_tokenizer(path, label))
        .flatten()
}
/// Look for a tokenizer next to the model file: first the pacha-cache
/// layout `<stem>.tokenizer.json`, then a plain `tokenizer.json` in the
/// same directory.
fn find_qwen_tokenizer_sibling(model_path: &Path) -> Option<Qwen2BpeTokenizer> {
    let parent = model_path.parent()?;
    // Preferred: "<stem>.tokenizer.json" beside the model (pacha cache).
    if let Some(stem) = model_path.file_stem().and_then(|s| s.to_str()) {
        let prefixed = parent.join(format!("{stem}.tokenizer.json"));
        if let Some(tok) = try_tokenizer_at(&prefixed, " from pacha cache") {
            return Some(tok);
        }
    }
    // Fallback: bare "tokenizer.json" in the model's directory.
    try_tokenizer_at(&parent.join("tokenizer.json"), "")
}
/// Locate a Qwen tokenizer, searching in order: files beside the model,
/// the HuggingFace hub cache, then the APR tokenizer cache.
///
/// # Errors
/// Returns [`CliError::InvalidFormat`] with remediation instructions when
/// no tokenizer is found in any location.
fn find_qwen_tokenizer(model_path: &Path) -> Result<Option<Qwen2BpeTokenizer>, CliError> {
    let found = find_qwen_tokenizer_sibling(model_path).or_else(|| {
        let home = dirs::home_dir()?;
        search_hf_cache_tokenizer(&home.join(".cache/huggingface/hub")).or_else(|| {
            try_tokenizer_at(
                &home.join(".apr/tokenizers/qwen2/tokenizer.json"),
                " from APR cache",
            )
        })
    });
    match found {
        Some(tok) => Ok(Some(tok)),
        None => Err(CliError::InvalidFormat(
            "No Qwen tokenizer found. Searched:\n\
             1. Pacha cache ({stem}.tokenizer.json alongside model)\n\
             2. Model directory (tokenizer.json)\n\
             3. HuggingFace cache (~/.cache/huggingface/hub/models--Qwen--*/snapshots/*/tokenizer.json)\n\
             4. APR cache (~/.apr/tokenizers/qwen2/tokenizer.json)\n\n\
             To fix: Download a Qwen model with tokenizer:\n\
             apr pull hf://Qwen/Qwen2.5-0.5B-Instruct-GGUF"
                .to_string(),
        )),
    }
}
/// Cap runs of repeated '!', '?', or '.' at three consecutive characters;
/// every other character is copied through unchanged.
fn normalize_repeated_punctuation(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut last = '\0';
    let mut run_len = 0usize;
    for ch in s.chars() {
        let extends_punct_run = ch == last && matches!(ch, '!' | '?' | '.');
        if extends_punct_run {
            run_len += 1;
            // Keep the 2nd and 3rd repeats; drop the 4th onward.
            if run_len < 3 {
                out.push(ch);
            }
        } else {
            run_len = 0;
            out.push(ch);
        }
        last = ch;
    }
    out
}
/// Heuristic: does `text` look like the model starting a new conversational
/// turn (a question opener, a "Human:" prefix, or a ChatML start marker)?
fn looks_like_new_turn(text: &str) -> bool {
    const TURN_OPENERS: [&str; 6] = ["Suggest", "What", "How", "Why", "Can", "Human:"];
    TURN_OPENERS.iter().any(|p| text.starts_with(p)) || text.contains("<|im_start|>")
}
/// Strip chat-template control tokens and BPE artifacts from raw model
/// output, normalize whitespace/punctuation, and cut the text off if the
/// model starts hallucinating a new conversational turn.
fn clean_chat_response(raw: &str) -> String {
    let mut cleaned = raw.to_string();
    // Remove ChatML / end-of-text control markers.
    for marker in &[
        "<|im_start|>assistant\n",
        "<|im_start|>assistant",
        "<|im_end|>",
        "<|im_start|>",
        "<|endoftext|>",
    ] {
        cleaned = cleaned.replace(marker, "");
    }
    // Map GPT-2/Qwen BPE whitespace placeholders back to real whitespace.
    // (U+0120 is 'Ġ' and U+010A is 'Ċ', so one replacement each suffices;
    // the original performed each replacement twice.)
    cleaned = cleaned.replace('\u{0120}', " ");
    cleaned = cleaned.replace('\u{010A}', "\n");
    cleaned = normalize_repeated_punctuation(&cleaned);
    // Collapse runs of spaces to a single space.
    // BUG FIX: the pattern must be TWO spaces — replacing " " with " "
    // never terminates once the text contains any space at all.
    while cleaned.contains("  ") {
        cleaned = cleaned.replace("  ", " ");
    }
    let trimmed = cleaned.trim();
    // If everything after the first line looks like a fresh turn, keep
    // only the first line of the response.
    if let Some(first_newline) = trimmed.find('\n') {
        let first_line = trimmed[..first_newline].trim();
        let rest = trimmed[first_newline..].trim();
        if looks_like_new_turn(rest) {
            return first_line.to_string();
        }
    }
    trimmed.to_string()
}
/// Identify a model container format from its leading bytes: APR and GGUF
/// by magic number, SafeTensors by a plausible little-endian header length,
/// anything else as the demo fallback.
#[cfg(feature = "inference")]
fn detect_format_from_bytes(data: &[u8]) -> ModelFormat {
    // Not enough bytes to hold any recognizable header.
    if data.len() < 8 {
        return ModelFormat::Demo;
    }
    let magic = &data[0..4];
    if magic == b"APRN" || magic == b"APR2" || magic == b"APR\0" {
        ModelFormat::Apr
    } else if magic == b"GGUF" {
        ModelFormat::Gguf
    } else {
        // SafeTensors files start with a u64 JSON-header length; accept any
        // non-zero value under 100 MB as plausible.
        let header_size = u64::from_le_bytes(data[0..8].try_into().unwrap_or([0; 8]));
        if header_size > 0 && header_size < 100_000_000 {
            ModelFormat::SafeTensors
        } else {
            ModelFormat::Demo
        }
    }
}
/// Print the startup banner: detected container format, inferred chat
/// template, generation settings, and the REPL slash-commands.
fn print_welcome_banner(path: &Path, config: &ChatConfig) {
    // Container format comes from the file extension.
    let format = detect_format(path);
    let model_name = path
        .file_stem()
        .and_then(|s| s.to_str())
        .unwrap_or("unknown");
    // The chat template is inferred from the model's file name.
    let template_format = detect_format_from_name(model_name);
    let template_name = match template_format {
        TemplateFormat::ChatML => "ChatML",
        TemplateFormat::Llama2 => "LLaMA2",
        TemplateFormat::Mistral => "Mistral",
        TemplateFormat::Phi => "Phi",
        TemplateFormat::Alpaca => "Alpaca",
        TemplateFormat::Custom => "Custom",
        TemplateFormat::Raw => "Raw",
    };
    // Per-format section header plus a one-line note about the backend.
    match format {
        ModelFormat::Apr => {
            output::section("Model Chat (APR Format)");
            println!();
            println!(
                "{}",
                "Using APR v2 format with mmap (Native Library Mandate)".cyan()
            );
        }
        ModelFormat::Gguf => {
            output::section("Model Chat (GGUF Format)");
            println!();
            println!(
                "{}",
                "Using GGUF format with realizar inference engine".cyan()
            );
        }
        ModelFormat::SafeTensors => {
            output::section("Model Chat (SafeTensors Format)");
            println!();
            println!(
                "{}",
                "Using SafeTensors with mmap (Native Library Mandate)".cyan()
            );
        }
        ModelFormat::Demo => {
            output::section("Chat Demo (Tiny Model)");
            println!();
            println!(
                "{}",
                "Note: Using tiny demo model. Pass .apr, .gguf, or .safetensors file for full model."
                    .yellow()
            );
        }
    }
    // Key/value summary of the session settings.
    println!();
    output::kv("Model", path.display());
    output::kv("Chat Template", template_name);
    output::kv("Temperature", config.temperature);
    output::kv("Top-P", config.top_p);
    output::kv("Max Tokens", config.max_tokens);
    if let Some(system) = &config.system {
        output::kv("System", system);
    }
    if config.inspect {
        println!();
        println!(
            "{}",
            "Inspection mode enabled - showing token probabilities".cyan()
        );
    }
    // REPL command reference, then a divider before the prompt loop starts.
    println!();
    println!("{}", "Commands:".white().bold());
    println!(" /quit Exit the chat");
    println!(" /clear Clear conversation history");
    println!(" /system Set system prompt");
    println!(" /help Show help");
    println!();
    println!("{}", "═".repeat(60));
    println!();
}
// The REPL loop (`run_repl`), the generation session, and the remaining chat
// helpers live in separate files but are textually included here, so they
// compile as part of this module and can use its private items.
include!("chat_session.rs");
include!("chat_generate_session.rs");
include!("chat_04.rs");