// nd_matrix/lib.rs
/// Convenience constructor for [`NdMatrix`].
///
/// * `matrix!([d0, d1, ...]; T, default)` builds a matrix whose axes have the
///   given sizes, with every element set to `default`, filled by one worker thread.
/// * `matrix!([d0, d1, ...]; T, default, threads)` does the same but splits the
///   initial fill across `threads` worker threads.
///
/// `default` and `threads` accept arbitrary expressions (e.g. `-1`, `n * 2`).
/// The expansion is a plain expression with no trailing `;`, so it can be used
/// on the right-hand side of a `let` or as a tail expression.
#[allow(dead_code)]
#[macro_export]
macro_rules! matrix {
    ($dimensions:tt; $type:ty, $default:expr) => {
        // `$crate::` keeps the path valid when the macro is invoked from another crate.
        $crate::NdMatrix::<$type>::new(vec!$dimensions, $default, 1)
    };

    ($dimensions:tt; $type:ty, $default:expr, $threads:expr) => {
        $crate::NdMatrix::<$type>::new(vec!$dimensions, $default, $threads)
    };
}
/// Returns `true` if the position `vec_a` lies outside the bounds `vec_b`.
///
/// `vec_a` is the caller-supplied index (one coordinate per axis) and `vec_b`
/// holds the size of each axis. A coordinate is out of bounds when it is not
/// strictly smaller than its axis size. A coordinate with no matching axis
/// (`vec_a` longer than `vec_b`) also counts as out of bounds; the original
/// code indexed past the end of `vec_b` and panicked in that case.
fn oob(vec_a: &[usize], vec_b: &[usize]) -> bool {
    vec_a.len() > vec_b.len() || vec_a.iter().zip(vec_b).any(|(coord, extent)| coord >= extent)
}
/// An N-dimensional matrix backed by a single flat `Vec`.
///
/// Elements are laid out in row-major order: the last axis varies fastest.
/// The struct also stores its own iteration cursor (`count`), which the
/// `Iterator` impl advances. Because `PartialEq` is derived, that cursor is
/// part of equality comparisons.
#[derive(Debug, Clone, PartialEq)]
pub struct NdMatrix<T> {
    /// Flat element storage. It is public, so callers can read or modify it directly.
    pub data: Vec<T>,

    /// Number of axes.
    dimensions: usize,
    /// Extent of each axis, outermost first.
    size: Vec<usize>,

    /// Total number of elements (the product of `size`).
    length: usize,

