daedalus_gpu/
convert.rs

1use crate::{GpuContextHandle, GpuError, GpuImageHandle, upload_rgba8_texture};
2use image::{DynamicImage, GenericImageView, GrayImage, RgbImage, RgbaImage};
3use std::any::Any;
4use std::sync::Arc;
5
6/// Opt-in bridge to allow CPU types to participate in GPU segments.
7/// Users implement this for their own types to describe how to upload/download.
/// Opt-in bridge to allow CPU types to participate in GPU segments.
/// Users implement this for their own types to describe how to upload/download.
pub trait GpuSendable {
    /// GPU-side representation of this type (e.g. a texture or buffer handle).
    type GpuRepr;

    /// Upload CPU data to a GPU representation.
    ///
    /// Default implementation reports `GpuError::Unsupported`, so types that
    /// never move to the GPU do not have to override it.
    fn upload(self, _ctx: &GpuContextHandle) -> Result<Self::GpuRepr, GpuError>
    where
        Self: Sized,
    {
        Err(GpuError::Unsupported)
    }

    /// Download a GPU representation back to CPU data.
    ///
    /// Default implementation reports `GpuError::Unsupported` (download not
    /// implemented for this type).
    fn download(_gpu: &Self::GpuRepr, _ctx: &GpuContextHandle) -> Result<Self, GpuError>
    where
        Self: Sized,
    {
        Err(GpuError::Unsupported)
    }
}
27
28/// Generic payload that can carry either CPU data or a GPU representation.
29#[derive(Debug, Clone)]
30pub enum Payload<T: GpuSendable> {
31    Cpu(T),
32    Gpu(T::GpuRepr),
33}
34
35impl<T: GpuSendable> Payload<T> {
36    pub fn is_gpu(&self) -> bool {
37        matches!(self, Payload::Gpu(_))
38    }
39
40    pub fn as_cpu(&self) -> Option<&T> {
41        match self {
42            Payload::Cpu(t) => Some(t),
43            _ => None,
44        }
45    }
46
47    pub fn as_gpu(&self) -> Option<&T::GpuRepr> {
48        match self {
49            Payload::Gpu(g) => Some(g),
50            _ => None,
51        }
52    }
53}
54
55impl Payload<DynamicImage> {
56    /// Return image dimensions without forcing the caller to pattern match.
57    pub fn dimensions(&self) -> (u32, u32) {
58        match self {
59            Payload::Cpu(img) => img.dimensions(),
60            Payload::Gpu(handle) => (handle.width, handle.height),
61        }
62    }
63
64    /// Ensure this payload is resident on GPU; uploads if necessary and returns the handle and dimensions.
65    pub fn into_gpu(self, ctx: &GpuContextHandle) -> Result<(GpuImageHandle, u32, u32), GpuError> {
66        match self {
67            Payload::Gpu(handle) => Ok((handle.clone(), handle.width, handle.height)),
68            Payload::Cpu(img) => {
69                let rgba = img.to_rgba8();
70                let (w, h) = rgba.dimensions();
71                let handle = upload_rgba8_texture(ctx, w, h, rgba.as_raw())?;
72                Ok((handle, w, h))
73            }
74        }
75    }
76}
77
// Storage for the type-erased value.
// NOTE(review): single-variant enum today — presumably an enum so additional
// storage kinds can be added later without reshaping the wrapper; confirm.
#[derive(Clone)]
enum ErasedPayloadInner {
    // Reference-counted, type-erased value: either the CPU value (`T`) or its
    // GPU representation (`T::GpuRepr`), depending on `ErasedPayload::is_gpu`.
    Any(Arc<dyn Any + Send + Sync>),
}
82
/// Type-erased payload wrapper so runtimes can carry GPU-capable data without monomorphizing.
#[derive(Clone)]
pub struct ErasedPayload {
    // True when `inner` holds the GPU representation; false for the CPU value.
    is_gpu: bool,
    // The erased value itself (shared via `Arc`).
    inner: ErasedPayloadInner,
    // `std::any::type_name` of the CPU type, recorded at construction and used
    // by the cross-dylib downcast fallback.
    cpu_type_name: &'static str,
    // Monomorphized transfer hooks captured at construction; they re-establish
    // the concrete type to perform the upload/download.
    upload: fn(&ErasedPayloadInner, &GpuContextHandle) -> Result<ErasedPayload, GpuError>,
    download: fn(&ErasedPayloadInner, &GpuContextHandle) -> Result<ErasedPayload, GpuError>,
}
92
93
94impl ErasedPayload {
95    fn cross_dylib_ref<'a, T: 'static>(any: &'a dyn Any, expected: &str) -> Option<&'a T> {
96        let actual = std::any::type_name::<T>();
97        if expected != actual && !expected.ends_with(actual) && !actual.ends_with(expected) {
98            return None;
99        }
100        let size_ok = std::mem::size_of_val(any) == std::mem::size_of::<T>();
101        let align_ok = std::mem::align_of_val(any) == std::mem::align_of::<T>();
102        if !(size_ok && align_ok) {
103            if std::env::var_os("DAEDALUS_TRACE_PAYLOAD_CROSS_DYLIB").is_some() {
104                eprintln!(
105                    "daedalus-gpu: cross_dylib_ref size/align mismatch expected={} actual={} size_any={} size_t={} align_any={} align_t={}",
106                    expected,
107                    actual,
108                    std::mem::size_of_val(any),
109                    std::mem::size_of::<T>(),
110                    std::mem::align_of_val(any),
111                    std::mem::align_of::<T>(),
112                );
113            }
114            return None;
115        }
116        let (data_ptr, _): (*const (), *const ()) = unsafe { std::mem::transmute(any) };
117        Some(unsafe { &*(data_ptr as *const T) })
118    }
119
120    pub fn from_cpu<T>(val: T) -> Self
121    where
122        T: GpuSendable + Clone + Send + Sync + 'static,
123        T::GpuRepr: Clone + Send + Sync + 'static,
124    {
125        fn upload<T>(inner: &ErasedPayloadInner, ctx: &GpuContextHandle) -> Result<ErasedPayload, GpuError>
126        where
127            T: GpuSendable + Clone + Send + Sync + 'static,
128            T::GpuRepr: Clone + Send + Sync + 'static,
129        {
130            let ErasedPayloadInner::Any(inner) = inner;
131            let cpu = inner
132                .downcast_ref::<T>()
133                .ok_or(GpuError::Unsupported)?
134                .clone();
135            let handle = cpu.upload(ctx)?;
136            Ok(ErasedPayload::from_gpu::<T>(handle))
137        }
138
139        fn download<T>(inner: &ErasedPayloadInner, _ctx: &GpuContextHandle) -> Result<ErasedPayload, GpuError>
140        where
141            T: GpuSendable + Clone + Send + Sync + 'static,
142            T::GpuRepr: Clone + Send + Sync + 'static,
143        {
144            let ErasedPayloadInner::Any(inner) = inner;
145            let cpu = inner
146                .downcast_ref::<T>()
147                .ok_or(GpuError::Unsupported)?
148                .clone();
149            Ok(ErasedPayload::from_cpu::<T>(cpu))
150        }
151
152        Self {
153            is_gpu: false,
154            inner: ErasedPayloadInner::Any(Arc::new(val)),
155            cpu_type_name: std::any::type_name::<T>(),
156            upload: upload::<T>,
157            download: download::<T>,
158        }
159    }
160
161    pub fn from_gpu<T>(val: T::GpuRepr) -> Self
162    where
163        T: GpuSendable + Clone + Send + Sync + 'static,
164        T::GpuRepr: Clone + Send + Sync + 'static,
165    {
166        fn upload<T>(inner: &ErasedPayloadInner, _ctx: &GpuContextHandle) -> Result<ErasedPayload, GpuError>
167        where
168            T: GpuSendable + Clone + Send + Sync + 'static,
169            T::GpuRepr: Clone + Send + Sync + 'static,
170        {
171            let ErasedPayloadInner::Any(inner) = inner;
172            let g = inner
173                .downcast_ref::<T::GpuRepr>()
174                .ok_or(GpuError::Unsupported)?
175                .clone();
176            Ok(ErasedPayload::from_gpu::<T>(g))
177        }
178
179        fn download<T>(inner: &ErasedPayloadInner, ctx: &GpuContextHandle) -> Result<ErasedPayload, GpuError>
180        where
181            T: GpuSendable + Clone + Send + Sync + 'static,
182            T::GpuRepr: Clone + Send + Sync + 'static,
183        {
184            let ErasedPayloadInner::Any(inner) = inner;
185            let g = inner
186                .downcast_ref::<T::GpuRepr>()
187                .ok_or(GpuError::Unsupported)?;
188            let cpu = T::download(g, ctx)?;
189            Ok(ErasedPayload::from_cpu::<T>(cpu))
190        }
191
192        Self {
193            is_gpu: true,
194            inner: ErasedPayloadInner::Any(Arc::new(val)),
195            cpu_type_name: std::any::type_name::<T>(),
196            upload: upload::<T>,
197            download: download::<T>,
198        }
199    }
200
201    pub fn is_gpu(&self) -> bool {
202        self.is_gpu
203    }
204
205    pub fn upload(&self, ctx: &GpuContextHandle) -> Result<ErasedPayload, GpuError> {
206        (self.upload)(&self.inner, ctx)
207    }
208
209    pub fn download(&self, ctx: &GpuContextHandle) -> Result<ErasedPayload, GpuError> {
210        (self.download)(&self.inner, ctx)
211    }
212
213    pub fn as_cpu<T>(&self) -> Option<&T>
214    where
215        T: GpuSendable + 'static,
216    {
217        if self.is_gpu {
218            None
219        } else {
220            match &self.inner {
221                ErasedPayloadInner::Any(inner) => {
222                    inner
223                        .downcast_ref::<T>()
224                        .or_else(|| Self::cross_dylib_ref::<T>(inner.as_ref(), self.cpu_type_name))
225                }
226            }
227        }
228    }
229
230    pub fn as_gpu<T>(&self) -> Option<&T::GpuRepr>
231    where
232        T: GpuSendable + 'static,
233        T::GpuRepr: 'static,
234    {
235        if self.is_gpu {
236            match &self.inner {
237                ErasedPayloadInner::Any(inner) => inner.downcast_ref::<T::GpuRepr>(),
238            }
239        } else {
240            None
241        }
242    }
243
244    pub fn try_downcast_cpu_any<T>(&self) -> Option<T>
245    where
246        T: Clone + Send + Sync + 'static,
247    {
248        if self.is_gpu {
249            return None;
250        }
251        let ErasedPayloadInner::Any(inner) = &self.inner;
252        if let Some(v) = inner.downcast_ref::<T>() {
253            return Some(v.clone());
254        }
255        Self::cross_dylib_ref::<T>(inner.as_ref(), self.cpu_type_name).cloned()
256    }
257
258    pub fn clone_cpu<T>(&self) -> Option<T>
259    where
260        T: GpuSendable + Clone + 'static,
261    {
262        if self.is_gpu {
263            return None;
264        }
265        let ErasedPayloadInner::Any(inner) = &self.inner;
266        if let Some(v) = inner.downcast_ref::<T>().cloned() {
267            return Some(v);
268        }
269        if self.cpu_type_name != std::any::type_name::<T>() {
270            return None;
271        }
272        Self::cross_dylib_ref::<T>(inner.as_ref(), self.cpu_type_name).cloned()
273    }
274
275    pub fn take_cpu<T>(self) -> Result<T, Self>
276    where
277        T: GpuSendable + Clone + Send + Sync + 'static,
278    {
279        if self.is_gpu {
280            return Err(self);
281        }
282
283        let ErasedPayload {
284            is_gpu,
285            inner,
286            cpu_type_name,
287            upload,
288            download,
289        } = self;
290        let restore = |inner| ErasedPayload {
291            is_gpu,
292            inner,
293            cpu_type_name,
294            upload,
295            download,
296        };
297
298        match inner {
299            ErasedPayloadInner::Any(inner) => match Arc::downcast::<T>(inner) {
300                Ok(arc) => match Arc::try_unwrap(arc) {
301                    Ok(v) => Ok(v),
302                    Err(arc) => Err(restore(ErasedPayloadInner::Any(arc))),
303                },
304                Err(arc) => {
305                    if let Some(v) = Self::cross_dylib_ref::<T>(arc.as_ref(), cpu_type_name) {
306                        return Ok(v.clone());
307                    }
308                    Err(restore(ErasedPayloadInner::Any(arc)))
309                }
310            },
311        }
312    }
313
314    pub fn take_cpu_any<T>(self) -> Result<T, Self>
315    where
316        T: Clone + Send + Sync + 'static,
317    {
318        if self.is_gpu {
319            return Err(self);
320        }
321
322        let ErasedPayload {
323            is_gpu,
324            inner,
325            cpu_type_name,
326            upload,
327            download,
328        } = self;
329        let restore = |inner| ErasedPayload {
330            is_gpu,
331            inner,
332            cpu_type_name,
333            upload,
334            download,
335        };
336
337        match inner {
338            ErasedPayloadInner::Any(inner) => match Arc::downcast::<T>(inner) {
339                Ok(arc) => match Arc::try_unwrap(arc) {
340                    Ok(v) => Ok(v),
341                    Err(arc) => Ok((*arc).clone()),
342                },
343                Err(arc) => {
344                    if let Some(v) = Self::cross_dylib_ref::<T>(arc.as_ref(), cpu_type_name) {
345                        return Ok(v.clone());
346                    }
347                    Err(restore(ErasedPayloadInner::Any(arc)))
348                }
349            },
350        }
351    }
352
353    pub fn clone_gpu<T>(&self) -> Option<T::GpuRepr>
354    where
355        T: GpuSendable + 'static,
356        T::GpuRepr: Clone + 'static,
357    {
358        if !self.is_gpu {
359            return None;
360        }
361        let ErasedPayloadInner::Any(inner) = &self.inner;
362        if let Some(v) = inner.downcast_ref::<T::GpuRepr>().cloned() {
363            return Some(v);
364        }
365        if self.cpu_type_name != std::any::type_name::<T>() {
366            return None;
367        }
368        None
369    }
370}
371
372impl std::fmt::Debug for ErasedPayload {
373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        f.debug_struct("ErasedPayload")
375            .field("is_gpu", &self.is_gpu)
376            .field("cpu_type_name", &self.cpu_type_name)
377            .finish()
378    }
379}
380
381impl GpuSendable for DynamicImage {
382    type GpuRepr = GpuImageHandle;
383
384    fn upload(self, ctx: &GpuContextHandle) -> Result<Self::GpuRepr, GpuError> {
385        let rgba = self.to_rgba8();
386        let (width, height) = rgba.dimensions();
387        upload_rgba8_texture(ctx, width, height, rgba.as_raw())
388    }
389
390    fn download(gpu: &Self::GpuRepr, ctx: &GpuContextHandle) -> Result<Self, GpuError> {
391        let bytes = ctx.read_texture(gpu)?;
392        let buf = image::ImageBuffer::<image::Rgba<u8>, _>::from_raw(gpu.width, gpu.height, bytes)
393            .ok_or(GpuError::AllocationFailed)?;
394        Ok(DynamicImage::ImageRgba8(buf))
395    }
396}
397
398impl GpuSendable for RgbaImage {
399    type GpuRepr = GpuImageHandle;
400
401    fn upload(self, ctx: &GpuContextHandle) -> Result<Self::GpuRepr, GpuError> {
402        let (width, height) = self.dimensions();
403        upload_rgba8_texture(ctx, width, height, self.as_raw())
404    }
405
406    fn download(gpu: &Self::GpuRepr, ctx: &GpuContextHandle) -> Result<Self, GpuError> {
407        let bytes = ctx.read_texture(gpu)?;
408        image::ImageBuffer::<image::Rgba<u8>, _>::from_raw(gpu.width, gpu.height, bytes)
409            .ok_or(GpuError::AllocationFailed)
410    }
411}
412
413impl GpuSendable for RgbImage {
414    type GpuRepr = GpuImageHandle;
415
416    fn upload(self, ctx: &GpuContextHandle) -> Result<Self::GpuRepr, GpuError> {
417        let (width, height) = self.dimensions();
418        let rgba = image::ImageBuffer::from_fn(width, height, |x, y| {
419            let p = self.get_pixel(x, y);
420            image::Rgba([p[0], p[1], p[2], 255])
421        });
422        upload_rgba8_texture(ctx, width, height, rgba.as_raw())
423    }
424
425    fn download(gpu: &Self::GpuRepr, ctx: &GpuContextHandle) -> Result<Self, GpuError> {
426        let bytes = ctx.read_texture(gpu)?;
427        let rgba = image::ImageBuffer::<image::Rgba<u8>, _>::from_raw(gpu.width, gpu.height, bytes)
428            .ok_or(GpuError::AllocationFailed)?;
429        Ok(image::ImageBuffer::from_fn(
430            gpu.width,
431            gpu.height,
432            |x, y| {
433                let p = rgba.get_pixel(x, y);
434                image::Rgb([p[0], p[1], p[2]])
435            },
436        ))
437    }
438}
439
440impl GpuSendable for GrayImage {
441    type GpuRepr = GpuImageHandle;
442
443    fn upload(self, ctx: &GpuContextHandle) -> Result<Self::GpuRepr, GpuError> {
444        let (width, height) = self.dimensions();
445        // Prefer an R8 upload when supported (reduces memory and readback bandwidth).
446        if ctx
447            .capabilities()
448            .supported_formats
449            .iter()
450            .any(|f| matches!(f, crate::GpuFormat::R8Unorm))
451        {
452            return crate::upload_r8_texture(ctx, width, height, self.as_raw());
453        }
454
455        // Fallback: expand to RGBA8.
456        let mut rgba = Vec::with_capacity(
457            (width as usize)
458                .saturating_mul(height as usize)
459                .saturating_mul(4),
460        );
461        for &v in self.as_raw() {
462            rgba.extend_from_slice(&[v, v, v, 255]);
463        }
464        upload_rgba8_texture(ctx, width, height, &rgba)
465    }
466
467    fn download(gpu: &Self::GpuRepr, ctx: &GpuContextHandle) -> Result<Self, GpuError> {
468        let bytes = ctx.read_texture(gpu)?;
469        match gpu.format {
470            crate::GpuFormat::R8Unorm => image::ImageBuffer::from_raw(gpu.width, gpu.height, bytes)
471                .ok_or(GpuError::AllocationFailed),
472            _ => {
473                let mut gray =
474                    Vec::with_capacity((gpu.width as usize).saturating_mul(gpu.height as usize));
475                for rgba in bytes.chunks_exact(4) {
476                    gray.push(rgba[0]);
477                }
478                image::ImageBuffer::from_raw(gpu.width, gpu.height, gray)
479                    .ok_or(GpuError::AllocationFailed)
480            }
481        }
482    }
483}
484
485impl Payload<DynamicImage> {
486    /// Get RGBA8 bytes + dimensions, downloading from GPU if needed.
487    pub fn to_rgba_bytes(
488        &self,
489        ctx: Option<&GpuContextHandle>,
490    ) -> Result<(Vec<u8>, u32, u32), GpuError> {
491        match self {
492            Payload::Cpu(cpu) => {
493                let rgba = cpu.to_rgba8();
494                let (w, h) = rgba.dimensions();
495                Ok((rgba.into_raw(), w, h))
496            }
497            Payload::Gpu(handle) => {
498                let ctx = ctx.ok_or(GpuError::Unsupported)?;
499                let bytes = ctx.read_texture(handle)?;
500                Ok((bytes, handle.width, handle.height))
501            }
502        }
503    }
504
505    /// Construct a payload from RGBA8 bytes, uploading if a GPU context is available.
506    pub fn from_rgba_bytes(
507        ctx: Option<&GpuContextHandle>,
508        bytes: Vec<u8>,
509        w: u32,
510        h: u32,
511    ) -> Result<Self, GpuError> {
512        if let Some(ctx) = ctx {
513            upload_rgba8_texture(ctx, w, h, &bytes).map(Payload::Gpu)
514        } else {
515            image::ImageBuffer::<image::Rgba<u8>, _>::from_raw(w, h, bytes)
516                .map(DynamicImage::ImageRgba8)
517                .map(Payload::Cpu)
518                .ok_or(GpuError::AllocationFailed)
519        }
520    }
521}