custom_logits_processor/
main.rs1use std::sync::Arc;
2
3use anyhow::Result;
4use mistralrs::{
5 CustomLogitsProcessor, IsqType, PagedAttentionMetaBuilder, RequestBuilder, Tensor,
6 TextMessageRole, TextModelBuilder,
7};
8use rand::Rng;
9
/// Logits processor configured with a single cut-off value.
///
/// The cut-off is consumed by this type's `CustomLogitsProcessor` implementation, which
/// compares every logit against it.
#[derive(Debug, Clone, Copy, PartialEq)]
struct ThresholdLogitsProcessor {
    /// Value each logit is compared against (`logit >= threshold`).
    threshold: f64,
}
13
14impl CustomLogitsProcessor for ThresholdLogitsProcessor {
15 fn apply(&self, logits: &Tensor, _context: &[u32]) -> mistralrs::Result<Tensor> {
16 let mask = logits.ge(self.threshold)?;
18 logits.broadcast_mul(&mask.to_dtype(logits.dtype())?)
19 }
20}
21
22#[tokio::main]
23async fn main() -> Result<()> {
24 let model = TextModelBuilder::new("microsoft/Phi-3.5-mini-instruct")
25 .with_isq(IsqType::Q4K)
26 .with_logging()
27 .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?
28 .build()
29 .await?;
30
31 let mut rng = rand::rng();
32 let random_value: f64 = rng.random_range(0.0..=1.0);
33 let threshold: f64 = rng.random_range(0.0..=0.5);
34
35 let request = RequestBuilder::new()
36 .add_logits_processor(Arc::new(move |logits: &Tensor, _context: &[u32]| {
37 logits * random_value
38 }))
39 .add_logits_processor(Arc::new(ThresholdLogitsProcessor { threshold }))
40 .add_message(
41 TextMessageRole::User,
42 "Please write a mathematical equation where a few numbers are added.",
43 );
44
45 let response = model.send_chat_request(request).await?;
46
47 println!("{}", response.choices[0].message.content.as_ref().unwrap());
48
49 Ok(())
50}