1use std::{vec, collections::{HashMap, HashSet}, ops};
97use wgpu::{util::DeviceExt, BufferUsages};
98use std::vec::Vec;
99use pollster::FutureExt;
100
101struct Buffer {
105 buffer: Option<wgpu::Buffer>,
106 name: String,
107 size: u64,
108 type_: BufferType,
109 type_stride: i32,
110 data: Option<Vec<u8>>,
111 usage: BufferUsage,
112 wgpu_usage: wgpu::BufferUsages,
113 is_output: bool,
114 output_buf: Option<wgpu::Buffer>,
115}
116
/// Largest byte size used for a single buffer by `Device::apply_on_vector`
/// (128 MiB); bigger vectors are split across several buffers.
static MAX_BYTE_BUFFER_SIZE: usize = 128 * 1024 * 1024;
/// Characters that may directly precede a buffer identifier in shader code.
static IDENT_SEPARATOR_CHARS: &str = " \t\n\r([,+-/*=%&|^~!<>{}";
120
/// Element types that can be stored in a buffer.
#[derive(Clone, Hash, PartialEq, Eq)]
pub enum BufferType {
    I32,
    U32,
    F32,
    Bool,
}

impl BufferType {
    /// Returns the type name as spelled in shader code (`"i32"`, `"u32"`, `"f32"`, `"bool"`).
    pub fn to_string(&self) -> String {
        let name = match *self {
            BufferType::I32 => "i32",
            BufferType::U32 => "u32",
            BufferType::F32 => "f32",
            BufferType::Bool => "bool",
        };
        name.to_owned()
    }

    /// Returns the size in bytes of one element: 1 for `Bool`, 4 for every other type.
    pub fn stride(&self) -> i32 {
        if *self == BufferType::Bool {
            1
        } else {
            4
        }
    }
}
144
145
146#[derive(Clone)]
148pub enum BufferUsage {
149 ReadOnly,
150 WriteOnly,
151 ReadWrite
152}
153impl ops::BitOr<BufferUsages> for BufferUsage {
154 type Output = wgpu::BufferUsages;
155 fn bitor(self, rhs: wgpu::BufferUsages) -> Self::Output {
156 buffer_usage(self, false) | rhs
157 }
158}
159
160
161#[inline]
163fn buffer_usage(usage: BufferUsage, output: bool) -> wgpu::BufferUsages {
164 let out;
165 match usage {
166 BufferUsage::ReadOnly => {
167 out = wgpu::BufferUsages::MAP_WRITE
168 | wgpu::BufferUsages::STORAGE
169 },
170 BufferUsage::WriteOnly => {
171 out = wgpu::BufferUsages::STORAGE
172 },
173 BufferUsage::ReadWrite => {
174 out = wgpu::BufferUsages::STORAGE
175 }
176 }
177 if output {
178 out | wgpu::BufferUsages::COPY_SRC
179 }else {
180 out
181 }
182}
183
/// Describes how many shader invocations are launched.
#[derive(Clone)]
pub enum Dispatch {
    /// One invocation per element; the generated shader exposes the element
    /// position as the `index` variable.
    Linear(usize),
    /// Explicit x, y and z dispatch dimensions, passed to wgpu unchanged.
    Custom(u32, u32, u32),
}
192
193pub trait ToVecU8<T> {
196 fn convert(v: &Vec<T>) -> (Vec<u8>, BufferType, i32, usize);
198 fn get_output(v: OutputVec) -> Vec<T>;
199}
200impl ToVecU8<i32> for i32 {
201 fn convert(v: &Vec<i32>) -> (Vec<u8>, BufferType, i32, usize) {
202 (bytemuck::cast_slice::<i32, u8>(&v).to_vec(), BufferType::I32, 4, v.len())
203 }
204 fn get_output(v: OutputVec) -> Vec<i32> {
205 v.unwrap_i32()
206 }
207}
208impl ToVecU8<u32> for u32 {
209 fn convert(v: &Vec<u32>) -> (Vec<u8>, BufferType, i32, usize) {
210 (bytemuck::cast_slice::<u32, u8>(&v).to_vec(), BufferType::U32, 4, v.len())
211 }
212 fn get_output(v: OutputVec) -> Vec<u32> {
213 v.unwrap_u32()
214 }
215}
216impl ToVecU8<f32> for f32 {
217 fn convert(v: &Vec<f32>) -> (Vec<u8>, BufferType, i32, usize) {
218 (bytemuck::cast_slice::<f32, u8>(&v).to_vec(), BufferType::F32, 4, v.len())
219 }
220 fn get_output(v: OutputVec) -> Vec<f32> {
221 v.unwrap_f32()
222 }
223}
224impl ToVecU8<bool> for bool {
225 fn convert(v: &Vec<bool>) -> (Vec<u8>, BufferType, i32, usize) {
226 (v.iter().map(|&e| e as u8).collect::<Vec<_>>(), BufferType::Bool, 1, v.len())
227 }
228 fn get_output(v: OutputVec) -> Vec<bool> {
229 v.unwrap_bool()
230 }
231}
232
/// Content of an output buffer read back from the GPU, tagged with its element type.
#[derive(Debug)]
pub enum OutputVec {
    VecI32(Vec<i32>),
    VecU32(Vec<u32>),
    VecF32(Vec<f32>),
    VecBool(Vec<bool>),
}

