Skip to main content

ferrum_models/models/
llama_family_pipeline.rs

1use ferrum_interfaces::kv_dtype::KvInt8;
2use ferrum_kernels::backend::{BackendInt8KvOps, KvLayer, MoeLlmBackend};
3use ferrum_types::{FerrumError, Result};
4
5use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
6
7use super::llama_family::{
8    llama_family_decode_op_profile_enabled, LlamaFamilyModel, LlamaStageHiddenBridgeTiming,
9};
10
/// Microseconds elapsed since `t0`, saturating at `u64::MAX` instead of
/// truncating the underlying `u128` count.
fn elapsed_micros_u64(t0: std::time::Instant) -> u64 {
    let micros = t0.elapsed().as_micros();
    u64::try_from(micros).unwrap_or(u64::MAX)
}

/// Overlapped-mode batches smaller than this are decoded as a single
/// microbatch (see `decode_microbatch_size`).
const MIN_OVERLAPPED_DECODE_BATCH: usize = 16;
16
/// Element type of the values held by a pipeline hidden-state buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PipelineHiddenDtype {
    /// 32-bit float elements.
    F32,
}

impl PipelineHiddenDtype {
    /// Lowercase identifier used in profiling metadata.
    fn as_str(self) -> &'static str {
        match self {
            PipelineHiddenDtype::F32 => "f32",
        }
    }

    /// Size in bytes of a single element of this dtype.
    fn elem_size_bytes(self) -> usize {
        match self {
            PipelineHiddenDtype::F32 => std::mem::size_of::<f32>(),
        }
    }
}
35
/// Where the values of a pipeline hidden-state buffer live.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PipelineHiddenDevice {
    /// Ordinary host memory.
    Host,
    /// Backend-owned memory; `ordinal` is the backend device index, if any.
    BackendDevice { ordinal: Option<usize> },
}

impl PipelineHiddenDevice {
    /// Identifier used in profiling metadata (the ordinal is not included).
    fn as_str(self) -> &'static str {
        match self {
            Self::BackendDevice { ordinal: _ } => "backend_device",
            Self::Host => "host",
        }
    }
}
50
/// Memory layout of a pipeline hidden-state buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PipelineHiddenLayout {
    /// Rows stored contiguously, one after another.
    RowMajor,
}

impl PipelineHiddenLayout {
    /// Identifier used in profiling metadata.
    fn as_str(self) -> &'static str {
        match self {
            PipelineHiddenLayout::RowMajor => "row_major",
        }
    }
}
63
64#[allow(dead_code)]
65pub(crate) enum PipelineHiddenStorage<B: MoeLlmBackend> {
66    Host(Vec<f32>),
67    Device(B::Buffer),
68}
69
70pub(crate) struct PipelineHidden<B: MoeLlmBackend> {
71    shape: [usize; 2],
72    dtype: PipelineHiddenDtype,
73    device: PipelineHiddenDevice,
74    layout: PipelineHiddenLayout,
75    storage: PipelineHiddenStorage<B>,
76}
77
78impl<B: MoeLlmBackend> PipelineHidden<B> {
79    pub(crate) fn host(data: Vec<f32>, batch: usize, hidden_size: usize) -> Self {
80        assert_eq!(
81            data.len(),
82            batch * hidden_size,
83            "pipeline hidden host buffer length {} != batch * hidden_size {}",
84            data.len(),
85            batch * hidden_size
86        );
87        Self {
88            shape: [batch, hidden_size],
89            dtype: PipelineHiddenDtype::F32,
90            device: PipelineHiddenDevice::Host,
91            layout: PipelineHiddenLayout::RowMajor,
92            storage: PipelineHiddenStorage::Host(data),
93        }
94    }
95
96    pub(crate) fn row_count(&self) -> usize {
97        self.shape[0]
98    }
99
100    pub(crate) fn hidden_size(&self) -> usize {
101        self.shape[1]
102    }
103
104    pub(crate) fn len_bytes(&self) -> usize {
105        self.shape[0] * self.shape[1] * self.dtype.elem_size_bytes()
106    }
107
108    pub(crate) fn host_slice(&self) -> &[f32] {
109        match &self.storage {
110            PipelineHiddenStorage::Host(data) => data,
111            PipelineHiddenStorage::Device(_) => {
112                panic!("device-resident PipelineHidden cannot be read through host_slice")
113            }
114        }
115    }
116
117    fn metadata_json(&self) -> serde_json::Value {
118        serde_json::json!({
119            "shape": self.shape,
120            "dtype": self.dtype.as_str(),
121            "device": self.device.as_str(),
122            "layout": self.layout.as_str(),
123            "len_bytes": self.len_bytes(),
124        })
125    }
126}
127
128pub(crate) trait LlamaPipelineStageBatchOps<B: MoeLlmBackend> {
129    fn decode_stage_tokens_to_hidden_batch(
130        &mut self,
131        batch: &[(String, u32, u32)],
132    ) -> PipelineHidden<B>;
133
134    fn decode_stage_hidden_from_host_batch(
135        &mut self,
136        batch: &[(String, u32, u32)],
137        hidden: &PipelineHidden<B>,
138    ) -> (PipelineHidden<B>, LlamaStageHiddenBridgeTiming);
139
140    fn logits_from_hidden_batch(
141        &mut self,
142        hidden: &PipelineHidden<B>,
143        force_full_logits: bool,
144    ) -> Vec<Vec<f32>>;
145}
146
147impl<B> LlamaPipelineStageBatchOps<B> for LlamaFamilyModel<B, KvInt8>
148where
149    B: MoeLlmBackend + BackendInt8KvOps,
150{
151    fn decode_stage_tokens_to_hidden_batch(
152        &mut self,
153        batch: &[(String, u32, u32)],
154    ) -> PipelineHidden<B> {
155        let h = self.config().hidden_size;
156        let mut hidden = Vec::with_capacity(batch.len() * h);
157        for (cache_id, token, pos) in batch {
158            hidden.extend_from_slice(&self.decode_stage_token_to_hidden(cache_id, *token, *pos));
159        }
160        PipelineHidden::host(hidden, batch.len(), h)
161    }
162
163    fn decode_stage_hidden_from_host_batch(
164        &mut self,
165        batch: &[(String, u32, u32)],
166        hidden: &PipelineHidden<B>,
167    ) -> (PipelineHidden<B>, LlamaStageHiddenBridgeTiming) {
168        let h = self.config().hidden_size;
169        let hidden_slice = hidden.host_slice();
170        assert_eq!(
171            hidden_slice.len(),
172            batch.len() * h,
173            "hidden length {} != batch * hidden_size {}",
174            hidden_slice.len(),
175            batch.len() * h
176        );
177        let mut out = Vec::with_capacity(hidden_slice.len());
178        let mut bridge_timing = LlamaStageHiddenBridgeTiming::default();
179        for (row, (cache_id, _, pos)) in batch.iter().enumerate() {
180            let start = row * h;
181            let (row_hidden, row_timing) = self.decode_stage_hidden_from_host_with_timing(
182                cache_id,
183                &hidden_slice[start..start + h],
184                *pos,
185            );
186            out.extend_from_slice(&row_hidden);
187            bridge_timing = bridge_timing.add(row_timing);
188        }
189        (PipelineHidden::host(out, batch.len(), h), bridge_timing)
190    }
191
192    fn logits_from_hidden_batch(
193        &mut self,
194        hidden: &PipelineHidden<B>,
195        _force_full_logits: bool,
196    ) -> Vec<Vec<f32>> {
197        let h = self.config().hidden_size;
198        let hidden_slice = hidden.host_slice();
199        assert_eq!(
200            hidden_slice.len(),
201            hidden.row_count() * h,
202            "hidden length {} != row_count * hidden_size {}",
203            hidden_slice.len(),
204            hidden.row_count() * h
205        );
206        (0..hidden.row_count())
207            .map(|row| {
208                let start = row * h;
209                self.logits_from_hidden(&hidden_slice[start..start + h])
210            })
211            .collect()
212    }
213}
214
/// How hidden states travel between layer-split pipeline stages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlamaPipelineTransport {
    /// Hidden rows cross stage boundaries through a host-memory bridge.
    HostHiddenBridge,
}

