// mistral.rs example: batching/main.rs
// Sends N identical chat requests concurrently to measure batched throughput.
1use anyhow::Result;
2use mistralrs::{
3    ChatCompletionResponse, IsqType, PagedAttentionMetaBuilder, TextMessageRole, TextMessages,
4    TextModelBuilder, Usage,
5};
6
7const N_REQUESTS: usize = 10;
8
9#[tokio::main]
10async fn main() -> Result<()> {
11    let model = TextModelBuilder::new("microsoft/Phi-3.5-mini-instruct")
12        .with_isq(IsqType::Q8_0)
13        .with_logging()
14        .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?
15        .build()
16        .await?;
17
18    let messages = TextMessages::new()
19        .add_message(
20            TextMessageRole::System,
21            "You are an AI agent with a specialty in programming.",
22        )
23        .add_message(
24            TextMessageRole::User,
25            "Hello! How are you? Please write generic binary search function in Rust.",
26        );
27
28    let mut handles = Vec::new();
29    for _ in 0..N_REQUESTS {
30        handles.push(model.send_chat_request(messages.clone()));
31    }
32    let responses = futures::future::join_all(handles)
33        .await
34        .into_iter()
35        .collect::<Result<Vec<_>>>()?;
36
37    let mut max_prompt = f32::MIN;
38    let mut max_completion = f32::MIN;
39
40    for response in responses {
41        let ChatCompletionResponse {
42            usage:
43                Usage {
44                    avg_compl_tok_per_sec,
45                    avg_prompt_tok_per_sec,
46                    ..
47                },
48            ..
49        } = response;
50        dbg!(avg_compl_tok_per_sec, avg_prompt_tok_per_sec);
51        if avg_compl_tok_per_sec > max_prompt {
52            max_prompt = avg_prompt_tok_per_sec;
53        }
54        if avg_compl_tok_per_sec > max_completion {
55            max_completion = avg_compl_tok_per_sec;
56        }
57    }
58    println!("Individual sequence stats: {max_prompt} max PP T/s, {max_completion} max TG T/s");
59
60    Ok(())
61}