1use crate::buffer::{GpuBuffer, GpuRasterBuffer};
7use crate::context::GpuContext;
8use crate::error::{GpuError, GpuResult};
9use crate::kernels::{
10 convolution::gaussian_blur,
11 raster::{ElementWiseOp, RasterKernel, ScalarKernel, ScalarOp, UnaryKernel, UnaryOp},
12 resampling::{ResamplingMethod, resize},
13 statistics::{
14 HistogramKernel, HistogramParams, ReductionKernel, ReductionOp, Statistics,
15 compute_statistics,
16 },
17};
18use crate::shaders::{
19 ComputePipelineBuilder, WgslShader, create_compute_bind_group_layout, storage_buffer_layout,
20 uniform_buffer_layout,
21};
22use bytemuck::{Pod, Zeroable};
23use std::marker::PhantomData;
24use tracing::debug;
25use wgpu::{
26 BindGroupDescriptor, BindGroupEntry, BufferUsages, CommandEncoderDescriptor,
27 ComputePassDescriptor, ComputePipeline as WgpuComputePipeline,
28};
29
/// Raster sample formats supported by the GPU conversion kernels.
///
/// Sub-32-bit integer types are packed into `u32` storage words on the
/// GPU (see `wgsl_storage_type`); `F64Emulated` is carried as a pair of
/// `f32` values (high/low split).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuDataType {
    /// Unsigned 8-bit integer (4 samples per `u32` word on the GPU).
    U8,
    /// Unsigned 16-bit integer (2 samples per `u32` word on the GPU).
    U16,
    /// Unsigned 32-bit integer.
    U32,
    /// Signed 8-bit integer (4 samples per `u32` word on the GPU).
    I8,
    /// Signed 16-bit integer (2 samples per `u32` word on the GPU).
    I16,
    /// Signed 32-bit integer.
    I32,
    /// IEEE-754 single-precision float (native WGSL type).
    F32,
    /// Double precision emulated as two `f32`s (`vec2<f32>` in WGSL).
    F64Emulated,
}
54
55impl GpuDataType {
56 pub fn size_bytes(&self) -> usize {
58 match self {
59 Self::U8 | Self::I8 => 1,
60 Self::U16 | Self::I16 => 2,
61 Self::U32 | Self::I32 | Self::F32 => 4,
62 Self::F64Emulated => 8,
63 }
64 }
65
66 pub fn min_value(&self) -> f64 {
68 match self {
69 Self::U8 => 0.0,
70 Self::U16 => 0.0,
71 Self::U32 => 0.0,
72 Self::I8 => -128.0,
73 Self::I16 => -32768.0,
74 Self::I32 => -2147483648.0,
75 Self::F32 => f32::MIN as f64,
76 Self::F64Emulated => f64::MIN,
77 }
78 }
79
80 pub fn max_value(&self) -> f64 {
82 match self {
83 Self::U8 => 255.0,
84 Self::U16 => 65535.0,
85 Self::U32 => 4294967295.0,
86 Self::I8 => 127.0,
87 Self::I16 => 32767.0,
88 Self::I32 => 2147483647.0,
89 Self::F32 => f32::MAX as f64,
90 Self::F64Emulated => f64::MAX,
91 }
92 }
93
94 pub fn is_signed(&self) -> bool {
96 matches!(
97 self,
98 Self::I8 | Self::I16 | Self::I32 | Self::F32 | Self::F64Emulated
99 )
100 }
101
102 pub fn is_float(&self) -> bool {
104 matches!(self, Self::F32 | Self::F64Emulated)
105 }
106
107 fn wgsl_storage_type(&self) -> &'static str {
109 match self {
110 Self::U8 | Self::I8 | Self::U16 | Self::I16 | Self::U32 | Self::I32 => "u32",
111 Self::F32 => "f32",
112 Self::F64Emulated => "vec2<f32>",
113 }
114 }
115}
116
/// Uniform-buffer parameters shared by both conversion kernels.
///
/// Field order and the trailing pad mirror the WGSL `ConversionParams`
/// struct (eight 4-byte fields, 32 bytes total) so the raw bytes can be
/// uploaded directly via `bytemuck`.
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
#[repr(C)]
pub struct ConversionParams {
    // Multiplicative factor applied to every sample.
    pub scale: f32,
    // Additive offset applied after scaling.
    pub offset: f32,
    // Lower clamp bound for the converted value.
    pub out_min: f32,
    // Upper clamp bound for the converted value.
    pub out_max: f32,
    // Input value treated as nodata (matched with 1e-6 tolerance in WGSL).
    pub nodata_in: f32,
    // Value written wherever nodata is detected.
    pub nodata_out: f32,
    // Non-zero enables the nodata check (WGSL uniforms have no bool).
    pub use_nodata: u32,
    // Pads the struct to a 16-byte multiple for uniform-buffer layout.
    _padding: u32,
}
138
139impl Default for ConversionParams {
140 fn default() -> Self {
141 Self {
142 scale: 1.0,
143 offset: 0.0,
144 out_min: f32::MIN,
145 out_max: f32::MAX,
146 nodata_in: 0.0,
147 nodata_out: 0.0,
148 use_nodata: 0,
149 _padding: 0,
150 }
151 }
152}
153
154impl ConversionParams {
155 pub fn new(scale: f32, offset: f32) -> Self {
157 Self {
158 scale,
159 offset,
160 ..Default::default()
161 }
162 }
163
164 pub fn for_type_conversion(src: GpuDataType, dst: GpuDataType) -> Self {
166 let src_range = src.max_value() - src.min_value();
168 let dst_range = dst.max_value() - dst.min_value();
169
170 let scale = if src_range > 0.0 && dst_range > 0.0 {
171 (dst_range / src_range) as f32
172 } else {
173 1.0
174 };
175
176 let offset = if src.min_value() != dst.min_value() {
177 (dst.min_value() - src.min_value() * scale as f64) as f32
178 } else {
179 0.0
180 };
181
182 Self {
183 scale,
184 offset,
185 out_min: dst.min_value() as f32,
186 out_max: dst.max_value() as f32,
187 ..Default::default()
188 }
189 }
190
191 pub fn with_clamp(mut self, min: f32, max: f32) -> Self {
193 self.out_min = min;
194 self.out_max = max;
195 self
196 }
197
198 pub fn with_nodata(mut self, input_nodata: f32, output_nodata: f32) -> Self {
200 self.nodata_in = input_nodata;
201 self.nodata_out = output_nodata;
202 self.use_nodata = 1;
203 self
204 }
205
206 pub fn u8_to_normalized() -> Self {
208 Self {
209 scale: 1.0 / 255.0,
210 offset: 0.0,
211 out_min: 0.0,
212 out_max: 1.0,
213 ..Default::default()
214 }
215 }
216
217 pub fn normalized_to_u8() -> Self {
219 Self {
220 scale: 255.0,
221 offset: 0.0,
222 out_min: 0.0,
223 out_max: 255.0,
224 ..Default::default()
225 }
226 }
227
228 pub fn u16_to_normalized() -> Self {
230 Self {
231 scale: 1.0 / 65535.0,
232 offset: 0.0,
233 out_min: 0.0,
234 out_max: 1.0,
235 ..Default::default()
236 }
237 }
238}
239
/// Compute kernel converting a packed buffer of one [`GpuDataType`] into
/// an `f32` storage buffer on the GPU.
pub struct DataTypeConversionKernel {
    // Device/queue handle; cloned per kernel (presumably a cheap,
    // Arc-backed clone — TODO confirm against GpuContext).
    context: GpuContext,
    // Pipeline compiled from WGSL generated for one specific source type.
    pipeline: WgpuComputePipeline,
    // Binding 0: packed input, 1: params uniform, 2: f32 output.
    bind_group_layout: wgpu::BindGroupLayout,
    // Threads per workgroup; must match `@workgroup_size(256)` in the shader.
    workgroup_size: u32,
}
247
impl DataTypeConversionKernel {
    /// Builds a conversion kernel specialized for `src_type` -> f32.
    ///
    /// Generates type-specific WGSL, compiles it, and creates the compute
    /// pipeline together with its bind group layout.
    ///
    /// # Errors
    /// Propagates shader compilation, bind-group-layout, or pipeline
    /// build failures.
    pub fn new(context: &GpuContext, src_type: GpuDataType) -> GpuResult<Self> {
        debug!(
            "Creating data type conversion kernel for {:?} -> f32",
            src_type
        );

        let shader_source = Self::conversion_shader(src_type);
        let mut shader = WgslShader::new(shader_source, "convert_type");
        let shader_module = shader.compile(context.device())?;

        // Binding 0: packed source (read-only storage), 1: ConversionParams
        // uniform, 2: f32 destination (read-write storage). Must match the
        // WGSL `@binding` declarations emitted below.
        let bind_group_layout = create_compute_bind_group_layout(
            context.device(),
            &[
                storage_buffer_layout(0, true),
                uniform_buffer_layout(1),
                storage_buffer_layout(2, false),
            ],
            Some("DataTypeConversionKernel BindGroupLayout"),
        )?;

        let pipeline = ComputePipelineBuilder::new(context.device(), shader_module, "convert_type")
            .bind_group_layout(&bind_group_layout)
            .label(format!(
                "DataTypeConversion Pipeline: {:?} -> f32",
                src_type
            ))
            .build()?;

        Ok(Self {
            context: context.clone(),
            pipeline,
            bind_group_layout,
            // Must stay in sync with `@workgroup_size(256)` in the shader.
            workgroup_size: 256,
        })
    }

