kizzasi_core/gpu_utils.rs

//! GPU memory management and tensor transfer utilities
//!
//! Provides efficient utilities for:
//! - Moving tensors between CPU and GPU
//! - Memory pooling for GPU tensors
//! - Batch transfer operations
//! - Memory usage tracking
//!
//! # Examples
//!
//! ```rust
//! use kizzasi_core::gpu_utils::{TensorTransfer, TransferBatch};
//! use kizzasi_core::device::DeviceConfig;
//! use candle_core::{Tensor, Device};
//!
//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
//! // Transfer single tensor to GPU
//! let cpu_tensor = Tensor::zeros((100, 100), candle_core::DType::F32, &Device::Cpu)?;
//! let gpu_device = DeviceConfig::default().create_device()?;
//!
//! let gpu_tensor = TensorTransfer::to_device(&cpu_tensor, &gpu_device)?;
//!
//! // Batch transfer multiple tensors
//! let tensors = vec![cpu_tensor.clone(), cpu_tensor.clone()];
//! let gpu_tensors = TransferBatch::transfer_all(&tensors, &gpu_device)?;
//! # Ok(())
//! # }
//! ```

use crate::error::{CoreError, CoreResult};
use candle_core::{Device, Tensor};
use std::collections::HashMap;

/// Tensor transfer utilities for CPU/GPU operations
pub struct TensorTransfer;

impl TensorTransfer {
    /// Transfer a tensor to a specific device
    ///
    /// # Arguments
    /// * `tensor` - The tensor to transfer
    /// * `device` - Target device
    ///
    /// # Returns
    /// Tensor on the target device
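    ///
    /// # Errors
    ///
    /// Returns `CoreError::DeviceError` if the underlying device copy fails.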
    pub fn to_device(tensor: &Tensor, device: &Device) -> CoreResult<Tensor> {
        tensor
            .to_device(device)
            .map_err(|e| CoreError::DeviceError(format!("Failed to transfer tensor: {}", e)))
    }

    /// Transfer a tensor to CPU
    pub fn to_cpu(tensor: &Tensor) -> CoreResult<Tensor> {
        Self::to_device(tensor, &Device::Cpu)
    }

    /// Transfer a tensor to GPU (auto-detect best GPU)
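    ///
    /// # Errors
    ///
    /// Returns `CoreError::DeviceError` when no GPU backend is available.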
    pub fn to_gpu(tensor: &Tensor) -> CoreResult<Tensor> {
        let device = crate::device::get_best_device();
        if matches!(device, Device::Cpu) {
            return Err(CoreError::DeviceError(
                "No GPU device available".to_string(),
            ));
        }
        Self::to_device(tensor, &device)
    }

    /// Check if a tensor is on GPU
    pub fn is_on_gpu(tensor: &Tensor) -> bool {
        !matches!(tensor.device(), Device::Cpu)
    }

    /// Check if a tensor is on CPU
    pub fn is_on_cpu(tensor: &Tensor) -> bool {
        matches!(tensor.device(), Device::Cpu)
    }

    /// Get device of a tensor
    pub fn get_device(tensor: &Tensor) -> Device {
        tensor.device().clone()
    }
}

/// Batch tensor transfer operations
pub struct TransferBatch;

impl TransferBatch {
    /// Transfer multiple tensors to a device in batch
    ///
    /// The current implementation transfers tensors sequentially; the batch
    /// API leaves room for backends that support async transfers to overlap
    /// them in the future.
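    ///
    /// # Examples
    ///
    /// A minimal sketch using the CPU device on both ends:
    ///
    /// ```rust
    /// # use kizzasi_core::gpu_utils::TransferBatch;
    /// # use candle_core::{Tensor, Device, DType};
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let tensors = vec![
    ///     Tensor::zeros((8, 8), DType::F32, &Device::Cpu)?,
    ///     Tensor::zeros((8, 8), DType::F32, &Device::Cpu)?,
    /// ];
    /// let moved = TransferBatch::transfer_all(&tensors, &Device::Cpu)?;
    /// assert_eq!(moved.len(), 2);
    /// # Ok(())
    /// # }
    /// ```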
    pub fn transfer_all(tensors: &[Tensor], device: &Device) -> CoreResult<Vec<Tensor>> {
        tensors
            .iter()
            .map(|t| TensorTransfer::to_device(t, device))
            .collect()
    }

    /// Transfer all tensors to CPU
    pub fn to_cpu_all(tensors: &[Tensor]) -> CoreResult<Vec<Tensor>> {
        Self::transfer_all(tensors, &Device::Cpu)
    }

    /// Transfer all tensors to GPU
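    ///
    /// # Errors
    ///
    /// Returns `CoreError::DeviceError` when no GPU backend is available.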
    pub fn to_gpu_all(tensors: &[Tensor]) -> CoreResult<Vec<Tensor>> {
        let device = crate::device::get_best_device();
        if matches!(device, Device::Cpu) {
            return Err(CoreError::DeviceError(
                "No GPU device available".to_string(),
            ));
        }
        Self::transfer_all(tensors, &device)
    }
}

/// Memory usage tracking for GPU tensors
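///
/// # Examples
///
/// A minimal sketch: track one tensor and read back its footprint.
///
/// ```rust
/// # use kizzasi_core::gpu_utils::MemoryStats;
/// # use candle_core::{Tensor, Device, DType};
/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let mut stats = MemoryStats::new();
/// let t = Tensor::zeros((10, 10), DType::F32, &Device::Cpu)?;
/// stats.track_tensor("t".to_string(), &t);
/// assert_eq!(stats.total_bytes(), 10 * 10 * 4); // 400 bytes of f32
/// # Ok(())
/// # }
/// ```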
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Total memory allocated (bytes)
    pub total_allocated: usize,
    /// Number of tensors tracked
    pub tensor_count: usize,
    /// Memory by tensor name
    pub memory_by_name: HashMap<String, usize>,
}

impl MemoryStats {
    /// Create a new memory stats tracker
    pub fn new() -> Self {
        Self {
            total_allocated: 0,
            tensor_count: 0,
            memory_by_name: HashMap::new(),
        }
    }

    /// Track a tensor's memory usage
    ///
    /// Re-tracking an existing name replaces the previous entry rather
    /// than double-counting it.
    pub fn track_tensor(&mut self, name: String, tensor: &Tensor) {
        let size = Self::tensor_size(tensor);
        // Remove any previous entry under this name from the totals first.
        if let Some(old_size) = self.memory_by_name.insert(name, size) {
            self.total_allocated = self.total_allocated.saturating_sub(old_size);
        } else {
            self.tensor_count += 1;
        }
        self.total_allocated += size;
    }

    /// Untrack a tensor
    pub fn untrack_tensor(&mut self, name: &str) {
        if let Some(size) = self.memory_by_name.remove(name) {
            self.total_allocated = self.total_allocated.saturating_sub(size);
            self.tensor_count = self.tensor_count.saturating_sub(1);
        }
    }

    /// Get total allocated memory in bytes
    pub fn total_bytes(&self) -> usize {
        self.total_allocated
    }

    /// Get total allocated memory in MB (1024-based, i.e. MiB)
    pub fn total_mb(&self) -> f64 {
        self.total_allocated as f64 / (1024.0 * 1024.0)
    }

    /// Get total allocated memory in GB (1024-based, i.e. GiB)
    pub fn total_gb(&self) -> f64 {
        self.total_allocated as f64 / (1024.0 * 1024.0 * 1024.0)
    }

    /// Calculate tensor size in bytes
    fn tensor_size(tensor: &Tensor) -> usize {
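        // Logical payload only: element count times dtype width. Any
        // backend-side padding or alignment overhead is not accounted for.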
        let elem_count: usize = tensor.dims().iter().product();
        let dtype_size = match tensor.dtype() {
            candle_core::DType::U8 => 1,
            candle_core::DType::U32 => 4,
            candle_core::DType::I64 => 8,
            candle_core::DType::F16 => 2,
            candle_core::DType::BF16 => 2,
            candle_core::DType::F32 => 4,
            candle_core::DType::F64 => 8,
        };
        elem_count * dtype_size
    }

    /// Clear all tracked tensors
    pub fn clear(&mut self) {
        self.total_allocated = 0;
        self.tensor_count = 0;
        self.memory_by_name.clear();
    }
}

impl Default for MemoryStats {
    fn default() -> Self {
        Self::new()
    }
}

/// GPU memory pool for efficient tensor allocation
///
/// Currently this allocates fresh zero-initialized tensors on the pool's
/// device and tracks their memory usage; it does not yet reuse released
/// allocations.
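///
/// # Examples
///
/// A minimal sketch using the CPU device; in practice pass a GPU device.
///
/// ```rust
/// # use kizzasi_core::gpu_utils::GPUMemoryPool;
/// # use candle_core::{Device, DType};
/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let mut pool = GPUMemoryPool::new(Device::Cpu);
/// let t = pool.allocate("activations".to_string(), &[64, 64], DType::F32)?;
/// assert_eq!(t.dims(), &[64, 64]);
/// assert_eq!(pool.stats().total_bytes(), 64 * 64 * 4);
/// pool.release("activations");
/// # Ok(())
/// # }
/// ```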
pub struct GPUMemoryPool {
    device: Device,
    stats: MemoryStats,
}

impl GPUMemoryPool {
    /// Create a new GPU memory pool
    pub fn new(device: Device) -> Self {
        Self {
            device,
            stats: MemoryStats::new(),
        }
    }

    /// Allocate a zero-initialized tensor on the pool's device
    pub fn allocate(
        &mut self,
        name: String,
        shape: &[usize],
        dtype: candle_core::DType,
    ) -> CoreResult<Tensor> {
        let tensor = Tensor::zeros(shape, dtype, &self.device)
            .map_err(|e| CoreError::DeviceError(format!("Failed to allocate tensor: {}", e)))?;

        self.stats.track_tensor(name, &tensor);
        Ok(tensor)
    }

    /// Release a tensor from the pool's tracking
    ///
    /// Only removes the stats entry; the tensor's memory is freed once the
    /// last handle to it is dropped.
    pub fn release(&mut self, name: &str) {
        self.stats.untrack_tensor(name);
    }

    /// Get memory statistics
    pub fn stats(&self) -> &MemoryStats {
        &self.stats
    }

    /// Get device
    pub fn device(&self) -> &Device {
        &self.device
    }
}

/// Prefetching utilities for optimizing data transfer
pub struct TensorPrefetch;

impl TensorPrefetch {
    /// Prefetch tensor to device (async hint for backends that support it)
    ///
    /// Note: This is a hint to the backend. Actual async behavior depends on
    /// the device backend (CUDA/Metal).
    pub fn prefetch(tensor: &Tensor, device: &Device) -> CoreResult<Tensor> {
        // For now, this is synchronous. Future implementations could leverage
        // async streams on CUDA or Metal command buffers
        TensorTransfer::to_device(tensor, device)
    }

    /// Prefetch multiple tensors
    pub fn prefetch_batch(tensors: &[Tensor], device: &Device) -> CoreResult<Vec<Tensor>> {
        TransferBatch::transfer_all(tensors, device)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    #[test]
    fn test_tensor_transfer_to_cpu() {
        let tensor = Tensor::zeros((10, 10), DType::F32, &Device::Cpu).unwrap();
        let cpu_tensor = TensorTransfer::to_cpu(&tensor).unwrap();

        assert!(TensorTransfer::is_on_cpu(&cpu_tensor));
        assert!(!TensorTransfer::is_on_gpu(&cpu_tensor));
    }

    #[test]
    fn test_batch_transfer() {
        let tensors = vec![
            Tensor::zeros((5, 5), DType::F32, &Device::Cpu).unwrap(),
            Tensor::zeros((10, 10), DType::F32, &Device::Cpu).unwrap(),
        ];

        let cpu_tensors = TransferBatch::to_cpu_all(&tensors).unwrap();
        assert_eq!(cpu_tensors.len(), 2);

        for tensor in &cpu_tensors {
            assert!(TensorTransfer::is_on_cpu(tensor));
        }
    }

    #[test]
    fn test_memory_stats() {
        let mut stats = MemoryStats::new();

        let tensor1 = Tensor::zeros((10, 10), DType::F32, &Device::Cpu).unwrap();
        let tensor2 = Tensor::zeros((20, 20), DType::F32, &Device::Cpu).unwrap();

        stats.track_tensor("tensor1".to_string(), &tensor1);
        stats.track_tensor("tensor2".to_string(), &tensor2);

        assert_eq!(stats.tensor_count, 2);
        // 10*10*4 + 20*20*4 = 400 + 1600 = 2000 bytes
        assert_eq!(stats.total_bytes(), 2000);

        stats.untrack_tensor("tensor1");
        assert_eq!(stats.tensor_count, 1);
        assert_eq!(stats.total_bytes(), 1600);

        stats.clear();
        assert_eq!(stats.tensor_count, 0);
        assert_eq!(stats.total_bytes(), 0);
    }
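
    // Covers the replace-on-retrack behavior of `track_tensor`: tracking a
    // new tensor under an existing name should replace the old entry
    // rather than double-count it.
    #[test]
    fn test_memory_stats_retrack_same_name() {
        let mut stats = MemoryStats::new();
        let small = Tensor::zeros((10, 10), DType::F32, &Device::Cpu).unwrap();
        let large = Tensor::zeros((20, 20), DType::F32, &Device::Cpu).unwrap();

        stats.track_tensor("t".to_string(), &small);
        stats.track_tensor("t".to_string(), &large);

        // Only the most recent entry under the name is counted.
        assert_eq!(stats.tensor_count, 1);
        assert_eq!(stats.total_bytes(), 20 * 20 * 4);
    }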

    #[test]
    fn test_memory_stats_mb_gb() {
        let mut stats = MemoryStats::new();

        // Create a large tensor: 1000 * 1000 * f32 (4 bytes) = 4,000,000 bytes
        let tensor = Tensor::zeros((1000, 1000), DType::F32, &Device::Cpu).unwrap();
        stats.track_tensor("large_tensor".to_string(), &tensor);

        // 1024-based: 4,000,000 / (1024 * 1024) ≈ 3.814 MB
        let expected_mb = 4_000_000.0 / (1024.0 * 1024.0);
        assert!((stats.total_mb() - expected_mb).abs() < 0.01);

        let expected_gb = 4_000_000.0 / (1024.0 * 1024.0 * 1024.0);
        assert!((stats.total_gb() - expected_gb).abs() < 0.0001);
    }

    #[test]
    fn test_gpu_memory_pool() {
        let mut pool = GPUMemoryPool::new(Device::Cpu);

        let tensor = pool
            .allocate("test_tensor".to_string(), &[100, 100], DType::F32)
            .unwrap();

        assert_eq!(tensor.dims(), &[100, 100]);
        assert_eq!(pool.stats().tensor_count, 1);
        assert_eq!(pool.stats().total_bytes(), 100 * 100 * 4);

        pool.release("test_tensor");
        assert_eq!(pool.stats().tensor_count, 0);
    }

    #[test]
    fn test_get_device() {
        let tensor = Tensor::zeros((10, 10), DType::F32, &Device::Cpu).unwrap();
        let device = TensorTransfer::get_device(&tensor);
        assert!(matches!(device, Device::Cpu));
    }

    #[test]
    fn test_tensor_size_calculation() {
        let tensor_f32 = Tensor::zeros((10, 20), DType::F32, &Device::Cpu).unwrap();
        assert_eq!(MemoryStats::tensor_size(&tensor_f32), 10 * 20 * 4);

        let tensor_f16 = Tensor::zeros((10, 20), DType::F16, &Device::Cpu).unwrap();
        assert_eq!(MemoryStats::tensor_size(&tensor_f16), 10 * 20 * 2);

        let tensor_i64 = Tensor::zeros((5, 5), DType::I64, &Device::Cpu).unwrap();
        assert_eq!(MemoryStats::tensor_size(&tensor_i64), 5 * 5 * 8);
    }
}