use num_bigint::BigInt;
use num_rational::BigRational;
use num_traits::One;
use num_traits::Zero;
use serde::Deserialize;
use serde::Serialize;
use crate::symbolic::calculus::differentiate;
use crate::symbolic::core::Expr;
use crate::symbolic::matrix::inverse_matrix;
use crate::symbolic::simplify_dag::simplify;
/// A dense symbolic tensor of arbitrary rank, stored as a flat,
/// row-major list of `Expr` components.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Tensor {
// Flattened components in row-major order (last axis varies fastest);
// length equals the product of `shape`.
pub components: Vec<Expr>,
// Extent of each dimension; `shape.len()` is the tensor's rank.
// An empty shape denotes a scalar with a single component.
pub shape: Vec<usize>,
}
impl Tensor {
    /// Builds a tensor from row-major `components` with the given `shape`.
    ///
    /// # Errors
    /// Fails when `components.len()` differs from the product of `shape`
    /// (the product of an empty shape is 1, i.e. a scalar).
    pub fn new(components: Vec<Expr>, shape: Vec<usize>) -> Result<Self, String> {
        let expected_len: usize = shape.iter().product();
        if components.len() != expected_len {
            return Err(format!(
                "Number of components ({}) does not match shape ({:?})",
                components.len(),
                shape
            ));
        }
        Ok(Self { components, shape })
    }

    /// Rank (number of indices) of the tensor.
    #[must_use]
    pub const fn rank(&self) -> usize {
        self.shape.len()
    }

    /// Maps a multi-index to the row-major offset into `components`,
    /// checking the index count against the rank and each index against its
    /// dimension. Shared by `get` and `get_mut` so the two stay consistent.
    fn flat_offset(&self, indices: &[usize]) -> Result<usize, String> {
        if indices.len() != self.rank() {
            return Err("Incorrect number of indices for tensor rank".to_string());
        }
        let mut flat_index = 0;
        let mut stride = 1;
        // Last axis varies fastest: accumulate strides from right to left.
        for (i, &dim) in self.shape.iter().enumerate().rev() {
            if indices[i] >= dim {
                return Err(format!(
                    "Index {} out of bounds for dimension {}",
                    indices[i], i
                ));
            }
            flat_index += indices[i] * stride;
            stride *= dim;
        }
        Ok(flat_index)
    }

    /// Returns the component at `indices`.
    ///
    /// # Errors
    /// Fails on a wrong index count or an out-of-bounds index.
    pub fn get(&self, indices: &[usize]) -> Result<&Expr, String> {
        let offset = self.flat_offset(indices)?;
        Ok(&self.components[offset])
    }

    /// Mutable access to the component at `indices`; same validation as `get`.
    pub(crate) fn get_mut(&mut self, indices: &[usize]) -> Result<&mut Expr, String> {
        let offset = self.flat_offset(indices)?;
        Ok(&mut self.components[offset])
    }

    /// Component-wise sum; both tensors must share a shape.
    pub fn add(&self, other: &Self) -> Result<Self, String> {
        if self.shape != other.shape {
            return Err("Tensors must have the same shape for addition".to_string());
        }
        let new_components = self
            .components
            .iter()
            .zip(other.components.iter())
            .map(|(a, b)| simplify(&Expr::new_add(a.clone(), b.clone())))
            .collect();
        Self::new(new_components, self.shape.clone())
    }

    /// Component-wise difference; both tensors must share a shape.
    pub fn sub(&self, other: &Self) -> Result<Self, String> {
        if self.shape != other.shape {
            return Err("Tensors must have the same shape for subtraction".to_string());
        }
        let new_components = self
            .components
            .iter()
            .zip(other.components.iter())
            .map(|(a, b)| simplify(&Expr::new_sub(a.clone(), b.clone())))
            .collect();
        Self::new(new_components, self.shape.clone())
    }

    /// Multiplies every component by `scalar`.
    pub fn scalar_mul(&self, scalar: &Expr) -> Result<Self, String> {
        let new_components = self
            .components
            .iter()
            .map(|c| simplify(&Expr::new_mul(scalar.clone(), c.clone())))
            .collect();
        Self::new(new_components, self.shape.clone())
    }

