Skip to main content

astrelis_render/
sampler_cache.rs

//! Sampler cache for efficient GPU sampler reuse.
//!
//! Creating GPU samplers is expensive. This module provides a cache
//! that hands out one shared sampler for every request with an identical configuration.
5
6use astrelis_core::profiling::profile_function;
7use ahash::HashMap;
8use std::hash::{Hash, Hasher};
9use std::sync::{Arc, RwLock};
10
11/// A hashable key for sampler descriptors.
12///
13/// wgpu::SamplerDescriptor doesn't implement Hash, so we need this wrapper.
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct SamplerKey {
16    /// Address mode for U coordinate
17    pub address_mode_u: wgpu::AddressMode,
18    /// Address mode for V coordinate
19    pub address_mode_v: wgpu::AddressMode,
20    /// Address mode for W coordinate
21    pub address_mode_w: wgpu::AddressMode,
22    /// Magnification filter
23    pub mag_filter: wgpu::FilterMode,
24    /// Minification filter
25    pub min_filter: wgpu::FilterMode,
26    /// Mipmap filter
27    pub mipmap_filter: wgpu::FilterMode,
28    /// Minimum LOD clamp
29    pub lod_min_clamp: u32, // f32 bits
30    /// Maximum LOD clamp
31    pub lod_max_clamp: u32, // f32 bits
32    /// Comparison function (if any)
33    pub compare: Option<wgpu::CompareFunction>,
34    /// Anisotropy clamp (1-16)
35    pub anisotropy_clamp: u16,
36    /// Border color (for ClampToBorder address mode)
37    pub border_color: Option<wgpu::SamplerBorderColor>,
38}
39
40impl Hash for SamplerKey {
41    fn hash<H: Hasher>(&self, state: &mut H) {
42        self.address_mode_u.hash(state);
43        self.address_mode_v.hash(state);
44        self.address_mode_w.hash(state);
45        self.mag_filter.hash(state);
46        self.min_filter.hash(state);
47        self.mipmap_filter.hash(state);
48        self.lod_min_clamp.hash(state);
49        self.lod_max_clamp.hash(state);
50        self.compare.hash(state);
51        self.anisotropy_clamp.hash(state);
52        self.border_color.hash(state);
53    }
54}
55
56impl SamplerKey {
57    /// Create a key for a repeating nearest (point) sampler.
58    pub fn nearest_repeat() -> Self {
59        Self {
60            address_mode_u: wgpu::AddressMode::Repeat,
61            address_mode_v: wgpu::AddressMode::Repeat,
62            address_mode_w: wgpu::AddressMode::Repeat,
63            mag_filter: wgpu::FilterMode::Nearest,
64            min_filter: wgpu::FilterMode::Nearest,
65            mipmap_filter: wgpu::FilterMode::Nearest,
66            lod_min_clamp: 0.0f32.to_bits(),
67            lod_max_clamp: f32::MAX.to_bits(),
68            compare: None,
69            anisotropy_clamp: 1,
70            border_color: None,
71        }
72    }
73
74    /// Create a key for a mirrored linear sampler.
75    pub fn linear_mirror() -> Self {
76        Self {
77            address_mode_u: wgpu::AddressMode::MirrorRepeat,
78            address_mode_v: wgpu::AddressMode::MirrorRepeat,
79            address_mode_w: wgpu::AddressMode::MirrorRepeat,
80            mag_filter: wgpu::FilterMode::Linear,
81            min_filter: wgpu::FilterMode::Linear,
82            mipmap_filter: wgpu::FilterMode::Linear,
83            lod_min_clamp: 0.0f32.to_bits(),
84            lod_max_clamp: f32::MAX.to_bits(),
85            compare: None,
86            anisotropy_clamp: 1,
87            border_color: None,
88        }
89    }
90
91    /// Create a key for a mirrored nearest (point) sampler.
92    pub fn nearest_mirror() -> Self {
93        Self {
94            address_mode_u: wgpu::AddressMode::MirrorRepeat,
95            address_mode_v: wgpu::AddressMode::MirrorRepeat,
96            address_mode_w: wgpu::AddressMode::MirrorRepeat,
97            mag_filter: wgpu::FilterMode::Nearest,
98            min_filter: wgpu::FilterMode::Nearest,
99            mipmap_filter: wgpu::FilterMode::Nearest,
100            lod_min_clamp: 0.0f32.to_bits(),
101            lod_max_clamp: f32::MAX.to_bits(),
102            compare: None,
103            anisotropy_clamp: 1,
104            border_color: None,
105        }
106    }
107
108    /// Create a key from a sampler descriptor.
109    pub fn from_descriptor(desc: &wgpu::SamplerDescriptor) -> Self {
110        Self {
111            address_mode_u: desc.address_mode_u,
112            address_mode_v: desc.address_mode_v,
113            address_mode_w: desc.address_mode_w,
114            mag_filter: desc.mag_filter,
115            min_filter: desc.min_filter,
116            mipmap_filter: desc.mipmap_filter,
117            lod_min_clamp: desc.lod_min_clamp.to_bits(),
118            lod_max_clamp: desc.lod_max_clamp.to_bits(),
119            compare: desc.compare,
120            anisotropy_clamp: desc.anisotropy_clamp,
121            border_color: desc.border_color,
122        }
123    }
124
125    /// Create a descriptor from this key.
126    pub fn to_descriptor<'a>(&self, label: Option<&'a str>) -> wgpu::SamplerDescriptor<'a> {
127        wgpu::SamplerDescriptor {
128            label,
129            address_mode_u: self.address_mode_u,
130            address_mode_v: self.address_mode_v,
131            address_mode_w: self.address_mode_w,
132            mag_filter: self.mag_filter,
133            min_filter: self.min_filter,
134            mipmap_filter: self.mipmap_filter,
135            lod_min_clamp: f32::from_bits(self.lod_min_clamp),
136            lod_max_clamp: f32::from_bits(self.lod_max_clamp),
137            compare: self.compare,
138            anisotropy_clamp: self.anisotropy_clamp,
139            border_color: self.border_color,
140        }
141    }
142
143    /// Create a key for a default linear sampler.
144    pub fn linear() -> Self {
145        Self {
146            address_mode_u: wgpu::AddressMode::ClampToEdge,
147            address_mode_v: wgpu::AddressMode::ClampToEdge,
148            address_mode_w: wgpu::AddressMode::ClampToEdge,
149            mag_filter: wgpu::FilterMode::Linear,
150            min_filter: wgpu::FilterMode::Linear,
151            mipmap_filter: wgpu::FilterMode::Linear,
152            lod_min_clamp: 0.0f32.to_bits(),
153            lod_max_clamp: f32::MAX.to_bits(),
154            compare: None,
155            anisotropy_clamp: 1,
156            border_color: None,
157        }
158    }
159
160    /// Create a key for a default nearest (point) sampler.
161    pub fn nearest() -> Self {
162        Self {
163            address_mode_u: wgpu::AddressMode::ClampToEdge,
164            address_mode_v: wgpu::AddressMode::ClampToEdge,
165            address_mode_w: wgpu::AddressMode::ClampToEdge,
166            mag_filter: wgpu::FilterMode::Nearest,
167            min_filter: wgpu::FilterMode::Nearest,
168            mipmap_filter: wgpu::FilterMode::Nearest,
169            lod_min_clamp: 0.0f32.to_bits(),
170            lod_max_clamp: f32::MAX.to_bits(),
171            compare: None,
172            anisotropy_clamp: 1,
173            border_color: None,
174        }
175    }
176
177    /// Create a key for a repeating linear sampler.
178    pub fn linear_repeat() -> Self {
179        Self {
180            address_mode_u: wgpu::AddressMode::Repeat,
181            address_mode_v: wgpu::AddressMode::Repeat,
182            address_mode_w: wgpu::AddressMode::Repeat,
183            mag_filter: wgpu::FilterMode::Linear,
184            min_filter: wgpu::FilterMode::Linear,
185            mipmap_filter: wgpu::FilterMode::Linear,
186            lod_min_clamp: 0.0f32.to_bits(),
187            lod_max_clamp: f32::MAX.to_bits(),
188            compare: None,
189            anisotropy_clamp: 1,
190            border_color: None,
191        }
192    }
193}
194
/// User-facing choice of texture sampling behaviour.
///
/// Each variant names one of the preset `SamplerKey` configurations (see
/// `ImageSampling::to_sampler_key`). Equal modes therefore always resolve to the same
/// cached sampler. `Linear` is the default.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum ImageSampling {
    /// Linear filtering with clamped edges (the default). Suits photos and gradients.
    #[default]
    Linear,
    /// Nearest-neighbour filtering with clamped edges. Keeps pixel art crisp.
    Nearest,
    /// Linear filtering with repeating (wrapped) UVs, for tiled textures.
    LinearRepeat,
    /// Nearest-neighbour filtering with repeating UVs, for tiled pixel art.
    NearestRepeat,
    /// Linear filtering with mirrored-repeat UVs.
    LinearMirror,
    /// Nearest-neighbour filtering with mirrored-repeat UVs.
    NearestMirror,
}
215
216impl ImageSampling {
217    /// Convert to a SamplerKey for cache lookup.
218    pub fn to_sampler_key(&self) -> SamplerKey {
219        match self {
220            Self::Linear => SamplerKey::linear(),
221            Self::Nearest => SamplerKey::nearest(),
222            Self::LinearRepeat => SamplerKey::linear_repeat(),
223            Self::NearestRepeat => SamplerKey::nearest_repeat(),
224            Self::LinearMirror => SamplerKey::linear_mirror(),
225            Self::NearestMirror => SamplerKey::nearest_mirror(),
226        }
227    }
228}
229
230/// A thread-safe cache of GPU samplers.
231///
232/// This cache ensures that identical sampler configurations share the same
233/// GPU sampler object, reducing memory usage and creation overhead.
234///
235/// # Example
236///
237/// ```ignore
238/// use astrelis_render::{SamplerCache, SamplerKey};
239///
240/// let cache = SamplerCache::new();
241///
242/// // Get or create a linear sampler
243/// let sampler = cache.get_or_create(&device, SamplerKey::linear());
244///
245/// // The same sampler is returned for identical keys
246/// let sampler2 = cache.get_or_create(&device, SamplerKey::linear());
247/// // sampler and sampler2 point to the same GPU sampler
248/// ```
249pub struct SamplerCache {
250    cache: RwLock<HashMap<SamplerKey, Arc<wgpu::Sampler>>>,
251}
252
253impl Default for SamplerCache {
254    fn default() -> Self {
255        Self::new()
256    }
257}
258
259impl SamplerCache {
260    /// Create a new empty sampler cache.
261    pub fn new() -> Self {
262        Self {
263            cache: RwLock::new(HashMap::default()),
264        }
265    }
266
267    /// Get a sampler from the cache or create a new one.
268    ///
269    /// If a sampler with the given key already exists in the cache,
270    /// it is returned. Otherwise, a new sampler is created and cached.
271    ///
272    /// # Panics
273    /// Panics if the internal RwLock is poisoned (another thread panicked while holding the lock).
274    pub fn get_or_create(&self, device: &wgpu::Device, key: SamplerKey) -> Arc<wgpu::Sampler> {
275        profile_function!();
276        // Try read lock first (fast path)
277        {
278            let cache = self.cache.read()
279                .expect("SamplerCache lock poisoned - a thread panicked while accessing the cache");
280            if let Some(sampler) = cache.get(&key) {
281                return Arc::clone(sampler);
282            }
283        }
284
285        // Slow path: create sampler and insert
286        let mut cache = self.cache.write()
287            .expect("SamplerCache lock poisoned - a thread panicked while accessing the cache");
288
289        // Double-check in case another thread inserted while we waited
290        if let Some(sampler) = cache.get(&key) {
291            return Arc::clone(sampler);
292        }
293
294        // Create the sampler
295        let descriptor = key.to_descriptor(Some("Cached Sampler"));
296        let sampler = Arc::new(device.create_sampler(&descriptor));
297        cache.insert(key, Arc::clone(&sampler));
298        sampler
299    }
300
301    /// Get a sampler from the cache or create one using a descriptor.
302    ///
303    /// This is a convenience method that converts the descriptor to a key.
304    pub fn get_or_create_from_descriptor(
305        &self,
306        device: &wgpu::Device,
307        descriptor: &wgpu::SamplerDescriptor,
308    ) -> Arc<wgpu::Sampler> {
309        let key = SamplerKey::from_descriptor(descriptor);
310        self.get_or_create(device, key)
311    }
312
313    /// Get a default linear sampler.
314    pub fn linear(&self, device: &wgpu::Device) -> Arc<wgpu::Sampler> {
315        self.get_or_create(device, SamplerKey::linear())
316    }
317
318    /// Get a default nearest (point) sampler.
319    pub fn nearest(&self, device: &wgpu::Device) -> Arc<wgpu::Sampler> {
320        self.get_or_create(device, SamplerKey::nearest())
321    }
322
323    /// Get a repeating linear sampler.
324    pub fn linear_repeat(&self, device: &wgpu::Device) -> Arc<wgpu::Sampler> {
325        self.get_or_create(device, SamplerKey::linear_repeat())
326    }
327
328    /// Get a repeating nearest sampler.
329    pub fn nearest_repeat(&self, device: &wgpu::Device) -> Arc<wgpu::Sampler> {
330        self.get_or_create(device, SamplerKey::nearest_repeat())
331    }
332
333    /// Get a mirrored linear sampler.
334    pub fn linear_mirror(&self, device: &wgpu::Device) -> Arc<wgpu::Sampler> {
335        self.get_or_create(device, SamplerKey::linear_mirror())
336    }
337
338    /// Get a mirrored nearest sampler.
339    pub fn nearest_mirror(&self, device: &wgpu::Device) -> Arc<wgpu::Sampler> {
340        self.get_or_create(device, SamplerKey::nearest_mirror())
341    }
342
343    /// Get a sampler for the given sampling mode.
344    pub fn from_sampling(&self, device: &wgpu::Device, sampling: ImageSampling) -> Arc<wgpu::Sampler> {
345        self.get_or_create(device, sampling.to_sampler_key())
346    }
347
348    /// Get the number of cached samplers.
349    ///
350    /// # Panics
351    /// Panics if the internal RwLock is poisoned.
352    pub fn len(&self) -> usize {
353        self.cache.read()
354            .expect("SamplerCache lock poisoned")
355            .len()
356    }
357
358    /// Check if the cache is empty.
359    ///
360    /// # Panics
361    /// Panics if the internal RwLock is poisoned.
362    pub fn is_empty(&self) -> bool {
363        self.cache.read()
364            .expect("SamplerCache lock poisoned")
365            .is_empty()
366    }
367
368    /// Clear the cache, releasing all cached samplers.
369    ///
370    /// This should only be called when you're sure no references to
371    /// cached samplers are still in use.
372    ///
373    /// # Panics
374    /// Panics if the internal RwLock is poisoned.
375    pub fn clear(&self) {
376        self.cache.write()
377            .expect("SamplerCache lock poisoned")
378            .clear();
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::collections::HashSet;

    /// Hash a key with the std SipHash hasher, which is deterministic within a process.
    fn hash_of(key: &SamplerKey) -> u64 {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// Every preset constructor, paired with a name for assertion messages.
    fn presets() -> [(&'static str, SamplerKey); 6] {
        [
            ("linear", SamplerKey::linear()),
            ("nearest", SamplerKey::nearest()),
            ("linear_repeat", SamplerKey::linear_repeat()),
            ("nearest_repeat", SamplerKey::nearest_repeat()),
            ("linear_mirror", SamplerKey::linear_mirror()),
            ("nearest_mirror", SamplerKey::nearest_mirror()),
        ]
    }

    #[test]
    fn test_sampler_key_hash_equality() {
        let key1 = SamplerKey::linear();
        let key2 = SamplerKey::linear();
        let key3 = SamplerKey::nearest();

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
        // Equal keys must hash equally for HashMap lookups to work.
        assert_eq!(hash_of(&key1), hash_of(&key2));
    }

    #[test]
    fn test_sampler_key_roundtrip() {
        for (name, key) in presets() {
            let desc = key.to_descriptor(Some("Test"));
            let back = SamplerKey::from_descriptor(&desc);
            assert_eq!(back, key, "descriptor roundtrip changed preset {name}");
        }
    }

    #[test]
    fn test_presets_are_distinct() {
        let mut seen = HashSet::new();
        for (name, key) in presets() {
            assert!(seen.insert(key), "preset {name} duplicates another preset");
        }
    }

    #[test]
    fn test_image_sampling_maps_to_presets() {
        let cases = [
            (ImageSampling::Linear, SamplerKey::linear()),
            (ImageSampling::Nearest, SamplerKey::nearest()),
            (ImageSampling::LinearRepeat, SamplerKey::linear_repeat()),
            (ImageSampling::NearestRepeat, SamplerKey::nearest_repeat()),
            (ImageSampling::LinearMirror, SamplerKey::linear_mirror()),
            (ImageSampling::NearestMirror, SamplerKey::nearest_mirror()),
        ];
        for (sampling, expected) in cases {
            assert_eq!(sampling.to_sampler_key(), expected, "wrong key for {sampling:?}");
        }
        assert_eq!(
            ImageSampling::default().to_sampler_key(),
            SamplerKey::linear()
        );
    }
}