//! Texture creation and upload helpers for `easy_wgpu` (texture.rs).

1use gloss_img::DynImage;
2use image::imageops::FilterType;
3// use image::GenericImage;
4use image::{EncodableLayout, GenericImageView, ImageBuffer};
5use log::{debug, warn};
6use pollster::FutureExt;
7use std::borrow::Cow;
8use wgpu::{util::DeviceExt, CommandEncoderDescriptor, TextureFormat}; //enabled create_texture_with_data
9
10// use gloss_utils::gloss_image;
11use gloss_utils::numerical;
12
13use crate::{buffer::Buffer, mipmap::RenderMipmapGenerator};
14
15#[cfg(feature = "burn-torch")]
16use crate::error::CudaInteropError;
17#[cfg(feature = "burn-torch")]
18use cust_raw;
19
20#[cfg(feature = "burn-torch")]
21use std::sync::Arc;
22#[cfg(feature = "burn-torch")]
23use tch::Tensor;
24#[cfg(feature = "burn-torch")]
25use wgpu_cuda_interop::{vulkan_wgpu_interop::WgpuBufferCudaMem, AllocSize};
26
/// Additional parameters for texture creation; the defaults are usually
/// what you want.
#[derive(Clone, Copy)]
pub struct TexParams {
    pub sample_count: u32,
    pub mip_level_count: u32,
    pub scale_factor: u32,
}

impl Default for TexParams {
    /// Single-sampled, single-mip, unscaled texture.
    fn default() -> Self {
        TexParams {
            sample_count: 1,
            mip_level_count: 1,
            scale_factor: 1,
        }
    }
}
44impl TexParams {
45    pub fn from_desc(desc: &wgpu::TextureDescriptor) -> Self {
46        Self {
47            sample_count: desc.sample_count,
48            mip_level_count: desc.mip_level_count,
49            scale_factor: 1,
50        }
51    }
52    pub fn apply(&self, desc: &mut wgpu::TextureDescriptor) {
53        desc.sample_count = self.sample_count;
54        desc.mip_level_count = self.mip_level_count;
55    }
56}
57
/// A wgpu texture bundled with its default view, a sampler and the
/// parameters it was created with.
#[derive(Clone)]
pub struct Texture {
    pub texture: wgpu::Texture,
    // View over the texture; recreated to cover the full mip chain when
    // mipmaps are (re)generated.
    pub view: wgpu::TextureView,
    pub sampler: wgpu::Sampler, //TODO should be optional or rather we should create a nearest and linear sampler as a global per frame uniform
    // pub width: u32,
    // pub height: u32,
    // pub bind_group: Option<wgpu::BindGroup>, //cannot lazily create because it depends on the binding locations
    // Sample count, mip count and scale factor used at creation time.
    pub tex_params: TexParams,