impl LlamaPipelineTransport {
    /// Kebab-case identifier for this transport.
    fn as_str(self) -> &'static str {
        match self {
            LlamaPipelineTransport::HostHiddenBridge => "host-hidden-bridge",
        }
    }

    /// The per-boundary bridge mechanism this transport uses.
    fn stage_bridge(self) -> LlamaPipelineStageBridge {
        match self {
            LlamaPipelineTransport::HostHiddenBridge => LlamaPipelineStageBridge::Host,
        }
    }
}

/// Mechanism used to move hidden rows across one stage boundary.
///
/// Only `Host` is currently produced by [`LlamaPipelineTransport::stage_bridge`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlamaPipelineStageBridge {
    Host,
    CudaPeer,
    CudaDeviceStaged,
}

impl LlamaPipelineStageBridge {
    /// Snake-case identifier used in profiling output.
    fn as_str(self) -> &'static str {
        match self {
            Self::CudaDeviceStaged => "cuda_device_staged",
            Self::CudaPeer => "cuda_peer",
            Self::Host => "host",
        }
    }
}
250
251#[derive(Debug, Clone, Copy, PartialEq, Eq)]
252pub enum LlamaPipelineMode {
253    Batch,
254    Overlapped,
255}
256
257impl LlamaPipelineMode {
258    pub fn as_str(self) -> &'static str {
259        match self {
260            Self::Batch => "batch",
261            Self::Overlapped => "overlapped",
262        }
263    }
264
265    pub fn default_for_stage_count(stage_count: usize) -> Self {
266        if stage_count == 2 {
267            Self::Overlapped
268        } else {
269            Self::Batch
270        }
271    }
272
273    pub fn from_config_value(value: &str) -> Result<Self> {
274        match value.trim().to_ascii_lowercase().as_str() {
275            "batch" => Ok(Self::Batch),
276            "overlapped" => Ok(Self::Overlapped),
277            other => Err(FerrumError::config(format!(
278                "layer_split_pipeline_mode must be batch or overlapped, got {other:?}"
279            ))),
280        }
281    }
282}
283
284#[derive(Debug, Clone)]
285struct PipelineDecodeStats {
286    calls: u64,
287    overlapped_calls: u64,
288    rows: u64,
289    max_batch: u64,
290    last_batch: u64,
291    microbatch_count_max: u64,
292    microbatch_count_last: u64,
293    microbatch_size_max: u64,
294    microbatch_size_last: u64,
295    in_flight_stage_count_max: u64,
296    in_flight_stage_count_last: u64,
297    queue_depth_max: u64,
298    queue_depth_last: u64,
299    host_bridge_bytes_total: u64,
300    host_bridge_bytes_last: u64,
301    bridge_us_total: u64,
302    bridge_us_last: u64,
303    host_copy_us_total: u64,
304    host_copy_us_last: u64,
305    device_copy_us_total: u64,
306    device_copy_us_last: u64,
307    stage_us_total: Vec<u64>,
308    stage_us_last: Vec<u64>,
309    logits_us_total: u64,
310    logits_us_last: u64,
311    total_us_total: u64,
312    total_us_last: u64,
313}
314
315impl PipelineDecodeStats {
316    fn new(stage_count: usize) -> Self {
317        Self {
318            calls: 0,
319            overlapped_calls: 0,
320            rows: 0,
321            max_batch: 0,
322            last_batch: 0,
323            microbatch_count_max: 0,
324            microbatch_count_last: 0,
325            microbatch_size_max: 0,
326            microbatch_size_last: 0,
327            in_flight_stage_count_max: 0,
328            in_flight_stage_count_last: 0,
329            queue_depth_max: 0,
330            queue_depth_last: 0,
331            host_bridge_bytes_total: 0,
332            host_bridge_bytes_last: 0,
333            bridge_us_total: 0,
334            bridge_us_last: 0,
335            host_copy_us_total: 0,
336            host_copy_us_last: 0,
337            device_copy_us_total: 0,
338            device_copy_us_last: 0,
339            stage_us_total: vec![0; stage_count],
340            stage_us_last: vec![0; stage_count],
341            logits_us_total: 0,
342            logits_us_last: 0,
343            total_us_total: 0,
344            total_us_last: 0,
345        }
346    }
347
348    fn record(
349        &mut self,
350        batch: usize,
351        microbatch_count: usize,
352        microbatch_size: usize,
353        in_flight_stage_count: usize,
354        queue_depth: usize,
355        overlapped: bool,
356        host_bridge_bytes: usize,
357        bridge_timing: LlamaStageHiddenBridgeTiming,
358        stage_us: &[u64],
359        logits_us: u64,
360        total_us: u64,
361    ) {
362        if self.stage_us_total.len() != stage_us.len() {
363            self.stage_us_total.resize(stage_us.len(), 0);
364            self.stage_us_last.resize(stage_us.len(), 0);
365        }
366        self.calls = self.calls.saturating_add(1);
367        if overlapped {
368            self.overlapped_calls = self.overlapped_calls.saturating_add(1);
369        }
370        self.rows = self.rows.saturating_add(batch as u64);
371        self.max_batch = self.max_batch.max(batch as u64);
372        self.last_batch = batch as u64;
373        self.microbatch_count_max = self.microbatch_count_max.max(microbatch_count as u64);
374        self.microbatch_count_last = microbatch_count as u64;
375        self.microbatch_size_max = self.microbatch_size_max.max(microbatch_size as u64);
376        self.microbatch_size_last = microbatch_size as u64;
377        self.in_flight_stage_count_max = self
378            .in_flight_stage_count_max
379            .max(in_flight_stage_count as u64);
380        self.in_flight_stage_count_last = in_flight_stage_count as u64;
381        self.queue_depth_max = self.queue_depth_max.max(queue_depth as u64);
382        self.queue_depth_last = queue_depth as u64;
383        self.host_bridge_bytes_total = self
384            .host_bridge_bytes_total
385            .saturating_add(host_bridge_bytes as u64);
386        self.host_bridge_bytes_last = host_bridge_bytes as u64;
387        self.bridge_us_total = self.bridge_us_total.saturating_add(bridge_timing.bridge_us);
388        self.bridge_us_last = bridge_timing.bridge_us;
389        self.host_copy_us_total = self
390            .host_copy_us_total
391            .saturating_add(bridge_timing.host_copy_us);
392        self.host_copy_us_last = bridge_timing.host_copy_us;
393        self.device_copy_us_total = self
394            .device_copy_us_total
395            .saturating_add(bridge_timing.device_copy_us);
396        self.device_copy_us_last = bridge_timing.device_copy_us;
397        for (idx, value) in stage_us.iter().copied().enumerate() {
398            self.stage_us_total[idx] = self.stage_us_total[idx].saturating_add(value);
399            self.stage_us_last[idx] = value;
400        }
401        self.logits_us_total = self.logits_us_total.saturating_add(logits_us);
402        self.logits_us_last = logits_us;
403        self.total_us_total = self.total_us_total.saturating_add(total_us);
404        self.total_us_last = total_us;
405    }
406
407    fn avg_per_call(&self, value: u64) -> Option<u64> {
408        (self.calls > 0).then(|| value / self.calls)
409    }
410
411    fn json(&self) -> serde_json::Value {
412        let stage_us_avg: Vec<Option<u64>> = self
413            .stage_us_total
414            .iter()
415            .map(|value| self.avg_per_call(*value))
416            .collect();
417        serde_json::json!({
418            "calls": self.calls,
419            "overlapped_calls": self.overlapped_calls,
420            "rows": self.rows,
421            "max_batch": self.max_batch,
422            "last_batch": self.last_batch,
423            "microbatch_count_max": self.microbatch_count_max,
424            "microbatch_count_last": self.microbatch_count_last,
425            "microbatch_size_max": self.microbatch_size_max,
426            "microbatch_size_last": self.microbatch_size_last,
427            "in_flight_stage_count_max": self.in_flight_stage_count_max,
428            "in_flight_stage_count_last": self.in_flight_stage_count_last,
429            "queue_depth_max": self.queue_depth_max,
430            "queue_depth_last": self.queue_depth_last,
431            "host_bridge_bytes_total": self.host_bridge_bytes_total,
432            "host_bridge_bytes_last": self.host_bridge_bytes_last,
433            "host_bridge_bytes_avg": self.avg_per_call(self.host_bridge_bytes_total),
434            "bridge_us_total": self.bridge_us_total,
435            "bridge_us_last": self.bridge_us_last,
436            "bridge_us_avg": self.avg_per_call(self.bridge_us_total),
437            "host_copy_us_total": self.host_copy_us_total,
438            "host_copy_us_last": self.host_copy_us_last,
439            "host_copy_us_avg": self.avg_per_call(self.host_copy_us_total),
440            "device_copy_us_total": self.device_copy_us_total,
441            "device_copy_us_last": self.device_copy_us_last,
442            "device_copy_us_avg": self.avg_per_call(self.device_copy_us_total),
443            "stage_us_total": self.stage_us_total,
444            "stage_us_last": self.stage_us_last,
445            "stage_us_avg": stage_us_avg,
446            "logits_us_total": self.logits_us_total,
447            "logits_us_last": self.logits_us_last,
448            "logits_us_avg": self.avg_per_call(self.logits_us_total),
449            "total_us_total": self.total_us_total,
450            "total_us_last": self.total_us_last,
451            "total_us_avg": self.avg_per_call(self.total_us_total),
452        })
453    }
454}
455
/// Device placement for a single pipeline stage.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LlamaPipelineStagePlacement {
    /// Backend device ordinal to run the stage on; `None` leaves the choice
    /// to the backend.
    pub backend_device_ordinal: Option<usize>,
}

