1use ferrum_interfaces::kv_dtype::KvInt8;
2use ferrum_kernels::backend::{BackendInt8KvOps, KvLayer, MoeLlmBackend};
3use ferrum_types::{FerrumError, Result};
4
5use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
6
7use super::llama_family::{
8 llama_family_decode_op_profile_enabled, LlamaFamilyModel, LlamaStageHiddenBridgeTiming,
9};
10
/// Microseconds elapsed since `t0`, saturating at `u64::MAX` rather than truncating the
/// `u128` returned by `Duration::as_micros`.
fn elapsed_micros_u64(t0: std::time::Instant) -> u64 {
    let micros: u128 = t0.elapsed().as_micros();
    u64::try_from(micros).unwrap_or(u64::MAX)
}

/// Smallest decode batch that the overlapped two-stage mode splits into two microbatches;
/// smaller batches are kept as a single microbatch (and therefore decoded sequentially).
const MIN_OVERLAPPED_DECODE_BATCH: usize = 16;
16
/// Element type of the values carried in a pipeline hidden buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PipelineHiddenDtype {
    /// 32-bit float elements.
    F32,
}

impl PipelineHiddenDtype {
    /// Stable lowercase name used in metrics JSON and profile logs.
    fn as_str(self) -> &'static str {
        match self {
            PipelineHiddenDtype::F32 => "f32",
        }
    }

    /// Number of bytes occupied by one element of this dtype.
    fn elem_size_bytes(self) -> usize {
        match self {
            PipelineHiddenDtype::F32 => std::mem::size_of::<f32>(),
        }
    }
}
35
/// Where a pipeline hidden buffer resides.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PipelineHiddenDevice {
    /// Host (CPU) memory.
    Host,
    /// Backend device memory, optionally on an explicit device ordinal.
    BackendDevice { ordinal: Option<usize> },
}

impl PipelineHiddenDevice {
    /// Stable name used in metrics JSON; the ordinal is not part of the name.
    fn as_str(self) -> &'static str {
        match self {
            PipelineHiddenDevice::BackendDevice { ordinal: _ } => "backend_device",
            PipelineHiddenDevice::Host => "host",
        }
    }
}
50
/// Element ordering of a pipeline hidden buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PipelineHiddenLayout {
    /// Rows stored contiguously one after another.
    RowMajor,
}

impl PipelineHiddenLayout {
    /// Stable name used in metrics JSON and profile logs.
    fn as_str(self) -> &'static str {
        match self {
            PipelineHiddenLayout::RowMajor => "row_major",
        }
    }
}
63
/// Backing storage for a pipeline hidden buffer.
// `Device` is matched (in `PipelineHidden::host_slice`) but, within this file, only `Host`
// is ever constructed — presumably reserved for a future device-resident bridge, hence the
// `dead_code` allowance. TODO confirm intent.
#[allow(dead_code)]
pub(crate) enum PipelineHiddenStorage<B: MoeLlmBackend> {
    /// Host-resident f32 values.
    Host(Vec<f32>),
    /// A backend buffer (not constructed anywhere in this file).
    Device(B::Buffer),
}
69
/// A batch of hidden-state rows passed between pipeline stages, plus its metadata.
pub(crate) struct PipelineHidden<B: MoeLlmBackend> {
    /// `[row_count, hidden_size]`.
    shape: [usize; 2],
    /// Element type of `storage`.
    dtype: PipelineHiddenDtype,
    /// Where `storage` resides.
    device: PipelineHiddenDevice,
    /// Element ordering of `storage`.
    layout: PipelineHiddenLayout,
    /// The values themselves.
    storage: PipelineHiddenStorage<B>,
}
77
78impl<B: MoeLlmBackend> PipelineHidden<B> {
79 pub(crate) fn host(data: Vec<f32>, batch: usize, hidden_size: usize) -> Self {
80 assert_eq!(
81 data.len(),
82 batch * hidden_size,
83 "pipeline hidden host buffer length {} != batch * hidden_size {}",
84 data.len(),
85 batch * hidden_size
86 );
87 Self {
88 shape: [batch, hidden_size],
89 dtype: PipelineHiddenDtype::F32,
90 device: PipelineHiddenDevice::Host,
91 layout: PipelineHiddenLayout::RowMajor,
92 storage: PipelineHiddenStorage::Host(data),
93 }
94 }
95
96 pub(crate) fn row_count(&self) -> usize {
97 self.shape[0]
98 }
99
100 pub(crate) fn hidden_size(&self) -> usize {
101 self.shape[1]
102 }
103
104 pub(crate) fn len_bytes(&self) -> usize {
105 self.shape[0] * self.shape[1] * self.dtype.elem_size_bytes()
106 }
107
108 pub(crate) fn host_slice(&self) -> &[f32] {
109 match &self.storage {
110 PipelineHiddenStorage::Host(data) => data,
111 PipelineHiddenStorage::Device(_) => {
112 panic!("device-resident PipelineHidden cannot be read through host_slice")
113 }
114 }
115 }
116
117 fn metadata_json(&self) -> serde_json::Value {
118 serde_json::json!({
119 "shape": self.shape,
120 "dtype": self.dtype.as_str(),
121 "device": self.device.as_str(),
122 "layout": self.layout.as_str(),
123 "len_bytes": self.len_bytes(),
124 })
125 }
126}
127
/// Batched per-stage operations used to drive one stage of a layer-split Llama pipeline.
///
/// Every batch entry is a `(cache_id, token, pos)` tuple.
pub(crate) trait LlamaPipelineStageBatchOps<B: MoeLlmBackend> {
    /// Runs the first pipeline stage for every batch entry and returns the stacked hidden
    /// rows (one row per entry).
    fn decode_stage_tokens_to_hidden_batch(
        &mut self,
        batch: &[(String, u32, u32)],
    ) -> PipelineHidden<B>;

    /// Feeds the previous stage's hidden rows (one per batch entry) through this stage and
    /// returns the new hidden rows together with the accumulated bridge timing.
    fn decode_stage_hidden_from_host_batch(
        &mut self,
        batch: &[(String, u32, u32)],
        hidden: &PipelineHidden<B>,
    ) -> (PipelineHidden<B>, LlamaStageHiddenBridgeTiming);

    /// Projects each hidden row to a logits vector. `force_full_logits` is passed through
    /// from the caller; the `KvInt8` implementation in this file ignores it.
    fn logits_from_hidden_batch(
        &mut self,
        hidden: &PipelineHidden<B>,
        force_full_logits: bool,
    ) -> Vec<Vec<f32>>;
}
146
147impl<B> LlamaPipelineStageBatchOps<B> for LlamaFamilyModel<B, KvInt8>
148where
149 B: MoeLlmBackend + BackendInt8KvOps,
150{
151 fn decode_stage_tokens_to_hidden_batch(
152 &mut self,
153 batch: &[(String, u32, u32)],
154 ) -> PipelineHidden<B> {
155 let h = self.config().hidden_size;
156 let mut hidden = Vec::with_capacity(batch.len() * h);
157 for (cache_id, token, pos) in batch {
158 hidden.extend_from_slice(&self.decode_stage_token_to_hidden(cache_id, *token, *pos));
159 }
160 PipelineHidden::host(hidden, batch.len(), h)
161 }
162
163 fn decode_stage_hidden_from_host_batch(
164 &mut self,
165 batch: &[(String, u32, u32)],
166 hidden: &PipelineHidden<B>,
167 ) -> (PipelineHidden<B>, LlamaStageHiddenBridgeTiming) {
168 let h = self.config().hidden_size;
169 let hidden_slice = hidden.host_slice();
170 assert_eq!(
171 hidden_slice.len(),
172 batch.len() * h,
173 "hidden length {} != batch * hidden_size {}",
174 hidden_slice.len(),
175 batch.len() * h
176 );
177 let mut out = Vec::with_capacity(hidden_slice.len());
178 let mut bridge_timing = LlamaStageHiddenBridgeTiming::default();
179 for (row, (cache_id, _, pos)) in batch.iter().enumerate() {
180 let start = row * h;
181 let (row_hidden, row_timing) = self.decode_stage_hidden_from_host_with_timing(
182 cache_id,
183 &hidden_slice[start..start + h],
184 *pos,
185 );
186 out.extend_from_slice(&row_hidden);
187 bridge_timing = bridge_timing.add(row_timing);
188 }
189 (PipelineHidden::host(out, batch.len(), h), bridge_timing)
190 }
191
192 fn logits_from_hidden_batch(
193 &mut self,
194 hidden: &PipelineHidden<B>,
195 _force_full_logits: bool,
196 ) -> Vec<Vec<f32>> {
197 let h = self.config().hidden_size;
198 let hidden_slice = hidden.host_slice();
199 assert_eq!(
200 hidden_slice.len(),
201 hidden.row_count() * h,
202 "hidden length {} != row_count * hidden_size {}",
203 hidden_slice.len(),
204 hidden.row_count() * h
205 );
206 (0..hidden.row_count())
207 .map(|row| {
208 let start = row * h;
209 self.logits_from_hidden(&hidden_slice[start..start + h])
210 })
211 .collect()
212 }
213}
214
/// How hidden activations are carried between pipeline stages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlamaPipelineTransport {
    /// Hidden rows are handed between stages through host memory.
    HostHiddenBridge,
}

