#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{Error as E, Result, bail};
use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_core as candle;
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{Repo, RepoType, api::sync::Api};
use serde::{Deserialize, Serialize};
use std::io::{self, Write};
use std::{fs, path::Path};
use candle_transformers::models::llama as model;
use model::{Llama, LlamaConfig};
/// Token text looked up in the tokenizer to find the end-of-sequence id when
/// the model config does not provide one.
const EOS_TOKEN: &str = "</s>";
/// Prompt used for single-shot generation and dry runs when `--prompt` is not given.
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
/// Persona profile that is serialized to and parsed from JSON
/// (for example the `persona.json` file inside a version directory).
struct Persona {
/// Optional display name; `to_system_prefix` substitutes "Persona" when absent.
pub name: Option<String>,
/// Free-text description, inserted right after the name in the system prefix.
pub description: String,
/// Behavioural guidelines, inserted after "Guidelines:" in the system prefix.
pub system_prompt: String,
/// Time of the last update; the persona-update path stores seconds since the
/// Unix epoch as a decimal string when the field is still empty.
pub updated_at: Option<String>,
/// Paths of the feed files the persona was generated from.
pub sources: Vec<String>,
}
impl Persona {
    /// Renders the persona as a system-style preamble placed at the start of a prompt.
    ///
    /// The name falls back to `"Persona"` when `name` is `None`. The returned
    /// text always ends with a blank line so further prompt parts can be
    /// appended directly.
    fn to_system_prefix(&self) -> String {
        // Borrow the name instead of cloning the whole `Option<String>`.
        let name = self.name.as_deref().unwrap_or("Persona");
        format!(
            "You are {}. {}\nGuidelines: {}\nStay in character.\n\n",
            name, self.description, self.system_prompt
        )
    }
}
#[derive(Debug, Clone)]
/// File-system locations that make up one named "version" (see `resolve_version_paths`).
struct VersionPaths {
/// Directory of the version: `<base>/<name>`.
root: std::path::PathBuf,
/// `persona.json` inside `root`; holds the serialized `Persona`.
persona_json: std::path::PathBuf,
/// `memory.txt` inside `root`; chat turns are appended here and excerpts are read back.
memory_txt: std::path::PathBuf,
}
/// Makes sure something exists at `p`, creating the directory (and any missing
/// parents) when the path is absent.
///
/// Returns early without touching the file system when `p` already exists;
/// directory-creation failures are propagated to the caller.
fn ensure_dir(p: &Path) -> Result<()> {
    if p.exists() {
        return Ok(());
    }
    fs::create_dir_all(p)?;
    Ok(())
}
/// Computes the paths belonging to the version called `name` under `base`.
///
/// Pure path arithmetic: nothing is created or checked on disk.
fn resolve_version_paths(base: &Path, name: &str) -> VersionPaths {
    let root = base.join(name);
    let persona_json = root.join("persona.json");
    let memory_txt = root.join("memory.txt");
    VersionPaths {
        root,
        persona_json,
        memory_txt,
    }
}
/// Best-effort loader for a persona JSON file.
///
/// Returns `None` when the file is missing, cannot be read, or does not
/// deserialize into a `Persona`; no error is reported in any of these cases.
fn load_persona_from_file(p: &Path) -> Option<Persona> {
    if !p.exists() {
        return None;
    }
    let raw = fs::read_to_string(p).ok()?;
    serde_json::from_str(&raw).ok()
}
/// Returns at most the last `max_chars` characters of the file at `path`.
///
/// * `Ok(None)` when nothing exists at `path`.
/// * `Ok(Some(text))` with the whole content when it has `max_chars`
///   characters or fewer, otherwise only its tail.
///
/// Length is measured in Unicode scalar values (`char`s), so the cut never
/// lands inside a multi-byte character.
///
/// # Errors
/// Propagates the read error (including invalid UTF-8) when the file exists
/// but cannot be read as text.
fn read_memory_excerpt(path: &Path, max_chars: usize) -> Result<Option<String>> {
    if !path.exists() {
        return Ok(None);
    }
    let mut content = fs::read_to_string(path)?;
    // Byte offset where the trailing `max_chars` characters begin. Walking the
    // characters from the back finds it in one pass, without the
    // reverse/collect/reverse round trip through a temporary string.
    let cut = match max_chars.checked_sub(1) {
        // `max_chars == 0`: nothing is kept.
        None => content.len(),
        // Fewer than `max_chars` characters in total: keep everything.
        Some(back) => content
            .char_indices()
            .rev()
            .nth(back)
            .map_or(0, |(idx, _)| idx),
    };
    // Drop the head in place; `cut` always lies on a char boundary.
    content.drain(..cut);
    Ok(Some(content))
}
/// Appends one chat exchange to the memory file at `path`.
///
/// The record has the form `User: …`, `Assistant: …`, then a `---` separator
/// line. The parent directory is ensured first and the file is created when
/// missing; existing content is never overwritten.
fn append_memory(path: &Path, user: &str, assistant: &str) -> Result<()> {
    use std::io::Write as _;

    let record = format!("User: {user}\nAssistant: {assistant}\n---\n");
    if let Some(dir) = path.parent() {
        ensure_dir(dir)?;
    }
    let mut file = fs::OpenOptions::new()
        .append(true)
        .create(true)
        .open(path)?;
    file.write_all(record.as_bytes())?;
    Ok(())
}
/// Collects feed text from a single file, or from every `.txt` file directly
/// inside a directory (extension compared case-insensitively, no recursion).
///
/// Returns `(combined, sources)`:
/// * `combined` – each source truncated to 50,000 characters and preceded by a
///   `\n---\n` separator. No further source is appended once more than
///   200,000 characters have been combined.
/// * `sources` – the paths of all files that were read, in reading order.
///
/// Directory entries are processed in sorted path order so the combined text
/// is reproducible; `read_dir` itself gives no ordering guarantee.
///
/// Progress messages are written to stderr.
///
/// # Errors
/// Fails when `feed_path` is neither a file nor a directory, or when listing
/// the directory or reading any feed file fails.
fn read_feed_texts(feed_path: &str) -> Result<(String, Vec<String>)> {
    const MAX_CHARS_PER_SOURCE: usize = 50_000;
    const COMBINED_CHAR_BUDGET: usize = 200_000;

    let path = Path::new(feed_path);
    let mut texts = Vec::new();
    let mut sources = Vec::new();
    if path.is_file() {
        eprintln!("Reading feed file: {}", path.to_string_lossy());
        io::stderr().flush().ok();
        texts.push(fs::read_to_string(path)?);
        sources.push(path.to_string_lossy().to_string());
    } else if path.is_dir() {
        let mut paths = Vec::new();
        for entry in fs::read_dir(path)? {
            let p = entry?.path();
            let is_txt = p
                .extension()
                .and_then(|ext| ext.to_str())
                .is_some_and(|ext| ext.eq_ignore_ascii_case("txt"));
            // Skip sub-directories that merely happen to be named `*.txt`.
            if is_txt && p.is_file() {
                paths.push(p);
            }
        }
        paths.sort();
        eprintln!(
            "Found {} .txt feed file(s) in {}",
            paths.len(),
            path.to_string_lossy()
        );
        io::stderr().flush().ok();
        for (i, p) in paths.iter().enumerate() {
            eprintln!(
                "Reading feed file ({}/{}): {}",
                i + 1,
                paths.len(),
                p.to_string_lossy()
            );
            io::stderr().flush().ok();
            texts.push(fs::read_to_string(p)?);
            sources.push(p.to_string_lossy().to_string());
        }
    } else {
        bail!("feed_path not found: {}", feed_path);
    }

    let mut combined = String::new();
    // Counted in characters, matching both the per-source limit and the log
    // message below (the byte length over-counts non-ASCII text).
    let mut total = 0usize;
    for text in &texts {
        if total > COMBINED_CHAR_BUDGET {
            break;
        }
        // Byte offset just past the last character that is kept.
        let end = text
            .char_indices()
            .nth(MAX_CHARS_PER_SOURCE)
            .map_or(text.len(), |(idx, _)| idx);
        let piece = &text[..end];
        total += piece.chars().count();
        combined.push_str("\n---\n");
        combined.push_str(piece);
    }
    eprintln!(
        "Combined {} characters from {} source(s).",
        total,
        sources.len()
    );
    io::stderr().flush().ok();
    Ok((combined, sources))
}
/// Builds the instruction prompt asking the model to emit a persona as JSON.
///
/// When `existing` is given, its pretty-printed JSON is embedded and the model
/// is asked to update it (an empty string is embedded if serialization
/// fails); otherwise the model is asked to create a new persona.
/// `feed_excerpt` is placed after the `FEED:` marker and the prompt ends with
/// a `JSON:` marker.
fn build_persona_update_prompt(existing: Option<&Persona>, feed_excerpt: &str) -> String {
    let intro =
        "You are an assistant that builds a concise persona profile from given text feeds.\n\n";
    let task = match existing {
        Some(p) => format!(
            "Current persona JSON:\n{}\n\nUpdate the persona with the new feed while keeping consistency.\n",
            serde_json::to_string_pretty(p).unwrap_or_default()
        ),
        None => "Create a new persona based on the feed.\n".to_string(),
    };
    let output_rules = "Return ONLY valid compact JSON with fields: name (optional), description (string), system_prompt (string), sources (array of strings). Do not include any other text.\n\nFEED:\n";
    [intro, task.as_str(), output_rules, feed_excerpt, "\n\nJSON:\n"].concat()
}
// Model preset selectable with `--which`; `main` maps every variant to a
// Hugging Face repository id and to the layout of its weight files.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc
// comments into help text, which would change the CLI output.
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
// Variants without an explicit `#[value(name = ...)]` get their CLI spelling
// from clap's derive.
V1,
V2,
V3,
V31,
V3Instruct,
V31Instruct,
V32_1b,
V32_1bInstruct,
V32_3b,
V32_3bInstruct,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
TinyLlama1_1BChat,
// NOTE(review): the CLI names below are spelled "SmoLM2" (single "l") while
// the repositories are "SmolLM2"; renaming would break existing invocations.
// The `SmolLM2_1B*` variants select the 1.7B checkpoints.
#[value(name = "SmoLM2-1.7B")]
SmolLM2_1B,
#[value(name = "SmoLM2-1.7B-Instruct")]
SmolLM2_1BInstruct,
#[value(name = "SmoLM2-360M")]
SmolLM2_360M,
#[value(name = "SmoLM2-360M-Instruct")]
SmolLM2_360MInstruct,
#[value(name = "SmoLM2-135M")]
SmolLM2_135M,
#[value(name = "SmoLM2-135M-Instruct")]
SmolLM2_135MInstruct,
}
// Inference backend selected with `--engine`.
// (`//` rather than `///`: clap's derive would surface doc comments as help text.)
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Engine {
// Delegate generation to a llama.cpp HTTP server reachable at `--llama-url`.
#[value(name = "llamacpp")]
LlamaCpp,
// Load the weights and run the model in-process with candle.
#[value(name = "candle")]
Candle,
}
// Command-line options of the program.
// Plain `//` comments are used throughout: clap's derive turns `///` doc
// comments into `--help` text, so adding them would change runtime output.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None, disable_version_flag = true)]
struct Args {
// Passed to `candle_examples::device`; the program's own messages describe it
// as forcing CPU execution.
#[arg(long)]
cpu: bool,
// Persona JSON file: read as the fallback persona and written by `--update-persona`.
#[arg(long)]
persona_file: Option<String>,
// Feed source for `--update-persona`: one file, or a directory of `.txt` files.
#[arg(long)]
feed_path: Option<String>,
// Generate a persona from the feed and save it to `--persona-file`, then exit.
#[arg(long)]
update_persona: bool,
// Run an interactive chat loop on stdin/stdout.
#[arg(long)]
chat: bool,
// Print the backend, model and final prompt without running inference
// (only honoured when neither `--chat` nor `--update-persona` is set).
#[arg(long)]
dry_run: bool,
// In chat mode, append every exchange to the version's `memory.txt`.
#[arg(long)]
auto_train: bool,
// Base directory holding named versions. Because of the default value this
// is always `Some` after parsing.
#[arg(long, default_value = "./versions")]
version_dir: Option<String>,
// Name of the version (sub-directory of `version_dir`) to load persona/memory from.
// The built-in `--version` flag is disabled above so this name is free.
#[arg(long)]
version: Option<String>,
// Save the active persona as a new version with this name, then exit.
#[arg(long)]
save_as_version: Option<String>,
// Number of most recent chat turns replayed in each chat prompt.
#[arg(long, default_value_t = 8)]
history_max_turns: usize,
// Sampling temperature; values <= 0 select arg-max (greedy) sampling.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
// Nucleus sampling cutoff; combined with `top_k` when both are given.
#[arg(long)]
top_p: Option<f64>,
// Restrict sampling to the k most likely tokens.
#[arg(long)]
top_k: Option<usize>,
// Seed for the sampler's random number generator.
#[arg(long, default_value_t = 299792458)]
seed: u64,
// Maximum number of tokens to generate.
#[arg(short = 'n', long, default_value_t = 256)]
sample_len: usize,
// End-of-sequence tokens are suppressed until this many tokens were generated.
#[arg(long, default_value_t = 16)]
min_tokens: usize,
// Disable the key/value cache (the whole token sequence is re-fed every step).
#[arg(long)]
no_kv_cache: bool,
// Prompt for single-shot generation; `DEFAULT_PROMPT` is used when absent.
#[arg(long)]
prompt: Option<String>,
// Weight dtype: "f16", "bf16" or "f32"; f16 when omitted, anything else is an error.
#[arg(long)]
dtype: Option<String>,
// Record a Chrome trace of the run via `tracing_chrome`.
#[arg(long)]
tracing: bool,
// Explicit Hugging Face repository id; overrides the one implied by `--which`.
#[arg(long)]
model_id: Option<String>,
// Repository revision; "main" when omitted.
#[arg(long)]
revision: Option<String>,
// Model preset; also decides whether sharded or single-file weights are fetched.
#[arg(long, default_value = "SmoLM2-1.7B-Instruct")]
which: Which,
// Forwarded to `LlamaConfig::into_config`.
#[arg(long)]
use_flash_attn: bool,
// Accepted for compatibility only; `main` prints a note and otherwise ignores it.
#[arg(long)]
features: Option<String>,
// Penalty applied to recently generated tokens; exactly 1.0 disables it.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
// Size of the trailing token window the repeat penalty looks at.
#[arg(long, default_value_t = 128)]
repeat_last_n: usize,
// Inference backend; defaults to the llama.cpp server.
#[arg(long, value_enum, default_value = "llamacpp")]
engine: Engine,
// Base URL of the llama.cpp server (used by the `llamacpp` engine).
#[arg(long, default_value = "http://127.0.0.1:8080")]
llama_url: String,
}
fn main() -> Result<()> {
use tokenizers::Tokenizer;
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
if let Some(ref feats) = args.features {
eprintln!(
"Note: --features '{}' is accepted for compatibility but is a no-op at runtime. Use cargo features to enable CUDA/MPS/etc., or use --cpu to force CPU.",
feats
);
}
#[cfg(feature = "cuda")]
{
eprintln!(
"[build] CUDA feature enabled. To use NVIDIA GPU, compile with `--features cuda` and select `--engine candle` (default engine is llama.cpp). Use `--cpu` to force CPU."
);
}
#[cfg(not(feature = "cuda"))]
{
eprintln!(
"[build] CUDA feature not enabled. For GPU support, rebuild with `--features cuda` and select `--engine candle`."
);
}
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
if args.dry_run && !args.chat && !args.update_persona {
let version_paths: Option<VersionPaths> = match (&args.version_dir, &args.version) {
(Some(dir), Some(name)) => Some(resolve_version_paths(Path::new(dir), name)),
_ => None,
};
let existing_persona: Option<Persona> = if let Some(ref pf) = args.persona_file {
if Path::new(pf).exists() {
match fs::read_to_string(pf) {
Ok(s) => serde_json::from_str(&s).ok(),
Err(_) => None,
}
} else {
None
}
} else {
None
};
let active_persona: Option<Persona> = if let Some(vp) = &version_paths {
load_persona_from_file(&vp.persona_json).or_else(|| existing_persona.clone())
} else {
existing_persona.clone()
};
let base_prompt = args
.prompt
.as_ref()
.map_or(DEFAULT_PROMPT.to_string(), |p| p.clone());
let mut final_prompt = String::new();
if let Some(p) = active_persona.as_ref() {
final_prompt.push_str(&p.to_system_prefix());
}
if let Some(vp) = &version_paths {
if let Ok(Some(mem)) = read_memory_excerpt(&vp.memory_txt, 20_000) {
final_prompt.push_str("Long-term memory (use to stay consistent):\n");
final_prompt.push_str(&mem);
final_prompt.push_str("\n\n");
}
}
final_prompt.push_str(&base_prompt);
if !final_prompt.trim_end().ends_with("Assistant:") {
final_prompt.push_str("\nAssistant: ");
}
let model_id = args.model_id.clone().unwrap_or_else(|| {
let str = match args.which {
Which::V1 => "Narsil/amall-7b",
Which::V2 => "meta-llama/Llama-2-7b-hf",
Which::V3 => "meta-llama/Meta-Llama-3-8B",
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct",
Which::V31 => "meta-llama/Llama-3.1-8B",
Which::V31Instruct => "meta-llama/Llama-3.1-8B-Instruct",
Which::V32_1b => "meta-llama/Llama-3.2-1B",
Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct",
Which::V32_3b => "meta-llama/Llama-3.2-3B",
Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct",
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0",
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
Which::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
Which::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
Which::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
Which::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
Which::SmolLM2_1B => "HuggingFaceTB/SmolLM2-1.7B",
Which::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
};
str.to_string()
});
let backend = match args.engine {
Engine::LlamaCpp => format!("llama.cpp @ {}", args.llama_url),
Engine::Candle => "candle".to_string(),
};
let model_line = if matches!(args.engine, Engine::LlamaCpp) {
args.model_id
.clone()
.unwrap_or_else(|| "server-default".to_string())
} else {
model_id.clone()
};
println!(
"[dry-run] Backend: {}\n[dry-run] Model: {}\n[dry-run] Sample length: {}\n\n[dry-run] Final prompt that would be sent:\n{}",
backend, model_line, args.sample_len, final_prompt
);
io::stdout().flush().ok();
return Ok(());
}
if matches!(args.engine, Engine::LlamaCpp) {
return run_with_llamacpp(&args);
}
let device = candle_examples::device(args.cpu)?;
let dtype = match args.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
let (llama, tokenizer_filename, mut cache, config) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| {
let str = match args.which {
Which::V1 => "Narsil/amall-7b",
Which::V2 => "meta-llama/Llama-2-7b-hf",
Which::V3 => "meta-llama/Meta-Llama-3-8B",
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct",
Which::V31 => "meta-llama/Llama-3.1-8B",
Which::V31Instruct => "meta-llama/Llama-3.1-8B-Instruct",
Which::V32_1b => "meta-llama/Llama-3.2-1B",
Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct",
Which::V32_3b => "meta-llama/Llama-3.2-3B",
Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct",
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0",
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
Which::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
Which::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
Which::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
Which::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
Which::SmolLM2_1B => "HuggingFaceTB/SmolLM2-1.7B",
Which::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
};
str.to_string()
});
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
eprintln!(
"Preparing repository {model_id}@{revision} (may download from Hugging Face cache)..."
);
io::stderr().flush().ok();
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
eprintln!("Fetching tokenizer.json (may download if not cached)...");
io::stderr().flush().ok();
let tokenizer_filename = api.get("tokenizer.json")?;
eprintln!("Fetching config.json...");
io::stderr().flush().ok();
let config_filename = api.get("config.json")?;
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = config.into_config(args.use_flash_attn);
eprintln!("Resolving model weight files...");
io::stderr().flush().ok();
let filenames = match args.which {
Which::V1
| Which::V2
| Which::V3
| Which::V3Instruct
| Which::V31
| Which::V31Instruct
| Which::V32_3b
| Which::V32_3bInstruct
| Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::SmolLM2_360M
| Which::SmolLM2_360MInstruct
| Which::SmolLM2_135M
| Which::SmolLM2_135MInstruct
| Which::SmolLM2_1B
| Which::SmolLM2_1BInstruct
| Which::V32_1b
| Which::V32_1bInstruct
| Which::TinyLlama1_1BChat => {
vec![api.get("model.safetensors")?]
}
};
eprintln!(
"Initializing KV cache (use_kv_cache={})...",
!args.no_kv_cache
);
io::stderr().flush().ok();
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
eprintln!("Memory-mapping {} weight file(s)...", filenames.len());
io::stderr().flush().ok();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
eprintln!("Loading model graph (this can take a while on first run)...");
io::stderr().flush().ok();
(Llama::load(vb, &config)?, tokenizer_filename, cache, config)
};
let tokenizer = Tokenizer::from_file(&tokenizer_filename).map_err(E::msg)?;
let eos_token_id = config.eos_token_id.clone().or_else(|| {
tokenizer
.token_to_id(EOS_TOKEN)
.map(model::LlamaEosToks::Single)
});
let existing_persona: Option<Persona> = if let Some(ref pf) = args.persona_file {
if Path::new(pf).exists() {
match fs::read_to_string(pf) {
Ok(s) => serde_json::from_str(&s).ok(),
Err(_) => None,
}
} else {
None
}
} else {
None
};
let version_paths: Option<VersionPaths> = match (&args.version_dir, &args.version) {
(Some(dir), Some(name)) => {
let base = Path::new(dir);
let vp = resolve_version_paths(base, name);
if let Err(e) = ensure_dir(&vp.root) {
eprintln!("warning: could not create version dir: {e}");
}
Some(vp)
}
_ => None,
};
let active_persona: Option<Persona> = if let Some(vp) = &version_paths {
load_persona_from_file(&vp.persona_json).or_else(|| existing_persona.clone())
} else {
existing_persona.clone()
};
if let Some(name) = &args.save_as_version {
let version_dir = args
.version_dir
.as_ref()
.ok_or_else(|| E::msg("--version-dir is required with --save-as-version"))?;
let vp = resolve_version_paths(Path::new(version_dir), name);
ensure_dir(&vp.root)?;
let persona_to_save: Persona = active_persona.clone().unwrap_or_default();
let json = serde_json::to_string_pretty(&persona_to_save)?;
fs::write(&vp.persona_json, json)?;
if !vp.memory_txt.exists() {
fs::write(&vp.memory_txt, "")?;
}
println!(
"Saved version '{}' into {}",
name,
vp.root.to_string_lossy()
);
return Ok(());
}
if args.update_persona {
let feed_path = args
.feed_path
.as_ref()
.ok_or_else(|| E::msg("--feed-path is required when --update-persona is set"))?;
let persona_file = args
.persona_file
.as_ref()
.ok_or_else(|| E::msg("--persona-file is required when --update-persona is set"))?;
let (feed_excerpt, sources) = read_feed_texts(feed_path)?;
let prompt_text = build_persona_update_prompt(active_persona.as_ref(), &feed_excerpt);
let mut tokens = tokenizer
.encode(prompt_text.as_str(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let mut index_pos = 0;
let mut out = String::new();
let start = std::time::Instant::now();
let mut last_report = start;
let mut generated: usize = 0;
eprintln!(
"[persona update] Generating persona JSON (max {} tokens)...",
args.sample_len
);
io::stderr().flush().ok();
for index in 0..args.sample_len {
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, context_index, &mut cache)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&tokens[start_at..],
)?
};
index_pos += ctxt.len();
let logits = match &eos_token_id {
Some(model::LlamaEosToks::Single(eos_id)) if generated < args.min_tokens => {
let mut data = logits.to_vec1::<f32>()?;
let i = *eos_id as usize;
if i < data.len() {
data[i] = f32::NEG_INFINITY;
}
Tensor::new(&data[..], &device)?
}
Some(model::LlamaEosToks::Multiple(ids)) if generated < args.min_tokens => {
let mut data = logits.to_vec1::<f32>()?;
for id in ids {
let i = *id as usize;
if i < data.len() {
data[i] = f32::NEG_INFINITY;
}
}
Tensor::new(&data[..], &device)?
}
_ => logits,
};
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
generated += 1;
let now = std::time::Instant::now();
if now.duration_since(last_report).as_millis() >= 750 {
let elapsed = now.duration_since(start).as_secs_f64();
let rate = if elapsed > 0.0 {
generated as f64 / elapsed
} else {
0.0
};
let rem = args.sample_len.saturating_sub(generated);
let eta = if rate > 0.0 {
(rem as f64 / rate) as u64
} else {
0
};
let mm = eta / 60;
let ss = eta % 60;
eprint!(
"\r[persona update] {} / {} tokens ({:.1} tok/s, ETA {:02}:{:02})",
generated, args.sample_len, rate, mm, ss
);
io::stderr().flush().ok();
last_report = now;
}
match eos_token_id {
Some(model::LlamaEosToks::Single(eos_tok_id))
if next_token == eos_tok_id && generated >= args.min_tokens =>
{
break;
}
Some(model::LlamaEosToks::Multiple(ref eos_ids))
if eos_ids.contains(&next_token) && generated >= args.min_tokens =>
{
break;
}
_ => (),
}
if let Some(t) = tokenizer.next_token(next_token)? {
out.push_str(&t);
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
out.push_str(&rest);
}
eprintln!("\n[persona update] Done. Generated {} token(s).", generated);
io::stderr().flush().ok();
let mut persona: Persona = match serde_json::from_str(out.trim()) {
Ok(p) => p,
Err(_) => Persona {
name: None,
description: out.trim().to_string(),
system_prompt: out.trim().to_string(),
updated_at: None,
sources: vec![],
},
};
persona.sources = sources;
if persona.updated_at.is_none() {
if let Ok(dur) = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
persona.updated_at = Some(format!("{}", dur.as_secs()));
}
}
let json = serde_json::to_string_pretty(&persona)?;
fs::write(persona_file, json)?;
println!("Persona updated and saved to {}", persona_file);
return Ok(());
}
if args.chat {
if args.auto_train && version_paths.is_none() {
eprintln!(
"warning: --auto-train set but no --version-dir/--version provided; memory will not be saved."
);
}
let mut history: Vec<(String, String)> = Vec::new();
let stdin = io::stdin();
println!("Chat started. Type 'exit' or Ctrl-D to quit.\n");
loop {
print!("You> ");
std::io::stdout().flush()?;
let mut user = String::new();
if stdin.read_line(&mut user).is_err() {
break;
}
let user = user.trim().to_string();
if user.is_empty() || user.eq_ignore_ascii_case("exit") {
break;
}
let mut prefix = String::new();
if let Some(p) = active_persona.as_ref() {
prefix.push_str(&p.to_system_prefix());
}
if let Some(vp) = &version_paths {
if let Ok(Some(mem)) = read_memory_excerpt(&vp.memory_txt, 20_000) {
prefix.push_str("Long-term memory (use to stay consistent):\n");
prefix.push_str(&mem);
prefix.push_str("\n\n");
}
}
let mut convo = String::new();
let start = history.len().saturating_sub(args.history_max_turns);
for (u, a) in history.iter().skip(start) {
convo.push_str("User: ");
convo.push_str(u);
convo.push_str("\n");
convo.push_str("Assistant: ");
convo.push_str(a);
convo.push_str("\n");
}
let final_prompt = format!("{}{}User: {}\nAssistant: ", prefix, convo, user);
let mut cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let mut tokens = Tokenizer::from_file(tokenizer_filename.clone())
.map_err(E::msg)?
.encode(final_prompt.as_str(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tok_stream = {
let t = Tokenizer::from_file(tokenizer_filename.clone()).map_err(E::msg)?;
candle_examples::token_output_stream::TokenOutputStream::new(t)
};
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let mut out = String::new();
let mut index_pos = 0;
let start = std::time::Instant::now();
let mut last_report = start;
let mut generated: usize = 0;
eprintln!(
"[chat] Generating response (max {} tokens)...",
args.sample_len
);
io::stderr().flush().ok();
print!("Assistant> ");
std::io::stdout().flush()?;
for index in 0..args.sample_len {
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, context_index, &mut cache)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&tokens[start_at..],
)?
};
index_pos += ctxt.len();
let logits = match &eos_token_id {
Some(model::LlamaEosToks::Single(eos_id)) if generated < args.min_tokens => {
let mut data = logits.to_vec1::<f32>()?;
let i = *eos_id as usize;
if i < data.len() {
data[i] = f32::NEG_INFINITY;
}
Tensor::new(&data[..], &device)?
}
Some(model::LlamaEosToks::Multiple(ids)) if generated < args.min_tokens => {
let mut data = logits.to_vec1::<f32>()?;
for id in ids {
let i = *id as usize;
if i < data.len() {
data[i] = f32::NEG_INFINITY;
}
}
Tensor::new(&data[..], &device)?
}
_ => logits,
};
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
generated += 1;
let now = std::time::Instant::now();
if now.duration_since(last_report).as_millis() >= 750 {
let elapsed = now.duration_since(start).as_secs_f64();
let rate = if elapsed > 0.0 {
generated as f64 / elapsed
} else {
0.0
};
let rem = args.sample_len.saturating_sub(generated);
let eta = if rate > 0.0 {
(rem as f64 / rate) as u64
} else {
0
};
let mm = eta / 60;
let ss = eta % 60;
eprintln!(
"[chat] {} / {} tokens ({:.1} tok/s, ETA {:02}:{:02})",
generated, args.sample_len, rate, mm, ss
);
io::stderr().flush().ok();
last_report = now;
}
match eos_token_id {
Some(model::LlamaEosToks::Single(eos_tok_id))
if next_token == eos_tok_id && generated >= args.min_tokens =>
{
break;
}
Some(model::LlamaEosToks::Multiple(ref eos_ids))
if eos_ids.contains(&next_token) && generated >= args.min_tokens =>
{
break;
}
_ => (),
}
if let Some(t) = tok_stream.next_token(next_token)? {
out.push_str(&t);
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tok_stream.decode_rest().map_err(E::msg)? {
out.push_str(&rest);
print!("{rest}");
}
eprintln!("\n[chat] Done. Generated {} token(s).", generated);
io::stderr().flush().ok();
println!("");
if args.auto_train {
if let Some(vp) = &version_paths {
append_memory(&vp.memory_txt, &user, out.trim())?;
}
}
history.push((user, out.trim().to_string()));
}
return Ok(());
}
let base_prompt = args
.prompt
.as_ref()
.map_or(DEFAULT_PROMPT.to_string(), |p| p.clone());
let mut final_prompt = String::new();
if let Some(p) = active_persona.as_ref() {
final_prompt.push_str(&p.to_system_prefix());
}
if let Some(vp) = &version_paths {
if let Ok(Some(mem)) = read_memory_excerpt(&vp.memory_txt, 20_000) {
final_prompt.push_str("Long-term memory (use to stay consistent):\n");
final_prompt.push_str(&mem);
final_prompt.push_str("\n\n");
}
}
final_prompt.push_str(&base_prompt);
if !final_prompt.trim_end().ends_with("Assistant:") {
final_prompt.push_str("\nAssistant: ");
}
let mut tokens = Tokenizer::from_file(tokenizer_filename.clone())
.map_err(E::msg)?
.encode(final_prompt.as_str(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(
Tokenizer::from_file(tokenizer_filename.clone()).map_err(E::msg)?,
);
println!("starting the inference loop");
print!("{final_prompt}");
std::io::stdout().flush()?;
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let mut start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;
let mut last_report = std::time::Instant::now();
eprintln!("[gen] Generating (max {} tokens)...", args.sample_len);
io::stderr().flush().ok();
for index in 0..args.sample_len {
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
if index == 1 {
start_gen = std::time::Instant::now()
}
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, context_index, &mut cache)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&tokens[start_at..],
)?
};
index_pos += ctxt.len();
let logits = match &eos_token_id {
Some(model::LlamaEosToks::Single(eos_id)) if token_generated < args.min_tokens => {
let mut data = logits.to_vec1::<f32>()?;
let i = *eos_id as usize;
if i < data.len() {
data[i] = f32::NEG_INFINITY;
}
Tensor::new(&data[..], &device)?
}
Some(model::LlamaEosToks::Multiple(ids)) if token_generated < args.min_tokens => {
let mut data = logits.to_vec1::<f32>()?;
for id in ids {
let i = *id as usize;
if i < data.len() {
data[i] = f32::NEG_INFINITY;
}
}
Tensor::new(&data[..], &device)?
}
_ => logits,
};
let next_token = logits_processor.sample(&logits)?;
token_generated += 1;
tokens.push(next_token);
let now = std::time::Instant::now();
if now.duration_since(last_report).as_millis() >= 750 {
let elapsed = now.duration_since(start_gen).as_secs_f64();
let rate = if elapsed > 0.0 {
token_generated as f64 / elapsed
} else {
0.0
};
let rem = args.sample_len.saturating_sub(token_generated);
let eta = if rate > 0.0 {
(rem as f64 / rate) as u64
} else {
0
};
let mm = eta / 60;
let ss = eta % 60;
eprintln!(
"[gen] {} / {} tokens ({:.1} tok/s, ETA {:02}:{:02})",
token_generated, args.sample_len, rate, mm, ss
);
io::stderr().flush().ok();
last_report = now;
}
match eos_token_id {
Some(model::LlamaEosToks::Single(eos_tok_id))
if next_token == eos_tok_id && token_generated >= args.min_tokens =>
{
break;
}
Some(model::LlamaEosToks::Multiple(ref eos_ids))
if eos_ids.contains(&next_token) && token_generated >= args.min_tokens =>
{
break;
}
_ => (),
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
eprintln!("\n[gen] Done. Generated {} token(s).", token_generated);
io::stderr().flush().ok();
let dt = start_gen.elapsed();
println!(
"\n\n{} tokens generated ({} token/s)\n",
token_generated,
(token_generated - 1) as f64 / dt.as_secs_f64(),
);
Ok(())
}
/// Pulls the generated text out of a llama.cpp server JSON response.
///
/// Lookup order (first string found wins):
/// 1. `choices[0].text`
/// 2. `choices[0].message.content`
/// 3. top-level `content`, then `completion`, then `text`
///
/// Returns `None` when none of these holds a JSON string.
fn extract_text_from_llamacpp_response(v: &serde_json::Value) -> Option<String> {
    let first_choice = v
        .get("choices")
        .and_then(|c| c.as_array())
        .and_then(|choices| choices.first());
    if let Some(choice) = first_choice {
        let from_choice = choice.get("text").and_then(|t| t.as_str()).or_else(|| {
            choice
                .get("message")
                .and_then(|m| m.get("content"))
                .and_then(|t| t.as_str())
        });
        if let Some(text) = from_choice {
            return Some(text.to_string());
        }
    }
    ["content", "completion", "text"]
        .iter()
        .find_map(|key| v.get(*key).and_then(|t| t.as_str()))
        .map(str::to_string)
}
fn llamacpp_complete(
llama_url: &str,
model: Option<&str>,
prompt: &str,
args: &Args,
) -> Result<String> {
let client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(120))
.build()
.map_err(E::msg)?;
let base = llama_url.trim_end_matches('/');
let stops: Vec<&str> = vec!["</s>", "\nUser:", "User:"];
let mut body = serde_json::Map::new();
if let Some(m) = model {
body.insert(
"model".to_string(),
serde_json::Value::String(m.to_string()),
);
}
body.insert(
"prompt".to_string(),
serde_json::Value::String(prompt.to_string()),
);
body.insert(
"max_tokens".to_string(),
serde_json::Value::Number((args.sample_len as u64).into()),
);
body.insert(
"temperature".to_string(),
serde_json::Value::Number(
serde_json::Number::from_f64(args.temperature)
.unwrap_or_else(|| serde_json::Number::from_f64(0.8).unwrap()),
),
);
if let Some(tp) = args.top_p {
body.insert(
"top_p".to_string(),
serde_json::Value::Number(serde_json::Number::from_f64(tp).unwrap()),
);
}
if let Some(tk) = args.top_k {
body.insert(
"top_k".to_string(),
serde_json::Value::Number((tk as u64).into()),
);
}
body.insert(
"stop".to_string(),
serde_json::Value::Array(
stops
.iter()
.map(|s| serde_json::Value::String(s.to_string()))
.collect(),
),
);
body.insert("stream".to_string(), serde_json::Value::Bool(false));
body.insert(
"repeat_penalty".to_string(),
serde_json::Value::Number(
serde_json::Number::from_f64(args.repeat_penalty as f64).unwrap(),
),
);
let url1 = format!("{}/v1/completions", base);
if let Ok(resp) = client
.post(&url1)
.json(&serde_json::Value::Object(body.clone()))
.send()
{
if resp.status().is_success() {
if let Ok(val) = resp.json::<serde_json::Value>() {
if let Some(text) = extract_text_from_llamacpp_response(&val) {
return Ok(text);
}
}
}
}
let mut body2 = serde_json::Map::new();
if let Some(m) = model {
body2.insert(
"model".to_string(),
serde_json::Value::String(m.to_string()),
);
}
body2.insert(
"prompt".to_string(),
serde_json::Value::String(prompt.to_string()),
);
body2.insert(
"n_predict".to_string(),
serde_json::Value::Number((args.sample_len as u64).into()),
);
body2.insert(
"temperature".to_string(),
serde_json::Value::Number(
serde_json::Number::from_f64(args.temperature)
.unwrap_or_else(|| serde_json::Number::from_f64(0.8).unwrap()),
),
);
if let Some(tp) = args.top_p {
body2.insert(
"top_p".to_string(),
serde_json::Value::Number(serde_json::Number::from_f64(tp).unwrap()),
);
}
if let Some(tk) = args.top_k {
body2.insert(
"top_k".to_string(),
serde_json::Value::Number((tk as u64).into()),
);
}
body2.insert(
"repeat_penalty".to_string(),
serde_json::Value::Number(
serde_json::Number::from_f64(args.repeat_penalty as f64).unwrap(),
),
);
body2.insert("cache_prompt".to_string(), serde_json::Value::Bool(true));
body2.insert("stream".to_string(), serde_json::Value::Bool(false));
body2.insert(
"stop".to_string(),
serde_json::Value::Array(
stops
.iter()
.map(|s| serde_json::Value::String(s.to_string()))
.collect(),
),
);
let url2 = format!("{}/completion", base);
let resp2 = match client
.post(&url2)
.json(&serde_json::Value::Object(body2))
.send()
{
Ok(r) => r,
Err(e) => {
let extra = if e.is_connect() {
"Connection refused or unreachable host. Is the llama.cpp server running?"
} else if e.is_timeout() {
"Request timed out. The server may be busy or the model is slow; consider increasing timeout."
} else {
"Request failed."
};
bail!(format!(
"Could not contact llama.cpp at {} (endpoint: {}). {}\nHints:\n- Start the server, e.g.: ./server -m /path/to/model.gguf --host 127.0.0.1 --port 8080\n- Or point to a different URL via LLAMA_CPP_URL or --llama-url\n- For a quick check without a server: use --dry-run\n- Alternatively, switch to the Candle backend: --engine candle",
llama_url, url2, extra
));
}
};
if !resp2.status().is_success() {
bail!(format!(
"llama.cpp server error {} on {}. If using multiple models, ensure --model-id matches a loaded GGUF name.",
resp2.status(),
url2
));
}
let val = resp2.json::<serde_json::Value>()?;
if let Some(text) = extract_text_from_llamacpp_response(&val) {
return Ok(text);
}
bail!(format!("Could not parse llama.cpp response: {}", val));
}
/// Runs the requested action using a llama.cpp HTTP server as the backend.
///
/// The server URL is taken from the `LLAMA_CPP_URL` environment variable when
/// it is set, otherwise from `args.llama_url`.
///
/// The active persona is the one stored in the selected version directory
/// (`--version-dir` + `--version`) if it can be loaded, otherwise the one in
/// `--persona-file`, otherwise none.
///
/// Exactly one mode runs, checked in this order:
///
/// 1. `save_as_version`: writes the active persona (or a default one) and, if
///    missing, an empty memory file into a new version directory.
/// 2. `update_persona`: asks the model to rewrite the persona from the feed
///    texts and saves the result to the persona file. A reply that is not
///    valid persona JSON is stored verbatim as both description and system
///    prompt.
/// 3. `chat`: interactive loop on stdin/stdout. Ends on `exit` or end of
///    input (Ctrl-D); blank lines are ignored and the prompt is shown again.
///    With `auto_train` and a version directory, each exchange is appended to
///    the memory file.
/// 4. Otherwise: a single completion of `args.prompt` (or `DEFAULT_PROMPT`),
///    printed together with the full prompt that was sent.
///
/// # Errors
///
/// Returns an error when a required option is missing for the chosen mode,
/// when a file cannot be written, or when the llama.cpp request fails.
fn run_with_llamacpp(args: &Args) -> Result<()> {
    let llama_url = std::env::var("LLAMA_CPP_URL").unwrap_or_else(|_| args.llama_url.clone());

    // A missing or unparsable persona file simply means "no persona".
    let existing_persona: Option<Persona> = args
        .persona_file
        .as_ref()
        .and_then(|pf| load_persona_from_file(Path::new(pf)));

    let version_paths: Option<VersionPaths> = match (&args.version_dir, &args.version) {
        (Some(dir), Some(name)) => {
            let vp = resolve_version_paths(Path::new(dir), name);
            if let Err(e) = ensure_dir(&vp.root) {
                eprintln!("warning: could not create version dir: {e}");
            }
            Some(vp)
        }
        _ => None,
    };

    // The version's own persona wins over the standalone persona file.
    let active_persona: Option<Persona> = version_paths
        .as_ref()
        .and_then(|vp| load_persona_from_file(&vp.persona_json))
        .or(existing_persona);

    if let Some(name) = &args.save_as_version {
        let version_dir = args
            .version_dir
            .as_ref()
            .ok_or_else(|| E::msg("--version-dir is required with --save-as-version"))?;
        let vp = resolve_version_paths(Path::new(version_dir), name);
        ensure_dir(&vp.root)?;
        let persona_to_save: Persona = active_persona.unwrap_or_default();
        let json = serde_json::to_string_pretty(&persona_to_save)?;
        fs::write(&vp.persona_json, json)?;
        // Never overwrite memory that already exists for this version.
        if !vp.memory_txt.exists() {
            fs::write(&vp.memory_txt, "")?;
        }
        println!(
            "Saved version '{}' into {}",
            name,
            vp.root.to_string_lossy()
        );
        return Ok(());
    }

    if args.update_persona {
        let feed_path = args
            .feed_path
            .as_ref()
            .ok_or_else(|| E::msg("--feed-path is required when --update-persona is set"))?;
        let persona_file = args
            .persona_file
            .as_ref()
            .ok_or_else(|| E::msg("--persona-file is required when --update-persona is set"))?;
        let (feed_excerpt, sources) = read_feed_texts(feed_path)?;
        let prompt_text = build_persona_update_prompt(active_persona.as_ref(), &feed_excerpt);
        eprintln!(
            "[persona update] Requesting completion from llama.cpp at {} ...",
            llama_url
        );
        io::stderr().flush().ok();
        let out = llamacpp_complete(&llama_url, args.model_id.as_deref(), &prompt_text, args)?;
        let reply = out.trim();
        // Fall back to using the raw reply when it is not valid persona JSON.
        let mut persona: Persona = serde_json::from_str(reply).unwrap_or_else(|_| Persona {
            name: None,
            description: reply.to_string(),
            system_prompt: reply.to_string(),
            updated_at: None,
            sources: vec![],
        });
        // Sources always come from the feed that was read, not from the reply.
        let mut unique_sources = sources;
        unique_sources.sort();
        unique_sources.dedup();
        persona.sources = unique_sources;
        if persona.updated_at.is_none() {
            if let Ok(dur) = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
                persona.updated_at = Some(dur.as_secs().to_string());
            }
        }
        let json = serde_json::to_string_pretty(&persona)?;
        fs::write(persona_file, json)?;
        println!("Persona updated and saved to {}", persona_file);
        return Ok(());
    }

    // Persona preamble plus long-term memory excerpt, shared by the chat and
    // single-shot modes. Memory read failures are ignored (best effort).
    let build_prefix = || -> String {
        let mut prefix = String::new();
        if let Some(p) = active_persona.as_ref() {
            prefix.push_str(&p.to_system_prefix());
        }
        if let Some(vp) = &version_paths {
            if let Ok(Some(mem)) = read_memory_excerpt(&vp.memory_txt, 20_000) {
                prefix.push_str("Long-term memory (use to stay consistent):\n");
                prefix.push_str(&mem);
                prefix.push_str("\n\n");
            }
        }
        prefix
    };

    if args.chat {
        if args.auto_train && version_paths.is_none() {
            eprintln!(
                "warning: --auto-train set but no --version-dir/--version provided; memory will not be saved."
            );
        }
        let mut history: Vec<(String, String)> = Vec::new();
        let stdin = io::stdin();
        println!("Chat started. Type 'exit' or Ctrl-D to quit.\n");
        loop {
            print!("You> ");
            io::stdout().flush()?;
            let mut line = String::new();
            match stdin.read_line(&mut line) {
                // Zero bytes read means end of input (Ctrl-D); a read error
                // also ends the session.
                Ok(0) | Err(_) => break,
                Ok(_) => {}
            }
            let user = line.trim().to_string();
            if user.is_empty() {
                // An accidental blank line should not end the conversation.
                continue;
            }
            if user.eq_ignore_ascii_case("exit") {
                break;
            }

            // Rebuilt every turn because the memory file may have grown.
            let prefix = build_prefix();

            // Only the most recent `history_max_turns` exchanges are replayed.
            let mut convo = String::new();
            let start = history.len().saturating_sub(args.history_max_turns);
            for (u, a) in history.iter().skip(start) {
                convo.push_str("User: ");
                convo.push_str(u);
                convo.push('\n');
                convo.push_str("Assistant: ");
                convo.push_str(a);
                convo.push('\n');
            }
            let final_prompt = format!("{}{}User: {}\nAssistant: ", prefix, convo, user);

            print!("Assistant> ");
            io::stdout().flush()?;
            let out =
                llamacpp_complete(&llama_url, args.model_id.as_deref(), &final_prompt, args)?;
            let reply = out.trim();
            println!("{}", reply);
            if args.auto_train {
                if let Some(vp) = &version_paths {
                    append_memory(&vp.memory_txt, &user, reply)?;
                }
            }
            history.push((user, reply.to_string()));
        }
        return Ok(());
    }

    // Single-shot completion.
    let base_prompt = args.prompt.as_deref().unwrap_or(DEFAULT_PROMPT);
    let mut final_prompt = build_prefix();
    final_prompt.push_str(base_prompt);
    // Make sure the prompt ends with the assistant cue the stop sequences expect.
    if !final_prompt.trim_end().ends_with("Assistant:") {
        final_prompt.push_str("\nAssistant: ");
    }
    eprintln!(
        "[gen] Requesting completion from llama.cpp at {} ...",
        llama_url
    );
    io::stderr().flush().ok();
    let out = llamacpp_complete(&llama_url, args.model_id.as_deref(), &final_prompt, args)?;
    println!("{}{}", final_prompt, out);
    Ok(())
}