/// Computes per-item query lengths and their cumulative (exclusive prefix-sum) offsets.
///
/// Each item is `(cid, q_tokens, pos_offset, is_final_chunk)`; only `q_tokens.len()` is
/// read here.
///
/// # Returns
///
/// `(q_lens, cu_seqlens_q, m_total)` where:
/// - `q_lens[i]` is the number of query tokens of item `i`;
/// - `cu_seqlens_q` has `items.len() + 1` entries, starts at `0`, and satisfies
///   `cu_seqlens_q[i + 1] == cu_seqlens_q[i] + q_lens[i]`;
/// - `m_total` is the total number of query tokens (the last entry of `cu_seqlens_q`).
///
/// # Panics
///
/// Panics if a single query length or the running total does not fit in `u32`.
/// Panicking is preferred over a silent truncation or wrap-around, which would yield
/// corrupt offsets.
pub fn compute_cu_seqlens_q(
    items: &[(String, Vec<u32>, usize, bool)],
) -> (Vec<usize>, Vec<u32>, usize) {
    let q_lens: Vec<usize> = items.iter().map(|it| it.1.len()).collect();

    let mut cu_seqlens_q: Vec<u32> = Vec::with_capacity(items.len() + 1);
    let mut running: u32 = 0;
    cu_seqlens_q.push(running);
    for &len in &q_lens {
        let len = u32::try_from(len).expect("query length does not fit in u32");
        running = running
            .checked_add(len)
            .expect("cumulative query length overflows u32");
        cu_seqlens_q.push(running);
    }

    // u32 -> usize is lossless on the 32- and 64-bit targets this code runs on.
    (q_lens, cu_seqlens_q, running as usize)
}
33
/// Returns each item's position offset (tuple field `.2`) converted to `u32`, in item order.
///
/// # Panics
///
/// Panics if an offset does not fit in `u32`. A plain `as u32` cast would silently
/// truncate it to an unrelated position.
pub fn compute_pos_offsets(items: &[(String, Vec<u32>, usize, bool)]) -> Vec<u32> {
    items
        .iter()
        .map(|it| u32::try_from(it.2).expect("position offset does not fit in u32"))
        .collect()
}
41
/// Returns the largest `pos_offset + q_tokens.len()` across all items, or `0` when
/// `items` is empty.
pub fn compute_max_kv_len(items: &[(String, Vec<u32>, usize, bool)]) -> usize {
    let mut longest = 0usize;
    for (_, q_tokens, pos_offset, _) in items {
        let span = *pos_offset + q_tokens.len();
        if span > longest {
            longest = span;
        }
    }
    longest
}
48
/// Concatenates every item's query tokens (tuple field `.1`) into one flat vector,
/// preserving item order and the token order within each item.
pub fn concat_q_tokens(items: &[(String, Vec<u32>, usize, bool)]) -> Vec<u32> {
    let total_tokens: usize = items.iter().map(|it| it.1.len()).sum();
    let mut flat = Vec::with_capacity(total_tokens);
    for (_, q_tokens, _, _) in items {
        flat.extend_from_slice(q_tokens);
    }
    flat
}
55
/// Packs per-item block tables into one row-major buffer of
/// `items.len() * max_blocks_per_seq` entries.
///
/// Row `i` starts at `i * max_blocks_per_seq` and holds `lookup(cid_i)`. Each table is
/// truncated to `max_blocks_per_seq` entries, and unused trailing slots stay `0`.
/// `lookup` is invoked exactly once per item, in item order.
pub fn stack_block_tables<F: Fn(&str) -> Vec<u32>>(
    items: &[(String, Vec<u32>, usize, bool)],
    max_blocks_per_seq: usize,
    lookup: F,
) -> Vec<u32> {
    let width = max_blocks_per_seq;
    let mut table = vec![0u32; items.len() * width];
    for (row, entry) in items.iter().enumerate() {
        let row_base = row * width;
        let blocks = lookup(entry.0.as_str());
        // zip stops at the shorter side, which both truncates and leaves padding zeros.
        for (slot, &block) in table[row_base..row_base + width]
            .iter_mut()
            .zip(blocks.iter())
        {
            *slot = block;
        }
    }
    table
}
79
/// Locates the last query token of every item whose final-chunk flag (tuple field `.3`)
/// is set.
///
/// `cu_seqlens_q` must be the cumulative offsets for `items`, as returned by
/// `compute_cu_seqlens_q` (`cu_seqlens_q[i]` is the first row of item `i` in the
/// concatenated query tokens).
///
/// # Returns
///
/// `(orig_idx, global_row)` pairs, in item order, where `global_row` is the index of
/// item `orig_idx`'s last query token within the concatenated query tokens.
///
/// # Panics
///
/// - If a final-chunk item has no query tokens. Such an item has no "last token";
///   the old unchecked `len() - 1` wrapped in release builds and silently returned the
///   previous item's row.
/// - If `cu_seqlens_q` has fewer entries than `items`.
pub fn compute_final_indices(
    items: &[(String, Vec<u32>, usize, bool)],
    cu_seqlens_q: &[u32],
) -> Vec<(usize, usize)> {
    items
        .iter()
        .enumerate()
        .filter(|(_, it)| it.3)
        .map(|(orig_idx, it)| {
            let last_token_local = it
                .1
                .len()
                .checked_sub(1)
                .expect("final chunk must contain at least one query token");
            let item_start = cu_seqlens_q[orig_idx] as usize;
            (orig_idx, item_start + last_token_local)
        })
        .collect()
}
99
/// Packs `(m_total, num_seqs)` into a tagged 64-bit key.
///
/// Bit layout:
/// - bit 63: always set. This keeps the key distinct from the untagged
///   `(m_total << 32) | num_seqs` layout.
/// - bits 32..=62: `m_total`, which must be `< 2^31`;
/// - bits 0..=31: `num_seqs`, which must be `<= u32::MAX`.
///
/// # Panics
///
/// Panics (at compile time when used in a const context) if either value is too wide
/// for its field. Without this check, oversized values would bleed into neighbouring
/// bits, and distinct shapes would collide on the same key.
pub const fn unified_graph_key(m_total: usize, num_seqs: usize) -> u64 {
    assert!(
        (m_total as u64) < (1u64 << 31),
        "unified_graph_key: m_total must fit in 31 bits"
    );
    assert!(
        (num_seqs as u64) <= (u32::MAX as u64),
        "unified_graph_key: num_seqs must fit in 32 bits"
    );
    (1u64 << 63) | ((m_total as u64) << 32) | (num_seqs as u64)
}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a batch item `(cid, q_tokens, pos, final_chunk)` with `q_len` zero tokens.
    fn item(
        cid: &str,
        q_len: usize,
        pos: usize,
        final_chunk: bool,
    ) -> (String, Vec<u32>, usize, bool) {
        (cid.to_string(), vec![0u32; q_len], pos, final_chunk)
    }

    /// The three-item batch shared by several tests: lengths 5/1/3, the last non-final.
    fn mixed_batch() -> Vec<(String, Vec<u32>, usize, bool)> {
        vec![
            item("a", 5, 0, true),
            item("b", 1, 100, true),
            item("c", 3, 10, false),
        ]
    }

    #[test]
    fn cu_seqlens_q_mixed_lengths() {
        let items = mixed_batch();
        let (q_lens, cu, m_total) = compute_cu_seqlens_q(&items);
        assert_eq!(q_lens, vec![5, 1, 3]);
        assert_eq!(cu, vec![0, 5, 6, 9]);
        assert_eq!(m_total, 9);
    }

    #[test]
    fn pos_offsets_and_max_kv_len() {
        let items = mixed_batch();
        assert_eq!(compute_pos_offsets(&items), vec![0u32, 100, 10]);
        assert_eq!(compute_max_kv_len(&items), 101);
    }

    #[test]
    fn concat_q_tokens_preserves_order() {
        let items = vec![
            ("a".to_string(), vec![1u32, 2, 3], 0usize, true),
            ("b".to_string(), Vec::<u32>::new(), 0usize, false),
            ("c".to_string(), vec![9u32, 8], 0usize, true),
        ];
        assert_eq!(concat_q_tokens(&items), vec![1, 2, 3, 9, 8]);
    }

    #[test]
    fn final_indices_only_final_chunks() {
        let items = mixed_batch();
        let (_, cu, _) = compute_cu_seqlens_q(&items);
        let fi = compute_final_indices(&items, &cu);
        assert_eq!(fi, vec![(0, 4), (1, 5)]);
    }

    #[test]
    fn graph_key_high_bit_set() {
        let k = unified_graph_key(32, 4);
        assert!(k & (1u64 << 63) != 0, "high bit must be set");
        // The untagged layout must never produce the same key.
        let legacy = (32u64 << 32) | 4u64;
        assert_ne!(k, legacy);
    }

    #[test]
    fn graph_key_exact_layout() {
        assert_eq!(unified_graph_key(32, 4), (1u64 << 63) | (32u64 << 32) | 4);
        assert_eq!(unified_graph_key(0, 0), 1u64 << 63);
    }

    #[test]
    fn stack_block_tables_pads_and_truncates() {
        let items = vec![item("a", 1, 0, true), item("b", 1, 0, true)];
        let stacked = stack_block_tables(&items, 3, |cid| match cid {
            // "a" is shorter than the row width (padded), "b" is longer (truncated).
            "a" => vec![10u32, 11u32],
            "b" => vec![20u32, 21u32, 22u32, 23u32, 24u32],
            _ => unreachable!(),
        });
        assert_eq!(stacked, vec![10, 11, 0, 20, 21, 22]);
    }

    #[test]
    fn stack_block_tables_zero_width() {
        let items = vec![item("a", 1, 0, true), item("b", 1, 0, true)];
        let stacked = stack_block_tables(&items, 0, |_| vec![1u32, 2, 3]);
        assert!(stacked.is_empty());
    }

    #[test]
    fn empty_items() {
        let items: Vec<(String, Vec<u32>, usize, bool)> = Vec::new();
        let (q_lens, cu, m_total) = compute_cu_seqlens_q(&items);
        assert!(q_lens.is_empty());
        assert_eq!(cu, vec![0]);
        assert_eq!(m_total, 0);

        assert!(compute_pos_offsets(&items).is_empty());
        assert_eq!(compute_max_kv_len(&items), 0);
        assert!(concat_q_tokens(&items).is_empty());
        assert!(compute_final_indices(&items, &cu).is_empty());
        assert!(stack_block_tables(&items, 4, |_| vec![1u32]).is_empty());
    }
}