impl LlamaPipelineTransport {
    /// Stable name reported in cache metrics.
    fn as_str(self) -> &'static str {
        match self {
            LlamaPipelineTransport::HostHiddenBridge => "host-hidden-bridge",
        }
    }

    /// The stage-to-stage bridge implied by this transport.
    fn stage_bridge(self) -> LlamaPipelineStageBridge {
        match self {
            LlamaPipelineTransport::HostHiddenBridge => LlamaPipelineStageBridge::Host,
        }
    }
}

/// Mechanism used to move hidden rows from one stage to the next.
///
/// Only `Host` is selected by any transport in this file; the CUDA variants are not
/// produced here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlamaPipelineStageBridge {
    Host,
    CudaPeer,
    CudaDeviceStaged,
}

impl LlamaPipelineStageBridge {
    /// Stable name reported in cache metrics and profile logs.
    fn as_str(self) -> &'static str {
        match self {
            Self::CudaDeviceStaged => "cuda_device_staged",
            Self::CudaPeer => "cuda_peer",
            Self::Host => "host",
        }
    }
}
250
251#[derive(Debug, Clone, Copy, PartialEq, Eq)]
252pub enum LlamaPipelineMode {
253 Batch,
254 Overlapped,
255}
256
257impl LlamaPipelineMode {
258 pub fn as_str(self) -> &'static str {
259 match self {
260 Self::Batch => "batch",
261 Self::Overlapped => "overlapped",
262 }
263 }
264
265 pub fn default_for_stage_count(stage_count: usize) -> Self {
266 if stage_count == 2 {
267 Self::Overlapped
268 } else {
269 Self::Batch
270 }
271 }
272
273 pub fn from_config_value(value: &str) -> Result<Self> {
274 match value.trim().to_ascii_lowercase().as_str() {
275 "batch" => Ok(Self::Batch),
276 "overlapped" => Ok(Self::Overlapped),
277 other => Err(FerrumError::config(format!(
278 "layer_split_pipeline_mode must be batch or overlapped, got {other:?}"
279 ))),
280 }
281 }
282}
283
/// Running counters for pipeline decode calls, reported as `pipeline_decode` in the
/// pipeline's cache metrics snapshot.
///
/// Field suffixes: `_total` is a saturating sum over all recorded calls, `_last` is the
/// value from the most recent call, and `_max` is the largest value seen so far.
#[derive(Debug, Clone)]
struct PipelineDecodeStats {
    /// Number of recorded decode calls.
    calls: u64,
    /// Calls that used the overlapped two-stage schedule.
    overlapped_calls: u64,
    /// Total batch rows decoded across all calls.
    rows: u64,
    max_batch: u64,
    last_batch: u64,
    microbatch_count_max: u64,
    microbatch_count_last: u64,
    microbatch_size_max: u64,
    microbatch_size_last: u64,
    in_flight_stage_count_max: u64,
    in_flight_stage_count_last: u64,
    queue_depth_max: u64,
    queue_depth_last: u64,
    /// Bytes of hidden state handed between stages.
    host_bridge_bytes_total: u64,
    host_bridge_bytes_last: u64,
    /// Bridge timings in microseconds, summed from `LlamaStageHiddenBridgeTiming`.
    bridge_us_total: u64,
    bridge_us_last: u64,
    host_copy_us_total: u64,
    host_copy_us_last: u64,
    device_copy_us_total: u64,
    device_copy_us_last: u64,
    /// Per-stage wall time in microseconds, indexed by stage.
    stage_us_total: Vec<u64>,
    stage_us_last: Vec<u64>,
    /// Time spent computing logits, in microseconds.
    logits_us_total: u64,
    logits_us_last: u64,
    /// End-to-end decode call time, in microseconds.
    total_us_total: u64,
    total_us_last: u64,
}
314
315impl PipelineDecodeStats {
316 fn new(stage_count: usize) -> Self {
317 Self {
318 calls: 0,
319 overlapped_calls: 0,
320 rows: 0,
321 max_batch: 0,
322 last_batch: 0,
323 microbatch_count_max: 0,
324 microbatch_count_last: 0,
325 microbatch_size_max: 0,
326 microbatch_size_last: 0,
327 in_flight_stage_count_max: 0,
328 in_flight_stage_count_last: 0,
329 queue_depth_max: 0,
330 queue_depth_last: 0,
331 host_bridge_bytes_total: 0,
332 host_bridge_bytes_last: 0,
333 bridge_us_total: 0,
334 bridge_us_last: 0,
335 host_copy_us_total: 0,
336 host_copy_us_last: 0,
337 device_copy_us_total: 0,
338 device_copy_us_last: 0,
339 stage_us_total: vec![0; stage_count],
340 stage_us_last: vec![0; stage_count],
341 logits_us_total: 0,
342 logits_us_last: 0,
343 total_us_total: 0,
344 total_us_last: 0,
345 }
346 }
347
348 fn record(
349 &mut self,
350 batch: usize,
351 microbatch_count: usize,
352 microbatch_size: usize,
353 in_flight_stage_count: usize,
354 queue_depth: usize,
355 overlapped: bool,
356 host_bridge_bytes: usize,
357 bridge_timing: LlamaStageHiddenBridgeTiming,
358 stage_us: &[u64],
359 logits_us: u64,
360 total_us: u64,
361 ) {
362 if self.stage_us_total.len() != stage_us.len() {
363 self.stage_us_total.resize(stage_us.len(), 0);
364 self.stage_us_last.resize(stage_us.len(), 0);
365 }
366 self.calls = self.calls.saturating_add(1);
367 if overlapped {
368 self.overlapped_calls = self.overlapped_calls.saturating_add(1);
369 }
370 self.rows = self.rows.saturating_add(batch as u64);
371 self.max_batch = self.max_batch.max(batch as u64);
372 self.last_batch = batch as u64;
373 self.microbatch_count_max = self.microbatch_count_max.max(microbatch_count as u64);
374 self.microbatch_count_last = microbatch_count as u64;
375 self.microbatch_size_max = self.microbatch_size_max.max(microbatch_size as u64);
376 self.microbatch_size_last = microbatch_size as u64;
377 self.in_flight_stage_count_max = self
378 .in_flight_stage_count_max
379 .max(in_flight_stage_count as u64);
380 self.in_flight_stage_count_last = in_flight_stage_count as u64;
381 self.queue_depth_max = self.queue_depth_max.max(queue_depth as u64);
382 self.queue_depth_last = queue_depth as u64;
383 self.host_bridge_bytes_total = self
384 .host_bridge_bytes_total
385 .saturating_add(host_bridge_bytes as u64);
386 self.host_bridge_bytes_last = host_bridge_bytes as u64;
387 self.bridge_us_total = self.bridge_us_total.saturating_add(bridge_timing.bridge_us);
388 self.bridge_us_last = bridge_timing.bridge_us;
389 self.host_copy_us_total = self
390 .host_copy_us_total
391 .saturating_add(bridge_timing.host_copy_us);
392 self.host_copy_us_last = bridge_timing.host_copy_us;
393 self.device_copy_us_total = self
394 .device_copy_us_total
395 .saturating_add(bridge_timing.device_copy_us);
396 self.device_copy_us_last = bridge_timing.device_copy_us;
397 for (idx, value) in stage_us.iter().copied().enumerate() {
398 self.stage_us_total[idx] = self.stage_us_total[idx].saturating_add(value);
399 self.stage_us_last[idx] = value;
400 }
401 self.logits_us_total = self.logits_us_total.saturating_add(logits_us);
402 self.logits_us_last = logits_us;
403 self.total_us_total = self.total_us_total.saturating_add(total_us);
404 self.total_us_last = total_us;
405 }
406
407 fn avg_per_call(&self, value: u64) -> Option<u64> {
408 (self.calls > 0).then(|| value / self.calls)
409 }
410
411 fn json(&self) -> serde_json::Value {
412 let stage_us_avg: Vec<Option<u64>> = self
413 .stage_us_total
414 .iter()
415 .map(|value| self.avg_per_call(*value))
416 .collect();
417 serde_json::json!({
418 "calls": self.calls,
419 "overlapped_calls": self.overlapped_calls,
420 "rows": self.rows,
421 "max_batch": self.max_batch,
422 "last_batch": self.last_batch,
423 "microbatch_count_max": self.microbatch_count_max,
424 "microbatch_count_last": self.microbatch_count_last,
425 "microbatch_size_max": self.microbatch_size_max,
426 "microbatch_size_last": self.microbatch_size_last,
427 "in_flight_stage_count_max": self.in_flight_stage_count_max,
428 "in_flight_stage_count_last": self.in_flight_stage_count_last,
429 "queue_depth_max": self.queue_depth_max,
430 "queue_depth_last": self.queue_depth_last,
431 "host_bridge_bytes_total": self.host_bridge_bytes_total,
432 "host_bridge_bytes_last": self.host_bridge_bytes_last,
433 "host_bridge_bytes_avg": self.avg_per_call(self.host_bridge_bytes_total),
434 "bridge_us_total": self.bridge_us_total,
435 "bridge_us_last": self.bridge_us_last,
436 "bridge_us_avg": self.avg_per_call(self.bridge_us_total),
437 "host_copy_us_total": self.host_copy_us_total,
438 "host_copy_us_last": self.host_copy_us_last,
439 "host_copy_us_avg": self.avg_per_call(self.host_copy_us_total),
440 "device_copy_us_total": self.device_copy_us_total,
441 "device_copy_us_last": self.device_copy_us_last,
442 "device_copy_us_avg": self.avg_per_call(self.device_copy_us_total),
443 "stage_us_total": self.stage_us_total,
444 "stage_us_last": self.stage_us_last,
445 "stage_us_avg": stage_us_avg,
446 "logits_us_total": self.logits_us_total,
447 "logits_us_last": self.logits_us_last,
448 "logits_us_avg": self.avg_per_call(self.logits_us_total),
449 "total_us_total": self.total_us_total,
450 "total_us_last": self.total_us_last,
451 "total_us_avg": self.avg_per_call(self.total_us_total),
452 })
453 }
454}
455
/// Device placement for a single pipeline stage.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LlamaPipelineStagePlacement {
    /// Explicit backend device ordinal, or `None` for the backend's default device.
    pub backend_device_ordinal: Option<usize>,
}

