rten_generate/
generator.rs

1//! Tools to run the generation loop for an auto-regressive model.
2
3use std::error::Error;
4use std::fmt;
5use std::ops::Range;
6
7use rten::{Dimension, NodeId, RunOptions, Value, ValueOrView, ValueView};
8use rten_tensor::prelude::*;
9use rten_tensor::{NdTensor, Tensor};
10
11#[cfg(feature = "text-decoder")]
12use rten_text::{Tokenizer, TokenizerError};
13
14use crate::filter::LogitsFilter;
15use crate::logits::Logits;
16use crate::metrics::Metrics;
17use crate::model::Model;
18use crate::sampler::{ArgMax, Sampler};
19
20#[cfg(feature = "text-decoder")]
21use crate::text_decoder::TextDecoder;
22
/// Integer type used to represent token IDs.
///
/// These are the values fed to the model's `input_ids` input and yielded by
/// [`Generator`]'s `Iterator` implementation.
pub type TokenId = u32;
25
/// Errors that occur when creating or running a [`Generator`].
#[derive(Debug)]
pub enum GeneratorError {
    /// An expected model input was not found.
    InputNotFound(String),

    /// An expected model output was not found.
    OutputNotFound(String),

    /// An input or output did not have the expected shape.
    ShapeMismatch(String),

    /// An error occurred while generating the next token.
    ///
    /// Wraps the underlying error from running the model.
    GenerateError(Box<dyn Error>),

    /// An error occurred while decoding tokens.
    ///
    /// Only available when the `text-decoder` feature is enabled.
    #[cfg(feature = "text-decoder")]
    DecodeError(TokenizerError),
}
45
46impl fmt::Display for GeneratorError {
47    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
48        match self {
49            GeneratorError::InputNotFound(name) => write!(f, "model input not found: {}", name),
50            GeneratorError::OutputNotFound(name) => write!(f, "model output not found: {}", name),
51            GeneratorError::ShapeMismatch(err) => write!(f, "shape mismatch: {}", err),
52            GeneratorError::GenerateError(err) => write!(f, "generation error: {}", err),
53            #[cfg(feature = "text-decoder")]
54            GeneratorError::DecodeError(err) => write!(f, "decode error: {}", err),
55        }
56    }
57}
58
59impl Error for GeneratorError {}
60
/// Wraps an error with associated context for debugging.
///
/// The wrapped error remains reachable via [`Error::source`].
#[derive(Debug)]
struct ErrorContext {
    /// The underlying error being wrapped.
    error: Box<dyn Error>,
    /// Human-readable description of what was being attempted when the
    /// error occurred.
    context: String,
}
67
68impl Error for ErrorContext {
69    fn source(&self) -> Option<&(dyn Error + 'static)> {
70        Some(self.error.as_ref())
71    }
72}
73
74impl std::fmt::Display for ErrorContext {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        write!(f, "{}: {}", self.context, self.error)
77    }
78}
79
/// Buffer holding cached key or value tensors, in one of the two layouts
/// used by supported models.
enum KvCacheData {
    /// Key-value cache with shape `[batch, seq_len, channels]`.
    ///
    /// In this configuration the channels for all heads are combined into the
    /// last dimension.
    BatchSeqChans(NdTensor<f32, 3>),
    /// Key-value cache with shape `[batch, heads, seq_len, channels]`.
    BatchHeadSeqChans(NdTensor<f32, 4>),
}
89
90impl KvCacheData {
91    /// Allocate a KV cache buffer with the given batch size, number of heads
92    /// and embed size.
93    ///
94    /// The buffer initially has capacity to be extended to a sequence length
95    /// of `seq_len_capacity`.
96    fn with_capacity(
97        batch_size: usize,
98        n_heads: Option<usize>,
99        size: usize,
100        seq_len_capacity: usize,
101    ) -> KvCacheData {
102        if let Some(n_heads) = n_heads {
103            KvCacheData::BatchHeadSeqChans(NdTensor::with_capacity(
104                [batch_size, n_heads, seq_len_capacity, size],
105                2, /* seq dim */
106            ))
107        } else {
108            KvCacheData::BatchSeqChans(NdTensor::with_capacity(
109                [batch_size, seq_len_capacity, size],
110                1, /* seq dim */
111            ))
112        }
113    }
114
115    /// Return the current sequence length of the cache.
116    fn sequence_len(&self) -> usize {
117        match self {
118            KvCacheData::BatchSeqChans(data) => data.size(1),
119            KvCacheData::BatchHeadSeqChans(data) => data.size(2),
120        }
121    }
122
123    /// Return true if the KV cache has capacity for a given sequence length.
124    fn has_capacity(&self, sequence_len: usize) -> bool {
125        match self {
126            KvCacheData::BatchSeqChans(data) => {
127                data.has_capacity(1 /* seq dim */, sequence_len)
128            }
129            KvCacheData::BatchHeadSeqChans(data) => {
130                data.has_capacity(2 /* seq dim */, sequence_len)
131            }
132        }
133    }
134
135    /// Clone this cache into a new buffer with space to store sequences of
136    /// a given size.
137    fn clone_with_capacity(&self, max_sequence_len: usize) -> KvCacheData {
138        let max_sequence_len = max_sequence_len.max(self.sequence_len());
139        match self {
140            KvCacheData::BatchSeqChans(data) => {
141                let [batch, _seq, chans] = data.shape();
142                let mut new_data =
143                    NdTensor::with_capacity([batch, max_sequence_len, chans], 1 /* seq dim */);
144                new_data.append(1, data).expect("should have capacity");
145                KvCacheData::BatchSeqChans(new_data)
146            }
147            KvCacheData::BatchHeadSeqChans(data) => {
148                let [batch, n_heads, _seq, chans] = data.shape();
149                let mut new_data = NdTensor::with_capacity(
150                    [batch, n_heads, max_sequence_len, chans],
151                    2, /* seq dim */
152                );
153                new_data.append(2, data).expect("should have capacity");
154                KvCacheData::BatchHeadSeqChans(new_data)
155            }
156        }
157    }
158}
159
/// Key-value cache for a single layer of a transformer model.
struct KvCache {
    /// Input ID for this cache entry.
    ///
    /// Node through which the previous step's cache is fed to the model.
    input_id: NodeId,

    /// Output ID for this cache entry.
    ///
    /// Node from which the updated cache is read after a run.
    output_id: NodeId,

