1use std::error::Error;
4use std::fmt;
5use std::ops::Range;
6
7use rten::{Dimension, NodeId, RunOptions, Value, ValueOrView, ValueView};
8use rten_tensor::prelude::*;
9use rten_tensor::{NdTensor, Tensor};
10
11#[cfg(feature = "text-decoder")]
12use rten_text::{Tokenizer, TokenizerError};
13
14use crate::filter::LogitsFilter;
15use crate::logits::Logits;
16use crate::metrics::Metrics;
17use crate::model::Model;
18use crate::sampler::{ArgMax, Sampler};
19
20#[cfg(feature = "text-decoder")]
21use crate::text_decoder::TextDecoder;
22
/// Integer type used to represent token IDs in model inputs and outputs.
pub type TokenId = u32;
25
/// Errors that occur while creating or running a [`Generator`].
#[derive(Debug)]
pub enum GeneratorError {
    /// No model input node with the given name was found.
    InputNotFound(String),

    /// No model output node with the given name was found.
    OutputNotFound(String),

    /// A model input or output had an unexpected shape.
    ShapeMismatch(String),

    /// An error occurred while running the model or processing its outputs.
    GenerateError(Box<dyn Error>),

    /// An error occurred while decoding generated tokens into text.
    #[cfg(feature = "text-decoder")]
    DecodeError(TokenizerError),
}

impl fmt::Display for GeneratorError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::InputNotFound(name) => write!(f, "model input not found: {}", name),
            Self::OutputNotFound(name) => write!(f, "model output not found: {}", name),
            Self::ShapeMismatch(err) => write!(f, "shape mismatch: {}", err),
            Self::GenerateError(err) => write!(f, "generation error: {}", err),
            #[cfg(feature = "text-decoder")]
            Self::DecodeError(err) => write!(f, "decode error: {}", err),
        }
    }
}

impl Error for GeneratorError {}
60
/// Pairs an underlying error with a message describing the operation that
/// failed. Displays as `"{context}: {error}"`.
#[derive(Debug)]
struct ErrorContext {
    error: Box<dyn Error>,
    context: String,
}

impl fmt::Display for ErrorContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}: {}", self.context, self.error)
    }
}

impl Error for ErrorContext {
    // Expose the wrapped error so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        Some(self.error.as_ref())
    }
}
79
/// Tensor holding the data for one KV-cache entry, in one of the two
/// supported layouts.
enum KvCacheData {
    /// Cache with `(batch, seq, channels)` layout.
    BatchSeqChans(NdTensor<f32, 3>),
    /// Cache with `(batch, heads, seq, channels)` layout.
    BatchHeadSeqChans(NdTensor<f32, 4>),
}
89
90impl KvCacheData {
91 fn with_capacity(
97 batch_size: usize,
98 n_heads: Option<usize>,
99 size: usize,
100 seq_len_capacity: usize,
101 ) -> KvCacheData {
102 if let Some(n_heads) = n_heads {
103 KvCacheData::BatchHeadSeqChans(NdTensor::with_capacity(
104 [batch_size, n_heads, seq_len_capacity, size],
105 2, ))
107 } else {
108 KvCacheData::BatchSeqChans(NdTensor::with_capacity(
109 [batch_size, seq_len_capacity, size],
110 1, ))
112 }
113 }
114
115 fn sequence_len(&self) -> usize {
117 match self {
118 KvCacheData::BatchSeqChans(data) => data.size(1),
119 KvCacheData::BatchHeadSeqChans(data) => data.size(2),
120 }
121 }
122
123 fn has_capacity(&self, sequence_len: usize) -> bool {
125 match self {
126 KvCacheData::BatchSeqChans(data) => {
127 data.has_capacity(1 , sequence_len)
128 }
129 KvCacheData::BatchHeadSeqChans(data) => {
130 data.has_capacity(2 , sequence_len)
131 }
132 }
133 }
134
135 fn clone_with_capacity(&self, max_sequence_len: usize) -> KvCacheData {
138 let max_sequence_len = max_sequence_len.max(self.sequence_len());
139 match self {
140 KvCacheData::BatchSeqChans(data) => {
141 let [batch, _seq, chans] = data.shape();
142 let mut new_data =
143 NdTensor::with_capacity([batch, max_sequence_len, chans], 1 );
144 new_data.append(1, data).expect("should have capacity");
145 KvCacheData::BatchSeqChans(new_data)
146 }
147 KvCacheData::BatchHeadSeqChans(data) => {
148 let [batch, n_heads, _seq, chans] = data.shape();
149 let mut new_data = NdTensor::with_capacity(
150 [batch, n_heads, max_sequence_len, chans],
151 2, );
153 new_data.append(2, data).expect("should have capacity");
154 KvCacheData::BatchHeadSeqChans(new_data)
155 }
156 }
157 }
158}
159
/// A single KV-cache entry: the model nodes it flows through plus its data.
struct KvCache {
    /// Model input node that receives the cache from the previous step.
    input_id: NodeId,

    /// Model output node that produces the updated cache.
    output_id: NodeId,

    /// Cache contents. `None` while ownership has been moved into the model
    /// inputs during a run (see `Generator::generate_impl`).
    cache: Option<KvCacheData>,
}
172
173impl KvCache {
174 fn size(&self) -> Option<usize> {
175 self.cache.as_ref().map(|c| c.sequence_len())
176 }
177}
178
/// Name pattern that identifies a KV-cache input or output node.
///
/// A node name matches when it starts with `prefix` and ends with `suffix`;
/// the text in between is expected to be the layer index.
pub struct KVCachePattern<'a> {
    /// Required prefix of the node name.
    pub prefix: &'a str,
    /// Required suffix of the node name.
    pub suffix: &'a str,
}