impl LlamaPipelineStagePlacement {
    /// Shared constructor for both public variants.
    fn with_ordinal(backend_device_ordinal: Option<usize>) -> Self {
        Self {
            backend_device_ordinal,
        }
    }

    /// Placement on the backend's default device (no explicit ordinal).
    pub fn default_backend_device() -> Self {
        Self::with_ordinal(None)
    }

    /// Placement pinned to backend device `ordinal`.
    pub fn backend_device(ordinal: usize) -> Self {
        Self::with_ordinal(Some(ordinal))
    }
}
474
475#[derive(Debug, Clone, PartialEq, Eq)]
476pub struct LlamaPipelinePlacement {
477 stages: Vec<LlamaPipelineStagePlacement>,
478 transport: LlamaPipelineTransport,
479 pipeline_mode: LlamaPipelineMode,
480}
481
482impl LlamaPipelinePlacement {
483 pub fn unplaced(stage_count: usize) -> Self {
484 Self {
485 stages: vec![LlamaPipelineStagePlacement::default_backend_device(); stage_count],
486 transport: LlamaPipelineTransport::HostHiddenBridge,
487 pipeline_mode: LlamaPipelineMode::default_for_stage_count(stage_count),
488 }
489 }
490
491 pub fn from_backend_device_ordinals(stage_device_ordinals: Vec<Option<usize>>) -> Self {
492 let stage_count = stage_device_ordinals.len();
493 Self {
494 stages: stage_device_ordinals
495 .into_iter()
496 .map(|backend_device_ordinal| LlamaPipelineStagePlacement {
497 backend_device_ordinal,
498 })
499 .collect(),
500 transport: LlamaPipelineTransport::HostHiddenBridge,
501 pipeline_mode: LlamaPipelineMode::default_for_stage_count(stage_count),
502 }
503 }
504
505 pub fn with_pipeline_mode(mut self, pipeline_mode: LlamaPipelineMode) -> Self {
506 self.pipeline_mode = pipeline_mode;
507 self
508 }
509
510 pub fn len(&self) -> usize {
511 self.stages.len()
512 }
513
514 pub fn is_empty(&self) -> bool {
515 self.stages.is_empty()
516 }
517
518 pub fn stage(&self, idx: usize) -> LlamaPipelineStagePlacement {
519 self.stages[idx]
520 }
521
522 pub fn stages(&self) -> &[LlamaPipelineStagePlacement] {
523 &self.stages
524 }
525
526 pub fn transport(&self) -> LlamaPipelineTransport {
527 self.transport
528 }
529
530 pub fn stage_bridge(&self) -> LlamaPipelineStageBridge {
531 self.transport.stage_bridge()
532 }
533
534 pub fn pipeline_mode(&self) -> LlamaPipelineMode {
535 self.pipeline_mode
536 }
537
538 pub fn stage_device_ordinals(&self) -> Vec<Option<usize>> {
539 self.stages
540 .iter()
541 .map(|stage| stage.backend_device_ordinal)
542 .collect()
543 }
544
545 fn has_explicit_device_ordinals(&self) -> bool {
546 self.stages
547 .iter()
548 .any(|stage| stage.backend_device_ordinal.is_some())
549 }
550}
551
/// A Llama-family model split by layer range across several stage models, each optionally
/// pinned to a backend device ordinal.
pub struct LlamaFamilyPipelineModel<B: MoeLlmBackend, K: KvLayer<B>> {
    /// Stage models in layer order; the constructor checks they cover all layers contiguously.
    stages: Vec<LlamaFamilyModel<B, K>>,
    /// Per-stage device placement, transport and pipeline mode.
    placement: LlamaPipelinePlacement,
    /// Runtime config cloned from the first stage; returned by `config()`.
    runtime_cfg: LlmRuntimeConfig,
    /// Counters accumulated by decode calls.
    decode_stats: PipelineDecodeStats,
}
558
559impl<B: MoeLlmBackend, K: KvLayer<B>> LlamaFamilyPipelineModel<B, K> {
560 pub fn new(stages: Vec<LlamaFamilyModel<B, K>>) -> Result<Self> {
561 let placement = LlamaPipelinePlacement::unplaced(stages.len());
562 Self::new_with_placement(stages, placement)
563 }
564
565 pub fn new_with_backend_device_ordinals(
566 stages: Vec<LlamaFamilyModel<B, K>>,
567 stage_device_ordinals: Vec<Option<usize>>,
568 ) -> Result<Self> {
569 Self::new_with_placement(
570 stages,
571 LlamaPipelinePlacement::from_backend_device_ordinals(stage_device_ordinals),
572 )
573 }
574
575 pub fn new_with_placement(
576 stages: Vec<LlamaFamilyModel<B, K>>,
577 placement: LlamaPipelinePlacement,
578 ) -> Result<Self> {
579 if stages.is_empty() {
580 return Err(FerrumError::model(
581 "LlamaFamilyPipelineModel requires at least one stage",
582 ));
583 }
584 if placement.len() != stages.len() {
585 return Err(FerrumError::model(format!(
586 "Llama pipeline stage device count {} must match stage count {}",
587 placement.len(),
588 stages.len()
589 )));
590 }
591 if placement.has_explicit_device_ordinals() && !B::supports_device_ordinal_scope() {
592 return Err(FerrumError::unsupported(
593 "Llama layer-split pipeline requested explicit backend device ordinals, \
594 but the selected backend does not support device-scoped execution",
595 ));
596 }
597 if placement.pipeline_mode() == LlamaPipelineMode::Overlapped && placement.len() != 2 {
598 return Err(FerrumError::model(
599 "Llama layer-split overlapped pipeline mode requires exactly two stages",
600 ));
601 }
602 if stages.first().is_some_and(|stage| stage.embed.is_none()) {
603 return Err(FerrumError::model(
604 "first Llama pipeline stage must load embedding weights",
605 ));
606 }
607 if stages.last().is_some_and(|stage| stage.lm_head.is_none()) {
608 return Err(FerrumError::model(
609 "last Llama pipeline stage must load lm_head weights",
610 ));
611 }
612
613 let runtime_cfg = stages[0].runtime_cfg.clone();
614 let mut expected_start = 0usize;
615 for stage in &stages {
616 if stage.runtime_cfg.hidden_size != runtime_cfg.hidden_size
617 || stage.runtime_cfg.vocab_size != runtime_cfg.vocab_size
618 || stage.runtime_cfg.num_kv_heads != runtime_cfg.num_kv_heads
619 || stage.runtime_cfg.head_dim != runtime_cfg.head_dim
620 || stage.runtime_cfg.max_seq_len != runtime_cfg.max_seq_len
621 {
622 return Err(FerrumError::model(
623 "Llama pipeline stages must share runtime dimensions",
624 ));
625 }
626 let range = stage.source_layer_range();
627 if range.start != expected_start {
628 return Err(FerrumError::model(format!(
629 "Llama pipeline stage range starts at {}, expected {expected_start}",
630 range.start
631 )));
632 }
633 expected_start = range.end;
634 }
635 if expected_start != runtime_cfg.num_layers {
636 return Err(FerrumError::model(format!(
637 "Llama pipeline stages cover {expected_start} layers but model has {}",
638 runtime_cfg.num_layers
639 )));
640 }
641
642 let stage_count = placement.len();
643 Ok(Self {
644 stages,
645 placement,
646 runtime_cfg,
647 decode_stats: PipelineDecodeStats::new(stage_count),
648 })
649 }
650
651 pub fn stages(&self) -> &[LlamaFamilyModel<B, K>] {
652 &self.stages
653 }
654
655 pub fn placement(&self) -> &LlamaPipelinePlacement {
656 &self.placement
657 }
658
659 fn pipeline_mode(&self) -> LlamaPipelineMode {
660 self.placement.pipeline_mode()
661 }
662
663 fn decode_microbatch_size(&self, batch_len: usize) -> usize {
664 match self.pipeline_mode() {
665 LlamaPipelineMode::Overlapped if batch_len < MIN_OVERLAPPED_DECODE_BATCH => {
666 batch_len.max(1)
667 }
668 LlamaPipelineMode::Overlapped => ((batch_len + 1) / 2).max(1),
669 LlamaPipelineMode::Batch => batch_len.max(1),
670 }
671 }
672
673 fn last_hidden_row<'a>(&self, hidden: &'a [f32], seq_len: usize) -> &'a [f32] {
674 let h = self.runtime_cfg.hidden_size;
675 &hidden[(seq_len - 1) * h..seq_len * h]
676 }
677}
678
679#[allow(private_bounds)]
680impl<B, K> LlamaFamilyPipelineModel<B, K>
681where
682 B: MoeLlmBackend,
683 K: KvLayer<B>,
684 LlamaFamilyModel<B, K>: DecoderOnlyLLM + LlamaPipelineStageBatchOps<B> + Send,
685{
686 fn decode_batch_sequential_internal(
687 &mut self,
688 batch: &[(String, u32, u32)],
689 force_full_logits: bool,
690 ) -> Vec<Vec<f32>> {
691 let profile = llama_family_decode_op_profile_enabled();
692 let total_t0 = std::time::Instant::now();
693 let mut stage_us: Vec<u64> = Vec::with_capacity(self.stages.len());
694 let mut host_bridge_bytes = 0usize;
695 let mut bridge_timing = LlamaStageHiddenBridgeTiming::default();
696 let stage_bridge = self.placement.stage_bridge();
697
698 let stage_t0 = std::time::Instant::now();
699 let mut hidden =
700 B::with_device_ordinal(self.placement.stage(0).backend_device_ordinal, || {
701 self.stages[0].decode_stage_tokens_to_hidden_batch(batch)
702 });
703 stage_us.push(elapsed_micros_u64(stage_t0));
704 for idx in 1..self.stages.len() {
705 let device = self.placement.stage(idx).backend_device_ordinal;
706 host_bridge_bytes = host_bridge_bytes.saturating_add(hidden.len_bytes());
707 let stage_t0 = std::time::Instant::now();
708 let (next_hidden, stage_bridge_timing) = B::with_device_ordinal(device, || {
709 self.stages[idx].decode_stage_hidden_from_host_batch(batch, &hidden)
710 });
711 hidden = next_hidden;
712 bridge_timing = bridge_timing.add(stage_bridge_timing);
713 stage_us.push(elapsed_micros_u64(stage_t0));
714 }
715 let last_idx = self.stages.len() - 1;
716 let logits_t0 = std::time::Instant::now();
717 let logits = B::with_device_ordinal(
718 self.placement.stage(last_idx).backend_device_ordinal,
719 || self.stages[last_idx].logits_from_hidden_batch(&hidden, force_full_logits),
720 );
721 let logits_us = elapsed_micros_u64(logits_t0);
722 let total_us = elapsed_micros_u64(total_t0);
723 self.decode_stats.record(
724 batch.len(),
725 1,
726 batch.len().max(1),
727 1,
728 0,
729 false,
730 host_bridge_bytes,
731 bridge_timing,
732 &stage_us,
733 logits_us,
734 total_us,
735 );
736 if profile {
737 static PIPELINE_PROFILE_CALLS: std::sync::atomic::AtomicU64 =
738 std::sync::atomic::AtomicU64::new(0);
739 let n = PIPELINE_PROFILE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
740 if n.is_multiple_of(8) {
741 eprintln!(
742 "[pipeline-decode-profile] call#{} m={} hidden={:?} mode=batch bridge={} host_bridge_bytes={} bridge_us={} host_copy_us={} device_copy_us={} stage_us={:?} logits_us={} total_us={}",
743 n,
744 batch.len(),
745 hidden.metadata_json(),
746 stage_bridge.as_str(),
747 host_bridge_bytes,
748 bridge_timing.bridge_us,
749 bridge_timing.host_copy_us,
750 bridge_timing.device_copy_us,
751 stage_us,
752 logits_us,
753 total_us,
754 );
755 }
756 }
757 logits
758 }
759
760 fn decode_batch_overlapped_two_stage(
761 &mut self,
762 batch: &[(String, u32, u32)],
763 force_full_logits: bool,
764 ) -> Vec<Vec<f32>> {
765 debug_assert_eq!(self.stages.len(), 2);
766 let microbatch_size = self.decode_microbatch_size(batch.len());
767 let chunks: Vec<Vec<(String, u32, u32)>> = batch
768 .chunks(microbatch_size)
769 .map(|chunk| chunk.to_vec())
770 .collect();
771 let chunk_count = chunks.len();
772 if chunk_count <= 1 {
773 return self.decode_batch_sequential_internal(batch, force_full_logits);
774 }
775
776 let profile = llama_family_decode_op_profile_enabled();
777 let total_t0 = std::time::Instant::now();
778 let stage0_device = self.placement.stage(0).backend_device_ordinal;
779 let stage1_device = self.placement.stage(1).backend_device_ordinal;
780 let stage_bridge = self.placement.stage_bridge();
781 let (stage0_slice, stage1_slice) = self.stages.split_at_mut(1);
782 let stage0 = &mut stage0_slice[0];
783 let stage1 = &mut stage1_slice[0];
784 let (tx, rx) = std::sync::mpsc::sync_channel::<(
785 usize,
786 Vec<(String, u32, u32)>,
787 PipelineHidden<B>,
788 u64,
789 )>(1);
790 let mut ordered_logits: Vec<Option<Vec<Vec<f32>>>> = vec![None; chunk_count];
791 let mut stage0_us_total = 0u64;
792 let mut stage1_us_total = 0u64;
793 let mut logits_us_total = 0u64;
794 let mut host_bridge_bytes = 0usize;
795 let mut bridge_timing = LlamaStageHiddenBridgeTiming::default();
796
797 std::thread::scope(|scope| {
798 let worker = scope.spawn(move || {
799 for (idx, chunk) in chunks.into_iter().enumerate() {
800 let stage_t0 = std::time::Instant::now();
801 let hidden = B::with_device_ordinal(stage0_device, || {
802 stage0.decode_stage_tokens_to_hidden_batch(&chunk)
803 });
804 let stage_us = elapsed_micros_u64(stage_t0);
805 if tx.send((idx, chunk, hidden, stage_us)).is_err() {
806 break;
807 }
808 }
809 });
810
811 for _ in 0..chunk_count {
812 let (idx, chunk, hidden, stage0_us) = rx
813 .recv()
814 .expect("pipeline stage0 worker ended before sending all microbatches");
815 stage0_us_total = stage0_us_total.saturating_add(stage0_us);
816 host_bridge_bytes = host_bridge_bytes.saturating_add(hidden.len_bytes());
817
818 let stage_t0 = std::time::Instant::now();
819 let (hidden, stage_bridge_timing) = B::with_device_ordinal(stage1_device, || {
820 stage1.decode_stage_hidden_from_host_batch(&chunk, &hidden)
821 });
822 bridge_timing = bridge_timing.add(stage_bridge_timing);
823 stage1_us_total = stage1_us_total.saturating_add(elapsed_micros_u64(stage_t0));
824
825 let logits_t0 = std::time::Instant::now();
826 let logits = B::with_device_ordinal(stage1_device, || {
827 stage1.logits_from_hidden_batch(&hidden, force_full_logits)
828 });
829 logits_us_total = logits_us_total.saturating_add(elapsed_micros_u64(logits_t0));
830 ordered_logits[idx] = Some(logits);
831 }
832
833 worker
834 .join()
835 .expect("pipeline stage0 worker panicked during overlapped decode");
836 });
837
838 let total_us = elapsed_micros_u64(total_t0);
839 let stage_us = vec![stage0_us_total, stage1_us_total];
840 self.decode_stats.record(
841 batch.len(),
842 chunk_count,
843 microbatch_size,
844 2,
845 1,
846 true,
847 host_bridge_bytes,
848 bridge_timing,
849 &stage_us,
850 logits_us_total,
851 total_us,
852 );
853 if profile {
854 static PIPELINE_PROFILE_CALLS: std::sync::atomic::AtomicU64 =
855 std::sync::atomic::AtomicU64::new(0);
856 let n = PIPELINE_PROFILE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
857 if n.is_multiple_of(8) {
858 eprintln!(
859 "[pipeline-decode-profile] call#{} m={} mode=overlapped microbatch_size={} microbatch_count={} bridge={} host_bridge_bytes={} bridge_us={} host_copy_us={} device_copy_us={} stage_us={:?} logits_us={} total_us={}",
860 n,
861 batch.len(),
862 microbatch_size,
863 chunk_count,
864 stage_bridge.as_str(),
865 host_bridge_bytes,
866 bridge_timing.bridge_us,
867 bridge_timing.host_copy_us,
868 bridge_timing.device_copy_us,
869 stage_us,
870 logits_us_total,
871 total_us,
872 );
873 }
874 }
875
876 let mut logits = Vec::with_capacity(batch.len());
877 for chunk_logits in ordered_logits {
878 logits.extend(
879 chunk_logits.expect("pipeline overlapped decode missing logits for microbatch"),
880 );
881 }
882 logits
883 }
884}
885
886impl<B, K> DecoderOnlyLLM for LlamaFamilyPipelineModel<B, K>
887where
888 B: MoeLlmBackend,
889 K: KvLayer<B>,
890 LlamaFamilyModel<B, K>: DecoderOnlyLLM + LlamaPipelineStageBatchOps<B> + Send,
891{
892 fn config(&self) -> &LlmRuntimeConfig {
893 &self.runtime_cfg
894 }
895
896 fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
897 let stage_bridge = self.placement.stage_bridge();
898 let pipeline_mode = self.pipeline_mode();
899 Some(serde_json::json!({
900 "position": "llama-layer-split-pipeline",
901 "stage_count": self.stages.len() as u64,
902 "stage_device_ordinals": self.placement.stage_device_ordinals(),
903 "transport": self.placement.transport().as_str(),
904 "selected_pipeline_mode": pipeline_mode.as_str(),
905 "selected_stage_bridge": stage_bridge.as_str(),
906 "pipeline_hidden": {
907 "dtype": PipelineHiddenDtype::F32.as_str(),
908 "device": PipelineHiddenDevice::Host.as_str(),
909 "layout": PipelineHiddenLayout::RowMajor.as_str(),
910 },
911 "pipeline_decode": self.decode_stats.json(),
912 }))
913 }
914
915 fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
916 for (idx, stage) in self.stages.iter_mut().enumerate() {
917 let device = self.placement.stage(idx).backend_device_ordinal;
918 B::with_device_ordinal(device, || {
919 stage.ensure_scratch(max_tokens);
920 stage.ensure_kv(cache_id);
921 });
922 }
923 }
924
925 fn kv_capacity(&self) -> usize {
926 self.stages
927 .iter()
928 .map(|stage| stage.kv_capacity())
929 .min()
930 .unwrap_or(self.runtime_cfg.max_seq_len)
931 }
932
933 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
934 assert!(!tokens.is_empty(), "pipeline prefill called with no tokens");
935 let pos_offset = self.stages[0].cache_len(cache_id);
936 let mut hidden =
937 B::with_device_ordinal(self.placement.stage(0).backend_device_ordinal, || {
938 self.stages[0].prefill_stage_tokens_to_hidden(cache_id, tokens, pos_offset)
939 });
940 for idx in 1..self.stages.len() {
941 let device = self.placement.stage(idx).backend_device_ordinal;
942 hidden = B::with_device_ordinal(device, || {
943 self.stages[idx].prefill_stage_hidden_from_host(
944 cache_id,
945 &hidden,
946 tokens.len(),
947 pos_offset,
948 )
949 });
950 }
951 let last_hidden = self.last_hidden_row(&hidden, tokens.len()).to_vec();
952 let last_idx = self.stages.len() - 1;
953 B::with_device_ordinal(
954 self.placement.stage(last_idx).backend_device_ordinal,
955 || self.stages[last_idx].logits_from_hidden(&last_hidden),
956 )
957 }
958
959 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
960 let total_t0 = std::time::Instant::now();
961 let mut stage_us: Vec<u64> = Vec::with_capacity(self.stages.len());
962 let mut host_bridge_bytes = 0usize;
963 let mut bridge_timing = LlamaStageHiddenBridgeTiming::default();
964
965 let stage_t0 = std::time::Instant::now();
966 let mut hidden =
967 B::with_device_ordinal(self.placement.stage(0).backend_device_ordinal, || {
968 self.stages[0].decode_stage_token_to_hidden(cache_id, token, pos)
969 });
970 stage_us.push(elapsed_micros_u64(stage_t0));
971 for idx in 1..self.stages.len() {
972 let device = self.placement.stage(idx).backend_device_ordinal;
973 host_bridge_bytes = host_bridge_bytes.saturating_add(hidden.len() * size_of::<f32>());
974 let stage_t0 = std::time::Instant::now();
975 let (next_hidden, stage_bridge_timing) = B::with_device_ordinal(device, || {
976 self.stages[idx].decode_stage_hidden_from_host_with_timing(cache_id, &hidden, pos)
977 });
978 hidden = next_hidden;
979 bridge_timing = bridge_timing.add(stage_bridge_timing);
980 stage_us.push(elapsed_micros_u64(stage_t0));
981 }
982 let last_idx = self.stages.len() - 1;
983 let logits_t0 = std::time::Instant::now();
984 let logits = B::with_device_ordinal(
985 self.placement.stage(last_idx).backend_device_ordinal,
986 || self.stages[last_idx].logits_from_hidden(&hidden),
987 );
988 self.decode_stats.record(
989 1,
990 1,
991 1,
992 1,
993 0,
994 false,
995 host_bridge_bytes,
996 bridge_timing,
997 &stage_us,
998 elapsed_micros_u64(logits_t0),
999 elapsed_micros_u64(total_t0),
1000 );
1001 logits
1002 }
1003
1004 fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
1005 self.decode_batch_with_full_logits(batch, false)
1006 }
1007
1008 fn decode_batch_with_full_logits(
1009 &mut self,
1010 batch: &[(String, u32, u32)],
1011 force_full_logits: bool,
1012 ) -> Vec<Vec<f32>> {
1013 if batch.is_empty() {
1014 return Vec::new();
1015 }
1016 if batch.len() == 1 && !force_full_logits {
1017 let (cache_id, token, pos) = &batch[0];
1018 return vec![self.decode(cache_id, *token, *pos)];
1019 }
1020
1021 if self.pipeline_mode() == LlamaPipelineMode::Overlapped {
1022 self.decode_batch_overlapped_two_stage(batch, force_full_logits)
1023 } else {
1024 self.decode_batch_sequential_internal(batch, force_full_logits)
1025 }
1026 }
1027
1028 fn release(&mut self, cache_id: &str) {
1029 for (idx, stage) in self.stages.iter_mut().enumerate() {
1030 B::with_device_ordinal(self.placement.stage(idx).backend_device_ordinal, || {
1031 stage.release(cache_id);
1032 });
1033 }
1034 }
1035
1036 fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
1037 for (idx, stage) in self.stages.iter_mut().enumerate() {
1038 B::with_device_ordinal(self.placement.stage(idx).backend_device_ordinal, || {
1039 stage.truncate_kv(cache_id, new_len);
1040 });
1041 }
1042 }
1043
1044 fn reset(&mut self) {
1045 for (idx, stage) in self.stages.iter_mut().enumerate() {
1046 B::with_device_ordinal(self.placement.stage(idx).backend_device_ordinal, || {
1047 stage.reset();
1048 });
1049 }
1050 }
1051}
1052
#[cfg(test)]
mod tests {
    // Parity tests: a two-stage CPU pipeline must reproduce the logits of the equivalent
    // single-model forward pass, and must record the expected pipeline decode metrics.
    use ferrum_interfaces::kv_dtype::KvFp16;
    use ferrum_kernels::backend::cpu::CpuBackend;
    use ferrum_quantization::{DenseLinear, QuantConfig, WeightLoader};
    use ferrum_types::{FerrumError, Result};