    /// Generates the WGSL source for a `src_type` -> f32 conversion.
    ///
    /// One invocation produces one output f32. Sub-32-bit integers are
    /// unpacked from little-endian `u32` words (4 x 8-bit or 2 x 16-bit
    /// samples per word); signed variants are sign-extended by hand since
    /// WGSL has no 8/16-bit integer types.
    fn conversion_shader(src_type: GpuDataType) -> String {
        // Each arm yields (storage element type, snippet that leaves the
        // decoded sample in a local named `value`).
        let (input_type, unpack_code) = match src_type {
            GpuDataType::U8 => (
                "u32",
                r#"
            // Unpack 4 u8 values from one u32
            let packed = input[idx / 4u];
            let byte_idx = idx % 4u;
            var value: f32;
            switch (byte_idx) {
                case 0u: { value = f32(packed & 0xFFu); }
                case 1u: { value = f32((packed >> 8u) & 0xFFu); }
                case 2u: { value = f32((packed >> 16u) & 0xFFu); }
                case 3u: { value = f32((packed >> 24u) & 0xFFu); }
                default: { value = 0.0; }
            }"#,
            ),
            GpuDataType::I8 => (
                "u32",
                r#"
            // Unpack 4 i8 values from one u32
            let packed = input[idx / 4u];
            let byte_idx = idx % 4u;
            var raw: u32;
            switch (byte_idx) {
                case 0u: { raw = packed & 0xFFu; }
                case 1u: { raw = (packed >> 8u) & 0xFFu; }
                case 2u: { raw = (packed >> 16u) & 0xFFu; }
                case 3u: { raw = (packed >> 24u) & 0xFFu; }
                default: { raw = 0u; }
            }
            // Sign extend from 8 bits
            var value: f32;
            if (raw >= 128u) {
                value = f32(i32(raw) - 256);
            } else {
                value = f32(raw);
            }"#,
            ),
            GpuDataType::U16 => (
                "u32",
                r#"
            // Unpack 2 u16 values from one u32
            let packed = input[idx / 2u];
            let half_idx = idx % 2u;
            var value: f32;
            if (half_idx == 0u) {
                value = f32(packed & 0xFFFFu);
            } else {
                value = f32((packed >> 16u) & 0xFFFFu);
            }"#,
            ),
            GpuDataType::I16 => (
                "u32",
                r#"
            // Unpack 2 i16 values from one u32
            let packed = input[idx / 2u];
            let half_idx = idx % 2u;
            var raw: u32;
            if (half_idx == 0u) {
                raw = packed & 0xFFFFu;
            } else {
                raw = (packed >> 16u) & 0xFFFFu;
            }
            // Sign extend from 16 bits
            var value: f32;
            if (raw >= 32768u) {
                value = f32(i32(raw) - 65536);
            } else {
                value = f32(raw);
            }"#,
            ),
            GpuDataType::U32 => (
                "u32",
                r#"
            let value = f32(input[idx]);"#,
            ),
            GpuDataType::I32 => (
                "u32",
                r#"
            let value = f32(bitcast<i32>(input[idx]));"#,
            ),
            GpuDataType::F32 => (
                "f32",
                r#"
            let value = input[idx];"#,
            ),
            GpuDataType::F64Emulated => (
                "vec2<f32>",
                r#"
            // Emulate f64 using two f32s (high and low parts)
            let packed = input[idx];
            // This is a simplified conversion - full f64 support would need more complex handling
            let value = packed.x + packed.y;"#,
            ),
        };

        format!(
            r#"
struct ConversionParams {{
    scale: f32,
    offset: f32,
    out_min: f32,
    out_max: f32,
    nodata_in: f32,
    nodata_out: f32,
    use_nodata: u32,
    _padding: u32,
}}

@group(0) @binding(0) var<storage, read> input: array<{input_type}>;
@group(0) @binding(1) var<uniform> params: ConversionParams;
@group(0) @binding(2) var<storage, read_write> output: array<{output_placeholder}>;

@compute @workgroup_size(256)
fn convert_type(@builtin(global_invocation_id) global_id: vec3<u32>) {{
    let idx = global_id.x;
    let output_len = arrayLength(&output);

    if (idx >= output_len) {{
        return;
    }}

{unpack_code}

    // Check for nodata
    if (params.use_nodata != 0u && abs(value - params.nodata_in) < 1e-6) {{
        output[idx] = params.nodata_out;
        return;
    }}

    // Apply scale and offset
    var result = value * params.scale + params.offset;

    // Clamp to output range
    result = clamp(result, params.out_min, params.out_max);

    output[idx] = result;
}}
"#,
            input_type = input_type,
            output_placeholder = "f32",
            unpack_code = unpack_code
        )
    }

