//! Model abstractions for neural networks: the core `Model` trait plus
//! concrete architectures, the `Sequential` container, and training helpers.

use crate::error::Result;
use crate::losses::Loss;
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
/// Common interface for trainable neural-network models.
///
/// All arrays are dynamically dimensioned (`IxDyn`) so the same trait can
/// cover models with differently shaped inputs and outputs.
pub trait Model<F: Float + Debug + ScalarOperand> {
    /// Runs a forward pass and returns the model output for `input`.
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>;

    /// Runs a backward pass: given the forward `input` and the gradient of
    /// the loss with respect to the output, returns the gradient with respect
    /// to the input.
    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>>;

    /// Updates the model parameters in place using the given learning rate.
    fn update(&mut self, learning_rate: F) -> Result<()>;

    /// Trains the model on a single batch and returns the batch loss.
    fn train_batch(
        &mut self,
        inputs: &Array<F, IxDyn>,
        targets: &Array<F, IxDyn>,
        loss_fn: &dyn Loss<F>,
        optimizer: &mut dyn Optimizer<F>,
    ) -> Result<F>;

    /// Produces predictions for `inputs` without modifying any parameters.
    fn predict(&self, inputs: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>;

    /// Computes the loss of the model's predictions on `inputs` against
    /// `targets`.
    fn evaluate(
        &self,
        inputs: &Array<F, IxDyn>,
        targets: &Array<F, IxDyn>,
        loss_fn: &dyn Loss<F>,
    ) -> Result<F>;
}
pub mod architectures;
pub mod sequential;
pub mod trainer;

pub use architectures::*;
pub use sequential::Sequential;
pub use trainer::{History, Trainer, TrainingConfig};
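
// Illustrative usage sketch (test-only, not part of the public API): the
// assumed helper `predict_and_score` shows how downstream code can drive any
// model generically through the `Model` trait alone. Its name and body are
// assumptions for demonstration; it calls only methods declared on the trait
// above.
#[cfg(test)]
mod model_trait_sketch {
    use super::*;

    /// Runs inference with `predict`, then scores the model with `evaluate`,
    /// returning the raw outputs together with the scalar loss.
    #[allow(dead_code)]
    fn predict_and_score<F: Float + Debug + ScalarOperand>(
        model: &dyn Model<F>,
        inputs: &Array<F, IxDyn>,
        targets: &Array<F, IxDyn>,
        loss_fn: &dyn Loss<F>,
    ) -> Result<(Array<F, IxDyn>, F)> {
        let outputs = model.predict(inputs)?;
        let loss = model.evaluate(inputs, targets, loss_fn)?;
        Ok((outputs, loss))
    }
}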