dendritic_ndarray/ops/
binary.rs1use crate::ndarray::NDArray;
2use std::fs::File;
3use std::io::{BufWriter, Read, Write};
4
5pub trait BinaryOps {
6 fn mult(&self, other: NDArray<f64>) -> Result <NDArray<f64 >, String>;
7 fn add(&self, other: NDArray<f64>) -> Result <NDArray<f64 >, String>;
8 fn subtract(&self, other: NDArray<f64>) -> Result<NDArray<f64>, String>;
9 fn dot(&self, other: NDArray<f64>) -> Result<NDArray<f64>, String>;
10 fn scale_add(&self, other: NDArray<f64>) -> Result<NDArray<f64>, String>;
11 fn scale_mult(&self, other: NDArray<f64>) -> Result<NDArray<f64>, String>;
12 fn save(&self, filepath: &str) -> std::io::Result<()>;
13 fn load(filepath: &str) -> std::io::Result<NDArray<f64>>;
14}
15
16
17impl BinaryOps for NDArray<f64> {
18
19
20 fn mult(&self, other: NDArray<f64>) -> Result <NDArray<f64 >, String> {
22
23 if self.rank() != other.rank() {
25 return Err("Mult: Rank Mismatch".to_string());
26 }
27
28 let mut result = NDArray::new(self.shape().values()).unwrap();
29 if self.size() != other.values().len() {
30 println!("{:?} {:?}", self.size(), other.values().len());
31 return Err("Mult: Size mismatch for arrays".to_string());
32 }
33
34 let mut counter = 0;
35 let values = other.values();
36 for item in self.values() {
37 let mult_result = item * values[counter];
38 let _ = result.set_idx(counter, mult_result);
39 counter += 1;
40 }
41
42 Ok(result)
43 }
44
45
46 fn add(&self, value: NDArray<f64>) -> Result<NDArray<f64>, String> {
48
49 if self.rank() != value.rank() {
51 return Err("Add: Rank Mismatch".to_string());
52 }
53
54 let mut result = NDArray::new(self.shape().values()).unwrap();
55 if self.size() != value.values().len() {
56 return Err("Add: Size mismatch for arrays".to_string());
57 }
58
59 let mut counter = 0;
60 let values = value.values();
61 for item in self.values() {
62 let add_result = item + values[counter];
63 let _ = result.set_idx(counter, add_result);
64 counter += 1;
65 }
66
67 Ok(result)
68 }
69
70
71 fn subtract(&self, value: NDArray<f64>) -> Result<NDArray<f64>, String> {
73
74 if self.rank() != value.rank() {
76 return Err("Subtract: Rank Mismatch".to_string());
77 }
78
79 let mut result = NDArray::new(self.shape().values()).unwrap();
80 if self.size() != value.values().len() {
81 return Err("Subtract: Size mismatch for arrays".to_string());
82 }
83
84 let mut counter = 0;
85 let values = value.values();
86 for item in self.values() {
87 let add_result = item - values[counter];
88 let _ = result.set_idx(counter, add_result);
89 counter += 1;
90 }
91
92 Ok(result)
93 }
94
95
96 fn dot(&self, input: NDArray<f64>) -> Result<NDArray<f64>, String> {
98
99 if self.rank() != input.rank() {
101 return Err("Dot: Rank Mismatch".to_string());
102 }
103
104 if self.rank() != 2 {
105 return Err("Dot: Requires rank 2 values".to_string());
106 }
107
108 if self.shape().dim(self.rank()-1) != input.shape().dim(0) {
109 return Err("Dot: Rows must equal columns".to_string());
110 }
111
112 let new_shape: Vec<usize> = vec![self.shape().dim(0), input.shape().dim(self.rank()-1)];
113 let mut result = NDArray::new(new_shape).unwrap();
114
115 let mut row_counter = 0;
118 let mut col_counter = 0;
119 let mut stride = 0;
120 for counter in 0..result.size() {
121
122 if stride == input.shape().dim(self.rank()-1) {
123 row_counter += 1;
124 stride = 0;
125 }
126
127 let col_dim = input.shape().dim(input.rank()-1);
128 if col_counter == col_dim {
129 col_counter = 0;
130 }
131
132 let curr: NDArray<f64> = self.axis(0, row_counter).unwrap();
133 let val: NDArray<f64> = input.axis(1, col_counter).unwrap();
134
135 let mut value = 0.0;
137 for item in 0..curr.size() {
138 value += curr.idx(item) * val.idx(item);
139 }
140 result.set_idx(counter, value).unwrap();
141
142
143 col_counter += 1;
145 stride += 1;
146
147 }
148
149 Ok(result)
150 }
151
152
153 fn scale_add(&self, value: NDArray<f64>) -> Result<NDArray<f64>, String> {
155
156 if value.shape().dim(0) != 1 {
157 return Err("Scale add must have a vector dimension (1, N)".to_string());
158 }
159
160 let mut total_counter = 0;
161 let mut counter = 0;
162 let vector_values = value.values();
163 let mut result = NDArray::new(self.shape().values()).unwrap();
164 for item in self.values() {
165 if counter == value.size() {
166 counter = 0;
167 }
168 let add_result = item + vector_values[counter];
169 let _ = result.set_idx(total_counter, add_result);
170 total_counter += 1;
171 }
172
173 Ok(result)
174
175 }
176
177 fn scale_mult(&self, value: NDArray<f64>) -> Result<NDArray<f64>, String> {
179
180 let value_shape = value.shape();
181 if value_shape.dim(0) != 1 {
182 return Err("Scale add must have a vector dimension (1, N)".to_string());
183 }
184
185 let mut total_counter = 0;
186 let mut counter = 0;
187 let vector_values = value.values();
188 let mut result = NDArray::new(self.shape().values()).unwrap();
189 for item in self.values() {
190 if counter == value.size() {
191 counter = 0;
192 }
193 let add_result = item * vector_values[counter];
194 let _ = result.set_idx(total_counter, add_result);
195 total_counter += 1;
196 }
197
198 Ok(result)
199 }
200
201
202 fn save(&self, filepath: &str) -> std::io::Result<()> {
204 let filename_format = format!("{filepath}.json");
205 let file = match File::create(filename_format) {
206 Ok(file) => file,
207 Err(err) => {
208 return Err(err);
209 }
210 };
211 let mut writer = BufWriter::new(file);
212 let json_string = serde_json::to_string_pretty(&self)?;
213 writer.write_all(json_string.as_bytes())?;
214 Ok(())
215 }
216
217
218 fn load(filepath: &str) -> std::io::Result<NDArray<f64>> {
220 let filename_format = format!("{filepath}.json");
221 let mut file = match File::open(filename_format) {
222 Ok(file) => file,
223 Err(err) => {
224 return Err(err);
225 }
226 };
227 let mut contents = String::new();
228 file.read_to_string(&mut contents)?;
229 let instance: NDArray<f64> = serde_json::from_str(&contents)?;
230 Ok(instance)
231 }
232}