use std::fmt::Debug;
use tch::Tensor;
/// A module operating on `tch` [`Tensor`]s whose forward pass can fail.
///
/// Instead of returning a [`Tensor`] directly, [`forward`](FallibleModule::forward)
/// returns a `Result`, so an implementor can report failures through its own
/// [`Error`](FallibleModule::Error) type.
///
/// Implementors must be `Debug` and `Send` (supertrait bounds).
pub trait FallibleModule: Debug + Send {
    /// Error type produced when the forward pass fails; chosen by the implementor.
    type Error;

    /// Runs the module on `input` and returns the resulting tensor.
    ///
    /// # Errors
    ///
    /// Returns `Err(Self::Error)` when the implementation cannot produce an
    /// output. Which conditions trigger this is implementation-defined.
    fn forward(&self, input: &Tensor) -> Result<Tensor, Self::Error>;
}
/// A fallible module whose forward pass also receives a `train` flag.
///
/// This is the counterpart of `FallibleModule` for modules whose behaviour may
/// depend on the flag. NOTE(review): presumably `train` selects training vs.
/// evaluation behaviour (e.g. dropout); its meaning is left to the implementor.
///
/// Every `FallibleModule` already implements this trait through the blanket
/// impl in this file, which ignores the flag. Such types therefore cannot
/// provide their own `FallibleModuleT` impl (it would overlap).
///
/// Implementors must be `Debug` and `Send` (supertrait bounds).
pub trait FallibleModuleT: Debug + Send {
    /// Error type produced when the forward pass fails; chosen by the implementor.
    type Error;

    /// Runs the module on `input`, with `train` passed through to the implementation.
    ///
    /// # Errors
    ///
    /// Returns `Err(Self::Error)` when the implementation cannot produce an
    /// output. Which conditions trigger this is implementation-defined.
    fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error>;
}
impl<M> FallibleModuleT for M
where
M: FallibleModule,
{
type Error = M::Error;
fn forward_t(&self, input: &Tensor, _train: bool) -> Result<Tensor, Self::Error> {
self.forward(input)
}
}