impl LlamaPipelineStagePlacement {
    /// Placement with no explicit device ordinal.
    pub fn default_backend_device() -> Self {
        Self::with_ordinal(None)
    }

    /// Placement pinned to backend device `ordinal`.
    pub fn backend_device(ordinal: usize) -> Self {
        Self::with_ordinal(Some(ordinal))
    }

    /// Shared constructor for the two public forms above.
    fn with_ordinal(backend_device_ordinal: Option<usize>) -> Self {
        Self {
            backend_device_ordinal,
        }
    }
}
474
475#[derive(Debug, Clone, PartialEq, Eq)]
476pub struct LlamaPipelinePlacement {
477    stages: Vec<LlamaPipelineStagePlacement>,
478    transport: LlamaPipelineTransport,
479    pipeline_mode: LlamaPipelineMode,
480}
481
482impl LlamaPipelinePlacement {
483    pub fn unplaced(stage_count: usize) -> Self {
484        Self {
485            stages: vec![LlamaPipelineStagePlacement::default_backend_device(); stage_count],
486            transport: LlamaPipelineTransport::HostHiddenBridge,
487            pipeline_mode: LlamaPipelineMode::default_for_stage_count(stage_count),
488        }
489    }
490
491    pub fn from_backend_device_ordinals(stage_device_ordinals: Vec<Option<usize>>) -> Self {
492        let stage_count = stage_device_ordinals.len();
493        Self {
494            stages: stage_device_ordinals
495                .into_iter()
496                .map(|backend_device_ordinal| LlamaPipelineStagePlacement {
497                    backend_device_ordinal,
498                })
499                .collect(),
500            transport: LlamaPipelineTransport::HostHiddenBridge,
501            pipeline_mode: LlamaPipelineMode::default_for_stage_count(stage_count),
502        }
503    }
504
505    pub fn with_pipeline_mode(mut self, pipeline_mode: LlamaPipelineMode) -> Self {
506        self.pipeline_mode = pipeline_mode;
507        self
508    }
509
510    pub fn len(&self) -> usize {
511        self.stages.len()
512    }
513
514    pub fn is_empty(&self) -> bool {
515        self.stages.is_empty()
516    }
517
518    pub fn stage(&self, idx: usize) -> LlamaPipelineStagePlacement {
519        self.stages[idx]
520    }
521
522    pub fn stages(&self) -> &[LlamaPipelineStagePlacement] {
523        &self.stages
524    }
525
526    pub fn transport(&self) -> LlamaPipelineTransport {
527        self.transport
528    }
529
530    pub fn stage_bridge(&self) -> LlamaPipelineStageBridge {
531        self.transport.stage_bridge()
532    }
533
534    pub fn pipeline_mode(&self) -> LlamaPipelineMode {
535        self.pipeline_mode
536    }
537
538    pub fn stage_device_ordinals(&self) -> Vec<Option<usize>> {
539        self.stages
540            .iter()
541            .map(|stage| stage.backend_device_ordinal)
542            .collect()
543    }
544
545    fn has_explicit_device_ordinals(&self) -> bool {
546        self.stages
547            .iter()
548            .any(|stage| stage.backend_device_ordinal.is_some())
549    }
550}
551
552pub struct LlamaFamilyPipelineModel<B: MoeLlmBackend, K: KvLayer<B>> {
553    stages: Vec<LlamaFamilyModel<B, K>>,
554    placement: LlamaPipelinePlacement,
555    runtime_cfg: LlmRuntimeConfig,
556    decode_stats: PipelineDecodeStats,
557}
558
559impl<B: MoeLlmBackend, K: KvLayer<B>> LlamaFamilyPipelineModel<B, K> {
560    pub fn new(stages: Vec<LlamaFamilyModel<B, K>>) -> Result<Self> {
561        let placement = LlamaPipelinePlacement::unplaced(stages.len());
562        Self::new_with_placement(stages, placement)
563    }
564
565    pub fn new_with_backend_device_ordinals(
566        stages: Vec<LlamaFamilyModel<B, K>>,
567        stage_device_ordinals: Vec<Option<usize>>,
568    ) -> Result<Self> {
569        Self::new_with_placement(
570            stages,
571            LlamaPipelinePlacement::from_backend_device_ordinals(stage_device_ordinals),
572        )
573    }
574
575    pub fn new_with_placement(
576        stages: Vec<LlamaFamilyModel<B, K>>,
577        placement: LlamaPipelinePlacement,
578    ) -> Result<Self> {
579        if stages.is_empty() {
580            return Err(FerrumError::model(
581                "LlamaFamilyPipelineModel requires at least one stage",
582            ));
583        }
584        if placement.len() != stages.len() {
585            return Err(FerrumError::model(format!(
586                "Llama pipeline stage device count {} must match stage count {}",
587                placement.len(),
588                stages.len()
589            )));
590        }
591        if placement.has_explicit_device_ordinals() && !B::supports_device_ordinal_scope() {
592            return Err(FerrumError::unsupported(
593                "Llama layer-split pipeline requested explicit backend device ordinals, \
594                 but the selected backend does not support device-scoped execution",
595            ));
596        }
597        if placement.pipeline_mode() == LlamaPipelineMode::Overlapped && placement.len() != 2 {
598            return Err(FerrumError::model(
599                "Llama layer-split overlapped pipeline mode requires exactly two stages",
600            ));
601        }
602        if stages.first().is_some_and(|stage| stage.embed.is_none()) {
603            return Err(FerrumError::model(
604                "first Llama pipeline stage must load embedding weights",
605            ));
606        }
607        if stages.last().is_some_and(|stage| stage.lm_head.is_none()) {
608            return Err(FerrumError::model(
609                "last Llama pipeline stage must load lm_head weights",
610            ));
611        }
612
613        let runtime_cfg = stages[0].runtime_cfg.clone();
614        let mut expected_start = 0usize;
615        for stage in &stages {
616            if stage.runtime_cfg.hidden_size != runtime_cfg.hidden_size
617                || stage.runtime_cfg.vocab_size != runtime_cfg.vocab_size
618                || stage.runtime_cfg.num_kv_heads != runtime_cfg.num_kv_heads
619                || stage.runtime_cfg.head_dim != runtime_cfg.head_dim
620                || stage.runtime_cfg.max_seq_len != runtime_cfg.max_seq_len
621            {
622                return Err(FerrumError::model(
623                    "Llama pipeline stages must share runtime dimensions",
624                ));
625            }
626            let range = stage.source_layer_range();
627            if range.start != expected_start {
628                return Err(FerrumError::model(format!(
629                    "Llama pipeline stage range starts at {}, expected {expected_start}",
630                    range.start
631                )));
632            }
633            expected_start = range.end;
634        }
635        if expected_start != runtime_cfg.num_layers {
636            return Err(FerrumError::model(format!(
637                "Llama pipeline stages cover {expected_start} layers but model has {}",
638                runtime_cfg.num_layers
639            )));
640        }
641
642        let stage_count = placement.len();
643        Ok(Self {
644            stages,
645            placement,
646            runtime_cfg,
647            decode_stats: PipelineDecodeStats::new(stage_count),
648        })
649    }
650
651    pub fn stages(&self) -> &[LlamaFamilyModel<B, K>] {
652        &self.stages
653    }
654
655    pub fn placement(&self) -> &LlamaPipelinePlacement {
656        &self.placement
657    }
658
659    fn pipeline_mode(&self) -> LlamaPipelineMode {
660        self.placement.pipeline_mode()
661    }
662
663    fn decode_microbatch_size(&self, batch_len: usize) -> usize {
664        match self.pipeline_mode() {
665            LlamaPipelineMode::Overlapped if batch_len < MIN_OVERLAPPED_DECODE_BATCH => {
666                batch_len.max(1)
667            }
668            LlamaPipelineMode::Overlapped => ((batch_len + 1) / 2).max(1),
669            LlamaPipelineMode::Batch => batch_len.max(1),
670        }
671    }
672
673    fn last_hidden_row<'a>(&self, hidden: &'a [f32], seq_len: usize) -> &'a [f32] {
674        let h = self.runtime_cfg.hidden_size;
675        &hidden[(seq_len - 1) * h..seq_len * h]
676    }
677}
678
679#[allow(private_bounds)]
680impl<B, K> LlamaFamilyPipelineModel<B, K>
681where
682    B: MoeLlmBackend,
683    K: KvLayer<B>,
684    LlamaFamilyModel<B, K>: DecoderOnlyLLM + LlamaPipelineStageBatchOps<B> + Send,
685{
686    fn decode_batch_sequential_internal(
687        &mut self,
688        batch: &[(String, u32, u32)],
689        force_full_logits: bool,
690    ) -> Vec<Vec<f32>> {
691        let profile = llama_family_decode_op_profile_enabled();
692        let total_t0 = std::time::Instant::now();
693        let mut stage_us: Vec<u64> = Vec::with_capacity(self.stages.len());
694        let mut host_bridge_bytes = 0usize;
695        let mut bridge_timing = LlamaStageHiddenBridgeTiming::default();
696        let stage_bridge = self.placement.stage_bridge();
697
698        let stage_t0 = std::time::Instant::now();
699        let mut hidden =
700            B::with_device_ordinal(self.placement.stage(0).backend_device_ordinal, || {
701                self.stages[0].decode_stage_tokens_to_hidden_batch(batch)
702            });
703        stage_us.push(elapsed_micros_u64(stage_t0));
704        for idx in 1..self.stages.len() {
705            let device = self.placement.stage(idx).backend_device_ordinal;
706            host_bridge_bytes = host_bridge_bytes.saturating_add(hidden.len_bytes());
707            let stage_t0 = std::time::Instant::now();
708            let (next_hidden, stage_bridge_timing) = B::with_device_ordinal(device, || {
709                self.stages[idx].decode_stage_hidden_from_host_batch(batch, &hidden)
710            });
711            hidden = next_hidden;
712            bridge_timing = bridge_timing.add(stage_bridge_timing);
713            stage_us.push(elapsed_micros_u64(stage_t0));
714        }
715        let last_idx = self.stages.len() - 1;
716        let logits_t0 = std::time::Instant::now();
717        let logits = B::with_device_ordinal(
718            self.placement.stage(last_idx).backend_device_ordinal,
719            || self.stages[last_idx].logits_from_hidden_batch(&hidden, force_full_logits),
720        );
721        let logits_us = elapsed_micros_u64(logits_t0);
722        let total_us = elapsed_micros_u64(total_t0);
723        self.decode_stats.record(
724            batch.len(),
725            1,
726            batch.len().max(1),
727            1,
728            0,
729            false,
730            host_bridge_bytes,
731            bridge_timing,
732            &stage_us,
733            logits_us,
734            total_us,
735        );
736        if profile {
737            static PIPELINE_PROFILE_CALLS: std::sync::atomic::AtomicU64 =
738                std::sync::atomic::AtomicU64::new(0);
739            let n = PIPELINE_PROFILE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
740            if n.is_multiple_of(8) {
741                eprintln!(
742                    "[pipeline-decode-profile] call#{} m={} hidden={:?} mode=batch bridge={} host_bridge_bytes={} bridge_us={} host_copy_us={} device_copy_us={} stage_us={:?} logits_us={} total_us={}",
743                    n,
744                    batch.len(),
745                    hidden.metadata_json(),
746                    stage_bridge.as_str(),
747                    host_bridge_bytes,
748                    bridge_timing.bridge_us,
749                    bridge_timing.host_copy_us,
750                    bridge_timing.device_copy_us,
751                    stage_us,
752                    logits_us,
753                    total_us,
754                );
755            }
756        }
757        logits
758    }
759
760    fn decode_batch_overlapped_two_stage(
761        &mut self,
762        batch: &[(String, u32, u32)],
763        force_full_logits: bool,
764    ) -> Vec<Vec<f32>> {
765        debug_assert_eq!(self.stages.len(), 2);
766        let microbatch_size = self.decode_microbatch_size(batch.len());
767        let chunks: Vec<Vec<(String, u32, u32)>> = batch
768            .chunks(microbatch_size)
769            .map(|chunk| chunk.to_vec())
770            .collect();
771        let chunk_count = chunks.len();
772        if chunk_count <= 1 {
773            return self.decode_batch_sequential_internal(batch, force_full_logits);
774        }
775
776        let profile = llama_family_decode_op_profile_enabled();
777        let total_t0 = std::time::Instant::now();
778        let stage0_device = self.placement.stage(0).backend_device_ordinal;
779        let stage1_device = self.placement.stage(1).backend_device_ordinal;
780        let stage_bridge = self.placement.stage_bridge();
781        let (stage0_slice, stage1_slice) = self.stages.split_at_mut(1);
782        let stage0 = &mut stage0_slice[0];
783        let stage1 = &mut stage1_slice[0];
784        let (tx, rx) = std::sync::mpsc::sync_channel::<(
785            usize,
786            Vec<(String, u32, u32)>,
787            PipelineHidden<B>,
788            u64,
789        )>(1);
790        let mut ordered_logits: Vec<Option<Vec<Vec<f32>>>> = vec![None; chunk_count];
791        let mut stage0_us_total = 0u64;
792        let mut stage1_us_total = 0u64;
793        let mut logits_us_total = 0u64;
794        let mut host_bridge_bytes = 0usize;
795        let mut bridge_timing = LlamaStageHiddenBridgeTiming::default();
796
797        std::thread::scope(|scope| {
798            let worker = scope.spawn(move || {
799                for (idx, chunk) in chunks.into_iter().enumerate() {
800                    let stage_t0 = std::time::Instant::now();
801                    let hidden = B::with_device_ordinal(stage0_device, || {
802                        stage0.decode_stage_tokens_to_hidden_batch(&chunk)
803                    });
804                    let stage_us = elapsed_micros_u64(stage_t0);
805                    if tx.send((idx, chunk, hidden, stage_us)).is_err() {
806                        break;
807                    }
808                }
809            });
810
811            for _ in 0..chunk_count {
812                let (idx, chunk, hidden, stage0_us) = rx
813                    .recv()
814                    .expect("pipeline stage0 worker ended before sending all microbatches");
815                stage0_us_total = stage0_us_total.saturating_add(stage0_us);
816                host_bridge_bytes = host_bridge_bytes.saturating_add(hidden.len_bytes());
817
818                let stage_t0 = std::time::Instant::now();
819                let (hidden, stage_bridge_timing) = B::with_device_ordinal(stage1_device, || {
820                    stage1.decode_stage_hidden_from_host_batch(&chunk, &hidden)
821                });
822                bridge_timing = bridge_timing.add(stage_bridge_timing);
823                stage1_us_total = stage1_us_total.saturating_add(elapsed_micros_u64(stage_t0));
824
825                let logits_t0 = std::time::Instant::now();
826                let logits = B::with_device_ordinal(stage1_device, || {
827                    stage1.logits_from_hidden_batch(&hidden, force_full_logits)
828                });
829                logits_us_total = logits_us_total.saturating_add(elapsed_micros_u64(logits_t0));
830                ordered_logits[idx] = Some(logits);
831            }
832
833            worker
834                .join()
835                .expect("pipeline stage0 worker panicked during overlapped decode");
836        });
837
838        let total_us = elapsed_micros_u64(total_t0);
839        let stage_us = vec![stage0_us_total, stage1_us_total];
840        self.decode_stats.record(
841            batch.len(),
842            chunk_count,
843            microbatch_size,
844            2,
845            1,
846            true,
847            host_bridge_bytes,
848            bridge_timing,
849            &stage_us,
850            logits_us_total,
851            total_us,
852        );
853        if profile {
854            static PIPELINE_PROFILE_CALLS: std::sync::atomic::AtomicU64 =
855                std::sync::atomic::AtomicU64::new(0);
856            let n = PIPELINE_PROFILE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
857            if n.is_multiple_of(8) {
858                eprintln!(
859                    "[pipeline-decode-profile] call#{} m={} mode=overlapped microbatch_size={} microbatch_count={} bridge={} host_bridge_bytes={} bridge_us={} host_copy_us={} device_copy_us={} stage_us={:?} logits_us={} total_us={}",
860                    n,
861                    batch.len(),
862                    microbatch_size,
863                    chunk_count,
864                    stage_bridge.as_str(),
865                    host_bridge_bytes,
866                    bridge_timing.bridge_us,
867                    bridge_timing.host_copy_us,
868                    bridge_timing.device_copy_us,
869                    stage_us,
870                    logits_us_total,
871                    total_us,
872                );
873            }
874        }
875
876        let mut logits = Vec::with_capacity(batch.len());
877        for chunk_logits in ordered_logits {
878            logits.extend(
879                chunk_logits.expect("pipeline overlapped decode missing logits for microbatch"),
880            );
881        }
882        logits
883    }
884}
885
886impl<B, K> DecoderOnlyLLM for LlamaFamilyPipelineModel<B, K>
887where
888    B: MoeLlmBackend,
889    K: KvLayer<B>,
890    LlamaFamilyModel<B, K>: DecoderOnlyLLM + LlamaPipelineStageBatchOps<B> + Send,
891{
892    fn config(&self) -> &LlmRuntimeConfig {
893        &self.runtime_cfg
894    }
895
896    fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
897        let stage_bridge = self.placement.stage_bridge();
898        let pipeline_mode = self.pipeline_mode();
899        Some(serde_json::json!({
900            "position": "llama-layer-split-pipeline",
901            "stage_count": self.stages.len() as u64,
902            "stage_device_ordinals": self.placement.stage_device_ordinals(),
903            "transport": self.placement.transport().as_str(),
904            "selected_pipeline_mode": pipeline_mode.as_str(),
905            "selected_stage_bridge": stage_bridge.as_str(),
906            "pipeline_hidden": {
907                "dtype": PipelineHiddenDtype::F32.as_str(),
908                "device": PipelineHiddenDevice::Host.as_str(),
909                "layout": PipelineHiddenLayout::RowMajor.as_str(),
910            },
911            "pipeline_decode": self.decode_stats.json(),
912        }))
913    }
914
915    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
916        for (idx, stage) in self.stages.iter_mut().enumerate() {
917            let device = self.placement.stage(idx).backend_device_ordinal;
918            B::with_device_ordinal(device, || {
919                stage.ensure_scratch(max_tokens);
920                stage.ensure_kv(cache_id);
921            });
922        }
923    }
924
925    fn kv_capacity(&self) -> usize {
926        self.stages
927            .iter()
928            .map(|stage| stage.kv_capacity())
929            .min()
930            .unwrap_or(self.runtime_cfg.max_seq_len)
931    }
932
933    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
934        assert!(!tokens.is_empty(), "pipeline prefill called with no tokens");
935        let pos_offset = self.stages[0].cache_len(cache_id);
936        let mut hidden =
937            B::with_device_ordinal(self.placement.stage(0).backend_device_ordinal, || {
938                self.stages[0].prefill_stage_tokens_to_hidden(cache_id, tokens, pos_offset)
939            });
940        for idx in 1..self.stages.len() {
941            let device = self.placement.stage(idx).backend_device_ordinal;
942            hidden = B::with_device_ordinal(device, || {
943                self.stages[idx].prefill_stage_hidden_from_host(
944                    cache_id,
945                    &hidden,
946                    tokens.len(),
947                    pos_offset,
948                )
949            });
950        }
951        let last_hidden = self.last_hidden_row(&hidden, tokens.len()).to_vec();
952        let last_idx = self.stages.len() - 1;
953        B::with_device_ordinal(
954            self.placement.stage(last_idx).backend_device_ordinal,
955            || self.stages[last_idx].logits_from_hidden(&last_hidden),
956        )
957    }
958
959    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
960        let total_t0 = std::time::Instant::now();
961        let mut stage_us: Vec<u64> = Vec::with_capacity(self.stages.len());
962        let mut host_bridge_bytes = 0usize;
963        let mut bridge_timing = LlamaStageHiddenBridgeTiming::default();
964
965        let stage_t0 = std::time::Instant::now();
966        let mut hidden =
967            B::with_device_ordinal(self.placement.stage(0).backend_device_ordinal, || {
968                self.stages[0].decode_stage_token_to_hidden(cache_id, token, pos)
969            });
970        stage_us.push(elapsed_micros_u64(stage_t0));
971        for idx in 1..self.stages.len() {
972            let device = self.placement.stage(idx).backend_device_ordinal;
973            host_bridge_bytes = host_bridge_bytes.saturating_add(hidden.len() * size_of::<f32>());
974            let stage_t0 = std::time::Instant::now();
975            let (next_hidden, stage_bridge_timing) = B::with_device_ordinal(device, || {
976                self.stages[idx].decode_stage_hidden_from_host_with_timing(cache_id, &hidden, pos)
977            });
978            hidden = next_hidden;
979            bridge_timing = bridge_timing.add(stage_bridge_timing);
980            stage_us.push(elapsed_micros_u64(stage_t0));
981        }
982        let last_idx = self.stages.len() - 1;
983        let logits_t0 = std::time::Instant::now();
984        let logits = B::with_device_ordinal(
985            self.placement.stage(last_idx).backend_device_ordinal,
986            || self.stages[last_idx].logits_from_hidden(&hidden),
987        );
988        self.decode_stats.record(
989            1,
990            1,
991            1,
992            1,
993            0,
994            false,
995            host_bridge_bytes,
996            bridge_timing,
997            &stage_us,
998            elapsed_micros_u64(logits_t0),
999            elapsed_micros_u64(total_t0),
1000        );
1001        logits
1002    }
1003
1004    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
1005        self.decode_batch_with_full_logits(batch, false)
1006    }
1007
1008    fn decode_batch_with_full_logits(
1009        &mut self,
1010        batch: &[(String, u32, u32)],
1011        force_full_logits: bool,
1012    ) -> Vec<Vec<f32>> {
1013        if batch.is_empty() {
1014            return Vec::new();
1015        }
1016        if batch.len() == 1 && !force_full_logits {
1017            let (cache_id, token, pos) = &batch[0];
1018            return vec![self.decode(cache_id, *token, *pos)];
1019        }
1020
1021        if self.pipeline_mode() == LlamaPipelineMode::Overlapped {
1022            self.decode_batch_overlapped_two_stage(batch, force_full_logits)
1023        } else {
1024            self.decode_batch_sequential_internal(batch, force_full_logits)
1025        }
1026    }
1027
1028    fn release(&mut self, cache_id: &str) {
1029        for (idx, stage) in self.stages.iter_mut().enumerate() {
1030            B::with_device_ordinal(self.placement.stage(idx).backend_device_ordinal, || {
1031                stage.release(cache_id);
1032            });
1033        }
1034    }
1035
1036    fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
1037        for (idx, stage) in self.stages.iter_mut().enumerate() {
1038            B::with_device_ordinal(self.placement.stage(idx).backend_device_ordinal, || {
1039                stage.truncate_kv(cache_id, new_len);
1040            });
1041        }
1042    }
1043
1044    fn reset(&mut self) {
1045        for (idx, stage) in self.stages.iter_mut().enumerate() {
1046            B::with_device_ordinal(self.placement.stage(idx).backend_device_ordinal, || {
1047                stage.reset();
1048            });
1049        }
1050    }
1051}
1052
#[cfg(test)]
mod tests {
    //! Parity tests for the layer-split pipeline.
    //!
    //! A two-stage CPU pipeline (stage 0 owns layer `0..1`, stage 1 owns
    //! layers `1..3`) must reproduce the logits of the unsplit three-layer
    //! model to within `1e-5`. It must also report the expected decode
    //! statistics through `cache_metrics_snapshot`.
    use ferrum_interfaces::kv_dtype::KvFp16;
    use ferrum_kernels::backend::cpu::CpuBackend;
    use ferrum_quantization::{DenseLinear, QuantConfig, WeightLoader};
    use ferrum_types::{FerrumError, Result};

