1use std::collections::hash_map::DefaultHasher;
22use std::hash::{Hash, Hasher};
23
24use crate::dynamic::sym;
25use crate::shape::DimBinding;
26
/// Execution phase a [`ModelVariant`] is specialised for.
///
/// The phase is part of the variant's cache key, so two variants with the
/// same dimensions but different phases are distinct.
///
/// NOTE: the declaration order of the variants feeds the derived `Hash`
/// implementation; reordering them changes every hash derived from a phase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelPhase {
    /// Phase assigned by `ModelVariant::prefill`.
    Prefill,
    /// Phase assigned by `ModelVariant::decode`; the only phase for which a
    /// past-sequence length is bound.
    Decode,
    /// Phase assigned by `ModelVariant::encoder`.
    Encoder,
    /// No constructor in this module produces this phase; presumably a
    /// generic single-pass model — TODO confirm against callers.
    Inference,
}

36#[derive(Debug, Clone, PartialEq, Eq)]
38pub struct ModelVariant {
39 pub batch: usize,
40 pub seq: usize,
41 pub past_seq: Option<usize>,
42 pub phase: ModelPhase,
43 pub extra: Vec<(u32, usize)>,
45}
46
47impl ModelVariant {
48 pub fn prefill(batch: usize, seq: usize) -> Self {
49 Self {
50 batch,
51 seq,
52 past_seq: None,
53 phase: ModelPhase::Prefill,
54 extra: Vec::new(),
55 }
56 }
57
58 pub fn decode(batch: usize, past_seq: usize, new_tokens: usize) -> Self {
60 Self {
61 batch,
62 seq: new_tokens,
63 past_seq: Some(past_seq),
64 phase: ModelPhase::Decode,
65 extra: Vec::new(),
66 }
67 }
68
69 pub fn encoder(batch: usize, seq: usize) -> Self {
70 Self {
71 batch,
72 seq,
73 past_seq: None,
74 phase: ModelPhase::Encoder,
75 extra: Vec::new(),
76 }
77 }
78
79 pub fn with_extra(mut self, symbol: u32, size: usize) -> Self {
80 self.extra.push((symbol, size));
81 self
82 }
83
84 pub fn cache_key(&self) -> u64 {
86 let mut h = DefaultHasher::new();
87 self.phase.hash(&mut h);
88 self.batch.hash(&mut h);
89 self.seq.hash(&mut h);
90 self.past_seq.hash(&mut h);
91 for (sym, size) in &self.extra {
92 sym.hash(&mut h);
93 size.hash(&mut h);
94 }
95 h.finish()
96 }
97
98 pub fn dim_binding(&self) -> DimBinding {
100 let mut b = match (self.phase, self.past_seq) {
101 (ModelPhase::Decode, Some(past)) => DimBinding::batch_past_seq(self.batch, past),
102 _ => DimBinding::batch_seq(self.batch, self.seq),
103 };
104 if self.phase == ModelPhase::Decode {
105 b.set(sym::SEQ, self.seq);
106 }
107 for (sym, size) in &self.extra {
108 b.set(*sym, *size);
109 }
110 b
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// A prefill variant binds the batch and sequence symbols.
    #[test]
    fn prefill_binding_sets_batch_seq() {
        let v = ModelVariant::prefill(2, 128);
        let b = v.dim_binding();
        assert_eq!(b.get(sym::BATCH), Some(2));
        assert_eq!(b.get(sym::SEQ), Some(128));
    }

    /// A decode variant binds the past length and uses the new-token count
    /// as the sequence length.
    #[test]
    fn decode_binding_sets_past_and_new_seq() {
        let v = ModelVariant::decode(1, 64, 1);
        let b = v.dim_binding();
        assert_eq!(b.get(sym::BATCH), Some(1));
        assert_eq!(b.get(sym::PAST_SEQ), Some(64));
        assert_eq!(b.get(sym::SEQ), Some(1));
    }

    /// Variants in different phases must not share a cache key.
    #[test]
    fn cache_key_differs_by_phase() {
        let a = ModelVariant::prefill(1, 8).cache_key();
        let b = ModelVariant::decode(1, 7, 1).cache_key();
        assert_ne!(a, b);
    }

    /// Equal variants must always map to the same cache key.
    #[test]
    fn cache_key_is_deterministic_for_equal_variants() {
        let a = ModelVariant::decode(2, 16, 1).with_extra(7, 4);
        let b = a.clone();
        assert_eq!(a, b);
        assert_eq!(a.cache_key(), b.cache_key());
    }

    /// `with_extra` appends bindings in call order.
    #[test]
    fn with_extra_appends_in_call_order() {
        let v = ModelVariant::encoder(1, 32).with_extra(7, 4).with_extra(9, 2);
        assert_eq!(v.extra, vec![(7, 4), (9, 2)]);
    }
}