//! Tensor-equation solving (tensorsolve), Einstein summation (einsum), and
//! Kronecker-product ops.

use crate::op::{ComputeContext, GradientContext, Op, OpError};
use crate::tensor::Tensor;
use crate::tensor_ops::convert_to_tensor;
use crate::Float;
use scirs2_core::ndarray::{Array2, ArrayD, IxDyn};

/// Op that solves the tensor equation `a · x = b` for `x`.
pub struct TensorSolveOp {
    /// Axes of `a` to treat as the solution (`x`) axes; `None` means the
    /// trailing `b.ndim()` axes.
    axes: Option<Vec<i32>>,
}
impl<F: Float + scirs2_core::ndarray::ScalarOperand> Op<F> for TensorSolveOp {
fn name(&self) -> &'static str {
"TensorSolve"
}
fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
let a = ctx.input(0);
let b = ctx.input(1);
let ashape = a.shape();
let bshape = b.shape();
        // Split a's axes into solution (x) and right-hand-side (b) groups and
        // check that b's element count matches the non-solution axes of a.
        let (prod_x, prod_b) = validate_tensor_solveshapes(ashape, bshape, &self.axes)?;
        // Flatten into an ordinary (prod_b × prod_x) linear system.
        let a_reshaped = reshape_for_solve(&a.view(), prod_b, prod_x)?;
let b_reshaped = reshape_vector(&b.view(), prod_b)?;
let x_flat = solve_linear_system(&a_reshaped, &b_reshaped)?;
let xshape = compute_solutionshape(ashape, bshape, &self.axes)?;
let x = reshape_solution(&x_flat, &xshape)?;
ctx.append_output(x);
Ok(())
}
fn grad(&self, ctx: &mut GradientContext<F>) {
let grad_output = ctx.output_grad();
let a = ctx.input(0);
let _b = ctx.input(1);
let x = ctx.output();
let g = ctx.graph();
        // If any required value cannot be evaluated, emit no gradients.
        let a_array = match a.eval(g) {
Ok(arr) => arr,
Err(_) => {
ctx.append_input_grad(0, None);
ctx.append_input_grad(1, None);
return;
}
};
let x_array = match x.eval(g) {
Ok(arr) => arr,
Err(_) => {
ctx.append_input_grad(0, None);
ctx.append_input_grad(1, None);
return;
}
};
let grad_output_array = match grad_output.eval(g) {
Ok(arr) => arr,
Err(_) => {
ctx.append_input_grad(0, None);
ctx.append_input_grad(1, None);
return;
}
};
        // NOTE: these helpers currently return simplified gradients (see
        // `compute_grad_a` / `compute_grad_b` below). For x = A⁻¹ · b the exact
        // gradients are grad_b = A⁻ᵀ · grad_x and grad_a = -A⁻ᵀ · grad_x ⊗ x.
        let grad_b = compute_grad_b(&a_array, &grad_output_array, &self.axes);
        let grad_a = compute_grad_a(&grad_output_array, &x_array, a_array.shape(), &self.axes);
let grad_a_tensor = convert_to_tensor(grad_a, g);
let grad_b_tensor = convert_to_tensor(grad_b, g);
ctx.append_input_grad(0, Some(grad_a_tensor));
ctx.append_input_grad(1, Some(grad_b_tensor));
}
}

/// Op computing an Einstein-summation contraction of two operands.
pub struct EinsumOp {
    /// Einsum pattern with an explicit output, e.g. `"ij,jk->ik"`.
    pattern: String,
}
impl<F: Float + scirs2_core::ndarray::ScalarOperand> Op<F> for EinsumOp {
fn name(&self) -> &'static str {
"Einsum"
}
fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
let (input_specs, output_spec) = parse_einsum_pattern(&self.pattern)?;
if input_specs.len() == 2 {
let a = ctx.input(0);
let b = ctx.input(1);
            // Fast paths for the common patterns; anything else falls through
            // to the general (currently unimplemented) path.
            if self.pattern == "ij,jk->ik" {
let result = compute_matmul(&a.view(), &b.view())?;
ctx.append_output(result);
} else if self.pattern == "i,i->" {
let result = compute_dot_product(&a.view(), &b.view())?;
ctx.append_output(result);
} else if self.pattern == "ij,ij->ij" {
let result = compute_elementwise_mul(&a.view(), &b.view())?;
ctx.append_output(result);
} else {
let result =
compute_general_einsum(&a.view(), &b.view(), &input_specs, &output_spec)?;
ctx.append_output(result);
}
} else {
return Err(OpError::Other(
"Only binary einsum operations supported".into(),
));
}
Ok(())
}
    fn grad(&self, ctx: &mut GradientContext<F>) {
        // NOTE: placeholder pass-through, not the true einsum gradient. The
        // exact gradients are pattern-specific contractions (e.g. for
        // "ij,jk->ik": grad_a = gy · bᵀ and grad_b = aᵀ · gy).
        let gy = ctx.output_grad();
        ctx.append_input_grad(0, Some(*gy));
        ctx.append_input_grad(1, Some(*gy));
    }
}

/// Splits `a`'s dimensions into solution (`x`) and right-hand-side (`b`)
/// groups and checks that `b`'s element count matches the non-solution axes.
/// Returns `(prod_x, prod_b)`.
#[allow(dead_code)]
fn validate_tensor_solveshapes(
    ashape: &[usize],
    bshape: &[usize],
    axes: &Option<Vec<i32>>,
) -> Result<(usize, usize), OpError> {
    let ndim_a = ashape.len();
    let ndim_b = bshape.len();
    let axes_normalized = if let Some(ax) = axes {
        ax.clone()
    } else {
        // By default, treat the trailing `ndim_b` axes of `a` as the solution
        // axes (mirroring `compute_solutionshape` below).
        if ndim_b > ndim_a {
            return Err(OpError::IncompatibleShape(
                "b cannot have more dimensions than a".into(),
            ));
        }
        let start = (ndim_a - ndim_b) as i32;
        (start..ndim_a as i32).collect()
    };
let mut prod_x = 1;
let mut prod_b = 1;
for (i, &dim) in ashape.iter().enumerate() {
if axes_normalized.contains(&(i as i32)) {
prod_x *= dim;
} else {
prod_b *= dim;
}
}
let b_prod: usize = bshape.iter().product();
if b_prod != prod_b {
return Err(OpError::IncompatibleShape(
"Incompatible shapes for tensor solve".into(),
));
}
Ok((prod_x, prod_b))
}
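
// A minimal worked example for the shape bookkeeping above, as a sanity-check
// sketch: for `a` of shape (2, 3, 6) with axis 2 as the solution axis,
// prod_x = 6 and prod_b = 2 * 3 = 6, matching a `b` of shape (2, 3). The
// module and test names are illustrative, not part of the original API.
#[cfg(test)]
mod tensor_solve_shape_tests {
    use super::*;

    #[test]
    fn explicit_axes_split_dims_into_solution_and_rhs_products() {
        // Axis 2 (size 6) is the solution axis; axes 0 and 1 cover b.
        let (prod_x, prod_b) =
            validate_tensor_solveshapes(&[2, 3, 6], &[2, 3], &Some(vec![2])).unwrap();
        assert_eq!(prod_x, 6);
        assert_eq!(prod_b, 6);
    }

    #[test]
    fn mismatched_rhs_element_count_is_rejected() {
        // b has 5 elements but the non-solution axes of a contribute 2 * 3 = 6.
        assert!(validate_tensor_solveshapes(&[2, 3, 6], &[5], &Some(vec![2])).is_err());
    }
}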

/// Reshapes a dynamic-dimensional tensor into a `rows × cols` matrix
/// (row-major element order) for the linear solver.
#[allow(dead_code)]
fn reshape_for_solve<F: Float>(
    tensor: &scirs2_core::ndarray::ArrayViewD<F>,
    rows: usize,
    cols: usize,
) -> Result<Array2<F>, OpError> {
    tensor
        .to_shape((rows, cols))
        .map_err(|_| OpError::IncompatibleShape("Failed to reshape tensor for solve".into()))
        .map(|m| m.to_owned())
}

/// Flattens a tensor into a length-`size` vector.
#[allow(dead_code)]
fn reshape_vector<F: Float>(
    tensor: &scirs2_core::ndarray::ArrayViewD<F>,
    size: usize,
) -> Result<scirs2_core::ndarray::Array1<F>, OpError> {
tensor
.view()
.to_shape(size)
.map_err(|_| OpError::IncompatibleShape("Failed to reshape vector".into()))
.map(|v| v.to_owned())
}

/// Solves `a · x = b`. A non-square `a` is handled via the normal equations
/// `aᵀa · x = aᵀb`, i.e. as a least-squares problem.
#[allow(dead_code)]
fn solve_linear_system<F: Float>(
    a: &Array2<F>,
    b: &scirs2_core::ndarray::Array1<F>,
) -> Result<scirs2_core::ndarray::Array1<F>, OpError> {
    let n = a.shape()[0];
    let m = a.shape()[1];
    if n != b.len() {
        return Err(OpError::IncompatibleShape(
            "Matrix-vector dimension mismatch".into(),
        ));
    }
    if n != m {
        // Least-squares fallback via the normal equations.
        let ata = a.t().dot(a);
        let atb = a.t().dot(b);
        return solve_square_system(&ata, &atb);
    }
    solve_square_system(a, b)
}

/// Solves a square system by Gaussian elimination with partial pivoting,
/// followed by back-substitution.
#[allow(dead_code)]
fn solve_square_system<F: Float>(
    a: &Array2<F>,
    b: &scirs2_core::ndarray::Array1<F>,
) -> Result<scirs2_core::ndarray::Array1<F>, OpError> {
    let n = a.shape()[0];
    // Build the augmented matrix [a | b].
    let mut aug = Array2::<F>::zeros((n, n + 1));
for i in 0..n {
for j in 0..n {
aug[[i, j]] = a[[i, j]];
}
aug[[i, n]] = b[i];
}
    // Forward elimination with partial pivoting.
    for i in 0..n {
let mut max_row = i;
for k in (i + 1)..n {
if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
max_row = k;
}
}
        if aug[[max_row, i]].abs() < F::epsilon() {
            return Err(OpError::Other("Matrix is singular".into()));
        }
if max_row != i {
for j in 0..=n {
aug.swap((i, j), (max_row, j));
}
}
for k in (i + 1)..n {
let factor = aug[[k, i]] / aug[[i, i]];
for j in i..=n {
aug[[k, j]] = aug[[k, j]] - factor * aug[[i, j]];
}
}
}
    // Back-substitution.
    let mut x = scirs2_core::ndarray::Array1::<F>::zeros(n);
for i in (0..n).rev() {
x[i] = aug[[i, n]];
for j in (i + 1)..n {
let x_j = x[j];
x[i] -= aug[[i, j]] * x_j;
}
x[i] /= aug[[i, i]];
}
Ok(x)
}
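
// A quick worked example for the elimination solver, assuming `f64`
// implements this crate's `Float` trait: the system
//   [[2, 1], [1, 3]] · x = [3, 5]
// has the unique solution x = [0.8, 1.4], and a rank-deficient matrix must
// be rejected. A sketch over the pure helper only.
#[cfg(test)]
mod linear_solver_tests {
    use super::*;
    use scirs2_core::ndarray::{arr1, arr2};

    #[test]
    fn solves_a_well_conditioned_2x2_system() {
        let a = arr2(&[[2.0_f64, 1.0], [1.0, 3.0]]);
        let b = arr1(&[3.0_f64, 5.0]);
        let x = solve_square_system(&a, &b).unwrap();
        assert!((x[0] - 0.8).abs() < 1e-12);
        assert!((x[1] - 1.4).abs() < 1e-12);
    }

    #[test]
    fn rejects_a_singular_matrix() {
        // The second row is twice the first, so no usable pivot survives.
        let a = arr2(&[[1.0_f64, 2.0], [2.0, 4.0]]);
        let b = arr1(&[1.0_f64, 2.0]);
        assert!(solve_square_system(&a, &b).is_err());
    }
}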

/// Computes the shape of the solution `x`: the dims of `a` selected by
/// `axes`, or (by default) the trailing `b.ndim()` dims of `a`, matching the
/// convention in `validate_tensor_solveshapes`.
#[allow(dead_code)]
fn compute_solutionshape(
    ashape: &[usize],
    bshape: &[usize],
    axes: &Option<Vec<i32>>,
) -> Result<Vec<usize>, OpError> {
    let ndim_a = ashape.len();
    let axes_normalized = if let Some(ax) = axes {
        ax.clone()
    } else {
        vec![]
    };
let mut xshape = Vec::new();
for (i, &dim) in ashape.iter().enumerate() {
if axes_normalized.contains(&(i as i32)) {
xshape.push(dim);
}
}
    if xshape.is_empty() {
        // No explicit axes: the solution keeps the trailing `ndim_b` dims of `a`.
        let ndim_b = bshape.len();
        for &dim in ashape.iter().skip(ndim_a - ndim_b) {
xshape.push(dim);
}
}
Ok(xshape)
}
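
// Example of the solution-shape rules: explicit axes pick those dims of `a`,
// and with no axes the trailing `b.ndim()` dims are kept. Illustrative test
// module, not part of the original API.
#[cfg(test)]
mod solution_shape_tests {
    use super::*;

    #[test]
    fn explicit_axes_select_dims_of_a() {
        let shape = compute_solutionshape(&[2, 3, 6], &[2, 3], &Some(vec![2])).unwrap();
        assert_eq!(shape, vec![6]);
    }

    #[test]
    fn default_axes_keep_trailing_dims_of_a() {
        // With no axes given, the trailing ndim_b = 2 dims of a survive.
        let shape = compute_solutionshape(&[2, 3, 6], &[2, 3], &None).unwrap();
        assert_eq!(shape, vec![3, 6]);
    }
}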

/// Reshapes the flat solution vector back to the tensor shape of `x`.
#[allow(dead_code)]
fn reshape_solution<F: Float>(
    flat: &scirs2_core::ndarray::Array1<F>,
    shape: &[usize],
) -> Result<ArrayD<F>, OpError> {
let total: usize = shape.iter().product();
if flat.len() != total {
return Err(OpError::IncompatibleShape(
"Solution reshape size mismatch".into(),
));
}
let dynshape = IxDyn(shape);
flat.view()
.to_shape(dynshape)
.map_err(|_| OpError::IncompatibleShape("Failed to reshape solution".into()))
.map(|v| v.to_owned())
}

/// Gradient of tensorsolve w.r.t. `b`.
///
/// NOTE: simplified placeholder that passes `grad_x` straight through; the
/// exact gradient is `A⁻ᵀ · grad_x`.
#[allow(dead_code)]
fn compute_grad_b<F: Float>(
    _a: &ArrayD<F>,
    grad_x: &ArrayD<F>,
    _axes: &Option<Vec<i32>>,
) -> ArrayD<F> {
    grad_x.clone()
}

/// Gradient of tensorsolve w.r.t. `a`.
///
/// NOTE: placeholder that returns zeros; the exact gradient is
/// `-A⁻ᵀ · grad_x ⊗ x`, reshaped to `a`'s shape.
#[allow(dead_code)]
fn compute_grad_a<F: Float>(
    _grad_x: &ArrayD<F>,
    _x: &ArrayD<F>,
    ashape: &[usize],
    _axes: &Option<Vec<i32>>,
) -> ArrayD<F> {
    ArrayD::<F>::zeros(IxDyn(ashape))
}

/// Splits an einsum pattern such as `"ij,jk->ik"` into its per-input specs
/// and output spec. Patterns without an explicit `->` output are rejected.
#[allow(dead_code)]
fn parse_einsum_pattern(pattern: &str) -> Result<(Vec<String>, String), OpError> {
    let parts: Vec<&str> = pattern.split("->").collect();
    if parts.len() != 2 {
        return Err(OpError::Other("Invalid einsum pattern".into()));
    }
let input_part = parts[0];
let output_part = parts[1];
let input_specs: Vec<String> = input_part.split(',').map(|s| s.to_string()).collect();
Ok((input_specs, output_part.to_string()))
}
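
// Example of the parser's behavior: "ij,jk->ik" splits into input specs
// ["ij", "jk"] and output spec "ik", while a pattern without the "->"
// separator is rejected. Illustrative test module, not part of the original API.
#[cfg(test)]
mod einsum_pattern_tests {
    use super::*;

    #[test]
    fn splits_inputs_and_output() {
        let (inputs, output) = parse_einsum_pattern("ij,jk->ik").unwrap();
        assert_eq!(inputs, vec!["ij".to_string(), "jk".to_string()]);
        assert_eq!(output, "ik");
    }

    #[test]
    fn rejects_pattern_without_arrow() {
        assert!(parse_einsum_pattern("ij,jk").is_err());
    }
}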

/// Computes the `"ij,jk->ik"` contraction as a plain 2D matrix product.
#[allow(dead_code)]
fn compute_matmul<F: Float>(
    a: &scirs2_core::ndarray::ArrayViewD<F>,
    b: &scirs2_core::ndarray::ArrayViewD<F>,
) -> Result<ArrayD<F>, OpError> {
    if a.ndim() != 2 || b.ndim() != 2 {
        return Err(OpError::IncompatibleShape(
            "Matrix multiplication requires 2D arrays".into(),
        ));
    }
    if a.shape()[1] != b.shape()[0] {
        return Err(OpError::IncompatibleShape(
            "Inner dimensions do not match for matrix multiplication".into(),
        ));
    }
    let a_2d = a
        .view()
        .into_dimensionality::<scirs2_core::ndarray::Ix2>()
        .map_err(|_| OpError::IncompatibleShape("Failed to view lhs as 2D".into()))?;
    let b_2d = b
        .view()
        .into_dimensionality::<scirs2_core::ndarray::Ix2>()
        .map_err(|_| OpError::IncompatibleShape("Failed to view rhs as 2D".into()))?;
    Ok(a_2d.dot(&b_2d).into_dyn())
}

/// Computes the `"i,i->"` contraction: a full elementwise product reduced to
/// a scalar (0-dimensional array).
#[allow(dead_code)]
fn compute_dot_product<F: Float>(
    a: &scirs2_core::ndarray::ArrayViewD<F>,
    b: &scirs2_core::ndarray::ArrayViewD<F>,
) -> Result<ArrayD<F>, OpError> {
if a.shape() != b.shape() {
return Err(OpError::IncompatibleShape(
"Dot product requires same shape".into(),
));
}
let mut sum = F::zero();
for (&a_val, &b_val) in a.iter().zip(b.iter()) {
sum += a_val * b_val;
}
Ok(scirs2_core::ndarray::arr0(sum).into_dyn())
}

/// Computes the `"ij,ij->ij"` pattern: elementwise (Hadamard) product.
#[allow(dead_code)]
fn compute_elementwise_mul<F: Float>(
    a: &scirs2_core::ndarray::ArrayViewD<F>,
    b: &scirs2_core::ndarray::ArrayViewD<F>,
) -> Result<ArrayD<F>, OpError> {
if a.shape() != b.shape() {
return Err(OpError::IncompatibleShape(
"Element-wise multiplication requires same shape".into(),
));
}
Ok((a * b).into_owned())
}

/// Fallback for einsum patterns without a specialized kernel.
///
/// General contraction is not implemented yet, so this reports an error
/// rather than silently returning a wrong result.
#[allow(dead_code)]
fn compute_general_einsum<F: Float>(
    _a: &scirs2_core::ndarray::ArrayViewD<F>,
    _b: &scirs2_core::ndarray::ArrayViewD<F>,
    _input_specs: &[String],
    _output_spec: &str,
) -> Result<ArrayD<F>, OpError> {
    Err(OpError::Other(
        "General einsum patterns are not yet supported".into(),
    ))
}

/// Solves the tensor equation `a · x = b` for `x`.
///
/// `axes` selects which axes of `a` correspond to `x`; when `None`, the
/// trailing `b.ndim()` axes are used.
#[allow(dead_code)]
pub fn tensor_solve<'g, F: Float + scirs2_core::ndarray::ScalarOperand>(
    a: &Tensor<'g, F>,
    b: &Tensor<'g, F>,
    axes: Option<Vec<i32>>,
) -> Tensor<'g, F> {
let g = a.graph();
Tensor::builder(g)
.append_input(a, false)
.append_input(b, false)
.build(TensorSolveOp { axes })
}

/// Einstein summation over exactly two operands, e.g.
/// `einsum("ij,jk->ik", &[&a, &b])` for a matrix product.
///
/// # Panics
///
/// Panics if `operands.len() != 2`; only binary patterns are supported.
#[allow(dead_code)]
pub fn einsum<'g, F: Float + scirs2_core::ndarray::ScalarOperand>(
    pattern: &str,
    operands: &[&Tensor<'g, F>],
) -> Tensor<'g, F> {
    if operands.len() != 2 {
        panic!("Only binary einsum operations are currently supported");
    }
let g = operands[0].graph();
Tensor::builder(g)
.append_input(operands[0], false)
.append_input(operands[1], false)
.build(EinsumOp {
pattern: pattern.to_string(),
})
}

/// Kronecker product of two 2D matrices: `(m, n)` and `(p, q)` inputs
/// produce an `(m·p, n·q)` output.
#[allow(dead_code)]
pub fn kron<'g, F: Float>(a: &Tensor<'g, F>, b: &Tensor<'g, F>) -> Tensor<'g, F> {
let g = a.graph();
Tensor::builder(g)
.append_input(a, false)
.append_input(b, false)
.build(KroneckerOp)
}

/// Op computing the Kronecker product of two 2D matrices.
struct KroneckerOp;
impl<F: Float> Op<F> for KroneckerOp {
fn name(&self) -> &'static str {
"Kronecker"
}
fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
let a = ctx.input(0);
let b = ctx.input(1);
let ashape = a.shape();
let bshape = b.shape();
if ashape.len() != 2 || bshape.len() != 2 {
return Err(OpError::IncompatibleShape(
"Kronecker product requires 2D matrices".into(),
));
}
let (m, n) = (ashape[0], ashape[1]);
let (p, q) = (bshape[0], bshape[1]);
        let mut result = ArrayD::<F>::zeros(IxDyn(&[m * p, n * q]));
        // Block (i, j) of the result is a[i, j] * b.
        for i in 0..m {
for j in 0..n {
let a_ij = a[[i, j]];
for k in 0..p {
for l in 0..q {
result[[i * p + k, j * q + l]] = a_ij * b[[k, l]];
}
}
}
}
ctx.append_output(result);
Ok(())
}
    fn grad(&self, ctx: &mut GradientContext<F>) {
        // NOTE: placeholder pass-through, not the true Kronecker gradient.
        // The exact gradients contract the output gradient blockwise with
        // the other operand.
        let gy = ctx.output_grad();
        ctx.append_input_grad(0, Some(*gy));
        ctx.append_input_grad(1, Some(*gy));
    }
}
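
// Worked examples for the einsum compute kernels, assuming `f64` implements
// this crate's `Float` trait: a 2×2 matrix product and a dot product checked
// against hand-computed values. A sketch over the pure helpers; the ops
// themselves need a full graph to exercise.
#[cfg(test)]
mod einsum_kernel_tests {
    use super::*;
    use scirs2_core::ndarray::{arr1, arr2};

    #[test]
    fn matmul_matches_hand_computation() {
        // [[1, 2], [3, 4]] · [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        let a = arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]).into_dyn();
        let b = arr2(&[[5.0_f64, 6.0], [7.0, 8.0]]).into_dyn();
        let c = compute_matmul(&a.view(), &b.view()).unwrap();
        assert_eq!(c[[0, 0]], 19.0);
        assert_eq!(c[[0, 1]], 22.0);
        assert_eq!(c[[1, 0]], 43.0);
        assert_eq!(c[[1, 1]], 50.0);
    }

    #[test]
    fn dot_product_sums_elementwise_products() {
        // [1, 2, 3] · [4, 5, 6] = 4 + 10 + 18 = 32
        let a = arr1(&[1.0_f64, 2.0, 3.0]).into_dyn();
        let b = arr1(&[4.0_f64, 5.0, 6.0]).into_dyn();
        let d = compute_dot_product(&a.view(), &b.view()).unwrap();
        assert_eq!(d.sum(), 32.0);
    }
}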