    use super::*;
    use crate::models::llama_family::{LlamaFamilyConfig, LlamaFamilyLayerStageConfig};

    /// Synthetic weight loader.
    ///
    /// Every tensor and linear layer is generated deterministically from its
    /// name. The full model and each pipeline stage therefore load exactly the
    /// same parameters, with no files involved.
    struct ParityLoader {
        cfg: LlamaFamilyConfig,
    }

    impl ParityLoader {
        fn new(cfg: LlamaFamilyConfig) -> Self {
            Self { cfg }
        }

        /// Produces `len` values of the form `base + k * scale`, with integer
        /// `k` in `-11..=11`.
        ///
        /// The values are seeded by a hash of `name`, so equal names always
        /// yield equal data.
        fn deterministic_values(name: &str, len: usize, base: f32, scale: f32) -> Vec<f32> {
            // 32-bit FNV-1a (offset basis 0x811c9dc5, prime 0x01000193) over the name.
            let mut hash = 0x811c9dc5u32;
            for byte in name.bytes() {
                hash ^= byte as u32;
                hash = hash.wrapping_mul(0x01000193);
            }
            (0..len)
                .map(|idx| {
                    // Mix the element index into the name hash (golden-ratio
                    // multiplier plus an index-dependent rotation) so values
                    // vary across the tensor.
                    let mixed = hash
                        .wrapping_add((idx as u32).wrapping_mul(0x9e3779b9))
                        .rotate_left((idx % 17) as u32);
                    // `mixed % 23` lies in 0..=22; shift it to -11..=11.
                    let centered = (mixed % 23) as f32 - 11.0;
                    base + centered * scale
                })
                .collect()
        }

