1use std::collections::hash_map::DefaultHasher;
10use std::hash::{Hash, Hasher};
11
12use crate::dynamic::sym;
13use crate::shape::DimBinding;
14
/// Execution phase a [`ModelVariant`] is specialised for.
///
/// Variant declaration order is significant: the derived `Hash` feeds
/// `ModelVariant::cache_key`, so reordering variants would change every key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ModelPhase {
    /// Phase produced by `ModelVariant::prefill`; bound via `DimBinding::batch_seq`.
    Prefill,
    /// Phase produced by `ModelVariant::decode`; the only phase for which
    /// `ModelVariant::dim_binding` binds a past sequence length.
    Decode,
    /// Phase produced by `ModelVariant::encoder`; bound via `DimBinding::batch_seq`.
    Encoder,
    /// Generic inference phase; `ModelVariant::dim_binding` binds it the same
    /// way as `Prefill` and `Encoder`.
    Inference,
}
23
24#[derive(Debug, Clone, PartialEq, Eq)]
26pub struct ModelVariant {
27 pub batch: usize,
28 pub seq: usize,
29 pub past_seq: Option<usize>,
30 pub phase: ModelPhase,
31 pub extra: Vec<(u32, usize)>,
33}
34
35impl ModelVariant {
36 pub fn prefill(batch: usize, seq: usize) -> Self {
37 Self {
38 batch,
39 seq,
40 past_seq: None,
41 phase: ModelPhase::Prefill,
42 extra: Vec::new(),
43 }
44 }
45
46 pub fn decode(batch: usize, past_seq: usize, new_tokens: usize) -> Self {
48 Self {
49 batch,
50 seq: new_tokens,
51 past_seq: Some(past_seq),
52 phase: ModelPhase::Decode,
53 extra: Vec::new(),
54 }
55 }
56
57 pub fn encoder(batch: usize, seq: usize) -> Self {
58 Self {
59 batch,
60 seq,
61 past_seq: None,
62 phase: ModelPhase::Encoder,
63 extra: Vec::new(),
64 }
65 }
66
67 pub fn with_extra(mut self, symbol: u32, size: usize) -> Self {
68 self.extra.push((symbol, size));
69 self
70 }
71
72 pub fn cache_key(&self) -> u64 {
74 let mut h = DefaultHasher::new();
75 self.phase.hash(&mut h);
76 self.batch.hash(&mut h);
77 self.seq.hash(&mut h);
78 self.past_seq.hash(&mut h);
79 for (sym, size) in &self.extra {
80 sym.hash(&mut h);
81 size.hash(&mut h);
82 }
83 h.finish()
84 }
85
86 pub fn dim_binding(&self) -> DimBinding {
88 let mut b = match (self.phase, self.past_seq) {
89 (ModelPhase::Decode, Some(past)) => DimBinding::batch_past_seq(self.batch, past),
90 _ => DimBinding::batch_seq(self.batch, self.seq),
91 };
92 if self.phase == ModelPhase::Decode {
93 b.set(sym::SEQ, self.seq);
94 }
95 for (sym, size) in &self.extra {
96 b.set(*sym, *size);
97 }
98 b
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    // Prefill variants bind both batch and sequence length directly.
    #[test]
    fn prefill_binding_sets_batch_seq() {
        let binding = ModelVariant::prefill(2, 128).dim_binding();
        assert_eq!(binding.get(sym::BATCH), Some(2));
        assert_eq!(binding.get(sym::SEQ), Some(128));
    }

    // Decode variants bind the past length plus the new-token count as SEQ.
    #[test]
    fn decode_binding_sets_past_and_new_seq() {
        let binding = ModelVariant::decode(1, 64, 1).dim_binding();
        assert_eq!(binding.get(sym::BATCH), Some(1));
        assert_eq!(binding.get(sym::PAST_SEQ), Some(64));
        assert_eq!(binding.get(sym::SEQ), Some(1));
    }

    // Phase participates in the cache key.
    #[test]
    fn cache_key_differs_by_phase() {
        let prefill_key = ModelVariant::prefill(1, 8).cache_key();
        let decode_key = ModelVariant::decode(1, 7, 1).cache_key();
        assert_ne!(prefill_key, decode_key);
    }
}