use mullama::{Context, ContextParams, Model, MullamaError, SamplerParams};
use std::io::{self, Write};
use std::sync::Arc;
/// Simple text-generation example for the `mullama` bindings.
///
/// Usage: `<program> <model_path> [prompt]`.
///
/// The example runs in these steps:
/// 1. Load the model at `model_path`.
/// 2. Create a context with `n_ctx = 2048` and `n_batch = 512`.
/// 3. Build a sampler chain (temperature 0.7, top-k 40, top-p 0.9).
/// 4. Tokenize `prompt`. It defaults to a fixed sentence when omitted.
/// 5. Stream up to `MAX_TOKENS` sampled tokens to stdout.
///
/// # Errors
///
/// Returns any [`MullamaError`] raised while doing any of the following:
/// - loading the model,
/// - creating the context or sampler,
/// - tokenizing or detokenizing,
/// - flushing stdout (via the crate's conversion from `io::Error`).
///
/// If the required `model_path` argument is missing, this prints usage and
/// exits the process with status 2 instead of returning.
fn main() -> Result<(), MullamaError> {
    /// Upper bound on the number of tokens sampled after the prompt.
    const MAX_TOKENS: usize = 100;

    let args: Vec<String> = std::env::args().collect();
    // argv[0] is conventionally the program name, but the OS does not
    // guarantee it is present. Fall back to a fixed name instead of panicking.
    let program = args.first().map_or("simple_generation", String::as_str);
    if args.len() < 2 {
        eprintln!("Usage: {} <model_path> [prompt]", program);
        eprintln!(
            "Example: {} path/to/model.gguf \"The future of AI is\"",
            program
        );
        // A missing required argument is a usage error.
        // Report failure to the shell rather than success.
        std::process::exit(2);
    }
    let model_path = &args[1];
    let prompt = args
        .get(2)
        .map(String::as_str)
        .unwrap_or("The future of artificial intelligence is");

    println!("Mullama Simple Generation Example");
    println!("Model: {}", model_path);
    println!("Prompt: \"{}\"", prompt);
    println!("{}", "=".repeat(50));

    println!("Loading model...");
    let model = Arc::new(Model::load(model_path)?);
    println!("Model loaded successfully!");
    println!(" Vocabulary size: {}", model.vocab_size());
    println!(" Context size: {}", model.n_ctx_train());

    let mut ctx_params = ContextParams::default();
    ctx_params.n_ctx = 2048;
    ctx_params.n_batch = 512;
    let mut context = Context::new(Arc::clone(&model), ctx_params)?;

    let mut sampler_params = SamplerParams::default();
    sampler_params.temperature = 0.7;
    sampler_params.top_k = 40;
    sampler_params.top_p = 0.9;
    let mut sampler = sampler_params.build_chain(Arc::clone(&model))?;

    println!("Tokenizing prompt...");
    let tokens = model.tokenize(prompt, true, false)?;
    println!("Prompt tokenized into {} tokens", tokens.len());

    println!("Processing prompt...");
    // NOTE(review): the prompt tokens are never submitted to `context`.
    // The previous code only iterated over them with an empty loop body.
    // Until they are evaluated through the context's decode/eval API,
    // the sampling below does not condition on the prompt.
    // TODO: wire this up.

    println!("Generating text...\n");
    print!("Output: {}", prompt);
    io::stdout().flush()?;

    let mut generated_tokens = 0;
    while generated_tokens < MAX_TOKENS {
        // NOTE(review): the meaning of the second argument (0) is not visible
        // here. It is presumably a logits/batch index; confirm it points at
        // the last evaluated token.
        let next_token = sampler.sample(&mut context, 0);
        // NOTE(review): token id 0 is treated as the end-of-generation marker.
        // It is not necessarily the model's EOS/EOG token. Confirm this
        // against the model's vocabulary.
        if next_token == 0 {
            println!("\n\nGeneration completed (end token reached)");
            break;
        }
        let text = model.token_to_str(next_token, 0, false)?;
        print!("{}", text);
        // Flush per token so output streams instead of appearing at the end.
        io::stdout().flush()?;
        // NOTE(review): the sampled token is not fed back into `context`.
        // As written, each iteration samples from the same state.
        generated_tokens += 1;
    }

    if generated_tokens >= MAX_TOKENS {
        println!("\n\nGeneration completed (max tokens reached)");
    }
    println!("Generated {} tokens", generated_tokens);
    Ok(())
}