        /// Norm weights: `hidden_size` values within ±0.055 of 1.0.
        fn layer_norm_values(&self, name: &str) -> Vec<f32> {
            Self::deterministic_values(name, self.cfg.hidden_size, 1.0, 0.005)
        }

        /// Returns `(out_features, in_features)` for each linear layer the
        /// model may request.
        ///
        /// The QKV and gate/up projections are expected in fused form. Any
        /// other name is an error, so a change in the requested weight set
        /// fails loudly.
        fn linear_dims(&self, name: &str) -> Result<(usize, usize)> {
            let q_dim = self.cfg.num_heads * self.cfg.head_dim;
            let kv_dim = self.cfg.num_kv_heads * self.cfg.head_dim;
            if name.ends_with(".self_attn.qkv_proj") {
                Ok((q_dim + 2 * kv_dim, self.cfg.hidden_size))
            } else if name.ends_with(".self_attn.o_proj") {
                Ok((self.cfg.hidden_size, q_dim))
            } else if name.ends_with(".mlp.gate_up_proj") {
                Ok((2 * self.cfg.intermediate_size, self.cfg.hidden_size))
            } else if name.ends_with(".mlp.down_proj") {
                Ok((self.cfg.hidden_size, self.cfg.intermediate_size))
            } else if name == "lm_head" || name == "model.embed_tokens" {
                Ok((self.cfg.vocab_size, self.cfg.hidden_size))
            } else {
                Err(FerrumError::model(format!(
                    "unexpected linear requested by parity loader: {name}"
                )))
            }
        }
    }