    /// Dispatches the conversion, producing `output.len()` f32 values.
    ///
    /// NOTE(review): `T` must be the packed word type this kernel was
    /// built for (e.g. `u32` words carrying 4 u8 samples); this is not
    /// validated here — confirm callers uphold it.
    ///
    /// # Errors
    /// Fails if the uniform parameter buffer cannot be created.
    pub fn execute<T: Pod>(
        &self,
        input: &GpuBuffer<T>,
        output: &mut GpuBuffer<f32>,
        params: &ConversionParams,
    ) -> GpuResult<()> {
        // Upload the conversion parameters as a single-element uniform.
        let params_buffer = GpuBuffer::from_data(
            &self.context,
            &[*params],
            BufferUsages::UNIFORM | BufferUsages::COPY_DST,
        )?;

        let bind_group = self
            .context
            .device()
            .create_bind_group(&BindGroupDescriptor {
                label: Some("DataTypeConversionKernel BindGroup"),
                layout: &self.bind_group_layout,
                entries: &[
                    BindGroupEntry {
                        binding: 0,
                        resource: input.buffer().as_entire_binding(),
                    },
                    BindGroupEntry {
                        binding: 1,
                        resource: params_buffer.buffer().as_entire_binding(),
                    },
                    BindGroupEntry {
                        binding: 2,
                        resource: output.buffer().as_entire_binding(),
                    },
                ],
            });

        let mut encoder = self
            .context
            .device()
            .create_command_encoder(&CommandEncoderDescriptor {
                label: Some("DataTypeConversionKernel Encoder"),
            });

        {
            let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
                label: Some("DataTypeConversionKernel Pass"),
                timestamp_writes: None,
            });

            compute_pass.set_pipeline(&self.pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);

            // Ceiling division: one invocation per output element.
            let num_workgroups =
                (output.len() as u32 + self.workgroup_size - 1) / self.workgroup_size;
            compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
        }

        self.context.queue().submit(Some(encoder.finish()));

        debug!(
            "Executed type conversion kernel on {} elements",
            output.len()
        );
        Ok(())
    }
}
509
/// Compute kernel converting an `f32` storage buffer into a packed buffer
/// of one destination [`GpuDataType`] — the inverse of
/// [`DataTypeConversionKernel`].
pub struct F32ToTypeKernel {
    // Device/queue handle cloned from the creating context.
    context: GpuContext,
    // Pipeline compiled from WGSL generated for one destination type.
    pipeline: WgpuComputePipeline,
    // Binding 0: f32 input, 1: params uniform, 2: packed output.
    bind_group_layout: wgpu::BindGroupLayout,
    // Threads per workgroup; must match `@workgroup_size(256)` in the shader.
    workgroup_size: u32,
    // Destination type, retained for logging.
    dst_type: GpuDataType,
}
518
impl F32ToTypeKernel {
    /// Builds a conversion kernel specialized for f32 -> `dst_type`.
    ///
    /// # Errors
    /// Propagates shader compilation, bind-group-layout, or pipeline
    /// build failures.
    pub fn new(context: &GpuContext, dst_type: GpuDataType) -> GpuResult<Self> {
        debug!(
            "Creating data type conversion kernel for f32 -> {:?}",
            dst_type
        );

        let shader_source = Self::conversion_shader(dst_type);
        let mut shader = WgslShader::new(shader_source, "convert_to_type");
        let shader_module = shader.compile(context.device())?;

        // Binding 0: f32 source (read-only storage), 1: ConversionParams
        // uniform, 2: packed destination (read-write storage).
        let bind_group_layout = create_compute_bind_group_layout(
            context.device(),
            &[
                storage_buffer_layout(0, true),
                uniform_buffer_layout(1),
                storage_buffer_layout(2, false),
            ],
            Some("F32ToTypeKernel BindGroupLayout"),
        )?;

        let pipeline =
            ComputePipelineBuilder::new(context.device(), shader_module, "convert_to_type")
                .bind_group_layout(&bind_group_layout)
                .label(format!("F32ToType Pipeline: f32 -> {:?}", dst_type))
                .build()?;

        Ok(Self {
            context: context.clone(),
            pipeline,
            bind_group_layout,
            // Must stay in sync with `@workgroup_size(256)` in the shader.
            workgroup_size: 256,
            dst_type,
        })
    }