impl<'a> From<(&'a str, &'a str)> for KVCachePattern<'a> {
    /// Build a pattern from a `(prefix, suffix)` pair.
    fn from((prefix, suffix): (&'a str, &'a str)) -> Self {
        KVCachePattern { prefix, suffix }
    }
}
195
/// A matched pair of patterns for a KV-cache model input and its
/// corresponding output.
pub struct KVCachePair<'a> {
    /// Pattern for the cache input name (eg. `past_key_values.{layer}.key`).
    pub input: KVCachePattern<'a>,

    /// Pattern for the corresponding output name (eg. `present.{layer}.key`).
    pub output: KVCachePattern<'a>,

    /// True if this entry belongs to the cross-attention ("encoder") KV
    /// cache, which is handled separately from the self-attention cache
    /// during generation.
    pub encoder: bool,
}
212
/// Names of the model input and output nodes that a [`Generator`] looks up.
pub struct ModelInputsConfig<'a> {
    /// Name of the input that receives the token IDs for the current step.
    pub input_ids: &'a str,

    /// Name of the output that produces the logits.
    pub logits: &'a str,

    /// Name of the attention-mask input, if the model has one.
    pub attention_mask: &'a str,

    /// Name of the cache-position input, if the model has one.
    pub cache_position: &'a str,

    /// Name patterns used to discover KV-cache input/output pairs.
    pub kv_caches: Vec<KVCachePair<'a>>,

    /// Name of the position-IDs input, if the model has one.
    pub position_ids: &'a str,

    /// Name of the flag input that tells the model whether to use the
    /// supplied KV cache, if the model has one.
    pub use_cache_flag: &'a str,
}
247
/// Configuration options for creating a [`Generator`].
#[derive(Default)]
pub struct GeneratorConfig<'a> {
    /// Names of model inputs and outputs to look up.
    pub model_inputs: ModelInputsConfig<'a>,

    /// Initial sequence-length capacity to reserve for KV-cache buffers.
    /// Defaults to 1 when unset; caches grow automatically during generation.
    pub kv_cache_capacity: Option<usize>,
}
266
267impl Default for ModelInputsConfig<'_> {
268 fn default() -> Self {
273 ModelInputsConfig {
274 input_ids: "input_ids",
275 logits: "logits",
276 attention_mask: "attention_mask",
277 cache_position: "cache_position",
278 position_ids: "position_ids",
279 use_cache_flag: "use_cache_branch",
280
281 kv_caches: [
284 KVCachePair {
288 input: ("past_key_values.", ".decoder.key").into(),
289 output: ("present.", ".decoder.key").into(),
290 encoder: false,
291 },
292 KVCachePair {
293 input: ("past_key_values.", ".decoder.value").into(),
294 output: ("present.", ".decoder.value").into(),
295 encoder: false,
296 },
297 KVCachePair {
298 input: ("past_key_values.", ".encoder.key").into(),
299 output: ("present.", ".encoder.key").into(),
300 encoder: true,
301 },
302 KVCachePair {
303 input: ("past_key_values.", ".encoder.value").into(),
304 output: ("present.", ".encoder.value").into(),
305 encoder: true,
306 },
307 KVCachePair {
309 input: ("past_key_values.", ".key").into(),
310 output: ("present.", ".key").into(),
311 encoder: false,
312 },
313 KVCachePair {
314 input: ("past_key_values.", ".value").into(),
315 output: ("present.", ".value").into(),
316 encoder: false,
317 },
318 ]
319 .into(),
320 }
321 }
322}
323
/// Generates a sequence of tokens by repeatedly running a model and sampling
/// from its logits.
pub struct Generator<'a> {
    /// Model that is run at each generation step.
    model: &'a dyn Model,

    /// Options passed to every model run.
    run_options: Option<RunOptions>,

    /// Additional model inputs whose values are fixed across all steps.
    constant_inputs: Vec<(NodeId, ValueOrView<'a>)>,

    /// Outputs precomputed by partially evaluating the model with the
    /// constant inputs. `None` means they need to be (re)computed.
    constant_prop_inputs: Option<Vec<(NodeId, Value)>>,

    /// Inputs recomputed each step from `(batch_size, position_range)`.
    #[allow(clippy::type_complexity)]
    varying_inputs: Vec<(NodeId, &'a dyn Fn(usize, Range<usize>) -> ValueOrView<'a>)>,

    /// Tokens that have not yet been fed to the model (prompt tokens plus any
    /// sampled token pending for the next step).
    input_ids: Vec<TokenId>,

    /// Sequence position of the first entry in `input_ids`.
    input_offset: usize,

    /// ID of the model input that receives the token IDs.
    input_ids_input: NodeId,

    /// ID of the model output that produces the logits.
    logits_output: NodeId,

    /// Optional filter applied to logits before sampling.
    logits_filter: Option<Box<dyn LogitsFilter + 'a>>,

    /// Sampler that picks the next token from the (filtered) logits.
    sampler: Box<dyn Sampler + 'a>,

    /// All tokens processed so far: the prompt (added after the first run)
    /// followed by sampled tokens.
    prev_tokens: Vec<u32>,

    /// Self-attention KV-cache entries, updated every step.
    kv_cache: Vec<KvCache>,

    /// Cross-attention ("encoder") KV-cache entries. Outputs for these are
    /// only saved when non-empty.
    encoder_kv_cache: Vec<KvCache>,
}
448
impl<'a> Generator<'a> {
    /// Create a generator for `model` using the default configuration.
    pub fn from_model(model: &'a dyn Model) -> Result<Generator<'a>, GeneratorError> {
        let config = GeneratorConfig {
            model_inputs: ModelInputsConfig::default(),
            kv_cache_capacity: None,
        };
        Self::from_model_config(model, config)
    }

    /// Create a generator for `model` using custom input/output names and
    /// KV-cache settings from `config`.
    ///
    /// This looks up the token-IDs input and logits output, discovers KV-cache
    /// input/output pairs by matching the configured name patterns against the
    /// model's inputs, and registers per-step ("varying") inputs for the
    /// attention mask, position IDs, cache position and use-cache flag when
    /// the model has inputs with those names.
    ///
    /// Returns an error if a required node is missing or a KV-cache input has
    /// an unexpected shape.
    pub fn from_model_config(
        model: &'a dyn Model,
        config: GeneratorConfig,
    ) -> Result<Generator<'a>, GeneratorError> {
        let model_inputs = &config.model_inputs;

        let input_ids_input =
            model
                .find_node(model_inputs.input_ids)
                .ok_or(GeneratorError::InputNotFound(
                    model_inputs.input_ids.to_string(),
                ))?;

        let logits_output =
            model
                .find_node(model_inputs.logits)
                .ok_or(GeneratorError::OutputNotFound(
                    model_inputs.logits.to_string(),
                ))?;

        let batch_size = 1;
        let mut kv_cache = Vec::new();
        let mut encoder_kv_cache = Vec::new();
        for &input_id in model.input_ids() {
            let input_info = model
                .node_info(input_id)
                .ok_or(GeneratorError::InputNotFound(format!(
                    "input ID {}",
                    input_id
                )))?;

            let name = input_info.name();

            // Skip inputs that don't match any KV-cache name pattern. The
            // first matching pattern wins.
            let Some(kv_pattern) = model_inputs
                .kv_caches
                .iter()
                .find(|pat| name.starts_with(pat.input.prefix) && name.ends_with(pat.input.suffix))
            else {
                continue;
            };

            // Extract the fixed head count (for 4D caches) and channel size
            // from the declared input shape.
            let (n_heads, size) = match *input_info.shape() {
                [_, Dimension::Fixed(n_heads), _, Dimension::Fixed(size)] => (Some(n_heads), size),
                [_, _, Dimension::Fixed(size)] => (None, size),
                _ => {
                    return Err(GeneratorError::ShapeMismatch(format!(
                        "input \"{}\" has unexpected shape. expected (batch, past_seq_len, chans) or (batch, heads, past_seq_len, chans) where `heads` and `size` are fixed",
                        name
                    )));
                }
            };

            let prefix = kv_pattern.input.prefix;

            // The text between the pattern's prefix and suffix must be a
            // numeric layer index; skip the input otherwise.
            let layer_index_start = prefix.len();
            let layer_index_end = name.len() - kv_pattern.input.suffix.len();
            let layer_index_str = &name[layer_index_start..layer_index_end];
            let Ok(layer_index) = layer_index_str.parse::<u32>() else {
                continue;
            };

            let output_prefix = kv_pattern.output.prefix;
            let output_suffix = kv_pattern.output.suffix;

            // Build the expected output name for this layer and look it up.
            let output_name = format!("{}{}{}", output_prefix, layer_index, output_suffix);
            let output_id = model
                .find_node(&output_name)
                .ok_or(GeneratorError::OutputNotFound(output_name))?;

            // Initial capacity for the cache buffer; it grows on demand
            // during generation.
            let max_seq_len = config.kv_cache_capacity.unwrap_or(1);

            let kv_cache_entry = KvCache {
                input_id,
                output_id,
                cache: Some(KvCacheData::with_capacity(
                    batch_size,
                    n_heads,
                    size,
                    max_seq_len,
                )),
            };

            if kv_pattern.encoder {
                encoder_kv_cache.push(kv_cache_entry);
            } else {
                kv_cache.push(kv_cache_entry);
            }
        }

        let mut generator = Generator {
            model,
            run_options: None,

            constant_inputs: Vec::new(),
            varying_inputs: Vec::new(),

            constant_prop_inputs: Some(Vec::new()),

            logits_filter: None,
            input_ids: vec![],
            input_ids_input,
            input_offset: 0,
            logits_output,
            kv_cache,
            encoder_kv_cache,
            prev_tokens: Vec::new(),
            sampler: Box::new(ArgMax::new()),
        };

        // Attention mask: all-ones over positions [0, end) of the sequence.
        let attention_mask_input = model.find_node(model_inputs.attention_mask);
        if let Some(attention_mask_input) = attention_mask_input {
            generator = generator
                .with_varying_input(attention_mask_input, &|batch_size, positions| {
                    NdTensor::full([batch_size, positions.end], 1i32).into()
                });
        }

        // Position IDs: absolute positions of the tokens in this step.
        let position_ids_input = model.find_node(model_inputs.position_ids);
        if let Some(position_ids_input) = position_ids_input {
            generator =
                generator.with_varying_input(position_ids_input, &|batch_size, positions| {
                    NdTensor::from_fn([batch_size, positions.len()], |[_batch, pos]| {
                        (positions.start + pos) as i32
                    })
                    .into()
                });
        }

        // Cache position: same positions as a 1D tensor.
        let cache_position_input = model.find_node(model_inputs.cache_position);
        if let Some(cache_position_input) = cache_position_input {
            generator =
                generator.with_varying_input(cache_position_input, &|_batch_size, positions| {
                    NdTensor::from_fn([positions.len()], |[pos]| (positions.start + pos) as i32)
                        .into()
                });
        }

        // Use-cache flag: 0 on the first step (no cached context), 1 after.
        let use_cache_input = model.find_node(model_inputs.use_cache_flag);
        if let Some(use_cache_input) = use_cache_input {
            generator = generator.with_varying_input(use_cache_input, &|_batch_size, positions| {
                Tensor::from(if positions.start == 0 { 0i32 } else { 1 }).into()
            });
        }

        Ok(generator)
    }