impl OutputVec {
    /// Name of the current variant, used in panic messages.
    fn variant_name(&self) -> &'static str {
        match self {
            OutputVec::VecI32(_) => "VecI32",
            OutputVec::VecU32(_) => "VecU32",
            OutputVec::VecF32(_) => "VecF32",
            OutputVec::VecBool(_) => "VecBool",
        }
    }

    /// Panics with a message naming the requested type and the actual variant.
    ///
    /// Only the variant name is printed: the payload can hold millions of
    /// elements and must not be formatted into the message.
    #[cold]
    fn wrong_type(&self, expected: &str) -> ! {
        panic!("value is not a {}!, it's a {}", expected, self.variant_name())
    }

    /// Returns the inner `Vec<i32>`.
    ///
    /// # Panics
    /// Panics if the value is not the `VecI32` variant.
    pub fn unwrap_i32(self) -> Vec<i32> {
        match self {
            OutputVec::VecI32(val) => val,
            other => other.wrong_type("i32"),
        }
    }

    /// Returns the inner `Vec<u32>`.
    ///
    /// # Panics
    /// Panics if the value is not the `VecU32` variant.
    pub fn unwrap_u32(self) -> Vec<u32> {
        match self {
            OutputVec::VecU32(val) => val,
            other => other.wrong_type("u32"),
        }
    }

    /// Returns the inner `Vec<f32>`.
    ///
    /// # Panics
    /// Panics if the value is not the `VecF32` variant.
    pub fn unwrap_f32(self) -> Vec<f32> {
        match self {
            OutputVec::VecF32(val) => val,
            other => other.wrong_type("f32"),
        }
    }

    /// Returns the inner `Vec<bool>`.
    ///
    /// # Panics
    /// Panics if the value is not the `VecBool` variant.
    pub fn unwrap_bool(self) -> Vec<bool> {
        match self {
            OutputVec::VecBool(val) => val,
            other => other.wrong_type("bool"),
        }
    }
}
284
285pub struct ShaderModule {
286 shader: wgpu::ShaderModule,
287 used_buffers_id: Vec<usize>,
288 dispatch: Dispatch
290}
291
/// Removes `// line` and `/* block */` comments from `code`.
///
/// The newline that ends a line comment is kept so line structure is
/// preserved. An unterminated block comment swallows the rest of the input.
/// Block comments are not treated as nested.
pub fn remove_comments(code: String) -> String {
    let mut out = String::with_capacity(code.len());
    let mut chars = code.chars().peekable();
    while let Some(c) = chars.next() {
        if c != '/' {
            out.push(c);
            continue;
        }
        match chars.peek() {
            Some('/') => {
                chars.next();
                // Skip to the end of the line, keeping the line break itself.
                for skipped in chars.by_ref() {
                    if skipped == '\n' {
                        out.push('\n');
                        break;
                    }
                }
            }
            Some('*') => {
                chars.next();
                // Remember whether the previous char was `*` instead of eagerly
                // consuming the char after it, so that `**/` closes the comment.
                let mut prev_star = false;
                for skipped in chars.by_ref() {
                    if prev_star && skipped == '/' {
                        break;
                    }
                    prev_star = skipped == '*';
                }
            }
            // A lone `/` (division): the following char is handled by the next iteration.
            _ => out.push('/'),
        }
    }
    out
}
327
328pub enum Command<'a> {
329 Shader(ShaderModule),
330 Copy(&'a str, &'a str),
331 Retrieve(&'a str)
332}
333
334pub struct Device {
336 adapter: wgpu::Adapter,
337 device: wgpu::Device,
338 queue: wgpu::Queue,
339 buffers: Vec<Buffer>,
340 buf_name_to_id: HashMap<String, usize>,
341 output_buffers_id: Vec<usize>,
342}
343impl Device {
344 pub fn new() -> Device {
346 async {
347 let instance = wgpu::Instance::new(wgpu::Backends::all());
348 let adapter = instance
349 .request_adapter(&wgpu::RequestAdapterOptions {
350 power_preference: wgpu::PowerPreference::HighPerformance,
351 compatible_surface: None,
352 force_fallback_adapter: false,
353 })
354 .await
355 .unwrap();
356 let (device, queue) = adapter
357 .request_device(&Default::default(), None)
358 .await
359 .unwrap();
360 Device {
361 adapter,
362 device,
363 queue,
364 buffers: vec![],
365 buf_name_to_id: HashMap::new(),
366 output_buffers_id: vec![],
367 }
368 }.block_on()
369 }
370
371 #[inline]
373 pub fn get_info(self) -> String {
374 format!("{:?}", self.adapter.get_info())
375 }
376
377 pub fn create_buffer(&mut self, name: &str, data_type: BufferType, size: usize, usage: BufferUsage, is_output: bool) {
387 let data_type_stride = data_type.stride();
388 let byte_size = (data_type_stride as usize) * size;
389 let id = self.buffers.len();
390 if is_output {
393 self.output_buffers_id.push(self.buffers.len())
394 }
395 self.buf_name_to_id.insert(name.to_string(), id);
397 self.buffers.push(Buffer {
399 buffer: None,
400 name: name.to_owned(),
401 size: byte_size as u64,
402 type_: data_type,
403 type_stride: data_type_stride as i32,
404 data: None,
405 usage: usage.clone(),
406 wgpu_usage: buffer_usage(usage, is_output),
407 is_output,
409 output_buf: None
410 });
411 }
412
413 pub fn create_buffer_from<T: ToVecU8<T>>(&mut self, name: &str, content: &Vec<T>, usage: BufferUsage, is_output: bool) {
422 let (raw_content, data_type, data_type_stride, size) = <T as ToVecU8<T>>::convert(content);
423 let byte_size = data_type_stride as u64 * size as u64;
424 let id = self.buffers.len();
425 if is_output {
428 self.output_buffers_id.push(id);
429 }
430 self.buf_name_to_id.insert(name.to_string(), id);
432 self.buffers.push(Buffer {
434 buffer: None,
435 size: byte_size,
436 name: name.to_owned(),
437 type_: data_type,
438 type_stride: data_type_stride,
439 data: Some(raw_content),
440 usage: usage.clone(),
441 wgpu_usage: buffer_usage(usage, is_output),
442 is_output,
444 output_buf: None
445 });
446 }
447
448 pub fn apply_buffer_usages(&mut self, buf_name: &str, usage: wgpu::BufferUsages, is_output: bool) {
452 let buf = &mut self.buffers[*self.buf_name_to_id.get(buf_name).expect("The buffer has not been created on this device (probably wrong name).")];
453 buf.wgpu_usage |= usage;
454 buf.is_output = is_output;
455 }
456
457 pub fn apply_on_vector<T: ToVecU8<T> + std::clone::Clone>(&mut self, vec: Vec<T>, code: &str) -> Vec<T>{
486 if vec.len() == 0 {
487 return vec;
488 }
489 let code = remove_comments(code.to_string());
491 let mut code = code.lines().collect::<Vec<_>>();
492 while code.last().unwrap_or(&"// empty").chars().all(|e| " \t\r\n".contains(e)) {
494 code.pop();
495 }
496 let mut code = code.join("\n");
497
498 let mut used_buffers_id = vec![];
500 for (i, b) in self.buffers.iter().enumerate() {
501 let mut found = false;
502 for c in IDENT_SEPARATOR_CHARS.chars() { let patern = &format!("{c}{}[", b.name);
504 if code.find(patern).is_some() {
505 found = true;
506 }
507 }
508 if found {
510 used_buffers_id.push(i);
511 }
512 }
513
514 self.build_buffers(&used_buffers_id, true);
515
516 let (_, type_, stride, size) = <T as ToVecU8<T>>::convert(&vec);
517 let byte_size = stride as usize *size;
520 let max_size = MAX_BYTE_BUFFER_SIZE / stride as usize;
521 if byte_size <= MAX_BYTE_BUFFER_SIZE {
522 code = code.replace("element", "reservedbuf[index]");
523 let mut other_code = code.lines().collect::<Vec<_>>();
524 other_code.pop();
525 let other_code = other_code.join("\n");
526 let other_code = other_code.as_str();
527 let apply_expr = code.lines().last().unwrap();
528
529 used_buffers_id.push(self.buffers.len());
531 self.create_buffer_from("reservedbuf", &vec, BufferUsage::ReadWrite, true);
533 self.buffers[*self.buf_name_to_id.get("reservedbuf").unwrap()].wgpu_usage |= wgpu::BufferUsages::MAP_READ;
534 self.build_buffers(&used_buffers_id, false);
536
537 let shader_module = self.create_shader_module(Dispatch::Linear(vec.len()), &format!("
539 {other_code}
540 fn main() {{
541 reservedbuf[index] = {apply_expr};
542 }}
543 ", ));
544 self.execute_commands(vec![
545 Command::Shader(shader_module),
546 Command::Retrieve("reservedbuf")
547 ]);
548 let result = self.get_buffer_data(vec!["reservedbuf"]).into_iter().next().unwrap();
549 return <T as ToVecU8<T>>::get_output(result);
550
551 }else {
552 for c in IDENT_SEPARATOR_CHARS.chars() {
554 if code.contains(&(c.to_string()+"index")) {
555 panic!("Cannot use the `index` variable with a vector of more than {} bytes.", max_size * 4)
556 }
557 }
558
559 let t = type_.to_string();
560 let mut other_code = code.lines().collect::<Vec<_>>();
561 other_code.pop();
562 let other_code = other_code.join("\n");
563 let other_code = other_code.as_str();
564
565 let apply_expr = code.lines().last().unwrap();
566 let function_element = format!("fn reservedfn(element: {t}) -> {t} {{return({apply_expr});}}\n");
567
568 let nb_buf = (byte_size as f64 / MAX_BYTE_BUFFER_SIZE as f64).ceil() as u32;
569 let size_last_buf = size % max_size;
570 let mut main_body = vec![];
571
572 for i in 0..(nb_buf-1) {
574 let vec_i = i as usize*max_size;
575 self.create_buffer_from(&format!("reservedbuf{i}"), &vec[vec_i..(vec_i + max_size)].to_vec(), BufferUsage::ReadWrite, true);
576 main_body.push(format!("reservedbuf{i}[index] = reservedfn(reservedbuf{i}[index]);\n"));
577 }
578 if size_last_buf != 0 {
579 self.create_buffer_from(&format!("reservedbuf{}", nb_buf-1), &vec[(vec.len() - size_last_buf)..].to_vec(), BufferUsage::ReadWrite, true);
580 main_body.push(format!("if (index < {size_last_buf}u) {{ reservedbuf{0}[index] = reservedfn(reservedbuf{0}[index]);}}\n", nb_buf-1))
581 }
582
583 let shader_module = self.create_shader_module(Dispatch::Linear(max_size), &format!("
585 {other_code}
586 {function_element}
587 fn main() {{
588 {}
589 }}
590 ", main_body.join("")));
591 let mut commands = vec![Command::Shader(shader_module)];
592 let mut buffer_names = vec![];
593 for i in 0..nb_buf {
594 buffer_names.push(format!("reservedbuf{i}"));
595 }
596 for name in buffer_names.iter() {
597 commands.push(Command::Retrieve(&name));
598 }
599 self.execute_commands(commands);
600 let mut bufs = vec![];
601 for i in 0..nb_buf {
602 bufs.push(format!("reservedbuf{i}"));
603 }
604 let results = self.get_buffer_data(bufs.iter().map(|e| e.as_str()).collect::<Vec<_>>());
605 let mut out = Vec::<T>::new();
606 for result in results {
607 out.append(&mut <T as ToVecU8<T>>::get_output(result));
608 }
609 return out;
610 }
611 }
612
613 #[inline]
616 pub fn create_shader_module(&self, dispatch: Dispatch, code: &str) -> ShaderModule {
617 let dispatch_linear_len;
618 match dispatch {
619 Dispatch::Linear(l) => {
620 dispatch_linear_len = Some(l);
621 }
622 Dispatch::Custom(_, _, _) => {
623 dispatch_linear_len = None;
624 }
625 }
626 let dispatch_linear = dispatch_linear_len.is_some();
627
628 let mut tmp_code = remove_comments(code.to_owned());
629 let mut main_headers = String::from("
631[[stage(compute), workgroup_size(1)]]
632fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {\n");
633 if dispatch_linear {
635 if dispatch_linear_len < Some(65536) {
636 main_headers += "\tlet index: u32 = global_id.x;\n";
637 }else {
638 main_headers += &format!("\tlet index: u32 = global_id.x + global_id.y * 65535u;\nif (index >= {}u) {{return;}}\n", dispatch_linear_len.unwrap());
639
640 }
641 }
642
643 let mut used_buffers_id = vec![];
645 for (i, b) in self.buffers.iter().enumerate() {
647 let mut found = false;
648 for c in IDENT_SEPARATOR_CHARS.chars() { let patern = &format!("{c}{}[", b.name);
650 if tmp_code.find(patern).is_some() {
651 found = true;
652 }
653 tmp_code = tmp_code.replace(patern, &format!("{c}{}.d[", b.name));
654 }
655
656 if found {
658 used_buffers_id.push(i);
659 }
660 }
661
662 let mut structs = vec![];
663 let mut struct_types = HashMap::new();
664 let mut bindings = vec![];
665 let mut used_buffers_id_i = 0;
666 let mut binding_i = 0;
667 for (i, b) in self.buffers.iter().enumerate() {
668 while used_buffers_id_i < used_buffers_id.len()-1 && used_buffers_id[used_buffers_id_i] < i {
669 used_buffers_id_i += 1;
670 }
671 if used_buffers_id[used_buffers_id_i] != i {
672 continue;
673 }
674 if b.type_stride != -1 && !struct_types.contains_key(&b.type_){
677 structs.push(format!("struct reserved{i} {{\n\td: [[stride({})]] array<{}>;\n}};\n",b.type_stride, b.type_.to_string()));
678 struct_types.insert(b.type_.clone(), i);
679 }
680
681 bindings.push(format!(
683 "[[group(0), binding({binding_i})]] \n var<storage, {}> {}: reserved{};\n",
684 match b.usage {
685 BufferUsage::ReadOnly => {"read".to_string()},
686 BufferUsage::WriteOnly => {"write".to_string()},
687 BufferUsage::ReadWrite => {"read_write".to_string()}
688 },
689 b.name,
690 struct_types.get(&b.type_).unwrap()
691 ));
692
693 binding_i += 1;
694 }
695
696 tmp_code = format!("{}{}{}",structs.join(""), bindings.join(""), tmp_code.replace("fn main() {\n", &main_headers));
698
699 ShaderModule {
700 shader: self.device.create_shader_module(&wgpu::ShaderModuleDescriptor {
701 label: None,
702 source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(&tmp_code))
703 }),
704 used_buffers_id,
706 dispatch
707 }
708 }
709
710 fn build_buffers(&mut self, indices: &Vec<usize>, build_output: bool) { let mut bufs_to_build = vec![];
713 let mut indices_i = 0;
714 if indices.len() == 0 {
715 for b in self.buffers.iter_mut() {
716 if b.is_output { bufs_to_build.push(b);
718 }
719 }
720 }else if build_output {
721 for (i, b) in self.buffers.iter_mut().enumerate() {
722 while indices_i < indices.len()-1 && indices[indices_i] < i {
723 indices_i += 1;
724 }
725 if b.is_output || i == indices[indices_i] { bufs_to_build.push(b);
727 }
728 }
729 }else {
730 for (i, b) in self.buffers.iter_mut().enumerate() {
731 while indices_i < indices.len()-1 && indices[indices_i] < i {
732 indices_i += 1;
733 }
734 if i == indices[indices_i] { bufs_to_build.push(b);
736 }
737 }
738 }
739
740
741 for buf in bufs_to_build.iter_mut() {
742 if buf.buffer.is_some(){
744 continue;
745 }
746 match &buf.data {
747 Some(data) => {
748 buf.buffer = Some(self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
749 label: Some(&buf.name),
750 contents: &data,
751 usage: buf.wgpu_usage,
752 }));
753 }
754 None => {
755 buf.buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
756 label: Some(&buf.name),
757 size: buf.size,
758 usage: buf.wgpu_usage,
759 mapped_at_creation: false,
760 }))
761 }
762 }
763 if buf.is_output {
764 buf.output_buf = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
765 label: Some(&buf.name),
766 size: buf.size,
767 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
768 mapped_at_creation: false,
769 }))
770 }
771 }
772 }
773
774 #[inline]
792 pub fn execute_shader_module(&mut self, shader_module: &ShaderModule) -> Vec<OutputVec> {
793 if self.buffers.len() == 0 {
794 panic!("In function `Device::execute_shader_module` : Cannot execute a shader if no buffer has been created.
795 You can create a buffer using the `Device::create_buffer` and `Device::create_buffer_from` functions");
796 }
797
798 self.build_buffers(&shader_module.used_buffers_id, false);
799
800 let mut bind_group_entries = vec![];
801 let mut i = 0;
802 for id in shader_module.used_buffers_id.iter() {
803 bind_group_entries.push(wgpu::BindGroupEntry{
804 binding: i as u32,
805 resource: self.buffers[*id].buffer.as_ref().unwrap().as_entire_binding()
806 });
807 i += 1;
808 }
809
810 let compute_pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
811 label: None,
812 layout: None,
813 module: &shader_module.shader,
814 entry_point: "main",
815 });
816
817 let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
818 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
819 label: None,
820 layout: &bind_group_layout,
821 entries: &bind_group_entries
822 });
823
824 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
825 {
826 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
827 cpass.set_pipeline(&compute_pipeline);
828 cpass.set_bind_group(0, &bind_group, &[]);
829 match shader_module.dispatch {
830 Dispatch::Linear(val) => {
831 if val < 65536 {
832 cpass.dispatch(val as u32, 1, 1);
833 }else {
834 cpass.dispatch(65535, (val as f64 / 65535f64).ceil() as u32, 1);
835 }
836 }
837 Dispatch::Custom(x, y, z) => {
838 cpass.dispatch(x, y, z);
839 }
840 }
841 }
842 for i in shader_module.used_buffers_id.iter() {
844 let b = &self.buffers[*i];
845 if b.is_output {
846 encoder.copy_buffer_to_buffer(&b.buffer.as_ref().unwrap(), 0, &b.output_buf.as_ref().unwrap(), 0, b.size);
847 }
848 }
849 self.queue.submit(Some(encoder.finish()));
850
851 self.retrieve_buffer_data(&shader_module.used_buffers_id)
852 }
853
854 pub fn execute_shader_code(&mut self, dispatch: Dispatch, code: &str) -> Vec<OutputVec>{
872 let shader_module = self.create_shader_module(dispatch.clone(), code);
873 self.execute_shader_module(&shader_module)
874 }
875
876 pub fn execute_commands(&mut self, commands: Vec<Command>) {
878 let mut shader_count = 0;
879 let mut shader_index = None;
880 for (i, c) in commands.iter().enumerate() {
881 if let Command::Shader(_) = c {
882 shader_count += 1;
883 shader_index = Some(i);
884 }
885 }
886 if shader_count > 1 {
887 panic!("In function `Device::execute_commands` : There should only be 1 shader in the commands, got {shader_count}");
888 }
889
890 let mut copy_buffers = HashSet::new();
892 for c in commands.iter() {
893 if let Command::Copy(from, to) = c {
894 let buf_id1 = self.buf_name_to_id.get(&from.to_string()).expect("The source buffer has not been created on this device (probably wrong name).");
895 self.buffers[*buf_id1].wgpu_usage |= wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::MAP_READ;
896 let buf_id2 = self.buf_name_to_id.get(&to.to_string()).expect("The destination buffer has not been created on this device (probably wrong name).");
897 self.buffers[*buf_id2].wgpu_usage |= wgpu::BufferUsages::COPY_DST;
898 copy_buffers.insert(*buf_id1);
899 copy_buffers.insert(*buf_id2);
900 }
901 }
902
903 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
904 if shader_count == 0 { for c in commands.iter() {
906 if let Command::Copy(from, to) = c {
907 let buf1 = &self.buffers[*self.buf_name_to_id.get(&from.to_string()).unwrap()]; let buf2 = &self.buffers[*self.buf_name_to_id.get(&to.to_string()).unwrap()].buffer; encoder.copy_buffer_to_buffer(&buf1.buffer.as_ref().unwrap(), 0, &buf2.as_ref().unwrap(), 0, buf1.size);
910 }
911 }
912 }else { let shader_index = shader_index.unwrap();
914 for c in commands.iter().take(shader_index) {
915 if let Command::Copy(from, to) = c {
916 let buf1 = &self.buffers[*self.buf_name_to_id.get(&from.to_string()).unwrap()];
917 let buf2 = &self.buffers[*self.buf_name_to_id.get(&to.to_string()).unwrap()].buffer;
918 encoder.copy_buffer_to_buffer(&buf1.buffer.as_ref().unwrap(), 0, &buf2.as_ref().unwrap(), 0, buf1.size);
919 }
920 }
921
922 let shader_module =
923 (if let Command::Shader(sm) = &commands[shader_index] {
924 Some(sm)
925 }else {
926 None
927 }).unwrap();
928
929 let used_buffers_shader_and_copy = &shader_module.used_buffers_id.iter().map(|e| *e).chain(copy_buffers.into_iter()).collect::<Vec<_>>();
930 self.build_buffers(used_buffers_shader_and_copy, true);
931
932 let mut bind_group_entries = vec![];
933 let mut i = 0;
934 for id in shader_module.used_buffers_id.iter() {
935 bind_group_entries.push(wgpu::BindGroupEntry{
936 binding: i as u32,
937 resource: self.buffers[*id].buffer.as_ref().unwrap().as_entire_binding()
938 });
939 i += 1;
940 }
941
942 let compute_pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
943 label: None,
944 layout: None,
945 module: &shader_module.shader,
946 entry_point: "main",
947 });
948
949 let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
950 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
951 label: None,
952 layout: &bind_group_layout,
953 entries: &bind_group_entries
954 });
955
956 {
957 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("execute commands")});
958 cpass.set_pipeline(&compute_pipeline);
959 cpass.set_bind_group(0, &bind_group, &[]);
960 match shader_module.dispatch {
961 Dispatch::Linear(val) => {
962 if val < 65536 {
963 cpass.dispatch(val as u32, 1, 1);
964 }else {
965 cpass.dispatch(65535, (val as f64 / 65535f64).ceil() as u32, 1);
966 }
967 }
968 Dispatch::Custom(x, y, z) => {
969 cpass.dispatch(x, y, z);
970 }
971 }
972 }
973
974 for c in commands.iter().skip(shader_index+1) {
976 if let Command::Copy(from, to) = c {
977 let buf1 = &self.buffers[*self.buf_name_to_id.get(&from.to_string()).unwrap()];
978 let buf2 = &self.buffers[*self.buf_name_to_id.get(&to.to_string()).unwrap()].buffer;
979 encoder.copy_buffer_to_buffer(&buf1.buffer.as_ref().unwrap(), 0, &buf2.as_ref().unwrap(), 0, buf1.size);
980 }
981 if let Command::Retrieve(buf) = c {
982 let buf1 = &self.buffers[*self.buf_name_to_id.get(&buf.to_string()).unwrap()];
983 encoder.copy_buffer_to_buffer(&buf1.buffer.as_ref().unwrap(), 0, &buf1.output_buf.as_ref().unwrap(), 0, buf1.size);
984 }
985 }
986
987 self.queue.submit(Some(encoder.finish()));
988
989
990 }
991 }
992
993 #[inline]
994 fn retrieve_buffer_data(&self, buffers_id: &Vec<usize>) -> Vec<OutputVec> {
995 let mut output_buffers_id = vec![];
996 let mut buffer_slices = vec![];
997 let mut result_futures = vec![];
998 for i in buffers_id.iter() {
999 let b = &self.buffers[*i];
1000 if !b.is_output {
1001 continue;
1002 }
1003 let out_b = b.output_buf.as_ref().unwrap();
1004 let buffer_slice = out_b.slice(..);
1006 let buffer_future = buffer_slice.map_async(wgpu::MapMode::Read);
1007 buffer_slices.push(buffer_slice);
1008 result_futures.push(buffer_future);
1009 output_buffers_id.push(*i);
1010 }
1011
1012 self.device.poll(wgpu::Maintain::Wait);
1014
1015
1016 let mut results = Vec::with_capacity(self.output_buffers_id.len());
1017 let mut buffer_index = 0;
1019
1020 async {
1021 for (i, result_future) in result_futures.into_iter().enumerate() {
1022 let buf = &self.buffers[output_buffers_id[i]];
1023 buffer_index += 1;
1024 if let Ok(()) = result_future.await {
1025 let data = buffer_slices[i].get_mapped_range();
1027 let result: OutputVec;
1029 match buf.type_ {
1030 BufferType::I32 => {
1031 result = OutputVec::VecI32(bytemuck::cast_slice::<u8, i32>(&data).to_vec());
1032 },
1033 BufferType::U32 => {
1034 result = OutputVec::VecU32(bytemuck::cast_slice::<u8, u32>(&data).to_vec());
1035 },
1036 BufferType::F32 => {
1037 result = OutputVec::VecF32(bytemuck::cast_slice::<u8, f32>(&data).to_vec());
1038 },
1039 BufferType::Bool => {
1040 result = OutputVec::VecBool(data.iter().map(|&e| e != 0).collect());
1041 }
1042 }
1043
1044 drop(data);
1046 buf.output_buf.as_ref().unwrap().unmap(); results.push(result);
1050 }else {
1051 panic!("computations failed");
1052 }
1053 }
1054 results
1055 }.block_on()
1056 }
1057
1058 pub fn get_buffer_data(&self, buffer_names: Vec<&str>) -> Vec<OutputVec> {
1060 let buffers_id = buffer_names.iter().map(|&e| *self.buf_name_to_id.get(e)
1061 .expect("The buffer has not been created on this device (probably wrong name)"))
1062 .collect::<Vec<_>>();
1063 self.retrieve_buffer_data(&buffers_id)
1064 }
1065}
1066
1067pub mod examples {
1068 use crate::{Device, BufferUsage, Dispatch, BufferType, Command};
1069
1070 pub fn simplest_apply() {
1072 let mut device = Device::new();
1073 let v1 = vec![1.0f32, 2.0, 3.0];
1074 let v1 = device.apply_on_vector(v1, "element * 2.0");
1076 println!("{v1:?}");
1077 }
1078
1079 pub fn apply_with_buf() {
1081 let mut device = Device::new();
1082 let v1 = vec![2.0f32, 3.0, 5.0, 7.0, 11.0];
1083 let exponent = vec![3.0];
1084 device.create_buffer_from("exponent", &exponent, BufferUsage::ReadOnly, false);
1085 let cubes = device.apply_on_vector(v1, "pow(element, exponent[0u])");
1086 println!("{cubes:?}")
1087 }
1088
1089 pub fn with_execute_shader() {
1091 let mut device = Device::new();
1092 let v1 = vec![1i32, 2, 3, 4, 5, 6];
1093 device.create_buffer_from("v1", &v1, BufferUsage::ReadOnly, false);
1094 device.create_buffer("output", BufferType::I32, v1.len(), BufferUsage::WriteOnly, true);
1095 let result = device.execute_shader_code(Dispatch::Linear(v1.len()), r"
1096 fn main() {
1097 output[index] = v1[index] * 2;
1098 }
1099 ").into_iter().next().unwrap().unwrap_i32();
1100 assert_eq!(result, vec![2, 4, 6, 8, 10, 12]);
1101 }
1102
1103 pub fn multiple_output_buffers() {
1105 let mut device = Device::new();
1106 let v = vec![1u32, 2, 3];
1107 let v2 = vec![3u32, 4, 5];
1108 let v3 = vec![7u32, 8, 9];
1109 device.create_buffer_from(
1111 "buf",
1112 &v,
1113 BufferUsage::ReadWrite,
1114 true
1115 );
1116 device.create_buffer_from(
1117 "buf2",
1118 &v2,
1119 BufferUsage::ReadOnly,
1120 false
1121 );
1122 device.create_buffer_from(
1124 "buf3",
1125 &v3,
1126 BufferUsage::ReadWrite,
1127 true
1128 );
1129 let mut result = device.execute_shader_code(Dispatch::Linear(v.len()), r"
1130 fn main() {
1131 buf[index] = buf[index] + buf2[index] + buf3[index];
1132 buf3[index] = buf[index] * buf2[index] * buf3[index];
1133 }
1134 ").into_iter();
1135
1136 let sum = result.next().unwrap().unwrap_u32();
1137 let product = result.next().unwrap().unwrap_u32();
1138 println!("{:?}", sum);
1139 println!("{:?}", product);
1140 }
1141
1142 pub fn global_id() {
1144 let mut device = Device::new();
1145 let vec = vec![2u32, 3, 5, 7, 11, 13, 17];
1146 device.create_buffer_from("vec1", &vec, BufferUsage::ReadWrite, true);
1148 let result = device.execute_shader_code(Dispatch::Custom(1, vec.len() as u32, 1), r"
1149 fn main() {
1150 vec1[global_id.y] = vec1[global_id.y] + global_id.x + global_id.z;
1151 }
1152 ").into_iter().next().unwrap().unwrap_u32();
1153 assert_eq!(result, vec![2u32, 3, 5, 7, 11, 13, 17]);
1156 }
1157
1158 pub fn shader_two_steps() {
1160 let mut device = Device::new();
1161 let v = vec![1u32, 2, 3, 4];
1162 device.create_buffer_from("buf1", &v, BufferUsage::ReadWrite, true);
1163 let shader_module = device.create_shader_module(Dispatch::Linear(v.len()), "
1164 fn main() {
1165 buf1[index] = buf1[index] * 17u;
1166 }
1167 ");
1168 let result = device.execute_shader_module(&shader_module).into_iter().next().unwrap().unwrap_u32();
1169 assert_eq!(result, vec![17u32, 34, 51, 68]);
1170 }
1171
1172 pub fn complete_pipeline() {
1174 let mut device = Device::new();
1175 let v1 = vec![1u32, 2, 3, 4, 5];
1176 device.create_buffer_from("v1", &v1, BufferUsage::ReadWrite, true);
1177 let shader_module = device.create_shader_module(Dispatch::Linear(v1.len()), "
1178 fn main() {
1179 v1[index] = v1[index] * 2u;
1180 }
1181 ");
1182 let mut commands = vec![];
1183 commands.push(Command::Shader(shader_module));
1184 commands.push(Command::Retrieve("v1"));
1185 device.execute_commands(commands);
1186 let result = device.get_buffer_data(vec!["v1"]).into_iter().next().unwrap().unwrap_u32();
1187 assert_eq!(result, vec![2u32, 4, 6, 8, 10]);
1188 }
1189
1190 pub fn reusing_device() {
1192 let mut device = Device::new();
1193 let v1 = vec![1i32, 2, 3, 4, 5, 6];
1194 device.create_buffer_from("v1", &v1, BufferUsage::ReadOnly, false);
1195 device.create_buffer("output", BufferType::I32, v1.len(), BufferUsage::WriteOnly, true);
1196 let result = device.execute_shader_code(Dispatch::Linear(v1.len()), r"
1197 fn main() {
1198 output[index] = v1[index] * 2;
1199 }
1200 ").into_iter().next().unwrap().unwrap_i32();
1201 assert_eq!(result, vec![2i32, 4, 6, 8, 10, 12]);
1202
1203 let result2 = device.execute_shader_code(Dispatch::Linear(v1.len()), r"
1204 fn main() {
1205 output[index] = v1[index] * 10;
1206 }
1207 ").into_iter().next().unwrap().unwrap_i32();
1208 assert_eq!(result2, vec![10, 20, 30, 40, 50, 60]);
1209 }
1210
1211 #[test]
1213 pub fn big_computations() {
1214 let mut device = Device::new();
1215 let size = 33_554_432;
1216 device.create_buffer(
1217 "buf",
1218 BufferType::U32,
1219 size,
1220 BufferUsage::WriteOnly,
1221 true);
1222 let result = device.execute_shader_code(Dispatch::Linear(size), r"
1223 fn number_of_seven_in_digit_product(number: u32) -> u32 {
1224 var p: u32 = 1u;
1225 var n: u32 = number;
1226 loop {
1227 if (n == 0u) {break;}
1228 p = p * (n % 10u);
1229 n = n / 10u;
1230 }
1231 var nb_seven: u32 = 0u;
1232 loop {
1233 if (p == 0u) {break;}
1234 if (p % 10u == 7u) {
1235 nb_seven = nb_seven + 1u;
1236 }
1237 p = p / 10u;
1238 }
1239 return nb_seven;
1240 }
1241 fn main() {
1242 buf[index] = number_of_seven_in_digit_product(index);
1243 }
1244 ").into_iter().next().unwrap().unwrap_u32();
1245 let mut index = 0;
1246 let mut max = result[0];
1247 for (i, e) in result.iter().enumerate() {
1248 if e > &max {
1249 max = *e;
1250 index = i;
1251 }
1252 }
1253 println!("The number who's digit product got the most seven below {size} is {index} with {max} sevens.");
1254 }
1255}