    use super::*;
    use crate::models::llama_family::{LlamaFamilyConfig, LlamaFamilyLayerStageConfig};

    /// Full single-model reference used as the parity oracle.
    type ParityModel = LlamaFamilyModel<CpuBackend, KvFp16>;
    /// Two-stage pipeline under test.
    type ParityPipeline = LlamaFamilyPipelineModel<CpuBackend, KvFp16>;

    /// Vocabulary size of the tiny parity model; synthetic token ids are reduced modulo it.
    const PARITY_VOCAB: usize = 7;

    /// Asserts the host hidden-bridge decode counters in a cache-metrics snapshot.
    ///
    /// Checks:
    /// - one `stage_us_last` entry per stage;
    /// - non-zero bridged bytes, bridge time, and host copy time;
    /// - a numeric `device_copy_us_total` counter.
    ///
    /// A macro (rather than a function) so the snapshot's value type never has to be named.
    macro_rules! assert_host_bridge_metrics {
        ($metrics:expr, $stage_count:expr) => {{
            let decode = &$metrics["pipeline_decode"];
            assert_eq!(
                decode["stage_us_last"].as_array().unwrap().len(),
                $stage_count
            );
            for key in [
                "host_bridge_bytes_total",
                "bridge_us_total",
                "host_copy_us_total",
            ] {
                assert!(
                    decode[key].as_u64().unwrap() > 0,
                    "{key} should be non-zero"
                );
            }
            assert!(
                decode["device_copy_us_total"].as_u64().is_some(),
                "device_copy_us_total should be reported as an integer"
            );
        }};
    }