    /// Set the prompt tokens, replacing any pending input tokens.
    pub fn with_prompt(mut self, prompt: &[TokenId]) -> Self {
        self.input_ids = prompt.to_vec();
        self
    }

    /// Append tokens to the pending input, to be processed at the next step.
    pub fn append_prompt(&mut self, prompt: &[TokenId]) {
        self.input_ids.extend(prompt);
    }

    /// Remove all pending input tokens.
    pub fn clear_prompt(&mut self) {
        self.input_ids.clear();
    }

    /// Return the tokens that are pending to be fed to the model.
    pub fn prompt(&self) -> &[TokenId] {
        &self.input_ids
    }

    /// Return all tokens processed so far (prompt plus generated tokens).
    pub fn prev_tokens(&self) -> &[TokenId] {
        &self.prev_tokens
    }

    /// Return the sequence length of the first self-attention KV-cache entry,
    /// or `None` if there is no cache or its data is not currently held.
    pub fn kv_cache_len(&self) -> Option<usize> {
        self.kv_cache.first()?.size()
    }

    /// Add a model input whose value is the same for every step.
    ///
    /// This invalidates any previously computed constant-propagation results,
    /// so the partial evaluation is redone on the next step.
    pub fn with_constant_input(mut self, input_id: NodeId, value: ValueView<'a>) -> Self {
        self.constant_prop_inputs = None;
        self.constant_inputs.push((input_id, value.into()));
        self
    }

