use async_trait::async_trait;
use candle_core::Tensor;
use ferrum_interfaces::{
    model_executor::{
        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorStatus,
        MemoryRequirements, PrefillInput, PrefillOutput,
    },
    KvCacheHandle, ModelExecutor, TensorRef,
};
use ferrum_types::{DataType, Device, FerrumError, ModelInfo, Result};
use std::collections::HashMap;
use std::sync::{
    atomic::{AtomicU64, Ordering},
    Arc,
};
use tracing::{debug, info};

use super::common::{self, GenericKvCacheHandle};
use crate::architectures::llama::LlamaModelWrapper;
use crate::tensor_wrapper::CandleTensorWrapper;
use parking_lot::Mutex;

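/// Per-sequence decode state: how many tokens are currently stored in that
/// sequence's KV cache.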
struct LlamaCacheState {
    sequence_length: usize,
}

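/// `ModelExecutor` implementation backed by the candle Llama model.
///
/// Tracks per-sequence cache state and, when built with the `cuda` feature,
/// lazily initializes a fused single-GPU decode runner and (optionally) a
/// multi-GPU tensor-parallel decode group.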
pub struct CandleModelExecutor {
    model: Arc<LlamaModelWrapper>,
    info: ModelInfo,
    states: Mutex<HashMap<String, LlamaCacheState>>,
    next_cache_id: AtomicU64,
    #[cfg(feature = "cuda")]
    cuda_runner: Mutex<Option<ferrum_cuda_kernels::cuda_decode::CudaDecodeRunner>>,
    #[cfg(feature = "cuda")]
    tp_group: Mutex<Option<ferrum_cuda_kernels::tp_decode::TpDecodeGroup>>,
}

impl CandleModelExecutor {
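    /// Wraps an already-loaded Llama model in an executor.
    ///
    /// A minimal construction sketch; `load_llama_model` is a hypothetical
    /// helper standing in for however the surrounding crate produces the
    /// `LlamaModelWrapper` and its `ModelInfo`:
    ///
    /// ```ignore
    /// let (model, info) = load_llama_model("/path/to/model")?; // hypothetical loader
    /// let executor = CandleModelExecutor::new(model, info);
    /// ```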
    pub fn new(model: LlamaModelWrapper, info: ModelInfo) -> Self {
        info!("Created CandleModelExecutor (Llama) for: {}", info.model_id);
        Self {
            model: Arc::new(model),
            info,
            states: Mutex::new(HashMap::new()),
            next_cache_id: AtomicU64::new(1),
            #[cfg(feature = "cuda")]
            cuda_runner: Mutex::new(None),
            #[cfg(feature = "cuda")]
            tp_group: Mutex::new(None),
        }
    }

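    /// Builds a `[1, seq_len]` token-id tensor on the model's device.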
    fn tokens_to_tensor(&self, token_ids: &[u32]) -> Result<Tensor> {
        Tensor::new(token_ids, self.model.device())
            .map_err(|e| FerrumError::model(format!("tensor: {e}")))?
            .unsqueeze(0)
            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))
    }

    fn wrap_tensor(&self, tensor: Tensor) -> TensorRef {
        Arc::new(CandleTensorWrapper::new(tensor))
    }

    fn tensor_to_tokens(&self, tensor: &TensorRef) -> Result<Vec<u32>> {
        common::tensor_to_tokens(tensor)
    }

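    /// Tensor-parallel world size: `FERRUM_TP` if set and parseable, otherwise
    /// the number of visible CUDA devices (1 if that query fails).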
    #[cfg(feature = "cuda")]
    fn tp_size() -> usize {
        if let Ok(v) = std::env::var("FERRUM_TP") {
            if let Ok(n) = v.parse::<usize>() {
                return n;
            }
        }
        candle_core::cuda_backend::cudarc::driver::CudaContext::device_count()
            .map(|n| n as usize)
            .unwrap_or(1)
    }

    #[cfg(not(feature = "cuda"))]
    fn tp_size() -> usize {
        0
    }

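    /// Lazily builds the tensor-parallel decode group: loads per-rank sharded
    /// weights on worker threads, syncs the replicated weights from rank 0,
    /// initializes NCCL, and stores the resulting `TpDecodeGroup`.
    ///
    /// Returns `true` once a group is available; any failure is logged and the
    /// caller falls back to single-GPU decoding.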
    #[cfg(feature = "cuda")]
    fn ensure_tp_group(&self) -> bool {
        if self.tp_group.lock().is_some() {
            return true;
        }
        let tp = Self::tp_size();
        if tp <= 1 {
            return false;
        }

        info!("Initializing tensor parallel group: tp_size={tp}");

        let model_dir = match self.model.model_dir.as_ref() {
            Some(d) => d.clone(),
            None => {
                tracing::warn!("TP requires model_dir");
                return false;
            }
        };

        // Load the weights once up front so a bad model dir fails fast, before
        // any per-rank worker threads are spawned.
        let loader = crate::loader::SafeTensorsLoader::new(&model_dir);
        let _vb = match loader.load_varbuilder(self.model.device(), self.model.dtype()) {
            Ok(v) => v,
            Err(e) => {
                tracing::warn!("TP weight load failed: {e}");
                return false;
            }
        };

        let cfg = self.model.config();
        let tp_cfg = crate::loader::tp_weight_loader::TpWeightConfig {
            num_hidden_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            intermediate_size: cfg.intermediate_size,
            num_attention_heads: cfg.num_attention_heads,
            num_kv_heads: cfg.num_key_value_heads,
            head_dim: cfg.head_dim,
            vocab_size: cfg.vocab_size,
            max_seq_len: cfg.max_position_embeddings,
            rope_theta: cfg.rope_theta as f64,
            has_qk_norm: crate::loader::SafeTensorsLoader::new(&model_dir)
                .load_varbuilder(self.model.device(), self.model.dtype())
                .map(|vb| {
                    vb.get(cfg.head_dim, "model.layers.0.self_attn.q_norm.weight")
                        .is_ok()
                })
                .unwrap_or(false),
            tp_size: tp,
            rank: 0,
        };
        info!(
            "TP config: has_qk_norm={}, head_dim={}, nq={}, nkv={}",
            tp_cfg.has_qk_norm, tp_cfg.head_dim, tp_cfg.num_attention_heads, tp_cfg.num_kv_heads
        );

        type RankResult = candle_core::Result<(
            ferrum_cuda_kernels::cuda_decode::CudaDecodeRunner,
            std::sync::Arc<candle_core::cuda_backend::cudarc::driver::CudaStream>,
        )>;

        let mut handles: Vec<std::thread::JoinHandle<RankResult>> = Vec::with_capacity(tp);
        let dtype = self.model.dtype();
        for rank in 0..tp {
            let mut rank_cfg = tp_cfg.clone();
            rank_cfg.rank = rank;
            let model_dir = model_dir.clone();

            handles.push(std::thread::spawn(move || {
                let device = candle_core::Device::new_cuda(rank)?;
                let loader = crate::loader::SafeTensorsLoader::new(&model_dir);
                let vb = loader
                    .load_varbuilder(&device, dtype)
                    .map_err(|e| candle_core::Error::Msg(format!("VB rank {rank}: {e}")))?;

                let (weights, dims, stream) =
                    crate::loader::tp_weight_loader::load_sharded_weights(&vb, &rank_cfg, &device)
                        .map_err(|e| candle_core::Error::Msg(format!("shard {rank}: {e}")))?;

                let cuda_dev = device.as_cuda_device()?.clone();

                let runner = ferrum_cuda_kernels::cuda_decode::CudaDecodeRunner::new(
                    weights,
                    dims,
                    cuda_dev,
                    stream.clone(),
                )?;
                Ok((runner, stream))
            }));
        }

        let mut runners = Vec::with_capacity(tp);
        let mut nccl_streams = Vec::with_capacity(tp);
        for (rank, handle) in handles.into_iter().enumerate() {
            match handle.join() {
                Ok(Ok((runner, stream))) => {
                    runners.push(runner);
                    nccl_streams.push(stream);
                }
                Ok(Err(e)) => {
                    tracing::warn!("TP rank {rank} failed: {e}");
                    return false;
                }
                Err(_) => {
                    tracing::warn!("TP rank {rank} panicked");
                    return false;
                }
            }
        }

        // Copy rank 0's replicated (non-sharded) weights to every other rank.
        for rank in 1..tp {
            let (first, rest) = runners.split_at_mut(1);
            let (src, dst) = (&first[0], &mut rest[rank - 1]);
            if let Err(e) = dst.sync_replicated_weights_from(src) {
                tracing::warn!("TP weight sync rank 0→{rank} failed: {e}");
                return false;
            }
            info!("Synced replicated weights: rank 0 → rank {rank}");
        }

        // Hold a Device handle for each GPU on this thread while NCCL and the
        // TP group are being set up.
        let _main_thread_devices: Vec<_> = (0..tp)
            .filter_map(|r| candle_core::Device::new_cuda(r).ok())
            .collect();

        let nccl_ranks = match ferrum_cuda_kernels::nccl_comm::NcclRank::init_all(nccl_streams) {
            Ok(ranks) => ranks,
            Err(e) => {
                tracing::warn!("NCCL init_all failed: {e}");
                return false;
            }
        };

        match ferrum_cuda_kernels::tp_decode::TpDecodeGroup::new(runners, nccl_ranks) {
            Ok(group) => {
                info!("Tensor parallel group initialized: {tp} GPUs");
                *self.tp_group.lock() = Some(group);
                true
            }
            Err(e) => {
                tracing::warn!("TpDecodeGroup init: {e}");
                false
            }
        }
    }

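    /// Lazily creates the single-GPU fused CUDA decode runner. Disabled when
    /// `FERRUM_DISABLE_CUDA_RUNNER=1`; failures are logged and decoding falls
    /// back to the plain candle forward pass.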
    #[cfg(feature = "cuda")]
    fn ensure_cuda_runner(&self) -> bool {
        if self.cuda_runner.lock().is_some() {
            return true;
        }
        if std::env::var("FERRUM_DISABLE_CUDA_RUNNER").map_or(false, |v| v == "1") {
            return false;
        }
        match self.model.create_decode_runner() {
            Ok(runner) => {
                info!("CUDA decode runner initialized for Llama");
                *self.cuda_runner.lock() = Some(runner);
                true
            }
            Err(e) => {
                tracing::warn!("CUDA runner init failed for Llama: {e}");
                false
            }
        }
    }

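    /// Migrates a sequence's candle-side KV cache into the CUDA runner the
    /// first time that sequence decodes there. For `-clone-` suffixed cache
    /// ids, exports the parent sequence's cache instead.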
    #[cfg(feature = "cuda")]
    fn ensure_runner_kv_cache(&self, cache_id: &str, _seq_len: usize) -> Result<()> {
        use candle_core::Storage;

        let mut runner_guard = self.cuda_runner.lock();
        let runner = match runner_guard.as_mut() {
            Some(r) => r,
            None => return Ok(()),
        };
        if runner.has_kv_cache(cache_id) {
            return Ok(());
        }

        let kv_data = self
            .model
            .export_kv_cache(cache_id)
            .or_else(|| {
                cache_id
                    .rfind("-clone-")
                    .and_then(|pos| self.model.export_kv_cache(&cache_id[..pos]))
            })
            .ok_or_else(|| FerrumError::model(format!("No KV cache to export for: {cache_id}")))?;

        if kv_data.is_empty() {
            return Err(FerrumError::model("Empty KV cache export"));
        }
        let prefill_len = kv_data[0].2;
        let max_len = kv_data[0].3;

        let mut kv_slices = Vec::new();
        for (k_tensor, v_tensor, _len, _max) in &kv_data {
            let (k_s, _) = k_tensor.storage_and_layout();
            let (v_s, _) = v_tensor.storage_and_layout();
            let k_cuda = match &*k_s {
                Storage::Cuda(cs) => cs
                    .as_cuda_slice::<half::f16>()
                    .map_err(|e| FerrumError::model(format!("KV extract: {e}")))?
                    .clone(),
                _ => return Err(FerrumError::model("KV not on CUDA")),
            };
            let v_cuda = match &*v_s {
                Storage::Cuda(cs) => cs
                    .as_cuda_slice::<half::f16>()
                    .map_err(|e| FerrumError::model(format!("KV extract: {e}")))?
                    .clone(),
                _ => return Err(FerrumError::model("KV not on CUDA")),
            };
            drop(k_s);
            drop(v_s);
            kv_slices.push((k_cuda, v_cuda));
        }

        runner
            .init_kv_cache(cache_id, kv_slices, prefill_len, max_len)
            .map_err(|e| FerrumError::model(format!("KV init: {e}")))?;

        if !self.states.lock().contains_key(cache_id) {
            self.states.lock().insert(
                cache_id.to_string(),
                LlamaCacheState {
                    sequence_length: prefill_len,
                },
            );
        }

        debug!("Migrated Llama KV to CUDA runner: {cache_id}");
        Ok(())
    }

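    /// Drops all cached state for one sequence: its length-tracking entry, the
    /// model-side KV cache, and (with the `cuda` feature) the runner's copy.
    ///
    /// A minimal usage sketch, assuming the id came from an earlier prefill:
    ///
    /// ```ignore
    /// executor.release_sequence("llama-cache-1");
    /// ```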
    pub fn release_sequence(&self, cache_id: &str) {
        self.states.lock().remove(cache_id);
        self.model.release_cache(cache_id);
        #[cfg(feature = "cuda")]
        if let Some(ref mut runner) = *self.cuda_runner.lock() {
            runner.release_kv_cache(cache_id);
        }
    }
}

#[async_trait]
impl ModelExecutor for CandleModelExecutor {
    fn info(&self) -> &ModelInfo {
        &self.info
    }

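    /// Runs the full-prompt forward pass under a freshly allocated
    /// `llama-cache-{n}` id and returns rank-3 logits plus a KV-cache handle
    /// for subsequent decode calls.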
    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
        debug!("Llama Prefill: seq_len={}", input.sequence_length());

        let tokens = self.tensor_to_tokens(&input.input_ids)?;
        if tokens.is_empty() {
            return Err(FerrumError::model("Empty input"));
        }

        let cache_id = format!(
            "llama-cache-{}",
            self.next_cache_id.fetch_add(1, Ordering::Relaxed)
        );

        let input_tensor = self.tokens_to_tensor(&tokens)?;
        let logits = self.model.forward_prefill(&input_tensor, &cache_id)?;

        let logits = match logits.dims().len() {
            2 => logits
                .unsqueeze(1)
                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?,
            3 => logits,
            d => return Err(FerrumError::model(format!("Unexpected logits rank: {d}"))),
        };

        let logits_ref = self.wrap_tensor(logits);
        let cfg = self.model.config();
        let handle = Arc::new(GenericKvCacheHandle::new(
            cfg.num_hidden_layers,
            cfg.num_attention_heads,
            cfg.head_dim,
            self.model.device().clone(),
            tokens.len(),
            cache_id.clone(),
        ));

        self.states.lock().insert(
            cache_id,
            LlamaCacheState {
                sequence_length: tokens.len(),
            },
        );

        Ok(PrefillOutput::new(logits_ref, handle))
    }

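    /// Produces next-token logits, preferring (1) the tensor-parallel group
    /// when `FERRUM_TP` / multiple GPUs are available, then (2) the single-GPU
    /// CUDA decode runner, and otherwise (3) the plain candle forward pass.
    /// Every path advances the sequence length recorded in the returned handle.
    ///
    /// A sketch of one generation step, assuming the caller has already built
    /// a `PrefillInput` and turns the returned handle into the next
    /// `DecodeInput` (both types come from `ferrum_interfaces`):
    ///
    /// ```ignore
    /// let prefill_out = executor.prefill(&prefill_input).await?;
    /// // ... sample a token from prefill_out, build `decode_input` around its KV handle ...
    /// let decode_out = executor.decode(&decode_input).await?;
    /// ```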
    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
        let handle = input
            .kv_cache
            .as_any()
            .downcast_ref::<GenericKvCacheHandle>()
            .ok_or_else(|| FerrumError::model("Invalid KV handle for Llama"))?;
        let cache_id = handle.request_cache_id().to_string();

        let seq_len = {
            let mut states = self.states.lock();
            if let Some(s) = states.get(&cache_id) {
                s.sequence_length
            } else {
                let len = handle.block_table().sequence_length;
                states.insert(
                    cache_id.clone(),
                    LlamaCacheState {
                        sequence_length: len,
                    },
                );
                len
            }
        };

        let tokens = self.tensor_to_tokens(&input.input_ids)?;
        if tokens.is_empty() {
            return Err(FerrumError::model("Empty decode input"));
        }

        // Path 1: tensor-parallel decode across multiple GPUs.
        #[cfg(feature = "cuda")]
        {
            if Self::tp_size() > 1 && self.ensure_tp_group() {
                // Shard and upload the KV cache to every rank the first time
                // this sequence decodes on the TP group.
                {
                    let mut group = self.tp_group.lock();
                    if let Some(ref mut g) = *group {
                        if !g.has_kv_cache(&cache_id) {
                            let tp = g.world_size();
                            let kv_data = self
                                .model
                                .export_kv_cache(&cache_id)
                                .or_else(|| {
                                    cache_id.rfind("-clone-").and_then(|pos| {
                                        self.model.export_kv_cache(&cache_id[..pos])
                                    })
                                })
                                .ok_or_else(|| {
                                    FerrumError::model(format!("No KV for TP: {cache_id}"))
                                })?;

                            if !kv_data.is_empty() {
                                let prefill_len = kv_data[0].2;
                                let max_len = kv_data[0].3;
                                let num_kv_heads = self.model.config().num_key_value_heads;
                                let heads_per_rank = num_kv_heads / tp;

                                use ferrum_cuda_kernels::tp_decode::KvSource;

                                let mut per_rank_kv: Vec<Vec<(KvSource, KvSource)>> =
                                    (0..tp).map(|_| Vec::new()).collect();

                                for (k_tensor, v_tensor, _len, _max) in &kv_data {
                                    for rank in 0..tp {
                                        let start = rank * heads_per_rank;
                                        let k_shard = k_tensor
                                            .narrow(1, start, heads_per_rank)
                                            .and_then(|t| t.contiguous())
                                            .map_err(|e| {
                                                FerrumError::model(format!("KV shard: {e}"))
                                            })?;
                                        let v_shard = v_tensor
                                            .narrow(1, start, heads_per_rank)
                                            .and_then(|t| t.contiguous())
                                            .map_err(|e| {
                                                FerrumError::model(format!("KV shard: {e}"))
                                            })?;

                                        if rank == 0 {
                                            // Rank 0's shard already lives on this device, so
                                            // hand over the CUDA slices directly.
                                            use candle_core::Storage;
                                            let (ks, _) = k_shard.storage_and_layout();
                                            let (vs, _) = v_shard.storage_and_layout();
                                            let kc = match &*ks {
                                                Storage::Cuda(cs) => cs
                                                    .as_cuda_slice::<half::f16>()
                                                    .map_err(|e| {
                                                        FerrumError::model(format!("KV: {e}"))
                                                    })?
                                                    .clone(),
                                                _ => return Err(FerrumError::model("KV not CUDA")),
                                            };
                                            let vc = match &*vs {
                                                Storage::Cuda(cs) => cs
                                                    .as_cuda_slice::<half::f16>()
                                                    .map_err(|e| {
                                                        FerrumError::model(format!("KV: {e}"))
                                                    })?
                                                    .clone(),
                                                _ => return Err(FerrumError::model("KV not CUDA")),
                                            };
                                            drop(ks);
                                            drop(vs);
                                            per_rank_kv[rank]
                                                .push((KvSource::Gpu(kc), KvSource::Gpu(vc)));
                                        } else {
                                            // Other ranks get a host copy of their shard.
                                            let k_host = k_shard
                                                .flatten_all()
                                                .and_then(|t| t.to_vec1::<half::f16>())
                                                .map_err(|e| {
                                                    FerrumError::model(format!("KV d2h: {e}"))
                                                })?;
                                            let v_host = v_shard
                                                .flatten_all()
                                                .and_then(|t| t.to_vec1::<half::f16>())
                                                .map_err(|e| {
                                                    FerrumError::model(format!("KV d2h: {e}"))
                                                })?;
                                            per_rank_kv[rank].push((
                                                KvSource::Host(k_host),
                                                KvSource::Host(v_host),
                                            ));
                                        }
                                    }
                                }
                                g.init_kv_cache(&cache_id, per_rank_kv, prefill_len, max_len)
                                    .map_err(|e| FerrumError::model(format!("TP KV init: {e}")))?;
                            }
                        }
                    }
                }

                let logits = {
                    let mut group = self.tp_group.lock();
                    let group = group
                        .as_mut()
                        .ok_or_else(|| FerrumError::model("TP group gone"))?;
                    group
                        .decode_step(tokens[0], seq_len, &cache_id)
                        .map_err(|e| FerrumError::model(format!("tp_decode: {e}")))?
                };

                // Wrap the raw CUDA logits slice in a `(1, 1, vocab)` candle tensor.
                let cuda_dev = self
                    .model
                    .candle_device()
                    .as_cuda_device()
                    .map_err(|e| FerrumError::model(format!("not CUDA: {e}")))?;
                let vocab = self.info.vocab_size;
                let storage = candle_core::cuda_backend::CudaStorage::wrap_cuda_slice(
                    logits,
                    cuda_dev.clone(),
                );
                let logits_tensor = candle_core::Tensor::from_storage(
                    candle_core::Storage::Cuda(storage),
                    (1, 1, vocab),
                    candle_core::op::BackpropOp::none(),
                    false,
                );
                let logits_ref = self.wrap_tensor(logits_tensor);
                let new_seq_len = seq_len + 1;
                {
                    let mut states = self.states.lock();
                    if let Some(s) = states.get_mut(&cache_id) {
                        s.sequence_length = new_seq_len;
                    }
                }
                let new_handle = Arc::new(handle.with_sequence_length(new_seq_len));
                return Ok(DecodeOutput::new(logits_ref, new_handle));
            }
        }

        // Path 2: single-GPU fused CUDA decode runner.
        #[cfg(feature = "cuda")]
        {
            if std::env::var("FERRUM_DISABLE_CUDA_RUNNER").map_or(true, |v| v != "1") {
                if self.ensure_cuda_runner() {
                    self.ensure_runner_kv_cache(&cache_id, seq_len)?;

                    let logits = {
                        let mut runner = self.cuda_runner.lock();
                        let runner = runner
                            .as_mut()
                            .ok_or_else(|| FerrumError::model("CUDA runner gone"))?;
                        runner
                            .decode_step(tokens[0], seq_len, &cache_id)
                            .map_err(|e| FerrumError::model(format!("decode: {e}")))?
                    };

                    let cuda_dev = self
                        .model
                        .candle_device()
                        .as_cuda_device()
                        .map_err(|e| FerrumError::model(format!("not CUDA: {e}")))?;
                    let vocab = self.info.vocab_size;
                    let storage = candle_core::cuda_backend::CudaStorage::wrap_cuda_slice(
                        logits,
                        cuda_dev.clone(),
                    );
                    let logits_tensor = candle_core::Tensor::from_storage(
                        candle_core::Storage::Cuda(storage),
                        (1, 1, vocab),
                        candle_core::op::BackpropOp::none(),
                        false,
                    );

                    let logits_ref = self.wrap_tensor(logits_tensor);

                    let new_seq_len = seq_len + 1;
                    {
                        let mut states = self.states.lock();
                        if let Some(s) = states.get_mut(&cache_id) {
                            s.sequence_length = new_seq_len;
                        }
                    }
                    let new_handle = Arc::new(handle.with_sequence_length(new_seq_len));
                    return Ok(DecodeOutput::new(logits_ref, new_handle));
                }
            }
        }

        // Path 3: fallback to the plain candle forward pass.
        let input_tensor = self.tokens_to_tensor(&tokens)?;
        let logits = self
            .model
            .forward_decode(&input_tensor, seq_len, &cache_id)?;
        let logits = match logits.dims().len() {
            2 => logits
                .unsqueeze(1)
                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?,
            _ => logits,
        };
        let logits_ref = self.wrap_tensor(logits);

        let new_seq_len = seq_len + tokens.len();
        {
            let mut states = self.states.lock();
            if let Some(s) = states.get_mut(&cache_id) {
                s.sequence_length = new_seq_len;
            }
        }
        let new_handle = Arc::new(handle.with_sequence_length(new_seq_len));
        Ok(DecodeOutput::new(logits_ref, new_handle))
    }

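    /// Static capability report. The memory figures assume FP16 weights
    /// (2 bytes per parameter) and a fixed 1 GiB runtime overhead.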
    fn capabilities(&self) -> ExecutorCapabilities {
        ExecutorCapabilities {
            max_batch_size: 1,
            max_sequence_length: self.info.max_sequence_length,
            attention_mechanisms: vec![AttentionType::MultiHead, AttentionType::GroupedQuery],
            supports_dynamic_batching: false,
            supports_continuous_batching: true,
            supports_speculative_decoding: false,
            supports_tensor_parallelism: false,
            supports_pipeline_parallelism: false,
            supported_dtypes: vec![DataType::FP16, DataType::FP32],
            supported_devices: vec![self.info.device.clone()],
            memory_requirements: MemoryRequirements {
                parameter_memory: (self.info.num_parameters * 2) as u64,
                activation_memory_per_token: 4 * self.info.hidden_size,
                kv_cache_memory_per_token: 2 * self.info.num_layers * self.info.hidden_size,
                overhead_memory: 1024 * 1024 * 1024,
            },
        }
    }

    fn release_cache(&self, cache_id: &str) {
        self.release_sequence(cache_id);
    }

    fn status(&self) -> ExecutorStatus {
        common::default_executor_status()
    }
}