use std::{fs::read_to_string, path::PathBuf, time::Instant};

use anyhow::{Context, Result};
use clap::Parser;
use either::Either;
use mistralrs::{
    cross_entropy_loss, parse_isq_value, Constraint, DType, Device, MistralRs, NormalRequest,
    Request, ResponseOk, SamplingParams, Tensor, TextModelBuilder,
};
use tokio::sync::mpsc::channel;

#[derive(Parser)]
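/// Compute the perplexity of a model over a text file, optionally applying ISQ quantization.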
struct Args {
    #[arg(short, long, default_value = "meta-llama/Llama-3.1-8B-Instruct")]
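    /// Model ID to load, e.g. a Hugging Face model ID.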
    model_id: String,

    #[arg(short, long)]
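    /// Text file to run the model over; perplexity is computed across its tokens.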
    file: String,

    #[arg(short, long)]
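    /// In-situ quantization (ISQ) to apply, parsed by `parse_isq_value` (e.g. "Q4K").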
    isq: Option<String>,

    #[arg(short, long)]
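    /// Calibration file to use when applying ISQ.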
    calibration_file: Option<PathBuf>,
}

/// Run one chunk of tokens through the model, returning the raw logits and the
/// tokens that were processed.
async fn process_chunk(runner: &MistralRs, chunk: Vec<u32>) -> anyhow::Result<(Tensor, Vec<u32>)> {
    let (tx, mut rx) = channel(1);

    let request = Request::Normal(Box::new(NormalRequest {
        messages: mistralrs::RequestMessage::CompletionTokens(chunk),
        sampling_params: SamplingParams {
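            // Generate no new tokens; only the logits for the prompt are needed.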
            max_len: Some(0),
            ..SamplingParams::deterministic()
        },
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        id: 0,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
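        // Return the raw logits instead of sampled text.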
        return_raw_logits: true,
        web_search_options: None,
        model_id: None,
        truncate_sequence: false,
    }));

    runner.get_sender(None)?.send(request).await?;

    // Expect a raw-logits response; any other response type is an error.
    let ResponseOk::Raw {
        logits_chunks,
        tokens,
    } = rx
        .recv()
        .await
        .context("Channel was erroneously closed!")?
        .as_result()?
    else {
        anyhow::bail!("Got unexpected response type.")
    };

    Ok((logits_chunks[0].clone(), tokens))
}

#[tokio::main]
async fn main() -> Result<()> {
    let args = Args::parse();

    // Parse the requested ISQ level, if any.
    let quant = if let Some(isq) = &args.isq {
        Some(parse_isq_value(isq, None).map_err(anyhow::Error::msg)?)
    } else {
        None
    };

    // Evaluate the text in chunks of this many tokens.
    let prompt_chunksize = 1024;
    let mut model_builder = TextModelBuilder::new(&args.model_id).with_logging();
    if let Some(quant) = quant {
        model_builder = model_builder.with_isq(quant);
    }
    if let Some(calibration_file) = &args.calibration_file {
        model_builder = model_builder.with_calibration_file(calibration_file.clone());
    }

    let model = model_builder.build().await?;

    let text = read_to_string(&args.file)?;
    let tokens = model
        .tokenize(Either::Right(text), None, false, false, None)
        .await?;
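    // Tokenize a single space with special tokens enabled to recover the BOS token id.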
    let bos_token = model
        .tokenize(Either::Right(" ".to_string()), None, true, false, None)
        .await?[0];
    let inner = model.inner();

    println!("Using bos token id `{bos_token}`.");

    let n_chunks = tokens.len().div_ceil(prompt_chunksize);
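    // Collect per-chunk perplexities so the mean and spread can be reported at the end.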
    let mut ppl_measurements = Vec::new();
    for (i, chunk) in tokens.chunks(prompt_chunksize).enumerate() {
        let start = Instant::now();
        let (logits, tokens) = {
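            // Prefix the chunk with the BOS token before running it through the model.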
            let chunk = [vec![bos_token], chunk.to_vec()].concat();
            process_chunk(inner, chunk).await?
        };

        // Work in F32 on the CPU to avoid precision issues in the loss computation.
        let logits = logits.to_device(&Device::Cpu)?.to_dtype(DType::F32)?;
        // Shift so that the logits at position n predict token n + 1.
        let shift_logits = logits.narrow(0, 0, logits.dim(0)? - 1)?.contiguous()?;
        // The labels are the input tokens shifted left by one.
        let shift_labels = Tensor::from_slice(&tokens[1..], (tokens.len() - 1,), &Device::Cpu)?;

        // Perplexity is the exponential of the mean cross-entropy loss.
        let loss = cross_entropy_loss(&shift_logits, &shift_labels)?;
        let perplexity = loss.exp()?.to_scalar::<f32>()?;
        let end = Instant::now();

        ppl_measurements.push(perplexity);
        println!(
            "Chunk {}/{n_chunks} ({} tokens): Perplexity for `{}`, ISQ `{:?}`, {}s: {perplexity}",
            i + 1,
            tokens.len(),
            args.file,
            quant,
            end.duration_since(start).as_secs_f32(),
        );
    }

    let mean = ppl_measurements.iter().sum::<f32>() / ppl_measurements.len() as f32;
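    // Population variance over the per-chunk perplexities.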
    let variance = ppl_measurements
        .iter()
        .map(|e| (mean - e).powi(2))
        .sum::<f32>()
        / ppl_measurements.len() as f32;
    let std_dev = variance.sqrt();
    println!();
    println!(
        "Final perplexity for `{}`, ISQ `{:?}`: {}±{} ppl",
        args.file, quant, mean, std_dev
    );

    Ok(())
}