    /// Generates the WGSL source for an f32 -> `dst_type` conversion.
    ///
    /// One invocation produces one OUTPUT element: for packed formats that
    /// is a whole `u32` word, so each invocation reads and packs 4 (8-bit)
    /// or 2 (16-bit) consecutive input samples, bounds-checked against the
    /// input length. Negative signed values are re-encoded as two's
    /// complement by hand since WGSL lacks 8/16-bit integers.
    fn conversion_shader(dst_type: GpuDataType) -> String {
        // Each arm yields (storage element type, snippet that converts and
        // writes `output[idx]`).
        let (output_type, pack_code) = match dst_type {
            GpuDataType::U8 => (
                "u32",
                r#"
            // Pack 4 u8 values into one u32
            let base_idx = idx * 4u;
            var packed = 0u;

            for (var i = 0u; i < 4u; i = i + 1u) {
                let src_idx = base_idx + i;
                if (src_idx < arrayLength(&input)) {
                    var value = input[src_idx];

                    // Check nodata
                    if (params.use_nodata != 0u && abs(value - params.nodata_in) < 1e-6) {
                        value = params.nodata_out;
                    }

                    // Apply scale and offset, then clamp
                    value = clamp(value * params.scale + params.offset, params.out_min, params.out_max);
                    let byte_val = u32(value) & 0xFFu;
                    packed = packed | (byte_val << (i * 8u));
                }
            }

            output[idx] = packed;"#,
            ),
            GpuDataType::U16 => (
                "u32",
                r#"
            // Pack 2 u16 values into one u32
            let base_idx = idx * 2u;
            var packed = 0u;

            for (var i = 0u; i < 2u; i = i + 1u) {
                let src_idx = base_idx + i;
                if (src_idx < arrayLength(&input)) {
                    var value = input[src_idx];

                    if (params.use_nodata != 0u && abs(value - params.nodata_in) < 1e-6) {
                        value = params.nodata_out;
                    }

                    value = clamp(value * params.scale + params.offset, params.out_min, params.out_max);
                    let half_val = u32(value) & 0xFFFFu;
                    packed = packed | (half_val << (i * 16u));
                }
            }

            output[idx] = packed;"#,
            ),
            GpuDataType::U32 => (
                "u32",
                r#"
            var value = input[idx];

            if (params.use_nodata != 0u && abs(value - params.nodata_in) < 1e-6) {
                value = params.nodata_out;
            }

            value = clamp(value * params.scale + params.offset, params.out_min, params.out_max);
            output[idx] = u32(value);"#,
            ),
            GpuDataType::I8 => (
                "u32",
                r#"
            // Pack 4 i8 values into one u32
            let base_idx = idx * 4u;
            var packed = 0u;

            for (var i = 0u; i < 4u; i = i + 1u) {
                let src_idx = base_idx + i;
                if (src_idx < arrayLength(&input)) {
                    var value = input[src_idx];

                    if (params.use_nodata != 0u && abs(value - params.nodata_in) < 1e-6) {
                        value = params.nodata_out;
                    }

                    value = clamp(value * params.scale + params.offset, params.out_min, params.out_max);
                    var byte_val: u32;
                    if (value < 0.0) {
                        byte_val = u32(i32(value) + 256) & 0xFFu;
                    } else {
                        byte_val = u32(value) & 0xFFu;
                    }
                    packed = packed | (byte_val << (i * 8u));
                }
            }

            output[idx] = packed;"#,
            ),
            GpuDataType::I16 => (
                "u32",
                r#"
            // Pack 2 i16 values into one u32
            let base_idx = idx * 2u;
            var packed = 0u;

            for (var i = 0u; i < 2u; i = i + 1u) {
                let src_idx = base_idx + i;
                if (src_idx < arrayLength(&input)) {
                    var value = input[src_idx];

                    if (params.use_nodata != 0u && abs(value - params.nodata_in) < 1e-6) {
                        value = params.nodata_out;
                    }

                    value = clamp(value * params.scale + params.offset, params.out_min, params.out_max);
                    var half_val: u32;
                    if (value < 0.0) {
                        half_val = u32(i32(value) + 65536) & 0xFFFFu;
                    } else {
                        half_val = u32(value) & 0xFFFFu;
                    }
                    packed = packed | (half_val << (i * 16u));
                }
            }

            output[idx] = packed;"#,
            ),
            GpuDataType::I32 => (
                "u32",
                r#"
            var value = input[idx];

            if (params.use_nodata != 0u && abs(value - params.nodata_in) < 1e-6) {
                value = params.nodata_out;
            }

            value = clamp(value * params.scale + params.offset, params.out_min, params.out_max);
            output[idx] = bitcast<u32>(i32(value));"#,
            ),
            GpuDataType::F32 => (
                "f32",
                r#"
            var value = input[idx];

            if (params.use_nodata != 0u && abs(value - params.nodata_in) < 1e-6) {
                value = params.nodata_out;
            }

            output[idx] = clamp(value * params.scale + params.offset, params.out_min, params.out_max);"#,
            ),
            GpuDataType::F64Emulated => (
                "vec2<f32>",
                r#"
            var value = input[idx];

            if (params.use_nodata != 0u && abs(value - params.nodata_in) < 1e-6) {
                value = params.nodata_out;
            }

            value = clamp(value * params.scale + params.offset, params.out_min, params.out_max);
            // Split into high and low parts for f64 emulation
            output[idx] = vec2<f32>(value, 0.0);"#,
            ),
        };

        format!(
            r#"
struct ConversionParams {{
    scale: f32,
    offset: f32,
    out_min: f32,
    out_max: f32,
    nodata_in: f32,
    nodata_out: f32,
    use_nodata: u32,
    _padding: u32,
}}

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<uniform> params: ConversionParams;
@group(0) @binding(2) var<storage, read_write> output: array<{output_type}>;

@compute @workgroup_size(256)
fn convert_to_type(@builtin(global_invocation_id) global_id: vec3<u32>) {{
    let idx = global_id.x;
    let output_len = arrayLength(&output);

    if (idx >= output_len) {{
        return;
    }}

{pack_code}
}}
"#,
            output_type = output_type,
            pack_code = pack_code
        )
    }

    /// Dispatches the conversion over `input`, filling `output`.
    ///
    /// NOTE(review): `T` must be the packed word type the kernel writes
    /// (e.g. `u32` for packed U8 output); this is not validated here.
    ///
    /// # Errors
    /// Fails if the uniform parameter buffer cannot be created.
    pub fn execute<T: Pod>(
        &self,
        input: &GpuBuffer<f32>,
        output: &mut GpuBuffer<T>,
        params: &ConversionParams,
    ) -> GpuResult<()> {
        // Upload the conversion parameters as a single-element uniform.
        let params_buffer = GpuBuffer::from_data(
            &self.context,
            &[*params],
            BufferUsages::UNIFORM | BufferUsages::COPY_DST,
        )?;

        let bind_group = self
            .context
            .device()
            .create_bind_group(&BindGroupDescriptor {
                label: Some("F32ToTypeKernel BindGroup"),
                layout: &self.bind_group_layout,
                entries: &[
                    BindGroupEntry {
                        binding: 0,
                        resource: input.buffer().as_entire_binding(),
                    },
                    BindGroupEntry {
                        binding: 1,
                        resource: params_buffer.buffer().as_entire_binding(),
                    },
                    BindGroupEntry {
                        binding: 2,
                        resource: output.buffer().as_entire_binding(),
                    },
                ],
            });

        let mut encoder = self
            .context
            .device()
            .create_command_encoder(&CommandEncoderDescriptor {
                label: Some("F32ToTypeKernel Encoder"),
            });

        {
            let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
                label: Some("F32ToTypeKernel Pass"),
                timestamp_writes: None,
            });

            compute_pass.set_pipeline(&self.pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);

