auto_diff/op/
vision.rs

1use tensor_rs::tensor::Tensor;
2use super::{OpTrait, OpHandle, OpCall, Op};
3
4use std::cell::{RefCell};
5use std::rc::Rc;
6
7use crate::var::{Var};
8use crate::err::AutoDiffError;
9
10#[cfg(feature = "use-serde")]
11use serde::{Serialize, Deserialize};
12#[cfg(feature = "use-serde")]
13use std::any::Any;
14
15#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
16pub struct GetPatch {
17    #[cfg_attr(feature = "use-serde", serde(skip))]
18    handle: OpHandle,
19    range: Vec<(usize, usize)>,
20    step: Option<Vec<usize>>,
21}
22impl GetPatch {
23    pub fn new(range: &[(usize, usize)], step: Option<&[usize]>)
24               -> GetPatch{
25        let new_range = range.to_vec();
26        let new_step = step.map(|v| v.to_vec());
27        GetPatch {
28            handle: OpHandle::new(),
29            range: new_range,
30            step: new_step
31        }
32    }
33    fn get_handle(&self) -> &OpHandle {
34        &self.handle
35    }
36    fn get_handle_mut(&mut self) -> &mut OpHandle {
37        &mut self.handle
38    }
39}
40impl OpCall for GetPatch {
41    fn call(&mut self, inputs: &[&Var])
42            -> Result<Vec<Var>, AutoDiffError> {
43        let new_one = GetPatch {
44            handle: OpHandle::new(),
45            range: self.range.clone(),
46            step: self.step.clone(),
47        };
48
49        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
50
51        inputs[0].called_with(op, &inputs[1..inputs.len()])
52    }
53}
54impl OpTrait for GetPatch {
55
56    fn get_name(&self) -> &'static str {
57        "GetPatch"
58    }
59    fn get_input_size(&self) -> usize {
60        1
61    }
62    fn get_output_size(&self) -> usize {
63        1
64    }
65    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
66        let step = self.step.as_ref().map(|v| &v[..]);
67        output[0].swap(&input[0].get_patch(&self.range, step));
68    }
69    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
70        unimplemented!();
71    }
72    fn get_values(&self) -> Vec<Tensor> {
73        Vec::new()
74    }
75    fn get_grads(&self) -> Vec<Tensor> {
76        Vec::new()
77    }
78    fn set_values(&self, _v: &[Tensor]) {
79    }
80    #[cfg(feature = "use-serde")]
81    fn as_any(&self) -> &dyn Any {
82	self
83    }
84}
85
86#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
87pub struct SetPatch {
88    #[cfg_attr(feature = "use-serde", serde(skip))]
89    handle: OpHandle,
90    range: Vec<(usize, usize)>,
91    step: Option<Vec<usize>>,
92}
93impl SetPatch {
94    pub fn new(range: &[(usize, usize)],
95               step: Option<&[usize]>)
96               -> SetPatch {
97        let new_range = range.to_vec();
98        let new_step = step.map(|v| v.to_vec());
99        SetPatch {
100            handle: OpHandle::new(),
101            range: new_range,
102            step: new_step
103        }
104    }
105    fn get_handle(&self) -> &OpHandle {
106        &self.handle
107    }
108    fn get_handle_mut(&mut self) -> &mut OpHandle {
109        &mut self.handle
110    }
111}
112impl OpCall for SetPatch {
113    fn call(&mut self, inputs: &[&Var])
114            -> Result<Vec<Var>, AutoDiffError> {
115        let new_one = SetPatch {
116            handle: OpHandle::new(),
117            range: self.range.clone(),
118            step: self.step.clone(),
119        };
120
121        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
122
123        inputs[0].called_with(op, &inputs[1..inputs.len()])
124    }
125}
126impl OpTrait for SetPatch {
127
128    fn get_name(&self) -> &'static str {
129        "SetPatch"
130    }
131    fn get_input_size(&self) -> usize {
132        2
133    }
134    fn get_output_size(&self) -> usize {
135        1
136    }
137    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
138        let step = self.step.as_ref().map(|v| &v[..]);
139        output[0].swap(&input[0].set_patch(&input[1], &self.range, step));
140    }
141    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
142        unimplemented!();
143    }
144    fn get_values(&self) -> Vec<Tensor> {
145        Vec::new()
146    }
147    fn get_grads(&self) -> Vec<Tensor> {
148        Vec::new()
149    }
150    fn set_values(&self, _v: &[Tensor]) {
151    }
152    #[cfg(feature = "use-serde")]
153    fn as_any(&self) -> &dyn Any {
154	self
155    }
156}
157
158