1use super::RenderNodeCpu;
2
/// Resampling algorithm requested for a [`ScaleNode`].
///
/// [`ScaleAlgorithm::Bilinear`] is the `Default` value.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ScaleAlgorithm {
    /// Bilinear interpolation (the default).
    #[default]
    Bilinear,
    /// Bicubic interpolation.
    Bicubic,
    /// Lanczos resampling.
    Lanczos,
}
14
/// GPU objects built once per [`ScaleNode`] and reused by every `process` call.
#[cfg(feature = "wgpu")]
struct ScalePipeline {
    /// Render pipeline running the `vs_main` / `fs_main` entry points of the scale shader.
    render_pipeline: wgpu::RenderPipeline,
    /// Layout of bind group 0: binding 0 is the source texture, binding 1 is the sampler.
    bind_group_layout: wgpu::BindGroupLayout,
    /// Clamp-to-edge sampler used to read the source texture.
    sampler: wgpu::Sampler,
}
23
24pub struct ScaleNode {
36 pub width: u32,
38 pub height: u32,
40 pub algorithm: ScaleAlgorithm,
42 #[cfg(feature = "wgpu")]
43 pipeline: std::sync::OnceLock<ScalePipeline>,
44}
45
46impl ScaleNode {
47 #[must_use]
48 pub fn new(width: u32, height: u32, algorithm: ScaleAlgorithm) -> Self {
49 Self {
50 width,
51 height,
52 algorithm,
53 #[cfg(feature = "wgpu")]
54 pipeline: std::sync::OnceLock::new(),
55 }
56 }
57}
58
59impl Default for ScaleNode {
60 fn default() -> Self {
61 Self::new(0, 0, ScaleAlgorithm::Bilinear)
62 }
63}
64
65impl RenderNodeCpu for ScaleNode {
68 fn process_cpu(&self, _rgba: &mut [u8], _w: u32, _h: u32) {
69 }
72}
73
#[cfg(feature = "wgpu")]
impl ScaleNode {
    /// Returns this node's GPU pipeline state, building it on the first call.
    ///
    /// The state lives in `self.pipeline` (a `OnceLock`), so the initialiser runs at most once
    /// per node and every later call returns the cached value. This means `self.algorithm` and
    /// `ctx.device` are only consulted on the first call.
    fn get_or_create_pipeline(&self, ctx: &crate::context::RenderContext) -> &ScalePipeline {
        self.pipeline.get_or_init(|| {
            let device = &ctx.device;

            // Every algorithm currently selects a linear sampler; the match stays exhaustive
            // so that adding a variant forces a decision here.
            let filter_mode = match self.algorithm {
                ScaleAlgorithm::Bilinear | ScaleAlgorithm::Bicubic | ScaleAlgorithm::Lanczos => {
                    wgpu::FilterMode::Linear
                }
            };

            let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("Scale shader"),
                source: wgpu::ShaderSource::Wgsl(include_str!("../shaders/scale.wgsl").into()),
            });

            // Binding 0: filterable 2D float texture (the source image).
            let texture_entry = wgpu::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgpu::ShaderStages::FRAGMENT,
                ty: wgpu::BindingType::Texture {
                    sample_type: wgpu::TextureSampleType::Float { filterable: true },
                    view_dimension: wgpu::TextureViewDimension::D2,
                    multisampled: false,
                },
                count: None,
            };
            // Binding 1: filtering sampler used to read binding 0.
            let sampler_entry = wgpu::BindGroupLayoutEntry {
                binding: 1,
                visibility: wgpu::ShaderStages::FRAGMENT,
                ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
                count: None,
            };
            let bind_group_layout =
                device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                    label: Some("Scale BGL"),
                    entries: &[texture_entry, sampler_entry],
                });

            let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("Scale layout"),
                bind_group_layouts: &[Some(&bind_group_layout)],
                immediate_size: 0,
            });

            let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
                label: Some("Scale sampler"),
                address_mode_u: wgpu::AddressMode::ClampToEdge,
                address_mode_v: wgpu::AddressMode::ClampToEdge,
                mag_filter: filter_mode,
                min_filter: filter_mode,
                ..wgpu::SamplerDescriptor::default()
            });

            // Single colour target: RGBA8 (unorm), no blending, all channels written.
            let color_target = wgpu::ColorTargetState {
                format: wgpu::TextureFormat::Rgba8Unorm,
                blend: None,
                write_mask: wgpu::ColorWrites::ALL,
            };
            // No vertex buffers are bound to the vertex stage.
            let vertex_state = wgpu::VertexState {
                module: &shader_module,
                entry_point: Some("vs_main"),
                buffers: &[],
                compilation_options: wgpu::PipelineCompilationOptions::default(),
            };
            let fragment_state = wgpu::FragmentState {
                module: &shader_module,
                entry_point: Some("fs_main"),
                targets: &[Some(color_target)],
                compilation_options: wgpu::PipelineCompilationOptions::default(),
            };

            let render_pipeline =
                device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
                    label: Some("Scale pipeline"),
                    layout: Some(&layout),
                    vertex: vertex_state,
                    fragment: Some(fragment_state),
                    primitive: wgpu::PrimitiveState::default(),
                    depth_stencil: None,
                    multisample: wgpu::MultisampleState::default(),
                    multiview_mask: None,
                    cache: None,
                });

            ScalePipeline {
                render_pipeline,
                bind_group_layout,
                sampler,
            }
        })
    }
}
166
#[cfg(feature = "wgpu")]
impl super::RenderNode for ScaleNode {
    /// GPU path: samples the first input texture and renders it into the first output texture.
    ///
    /// Only `inputs[0]` and `outputs[0]` are used; any further entries are ignored. The call
    /// logs a warning and returns without submitting GPU work when:
    ///
    /// * `inputs` is empty,
    /// * `outputs` is empty, or
    /// * the output texture's format is not `Rgba8Unorm`, the only colour target format the
    ///   cached render pipeline is built for.
    ///
    /// Otherwise one command buffer containing a single render pass is submitted to
    /// `ctx.queue`.
    fn process(
        &self,
        inputs: &[&wgpu::Texture],
        outputs: &[&wgpu::Texture],
        ctx: &crate::context::RenderContext,
    ) {
        let Some(input) = inputs.first() else {
            log::warn!("ScaleNode::process called with no inputs");
            return;
        };
        let Some(output) = outputs.first() else {
            log::warn!("ScaleNode::process called with no outputs");
            return;
        };

        // The pipeline's colour target is hard-coded to Rgba8Unorm. Rendering into an
        // attachment of any other format is a wgpu validation error, so bail out the same way
        // as for missing inputs/outputs instead of submitting an invalid pass.
        let output_format = output.format();
        if output_format != wgpu::TextureFormat::Rgba8Unorm {
            log::warn!(
                "ScaleNode::process called with unsupported output format {output_format:?} (expected Rgba8Unorm)"
            );
            return;
        }

        let pd = self.get_or_create_pipeline(ctx);

        let input_view = input.create_view(&wgpu::TextureViewDescriptor::default());
        let output_view = output.create_view(&wgpu::TextureViewDescriptor::default());

        // The bind group is rebuilt on every call because it references the per-call input
        // view; the layout and sampler come from the cached pipeline state.
        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Scale BG"),
            layout: &pd.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: wgpu::BindingResource::TextureView(&input_view),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: wgpu::BindingResource::Sampler(&pd.sampler),
                },
            ],
        });

        let mut encoder = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Scale pass"),
            });
        {
            // Scoped so the render pass is dropped (ended) before `encoder.finish()`.
            let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
                label: Some("Scale pass"),
                color_attachments: &[Some(wgpu::RenderPassColorAttachment {
                    view: &output_view,
                    resolve_target: None,
                    depth_slice: None,
                    ops: wgpu::Operations {
                        // Start from a fully transparent target and keep the rendered result.
                        load: wgpu::LoadOp::Clear(wgpu::Color::TRANSPARENT),
                        store: wgpu::StoreOp::Store,
                    },
                })],
                depth_stencil_attachment: None,
                timestamp_writes: None,
                occlusion_query_set: None,
                multiview_mask: None,
            });
            pass.set_pipeline(&pd.render_pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            // Six vertices, one instance. The pipeline binds no vertex buffers, so the
            // positions presumably come from the vertex index in the shader (two triangles
            // covering the target) — confirm against scale.wgsl.
            pass.draw(0..6, 0..1);
        }
        ctx.queue.submit(std::iter::once(encoder.finish()));
    }
}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// The CPU path must leave the pixel buffer exactly as it was.
    #[test]
    fn scale_node_cpu_path_is_passthrough() {
        let expected = [10u8, 20, 30, 255];
        let mut pixels = expected;

        ScaleNode::new(100, 100, ScaleAlgorithm::Bilinear).process_cpu(&mut pixels, 1, 1);

        assert_eq!(pixels, expected, "ScaleNode CPU path must be a no-op");
    }

    /// `ScaleAlgorithm::default()` is the variant marked `#[default]`.
    #[test]
    fn scale_algorithm_default_should_be_bilinear() {
        let default_algorithm = ScaleAlgorithm::default();
        assert_eq!(default_algorithm, ScaleAlgorithm::Bilinear);
    }
}