            // Ceiling division: one invocation per output word/element.
            let num_workgroups =
                (output.len() as u32 + self.workgroup_size - 1) / self.workgroup_size;
            compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
        }

        self.context.queue().submit(Some(encoder.finish()));

        debug!(
            "Executed f32 -> {:?} conversion on {} elements",
            self.dst_type,
            input.len()
        );
        Ok(())
    }
}
825
/// Convenience wrapper that creates the appropriate conversion kernel and
/// allocates a correctly sized output buffer for a whole-buffer convert.
pub struct BatchTypeConverter {
    // Device/queue handle used for kernel and buffer creation.
    context: GpuContext,
    // Intended tile size for chunked conversion. NOTE(review): not
    // consulted by the conversion methods visible here — confirm whether
    // tiling is implemented elsewhere or still pending.
    tile_size: usize,
}
834
835impl BatchTypeConverter {
836 pub fn new(context: &GpuContext) -> Self {
838 Self {
839 context: context.clone(),
840 tile_size: 1024 * 1024, }
842 }
843
844 pub fn with_tile_size(mut self, size: usize) -> Self {
846 self.tile_size = size;
847 self
848 }
849
850 pub fn convert_to_f32<T: Pod>(
858 &self,
859 input: &GpuBuffer<T>,
860 src_type: GpuDataType,
861 params: &ConversionParams,
862 ) -> GpuResult<GpuBuffer<f32>> {
863 let kernel = DataTypeConversionKernel::new(&self.context, src_type)?;
864
865 let output_len = match src_type {
867 GpuDataType::U8 | GpuDataType::I8 => input.len() * 4,
868 GpuDataType::U16 | GpuDataType::I16 => input.len() * 2,
869 _ => input.len(),
870 };
871
872 let mut output = GpuBuffer::new(
873 &self.context,
874 output_len,
875 BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
876 )?;
877
878 kernel.execute(input, &mut output, params)?;
879
880 Ok(output)
881 }
882
883 pub fn convert_from_f32<T: Pod>(
889 &self,
890 input: &GpuBuffer<f32>,
891 dst_type: GpuDataType,
892 params: &ConversionParams,
893 ) -> GpuResult<GpuBuffer<T>> {
894 let kernel = F32ToTypeKernel::new(&self.context, dst_type)?;
895
896 let output_len = match dst_type {
898 GpuDataType::U8 | GpuDataType::I8 => (input.len() + 3) / 4,
899 GpuDataType::U16 | GpuDataType::I16 => (input.len() + 1) / 2,
900 _ => input.len(),
901 };
902
903 let mut output = GpuBuffer::new(
904 &self.context,
905 output_len,
906 BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
907 )?;
908
909 kernel.execute(input, &mut output, params)?;
910
911 Ok(output)
912 }
913}
914
/// Fluent, chainable GPU processing pipeline over a single raster band.
///
/// Each operation consumes `self` and returns a new pipeline whose
/// `current_buffer` holds the operation's output.
pub struct ComputePipeline<T: Pod> {
    // Device/queue handle used by every kernel in the chain.
    context: GpuContext,
    // The band's current contents; replaced after each operation.
    current_buffer: GpuBuffer<T>,
    // Raster width in pixels.
    width: u32,
    // Raster height in pixels.
    height: u32,
    // NOTE(review): T already appears in `current_buffer`, so this marker
    // looks redundant — confirm before removing.
    _phantom: PhantomData<T>,
}
931
impl<T: Pod + Zeroable> ComputePipeline<T> {
    /// Wraps an existing GPU buffer as a pipeline.
    ///
    /// # Errors
    /// Returns an error if `input.len()` does not equal `width * height`.
    pub fn new(
        context: &GpuContext,
        input: GpuBuffer<T>,
        width: u32,
        height: u32,
    ) -> GpuResult<Self> {
        let expected_size = (width as usize) * (height as usize);
        if input.len() != expected_size {
            return Err(GpuError::invalid_kernel_params(format!(
                "Buffer size mismatch: expected {}, got {}",
                expected_size,
                input.len()
            )));
        }

        Ok(Self {
            context: context.clone(),
            current_buffer: input,
            width,
            height,
            _phantom: PhantomData,
        })
    }

    /// Uploads host data into a new GPU buffer and wraps it as a pipeline.
    ///
    /// # Errors
    /// Fails on buffer allocation/upload or on a size/dimension mismatch.
    pub fn from_data(context: &GpuContext, data: &[T], width: u32, height: u32) -> GpuResult<Self> {
        let buffer = GpuBuffer::from_data(
            context,
            data,
            BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
        )?;

        Self::new(context, buffer, width, height)
    }

    /// Borrows the pipeline's current GPU buffer.
    pub fn buffer(&self) -> &GpuBuffer<T> {
        &self.current_buffer
    }

    /// Returns `(width, height)` in pixels.
    pub fn dimensions(&self) -> (u32, u32) {
        (self.width, self.height)
    }

    /// Applies an element-wise binary operation against `other`,
    /// replacing the current buffer with the result.
    pub fn element_wise(mut self, op: ElementWiseOp, other: &GpuBuffer<T>) -> GpuResult<Self> {
        debug!("Pipeline: applying {:?}", op);

        let kernel = RasterKernel::new(&self.context, op)?;
        let mut output = GpuBuffer::new(
            &self.context,
            self.current_buffer.len(),
            BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
        )?;

        kernel.execute(&self.current_buffer, other, &mut output)?;
        self.current_buffer = output;

        Ok(self)
    }

    /// Applies a unary operation (abs, sqrt, log, ...) to every element.
    pub fn unary(mut self, op: UnaryOp) -> GpuResult<Self> {
        debug!("Pipeline: applying unary {:?}", op);

        let kernel = UnaryKernel::new(&self.context, op)?;
        let mut output = GpuBuffer::new(
            &self.context,
            self.current_buffer.len(),
            BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
        )?;

        kernel.execute(&self.current_buffer, &mut output)?;
        self.current_buffer = output;

        Ok(self)
    }

    /// Applies an element-vs-scalar operation to every element.
    pub fn scalar(mut self, op: ScalarOp) -> GpuResult<Self> {
        debug!("Pipeline: applying scalar {:?}", op);

        let kernel = ScalarKernel::new(&self.context, op)?;
        let mut output = GpuBuffer::new(
            &self.context,
            self.current_buffer.len(),
            BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
        )?;

        kernel.execute(&self.current_buffer, &mut output)?;
        self.current_buffer = output;

        Ok(self)
    }

    /// Applies a Gaussian blur with the given standard deviation.
    pub fn gaussian_blur(mut self, sigma: f32) -> GpuResult<Self> {
        debug!("Pipeline: applying Gaussian blur (sigma={})", sigma);

        let output = gaussian_blur(
            &self.context,
            &self.current_buffer,
            self.width,
            self.height,
            sigma,
        )?;
        self.current_buffer = output;

        Ok(self)
    }

    /// Resamples the raster to `new_width` x `new_height` using `method`,
    /// updating the stored dimensions to match.
    pub fn resize(
        mut self,
        new_width: u32,
        new_height: u32,
        method: ResamplingMethod,
    ) -> GpuResult<Self> {
        debug!(
            "Pipeline: resizing {}x{} -> {}x{} ({:?})",
            self.width, self.height, new_width, new_height, method
        );

        let output = resize(
            &self.context,
            &self.current_buffer,
            self.width,
            self.height,
            new_width,
            new_height,
            method,
        )?;

        self.width = new_width;
        self.height = new_height;
        self.current_buffer = output;

        Ok(self)
    }

    /// Adds `value` to every element.
    pub fn add(self, value: f32) -> GpuResult<Self> {
        self.scalar(ScalarOp::Add(value))
    }

    /// Multiplies every element by `value`.
    pub fn multiply(self, value: f32) -> GpuResult<Self> {
        self.scalar(ScalarOp::Multiply(value))
    }

    /// Clamps every element to `[min, max]`.
    pub fn clamp(self, min: f32, max: f32) -> GpuResult<Self> {
        self.scalar(ScalarOp::Clamp { min, max })
    }

    /// Binarizes: elements compared against `threshold` become `above`
    /// or `below` (exact comparison semantics are defined by ScalarKernel).
    pub fn threshold(self, threshold: f32, above: f32, below: f32) -> GpuResult<Self> {
        self.scalar(ScalarOp::Threshold {
            threshold,
            above,
            below,
        })
    }

    /// Absolute value of every element.
    pub fn abs(self) -> GpuResult<Self> {
        self.unary(UnaryOp::Abs)
    }

