use std::fmt::Debug;
use std::marker::PhantomData;
use std::str::FromStr;
use crate::Stack;
use crate::arr::{Arr, SerializedVec};
use crate::cuda::ToHost;
use crate::device::Device;
use crate::device::output::DeviceLinearOutput;
use crate::error::{ConfigReadError, EvaluateError, PersistenceError, SizeMismatchError, TrainingError, TypeConvertError};
use crate::layer::{AskDiffInput, BackwardAll, BatchBackward, BatchDataType, BatchForward, BatchForwardBase, BatchLoss, BatchPreTrain, BatchPreTrainBase, BatchSize, BatchTrain, ForwardAll, Loss, PreTrain, Train, UpdateWeight};
use crate::lossfunction::{BatchLossFunctionLinear, LossFunction, LossFunctionLinear};
use crate::ope::UnitValue;
use crate::persistence::{Linear, LinearPersistence, Persistence, Specialized, TextFilePersistence};

/// Output layer that applies a linear output to the parent layer's result and
/// converts it back to a host-side value.
pub struct LinearOutputLayer<U,P,D,I,PI,const N:usize>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          D: Device<U>,
          PI: Debug + 'static,
          I: Debug + Send + Sync {
    u:PhantomData<U>,
    i:PhantomData<I>,
    io:PhantomData<PI>,
    n:PhantomData<[();N]>,
    parent:P,
    device:D,
}
impl<U,P,D,I,PI,const N:usize> LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          D: Device<U>,
          PI: Debug + 'static,
          I: Debug + Send + Sync {
    /// Creates an instance of `LinearOutputLayer`.
    ///
    /// # Arguments
    /// * `parent` - parent (upper) layer
    /// * `device` - device object used for the computation
    pub fn new(parent:P,device:&D) -> LinearOutputLayer<U,P,D,I,PI,N> {
        LinearOutputLayer {
            u:PhantomData::<U>,
            i:PhantomData::<I>,
            io:PhantomData::<PI>,
            n:PhantomData::<[();N]>,
            parent,
            device:device.clone(),
        }
    }
}
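// Usage sketch (illustrative only, not part of this crate's API): the output layer normally
// wraps the last hidden layer of a network and borrows the device used for computation.
// `DeviceCpu`, `hidden` and `sample` below are assumed names for illustration; the actual
// builder types of this crate may differ.
//
//     let device = DeviceCpu::new()?;                       // hypothetical CPU device
//     let output = LinearOutputLayer::<f32,_,_,_,_,10>::new(hidden, &device);
//     let y = output.forward_all(sample)?;                  // host-side Arr<f32,10>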
impl<U,P,D,I,PI,const N:usize> Persistence<U,TextFilePersistence<U>,Specialized> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> +
             PreTrain<U,PreOutput=PI> + Loss<U> + Persistence<U,TextFilePersistence<U>,Specialized>,
          U: Default + Clone + Copy + UnitValue<U> + FromStr + Sized,
          D: Device<U>,
          PI: Debug + 'static,
          I: Debug + Send + Sync {
    fn load(&mut self, persistence: &mut TextFilePersistence<U>) -> Result<(),ConfigReadError> {
        self.parent.load(persistence)?;
        persistence.verify_eof()
    }

    fn save(&mut self, persistence: &mut TextFilePersistence<U>) -> Result<(), PersistenceError> {
        self.parent.save(persistence)
    }
}
impl<T,U,P,D,I,PI,const N:usize> Persistence<U,T,Linear> for LinearOutputLayer<U,P,D,I,PI,N>
    where T: LinearPersistence<U>,
          P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> +
             PreTrain<U,PreOutput=PI> + Loss<U> + Persistence<U,T,Linear>,
          U: Default + Clone + Copy + UnitValue<U>,
          D: Device<U>,
          PI: Debug + 'static,
          I: Debug + Send + Sync {
    fn load(&mut self, persistence: &mut T) -> Result<(),ConfigReadError> {
        self.parent.load(persistence)?;
        persistence.verify_eof()
    }

    fn save(&mut self, persistence: &mut T) -> Result<(), PersistenceError> {
        self.parent.save(persistence)
    }
}
impl<U,P,D,I,PI,const N:usize> ForwardAll for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync,
          <PI as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type Input = I;
    type Output = <PI as ToHost<U>>::Output;
    fn forward_all(&self, input: Self::Input) -> Result<Self::Output, EvaluateError> {
        // Forward through the parent layers and copy the result back to host memory.
        Ok(self.parent.forward_all(input)?.to_host()?)
    }
}
impl<U,P,D,I,PI,const N:usize> PreTrain<U> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync,
          <PI as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type PreOutput = PI;
    type OutStack = P::OutStack;

    fn pre_train(&self, input: Self::Input) -> Result<Self::OutStack, EvaluateError> {
        self.parent.pre_train(input)
    }
}
impl<U,P,D,I,PI,const N:usize> BackwardAll<U> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: BackwardAll<U,LossInput=PI> +
             ForwardAll<Input=I,Output=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync,
          <PI as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type LossInput = PI;
    type LossOutput = <P as BackwardAll<U>>::LossOutput;

    fn backward_all<L: LossFunction<U>>(&mut self, input: Self::LossInput, stack:Self::OutStack, lossf:&L)
        -> Result<(<Self as BackwardAll<U>>::LossOutput,<Self as UpdateWeight<U>>::GradientStack), TrainingError> {
        self.parent.backward_all(input, stack, lossf)
    }
}
impl<U,P,D,I,PI,const N:usize> UpdateWeight<U> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: BackwardAll<U,LossInput=PI> +
             ForwardAll<Input=I,Output=PI> +
             PreTrain<U,PreOutput=PI> + Loss<U> + UpdateWeight<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync,
          <PI as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type GradientStack = <P as UpdateWeight<U>>::GradientStack;

    fn update_weight(&mut self, stack: Self::GradientStack) -> Result<(), TrainingError> {
        Ok(self.parent.update_weight(stack)?)
    }
}
impl<U,P,D,I,PI,const N:usize> AskDiffInput<U> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: BackwardAll<U,LossInput=PI> +
             ForwardAll<Input=I,Output=PI> + PreTrain<U,PreOutput=PI> + Loss<U> + AskDiffInput<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync,
          <PI as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type DiffInput = P::DiffInput;

    fn ask_diff_input(&self, stack: &Self::OutStack) -> Result<Self::DiffInput,TypeConvertError> {
        self.parent.ask_diff_input(stack)
    }
}
impl<U,P,D,I,PI,L,const N:usize> Train<U,L> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync,
          <PI as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI>,
          for<'a> L: LossFunction<U> + LossFunctionLinear<'a,U,PI,D,N,Output=PI> {
    fn train(&mut self, expected: Self::Output, input: Self::Input, lossf: &L) -> Result<U, TrainingError> {
        // Forward pass, keeping the intermediate output of every layer on the stack.
        let stack = self.pre_train(input)?;

        // Total loss that is reported back to the caller.
        let total_loss = stack.map(|actual| self.device.loss_linear_total(&expected,actual,lossf))?;

        // When the output activation and the loss function form a canonical link
        // (for example, softmax with cross-entropy, or an identity output with mean
        // squared error), the gradient reduces to actual - expected and the loss
        // derivative does not have to be propagated through the parent's Loss impl.
        let (stack,loss) = if self.parent.is_canonical_link(lossf) {
            let loss = stack.map(|actual| {
                self.device.loss_linear_by_canonical_link(&expected, &actual)
            })?;

            (stack,loss)
        } else {
            let loss = stack.map(|actual| {
                self.device.loss_linear(&expected,&actual,lossf)
            })?;

            self.parent.loss(loss,lossf,stack)?
        };

        // Backward pass, then apply the accumulated weight updates.
        let (_,s) = self.backward_all(loss,stack,lossf)?;

        self.parent.update_weight(s)?;

        Ok(total_loss)
    }
}
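// Training sketch (illustrative only): a single sample is trained by passing the expected
// host-side output together with the input and a loss function implementing
// LossFunction/LossFunctionLinear. `Mse` and the variable names are assumptions made
// purely for illustration.
//
//     let lossf = Mse::new();                                   // hypothetical loss function
//     let total_loss = output.train(expected_arr, sample, &lossf)?;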
impl<U,P,D,I,PI,const N:usize> BatchForwardBase for LinearOutputLayer<U,P,D,I,PI,N>
    where P: PreTrain<U,PreOutput=PI> + ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + Loss<U> +
             BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type>,
          U: Default + Clone + Copy + Send + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync + BatchDataType,
          <PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
          <PI as ToHost<U>>::Output: Debug + 'static,
          <I as BatchDataType>::Type: Debug,
          <<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type BatchInput = <I as BatchDataType>::Type;
    type BatchOutput = <<PI as BatchDataType>::Type as ToHost<U>>::Output;
}
impl<U,P,D,I,PI,const N:usize> BatchForward for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U> +
             BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type> + BatchForward,
          U: Default + Clone + Copy + Send + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync + BatchDataType,
          <PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
          <PI as ToHost<U>>::Output: Debug + 'static,
          <I as BatchDataType>::Type: Debug,
          <<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    fn batch_forward(&self, input: Self::BatchInput) -> Result<Self::BatchOutput, TrainingError> {
        Ok(self.parent.batch_forward(input)?.to_host()?)
    }
}
impl<U,P,D,I,PI,const N:usize> BatchPreTrainBase<U> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U> +
             BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type> + BatchForward +
             BatchPreTrainBase<U,BatchPreOutput=<PI as BatchDataType>::Type>,
          U: Default + Clone + Copy + Send + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync + BatchDataType,
          <PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
          <PI as ToHost<U>>::Output: Debug + 'static,
          <I as BatchDataType>::Type: Debug,
          <<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type BatchPreOutput = <PI as BatchDataType>::Type;
    type BatchOutStack = P::BatchOutStack;
}
impl<U,P,D,I,PI,const N:usize> BatchPreTrain<U> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U> +
             BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type> + BatchForward +
             BatchPreTrainBase<U,BatchPreOutput=<PI as BatchDataType>::Type> + BatchPreTrain<U>,
          U: Default + Clone + Copy + Send + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync + BatchDataType,
          <PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
          <PI as ToHost<U>>::Output: Debug + 'static,
          <I as BatchDataType>::Type: Debug,
          <<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    fn batch_pre_train(&self, input: Self::BatchInput) -> Result<Self::BatchOutStack, TrainingError> {
        self.parent.batch_pre_train(input)
    }
}
impl<U,P,D,I,PI,const N:usize> BatchBackward<U> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U> +
             BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type> + BatchForward +
             BatchPreTrainBase<U,BatchPreOutput=<PI as BatchDataType>::Type> + BatchPreTrain<U> +
             BatchBackward<U> + UpdateWeight<U> + BatchLoss<U,BatchLossInput=<PI as BatchDataType>::Type>,
          U: Default + Clone + Copy + Send + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync + BatchDataType,
          <PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
          <PI as ToHost<U>>::Output: Debug + 'static,
          <I as BatchDataType>::Type: Debug + BatchSize,
          <<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type BatchLossInput = <PI as BatchDataType>::Type;
    type BatchLossOutput = <P as BatchBackward<U>>::BatchLossOutput;

    fn batch_backward<L: LossFunction<U>>(&mut self, input: Self::BatchLossInput, stack: Self::BatchOutStack, lossf: &L)
        -> Result<(<Self as BatchBackward<U>>::BatchLossOutput,<Self as UpdateWeight<U>>::GradientStack), TrainingError> {
        self.parent.batch_backward(input,stack,lossf)
    }
}
impl<U,P,D,I,PI,L,const N:usize> BatchTrain<U,D,L> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U> +
             BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type> + BatchForward +
             BatchPreTrainBase<U,BatchPreOutput=<PI as BatchDataType>::Type> + BatchPreTrain<U> +
             BatchBackward<U> + UpdateWeight<U> + BatchLoss<U,BatchLossInput=<PI as BatchDataType>::Type>,
          U: Default + Clone + Copy + Send + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync + BatchDataType,
          <PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
          <PI as ToHost<U>>::Output: Debug + 'static,
          <I as BatchDataType>::Type: Debug + BatchSize,
          <<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI,BatchIO=<PI as BatchDataType>::Type>,
          f64: From<U>,
          Self: UpdateWeight<U,GradientStack = <P as UpdateWeight<U>>::GradientStack>,
          for<'a> L: LossFunction<U> + BatchLossFunctionLinear<'a,U,<PI as BatchDataType>::Type,D,N,Output=<PI as BatchDataType>::Type> {
    fn batch_train(&mut self, expected:Self::BatchOutput, input:Self::BatchInput, lossf:&L) -> Result<U, TrainingError> {
        // The number of expected rows must match the number of samples in the input batch.
        if expected.len() != input.size() {
            return Err(TrainingError::from(SizeMismatchError(expected.len(),input.size())));
        }

        // Forward pass over the whole batch, keeping intermediate outputs on the stack.
        let stack = self.batch_pre_train(input)?;

        let total_loss = stack.map(|actual| self.device.batch_loss_linear_total(&expected,actual,lossf))?;

        // Same canonical-link shortcut as in Train::train, applied per batch.
        let (stack,loss) = if self.parent.is_canonical_link(lossf) {
            let loss = stack.map(|actual| {
                self.device.loss_linear_batch_by_canonical_link(&expected, &actual)
            })?;

            (stack,loss)
        } else {
            let loss = stack.map(|actual| {
                self.device.batch_loss_linear(&expected,actual,lossf)
            })?;

            self.parent.batch_loss(loss,lossf,stack)?
        };

        // Backward pass, then apply the accumulated weight updates.
        let (_,s) = self.parent.batch_backward(loss,stack,lossf)?;

        self.parent.update_weight(s)?;

        Ok(total_loss)
    }
}
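// Batch-training sketch (illustrative only): `expected_batch` is a SerializedVec of host
// arrays whose length must equal the number of samples in `input_batch`; otherwise
// batch_train returns a SizeMismatchError wrapped in a TrainingError. The names below are
// assumptions made purely for illustration.
//
//     let lossf = Mse::new();                                   // hypothetical loss function
//     let total_loss = output.batch_train(expected_batch, input_batch, &lossf)?;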