// ferrum_models/executor/common.rs

1//! Common executor utilities — extracted from duplicated code across
2//! Qwen3, Qwen2, and Llama executors.
3
4use candle_core::{Device as CandleDevice, Tensor};
5use ferrum_interfaces::{
6    kv_cache::{BlockTable, CacheHandleStats},
7    KvCacheHandle, TensorRef,
8};
9use ferrum_types::{Device, FerrumError, Result};
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::Arc;
12
use std::time::Instant;

/// Global counter for generating unique clone cache IDs.
/// Incremented with `Ordering::Relaxed` — only uniqueness matters, not ordering.
static CLONE_COUNTER: AtomicU64 = AtomicU64::new(0);
16
17use crate::tensor_wrapper::CandleTensorWrapper;
18
19// ======================== Tensor Conversion ========================
20
21/// Extract token IDs from a TensorRef.
22pub fn tensor_to_tokens(tensor: &TensorRef) -> Result<Vec<u32>> {
23    if let Ok(tokens) = tensor.to_vec_u32() {
24        if tokens.is_empty() {
25            return Err(FerrumError::model("Input token tensor is empty"));
26        }
27        return Ok(tokens);
28    }
29    if let Ok(tokens_f32) = tensor.to_vec_f32() {
30        let tokens: Vec<u32> = tokens_f32.into_iter().map(|x| x as u32).collect();
31        if tokens.is_empty() {
32            return Err(FerrumError::model("Input token tensor is empty"));
33        }
34        return Ok(tokens);
35    }
36    Err(FerrumError::model(
37        "Unable to extract token IDs from input tensor",
38    ))
39}
40
41/// Convert token IDs to a candle Tensor on the target device.
42pub fn tokens_to_tensor(tokens: &[u32], device: &CandleDevice) -> Result<Tensor> {
43    let base = Tensor::new(tokens, &CandleDevice::Cpu)
44        .map_err(|e| FerrumError::model(format!("Failed to create tensor: {}", e)))?
45        .unsqueeze(0)
46        .map_err(|e| FerrumError::model(format!("Failed to unsqueeze tensor: {}", e)))?
47        .to_dtype(candle_core::DType::I64)
48        .map_err(|e| FerrumError::model(format!("Failed to cast tokens to I64: {}", e)))?;
49
50    if matches!(device, CandleDevice::Cpu) {
51        Ok(base)
52    } else {
53        base.to_device(device)
54            .map_err(|e| FerrumError::model(format!("Failed to move tensor to device: {}", e)))
55    }
56}
57
58/// Wrap a candle Tensor as a TensorRef.
59pub fn wrap_tensor(tensor: Tensor) -> TensorRef {
60    Arc::new(CandleTensorWrapper::new(tensor))
61}
62
63// ======================== Generic KV Cache Handle ========================
64
/// Generic KV cache handle usable by any model architecture.
#[derive(Debug, Clone)]
pub struct GenericKvCacheHandle {
    // Paged block bookkeeping; also carries the current sequence length.
    block_table: BlockTable,
    // Model geometry — fixed at construction time.
    num_layers: usize,
    num_heads: usize,
    head_dim: usize,
    // Crate-level device (mapped from the candle device in `new`).
    device: Device,
    // Identifier used by runners to key per-sequence cache state;
    // clones get a unique derived ID (see `clone_handle`).
    cache_id: String,
}
75
76impl GenericKvCacheHandle {
77    pub fn new(
78        num_layers: usize,
79        num_heads: usize,
80        head_dim: usize,
81        device: CandleDevice,
82        seq_len: usize,
83        cache_id: String,
84    ) -> Self {
85        let mut block_table = BlockTable::new(16);
86        block_table.sequence_length = seq_len;
87
88        Self {
89            block_table,
90            num_layers,
91            num_heads,
92            head_dim,
93            cache_id,
94            device: match device {
95                CandleDevice::Cpu => Device::CPU,
96                CandleDevice::Cuda(_) => Device::CUDA(0),
97                #[cfg(any(target_os = "macos", target_os = "ios"))]
98                CandleDevice::Metal(_) => Device::Metal,
99                #[cfg(not(any(target_os = "macos", target_os = "ios")))]
100                CandleDevice::Metal(_) => Device::CPU,
101            },
102        }
103    }
104
105    pub fn with_sequence_length(&self, seq_len: usize) -> Self {
106        let mut handle = self.clone();
107        handle.block_table.sequence_length = seq_len;
108        handle
109    }
110
111    pub fn request_cache_id(&self) -> &str {
112        &self.cache_id
113    }
114}
115
116impl KvCacheHandle for GenericKvCacheHandle {
117    fn block_table(&self) -> &BlockTable {
118        &self.block_table
119    }
120    fn block_table_mut(&mut self) -> &mut BlockTable {
121        &mut self.block_table
122    }
123    fn as_any(&self) -> &dyn std::any::Any {
124        self
125    }
126    fn device(&self) -> Device {
127        self.device.clone()
128    }
129    fn num_layers(&self) -> usize {
130        self.num_layers
131    }
132    fn num_heads(&self) -> usize {
133        self.num_heads
134    }
135    fn head_dim(&self) -> usize {
136        self.head_dim
137    }
138    fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
139        Ok(None)
140    }
141    fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
142        Ok(None)
143    }
144    fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
145        // Each clone gets a unique cache_id so CUDA runner's kv_states
146        // HashMap can distinguish between sequences with same prompt.
147        let mut cloned = self.clone();
148        let n = CLONE_COUNTER.fetch_add(1, Ordering::Relaxed);
149        cloned.cache_id = format!("{}-clone-{n}", self.cache_id);
150        Ok(Arc::new(cloned))
151    }
152    fn stats(&self) -> CacheHandleStats {
153        CacheHandleStats {
154            memory_bytes: 0,
155            blocks_allocated: self.block_table.num_blocks(),
156            tokens_stored: self.block_table.sequence_length,
157            utilization: 0.0,
158            last_access: Instant::now(),
159        }
160    }
161    fn is_valid(&self) -> bool {
162        true
163    }
164    fn cache_id(&self) -> String {
165        self.cache_id.clone()
166    }
167}
168
169// ======================== Default Executor Status ========================
170
171/// Default executor status (all executors return the same thing).
172pub fn default_executor_status() -> ferrum_interfaces::model_executor::ExecutorStatus {
173    use ferrum_interfaces::model_executor::*;
174    ExecutorStatus {
175        state: ExecutorState::Ready,
176        is_ready: true,
177        current_batch_size: 0,
178        prefill_operations: 0,
179        decode_operations: 0,
180        avg_prefill_time_ms: 0.0,
181        avg_decode_time_ms: 0.0,
182        memory_usage: ExecutorMemoryUsage {
183            allocated_bytes: 0,
184            used_bytes: 0,
185            peak_bytes: 0,
186            utilization_percent: 0.0,
187        },
188        last_operation: Some(Instant::now()),
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use ferrum_interfaces::KvCacheHandle;

    #[test]
    fn tensor_to_tokens_from_u32() {
        let mock = ferrum_testkit::MockTensor::from_u32(&[1, 2, 3], &[3]);
        let ids = tensor_to_tokens(&mock.into_ref()).unwrap();
        assert_eq!(ids, vec![1, 2, 3]);
    }

    #[test]
    fn tensor_to_tokens_from_f32() {
        let mock = ferrum_testkit::MockTensor::from_f32(vec![10.0, 20.0, 30.0], &[3]);
        let ids = tensor_to_tokens(&mock.into_ref()).unwrap();
        assert_eq!(ids, vec![10, 20, 30]);
    }

    #[test]
    fn tensor_to_tokens_empty_fails() {
        let mock = ferrum_testkit::MockTensor::from_u32(&[], &[0]);
        assert!(tensor_to_tokens(&mock.into_ref()).is_err());
    }

    #[test]
    fn tokens_to_tensor_cpu() {
        let out = tokens_to_tensor(&[42, 100], &CandleDevice::Cpu).unwrap();
        assert_eq!(out.dims(), &[1, 2]);
        assert_eq!(out.dtype(), candle_core::DType::I64);
    }

    #[test]
    fn wrap_tensor_creates_tensor_ref() {
        let zeros = Tensor::zeros((2, 3), candle_core::DType::F32, &CandleDevice::Cpu).unwrap();
        let wrapped = wrap_tensor(zeros);
        assert_eq!(wrapped.shape(), &[2, 3]);
    }

    #[test]
    fn generic_kv_cache_handle_basic() {
        let handle = GenericKvCacheHandle::new(
            36,  // num_layers
            32,  // num_heads
            128, // head_dim
            CandleDevice::Cpu,
            10, // seq_len
            "test-cache-1".to_string(),
        );

        assert_eq!(handle.num_layers(), 36);
        assert_eq!(handle.num_heads(), 32);
        assert_eq!(handle.head_dim(), 128);
        assert_eq!(handle.cache_id(), "test-cache-1");
        assert_eq!(handle.block_table().sequence_length, 10);
        assert!(handle.is_valid());
    }

    #[test]
    fn generic_kv_cache_handle_with_sequence_length() {
        let original =
            GenericKvCacheHandle::new(4, 8, 64, CandleDevice::Cpu, 5, "cache-1".to_string());
        let grown = original.with_sequence_length(15);
        assert_eq!(grown.block_table().sequence_length, 15);
        assert_eq!(grown.request_cache_id(), "cache-1");
        // The source handle must be left untouched.
        assert_eq!(original.block_table().sequence_length, 5);
    }

    #[test]
    fn generic_kv_cache_handle_clone_handle() {
        let original =
            GenericKvCacheHandle::new(4, 8, 64, CandleDevice::Cpu, 5, "cache-2".to_string());
        let duplicate = original.clone_handle().unwrap();
        // Clones are disambiguated via a "-clone-N" suffix.
        assert!(duplicate.cache_id().starts_with("cache-2-clone-"));
        assert_eq!(duplicate.num_layers(), 4);
    }

    #[test]
    fn default_executor_status_is_ready() {
        let status = default_executor_status();
        assert!(status.is_ready);
        assert_eq!(
            status.state,
            ferrum_interfaces::model_executor::ExecutorState::Ready
        );
    }
}