Skip to main content

custom_logits_processor/
main.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use mistralrs::{
5    CustomLogitsProcessor, IsqType, PagedAttentionMetaBuilder, RequestBuilder, Tensor,
6    TextMessageRole, TextModelBuilder,
7};
8use rand::Rng;
9
10struct ThresholdLogitsProcessor {
11    threshold: f64,
12}
13
14impl CustomLogitsProcessor for ThresholdLogitsProcessor {
15    fn apply(&self, logits: &Tensor, _context: &[u32]) -> mistralrs::Result<Tensor> {
16        // Mask is 1 for true, 0 for false.
17        let mask = logits.ge(self.threshold)?;
18        logits.broadcast_mul(&mask.to_dtype(logits.dtype())?)
19    }
20}
21
22#[tokio::main]
23async fn main() -> Result<()> {
24    let model = TextModelBuilder::new("microsoft/Phi-3.5-mini-instruct")
25        .with_isq(IsqType::Q4K)
26        .with_logging()
27        .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?
28        .build()
29        .await?;
30
31    let mut rng = rand::rng();
32    let random_value: f64 = rng.random_range(0.0..=1.0);
33    let threshold: f64 = rng.random_range(0.0..=0.5);
34
35    let request = RequestBuilder::new()
36        .add_logits_processor(Arc::new(move |logits: &Tensor, _context: &[u32]| {
37            logits * random_value
38        }))
39        .add_logits_processor(Arc::new(ThresholdLogitsProcessor { threshold }))
40        .add_message(
41            TextMessageRole::User,
42            "Please write a mathematical equation where a few numbers are added.",
43        );
44
45    let response = model.send_chat_request(request).await?;
46
47    println!("{}", response.choices[0].message.content.as_ref().unwrap());
48
49    Ok(())
50}