    /// The cached keys and values. This is set to `None` during inference, as
    /// the model temporarily takes ownership of it.
    cache: Option<KvCacheData>,
}
172
173impl KvCache {
174    fn size(&self) -> Option<usize> {
175        self.cache.as_ref().map(|c| c.sequence_len())
176    }
177}
178
/// Specifies a pattern for the name of a key-value cache input or output.
///
/// These inputs are expected to have the form `{prefix}{layer_number}{suffix}`,
/// with one input and output per layer for the key cache and the value cache.
pub struct KVCachePattern<'a> {
    /// Fixed portion of the name before the layer number.
    pub prefix: &'a str,
    /// Fixed portion of the name after the layer number.
    pub suffix: &'a str,
}
187
188impl<'a> From<(&'a str, &'a str)> for KVCachePattern<'a> {
189    /// Construct a [`KVCachePattern`] from a `(prefix, suffix)` tuple.
190    fn from(value: (&'a str, &'a str)) -> Self {
191        let (prefix, suffix) = value;
192        KVCachePattern { prefix, suffix }
193    }
194}
195
/// Specifies a pair of patterns for corresponding input and output key-value
/// cache entries.
///
/// See [`KVCachePattern`] for the expected name format.
pub struct KVCachePair<'a> {
    /// The pattern for the model input name.
    pub input: KVCachePattern<'a>,

    /// The pattern for the model output name.
    pub output: KVCachePattern<'a>,

    /// Specifies whether this cache is used for a cross-attention ("encoder")
    /// KV cache.
    ///
    /// Encoder KV-cache entries are computed only on the first run of the
    /// model and reused in subsequent runs.
    pub encoder: bool,
}
212
/// Specifies the names of model inputs and outputs.
///
/// The [`Default`] impl for this struct returns an instance whose names
/// follow the configuration of Hugging Face's Optimum tool.
///
/// Any inputs that are not present in the model are ignored.
pub struct ModelInputsConfig<'a> {
    /// Model input that contains the token IDs of the prompt and output
    /// generated so far.
    pub input_ids: &'a str,

    /// Model output that contains logits.
    pub logits: &'a str,

    /// Model input that contains an attention mask.
    ///
    /// When present, the generator fills this input with ones for all
    /// positions.
    pub attention_mask: &'a str,

    /// Model input that contains KV cache positions for each position.
    ///
    /// This input does not have a batch dimension.
    pub cache_position: &'a str,

    /// Patterns for inputs and outputs used for key-value caches.
    ///
    /// Patterns are tried in order; the first whose input prefix and suffix
    /// match wins.
    pub kv_caches: Vec<KVCachePair<'a>>,

    /// Model input that contains position IDs for each position.
    ///
    /// This input has a batch dimension.
    pub position_ids: &'a str,

    /// Boolean input that is set to false on the first run and true on
    /// subsequent runs.
    pub use_cache_flag: &'a str,
}
247
/// Contains essential configuration needed for a `Generator` to execute a
/// model, such as the roles of different inputs and outputs.
///
/// Used with [`Generator::from_model_config`].
#[derive(Default)]
pub struct GeneratorConfig<'a> {
    /// Specifies names and roles of model inputs and outputs.
    pub model_inputs: ModelInputsConfig<'a>,

    /// Reserve capacity in the KV cache for a given number of tokens.
    ///
    /// By default the KV cache starts off with a small initial capacity and is
    /// re-allocated if needed after a token is generated, with the capacity
    /// being doubled each time to amortize overhead.
    ///
    /// If the maximum number of tokens that will be generated is known in
    /// advance, this can be used to reduce the overhead of re-allocating the
    /// KV cache during inference.
    pub kv_cache_capacity: Option<usize>,
}
266
267impl Default for ModelInputsConfig<'_> {
268    /// Return default model input names.
269    ///
270    /// These are based on [Hugging Face's
271    /// Optimum](https://huggingface.co/docs/optimum/en/index) model exporter.
272    fn default() -> Self {
273        ModelInputsConfig {
274            input_ids: "input_ids",
275            logits: "logits",
276            attention_mask: "attention_mask",
277            cache_position: "cache_position",
278            position_ids: "position_ids",
279            use_cache_flag: "use_cache_branch",
280
281            // Patterns are matched in order, so patterns with longer prefixes/
282            // suffixes are listed first to ensure we match them.
283            kv_caches: [
284                // "Merged" decoders exported by Optimum for encoder-decoder
285                // models. These have KV caches for both the self-attention and
286                // cross-attention modules.
287                KVCachePair {
288                    input: ("past_key_values.", ".decoder.key").into(),
289                    output: ("present.", ".decoder.key").into(),
290                    encoder: false,
291                },
292                KVCachePair {
293                    input: ("past_key_values.", ".decoder.value").into(),
294                    output: ("present.", ".decoder.value").into(),
295                    encoder: false,
296                },
297                KVCachePair {
298                    input: ("past_key_values.", ".encoder.key").into(),
299                    output: ("present.", ".encoder.key").into(),
300                    encoder: true,
301                },
302                KVCachePair {
303                    input: ("past_key_values.", ".encoder.value").into(),
304                    output: ("present.", ".encoder.value").into(),
305                    encoder: true,
306                },
307                // Decoder-only models exported by Optimum.
308                KVCachePair {
309                    input: ("past_key_values.", ".key").into(),
310                    output: ("present.", ".key").into(),
311                    encoder: false,
312                },
313                KVCachePair {
314                    input: ("past_key_values.", ".value").into(),
315                    output: ("present.", ".value").into(),
316                    encoder: false,
317                },
318            ]
319            .into(),
320        }
321    }
322}
323
/// Generates a token ID sequence using a transformer decoder model.
///
/// This is an iterator that runs the model on each call to [`Iterator::next`]
/// and yields a result containing the next token ID or an error.
///
/// The token ID sequence can be converted to text using the
/// [`decode`](GeneratorUtils::decode) method of the [`GeneratorUtils`] trait.
///
/// The `GeneratorUtils` trait also provides useful wrappers for the output,
/// such as stopping generation when an end-of-text token is reached. You can
/// also use all of the standard iterator adapters. For example
/// `generator.take(30)` will return an iterator that stops generation after 30
/// tokens have been produced.
///
/// ## Processing pipeline
///
/// Each call to [`next`](Iterator::next) performs the following steps:
///
/// 1. Compute model inputs. This includes the token IDs (prompt on first run,
///    most recently sampled token after that), constant inputs and
///    position-varying inputs (eg. attention mask).
/// 2. Run the model
/// 3. Apply filters to model outputs ("logits")
/// 4. Sample a token from the logits
/// 5. Save the sampled token and KV-caches from the model for the next
///    generation step.
/// 6. Return the sampled token as the iterator output
///
/// ## Logit filters
///
/// The raw model outputs can be modified before sampling by configuring a
/// [`LogitsFilter`] using [`with_logits_filter`](Generator::with_logits_filter).
///
/// A chain of filters can be created using [`Chain`](crate::filter::Chain).
/// When setting up the chain, a best practice is to put the most selective
/// filters first (ie. those which eliminate the most tokens).
///
///
/// ```no_run
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// use rten::Model;
/// use rten_generate::Generator;
/// use rten_generate::filter::Chain;
/// use rten_generate::sampler::Multinomial;
///
/// let model = Model::load_file("model.onnx")?;
/// let mut generator = Generator::from_model(&model)?
///   .with_logits_filter(
///     Chain::new()
///       .top_k(30) // Remove all tokens except the 30 with the highest score
///       .temperature(0.7) // Scale scores using `score / 0.7`.
///       .top_p(0.9) // Take the top N tokens whose cumulative probability exceeds 0.9.
///   )
///   // Select a random token from the filtered set according to the probability
///   // of each.
///   .with_sampler(Multinomial::new());
///
/// # Ok(()) }
/// ```
///
/// ## Sampling
///
/// The token ID is sampled from the outputs of the model (the "logits") using
/// a [`Sampler`]. By default this is [`ArgMax`] which simply chooses
/// the token with the highest probability. The sampler can be configured using
/// [`with_sampler`](Self::with_sampler).
///
/// ## Key-value caches and generation performance
///
/// To enable efficient decoding, the model should have inputs and outputs for
/// the [key-value
/// cache](https://peterchng.com/blog/2024/06/11/what-is-the-transformer-kv-cache/).
/// The generator will work with models that do not have cache inputs, but
/// decoding of long output sequences will be much slower.
pub struct Generator<'a> {
    /// The model that is run on each generation step.
    model: &'a dyn Model,

    /// Execution options passed to each model run.
    run_options: Option<RunOptions>,

    /// Additional constant model inputs (eg. encoder outputs) passed to the
    /// model at each step.
    constant_inputs: Vec<(NodeId, ValueOrView<'a>)>,

    /// Additional model inputs computed using constant propagation. This
    /// effectively caches parts of the graph that don't change in each
    /// generation step. This is `None` if the cache is out of date.
    constant_prop_inputs: Option<Vec<(NodeId, Value)>>,

    /// Additional varying model inputs computed and passed to the model at
    /// each step. The functions receive `(batch_size, sequence_positions)` as inputs.
    #[allow(clippy::type_complexity)]
    varying_inputs: Vec<(NodeId, &'a dyn Fn(usize, Range<usize>) -> ValueOrView<'a>)>,

    /// Input token IDs for the next run of the model.
    input_ids: Vec<TokenId>,

    /// Position ID associated with the first token in `input_ids`.
    input_offset: usize,

    /// Input node IDs
    input_ids_input: NodeId,

    /// Output node IDs
    logits_output: NodeId,

    /// Filter used to modify logits before sampling.
    logits_filter: Option<Box<dyn LogitsFilter + 'a>>,

    /// Sampler used to get the next token ID from the output logits.
    sampler: Box<dyn Sampler + 'a>,

    /// Previously sampled tokens. These are retained for conditional filtering
    /// and sampling.
    ///
    /// These are `TokenId` values; includes the prompt.
    prev_tokens: Vec<u32>,

    /// Self-attention key-value cache. This is extended on each iteration.
    kv_cache: Vec<KvCache>,

    /// Cross-attention key-value cache.
    ///
    /// This is used by encoder-decoder models. The cross-attention values
    /// are computed on the first run and reused in subsequent runs.
    encoder_kv_cache: Vec<KvCache>,
}
448
449impl<'a> Generator<'a> {
450    /// Create a generator that iteratively produces tokens using a model.
451    ///
452    /// This function assumes default names for model inputs and outputs
453    /// based on the conventions of Hugging Face's Optimum exporter. These
454    /// can be customized using [`from_model_config`](Self::from_model_config).
455    ///
456    /// The model must have the required inputs:
457    ///
458    ///  - `input_ids` - (batch, sequence) tensor of token IDs
459    ///
460    /// The model may have the optional inputs:
461    ///
462    ///  - `attention_mask` - (batch, sequence) tensor of booleans
463    ///  - `cache_position` - (sequence) tensor of KV-cache positions. Usually the same as `position_ids`
464    ///  - `position_ids` - (batch, sequence) tensor of position indices
465    ///  - `past_key_values.N.key` - (batch, head, past_seq_len, size) key vector cache
466    ///    where `N` is the layer index
467    ///  - `past_key_values.N.value` - (batch, head, past_key_values, size) value vector cache,
468    ///    where `N` is the layer index
469    ///
470    /// **Warning:** Generation of long sequences will be much slower in models without
471    /// key-value caches.
472    ///
473    /// The model must have the outputs:
474    ///
475    ///  - `logits` - output (batch, sequence, vocab) tensor of next token probabilities
476    ///
477    /// The model may have the optional outputs:
478    ///
479    ///  - `present.N.key` - (batch, head, past_seq_len + 1, size) updated key vector cache
480    ///  - `present.N.value` - (batch, head, past_seq_len + 1, size) updated value vector cache
481    pub fn from_model(model: &'a dyn Model) -> Result<Generator<'a>, GeneratorError> {
482        let config = GeneratorConfig {
483            model_inputs: ModelInputsConfig::default(),
484            kv_cache_capacity: None,
485        };
486        Self::from_model_config(model, config)
487    }
488
489    /// Create a generator that iteratively produces tokens using a model.
490    ///
491    /// This is a variant of [`from_model`](Self::from_model) that allows
492    /// specifying custom names for model inputs.
493    pub fn from_model_config(
494        model: &'a dyn Model,
495        config: GeneratorConfig,
496    ) -> Result<Generator<'a>, GeneratorError> {
497        let model_inputs = &config.model_inputs;
498
499        let input_ids_input =
500            model
501                .find_node(model_inputs.input_ids)
502                .ok_or(GeneratorError::InputNotFound(
503                    model_inputs.input_ids.to_string(),
504                ))?;
505
506        let logits_output =
507            model
508                .find_node(model_inputs.logits)
509                .ok_or(GeneratorError::OutputNotFound(
510                    model_inputs.logits.to_string(),
511                ))?;
512
513        // Find inputs and corresponding outputs for key-value cache.
514        let batch_size = 1;
515        let mut kv_cache = Vec::new();
516        let mut encoder_kv_cache = Vec::new();
517        for &input_id in model.input_ids() {
518            let input_info = model
519                .node_info(input_id)
520                .ok_or(GeneratorError::InputNotFound(format!(
521                    "input ID {}",
522                    input_id
523                )))?;
524
525            let name = input_info.name();
526
527            let Some(kv_pattern) = model_inputs
528                .kv_caches
529                .iter()
530                .find(|pat| name.starts_with(pat.input.prefix) && name.ends_with(pat.input.suffix))
531            else {
532                continue;
533            };
534
535            let (n_heads, size) = match *input_info.shape() {
536                [_, Dimension::Fixed(n_heads), _, Dimension::Fixed(size)] => (Some(n_heads), size),
537                [_, _, Dimension::Fixed(size)] => (None, size),
538                _ => {
539                    return Err(GeneratorError::ShapeMismatch(format!(
540                        "input \"{}\" has unexpected shape. expected (batch, past_seq_len, chans) or (batch, heads, past_seq_len, chans) where `heads` and `size` are fixed",
541                        name
542                    )));
543                }
544            };
545
546            let prefix = kv_pattern.input.prefix;
547
548            let layer_index_start = prefix.len();
549            let layer_index_end = name.len() - kv_pattern.input.suffix.len();
550            let layer_index_str = &name[layer_index_start..layer_index_end];
551            let Ok(layer_index) = layer_index_str.parse::<u32>() else {
552                continue;
553            };
554
555            let output_prefix = kv_pattern.output.prefix;
556            let output_suffix = kv_pattern.output.suffix;
557
558            let output_name = format!("{}{}{}", output_prefix, layer_index, output_suffix);
559            let output_id = model
560                .find_node(&output_name)
561                .ok_or(GeneratorError::OutputNotFound(output_name))?;
562
563            // Initial sequence length capacity for KV cache buffer.
564            //
565            // For models that execute different operations on the first vs
566            // subsequent iterations (eg. Hugging Face "merged" models with
567            // past and no-past branches) the input buffer may not be used in
568            // the first iteration. Instead we need to reserve capacity once
569            // the model returns the initial KV cache.
570            //
571            // For other simpler models the input KV cache buffer is used for
572            // all iterations, in which case we would ideally reserve capacity
573            // up-front based on the max expected sequence length.
574            let max_seq_len = config.kv_cache_capacity.unwrap_or(1);
575
576            let kv_cache_entry = KvCache {
577                input_id,
578                output_id,
579                cache: Some(KvCacheData::with_capacity(
580                    batch_size,
581                    n_heads,
582                    size,
583                    max_seq_len,
584                )),
585            };
586
587            if kv_pattern.encoder {
588                encoder_kv_cache.push(kv_cache_entry);
589            } else {
590                kv_cache.push(kv_cache_entry);
591            }
592        }
593
594        let mut generator = Generator {
595            model,
596            run_options: None,
597
598            constant_inputs: Vec::new(),
599            varying_inputs: Vec::new(),
600
601            // Constant propagation is performed as a graph optimization when
602            // the model is loaded, so we only need to re-do it if additional
603            // constant inputs are added.
604            constant_prop_inputs: Some(Vec::new()),
605
606            logits_filter: None,
607            input_ids: vec![],
608            input_ids_input,
609            input_offset: 0,
610            logits_output,
611            kv_cache,
612            encoder_kv_cache,
613            prev_tokens: Vec::new(),
614            sampler: Box::new(ArgMax::new()),
615        };
616
617        let attention_mask_input = model.find_node(model_inputs.attention_mask);
618        if let Some(attention_mask_input) = attention_mask_input {
619            generator = generator
620                .with_varying_input(attention_mask_input, &|batch_size, positions| {
621                    NdTensor::full([batch_size, positions.end], 1i32).into()
622                });
623        }
624
625        let position_ids_input = model.find_node(model_inputs.position_ids);
626        if let Some(position_ids_input) = position_ids_input {
627            generator =
628                generator.with_varying_input(position_ids_input, &|batch_size, positions| {
629                    NdTensor::from_fn([batch_size, positions.len()], |[_batch, pos]| {
630                        (positions.start + pos) as i32
631                    })
632                    .into()
633                });
634        }
635
636        let cache_position_input = model.find_node(model_inputs.cache_position);
637        if let Some(cache_position_input) = cache_position_input {
638            generator =
639                generator.with_varying_input(cache_position_input, &|_batch_size, positions| {
640                    NdTensor::from_fn([positions.len()], |[pos]| (positions.start + pos) as i32)
641                        .into()
642                });
643        }
644
645        let use_cache_input = model.find_node(model_inputs.use_cache_flag);
646        if let Some(use_cache_input) = use_cache_input {
647            generator = generator.with_varying_input(use_cache_input, &|_batch_size, positions| {
648                Tensor::from(if positions.start == 0 { 0i32 } else { 1 }).into()
649            });
650        }
651
652        Ok(generator)
653    }
654
655    /// Set the initial sequence of tokens (aka. the prompt) passed to the model
656    /// when it is first run.
657    ///
658    /// To add new inputs after the initial generation, use
659    /// [`append_prompt`](Self::append_prompt) instead.
660    pub fn with_prompt(mut self, prompt: &[TokenId]) -> Self {
661        self.input_ids = prompt.to_vec();
662        self
663    }
664
665    /// Add input tokens to be included in the next iteration of the model.
666    ///
667    /// This is useful in applications such as chat where the model's input
668    /// alternates between encoded user input and model-generated output.
669    pub fn append_prompt(&mut self, prompt: &[TokenId]) {
670        self.input_ids.extend(prompt);
671    }
672
673    /// Clear the pending prompt for the next generation.
674    ///
675    /// This includes tokens added using [`with_prompt`](Self::with_prompt) and
676    /// [`append_prompt`](Self::append_prompt) as well as the single token
677    /// sampled by the most recent call to [`Iterator::next`](Self::next).
678    ///
679    /// This does not affect the state resulting from tokens that have already
680    /// been generated. In other words, it does not "rewind" the conversation.
681    pub fn clear_prompt(&mut self) {
682        self.input_ids.clear();
683    }
684
685    /// Return the prompt that will be used for the next generation.
686    pub fn prompt(&self) -> &[TokenId] {
687        &self.input_ids
688    }
689
690    /// Return the tokens that have been generated so far, including the prompt.
691    pub fn prev_tokens(&self) -> &[TokenId] {
692        &self.prev_tokens
693    }
694
695    /// Return the current decoder KV-cache length.
696    ///
697    /// Returns `None` if the model does not use a KV cache.
698    pub fn kv_cache_len(&self) -> Option<usize> {
699        self.kv_cache.first()?.size()
700    }
701
702    /// Add a constant input which is provided to the model at each iteration.
703    ///
704    /// A common use case is to pass the outputs of an encoder model to
705    /// an auto-regressive decoder.
706    pub fn with_constant_input(mut self, input_id: NodeId, value: ValueView<'a>) -> Self {
707        self.constant_prop_inputs = None;
708        self.constant_inputs.push((input_id, value.into()));
709        self
710    }
711
712    /// Add an input which varies with the sequence position.
713    ///
714    /// `value_fn` receives `(batch_size, sequence_positions)` as input and
715    /// computes the value for the input at the given positions.
716    ///
717    /// A common use case is to pass position embeddings, if they are not
718    /// computed internally by the model.
719    pub fn with_varying_input<F: Fn(usize, Range<usize>) -> ValueOrView<'a>>(
720        mut self,
721        input_id: NodeId,
722        value_fn: &'a F,
723    ) -> Self {
724        self.varying_inputs.push((input_id, value_fn));
725        self
726    }
727
728    /// Set the filter used to process model output logits before passing them
729    /// to the sampler to select a token ID.
730    ///
731    /// To combine multiple filters, use a [`Chain`](crate::filter::Chain).
732    pub fn with_logits_filter<F: LogitsFilter + 'a>(mut self, filter: F) -> Self {
733        self.logits_filter = Some(Box::new(filter));
734        self
735    }
736
737    /// Set the sampler used to sample the next token ID from the output logits.
738    ///
739    /// The default sampler picks the token with the highest probability
740    /// (aka. greedy sampling). The most common alternative is
741    /// [`Multinomial`](crate::sampler::Multinomial) which samples tokens
742    /// according to their respective probabilities.
743    pub fn with_sampler<S: Sampler + 'a>(mut self, sampler: S) -> Self {
744        self.sampler = Box::new(sampler);
745        self
746    }
747
748    /// Set execution options used when running model inference.
749    pub fn with_run_options(mut self, opts: Option<RunOptions>) -> Self {
750        self.run_options = opts;
751        self
752    }
753
    /// Feed the current prompt into the model and update the KV cache.
    ///
    /// If `generate_logits` is true, the model's logits output is computed and
    /// returned as a `(batch, sequence, vocab)` tensor.
    ///
    /// On success the pending tokens in `input_ids` are folded into the KV
    /// cache (when the model has one) and `input_offset` is advanced past
    /// them.
    fn generate_impl(
        &mut self,
        generate_logits: bool,
    ) -> Result<Option<NdTensor<f32, 3>>, GeneratorError> {
        // Only unbatched generation is supported; every input uses a leading
        // batch dimension of 1.
        let batch_size = 1;
        // Convert the pending prompt tokens into a `(batch, seq)` i32 tensor.
        let input_ids: NdTensor<i32, 2> = self
            .input_ids
            .iter()
            .map(|id| *id as i32)
            .collect::<Tensor<_>>()
            .into_shape([batch_size, self.input_ids.len()]);

        // Positions of the pending tokens within the overall sequence. Passed
        // to the `varying_inputs` callbacks below.
        let input_positions = self.input_offset..self.input_offset + self.input_ids.len();

        let mut model_inputs: Vec<(NodeId, ValueOrView)> =
            vec![(self.input_ids_input, input_ids.view().into())];

        // Propagate constants on the first run.
        if self.constant_prop_inputs.is_none() {
            let inputs = match self.model.partial_run(
                self.constant_inputs.clone(),
                &[self.logits_output],
                self.run_options.clone(),
            ) {
                Ok(inputs) => inputs,
                Err(err) => {
                    return Err(wrap_error(
                        err,
                        "failed to partially evaluate model with constant inputs",
                    ));
                }
            };
            self.constant_prop_inputs = Some(inputs);
        }

        // Add the values pre-computed from the constant inputs above.
        if let Some(constants) = self.constant_prop_inputs.as_ref() {
            model_inputs.extend(
                constants
                    .iter()
                    .map(|(node_id, output)| (*node_id, output.as_view().into())),
            );
        }

        // Add inputs whose values depend on the positions of the current
        // tokens.
        if !self.varying_inputs.is_empty() {
            model_inputs.extend(self.varying_inputs.iter().map(|(node_id, value_fn)| {
                (*node_id, value_fn(batch_size, input_positions.clone()))
            }));
        }

        // Add key-value cache from previous run. The model takes ownership
        // of the KV-cache tensor during the run so it can efficiently append
        // the entry for the current step, without copying the existing buffer.
        for entry in self.kv_cache.iter_mut() {
            // `take` leaves `None` in the entry. It is repopulated from the
            // model outputs after the run below.
            let cache = entry.cache.take();
            match cache {
                Some(KvCacheData::BatchSeqChans(cache)) => {
                    model_inputs.push((entry.input_id, cache.into()));
                }
                Some(KvCacheData::BatchHeadSeqChans(cache)) => {
                    model_inputs.push((entry.input_id, cache.into()));
                }
                None => {}
            }
        }

        // Add cross-attention key-value cache. Unlike the self-attention
        // cache above, this is only borrowed (`iter`, no `take`) so the same
        // value persists across runs.
        for entry in self.encoder_kv_cache.iter() {
            match &entry.cache {
                Some(KvCacheData::BatchSeqChans(cache)) => {
                    model_inputs.push((entry.input_id, cache.into()));
                }
                Some(KvCacheData::BatchHeadSeqChans(cache)) => {
                    model_inputs.push((entry.input_id, cache.into()));
                }
                None => {}
            }
        }

        // Run the model and collect updated KV cache and logits.
        //
        // Outputs are requested in the same order that they are unpacked
        // below: self-attention caches, then cross-attention caches, then
        // (optionally) the logits.
        let mut model_outputs: Vec<NodeId> = self
            .kv_cache
            .iter()
            .map(|entry| entry.output_id)
            .chain(self.encoder_kv_cache.iter().map(|entry| entry.output_id))
            .collect();

        if generate_logits {
            model_outputs.push(self.logits_output);
        }

        let mut outputs = self
            .model
            .run(model_inputs, &model_outputs, self.run_options.clone())
            .map_err(|e| wrap_error(e, "failed to run model"))?;

        // Update the self-attention key-value cache.
        //
        // The KV cache tensors returned from the model should be the same as
        // the passed in tensors, but extended by one element along the sequence
        // axis.
        for cache_entry in self.kv_cache.iter_mut() {
            let output = outputs.remove(0);

            let err_context = "failed to save self-attention KV-cache";
            // Cache tensors may be 3D `(batch, seq, chans)` or 4D
            // `(batch, heads, seq, chans)` depending on the model.
            let mut kv_cache = match output.ndim() {
                3 => KvCacheData::BatchSeqChans(
                    output.try_into().map_err(|e| wrap_error(e, err_context))?,
                ),
                4 => KvCacheData::BatchHeadSeqChans(
                    output.try_into().map_err(|e| wrap_error(e, err_context))?,
                ),
                ndim => {
                    return Err(wrap_error(
                        format!("KV cache has {} dims, expected 3 or 4", ndim),
                        err_context,
                    ));
                }
            };

            // Grow the KV cache buffer if it has reached the limit of its
            // pre-allocated sequence length.
            //
            // Double the capacity each time to amortize the costs of copying
            // the previous buffer.
            if !kv_cache.has_capacity(kv_cache.sequence_len() + 1) {
                kv_cache = kv_cache.clone_with_capacity(kv_cache.sequence_len() * 2);
            }

            cache_entry.cache = Some(kv_cache);
        }

        // Update the cross-attention key-value cache.
        for cache_entry in self.encoder_kv_cache.iter_mut() {
            let output = outputs.remove(0);
            if output.is_empty() {
                // Optimum-exported models only return encoder KV-cache tensors
                // on the first run and dummy empty tensors on subsequent runs.
                // Ignore these and continue to use the value from the first run.
                continue;
            }

            let err_context = "failed to save cross-attention KV-cache";
            let kv_cache = match output.ndim() {
                3 => KvCacheData::BatchSeqChans(
                    output.try_into().map_err(|e| wrap_error(e, err_context))?,
                ),
                4 => KvCacheData::BatchHeadSeqChans(
                    output.try_into().map_err(|e| wrap_error(e, err_context))?,
                ),
                ndim => {
                    return Err(wrap_error(
                        format!("KV cache has {} dims, expected 3 or 4", ndim),
                        err_context,
                    ));
                }
            };
            cache_entry.cache = Some(kv_cache);
        }

        // Save prompt for use in logit filters.
        if self.prev_tokens.is_empty() {
            self.prev_tokens.extend(self.input_ids.iter());
        }

        // Clear the prompt for the next generation.
        //
        // Without a KV cache the model has no memory of previous steps, so
        // the accumulated prompt is kept and re-fed on every run instead.
        if !self.kv_cache.is_empty() {
            self.input_offset += self.input_ids.len();
            self.input_ids.clear();
        }

        if generate_logits {
            // Apply filtering to model outputs.
            let logits: NdTensor<f32, 3> = outputs
                .remove(0)
                .try_into()
                .map_err(|e| wrap_error(e, "failed to extract logits from model outputs"))?;
            Ok(Some(logits))
        } else {
            Ok(None)
        }
    }