    //optional stuff for vulkan-cuda interop
    #[cfg(feature = "burn-torch")]
    pub staging_buffer_backed_by_cuda_mem: Option<Arc<WgpuBufferCudaMem>>,
}
72
73impl Texture {
74    pub fn new(
75        device: &wgpu::Device,
76        width: u32,
77        height: u32,
78        format: wgpu::TextureFormat,
79        usage: wgpu::TextureUsages,
80        tex_params: TexParams,
81    ) -> Self {
82        debug!("New texture");
83        // let format = wgpu::TextureFormat::Rgba8UnormSrgb;
84        let mut texture_desc = wgpu::TextureDescriptor {
85            size: wgpu::Extent3d {
86                width,
87                height,
88                depth_or_array_layers: 1,
89            },
90            mip_level_count: 1,
91            sample_count: 1,
92            dimension: wgpu::TextureDimension::D2,
93            format,
94            // usage: wgpu::TextureUsages::COPY_SRC | wgpu::TextureUsages::RENDER_ATTACHMENT,
95            usage,
96            label: None,
97            view_formats: if cfg!(target_arch = "wasm32") {
98                &[]
99            } else {
100                &[format.add_srgb_suffix(), format.remove_srgb_suffix()]
101            },
102        };
103        tex_params.apply(&mut texture_desc);
104
105        let texture = device.create_texture(&texture_desc);
106        let view = texture.create_view(&wgpu::TextureViewDescriptor::default());
107        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
108            address_mode_u: wgpu::AddressMode::ClampToEdge,
109            address_mode_v: wgpu::AddressMode::ClampToEdge,
110            address_mode_w: wgpu::AddressMode::ClampToEdge,
111            mag_filter: wgpu::FilterMode::Linear,
112            min_filter: wgpu::FilterMode::Linear,
113            mipmap_filter: wgpu::FilterMode::Linear,
114            ..Default::default()
115        });
116
117        Self {
118            texture,
119            view,
120            sampler,
121            tex_params,
122            // width,
123            // height,
124            // bind_group: None,
125            #[cfg(feature = "burn-torch")]
126            staging_buffer_backed_by_cuda_mem: None,
127        }
128    }
129
130    /// # Panics
131    /// Will panic if bytes cannot be decoded into a image representation
132    pub fn from_bytes(device: &wgpu::Device, queue: &wgpu::Queue, bytes: &[u8], label: &str) -> Self {
133        let img = image::load_from_memory(bytes).unwrap();
134        Self::from_image(device, queue, &img, Some(label))
135    }
136
137    pub fn from_image(device: &wgpu::Device, queue: &wgpu::Queue, img: &image::DynamicImage, label: Option<&str>) -> Self {
138        let rgba = img.to_rgba8();
139        let dimensions = img.dimensions();
140
141        let size = wgpu::Extent3d {
142            width: dimensions.0,
143            height: dimensions.1,
144            depth_or_array_layers: 1,
145        };
146        let format = wgpu::TextureFormat::Rgba8UnormSrgb;
147        let desc = wgpu::TextureDescriptor {
148            label,
149            size,
150            mip_level_count: 1,
151            sample_count: 1,
152            dimension: wgpu::TextureDimension::D2,
153            format,
154            usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST,
155            view_formats: if cfg!(target_arch = "wasm32") {
156                &[]
157            } else {
158                &[format.add_srgb_suffix(), format.remove_srgb_suffix()]
159            },
160        };
161        let tex_params = TexParams::from_desc(&desc);
162        let texture = device.create_texture(&desc);
163
164        queue.write_texture(
165            wgpu::TexelCopyTextureInfo {
166                aspect: wgpu::TextureAspect::All,
167                texture: &texture,
168                mip_level: 0,
169                origin: wgpu::Origin3d::ZERO,
170            },
171            &rgba,
172            wgpu::TexelCopyBufferLayout {
173                offset: 0,
174                bytes_per_row: Some(4 * dimensions.0),
175                rows_per_image: Some(dimensions.1),
176            },
177            size,
178        );
179
180        let view = texture.create_view(&wgpu::TextureViewDescriptor::default());
181        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
182            address_mode_u: wgpu::AddressMode::ClampToEdge,
183            address_mode_v: wgpu::AddressMode::ClampToEdge,
184            address_mode_w: wgpu::AddressMode::ClampToEdge,
185            mag_filter: wgpu::FilterMode::Linear,
186            min_filter: wgpu::FilterMode::Nearest,
187            mipmap_filter: wgpu::FilterMode::Nearest,
188            ..Default::default()
189        });
190
191        Self {
192            texture,
193            view,
194            sampler,
195            tex_params, /* width: dimensions.0,
196                         * height: dimensions.1,
197                         * bind_group: None, */
198            #[cfg(feature = "burn-torch")]
199            staging_buffer_backed_by_cuda_mem: None,
200        }
201    }
202
203    /// reads image from format and into this texture
204    /// if `is_srgb` is set then the reading will perform a conversion from
205    /// gamma space to linear space when sampling the texture in a shader
206    /// When writing to the texture, the opposite conversion takes place.
207    /// # Panics
208    /// Will panic if the path cannot be found
209    pub fn from_path(path: &str, device: &wgpu::Device, queue: &wgpu::Queue, is_srgb: bool) -> Self {
210        //read to cpu
211        let img = image::ImageReader::open(path).unwrap().decode().unwrap();
212        Self::from_img(
213            &img.try_into().unwrap(),
214            device,
215            queue,
216            is_srgb,
217            true,
218            false, //TODO what do we set as default here?
219            None,
220            None,
221        )
222        .block_on()
223        .unwrap()
224    }
225
226    /// # Panics
227    /// Will panic if textures that have more than 1 byte per channel or more
228    /// than 4 channels.
229    #[allow(clippy::missing_errors_doc)]
230    #[allow(clippy::too_many_lines)]
231    #[allow(clippy::too_many_arguments)]
232    pub async fn from_img(
233        img: &DynImage,
234        device: &wgpu::Device,
235        queue: &wgpu::Queue,
236        is_srgb: bool,
237        generate_mipmaps: bool,
238        mipmap_generation_cpu: bool,
239        staging_buffer: Option<&Buffer>,
240        mipmaper: Option<&RenderMipmapGenerator>,
241    ) -> Result<Self, Box<dyn std::error::Error>> {
242        let dimensions = img.dimensions();
243        let nr_channels = img.color().channel_count();
244        let bytes_per_channel = img.color().bytes_per_pixel() / nr_channels;
245        assert!(bytes_per_channel == 1, "We are only supporting textures which have 1 byte per channel.");
246        //convert 3 channels to 4 channels and keep 2 channels as 2 channels
247        let img_vec;
248        let img_buf = match nr_channels {
249            1 | 2 | 4 => img.as_bytes(),
250            3 => {
251                img_vec = img.to_rgba8().into_vec();
252                img_vec.as_bytes()
253            }
254            _ => panic!("Format with more than 4 channels not supported"),
255        };
256
257        let tex_format = Self::format_from_img(img, is_srgb);
258
259        let size = wgpu::Extent3d {
260            width: dimensions.0,
261            height: dimensions.1,
262            depth_or_array_layers: 1,
263        };
264        let mut nr_mip_maps = 1;
265        let mut usages = wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST;
266        if generate_mipmaps {
267            nr_mip_maps = size.max_mips(wgpu::TextureDimension::D2);
268        }
269        if mipmaper.is_some() && generate_mipmaps {
270            usages |= RenderMipmapGenerator::required_usage();
271        }
272
273        let desc = wgpu::TextureDescriptor {
274            label: None,
275            size,
276            mip_level_count: nr_mip_maps,
277            sample_count: 1,
278            dimension: wgpu::TextureDimension::D2,
279            format: tex_format,
280            usage: usages,
281            view_formats: if cfg!(target_arch = "wasm32") {
282                &[]
283            } else {
284                &[tex_format.add_srgb_suffix(), tex_format.remove_srgb_suffix()]
285            },
286        };
287        let tex_params = TexParams::from_desc(&desc);
288
289        let texture = device.create_texture(&desc); //create with all mips but upload only 1 mip
290
291        Self::upload_single_mip(&texture, device, queue, &desc, img_buf, staging_buffer, 0).await?;
292
293        //mipmaps
294        if generate_mipmaps {
295            Self::generate_mipmaps(
296                img,
297                &texture,
298                device,
299                queue,
300                &desc,
301                nr_mip_maps,
302                mipmap_generation_cpu,
303                staging_buffer,
304                mipmaper,
305            )
306            .await?;
307        }
308
309        // let view = texture.create_view(&wgpu::TextureViewDescriptor::default());
310        let view = texture.create_view(&wgpu::TextureViewDescriptor {
311            mip_level_count: Some(nr_mip_maps),
312            ..Default::default()
313        });
314        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
315            address_mode_u: wgpu::AddressMode::ClampToEdge,
316            address_mode_v: wgpu::AddressMode::ClampToEdge,
317            address_mode_w: wgpu::AddressMode::ClampToEdge,
318            mag_filter: wgpu::FilterMode::Linear,
319            min_filter: wgpu::FilterMode::Nearest,
320            mipmap_filter: wgpu::FilterMode::Nearest,
321            ..Default::default()
322        });
323
324        Ok(Self {
325            texture,
326            view,
327            sampler,
328            tex_params, /* width: dimensions.0,
329                         * height: dimensions.1,
330                         * bind_group: None, */
331            #[cfg(feature = "burn-torch")]
332            staging_buffer_backed_by_cuda_mem: None,
333        })
334    }
335
    /// Re-uploads `img` into the existing texture and recreates the view.
    ///
    /// NOTE(review): the texture itself is NOT recreated — this assumes `img`
    /// has the same dimensions and a compatible format as the texture this
    /// struct was created with; confirm at call sites.
    /// # Panics
    /// Will panic if the image has more than 1 byte per channel
    #[allow(clippy::missing_errors_doc)]
    #[allow(clippy::too_many_arguments)]
    pub async fn update_from_img(
        &mut self,
        img: &DynImage,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        is_srgb: bool,
        generate_mipmaps: bool,
        mipmap_generation_cpu: bool,
        staging_buffer: Option<&Buffer>,
        mipmaper: Option<&RenderMipmapGenerator>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        // let dimensions = img.dimensions();
        let nr_channels = img.color().channel_count();
        let bytes_per_channel = img.color().bytes_per_pixel() / nr_channels;
        assert!(bytes_per_channel == 1, "We are only supporting textures which have 1 byte per channel.");

        // TODO refactor this into its own func because there is a lot of duplication
        // with the from_img function convert 3 channels to 4 channels and keep
        // 2 channels as 2 channels
        let img_vec;
        let img_buf = match nr_channels {
            1 | 2 | 4 => img.as_bytes(),
            3 => {
                // rgb has no wgpu format; expand to rgba
                img_vec = img.to_rgba8().into_vec();
                img_vec.as_bytes()
            }
            _ => panic!("Format with more than 4 channels not supported"),
        };

        let size = Self::extent_from_img(img);
        let tex_format = Self::format_from_img(img, is_srgb);
        let mut nr_mip_maps = 1;
        let mut usages = wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST;
        if generate_mipmaps {
            nr_mip_maps = size.max_mips(wgpu::TextureDimension::D2);
        }
        if mipmaper.is_some() && generate_mipmaps {
            usages |= RenderMipmapGenerator::required_usage();
        }

        // NOTE(review): this descriptor only describes the upload layout and
        // mip sizes for the helpers below; no new texture is created from it.
        let desc = wgpu::TextureDescriptor {
            label: None,
            size,
            mip_level_count: nr_mip_maps,
            sample_count: 1,
            dimension: wgpu::TextureDimension::D2,
            format: tex_format,
            usage: usages,
            view_formats: if cfg!(target_arch = "wasm32") {
                &[]
            } else {
                &[tex_format.add_srgb_suffix(), tex_format.remove_srgb_suffix()]
            },
        };

        Self::upload_single_mip(&self.texture, device, queue, &desc, img_buf, staging_buffer, 0).await?;

        //mipmaps
        if generate_mipmaps {
            Self::generate_mipmaps(
                img,
                &self.texture,
                device,
                queue,
                &desc,
                nr_mip_maps,
                mipmap_generation_cpu,
                staging_buffer,
                mipmaper,
            )
            .await?;
        }

        // let view = texture.create_view(&wgpu::TextureViewDescriptor::default());
        // Recreate the view so it covers the full mip chain.
        let view = self.texture.create_view(&wgpu::TextureViewDescriptor {
            mip_level_count: Some(nr_mip_maps),
            ..Default::default()
        });

        //update
        self.view = view;

        Ok(())
    }
