Skip to main content

batch_inference/
batch_inference.rs

1use litert_lm::{Backend, Engine};
2
3fn main() -> Result<(), Box<dyn std::error::Error>> {
4    // Get model path from command line argument
5    let args: Vec<String> = std::env::args().collect();
6    if args.len() < 2 {
7        eprintln!("Usage: {} <model_path>", args[0]);
8        eprintln!("Example: {} model.tflite", args[0]);
9        std::process::exit(1);
10    }
11    let model_path = &args[1];
12
13    println!("Loading model from: {}", model_path);
14
15    // Create engine
16    let engine = Engine::new(model_path, Backend::Cpu)?;
17    println!("Engine created successfully!\n");
18
19    // Test prompts
20    let prompts = vec![
21        "What is the capital of France?",
22        "Explain quantum computing in simple terms.",
23        "Write a haiku about programming.",
24        "What is 2 + 2?",
25    ];
26
27    println!("Running batch inference...\n");
28    println!("========================================");
29
30    // Process each prompt in a separate session
31    for (i, prompt) in prompts.iter().enumerate() {
32        println!("\n[{}] Prompt: {}", i + 1, prompt);
33
34        // Create a new session for each prompt
35        let session = engine.create_session()?;
36
37        match session.generate(prompt) {
38            Ok(response) => {
39                println!("Response: {}", response);
40            }
41            Err(e) => {
42                eprintln!("Error: {}", e);
43            }
44        }
45
46        println!("----------------------------------------");
47    }
48
49    println!("\nBatch inference complete!");
50
51    Ok(())
52}