Skip to main content

astrelis_render/
sampler_cache.rs

1//! Sampler cache for efficient GPU sampler reuse.
2//!
3//! Creating GPU samplers is expensive. This module provides a cache
4//! that reuses samplers with identical descriptors.
5
6use ahash::HashMap;
7use astrelis_core::profiling::profile_function;
8use std::hash::{Hash, Hasher};
9use std::sync::{Arc, RwLock};
10
11/// A hashable key for sampler descriptors.
12///
13/// wgpu::SamplerDescriptor doesn't implement Hash, so we need this wrapper.
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct SamplerKey {
16    /// Address mode for U coordinate
17    pub address_mode_u: wgpu::AddressMode,
18    /// Address mode for V coordinate
19    pub address_mode_v: wgpu::AddressMode,
20    /// Address mode for W coordinate
21    pub address_mode_w: wgpu::AddressMode,
22    /// Magnification filter
23    pub mag_filter: wgpu::FilterMode,
24    /// Minification filter
25    pub min_filter: wgpu::FilterMode,
26    /// Mipmap filter
27    pub mipmap_filter: wgpu::FilterMode,
28    /// Minimum LOD clamp
29    pub lod_min_clamp: u32, // f32 bits
30    /// Maximum LOD clamp
31    pub lod_max_clamp: u32, // f32 bits
32    /// Comparison function (if any)
33    pub compare: Option<wgpu::CompareFunction>,
34    /// Anisotropy clamp (1-16)
35    pub anisotropy_clamp: u16,
36    /// Border color (for ClampToBorder address mode)
37    pub border_color: Option<wgpu::SamplerBorderColor>,
38}
39
40impl Hash for SamplerKey {
41    fn hash<H: Hasher>(&self, state: &mut H) {
42        self.address_mode_u.hash(state);
43        self.address_mode_v.hash(state);
44        self.address_mode_w.hash(state);
45        self.mag_filter.hash(state);
46        self.min_filter.hash(state);
47        self.mipmap_filter.hash(state);
48        self.lod_min_clamp.hash(state);
49        self.lod_max_clamp.hash(state);
50        self.compare.hash(state);
51        self.anisotropy_clamp.hash(state);
52        self.border_color.hash(state);
53    }
54}
55
56impl SamplerKey {
57    /// Create a key for a repeating nearest (point) sampler.
58    pub fn nearest_repeat() -> Self {
59        Self {
60            address_mode_u: wgpu::AddressMode::Repeat,
61            address_mode_v: wgpu::AddressMode::Repeat,
62            address_mode_w: wgpu::AddressMode::Repeat,
63            mag_filter: wgpu::FilterMode::Nearest,
64            min_filter: wgpu::FilterMode::Nearest,
65            mipmap_filter: wgpu::FilterMode::Nearest,
66            lod_min_clamp: 0.0f32.to_bits(),
67            lod_max_clamp: f32::MAX.to_bits(),
68            compare: None,
69            anisotropy_clamp: 1,
70            border_color: None,
71        }
72    }
73
74    /// Create a key for a mirrored linear sampler.
75    pub fn linear_mirror() -> Self {
76        Self {
77            address_mode_u: wgpu::AddressMode::MirrorRepeat,
78            address_mode_v: wgpu::AddressMode::MirrorRepeat,
79            address_mode_w: wgpu::AddressMode::MirrorRepeat,
80            mag_filter: wgpu::FilterMode::Linear,
81            min_filter: wgpu::FilterMode::Linear,
82            mipmap_filter: wgpu::FilterMode::Linear,
83            lod_min_clamp: 0.0f32.to_bits(),
84            lod_max_clamp: f32::MAX.to_bits(),
85            compare: None,
86            anisotropy_clamp: 1,
87            border_color: None,
88        }
89    }
90
91    /// Create a key for a mirrored nearest (point) sampler.
92    pub fn nearest_mirror() -> Self {
93        Self {
94            address_mode_u: wgpu::AddressMode::MirrorRepeat,
95            address_mode_v: wgpu::AddressMode::MirrorRepeat,
96            address_mode_w: wgpu::AddressMode::MirrorRepeat,
97            mag_filter: wgpu::FilterMode::Nearest,
98            min_filter: wgpu::FilterMode::Nearest,
99            mipmap_filter: wgpu::FilterMode::Nearest,
100            lod_min_clamp: 0.0f32.to_bits(),
101            lod_max_clamp: f32::MAX.to_bits(),
102            compare: None,
103            anisotropy_clamp: 1,
104            border_color: None,
105        }
106    }
107
108    /// Create a key from a sampler descriptor.
109    pub fn from_descriptor(desc: &wgpu::SamplerDescriptor) -> Self {
110        Self {
111            address_mode_u: desc.address_mode_u,
112            address_mode_v: desc.address_mode_v,
113            address_mode_w: desc.address_mode_w,
114            mag_filter: desc.mag_filter,
115            min_filter: desc.min_filter,
116            mipmap_filter: desc.mipmap_filter,
117            lod_min_clamp: desc.lod_min_clamp.to_bits(),
118            lod_max_clamp: desc.lod_max_clamp.to_bits(),
119            compare: desc.compare,
120            anisotropy_clamp: desc.anisotropy_clamp,
121            border_color: desc.border_color,
122        }
123    }
124
125    /// Create a descriptor from this key.
126    pub fn to_descriptor<'a>(&self, label: Option<&'a str>) -> wgpu::SamplerDescriptor<'a> {
127        wgpu::SamplerDescriptor {
128            label,
129            address_mode_u: self.address_mode_u,
130            address_mode_v: self.address_mode_v,
131            address_mode_w: self.address_mode_w,
132            mag_filter: self.mag_filter,
133            min_filter: self.min_filter,
134            mipmap_filter: self.mipmap_filter,
135            lod_min_clamp: f32::from_bits(self.lod_min_clamp),
136            lod_max_clamp: f32::from_bits(self.lod_max_clamp),
137            compare: self.compare,
138            anisotropy_clamp: self.anisotropy_clamp,
139            border_color: self.border_color,
140        }
141    }
142
143    /// Create a key for a default linear sampler.
144    pub fn linear() -> Self {
145        Self {
146            address_mode_u: wgpu::AddressMode::ClampToEdge,
147            address_mode_v: wgpu::AddressMode::ClampToEdge,
148            address_mode_w: wgpu::AddressMode::ClampToEdge,
149            mag_filter: wgpu::FilterMode::Linear,
150            min_filter: wgpu::FilterMode::Linear,
151            mipmap_filter: wgpu::FilterMode::Linear,
152            lod_min_clamp: 0.0f32.to_bits(),
153            lod_max_clamp: f32::MAX.to_bits(),
154            compare: None,
155            anisotropy_clamp: 1,
156            border_color: None,
157        }
158    }
159
160    /// Create a key for a default nearest (point) sampler.
161    pub fn nearest() -> Self {
162        Self {
163            address_mode_u: wgpu::AddressMode::ClampToEdge,
164            address_mode_v: wgpu::AddressMode::ClampToEdge,
165            address_mode_w: wgpu::AddressMode::ClampToEdge,
166            mag_filter: wgpu::FilterMode::Nearest,
167            min_filter: wgpu::FilterMode::Nearest,
168            mipmap_filter: wgpu::FilterMode::Nearest,
169            lod_min_clamp: 0.0f32.to_bits(),
170            lod_max_clamp: f32::MAX.to_bits(),
171            compare: None,
172            anisotropy_clamp: 1,
173            border_color: None,
174        }
175    }
176
177    /// Create a key for a repeating linear sampler.
178    pub fn linear_repeat() -> Self {
179        Self {
180            address_mode_u: wgpu::AddressMode::Repeat,
181            address_mode_v: wgpu::AddressMode::Repeat,
182            address_mode_w: wgpu::AddressMode::Repeat,
183            mag_filter: wgpu::FilterMode::Linear,
184            min_filter: wgpu::FilterMode::Linear,
185            mipmap_filter: wgpu::FilterMode::Linear,
186            lod_min_clamp: 0.0f32.to_bits(),
187            lod_max_clamp: f32::MAX.to_bits(),
188            compare: None,
189            anisotropy_clamp: 1,
190            border_color: None,
191        }
192    }
193}
194
/// Common texture-sampling presets for images.
///
/// A convenience selector over the most frequently used sampler setups. Each variant
/// corresponds to one `SamplerKey` preset, which is what the sampler cache is keyed on.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum ImageSampling {
    /// Smooth bilinear filtering. This is the default and suits photos and gradients.
    #[default]
    Linear,
    /// Nearest-neighbor filtering that keeps pixels crisp, for pixel art.
    Nearest,
    /// Linear filtering with repeating (wrapping) UVs, for tiled textures.
    LinearRepeat,
    /// Nearest filtering with repeating (wrapping) UVs, for tiled pixel art.
    NearestRepeat,
    /// Linear filtering with mirrored UV wrapping.
    LinearMirror,
    /// Nearest filtering with mirrored UV wrapping.
    NearestMirror,
}
215
216impl ImageSampling {
217    /// Convert to a SamplerKey for cache lookup.
218    pub fn to_sampler_key(&self) -> SamplerKey {
219        match self {
220            Self::Linear => SamplerKey::linear(),
221            Self::Nearest => SamplerKey::nearest(),
222            Self::LinearRepeat => SamplerKey::linear_repeat(),
223            Self::NearestRepeat => SamplerKey::nearest_repeat(),
224            Self::LinearMirror => SamplerKey::linear_mirror(),
225            Self::NearestMirror => SamplerKey::nearest_mirror(),
226        }
227    }
228}
229
230/// A thread-safe cache of GPU samplers.
231///
232/// This cache ensures that identical sampler configurations share the same
233/// GPU sampler object, reducing memory usage and creation overhead.
234///
235/// # Example
236///
237/// ```ignore
238/// use astrelis_render::{SamplerCache, SamplerKey};
239///
240/// let cache = SamplerCache::new();
241///
242/// // Get or create a linear sampler
243/// let sampler = cache.get_or_create(&device, SamplerKey::linear());
244///
245/// // The same sampler is returned for identical keys
246/// let sampler2 = cache.get_or_create(&device, SamplerKey::linear());
247/// // sampler and sampler2 point to the same GPU sampler
248/// ```
249pub struct SamplerCache {
250    cache: RwLock<HashMap<SamplerKey, Arc<wgpu::Sampler>>>,
251}
252
253impl Default for SamplerCache {
254    fn default() -> Self {
255        Self::new()
256    }
257}
258
259impl SamplerCache {
260    /// Create a new empty sampler cache.
261    pub fn new() -> Self {
262        Self {
263            cache: RwLock::new(HashMap::default()),
264        }
265    }
266
267    /// Get a sampler from the cache or create a new one.
268    ///
269    /// If a sampler with the given key already exists in the cache,
270    /// it is returned. Otherwise, a new sampler is created and cached.
271    ///
272    /// # Panics
273    /// Panics if the internal RwLock is poisoned (another thread panicked while holding the lock).
274    pub fn get_or_create(&self, device: &wgpu::Device, key: SamplerKey) -> Arc<wgpu::Sampler> {
275        profile_function!();
276        // Try read lock first (fast path)
277        {
278            let cache = self
279                .cache
280                .read()
281                .expect("SamplerCache lock poisoned - a thread panicked while accessing the cache");
282            if let Some(sampler) = cache.get(&key) {
283                return Arc::clone(sampler);
284            }
285        }
286
287        // Slow path: create sampler and insert
288        let mut cache = self
289            .cache
290            .write()
291            .expect("SamplerCache lock poisoned - a thread panicked while accessing the cache");
292
293        // Double-check in case another thread inserted while we waited
294        if let Some(sampler) = cache.get(&key) {
295            return Arc::clone(sampler);
296        }
297
298        // Create the sampler
299        let descriptor = key.to_descriptor(Some("Cached Sampler"));
300        let sampler = Arc::new(device.create_sampler(&descriptor));
301        cache.insert(key, Arc::clone(&sampler));
302        sampler
303    }
304
305    /// Get a sampler from the cache or create one using a descriptor.
306    ///
307    /// This is a convenience method that converts the descriptor to a key.
308    pub fn get_or_create_from_descriptor(
309        &self,
310        device: &wgpu::Device,
311        descriptor: &wgpu::SamplerDescriptor,
312    ) -> Arc<wgpu::Sampler> {
313        let key = SamplerKey::from_descriptor(descriptor);
314        self.get_or_create(device, key)
315    }
316
317    /// Get a default linear sampler.
318    pub fn linear(&self, device: &wgpu::Device) -> Arc<wgpu::Sampler> {
319        self.get_or_create(device, SamplerKey::linear())
320    }
321
322    /// Get a default nearest (point) sampler.
323    pub fn nearest(&self, device: &wgpu::Device) -> Arc<wgpu::Sampler> {
324        self.get_or_create(device, SamplerKey::nearest())
325    }
326
327    /// Get a repeating linear sampler.
328    pub fn linear_repeat(&self, device: &wgpu::Device) -> Arc<wgpu::Sampler> {
329        self.get_or_create(device, SamplerKey::linear_repeat())
330    }
331
332    /// Get a repeating nearest sampler.
333    pub fn nearest_repeat(&self, device: &wgpu::Device) -> Arc<wgpu::Sampler> {
334        self.get_or_create(device, SamplerKey::nearest_repeat())
335    }
336
337    /// Get a mirrored linear sampler.
338    pub fn linear_mirror(&self, device: &wgpu::Device) -> Arc<wgpu::Sampler> {
339        self.get_or_create(device, SamplerKey::linear_mirror())
340    }
341
342    /// Get a mirrored nearest sampler.
343    pub fn nearest_mirror(&self, device: &wgpu::Device) -> Arc<wgpu::Sampler> {
344        self.get_or_create(device, SamplerKey::nearest_mirror())
345    }
346
347    /// Get a sampler for the given sampling mode.
348    pub fn from_sampling(
349        &self,
350        device: &wgpu::Device,
351        sampling: ImageSampling,
352    ) -> Arc<wgpu::Sampler> {
353        self.get_or_create(device, sampling.to_sampler_key())
354    }
355
356    /// Get the number of cached samplers.
357    ///
358    /// # Panics
359    /// Panics if the internal RwLock is poisoned.
360    pub fn len(&self) -> usize {
361        self.cache.read().expect("SamplerCache lock poisoned").len()
362    }
363
364    /// Check if the cache is empty.
365    ///
366    /// # Panics
367    /// Panics if the internal RwLock is poisoned.
368    pub fn is_empty(&self) -> bool {
369        self.cache
370            .read()
371            .expect("SamplerCache lock poisoned")
372            .is_empty()
373    }
374
375    /// Clear the cache, releasing all cached samplers.
376    ///
377    /// This should only be called when you're sure no references to
378    /// cached samplers are still in use.
379    ///
380    /// # Panics
381    /// Panics if the internal RwLock is poisoned.
382    pub fn clear(&self) {
383        self.cache
384            .write()
385            .expect("SamplerCache lock poisoned")
386            .clear();
387    }
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Every built-in preset, labelled for assertion messages.
    fn presets() -> [(&'static str, SamplerKey); 6] {
        [
            ("linear", SamplerKey::linear()),
            ("nearest", SamplerKey::nearest()),
            ("linear_repeat", SamplerKey::linear_repeat()),
            ("nearest_repeat", SamplerKey::nearest_repeat()),
            ("linear_mirror", SamplerKey::linear_mirror()),
            ("nearest_mirror", SamplerKey::nearest_mirror()),
        ]
    }

    #[test]
    fn test_sampler_key_hash_equality() {
        let key1 = SamplerKey::linear();
        let key2 = SamplerKey::linear();
        let key3 = SamplerKey::nearest();

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);

        // Equal keys must produce equal hashes.
        use std::collections::hash_map::DefaultHasher;
        let mut hasher1 = DefaultHasher::new();
        let mut hasher2 = DefaultHasher::new();
        key1.hash(&mut hasher1);
        key2.hash(&mut hasher2);
        assert_eq!(hasher1.finish(), hasher2.finish());
    }

    #[test]
    fn test_sampler_key_roundtrip() {
        let key = SamplerKey::linear();
        let desc = key.to_descriptor(Some("Test"));
        let key2 = SamplerKey::from_descriptor(&desc);
        assert_eq!(key, key2);
    }

    /// Every preset, not just `linear`, must survive key -> descriptor -> key.
    #[test]
    fn test_all_presets_roundtrip_through_descriptor() {
        for (name, key) in presets() {
            let desc = key.to_descriptor(None);
            assert_eq!(
                SamplerKey::from_descriptor(&desc),
                key,
                "preset `{name}` did not round-trip"
            );
        }
    }

    /// Distinct presets must not collapse onto the same cache entry.
    #[test]
    fn test_presets_are_pairwise_distinct() {
        let distinct: std::collections::HashSet<SamplerKey> =
            presets().iter().map(|(_, key)| key.clone()).collect();
        assert_eq!(distinct.len(), presets().len());
    }

    /// Each `ImageSampling` variant maps to its matching `SamplerKey` preset.
    #[test]
    fn test_image_sampling_maps_to_matching_preset() {
        let cases = [
            (ImageSampling::Linear, SamplerKey::linear()),
            (ImageSampling::Nearest, SamplerKey::nearest()),
            (ImageSampling::LinearRepeat, SamplerKey::linear_repeat()),
            (ImageSampling::NearestRepeat, SamplerKey::nearest_repeat()),
            (ImageSampling::LinearMirror, SamplerKey::linear_mirror()),
            (ImageSampling::NearestMirror, SamplerKey::nearest_mirror()),
        ];
        for (mode, expected) in &cases {
            assert_eq!(&mode.to_sampler_key(), expected, "{mode:?}");
        }
        assert_eq!(ImageSampling::default(), ImageSampling::Linear);
    }
}