nncombinator/layer/output.rs

//! Implementation of output layers
use std::fmt::Debug;
use std::marker::PhantomData;
use std::str::FromStr;
use crate::{Stack};
use crate::arr::{Arr, SerializedVec};
use crate::cuda::ToHost;
use crate::device::{Device};
use crate::device::output::DeviceLinearOutput;
use crate::error::{ConfigReadError, EvaluateError, PersistenceError, SizeMismatchError, TrainingError, TypeConvertError};
use crate::layer::{AskDiffInput, BackwardAll, BatchBackward, BatchDataType, BatchForward, BatchForwardBase, BatchLoss, BatchPreTrain, BatchPreTrainBase, BatchSize, BatchTrain, ForwardAll, Loss, PreTrain, Train, UpdateWeight};
use crate::lossfunction::{BatchLossFunctionLinear, LossFunction, LossFunctionLinear};
use crate::ope::UnitValue;
use crate::persistence::{Linear, LinearPersistence, Persistence, Specialized, TextFilePersistence};

/// Layer implementation of the output layer (linear layer)
pub struct LinearOutputLayer<U,P,D,I,PI,const N:usize>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          D: Device<U>,
          PI: Debug + 'static,
          I: Debug + Send + Sync {
    u:PhantomData<U>,
    i:PhantomData<I>,
    io:PhantomData<PI>,
    n:PhantomData<[();N]>,
    parent:P,
    device:D,
}
impl<U,P,D,I,PI,const N:usize> LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          D: Device<U>,
          PI: Debug + 'static,
          I: Debug + Send + Sync {
    /// Creates and returns an instance of `LinearOutputLayer`
    /// # Arguments
    /// * `parent` - upstream layer (the rest of the network this output layer is attached to)
    /// * `device` - Device object used for neural network computation
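    ///
    /// # Examples
    ///
    /// A minimal construction sketch (not compiled here): `hidden` stands for an
    /// already built layer stack satisfying the bounds on `P`, and `DeviceCpu` is
    /// assumed to be this crate's CPU device type; both names are placeholders
    /// rather than definitions from this module.
    ///
    /// ```ignore
    /// // `hidden`: ForwardAll + BackwardAll + PreTrain + Loss producing 10 output units.
    /// let device = DeviceCpu::<f32>::new()?;
    /// let output = LinearOutputLayer::<f32,_,_,_,_,10>::new(hidden, &device);
    /// ```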
    pub fn new(parent:P,device:&D) -> LinearOutputLayer<U,P,D,I,PI,N> {
        LinearOutputLayer {
            u:PhantomData::<U>,
            i:PhantomData::<I>,
            io:PhantomData::<PI>,
            n:PhantomData::<[();N]>,
            parent,
            device:device.clone(),
        }
    }
}
impl<U,P,D,I,PI,const N:usize> Persistence<U,TextFilePersistence<U>,Specialized> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> +
             PreTrain<U,PreOutput=PI> + Loss<U> + Persistence<U,TextFilePersistence<U>,Specialized>,
          U: Default + Clone + Copy + UnitValue<U> + FromStr + Sized,
          D: Device<U>,
          PI: Debug + 'static,
          I: Debug + Send + Sync {
    fn load(&mut self, persistence: &mut TextFilePersistence<U>) -> Result<(),ConfigReadError> {
        self.parent.load(persistence)?;
        persistence.verify_eof()
    }

    fn save(&mut self, persistence: &mut TextFilePersistence<U>) -> Result<(), PersistenceError> {
        self.parent.save(persistence)
    }
}
impl<T,U,P,D,I,PI,const N:usize> Persistence<U,T,Linear> for LinearOutputLayer<U,P,D,I,PI,N>
    where T: LinearPersistence<U>,
          P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> +
             PreTrain<U,PreOutput=PI> + Loss<U> + Persistence<U,T,Linear>,
          U: Default + Clone + Copy + UnitValue<U>,
          D: Device<U>,
          PI: Debug + 'static,
          I: Debug + Send + Sync {
    fn load(&mut self, persistence: &mut T) -> Result<(),ConfigReadError> {
        self.parent.load(persistence)?;
        persistence.verify_eof()
    }

    fn save(&mut self, persistence: &mut T) -> Result<(), PersistenceError> {
        self.parent.save(persistence)
    }
}
impl<U,P,D,I,PI,const N:usize> ForwardAll for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync,
          <PI as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type Input = I;
    type Output = <PI as ToHost<U>>::Output;
    fn forward_all(&self, input: Self::Input) -> Result<Self::Output, EvaluateError> {
        Ok(self.parent.forward_all(input)?.to_host()?)
    }
}
impl<U,P,D,I,PI,const N:usize> PreTrain<U> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync,
          <PI as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type PreOutput = PI;
    type OutStack = P::OutStack;

    fn pre_train(&self, input: Self::Input) -> Result<Self::OutStack, EvaluateError> {
        self.parent.pre_train(input)
    }
}
impl<U,P,D,I,PI,const N:usize> BackwardAll<U> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: BackwardAll<U,LossInput=PI> +
             ForwardAll<Input=I,Output=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync,
          <PI as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type LossInput = PI;
    type LossOutput = <P as BackwardAll<U>>::LossOutput;

    fn backward_all<L: LossFunction<U>>(&mut self, input: Self::LossInput, stack:Self::OutStack, lossf:&L)
        -> Result<(<Self as BackwardAll<U>>::LossOutput,<Self as UpdateWeight<U>>::GradientStack), TrainingError> {
        self.parent.backward_all(input, stack, lossf)
    }
}
impl<U,P,D,I,PI,const N:usize> UpdateWeight<U> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: BackwardAll<U,LossInput=PI> +
             ForwardAll<Input=I,Output=PI> +
             PreTrain<U,PreOutput=PI> + Loss<U> + UpdateWeight<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync,
          <PI as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type GradientStack = <P as UpdateWeight<U>>::GradientStack;

    fn update_weight(&mut self, stack: Self::GradientStack) -> Result<(), TrainingError> {
        Ok(self.parent.update_weight(stack)?)
    }
}
impl<U,P,D,I,PI,const N:usize> AskDiffInput<U> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: BackwardAll<U,LossInput=PI> +
             ForwardAll<Input=I,Output=PI> + PreTrain<U,PreOutput=PI> + Loss<U> + AskDiffInput<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync,
          <PI as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type DiffInput = P::DiffInput;

    fn ask_diff_input(&self, stack: &Self::OutStack) -> Result<Self::DiffInput,TypeConvertError> {
        self.parent.ask_diff_input(stack)
    }
}
impl<U,P,D,I,PI,L,const N:usize> Train<U,L> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U>,
          U: Default + Clone + Copy + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync,
          <PI as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI>,
          for<'a> L: LossFunction<U> + LossFunctionLinear<'a,U,PI,D,N,Output=PI> {
    fn train(&mut self, expected: Self::Output, input: Self::Input, lossf: &L) -> Result<U, TrainingError> {
        let stack = self.pre_train(input)?;

        let total_loss = stack.map(|l| self.device.loss_linear_total(&expected,l,lossf))?;

        let (stack,loss) = if self.parent.is_canonical_link(lossf) {
            let loss = stack.map(|actual| {
                self.device.loss_linear_by_canonical_link(&expected, &actual)
            })?;

            (stack,loss)
        } else {
            let loss = stack.map(|actual| {
                self.device.loss_linear(&expected,&actual,lossf)
            })?;

            self.parent.loss(loss,lossf,stack)?
        };

        let (_,s) = self.backward_all(loss,stack,lossf)?;

        self.parent.update_weight(s)?;

        Ok(total_loss)
    }
}
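// Illustrative only: `Train::train` above relies on a canonical-link shortcut.
// When the output activation is the canonical link of the loss function
// (identity with MSE, sigmoid with binary cross-entropy, softmax with
// categorical cross-entropy), the derivative of the loss with respect to the
// layer's pre-activation reduces to `actual - expected`, which is what the
// device-side `loss_linear_by_canonical_link` is assumed to compute. The
// test-only helper below is a plain-`f32` sketch of that shortcut and is not
// part of the library API.
#[cfg(test)]
#[allow(dead_code)]
fn canonical_link_delta<const N: usize>(expected: &[f32; N], actual: &[f32; N]) -> [f32; N] {
    let mut delta = [0.0f32; N];

    for ((d, a), e) in delta.iter_mut().zip(actual.iter()).zip(expected.iter()) {
        // Canonical link: the error signal is simply the per-unit difference.
        *d = a - e;
    }

    delta
}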
impl<U,P,D,I,PI,const N:usize> BatchForwardBase for LinearOutputLayer<U,P,D,I,PI,N>
    where P: PreTrain<U,PreOutput=PI> + ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + Loss<U> +
             BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type>,
          U: Default + Clone + Copy + Send + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync + BatchDataType,
          <PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
          <PI as ToHost<U>>::Output: Debug + 'static,
          <I as BatchDataType>::Type: Debug,
          <<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type BatchInput = <I as BatchDataType>::Type;
    type BatchOutput = <<PI as BatchDataType>::Type as ToHost<U>>::Output;
}
impl<U,P,D,I,PI,const N:usize> BatchForward for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U> +
             BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type> + BatchForward,
          U: Default + Clone + Copy + Send + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync + BatchDataType,
          <PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
          <PI as ToHost<U>>::Output: Debug + 'static,
          <I as BatchDataType>::Type: Debug,
          <<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    fn batch_forward(&self, input: Self::BatchInput) -> Result<Self::BatchOutput, TrainingError> {
        Ok(self.parent.batch_forward(input)?.to_host()?)
    }
}
impl<U,P,D,I,PI,const N:usize> BatchPreTrainBase<U> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U> +
             BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type> + BatchForward +
             BatchPreTrainBase<U,BatchPreOutput=<PI as BatchDataType>::Type>,
          U: Default + Clone + Copy + Send + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync + BatchDataType,
          <PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
          <PI as ToHost<U>>::Output: Debug + 'static,
          <I as BatchDataType>::Type: Debug,
          <<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type BatchPreOutput = <PI as BatchDataType>::Type;
    type BatchOutStack = P::BatchOutStack;
}
impl<U,P,D,I,PI,const N:usize> BatchPreTrain<U> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U> +
             BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type> + BatchForward +
             BatchPreTrainBase<U,BatchPreOutput=<PI as BatchDataType>::Type> + BatchPreTrain<U>,
          U: Default + Clone + Copy + Send + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync + BatchDataType,
          <PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
          <PI as ToHost<U>>::Output: Debug + 'static,
          <I as BatchDataType>::Type: Debug,
          <<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    fn batch_pre_train(&self, input: Self::BatchInput) -> Result<Self::BatchOutStack, TrainingError> {
        self.parent.batch_pre_train(input)
    }
}
impl<U,P,D,I,PI,const N:usize> BatchBackward<U> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U> +
             BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type> + BatchForward +
             BatchPreTrainBase<U,BatchPreOutput=<PI as BatchDataType>::Type> + BatchPreTrain<U> +
             BatchBackward<U> + UpdateWeight<U> + BatchLoss<U,BatchLossInput=<PI as BatchDataType>::Type>,
          U: Default + Clone + Copy + Send + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync + BatchDataType,
          <PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
          <PI as ToHost<U>>::Output: Debug + 'static,
          <I as BatchDataType>::Type: Debug + BatchSize,
          <<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI> {
    type BatchLossInput = <PI as BatchDataType>::Type;
    type BatchLossOutput = <P as BatchBackward<U>>::BatchLossOutput;

    fn batch_backward<L: LossFunction<U>>(&mut self, input: Self::BatchLossInput, stack: Self::BatchOutStack, lossf: &L)
        -> Result<(<Self as BatchBackward<U>>::BatchLossOutput,<Self as UpdateWeight<U>>::GradientStack), TrainingError> {
        self.parent.batch_backward(input,stack,lossf)
    }
}
impl<U,P,D,I,PI,L,const N:usize> BatchTrain<U,D,L> for LinearOutputLayer<U,P,D,I,PI,N>
    where P: ForwardAll<Input=I,Output=PI> + BackwardAll<U,LossInput=PI> + PreTrain<U,PreOutput=PI> + Loss<U> +
             BatchForwardBase<BatchInput=<I as BatchDataType>::Type,BatchOutput=<PI as BatchDataType>::Type> + BatchForward +
             BatchPreTrainBase<U,BatchPreOutput=<PI as BatchDataType>::Type> + BatchPreTrain<U> +
             BatchBackward<U> + UpdateWeight<U> + BatchLoss<U,BatchLossInput=<PI as BatchDataType>::Type>,
          U: Default + Clone + Copy + Send + UnitValue<U>,
          PI: Debug + BatchDataType + ToHost<U,Output=Arr<U,N>> + 'static,
          I: Debug + Send + Sync + BatchDataType,
          <PI as BatchDataType>::Type: Debug + ToHost<U,Output=SerializedVec<U,Arr<U,N>>>,
          <PI as ToHost<U>>::Output: Debug + 'static,
          <I as BatchDataType>::Type: Debug + BatchSize,
          <<PI as BatchDataType>::Type as ToHost<U>>::Output: Debug + 'static,
          for<'a> D: Device<U> + DeviceLinearOutput<'a,U,N,IO=PI,BatchIO=<PI as BatchDataType>::Type>,
          f64: From<U>,
          Self: UpdateWeight<U,GradientStack = <P as UpdateWeight<U>>::GradientStack>,
          for<'a> L: LossFunction<U> + BatchLossFunctionLinear<'a,U,<PI as BatchDataType>::Type,D,N,Output=<PI as BatchDataType>::Type> {
    fn batch_train(&mut self, expected:Self::BatchOutput, input:Self::BatchInput, lossf:&L) -> Result<U, TrainingError> {
        if expected.len() != input.size() {
            return Err(TrainingError::from(SizeMismatchError(expected.len(),input.size())));
        }

        let stack = self.batch_pre_train(input)?;

        let total_loss = stack.map(|l| self.device.batch_loss_linear_total(&expected,l,lossf))?;

        let (stack,loss) = if self.parent.is_canonical_link(lossf) {
            let loss = stack.map(|actual| {
                self.device.loss_linear_batch_by_canonical_link(&expected, &actual)
            })?;

            (stack,loss)
        } else {
            let loss = stack.map(|actual| {
                self.device.batch_loss_linear(&expected,actual,lossf)
            })?;

            self.parent.batch_loss(loss,lossf,stack)?
        };

        let (_,s) = self.parent.batch_backward(loss,stack,lossf)?;

        self.parent.update_weight(s)?;

        Ok(total_loss)
    }
}
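// Illustrative only: `BatchTrain::batch_train` above first rejects mini-batches
// whose expected/input sizes differ and then reduces the per-sample linear
// losses to a single scalar via `batch_loss_linear_total`. The test-only module
// below mirrors that flow with plain `f32` data, using a summed squared-error
// reduction as an assumed stand-in; the real reduction depends on the loss
// function `L` supplied by the caller and is not defined in this file.
#[cfg(test)]
mod batch_train_sketch {
    /// Total squared-error loss over a mini-batch; `Err` mimics the size-mismatch guard.
    fn batch_sse_total(expected: &[Vec<f32>], actual: &[Vec<f32>]) -> Result<f32, String> {
        if expected.len() != actual.len() {
            return Err(format!("size mismatch: {} vs {}", expected.len(), actual.len()));
        }

        Ok(expected.iter().zip(actual.iter()).map(|(e, a)| {
            // Per-sample loss: sum of (actual - expected)^2 / 2 over the output units.
            e.iter().zip(a.iter()).map(|(e, a)| (a - e) * (a - e) / 2.0).sum::<f32>()
        }).sum())
    }

    #[test]
    fn rejects_mismatched_batch_sizes() {
        assert!(batch_sse_total(&vec![vec![0.0; 3]; 4], &vec![vec![0.0; 3]; 2]).is_err());
    }

    #[test]
    fn sums_per_sample_losses() {
        let expected = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let actual = vec![vec![0.5, 0.5], vec![0.5, 0.5]];

        // Each sample contributes 0.25, so the batch total is 0.5.
        assert_eq!(batch_sse_total(&expected, &actual), Ok(0.5));
    }
}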