    /// Weight loader that fabricates small deterministic tensors for a `LlamaFamilyConfig`,
    /// so full-model and per-stage builds see byte-identical weights.
    struct ParityLoader {
        cfg: LlamaFamilyConfig,
    }

    impl ParityLoader {
        fn new(cfg: LlamaFamilyConfig) -> Self {
            Self { cfg }
        }

        /// Produces `len` values around `base` (spread of ±11 * `scale`), seeded by an
        /// FNV-1a hash of `name` so every tensor name gets a stable, distinct pattern.
        fn deterministic_values(name: &str, len: usize, base: f32, scale: f32) -> Vec<f32> {
            let mut hash = 0x811c9dc5u32;
            for byte in name.bytes() {
                hash ^= byte as u32;
                hash = hash.wrapping_mul(0x01000193);
            }
            (0..len)
                .map(|idx| {
                    let mixed = hash
                        .wrapping_add((idx as u32).wrapping_mul(0x9e3779b9))
                        .rotate_left((idx % 17) as u32);
                    let centered = (mixed % 23) as f32 - 11.0;
                    base + centered * scale
                })
                .collect()
        }

        /// RMS-norm weights: close to 1.0 so the norms stay well-conditioned.
        fn layer_norm_values(&self, name: &str) -> Vec<f32> {
            Self::deterministic_values(name, self.cfg.hidden_size, 1.0, 0.005)
        }