    /// Outer (tensor) product: the result's shape is `self.shape` followed by
    /// `other.shape`, with `self`'s indices varying slowest.
    pub fn outer_product(&self, other: &Self) -> Result<Self, String> {
        let new_shape: Vec<usize> = self
            .shape
            .iter()
            .chain(other.shape.iter())
            .copied()
            .collect();
        let mut new_components =
            Vec::with_capacity(self.components.len() * other.components.len());
        for c1 in &self.components {
            for c2 in &other.components {
                new_components.push(simplify(&Expr::new_mul(c1.clone(), c2.clone())));
            }
        }
        Self::new(new_components, new_shape)
    }

    /// Contracts (sums over) the pair of distinct axes `axis1` and `axis2`,
    /// producing a tensor of rank `self.rank() - 2` (a scalar tensor when the
    /// input rank was 2).
    ///
    /// # Errors
    /// Fails when an axis is out of bounds, the axes are equal, or their
    /// dimensions differ.
    pub fn contract(&self, axis1: usize, axis2: usize) -> Result<Self, String> {
        if axis1 >= self.rank() || axis2 >= self.rank() {
            return Err("Axis out of bounds".to_string());
        }
        // Reject self-contraction before the dimension comparison; equal axes
        // would trivially pass it.
        if axis1 == axis2 {
            return Err("Cannot contract an axis with itself".to_string());
        }
        if self.shape[axis1] != self.shape[axis2] {
            return Err("Dimensions of contracted axes must be equal".to_string());
        }
        let mut new_shape = self.shape.clone();
        let dim = self.shape[axis1];
        // Remove the higher axis first so the lower index stays valid.
        new_shape.remove(axis1.max(axis2));
        new_shape.remove(axis1.min(axis2));
        let new_len: usize = if new_shape.is_empty() {
            1
        } else {
            new_shape.iter().product()
        };
        let new_components = vec![Expr::BigInt(BigInt::zero()); new_len];
        let mut new_tensor = Self::new(new_components, new_shape.clone())?;
        // Odometer iteration over the non-contracted indices; for each
        // position, sum the diagonal of the contracted pair.
        let mut current_indices = vec![0; self.rank()];
        loop {
            let mut sum_val = Expr::BigInt(BigInt::zero());
            for i in 0..dim {
                current_indices[axis1] = i;
                current_indices[axis2] = i;
                // NOTE: fixes a mangled `&current_indices` borrow in the
                // original source (`¤t_indices`), which did not compile.
                sum_val = simplify(&Expr::new_add(
                    sum_val,
                    self.get(&current_indices)?.clone(),
                ));
            }
            let new_indices: Vec<usize> = current_indices
                .iter()
                .enumerate()
                .filter(|(idx, _)| *idx != axis1 && *idx != axis2)
                .map(|(_, &val)| val)
                .collect();
            if new_tensor.rank() > 0 {
                *new_tensor.get_mut(&new_indices)? = sum_val;
            } else {
                // Full contraction to a scalar: exactly one component remains.
                new_tensor.components[0] = sum_val;
            }
            // Advance the odometer, skipping the contracted axes.
            let mut done = true;
            for idx in (0..self.rank()).rev() {
                if idx == axis1 || idx == axis2 {
                    continue;
                }
                current_indices[idx] += 1;
                if current_indices[idx] < self.shape[idx] {
                    done = false;
                    break;
                }
                current_indices[idx] = 0;
            }
            if done {
                break;
            }
        }
        Ok(new_tensor)
    }

