//! Tensor device-transfer helpers and lightweight memory tracking utilities.

use crate::error::{CoreError, CoreResult};
use candle_core::{Device, Tensor};
use std::collections::HashMap;

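/// Helpers for moving a single tensor between devices.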
pub struct TensorTransfer;

impl TensorTransfer {
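    /// Moves a tensor to the given device, mapping any candle error into
    /// a `CoreError::DeviceError`.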
    pub fn to_device(tensor: &Tensor, device: &Device) -> CoreResult<Tensor> {
        tensor
            .to_device(device)
            .map_err(|e| CoreError::DeviceError(format!("Failed to transfer tensor: {}", e)))
    }

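    /// Moves a tensor to the CPU.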
    pub fn to_cpu(tensor: &Tensor) -> CoreResult<Tensor> {
        Self::to_device(tensor, &Device::Cpu)
    }

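    /// Moves a tensor to the best available GPU, or returns a `DeviceError`
    /// if only the CPU is available.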
    pub fn to_gpu(tensor: &Tensor) -> CoreResult<Tensor> {
        let device = crate::device::get_best_device();
        if matches!(device, Device::Cpu) {
            return Err(CoreError::DeviceError(
                "No GPU device available".to_string(),
            ));
        }
        Self::to_device(tensor, &device)
    }

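    /// Returns true if the tensor lives on any non-CPU device.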
    pub fn is_on_gpu(tensor: &Tensor) -> bool {
        !matches!(tensor.device(), Device::Cpu)
    }

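    /// Returns true if the tensor lives on the CPU.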
    pub fn is_on_cpu(tensor: &Tensor) -> bool {
        matches!(tensor.device(), Device::Cpu)
    }

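    /// Returns a clone of the device the tensor currently lives on.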
    pub fn get_device(tensor: &Tensor) -> Device {
        tensor.device().clone()
    }
}

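/// Helpers for moving batches of tensors between devices in one call.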
pub struct TransferBatch;

impl TransferBatch {
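    /// Moves every tensor in the slice to the given device, failing on the
    /// first transfer error.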
    pub fn transfer_all(tensors: &[Tensor], device: &Device) -> CoreResult<Vec<Tensor>> {
        tensors
            .iter()
            .map(|t| TensorTransfer::to_device(t, device))
            .collect()
    }

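    /// Moves every tensor in the slice to the CPU.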
    pub fn to_cpu_all(tensors: &[Tensor]) -> CoreResult<Vec<Tensor>> {
        Self::transfer_all(tensors, &Device::Cpu)
    }

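    /// Moves every tensor in the slice to the best available GPU, or returns
    /// a `DeviceError` if only the CPU is available.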
    pub fn to_gpu_all(tensors: &[Tensor]) -> CoreResult<Vec<Tensor>> {
        let device = crate::device::get_best_device();
        if matches!(device, Device::Cpu) {
            return Err(CoreError::DeviceError(
                "No GPU device available".to_string(),
            ));
        }
        Self::transfer_all(tensors, &device)
    }
}

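/// Running totals for tensor memory usage, keyed by caller-supplied names.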
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Total bytes currently tracked across all named tensors.
    pub total_allocated: usize,
    /// Number of tensors currently tracked.
    pub tensor_count: usize,
    /// Per-tensor byte counts, keyed by name.
    pub memory_by_name: HashMap<String, usize>,
}

impl MemoryStats {
    pub fn new() -> Self {
        Self {
            total_allocated: 0,
            tensor_count: 0,
            memory_by_name: HashMap::new(),
        }
    }

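    /// Records a tensor's size under `name` and updates the running totals.
    /// Re-tracking an existing name replaces the previous entry.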
    pub fn track_tensor(&mut self, name: String, tensor: &Tensor) {
        let size = Self::tensor_size(tensor);
        // If this name was already tracked, drop its old contribution first so
        // repeated tracking does not double-count.
        if let Some(old_size) = self.memory_by_name.insert(name, size) {
            self.total_allocated = self.total_allocated.saturating_sub(old_size);
            self.tensor_count = self.tensor_count.saturating_sub(1);
        }
        self.total_allocated += size;
        self.tensor_count += 1;
    }

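    /// Removes a previously tracked tensor and subtracts its size from the totals.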
    pub fn untrack_tensor(&mut self, name: &str) {
        if let Some(size) = self.memory_by_name.remove(name) {
            self.total_allocated = self.total_allocated.saturating_sub(size);
            self.tensor_count = self.tensor_count.saturating_sub(1);
        }
    }

    pub fn total_bytes(&self) -> usize {
        self.total_allocated
    }

    pub fn total_mb(&self) -> f64 {
        self.total_allocated as f64 / (1024.0 * 1024.0)
    }

    pub fn total_gb(&self) -> f64 {
        self.total_allocated as f64 / (1024.0 * 1024.0 * 1024.0)
    }

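    /// Estimates a tensor's footprint as element count times dtype width in
    /// bytes; strides, padding, and allocator overhead are not accounted for.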
    fn tensor_size(tensor: &Tensor) -> usize {
        let elem_count: usize = tensor.dims().iter().product();
        let dtype_size = match tensor.dtype() {
            candle_core::DType::U8 => 1,
            candle_core::DType::U32 => 4,
            candle_core::DType::I64 => 8,
            candle_core::DType::F16 => 2,
            candle_core::DType::BF16 => 2,
            candle_core::DType::F32 => 4,
            candle_core::DType::F64 => 8,
        };
        elem_count * dtype_size
    }

    pub fn clear(&mut self) {
        self.total_allocated = 0;
        self.tensor_count = 0;
        self.memory_by_name.clear();
    }
}

impl Default for MemoryStats {
    fn default() -> Self {
        Self::new()
    }
}

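/// A minimal allocation helper: creates zeroed tensors on a fixed device and
/// tracks their sizes in [`MemoryStats`]. Allocations are not reused or cached;
/// `release` only drops the tracking entry.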
pub struct GPUMemoryPool {
    /// Device on which all pool allocations are placed.
    device: Device,
    /// Running statistics for tensors allocated through this pool.
    stats: MemoryStats,
}

impl GPUMemoryPool {
    pub fn new(device: Device) -> Self {
        Self {
            device,
            stats: MemoryStats::new(),
        }
    }

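    /// Allocates a zero-initialized tensor with the given shape and dtype on
    /// the pool's device and records it under `name`.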
    pub fn allocate(
        &mut self,
        name: String,
        shape: &[usize],
        dtype: candle_core::DType,
    ) -> CoreResult<Tensor> {
        let tensor = Tensor::zeros(shape, dtype, &self.device)
            .map_err(|e| CoreError::DeviceError(format!("Failed to allocate tensor: {}", e)))?;

        self.stats.track_tensor(name, &tensor);
        Ok(tensor)
    }

    pub fn release(&mut self, name: &str) {
        self.stats.untrack_tensor(name);
    }

    pub fn stats(&self) -> &MemoryStats {
        &self.stats
    }

    pub fn device(&self) -> &Device {
        &self.device
    }
}

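/// Helpers for staging tensors on a target device ahead of use.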
pub struct TensorPrefetch;

impl TensorPrefetch {
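    /// Stages a tensor on the target device. This is currently a synchronous
    /// transfer; no asynchronous copy or overlap with compute is performed.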
    pub fn prefetch(tensor: &Tensor, device: &Device) -> CoreResult<Tensor> {
        TensorTransfer::to_device(tensor, device)
    }

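    /// Stages a batch of tensors on the target device, synchronously.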
    pub fn prefetch_batch(tensors: &[Tensor], device: &Device) -> CoreResult<Vec<Tensor>> {
        TransferBatch::transfer_all(tensors, device)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    #[test]
    fn test_tensor_transfer_to_cpu() {
        let tensor = Tensor::zeros((10, 10), DType::F32, &Device::Cpu).unwrap();
        let cpu_tensor = TensorTransfer::to_cpu(&tensor).unwrap();

        assert!(TensorTransfer::is_on_cpu(&cpu_tensor));
        assert!(!TensorTransfer::is_on_gpu(&cpu_tensor));
    }

    #[test]
    fn test_batch_transfer() {
        let tensors = vec![
            Tensor::zeros((5, 5), DType::F32, &Device::Cpu).unwrap(),
            Tensor::zeros((10, 10), DType::F32, &Device::Cpu).unwrap(),
        ];

        let cpu_tensors = TransferBatch::to_cpu_all(&tensors).unwrap();
        assert_eq!(cpu_tensors.len(), 2);

        for tensor in &cpu_tensors {
            assert!(TensorTransfer::is_on_cpu(tensor));
        }
    }

    #[test]
    fn test_memory_stats() {
        let mut stats = MemoryStats::new();

        let tensor1 = Tensor::zeros((10, 10), DType::F32, &Device::Cpu).unwrap();
        let tensor2 = Tensor::zeros((20, 20), DType::F32, &Device::Cpu).unwrap();

        stats.track_tensor("tensor1".to_string(), &tensor1);
        stats.track_tensor("tensor2".to_string(), &tensor2);

        assert_eq!(stats.tensor_count, 2);
        // 10 * 10 * 4 bytes + 20 * 20 * 4 bytes = 400 + 1600 = 2000
        assert_eq!(stats.total_bytes(), 2000);

        stats.untrack_tensor("tensor1");
        assert_eq!(stats.tensor_count, 1);
        assert_eq!(stats.total_bytes(), 1600);

        stats.clear();
        assert_eq!(stats.tensor_count, 0);
        assert_eq!(stats.total_bytes(), 0);
    }

    #[test]
    fn test_memory_stats_mb_gb() {
        let mut stats = MemoryStats::new();

        // 1000 * 1000 f32 elements = 4,000,000 bytes
        let tensor = Tensor::zeros((1000, 1000), DType::F32, &Device::Cpu).unwrap();
        stats.track_tensor("large_tensor".to_string(), &tensor);

        let expected_mb = 4_000_000.0 / (1024.0 * 1024.0);
        assert!((stats.total_mb() - expected_mb).abs() < 0.01);

        let expected_gb = 4_000_000.0 / (1024.0 * 1024.0 * 1024.0);
        assert!((stats.total_gb() - expected_gb).abs() < 0.0001);
    }

    #[test]
    fn test_gpu_memory_pool() {
        // Using the CPU device keeps this test runnable on machines without a GPU.
        let mut pool = GPUMemoryPool::new(Device::Cpu);

        let tensor = pool
            .allocate("test_tensor".to_string(), &[100, 100], DType::F32)
            .unwrap();

        assert_eq!(tensor.dims(), &[100, 100]);
        assert_eq!(pool.stats().tensor_count, 1);
        assert_eq!(pool.stats().total_bytes(), 100 * 100 * 4);

        pool.release("test_tensor");
        assert_eq!(pool.stats().tensor_count, 0);
    }

    #[test]
    fn test_get_device() {
        let tensor = Tensor::zeros((10, 10), DType::F32, &Device::Cpu).unwrap();
        let device = TensorTransfer::get_device(&tensor);
        assert!(matches!(device, Device::Cpu));
    }

    #[test]
    fn test_tensor_size_calculation() {
        let tensor_f32 = Tensor::zeros((10, 20), DType::F32, &Device::Cpu).unwrap();
        assert_eq!(MemoryStats::tensor_size(&tensor_f32), 10 * 20 * 4);

        let tensor_f16 = Tensor::zeros((10, 20), DType::F16, &Device::Cpu).unwrap();
        assert_eq!(MemoryStats::tensor_size(&tensor_f16), 10 * 20 * 2);

        let tensor_i64 = Tensor::zeros((5, 5), DType::I64, &Device::Cpu).unwrap();
        assert_eq!(MemoryStats::tensor_size(&tensor_i64), 5 * 5 * 8);
    }
}