    /// Add a model input whose value is recomputed each step by calling
    /// `value_fn(batch_size, position_range)`.
    pub fn with_varying_input<F: Fn(usize, Range<usize>) -> ValueOrView<'a>>(
        mut self,
        input_id: NodeId,
        value_fn: &'a F,
    ) -> Self {
        self.varying_inputs.push((input_id, value_fn));
        self
    }

    /// Set a filter applied to the logits before sampling.
    pub fn with_logits_filter<F: LogitsFilter + 'a>(mut self, filter: F) -> Self {
        self.logits_filter = Some(Box::new(filter));
        self
    }

    /// Set the sampler used to pick the next token. Defaults to [`ArgMax`].
    pub fn with_sampler<S: Sampler + 'a>(mut self, sampler: S) -> Self {
        self.sampler = Box::new(sampler);
        self
    }

    /// Set the options passed to every model run.
    pub fn with_run_options(mut self, opts: Option<RunOptions>) -> Self {
        self.run_options = opts;
        self
    }

    /// Run the model over the pending input tokens and update the KV caches.
    ///
    /// Returns the logits when `generate_logits` is true, otherwise `None`.
    fn generate_impl(
        &mut self,
        generate_logits: bool,
    ) -> Result<Option<NdTensor<f32, 3>>, GeneratorError> {
        let batch_size = 1;
        let input_ids: NdTensor<i32, 2> = self
            .input_ids
            .iter()
            .map(|id| *id as i32)
            .collect::<Tensor<_>>()
            .into_shape([batch_size, self.input_ids.len()]);

        // Absolute sequence positions covered by this step's input tokens.
        let input_positions = self.input_offset..self.input_offset + self.input_ids.len();

        let mut model_inputs: Vec<(NodeId, ValueOrView)> =
            vec![(self.input_ids_input, input_ids.view().into())];

        // (Re)compute outputs derivable from just the constant inputs, so
        // they don't have to be re-evaluated on every step.
        if self.constant_prop_inputs.is_none() {
            let inputs = match self.model.partial_run(
                self.constant_inputs.clone(),
                &[self.logits_output],
                self.run_options.clone(),
            ) {
                Ok(inputs) => inputs,
                Err(err) => {
                    return Err(wrap_error(
                        err,
                        "failed to partially evaluate model with constant inputs",
                    ));
                }
            };
            self.constant_prop_inputs = Some(inputs);
        }

        if let Some(constants) = self.constant_prop_inputs.as_ref() {
            model_inputs.extend(
                constants
                    .iter()
                    .map(|(node_id, output)| (*node_id, output.as_view().into())),
            );
        }

        if !self.varying_inputs.is_empty() {
            model_inputs.extend(self.varying_inputs.iter().map(|(node_id, value_fn)| {
                (*node_id, value_fn(batch_size, input_positions.clone()))
            }));
        }

        // Move self-attention cache data into the model inputs (by value);
        // it is replaced from the model outputs after the run.
        for entry in self.kv_cache.iter_mut() {
            let cache = entry.cache.take();
            match cache {
                Some(KvCacheData::BatchSeqChans(cache)) => {
                    model_inputs.push((entry.input_id, cache.into()));
                }
                Some(KvCacheData::BatchHeadSeqChans(cache)) => {
                    model_inputs.push((entry.input_id, cache.into()));
                }
                None => {}
            }
        }

        // Cross-attention caches are passed by view; they are not rebuilt
        // every step.
        for entry in self.encoder_kv_cache.iter() {
            match &entry.cache {
                Some(KvCacheData::BatchSeqChans(cache)) => {
                    model_inputs.push((entry.input_id, cache.into()));
                }
                Some(KvCacheData::BatchHeadSeqChans(cache)) => {
                    model_inputs.push((entry.input_id, cache.into()));
                }
                None => {}
            }
        }

        // Output order matters: self-attention caches, then cross-attention
        // caches, then (optionally) logits. The loops below consume the run's
        // outputs in this same order.
        let mut model_outputs: Vec<NodeId> = self
            .kv_cache
            .iter()
            .map(|entry| entry.output_id)
            .chain(self.encoder_kv_cache.iter().map(|entry| entry.output_id))
            .collect();

        if generate_logits {
            model_outputs.push(self.logits_output);
        }

        let mut outputs = self
            .model
            .run(model_inputs, &model_outputs, self.run_options.clone())
            .map_err(|e| wrap_error(e, "failed to run model"))?;

        // Save the updated self-attention caches.
        for cache_entry in self.kv_cache.iter_mut() {
            let output = outputs.remove(0);

            let err_context = "failed to save self-attention KV-cache";
            let mut kv_cache = match output.ndim() {
                3 => KvCacheData::BatchSeqChans(
                    output.try_into().map_err(|e| wrap_error(e, err_context))?,
                ),
                4 => KvCacheData::BatchHeadSeqChans(
                    output.try_into().map_err(|e| wrap_error(e, err_context))?,
                ),
                ndim => {
                    return Err(wrap_error(
                        format!("KV cache has {} dims, expected 3 or 4", ndim),
                        err_context,
                    ));
                }
            };

            // Ensure there is room for at least one more token; double the
            // capacity when growing to amortize reallocation cost.
            if !kv_cache.has_capacity(kv_cache.sequence_len() + 1) {
                kv_cache = kv_cache.clone_with_capacity(kv_cache.sequence_len() * 2);
            }

            cache_entry.cache = Some(kv_cache);
        }

        // Save cross-attention caches. Empty outputs mean the model did not
        // recompute them this step, so the existing data is kept.
        for cache_entry in self.encoder_kv_cache.iter_mut() {
            let output = outputs.remove(0);
            if output.is_empty() {
                continue;
            }

            let err_context = "failed to save cross-attention KV-cache";
            let kv_cache = match output.ndim() {
                3 => KvCacheData::BatchSeqChans(
                    output.try_into().map_err(|e| wrap_error(e, err_context))?,
                ),
                4 => KvCacheData::BatchHeadSeqChans(
                    output.try_into().map_err(|e| wrap_error(e, err_context))?,
                ),
                ndim => {
                    return Err(wrap_error(
                        format!("KV cache has {} dims, expected 3 or 4", ndim),
                        err_context,
                    ));
                }
            };
            cache_entry.cache = Some(kv_cache);
        }

        // On the first run, record the prompt as processed tokens.
        if self.prev_tokens.is_empty() {
            self.prev_tokens.extend(self.input_ids.iter());
        }

        // With a KV cache, processed tokens don't need to be re-fed; without
        // one, the full history is fed again each step.
        if !self.kv_cache.is_empty() {
            self.input_offset += self.input_ids.len();
            self.input_ids.clear();
        }

        if generate_logits {
            let logits: NdTensor<f32, 3> = outputs
                .remove(0)
                .try_into()
                .map_err(|e| wrap_error(e, "failed to extract logits from model outputs"))?;
            Ok(Some(logits))
        } else {
            Ok(None)
        }
    }

    /// Run the model over the pending prompt tokens without sampling a token,
    /// populating the KV caches.
    pub fn process_prompt(&mut self) -> Result<(), GeneratorError> {
        self.generate_impl(false).map(|_| ())
    }

    /// Run one generation step: evaluate the model, filter the logits for the
    /// last position, sample the next token and queue it as the next input.
    fn generate_next_token(&mut self) -> Result<TokenId, GeneratorError> {
        let logits = self.generate_impl(true)?.expect("should have logits");
        // Only the logits for the last sequence position are sampled from.
        let last_logits = Logits::dense(logits.slice((0, -1)).to_contiguous().to_vec());
        let filtered_logits = if let Some(filter) = self.logits_filter.as_ref() {
            filter.filter(last_logits, &self.prev_tokens)
        } else {
            last_logits
        };

        if filtered_logits.is_empty() {
            return Err(GeneratorError::GenerateError(
                "filtered logits are empty".into(),
            ));
        }

        let next_id = self.sampler.sample(&filtered_logits);

        self.prev_tokens.push(next_id);
        self.input_ids.push(next_id);

        Ok(next_id)
    }
}
978
979fn wrap_error<E>(error: E, context: &str) -> GeneratorError
980where
981 E: Into<Box<dyn Error>>,
982{
983 let error_ctx = ErrorContext {
984 error: error.into(),
985 context: context.to_string(),
986 };
987 GeneratorError::GenerateError(error_ctx.into())
988}
989
/// Item type yielded when iterating over a [`Generator`].
pub type GeneratorItem = Result<TokenId, GeneratorError>;