    /// Square root of every element.
    pub fn sqrt(self) -> GpuResult<Self> {
        self.unary(UnaryOp::Sqrt)
    }

    /// Logarithm of every element (base defined by UnaryKernel).
    pub fn log(self) -> GpuResult<Self> {
        self.unary(UnaryOp::Log)
    }

    /// Exponential of every element.
    pub fn exp(self) -> GpuResult<Self> {
        self.unary(UnaryOp::Exp)
    }

    /// Computes summary statistics of the current buffer.
    ///
    /// NOTE(review): the readback below BIT-REINTERPRETS each element's
    /// little-endian bytes as an f32 (and maps non-4-byte types to 0.0).
    /// That is only meaningful when `T` is `f32`; for integer `T` the
    /// statistics are computed over garbage values. For other types use
    /// [`Self::statistics_with_conversion`].
    pub async fn statistics(&self) -> GpuResult<Statistics> {
        let staging = GpuBuffer::staging(&self.context, self.current_buffer.len())?;
        // Presumably GpuBuffer::clone aliases the same underlying wgpu
        // buffer, so writing through the clone is visible to `staging` —
        // TODO confirm against GpuBuffer.
        let mut staging_mut = staging.clone();
        staging_mut.copy_from(&self.current_buffer)?;

        let data = staging.read().await?;
        let f32_data: Vec<f32> = data
            .into_iter()
            .map(|v: T| {
                let bytes = bytemuck::bytes_of(&v);
                if bytes.len() == 4 {
                    f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
                } else {
                    0.0f32
                }
            })
            .collect();

        // Re-upload as f32 and run the GPU statistics reduction.
        let input_buffer = GpuBuffer::from_data(
            &self.context,
            &f32_data,
            BufferUsages::STORAGE | BufferUsages::COPY_SRC,
        )?;

        compute_statistics(&self.context, &input_buffer).await
    }

    /// Computes statistics after a proper value conversion from
    /// `src_type` to f32 (the safe path for non-f32 pipelines).
    pub async fn statistics_with_conversion(
        &self,
        src_type: GpuDataType,
        params: &ConversionParams,
    ) -> GpuResult<Statistics> {
        let converter = BatchTypeConverter::new(&self.context);
        let f32_buffer = converter.convert_to_f32(&self.current_buffer, src_type, params)?;
        compute_statistics(&self.context, &f32_buffer).await
    }

    /// Computes a histogram of the current buffer with `num_bins` bins
    /// over `[min_value, max_value]`.
    pub async fn histogram(
        &self,
        num_bins: u32,
        min_value: f32,
        max_value: f32,
    ) -> GpuResult<Vec<u32>> {
        let kernel = HistogramKernel::new(&self.context)?;
        let params = HistogramParams::new(num_bins, min_value, max_value);
        kernel.execute(&self.current_buffer, params).await
    }

    /// Reduces the buffer to a single value (sum, min, max, ...).
    pub async fn reduce(&self, op: ReductionOp) -> GpuResult<T>
    where
        T: Copy,
    {
        let kernel = ReductionKernel::new(&self.context, op)?;
        kernel.execute(&self.current_buffer, op).await
    }

    /// Consumes the pipeline, yielding the final GPU buffer.
    pub fn finish(self) -> GpuBuffer<T> {
        self.current_buffer
    }

    /// Reads the final contents back to the host.
    pub async fn read(self) -> GpuResult<Vec<T>> {
        let staging = GpuBuffer::staging(&self.context, self.current_buffer.len())?;
        // Same clone-aliasing assumption as in `statistics` above.
        let mut staging_mut = staging.clone();
        staging_mut.copy_from(&self.current_buffer)?;
        staging.read().await
    }

    /// Blocking wrapper around [`Self::read`] for non-async callers.
    pub fn read_blocking(self) -> GpuResult<Vec<T>> {
        pollster::block_on(self.read())
    }

    /// Converts this pipeline's packed `src_type` data into an f32
    /// pipeline, preserving the raster dimensions.
    pub fn convert_to_f32(
        self,
        src_type: GpuDataType,
        params: &ConversionParams,
    ) -> GpuResult<ComputePipeline<f32>> {
        let converter = BatchTypeConverter::new(&self.context);
        let f32_buffer = converter.convert_to_f32(&self.current_buffer, src_type, params)?;

        Ok(ComputePipeline {
            context: self.context,
            current_buffer: f32_buffer,
            width: self.width,
            height: self.height,
            _phantom: PhantomData,
        })
    }

    /// Applies `x * scale + offset` to every element (two scalar passes).
    pub fn linear_transform(self, scale: f32, offset: f32) -> GpuResult<Self> {
        self.scalar(ScalarOp::Multiply(scale))?
            .scalar(ScalarOp::Add(offset))
    }

    /// Linearly remaps `[current_min, current_max]` onto
    /// `[new_min, new_max]`.
    ///
    /// # Errors
    /// Fails if the current range is smaller than 1e-10 (degenerate).
    pub fn normalize_range(
        self,
        current_min: f32,
        current_max: f32,
        new_min: f32,
        new_max: f32,
    ) -> GpuResult<Self> {
        let current_range = current_max - current_min;
        let new_range = new_max - new_min;

        if current_range.abs() < 1e-10 {
            return Err(GpuError::invalid_kernel_params(
                "Current range is too small for normalization",
            ));
        }

        let scale = new_range / current_range;
        let offset = new_min - current_min * scale;

        self.linear_transform(scale, offset)
    }
}
1266
1267impl ComputePipeline<f32> {
1269 pub fn convert_to_type<U: Pod + Zeroable>(
1275 self,
1276 dst_type: GpuDataType,
1277 params: &ConversionParams,
1278 ) -> GpuResult<ComputePipeline<U>> {
1279 let converter = BatchTypeConverter::new(&self.context);
1280 let output_buffer: GpuBuffer<U> =
1281 converter.convert_from_f32(&self.current_buffer, dst_type, params)?;
1282
1283 let (new_width, new_height) = match dst_type {
1285 GpuDataType::U8 | GpuDataType::I8 => {
1286 let total_elements = (self.width * self.height) as usize;
1288 let packed_len = (total_elements + 3) / 4;
1289 (packed_len as u32, 1)
1290 }
1291 GpuDataType::U16 | GpuDataType::I16 => {
1292 let total_elements = (self.width * self.height) as usize;
1294 let packed_len = (total_elements + 1) / 2;
1295 (packed_len as u32, 1)
1296 }
1297 _ => (self.width, self.height),
1298 };
1299
1300 Ok(ComputePipeline {
1301 context: self.context,
1302 current_buffer: output_buffer,
1303 width: new_width,
1304 height: new_height,
1305 _phantom: PhantomData,
1306 })
1307 }
1308
1309 pub fn from_u8_normalized(
1315 context: &GpuContext,
1316 data: &[u8],
1317 width: u32,
1318 height: u32,
1319 ) -> GpuResult<Self> {
1320 let f32_data: Vec<f32> = data.iter().map(|&v| v as f32 / 255.0).collect();
1322 Self::from_data(context, &f32_data, width, height)
1323 }
1324
1325 pub fn from_u16_normalized(
1331 context: &GpuContext,
1332 data: &[u16],
1333 width: u32,
1334 height: u32,
1335 ) -> GpuResult<Self> {
1336 let f32_data: Vec<f32> = data.iter().map(|&v| v as f32 / 65535.0).collect();
1338 Self::from_data(context, &f32_data, width, height)
1339 }
1340
1341 pub fn scale_offset(self, scale: f32, offset: f32) -> GpuResult<Self> {
1345 if (scale - 1.0).abs() < 1e-10 && offset.abs() < 1e-10 {
1346 return Ok(self);
1348 }
1349
1350 self.linear_transform(scale, offset)
1351 }
1352}
1353
/// A set of per-band [`ComputePipeline`]s sharing one raster's dimensions,
/// for processing multiband imagery.
pub struct MultibandPipeline<T: Pod> {
    // Device/queue handle shared by all bands.
    context: GpuContext,
    // One pipeline per band, in the raster's band order.
    bands: Vec<ComputePipeline<T>>,
}
1359
impl<T: Pod + Zeroable> MultibandPipeline<T> {
    /// Wraps every band of `raster` in its own pipeline.
    ///
    /// Presumably `band.clone()` aliases the underlying GPU buffer rather
    /// than copying it — TODO confirm against GpuBuffer's Clone impl.
    ///
    /// # Errors
    /// Fails if any band's length disagrees with the raster dimensions.
    pub fn new(context: &GpuContext, raster: &GpuRasterBuffer<T>) -> GpuResult<Self> {
        let (width, height) = raster.dimensions();
        let bands = raster
            .bands()
            .iter()
            .map(|band| ComputePipeline::new(context, band.clone(), width, height))
            .collect::<GpuResult<Vec<_>>>()?;

        Ok(Self {
            context: context.clone(),
            bands,
        })
    }

