arr_o_gpu/arr_o_gpu/module/method/function/broadcasting/
broadcast.rs

1use std::sync::Arc;
2
3use wgpu::{
4    ComputePassDescriptor,
5    ComputePipelineDescriptor,
6    PipelineCompilationOptions,
7    PipelineLayoutDescriptor,
8    ShaderModuleDescriptor,
9    ShaderSource,
10    wgt::{ CommandEncoderDescriptor, PollType },
11};
12
13use crate::{ ArrOgpuErr, ArrOgpuModule, GpuArray, broadcast_bind_group, get_stride_from_shape };
14
15impl ArrOgpuModule {
16    pub fn broadcasting(&self, arr: &GpuArray, broadcast: &[u32]) -> Result<GpuArray, ArrOgpuErr> {
17        let arr_shape = arr.shape();
18        if arr_shape.len() > broadcast.len() {
19            let err = format!(
20                "Array Broadcasting Error, Array {:?} can't Broadcast to {:?}",
21                arr_shape,
22                broadcast
23            );
24
25            return Err(ArrOgpuErr::Broadcast(err));
26        }
27
28        // expend shape
29        let diff_range = broadcast.len() - arr_shape.len();
30        let extend_arr_shape = if diff_range != 0 {
31            let mut extend = vec![1;diff_range];
32            extend.extend_from_slice(arr_shape);
33            extend
34        } else {
35            arr_shape.clone()
36        };
37
38        // validation every dim
39        let mut broadcast_index = None;
40        for i in (0..broadcast.len()).rev() {
41            let arr_dim = extend_arr_shape[i];
42            let broadcast_dim = broadcast[i];
43
44            if arr_dim != broadcast_dim {
45                if arr_dim == 1 {
46                    if let None = broadcast_index {
47                        broadcast_index = Some(i);
48                    } else {
49                        let err = format!(
50                            "Array Broadcasting Error, Array {:?} can't Broadcast to {:?} Cause Broadcast can't over then one target",
51                            arr_shape,
52                            broadcast
53                        );
54
55                        return Err(ArrOgpuErr::Broadcast(err));
56                    }
57                } else {
58                    let err = format!(
59                        "Array Broadcasting Error, Array {:?} can't Broadcast to {:?}",
60                        arr_shape,
61                        broadcast
62                    );
63
64                    return Err(ArrOgpuErr::Broadcast(err));
65                }
66            }
67        }
68
69        let wgpu_init = self.wgpu_init.read().unwrap();
70        let allocator = self.allocator_write();
71        // binding
72        let heap_binding = &self.binding_compounds.read().unwrap()[0];
73        let (bind_group_layout, bind_group, thread_limit, stride_out, output_allocate) =
74            broadcast_bind_group(
75                &wgpu_init.device,
76                arr,
77                allocator,
78                broadcast,
79                broadcast_index.unwrap()
80            );
81
82        // pipeline
83        let pipeline_layout = wgpu_init.device.create_pipeline_layout(
84            &(PipelineLayoutDescriptor {
85                label: Some("Create Pipeline Layout For Broadcast"),
86                bind_group_layouts: &[&heap_binding.binding_group_layouts, &bind_group_layout],
87                push_constant_ranges: &[],
88            })
89        );
90
91        let shader = wgpu_init.device.create_shader_module(ShaderModuleDescriptor {
92            label: Some("Create Shaders For Broadcast"),
93            source: ShaderSource::Wgsl(include_str!("./broadcast.wgsl").into()),
94        });
95
96        let pipeline = wgpu_init.device.create_compute_pipeline(
97            &(ComputePipelineDescriptor {
98                label: Some("()"),
99                cache: None,
100                compilation_options: PipelineCompilationOptions::default(),
101                entry_point: Some("main"),
102                layout: Some(&pipeline_layout),
103                module: &shader,
104            })
105        );
106
107        let mut encoder = wgpu_init.device.create_command_encoder(
108            &(CommandEncoderDescriptor {
109                label: Some("Create Encoder For Broadcast"),
110            })
111        );
112
113        {
114            let mut bcp = encoder.begin_compute_pass(
115                &(ComputePassDescriptor {
116                    label: Some("Create Compute Pass For Broadcast"),
117                    timestamp_writes: None,
118                })
119            );
120
121            // pipeline
122            bcp.set_pipeline(&pipeline);
123
124            // group
125            // // 0
126            bcp.set_bind_group(0, Some(&heap_binding.binding_groups), &[]);
127            // // 1
128            bcp.set_bind_group(1, Some(&bind_group), &[]);
129
130            // dispatch workgroup
131            let x = ((thread_limit as f32) / 16.0).ceil() as u32;
132            let y = ((stride_out as f32) / 16.0).ceil() as u32;
133            bcp.dispatch_workgroups(x, y, 1);
134        }
135        wgpu_init.queue.submit(Some(encoder.finish()));
136        wgpu_init.device.poll(PollType::Wait).unwrap();
137
138        let out_shape = broadcast.to_vec();
139        let stride = get_stride_from_shape(&out_shape);
140        let len = out_shape.iter().product::<u32>() as usize;
141        let arr = GpuArray {
142            length: len,
143            module: Arc::new(self.clone()),
144            shape: out_shape,
145            stride,
146            pointer: (output_allocate.1 as usize, output_allocate.2 as usize),
147            space_type: output_allocate.0,
148        };
149        Ok(arr)
150    }
151}