1#![allow(clippy::redundant_closure)]
2use std::collections::{BTreeSet, BTreeMap};
3use std::fmt;
4
5use crate::collection::generational_index::{GenIndex, GenKey};
6use crate::collection::directed_graph::Graph;
7use tensor_rs::tensor::Tensor;
8use crate::op::Op;
9use crate::err::AutoDiffError;
10
11#[cfg(feature = "use-serde")]
12use serde::{Serialize, Deserialize};
13
/// A computation graph together with the tensors and ops it connects.
///
/// Tensors and ops are kept in generational-index stores. `graph` records, per op
/// key, which data keys it reads and which data keys it writes (see `connect`).
/// Gradients produced by `bptt` and user-assigned labels are tracked alongside.
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
#[derive(Clone)]
pub struct Net {
    /// Tensor storage; `add_tensor` inserts here and registers the key in `graph`.
    data: GenIndex<Tensor>,
    /// Op storage; `add_op` inserts here and registers the key in `graph`.
    ops: GenIndex<Op>,
    /// Data keys flagged with `set_mark` and cleared with `unset_mark`.
    /// NOTE(review): no method in this impl reads it back — confirm its consumer.
    set_mark: BTreeSet<GenKey>,
    /// Directed graph over data keys and op keys, built by `add_tensor`, `add_op`
    /// and `connect`, and walked by `eval` / `bptt`.
    graph: Graph<GenKey, GenKey>,
    /// Per-data-key gradients; cleared and repopulated by every `bptt` call.
    data_grad: BTreeMap<GenKey, Tensor>,
    /// User-assigned labels mapped to data keys (`set_label`, `get_id_by_label`).
    label2id: BTreeMap<String, GenKey>,
}
26
27impl Net {
28 pub fn new() -> Net {
29 Net {
30 data: GenIndex::new(),
31 ops: GenIndex::new(),
32 set_mark: BTreeSet::new(),
33 graph: Graph::new(),
34 data_grad: BTreeMap::new(),
35 label2id: BTreeMap::new(),
36 }
37 }
38
39 pub fn get_data(&self) -> &GenIndex<Tensor> {
40 &self.data
41 }
42
43 pub fn get_data_mut(&mut self) -> &mut GenIndex<Tensor> {
44 &mut self.data
45 }
46 pub fn get_ops(&self) -> &GenIndex<Op> {
47 &self.ops
48 }
49 pub fn get_ops_mut(&mut self) -> &mut GenIndex<Op> {
50 &mut self.ops
51 }
52
53 pub fn add_tensor(&mut self, t: Tensor) -> GenKey {
54 let id = self.data.insert(t);
55 self.graph.add_data(&id).expect("");
56 id
57 }
58
59 pub fn get_tensor(&self, id: GenKey) -> Result<Tensor, AutoDiffError> {
60 match self.data.get(&id) {
61 Ok(v) => {Ok(v.ref_copy())}, Err(v) => {Err(v)}
63 }
64 }
65 pub fn set_tensor(&mut self, id: GenKey, val: Tensor) -> Result<(), AutoDiffError> {
66 self.data.replace(&id, val)
67 }
68
69 pub fn add_op(&mut self, op: Op) -> GenKey {
71 let id = self.ops.insert(op);
72 self.graph.add_op(&id).expect("");
73 id
74 }
75 pub fn get_op(&self, id: GenKey) -> Result<Op, AutoDiffError> {
76 Ok(self.ops.get(&id)?.ref_copy())
77 }
78
79 pub fn get_grad(&self, id: GenKey) -> Result<Tensor, AutoDiffError> {
80 match self.data_grad.get(&id) {
81 Some(v) => {Ok(v.ref_copy())},
82 None => {Err(AutoDiffError::new(&format!("Data {:?} doesn't ahave gradient yet.", id)))}
83 }
84 }
85
86 pub fn get_input_edge_data(&self) -> BTreeSet<GenKey> {
87 self.graph.get_input_edge_data()
88 }
89
90 pub fn get_output_edge_data(&self) -> BTreeSet<GenKey> {
91 self.graph.get_output_edge_data()
92 }
93
94
95pub fn connect(&mut self, input: &[GenKey], op: GenKey, output: &[GenKey]) {
150
151 self.graph.connect(input, output, &op).expect("");
152 }
153
154
155 pub fn set_mark(&mut self, did: &GenKey) {
157 self.set_mark.insert(*did);
158 }
159 pub fn unset_mark(&mut self, did: &GenKey) {
160 self.set_mark.remove(did);
161 }
162
163 pub fn eval(&mut self, starting_node: &[GenKey]) -> Result<(), BTreeSet<GenKey>> {
165
166 self.graph
167 .walk(
168 starting_node,
169 true,
170 |input, output, op| {
171 let mut inputs: Vec<Tensor> = Vec::new();
174 for input_id in input {
175 let a = self.data.get(input_id).expect("").ref_copy();
176 inputs.push(a);
177 }
178
179 let mut outputs: Vec<Tensor> = Vec::new();
180 for output_id in output {
181 let a = self.data.get(output_id).expect("").ref_copy();
182 outputs.push(a);
183 }
184
185 self.ops
186 .get(op)
187 .expect("")
188 .apply(&inputs, &outputs);
189
190 }
193 )?;
194
195 Ok(())
196 }
197
198pub fn bptt(&mut self, output_grad: &BTreeMap<GenKey, Tensor>) {
229 let mut output = Vec::new();
230 self.data_grad.clear();
231 for (k, v) in output_grad {
232 output.push(*k);
233 self.data_grad.insert(*k, v.clone());
234 }
235
236 for i in self.graph.iter_data() {
237 self.data_grad.entry(*i).or_insert_with(Tensor::new);
238 }
239
240 self.graph
241 .walk(
242 &output[..],
243 false,
244 |output_grads, input_grads, op| {
245 let mut inputs: Vec<Tensor> = Vec::new();
249 for input_id in input_grads {
250 let a = self.data.get(input_id).expect("").ref_copy();
252 inputs.push(a);
253 }
254 let mut output_grad: Vec<Tensor> = Vec::new();
258 for output_id in output_grads {
259 let a = self.data_grad.get(output_id).expect("").ref_copy();
261 output_grad.push(a);
262 }
263 let mut input_grad: Vec<Tensor> = Vec::new();
267 for input_id in input_grads {
268 let a = self.data_grad.get(input_id).expect("").ref_copy();
270 input_grad.push(a);
271 }
272 self.ops
275 .get(op)
276 .expect("")
277 .grad(&inputs, &output_grad, &input_grad);
278
279 }
282 ).expect("");
283 }
284
285 pub fn visit_op<F>(&mut self, closure: F,
287 allow: Option<Vec<GenKey>>,
288 skip: Option<Vec<GenKey>>)
289 where F: Fn(&Op) {
290 let allow_list = if let Some(val) = allow { val } else {Vec::new()};
291 let skip_list = if let Some(val) = skip {val} else {Vec::new()};
292
293 for i in self.graph.iter_op() {
294 if (allow_list.is_empty() && skip_list.is_empty()) ||
295 (!allow_list.is_empty() && allow_list.contains(i)) ||
296 (!skip_list.is_empty() && !skip_list.contains(i) ) {
297 closure(self.ops.get(i).expect(""));
298 }
299 }
300 }
301
302 pub fn visit_data<F>(&mut self, closure: F)
303 where F: Fn(GenKey, &Tensor) {
304 for i in self.graph.iter_data() {
305 closure(*i, self.data.get(i).expect(""));
306 }
307 }
308
309 pub fn append(&mut self, other: &Self,
312 original_keys: &[GenKey]) -> Result<Vec<GenKey>, AutoDiffError> {
313
314 let mut data_key_map = BTreeMap::new();
315 let mut ret_keys = Vec::new();
316 for key in other.get_data().iter_key() {
317 let new_key = self.add_tensor(other.get_tensor(key)?);
318 if original_keys.contains(&key) {
319 ret_keys.push(new_key);
320 }
321 data_key_map.insert(key, new_key);
322 }
323
324 let mut op_key_map = BTreeMap::new();
325 for key in other.get_ops().iter_key() {
326 let new_key = self.add_op(other.get_op(key)?);
327 op_key_map.insert(key, new_key);
328 }
329
330 self.graph.append(&other.graph, data_key_map, op_key_map)?;
331
332 Ok(ret_keys)
333 }
334
335 pub fn set_label(&mut self, label: &str, id: &GenKey) -> Result<(), AutoDiffError>{
337 if !self.data.contains(id) {
338 Err(AutoDiffError::new("unknown id."))
339 } else {
340 self.label2id.insert(label.to_string(), *id);
341 Ok(())
342 }
343 }
344
345 pub fn get_id_by_label(&self, label: &str) -> Result<GenKey, AutoDiffError> {
346 match self.label2id.get(label) {
347 Some(v) => {Ok(*v)},
348 None => {Err(AutoDiffError::new("unknown label."))}
349 }
350 }
351
352 pub fn drop_label(&mut self, label: &str) -> Result<GenKey, AutoDiffError> {
353 if !self.label2id.contains_key(label) {
354 Err(AutoDiffError::new("unknown label."))
355 } else {
356 Ok(*self.label2id.get(label).expect("unknown label."))
357 }
358 }
359}
360
361impl fmt::Debug for Net {
362 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
363 writeln!(f, "Dumping Net:")?;
364 for i in self.data.iter_key() {
365 writeln!(f, "id: {:?} data: {:?}", i, self.data.get(&i)?)?;
366 }
367 writeln!(f, "data_grad")?;
368 for (k, v) in self.data_grad.iter() {
369 writeln!(f, "id: {:?} data: {:?}", k, v)?;
370 }
371 writeln!(f, "op names")?;
372 for i in self.ops.iter_key() {
373 writeln!(f, "id: {:?} \n data: {:?}", i, self.ops.get(&i)?.get_name())?;
374 }
375 writeln!(f, "graph: {:?}", self.graph)
376 }
377}
378
379impl Default for Net {
380 fn default() -> Self {
381 Self::new()
382 }
383}
384
385