use std::io::Write as _;
use std::time::Instant;
use clap::Parser;
use oxillama_runtime::{EngineConfig, SamplerConfig, SpeculativeConfig, SpeculativeEngine};
/// Command-line options for the speculative-decoding demo.
///
/// Each field's doc comment is picked up by clap and shown as that flag's
/// `--help` text.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Path to the target model
    #[arg(long, short = 't')]
    target: String,

    /// Path to the draft model
    #[arg(long, short = 'd')]
    draft: String,

    /// Prompt text handed to the engine
    #[arg(long, short = 'p', default_value = "The meaning of life is")]
    prompt: String,

    /// Token budget handed to the engine for generation
    #[arg(long, short = 'n', default_value_t = 200)]
    max_tokens: usize,

    /// Speculation window, in tokens per round
    #[arg(long, short = 'k', default_value_t = 5)]
    num_speculative: usize,

    /// Seed for the speculative engine; 0 leaves the seed unset
    #[arg(long, default_value_t = 0)]
    seed: u64,

    /// Thread count for the target model
    #[arg(long, default_value_t = 4)]
    target_threads: usize,

    /// Thread count for the draft model
    #[arg(long, default_value_t = 2)]
    draft_threads: usize,

    /// Sampling temperature for the draft model's sampler
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Top-k setting for the draft model's sampler
    #[arg(long, default_value_t = 40)]
    top_k: usize,
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
let draft_sampler = SamplerConfig {
temperature: args.temperature as f32,
top_k: args.top_k,
..SamplerConfig::default()
};
let target_config = EngineConfig {
model_path: args.target.clone(),
num_threads: args.target_threads,
..EngineConfig::default()
};
let draft_config = EngineConfig {
model_path: args.draft.clone(),
num_threads: args.draft_threads,
sampler: draft_sampler,
..EngineConfig::default()
};
let mut spec_config = SpeculativeConfig::new(target_config, draft_config);
spec_config.num_speculative = args.num_speculative;
if args.seed != 0 {
spec_config.seed = Some(args.seed);
}
eprintln!("Loading target model : {}", args.target);
eprintln!("Loading draft model : {}", args.draft);
eprintln!(
"Speculation window : {} tokens/round",
args.num_speculative
);
let mut engine = SpeculativeEngine::new(spec_config)?;
eprintln!("Both models loaded.\n");
eprintln!("Prompt: {}", args.prompt);
eprintln!("---");
let stdout = std::io::stdout();
let mut locked = stdout.lock();
let mut accepted_tokens = 0usize;
let wall_start = Instant::now();
engine.generate(&args.prompt, args.max_tokens, |tok| {
let _ = locked.write_all(tok.as_bytes());
let _ = locked.flush();
accepted_tokens += 1;
})?;
let elapsed = wall_start.elapsed();
drop(locked);
let tps = if elapsed.as_secs_f64() > 0.0 {
accepted_tokens as f64 / elapsed.as_secs_f64()
} else {
0.0
};
eprintln!("\n---");
eprintln!("Accepted tokens : {accepted_tokens}");
eprintln!("Wall time : {:.2}s", elapsed.as_secs_f64());
eprintln!("Throughput : {tps:.1} accepted tok/s");
eprintln!("Draft window : {} tokens/round", args.num_speculative);
eprintln!(
"Note: acceptance rate requires per-round counters \
(not yet exposed by SpeculativeEngine)."
);
Ok(())
}