        /// Returns `(out_features, in_features)` for the fused linears the model requests,
        /// or an error for any name the parity model is not expected to load.
        fn linear_dims(&self, name: &str) -> Result<(usize, usize)> {
            let q_dim = self.cfg.num_heads * self.cfg.head_dim;
            let kv_dim = self.cfg.num_kv_heads * self.cfg.head_dim;
            if name.ends_with(".self_attn.qkv_proj") {
                Ok((q_dim + 2 * kv_dim, self.cfg.hidden_size))
            } else if name.ends_with(".self_attn.o_proj") {
                Ok((self.cfg.hidden_size, q_dim))
            } else if name.ends_with(".mlp.gate_up_proj") {
                Ok((2 * self.cfg.intermediate_size, self.cfg.hidden_size))
            } else if name.ends_with(".mlp.down_proj") {
                Ok((self.cfg.hidden_size, self.cfg.intermediate_size))
            } else if name == "lm_head" || name == "model.embed_tokens" {
                Ok((self.cfg.vocab_size, self.cfg.hidden_size))
            } else {
                Err(FerrumError::model(format!(
                    "unexpected linear requested by parity loader: {name}"
                )))
            }
        }
    }

    impl WeightLoader<CpuBackend> for ParityLoader {
        fn load_tensor(&self, name: &str) -> Result<Vec<f32>> {
            if name == "model.embed_tokens.weight" {
                return Ok(Self::deterministic_values(
                    name,
                    self.cfg.vocab_size * self.cfg.hidden_size,
                    0.0,
                    0.02,
                ));
            }
            if name == "model.norm.weight"
                || name.ends_with(".input_layernorm.weight")
                || name.ends_with(".post_attention_layernorm.weight")
            {
                return Ok(self.layer_norm_values(name));
            }
            Err(FerrumError::model(format!(
                "unexpected tensor requested by parity loader: {name}"
            )))
        }

