use anyhow::{Context, Result};
use clap::Parser;
use multiscreen_rs::prelude::*;
use sentencepiece_rs::SentencePieceProcessor;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
/// Tokenizer wrapper around a loaded [`SentencePieceProcessor`].
///
/// The methods in the accompanying `impl` convert between the processor's
/// ids and the `u32` ids used by the rest of this file.
struct SpTokenizer {
/// The underlying SentencePiece processor that every method delegates to.
proc: SentencePieceProcessor,
}
impl SpTokenizer {
    /// Opens the SentencePiece model stored at `path`.
    ///
    /// # Errors
    ///
    /// Returns the processor's error, annotated with the offending path, when
    /// the model cannot be opened.
    fn load(path: &Path) -> Result<Self> {
        let proc = SentencePieceProcessor::open(path)
            .with_context(|| format!("failed to load {}", path.display()))?;
        Ok(Self { proc })
    }

    /// Encodes `text` into token ids.
    ///
    /// An encoding failure is not reported: it produces an empty vector, so
    /// callers must treat an empty result as "could not tokenize".
    fn encode(&self, text: &str) -> Vec<u32> {
        let ids = self.proc.encode_to_ids(text).unwrap_or_default();
        ids.into_iter().map(|id| id as u32).collect()
    }

    /// Returns the end-of-sequence token id, or `None` when the processor
    /// does not report one.
    fn eos_id(&self) -> Option<u32> {
        let eos = self.proc.eos_id();
        eos.map(|id| id as u32)
    }

    /// Returns an owned copy of the piece for `id`, or an empty string when
    /// the processor has no piece for that id.
    fn id_to_piece(&self, id: u32) -> String {
        let piece = self.proc.id_to_piece(id as usize);
        piece.map(|s| s.to_owned()).unwrap_or_default()
    }
}
/// Removes a single leading U+2581 (`▁`) marker from `piece`, if present.
///
/// Only the first character is considered; markers elsewhere in the piece are
/// left untouched. A piece without the leading marker is returned unchanged.
///
/// NOTE(review): SentencePiece conventionally uses U+2581 to stand for
/// whitespace, so stripping it also drops the space between words — confirm
/// this is the intended rendering.
fn piece_to_text(piece: &str) -> &str {
    match piece.strip_prefix('\u{2581}') {
        Some(rest) => rest,
        None => piece,
    }
}
/// Locates the `tokenizer.model` file to use for a run.
///
/// Candidates are probed in order and the first existing path wins:
/// 1. `<run_dir>/tokenizer.model`
/// 2. `examples/data/tokenizer.model` (relative to the working directory)
///
/// # Errors
///
/// Fails when neither candidate exists.
fn find_tokenizer(run_dir: &Path) -> Result<PathBuf> {
    let candidates = [
        run_dir.join("tokenizer.model"),
        PathBuf::from("examples/data/tokenizer.model"),
    ];
    // `find` is lazy, so the fallback is only probed when the run directory
    // has no tokenizer of its own.
    if let Some(found) = candidates.into_iter().find(|candidate| candidate.exists()) {
        return Ok(found);
    }
    anyhow::bail!(
        "tokenizer.model not found in {} or examples/data/",
        run_dir.display()
    )
}
/// Returns a copy of `logits` with the repetition penalty applied.
///
/// When `repetition_penalty > 1.0`, every distinct id in `recent_tokens` that
/// is a valid index has its score pushed towards "less likely": positive
/// scores are divided by the penalty and non-positive scores are multiplied by
/// it. Ids outside the logits range are ignored, and a penalty `<= 1.0` leaves
/// the scores untouched.
fn penalized_logits(logits: &[f32], repetition_penalty: f32, recent_tokens: &[u32]) -> Vec<f32> {
    let mut scores = logits.to_vec();
    if repetition_penalty > 1.0 {
        for (pos, &tok) in recent_tokens.iter().enumerate() {
            // Penalise each distinct token once, no matter how often it
            // occurs in the window, so greedy and sampled decoding agree.
            if recent_tokens[..pos].contains(&tok) {
                continue;
            }
            if let Some(score) = scores.get_mut(tok as usize) {
                // Dividing a negative logit would *raise* it, so the
                // direction of the adjustment depends on the sign.
                if *score > 0.0 {
                    *score /= repetition_penalty;
                } else {
                    *score *= repetition_penalty;
                }
            }
        }
    }
    scores
}