    impl WeightLoader<CpuBackend> for ParityLoader {
        /// Serves the `vocab_size * hidden_size` embedding table (values
        /// within ±0.22 of 0) and the final and per-layer norm weights. Any
        /// other raw tensor is an error.
        fn load_tensor(&self, name: &str) -> Result<Vec<f32>> {
            if name == "model.embed_tokens.weight" {
                return Ok(Self::deterministic_values(
                    name,
                    self.cfg.vocab_size * self.cfg.hidden_size,
                    0.0,
                    0.02,
                ));
            }
            if name == "model.norm.weight"
                || name.ends_with(".input_layernorm.weight")
                || name.ends_with(".post_attention_layernorm.weight")
            {
                return Ok(self.layer_norm_values(name));
            }
            Err(FerrumError::model(format!(
                "unexpected tensor requested by parity loader: {name}"
            )))
        }

        /// Builds a `DenseLinear` with the shape reported by `linear_dims` and
        /// values within ±0.165 of 0.
        fn load_linear(
            &self,
            name: &str,
        ) -> Result<Box<dyn ferrum_quantization::Linear<CpuBackend>>> {
            let (out_features, in_features) = self.linear_dims(name)?;
            let weights = Self::deterministic_values(name, out_features * in_features, 0.0, 0.015);
            Ok(Box::new(DenseLinear::<CpuBackend>::from_rows(
                &weights,
                out_features,
                in_features,
            )))
        }

        /// Reports only `lm_head.weight` as present.
        // NOTE(review): presumably this makes the model load a separate
        // `lm_head` linear rather than tying it to the embeddings — confirm.
        fn has_tensor(&self, name: &str) -> bool {
            name == "lm_head.weight"
        }

        /// Dense weights only: no quantization config.
        fn quant_config(&self) -> Option<&QuantConfig> {
            None
        }
    }

    /// Tiny Llama-family config used by the parity tests.
    ///
    /// Hidden size 4, two heads of dim 2, vocab 7, up to 16 positions. Both
    /// the full model and the pipeline are built from it.
    fn parity_config(num_layers: usize) -> LlamaFamilyConfig {
        LlamaFamilyConfig {
            hidden_size: 4,
            intermediate_size: 8,
            num_heads: 2,
            num_kv_heads: 2,
            head_dim: 2,
            num_layers,
            vocab_size: 7,
            max_seq_len: 16,
            rms_norm_eps: 1e-5,
            rope_theta: 10_000.0,
            rope_scaling: None,
            rope_interleaved: false,
            has_qk_norm: false,
            sliding_window: 0,
        }
    }

    /// Builds the reference model and a two-stage pipeline using the default
    /// mode for two stages. The metrics report that mode as `"overlapped"`.
    fn build_full_and_pipeline() -> (
        LlamaFamilyModel<CpuBackend, KvFp16>,
        LlamaFamilyPipelineModel<CpuBackend, KvFp16>,
    ) {
        build_full_and_pipeline_with_mode(LlamaPipelineMode::default_for_stage_count(2))
    }