939
940    /// Run the model and update the KV cache.
941    ///
942    /// Unlike calling [`next`](Self::next) this does not generate the logits
943    /// output and sample a token.
944    pub fn process_prompt(&mut self) -> Result<(), GeneratorError> {
945        self.generate_impl(false).map(|_| ())
946    }
947
948    /// Run the model and generate the next token.
949    ///
950    /// The generated token is automatically added to the prompt for the next
951    /// generation.
952    fn generate_next_token(&mut self) -> Result<TokenId, GeneratorError> {
953        let logits = self.generate_impl(true)?.expect("should have logits");
954        let last_logits = Logits::dense(logits.slice((0, -1)).to_contiguous().to_vec());
955        let filtered_logits = if let Some(filter) = self.logits_filter.as_ref() {
956            filter.filter(last_logits, &self.prev_tokens)
957        } else {
958            last_logits
959        };
960
961        // If filtering removed all the tokens, we have nothing to sample from.
962        if filtered_logits.is_empty() {
963            return Err(GeneratorError::GenerateError(
964                "filtered logits are empty".into(),
965            ));
966        }
967
968        // Sample output token.
969        let next_id = self.sampler.sample(&filtered_logits);
970
971        // Append token to prompt for next generation.
972        self.prev_tokens.push(next_id);
973        self.input_ids.push(next_id);
974
975        Ok(next_id)
976    }
977}
978
979fn wrap_error<E>(error: E, context: &str) -> GeneratorError
980where
981    E: Into<Box<dyn Error>>,
982{
983    let error_ctx = ErrorContext {
984        error: error.into(),
985        context: context.to_string(),
986    };
987    GeneratorError::GenerateError(error_ctx.into())
988}
989
/// Output items from a [`Generator`].
///
/// Each item is either the next generated [`TokenId`] or the error from the
/// step that tried to produce it.
pub type GeneratorItem = Result<TokenId, GeneratorError>;
992
993impl Iterator for Generator<'_> {
994    type Item = Result<TokenId, GeneratorError>;
995
996    /// Run the model and generate the next output token.
997    ///
998    /// The generated token is added to the prompt. This enables calling `next`
999    /// repeatedly to generate a sequence of tokens. The prompt can be extended
1000    /// or cleared using [`append_prompt`](Self::append_prompt) or
1001    /// [`clear_prompt`](Self::clear_prompt) respectively.
1002    fn next(&mut self) -> Option<Self::Item> {
1003        Some(self.generate_next_token())
1004    }
1005}
1006
1007/// Iterator utilities that wrap a [`Generator`] to perform common tasks such
1008/// as stopping generation when an end-of-text token is encountered.
1009pub trait GeneratorUtils: Iterator<Item = GeneratorItem> + Sized {
1010    /// Stop the generator when any token in `eos_tokens` is encountered.
1011    fn stop_on_tokens<A: AsRef<[u32]>>(self, eos_tokens: A) -> impl Iterator<Item = GeneratorItem> {
1012        self.take_while(move |tok| match tok {
1013            Ok(tok_id) => !eos_tokens.as_ref().contains(tok_id),
1014            _ => true,
1015        })
1016    }
1017
1018    /// Decode the tokens to text using a tokenizer.
1019    ///
1020    /// To get both the decoded text and token IDs, call
1021    /// [`with_ids`](TextDecoder::with_ids) on the result.
1022    #[cfg(feature = "text-decoder")]
1023    fn decode(self, tokenizer: &Tokenizer) -> TextDecoder<'_, Self> {
1024        TextDecoder::wrap(self, tokenizer)
1025    }
1026
1027    /// Record timing metrics.
1028    ///
1029    /// Metrics such as the number of tokens generated per second will be
1030    /// available from `metrics` after generation has finished.
1031    fn profile(self, metrics: &mut Metrics) -> impl Iterator<Item = Self::Item> {
1032        Profiler::wrap(self, metrics)
1033    }
1034}
1035
1036impl<I: Iterator<Item = GeneratorItem>> GeneratorUtils for I {}
1037
1038/// Wraps a [`Generator`] to record timing metrics into a [`Metrics`] struct.
1039struct Profiler<'a, G: Iterator> {
1040    generator: G,
1041    metrics: &'a mut Metrics,
1042}
1043
1044impl<'a, G: Iterator> Profiler<'a, G> {
1045    fn wrap(generator: G, metrics: &'a mut Metrics) -> Profiler<'a, G> {
1046        Profiler { generator, metrics }
1047    }
1048}
1049
1050impl<G: Iterator> Iterator for Profiler<'_, G> {
1051    type Item = G::Item;
1052
1053    fn next(&mut self) -> Option<Self::Item> {
1054        let start = std::time::Instant::now();
1055        let item = self.generator.next()?;
1056        self.metrics.add_step_duration(start.elapsed());
1057        Some(item)
1058    }
1059}
1060
1061#[cfg(test)]
1062mod tests {
1063    use std::cell::{Cell, RefCell};
1064    use std::collections::HashMap;
1065    use std::error::Error;
1066    use std::rc::Rc;
1067
1068    use rten::{Dimension, NodeId, RunOptions, Value, ValueOrView};
1069    use rten_tensor::NdTensor;
1070    use rten_tensor::prelude::*;
1071
1072    use super::{Generator, GeneratorUtils, Logits};
1073    use crate::filter::LogitsFilter;
1074    use crate::metrics::Metrics;
1075    use crate::model::{Model, NodeInfo};
1076
    /// Stub [`Model`] implementation which replays pre-recorded outputs and
    /// captures the inputs passed to each inference step.
    struct FakeModel {
        nodes: Vec<NodeInfo>,
        input_ids: Vec<NodeId>,
        output_ids: Vec<NodeId>,

        // Next inference step
        step: Cell<usize>,

        // Inference outputs for each step
        outputs: Vec<HashMap<NodeId, Value>>,

        // Inference inputs for each step
        inputs: RefCell<Vec<HashMap<NodeId, Value>>>,

        // Run options for most recent inference
        run_opts: Cell<Option<RunOptions>>,
    }
1094
1095    impl FakeModel {
1096        /// Return a model with a given set of inputs and outputs.
1097        fn with_inputs_and_outputs(inputs: &[NodeInfo], outputs: &[NodeInfo]) -> FakeModel {
1098            let node_infos = [inputs, outputs].concat();
1099            let input_ids = (0..inputs.len())
1100                .map(|id| NodeId::from_u32(id as u32))
1101                .collect();
1102            let output_ids = (inputs.len()..(inputs.len() + outputs.len()))
1103                .map(|id| NodeId::from_u32(id as u32))
1104                .collect();
1105
1106            FakeModel {
1107                input_ids,
1108                output_ids,
1109                nodes: node_infos,
1110                step: Cell::new(0),
1111                inputs: RefCell::new(vec![]),
1112                outputs: vec![],
1113                run_opts: Cell::new(None),
1114            }
1115        }
1116
1117        /// Add inference outputs for one run of the model.
1118        fn add_outputs(&mut self, outputs: HashMap<NodeId, Value>) {
1119            self.outputs.push(outputs)
1120        }
1121
1122        /// Get an input for the `step`th run of the model.
1123        fn get_inputs(&self, step: usize, node_id: NodeId) -> Option<Value> {
1124            self.inputs
1125                .borrow()
1126                .get(step)
1127                .map(|step_inputs| step_inputs.get(&node_id))
1128                .flatten()
1129                .cloned()
1130        }
1131    }
1132
    impl Model for FakeModel {
        fn find_node(&self, name: &str) -> Option<NodeId> {
            // A node's ID is simply its position in `self.nodes`.
            self.nodes
                .iter()
                .position(|info| info.name() == name)
                .map(|pos| NodeId::from_u32(pos as u32))
        }

        fn node_info(&self, id: NodeId) -> Option<NodeInfo> {
            self.nodes.get(id.as_usize()).cloned()
        }

        fn input_ids(&self) -> &[NodeId] {
            &self.input_ids
        }

        /// "Run" the model: validate the requested inputs and outputs, record
        /// the inputs for later inspection by tests, and replay the outputs
        /// registered for the current step via [`FakeModel::add_outputs`].
        fn run(
            &self,
            inputs: Vec<(NodeId, ValueOrView)>,
            outputs: &[NodeId],
            opts: Option<RunOptions>,
        ) -> Result<Vec<Value>, Box<dyn Error>> {
            // Reject unknown inputs and require every declared input to be
            // provided, mirroring a strict real model run.
            if let Some((input_id, _)) = inputs.iter().find(|(id, _)| !self.input_ids.contains(id))
            {
                return Err(format!("invalid input ID {}", input_id).into());
            }
            for &expected_input in self.input_ids.iter() {
                if !inputs.iter().any(|&(id, _)| id == expected_input) {
                    return Err(format!("missing input ID {}", expected_input).into());
                }
            }

            if let Some(output_id) = outputs.iter().find(|id| !self.output_ids.contains(id)) {
                return Err(format!("invalid output ID {}", output_id).into());
            }

            // Record the inputs so tests can assert on what was passed in at
            // each step.
            self.inputs.borrow_mut().push(
                inputs
                    .into_iter()
                    .map(|(id, input_or_output)| (id, input_or_output.to_owned()))
                    .collect(),
            );

            // Replay the canned outputs registered for this step. Panics if
            // no outputs were registered, as that is a test setup bug.
            let result = outputs
                .iter()
                .map(|id| {
                    let step_outputs = self
                        .outputs
                        .get(self.step.get())
                        .expect("outputs not specified for step");

                    step_outputs
                        .get(id)
                        .cloned()
                        .expect("invalid output node ID")
                })
                .collect();

            self.step.set(self.step.get() + 1);
            self.run_opts.set(opts);

            Ok(result)
        }

        /// Constant propagation is a no-op for the fake model.
        fn partial_run(
            &self,
            _inputs: Vec<(NodeId, ValueOrView)>,
            _outputs: &[NodeId],
            _opts: Option<RunOptions>,
        ) -> Result<Vec<(NodeId, Value)>, Box<dyn Error>> {
            Ok(Vec::new())
        }
    }
1206
1207    /// Generate `[batch, sequence, n_vocab]` tensor for `logits` output.
1208    fn generate_logits(n_vocab: usize, token_ids: &[u32]) -> NdTensor<f32, 3> {
1209        let mut logits = NdTensor::zeros([1, token_ids.len(), n_vocab]);
1210        for (idx, id) in token_ids.iter().copied().enumerate() {
1211            logits[[0, idx, id as usize]] = 1.0;
1212        }
1213        logits
1214    }
1215
    /// Hyperparameters describing the shape of a fake transformer model.
    #[derive(Copy, Clone, PartialEq)]
    struct TransformerParams {
        /// Number of layers. This determines the number of KV-cache inputs
        /// and outputs.
        n_layers: usize,
        /// Number of attention heads (second dimension of KV-cache tensors).
        n_heads: usize,
        /// Size of the last dimension of each KV-cache tensor.
        n_embed: usize,

        /// Vocabulary size. This is the size of the last dimension of the
        /// logits output.
        n_vocab: usize,
    }
1228
1229    impl Default for TransformerParams {
1230        fn default() -> Self {
1231            Self {
1232                n_layers: 5,
1233                n_heads: 3,
1234                n_vocab: 5,
1235                n_embed: 8,
1236            }
1237        }
1238    }
1239
    /// Which kinds of KV-cache inputs and outputs the fake model exposes.
    #[derive(Copy, Clone, PartialEq)]
    enum KvCacheType {
        /// Add KV-cache inputs and outputs for self-attention.
        Decoder,
        /// Add KV-cache inputs and outputs for self-attention and cross-
        /// attention.
        EncoderDecoder,
    }
1248
    /// Create a fake transformer model using the default names for inputs and
    /// outputs.
    ///
    /// `output_token_ids` specifies, for each generation step, the token ID
    /// that receives the highest score in the `logits` output. `prompt_len`
    /// is used to size the KV-cache outputs for the first step.
    fn fake_transformer_model(
        params: TransformerParams,
        kv_cache: Option<KvCacheType>,
        prompt_len: usize,
        output_token_ids: &[u32],
    ) -> FakeModel {
        let TransformerParams {
            n_layers,
            n_heads,
            n_vocab,
            n_embed,
        } = params;

        // Add inputs and outputs using the standard names.
        let mut inputs = vec![
            NodeInfo::from_name_shape("input_ids", &[]),
            NodeInfo::from_name_shape("cache_position", &[]),
            NodeInfo::from_name_shape("position_ids", &[]),
            NodeInfo::from_name_shape("attention_mask", &[]),
        ];
        let mut outputs = vec![NodeInfo::from_name_shape("logits", &[])];

        // Add KV-cache inputs and outputs.
        let mut kv_cache_output_names = Vec::new();
        if let Some(kv_cache_type) = kv_cache {
            let dims = [
                Dimension::Symbolic("batch".to_string()),
                Dimension::Fixed(n_heads as usize),
                Dimension::Symbolic("seq".to_string()),
                Dimension::Fixed(n_embed),
            ];
            let make_name_info = |name: &str| NodeInfo::from_name_shape(name, &dims);

            // Use the conventional `past_key_values.*` / `present.*` names
            // for the cache inputs and outputs of each layer.
            for layer in 0..n_layers {
                let past_names: Vec<String>;
                let present_names: Vec<String>;

                match kv_cache_type {
                    KvCacheType::Decoder => {
                        past_names = [
                            format!("past_key_values.{}.key", layer),
                            format!("past_key_values.{}.value", layer),
                        ]
                        .into();
                        present_names = [
                            format!("present.{}.key", layer),
                            format!("present.{}.value", layer),
                        ]
                        .into();
                    }
                    KvCacheType::EncoderDecoder => {
                        past_names = [
                            format!("past_key_values.{}.decoder.key", layer),
                            format!("past_key_values.{}.decoder.value", layer),
                            format!("past_key_values.{}.encoder.key", layer),
                            format!("past_key_values.{}.encoder.value", layer),
                        ]
                        .into();

                        present_names = [
                            format!("present.{}.decoder.key", layer),
                            format!("present.{}.decoder.value", layer),
                            format!("present.{}.encoder.key", layer),
                            format!("present.{}.encoder.value", layer),
                        ]
                        .into();
                    }
                }

                inputs.extend(past_names.iter().map(|name| make_name_info(&name)));
                outputs.extend(present_names.iter().map(|name| make_name_info(&name)));
                kv_cache_output_names.extend(present_names);
            }

            if kv_cache_type == KvCacheType::EncoderDecoder {
                inputs.push(NodeInfo::from_name_shape("use_cache_branch", &[]));
            }
        }

        let mut model = FakeModel::with_inputs_and_outputs(&inputs, &outputs);
        let logits_id = model.find_node("logits").unwrap();

        // Register the canned outputs for each generation step.
        for (step, output_token_id) in output_token_ids.iter().copied().enumerate() {
            assert!(
                output_token_id < n_vocab as u32,
                "token ID is invalid for vocab size"
            );

            // With a KV cache only the latest token's logits are produced;
            // without one the logits cover the whole sequence so far.
            let logits = if kv_cache.is_some() {
                generate_logits(n_vocab, &[output_token_id])
            } else {
                generate_logits(n_vocab, &output_token_ids[..=step])
            };

            let mut outputs = HashMap::new();
            outputs.insert(logits_id, Value::FloatTensor(logits.into()));

            // Add KV cache outputs
            for kv_output in kv_cache_output_names.iter() {
                let kv_output_id = model.find_node(&kv_output).unwrap();
                // Sequence length of the cache tensor returned for this step.
                let context_len = if step == 0 {
                    prompt_len
                } else {
                    prompt_len + step - 1
                };

                let is_encoder = model
                    .node_info(kv_output_id)
                    .as_ref()
                    .map(|ni| ni.name())
                    .unwrap_or("")
                    .contains("encoder");

                let output_n_embed = if is_encoder && step > 0 {
                    // Encoder KV cache outputs are only used on the first run.
                    // On subsequent runs return a dummy output, which should
                    // be ignored.
                    0
                } else {
                    n_embed
                };

                outputs.insert(
                    kv_output_id,
                    Value::FloatTensor(
                        NdTensor::zeros([1, n_heads, context_len, output_n_embed]).into(),
                    ),
                );
            }

            model.add_outputs(outputs);
        }

        model
    }
1386
    /// Drive a fake model for several generation steps and check both the
    /// generated tokens and the inputs passed to the model at every step.
    fn test_generator_impl(kv_cache_type: Option<KvCacheType>) -> Result<(), Box<dyn Error>> {
        let params = TransformerParams::default();
        let expected_token_ids = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 0, 0];
        let prompt = [1, 2, 3, 1, 2, 3];
        let model =
            fake_transformer_model(params, kv_cache_type, prompt.len(), &expected_token_ids);

        let generator = Generator::from_model(&model)?;
        let generation_len = 10;

        let output_token_ids: Vec<_> = generator
            .with_prompt(&prompt)
            .take(generation_len)
            .map(|id| id.expect("generation failed"))
            .collect();

        // Check generator outputs
        assert_eq!(output_token_ids.len(), generation_len);
        assert_eq!(output_token_ids, &expected_token_ids[..generation_len]);

        // Check model inputs
        let input_id = model.find_node("input_ids").unwrap();
        let position_ids = model.find_node("position_ids").unwrap();
        let attention_mask = model.find_node("attention_mask").unwrap();
        let cache_branch = model.find_node("use_cache_branch");
        let cache_position = model.find_node("cache_position").unwrap();

        for step in 0..generation_len {
            let step_inputs = model.get_inputs(step, input_id).unwrap();
            let step_inputs: NdTensor<i32, 2> = step_inputs.try_into().unwrap();

            let step_pos_ids = model.get_inputs(step, position_ids).unwrap();
            let step_pos_ids: NdTensor<i32, 2> = step_pos_ids.try_into().unwrap();

            let step_cache_pos = model.get_inputs(step, cache_position).unwrap();
            let step_cache_pos: NdTensor<i32, 1> = step_cache_pos.try_into().unwrap();

            let step_attn_mask = model.get_inputs(step, attention_mask).unwrap();
            let step_attn_mask: NdTensor<i32, 2> = step_attn_mask.try_into().unwrap();

            // `use_cache_branch` is only present for encoder-decoder models.
            let cache_branch = cache_branch.map(|cb_id| {
                let cb = model.get_inputs(step, cb_id).unwrap();
                let cb: NdTensor<i32, 0> = cb.try_into().unwrap();
                cb
            });

            if step == 0 {
                // First step: the whole prompt is fed in, with positions
                // starting from zero.
                assert_eq!(step_inputs.size(1), prompt.len());
                assert!(
                    step_inputs
                        .iter()
                        .map(|x| *x as u32)
                        .eq(prompt.iter().copied())
                );

                assert_eq!(step_attn_mask.size(1), prompt.len());
                assert!(step_attn_mask.iter().all(|x| *x == 1));

                assert_eq!(step_pos_ids.size(1), prompt.len());
                assert!(step_pos_ids.iter().map(|x| *x as usize).eq(0..prompt.len()));

                assert_eq!(step_cache_pos.size(0), prompt.len());
                assert!(
                    step_cache_pos
                        .iter()
                        .map(|x| *x as usize)
                        .eq(0..prompt.len())
                );

                if let Some(cache_branch) = cache_branch {
                    assert_eq!(cache_branch.item(), Some(&0));
                }
            } else if kv_cache_type.is_some() {
                // Later steps with a KV cache: only the most recently sampled
                // token is fed in.
                assert_eq!(step_inputs.size(1), 1);
                assert_eq!(step_inputs[[0, 0]] as u32, expected_token_ids[step - 1]);

                assert_eq!(step_attn_mask.size(1), prompt.len() + step);
                assert_eq!(step_attn_mask[[0, 0]], 1);

                assert_eq!(step_pos_ids.size(1), 1);
                assert_eq!(step_pos_ids[[0, 0]], (prompt.len() + step - 1) as i32);

                assert_eq!(step_cache_pos.size(0), 1);
                assert_eq!(step_cache_pos[[0]], (prompt.len() + step - 1) as i32);

                if let Some(cache_branch) = cache_branch {
                    assert_eq!(cache_branch.item(), Some(&1));
                }
            } else {
                // Later steps without a KV cache: the entire sequence so far
                // (prompt plus previously generated tokens) is re-fed.
                let expected_inputs: Vec<i32> = prompt
                    .iter()
                    .copied()
                    .chain(expected_token_ids)
                    .take(prompt.len() + step)
                    .map(|x| x as i32)
                    .collect();
                assert_eq!(
                    step_inputs,
                    NdTensor::from_data([1, expected_inputs.len()], expected_inputs)
                );

                let expected_attn_mask = vec![1i32; prompt.len() + step];
                assert_eq!(
                    step_attn_mask,
                    NdTensor::from_data([1, expected_attn_mask.len()], expected_attn_mask)
                );

                let expected_pos_ids: Vec<i32> =
                    (0..prompt.len() + step).map(|x| x as i32).collect();
                assert_eq!(
                    step_pos_ids,
                    NdTensor::from_data([1, expected_pos_ids.len()], expected_pos_ids.clone())
                );
                assert_eq!(
                    step_cache_pos,
                    NdTensor::from_data([expected_pos_ids.len()], expected_pos_ids)
                );
            }
        }

        Ok(())
    }
