1use candle_core::{Device as CandleDevice, Tensor};
5use ferrum_interfaces::{
6 kv_cache::{BlockTable, CacheHandleStats},
7 KvCacheHandle, TensorRef,
8};
9use ferrum_types::{Device, FerrumError, Result};
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::Arc;
12
// Process-wide monotonically increasing counter used by `clone_handle` to
// derive a unique cache id for each cloned handle.
static CLONE_COUNTER: AtomicU64 = AtomicU64::new(0);
15use std::time::Instant;
16
17use crate::tensor_wrapper::CandleTensorWrapper;
18
19pub fn tensor_to_tokens(tensor: &TensorRef) -> Result<Vec<u32>> {
23 if let Ok(tokens) = tensor.to_vec_u32() {
24 if tokens.is_empty() {
25 return Err(FerrumError::model("Input token tensor is empty"));
26 }
27 return Ok(tokens);
28 }
29 if let Ok(tokens_f32) = tensor.to_vec_f32() {
30 let tokens: Vec<u32> = tokens_f32.into_iter().map(|x| x as u32).collect();
31 if tokens.is_empty() {
32 return Err(FerrumError::model("Input token tensor is empty"));
33 }
34 return Ok(tokens);
35 }
36 Err(FerrumError::model(
37 "Unable to extract token IDs from input tensor",
38 ))
39}
40
41pub fn tokens_to_tensor(tokens: &[u32], device: &CandleDevice) -> Result<Tensor> {
43 let base = Tensor::new(tokens, &CandleDevice::Cpu)
44 .map_err(|e| FerrumError::model(format!("Failed to create tensor: {}", e)))?
45 .unsqueeze(0)
46 .map_err(|e| FerrumError::model(format!("Failed to unsqueeze tensor: {}", e)))?
47 .to_dtype(candle_core::DType::I64)
48 .map_err(|e| FerrumError::model(format!("Failed to cast tokens to I64: {}", e)))?;
49
50 if matches!(device, CandleDevice::Cpu) {
51 Ok(base)
52 } else {
53 base.to_device(device)
54 .map_err(|e| FerrumError::model(format!("Failed to move tensor to device: {}", e)))
55 }
56}
57
58pub fn wrap_tensor(tensor: Tensor) -> TensorRef {
60 Arc::new(CandleTensorWrapper::new(tensor))
61}
62
/// Lightweight, backend-agnostic KV-cache handle.
///
/// Tracks cache geometry (layers / heads / head dim) and a block table but
/// owns no tensor storage: `key_cache`/`value_cache` return `None` and
/// `stats()` reports zero memory.
#[derive(Debug, Clone)]
pub struct GenericKvCacheHandle {
    // Block table; also carries the current sequence length.
    block_table: BlockTable,
    // Number of transformer layers this cache covers.
    num_layers: usize,
    // Attention heads per layer.
    num_heads: usize,
    // Dimension of each attention head.
    head_dim: usize,
    // Logical device the cache is associated with.
    device: Device,
    // Identifier tying the cache back to the originating request.
    cache_id: String,
}
75
76impl GenericKvCacheHandle {
77 pub fn new(
78 num_layers: usize,
79 num_heads: usize,
80 head_dim: usize,
81 device: CandleDevice,
82 seq_len: usize,
83 cache_id: String,
84 ) -> Self {
85 let mut block_table = BlockTable::new(16);
86 block_table.sequence_length = seq_len;
87
88 Self {
89 block_table,
90 num_layers,
91 num_heads,
92 head_dim,
93 cache_id,
94 device: match device {
95 CandleDevice::Cpu => Device::CPU,
96 CandleDevice::Cuda(_) => Device::CUDA(0),
97 #[cfg(any(target_os = "macos", target_os = "ios"))]
98 CandleDevice::Metal(_) => Device::Metal,
99 #[cfg(not(any(target_os = "macos", target_os = "ios")))]
100 CandleDevice::Metal(_) => Device::CPU,
101 },
102 }
103 }
104
105 pub fn with_sequence_length(&self, seq_len: usize) -> Self {
106 let mut handle = self.clone();
107 handle.block_table.sequence_length = seq_len;
108 handle
109 }
110
111 pub fn request_cache_id(&self) -> &str {
112 &self.cache_id
113 }
114}
115
116impl KvCacheHandle for GenericKvCacheHandle {
117 fn block_table(&self) -> &BlockTable {
118 &self.block_table
119 }
120 fn block_table_mut(&mut self) -> &mut BlockTable {
121 &mut self.block_table
122 }
123 fn as_any(&self) -> &dyn std::any::Any {
124 self
125 }
126 fn device(&self) -> Device {
127 self.device.clone()
128 }
129 fn num_layers(&self) -> usize {
130 self.num_layers
131 }
132 fn num_heads(&self) -> usize {
133 self.num_heads
134 }
135 fn head_dim(&self) -> usize {
136 self.head_dim
137 }
138 fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
139 Ok(None)
140 }
141 fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
142 Ok(None)
143 }
144 fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
145 let mut cloned = self.clone();
148 let n = CLONE_COUNTER.fetch_add(1, Ordering::Relaxed);
149 cloned.cache_id = format!("{}-clone-{n}", self.cache_id);
150 Ok(Arc::new(cloned))
151 }
152 fn stats(&self) -> CacheHandleStats {
153 CacheHandleStats {
154 memory_bytes: 0,
155 blocks_allocated: self.block_table.num_blocks(),
156 tokens_stored: self.block_table.sequence_length,
157 utilization: 0.0,
158 last_access: Instant::now(),
159 }
160 }
161 fn is_valid(&self) -> bool {
162 true
163 }
164 fn cache_id(&self) -> String {
165 self.cache_id.clone()
166 }
167}
168
169pub fn default_executor_status() -> ferrum_interfaces::model_executor::ExecutorStatus {
173 use ferrum_interfaces::model_executor::*;
174 ExecutorStatus {
175 state: ExecutorState::Ready,
176 is_ready: true,
177 current_batch_size: 0,
178 prefill_operations: 0,
179 decode_operations: 0,
180 avg_prefill_time_ms: 0.0,
181 avg_decode_time_ms: 0.0,
182 memory_usage: ExecutorMemoryUsage {
183 allocated_bytes: 0,
184 used_bytes: 0,
185 peak_bytes: 0,
186 utilization_percent: 0.0,
187 },
188 last_operation: Some(Instant::now()),
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use ferrum_interfaces::KvCacheHandle;

    #[test]
    fn tensor_to_tokens_from_u32() {
        // Native u32 tensors round-trip unchanged.
        let input = ferrum_testkit::MockTensor::from_u32(&[1, 2, 3], &[3]);
        assert_eq!(tensor_to_tokens(&input.into_ref()).unwrap(), vec![1, 2, 3]);
    }

    #[test]
    fn tensor_to_tokens_from_f32() {
        // f32 tensors are truncated into u32 token IDs.
        let input = ferrum_testkit::MockTensor::from_f32(vec![10.0, 20.0, 30.0], &[3]);
        assert_eq!(tensor_to_tokens(&input.into_ref()).unwrap(), vec![10, 20, 30]);
    }

    #[test]
    fn tensor_to_tokens_empty_fails() {
        // An empty tensor must be rejected rather than yield an empty prompt.
        let input = ferrum_testkit::MockTensor::from_u32(&[], &[0]);
        assert!(tensor_to_tokens(&input.into_ref()).is_err());
    }

    #[test]
    fn tokens_to_tensor_cpu() {
        let t = tokens_to_tensor(&[42, 100], &CandleDevice::Cpu).unwrap();
        // Shape is [batch=1, seq=2] with I64 dtype.
        assert_eq!(t.dims(), &[1, 2]);
        assert_eq!(t.dtype(), candle_core::DType::I64);
    }

    #[test]
    fn wrap_tensor_creates_tensor_ref() {
        let raw = Tensor::zeros((2, 3), candle_core::DType::F32, &CandleDevice::Cpu).unwrap();
        assert_eq!(wrap_tensor(raw).shape(), &[2, 3]);
    }

    #[test]
    fn generic_kv_cache_handle_basic() {
        let handle = GenericKvCacheHandle::new(
            36,
            32,
            128,
            CandleDevice::Cpu,
            10,
            "test-cache-1".to_string(),
        );

        assert_eq!(handle.num_layers(), 36);
        assert_eq!(handle.num_heads(), 32);
        assert_eq!(handle.head_dim(), 128);
        assert_eq!(handle.cache_id(), "test-cache-1");
        assert_eq!(handle.block_table().sequence_length, 10);
        assert!(handle.is_valid());
    }

    #[test]
    fn generic_kv_cache_handle_with_sequence_length() {
        let original =
            GenericKvCacheHandle::new(4, 8, 64, CandleDevice::Cpu, 5, "cache-1".to_string());
        let resized = original.with_sequence_length(15);
        // The copy carries the new length; the source handle is untouched.
        assert_eq!(resized.block_table().sequence_length, 15);
        assert_eq!(resized.request_cache_id(), "cache-1");
        assert_eq!(original.block_table().sequence_length, 5);
    }

    #[test]
    fn generic_kv_cache_handle_clone_handle() {
        let parent =
            GenericKvCacheHandle::new(4, 8, 64, CandleDevice::Cpu, 5, "cache-2".to_string());
        let child = parent.clone_handle().unwrap();
        // Clones get a derived unique id but keep the parent's geometry.
        assert!(child.cache_id().starts_with("cache-2-clone-"));
        assert_eq!(child.num_layers(), 4);
    }

    #[test]
    fn default_executor_status_is_ready() {
        let status = default_executor_status();
        assert!(status.is_ready);
        assert_eq!(
            status.state,
            ferrum_interfaces::model_executor::ExecutorState::Ready
        );
    }
}