424
425    pub fn nr_channels(&self) -> u32 {
426        self.texture.format().components().into()
427    }
428
    /// Fills mip levels `1..nr_mip_maps` of `texture`, either by CPU
    /// downscaling + per-mip upload, or by delegating to the GPU `mipmaper`.
    #[allow(clippy::too_many_arguments)]
    #[allow(clippy::missing_errors_doc)]
    pub async fn generate_mipmaps(
        img: &DynImage,
        texture: &wgpu::Texture,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        desc: &wgpu::TextureDescriptor<'_>,
        nr_mip_maps: u32,
        mipmap_generation_cpu: bool,
        staging_buffer: Option<&Buffer>,
        mipmaper: Option<&RenderMipmapGenerator>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let nr_channels = img.color().channel_count();
        if mipmap_generation_cpu {
            //CPU generation
            //similar to https://github.com/DGriffin91/bevy_mod_mipmap_generator/blob/main/src/lib.rs
            // Each level is resized from the previous level (not from the
            // original image); the 1x1 L8 placeholder is overwritten at mip 1.
            let mut img_mip = DynImage::new(1, 1, image::ColorType::L8);
            for mip_lvl in 1..nr_mip_maps {
                let mip_size = desc.mip_level_size(mip_lvl).unwrap();
                let prev_img_mip = if mip_lvl == 1 { img } else { &img_mip };
                img_mip = prev_img_mip.resize_exact(mip_size.width, mip_size.height, FilterType::Triangle);
                debug!("mip lvl {mip_lvl} has size {mip_size:?}");

                // Same 3->4 channel expansion as in `from_img` (no rgb8 wgpu format).
                let img_mip_vec;
                let img_mip_buf = match nr_channels {
                    1 | 2 | 4 => img_mip.as_bytes(),
                    3 => {
                        img_mip_vec = img_mip.to_rgba8().into_vec();
                        img_mip_vec.as_bytes()
                    }
                    _ => panic!("Format with more than 4 channels not supported"),
                };

                Self::upload_single_mip(texture, device, queue, desc, img_mip_buf, staging_buffer, mip_lvl).await?;
            }
        } else {
            //GPU mipmaps generation
            if let Some(mipmaper) = mipmaper {
                let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor::default());
                mipmaper.generate(device, &mut encoder, texture, desc).unwrap();
                queue.submit(std::iter::once(encoder.finish()));
            } else {
                // Best-effort: texture keeps only mip 0 populated in this case.
                warn!("Couldn't generate mipmaps since the mipmapper was not provided");
            }
        }

        Ok(())
    }
