1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
use super::{BackwardRecordedOps, ForwardRecordedOps, RecordedOpsParentRef};
use crate::graph::{converter::Forward2BackwardGraphConverter, node::BackwardNodeState};
use burn_tensor::ops::Zeros;
use std::ops::Add;
#[derive(new, Debug, Clone)]
pub struct InitRecordedOps {}
impl<Out> BackwardRecordedOps<Out> for InitRecordedOps
where
Out: Clone + Zeros + Add<Output = Out> + std::fmt::Debug + 'static,
{
fn backward_step(&self, _: &BackwardNodeState<Out>) {}
fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
vec![]
}
}
impl<Out> ForwardRecordedOps<Out> for InitRecordedOps
where
Out: Clone + Zeros + Add<Output = Out> + std::fmt::Debug + 'static,
{
fn to_backward(
&self,
_graph: &mut Forward2BackwardGraphConverter,
) -> super::BackwardRecordedOpsBoxed<Out> {
Box::new(self.clone())
}
}