arr_o_gpu/arr_o_gpu/array/array_trait/
array_compute.rs1use wgpu::{BindGroup, BindGroupLayout};
2
3use crate::{ArrayType, CheckArrayType, GpuArray, GpuArrayView, MetadataCompound};
4
5pub trait ArrayCompute {
6 fn module(&self) -> &std::sync::Arc<crate::ArrOgpuModule>;
7
8 fn pointer_to_arr(&self) -> [u32; 2];
9
10 fn dim(&self) -> usize;
11
12 fn stride(&self) -> &Vec<u32>;
13
14 fn shape(&self) -> &Vec<u32>;
15
16 fn pointer(&self) -> (u32, u32);
17
18 fn offset(&self) -> u32;
19
20 fn len(&self) -> u32;
21
22 fn is_contiguous(&self) -> bool;
23
24 fn binding(&self) -> Option<&(BindGroupLayout, BindGroup)>;
25
26 fn metadata_compound(&self) -> Option<&MetadataCompound>;
27
28 fn check_contiguous_or_view<'a>(&'a self) -> ArrayType<'a>;
29}
30
31impl ArrayCompute for GpuArray {
32 fn module(&self) -> &std::sync::Arc<crate::ArrOgpuModule> {
33 self.module()
34 }
35
36 fn pointer_to_arr(&self) -> [u32; 2] {
37 self.pointer_to_arr()
38 }
39
40 fn dim(&self) -> usize {
41 self.dim()
42 }
43
44 fn stride(&self) -> &Vec<u32> {
45 self.stride()
46 }
47
48 fn shape(&self) -> &Vec<u32> {
49 self.shape()
50 }
51
52 fn pointer(&self) -> (u32, u32) {
53 self.pointer()
54 }
55
56 fn offset(&self) -> u32 {
57 0
58 }
59
60 fn len(&self) -> u32 {
61 self.len() as u32
62 }
63
64 fn is_contiguous(&self) -> bool {
65 true
66 }
67
68 fn binding(&self) -> Option<&(BindGroupLayout, BindGroup)> {
69 self.binding()
70 }
71
72 fn metadata_compound(&self) -> Option<&MetadataCompound> {
73 self.metadata_compound()
74 }
75
76 fn check_contiguous_or_view<'a>(&'a self) -> ArrayType<'a> {
77 self.check()
78 }
79}
80
81impl<'a, A> ArrayCompute for GpuArrayView<'a, A>
82where
83 A: ArrayCompute,
84{
85 fn module(&self) -> &std::sync::Arc<crate::ArrOgpuModule> {
86 self.array.module()
87 }
88
89 fn pointer_to_arr(&self) -> [u32; 2] {
90 self.array.pointer_to_arr()
91 }
92
93 fn dim(&self) -> usize {
94 self.shape.len()
95 }
96
97 fn stride(&self) -> &Vec<u32> {
98 &self.stride
99 }
100
101 fn shape(&self) -> &Vec<u32> {
102 &self.shape
103 }
104
105 fn pointer(&self) -> (u32, u32) {
106 self.array.pointer()
107 }
108
109 fn offset(&self) -> u32 {
110 self.offset
111 }
112
113 fn len(&self) -> u32 {
114 self.shape.iter().product::<u32>()
115 }
116
117 fn is_contiguous(&self) -> bool {
118 self.pointer.1 - self.pointer.0 == self.shape.iter().product::<u32>()
119 }
120
121 fn binding(&self) -> Option<&(BindGroupLayout, BindGroup)> {
122 self.binding.as_ref()
123 }
124
125 fn metadata_compound(&self) -> Option<&MetadataCompound> {
126 self.metadata_compound.as_ref()
127 }
128
129 fn check_contiguous_or_view(&self) -> ArrayType<'_> {
130 self.check()
131 }
132}