memkit_gpu/backend.rs

//! GPU backend trait and implementations.

use crate::buffer::MkBufferUsage;
use crate::memory::MkMemoryType;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};

/// Trait for GPU backend implementations.
///
/// This trait abstracts over different GPU APIs (Vulkan, Metal, DX12).
/// Implement this trait to add support for a new graphics backend.
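///
/// A minimal staging-upload sketch, generic over any backend (crate paths,
/// usage flags, and the `upload` helper are illustrative assumptions; error
/// handling is minimal):
///
/// ```ignore
/// use memkit_gpu::backend::MkGpuBackend;
/// use memkit_gpu::buffer::MkBufferUsage;
/// use memkit_gpu::memory::MkMemoryType;
///
/// /// Uploads `data` to a device-local buffer through a host-visible staging buffer.
/// fn upload<B: MkGpuBackend>(backend: &B, data: &[u8]) -> Result<B::BufferHandle, B::Error> {
///     let staging = backend.create_buffer(
///         data.len(), MkBufferUsage::TRANSFER_SRC, MkMemoryType::HostVisible)?;
///     let device = backend.create_buffer(
///         data.len(), MkBufferUsage::TRANSFER_DST, MkMemoryType::DeviceLocal)?;
///
///     // Write through the mapped staging pointer, then publish the write.
///     if let Some(ptr) = backend.map(&staging) {
///         unsafe { std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, data.len()) };
///         backend.flush(&staging, 0, data.len());
///         backend.unmap(&staging);
///     }
///
///     backend.copy_buffer(&staging, &device, data.len())?;
///     backend.wait_idle()?;
///     backend.destroy_buffer(&staging);
///     Ok(device)
/// }
/// ```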
pub trait MkGpuBackend: Sized + Send + Sync {
    /// The native buffer handle type for this backend.
    type BufferHandle: Clone + Send + Sync;

    /// Error type for this backend.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Get backend name.
    fn name(&self) -> &'static str;

    /// Get backend capabilities.
    fn capabilities(&self) -> MkGpuCapabilities;

    /// Create a new buffer.
    fn create_buffer(
        &self,
        size: usize,
        usage: MkBufferUsage,
        memory_type: MkMemoryType,
    ) -> Result<Self::BufferHandle, Self::Error>;

    /// Destroy a buffer.
    fn destroy_buffer(&self, handle: &Self::BufferHandle);

    /// Map buffer memory for CPU access.
    /// Returns None if the buffer is not host-visible.
    fn map(&self, handle: &Self::BufferHandle) -> Option<*mut u8>;

    /// Unmap buffer memory.
    fn unmap(&self, handle: &Self::BufferHandle);

    /// Flush mapped memory range to make writes visible to GPU.
    fn flush(&self, handle: &Self::BufferHandle, offset: usize, size: usize);

    /// Invalidate mapped memory range to make GPU writes visible to CPU.
    fn invalidate(&self, handle: &Self::BufferHandle, offset: usize, size: usize);

    /// Copy data between buffers.
    fn copy_buffer(
        &self,
        src: &Self::BufferHandle,
        dst: &Self::BufferHandle,
        size: usize,
    ) -> Result<(), Self::Error>;

    /// Copy a single region between buffers at the given offsets.
    fn copy_buffer_regions(
        &self,
        src: &Self::BufferHandle,
        src_offset: usize,
        dst: &Self::BufferHandle,
        dst_offset: usize,
        size: usize,
    ) -> Result<(), Self::Error>;

    /// Wait for all operations to complete.
    fn wait_idle(&self) -> Result<(), Self::Error>;
}

/// GPU capabilities query.
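///
/// Callers can use the reported limits to validate a request up front; a
/// small sketch (`backend` and `requested_size` are placeholders):
///
/// ```ignore
/// let caps = backend.capabilities();
/// assert!(
///     requested_size <= caps.max_buffer_size,
///     "allocation of {} bytes exceeds {}'s limit of {} bytes",
///     requested_size, caps.device_name, caps.max_buffer_size,
/// );
/// ```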
#[derive(Debug, Clone, Default)]
pub struct MkGpuCapabilities {
    /// Maximum buffer size in bytes.
    pub max_buffer_size: usize,
    /// Maximum number of allocations.
    pub max_allocations: usize,
    /// Whether unified memory is supported.
    pub unified_memory: bool,
    /// Whether coherent memory is available.
    pub coherent_memory: bool,
    /// Device name.
    pub device_name: String,
    /// Vendor name.
    pub vendor_name: String,
}

// ============================================================================
// Dummy Backend - Full simulation for testing and CPU-only usage
// ============================================================================

/// Dummy backend for testing and CPU-only usage.
///
/// This backend simulates GPU operations entirely in CPU memory.
/// It's useful for:
/// - Unit testing without a GPU
/// - Running on systems without GPU support
/// - Prototyping before implementing a real backend
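///
/// A minimal sketch of direct use (crate paths are assumed; `unwrap` used for
/// brevity):
///
/// ```ignore
/// use memkit_gpu::backend::{DummyBackend, MkGpuBackend};
/// use memkit_gpu::buffer::MkBufferUsage;
/// use memkit_gpu::memory::MkMemoryType;
///
/// let backend = DummyBackend::new();
/// let handle = backend
///     .create_buffer(1024, MkBufferUsage::VERTEX, MkMemoryType::HostVisible)
///     .unwrap();
///
/// // Host-visible buffers are backed by ordinary heap memory and can be mapped.
/// let ptr = backend.map(&handle).unwrap();
/// unsafe { std::ptr::write_bytes(ptr, 0, 1024) };
/// backend.unmap(&handle);
///
/// // The backend tracks its simulated allocations.
/// assert_eq!(backend.buffer_count(), 1);
/// assert_eq!(backend.total_allocated(), 1024);
///
/// backend.destroy_buffer(&handle);
/// ```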
pub struct DummyBackend {
    next_id: AtomicU64,
    buffers: Arc<RwLock<HashMap<u64, DummyBuffer>>>,
    config: DummyBackendConfig,
}

/// Configuration for the dummy backend.
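///
/// Defaults can be overridden per field; a small sketch:
///
/// ```ignore
/// let config = DummyBackendConfig {
///     max_buffer_size: 64 * 1024 * 1024, // cap simulated buffers at 64 MB
///     transfer_delay_us: 50,             // pretend copies take 50 µs
///     ..Default::default()
/// };
/// let backend = DummyBackend::with_config(config);
/// ```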
#[derive(Debug, Clone)]
pub struct DummyBackendConfig {
    /// Maximum buffer size (default: 1 GB).
    pub max_buffer_size: usize,
    /// Simulate device-local memory (allocate but don't allow mapping).
    pub simulate_device_local: bool,
    /// Simulate transfer delays (in microseconds).
    pub transfer_delay_us: u64,
}

impl Default for DummyBackendConfig {
    fn default() -> Self {
        Self {
            max_buffer_size: 1024 * 1024 * 1024, // 1 GB
            simulate_device_local: true,
            transfer_delay_us: 0,
        }
    }
}

/// Internal buffer storage for dummy backend.
struct DummyBuffer {
    data: *mut u8,
    size: usize,
    usage: MkBufferUsage,
    memory_type: MkMemoryType,
    mapped: bool,
}

// Safety: DummyBuffer is protected by RwLock in DummyBackend
unsafe impl Send for DummyBuffer {}
unsafe impl Sync for DummyBuffer {}

impl DummyBackend {
    /// Create a new dummy backend with default config.
    pub fn new() -> Self {
        Self::with_config(DummyBackendConfig::default())
    }

    /// Create a new dummy backend with custom config.
    pub fn with_config(config: DummyBackendConfig) -> Self {
        Self {
            next_id: AtomicU64::new(1),
            buffers: Arc::new(RwLock::new(HashMap::new())),
            config,
        }
    }

    /// Get the number of allocated buffers.
    pub fn buffer_count(&self) -> usize {
        self.buffers.read().unwrap().len()
    }

    /// Get total allocated memory.
    pub fn total_allocated(&self) -> usize {
        self.buffers.read().unwrap().values().map(|b| b.size).sum()
    }
}

impl Default for DummyBackend {
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for DummyBackend {
    fn drop(&mut self) {
        // Clean up all allocated buffers
        let buffers = std::mem::take(&mut *self.buffers.write().unwrap());
        for (_, buffer) in buffers {
            if !buffer.data.is_null() {
                let layout = std::alloc::Layout::from_size_align(buffer.size, 8).unwrap();
                unsafe { std::alloc::dealloc(buffer.data, layout) };
            }
        }
    }
}

/// Dummy buffer handle.
#[derive(Clone, Debug)]
pub struct DummyBufferHandle {
    id: u64,
    size: usize,
    memory_type: MkMemoryType,
}

// Safety: DummyBufferHandle is just an ID, actual data is in DummyBackend
unsafe impl Send for DummyBufferHandle {}
unsafe impl Sync for DummyBufferHandle {}

impl DummyBufferHandle {
    /// Get the buffer size.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Get the memory type.
    pub fn memory_type(&self) -> MkMemoryType {
        self.memory_type
    }
}

/// Dummy backend error.
#[derive(Debug, Clone)]
pub enum DummyError {
    /// Buffer size exceeds maximum.
    BufferTooLarge { requested: usize, max: usize },
    /// Memory allocation failed.
    AllocationFailed,
    /// Buffer not found.
    BufferNotFound(u64),
    /// Buffer is not mappable.
    NotMappable,
    /// Buffer already mapped.
    AlreadyMapped,
    /// Other error.
    Other(String),
}

impl std::fmt::Display for DummyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DummyError::BufferTooLarge { requested, max } => {
                write!(f, "Buffer size {} exceeds maximum {}", requested, max)
            }
            DummyError::AllocationFailed => write!(f, "Memory allocation failed"),
            DummyError::BufferNotFound(id) => write!(f, "Buffer {} not found", id),
            DummyError::NotMappable => write!(f, "Buffer is not mappable (device-local)"),
            DummyError::AlreadyMapped => write!(f, "Buffer is already mapped"),
            DummyError::Other(msg) => write!(f, "{}", msg),
        }
    }
}

impl std::error::Error for DummyError {}

impl MkGpuBackend for DummyBackend {
    type BufferHandle = DummyBufferHandle;
    type Error = DummyError;

    fn name(&self) -> &'static str {
        "Dummy (CPU Simulation)"
    }

    fn capabilities(&self) -> MkGpuCapabilities {
        MkGpuCapabilities {
            max_buffer_size: self.config.max_buffer_size,
            max_allocations: usize::MAX,
            unified_memory: true,
            coherent_memory: true,
            device_name: "Dummy GPU".to_string(),
            vendor_name: "memkit".to_string(),
        }
    }

    fn create_buffer(
        &self,
        size: usize,
        usage: MkBufferUsage,
        memory_type: MkMemoryType,
    ) -> Result<Self::BufferHandle, Self::Error> {
        if size > self.config.max_buffer_size {
            return Err(DummyError::BufferTooLarge {
                requested: size,
                max: self.config.max_buffer_size,
            });
        }

        let id = self.next_id.fetch_add(1, Ordering::Relaxed);

        // Allocate memory
        let data = if size > 0 {
            let layout = std::alloc::Layout::from_size_align(size, 8)
                .map_err(|_| DummyError::AllocationFailed)?;
            let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
            if ptr.is_null() {
                return Err(DummyError::AllocationFailed);
            }
            ptr
        } else {
            std::ptr::null_mut()
        };

        let buffer = DummyBuffer {
            data,
            size,
            usage,
            memory_type,
            mapped: false,
        };

        self.buffers.write().unwrap().insert(id, buffer);

        Ok(DummyBufferHandle { id, size, memory_type })
    }

    fn destroy_buffer(&self, handle: &Self::BufferHandle) {
        if let Some(buffer) = self.buffers.write().unwrap().remove(&handle.id) {
            if !buffer.data.is_null() {
                let layout = std::alloc::Layout::from_size_align(buffer.size, 8).unwrap();
                unsafe { std::alloc::dealloc(buffer.data, layout) };
            }
        }
    }

    fn map(&self, handle: &Self::BufferHandle) -> Option<*mut u8> {
        let mut buffers = self.buffers.write().unwrap();
        let buffer = buffers.get_mut(&handle.id)?;

        // Check if mappable
        if self.config.simulate_device_local && handle.memory_type == MkMemoryType::DeviceLocal {
            return None;
        }

        // Zero-sized buffers have no backing allocation; don't hand out a null pointer.
        if buffer.data.is_null() {
            return None;
        }

        buffer.mapped = true;
        Some(buffer.data)
    }

    fn unmap(&self, handle: &Self::BufferHandle) {
        if let Some(buffer) = self.buffers.write().unwrap().get_mut(&handle.id) {
            buffer.mapped = false;
        }
    }

    fn flush(&self, _handle: &Self::BufferHandle, _offset: usize, _size: usize) {
        // No-op for dummy - memory is coherent
    }

    fn invalidate(&self, _handle: &Self::BufferHandle, _offset: usize, _size: usize) {
        // No-op for dummy - memory is coherent
    }

    fn copy_buffer(
        &self,
        src: &Self::BufferHandle,
        dst: &Self::BufferHandle,
        size: usize,
    ) -> Result<(), Self::Error> {
        self.copy_buffer_regions(src, 0, dst, 0, size)
    }

    fn copy_buffer_regions(
        &self,
        src: &Self::BufferHandle,
        src_offset: usize,
        dst: &Self::BufferHandle,
        dst_offset: usize,
        size: usize,
    ) -> Result<(), Self::Error> {
        // Simulate transfer delay
        if self.config.transfer_delay_us > 0 {
            std::thread::sleep(std::time::Duration::from_micros(self.config.transfer_delay_us));
        }

        let buffers = self.buffers.read().unwrap();

        let src_buf = buffers.get(&src.id)
            .ok_or(DummyError::BufferNotFound(src.id))?;
        let dst_buf = buffers.get(&dst.id)
            .ok_or(DummyError::BufferNotFound(dst.id))?;

        // Bounds check
        let copy_size = size
            .min(src_buf.size.saturating_sub(src_offset))
            .min(dst_buf.size.saturating_sub(dst_offset));

        if copy_size > 0 && !src_buf.data.is_null() && !dst_buf.data.is_null() {
            let src_ptr = unsafe { src_buf.data.add(src_offset) };
            let dst_ptr = unsafe { dst_buf.data.add(dst_offset) };
            if src.id == dst.id {
                // Same underlying buffer: the regions may overlap, so use the
                // memmove-style copy to stay well-defined.
                unsafe { std::ptr::copy(src_ptr, dst_ptr, copy_size) };
            } else {
                unsafe { std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, copy_size) };
            }
        }

        Ok(())
    }

    fn wait_idle(&self) -> Result<(), Self::Error> {
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dummy_backend_create_destroy() {
        let backend = DummyBackend::new();
        assert_eq!(backend.buffer_count(), 0);

        let handle = backend.create_buffer(
            1024,
            MkBufferUsage::VERTEX,
            MkMemoryType::HostVisible,
        ).unwrap();

        assert_eq!(backend.buffer_count(), 1);
        assert_eq!(backend.total_allocated(), 1024);

        backend.destroy_buffer(&handle);
        assert_eq!(backend.buffer_count(), 0);
    }

    #[test]
    fn test_dummy_backend_map_write_read() {
        let backend = DummyBackend::new();

        let handle = backend.create_buffer(
            1024,
            MkBufferUsage::VERTEX,
            MkMemoryType::HostVisible,
        ).unwrap();

        // Map and write
        let ptr = backend.map(&handle).unwrap();
        unsafe {
            for i in 0..256 {
                *ptr.add(i) = i as u8;
            }
        }
        backend.unmap(&handle);

        // Map and read
        let ptr2 = backend.map(&handle).unwrap();
        for i in 0..256 {
            assert_eq!(unsafe { *ptr2.add(i) }, i as u8);
        }

        backend.destroy_buffer(&handle);
    }

    #[test]
    fn test_dummy_backend_copy() {
        let backend = DummyBackend::new();

        let src = backend.create_buffer(256, MkBufferUsage::TRANSFER_SRC, MkMemoryType::HostVisible).unwrap();
        let dst = backend.create_buffer(256, MkBufferUsage::TRANSFER_DST, MkMemoryType::HostVisible).unwrap();

        // Write to source
        let ptr = backend.map(&src).unwrap();
        unsafe {
            for i in 0..256 {
                *ptr.add(i) = i as u8;
            }
        }
        backend.unmap(&src);

        // Copy
        backend.copy_buffer(&src, &dst, 256).unwrap();

        // Verify destination
        let ptr2 = backend.map(&dst).unwrap();
        for i in 0..256 {
            assert_eq!(unsafe { *ptr2.add(i) }, i as u8);
        }

        backend.destroy_buffer(&src);
        backend.destroy_buffer(&dst);
    }

    #[test]
    fn test_dummy_backend_device_local_not_mappable() {
        let backend = DummyBackend::new();

        let handle = backend.create_buffer(
            1024,
            MkBufferUsage::VERTEX,
            MkMemoryType::DeviceLocal,
        ).unwrap();

        // Should not be mappable
        assert!(backend.map(&handle).is_none());

        backend.destroy_buffer(&handle);
    }

    #[test]
    fn test_dummy_backend_capabilities() {
        let backend = DummyBackend::new();
        let caps = backend.capabilities();

        assert_eq!(caps.device_name, "Dummy GPU");
        assert!(caps.unified_memory);
        assert!(caps.coherent_memory);
    }

    #[test]
    fn test_dummy_backend_buffer_too_large() {
        let config = DummyBackendConfig {
            max_buffer_size: 1024,
            ..Default::default()
        };
        let backend = DummyBackend::with_config(config);

        let result = backend.create_buffer(
            2048,
            MkBufferUsage::VERTEX,
            MkMemoryType::HostVisible,
        );

        assert!(matches!(result, Err(DummyError::BufferTooLarge { .. })));
    }
}