arr_o_gpu/arr_o_gpu/module/method/function/broadcasting/
broadcast.rs1use std::sync::Arc;
2
3use wgpu::{
4 ComputePassDescriptor,
5 ComputePipelineDescriptor,
6 PipelineCompilationOptions,
7 PipelineLayoutDescriptor,
8 ShaderModuleDescriptor,
9 ShaderSource,
10 wgt::{ CommandEncoderDescriptor, PollType },
11};
12
13use crate::{ ArrOgpuErr, ArrOgpuModule, GpuArray, broadcast_bind_group, get_stride_from_shape };
14
15impl ArrOgpuModule {
16 pub fn broadcasting(&self, arr: &GpuArray, broadcast: &[u32]) -> Result<GpuArray, ArrOgpuErr> {
17 let arr_shape = arr.shape();
18 if arr_shape.len() > broadcast.len() {
19 let err = format!(
20 "Array Broadcasting Error, Array {:?} can't Broadcast to {:?}",
21 arr_shape,
22 broadcast
23 );
24
25 return Err(ArrOgpuErr::Broadcast(err));
26 }
27
28 let diff_range = broadcast.len() - arr_shape.len();
30 let extend_arr_shape = if diff_range != 0 {
31 let mut extend = vec![1;diff_range];
32 extend.extend_from_slice(arr_shape);
33 extend
34 } else {
35 arr_shape.clone()
36 };
37
38 let mut broadcast_index = None;
40 for i in (0..broadcast.len()).rev() {
41 let arr_dim = extend_arr_shape[i];
42 let broadcast_dim = broadcast[i];
43
44 if arr_dim != broadcast_dim {
45 if arr_dim == 1 {
46 if let None = broadcast_index {
47 broadcast_index = Some(i);
48 } else {
49 let err = format!(
50 "Array Broadcasting Error, Array {:?} can't Broadcast to {:?} Cause Broadcast can't over then one target",
51 arr_shape,
52 broadcast
53 );
54
55 return Err(ArrOgpuErr::Broadcast(err));
56 }
57 } else {
58 let err = format!(
59 "Array Broadcasting Error, Array {:?} can't Broadcast to {:?}",
60 arr_shape,
61 broadcast
62 );
63
64 return Err(ArrOgpuErr::Broadcast(err));
65 }
66 }
67 }
68
69 let wgpu_init = self.wgpu_init.read().unwrap();
70 let allocator = self.allocator_write();
71 let heap_binding = &self.binding_compounds.read().unwrap()[0];
73 let (bind_group_layout, bind_group, thread_limit, stride_out, output_allocate) =
74 broadcast_bind_group(
75 &wgpu_init.device,
76 arr,
77 allocator,
78 broadcast,
79 broadcast_index.unwrap()
80 );
81
82 let pipeline_layout = wgpu_init.device.create_pipeline_layout(
84 &(PipelineLayoutDescriptor {
85 label: Some("Create Pipeline Layout For Broadcast"),
86 bind_group_layouts: &[&heap_binding.binding_group_layouts, &bind_group_layout],
87 push_constant_ranges: &[],
88 })
89 );
90
91 let shader = wgpu_init.device.create_shader_module(ShaderModuleDescriptor {
92 label: Some("Create Shaders For Broadcast"),
93 source: ShaderSource::Wgsl(include_str!("./broadcast.wgsl").into()),
94 });
95
96 let pipeline = wgpu_init.device.create_compute_pipeline(
97 &(ComputePipelineDescriptor {
98 label: Some("()"),
99 cache: None,
100 compilation_options: PipelineCompilationOptions::default(),
101 entry_point: Some("main"),
102 layout: Some(&pipeline_layout),
103 module: &shader,
104 })
105 );
106
107 let mut encoder = wgpu_init.device.create_command_encoder(
108 &(CommandEncoderDescriptor {
109 label: Some("Create Encoder For Broadcast"),
110 })
111 );
112
113 {
114 let mut bcp = encoder.begin_compute_pass(
115 &(ComputePassDescriptor {
116 label: Some("Create Compute Pass For Broadcast"),
117 timestamp_writes: None,
118 })
119 );
120
121 bcp.set_pipeline(&pipeline);
123
124 bcp.set_bind_group(0, Some(&heap_binding.binding_groups), &[]);
127 bcp.set_bind_group(1, Some(&bind_group), &[]);
129
130 let x = ((thread_limit as f32) / 16.0).ceil() as u32;
132 let y = ((stride_out as f32) / 16.0).ceil() as u32;
133 bcp.dispatch_workgroups(x, y, 1);
134 }
135 wgpu_init.queue.submit(Some(encoder.finish()));
136 wgpu_init.device.poll(PollType::Wait).unwrap();
137
138 let out_shape = broadcast.to_vec();
139 let stride = get_stride_from_shape(&out_shape);
140 let len = out_shape.iter().product::<u32>() as usize;
141 let arr = GpuArray {
142 length: len,
143 module: Arc::new(self.clone()),
144 shape: out_shape,
145 stride,
146 pointer: (output_allocate.1 as usize, output_allocate.2 as usize),
147 space_type: output_allocate.0,
148 };
149 Ok(arr)
150 }
151}