    /// Builds the unsplit three-layer reference model plus a two-stage
    /// pipeline in `pipeline_mode`.
    ///
    /// Stage 0 owns layer `0..1` and stage 1 owns layers `1..3`. Both stages
    /// come from the same `ParityLoader` and are left unplaced (no device
    /// ordinals).
    fn build_full_and_pipeline_with_mode(
        pipeline_mode: LlamaPipelineMode,
    ) -> (
        LlamaFamilyModel<CpuBackend, KvFp16>,
        LlamaFamilyPipelineModel<CpuBackend, KvFp16>,
    ) {
        let cfg = parity_config(3);
        let loader = ParityLoader::new(cfg.clone());
        let full = LlamaFamilyModel::<CpuBackend, KvFp16>::new(cfg.clone(), &loader).unwrap();
        // NOTE(review): the two booleans passed to `pipeline_stage` presumably
        // mark the first stage (owns embeddings) and the last stage (owns the
        // final norm / LM head) — confirm against `LlamaFamilyLayerStageConfig`.
        let stage0 = LlamaFamilyModel::<CpuBackend, KvFp16>::new_layer_stage(
            cfg.clone(),
            &loader,
            LlamaFamilyLayerStageConfig::pipeline_stage(0..1, true, false),
        )
        .unwrap();
        let stage1 = LlamaFamilyModel::<CpuBackend, KvFp16>::new_layer_stage(
            cfg,
            &loader,
            LlamaFamilyLayerStageConfig::pipeline_stage(1..3, false, true),
        )
        .unwrap();
        let pipeline = LlamaFamilyPipelineModel::new_with_placement(
            vec![stage0, stage1],
            LlamaPipelinePlacement::unplaced(2).with_pipeline_mode(pipeline_mode),
        )
        .unwrap();
        (full, pipeline)
    }

    /// Asserts both logit vectors have equal length and a maximum absolute
    /// element difference below `1e-5`.
    fn assert_logits_close(label: &str, expected: &[f32], actual: &[f32]) {
        assert_eq!(expected.len(), actual.len(), "{label} length mismatch");
        let max_diff = expected
            .iter()
            .zip(actual)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(
            max_diff < 1e-5,
            "{label} logits diverged: max_diff={max_diff} expected={expected:?} actual={actual:?}"
        );
    }

    /// A four-token prefill through both stages matches the full model.
    ///
    /// Also checks the unplaced host-bridge defaults and that prefill records
    /// no decode calls.
    #[test]
    fn pipeline_prefill_matches_full_model_on_multi_token_cpu_model() {
        let (mut full, mut pipeline) = build_full_and_pipeline();

        let full_logits = full.prefill("full", &[0, 1, 2, 3]);
        let pipeline_logits = pipeline.prefill("pipe", &[0, 1, 2, 3]);

        assert_eq!(pipeline.config().num_layers, 3);
        assert_eq!(pipeline.stages().len(), 2);
        assert_eq!(
            pipeline.placement().stage_device_ordinals(),
            vec![None, None]
        );
        assert_eq!(
            pipeline.placement().transport(),
            LlamaPipelineTransport::HostHiddenBridge
        );
        let metrics = pipeline.cache_metrics_snapshot().unwrap();
        assert_eq!(metrics["selected_pipeline_mode"], "overlapped");
        assert_eq!(metrics["selected_stage_bridge"], "host");
        assert_eq!(metrics["pipeline_decode"]["calls"], 0);
        assert_logits_close("multi-token prefill", &full_logits, &pipeline_logits);
    }

    /// Two single-token decodes after a three-token prefill match the full
    /// model.
    ///
    /// Both calls are recorded as non-overlapped decode calls, with one timing
    /// entry per stage and non-zero bridge accounting.
    #[test]
    fn pipeline_decode_after_multi_token_prefill_matches_full_model() {
        let (mut full, mut pipeline) = build_full_and_pipeline();

        let _ = full.prefill("full", &[0, 1, 2]);
        let _ = pipeline.prefill("pipe", &[0, 1, 2]);
        let full_logits_3 = full.decode("full", 3, 3);
        let pipeline_logits_3 = pipeline.decode("pipe", 3, 3);
        assert_logits_close("decode pos 3", &full_logits_3, &pipeline_logits_3);

        let full_logits_4 = full.decode("full", 4, 4);
        let pipeline_logits_4 = pipeline.decode("pipe", 4, 4);
        assert_logits_close("decode pos 4", &full_logits_4, &pipeline_logits_4);

        let metrics = pipeline.cache_metrics_snapshot().unwrap();
        assert_eq!(metrics["pipeline_decode"]["calls"], 2);
        assert_eq!(metrics["pipeline_decode"]["overlapped_calls"], 0);
        assert_eq!(metrics["pipeline_decode"]["rows"], 2);
        assert_eq!(metrics["pipeline_decode"]["max_batch"], 1);
        assert_eq!(
            metrics["pipeline_decode"]["stage_us_last"]
                .as_array()
                .unwrap()
                .len(),
            pipeline.stages().len()
        );
        assert!(
            metrics["pipeline_decode"]["host_bridge_bytes_total"]
                .as_u64()
                .unwrap()
                > 0
        );
        // NOTE(review): the `*_us_total > 0` checks assume each bridge hop
        // registers at least 1 µs. With hidden_size 4 a host copy may finish
        // in well under a microsecond unless the timing code rounds up, so
        // confirm these are not flaky. The final sum check is already implied
        // by the `host_copy_us_total > 0` check before it.
        assert!(
            metrics["pipeline_decode"]["bridge_us_total"]
                .as_u64()
                .unwrap()
                > 0
        );
        assert!(
            metrics["pipeline_decode"]["host_copy_us_total"]
                .as_u64()
                .unwrap()
                > 0
        );
        assert!(
            metrics["pipeline_decode"]["host_copy_us_total"]
                .as_u64()
                .unwrap()
                + metrics["pipeline_decode"]["device_copy_us_total"]
                    .as_u64()
                    .unwrap()
                > 0
        );
    }

    /// A 16-row batch matches the full model row-for-row, in order, across
    /// two consecutive decode steps.
    ///
    /// 16 equals `MIN_OVERLAPPED_DECODE_BATCH`. The metrics below expect each
    /// call to run overlapped as two microbatches of 8.
    #[test]
    fn pipeline_decode_batch_matches_full_model_and_preserves_order() {
        let (mut full, mut pipeline) = build_full_and_pipeline();

        // Prefill lengths cycle through 1..=4, so rows decode at different positions.
        let prefills: Vec<Vec<u32>> = (0..16)
            .map(|idx| {
                let len = 1 + (idx % 4);
                (0..len).map(|offset| ((idx + offset) % 7) as u32).collect()
            })
            .collect();
        for (idx, tokens) in prefills.iter().enumerate() {
            let _ = full.prefill(&format!("full_{idx}"), tokens);
            let _ = pipeline.prefill(&format!("pipe_{idx}"), tokens);
        }

        let first_tokens: Vec<u32> = (0..prefills.len())
            .map(|idx| ((idx + 5) % 7) as u32)
            .collect();
        let second_tokens: Vec<u32> = (0..prefills.len())
            .map(|idx| ((idx + 1) % 7) as u32)
            .collect();
        // First decode step: each row decodes at position == its prefill length.
        let full_first: Vec<_> = prefills
            .iter()
            .enumerate()
            .map(|(idx, tokens)| {
                (
                    format!("full_{idx}"),
                    first_tokens[idx],
                    tokens.len() as u32,
                )
            })
            .collect();
        let pipe_first: Vec<_> = prefills
            .iter()
            .enumerate()
            .map(|(idx, tokens)| {
                (
                    format!("pipe_{idx}"),
                    first_tokens[idx],
                    tokens.len() as u32,
                )
            })
            .collect();

        let expected = full.decode_batch(&full_first);
        let actual = pipeline.decode_batch(&pipe_first);

        assert_eq!(actual.len(), prefills.len());
        for row in 0..prefills.len() {
            assert_logits_close(
                &format!("decode batch row {row}"),
                &expected[row],
                &actual[row],
            );
        }

        // Second decode step, one position further, to check KV state carried
        // over from the first batched step.
        let full_next: Vec<_> = prefills
            .iter()
            .enumerate()
            .map(|(idx, tokens)| {
                (
                    format!("full_{idx}"),
                    second_tokens[idx],
                    tokens.len() as u32 + 1,
                )
            })
            .collect();
        let pipe_next: Vec<_> = prefills
            .iter()
            .enumerate()
            .map(|(idx, tokens)| {
                (
                    format!("pipe_{idx}"),
                    second_tokens[idx],
                    tokens.len() as u32 + 1,
                )
            })
            .collect();
        let expected_next = full.decode_batch(&full_next);
        let actual_next = pipeline.decode_batch(&pipe_next);

        for row in 0..prefills.len() {
            assert_logits_close(
                &format!("follow-up decode batch row {row}"),
                &expected_next[row],
                &actual_next[row],
            );
        }

        let metrics = pipeline.cache_metrics_snapshot().unwrap();
        assert_eq!(metrics["pipeline_decode"]["calls"], 2);
        assert_eq!(metrics["pipeline_decode"]["overlapped_calls"], 2);
        assert_eq!(metrics["pipeline_decode"]["rows"], 32);
        assert_eq!(metrics["pipeline_decode"]["max_batch"], 16);
        assert_eq!(metrics["pipeline_decode"]["last_batch"], 16);
        assert_eq!(metrics["pipeline_decode"]["microbatch_count_max"], 2);
        assert_eq!(metrics["pipeline_decode"]["microbatch_size_max"], 8);
        assert_eq!(metrics["pipeline_decode"]["in_flight_stage_count_max"], 2);
        assert_eq!(metrics["pipeline_decode"]["queue_depth_max"], 1);
        assert_eq!(
            metrics["pipeline_decode"]["stage_us_last"]
                .as_array()
                .unwrap()
                .len(),
            pipeline.stages().len()
        );
        assert!(
            metrics["pipeline_decode"]["host_bridge_bytes_total"]
                .as_u64()
                .unwrap()
                > 0
        );
        // NOTE(review): the same sub-microsecond timing caveat applies as in
        // the single-token decode test; the final sum check is implied by the
        // `host_copy_us_total > 0` check.
        assert!(
            metrics["pipeline_decode"]["bridge_us_total"]
                .as_u64()
                .unwrap()
                > 0
        );
        assert!(
            metrics["pipeline_decode"]["host_copy_us_total"]
                .as_u64()
                .unwrap()
                > 0
        );
        assert!(
            metrics["pipeline_decode"]["host_copy_us_total"]
                .as_u64()
                .unwrap()
                + metrics["pipeline_decode"]["device_copy_us_total"]
                    .as_u64()
                    .unwrap()
                > 0
        );
    }