    /// Number of bands in the pipeline.
    pub fn num_bands(&self) -> usize {
        self.bands.len()
    }

    /// Borrows the pipeline for band `index`, if it exists.
    pub fn band(&self, index: usize) -> Option<&ComputePipeline<T>> {
        self.bands.get(index)
    }

    /// Applies `f` to every band, short-circuiting on the first error.
    pub fn map<F>(mut self, mut f: F) -> GpuResult<Self>
    where
        F: FnMut(ComputePipeline<T>) -> GpuResult<ComputePipeline<T>>,
    {
        self.bands = self
            .bands
            .into_iter()
            .map(|band| f(band))
            .collect::<GpuResult<Vec<_>>>()?;

        Ok(self)
    }

    /// Computes NDVI = (NIR - Red) / (NIR + Red) as a new single-band
    /// pipeline.
    ///
    /// Assumes band order R, G, B, NIR (Red = band 0, NIR = band 3) —
    /// NOTE(review): confirm this matches the raster source's band order.
    /// Division-by-zero handling (NIR + Red == 0) is delegated to the
    /// Divide kernel; its behavior is not visible here.
    ///
    /// # Errors
    /// Fails when fewer than 4 bands are present, or on any kernel error.
    pub fn ndvi(self) -> GpuResult<ComputePipeline<T>> {
        if self.bands.len() < 4 {
            return Err(GpuError::invalid_kernel_params(
                "NDVI requires at least 4 bands (R,G,B,NIR)",
            ));
        }

        let nir = self
            .bands
            .get(3)
            .ok_or_else(|| GpuError::internal("Missing NIR band"))?;
        let red = self
            .bands
            .get(0)
            .ok_or_else(|| GpuError::internal("Missing Red band"))?;

        // Clones presumably alias the band buffers (no data copy) —
        // TODO confirm against GpuBuffer's Clone impl.
        let nir_buffer = nir.buffer().clone();
        let red_buffer = red.buffer().clone();

        let width = nir.width;
        let height = nir.height;

        // NIR - Red
        let diff_kernel = RasterKernel::new(&self.context, ElementWiseOp::Subtract)?;
        let mut diff_buffer = GpuBuffer::new(
            &self.context,
            nir_buffer.len(),
            BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
        )?;
        diff_kernel.execute(&nir_buffer, &red_buffer, &mut diff_buffer)?;

        // NIR + Red
        let sum_kernel = RasterKernel::new(&self.context, ElementWiseOp::Add)?;
        let mut sum_buffer = GpuBuffer::new(
            &self.context,
            nir_buffer.len(),
            BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
        )?;
        sum_kernel.execute(&nir_buffer, &red_buffer, &mut sum_buffer)?;

        // (NIR - Red) / (NIR + Red)
        let div_kernel = RasterKernel::new(&self.context, ElementWiseOp::Divide)?;
        let mut ndvi_buffer = GpuBuffer::new(
            &self.context,
            nir_buffer.len(),
            BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
        )?;
        div_kernel.execute(&diff_buffer, &sum_buffer, &mut ndvi_buffer)?;

        ComputePipeline::new(&self.context, ndvi_buffer, width, height)
    }

    /// Consumes the pipeline, yielding each band's final GPU buffer.
    pub fn finish(self) -> Vec<GpuBuffer<T>> {
        self.bands.into_iter().map(|b| b.finish()).collect()
    }