/// Picks the next token id from `logits`.
///
/// * `temperature <= 0.0` selects greedily: the highest penalised logit wins,
///   ties resolve to the lowest index, and NaN entries are never chosen.
/// * Otherwise the penalised logits are divided by `temperature`, the `top_k`
///   highest are kept (`top_k == 0` keeps all of them), and a token is drawn
///   from the softmax over the kept candidates.
///
/// `repetition_penalty` (effective only when `> 1.0`) is applied to the ids in
/// `recent_tokens` in the same sign-aware way in both modes.
///
/// Returns `0` when `logits` is empty. If no candidate has a finite scaled
/// score, the highest-ranked candidate is returned without sampling.
fn sample_token(
    logits: &[f32],
    temperature: f32,
    top_k: usize,
    repetition_penalty: f32,
    recent_tokens: &[u32],
) -> u32 {
    use rand::Rng;

    let scores = penalized_logits(logits, repetition_penalty, recent_tokens);

    if temperature <= 0.0 {
        // Strict `>` keeps the first maximum and skips NaN scores.
        let mut best = (0usize, f32::NEG_INFINITY);
        for (i, &score) in scores.iter().enumerate() {
            if score > best.1 {
                best = (i, score);
            }
        }
        return best.0 as u32;
    }

    let mut candidates: Vec<(usize, f32)> = scores
        .iter()
        .enumerate()
        .map(|(idx, &score)| {
            let scaled = score / temperature;
            // NaN is mapped to -inf so it can never be sampled and the
            // comparator below is a genuine total order.
            (idx, if scaled.is_nan() { f32::NEG_INFINITY } else { scaled })
        })
        .collect();
    // Stable sort, descending by score: ties keep ascending index order.
    candidates.sort_by(|a, b| b.1.total_cmp(&a.1));
    if top_k > 0 {
        candidates.truncate(top_k);
    }

    let Some(&(top_idx, max_val)) = candidates.first() else {
        return 0;
    };
    if !max_val.is_finite() {
        // The softmax below would be NaN everywhere; fall back to the top candidate.
        return top_idx as u32;
    }

    // Softmax over the kept candidates, shifted by the maximum for stability.
    let weights: Vec<(usize, f32)> = candidates
        .iter()
        .map(|&(idx, score)| (idx, (score - max_val).exp()))
        .collect();
    let total: f32 = weights.iter().map(|&(_, w)| w).sum();

    let mut rng = rand::thread_rng();
    let r: f32 = rng.r#gen();
    let mut cumulative = 0.0f32;
    for &(idx, w) in &weights {
        cumulative += w / total;
        if r <= cumulative {
            return idx as u32;
        }
    }
    // Rounding can leave the cumulative sum marginally below `r`.
    weights.last().map_or(0, |&(idx, _)| idx as u32)
}
// Command-line options for the chat example.
//
// Plain `//` comments are used on purpose: clap's derive turns `///` doc
// comments into help text, which would change the program's `--help` output.
#[derive(Parser)]
#[command(
name = "chat_with_tokenizer",
about = "Chat with a trained Multiscreen model using streaming output"
)]
struct Args {
// Run directory: base of the default checkpoint path and the first place
// searched for `tokenizer.model`.
#[arg(long, default_value = "runs/my-model")]
run_dir: PathBuf,
// Explicit checkpoint file; when absent, `main` falls back to
// `<run_dir>/checkpoints/latest.mpk`.
#[arg(long)]
checkpoint: Option<PathBuf>,
// One-shot prompt. When given, a single completion is generated and the
// program exits; otherwise `main` enters interactive mode.
#[arg(long)]
prompt: Option<String>,
// Upper bound on the number of tokens generated per reply.
#[arg(long, default_value_t = 128)]
max_new_tokens: usize,
// Sampling temperature, forwarded to `sample_token` (values <= 0 select
// its greedy branch).
#[arg(long, default_value_t = 0.8)]
temperature: f32,
// Number of highest-scoring candidates kept before sampling.
#[arg(long, default_value_t = 40)]
top_k: usize,
// Penalty applied to recently generated tokens; `sample_token` only
// applies it when the value is > 1.0.
#[arg(long, default_value_t = 1.2)]
repetition_penalty: f32,
// Replacement for the built-in system prompt text.
#[arg(long)]
system_prompt: Option<String>,
}
/// Entry point: loads the tokenizer and model, then either answers a single
/// `--prompt` or runs an interactive read–generate loop on stdin/stdout.
///
/// Generated text is streamed to stdout piece by piece; diagnostics go to
/// stderr.
///
/// # Errors
///
/// Fails when the checkpoint or tokenizer cannot be found or loaded, when a
/// prompt tokenizes to nothing, when the model or the logits conversion
/// reports an error, or on stdin/stdout I/O errors.
fn main() -> Result<()> {
    let args = Args::parse();

    let ckpt = args
        .checkpoint
        .unwrap_or_else(|| args.run_dir.join("checkpoints/latest.mpk"));
    anyhow::ensure!(ckpt.exists(), "checkpoint not found: {}", ckpt.display());

    let tok_path = find_tokenizer(&args.run_dir)?;
    let sp = SpTokenizer::load(&tok_path)?;
    eprintln!("tokenizer: {}", tok_path.display());

    let model = ChatModel::load(&ckpt)?;
    eprintln!(
        "loaded model: {} params",
        model.config().estimated_parameter_count()
    );
    let seq_len = model.config().seq_len;
    let vocab_size = model.config().vocab_size;
    eprintln!("seq_len={seq_len}, vocab_size={vocab_size}");

    let eos_id = sp.eos_id();
    let pad_token_id: u32 = 0;

    let system_text = args.system_prompt.as_deref().unwrap_or(
        "You are หมิว, a shy but sharp ม.6 student who secretly likes ธันวา. \
         Speak Thai mixed with English naturally. Be caring indirectly, thoughtful, and concise.",
    );
    let system_line = format!("system: {system_text}");
    eprintln!(
        "sampling: temperature={:.2}, top_k={}, repetition_penalty={:.2}",
        args.temperature, args.top_k, args.repetition_penalty
    );
    eprintln!();

    // Generates up to `max_new_tokens` tokens after `prompt_text`, streaming
    // each piece to stdout, and returns the concatenated generated text.
    let generate = |prompt_text: &str,
                    max_new_tokens: usize,
                    temperature: f32,
                    top_k: usize,
                    repetition_penalty: f32|
     -> Result<String> {
        // How many of the most recently generated tokens are subject to the
        // repetition penalty.
        const RECENT_WINDOW: usize = 16;

        let mut output_ids = sp.encode(prompt_text);
        if output_ids.is_empty() {
            anyhow::bail!("prompt tokenized to empty sequence");
        }
        let prompt_len = output_ids.len();
        let mut full_text = String::new();

        for _ in 0..max_new_tokens {
            let input = context_window(&output_ids, seq_len, pad_token_id);
            let logits_tensor = model.predict_logits(&input)?;
            // Only the logits of the last position are needed to pick the next token.
            let last_logits = logits_tensor
                .slice([0..1, seq_len - 1..seq_len, 0..vocab_size])
                .reshape([vocab_size]);
            // Surface a conversion failure instead of sampling from empty logits.
            let logits_vec: Vec<f32> = last_logits
                .into_data()
                .to_vec()
                .map_err(|e| anyhow::anyhow!("failed to read logits as f32: {e:?}"))?;

            // The penalty window is the tail of the *generated* tokens only;
            // prompt tokens are never penalised.
            let recent_start = prompt_len.max(output_ids.len().saturating_sub(RECENT_WINDOW));
            let next_token = sample_token(
                &logits_vec,
                temperature,
                top_k,
                repetition_penalty,
                &output_ids[recent_start..],
            );
            if Some(next_token) == eos_id {
                break;
            }
            output_ids.push(next_token);

            let piece = sp.id_to_piece(next_token);
            let text = piece_to_text(&piece);
            full_text.push_str(text);
            print!("{text}");
            // Best-effort flush so tokens appear as they are produced.
            io::stdout().flush().ok();
        }
        println!();
        Ok(full_text)
    };

    if let Some(prompt) = &args.prompt {
        // A prompt that already carries a role prefix is used verbatim;
        // anything else is wrapped in the chat template.
        let full_prompt = if prompt.starts_with("system:") || prompt.starts_with("user:") {
            prompt.clone()
        } else {
            format!("{system_line}\nuser: {prompt}\nassistant:")
        };
        let _ = generate(
            &full_prompt,
            args.max_new_tokens,
            args.temperature,
            args.top_k,
            args.repetition_penalty,
        )?;
        return Ok(());
    }

    eprintln!("interactive mode — type a message, press Enter. Empty line to exit.");
    let stdin = io::stdin();
    loop {
        print!("you> ");
        io::stdout().flush()?;

        let mut input = String::new();
        // Zero bytes read means end of input.
        if stdin.read_line(&mut input)? == 0 {
            break;
        }
        let input = input.trim();
        if input.is_empty() {
            break;
        }

        let prompt = format!("{system_line}\nuser: {input}\nassistant:");
        print!("ai> ");
        io::stdout().flush()?;
        let _ = generate(
            &prompt,
            args.max_new_tokens,
            args.temperature,
            args.top_k,
            args.repetition_penalty,
        )?;
    }
    Ok(())
}
/// Builds a window of exactly `seq_len` token ids ending at the end of `tokens`.
///
/// When `tokens` holds at least `seq_len` ids, the last `seq_len` of them are
/// returned. Shorter inputs are left-padded with `pad_token_id` so the real
/// tokens sit at the end of the window.
fn context_window(tokens: &[u32], seq_len: usize, pad_token_id: u32) -> Vec<u32> {
    // Keep at most the trailing `seq_len` tokens.
    let tail = &tokens[tokens.len().saturating_sub(seq_len)..];
    let missing = seq_len - tail.len();
    std::iter::repeat(pad_token_id)
        .take(missing)
        .chain(tail.iter().copied())
        .collect()
}