    /// An 8-row batch in overlapped mode still matches the full model.
    ///
    /// The metrics expect it to run as one whole microbatch with no
    /// overlapped call.
    #[test]
    fn pipeline_overlapped_mode_keeps_small_batches_whole() {
        let (mut full, mut pipeline) = build_full_and_pipeline();

        let prefills: Vec<Vec<u32>> = (0..8)
            .map(|idx| {
                let len = 1 + (idx % 3);
                (0..len).map(|offset| ((idx + offset) % 7) as u32).collect()
            })
            .collect();
        for (idx, tokens) in prefills.iter().enumerate() {
            let _ = full.prefill(&format!("full_{idx}"), tokens);
            let _ = pipeline.prefill(&format!("pipe_{idx}"), tokens);
        }

        let full_batch: Vec<_> = prefills
            .iter()
            .enumerate()
            .map(|(idx, tokens)| {
                (
                    format!("full_{idx}"),
                    ((idx + 5) % 7) as u32,
                    tokens.len() as u32,
                )
            })
            .collect();
        let pipe_batch: Vec<_> = prefills
            .iter()
            .enumerate()
            .map(|(idx, tokens)| {
                (
                    format!("pipe_{idx}"),
                    ((idx + 5) % 7) as u32,
                    tokens.len() as u32,
                )
            })
            .collect();

        let expected = full.decode_batch(&full_batch);
        let actual = pipeline.decode_batch(&pipe_batch);

        assert_eq!(actual.len(), prefills.len());
        for row in 0..prefills.len() {
            assert_logits_close(
                &format!("small batch row {row}"),
                &expected[row],
                &actual[row],
            );
        }

        let metrics = pipeline.cache_metrics_snapshot().unwrap();
        assert_eq!(metrics["selected_pipeline_mode"], "overlapped");
        assert_eq!(metrics["pipeline_decode"]["calls"], 1);
        assert_eq!(metrics["pipeline_decode"]["overlapped_calls"], 0);
        assert_eq!(metrics["pipeline_decode"]["max_batch"], 8);
        assert_eq!(metrics["pipeline_decode"]["microbatch_count_max"], 1);
        assert_eq!(metrics["pipeline_decode"]["microbatch_size_max"], 8);
        assert_eq!(metrics["pipeline_decode"]["in_flight_stage_count_max"], 1);
        assert_eq!(metrics["pipeline_decode"]["queue_depth_max"], 0);
    }

    /// Explicit `Batch` mode decodes a two-row batch as one non-overlapped
    /// microbatch and matches the full model.
    #[test]
    fn pipeline_batch_mode_uses_stage_batch_without_overlap() {
        let (mut full, mut pipeline) = build_full_and_pipeline_with_mode(LlamaPipelineMode::Batch);

        let _ = full.prefill("full_a", &[0, 1]);
        let _ = full.prefill("full_b", &[2, 3, 4]);
        let _ = pipeline.prefill("pipe_a", &[0, 1]);
        let _ = pipeline.prefill("pipe_b", &[2, 3, 4]);

        let expected =
            full.decode_batch(&[("full_a".to_string(), 5, 2), ("full_b".to_string(), 6, 3)]);
        let actual =
            pipeline.decode_batch(&[("pipe_a".to_string(), 5, 2), ("pipe_b".to_string(), 6, 3)]);

        assert_logits_close("batch mode row 0", &expected[0], &actual[0]);
        assert_logits_close("batch mode row 1", &expected[1], &actual[1]);

        let metrics = pipeline.cache_metrics_snapshot().unwrap();
        assert_eq!(metrics["selected_pipeline_mode"], "batch");
        assert_eq!(metrics["pipeline_decode"]["calls"], 1);
        assert_eq!(metrics["pipeline_decode"]["overlapped_calls"], 0);
        assert_eq!(metrics["pipeline_decode"]["microbatch_count_max"], 1);
        assert_eq!(metrics["pipeline_decode"]["microbatch_size_max"], 2);
        assert_eq!(metrics["pipeline_decode"]["in_flight_stage_count_max"], 1);
        assert_eq!(metrics["pipeline_decode"]["queue_depth_max"], 0);
    }

    /// A second prefill on an existing cache continues at the cached position
    /// and matches the full model.
    #[test]
    fn pipeline_incremental_prefill_matches_full_model_position_offset() {
        let (mut full, mut pipeline) = build_full_and_pipeline();

        let _ = full.prefill("full", &[0, 1]);
        let _ = pipeline.prefill("pipe", &[0, 1]);
        let full_logits = full.prefill("full", &[2, 3]);
        let pipeline_logits = pipeline.prefill("pipe", &[2, 3]);

        assert_logits_close("incremental prefill", &full_logits, &pipeline_logits);
    }

    /// The CPU backend cannot scope execution to a device ordinal.
    ///
    /// Building a placed pipeline on it must therefore fail with a descriptive
    /// error.
    #[test]
    fn pipeline_rejects_device_ordinals_without_backend_scope_support() {
        let cfg = parity_config(2);
        let loader = ParityLoader::new(cfg.clone());
        let stage0 = LlamaFamilyModel::<CpuBackend, KvFp16>::new_layer_stage(
            cfg.clone(),
            &loader,
            LlamaFamilyLayerStageConfig::pipeline_stage(0..1, true, false),
        )
        .unwrap();
        let stage1 = LlamaFamilyModel::<CpuBackend, KvFp16>::new_layer_stage(
            cfg,
            &loader,
            LlamaFamilyLayerStageConfig::pipeline_stage(1..2, false, true),
        )
        .unwrap();

        let err = match LlamaFamilyPipelineModel::new_with_backend_device_ordinals(
            vec![stage0, stage1],
            vec![Some(0), Some(1)],
        ) {
            Ok(_) => panic!("pipeline unexpectedly accepted unsupported device ordinals"),
            Err(err) => err.to_string(),
        };

        assert!(err.contains("does not support device-scoped execution"));
    }
}