Skip to main content

perplexity/
main.rs

1use std::{fs::read_to_string, path::PathBuf, time::Instant};
2
3use anyhow::{Context, Result};
4use clap::Parser;
5use either::Either;
6use mistralrs::{
7    cross_entropy_loss, parse_isq_value, Constraint, DType, Device, MistralRs, NormalRequest,
8    Request, ResponseOk, SamplingParams, Tensor, TextModelBuilder,
9};
10use tokio::sync::mpsc::channel;
11
12/// Calculate perplexity of a model. By default, this uses the Llama 3.1 8B model.
13#[derive(Parser)]
14struct Args {
15    /// The model to run.
16    #[arg(short, long, default_value = "meta-llama/Llama-3.1-8B-Instruct")]
17    model_id: String,
18
19    /// Filename to text to run the model on. This is recommended to be the Wikitext 2 dataset:
20    /// https://huggingface.co/datasets/EricB/wikitext2
21    #[arg(short, long)]
22    file: String,
23
24    /// ISQ quantization to run with.
25    #[arg(short, long)]
26    isq: Option<String>,
27
28    /// Generate and utilize an imatrix to enhance GGUF quantizations.
29    #[arg(short, long)]
30    calibration_file: Option<PathBuf>,
31}
32
33async fn process_chunk(runner: &MistralRs, chunk: Vec<u32>) -> anyhow::Result<(Tensor, Vec<u32>)> {
34    let (tx, mut rx) = channel(1);
35
36    let request = Request::Normal(Box::new(NormalRequest {
37        messages: mistralrs::RequestMessage::CompletionTokens(chunk),
38        sampling_params: SamplingParams {
39            max_len: Some(0),
40            ..SamplingParams::deterministic()
41        },
42        response: tx,
43        return_logprobs: false,
44        is_streaming: false,
45        id: 0,
46        constraint: Constraint::None,
47        suffix: None,
48        tools: None,
49        tool_choice: None,
50        logits_processors: None,
51        return_raw_logits: true,
52        web_search_options: None,
53        model_id: None,
54        truncate_sequence: false,
55    }));
56
57    runner.get_sender(None)?.send(request).await?;
58
59    let ResponseOk::Raw {
60        logits_chunks,
61        tokens,
62    } = rx
63        .recv()
64        .await
65        .context("Channel was erroneously closed!")?
66        .as_result()?
67    else {
68        anyhow::bail!("Got unexpected response type.")
69    };
70
71    Ok((logits_chunks[0].clone(), tokens))
72}
73
74#[tokio::main]
75async fn main() -> Result<()> {
76    let args = Args::parse();
77
78    let quant = if let Some(isq) = &args.isq {
79        Some(parse_isq_value(isq, None).map_err(anyhow::Error::msg)?)
80    } else {
81        None
82    };
83
84    let prompt_chunksize = 1024;
85    let mut model_builder = TextModelBuilder::new(&args.model_id).with_logging();
86    if let Some(quant) = quant {
87        model_builder = model_builder.with_isq(quant);
88    }
89    if let Some(calibration_file) = &args.calibration_file {
90        model_builder = model_builder.with_calibration_file(calibration_file.clone());
91    }
92
93    let model = model_builder.build().await?;
94
95    let text = read_to_string(&args.file)?;
96    let tokens = model
97        .tokenize(Either::Right(text), None, false, false, None)
98        .await?;
99    let bos_token = model
100        .tokenize(Either::Right(" ".to_string()), None, true, false, None)
101        .await?[0];
102    let inner = model.inner();
103
104    println!("Using bos token id `{bos_token}`.");
105
106    let n_chunks = tokens.len().div_ceil(prompt_chunksize);
107    let mut ppl_measurements = Vec::new();
108    for (i, chunk) in tokens.chunks(prompt_chunksize).enumerate() {
109        let start = Instant::now();
110        let (logits, tokens) = {
111            let chunk = [vec![bos_token], chunk.to_vec()].concat();
112            process_chunk(inner, chunk).await?
113        };
114
115        // Upcast to float if we need to compute the loss to avoid potential precision issues
116        let logits = logits.to_device(&Device::Cpu)?.to_dtype(DType::F32)?;
117        // Shift so that tokens < n predict n
118        let shift_logits = logits.narrow(0, 0, logits.dim(0)? - 1)?.contiguous()?;
119        let shift_labels = Tensor::from_slice(&tokens[1..], (tokens.len() - 1,), &Device::Cpu)?;
120
121        let loss_fct = cross_entropy_loss(&shift_logits, &shift_labels)?;
122        let perplexity = loss_fct.exp()?.to_scalar::<f32>()?;
123        let end = Instant::now();
124
125        ppl_measurements.push(perplexity);
126        println!(
127            "Chunk {i}/{n_chunks} ({} tokens): Perplexity for `{}`, ISQ `{:?}`, {}s: {perplexity}",
128            tokens.len(),
129            args.file,
130            quant,
131            end.duration_since(start).as_secs_f32(),
132        );
133    }
134
135    let mean = ppl_measurements.iter().sum::<f32>() / ppl_measurements.len() as f32;
136    let variance = ppl_measurements
137        .iter()
138        .map(|e| (mean - e).powf(2.))
139        .sum::<f32>()
140        / ppl_measurements.len() as f32;
141    let std_dev = variance.sqrt();
142    println!();
143    println!(
144        "Final perplexity for `{}`, ISQ `{:?}`: {}±{} ppl",
145        args.file, quant, mean, std_dev
146    );
147
148    Ok(())
149}