    /// Reads every band back to the host, in band order.
    pub async fn read_all(self) -> GpuResult<Vec<Vec<T>>> {
        let mut results = Vec::with_capacity(self.bands.len());

        for band in self.bands {
            results.push(band.read().await?);
        }

        Ok(results)
    }
}
1478
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_compute_pipeline() {
        let Ok(context) = GpuContext::new().await else {
            return;
        };
        let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let Ok(pipeline) = ComputePipeline::from_data(&context, &data, 10, 10) else {
            return;
        };
        if let Ok(result) = pipeline.add(5.0).and_then(|p| p.multiply(2.0)) {
            let _ = result.finish();
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_pipeline_chaining() {
        let Ok(context) = GpuContext::new().await else {
            return;
        };
        let data: Vec<f32> = vec![1.0; 64 * 64];
        let Ok(pipeline) = ComputePipeline::from_data(&context, &data, 64, 64) else {
            return;
        };
        let chained = pipeline
            .add(10.0)
            .and_then(|p| p.multiply(2.0))
            .and_then(|p| p.clamp(0.0, 100.0));
        if let Ok(result) = chained {
            if let Ok(stats) = result.statistics().await {
                println!("Mean: {}", stats.mean());
            }
        }
    }

    #[test]
    fn test_gpu_data_type_properties() {
        // Storage sizes in bytes.
        for (dtype, expected) in [
            (GpuDataType::U8, 1),
            (GpuDataType::U16, 2),
            (GpuDataType::U32, 4),
            (GpuDataType::F32, 4),
            (GpuDataType::F64Emulated, 8),
        ] {
            assert_eq!(dtype.size_bytes(), expected);
        }

        // Representable value ranges.
        for (dtype, min, max) in [
            (GpuDataType::U8, 0.0, 255.0),
            (GpuDataType::I8, -128.0, 127.0),
        ] {
            assert_eq!(dtype.min_value(), min);
            assert_eq!(dtype.max_value(), max);
        }
        assert_eq!(GpuDataType::U16.max_value(), 65535.0);

        // Signedness.
        assert!(!GpuDataType::U8.is_signed());
        assert!(GpuDataType::I8.is_signed());
        assert!(GpuDataType::F32.is_signed());

        // Floating-point classification.
        assert!(!GpuDataType::U8.is_float());
        assert!(GpuDataType::F32.is_float());
        assert!(GpuDataType::F64Emulated.is_float());
    }

    #[test]
    fn test_conversion_params_default() {
        let p = ConversionParams::default();
        assert_eq!(p.scale, 1.0);
        assert_eq!(p.offset, 0.0);
        assert_eq!(p.use_nodata, 0);
    }

    #[test]
    fn test_conversion_params_u8_to_normalized() {
        let p = ConversionParams::u8_to_normalized();
        assert!((p.scale - (1.0 / 255.0)).abs() < 1e-6);
        assert_eq!(p.offset, 0.0);
        assert_eq!(p.out_min, 0.0);
        assert_eq!(p.out_max, 1.0);
    }

    #[test]
    fn test_conversion_params_normalized_to_u8() {
        let p = ConversionParams::normalized_to_u8();
        assert_eq!(p.scale, 255.0);
        assert_eq!(p.offset, 0.0);
        assert_eq!(p.out_min, 0.0);
        assert_eq!(p.out_max, 255.0);
    }

    #[test]
    fn test_conversion_params_with_clamp() {
        let p = ConversionParams::new(2.0, 10.0).with_clamp(0.0, 100.0);
        assert_eq!(p.scale, 2.0);
        assert_eq!(p.offset, 10.0);
        assert_eq!(p.out_min, 0.0);
        assert_eq!(p.out_max, 100.0);
    }

    #[test]
    fn test_conversion_params_with_nodata() {
        let p = ConversionParams::default().with_nodata(-9999.0, f32::NAN);
        assert_eq!(p.nodata_in, -9999.0);
        assert_eq!(p.use_nodata, 1);
    }

    #[test]
    fn test_conversion_params_for_type_conversion() {
        let p = ConversionParams::for_type_conversion(GpuDataType::U8, GpuDataType::U16);
        // Full-range U8 should map onto full-range U16.
        let expected_scale = 65535.0 / 255.0;
        assert!((p.scale - expected_scale as f32).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_data_type_conversion_kernel_creation() {
        let Ok(context) = GpuContext::new().await else {
            return;
        };
        let dtypes = [
            GpuDataType::U8,
            GpuDataType::U16,
            GpuDataType::U32,
            GpuDataType::I8,
            GpuDataType::I16,
            GpuDataType::I32,
            GpuDataType::F32,
        ];
        for dtype in &dtypes {
            let result = DataTypeConversionKernel::new(&context, *dtype);
            assert!(result.is_ok(), "Failed to create kernel for {:?}", dtype);
        }
    }

    #[tokio::test]
    async fn test_f32_to_type_kernel_creation() {
        let Ok(context) = GpuContext::new().await else {
            return;
        };
        let dtypes = [
            GpuDataType::U8,
            GpuDataType::U16,
            GpuDataType::U32,
            GpuDataType::F32,
        ];
        for dtype in &dtypes {
            let result = F32ToTypeKernel::new(&context, *dtype);
            assert!(
                result.is_ok(),
                "Failed to create F32ToType kernel for {:?}",
                dtype
            );
        }
    }

    #[tokio::test]
    async fn test_batch_type_converter() {
        let Ok(context) = GpuContext::new().await else {
            return;
        };
        let converter = BatchTypeConverter::new(&context);
        let f32_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let buffer = GpuBuffer::from_data(
            &context,
            &f32_data,
            BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
        );
        if let Ok(buffer) = buffer {
            let params = ConversionParams::default();
            let result = converter.convert_to_f32(&buffer, GpuDataType::F32, &params);
            assert!(result.is_ok());
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_pipeline_with_u8_normalized() {
        let Ok(context) = GpuContext::new().await else {
            return;
        };
        let u8_data: Vec<u8> = (0..100).collect();
        let Ok(pipeline) = ComputePipeline::<f32>::from_u8_normalized(&context, &u8_data, 10, 10)
        else {
            return;
        };
        if let Ok(data) = pipeline.read_blocking() {
            // 0 normalizes to 0.0, 99 to 99/255.
            assert!(data[0].abs() < 1e-6);
            let expected = 99.0 / 255.0;
            assert!((data[99] - expected).abs() < 1e-4);
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_pipeline_linear_transform() {
        let Ok(context) = GpuContext::new().await else {
            return;
        };
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let Ok(pipeline) = ComputePipeline::from_data(&context, &data, 2, 2) else {
            return;
        };
        if let Ok(result) = pipeline.linear_transform(2.0, 10.0) {
            if let Ok(output) = result.read_blocking() {
                // y = 2x + 10 for each input value.
                for (actual, expected) in output.iter().zip([12.0, 14.0, 16.0, 18.0]) {
                    assert!((actual - expected).abs() < 1e-4);
                }
            }
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_pipeline_normalize_range() {
        let Ok(context) = GpuContext::new().await else {
            return;
        };
        let data: Vec<f32> = vec![0.0, 50.0, 100.0, 25.0];
        let Ok(pipeline) = ComputePipeline::from_data(&context, &data, 2, 2) else {
            return;
        };
        if let Ok(result) = pipeline.normalize_range(0.0, 100.0, 0.0, 1.0) {
            if let Ok(output) = result.read_blocking() {
                // [0, 100] rescaled onto [0, 1].
                for (actual, expected) in output.iter().zip([0.0, 0.5, 1.0, 0.25]) {
                    assert!((actual - expected).abs() < 1e-4);
                }
            }
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_pipeline_scale_offset_noop() {
        let Ok(context) = GpuContext::new().await else {
            return;
        };
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let Ok(pipeline) = ComputePipeline::from_data(&context, &data, 2, 2) else {
            return;
        };
        if let Ok(result) = pipeline.scale_offset(1.0, 0.0) {
            if let Ok(output) = result.read_blocking() {
                // scale=1, offset=0 must leave values unchanged.
                for (out, orig) in output.iter().zip(&data) {
                    assert!((out - orig).abs() < 1e-6);
                }
            }
        }
    }

    #[test]
    fn test_gpu_data_type_wgsl_storage_type() {
        assert_eq!(GpuDataType::U8.wgsl_storage_type(), "u32");
        assert_eq!(GpuDataType::F32.wgsl_storage_type(), "f32");
        assert_eq!(GpuDataType::F64Emulated.wgsl_storage_type(), "vec2<f32>");
    }
}