arr_o_gpu/arr_o_gpu/array/array_trait/array_compute.rs

1use wgpu::{BindGroup, BindGroupLayout};
2
3use crate::{ArrayType, CheckArrayType, GpuArray, GpuArrayView, MetadataCompound};
4
5pub trait ArrayCompute {
6    fn module(&self) -> &std::sync::Arc<crate::ArrOgpuModule>;
7
8    fn pointer_to_arr(&self) -> [u32; 2];
9
10    fn dim(&self) -> usize;
11
12    fn stride(&self) -> &Vec<u32>;
13
14    fn shape(&self) -> &Vec<u32>;
15
16    fn pointer(&self) -> (u32, u32);
17
18    fn offset(&self) -> u32;
19
20    fn len(&self) -> u32;
21
22    fn is_contiguous(&self) -> bool;
23
24    fn binding(&self) -> Option<&(BindGroupLayout, BindGroup)>;
25
26    fn metadata_compound(&self) -> Option<&MetadataCompound>;
27
28    fn check_contiguous_or_view<'a>(&'a self) -> ArrayType<'a>;
29}
30
31impl ArrayCompute for GpuArray {
32    fn module(&self) -> &std::sync::Arc<crate::ArrOgpuModule> {
33        self.module()
34    }
35
36    fn pointer_to_arr(&self) -> [u32; 2] {
37        self.pointer_to_arr()
38    }
39
40    fn dim(&self) -> usize {
41        self.dim()
42    }
43
44    fn stride(&self) -> &Vec<u32> {
45        self.stride()
46    }
47
48    fn shape(&self) -> &Vec<u32> {
49        self.shape()
50    }
51
52    fn pointer(&self) -> (u32, u32) {
53        self.pointer()
54    }
55
56    fn offset(&self) -> u32 {
57        0
58    }
59
60    fn len(&self) -> u32 {
61        self.len() as u32
62    }
63
64    fn is_contiguous(&self) -> bool {
65        true
66    }
67
68    fn binding(&self) -> Option<&(BindGroupLayout, BindGroup)> {
69        self.binding()
70    }
71
72    fn metadata_compound(&self) -> Option<&MetadataCompound> {
73        self.metadata_compound()
74    }
75
76    fn check_contiguous_or_view<'a>(&'a self) -> ArrayType<'a> {
77        self.check()
78    }
79}
80
81impl<'a, A> ArrayCompute for GpuArrayView<'a, A>
82where
83    A: ArrayCompute,
84{
85    fn module(&self) -> &std::sync::Arc<crate::ArrOgpuModule> {
86        self.array.module()
87    }
88
89    fn pointer_to_arr(&self) -> [u32; 2] {
90        self.array.pointer_to_arr()
91    }
92
93    fn dim(&self) -> usize {
94        self.shape.len()
95    }
96
97    fn stride(&self) -> &Vec<u32> {
98        &self.stride
99    }
100
101    fn shape(&self) -> &Vec<u32> {
102        &self.shape
103    }
104
105    fn pointer(&self) -> (u32, u32) {
106        self.array.pointer()
107    }
108
109    fn offset(&self) -> u32 {
110        self.offset
111    }
112
113    fn len(&self) -> u32 {
114        self.shape.iter().product::<u32>()
115    }
116
117    fn is_contiguous(&self) -> bool {
118        self.pointer.1 - self.pointer.0 == self.shape.iter().product::<u32>()
119    }
120
121    fn binding(&self) -> Option<&(BindGroupLayout, BindGroup)> {
122        self.binding.as_ref()
123    }
124
125    fn metadata_compound(&self) -> Option<&MetadataCompound> {
126        self.metadata_compound.as_ref()
127    }
128
129    fn check_contiguous_or_view(&self) -> ArrayType<'_> {
130        self.check()
131    }
132}