        fn load_linear(
            &self,
            name: &str,
        ) -> Result<Box<dyn ferrum_quantization::Linear<CpuBackend>>> {
            let (out_features, in_features) = self.linear_dims(name)?;
            let weights = Self::deterministic_values(name, out_features * in_features, 0.0, 0.015);
            Ok(Box::new(DenseLinear::<CpuBackend>::from_rows(
                &weights,
                out_features,
                in_features,
            )))
        }

        fn has_tensor(&self, name: &str) -> bool {
            name == "lm_head.weight"
        }

        fn quant_config(&self) -> Option<&QuantConfig> {
            None
        }
    }

    /// Tiny Llama-family config with `num_layers` layers; small enough for fast CPU parity.
    fn parity_config(num_layers: usize) -> LlamaFamilyConfig {
        LlamaFamilyConfig {
            hidden_size: 4,
            intermediate_size: 8,
            num_heads: 2,
            num_kv_heads: 2,
            head_dim: 2,
            num_layers,
            vocab_size: PARITY_VOCAB,
            max_seq_len: 16,
            rms_norm_eps: 1e-5,
            rope_theta: 10_000.0,
            rope_scaling: None,
            rope_interleaved: false,
            has_qk_norm: false,
            sliding_window: 0,
        }
    }

    /// Builds two pipeline stages that split `cfg`'s layers at `split`.
    ///
    /// - Stage 0 covers layers `0..split` and is built with flags `(true, false)`.
    /// - Stage 1 covers `split..num_layers` and is built with flags `(false, true)`.
    ///
    /// The flags presumably mark first/last-stage ownership of the embedding and LM head.
    fn build_two_stages(
        cfg: &LlamaFamilyConfig,
        loader: &ParityLoader,
        split: usize,
    ) -> Vec<ParityModel> {
        let head_stage = ParityModel::new_layer_stage(
            cfg.clone(),
            loader,
            LlamaFamilyLayerStageConfig::pipeline_stage(0..split, true, false),
        )
        .unwrap();
        let tail_stage = ParityModel::new_layer_stage(
            cfg.clone(),
            loader,
            LlamaFamilyLayerStageConfig::pipeline_stage(split..cfg.num_layers, false, true),
        )
        .unwrap();
        vec![head_stage, tail_stage]
    }

    fn build_full_and_pipeline() -> (ParityModel, ParityPipeline) {
        build_full_and_pipeline_with_mode(LlamaPipelineMode::default_for_stage_count(2))
    }

    /// Builds a 3-layer reference model and an unplaced two-stage pipeline
    /// (layers `0..1` | `1..3`) over identical weights, using `pipeline_mode`.
    fn build_full_and_pipeline_with_mode(
        pipeline_mode: LlamaPipelineMode,
    ) -> (ParityModel, ParityPipeline) {
        let cfg = parity_config(3);
        let loader = ParityLoader::new(cfg.clone());
        let full = ParityModel::new(cfg.clone(), &loader).unwrap();
        let pipeline = ParityPipeline::new_with_placement(
            build_two_stages(&cfg, &loader, 1),
            LlamaPipelinePlacement::unplaced(2).with_pipeline_mode(pipeline_mode),
        )
        .unwrap();
        (full, pipeline)
    }

    /// `count` prompts of ragged length `1 + idx % max_len`.
    ///
    /// Prompt `idx` is the consecutive token ids starting at `idx`, modulo the vocab.
    fn ragged_prefills(count: usize, max_len: usize) -> Vec<Vec<u32>> {
        (0..count)
            .map(|idx| {
                let len = 1 + (idx % max_len);
                (0..len)
                    .map(|offset| ((idx + offset) % PARITY_VOCAB) as u32)
                    .collect()
            })
            .collect()
    }

    /// Prefills every prompt into both models.
    ///
    /// The full model uses cache ids `full_{idx}`; the pipeline uses `pipe_{idx}`.
    fn prefill_both(full: &mut ParityModel, pipeline: &mut ParityPipeline, prefills: &[Vec<u32>]) {
        for (idx, tokens) in prefills.iter().enumerate() {
            let _ = full.prefill(&format!("full_{idx}"), tokens);
            let _ = pipeline.prefill(&format!("pipe_{idx}"), tokens);
        }
    }

    /// Builds one `(cache_id, token, pos)` decode request per prompt.
    ///
    /// Row `idx` has cache id `{prefix}_{idx}`, token `token_for(idx)`, and position
    /// `prompt_len + pos_offset`.
    fn decode_requests(
        prefix: &str,
        prefills: &[Vec<u32>],
        token_for: impl Fn(usize) -> u32,
        pos_offset: u32,
    ) -> Vec<(String, u32, u32)> {
        prefills
            .iter()
            .enumerate()
            .map(|(idx, tokens)| {
                (
                    format!("{prefix}_{idx}"),
                    token_for(idx),
                    tokens.len() as u32 + pos_offset,
                )
            })
            .collect()
    }

