dendritic_ndarray/ops/binary.rs

1use crate::ndarray::NDArray;
2use std::fs::File;
3use std::io::{BufWriter, Read, Write}; 
4
5pub trait BinaryOps {
6    fn mult(&self, other: NDArray<f64>) -> Result <NDArray<f64 >, String>; 
7    fn add(&self, other: NDArray<f64>) -> Result <NDArray<f64 >, String>;
8    fn subtract(&self, other: NDArray<f64>) -> Result<NDArray<f64>, String>;
9    fn dot(&self, other: NDArray<f64>) -> Result<NDArray<f64>, String>;
10    fn scale_add(&self, other: NDArray<f64>) -> Result<NDArray<f64>, String>;
11    fn scale_mult(&self, other: NDArray<f64>) -> Result<NDArray<f64>, String>;
12    fn save(&self, filepath: &str) -> std::io::Result<()>; 
13    fn load(filepath: &str) -> std::io::Result<NDArray<f64>>;
14}
15
16
17impl BinaryOps for NDArray<f64> {
18
19
20    /// Multiply an ndarray by another
21    fn mult(&self, other: NDArray<f64>) -> Result <NDArray<f64 >, String> {
22
23        /* rank mismatch */
24        if self.rank() != other.rank() {
25            return Err("Mult: Rank Mismatch".to_string());
26        }
27
28        let mut result = NDArray::new(self.shape().values()).unwrap();
29        if self.size() != other.values().len() {
30            println!("{:?} {:?}", self.size(), other.values().len()); 
31            return Err("Mult: Size mismatch for arrays".to_string());
32        }
33
34        let mut counter = 0; 
35        let values = other.values(); 
36        for item in self.values() {
37            let mult_result = item * values[counter];
38            let _ = result.set_idx(counter, mult_result);
39            counter += 1;
40        }
41
42        Ok(result)
43    }
44
45
46    /// Add two NDArray's and get resulting NDArray instance
47    fn add(&self, value: NDArray<f64>) -> Result<NDArray<f64>, String> {
48
49        /* rank mismatch */
50        if self.rank() != value.rank() {
51            return Err("Add: Rank Mismatch".to_string());
52        }
53
54        let mut result = NDArray::new(self.shape().values()).unwrap();
55        if self.size() != value.values().len() {
56            return Err("Add: Size mismatch for arrays".to_string());
57        }
58
59        let mut counter = 0; 
60        let values = value.values(); 
61        for item in self.values() {
62            let add_result = item + values[counter];
63            let _ = result.set_idx(counter, add_result);
64            counter += 1;
65        }
66
67        Ok(result)
68    }
69
70
71    /// Subtract values in NDArray instances
72    fn subtract(&self, value: NDArray<f64>) -> Result<NDArray<f64>, String> {
73
74        /* rank mismatch */
75        if self.rank() != value.rank() {
76            return Err("Subtract: Rank Mismatch".to_string());
77        }
78
79        let mut result = NDArray::new(self.shape().values()).unwrap();
80        if self.size() != value.values().len() {
81            return Err("Subtract: Size mismatch for arrays".to_string());
82        }
83
84        let mut counter = 0; 
85        let values = value.values(); 
86        for item in self.values() {
87            let add_result = item - values[counter];
88            let _ = result.set_idx(counter, add_result);
89            counter += 1;
90        }
91
92        Ok(result)
93    }
94
95
96    /// Perform dot product of current NDArray on another NDArray instance
97    fn dot(&self, input: NDArray<f64>) -> Result<NDArray<f64>, String> {
98
99        /* rank mismatch */
100        if self.rank() != input.rank() {
101            return Err("Dot: Rank Mismatch".to_string());
102        }
103
104        if self.rank() != 2 {
105            return Err("Dot: Requires rank 2 values".to_string());
106        }
107
108        if self.shape().dim(self.rank()-1) != input.shape().dim(0) {
109            return Err("Dot: Rows must equal columns".to_string());
110        }
111
112        let new_shape: Vec<usize> = vec![self.shape().dim(0), input.shape().dim(self.rank()-1)];
113        let mut result = NDArray::new(new_shape).unwrap();
114
115        /* stride values to stay in constant time */ 
116        // let mut counter = 0; 
117        let mut row_counter = 0; 
118        let mut col_counter = 0; 
119        let mut stride = 0;  
120        for counter in 0..result.size() {
121
122            if stride == input.shape().dim(self.rank()-1)  {
123                row_counter += 1;
124                stride = 0; 
125            }
126
127            let col_dim = input.shape().dim(input.rank()-1);
128            if col_counter == col_dim {
129                col_counter = 0; 
130            }
131
132            let curr: NDArray<f64> = self.axis(0, row_counter).unwrap();
133            let val: NDArray<f64> = input.axis(1, col_counter).unwrap();
134
135            /* multiply */ 
136            let mut value = 0.0; 
137            for item in 0..curr.size() {
138                value += curr.idx(item) * val.idx(item);
139            }
140            result.set_idx(counter, value).unwrap(); 
141
142            
143            // counter += 1; 
144            col_counter += 1;
145            stride += 1;  
146                    
147        }
148
149        Ok(result)
150    }
151
152
153    /// Add values by scalar for current NDArray instance
154    fn scale_add(&self, value: NDArray<f64>) -> Result<NDArray<f64>, String> {
155
156        if value.shape().dim(0) != 1 {
157            return Err("Scale add must have a vector dimension (1, N)".to_string());
158        }
159
160        let mut total_counter = 0; 
161        let mut counter = 0;
162        let vector_values = value.values();
163        let mut result = NDArray::new(self.shape().values()).unwrap();
164        for item in self.values() {
165            if counter == value.size() {
166                counter = 0;
167            }
168             let add_result = item + vector_values[counter];
169             let _ = result.set_idx(total_counter, add_result);
170             total_counter += 1; 
171        }
172
173        Ok(result)
174
175    }
176
177    /// Elementwise multiplication of ndarray
178    fn scale_mult(&self, value: NDArray<f64>) -> Result<NDArray<f64>, String> {
179    
180        let value_shape = value.shape();
181        if value_shape.dim(0) != 1 {
182            return Err("Scale add must have a vector dimension (1, N)".to_string());
183        }
184
185        let mut total_counter = 0; 
186        let mut counter = 0;
187        let vector_values = value.values();
188        let mut result = NDArray::new(self.shape().values()).unwrap();
189        for item in self.values() {
190            if counter == value.size() {
191                counter = 0;
192            }
193             let add_result = item * vector_values[counter];
194             let _ = result.set_idx(total_counter, add_result);
195             total_counter += 1; 
196        }
197
198        Ok(result)
199    }
200
201
202    /// Save instance of NDArray to json file with serialized values
203    fn save(&self, filepath: &str) -> std::io::Result<()> {
204        let filename_format = format!("{filepath}.json");
205        let file = match File::create(filename_format) {
206            Ok(file) => file,
207            Err(err) => {
208                return Err(err);
209            }
210        };
211        let mut writer = BufWriter::new(file);
212        let json_string = serde_json::to_string_pretty(&self)?;
213        writer.write_all(json_string.as_bytes())?;
214        Ok(())
215    }
216
217
218    /// Load Instance of saved NDarray, serialize to NDArray structure
219    fn load(filepath: &str) -> std::io::Result<NDArray<f64>> {
220        let filename_format = format!("{filepath}.json");
221        let mut file = match File::open(filename_format) {
222            Ok(file) => file,
223            Err(err) => {
224                return Err(err);
225            }
226        };
227        let mut contents = String::new();
228        file.read_to_string(&mut contents)?;
229        let instance: NDArray<f64> = serde_json::from_str(&contents)?;
230        Ok(instance)
231    }
232}