impl Iterator for Generator<'_> {
    type Item = Result<TokenId, GeneratorError>;

    /// Generate the next token. This iterator never returns `None` on its
    /// own; bound it with e.g. `take` or `GeneratorUtils::stop_on_tokens`.
    fn next(&mut self) -> Option<Self::Item> {
        Some(self.generate_next_token())
    }
}
1006
/// Convenience adapters for iterators that yield generated tokens.
pub trait GeneratorUtils: Iterator<Item = GeneratorItem> + Sized {
    /// Stop iteration when a token in `eos_tokens` is produced. Errors are
    /// passed through unchanged.
    fn stop_on_tokens<A: AsRef<[u32]>>(self, eos_tokens: A) -> impl Iterator<Item = GeneratorItem> {
        self.take_while(move |tok| match tok {
            Ok(tok_id) => !eos_tokens.as_ref().contains(tok_id),
            _ => true,
        })
    }

    /// Decode generated tokens into text using `tokenizer`.
    #[cfg(feature = "text-decoder")]
    fn decode(self, tokenizer: &Tokenizer) -> TextDecoder<'_, Self> {
        TextDecoder::wrap(self, tokenizer)
    }

    /// Record the duration of each generation step into `metrics`.
    fn profile(self, metrics: &mut Metrics) -> impl Iterator<Item = Self::Item> {
        Profiler::wrap(self, metrics)
    }
}

impl<I: Iterator<Item = GeneratorItem>> GeneratorUtils for I {}
1037
/// Iterator adapter that records how long each step of the wrapped iterator
/// takes.
struct Profiler<'a, G: Iterator> {
    generator: G,
    metrics: &'a mut Metrics,
}