478
479    pub fn extent_from_img(img: &DynImage) -> wgpu::Extent3d {
480        let dimensions = img.dimensions();
481        wgpu::Extent3d {
482            width: dimensions.0,
483            height: dimensions.1,
484            depth_or_array_layers: 1,
485        }
486    }
487
488    /// # Panics
489    /// Will panic if the image has more than 1 byte per channel
490    pub fn format_from_img(img: &DynImage, is_srgb: bool) -> wgpu::TextureFormat {
491        let nr_channels = img.color().channel_count();
492        let bytes_per_channel = img.color().bytes_per_pixel() / nr_channels;
493        assert!(bytes_per_channel == 1, "We are only supporting textures which have 1 byte per channel.");
494
495        //get a format for the texture
496        let mut tex_format = match nr_channels {
497            1 => wgpu::TextureFormat::R8Unorm,
498            2 => wgpu::TextureFormat::Rg8Unorm,
499            3 | 4 => wgpu::TextureFormat::Rgba8Unorm,
500            _ => panic!("Format with more than 4 channels not supported"),
501        };
502        if is_srgb {
503            tex_format = tex_format.add_srgb_suffix();
504        }
505
506        tex_format
507    }
508
    /// Basically the same as `device.create_texture_with_data` but without the
    /// creation part and the data is assumed to contain only one mip.
    /// This is async for handling of textures on web environments.
    ///
    /// If `staging_buffer` is provided, the data goes through a manually mapped
    /// buffer followed by `copy_buffer_to_texture`; otherwise
    /// `Queue::write_texture` is used.
    ///
    /// # Panics
    /// Will panic if the data does not fit in the defined mipmaps described in
    /// textureDescriptor
    #[allow(clippy::missing_errors_doc)]
    pub async fn upload_single_mip(
        texture: &wgpu::Texture,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        desc: &wgpu::TextureDescriptor<'_>,
        data: &[u8],
        staging_buffer: Option<&Buffer>,
        mip: u32,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut mip_size = desc.mip_level_size(mip).unwrap();
        // copying layers separately
        if desc.dimension != wgpu::TextureDimension::D3 {
            mip_size.depth_or_array_layers = 1;
        }

        // Will return None only if it's a combined depth-stencil format
        // If so, default to 4, validation will fail later anyway since the depth or
        // stencil aspect needs to be written to individually
        let block_size = desc.format.block_copy_size(None).unwrap_or(4);
        let (block_width, block_height) = desc.format.block_dimensions();

        // When uploading mips of compressed textures and the mip is supposed to be
        // a size that isn't a multiple of the block size, the mip needs to be uploaded
        // as its "physical size" which is the size rounded up to the nearest block
        // size.
        let mip_physical = mip_size.physical_size(desc.format);

        // All these calculations are performed on the physical size as that's the
        // data that exists in the buffer.
        let width_blocks = mip_physical.width / block_width;
        let height_blocks = mip_physical.height / block_height;

        let bytes_per_row = width_blocks * block_size;
        // let data_size = bytes_per_row * height_blocks *
        // mip_size.depth_or_array_layers;

        // let end_offset = binary_offset + data_size as usize;

        if let Some(staging_buffer) = staging_buffer {
            warn!("Using slow CPU->GPU transfer for texture upload. Might use less memory that staging buffer using by wgpu but it will be slower.");

            //get some metadata
            let bytes_per_row_unpadded = texture.format().block_copy_size(None).unwrap() * mip_size.width;
            // rows must be aligned to COPY_BYTES_PER_ROW_ALIGNMENT for buffer->texture copies
            let bytes_per_row_padded = numerical::align(bytes_per_row_unpadded, wgpu::COPY_BYTES_PER_ROW_ALIGNMENT);

            // NOTE(review): `data` is written tightly packed below, but the copy
            // layout uses the padded bytes_per_row — rows would be misread
            // whenever unpadded != padded; confirm callers only hit widths
            // where the two coincide.

            //map buffer and copy into it
            // https://docs.rs/wgpu/latest/wgpu/struct.Buffer.html#mapping-buffers
            //the mapping range has to be aligned to COPY_BUFFER_ALIGNMENT(4 bytes)
            let slice_size = numerical::align(u32::try_from(data.len()).unwrap(), u32::try_from(wgpu::COPY_BUFFER_ALIGNMENT).unwrap());
            {
                let buffer_slice = staging_buffer.buffer.slice(0..u64::from(slice_size));
                // NOTE: We have to create the mapping THEN device.poll() before await
                // the future. Otherwise the application will freeze.
                let (tx, rx) = futures::channel::oneshot::channel();
                buffer_slice.map_async(wgpu::MapMode::Write, move |result| {
                    tx.send(result).unwrap();
                });
                let _ = device.poll(wgpu::PollType::Wait);
                rx.await.unwrap()?;
                let mut buf_data = buffer_slice.get_mapped_range_mut();

                //copy into it
                buf_data.get_mut(0..data.len()).unwrap().clone_from_slice(data);
            }

            //finish
            staging_buffer.buffer.unmap();

            //copy from buffer to texture
            let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
            encoder.copy_buffer_to_texture(
                wgpu::TexelCopyBufferInfo {
                    buffer: &staging_buffer.buffer,
                    layout: wgpu::TexelCopyBufferLayout {
                        offset: 0,
                        bytes_per_row: Some(bytes_per_row_padded),
                        rows_per_image: Some(mip_size.height),
                    },
                },
                wgpu::TexelCopyTextureInfo {
                    aspect: wgpu::TextureAspect::All,
                    texture,
                    mip_level: mip,
                    origin: wgpu::Origin3d::ZERO,
                },
                wgpu::Extent3d {
                    width: mip_size.width,
                    height: mip_size.height,
                    depth_or_array_layers: 1,
                },
            );
            queue.submit(Some(encoder.finish()));

            //wait to finish because we might be reusing the staging buffer for
            // something else later TODO maybe this is not needed
            // since the mapping will block either way if the buffer is still in
            // use device.poll(wgpu::PollType::Wait);
        } else {
            //Use wgpu write_texture which schedules internally the transfer to happen
            // later
            queue.write_texture(
                wgpu::TexelCopyTextureInfo {
                    texture,
                    mip_level: mip,
                    origin: wgpu::Origin3d { x: 0, y: 0, z: 0 },
                    aspect: wgpu::TextureAspect::All,
                },
                data,
                wgpu::TexelCopyBufferLayout {
                    offset: 0,
                    bytes_per_row: Some(bytes_per_row),
                    rows_per_image: Some(height_blocks),
                },
                mip_physical,
            );
        }

        Ok(())
    }
634
    /// Basically the same as `device.create_texture_with_data` but without the
    /// creation part Assumes the data contains info for all mips
    /// (tightly packed, layer-major then mip-major, matching the descriptor).
    /// # Panics
    /// Will panic if the data does not fit in the defined mipmaps described in
    /// textureDescriptor
    pub fn upload_all_mips(
        texture: &wgpu::Texture,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        desc: &wgpu::TextureDescriptor,
        data: &[u8],
        staging_buffer: Option<&Buffer>,
    ) {
        // Will return None only if it's a combined depth-stencil format
        // If so, default to 4, validation will fail later anyway since the depth or
        // stencil aspect needs to be written to individually
        let block_size = desc.format.block_copy_size(None).unwrap_or(4);
        let (block_width, block_height) = desc.format.block_dimensions();
        let layer_iterations = desc.array_layer_count();

        let (min_mip, max_mip) = (0, desc.mip_level_count);

        // Running offset into `data`; advanced mip by mip, layer by layer.
        let mut binary_offset = 0;
        for layer in 0..layer_iterations {
            for mip in min_mip..max_mip {
                let mut mip_size = desc.mip_level_size(mip).unwrap();
                // copying layers separately
                if desc.dimension != wgpu::TextureDimension::D3 {
                    mip_size.depth_or_array_layers = 1;
                }

                // When uploading mips of compressed textures and the mip is supposed to be
                // a size that isn't a multiple of the block size, the mip needs to be uploaded
                // as its "physical size" which is the size rounded up to the nearest block
                // size.
                let mip_physical = mip_size.physical_size(desc.format);

                // All these calculations are performed on the physical size as that's the
                // data that exists in the buffer.
                let width_blocks = mip_physical.width / block_width;
                let height_blocks = mip_physical.height / block_height;

                let bytes_per_row = width_blocks * block_size;
                let data_size = bytes_per_row * height_blocks * mip_size.depth_or_array_layers;

                let end_offset = binary_offset + data_size as usize;

                if let Some(staging_buffer) = staging_buffer {
                    warn!("Using slow CPU->GPU transfer for texture upload. Might use less memory that staging buffer using by wgpu but it will be slower.");

                    //get some metadata
                    let bytes_per_row_unpadded = texture.format().block_copy_size(None).unwrap() * mip_size.width;
                    // rows must be aligned to COPY_BYTES_PER_ROW_ALIGNMENT for buffer->texture copies
                    let bytes_per_row_padded = numerical::align(bytes_per_row_unpadded, wgpu::COPY_BYTES_PER_ROW_ALIGNMENT);

                    // NOTE(review): the mip data is written tightly packed but the
                    // copy layout uses the padded bytes_per_row — rows would be
                    // misread whenever unpadded != padded; confirm callers only
                    // hit widths where the two coincide.

                    //map buffer and copy into it
                    // https://docs.rs/wgpu/latest/wgpu/struct.Buffer.html#mapping-buffers
                    let data_to_copy = &data[binary_offset..end_offset];
                    //the mapping range has to be aligned to COPY_BUFFER_ALIGNMENT(4 bytes)
                    let slice_size = numerical::align(
                        u32::try_from(data_to_copy.len()).unwrap(),
                        u32::try_from(wgpu::COPY_BUFFER_ALIGNMENT).unwrap(),
                    );
                    {
                        let buffer_slice = staging_buffer.buffer.slice(0..u64::from(slice_size));
                        // NOTE: We have to create the mapping THEN device.poll() before await
                        // the future. Otherwise the application will freeze.
                        let (tx, rx) = futures::channel::oneshot::channel();
                        buffer_slice.map_async(wgpu::MapMode::Write, move |result| {
                            tx.send(result).unwrap();
                        });
                        let _ = device.poll(wgpu::PollType::Wait);
                        // Synchronous fn, so the map result is awaited via block_on.
                        rx.block_on().unwrap().unwrap();
                        let mut buf_data = buffer_slice.get_mapped_range_mut();

                        //copy into it
                        buf_data.get_mut(0..data_to_copy.len()).unwrap().clone_from_slice(data_to_copy);
                    }

                    //finish
                    staging_buffer.buffer.unmap();

                    //copy from buffer to texture
                    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
                    encoder.copy_buffer_to_texture(
                        wgpu::TexelCopyBufferInfo {
                            buffer: &staging_buffer.buffer,
                            layout: wgpu::TexelCopyBufferLayout {
                                offset: 0,
                                bytes_per_row: Some(bytes_per_row_padded),
                                rows_per_image: Some(mip_size.height),
                            },
                        },
                        wgpu::TexelCopyTextureInfo {
                            aspect: wgpu::TextureAspect::All,
                            texture,
                            mip_level: mip,
                            origin: wgpu::Origin3d::ZERO,
                        },
                        wgpu::Extent3d {
                            width: mip_size.width,
                            height: mip_size.height,
                            depth_or_array_layers: 1,
                        },
                    );
                    queue.submit(Some(encoder.finish()));

                    //wait to finish because we might be reusing the staging
                    // buffer for something else later
                    // TODO maybe this is not needed since the mapping will
                    // block either way if the buffer is still in use
                    // device.poll(wgpu::PollType::Wait);
                } else {
                    //Use wgpu write_texture which schedules internally the transfer to happen
                    // later
                    queue.write_texture(
                        wgpu::TexelCopyTextureInfo {
                            texture,
                            mip_level: mip,
                            origin: wgpu::Origin3d { x: 0, y: 0, z: layer },
                            aspect: wgpu::TextureAspect::All,
                        },
                        &data[binary_offset..end_offset],
                        wgpu::TexelCopyBufferLayout {
                            offset: 0,
                            bytes_per_row: Some(bytes_per_row),
                            rows_per_image: Some(height_blocks),
                        },
                        mip_physical,
                    );
                }

                binary_offset = end_offset;
            }
        }
    }
