llm_base/inference_session.rs

use std::fmt::Display;

use partial_sort::PartialSort;
use rand::{distributions::WeightedIndex, prelude::Distribution};
use thiserror::Error;

use crate::{
    mulf, InferenceError, InferenceParameters, Model, OutputRequest, TokenId, TokenUtf8Buffer,
};

// The size of a scratch buffer used for inference. This is used for temporary
// storage of intermediate results during inference.
//
// The specific value was copied from `llama.cpp`.
const SCRATCH_SIZE: usize = 512 * 1024 * 1024;

/// An inference session represents the state of the text generation. This holds
/// the full context window, as well as several additional parameters used
/// during sampling.
///
/// # Safety
/// This implements `Send` as it can be sent to another thread. However, it does
/// not implement `Sync` - it *cannot* be used from multiple threads at the same time.
///
/// Consider spawning multiple inference sessions for the same model if you need
/// to use it from multiple threads.
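///
/// # Example
///
/// A minimal sketch of obtaining a session from an already-loaded model.
/// Model loading is out of scope here, and the paths assume this crate's
/// types are re-exported at the root as `llm_base::*`:
///
/// ```no_run
/// use llm_base::{InferenceSession, Model};
///
/// fn start(model: &dyn Model) -> InferenceSession {
///     // Create a session with the default configuration (32-bit K/V memory).
///     model.start_session(Default::default())
/// }
/// ```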
pub struct InferenceSession {
    // Must be kept alive for the model
    pub(crate) _session_ctx: ggml::Context,

    // Original size of the memory used to create this context.
    pub(crate) memory_size: usize,

    // Configuration for the session.
    pub(crate) config: InferenceSessionConfig,

    /// Memory K
    #[doc(hidden)]
    pub memory_k: ggml::Tensor,

    /// Memory V
    #[doc(hidden)]
    pub memory_v: ggml::Tensor,

    /// How many tokens have been fed into the model's working memory so far.
    #[doc(hidden)]
    pub n_past: usize,

    /// How much memory is required per token for the temporary context used
    /// during inference.
    #[doc(hidden)]
    pub mem_per_token: usize,

    /// All tokens generated by this inference session.
    pub(crate) tokens: Vec<TokenId>,

    /// The logits that were last predicted by the network. Zeroed out otherwise.
    #[doc(hidden)]
    pub last_logits: Vec<f32>,

    /// Scratch buffers used during inference.
    ///
    /// The number of scratch buffers was copied from `llama.cpp`.
    /// There is no specific reason for this number, but one is insufficient.
    #[doc(hidden)]
    pub scratch: [ggml::Buffer; 2],
}
unsafe impl Send for InferenceSession {}
impl InferenceSession {
    /// Feed a prompt to the model for this session.
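    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the session, model, and parameters already
    /// exist; `std::convert::Infallible` stands in for a callback error type
    /// that can never occur:
    ///
    /// ```no_run
    /// use llm_base::{InferenceError, InferenceParameters, InferenceSession, Model};
    ///
    /// fn feed(
    ///     session: &mut InferenceSession,
    ///     model: &dyn Model,
    ///     params: &InferenceParameters,
    /// ) -> Result<(), InferenceError> {
    ///     session.feed_prompt::<std::convert::Infallible>(
    ///         model,
    ///         params,
    ///         "The quick brown fox",
    ///         // No extra logits/embeddings are requested back.
    ///         &mut Default::default(),
    ///         // Print each prompt token as it is fed; individual tokens may
    ///         // not be valid UTF-8 on their own, hence the lossy conversion.
    ///         |bytes| {
    ///             print!("{}", String::from_utf8_lossy(bytes));
    ///             Ok(())
    ///         },
    ///     )
    /// }
    /// ```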
    pub fn feed_prompt<E: std::error::Error + 'static>(
        &mut self,
        model: &dyn Model,
        params: &InferenceParameters,
        prompt: &str,
        output_request: &mut OutputRequest,
        mut callback: impl FnMut(&[u8]) -> Result<(), E>,
    ) -> Result<(), InferenceError> {
        let beginning_of_sentence = self.n_past == 0;

        let vocab = model.vocabulary();
        let prompt_tokens: Vec<TokenId> = vocab
            .tokenize(prompt, beginning_of_sentence)?
            .iter()
            .map(|(_, tok)| *tok)
            .collect();

        if self.n_past + prompt_tokens.len() >= model.n_context_tokens() {
            return Err(InferenceError::ContextFull);
        }

        for batch in prompt_tokens.chunks(params.n_batch) {
            model.evaluate(self, params, batch, output_request);
            for &tk in batch {
                let should_call_callback = Some(tk) != model.bot_token_id();

                if should_call_callback {
                    // NOTE: No string ever tokenizes to the end of sentence. So we
                    // can just return the id here.
                    if let Err(e) = callback(vocab.token(tk as usize)) {
                        return Err(InferenceError::UserCallback(Box::new(e)));
                    }
                }

                // Update the tokens for this session
                self.tokens.push(tk);
            }
        }

        Ok(())
    }

    /// Infer the next token for this session.
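    ///
    /// # Example
    ///
    /// A minimal sketch of a hand-rolled generation loop. The session is
    /// assumed to have already been fed a prompt, and `rand::thread_rng` is
    /// only one possible RNG:
    ///
    /// ```no_run
    /// use llm_base::{InferenceError, InferenceParameters, InferenceSession, Model};
    ///
    /// fn generate(
    ///     session: &mut InferenceSession,
    ///     model: &dyn Model,
    ///     params: &InferenceParameters,
    /// ) -> Result<(), InferenceError> {
    ///     let mut rng = rand::thread_rng();
    ///     loop {
    ///         match session.infer_next_token(model, params, &mut Default::default(), &mut rng) {
    ///             // Raw token bytes; they may not be valid UTF-8 on their
    ///             // own (see `TokenUtf8Buffer` for proper buffering).
    ///             Ok(bytes) => print!("{}", String::from_utf8_lossy(bytes)),
    ///             // The model produced its end-of-text token; stop cleanly.
    ///             Err(InferenceError::EndOfText) => break,
    ///             Err(e) => return Err(e),
    ///         }
    ///     }
    ///     Ok(())
    /// }
    /// ```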
    pub fn infer_next_token<'v>(
        &mut self,
        model: &'v dyn Model,
        params: &InferenceParameters,
        output_request: &mut OutputRequest,
        rng: &mut impl rand::Rng,
    ) -> Result<&'v [u8], InferenceError> {
        if self.n_past + 1 >= model.n_context_tokens() {
            return Err(InferenceError::ContextFull);
        }

        // First, sample the next token using the stored last_logits.
        let next_token = self.sample_top_p_top_k(params, rng);

        // Update the tokens for this session
        self.tokens.push(next_token);

        // Then, evaluate the network again to compute the new last_logits
        model.evaluate(self, params, &[next_token], output_request);

        // Return the next token
        if next_token as TokenId == model.eot_token_id() {
            Err(InferenceError::EndOfText)
        } else {
            Ok(model.vocabulary().token(next_token as usize))
        }
    }

    /// Generate text by using the provided [Model] to evaluate the `prompt`.
    ///
    /// The `callback` is called with each new token until an end-of-text (EOT)
    /// token is encountered or the maximum number of tokens has been
    /// generated (specified by [InferenceRequest::maximum_token_count]).
    ///
    /// This is a wrapper around [Self::feed_prompt] and [Self::infer_next_token].
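    ///
    /// # Example
    ///
    /// A minimal sketch of driving a full generation. The model is assumed to
    /// have been loaded elsewhere, the prompt and token limit are only
    /// examples, and the callback is infallible:
    ///
    /// ```no_run
    /// use llm_base::{InferenceError, InferenceRequest, InferenceSession, InferenceStats, Model};
    ///
    /// fn run(
    ///     session: &mut InferenceSession,
    ///     model: &dyn Model,
    /// ) -> Result<InferenceStats, InferenceError> {
    ///     session.infer::<std::convert::Infallible>(
    ///         model,
    ///         &mut rand::thread_rng(),
    ///         &InferenceRequest {
    ///             prompt: "Tell me a story about a llama.",
    ///             // Use the model's default parameters, don't replay old
    ///             // tokens, and stop after at most 256 new tokens.
    ///             maximum_token_count: Some(256),
    ///             ..Default::default()
    ///         },
    ///         // No extra output (logits/embeddings) is requested.
    ///         &mut Default::default(),
    ///         |text| {
    ///             print!("{text}");
    ///             Ok(())
    ///         },
    ///     )
    /// }
    /// ```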
    pub fn infer<E: std::error::Error + 'static>(
        &mut self,
        model: &dyn Model,
        rng: &mut impl rand::Rng,
        request: &InferenceRequest,
        output_request: &mut OutputRequest,
        mut callback: impl FnMut(&str) -> Result<(), E>,
    ) -> Result<InferenceStats, InferenceError> {
        let maximum_token_count = request.maximum_token_count.unwrap_or(usize::MAX);
        if request.play_back_previous_tokens {
            // "Play back" the existing tokens, so that loading from an inference snapshot works
            // as expected.
            let mut token_utf8_buf = TokenUtf8Buffer::new();
            for token_id in &self.tokens {
                // Buffer the token until it's valid UTF-8, then call the callback.
                if let Some(tokens) =
                    token_utf8_buf.push(model.vocabulary().token(*token_id as usize))
                {
                    if let Err(e) = callback(&tokens) {
                        return Err(InferenceError::UserCallback(Box::new(e)));
                    }
                }
            }
        }

        let mut stats = InferenceStats::default();
        let start_at = std::time::SystemTime::now();

        let parameters = request.parameters.unwrap_or(model.inference_parameters());

        // Feed the initial prompt through the transformer, to update its
        // context window with new data.
        self.feed_prompt(
            model,
            parameters,
            request.prompt,
            output_request,
            TokenUtf8Buffer::adapt_callback(&mut callback),
        )?;
        stats.feed_prompt_duration = start_at.elapsed().unwrap();
        stats.prompt_tokens = self.n_past;

        // After the prompt is consumed, sample tokens by repeatedly calling
        // `infer_next_token`. We generate tokens until the model returns an
        // EndOfText token, or we run out of space in the context window,
        // or we reach the specified limit.
        let mut tokens_processed = 0;
        let mut token_utf8_buf = TokenUtf8Buffer::new();
        while tokens_processed < maximum_token_count {
            let token = match self.infer_next_token(model, parameters, &mut Default::default(), rng)
            {
                Ok(token) => token,
                Err(InferenceError::EndOfText) => break,
                Err(e) => return Err(e),
            };

            // Buffer the token until it's valid UTF-8, then call the callback.
            if let Some(tokens) = token_utf8_buf.push(token) {
                if let Err(e) = callback(&tokens) {
                    return Err(InferenceError::UserCallback(Box::new(e)));
                }
            }

            tokens_processed += 1;
        }
        stats.predict_duration = start_at.elapsed().unwrap();
        stats.predict_tokens = self.n_past;

        Ok(stats)
    }

    /// Sample a token using Top-P/Top-K sampling and the last logits from this session.
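    ///
    /// # Example
    ///
    /// A minimal sketch; the session is assumed to have been evaluated at
    /// least once so that `last_logits` is populated, and `rand::thread_rng`
    /// is only one possible RNG:
    ///
    /// ```no_run
    /// use llm_base::{InferenceParameters, InferenceSession, TokenId};
    ///
    /// fn sample(session: &InferenceSession, params: &InferenceParameters) -> TokenId {
    ///     // Samples from the stored logits without evaluating the model.
    ///     session.sample_top_p_top_k(params, &mut rand::thread_rng())
    /// }
    /// ```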
    pub fn sample_top_p_top_k(
        &self,
        params: &InferenceParameters,
        rng: &mut impl rand::Rng,
    ) -> TokenId {
        let logits = &self.last_logits;
        let n_logits = logits.len();
        let mut logits_id = Vec::<(f32, TokenId)>::with_capacity(n_logits);

        {
            let scale = 1.0 / params.temperature;
            for (i, &logit) in logits.iter().enumerate() {
                let tid = i as TokenId;

                let val = if let Some(logit_override) = params.bias_tokens.get(tid) {
                    logit_override
                } else if self.tokens[self
                    .tokens
                    .len()
                    .saturating_sub(params.repetition_penalty_last_n)..]
                    .contains(&(i as TokenId))
                {
                    // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
                    // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main

                    // if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
                    if logits[i] < 0.0 {
                        logit * scale * params.repeat_penalty
                    } else {
                        logit * scale / params.repeat_penalty
                    }
                } else {
                    logit * scale
                };
                logits_id.push((val, tid));
            }
        }

        // find the top K tokens
        {
            logits_id.partial_sort(params.top_k, |a, b| {
                // Sort descending
                b.0.total_cmp(&a.0)
            });
            logits_id.truncate(params.top_k);
        }

        let maxl = logits_id
            .iter()
            .map(|x| x.0)
            .max_by(f32::total_cmp)
            .unwrap();

        // compute probs for the top K tokens
        let mut probs: Vec<f32> = logits_id
            .iter()
            .copied()
            .map(|(k, _)| (k - maxl).exp())
            .collect();
        let sum: f32 = probs.iter().copied().sum();

        // Normalize the probs
        for p in probs.iter_mut() {
            *p /= sum;
        }

        // Top p sampling
        if params.top_p < 1.0 {
            let mut cumsum = 0.0;
            for i in 0..probs.len() {
                cumsum += probs[i];
                if cumsum >= params.top_p {
                    probs.truncate(i + 1);
                    logits_id.truncate(i + 1);
                    break;
                }
            }

            cumsum = 1.0 / cumsum;
            for p in probs.iter_mut() {
                *p *= cumsum;
            }
        }

        let dist = WeightedIndex::new(&probs).expect("WeightedIndex error");
        let idx = dist.sample(rng);

        logits_id[idx].1
    }

    /// Obtains a serializable snapshot of the current inference status. This
    /// can be used to cache the state of the model and store it in a file.
    ///
    /// # Safety
    ///
    /// This function provides raw access to the underlying memory owned by the
    /// ggml context. While the provided `InferenceSnapshotRef` object is alive,
    /// no other methods for this model object should be called.
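    ///
    /// # Example
    ///
    /// A minimal sketch that copies the borrowed state into an owned
    /// [InferenceSnapshot], which can then be written out with any
    /// binary-efficient `serde` serializer (e.g. `bincode`, which is not a
    /// dependency of this module):
    ///
    /// ```no_run
    /// use llm_base::{InferenceSession, InferenceSnapshot};
    ///
    /// fn snapshot(session: &mut InferenceSession) -> InferenceSnapshot {
    ///     // SAFETY: nothing else touches the session while the returned
    ///     // reference is alive; it is consumed immediately below.
    ///     let snapshot_ref = unsafe { session.get_snapshot() };
    ///     // Copy the borrowed K/V memory so the session can be used again.
    ///     snapshot_ref.to_owned()
    /// }
    /// ```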
    pub unsafe fn get_snapshot(&mut self) -> InferenceSnapshotRef<'_> {
        let memory_k = unsafe {
            std::slice::from_raw_parts(self.memory_k.data() as *mut u8, self.memory_k.nbytes())
        };
        let memory_v = unsafe {
            std::slice::from_raw_parts(self.memory_v.data() as *mut u8, self.memory_v.nbytes())
        };

        InferenceSnapshotRef {
            npast: self.n_past,
            config: self.config,
            tokens: self.tokens.clone(),
            logits: self.last_logits.clone(),
            memory_k,
            memory_v,
        }
    }

    /// Creates an [InferenceSession] from a snapshot.
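    ///
    /// # Example
    ///
    /// A minimal sketch; obtaining the [InferenceSnapshot] (e.g. by
    /// deserializing it from disk) is assumed to happen elsewhere:
    ///
    /// ```no_run
    /// use llm_base::{InferenceSession, InferenceSnapshot, Model, SnapshotError};
    ///
    /// fn restore(
    ///     model: &dyn Model,
    ///     snapshot: InferenceSnapshot,
    /// ) -> Result<InferenceSession, SnapshotError> {
    ///     // Rebuilds the session's K/V memory and token history from the snapshot.
    ///     InferenceSession::from_snapshot(snapshot, model)
    /// }
    /// ```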
    pub fn from_snapshot(
        snapshot: InferenceSnapshot,
        model: &dyn Model,
    ) -> Result<Self, SnapshotError> {
        let mut session = model.start_session(snapshot.config);

        if session.memory_k.nbytes() != snapshot.memory_k.len()
            || session.memory_v.nbytes() != snapshot.memory_v.len()
        {
            return Err(SnapshotError::MemorySizeMismatch {
                self_size: session.memory_k.nbytes() + session.memory_v.nbytes(),
                input_size: snapshot.memory_k.len() + snapshot.memory_v.len(),
            });
        }

        // SAFETY: We have exclusive access to Session, which means no one else
        // should be touching the context's memory. We can write to it because
        // we already checked the size.
        unsafe {
            session.memory_k.write_data(&snapshot.memory_k);
            session.memory_v.write_data(&snapshot.memory_v);
        }

        session.n_past = snapshot.npast;
        session.tokens = snapshot.tokens;
        session.last_logits = snapshot.last_logits;

        Ok(session)
    }
}
impl InferenceSession {
    /// Create a new InferenceSession
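    ///
    /// # Example
    ///
    /// Model implementations normally call this from `start_session`; the
    /// hyperparameter values below are purely illustrative (roughly the size
    /// of a 7B LLaMA model) and not taken from any particular source:
    ///
    /// ```no_run
    /// use llm_base::{InferenceSession, InferenceSessionConfig};
    ///
    /// let session = InferenceSession::new(
    ///     InferenceSessionConfig::default(),
    ///     2048,  // n_ctx: context window length in tokens
    ///     32,    // n_layer
    ///     4096,  // n_embd
    ///     32000, // n_vocab
    /// );
    /// ```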
    pub fn new(
        config: InferenceSessionConfig,
        n_ctx: usize,
        n_layer: usize,
        n_embd: usize,
        n_vocab: usize,
    ) -> InferenceSession {
        let ctx_size = {
            let mut ctx_size = 0;
            ctx_size += mulf!(
                n_ctx,
                n_layer,
                n_embd,
                ggml::type_sizef(config.memory_k_type.into())
            ); // memory_k
            ctx_size += mulf!(
                n_ctx,
                n_layer,
                n_embd,
                ggml::type_sizef(config.memory_v_type.into())
            ); // memory_v
            ctx_size += (5 + 10 * n_layer) * 256; // object overhead
            ctx_size
        };

        let session_ctx = ggml::Context::init(ctx_size, true);

        // Initialize key + value memory tensors
        let n_mem = n_layer * n_ctx;
        let n_elements = n_embd * n_mem;
        let memory_k = session_ctx.new_tensor_1d(config.memory_k_type.into(), n_elements);
        let memory_v = session_ctx.new_tensor_1d(config.memory_v_type.into(), n_elements);

        InferenceSession {
            _session_ctx: session_ctx,
            memory_size: ctx_size,
            config,
            memory_k,
            memory_v,
            n_past: 0,
            mem_per_token: 0,
            tokens: vec![],
            last_logits: vec![0.0; n_vocab],
            scratch: scratch_buffers(),
        }
    }
}
impl Clone for InferenceSession {
    fn clone(&self) -> Self {
        let context = ggml::Context::init(self.memory_size, true);
        let memory_k = context.new_tensor_1d(self.memory_k.get_type(), self.memory_k.nelements());
        let memory_v = context.new_tensor_1d(self.memory_v.get_type(), self.memory_v.nelements());

        Self {
            _session_ctx: context,
            memory_size: self.memory_size,
            config: self.config,
            memory_k,
            memory_v,
            n_past: self.n_past,
            mem_per_token: self.mem_per_token,
            tokens: self.tokens.clone(),
            last_logits: self.last_logits.clone(),
            scratch: scratch_buffers(),
        }
    }
}

#[derive(Error, Debug)]
/// Errors encountered during the snapshot process.
pub enum SnapshotError {
    /// Arbitrary I/O error.
    #[error("I/O error while reading or writing snapshot")]
    IO(#[from] std::io::Error),
    /// Mismatch between the size of the snapshotted memory and the memory of the current session.
    #[error("could not read snapshot due to size mismatch (self={self_size}, input={input_size})")]
    MemorySizeMismatch {
        /// The size of the current session's memory.
        self_size: usize,
        /// The size of the memory stored in the snapshot.
        input_size: usize,
    },
}

#[derive(serde::Serialize, Clone, PartialEq)]
/// A serializable snapshot of the inference process.
/// Can be created by calling [InferenceSession::get_snapshot].
///
/// If serializing, ensure that your serializer is binary-efficient.
/// This type contains a large array of bytes; traditional textual serializers
/// are likely to serialize this as an array of numbers at extreme cost.
// Keep in sync with [InferenceSession] and [InferenceSnapshot].
pub struct InferenceSnapshotRef<'a> {
    /// How many tokens have been stored in the memory so far.
    pub npast: usize,
    /// Parameters associated with the saved inference session.
    pub config: InferenceSessionConfig,
    /// All tokens generated by this inference session.
    pub tokens: Vec<TokenId>,
    /// The vector of logits that was produced after the last inference.
    pub logits: Vec<f32>,
    /// The contents of the 'key' memory tensor.
    #[serde(with = "serde_bytes")]
    pub memory_k: &'a [u8],
    /// The contents of the 'value' memory tensor.
    #[serde(with = "serde_bytes")]
    pub memory_v: &'a [u8],
}
impl InferenceSnapshotRef<'_> {
    /// Creates an owned [InferenceSnapshot] from this [InferenceSnapshotRef].
    ///
    /// The [ToOwned] trait is not used due to its blanket implementation for all [Clone] types.
    pub fn to_owned(&self) -> InferenceSnapshot {
        InferenceSnapshot {
            npast: self.npast,
            config: self.config,
            tokens: self.tokens.clone(),
            last_logits: self.logits.clone(),
            memory_k: self.memory_k.to_vec(),
            memory_v: self.memory_v.to_vec(),
        }
    }
}

/// A serializable snapshot of the inference process. Can be restored by calling
/// [InferenceSession::from_snapshot].
#[derive(serde::Deserialize, Clone, PartialEq)]
// Keep in sync with [InferenceSession] and [InferenceSnapshotRef].
pub struct InferenceSnapshot {
    /// How many tokens have been stored in the memory so far.
    pub npast: usize,
    /// Parameters associated with the saved inference session.
    pub config: InferenceSessionConfig,
    /// All tokens generated by this inference session.
    pub tokens: Vec<TokenId>,
    /// The vector of logits that was produced after the last inference.
    pub last_logits: Vec<f32>,
    /// The contents of the 'key' memory tensor.
    #[serde(with = "serde_bytes")]
    pub memory_k: Vec<u8>,
    /// The contents of the 'value' memory tensor.
    #[serde(with = "serde_bytes")]
    pub memory_v: Vec<u8>,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
/// Configuration for an inference session.
///
/// This is specified at the time of creation of an [InferenceSession],
/// and cannot be changed after the session has been created.
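///
/// # Example
///
/// A minimal sketch of starting a session with 16-bit K/V memory, which
/// roughly halves the memory used for the context window (the model is
/// assumed to be loaded elsewhere):
///
/// ```no_run
/// use llm_base::{InferenceSession, InferenceSessionConfig, Model, ModelKVMemoryType};
///
/// fn start_f16(model: &dyn Model) -> InferenceSession {
///     model.start_session(InferenceSessionConfig {
///         memory_k_type: ModelKVMemoryType::Float16,
///         memory_v_type: ModelKVMemoryType::Float16,
///     })
/// }
/// ```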
pub struct InferenceSessionConfig {
    /// The type of the memory K tensor.
    pub memory_k_type: ModelKVMemoryType,
    /// The type of the memory V tensor.
    pub memory_v_type: ModelKVMemoryType,
}
impl Default for InferenceSessionConfig {
    fn default() -> Self {
        Self {
            memory_k_type: ModelKVMemoryType::Float32,
            memory_v_type: ModelKVMemoryType::Float32,
        }
    }
}

#[derive(Debug, PartialEq, Default, Clone, Copy)]
/// Settings specific to [InferenceSession::infer].
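///
/// # Example
///
/// A minimal sketch of a request that limits generation to 128 new tokens and
/// otherwise uses the defaults (the prompt text is only an example):
///
/// ```no_run
/// use llm_base::InferenceRequest;
///
/// let request = InferenceRequest {
///     prompt: "Write a haiku about llamas.",
///     maximum_token_count: Some(128),
///     ..Default::default()
/// };
/// ```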
pub struct InferenceRequest<'a> {
    /// The prompt to feed to the model.
    pub prompt: &'a str,
    /// The parameters to use during this inference attempt.
    /// If not specified, this will default to the parameters
    /// specified in the model.
    pub parameters: Option<&'a InferenceParameters>,
    /// Whether or not to call the callback with the previous tokens
    /// that were encountered in this session.
    ///
    /// You likely want to turn this on if you're using a session
    /// that has been rehydrated from a snapshot.
    pub play_back_previous_tokens: bool,
    /// The maximum number of tokens to generate.
    pub maximum_token_count: Option<usize>,
}

/// Statistics about the inference process.
#[derive(Debug, Clone, Copy)]
pub struct InferenceStats {
    /// How long it took to feed the prompt.
    pub feed_prompt_duration: std::time::Duration,
    /// The number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// How long it took to predict new tokens.
    pub predict_duration: std::time::Duration,
    /// The number of predicted tokens.
    pub predict_tokens: usize,
}
impl Default for InferenceStats {
    fn default() -> Self {
        Self {
            feed_prompt_duration: std::time::Duration::from_secs(0),
            prompt_tokens: 0,
            predict_duration: std::time::Duration::from_secs(0),
            predict_tokens: 0,
        }
    }
}
impl Display for InferenceStats {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(
            f,
            "feed_prompt_duration: {}ms\nprompt_tokens: {}\npredict_duration: {}ms\npredict_tokens: {}\nper_token_duration: {:.3}ms",
            self.feed_prompt_duration.as_millis(),
            self.prompt_tokens,
            self.predict_duration.as_millis(),
            self.predict_tokens,
            (self.predict_duration.as_millis() as f64) / (self.predict_tokens as f64),
        )
    }
}

/// Allowed types for the model memory K/V tensors.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ModelKVMemoryType {
    /// 16-bit float.
    Float16,
    /// 32-bit float.
    Float32,
}
impl From<ModelKVMemoryType> for ggml::Type {
    fn from(value: ModelKVMemoryType) -> Self {
        match value {
            ModelKVMemoryType::Float16 => ggml::Type::F16,
            ModelKVMemoryType::Float32 => ggml::Type::F32,
        }
    }
}

fn scratch_buffers() -> [ggml::Buffer; 2] {
    [
        ggml::Buffer::new(SCRATCH_SIZE),
        ggml::Buffer::new(SCRATCH_SIZE),
    ]
}