1use tensor_rs::tensor::Tensor;
2use super::{OpTrait, OpHandle, OpCall, Op};
3
4use std::cell::{RefCell};
5use std::rc::Rc;
6
7use crate::var::{Var};
8use crate::err::AutoDiffError;
9
10#[cfg(feature = "use-serde")]
11use serde::{Serialize, Deserialize};
12#[cfg(feature = "use-serde")]
13use std::any::Any;
14
15#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
16pub struct GetPatch {
17 #[cfg_attr(feature = "use-serde", serde(skip))]
18 handle: OpHandle,
19 range: Vec<(usize, usize)>,
20 step: Option<Vec<usize>>,
21}
22impl GetPatch {
23 pub fn new(range: &[(usize, usize)], step: Option<&[usize]>)
24 -> GetPatch{
25 let new_range = range.to_vec();
26 let new_step = step.map(|v| v.to_vec());
27 GetPatch {
28 handle: OpHandle::new(),
29 range: new_range,
30 step: new_step
31 }
32 }
33 fn get_handle(&self) -> &OpHandle {
34 &self.handle
35 }
36 fn get_handle_mut(&mut self) -> &mut OpHandle {
37 &mut self.handle
38 }
39}
40impl OpCall for GetPatch {
41 fn call(&mut self, inputs: &[&Var])
42 -> Result<Vec<Var>, AutoDiffError> {
43 let new_one = GetPatch {
44 handle: OpHandle::new(),
45 range: self.range.clone(),
46 step: self.step.clone(),
47 };
48
49 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
50
51 inputs[0].called_with(op, &inputs[1..inputs.len()])
52 }
53}
54impl OpTrait for GetPatch {
55
56 fn get_name(&self) -> &'static str {
57 "GetPatch"
58 }
59 fn get_input_size(&self) -> usize {
60 1
61 }
62 fn get_output_size(&self) -> usize {
63 1
64 }
65 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
66 let step = self.step.as_ref().map(|v| &v[..]);
67 output[0].swap(&input[0].get_patch(&self.range, step));
68 }
69 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
70 unimplemented!();
71 }
72 fn get_values(&self) -> Vec<Tensor> {
73 Vec::new()
74 }
75 fn get_grads(&self) -> Vec<Tensor> {
76 Vec::new()
77 }
78 fn set_values(&self, _v: &[Tensor]) {
79 }
80 #[cfg(feature = "use-serde")]
81 fn as_any(&self) -> &dyn Any {
82 self
83 }
84}
85
86#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
87pub struct SetPatch {
88 #[cfg_attr(feature = "use-serde", serde(skip))]
89 handle: OpHandle,
90 range: Vec<(usize, usize)>,
91 step: Option<Vec<usize>>,
92}
93impl SetPatch {
94 pub fn new(range: &[(usize, usize)],
95 step: Option<&[usize]>)
96 -> SetPatch {
97 let new_range = range.to_vec();
98 let new_step = step.map(|v| v.to_vec());
99 SetPatch {
100 handle: OpHandle::new(),
101 range: new_range,
102 step: new_step
103 }
104 }
105 fn get_handle(&self) -> &OpHandle {
106 &self.handle
107 }
108 fn get_handle_mut(&mut self) -> &mut OpHandle {
109 &mut self.handle
110 }
111}
112impl OpCall for SetPatch {
113 fn call(&mut self, inputs: &[&Var])
114 -> Result<Vec<Var>, AutoDiffError> {
115 let new_one = SetPatch {
116 handle: OpHandle::new(),
117 range: self.range.clone(),
118 step: self.step.clone(),
119 };
120
121 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
122
123 inputs[0].called_with(op, &inputs[1..inputs.len()])
124 }
125}
126impl OpTrait for SetPatch {
127
128 fn get_name(&self) -> &'static str {
129 "SetPatch"
130 }
131 fn get_input_size(&self) -> usize {
132 2
133 }
134 fn get_output_size(&self) -> usize {
135 1
136 }
137 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
138 let step = self.step.as_ref().map(|v| &v[..]);
139 output[0].swap(&input[0].set_patch(&input[1], &self.range, step));
140 }
141 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
142 unimplemented!();
143 }
144 fn get_values(&self) -> Vec<Tensor> {
145 Vec::new()
146 }
147 fn get_grads(&self) -> Vec<Tensor> {
148 Vec::new()
149 }
150 fn set_values(&self, _v: &[Tensor]) {
151 }
152 #[cfg(feature = "use-serde")]
153 fn as_any(&self) -> &dyn Any {
154 self
155 }
156}
157
158