impl<'a, G: Iterator> Profiler<'a, G> {
    /// Wrap `generator`, recording per-step durations into `metrics`.
    fn wrap(generator: G, metrics: &'a mut Metrics) -> Profiler<'a, G> {
        Profiler { generator, metrics }
    }
}
1049
1050impl<G: Iterator> Iterator for Profiler<'_, G> {
1051 type Item = G::Item;
1052
1053 fn next(&mut self) -> Option<Self::Item> {
1054 let start = std::time::Instant::now();
1055 let item = self.generator.next()?;
1056 self.metrics.add_step_duration(start.elapsed());
1057 Some(item)
1058 }
1059}
1060
1061#[cfg(test)]
1062mod tests {
1063 use std::cell::{Cell, RefCell};
1064 use std::collections::HashMap;
1065 use std::error::Error;
1066 use std::rc::Rc;
1067
1068 use rten::{Dimension, NodeId, RunOptions, Value, ValueOrView};
1069 use rten_tensor::NdTensor;
1070 use rten_tensor::prelude::*;
1071
1072 use super::{Generator, GeneratorUtils, Logits};
1073 use crate::filter::LogitsFilter;
1074 use crate::metrics::Metrics;
1075 use crate::model::{Model, NodeInfo};
1076
1077 struct FakeModel {
1078 nodes: Vec<NodeInfo>,
1079 input_ids: Vec<NodeId>,
1080 output_ids: Vec<NodeId>,
1081
1082 step: Cell<usize>,
1084
1085 outputs: Vec<HashMap<NodeId, Value>>,
1087
1088 inputs: RefCell<Vec<HashMap<NodeId, Value>>>,
1090
1091 run_opts: Cell<Option<RunOptions>>,
1093 }
1094
1095 impl FakeModel {
1096 fn with_inputs_and_outputs(inputs: &[NodeInfo], outputs: &[NodeInfo]) -> FakeModel {
1098 let node_infos = [inputs, outputs].concat();
1099 let input_ids = (0..inputs.len())
1100 .map(|id| NodeId::from_u32(id as u32))
1101 .collect();
1102 let output_ids = (inputs.len()..(inputs.len() + outputs.len()))
1103 .map(|id| NodeId::from_u32(id as u32))
1104 .collect();
1105
1106 FakeModel {
1107 input_ids,
1108 output_ids,
1109 nodes: node_infos,
1110 step: Cell::new(0),
1111 inputs: RefCell::new(vec![]),
1112 outputs: vec![],
1113 run_opts: Cell::new(None),
1114 }
1115 }
1116
1117 fn add_outputs(&mut self, outputs: HashMap<NodeId, Value>) {
1119 self.outputs.push(outputs)
1120 }
1121
1122 fn get_inputs(&self, step: usize, node_id: NodeId) -> Option<Value> {
1124 self.inputs
1125 .borrow()
1126 .get(step)
1127 .map(|step_inputs| step_inputs.get(&node_id))
1128 .flatten()
1129 .cloned()
1130 }
1131 }
1132
1133 impl Model for FakeModel {
1134 fn find_node(&self, name: &str) -> Option<NodeId> {
1135 self.nodes
1136 .iter()
1137 .position(|info| info.name() == name)
1138 .map(|pos| NodeId::from_u32(pos as u32))
1139 }
1140
1141 fn node_info(&self, id: NodeId) -> Option<NodeInfo> {
1142 self.nodes.get(id.as_usize()).cloned()
1143 }
1144
1145 fn input_ids(&self) -> &[NodeId] {
1146 &self.input_ids
1147 }
1148
1149 fn run(
1150 &self,
1151 inputs: Vec<(NodeId, ValueOrView)>,
1152 outputs: &[NodeId],
1153 opts: Option<RunOptions>,
1154 ) -> Result<Vec<Value>, Box<dyn Error>> {
1155 if let Some((input_id, _)) = inputs.iter().find(|(id, _)| !self.input_ids.contains(id))
1156 {
1157 return Err(format!("invalid input ID {}", input_id).into());
1158 }
1159 for &expected_input in self.input_ids.iter() {
1160 if !inputs.iter().any(|&(id, _)| id == expected_input) {
1161 return Err(format!("missing input ID {}", expected_input).into());
1162 }
1163 }
1164
1165 if let Some(output_id) = outputs.iter().find(|id| !self.output_ids.contains(id)) {
1166 return Err(format!("invalid output ID {}", output_id).into());
1167 }
1168
1169 self.inputs.borrow_mut().push(
1170 inputs
1171 .into_iter()
1172 .map(|(id, input_or_output)| (id, input_or_output.to_owned()))
1173 .collect(),
1174 );
1175
1176 let result = outputs
1177 .iter()
1178 .map(|id| {
1179 let step_outputs = self
1180 .outputs
1181 .get(self.step.get())
1182 .expect("outputs not specified for step");
1183
1184 step_outputs
1185 .get(id)
1186 .cloned()
1187 .expect("invalid output node ID")
1188 })
1189 .collect();
1190
1191 self.step.set(self.step.get() + 1);
1192 self.run_opts.set(opts);
1193
1194 Ok(result)
1195 }
1196
1197 fn partial_run(
1198 &self,
1199 _inputs: Vec<(NodeId, ValueOrView)>,
1200 _outputs: &[NodeId],
1201 _opts: Option<RunOptions>,
1202 ) -> Result<Vec<(NodeId, Value)>, Box<dyn Error>> {
1203 Ok(Vec::new())
1204 }
1205 }
1206
1207 fn generate_logits(n_vocab: usize, token_ids: &[u32]) -> NdTensor<f32, 3> {
1209 let mut logits = NdTensor::zeros([1, token_ids.len(), n_vocab]);
1210 for (idx, id) in token_ids.iter().copied().enumerate() {
1211 logits[[0, idx, id as usize]] = 1.0;
1212 }
1213 logits
1214 }
1215
1216 #[derive(Copy, Clone, PartialEq)]
1217 struct TransformerParams {
1218 n_layers: usize,
1221 n_heads: usize,
1222 n_embed: usize,
1223
1224 n_vocab: usize,
1227 }
1228
1229 impl Default for TransformerParams {
1230 fn default() -> Self {
1231 Self {
1232 n_layers: 5,
1233 n_heads: 3,
1234 n_vocab: 5,
1235 n_embed: 8,
1236 }
1237 }
1238 }
1239
1240 #[derive(Copy, Clone, PartialEq)]
1241 enum KvCacheType {
1242 Decoder,
1244 EncoderDecoder,
1247 }
1248
1249 fn fake_transformer_model(
1252 params: TransformerParams,
1253 kv_cache: Option<KvCacheType>,
1254 prompt_len: usize,
1255 output_token_ids: &[u32],
1256 ) -> FakeModel {
1257 let TransformerParams {
1258 n_layers,
1259 n_heads,
1260 n_vocab,
1261 n_embed,
1262 } = params;
1263
1264 let mut inputs = vec![
1266 NodeInfo::from_name_shape("input_ids", &[]),
1267 NodeInfo::from_name_shape("cache_position", &[]),
1268 NodeInfo::from_name_shape("position_ids", &[]),
1269 NodeInfo::from_name_shape("attention_mask", &[]),
1270 ];
1271 let mut outputs = vec![NodeInfo::from_name_shape("logits", &[])];
1272
1273 let mut kv_cache_output_names = Vec::new();
1275 if let Some(kv_cache_type) = kv_cache {
1276 let dims = [
1277 Dimension::Symbolic("batch".to_string()),
1278 Dimension::Fixed(n_heads as usize),
1279 Dimension::Symbolic("seq".to_string()),
1280 Dimension::Fixed(n_embed),
1281 ];
1282 let make_name_info = |name: &str| NodeInfo::from_name_shape(name, &dims);
1283
1284 for layer in 0..n_layers {
1285 let past_names: Vec<String>;
1286 let present_names: Vec<String>;
1287
1288 match kv_cache_type {
1289 KvCacheType::Decoder => {
1290 past_names = [
1291 format!("past_key_values.{}.key", layer),
1292 format!("past_key_values.{}.value", layer),
1293 ]
1294 .into();
1295 present_names = [
1296 format!("present.{}.key", layer),
1297 format!("present.{}.value", layer),
1298 ]
1299 .into();
1300 }
1301 KvCacheType::EncoderDecoder => {
1302 past_names = [
1303 format!("past_key_values.{}.decoder.key", layer),
1304 format!("past_key_values.{}.decoder.value", layer),
1305 format!("past_key_values.{}.encoder.key", layer),
1306 format!("past_key_values.{}.encoder.value", layer),
1307 ]
1308 .into();
1309
1310 present_names = [
1311 format!("present.{}.decoder.key", layer),
1312 format!("present.{}.decoder.value", layer),
1313 format!("present.{}.encoder.key", layer),
1314 format!("present.{}.encoder.value", layer),
1315 ]
1316 .into();
1317 }
1318 }
1319
1320 inputs.extend(past_names.iter().map(|name| make_name_info(&name)));
1321 outputs.extend(present_names.iter().map(|name| make_name_info(&name)));
1322 kv_cache_output_names.extend(present_names);
1323 }
1324
1325 if kv_cache_type == KvCacheType::EncoderDecoder {
1326 inputs.push(NodeInfo::from_name_shape("use_cache_branch", &[]));
1327 }
1328 }
1329
1330 let mut model = FakeModel::with_inputs_and_outputs(&inputs, &outputs);
1331 let logits_id = model.find_node("logits").unwrap();
1332
1333 for (step, output_token_id) in output_token_ids.iter().copied().enumerate() {
1334 assert!(
1335 output_token_id < n_vocab as u32,
1336 "token ID is invalid for vocab size"
1337 );
1338
1339 let logits = if kv_cache.is_some() {
1340 generate_logits(n_vocab, &[output_token_id])
1341 } else {
1342 generate_logits(n_vocab, &output_token_ids[..=step])
1343 };
1344
1345 let mut outputs = HashMap::new();
1346 outputs.insert(logits_id, Value::FloatTensor(logits.into()));
1347
1348 for kv_output in kv_cache_output_names.iter() {
1350 let kv_output_id = model.find_node(&kv_output).unwrap();
1351 let context_len = if step == 0 {
1352 prompt_len
1353 } else {
1354 prompt_len + step - 1
1355 };
1356
1357 let is_encoder = model
1358 .node_info(kv_output_id)
1359 .as_ref()
1360 .map(|ni| ni.name())
1361 .unwrap_or("")
1362 .contains("encoder");
1363
1364 let output_n_embed = if is_encoder && step > 0 {
1365 0
1369 } else {
1370 n_embed
1371 };
1372
1373 outputs.insert(
1374 kv_output_id,
1375 Value::FloatTensor(
1376 NdTensor::zeros([1, n_heads, context_len, output_n_embed]).into(),
1377 ),
1378 );
1379 }
1380
1381 model.add_outputs(outputs);
1382 }
1383
1384 model
1385 }
1386
1387 fn test_generator_impl(kv_cache_type: Option<KvCacheType>) -> Result<(), Box<dyn Error>> {
1388 let params = TransformerParams::default();
1389 let expected_token_ids = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 0, 0];
1390 let prompt = [1, 2, 3, 1, 2, 3];
1391 let model =
1392 fake_transformer_model(params, kv_cache_type, prompt.len(), &expected_token_ids);
1393
1394 let generator = Generator::from_model(&model)?;
1395 let generation_len = 10;
1396
1397 let output_token_ids: Vec<_> = generator
1398 .with_prompt(&prompt)
1399 .take(generation_len)
1400 .map(|id| id.expect("generation failed"))
1401 .collect();
1402
1403 assert_eq!(output_token_ids.len(), generation_len);
1405 assert_eq!(output_token_ids, &expected_token_ids[..generation_len]);
1406
1407 let input_id = model.find_node("input_ids").unwrap();
1409 let position_ids = model.find_node("position_ids").unwrap();
1410 let attention_mask = model.find_node("attention_mask").unwrap();
1411 let cache_branch = model.find_node("use_cache_branch");
1412 let cache_position = model.find_node("cache_position").unwrap();
1413
1414 for step in 0..generation_len {
1415 let step_inputs = model.get_inputs(step, input_id).unwrap();
1416 let step_inputs: NdTensor<i32, 2> = step_inputs.try_into().unwrap();
1417
1418 let step_pos_ids = model.get_inputs(step, position_ids).unwrap();
1419 let step_pos_ids: NdTensor<i32, 2> = step_pos_ids.try_into().unwrap();
1420
1421 let step_cache_pos = model.get_inputs(step, cache_position).unwrap();
1422 let step_cache_pos: NdTensor<i32, 1> = step_cache_pos.try_into().unwrap();
1423
1424 let step_attn_mask = model.get_inputs(step, attention_mask).unwrap();
1425 let step_attn_mask: NdTensor<i32, 2> = step_attn_mask.try_into().unwrap();
1426
1427 let cache_branch = cache_branch.map(|cb_id| {
1428 let cb = model.get_inputs(step, cb_id).unwrap();
1429 let cb: NdTensor<i32, 0> = cb.try_into().unwrap();
1430 cb
1431 });
1432
1433 if step == 0 {
1434 assert_eq!(step_inputs.size(1), prompt.len());
1435 assert!(
1436 step_inputs
1437 .iter()
1438 .map(|x| *x as u32)
1439 .eq(prompt.iter().copied())
1440 );
1441
1442 assert_eq!(step_attn_mask.size(1), prompt.len());
1443 assert!(step_attn_mask.iter().all(|x| *x == 1));
1444
1445 assert_eq!(step_pos_ids.size(1), prompt.len());
1446 assert!(step_pos_ids.iter().map(|x| *x as usize).eq(0..prompt.len()));
1447
1448 assert_eq!(step_cache_pos.size(0), prompt.len());
1449 assert!(
1450 step_cache_pos
1451 .iter()
1452 .map(|x| *x as usize)
1453 .eq(0..prompt.len())
1454 );
1455
1456 if let Some(cache_branch) = cache_branch {
1457 assert_eq!(cache_branch.item(), Some(&0));
1458 }
1459 } else if kv_cache_type.is_some() {
1460 assert_eq!(step_inputs.size(1), 1);
1461 assert_eq!(step_inputs[[0, 0]] as u32, expected_token_ids[step - 1]);
1462
1463 assert_eq!(step_attn_mask.size(1), prompt.len() + step);
1464 assert_eq!(step_attn_mask[[0, 0]], 1);
1465
1466 assert_eq!(step_pos_ids.size(1), 1);
1467 assert_eq!(step_pos_ids[[0, 0]], (prompt.len() + step - 1) as i32);
1468
1469 assert_eq!(step_cache_pos.size(0), 1);
1470 assert_eq!(step_cache_pos[[0]], (prompt.len() + step - 1) as i32);
1471
1472 if let Some(cache_branch) = cache_branch {
1473 assert_eq!(cache_branch.item(), Some(&1));
1474 }
1475 } else {
1476 let expected_inputs: Vec<i32> = prompt
1477 .iter()
1478 .copied()
1479 .chain(expected_token_ids)
1480 .take(prompt.len() + step)
1481 .map(|x| x as i32)
1482 .collect();
1483 assert_eq!(
1484 step_inputs,
1485 NdTensor::from_data([1, expected_inputs.len()], expected_inputs)
1486 );
1487
1488 let expected_attn_mask = vec![1i32; prompt.len() + step];
1489 assert_eq!(
1490 step_attn_mask,
1491 NdTensor::from_data([1, expected_attn_mask.len()], expected_attn_mask)
1492 );
1493
1494 let expected_pos_ids: Vec<i32> =
1495 (0..prompt.len() + step).map(|x| x as i32).collect();
1496 assert_eq!(
1497 step_pos_ids,
1498 NdTensor::from_data([1, expected_pos_ids.len()], expected_pos_ids.clone())
1499 );
1500 assert_eq!(
1501 step_cache_pos,
1502 NdTensor::from_data([expected_pos_ids.len()], expected_pos_ids)
1503 );
1504 }
1505 }
1506
1507 Ok(())
1508 }
1509
1510 #[test]
1511 fn test_generator_with_decoder_kv_cache() -> Result<(), Box<dyn Error>> {
1512 test_generator_impl(Some(KvCacheType::Decoder))
1513 }
1514
    // End-to-end generation loop using an encoder-decoder style KV cache.
    #[test]
    fn test_generator_with_encoder_decoder_kv_cache() -> Result<(), Box<dyn Error>> {
        test_generator_impl(Some(KvCacheType::EncoderDecoder))
    }
