use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::{concatenate, Array, Axis, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
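/// A bidirectional wrapper around a sequence layer (e.g. an RNN).
///
/// The wrapped `forward_layer` processes the input in its original time
/// order, while `backward_layer` (or the forward layer itself, with shared
/// weights, when none is given) processes the time-reversed input. Both
/// outputs are concatenated along the feature axis, so for an inner hidden
/// size `H` the output of `forward` has shape `[batch_size, seq_len, 2 * H]`.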
pub struct Bidirectional<F: Float + Debug + Send + Sync + NumAssign> {
    /// Layer that processes the sequence in its original time order.
    forward_layer: Box<dyn Layer<F> + Send + Sync>,
    /// Optional layer for the reversed direction; when `None`, the forward
    /// layer's weights are shared across both directions.
    backward_layer: Option<Box<dyn Layer<F> + Send + Sync>>,
    /// Optional human-readable name for the layer.
    name: Option<String>,
    /// Input cached by `forward()` for use in `backward()`.
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Bidirectional<F> {
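    /// Creates a bidirectional layer from a `forward_layer` and an optional
    /// `backward_layer`. When `backward_layer` is `None`, the forward layer
    /// is reused (with shared weights) for the reversed direction.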
pub fn new(
forward_layer: Box<dyn Layer<F> + Send + Sync>,
backward_layer: Option<Box<dyn Layer<F> + Send + Sync>>,
name: Option<&str>,
) -> Result<Self> {
Ok(Self {
forward_layer,
backward_layer,
name: name.map(String::from),
input_cache: Arc::new(RwLock::new(None)),
})
}
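    /// Creates a bidirectional layer in which a single layer's weights are
    /// shared across both directions.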
pub fn new_with_single_layer(
layer: Box<dyn Layer<F> + Send + Sync>,
name: Option<&str>,
) -> Result<Self> {
Self::new(layer, None, name)
}
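    /// Returns the layer's name, if one was set.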
pub fn name(&self) -> Option<&str> {
self.name.as_deref()
}
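    /// Reverses an array along the time axis (axis 1), mapping step `t` to
    /// step `seq_len - 1 - t`. Used both to feed the backward direction a
    /// reversed sequence and to re-align its output (and gradients) with
    /// the forward direction.
    fn reverse_time(x: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        let seq_len = x.shape()[1];
        let slices: Vec<_> = (0..seq_len)
            .rev()
            .map(|t| x.slice(scirs2_core::ndarray::s![.., t..t + 1, ..]))
            .collect();
        Ok(concatenate(Axis(1), &slices)?.into_dyn())
    }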
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
for Bidirectional<F>
{
    /// Runs both directions over a `[batch_size, seq_len, input_size]` input
    /// and returns the concatenated `[batch_size, seq_len, hidden_size * 2]`
    /// output.
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Cache the input so backward() can replay both directions.
        *self
            .input_cache
            .write()
            .expect("input_cache lock poisoned") = Some(input.clone());
        let input_shape = input.shape();
        if input_shape.len() != 3 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 3D input [batch_size, seq_len, input_size], got {input_shape:?}"
            )));
        }
        // Forward direction: run the wrapped layer over the sequence as-is.
        let forward_output = self.forward_layer.forward(input)?;
        // Backward direction: use the dedicated backward layer if present,
        // otherwise reuse the forward layer with shared weights.
        let backward_layer = self
            .backward_layer
            .as_deref()
            .unwrap_or(&*self.forward_layer);
        // Run the backward direction on the time-reversed input, then reverse
        // its output so both directions are aligned step-for-step.
        let reversed_input = Self::reverse_time(input)?;
        let backward_output = backward_layer.forward(&reversed_input)?;
        let backward_output_aligned = Self::reverse_time(&backward_output)?;
        // Concatenate along the feature axis:
        // [batch, seq, hidden] x 2 -> [batch, seq, hidden * 2].
        let output = concatenate(
            Axis(2),
            &[forward_output.view(), backward_output_aligned.view()],
        )?
        .into_dyn();
        Ok(output)
    }
    /// Backpropagates through both directions and returns the gradient with
    /// respect to the input. Relies on the input cached by `forward()`, not
    /// on the `_input` argument.
    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        let input_ref = self.input_cache.read().expect("input_cache lock poisoned");
        let cached_input = input_ref.as_ref().ok_or_else(|| {
            NeuralError::InferenceError(
                "No cached input for backward pass. Call forward() first.".to_string(),
            )
        })?;
        let grad_shape = grad_output.shape();
        if grad_shape.len() != 3 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 3D gradient [batch_size, seq_len, hidden_size*2], got {grad_shape:?}"
            )));
        }
        // The output was [forward | backward] along the feature axis, so
        // split the incoming gradient into its two halves.
        let hidden_size = grad_shape[2] / 2;
        let grad_forward = grad_output
            .slice(scirs2_core::ndarray::s![.., .., ..hidden_size])
            .to_owned()
            .into_dyn();
        let grad_backward = grad_output
            .slice(scirs2_core::ndarray::s![.., .., hidden_size..])
            .to_owned()
            .into_dyn();
        // Forward direction: ordinary backprop through the forward layer.
        let grad_input_forward = self.forward_layer.backward(cached_input, &grad_forward)?;
        // Backward direction: the layer saw a time-reversed input, so reverse
        // the gradient going in and un-reverse the gradient coming out. With
        // no dedicated backward layer, the forward layer's weights are shared.
        let backward_layer = self
            .backward_layer
            .as_deref()
            .unwrap_or(&*self.forward_layer);
        let grad_backward_reversed = Self::reverse_time(&grad_backward)?;
        let input_reversed = Self::reverse_time(cached_input)?;
        let grad_input_backward_reversed =
            backward_layer.backward(&input_reversed, &grad_backward_reversed)?;
        let grad_input_backward = Self::reverse_time(&grad_input_backward_reversed)?;
        // The input feeds both directions, so their input gradients add.
        Ok(grad_input_forward + grad_input_backward)
    }
    /// Applies a parameter update to the forward layer and, when present,
    /// to the separate backward layer.
    fn update(&mut self, learning_rate: F) -> Result<()> {
        self.forward_layer.update(learning_rate)?;
        if let Some(ref mut backward_layer) = self.backward_layer {
            backward_layer.update(learning_rate)?;
        }
        Ok(())
    }
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}