use std::io::Write as _;
use std::sync::Arc;

use anyhow::Context as _;
use clap::Parser;
use oxillama_arch::lora::LoadedLora;
use oxillama_runtime::{EngineConfig, InferenceEngine, SamplerConfig};
// Command-line options for the LoRA hot-swap demo. The `///` comments on the
// fields are rendered by clap as the `--help` text for each flag.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Path to the base model to load.
    #[arg(long, short = 'm', value_name = "PATH")]
    model: String,
    /// LoRA adapter applied first; enables the hot-swap demonstration.
    #[arg(long, value_name = "PATH")]
    adapter_a: Option<String>,
    /// LoRA adapter swapped in after adapter A (requires --adapter-a).
    // Without `requires`, passing only --adapter-b was accepted and then
    // silently ignored, because B is only used inside the adapter-A branch.
    #[arg(long, value_name = "PATH", requires = "adapter_a")]
    adapter_b: Option<String>,
    /// Prompt used for every generation run.
    #[arg(long, short = 'p', default_value = "Translate: hello world")]
    prompt: String,
    /// Maximum number of tokens to generate per run.
    #[arg(long, short = 'n', default_value_t = 64)]
    max_tokens: usize,
    /// Scale factor applied to each pushed LoRA adapter.
    #[arg(long, default_value_t = 1.0)]
    lora_scale: f32,
    /// Number of inference threads.
    #[arg(long, default_value_t = 4)]
    threads: usize,
}
fn generate_and_print(
engine: &mut InferenceEngine,
prompt: &str,
max_tokens: usize,
label: &str,
) -> anyhow::Result<()> {
eprintln!("\n[{label}] Prompt: {prompt}");
let stdout = std::io::stdout();
let mut locked = stdout.lock();
engine.generate_with_config(prompt, max_tokens, SamplerConfig::default(), |tok| {
let _ = locked.write_all(tok.as_bytes());
let _ = locked.flush();
})?;
drop(locked);
eprintln!(); Ok(())
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
eprintln!("Loading base model: {}", args.model);
let config = EngineConfig {
model_path: args.model.clone(),
num_threads: args.threads,
..EngineConfig::default()
};
let mut engine = InferenceEngine::new(config);
engine.load_model()?;
eprintln!(" Base model loaded. is_loaded={}", engine.is_loaded());
generate_and_print(&mut engine, &args.prompt, args.max_tokens, "base (no LoRA)")?;
if let Some(ref path_a) = args.adapter_a {
eprintln!("\nLoading LoRA adapter A: {path_a}");
let lora_a = LoadedLora::load(path_a)?;
let lora_a_arc = Arc::new(lora_a);
eprintln!(
" Adapter A: rank={} alpha={}",
lora_a_arc.rank, lora_a_arc.alpha
);
engine.push_lora(Arc::clone(&lora_a_arc), args.lora_scale);
eprintln!(" Stack depth after push: {}", engine.lora_stack().len());
engine.apply_lora_stack()?;
eprintln!(" Adapter A applied to model weights.");
generate_and_print(&mut engine, &args.prompt, args.max_tokens, "adapter A")?;
if let Some(ref path_b) = args.adapter_b {
eprintln!("\nHot-swapping to LoRA adapter B: {path_b}");
let _popped = engine.pop_lora();
eprintln!(
" Adapter A popped. Stack depth: {}",
engine.lora_stack().len()
);
let lora_b = LoadedLora::load(path_b)?;
let lora_b_arc = Arc::new(lora_b);
eprintln!(
" Adapter B: rank={} alpha={}",
lora_b_arc.rank, lora_b_arc.alpha
);
engine.push_lora(lora_b_arc, args.lora_scale);
engine.apply_lora_stack()?;
eprintln!(" Adapter B applied to model weights.");
generate_and_print(&mut engine, &args.prompt, args.max_tokens, "adapter B")?;
}
engine.clear_loras();
eprintln!(
"\nAll adapters cleared. Stack depth: {}",
engine.lora_stack().len()
);
engine.reset();
generate_and_print(
&mut engine,
&args.prompt,
args.max_tokens,
"base (adapters cleared)",
)?;
} else {
eprintln!("\n(No adapter paths provided — skipping LoRA demo.)");
eprintln!("Use --adapter-a <path.gguf> to enable the hot-swap demonstration.");
}
eprintln!("\nDone.");
Ok(())
}