use std::collections::HashMap;
use crate::autograd::Variable;
use crate::nn::{Module, NamedInputModule};
use crate::tensor::Result;
pub struct StateAdd;
impl Module for StateAdd {
fn name(&self) -> &str { "state_add" }
fn forward(&self, input: &Variable) -> Result<Variable> {
Ok(input.clone())
}
fn as_named_input(&self) -> Option<&dyn crate::nn::NamedInputModule> {
Some(self)
}
}
impl NamedInputModule for StateAdd {
fn forward_named(
&self,
input: &Variable,
refs: &HashMap<String, Variable>,
) -> Result<Variable> {
let mut result = input.clone();
for v in refs.values() {
result = result.add(v)?;
}
Ok(result)
}
}