    /// Views a rank-2 tensor as an `Expr::Matrix` of its rows.
    ///
    /// # Errors
    /// Fails when the tensor is not rank 2.
    pub fn to_matrix_expr(&self) -> Result<Expr, String> {
        if self.rank() != 2 {
            return Err("Can only convert a rank-2 tensor to a matrix expression.".to_string());
        }
        Ok(Expr::Matrix(
            self.components
                .chunks(self.shape[1])
                .map(<[Expr]>::to_vec)
                .collect(),
        ))
    }
}
/// A metric tensor together with its precomputed inverse, used for raising
/// and lowering indices and for curvature computations below.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MetricTensor {
// The metric g (square, rank-2), validated by `MetricTensor::new`.
pub g: Tensor,
// The symbolic inverse of `g`, computed once at construction.
pub g_inv: Tensor,
}
impl MetricTensor {
    /// Builds a metric from a square rank-2 tensor `g`, precomputing its
    /// symbolic inverse for index raising.
    ///
    /// # Errors
    /// Fails when `g` is not a square rank-2 tensor, or when
    /// `inverse_matrix` does not yield a matrix expression.
    pub fn new(g: Tensor) -> Result<Self, String> {
        if g.rank() != 2 || g.shape[0] != g.shape[1] {
            return Err("Metric tensor must be a square rank-2 tensor.".to_string());
        }
        // Reuse the canonical tensor -> matrix conversion instead of
        // duplicating the chunking logic here; rank 2 was checked above,
        // so this `?` cannot fail.
        let g_matrix = g.to_matrix_expr()?;
        let g_inv_matrix = inverse_matrix(&g_matrix);
        let g_inv = if let Expr::Matrix(rows) = g_inv_matrix {
            Tensor::new(rows.into_iter().flatten().collect(), g.shape.clone())?
        } else {
            return Err("Failed to invert metric tensor".to_string());
        };
        Ok(Self { g, g_inv })
    }

    /// Raises the index of a rank-1 covector: v^i = g^{ij} w_j.
    ///
    /// # Errors
    /// Fails when `covector` is not rank 1.
    pub fn raise_index(&self, covector: &Tensor) -> Result<Tensor, String> {
        if covector.rank() != 1 {
            return Err("Can only raise index of a rank-1 tensor (covector).".to_string());
        }
        // (g_inv ⊗ w) has shape [n, n, n]; contracting axes 1 and 2
        // performs the sum over j.
        let product = self.g_inv.outer_product(covector)?;
        product.contract(1, 2)
    }

