pub mod bidirectional;
pub mod gru;
pub mod lstm;
pub mod rnn;
pub use bidirectional::Bidirectional;
pub use gru::GRU;
pub use lstm::LSTM;
pub use rnn::RNN;
use scirs2_core::ndarray::{Array, IxDyn};
use std::sync::{Arc, RwLock};
/// Shared, lock-guarded slot holding an optional pair of dynamic-dimension arrays.
///
/// The `Arc` lets several owners hold the same slot, the `RwLock` mediates
/// access to it, and the `Option` allows the slot to be empty.
///
/// NOTE(review): the name suggests the pair is an LSTM `(hidden, cell)` state;
/// the element order is not visible in this module — confirm against `lstm`.
pub type LstmStateCache<F> = Arc<RwLock<Option<(Array<F, IxDyn>, Array<F, IxDyn>)>>>;
/// Shared, lock-guarded slot holding an optional 4-tuple of dynamic-dimension arrays.
///
/// The `Arc` lets several owners hold the same slot, the `RwLock` mediates
/// access to it, and the `Option` allows the slot to be empty.
///
/// NOTE(review): presumably the four arrays are the LSTM gate activations;
/// which position maps to which gate is not visible here — confirm against `lstm`.
pub type LstmGateCache<F> = Arc<
    RwLock<
        Option<(
            Array<F, IxDyn>,
            Array<F, IxDyn>,
            Array<F, IxDyn>,
            Array<F, IxDyn>,
        )>,
    >,
>;
/// Shared, lock-guarded slot holding an optional single dynamic-dimension array.
///
/// The `Arc` lets several owners hold the same slot, the `RwLock` mediates
/// access to it, and the `Option` allows the slot to be empty.
///
/// NOTE(review): presumably the array is the GRU hidden state — confirm against `gru`.
pub type GruStateCache<F> = Arc<RwLock<Option<Array<F, IxDyn>>>>;
/// Shared, lock-guarded slot holding an optional triple of dynamic-dimension arrays.
///
/// The `Arc` lets several owners hold the same slot, the `RwLock` mediates
/// access to it, and the `Option` allows the slot to be empty.
///
/// NOTE(review): presumably the three arrays are the GRU gate activations;
/// which position maps to which gate is not visible here — confirm against `gru`.
pub type GruGateCache<F> = Arc<RwLock<Option<(Array<F, IxDyn>, Array<F, IxDyn>, Array<F, IxDyn>)>>>;
/// Shared, lock-guarded slot holding an optional single dynamic-dimension array.
///
/// The `Arc` lets several owners hold the same slot, the `RwLock` mediates
/// access to it, and the `Option` allows the slot to be empty.
///
/// NOTE(review): presumably the array is the RNN hidden state — confirm against `rnn`.
pub type RnnStateCache<F> = Arc<RwLock<Option<Array<F, IxDyn>>>>;
/// Tuple of two dynamic-dimension arrays followed by a nested 4-tuple of arrays.
///
/// The nested 4-tuple has the same element types as the payload stored in
/// `LstmGateCache`.
///
/// NOTE(review): the name suggests this is the result of a single LSTM step,
/// presumably `(hidden, cell, gates)`; the element order is not visible here —
/// confirm against `lstm`.
pub type LstmStepOutput<F> = (
    Array<F, IxDyn>,
    Array<F, IxDyn>,
    (
        Array<F, IxDyn>,
        Array<F, IxDyn>,
        Array<F, IxDyn>,
        Array<F, IxDyn>,
    ),
);
/// Tuple of one dynamic-dimension array followed by a nested triple of arrays.
///
/// The nested triple has the same element types as the payload stored in
/// `GruGateCache`.
///
/// NOTE(review): the name suggests this is the result of a GRU forward
/// computation, presumably `(hidden, gates)`; confirm against `gru`.
pub type GruForwardOutput<F> = (
    Array<F, IxDyn>,
    (Array<F, IxDyn>, Array<F, IxDyn>, Array<F, IxDyn>),
);