1519
    // End-to-end generation loop with no KV cache: the full token history is
    // re-fed to the model on every step.
    #[test]
    fn test_generator_without_kv_cache() -> Result<(), Box<dyn Error>> {
        test_generator_impl(None)
    }
1524
1525 #[test]
1526 fn test_generator_append_prompt() -> Result<(), Box<dyn Error>> {
1527 let mut params = TransformerParams::default();
1528 params.n_vocab = 110;
1529 let output_token_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8];
1530 let prompt = [99];
1531 let model = fake_transformer_model(
1532 params,
1533 Some(KvCacheType::Decoder),
1534 prompt.len(),
1535 &output_token_ids,
1536 );
1537
1538 let mut generator = Generator::from_model(&model)?.with_prompt(&prompt);
1539
1540 generator.next();
1541 generator.append_prompt(&[100]);
1542 generator.next();
1543 generator.append_prompt(&[101, 102]);
1544 generator.next();
1545
1546 let input_id = model.find_node("input_ids").unwrap();
1547
1548 let inputs = model.get_inputs(0, input_id).unwrap();
1550 let inputs: NdTensor<i32, 2> = inputs.try_into().unwrap();
1551 assert_eq!(inputs, NdTensor::from([[99]]));
1552
1553 let inputs = model.get_inputs(1, input_id).unwrap();
1556 let inputs: NdTensor<i32, 2> = inputs.try_into().unwrap();
1557 assert_eq!(inputs, NdTensor::from([[0, 100]]));
1558
1559 let inputs = model.get_inputs(2, input_id).unwrap();
1560 let inputs: NdTensor<i32, 2> = inputs.try_into().unwrap();
1561 assert_eq!(inputs, NdTensor::from([[1, 101, 102]]));
1562
1563 Ok(())
1564 }
1565
1566 #[test]
1567 fn test_stop_on_tokens() -> Result<(), Box<dyn Error>> {
1568 let params = TransformerParams::default();
1569 let expected_token_ids = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 0, 0];
1570 let prompt = [1, 2, 3, 1, 2, 3];
1571 let model = fake_transformer_model(
1572 params,
1573 Some(KvCacheType::Decoder),
1574 prompt.len(),
1575 &expected_token_ids,
1576 );
1577
1578 let generator = Generator::from_model(&model)?;
1579
1580 let output_token_ids: Vec<_> = generator
1581 .with_prompt(&prompt)
1582 .stop_on_tokens([4])
1583 .map(|id| id.expect("generation failed"))
1584 .collect();
1585
1586 assert_eq!(output_token_ids, &[0, 1, 2, 3]);
1587
1588 Ok(())
1589 }
1590
1591 #[test]
1592 fn test_process_prompt() -> Result<(), Box<dyn Error>> {
1593 let params = TransformerParams::default();
1594 let expected_token_ids = [0];
1595 let prompt = [1, 2, 3, 1, 2, 3];
1596 let model = fake_transformer_model(
1597 params,
1598 Some(KvCacheType::Decoder),
1599 prompt.len(),
1600 &expected_token_ids,
1601 );
1602
1603 let mut generator = Generator::from_model(&model)?.with_prompt(&prompt);
1604 assert_eq!(generator.prompt(), prompt);
1605 assert!(generator.prev_tokens().is_empty());
1606 assert_eq!(generator.kv_cache_len(), Some(0));
1607
1608 generator.process_prompt().unwrap();
1609
1610 assert!(generator.prompt().is_empty());
1611 assert_eq!(generator.prev_tokens(), prompt);
1612 assert_eq!(generator.kv_cache_len(), Some(prompt.len()));
1613
1614 Ok(())
1615 }
1616
1617 #[test]
1618 fn test_profile() -> Result<(), Box<dyn Error>> {
1619 let params = TransformerParams::default();
1620 let expected_token_ids = [0, 1, 2, 3, 4];
1621 let prompt = [1, 2, 3, 1, 2, 3];
1622 let model = fake_transformer_model(
1623 params,
1624 Some(KvCacheType::Decoder),
1625 prompt.len(),
1626 &expected_token_ids,
1627 );
1628
1629 let generator = Generator::from_model(&model)?;
1630 let mut metrics = Metrics::new();
1631
1632 let output_token_ids: Vec<_> = generator
1633 .with_prompt(&prompt)
1634 .profile(&mut metrics)
1635 .take(expected_token_ids.len())
1636 .map(|id| id.expect("generation failed"))
1637 .collect();
1638
1639 assert_eq!(output_token_ids, expected_token_ids);
1640 assert!(metrics.warmup_duration().is_some());
1641 assert_eq!(metrics.step_durations().len(), output_token_ids.len() - 1);
1642
1643 Ok(())
1644 }
1645
1646 #[test]
1647 fn test_filter() -> Result<(), Box<dyn Error>> {
1648 let mut params = TransformerParams::default();
1649 params.n_vocab = 8; let expected_token_ids = [0, 1, 2, 3];
1652 let prompt = [5, 6, 7];
1653 let model = fake_transformer_model(
1654 params,
1655 Some(KvCacheType::Decoder),
1656 prompt.len(),
1657 &expected_token_ids,
1658 );
1659
1660 let generator = Generator::from_model(&model)?;
1661
1662 struct DoubleIndexFilter {
1664 prev_tokens: Rc<RefCell<Vec<u32>>>,
1665 }
1666 impl LogitsFilter for DoubleIndexFilter {
1667 fn filter(&self, logits: Logits, prev_tokens: &[u32]) -> Logits {
1668 self.prev_tokens.replace(prev_tokens.to_vec());
1669
1670 let max_idx = logits
1671 .logits()
1672 .iter()
1673 .zip(logits.indices())
1674 .max_by(|(x, _i), (y, _j)| x.total_cmp(y))
1675 .map(|(_x, i)| i)
1676 .unwrap();
1677
1678 Logits::sparse(vec![1.0], vec![max_idx * 2])
1679 }
1680 }
1681
1682 let prev_tokens = Rc::new(RefCell::new(Vec::new()));
1683 let output_token_ids: Vec<_> = generator
1684 .with_prompt(&prompt)
1685 .with_logits_filter(DoubleIndexFilter {
1686 prev_tokens: prev_tokens.clone(),
1687 })
1688 .take(expected_token_ids.len())
1689 .map(|id| id.expect("generation failed"))
1690 .collect();
1691
1692 assert_eq!(output_token_ids, [0, 2, 4, 6]);
1693 assert_eq!(prev_tokens.borrow().as_slice(), [5, 6, 7, 0, 2, 4]);
1694
1695 Ok(())
1696 }
1697
1698 #[test]
1699 fn test_empty_filter_output() {
1700 let params = TransformerParams::default();
1701 let prompt = [1];
1702 let model = fake_transformer_model(
1703 params,
1704 Some(KvCacheType::Decoder),
1705 prompt.len(),
1706 &[0, 1, 2, 3],
1707 );
1708
1709 struct RemoveAllFilter;
1710 impl LogitsFilter for RemoveAllFilter {
1711 fn filter(&self, _logits: Logits, _prev_tokens: &[u32]) -> Logits {
1712 Logits::dense(vec![])
1713 }
1714 }
1715
1716 let mut generator = Generator::from_model(&model)
1717 .unwrap()
1718 .with_logits_filter(RemoveAllFilter);
1719 let err = generator.next().unwrap().err().unwrap();
1720 assert!(err.to_string().contains("filtered logits are empty"));
1721 }
1722
1723 #[test]
1724 fn test_run_options() -> Result<(), Box<dyn Error>> {
1725 let params = TransformerParams::default();
1726 let expected_token_ids = [0, 1, 2, 3, 4];
1727 let prompt = [1, 2, 3, 1, 2, 3];
1728 let model = fake_transformer_model(
1729 params,
1730 Some(KvCacheType::Decoder),
1731 prompt.len(),
1732 &expected_token_ids,
1733 );
1734
1735 let generator = Generator::from_model(&model)?;
1736
1737 let run_opts = RunOptions::default().with_verbose(true);
1738 let output_token_ids: Vec<_> = generator
1739 .with_prompt(&prompt)
1740 .with_run_options(Some(run_opts.clone()))
1741 .take(expected_token_ids.len())
1742 .map(|id| id.expect("generation failed"))
1743 .collect();
1744
1745 assert_eq!(output_token_ids, expected_token_ids);
1746 assert_eq!(model.run_opts.take(), Some(run_opts));
1747
1748 Ok(())
1749 }
1750}