    /// Lowers the index of a rank-1 vector: w_i = g_{ij} v^j.
    ///
    /// # Errors
    /// Fails when `vector` is not rank 1.
    pub fn lower_index(&self, vector: &Tensor) -> Result<Tensor, String> {
        if vector.rank() != 1 {
            return Err("Can only lower index of a rank-1 tensor (vector).".to_string());
        }
        let product = self.g.outer_product(vector)?;
        product.contract(1, 2)
    }
}
/// Christoffel symbols of the first kind for `metric`, indexed as [i, j, k]:
/// Γ_{ijk} = ½ (∂_j g_{ik} + ∂_i g_{jk} − ∂_k g_{ij}),
/// where `vars` names the coordinate variables.
///
/// # Errors
/// Fails when `vars.len()` does not equal the metric dimension.
pub fn christoffel_symbols_first_kind(
    metric: &MetricTensor,
    vars: &[&str],
) -> Result<Tensor, String> {
    let dim = metric.g.shape[0];
    if vars.len() != dim {
        return Err("Number of variables must match metric dimension".to_string());
    }
    // The invariant ½ factor, built once and cloned per entry.
    let one_half = Expr::Rational(BigRational::new(BigInt::one(), BigInt::from(2)));
    let mut components = Vec::with_capacity(dim * dim * dim);
    for i in 0..dim {
        for j in 0..dim {
            for k in 0..dim {
                let d_g_ik = differentiate(metric.g.get(&[i, k])?, vars[j]);
                let d_g_jk = differentiate(metric.g.get(&[j, k])?, vars[i]);
                let d_g_ij = differentiate(metric.g.get(&[i, j])?, vars[k]);
                let partial_sum = simplify(&Expr::new_add(d_g_ik, d_g_jk));
                let bracket = simplify(&Expr::new_sub(partial_sum, d_g_ij));
                components.push(simplify(&Expr::new_mul(one_half.clone(), bracket)));
            }
        }
    }
    Tensor::new(components, vec![dim, dim, dim])
}
/// Christoffel symbols of the second kind, Γ^i_{jk} = g^{il} Γ_{ljk},
/// obtained by raising the first index of the first-kind symbols with the
/// inverse metric.
///
/// # Errors
/// Propagates failures from the first-kind computation (e.g. a variable
/// count that does not match the metric dimension).
pub fn christoffel_symbols_second_kind(
    metric: &MetricTensor,
    vars: &[&str],
) -> Result<Tensor, String> {
    let first_kind = christoffel_symbols_first_kind(metric, vars)?;
    // g_inv ⊗ Γ has shape [n, n, n, n, n]; contracting axes 1 and 2
    // sums g^{il} Γ_{ljk} over l.
    metric
        .g_inv
        .outer_product(&first_kind)?
        .contract(1, 2)
}
/// Riemann curvature tensor of `metric`, indexed as [i, j, k, l]:
/// R^i_{jkl} = ∂_k Γ^i_{jl} − ∂_l Γ^i_{jk} + Γ^i_{mk} Γ^m_{jl} − Γ^i_{ml} Γ^m_{jk}.
///
/// # Errors
/// Propagates failures from the Christoffel-symbol computation (e.g. when
/// `vars.len()` does not match the metric dimension).
pub fn riemann_curvature_tensor(
metric: &MetricTensor,
vars: &[&str],
) -> Result<Tensor, String> {
let dim = metric.g.shape[0];
let christoffel_2nd = christoffel_symbols_second_kind(metric, vars)?;
let mut components = Vec::new();
for i in 0..dim {
for j in 0..dim {
for k in 0..dim {
for l in 0..dim {
// Derivative terms: term1 = ∂_k Γ^i_{jl}, term2 = ∂_l Γ^i_{jk}.
let term1 = differentiate(christoffel_2nd.get(&[i, j, l])?, vars[k]);
let term2 = differentiate(christoffel_2nd.get(&[i, j, k])?, vars[l]);
// term3 = Σ_m Γ^m_{jl} Γ^i_{mk}.
let mut term3 = Expr::BigInt(BigInt::zero());
for m in 0..dim {
let g_mjl = christoffel_2nd.get(&[m, j, l])?;
let g_imk = christoffel_2nd.get(&[i, m, k])?;
term3 = simplify(&Expr::new_add(
term3,
Expr::new_mul(g_mjl.clone(), g_imk.clone()),
));
}
// term4 = Σ_m Γ^m_{jk} Γ^i_{ml}.
let mut term4 = Expr::BigInt(BigInt::zero());
for m in 0..dim {
let g_mjk = christoffel_2nd.get(&[m, j, k])?;
let g_iml = christoffel_2nd.get(&[i, m, l])?;
term4 = simplify(&Expr::new_add(
term4,
Expr::new_mul(g_mjk.clone(), g_iml.clone()),
));
}
// R^i_{jkl} = (term1 + term3) − (term2 + term4).
let r_ijkl = simplify(&Expr::new_sub(
simplify(&Expr::new_add(term1, term3)),
simplify(&Expr::new_add(term2, term4)),
));
components.push(r_ijkl);
}
}
}
}
Tensor::new(components, vec![dim, dim, dim, dim])
}
/// Covariant derivative of a rank-1 vector field, returned as the rank-2
/// tensor with components ∇_k v^i = ∂_k v^i + Σ_j Γ^i_{jk} v^j,
/// indexed as [i, k].
///
/// # Errors
/// Fails when `vector_field` is not rank 1, or when the Christoffel-symbol
/// computation fails (e.g. `vars.len()` differs from the dimension).
pub fn covariant_derivative_vector(
    vector_field: &Tensor,
    metric: &MetricTensor,
    vars: &[&str],
) -> Result<Tensor, String> {
    if vector_field.rank() != 1 {
        return Err("Input must be a vector field (rank-1 tensor)".to_string());
    }
    let dim = vector_field.shape[0];
    // Errors out whenever vars.len() != dim, so vars[k] below is in bounds.
    let gamma = christoffel_symbols_second_kind(metric, vars)?;
    let mut components = Vec::with_capacity(dim * dim);
    for i in 0..dim {
        for k in 0..dim {
            // ∂_k v^i
            let partial = differentiate(vector_field.get(&[i])?, vars[k]);
            // Σ_j Γ^i_{jk} v^j
            let mut correction = Expr::BigInt(BigInt::zero());
            for j in 0..dim {
                let gamma_ijk = gamma.get(&[i, j, k])?;
                let v_j = vector_field.get(&[j])?;
                correction = simplify(&Expr::new_add(
                    correction,
                    Expr::new_mul(gamma_ijk.clone(), v_j.clone()),
                ));
            }
            components.push(simplify(&Expr::new_add(partial, correction)));
        }
    }
    Tensor::new(components, vec![dim, dim])
}