1509
    /// Run the shared generation test with self-attention KV-cache I/O.
    #[test]
    fn test_generator_with_decoder_kv_cache() -> Result<(), Box<dyn Error>> {
        test_generator_impl(Some(KvCacheType::Decoder))
    }
1514
    /// Run the shared generation test with self- and cross-attention
    /// KV-cache I/O.
    #[test]
    fn test_generator_with_encoder_decoder_kv_cache() -> Result<(), Box<dyn Error>> {
        test_generator_impl(Some(KvCacheType::EncoderDecoder))
    }
1519
    /// Run the shared generation test with no KV-cache I/O at all.
    #[test]
    fn test_generator_without_kv_cache() -> Result<(), Box<dyn Error>> {
        test_generator_impl(None)
    }
1524
1525    #[test]
1526    fn test_generator_append_prompt() -> Result<(), Box<dyn Error>> {
1527        let mut params = TransformerParams::default();
1528        params.n_vocab = 110;
1529        let output_token_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8];
1530        let prompt = [99];
1531        let model = fake_transformer_model(
1532            params,
1533            Some(KvCacheType::Decoder),
1534            prompt.len(),
1535            &output_token_ids,
1536        );
1537
1538        let mut generator = Generator::from_model(&model)?.with_prompt(&prompt);
1539
1540        generator.next();
1541        generator.append_prompt(&[100]);
1542        generator.next();
1543        generator.append_prompt(&[101, 102]);
1544        generator.next();
1545
1546        let input_id = model.find_node("input_ids").unwrap();
1547
1548        // The input to the first step is just the prompt.
1549        let inputs = model.get_inputs(0, input_id).unwrap();
1550        let inputs: NdTensor<i32, 2> = inputs.try_into().unwrap();
1551        assert_eq!(inputs, NdTensor::from([[99]]));
1552
1553        // The inputs for the next steps are the output followed by the inputs
1554        // added with `append_prompt`.
1555        let inputs = model.get_inputs(1, input_id).unwrap();
1556        let inputs: NdTensor<i32, 2> = inputs.try_into().unwrap();
1557        assert_eq!(inputs, NdTensor::from([[0, 100]]));
1558
1559        let inputs = model.get_inputs(2, input_id).unwrap();
1560        let inputs: NdTensor<i32, 2> = inputs.try_into().unwrap();
1561        assert_eq!(inputs, NdTensor::from([[1, 101, 102]]));
1562
1563        Ok(())
1564    }
1565
1566    #[test]
1567    fn test_stop_on_tokens() -> Result<(), Box<dyn Error>> {
1568        let params = TransformerParams::default();
1569        let expected_token_ids = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 0, 0];
1570        let prompt = [1, 2, 3, 1, 2, 3];
1571        let model = fake_transformer_model(
1572            params,
1573            Some(KvCacheType::Decoder),
1574            prompt.len(),
1575            &expected_token_ids,
1576        );
1577
1578        let generator = Generator::from_model(&model)?;
1579
1580        let output_token_ids: Vec<_> = generator
1581            .with_prompt(&prompt)
1582            .stop_on_tokens([4])
1583            .map(|id| id.expect("generation failed"))
1584            .collect();
1585
1586        assert_eq!(output_token_ids, &[0, 1, 2, 3]);
1587
1588        Ok(())
1589    }
1590
1591    #[test]
1592    fn test_process_prompt() -> Result<(), Box<dyn Error>> {
1593        let params = TransformerParams::default();
1594        let expected_token_ids = [0];
1595        let prompt = [1, 2, 3, 1, 2, 3];
1596        let model = fake_transformer_model(
1597            params,
1598            Some(KvCacheType::Decoder),
1599            prompt.len(),
1600            &expected_token_ids,
1601        );
1602
1603        let mut generator = Generator::from_model(&model)?.with_prompt(&prompt);
1604        assert_eq!(generator.prompt(), prompt);
1605        assert!(generator.prev_tokens().is_empty());
1606        assert_eq!(generator.kv_cache_len(), Some(0));
1607
1608        generator.process_prompt().unwrap();
1609
1610        assert!(generator.prompt().is_empty());
1611        assert_eq!(generator.prev_tokens(), prompt);
1612        assert_eq!(generator.kv_cache_len(), Some(prompt.len()));
1613
1614        Ok(())
1615    }
1616
1617    #[test]
1618    fn test_profile() -> Result<(), Box<dyn Error>> {
1619        let params = TransformerParams::default();
1620        let expected_token_ids = [0, 1, 2, 3, 4];
1621        let prompt = [1, 2, 3, 1, 2, 3];
1622        let model = fake_transformer_model(
1623            params,
1624            Some(KvCacheType::Decoder),
1625            prompt.len(),
1626            &expected_token_ids,
1627        );
1628
1629        let generator = Generator::from_model(&model)?;
1630        let mut metrics = Metrics::new();
1631
1632        let output_token_ids: Vec<_> = generator
1633            .with_prompt(&prompt)
1634            .profile(&mut metrics)
1635            .take(expected_token_ids.len())
1636            .map(|id| id.expect("generation failed"))
1637            .collect();
1638
1639        assert_eq!(output_token_ids, expected_token_ids);
1640        assert!(metrics.warmup_duration().is_some());
1641        assert_eq!(metrics.step_durations().len(), output_token_ids.len() - 1);
1642
1643        Ok(())
1644    }
1645
1646    #[test]
1647    fn test_filter() -> Result<(), Box<dyn Error>> {
1648        let mut params = TransformerParams::default();
1649        params.n_vocab = 8; // Must be >2x the max token ID in `expected_token_ids`.
1650
1651        let expected_token_ids = [0, 1, 2, 3];
1652        let prompt = [5, 6, 7];
1653        let model = fake_transformer_model(
1654            params,
1655            Some(KvCacheType::Decoder),
1656            prompt.len(),
1657            &expected_token_ids,
1658        );
1659
1660        let generator = Generator::from_model(&model)?;
1661
1662        // Filter that modifies logits to double the selected token ID.
1663        struct DoubleIndexFilter {
1664            prev_tokens: Rc<RefCell<Vec<u32>>>,
1665        }
1666        impl LogitsFilter for DoubleIndexFilter {
1667            fn filter(&self, logits: Logits, prev_tokens: &[u32]) -> Logits {
1668                self.prev_tokens.replace(prev_tokens.to_vec());
1669
1670                let max_idx = logits
1671                    .logits()
1672                    .iter()
1673                    .zip(logits.indices())
1674                    .max_by(|(x, _i), (y, _j)| x.total_cmp(y))
1675                    .map(|(_x, i)| i)
1676                    .unwrap();
1677
1678                Logits::sparse(vec![1.0], vec![max_idx * 2])
1679            }
1680        }
1681
1682        let prev_tokens = Rc::new(RefCell::new(Vec::new()));
1683        let output_token_ids: Vec<_> = generator
1684            .with_prompt(&prompt)
1685            .with_logits_filter(DoubleIndexFilter {
1686                prev_tokens: prev_tokens.clone(),
1687            })
1688            .take(expected_token_ids.len())
1689            .map(|id| id.expect("generation failed"))
1690            .collect();
1691
1692        assert_eq!(output_token_ids, [0, 2, 4, 6]);
1693        assert_eq!(prev_tokens.borrow().as_slice(), [5, 6, 7, 0, 2, 4]);
1694
1695        Ok(())
1696    }
1697
1698    #[test]
1699    fn test_empty_filter_output() {
1700        let params = TransformerParams::default();
1701        let prompt = [1];
1702        let model = fake_transformer_model(
1703            params,
1704            Some(KvCacheType::Decoder),
1705            prompt.len(),
1706            &[0, 1, 2, 3],
1707        );
1708
1709        struct RemoveAllFilter;
1710        impl LogitsFilter for RemoveAllFilter {
1711            fn filter(&self, _logits: Logits, _prev_tokens: &[u32]) -> Logits {
1712                Logits::dense(vec![])
1713            }
1714        }
1715
1716        let mut generator = Generator::from_model(&model)
1717            .unwrap()
1718            .with_logits_filter(RemoveAllFilter);
1719        let err = generator.next().unwrap().err().unwrap();
1720        assert!(err.to_string().contains("filtered logits are empty"));
1721    }
1722
1723    #[test]
1724    fn test_run_options() -> Result<(), Box<dyn Error>> {
1725        let params = TransformerParams::default();
1726        let expected_token_ids = [0, 1, 2, 3, 4];
1727        let prompt = [1, 2, 3, 1, 2, 3];
1728        let model = fake_transformer_model(
1729            params,
1730            Some(KvCacheType::Decoder),
1731            prompt.len(),
1732            &expected_token_ids,
1733        );
1734
1735        let generator = Generator::from_model(&model)?;
1736
1737        let run_opts = RunOptions::default().with_verbose(true);
1738        let output_token_ids: Vec<_> = generator
1739            .with_prompt(&prompt)
1740            .with_run_options(Some(run_opts.clone()))
1741            .take(expected_token_ids.len())
1742            .map(|id| id.expect("generation failed"))
1743            .collect();
1744
1745        assert_eq!(output_token_ids, expected_token_ids);
1746        assert_eq!(model.run_opts.take(), Some(run_opts));
1747
1748        Ok(())
1749    }
1750}