    /// Flat index of the next element the iterator will yield.
    count: usize,
}
41use std::thread;
42impl<T: Clone + Send + 'static> NdMatrix<T> {
43    pub fn new(dim:Vec<usize>, default: T, threads:usize) -> Self {
44
45        let dimensions = dim.len();
46        let size = dim;
47
48        let mut length:usize = 1;
49
50        for i in 0..size.len() {
51            length *= size[i];
52        }
53        
54        //MULTITHREADING MAGIC
55        let split = length / threads;
56        let remod = length % threads;
57
58        let mut thread_table:Vec<thread::JoinHandle<Vec<T>>> = vec![];
59
60        for _ in 0..threads {//thread inits
61            let thr_split = split;
62            let thr_default = default.clone();
63
64            thread_table.push(thread::spawn(move || {
65                vec![thr_default; thr_split]
66            }));
67        }
68
69        let mut data:Vec<T> = vec![];
70        for thread in thread_table {
71            let mut split_table = thread.join().unwrap();
72
73            data.append(&mut split_table);
74        }
75        data.append(&mut vec![default; remod]);
76        
77        //let data = vec![default; length];
78        
79        NdMatrix {data, dimensions, size, length, count:0}
80    }
81
82
83    pub fn pos(&self, index:Vec<usize>) -> Result<T, Error> {
84        if index.len() != self.size.len() {return Err(Error::InvalidDimensions)}
85        if oob(&index, &self.size) {return Err(Error::OOBIndex)}
86
87
88        let total = self.pos_to_nth(index).unwrap();
89        
90        Ok(self.data[total].clone())
91    }
92
93    pub fn nth(&self, index:usize) -> Result<T, Error> {
94        if index >= self.length {return Err(Error::OOBIndex)}
95
96        Ok(self.data[index].clone())
97    }
98
99
100    pub fn set_pos(&mut self, index:Vec<usize>, value:T) -> Result<(), Error> {
101        if index.len() != self.size.len() {return Err(Error::InvalidDimensions)}
102        if oob(&index, &self.size) {return Err(Error::OOBIndex)}
103        
104        let total = self.pos_to_nth(index).unwrap();
105
106        self.data[total] = value;
107
108        Ok(()) //returning None is actually the good path as no errors were returned
109    }
110
111    pub fn set_nth(&mut self, index:usize, value:T) -> Result<(), Error> {
112        if index >= self.length {return Err(Error::OOBIndex)}
113
114        self.data[index] = value;
115
116        Ok(())
117    }
118
119
120    //thank you chatgpt for writing these 2 methods because i wanted to pull my hair out
121    //no idea wtfs goin on in the for loop tho
122    pub fn pos_to_nth(&self, index:Vec<usize>) -> Result<usize, Error> {
123        if index.len() != self.size.len() {return Err(Error::InvalidDimensions)}
124        if oob(&index, &self.size) {return Err(Error::OOBIndex)}       
125        
126        let mut result = 0;
127        let mut stride = 1;
128        for (p, s) in index.iter().rev().zip(self.size.iter().rev()) {
129            result += p * stride;
130            stride *= s;
131        }
132        Ok(result)
133    }
134
135    pub fn nth_to_pos(&self, index:usize) -> Result<Vec<usize>, Error> {
136        if index >= self.length {return Err(Error::OOBIndex)}
137
138        let mut result = Vec::with_capacity(self.size.len());
139        let mut rem = index;
140        for s in self.size.iter().rev() {
141            let p = rem % s;
142            rem /= s;
143            result.push(p);
144        }
145        result.reverse();
146        Ok(result)
147    }
148
149
150
151        
152    
153
154
155    //properties
156    pub fn len(&self) -> usize {
157        self.length
158    }
159
160    pub fn dim(&self) -> usize {
161        self.dimensions
162    }
163
164    pub fn size(&self) -> Vec<usize> {
165        self.size.clone()
166    }
167
168}
169
170
171impl<T: Clone + Send + 'static> Iterator for NdMatrix<T> {
172    type Item = (T, usize, Vec<usize>);
173
174    fn next(&mut self) -> Option<Self::Item> {
175        if self.count < self.len() {
176            let item = self.data[self.count].clone();
177            let index = self.count;
178            let position = self.nth_to_pos(self.count).unwrap();
179            self.count += 1;
180            Some((item, index, position))
181        }else{None}
182    }
183}
184
185
186//arithmetic
187impl<T:num_traits::Num + Clone + Copy + Send + 'static +
188num_traits::CheckedAdd + num_traits::CheckedSub + num_traits::CheckedMul + num_traits::CheckedDiv> 
189NdMatrix<T> {
190    //basic arithmetic
191    pub fn add(&mut self, operand:NdMatrix<T>) -> Result<(), Error> {
192        if self.dimensions != operand.dimensions {return Err(Error::InvalidDimensions)}
193        
194        let res = self.data[0].checked_add(&operand.data[0]);
195        if res == None {return Err(Error::CannotOperate)}
196
197        for i in 0..self.len() {
198            self.data[i] = self.data[i] + operand.data[i]
199        }
200
201        Ok(())
202    }
203    
204    pub fn sub(&mut self, operand:NdMatrix<T>) -> Result<(), Error> {
205        if self.dimensions != operand.dimensions {return Err(Error::InvalidDimensions)}
206
207        let res = self.data[0].checked_sub(&operand.data[0]);
208        if res == None {return Err(Error::CannotOperate)}
209
210        for i in 0..self.len() {
211            self.data[i] = self.data[i] - operand.data[i]
212        }
213
214        Ok(())
215    }
216
217    pub fn mul(&mut self, operand:NdMatrix<T>) -> Result<(), Error> {
218        if self.dimensions != operand.dimensions {return Err(Error::InvalidDimensions)}
219
220        let res = self.data[0].checked_mul(&operand.data[0]);
221        if res == None {return Err(Error::CannotOperate)}
222
223        for i in 0..self.len() {
224            self.data[i] = self.data[i] * operand.data[i]
225        }
226
227        Ok(())
228    }
229    
230    pub fn div(&mut self, operand:NdMatrix<T>) -> Result<(), Error> {
231        //error catches
232        if self.dimensions != operand.dimensions {return Err(Error::InvalidDimensions)}
233
234        let res = self.data[0].checked_div(&operand.data[0]);
235        if res == None {return Err(Error::CannotOperate)}
236
237        for i in 0..self.len() {
238            self.data[i] = self.data[i] / operand.data[i]
239        }
240
241        Ok(())
242    }
243
244
245    //const arithmetic
246    pub fn const_add(&mut self, operand:T) {
247        
248        for i in 0..self.len() {
249            self.data[i] = self.data[i] + operand
250        }
251
252    }
253
254    pub fn const_sub(&mut self, operand:T) {
255        
256        for i in 0..self.len() {
257            self.data[i] = self.data[i] - operand
258        }
259      
260    }
261
262    pub fn const_mul(&mut self, operand:T) {
263        
264        for i in 0..self.len() {
265            self.data[i] = self.data[i] * operand
266        }
267     
268    }
269
270    pub fn const_div(&mut self, operand:T) {
271        
272        for i in 0..self.len() {
273            self.data[i] = self.data[i] / operand
274        }
275     
276    }
277
278}
279
280
281
282
/// Errors returned by [`NdMatrix`] operations.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Error {
    /// An index or operand does not match the matrix's dimensions.
    InvalidDimensions,
    /// An index lies outside the bounds of the matrix.
    OOBIndex,
    /// A checked arithmetic operation failed (overflow or division by zero).
    CannotOperate,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match *self {
            Error::InvalidDimensions => "dimensions do not match the matrix's shape",
            Error::OOBIndex => "index out of bounds",
            Error::CannotOperate => "arithmetic operation failed (overflow or division by zero)",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for Error {}