use crate::data_types::{Type, BIT};
use crate::errors::Result;
use crate::graphs::Node;
/// Validates the argument types of a broadcasting bitwise custom operation.
///
/// Exactly two arguments are expected. Both must be `Type::Array`s whose last
/// dimensions are equal and whose scalar type is `BIT`.
///
/// # Errors
///
/// Returns a runtime error when:
/// - the number of arguments is not 2 (message without the operation-name prefix);
/// - either argument is not an array;
/// - either array has an empty shape;
/// - the last dimensions of the two arrays differ;
/// - the scalar type of either array is not `BIT`.
///
/// Except for the argument-count error, messages are prefixed with `"{custom_op_name}: "`.
/// Conditions are checked in the order listed, and only the first failure is reported.
pub(super) fn validate_arguments_in_broadcast_bit_ops(
    arguments_types: Vec<Type>,
    custom_op_name: &str,
) -> Result<()> {
    if arguments_types.len() != 2 {
        return Err(runtime_error!("Invalid number of arguments"));
    }
    let reason = match (&arguments_types[0], &arguments_types[1]) {
        (Type::Array(shape0, scalar_type0), Type::Array(shape1, scalar_type1)) => {
            // `last()` instead of `shape[shape.len() - 1]`: an empty shape would
            // otherwise underflow the index and panic instead of reporting an error.
            match (shape0.last(), shape1.last()) {
                (Some(last0), Some(last1)) if last0 != last1 => {
                    "Input arrays' last dimensions are not the same"
                }
                (None, _) | (_, None) => "Input arrays must have at least one dimension",
                _ if *scalar_type0 != BIT => "Input array [0]'s ScalarType is not BIT",
                _ if *scalar_type1 != BIT => "Input array [1]'s ScalarType is not BIT",
                _ => return Ok(()),
            }
        }
        _ => "Invalid input argument type, expected Array type",
    };
    Err(runtime_error!("{}: {}", custom_op_name, reason))
}
/// Moves the last axis of `x` to the front, keeping the other axes in their original order.
///
/// The axis permutation passed to `permute_axes` is `[r - 1, 0, 1, ..., r - 2]`, where `r` is
/// the number of dimensions of `x`. This describes "last axis to front" assuming
/// `permute_axes` follows the NumPy `transpose` convention — TODO confirm. Judging by the
/// name, the last axis presumably holds the bits of each element.
///
/// If `x` has exactly one dimension, it is returned unchanged.
///
/// # Errors
///
/// Propagates any error from `Node::get_type` or `Node::permute_axes`.
pub fn pull_out_bits(x: Node) -> Result<Node> {
    let rank = x.get_type()?.get_dimensions().len();
    if rank == 1 {
        return Ok(x);
    }
    let last_axis = rank as u64 - 1;
    // The last axis goes first; every other axis keeps its relative order.
    let permutation: Vec<u64> = std::iter::once(last_axis).chain(0..last_axis).collect();
    let permuted = x.permute_axes(permutation)?;
    Ok(permuted)
}
/// Moves the first axis of `x` to the back, keeping the other axes in their original order.
///
/// The axis permutation passed to `permute_axes` is `[1, 2, ..., r - 1, 0]`, where `r` is the
/// number of dimensions of `x`. This is the inverse of the permutation used by `pull_out_bits`,
/// so the two functions undo each other.
///
/// If `x` has exactly one dimension, it is returned unchanged.
///
/// # Errors
///
/// Propagates any error from `Node::get_type` or `Node::permute_axes`.
pub fn put_in_bits(x: Node) -> Result<Node> {
    let rank = x.get_type()?.get_dimensions().len();
    if rank == 1 {
        return Ok(x);
    }
    // Axes 1..rank keep their order, and the former axis 0 is appended last.
    let permutation: Vec<u64> = (1..rank as u64).chain(std::iter::once(0)).collect();
    let permuted = x.permute_axes(permutation)?;
    Ok(permuted)
}