    /// Asserts two logit rows have equal length and agree element-wise within `1e-5`.
    ///
    /// A NaN difference (NaN/inf in either row) fails the check. A plain `f32::max` fold
    /// would silently return the non-NaN operand and let diverged logits pass.
    fn assert_logits_close(label: &str, expected: &[f32], actual: &[f32]) {
        assert_eq!(expected.len(), actual.len(), "{label} length mismatch");
        let max_diff = expected
            .iter()
            .zip(actual)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, |worst, diff| {
                if worst.is_nan() || diff.is_nan() {
                    f32::NAN
                } else {
                    worst.max(diff)
                }
            });
        assert!(
            max_diff < 1e-5,
            "{label} logits diverged: max_diff={max_diff} expected={expected:?} actual={actual:?}"
        );
    }

    /// Asserts both batches have `rows` rows and compares them row by row.
    ///
    /// Each row is labelled `"{label} row {row}"`.
    fn assert_rows_close(label: &str, expected: &[Vec<f32>], actual: &[Vec<f32>], rows: usize) {
        assert_eq!(expected.len(), rows, "{label} expected row count");
        assert_eq!(actual.len(), rows, "{label} actual row count");
        for (row, (want, got)) in expected.iter().zip(actual).enumerate() {
            assert_logits_close(&format!("{label} row {row}"), want, got);
        }
    }

    #[test]
    fn pipeline_prefill_matches_full_model_on_multi_token_cpu_model() {
        let (mut full, mut pipeline) = build_full_and_pipeline();

        let full_logits = full.prefill("full", &[0, 1, 2, 3]);
        let pipeline_logits = pipeline.prefill("pipe", &[0, 1, 2, 3]);

        assert_eq!(pipeline.config().num_layers, 3);
        assert_eq!(pipeline.stages().len(), 2);
        assert_eq!(
            pipeline.placement().stage_device_ordinals(),
            vec![None, None]
        );
        assert_eq!(
            pipeline.placement().transport(),
            LlamaPipelineTransport::HostHiddenBridge
        );
        let metrics = pipeline.cache_metrics_snapshot().unwrap();
        assert_eq!(metrics["selected_pipeline_mode"], "overlapped");
        assert_eq!(metrics["selected_stage_bridge"], "host");
        assert_eq!(metrics["pipeline_decode"]["calls"], 0);
        assert_logits_close("multi-token prefill", &full_logits, &pipeline_logits);
    }

    #[test]
    fn pipeline_decode_after_multi_token_prefill_matches_full_model() {
        let (mut full, mut pipeline) = build_full_and_pipeline();

        let _ = full.prefill("full", &[0, 1, 2]);
        let _ = pipeline.prefill("pipe", &[0, 1, 2]);
        // Each step feeds token id == position, continuing right after the 3-token prompt.
        for step in [3u32, 4] {
            let full_logits = full.decode("full", step, step);
            let pipeline_logits = pipeline.decode("pipe", step, step);
            assert_logits_close(&format!("decode pos {step}"), &full_logits, &pipeline_logits);
        }

        let metrics = pipeline.cache_metrics_snapshot().unwrap();
        let decode = &metrics["pipeline_decode"];
        assert_eq!(decode["calls"], 2);
        assert_eq!(decode["overlapped_calls"], 0);
        assert_eq!(decode["rows"], 2);
        assert_eq!(decode["max_batch"], 1);
        assert_host_bridge_metrics!(metrics, pipeline.stages().len());
    }

    #[test]
    fn pipeline_decode_batch_matches_full_model_and_preserves_order() {
        let (mut full, mut pipeline) = build_full_and_pipeline();

        let prefills = ragged_prefills(16, 4);
        prefill_both(&mut full, &mut pipeline, &prefills);

        let first_token = |idx: usize| ((idx + 5) % PARITY_VOCAB) as u32;
        let second_token = |idx: usize| ((idx + 1) % PARITY_VOCAB) as u32;

        let expected = full.decode_batch(&decode_requests("full", &prefills, first_token, 0));
        let actual = pipeline.decode_batch(&decode_requests("pipe", &prefills, first_token, 0));
        assert_rows_close("decode batch", &expected, &actual, prefills.len());

        let expected_next =
            full.decode_batch(&decode_requests("full", &prefills, second_token, 1));
        let actual_next =
            pipeline.decode_batch(&decode_requests("pipe", &prefills, second_token, 1));
        assert_rows_close(
            "follow-up decode batch",
            &expected_next,
            &actual_next,
            prefills.len(),
        );

        let metrics = pipeline.cache_metrics_snapshot().unwrap();
        let decode = &metrics["pipeline_decode"];
        assert_eq!(decode["calls"], 2);
        assert_eq!(decode["overlapped_calls"], 2);
        assert_eq!(decode["rows"], 32);
        assert_eq!(decode["max_batch"], 16);
        assert_eq!(decode["last_batch"], 16);
        assert_eq!(decode["microbatch_count_max"], 2);
        assert_eq!(decode["microbatch_size_max"], 8);
        assert_eq!(decode["in_flight_stage_count_max"], 2);
        assert_eq!(decode["queue_depth_max"], 1);
        assert_host_bridge_metrics!(metrics, pipeline.stages().len());
    }

    #[test]
    fn pipeline_overlapped_mode_keeps_small_batches_whole() {
        let (mut full, mut pipeline) = build_full_and_pipeline();

        let prefills = ragged_prefills(8, 3);
        prefill_both(&mut full, &mut pipeline, &prefills);

        let token_for = |idx: usize| ((idx + 5) % PARITY_VOCAB) as u32;
        let expected = full.decode_batch(&decode_requests("full", &prefills, token_for, 0));
        let actual = pipeline.decode_batch(&decode_requests("pipe", &prefills, token_for, 0));
        assert_rows_close("small batch", &expected, &actual, prefills.len());

        let metrics = pipeline.cache_metrics_snapshot().unwrap();
        assert_eq!(metrics["selected_pipeline_mode"], "overlapped");
        let decode = &metrics["pipeline_decode"];
        assert_eq!(decode["calls"], 1);
        assert_eq!(decode["overlapped_calls"], 0);
        assert_eq!(decode["max_batch"], 8);
        assert_eq!(decode["microbatch_count_max"], 1);
        assert_eq!(decode["microbatch_size_max"], 8);
        assert_eq!(decode["in_flight_stage_count_max"], 1);
        assert_eq!(decode["queue_depth_max"], 0);
    }

    #[test]
    fn pipeline_batch_mode_uses_stage_batch_without_overlap() {
        let (mut full, mut pipeline) = build_full_and_pipeline_with_mode(LlamaPipelineMode::Batch);

        let _ = full.prefill("full_a", &[0, 1]);
        let _ = full.prefill("full_b", &[2, 3, 4]);
        let _ = pipeline.prefill("pipe_a", &[0, 1]);
        let _ = pipeline.prefill("pipe_b", &[2, 3, 4]);

        let expected =
            full.decode_batch(&[("full_a".to_string(), 5, 2), ("full_b".to_string(), 6, 3)]);
        let actual =
            pipeline.decode_batch(&[("pipe_a".to_string(), 5, 2), ("pipe_b".to_string(), 6, 3)]);
        assert_rows_close("batch mode", &expected, &actual, 2);

        let metrics = pipeline.cache_metrics_snapshot().unwrap();
        assert_eq!(metrics["selected_pipeline_mode"], "batch");
        let decode = &metrics["pipeline_decode"];
        assert_eq!(decode["calls"], 1);
        assert_eq!(decode["overlapped_calls"], 0);
        assert_eq!(decode["microbatch_count_max"], 1);
        assert_eq!(decode["microbatch_size_max"], 2);
        assert_eq!(decode["in_flight_stage_count_max"], 1);
        assert_eq!(decode["queue_depth_max"], 0);
    }

    #[test]
    fn pipeline_incremental_prefill_matches_full_model_position_offset() {
        let (mut full, mut pipeline) = build_full_and_pipeline();

        let _ = full.prefill("full", &[0, 1]);
        let _ = pipeline.prefill("pipe", &[0, 1]);
        let full_logits = full.prefill("full", &[2, 3]);
        let pipeline_logits = pipeline.prefill("pipe", &[2, 3]);

        assert_logits_close("incremental prefill", &full_logits, &pipeline_logits);
    }

    #[test]
    fn pipeline_rejects_device_ordinals_without_backend_scope_support() {
        let cfg = parity_config(2);
        let loader = ParityLoader::new(cfg.clone());
        let stages = build_two_stages(&cfg, &loader, 1);

        let err = match ParityPipeline::new_with_backend_device_ordinals(
            stages,
            vec![Some(0), Some(1)],
        ) {
            Ok(_) => panic!("pipeline unexpectedly accepted unsupported device ordinals"),
            Err(err) => err.to_string(),
        };

        assert!(
            err.contains("does not support device-scoped execution"),
            "unexpected error message: {err}"
        );
    }
}