770
    /// Uploads `data` into mip level `mip_lvl` of `texture` by first mapping a
    /// persistent `staging_buffer`, memcpy-ing the bytes into it, and then
    /// recording a buffer-to-texture copy (instead of `Queue::write_texture`).
    ///
    /// Blocks until both the buffer mapping and the final copy submission have
    /// completed, because the staging buffer may be reused right after.
    ///
    /// NOTE(review): the buffer-to-texture copy uses a row pitch padded to
    /// `COPY_BYTES_PER_ROW_ALIGNMENT`, while `data` is copied verbatim into
    /// the buffer — this assumes `data` is already row-padded to that
    /// alignment (and that `staging_buffer` is at least `data.len()` bytes).
    /// Confirm against callers.
    ///
    /// # Panics
    /// Panics if `mip_lvl` is out of range for `desc`, if the mapping fails,
    /// or if the format has no well-defined block copy size.
    pub fn upload_from_cpu_with_staging_buffer(
        texture: &wgpu::Texture,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        desc: &wgpu::TextureDescriptor,
        data: &[u8],
        staging_buffer: &Buffer,
        mip_lvl: u32,
    ) {
        let mip_size = desc.mip_level_size(mip_lvl).unwrap();

        //map buffer and copy into it
        // https://docs.rs/wgpu/latest/wgpu/struct.Buffer.html#mapping-buffers
        // Scoped so the mapped range guard is dropped before unmap() below.
        {
            let buffer_slice = staging_buffer.buffer.slice(0..data.len() as u64);
            // NOTE: We have to create the mapping THEN device.poll() before await
            // the future. Otherwise the application will freeze.
            let (tx, rx) = futures::channel::oneshot::channel();
            buffer_slice.map_async(wgpu::MapMode::Write, move |result| {
                tx.send(result).unwrap();
            });
            let _ = device.poll(wgpu::PollType::Wait);
            rx.block_on().unwrap().unwrap();
            let mut buf_data = buffer_slice.get_mapped_range_mut();

            //copy into it
            buf_data.clone_from_slice(data);
        }

        //finish
        staging_buffer.buffer.unmap();

        //get some metadata
        // Rows of a buffer->texture copy must be aligned to COPY_BYTES_PER_ROW_ALIGNMENT.
        let bytes_per_row_unpadded = texture.format().block_copy_size(None).unwrap() * mip_size.width;
        let bytes_per_row_padded = numerical::align(bytes_per_row_unpadded, wgpu::COPY_BYTES_PER_ROW_ALIGNMENT);

        //copy from buffer to texture
        let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        encoder.copy_buffer_to_texture(
            wgpu::TexelCopyBufferInfo {
                buffer: &staging_buffer.buffer,
                layout: wgpu::TexelCopyBufferLayout {
                    offset: 0,
                    bytes_per_row: Some(bytes_per_row_padded),
                    rows_per_image: Some(mip_size.height),
                },
            },
            wgpu::TexelCopyTextureInfo {
                aspect: wgpu::TextureAspect::All,
                texture,
                mip_level: mip_lvl,
                origin: wgpu::Origin3d::ZERO,
            },
            wgpu::Extent3d {
                width: mip_size.width,
                height: mip_size.height,
                depth_or_array_layers: 1,
            },
        );
        queue.submit(Some(encoder.finish()));

        //wait to finish because we might be reusing the staging buffer for something
        // else later
        let _ = device.poll(wgpu::PollType::Wait);
    }
