candle_transformers/generation/
mod.rs1use candle::{Context, DType, Error, Result, Tensor};
7use rand::{distr::Distribution, SeedableRng};
8
/// Strategy a [`LogitsProcessor`] uses to pick the next token.
///
/// Every variant except [`Sampling::ArgMax`] carries a `temperature` by which the logits are
/// divided before being turned into probabilities.
#[derive(Debug, Clone, PartialEq)]
pub enum Sampling {
    /// Deterministically pick the token with the largest logit.
    ArgMax,
    /// Draw from the full softmax distribution.
    All {
        temperature: f64,
    },
    /// Draw only among the `k` most probable tokens.
    TopK {
        k: usize,
        temperature: f64,
    },
    /// Nucleus sampling: draw among the most probable tokens whose cumulative
    /// probability reaches `p`.
    TopP {
        p: f64,
        temperature: f64,
    },
    /// Restrict to the `k` most probable tokens first, then apply the top-p cut-off
    /// `p` inside that subset.
    TopKThenTopP {
        k: usize,
        p: f64,
        temperature: f64,
    },
    /// Sample through `candle_nn::sampling::gumbel_softmax`.
    GumbelSoftmax {
        temperature: f64,
    },
}
19
20pub struct LogitsProcessor {
21 rng: rand::rngs::StdRng,
22 sampling: Sampling,
23}
24
25impl LogitsProcessor {
26 pub fn from_sampling(seed: u64, sampling: Sampling) -> Self {
27 let rng = rand::rngs::StdRng::seed_from_u64(seed);
28 Self { rng, sampling }
29 }
30
31 pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
32 let temperature = temperature.and_then(|v| if v < 1e-7 { None } else { Some(v) });
33 let sampling = match temperature {
34 None => Sampling::ArgMax,
35 Some(temperature) => match top_p {
36 None => Sampling::All { temperature },
37 Some(p) => Sampling::TopP { p, temperature },
38 },
39 };
40 Self::from_sampling(seed, sampling)
41 }
42
43 fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
44 let logits_v: Vec<f32> = logits.to_vec1()?;
45 let next_token = logits_v
46 .iter()
47 .enumerate()
48 .max_by(|(_, u), (_, v)| u.total_cmp(v))
49 .map(|(i, _)| i as u32)
50 .context("empty logits")?;
51 Ok(next_token)
52 }
53
54 fn sample_gumbel_softmax(&mut self, logits: &Tensor, temperature: f64) -> Result<u32> {
55 let sampled = candle_nn::sampling::gumbel_softmax(logits, temperature, candle::D::Minus1)?;
56 sampled.to_vec0::<u32>()
57 }
58
59 fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
60 let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
61 let next_token = distr.sample(&mut self.rng) as u32;
62 Ok(next_token)
63 }
64
65 fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
69 let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
70
71 argsort_indices.sort_by(|&i, &j| prs[j].total_cmp(&prs[i]));
73
74 let mut cumsum = 0.;
76 for index in &argsort_indices {
77 if cumsum >= top_p {
78 prs[*index] = 0.0;
79 } else {
80 cumsum += prs[*index];
81 }
82 }
83 self.sample_multinomial(prs)
85 }
86
87 fn sample_topk(&mut self, prs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
89 if top_k >= prs.len() {
90 self.sample_multinomial(prs)
91 } else {
92 let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
93 let (indices, _, _) =
94 argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
95 let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
96 let index = self.sample_multinomial(&prs)?;
97 Ok(indices[index as usize] as u32)
98 }
99 }
100
101 fn sample_topk_topp(&mut self, prs: &mut Vec<f32>, top_k: usize, top_p: f32) -> Result<u32> {
104 if top_k >= prs.len() {
105 self.sample_topp(prs, top_p)
106 } else {
107 let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
108 let (indices, _, _) =
109 argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
110 let mut prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
111 let sum_p = prs.iter().sum::<f32>();
112 let index = if top_p <= 0.0 || top_p >= sum_p {
113 self.sample_multinomial(&prs)?
114 } else {
115 self.sample_topp(&mut prs, top_p)?
116 };
117 Ok(indices[index as usize] as u32)
118 }
119 }
120
121 pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
122 self.sample_f(logits, |_| {})
123 }
124
125 pub fn sample_f(&mut self, logits: &Tensor, f: impl FnOnce(&mut [f32])) -> Result<u32> {
126 let logits = logits.to_dtype(DType::F32)?;
127 let prs = |temperature: f64| -> Result<Vec<f32>> {
128 let logits = (&logits / temperature)?;
129 let prs = candle_nn::ops::softmax_last_dim(&logits)?;
130 let mut prs = prs.to_vec1()?;
131 f(&mut prs);
132 Ok(prs)
133 };
134
135 let next_token = match &self.sampling {
136 Sampling::ArgMax => self.sample_argmax(logits)?,
137 Sampling::GumbelSoftmax { temperature } => {
138 self.sample_gumbel_softmax(&logits, *temperature)?
139 }
140 Sampling::All { temperature } => {
141 let prs = prs(*temperature)?;
142 self.sample_multinomial(&prs)?
143 }
144 Sampling::TopP { p, temperature } => {
145 let mut prs = prs(*temperature)?;
146 if *p <= 0.0 || *p >= 1.0 {
147 self.sample_multinomial(&prs)?
149 } else {
150 self.sample_topp(&mut prs, *p as f32)?
152 }
153 }
154 Sampling::TopK { k, temperature } => {
155 let mut prs = prs(*temperature)?;
156 self.sample_topk(&mut prs, *k)?
157 }
158 Sampling::TopKThenTopP { k, p, temperature } => {
159 let mut prs = prs(*temperature)?;
160 self.sample_topk_topp(&mut prs, *k, *p as f32)?
161 }
162 };
163 Ok(next_token)
164 }
165}