Skip to main content

rgpui_wgpu/
wgpu_atlas.rs

1use anyhow::{Context as _, Result};
2use etagere::{BucketedAtlasAllocator, size2};
3use parking_lot::Mutex;
4use rgpui::collections::FxHashMap;
5use rgpui::{
6    AtlasKey, AtlasTextureId, AtlasTextureKind, AtlasTextureList, AtlasTile, Bounds, DevicePixels,
7    PlatformAtlas, Point, Size,
8};
9use std::{borrow::Cow, ops, sync::Arc};
10
11use crate::WgpuContext;
12
13fn device_size_to_etagere(size: Size<DevicePixels>) -> etagere::Size {
14    size2(size.width.0, size.height.0)
15}
16
17fn etagere_point_to_device(point: etagere::Point) -> Point<DevicePixels> {
18    Point {
19        x: DevicePixels(point.x),
20        y: DevicePixels(point.y),
21    }
22}
23
24pub struct WgpuAtlas(Mutex<WgpuAtlasState>);
25
26struct PendingUpload {
27    id: AtlasTextureId,
28    bounds: Bounds<DevicePixels>,
29    data: Vec<u8>,
30}
31
32struct WgpuAtlasState {
33    device: Arc<wgpu::Device>,
34    queue: Arc<wgpu::Queue>,
35    max_texture_size: u32,
36    color_texture_format: wgpu::TextureFormat,
37    storage: WgpuAtlasStorage,
38    tiles_by_key: FxHashMap<AtlasKey, AtlasTile>,
39    pending_uploads: Vec<PendingUpload>,
40}
41
42pub struct WgpuTextureInfo {
43    pub view: wgpu::TextureView,
44}
45
46impl WgpuAtlas {
47    pub fn new(
48        device: Arc<wgpu::Device>,
49        queue: Arc<wgpu::Queue>,
50        color_texture_format: wgpu::TextureFormat,
51    ) -> Self {
52        let max_texture_size = device.limits().max_texture_dimension_2d;
53        WgpuAtlas(Mutex::new(WgpuAtlasState {
54            device,
55            queue,
56            max_texture_size,
57            color_texture_format,
58            storage: WgpuAtlasStorage::default(),
59            tiles_by_key: Default::default(),
60            pending_uploads: Vec::new(),
61        }))
62    }
63
64    pub fn from_context(context: &WgpuContext) -> Self {
65        Self::new(
66            context.device.clone(),
67            context.queue.clone(),
68            context.color_texture_format(),
69        )
70    }
71
72    pub fn before_frame(&self) {
73        let mut lock = self.0.lock();
74        lock.flush_uploads();
75    }
76
77    pub fn get_texture_info(&self, id: AtlasTextureId) -> WgpuTextureInfo {
78        let lock = self.0.lock();
79        let texture = &lock.storage[id];
80        WgpuTextureInfo {
81            view: texture.view.clone(),
82        }
83    }
84
85    /// Clears all cached textures and tiles, forcing them to be recreated.
86    /// Use this for incremental recovery when the device is still valid.
87    pub fn clear(&self) {
88        let mut lock = self.0.lock();
89        lock.storage = WgpuAtlasStorage::default();
90        lock.tiles_by_key.clear();
91        lock.pending_uploads.clear();
92    }
93
94    /// Handles device lost by clearing all textures and cached tiles.
95    /// The atlas will lazily recreate textures as needed on subsequent frames.
96    pub fn handle_device_lost(&self, context: &WgpuContext) {
97        let mut lock = self.0.lock();
98        lock.device = context.device.clone();
99        lock.queue = context.queue.clone();
100        lock.color_texture_format = context.color_texture_format();
101        lock.storage = WgpuAtlasStorage::default();
102        lock.tiles_by_key.clear();
103        lock.pending_uploads.clear();
104    }
105}
106
107impl PlatformAtlas for WgpuAtlas {
108    fn get_or_insert_with<'a>(
109        &self,
110        key: &AtlasKey,
111        build: &mut dyn FnMut() -> Result<Option<(Size<DevicePixels>, Cow<'a, [u8]>)>>,
112    ) -> Result<Option<AtlasTile>> {
113        let mut lock = self.0.lock();
114        if let Some(tile) = lock.tiles_by_key.get(key) {
115            Ok(Some(*tile))
116        } else {
117            profiling::scope!("new tile");
118            let Some((size, bytes)) = build()? else {
119                return Ok(None);
120            };
121            let tile = lock
122                .allocate(size, key.texture_kind())
123                .context("failed to allocate")?;
124            lock.upload_texture(tile.texture_id, tile.bounds, &bytes);
125            lock.tiles_by_key.insert(key.clone(), tile);
126            Ok(Some(tile))
127        }
128    }
129
130    fn remove(&self, key: &AtlasKey) {
131        let mut lock = self.0.lock();
132
133        let Some(tile) = lock.tiles_by_key.remove(key) else {
134            return;
135        };
136        let id = tile.texture_id;
137
138        let Some(texture_slot) = lock.storage[id.kind].textures.get_mut(id.index as usize) else {
139            return;
140        };
141
142        if let Some(mut texture) = texture_slot.take() {
143            texture.allocator.deallocate(tile.tile_id.into());
144            texture.decrement_ref_count();
145            if texture.is_unreferenced() {
146                lock.pending_uploads
147                    .retain(|upload| upload.id != texture.id);
148                lock.storage[id.kind]
149                    .free_list
150                    .push(texture.id.index as usize);
151            } else {
152                *texture_slot = Some(texture);
153            }
154        }
155    }
156}
157
158impl WgpuAtlasState {
159    fn allocate(
160        &mut self,
161        size: Size<DevicePixels>,
162        texture_kind: AtlasTextureKind,
163    ) -> Option<AtlasTile> {
164        {
165            let textures = &mut self.storage[texture_kind];
166
167            if let Some(tile) = textures
168                .iter_mut()
169                .rev()
170                .find_map(|texture| texture.allocate(size))
171            {
172                return Some(tile);
173            }
174        }
175
176        let texture = self.push_texture(size, texture_kind);
177        texture.allocate(size)
178    }
179
180    fn push_texture(
181        &mut self,
182        min_size: Size<DevicePixels>,
183        kind: AtlasTextureKind,
184    ) -> &mut WgpuAtlasTexture {
185        const DEFAULT_ATLAS_SIZE: Size<DevicePixels> = Size {
186            width: DevicePixels(1024),
187            height: DevicePixels(1024),
188        };
189        let max_texture_size = self.max_texture_size as i32;
190        let max_atlas_size = Size {
191            width: DevicePixels(max_texture_size),
192            height: DevicePixels(max_texture_size),
193        };
194
195        let size = min_size.min(&max_atlas_size).max(&DEFAULT_ATLAS_SIZE);
196        let format = match kind {
197            AtlasTextureKind::Monochrome => wgpu::TextureFormat::R8Unorm,
198            AtlasTextureKind::Subpixel | AtlasTextureKind::Polychrome => self.color_texture_format,
199        };
200
201        let texture = self.device.create_texture(&wgpu::TextureDescriptor {
202            label: Some("atlas"),
203            size: wgpu::Extent3d {
204                width: size.width.0 as u32,
205                height: size.height.0 as u32,
206                depth_or_array_layers: 1,
207            },
208            mip_level_count: 1,
209            sample_count: 1,
210            dimension: wgpu::TextureDimension::D2,
211            format,
212            usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST,
213            view_formats: &[],
214        });
215
216        let view = texture.create_view(&wgpu::TextureViewDescriptor::default());
217
218        let texture_list = &mut self.storage[kind];
219        let index = texture_list.free_list.pop();
220
221        let atlas_texture = WgpuAtlasTexture {
222            id: AtlasTextureId {
223                index: index.unwrap_or(texture_list.textures.len()) as u32,
224                kind,
225            },
226            allocator: BucketedAtlasAllocator::new(device_size_to_etagere(size)),
227            format,
228            texture,
229            view,
230            live_atlas_keys: 0,
231        };
232
233        if let Some(ix) = index {
234            texture_list.textures[ix] = Some(atlas_texture);
235            texture_list
236                .textures
237                .get_mut(ix)
238                .and_then(|t| t.as_mut())
239                .expect("texture must exist")
240        } else {
241            texture_list.textures.push(Some(atlas_texture));
242            texture_list
243                .textures
244                .last_mut()
245                .and_then(|t| t.as_mut())
246                .expect("texture must exist")
247        }
248    }
249
250    fn upload_texture(&mut self, id: AtlasTextureId, bounds: Bounds<DevicePixels>, bytes: &[u8]) {
251        let data = self
252            .storage
253            .get(id)
254            .map(|texture| swizzle_upload_data(bytes, texture.format))
255            .unwrap_or_else(|| bytes.to_vec());
256
257        self.pending_uploads
258            .push(PendingUpload { id, bounds, data });
259    }
260
261    fn flush_uploads(&mut self) {
262        for upload in self.pending_uploads.drain(..) {
263            let Some(texture) = self.storage.get(upload.id) else {
264                continue;
265            };
266            let bytes_per_pixel = texture.bytes_per_pixel();
267
268            self.queue.write_texture(
269                wgpu::TexelCopyTextureInfo {
270                    texture: &texture.texture,
271                    mip_level: 0,
272                    origin: wgpu::Origin3d {
273                        x: upload.bounds.origin.x.0 as u32,
274                        y: upload.bounds.origin.y.0 as u32,
275                        z: 0,
276                    },
277                    aspect: wgpu::TextureAspect::All,
278                },
279                &upload.data,
280                wgpu::TexelCopyBufferLayout {
281                    offset: 0,
282                    bytes_per_row: Some(upload.bounds.size.width.0 as u32 * bytes_per_pixel as u32),
283                    rows_per_image: None,
284                },
285                wgpu::Extent3d {
286                    width: upload.bounds.size.width.0 as u32,
287                    height: upload.bounds.size.height.0 as u32,
288                    depth_or_array_layers: 1,
289                },
290            );
291        }
292    }
293}
294
295#[derive(Default)]
296struct WgpuAtlasStorage {
297    monochrome_textures: AtlasTextureList<WgpuAtlasTexture>,
298    subpixel_textures: AtlasTextureList<WgpuAtlasTexture>,
299    polychrome_textures: AtlasTextureList<WgpuAtlasTexture>,
300}
301
302impl ops::Index<AtlasTextureKind> for WgpuAtlasStorage {
303    type Output = AtlasTextureList<WgpuAtlasTexture>;
304    fn index(&self, kind: AtlasTextureKind) -> &Self::Output {
305        match kind {
306            AtlasTextureKind::Monochrome => &self.monochrome_textures,
307            AtlasTextureKind::Subpixel => &self.subpixel_textures,
308            AtlasTextureKind::Polychrome => &self.polychrome_textures,
309        }
310    }
311}
312
313impl ops::IndexMut<AtlasTextureKind> for WgpuAtlasStorage {
314    fn index_mut(&mut self, kind: AtlasTextureKind) -> &mut Self::Output {
315        match kind {
316            AtlasTextureKind::Monochrome => &mut self.monochrome_textures,
317            AtlasTextureKind::Subpixel => &mut self.subpixel_textures,
318            AtlasTextureKind::Polychrome => &mut self.polychrome_textures,
319        }
320    }
321}
322
323impl WgpuAtlasStorage {
324    fn get(&self, id: AtlasTextureId) -> Option<&WgpuAtlasTexture> {
325        self[id.kind]
326            .textures
327            .get(id.index as usize)
328            .and_then(|t| t.as_ref())
329    }
330}
331
332impl ops::Index<AtlasTextureId> for WgpuAtlasStorage {
333    type Output = WgpuAtlasTexture;
334    fn index(&self, id: AtlasTextureId) -> &Self::Output {
335        let textures = match id.kind {
336            AtlasTextureKind::Monochrome => &self.monochrome_textures,
337            AtlasTextureKind::Subpixel => &self.subpixel_textures,
338            AtlasTextureKind::Polychrome => &self.polychrome_textures,
339        };
340        textures[id.index as usize]
341            .as_ref()
342            .expect("texture must exist")
343    }
344}
345
346struct WgpuAtlasTexture {
347    id: AtlasTextureId,
348    allocator: BucketedAtlasAllocator,
349    texture: wgpu::Texture,
350    view: wgpu::TextureView,
351    format: wgpu::TextureFormat,
352    live_atlas_keys: u32,
353}
354
355impl WgpuAtlasTexture {
356    fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> {
357        let allocation = self.allocator.allocate(device_size_to_etagere(size))?;
358        let tile = AtlasTile {
359            texture_id: self.id,
360            tile_id: allocation.id.into(),
361            padding: 0,
362            bounds: Bounds {
363                origin: etagere_point_to_device(allocation.rectangle.min),
364                size,
365            },
366        };
367        self.live_atlas_keys += 1;
368        Some(tile)
369    }
370
371    fn bytes_per_pixel(&self) -> u8 {
372        match self.format {
373            wgpu::TextureFormat::R8Unorm => 1,
374            wgpu::TextureFormat::Bgra8Unorm | wgpu::TextureFormat::Rgba8Unorm => 4,
375            _ => 4,
376        }
377    }
378
379    fn decrement_ref_count(&mut self) {
380        self.live_atlas_keys -= 1;
381    }
382
383    fn is_unreferenced(&self) -> bool {
384        self.live_atlas_keys == 0
385    }
386}
387
388fn swizzle_upload_data(bytes: &[u8], format: wgpu::TextureFormat) -> Vec<u8> {
389    match format {
390        wgpu::TextureFormat::Rgba8Unorm => {
391            let mut data = bytes.to_vec();
392            for pixel in data.chunks_exact_mut(4) {
393                pixel.swap(0, 2);
394            }
395            data
396        }
397        _ => bytes.to_vec(),
398    }
399}
400
// Unit tests for the atlas. The atlas tests create a real `wgpu` device, so they are
// compiled out on wasm targets and fail (via `?`) when no adapter or device can be obtained.
#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
    use super::*;
    use rgpui::block_on;
    use rgpui::{ImageId, RenderImageParams};
    use std::sync::Arc;

    /// Requests a low-power adapter from any backend and opens a device/queue pair on it.
    ///
    /// Returns an error if no adapter is found or the device request fails.
    fn test_device_and_queue() -> anyhow::Result<(Arc<wgpu::Device>, Arc<wgpu::Queue>)> {
        block_on(async {
            let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
                backends: wgpu::Backends::all(),
                flags: wgpu::InstanceFlags::default(),
                backend_options: wgpu::BackendOptions::default(),
                memory_budget_thresholds: wgpu::MemoryBudgetThresholds::default(),
                display: None,
            });
            let adapter = instance
                .request_adapter(&wgpu::RequestAdapterOptions {
                    power_preference: wgpu::PowerPreference::LowPower,
                    compatible_surface: None,
                    force_fallback_adapter: false,
                })
                .await
                .map_err(|error| anyhow::anyhow!("failed to request adapter: {error}"))?;
            let (device, queue) = adapter
                .request_device(&wgpu::DeviceDescriptor {
                    label: Some("wgpu_atlas_test_device"),
                    required_features: wgpu::Features::empty(),
                    // Downlevel defaults, with resolution/alignment limits taken from the
                    // adapter.
                    required_limits: wgpu::Limits::downlevel_defaults()
                        .using_resolution(adapter.limits())
                        .using_alignment(adapter.limits()),
                    memory_hints: wgpu::MemoryHints::MemoryUsage,
                    trace: wgpu::Trace::Off,
                    experimental_features: wgpu::ExperimentalFeatures::disabled(),
                })
                .await
                .map_err(|error| anyhow::anyhow!("failed to request device: {error}"))?;
            Ok((Arc::new(device), Arc::new(queue)))
        })
    }

    /// Removing the only tile of a texture drops that texture. Its queued upload must then
    /// be discarded (or skipped) instead of being flushed against a missing texture.
    #[test]
    fn before_frame_skips_uploads_for_removed_texture() -> anyhow::Result<()> {
        let (device, queue) = test_device_and_queue()?;

        let atlas = WgpuAtlas::new(device, queue, wgpu::TextureFormat::Bgra8Unorm);
        let key = AtlasKey::Image(RenderImageParams {
            image_id: ImageId(1),
            frame_index: 0,
        });
        let size = Size {
            width: DevicePixels(1),
            height: DevicePixels(1),
        };
        // One opaque 1x1 pixel (4 bytes).
        let mut build = || Ok(Some((size, Cow::Owned(vec![0, 0, 0, 255]))));

        // Regression test: before the fix, this panicked in flush_uploads
        atlas
            .get_or_insert_with(&key, &mut build)?
            .expect("tile should be created");
        atlas.remove(&key);
        atlas.before_frame();
        Ok(())
    }

    /// Removing a tile must return its region to the texture's allocator.
    ///
    /// Both tiles are expected to land in one default-sized (1024x1024) texture, assuming
    /// the device limit is at least 1024. Two 700x700 tiles cannot coexist there, so
    /// `tile_b` can only reuse the keeper's texture if `tile_a`'s space was deallocated.
    #[test]
    fn remove_deallocates_tile_space_for_reuse() -> anyhow::Result<()> {
        let (device, queue) = test_device_and_queue()?;
        let atlas = WgpuAtlas::new(device, queue, wgpu::TextureFormat::Bgra8Unorm);

        let small = Size {
            width: DevicePixels(64),
            height: DevicePixels(64),
        };
        let big = Size {
            width: DevicePixels(700),
            height: DevicePixels(700),
        };

        let make_key = |image_id: usize| {
            AtlasKey::Image(RenderImageParams {
                image_id: ImageId(image_id),
                frame_index: 0,
            })
        };
        // Inserts a zero-filled 4-bytes-per-pixel image of `size` under `key`.
        let insert = |key: &AtlasKey, size: Size<DevicePixels>| {
            let byte_count = (size.width.0 as usize) * (size.height.0 as usize) * 4;
            atlas
                .get_or_insert_with(key, &mut || {
                    Ok(Some((size, Cow::Owned(vec![0u8; byte_count]))))
                })
                .expect("allocation should succeed")
                .expect("callback returns Some")
        };

        let keeper_key = make_key(1);
        let big_key_a = make_key(2);
        let big_key_b = make_key(3);

        // The keeper tile holds the texture alive after `big_key_a` is removed.
        let keeper_tile = insert(&keeper_key, small);
        let tile_a = insert(&big_key_a, big);
        assert_eq!(keeper_tile.texture_id, tile_a.texture_id);

        atlas.remove(&big_key_a);
        let tile_b = insert(&big_key_b, big);
        assert_eq!(tile_b.texture_id, keeper_tile.texture_id);
        Ok(())
    }

    /// BGRA targets receive the bytes unchanged.
    #[test]
    fn swizzle_upload_data_preserves_bgra_uploads() {
        let input = vec![0x10, 0x20, 0x30, 0x40];
        assert_eq!(
            swizzle_upload_data(&input, wgpu::TextureFormat::Bgra8Unorm),
            input
        );
    }

    /// RGBA targets get bytes 0 and 2 of every pixel swapped.
    #[test]
    fn swizzle_upload_data_converts_bgra_to_rgba() {
        let input = vec![0x10, 0x20, 0x30, 0x40, 0xAA, 0xBB, 0xCC, 0xDD];
        assert_eq!(
            swizzle_upload_data(&input, wgpu::TextureFormat::Rgba8Unorm),
            vec![0x30, 0x20, 0x10, 0x40, 0xCC, 0xBB, 0xAA, 0xDD]
        );
    }
}