836
837    /// This functions downloads the texture to the cpu and returns a `DynImage`
838    /// The aspect of the texture to download is specified by the aspect parameter
839    /// This is required since we use texture format `Depth32FloatStencil8`
840    /// which has both a depth and a stencil component
841    pub async fn download_to_cpu(&self, device: &wgpu::Device, queue: &wgpu::Queue, aspect: wgpu::TextureAspect) -> DynImage {
842        // create buffer
843        let bytes_per_row_unpadded = self.texture.format().block_copy_size(None).unwrap_or(4) * self.width();
844        let bytes_per_row_padded = numerical::align(bytes_per_row_unpadded, wgpu::COPY_BYTES_PER_ROW_ALIGNMENT);
845        let output_buffer_size = u64::from(bytes_per_row_padded * self.height());
846        let output_buffer_desc = wgpu::BufferDescriptor {
847            size: output_buffer_size,
848            usage: wgpu::BufferUsages::COPY_DST
849        // this tells wpgu that we want to read this buffer from the cpu
850        | wgpu::BufferUsages::MAP_READ,
851            label: None,
852            mapped_at_creation: false,
853        };
854
855        let output_buffer = device.create_buffer(&output_buffer_desc);
856
857        //copy from texture to buffer
858        let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
859        encoder.copy_texture_to_buffer(
860            wgpu::TexelCopyTextureInfo {
861                aspect,
862                texture: &self.texture,
863                mip_level: 0,
864                origin: wgpu::Origin3d::ZERO,
865            },
866            wgpu::TexelCopyBufferInfo {
867                buffer: &output_buffer,
868                layout: wgpu::TexelCopyBufferLayout {
869                    offset: 0,
870                    bytes_per_row: Some(bytes_per_row_padded),
871                    rows_per_image: Some(self.height()),
872                },
873            },
874            wgpu::Extent3d {
875                width: self.width(),
876                height: self.height(),
877                depth_or_array_layers: 1,
878            },
879        );
880        queue.submit(Some(encoder.finish()));
881
882        // map and get to cpu
883        // We need to scope the mapping variables so that we can unmap the buffer
884
885        // let mut buffer = DynImage::new(self.width(), self.height(),
886        // self.texture.format()); let mut buffer = match self.texture.format()
887        // {     TextureFormat::Rgba8Unorm => DynImage::new_rgba8(self.width(),
888        // self.height()),     TextureFormat::Depth32Float =>
889        // DynImage::new_luma32f(self.width(), self.height()),     _ => panic!("
890        // Texture format not implemented!"), };
891
892        let img: Option<DynImage> = {
893            let buffer_slice = output_buffer.slice(..);
894
895            // NOTE: We have to create the mapping THEN device.poll() before await
896            // the future. Otherwise the application will freeze.
897            //TODO maybe change the future_intrusive to futures. Future_intrusive seems to
898            // give some issues on wasm
899            let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
900            buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
901                tx.send(result).unwrap();
902            });
903            let _ = device.poll(wgpu::PollType::Wait);
904            rx.receive().await.unwrap().unwrap();
905
906            let data = buffer_slice.get_mapped_range();
907
908            //TODO remove padding and copy into image
909            // https://github.com/rerun-io/rerun/blob/93146b6d04f8f494258901c8b892eee0bb31b1a8/crates/re_renderer/src/texture_info.rs#L57
910            let data_unpadded = Texture::remove_padding(data.as_bytes(), bytes_per_row_unpadded, bytes_per_row_padded, self.height());
911
912            // let copy_from = data_unpadded.as_bytes();
913            // buffer.copy_from_bytes(self.width(), self.height(), copy_from);
914            let w = self.width();
915            let h = self.height();
916            match self.texture.format() {
917                TextureFormat::Rgba8Unorm => ImageBuffer::from_raw(w, h, data_unpadded.to_vec()).map(DynImage::ImageRgba8),
918                TextureFormat::Bgra8Unorm => {
919                    let bgra_data = data_unpadded.to_vec();
920                    // Convert BGRA to RGBA by swapping channels
921                    let mut rgba_data = bgra_data.clone();
922                    for chunk in rgba_data.chunks_exact_mut(4) {
923                        chunk.swap(0, 2); // Swap B and R
924                    }
925                    ImageBuffer::from_raw(w, h, rgba_data).map(DynImage::ImageRgba8)
926                }
927                TextureFormat::Rgba32Float => ImageBuffer::from_raw(w, h, numerical::u8_to_f32_vec(&data_unpadded)).map(DynImage::ImageRgba32F),
928                TextureFormat::Depth32Float | TextureFormat::Depth32FloatStencil8 => {
929                    ImageBuffer::from_raw(w, h, numerical::u8_to_f32_vec(&data_unpadded)).map(DynImage::ImageLuma32F)
930                }
931                x => panic!("Texture format not implemented! {x:?}"),
932            }
933        };
934        output_buffer.unmap();
935        img.unwrap()
936    }
937
938    pub async fn download_pixel_to_cpu(&self, device: &wgpu::Device, queue: &wgpu::Queue, aspect: wgpu::TextureAspect, x: u32, y: u32) -> DynImage {
939        // Create a single pixel buffer to read back the selected pixel
940        let output_buffer_desc = wgpu::BufferDescriptor {
941            label: Some("ID Readback Buffer"),
942            size: 4, // 4 bytes for a single u32
943            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
944            mapped_at_creation: false,
945        };
946        let output_buffer = device.create_buffer(&output_buffer_desc);
947
948        let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
949            label: Some("ID Readback Encoder"),
950        });
951
952        // let scaled_x = x / self.tex_params.scale_factor;
953        // let scaled_y = y / self.tex_params.scale_factor;
954
955        encoder.copy_texture_to_buffer(
956            wgpu::TexelCopyTextureInfo {
957                aspect,
958                texture: &self.texture,
959                mip_level: 0,
960                origin: wgpu::Origin3d { x, y, z: 0 },
961            },
962            wgpu::TexelCopyBufferInfo {
963                buffer: &output_buffer,
964                layout: wgpu::TexelCopyBufferLayout {
965                    offset: 0,
966                    bytes_per_row: None,
967                    rows_per_image: None,
968                },
969            },
970            wgpu::Extent3d {
971                width: 1,
972                height: 1,
973                depth_or_array_layers: 1,
974            },
975        );
976
977        queue.submit(Some(encoder.finish()));
978
979        let pixel: Option<DynImage> = {
980            let buffer_slice = output_buffer.slice(..);
981
982            let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
983            buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
984                tx.send(result).unwrap();
985            });
986            let _ = device.poll(wgpu::PollType::Wait);
987            rx.receive().await.unwrap().unwrap();
988
989            let data = buffer_slice.get_mapped_range();
990            match self.texture.format() {
991                TextureFormat::Rgba8Unorm => {
992                    // This would be a ingle byte image
993                    let single_pixel_bytes = *data.to_vec().first().unwrap();
994                    ImageBuffer::from_raw(1, 1, [single_pixel_bytes].to_vec()).map(DynImage::ImageLuma8)
995                }
996                x => panic!("Texture format not implemented! {x:?}"),
997            }
998        };
999        output_buffer.unmap();
1000        pixel.unwrap()
1001    }
1002
1003    pub fn remove_padding(buffer: &[u8], bytes_per_row_unpadded: u32, bytes_per_row_padded: u32, nr_rows: u32) -> Cow<'_, [u8]> {
1004        // re_tracing::profile_function!();
1005
1006        // assert_eq!(buffer.len() as wgpu::BufferAddress, self.buffer_size_padded);
1007
1008        if bytes_per_row_padded == bytes_per_row_unpadded {
1009            return Cow::Borrowed(buffer);
1010        }
1011
1012        let mut unpadded_buffer = Vec::with_capacity((bytes_per_row_unpadded * nr_rows) as _);
1013
1014        for row in 0..nr_rows {
1015            let offset = (bytes_per_row_padded * row) as usize;
1016            unpadded_buffer.extend_from_slice(&buffer[offset..(offset + bytes_per_row_unpadded as usize)]);
1017        }
1018
1019        unpadded_buffer.into()
1020    }
1021
1022    pub fn create_bind_group_layout(device: &wgpu::Device, binding_tex: u32, binding_sampler: u32) -> wgpu::BindGroupLayout {
1023        device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1024            entries: &[
1025                wgpu::BindGroupLayoutEntry {
1026                    binding: binding_tex, //matches with the @binding in the shader
1027                    visibility: wgpu::ShaderStages::FRAGMENT,
1028                    ty: wgpu::BindingType::Texture {
1029                        multisampled: false,
1030                        view_dimension: wgpu::TextureViewDimension::D2,
1031                        sample_type: wgpu::TextureSampleType::Float { filterable: true },
1032                    },
1033                    count: None,
1034                },
1035                wgpu::BindGroupLayoutEntry {
1036                    binding: binding_sampler, //matches with the @binding in the shader
1037                    visibility: wgpu::ShaderStages::FRAGMENT,
1038                    ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
1039                    count: None,
1040                },
1041            ],
1042            label: Some("texture_bind_group_layout"),
1043        })
1044    }
1045    #[must_use]
1046    pub fn depth_linearize(&self, device: &wgpu::Device, queue: &wgpu::Queue, near: f32, far: f32) -> DynImage {
1047        //panics if depth map retrieval is attempted with MSAA sample count set to > 1
1048        assert!(
1049            !(self.texture.sample_count() > 1 && self.texture.format() == TextureFormat::Depth32Float),
1050            "InvalidSampleCount: Depth maps not supported for MSAA sample count {} (Use a config to set msaa_nr_samples as 1)",
1051            self.texture.sample_count()
1052        );
1053
1054        // This download specifically happens for a depth map so we only want to download the depth component
1055        let aspect = wgpu::TextureAspect::DepthOnly;
1056        let dynamic_img = pollster::block_on(self.download_to_cpu(device, queue, aspect));
1057        let w = dynamic_img.width();
1058        let h = dynamic_img.height();
1059        let c = dynamic_img.channels();
1060        assert!(c == 1, "Depth maps should have only 1 channel");
1061
1062        let linearized_img = match dynamic_img {
1063            DynImage::ImageLuma32F(v) => {
1064                let img_vec_ndc = v.to_vec();
1065                let img_vec: Vec<f32> = img_vec_ndc.iter().map(|&x| numerical::linearize_depth_reverse_z(x, near, far)).collect();
1066                DynImage::ImageLuma32F(ImageBuffer::from_raw(w, h, img_vec).unwrap())
1067            }
1068            _ => panic!("Texture format not implemented for remap (Only for depths)!"),
1069        };
1070        linearized_img
1071    }
1072
1073    pub fn create_bind_group(&self, device: &wgpu::Device, binding_tex: u32, binding_sampler: u32) -> wgpu::BindGroup {
1074        //create bind group
1075        //recreate the bind group
1076        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1077            layout: &Self::create_bind_group_layout(device, binding_tex, binding_sampler),
1078            entries: &[
1079                wgpu::BindGroupEntry {
1080                    binding: binding_tex,
1081                    resource: wgpu::BindingResource::TextureView(&self.view),
1082                },
1083                wgpu::BindGroupEntry {
1084                    binding: binding_sampler,
1085                    resource: wgpu::BindingResource::Sampler(&self.sampler),
1086                },
1087            ],
1088            label: Some("bind_group"),
1089        });
1090        bind_group
1091    }
1092
1093    pub fn resize(&mut self, device: &wgpu::Device, width: u32, height: u32) {
1094        //essentially creates a whole new texture with the same format and usage
1095        let format = self.texture.format();
1096        let usage = self.texture.usage();
1097        let mut new = Self::new(device, width, height, format, usage, self.tex_params);
1098        std::mem::swap(self, &mut new);
1099    }
1100
1101    //make a default 4x4 texture that can be used as a dummy texture
1102    pub fn create_default_texture(device: &wgpu::Device, queue: &wgpu::Queue) -> Self {
1103        // //read to cpu
1104        // let img = ImageReader::open(path).unwrap().decode().unwrap();
1105        // let rgba = img.to_rgba8();
1106
1107        //we make a 4x4 texture because some gbus don't allow 1x1 or 2x2 so 4x4 seems
1108        // to be the minimum allowed
1109        let width = 4;
1110        let height = 4;
1111
1112        let mut img_data: Vec<u8> = Vec::new();
1113        for _ in 0..height {
1114            for _ in 0..width {
1115                //assume 4 channels
1116                img_data.push(255);
1117                img_data.push(0);
1118                img_data.push(0);
1119                img_data.push(0);
1120            }
1121        }
1122
1123        // let rgba = img.to_rgba8();
1124        // let dimensions = img.dimensions();
1125
1126        let size = wgpu::Extent3d {
1127            width,
1128            height,
1129            depth_or_array_layers: 1,
1130        };
1131        // let format = wgpu::TextureFormat::Rgba8UnormSrgb;
1132        let format = wgpu::TextureFormat::Rgba8UnormSrgb;
1133        let desc = wgpu::TextureDescriptor {
1134            label: None,
1135            size,
1136            mip_level_count: 1,
1137            sample_count: 1,
1138            dimension: wgpu::TextureDimension::D2,
1139            format,
1140            usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST,
1141            view_formats: if cfg!(target_arch = "wasm32") {
1142                &[]
1143            } else {
1144                &[format.add_srgb_suffix(), format.remove_srgb_suffix()]
1145            },
1146        };
1147        let tex_params = TexParams::from_desc(&desc);
1148        let texture = device.create_texture_with_data(queue, &desc, wgpu::util::TextureDataOrder::LayerMajor, img_data.as_slice());
1149
1150        let view = texture.create_view(&wgpu::TextureViewDescriptor::default());
1151        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
1152            address_mode_u: wgpu::AddressMode::ClampToEdge,
1153            address_mode_v: wgpu::AddressMode::ClampToEdge,
1154            address_mode_w: wgpu::AddressMode::ClampToEdge,
1155            mag_filter: wgpu::FilterMode::Linear,
1156            min_filter: wgpu::FilterMode::Nearest,
1157            mipmap_filter: wgpu::FilterMode::Nearest,
1158            ..Default::default()
1159        });
1160
1161        Self {
1162            texture,
1163            view,
1164            sampler,
1165            tex_params, /* width,
1166                         * height, */
1167            #[cfg(feature = "burn-torch")]
1168            staging_buffer_backed_by_cuda_mem: None,
1169        }
1170    }
1171
1172    pub fn create_default_cubemap(device: &wgpu::Device, queue: &wgpu::Queue) -> Self {
1173        // //read to cpu
1174        // let img = ImageReader::open(path).unwrap().decode().unwrap();
1175        // let rgba = img.to_rgba8();
1176
1177        //we make a 4x4 texture because some gbus don't allow 1x1 or 2x2 so 4x4 seems
1178        // to be the minimum allowed
1179        let width = 4;
1180        let height = 4;
1181
1182        let mut img_data: Vec<u8> = Vec::new();
1183        for _ in 0..6 {
1184            for _ in 0..height {
1185                for _ in 0..width {
1186                    //assume 4 channels
1187                    img_data.push(255);
1188                    img_data.push(0);
1189                    img_data.push(0);
1190                    img_data.push(0);
1191                }
1192            }
1193        }
1194
1195        let size = wgpu::Extent3d {
1196            width,
1197            height,
1198            depth_or_array_layers: 6,
1199        };
1200        // let format = wgpu::TextureFormat::Rgba8UnormSrgb;
1201        let format = wgpu::TextureFormat::Rgba8UnormSrgb;
1202        let desc = wgpu::TextureDescriptor {
1203            label: None,
1204            size,
1205            mip_level_count: 1,
1206            sample_count: 1,
1207            dimension: wgpu::TextureDimension::D2,
1208            format,
1209            usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST,
1210            view_formats: if cfg!(target_arch = "wasm32") {
1211                &[]
1212            } else {
1213                &[format.add_srgb_suffix(), format.remove_srgb_suffix()]
1214            },
1215        };
1216        let tex_params = TexParams::from_desc(&desc);
1217        let texture = device.create_texture_with_data(queue, &desc, wgpu::util::TextureDataOrder::LayerMajor, img_data.as_slice());
1218
1219        let view = texture.create_view(&wgpu::TextureViewDescriptor {
1220            dimension: Some(wgpu::TextureViewDimension::Cube),
1221            ..Default::default()
1222        });
1223        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
1224            address_mode_u: wgpu::AddressMode::ClampToEdge,
1225            address_mode_v: wgpu::AddressMode::ClampToEdge,
1226            address_mode_w: wgpu::AddressMode::ClampToEdge,
1227            mag_filter: wgpu::FilterMode::Linear,
1228            min_filter: wgpu::FilterMode::Linear,
1229            mipmap_filter: wgpu::FilterMode::Linear,
1230            ..Default::default()
1231        });
1232
1233        Self {
1234            texture,
1235            view,
1236            sampler,
1237            tex_params, /* width,
1238                         * height, */
1239            #[cfg(feature = "burn-torch")]
1240            staging_buffer_backed_by_cuda_mem: None,
1241        }
1242    }
1243
    /// Width in texels of the underlying wgpu texture (mip level 0).
    pub fn width(&self) -> u32 {
        self.texture.width()
    }
    /// Height in texels of the underlying wgpu texture (mip level 0).
    pub fn height(&self) -> u32 {
        self.texture.height()
    }
    /// 2D extent of the texture at mip level 0.
    ///
    /// NOTE(review): always reports 1 layer, even for cubemap/array textures
    /// — confirm callers only use this for single-layer targets.
    pub fn extent(&self) -> wgpu::Extent3d {
        wgpu::Extent3d {
            width: self.width(),
            height: self.height(),
            depth_or_array_layers: 1,
        }
    }
