1use std::{vec, collections::{HashMap, HashSet}, ops};
97use wgpu::{util::DeviceExt, BufferUsages};
98use std::vec::Vec;
99use pollster::FutureExt;
100
101struct Buffer {
105    buffer: Option<wgpu::Buffer>,
106    name: String,
107    size: u64,
108    type_: BufferType,
109    type_stride: i32,
110    data: Option<Vec<u8>>,
111    usage: BufferUsage,
112    wgpu_usage: wgpu::BufferUsages,
113    is_output: bool,
114    output_buf: Option<wgpu::Buffer>,
115}
116
/// Largest buffer, in bytes, that `apply_on_vector` processes in a single storage
/// buffer (128 MiB); larger vectors are split into several buffers.
const MAX_BYTE_BUFFER_SIZE: usize = 128 * 1024 * 1024;
/// Characters that may precede a buffer name in WGSL code. They are used to find
/// `name[` occurrences without matching names that merely end with `name`.
const IDENT_SEPARATOR_CHARS: &str = " \t\n\r([,+-/*=%&|^~!<>{}";
120
/// Element type stored in a buffer, named after the matching WGSL scalar type.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum BufferType {
    I32,
    U32,
    F32,
    Bool,
}

impl BufferType {
    /// Size in bytes of one element on the host side.
    ///
    /// `Bool` elements are serialised as one byte each; every other type uses 4.
    pub fn stride(&self) -> i32 {
        match self {
            BufferType::Bool => 1,
            BufferType::I32 | BufferType::U32 | BufferType::F32 => 4,
        }
    }
}

/// Formats the type as its WGSL scalar name (`i32`, `u32`, `f32` or `bool`).
///
/// `BufferType::to_string()` is provided by this impl through the standard
/// `ToString` blanket implementation.
impl std::fmt::Display for BufferType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            BufferType::I32 => "i32",
            BufferType::U32 => "u32",
            BufferType::F32 => "f32",
            BufferType::Bool => "bool",
        })
    }
}
144
145
/// Access mode of a buffer inside generated shaders.
///
/// It maps to the WGSL storage access mode of the buffer's binding:
/// `ReadOnly` to `read`, `WriteOnly` to `write`, and `ReadWrite` to `read_write`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferUsage {
    ReadOnly,
    WriteOnly,
    ReadWrite,
}
153impl ops::BitOr<BufferUsages> for BufferUsage {
154    type Output = wgpu::BufferUsages;
155    fn bitor(self, rhs: wgpu::BufferUsages) -> Self::Output {
156        buffer_usage(self, false) | rhs
157    }
158}
159
160
161#[inline]
163fn buffer_usage(usage: BufferUsage, output: bool) -> wgpu::BufferUsages {
164    let out;
165    match usage {
166        BufferUsage::ReadOnly => {
167            out = wgpu::BufferUsages::MAP_WRITE
168            | wgpu::BufferUsages::STORAGE
169        },
170        BufferUsage::WriteOnly => {
171            out = wgpu::BufferUsages::STORAGE
172        },
173        BufferUsage::ReadWrite => {
174            out = wgpu::BufferUsages::STORAGE
175        }
176    }
177    if output {
178        out | wgpu::BufferUsages::COPY_SRC
179    }else {
180        out
181    }
182}
183
/// How a compute shader is dispatched.
///
/// * `Linear(n)`: `n` invocations, each exposed to the shader as a linear `index`.
///   Above 65535 invocations the work is spread over the `y` dimension.
/// * `Custom(x, y, z)`: workgroup counts passed to the dispatch call unchanged.
///   No `index` variable is generated in this case.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Dispatch {
    Linear(usize),
    Custom(u32, u32, u32),
}
192
193pub trait ToVecU8<T> {
196    fn convert(v: &Vec<T>) -> (Vec<u8>, BufferType, i32, usize);
198    fn get_output(v: OutputVec) -> Vec<T>;
199}
200impl ToVecU8<i32> for i32 {
201    fn convert(v: &Vec<i32>) -> (Vec<u8>, BufferType, i32, usize) {
202        (bytemuck::cast_slice::<i32, u8>(&v).to_vec(), BufferType::I32, 4, v.len())
203    }
204    fn get_output(v: OutputVec) -> Vec<i32> {
205        v.unwrap_i32()
206    }
207}
208impl ToVecU8<u32> for u32 {
209    fn convert(v: &Vec<u32>) -> (Vec<u8>, BufferType, i32, usize) {
210        (bytemuck::cast_slice::<u32, u8>(&v).to_vec(), BufferType::U32, 4, v.len())
211    }
212    fn get_output(v: OutputVec) -> Vec<u32> {
213        v.unwrap_u32()
214    }
215}
216impl ToVecU8<f32> for f32 {
217    fn convert(v: &Vec<f32>) -> (Vec<u8>, BufferType, i32, usize) {
218        (bytemuck::cast_slice::<f32, u8>(&v).to_vec(), BufferType::F32, 4, v.len())
219    }
220    fn get_output(v: OutputVec) -> Vec<f32> {
221        v.unwrap_f32()
222    }
223}
224impl ToVecU8<bool> for bool {
225    fn convert(v: &Vec<bool>) -> (Vec<u8>, BufferType, i32, usize) {
226        (v.iter().map(|&e| e as u8).collect::<Vec<_>>(), BufferType::Bool, 1, v.len())
227    }
228    fn get_output(v: OutputVec) -> Vec<bool> {
229        v.unwrap_bool()
230    }
231}
232
/// Data read back from an output buffer, tagged with its element type.
#[derive(Debug, Clone, PartialEq)]
pub enum OutputVec {
    VecI32(Vec<i32>),
    VecU32(Vec<u32>),
    VecF32(Vec<f32>),
    VecBool(Vec<bool>),
}