1257    // pub fn clone(&self) -> Self {
1258    //     Self {
1259    //         texture: self.texture,
1260    //         view: (),
1261    //         sampler: (),
1262    //         width: (),
1263    //         height: (),
1264    //     }
1265    // }
1266
1267    #[cfg(feature = "burn-torch")]
1268    pub fn from_tensor(
1269        &mut self,
1270        tensor: &Tensor,
1271        device: &wgpu::Device,
1272        queue: &wgpu::Queue,
1273        adapter: &wgpu::Adapter,
1274    ) -> Result<(), CudaInteropError> {
1275        if tensor.dim() != 4 {
1276            return Err(CudaInteropError::InvalidTensorDim(tensor.dim() as usize));
1277        }
1278        if tensor.size()[0] != 1 {
1279            return Err(CudaInteropError::InvalidBatchSize(tensor.size()[0] as usize));
1280        }
1281        if tensor.kind() != tch::Kind::Uint8 {
1282            return Err(CudaInteropError::InvalidTensorType(tensor.kind()));
1283        }
1284
1285        let mut tensor_hwc: Tensor = tensor.permute([0, 2, 3, 1]).squeeze().contiguous();
1286        let nr_channels = tensor_hwc.size()[2] as usize;
1287        if nr_channels > 4 {
1288            return Err(CudaInteropError::InvalidChannelSize(nr_channels));
1289        }
1290
1291        //since there is no such thing as rgb textures (just R, RG and RGBA). If it has 3 channels we pad with a channel full of zeros
1292        if nr_channels == 3 {
1293            let zero_channel = Tensor::empty_like(&tensor_hwc.slice(2, 0, 1, 1));
1294            tensor_hwc = Tensor::cat(&[tensor_hwc, zero_channel], 2);
1295        }
1296        let nr_channels = tensor_hwc.size()[2] as usize;
1297
1298        //calculate the size of the image
1299        let height = tensor_hwc.size()[0] as usize;
1300        let width = tensor_hwc.size()[1] as usize;
1301        let bytes_per_channel = 1; //we assume is one byte per channel since we only allow uint8
1302        let img_size = AllocSize {
1303            height: height,
1304            width: width,
1305            stride: width * nr_channels * bytes_per_channel,
1306        };
1307
1308        //recreate the staging buffer if necessary
1309        if self.staging_buffer_backed_by_cuda_mem.is_none()
1310            || self.staging_buffer_backed_by_cuda_mem.as_ref().unwrap().cuda_mem.alloc_size != img_size
1311        {
1312            debug!("staging_buffer_backed_by_cuda_mem creating because it is none or the size is different");
1313            let wgpu_cuda = wgpu_cuda_interop::interop::create_wgpu_cuda_buffer(device, adapter, img_size, wgpu::BufferUsages::COPY_SRC);
1314            self.staging_buffer_backed_by_cuda_mem = Some(Arc::new(wgpu_cuda));
1315        }
1316
1317        //remake the texture size if necessary
1318        if self.texture.height() != height as u32 || self.texture.width() != width as u32 || self.nr_channels() != nr_channels as u32 {
1319            // let old_format= self.texture.format();
1320            // println!("old format {:?}", old_format);
1321            let new_format = match nr_channels {
1322                1 => wgpu::TextureFormat::R8Unorm,
1323                2 => wgpu::TextureFormat::Rg8Unorm,
1324                4 => wgpu::TextureFormat::Rgba8UnormSrgb,
1325                _ => panic!("Unsupported number of channels"),
1326            };
1327            let new_tex = Texture::new(device, width as u32, height as u32, new_format, self.texture.usage(), self.tex_params);
1328            //replace
1329            // println!("replacing the view!");
1330            self.texture = new_tex.texture;
1331            self.view = new_tex.view;
1332        }
1333
1334        //copy from cuda memory to the staging buffer and then to the texture
1335        let source_ptr = tensor_hwc.data_ptr() as cust_raw::CUdeviceptr;
1336        if let Some(staging_buffer) = self.staging_buffer_backed_by_cuda_mem.as_ref() {
1337            wgpu_cuda_interop::interop::cuda_img_to_wgpu(source_ptr, img_size, staging_buffer, &self.texture, device, queue);
1338        }
1339
1340        Ok(())
1341    }
1342}