impl OutputVec {
    /// Returns the contained `i32` data.
    ///
    /// # Panics
    ///
    /// Panics if `self` holds another element type.
    pub fn unwrap_i32(self) -> Vec<i32> {
        match self {
            OutputVec::VecI32(val) => val,
            other => panic!("value is not a i32!, it's a {other:?}"),
        }
    }

    /// Returns the contained `u32` data.
    ///
    /// # Panics
    ///
    /// Panics if `self` holds another element type.
    pub fn unwrap_u32(self) -> Vec<u32> {
        match self {
            OutputVec::VecU32(val) => val,
            other => panic!("value is not a u32!, it's a {other:?}"),
        }
    }

    /// Returns the contained `f32` data.
    ///
    /// # Panics
    ///
    /// Panics if `self` holds another element type.
    pub fn unwrap_f32(self) -> Vec<f32> {
        match self {
            OutputVec::VecF32(val) => val,
            other => panic!("value is not a f32!, it's a {other:?}"),
        }
    }

    /// Returns the contained `bool` data.
    ///
    /// # Panics
    ///
    /// Panics if `self` holds another element type.
    pub fn unwrap_bool(self) -> Vec<bool> {
        match self {
            OutputVec::VecBool(val) => val,
            other => panic!("value is not a bool!, it's a {other:?}"),
        }
    }
}
284
285pub struct ShaderModule {
286    shader: wgpu::ShaderModule,
287    used_buffers_id: Vec<usize>,
288    dispatch: Dispatch
290}
291
/// Strips WGSL comments from `code`.
///
/// * `// ...` line comments are removed up to, but not including, the
///   terminating `'\n'`, so the line structure is preserved.
/// * `/* ... */` block comments are removed entirely, including any newlines
///   inside them. Block comments may nest, as in WGSL, so `/* a /* b */ c */` is
///   removed as a whole. A run of stars such as `**/` correctly closes a comment.
///
/// A `/` that does not start a comment is kept unchanged. An unterminated block
/// comment removes everything up to the end of the input.
pub fn remove_comments(code: String) -> String {
    let mut out = String::with_capacity(code.len());
    let mut chars = code.chars().peekable();
    while let Some(c) = chars.next() {
        if c != '/' {
            out.push(c);
            continue;
        }
        match chars.peek().copied() {
            Some('/') => {
                // Line comment: drop everything up to the newline, keep the newline.
                for c in chars.by_ref() {
                    if c == '\n' {
                        out.push('\n');
                        break;
                    }
                }
            }
            Some('*') => {
                chars.next();
                let mut depth = 1usize;
                while depth > 0 {
                    let c = match chars.next() {
                        Some(c) => c,
                        None => break,
                    };
                    // Peek instead of consuming so `**/` is still recognised as a terminator.
                    match (c, chars.peek().copied()) {
                        ('*', Some('/')) => {
                            chars.next();
                            depth -= 1;
                        }
                        ('/', Some('*')) => {
                            chars.next();
                            depth += 1;
                        }
                        _ => {}
                    }
                }
            }
            _ => out.push('/'),
        }
    }
    out
}
327
328pub enum Command<'a> {
329    Shader(ShaderModule),
330    Copy(&'a str, &'a str),
331    Retrieve(&'a str)
332}
333
334pub struct Device {
336    adapter: wgpu::Adapter,
337    device: wgpu::Device,
338    queue: wgpu::Queue,
339    buffers: Vec<Buffer>,
340    buf_name_to_id: HashMap<String, usize>,
341    output_buffers_id: Vec<usize>,
342}
343impl  Device {
344    pub fn new() -> Device {
346        async {
347            let instance = wgpu::Instance::new(wgpu::Backends::all());
348            let adapter = instance
349                .request_adapter(&wgpu::RequestAdapterOptions {
350                    power_preference: wgpu::PowerPreference::HighPerformance,
351                    compatible_surface: None,
352                    force_fallback_adapter: false,
353                })
354                .await
355                .unwrap();
356            let (device, queue) = adapter
357                .request_device(&Default::default(), None)
358                .await
359                .unwrap();
360            Device {
361                adapter,
362                device,
363                queue,
364                buffers: vec![],
365                buf_name_to_id: HashMap::new(),
366                output_buffers_id: vec![],
367            }
368        }.block_on()
369    }
370
371    #[inline]
373    pub fn get_info(self) -> String {
374        format!("{:?}", self.adapter.get_info())
375    }
376
377    pub fn create_buffer(&mut self, name: &str, data_type: BufferType, size: usize, usage: BufferUsage, is_output: bool) {
387        let data_type_stride = data_type.stride();
388        let byte_size = (data_type_stride as usize) * size;
389        let id = self.buffers.len();
390        if is_output {
393            self.output_buffers_id.push(self.buffers.len())
394        }
395        self.buf_name_to_id.insert(name.to_string(), id);
397        self.buffers.push(Buffer {
399            buffer: None,
400            name: name.to_owned(),
401            size: byte_size as u64,
402            type_: data_type,
403            type_stride: data_type_stride as i32,
404            data: None,
405            usage: usage.clone(),
406            wgpu_usage: buffer_usage(usage, is_output),
407            is_output,
409            output_buf: None
410        });
411    }
412
413    pub fn create_buffer_from<T: ToVecU8<T>>(&mut self, name: &str, content: &Vec<T>, usage: BufferUsage, is_output: bool) {
422        let (raw_content, data_type, data_type_stride, size) = <T as ToVecU8<T>>::convert(content);
423        let byte_size = data_type_stride as u64 * size as u64;
424        let id = self.buffers.len();
425         if is_output {
428            self.output_buffers_id.push(id);
429        }
430        self.buf_name_to_id.insert(name.to_string(), id);
432        self.buffers.push(Buffer {
434            buffer: None,
435            size: byte_size,
436            name: name.to_owned(),
437            type_: data_type,
438            type_stride: data_type_stride,
439            data: Some(raw_content),
440            usage: usage.clone(),
441            wgpu_usage: buffer_usage(usage, is_output),
442            is_output,
444            output_buf: None
445        });
446    }
447
448    pub fn apply_buffer_usages(&mut self, buf_name: &str, usage: wgpu::BufferUsages, is_output: bool) {
452        let buf = &mut self.buffers[*self.buf_name_to_id.get(buf_name).expect("The buffer has not been created on this device (probably wrong name).")];
453        buf.wgpu_usage |= usage;
454        buf.is_output = is_output;
455    }
456
457    pub fn apply_on_vector<T: ToVecU8<T> + std::clone::Clone>(&mut self, vec: Vec<T>, code: &str) -> Vec<T>{
486        if vec.len() == 0 {
487            return vec;
488        }
489        let code = remove_comments(code.to_string());
491        let mut code = code.lines().collect::<Vec<_>>();
492        while code.last().unwrap_or(&"// empty").chars().all(|e| " \t\r\n".contains(e)) {
494            code.pop();
495        }
496        let mut code = code.join("\n");
497
498        let mut used_buffers_id = vec![];
500        for (i, b) in self.buffers.iter().enumerate() {
501            let mut found = false;
502            for c in IDENT_SEPARATOR_CHARS.chars() { let patern = &format!("{c}{}[", b.name);
504                if code.find(patern).is_some() {
505                    found = true;
506                }
507            }
508            if found {
510                used_buffers_id.push(i);
511            }
512        }
513
514        self.build_buffers(&used_buffers_id, true);
515
516        let (_, type_, stride, size) = <T as ToVecU8<T>>::convert(&vec);
517        let byte_size =  stride as usize *size;
520        let max_size = MAX_BYTE_BUFFER_SIZE / stride as usize;
521        if byte_size <= MAX_BYTE_BUFFER_SIZE {
522            code = code.replace("element", "reservedbuf[index]");
523            let mut other_code = code.lines().collect::<Vec<_>>();
524            other_code.pop();
525            let other_code = other_code.join("\n");
526            let other_code = other_code.as_str();
527            let apply_expr = code.lines().last().unwrap();
528
529            used_buffers_id.push(self.buffers.len());
531            self.create_buffer_from("reservedbuf", &vec, BufferUsage::ReadWrite, true);
533            self.buffers[*self.buf_name_to_id.get("reservedbuf").unwrap()].wgpu_usage |= wgpu::BufferUsages::MAP_READ;
534            self.build_buffers(&used_buffers_id, false);
536
537            let shader_module = self.create_shader_module(Dispatch::Linear(vec.len()), &format!("
539            {other_code}
540            fn main() {{
541                reservedbuf[index] = {apply_expr};
542            }}
543            ", ));
544            self.execute_commands(vec![
545                Command::Shader(shader_module),
546                Command::Retrieve("reservedbuf")
547            ]);
548            let result = self.get_buffer_data(vec!["reservedbuf"]).into_iter().next().unwrap();
549            return <T as ToVecU8<T>>::get_output(result);
550
551        }else {
552            for c in IDENT_SEPARATOR_CHARS.chars() {
554                if code.contains(&(c.to_string()+"index")) {
555                    panic!("Cannot use the `index` variable with a vector of more than {} bytes.", max_size * 4)
556                }
557            }
558
559            let t = type_.to_string();
560            let mut other_code = code.lines().collect::<Vec<_>>();
561            other_code.pop();
562            let other_code = other_code.join("\n");
563            let other_code = other_code.as_str();
564
565            let apply_expr = code.lines().last().unwrap();
566            let function_element = format!("fn reservedfn(element: {t}) -> {t} {{return({apply_expr});}}\n");
567
568            let nb_buf = (byte_size as f64 / MAX_BYTE_BUFFER_SIZE as f64).ceil() as u32;
569            let size_last_buf = size % max_size;
570            let mut main_body = vec![];
571
572            for i in 0..(nb_buf-1) {
574                let vec_i = i as usize*max_size;
575                self.create_buffer_from(&format!("reservedbuf{i}"), &vec[vec_i..(vec_i + max_size)].to_vec(), BufferUsage::ReadWrite, true);
576                main_body.push(format!("reservedbuf{i}[index] = reservedfn(reservedbuf{i}[index]);\n"));
577            }
578            if size_last_buf != 0 {
579                self.create_buffer_from(&format!("reservedbuf{}", nb_buf-1), &vec[(vec.len() - size_last_buf)..].to_vec(), BufferUsage::ReadWrite, true);
580                main_body.push(format!("if (index < {size_last_buf}u) {{ reservedbuf{0}[index] = reservedfn(reservedbuf{0}[index]);}}\n", nb_buf-1))
581            }
582
583            let shader_module = self.create_shader_module(Dispatch::Linear(max_size), &format!("
585            {other_code}
586            {function_element}
587            fn main() {{
588                {}
589            }}
590            ", main_body.join("")));
591            let mut commands = vec![Command::Shader(shader_module)];
592            let mut buffer_names = vec![];
593            for i in 0..nb_buf {
594                buffer_names.push(format!("reservedbuf{i}"));
595            }
596            for name in buffer_names.iter() {
597                commands.push(Command::Retrieve(&name));
598            }
599            self.execute_commands(commands);
600            let mut bufs = vec![];
601            for i in 0..nb_buf {
602                bufs.push(format!("reservedbuf{i}"));
603            }
604            let results = self.get_buffer_data(bufs.iter().map(|e| e.as_str()).collect::<Vec<_>>());
605            let mut out = Vec::<T>::new();
606            for result in results {
607                out.append(&mut <T as ToVecU8<T>>::get_output(result));
608            }
609            return out;
610        }
611    }
612
613    #[inline]
616    pub fn create_shader_module(&self, dispatch: Dispatch, code: &str) -> ShaderModule {
617        let dispatch_linear_len;
618        match dispatch {
619            Dispatch::Linear(l) => {
620                dispatch_linear_len = Some(l);
621            }
622            Dispatch::Custom(_, _, _) => {
623                dispatch_linear_len = None;
624            } 
625        }
626        let dispatch_linear = dispatch_linear_len.is_some();
627
628        let mut tmp_code = remove_comments(code.to_owned());
629        let mut main_headers = String::from("
631[[stage(compute), workgroup_size(1)]]
632fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {\n");
633        if dispatch_linear {
635            if dispatch_linear_len < Some(65536) {
636                main_headers += "\tlet index: u32 = global_id.x;\n";
637            }else {
638                main_headers += &format!("\tlet index: u32 = global_id.x + global_id.y * 65535u;\nif (index >= {}u) {{return;}}\n", dispatch_linear_len.unwrap());
639                
640            }
641        }
642
643        let mut used_buffers_id = vec![];
645        for (i, b) in self.buffers.iter().enumerate() {
647            let mut found = false;
648            for c in IDENT_SEPARATOR_CHARS.chars() { let patern = &format!("{c}{}[", b.name);
650                if tmp_code.find(patern).is_some() {
651                    found = true;
652                }
653                tmp_code = tmp_code.replace(patern, &format!("{c}{}.d[", b.name));
654            }
655
656            if found {
658                used_buffers_id.push(i);
659            }
660        }
661
662        let mut structs = vec![];
663        let mut struct_types = HashMap::new();
664        let mut bindings = vec![];
665        let mut used_buffers_id_i = 0;
666        let mut binding_i = 0;
667        for (i, b) in self.buffers.iter().enumerate() {
668            while used_buffers_id_i < used_buffers_id.len()-1 && used_buffers_id[used_buffers_id_i] < i {
669                used_buffers_id_i += 1;
670            }
671            if used_buffers_id[used_buffers_id_i] != i {
672                continue;
673            }
674            if b.type_stride != -1 && !struct_types.contains_key(&b.type_){
677                structs.push(format!("struct reserved{i} {{\n\td: [[stride({})]] array<{}>;\n}};\n",b.type_stride, b.type_.to_string()));
678                struct_types.insert(b.type_.clone(), i);
679            }
680            
681            bindings.push(format!(
683                "[[group(0), binding({binding_i})]] \n var<storage, {}> {}: reserved{};\n",
684                match b.usage {
685                    BufferUsage::ReadOnly => {"read".to_string()},
686                    BufferUsage::WriteOnly => {"write".to_string()},
687                    BufferUsage::ReadWrite => {"read_write".to_string()}
688                },
689                b.name,
690                struct_types.get(&b.type_).unwrap()
691            ));
692
693            binding_i += 1;
694        }
695        
696        tmp_code = format!("{}{}{}",structs.join(""), bindings.join(""),  tmp_code.replace("fn main() {\n", &main_headers));
698
699        ShaderModule {
700            shader: self.device.create_shader_module(&wgpu::ShaderModuleDescriptor {
701                label: None,
702                source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(&tmp_code))
703            }),
704            used_buffers_id,
706            dispatch
707        }
708    }
709
710    fn build_buffers(&mut self, indices: &Vec<usize>, build_output: bool) { let mut bufs_to_build = vec![];
713        let mut indices_i = 0;
714        if indices.len() == 0 {
715            for b in self.buffers.iter_mut() {
716                if b.is_output { bufs_to_build.push(b);
718                }
719            }
720        }else if build_output {
721            for (i, b) in self.buffers.iter_mut().enumerate() {
722                while indices_i < indices.len()-1 && indices[indices_i] < i {
723                    indices_i += 1;
724                }
725                if b.is_output || i == indices[indices_i] { bufs_to_build.push(b);
727                }
728            }
729        }else {
730            for (i, b) in self.buffers.iter_mut().enumerate() {
731                while indices_i < indices.len()-1 && indices[indices_i] < i {
732                    indices_i += 1;
733                }
734                if i == indices[indices_i] { bufs_to_build.push(b);
736                }
737            }
738        }
739        
740
741        for buf in bufs_to_build.iter_mut() {
742            if buf.buffer.is_some(){
744                continue;
745            }
746            match &buf.data {
747                Some(data) => {
748                    buf.buffer = Some(self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
749                        label: Some(&buf.name),
750                        contents: &data,
751                        usage: buf.wgpu_usage,
752                    }));
753                }
754                None => {
755                    buf.buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
756                            label: Some(&buf.name),
757                            size: buf.size,
758                            usage: buf.wgpu_usage,
759                            mapped_at_creation: false,
760                        }))
761                }
762            }
763            if buf.is_output {
764                buf.output_buf = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
765                    label: Some(&buf.name),
766                    size: buf.size,
767                    usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
768                    mapped_at_creation: false,
769                }))
770            }
771        }
772    }
773
774    #[inline]
792    pub fn execute_shader_module(&mut self, shader_module: &ShaderModule) -> Vec<OutputVec> {
793        if self.buffers.len() == 0 {
794            panic!("In function `Device::execute_shader_module` : Cannot execute a shader if no buffer has been created. 
795            You can create a buffer using the `Device::create_buffer` and `Device::create_buffer_from` functions");
796        }
797
798        self.build_buffers(&shader_module.used_buffers_id, false);
799
800        let mut bind_group_entries = vec![];
801        let mut i = 0;
802        for id in shader_module.used_buffers_id.iter() {
803            bind_group_entries.push(wgpu::BindGroupEntry{
804                binding: i as u32,
805                resource: self.buffers[*id].buffer.as_ref().unwrap().as_entire_binding()
806            });
807            i += 1;
808        }
809
810        let compute_pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
811            label: None,
812            layout: None,
813            module: &shader_module.shader,
814            entry_point: "main",
815        });
816
817        let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
818        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
819            label: None,
820            layout: &bind_group_layout,
821            entries: &bind_group_entries
822        });
823
824        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
825        {
826            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
827            cpass.set_pipeline(&compute_pipeline);
828            cpass.set_bind_group(0, &bind_group, &[]);
829            match shader_module.dispatch {
830                Dispatch::Linear(val) => {
831                    if val < 65536 {
832                        cpass.dispatch(val as u32, 1, 1);
833                    }else {
834                        cpass.dispatch(65535, (val as f64 / 65535f64).ceil() as u32, 1);
835                    }
836                }
837                Dispatch::Custom(x, y, z) => {
838                    cpass.dispatch(x, y, z);
839                }
840            }
841        }
842        for i in shader_module.used_buffers_id.iter() {
844                let b = &self.buffers[*i];
845                if b.is_output {
846                    encoder.copy_buffer_to_buffer(&b.buffer.as_ref().unwrap(), 0, &b.output_buf.as_ref().unwrap(), 0, b.size);
847                }
848            }
849        self.queue.submit(Some(encoder.finish()));
850        
851        self.retrieve_buffer_data(&shader_module.used_buffers_id)
852    }
853
854    pub fn execute_shader_code(&mut self, dispatch: Dispatch, code: &str) -> Vec<OutputVec>{
872        let shader_module = self.create_shader_module(dispatch.clone(), code);
873        self.execute_shader_module(&shader_module)
874    }
875
876    pub fn execute_commands(&mut self, commands: Vec<Command>) {
878        let mut shader_count = 0;
879        let mut shader_index = None;
880        for (i, c) in commands.iter().enumerate() {
881            if let Command::Shader(_) = c {
882                shader_count += 1;
883                shader_index = Some(i);
884            }
885        }
886        if shader_count > 1 {
887            panic!("In function `Device::execute_commands` : There should only be 1 shader in the commands, got {shader_count}");
888        }
889
890        let mut copy_buffers = HashSet::new();
892        for c in commands.iter() {
893            if let Command::Copy(from, to) = c {
894                let buf_id1 = self.buf_name_to_id.get(&from.to_string()).expect("The source buffer has not been created on this device (probably wrong name).");
895                self.buffers[*buf_id1].wgpu_usage |= wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::MAP_READ;
896                let buf_id2 = self.buf_name_to_id.get(&to.to_string()).expect("The destination buffer has not been created on this device (probably wrong name).");
897                self.buffers[*buf_id2].wgpu_usage |= wgpu::BufferUsages::COPY_DST;
898                copy_buffers.insert(*buf_id1);
899                copy_buffers.insert(*buf_id2);
900            }
901        }
902
903        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
904        if shader_count == 0 { for c in commands.iter() {
906                if let Command::Copy(from, to) = c {
907                    let buf1 = &self.buffers[*self.buf_name_to_id.get(&from.to_string()).unwrap()]; let buf2 = &self.buffers[*self.buf_name_to_id.get(&to.to_string()).unwrap()].buffer; encoder.copy_buffer_to_buffer(&buf1.buffer.as_ref().unwrap(), 0, &buf2.as_ref().unwrap(), 0, buf1.size);
910                }
911            }
912        }else { let shader_index = shader_index.unwrap();
914            for c in commands.iter().take(shader_index) {
915                if let Command::Copy(from, to) = c {
916                    let buf1 = &self.buffers[*self.buf_name_to_id.get(&from.to_string()).unwrap()];
917                    let buf2 = &self.buffers[*self.buf_name_to_id.get(&to.to_string()).unwrap()].buffer;
918                    encoder.copy_buffer_to_buffer(&buf1.buffer.as_ref().unwrap(), 0, &buf2.as_ref().unwrap(), 0, buf1.size);
919                }
920            }
921
922            let shader_module = 
923            (if let Command::Shader(sm) = &commands[shader_index] {
924                Some(sm)
925            }else {
926                None
927            }).unwrap();
928
929            let used_buffers_shader_and_copy = &shader_module.used_buffers_id.iter().map(|e| *e).chain(copy_buffers.into_iter()).collect::<Vec<_>>();
930            self.build_buffers(used_buffers_shader_and_copy, true);
931
932            let mut bind_group_entries = vec![];
933            let mut i = 0;
934            for id in shader_module.used_buffers_id.iter() {
935                bind_group_entries.push(wgpu::BindGroupEntry{
936                    binding: i as u32,
937                    resource: self.buffers[*id].buffer.as_ref().unwrap().as_entire_binding()
938                });
939                i += 1;
940            }
941
942            let compute_pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
943                label: None,
944                layout: None,
945                module: &shader_module.shader,
946                entry_point: "main",
947            });
948
949            let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
950            let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
951                label: None,
952                layout: &bind_group_layout,
953                entries: &bind_group_entries
954            });
955
956            {
957                let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("execute commands")});
958                cpass.set_pipeline(&compute_pipeline);
959                cpass.set_bind_group(0, &bind_group, &[]);
960                match shader_module.dispatch {
961                    Dispatch::Linear(val) => {
962                        if val < 65536 {
963                            cpass.dispatch(val as u32, 1, 1);
964                        }else {
965                            cpass.dispatch(65535, (val as f64 / 65535f64).ceil() as u32, 1);
966                        }
967                    }
968                    Dispatch::Custom(x, y, z) => {
969                        cpass.dispatch(x, y, z);
970                    }
971                }
972            }
973
974            for c in commands.iter().skip(shader_index+1) {
976                if let Command::Copy(from, to) = c {
977                    let buf1 = &self.buffers[*self.buf_name_to_id.get(&from.to_string()).unwrap()];
978                    let buf2 = &self.buffers[*self.buf_name_to_id.get(&to.to_string()).unwrap()].buffer;
979                    encoder.copy_buffer_to_buffer(&buf1.buffer.as_ref().unwrap(), 0, &buf2.as_ref().unwrap(), 0, buf1.size);
980                }
981                if let Command::Retrieve(buf) = c {
982                    let buf1 = &self.buffers[*self.buf_name_to_id.get(&buf.to_string()).unwrap()];
983                    encoder.copy_buffer_to_buffer(&buf1.buffer.as_ref().unwrap(), 0, &buf1.output_buf.as_ref().unwrap(), 0, buf1.size);
984                }
985            }
986
987            self.queue.submit(Some(encoder.finish()));
988            
989    
990        }
991    }
992
993    #[inline]
994    fn retrieve_buffer_data(&self, buffers_id: &Vec<usize>) -> Vec<OutputVec> {
995        let mut output_buffers_id = vec![];
996        let mut buffer_slices = vec![];
997        let mut result_futures = vec![];
998        for i in buffers_id.iter() {
999            let b = &self.buffers[*i];
1000            if !b.is_output {
1001                continue;
1002            }
1003            let out_b = b.output_buf.as_ref().unwrap();
1004            let buffer_slice = out_b.slice(..);
1006            let buffer_future = buffer_slice.map_async(wgpu::MapMode::Read);
1007            buffer_slices.push(buffer_slice);
1008            result_futures.push(buffer_future);
1009            output_buffers_id.push(*i);
1010        }
1011
1012        self.device.poll(wgpu::Maintain::Wait);
1014
1015
1016        let mut results = Vec::with_capacity(self.output_buffers_id.len());
1017        let mut buffer_index = 0;
1019
1020        async {
1021            for (i, result_future) in result_futures.into_iter().enumerate() {
1022                let buf = &self.buffers[output_buffers_id[i]];
1023                buffer_index += 1;
1024                if let Ok(()) = result_future.await {
1025                    let data = buffer_slices[i].get_mapped_range();
1027                    let result: OutputVec;
1029                    match buf.type_ {
1030                        BufferType::I32 => {
1031                            result = OutputVec::VecI32(bytemuck::cast_slice::<u8, i32>(&data).to_vec());
1032                        },
1033                        BufferType::U32 => {
1034                            result = OutputVec::VecU32(bytemuck::cast_slice::<u8, u32>(&data).to_vec());
1035                        },
1036                        BufferType::F32 => {
1037                            result = OutputVec::VecF32(bytemuck::cast_slice::<u8, f32>(&data).to_vec());
1038                        },
1039                        BufferType::Bool => {
1040                            result = OutputVec::VecBool(data.iter().map(|&e| e != 0).collect());
1041                        }
1042                    }
1043
1044                    drop(data);
1046                    buf.output_buf.as_ref().unwrap().unmap(); results.push(result);
1050                }else {
1051                    panic!("computations failed");
1052                }
1053            }
1054        results
1055        }.block_on()
1056    }
1057
1058    pub fn get_buffer_data(&self, buffer_names: Vec<&str>) -> Vec<OutputVec> {
1060        let buffers_id = buffer_names.iter().map(|&e| *self.buf_name_to_id.get(e)
1061            .expect("The buffer has not been created on this device (probably wrong name)"))
1062            .collect::<Vec<_>>();
1063        self.retrieve_buffer_data(&buffers_id)
1064    }
1065}
1066
1067pub mod examples {
1068    use crate::{Device, BufferUsage, Dispatch, BufferType, Command};
1069
1070    pub fn simplest_apply() {
1072        let mut device = Device::new();
1073        let v1 = vec![1.0f32, 2.0, 3.0];
1074        let v1 = device.apply_on_vector(v1, "element * 2.0");
1076        println!("{v1:?}");
1077    }
1078
1079    pub fn apply_with_buf() {
1081        let mut device = Device::new();
1082        let v1 = vec![2.0f32, 3.0, 5.0, 7.0, 11.0];
1083        let exponent = vec![3.0];
1084        device.create_buffer_from("exponent", &exponent, BufferUsage::ReadOnly, false);
1085        let cubes = device.apply_on_vector(v1, "pow(element, exponent[0u])");
1086        println!("{cubes:?}")
1087    }
1088
1089    pub fn with_execute_shader() {
1091        let mut device = Device::new();
1092        let v1 = vec![1i32, 2, 3, 4, 5, 6];
1093        device.create_buffer_from("v1", &v1, BufferUsage::ReadOnly, false);
1094        device.create_buffer("output", BufferType::I32, v1.len(), BufferUsage::WriteOnly, true);
1095        let result = device.execute_shader_code(Dispatch::Linear(v1.len()), r"
1096        fn main() {
1097            output[index] = v1[index] * 2;
1098        }
1099        ").into_iter().next().unwrap().unwrap_i32();
1100        assert_eq!(result, vec![2, 4, 6, 8, 10, 12]);
1101    }
1102
1103    pub fn multiple_output_buffers() {
1105        let mut device = Device::new();
1106        let v = vec![1u32, 2, 3];
1107        let v2 = vec![3u32, 4, 5];
1108        let v3 = vec![7u32, 8, 9];
1109        device.create_buffer_from(
1111            "buf",
1112            &v,
1113            BufferUsage::ReadWrite,
1114            true
1115        );
1116        device.create_buffer_from(
1117            "buf2",
1118            &v2,
1119            BufferUsage::ReadOnly,
1120            false
1121        );
1122        device.create_buffer_from(
1124            "buf3",
1125            &v3,
1126            BufferUsage::ReadWrite,
1127            true
1128        );
1129        let mut result = device.execute_shader_code(Dispatch::Linear(v.len()), r"
1130            fn main() {
1131                buf[index] = buf[index] + buf2[index] + buf3[index];
1132                buf3[index] = buf[index] * buf2[index] * buf3[index];
1133            }
1134        ").into_iter();
1135
1136        let sum = result.next().unwrap().unwrap_u32();
1137        let product = result.next().unwrap().unwrap_u32();
1138        println!("{:?}", sum);
1139        println!("{:?}", product);
1140    }
1141
1142    pub fn global_id() {
1144        let mut device = Device::new();
1145        let vec = vec![2u32, 3, 5, 7, 11, 13, 17];
1146        device.create_buffer_from("vec1", &vec, BufferUsage::ReadWrite, true);
1148        let result = device.execute_shader_code(Dispatch::Custom(1, vec.len() as u32, 1), r"
1149        fn main() {
1150            vec1[global_id.y] = vec1[global_id.y] + global_id.x + global_id.z;
1151        }
1152        ").into_iter().next().unwrap().unwrap_u32();
1153        assert_eq!(result, vec![2u32, 3, 5, 7, 11, 13, 17]);
1156    }
1157
1158    pub fn shader_two_steps() {
1160        let mut device = Device::new();
1161        let v = vec![1u32, 2, 3, 4];
1162        device.create_buffer_from("buf1", &v, BufferUsage::ReadWrite, true);
1163        let shader_module = device.create_shader_module(Dispatch::Linear(v.len()), "
1164        fn main() {
1165            buf1[index] = buf1[index] * 17u;
1166        }
1167        ");
1168        let result = device.execute_shader_module(&shader_module).into_iter().next().unwrap().unwrap_u32();
1169        assert_eq!(result, vec![17u32, 34, 51, 68]);
1170    }
1171
1172    pub fn complete_pipeline() {
1174        let mut device = Device::new();
1175        let v1 = vec![1u32, 2, 3, 4, 5];
1176        device.create_buffer_from("v1", &v1, BufferUsage::ReadWrite, true);
1177        let shader_module = device.create_shader_module(Dispatch::Linear(v1.len()), "
1178        fn main() {
1179            v1[index] = v1[index] * 2u;
1180        }
1181        ");
1182        let mut commands = vec![];
1183        commands.push(Command::Shader(shader_module));
1184        commands.push(Command::Retrieve("v1"));
1185        device.execute_commands(commands);
1186        let result = device.get_buffer_data(vec!["v1"]).into_iter().next().unwrap().unwrap_u32();
1187        assert_eq!(result, vec![2u32, 4, 6, 8, 10]);
1188    }
1189
1190    pub fn reusing_device() {
1192        let mut device = Device::new();
1193        let v1 = vec![1i32, 2, 3, 4, 5, 6];
1194        device.create_buffer_from("v1", &v1, BufferUsage::ReadOnly, false);
1195        device.create_buffer("output", BufferType::I32, v1.len(), BufferUsage::WriteOnly, true);
1196        let result = device.execute_shader_code(Dispatch::Linear(v1.len()), r"
1197        fn main() {
1198            output[index] = v1[index] * 2;
1199        }
1200        ").into_iter().next().unwrap().unwrap_i32();
1201        assert_eq!(result, vec![2i32, 4, 6, 8, 10, 12]);
1202
1203        let result2 = device.execute_shader_code(Dispatch::Linear(v1.len()), r"
1204        fn main() {
1205            output[index] = v1[index] * 10;
1206        }
1207        ").into_iter().next().unwrap().unwrap_i32();
1208        assert_eq!(result2, vec![10, 20, 30, 40, 50, 60]);
1209    }
1210
1211    #[test]
1213    pub fn big_computations() {
1214        let mut device = Device::new();
1215        let size = 33_554_432;
1216        device.create_buffer(
1217            "buf",
1218            BufferType::U32,
1219            size,
1220            BufferUsage::WriteOnly,
1221            true);
1222        let result = device.execute_shader_code(Dispatch::Linear(size), r"
1223        fn number_of_seven_in_digit_product(number: u32) -> u32 {
1224            var p: u32 = 1u;
1225            var n: u32 = number;
1226            loop {
1227                if (n == 0u) {break;}
1228                p = p * (n % 10u);
1229                n = n / 10u;
1230            }
1231            var nb_seven: u32 = 0u;
1232            loop {
1233                if (p == 0u) {break;}
1234                if (p % 10u == 7u) {
1235                    nb_seven = nb_seven + 1u;
1236                }
1237                p = p / 10u;
1238            }
1239            return nb_seven;
1240        }
1241        fn main() {
1242            buf[index] = number_of_seven_in_digit_product(index);
1243        }
1244        ").into_iter().next().unwrap().unwrap_u32();
1245        let mut index = 0;
1246        let mut max = result[0];
1247        for (i, e) in result.iter().enumerate() {
1248            if e > &max {
1249                max = *e;
1250                index = i;
1251            }
1252        }
1253        println!("The number who's digit product got the most seven below {size} is {index} with {max} sevens.");
1254    }
1255}