#![allow(clippy::too_many_arguments)]
use crate::char_parse::char_parse;
use crate::containers::*;
use crate::float_trait::*;
use core::ptr::null;
use derive_builder::Builder;
use std::collections::{BTreeMap, BTreeSet};
use tblis_ffi::tblis::{tblis_comm, tblis_config};
/// Configuration for three-tensor tblis operations; re-exported in this file as `TblisMultCfg`
/// (used by `tblis_tensor_mult` / `tblis_tensor_mult_f`).
///
/// Every field carries a `#[builder(default = ...)]`, so `TblisTriCfgBuilder` can build it with
/// no setters called. The struct is `#[non_exhaustive]`, so code outside this crate constructs
/// it through the builder or `Default`.
#[non_exhaustive]
#[derive(Builder, Debug, Clone)]
pub struct TblisTriCfg<T>
where
T: TblisFloatAPI,
{
/// Communicator pointer forwarded to tblis unchanged; defaults to null.
/// NOTE(review): presumably null lets tblis choose its default communicator — confirm.
#[builder(default = "null()")]
pub comm: *const tblis_comm,
/// Configuration pointer forwarded to tblis unchanged; defaults to null.
#[builder(default = "null()")]
pub cntx: *const tblis_config,
/// Written into the first input tensor's `scalar` before the call; defaults to one.
#[builder(default = "T::one()")]
pub alpha: T,
/// Written into the output tensor's `scalar` before the call; defaults to zero.
#[builder(default = "T::zero()")]
pub beta: T,
/// Written into the first input tensor's `conj` flag; defaults to `false`.
#[builder(default = "false")]
pub conja: bool,
/// Written into the second input tensor's `conj` flag; defaults to `false`.
#[builder(default = "false")]
pub conjb: bool,
}
impl<T> Default for TblisTriCfg<T>
where
    T: TblisFloatAPI,
{
    /// Returns the configuration an untouched builder produces: null `comm`/`cntx`,
    /// `alpha = 1`, `beta = 0` and no conjugation.
    fn default() -> Self {
        // Every field has a builder default, so building an untouched builder cannot fail;
        // the `unwrap` only guards that invariant.
        let untouched = TblisTriCfgBuilder::<T>::default();
        untouched.build().unwrap()
    }
}
/// Configuration for two-tensor tblis operations; re-exported in this file as `TblisAddCfg`
/// and `TblisDotCfg` (used by the add and dot wrappers).
///
/// Every field carries a `#[builder(default = ...)]`, so `TblisBiCfgBuilder` can build it with
/// no setters called. Unlike `TblisTriCfg`, `beta` defaults to one here.
#[non_exhaustive]
#[derive(Builder, Debug, Clone)]
pub struct TblisBiCfg<T>
where
T: TblisFloatAPI,
{
/// Communicator pointer forwarded to tblis unchanged; defaults to null.
/// NOTE(review): presumably null lets tblis choose its default communicator — confirm.
#[builder(default = "null()")]
pub comm: *const tblis_comm,
/// Configuration pointer forwarded to tblis unchanged; defaults to null.
#[builder(default = "null()")]
pub cntx: *const tblis_config,
/// Written into the first tensor's `scalar` before the call; defaults to one.
#[builder(default = "T::one()")]
pub alpha: T,
/// Written into the second tensor's `scalar` before the call; defaults to one.
#[builder(default = "T::one()")]
pub beta: T,
/// Written into the first tensor's `conj` flag; defaults to `false`.
#[builder(default = "false")]
pub conja: bool,
/// Written into the second tensor's `conj` flag; defaults to `false`.
#[builder(default = "false")]
pub conjb: bool,
}
impl<T> Default for TblisBiCfg<T>
where
    T: TblisFloatAPI,
{
    /// Returns the configuration an untouched builder produces: null `comm`/`cntx`,
    /// `alpha = beta = 1` and no conjugation.
    fn default() -> Self {
        // Every field has a builder default, so building an untouched builder cannot fail;
        // the `unwrap` only guards that invariant.
        let untouched = TblisBiCfgBuilder::<T>::default();
        untouched.build().unwrap()
    }
}
/// Configuration for single-tensor tblis operations; re-exported in this file as
/// `TblisReduceCfg`, `TblisScaleCfg` and `TblisShiftCfg`.
///
/// Every field carries a `#[builder(default = ...)]`, so `TblisUniCfgBuilder` can build it with
/// no setters called.
#[non_exhaustive]
#[derive(Builder, Debug, Clone)]
pub struct TblisUniCfg<T>
where
T: TblisFloatAPI,
{
/// Communicator pointer forwarded to tblis unchanged; defaults to null.
/// NOTE(review): presumably null lets tblis choose its default communicator — confirm.
#[builder(default = "null()")]
pub comm: *const tblis_comm,
/// Configuration pointer forwarded to tblis unchanged; defaults to null.
#[builder(default = "null()")]
pub cntx: *const tblis_config,
/// Written into the tensor's `scalar` before the call; defaults to one.
#[builder(default = "T::one()")]
pub alpha: T,
/// Written into the tensor's `conj` flag before the call; defaults to `false`.
#[builder(default = "false")]
pub conj: bool,
}
impl<T> Default for TblisUniCfg<T>
where
    T: TblisFloatAPI,
{
    /// Returns the configuration an untouched builder produces: null `comm`/`cntx`,
    /// `alpha = 1` and no conjugation.
    fn default() -> Self {
        // Every field has a builder default, so building an untouched builder cannot fail;
        // the `unwrap` only guards that invariant.
        let untouched = TblisUniCfgBuilder::<T>::default();
        untouched.build().unwrap()
    }
}
/// Scalar-free configuration for tblis operations; re-exported in this file as `TblisSetCfg`.
///
/// Both fields carry a `#[builder(default = ...)]`, so `TblisZeroCfgBuilder` can build it with
/// no setters called.
#[non_exhaustive]
#[derive(Builder, Debug, Clone)]
pub struct TblisZeroCfg {
/// Communicator pointer forwarded to tblis unchanged; defaults to null.
/// NOTE(review): presumably null lets tblis choose its default communicator — confirm.
#[builder(default = "null()")]
pub comm: *const tblis_comm,
/// Configuration pointer forwarded to tblis unchanged; defaults to null.
#[builder(default = "null()")]
pub cntx: *const tblis_config,
}
impl Default for TblisZeroCfg {
    /// Returns the configuration an untouched builder produces: null `comm` and `cntx`.
    fn default() -> Self {
        // Both fields have builder defaults, so building an untouched builder cannot fail;
        // the `unwrap` only guards that invariant.
        let untouched = TblisZeroCfgBuilder::default();
        untouched.build().unwrap()
    }
}
/// Validates subscripts against shapes and collects the extent of every index label.
///
/// `subscripts[k]` labels the axes of the tensor whose shape is `shapes[k]`, one `char` per
/// axis. A label that occurs more than once, in the same subscript or across subscripts, must
/// have the same extent each time. Zero extents are accepted; negative extents are not.
///
/// # Errors
///
/// Returns a descriptive `String` in any of these cases:
/// - the two slices differ in length;
/// - a subscript's label count differs from its shape's rank;
/// - an extent is negative;
/// - a label is bound to two different extents.
fn check_size_dict(subscripts: &[&str], shapes: &[&[isize]]) -> Result<BTreeMap<char, isize>, String> {
    if subscripts.len() != shapes.len() {
        return Err(format!(
            "Number of subscripts and shapes do not match: {} vs {}",
            subscripts.len(),
            shapes.len()
        ));
    }

    let mut extents: BTreeMap<char, isize> = BTreeMap::new();
    for (labels, shape) in subscripts.iter().zip(shapes) {
        let subscript: Vec<char> = labels.chars().collect();
        if subscript.len() != shape.len() {
            return Err(format!("Subscript length and shape length do not match: {subscript:?} vs {shape:?}"));
        }
        for (&c, &s) in subscript.iter().zip(shape.iter()) {
            if s < 0 {
                return Err(format!("Invalid dimension size {s} for index {c} in subscript {subscript:?}"));
            }
            // The first sighting of a label records `s`; later sightings read back the recorded extent.
            let existing = *extents.entry(c).or_insert(s);
            if existing != s {
                return Err(format!(
                    "Inconsistent dimension size for index {c}: {existing} vs {s} in subscript {subscript:?}"
                ));
            }
        }
    }
    Ok(extents)
}
pub use TblisBiCfg as TblisAddCfg;
pub use TblisBiCfgBuilder as TblisAddCfgBuilder;
/// Adds `a` into `b` in place through tblis, panicking if the arguments are rejected.
///
/// This is the panicking counterpart of [`tblis_tensor_add_f`]. As there, `cfg.beta` and
/// `cfg.conjb` are written into `b`'s `scalar`/`conj` fields and stay there after the call.
///
/// # Panics
///
/// Panics with the error message of [`tblis_tensor_add_f`] whenever it returns `Err`. That
/// happens on a subscript/shape mismatch, when both subscripts contain labels missing from the
/// other, or when `char_parse` rejects the subscripts.
///
/// # Safety
///
/// `cfg.comm` and `cfg.cntx` are raw pointers passed to tblis unchecked. Each must be null or
/// point to a valid tblis object. The tensors are handed to the C library via `to_ffi_tensor`,
/// so the memory they describe must be valid for their shapes and strides
/// (NOTE(review): inferred from the FFI hand-off; confirm the exact requirements on `TblisTensor`).
pub unsafe fn tblis_tensor_add<T>(
    a: &TblisTensor<T>,
    idx_a: &str,
    b: &mut TblisTensor<T>,
    idx_b: &str,
    cfg: Option<TblisAddCfg<T>>,
) where
    T: TblisFloatAPI,
{
    // SAFETY: arguments are forwarded unchanged; the caller upholds the contract above.
    let outcome = unsafe { tblis_tensor_add_f(a, idx_a, b, idx_b, cfg) };
    outcome.unwrap()
}
/// Fallible tensor addition through tblis `tblis_tensor_add`.
///
/// `cfg.alpha` and `cfg.conja` are applied to a private clone of `a`, so the caller's `a` is
/// untouched. `cfg.beta` and `cfg.conjb` are written directly into `b.scalar` and `b.conj`, and
/// those writes persist after the function returns. Per tblis this presumably computes
/// `b = alpha * a + beta * b`, with the conjugation flags applied (NOTE(review): confirm
/// against the tblis documentation).
///
/// # Errors
///
/// - The subscripts disagree with the tensor shapes (see `check_size_dict`).
/// - `idx_a` and `idx_b` *both* contain labels that the other subscript lacks. At most one
///   side may have such unique labels.
/// - `char_parse` rejects the subscripts.
///
/// # Safety
///
/// `cfg.comm` and `cfg.cntx` are raw pointers passed to tblis unchecked. Each must be null or
/// point to a valid tblis object. The tensors are handed to the C library via `to_ffi_tensor`,
/// so the memory they describe must be valid for their shapes and strides
/// (NOTE(review): inferred from the FFI hand-off; confirm the exact requirements on `TblisTensor`).
pub unsafe fn tblis_tensor_add_f<T>(
    a: &TblisTensor<T>,
    idx_a: &str,
    b: &mut TblisTensor<T>,
    idx_b: &str,
    cfg: Option<TblisAddCfg<T>>,
) -> Result<(), String>
where
    T: TblisFloatAPI,
{
    check_size_dict(&[idx_a, idx_b], &[&a.shape, &b.shape])?;

    // Labels present in exactly one of the two subscripts.
    let labels_a: BTreeSet<char> = idx_a.chars().collect();
    let labels_b: BTreeSet<char> = idx_b.chars().collect();
    let chk_a_only: BTreeSet<char> = labels_a.difference(&labels_b).copied().collect();
    let chk_b_only: BTreeSet<char> = labels_b.difference(&labels_a).copied().collect();
    if !chk_a_only.is_empty() && !chk_b_only.is_empty() {
        return Err(format!(
            "tblis_tensor_add: Only one of two tensors can have unique indices. Unique to a ({idx_a}): {chk_a_only:?}, unique to b ({idx_b}): {chk_b_only:?}"
        ));
    }

    let indices = char_parse(&[idx_a, idx_b])?;
    let TblisAddCfg { comm, cntx, alpha, beta, conja, conjb } = cfg.unwrap_or_default();

    let mut a = a.clone();
    a.scalar = alpha;
    a.conj = conja;
    b.scalar = beta;
    b.conj = conjb;

    // SAFETY: `indices` outlives the call, so the label pointers stay valid; the remaining
    // requirements are the caller's, per `# Safety`.
    unsafe {
        tblis_ffi::tblis::tblis_tensor_add(
            comm,
            cntx,
            &a.to_ffi_tensor(),
            indices[0].as_ptr(),
            &mut b.to_ffi_tensor(),
            indices[1].as_ptr(),
        );
    }
    Ok(())
}
pub use TblisBiCfg as TblisDotCfg;
pub use TblisBiCfgBuilder as TblisDotCfgBuilder;
/// Full contraction of `a` with `b` through tblis, panicking if the arguments are rejected.
///
/// This is the panicking counterpart of [`tblis_tensor_dot_f`]. Both subscripts must use
/// exactly the same set of labels. Neither caller tensor is modified.
///
/// # Panics
///
/// Panics with the error message of [`tblis_tensor_dot_f`] whenever it returns `Err`. That
/// happens on a subscript/shape mismatch, when either subscript has a label the other lacks,
/// or when `char_parse` rejects the subscripts.
///
/// # Safety
///
/// `cfg.comm` and `cfg.cntx` are raw pointers passed to tblis unchecked. Each must be null or
/// point to a valid tblis object. The tensors are handed to the C library via `to_ffi_tensor`,
/// so the memory they describe must be valid for their shapes and strides
/// (NOTE(review): inferred from the FFI hand-off; confirm the exact requirements on `TblisTensor`).
pub unsafe fn tblis_tensor_dot<T>(
    a: &TblisTensor<T>,
    idx_a: &str,
    b: &TblisTensor<T>,
    idx_b: &str,
    cfg: Option<TblisDotCfg<T>>,
) -> T
where
    T: TblisFloatAPI,
{
    // SAFETY: arguments are forwarded unchanged; the caller upholds the contract above.
    let outcome = unsafe { tblis_tensor_dot_f(a, idx_a, b, idx_b, cfg) };
    outcome.unwrap()
}
/// Fallible full contraction of `a` with `b` through tblis `tblis_tensor_dot`.
///
/// Both subscripts must contain exactly the same set of labels. `cfg.alpha`/`cfg.conja` are
/// applied to a private clone of `a`, and `cfg.beta`/`cfg.conjb` to a private clone of `b`, so
/// the caller's tensors are not modified. NOTE(review): tblis presumably folds both `scalar`
/// factors into the returned value — confirm against the tblis documentation.
///
/// # Errors
///
/// - The subscripts disagree with the tensor shapes (see `check_size_dict`).
/// - `idx_a` has a label missing from `idx_b`, or vice versa.
/// - `char_parse` rejects the subscripts.
///
/// # Safety
///
/// `cfg.comm` and `cfg.cntx` are raw pointers passed to tblis unchecked. Each must be null or
/// point to a valid tblis object. The tensors are handed to the C library via `to_ffi_tensor`,
/// so the memory they describe must be valid for their shapes and strides
/// (NOTE(review): inferred from the FFI hand-off; confirm the exact requirements on `TblisTensor`).
pub unsafe fn tblis_tensor_dot_f<T>(
    a: &TblisTensor<T>,
    idx_a: &str,
    b: &TblisTensor<T>,
    idx_b: &str,
    cfg: Option<TblisDotCfg<T>>,
) -> Result<T, String>
where
    T: TblisFloatAPI,
{
    check_size_dict(&[idx_a, idx_b], &[&a.shape, &b.shape])?;
    let chk_a = idx_a.chars().collect::<BTreeSet<char>>();
    let chk_b = idx_b.chars().collect::<BTreeSet<char>>();
    let chk_ab = &chk_a & &chk_b;
    let chk_a_only = &chk_a - &chk_ab;
    let chk_b_only = &chk_b - &chk_ab;
    if !chk_a_only.is_empty() {
        return Err(format!(
            "tblis_tensor_dot: Unique indices is not allowed. Unique to a ({idx_a}) of b ({idx_b}): {chk_a_only:?}"
        ));
    }
    if !chk_b_only.is_empty() {
        return Err(format!(
            "tblis_tensor_dot: Unique indices is not allowed. Unique to b ({idx_b}) of a ({idx_a}): {chk_b_only:?}"
        ));
    }
    let indices = char_parse(&[idx_a, idx_b])?;
    let (a_idx, b_idx) = (indices[0].as_ptr(), indices[1].as_ptr());
    let TblisDotCfg { comm, cntx, alpha, beta, conja, conjb } = cfg.unwrap_or_default();
    let mut a = a.clone();
    let mut b = b.clone();
    a.scalar = alpha;
    b.scalar = beta;
    a.conj = conja;
    b.conj = conjb;
    // tblis writes the dot product into this FFI scalar. It must be a named binding, not a
    // temporary from `to_ffi_scalar()`, so that the value can be read back after the call.
    // Starting from `T::zero()` gives it the element type matching `T`.
    let mut result = T::zero().to_ffi_scalar();
    // SAFETY: `indices` and `result` outlive the call; the remaining requirements are the
    // caller's, per `# Safety`.
    unsafe {
        tblis_ffi::tblis::tblis_tensor_dot(
            comm,
            cntx,
            &a.to_ffi_tensor(),
            a_idx,
            &b.to_ffi_tensor(),
            b_idx,
            &mut result,
        );
    }
    // SAFETY: the C `tblis_scalar` stores its value in a leading `data` union whose members
    // (float/double/scomplex/dcomplex) share offset 0. The active member has `T`'s type, because
    // `result` was created from a `T`. Reading the first `size_of::<T>()` bytes therefore yields
    // that value. NOTE(review): if `TblisFloatAPI` offers a dedicated conversion back from
    // `tblis_scalar`, prefer it here.
    Ok(unsafe { core::mem::transmute_copy::<_, T>(&result) })
}
pub use TblisTriCfg as TblisMultCfg;
pub use TblisTriCfgBuilder as TblisMultCfgBuilder;
/// Contracts `a` with `b` into `c` through tblis, panicking if the arguments are rejected.
///
/// This is the panicking counterpart of [`tblis_tensor_mult_f`]. As there, `cfg.beta` is
/// written into `c.scalar` and `c.conj` is set to `false`; both writes persist.
///
/// # Panics
///
/// Panics with the error message of [`tblis_tensor_mult_f`] whenever it returns `Err`. That
/// happens on a subscript/shape mismatch, when a label appears in only one of the three
/// subscripts, or when `char_parse` rejects the subscripts.
///
/// # Safety
///
/// `cfg.comm` and `cfg.cntx` are raw pointers passed to tblis unchecked. Each must be null or
/// point to a valid tblis object. The tensors are handed to the C library via `to_ffi_tensor`,
/// so the memory they describe must be valid for their shapes and strides
/// (NOTE(review): inferred from the FFI hand-off; confirm the exact requirements on `TblisTensor`).
pub unsafe fn tblis_tensor_mult<T>(
    a: &TblisTensor<T>,
    idx_a: &str,
    b: &TblisTensor<T>,
    idx_b: &str,
    c: &mut TblisTensor<T>,
    idx_c: &str,
    cfg: Option<TblisMultCfg<T>>,
) where
    T: TblisFloatAPI,
{
    // SAFETY: arguments are forwarded unchanged; the caller upholds the contract above.
    let outcome = unsafe { tblis_tensor_mult_f(a, idx_a, b, idx_b, c, idx_c, cfg) };
    outcome.unwrap()
}
/// Fallible tensor contraction `a · b → c` through tblis `tblis_tensor_mult`.
///
/// Every label must appear in at least two of the three subscripts. `cfg.alpha` and
/// `cfg.conja` go onto a private clone of `a`. The clone of `b` gets `scalar = 1` and
/// `cfg.conjb`. The output `c` is modified in place: `c.scalar = cfg.beta` and
/// `c.conj = false`, and both settings persist after the call. NOTE(review): per tblis this
/// presumably computes `c = alpha * a·b + beta * c` — confirm against the tblis documentation.
///
/// # Errors
///
/// - The subscripts disagree with the tensor shapes (see `check_size_dict`).
/// - Some label appears in only one subscript.
/// - `char_parse` rejects the subscripts.
///
/// # Safety
///
/// `cfg.comm` and `cfg.cntx` are raw pointers passed to tblis unchecked. Each must be null or
/// point to a valid tblis object. The tensors are handed to the C library via `to_ffi_tensor`,
/// so the memory they describe must be valid for their shapes and strides
/// (NOTE(review): inferred from the FFI hand-off; confirm the exact requirements on `TblisTensor`).
pub unsafe fn tblis_tensor_mult_f<T>(
    a: &TblisTensor<T>,
    idx_a: &str,
    b: &TblisTensor<T>,
    idx_b: &str,
    c: &mut TblisTensor<T>,
    idx_c: &str,
    cfg: Option<TblisMultCfg<T>>,
) -> Result<(), String>
where
    T: TblisFloatAPI,
{
    check_size_dict(&[idx_a, idx_b, idx_c], &[&a.shape, &b.shape, &c.shape])?;

    // A label is "private" to an operand when neither of the other two subscripts mentions it.
    let labels_a: BTreeSet<char> = idx_a.chars().collect();
    let labels_b: BTreeSet<char> = idx_b.chars().collect();
    let labels_c: BTreeSet<char> = idx_c.chars().collect();
    let chk_a_only = &labels_a - &(&labels_b | &labels_c);
    let chk_b_only = &labels_b - &(&labels_a | &labels_c);
    let chk_c_only = &labels_c - &(&labels_a | &labels_b);
    let has_private = [&chk_a_only, &chk_b_only, &chk_c_only].iter().any(|set| !set.is_empty());
    if has_private {
        return Err(format!(
            "tblis_tensor_mult: Unique indices is not allowed. Input and unique indices: a ({idx_a}): {chk_a_only:?}, b ({idx_b}): {chk_b_only:?}, c ({idx_c}): {chk_c_only:?}"
        ));
    }

    let indices = char_parse(&[idx_a, idx_b, idx_c])?;
    let TblisMultCfg { comm, cntx, alpha, beta, conja, conjb } = cfg.unwrap_or_default();

    // Inputs are cloned, so the caller's `a` and `b` keep their own scalar/conj settings.
    let mut a = a.clone();
    a.scalar = alpha;
    a.conj = conja;
    let mut b = b.clone();
    b.scalar = T::one(); // `alpha` is carried entirely by `a`.
    b.conj = conjb;
    // The output is updated in place and is never conjugated.
    c.scalar = beta;
    c.conj = false;

    // SAFETY: `indices` outlives the call, so the label pointers stay valid; the remaining
    // requirements are the caller's, per `# Safety`.
    unsafe {
        tblis_ffi::tblis::tblis_tensor_mult(
            comm,
            cntx,
            &a.to_ffi_tensor(),
            indices[0].as_ptr(),
            &b.to_ffi_tensor(),
            indices[1].as_ptr(),
            &mut c.to_ffi_tensor(),
            indices[2].as_ptr(),
        );
    }
    Ok(())
}
/// Reduction selector for `tblis_tensor_reduce`, mirroring tblis' `reduce_t`.
///
/// Each variant's discriminant is taken from the corresponding `reduce_t` constant. The
/// associated constants `Norm1` (= `SumAbs`) and `NormInf` (= `MaxAbs`) provide norm-style
/// aliases.
#[repr(u32)]
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TblisReduceOp {
/// Maps to tblis `REDUCE_SUM`.
Sum = tblis_ffi::tblis::reduce_t::REDUCE_SUM as _,
/// Maps to tblis `REDUCE_SUM_ABS`; also available as `TblisReduceOp::Norm1`.
SumAbs = tblis_ffi::tblis::reduce_t::REDUCE_SUM_ABS as _,
/// Maps to tblis `REDUCE_MAX`.
Max = tblis_ffi::tblis::reduce_t::REDUCE_MAX as _,
/// Maps to tblis `REDUCE_MAX_ABS`; also available as `TblisReduceOp::NormInf`.
MaxAbs = tblis_ffi::tblis::reduce_t::REDUCE_MAX_ABS as _,
/// Maps to tblis `REDUCE_MIN`.
Min = tblis_ffi::tblis::reduce_t::REDUCE_MIN as _,
/// Maps to tblis `REDUCE_MIN_ABS`.
MinAbs = tblis_ffi::tblis::reduce_t::REDUCE_MIN_ABS as _,
/// Maps to tblis `REDUCE_NORM_2`.
Norm2 = tblis_ffi::tblis::reduce_t::REDUCE_NORM_2 as _,
}
#[allow(non_upper_case_globals)]
impl TblisReduceOp {
pub const Norm1: Self = Self::SumAbs;
pub const NormInf: Self = Self::MaxAbs;
}
impl From<&str> for TblisReduceOp {
fn from(s: &str) -> Self {
let st = s.to_lowercase().replace(['-', ' ', '_'], "");
match st.as_str() {
"sum" => Self::Sum,
"sumabs" => Self::SumAbs,
"max" => Self::Max,
"maxabs" => Self::MaxAbs,
"min" => Self::Min,
"minabs" => Self::MinAbs,
"norm2" => Self::Norm2,
"norm1" => Self::Norm1,
"norminf" => Self::NormInf,
_ => panic!("Invalid reduction operation string: {s}"),
}
}
}
impl From<TblisReduceOp> for tblis_ffi::tblis::reduce_t {
    /// Converts the wrapper back to the raw tblis `reduce_t` value, one variant to one constant.
    fn from(op: TblisReduceOp) -> Self {
        match op {
            TblisReduceOp::Sum => Self::REDUCE_SUM,
            TblisReduceOp::SumAbs => Self::REDUCE_SUM_ABS,
            TblisReduceOp::Max => Self::REDUCE_MAX,
            TblisReduceOp::MaxAbs => Self::REDUCE_MAX_ABS,
            TblisReduceOp::Min => Self::REDUCE_MIN,
            TblisReduceOp::MinAbs => Self::REDUCE_MIN_ABS,
            TblisReduceOp::Norm2 => Self::REDUCE_NORM_2,
        }
    }
}
pub use TblisUniCfg as TblisReduceCfg;
pub use TblisUniCfgBuilder as TblisReduceCfgBuilder;
/// Reduces `a` with `op` through tblis, panicking if the arguments are rejected.
///
/// This is the panicking counterpart of [`tblis_tensor_reduce_f`]. The caller's tensor is not
/// modified.
///
/// # Panics
///
/// Panics with the error message of [`tblis_tensor_reduce_f`] whenever it returns `Err`, for
/// example when `char_parse` rejects the subscript.
///
/// # Safety
///
/// `cfg.comm` and `cfg.cntx` are raw pointers passed to tblis unchecked. Each must be null or
/// point to a valid tblis object. The tensor is handed to the C library via `to_ffi_tensor`, so
/// the memory it describes must be valid for its shape and strides
/// (NOTE(review): inferred from the FFI hand-off; confirm the exact requirements on `TblisTensor`).
pub unsafe fn tblis_tensor_reduce<T>(
    a: &TblisTensor<T>,
    idx_a: &str,
    op: TblisReduceOp,
    cfg: Option<TblisReduceCfg<T>>,
) -> T
where
    T: TblisFloatAPI,
{
    // SAFETY: arguments are forwarded unchanged; the caller upholds the contract above.
    let outcome = unsafe { tblis_tensor_reduce_f(a, idx_a, op, cfg) };
    outcome.unwrap()
}
/// Fallible reduction of `a` with `op` through tblis `tblis_tensor_reduce`.
///
/// `cfg.alpha` and `cfg.conj` are applied to a private clone of `a`, so the caller's tensor is
/// not modified. tblis also reports a position through an index output. This API does not
/// return it (NOTE(review): presumably the location of the extremal element for max/min
/// reductions — confirm).
///
/// # Errors
///
/// - The subscript disagrees with `a.shape` (see `check_size_dict`).
/// - `char_parse` rejects the subscript.
///
/// # Safety
///
/// `cfg.comm` and `cfg.cntx` are raw pointers passed to tblis unchecked. Each must be null or
/// point to a valid tblis object. The tensor is handed to the C library via `to_ffi_tensor`, so
/// the memory it describes must be valid for its shape and strides
/// (NOTE(review): inferred from the FFI hand-off; confirm the exact requirements on `TblisTensor`).
pub unsafe fn tblis_tensor_reduce_f<T>(
    a: &TblisTensor<T>,
    idx_a: &str,
    op: TblisReduceOp,
    cfg: Option<TblisReduceCfg<T>>,
) -> Result<T, String>
where
    T: TblisFloatAPI,
{
    // Report shape problems through the `Result`, like the other `_f` functions, instead of panicking.
    check_size_dict(&[idx_a], &[&a.shape])?;
    let indices = char_parse(&[idx_a])?;
    let a_idx = indices[0].as_ptr();
    let TblisReduceCfg { comm, cntx, alpha, conj } = cfg.unwrap_or_default();
    let mut a = a.clone();
    a.scalar = alpha;
    a.conj = conj;
    let op = op.into();
    // tblis writes the reduction into this FFI scalar. It must be a named binding, not a
    // temporary from `to_ffi_scalar()`, so that the value can be read back after the call.
    // Starting from `T::zero()` gives it the element type matching `T`.
    let mut result = T::zero().to_ffi_scalar();
    // Position output required by the tblis signature; intentionally discarded.
    let mut idx = 0_isize;
    // SAFETY: `indices`, `result` and `idx` outlive the call; the remaining requirements are
    // the caller's, per `# Safety`.
    unsafe {
        tblis_ffi::tblis::tblis_tensor_reduce(
            comm,
            cntx,
            op,
            &a.to_ffi_tensor(),
            a_idx,
            &mut result,
            &mut idx,
        );
    }
    // SAFETY: the C `tblis_scalar` stores its value in a leading `data` union whose members
    // (float/double/scomplex/dcomplex) share offset 0. The active member has `T`'s type, because
    // `result` was created from a `T`. Reading the first `size_of::<T>()` bytes therefore yields
    // that value. NOTE(review): if `TblisFloatAPI` offers a dedicated conversion back from
    // `tblis_scalar`, prefer it here.
    Ok(unsafe { core::mem::transmute_copy::<_, T>(&result) })
}
pub use TblisUniCfg as TblisScaleCfg;
pub use TblisUniCfgBuilder as TblisScaleCfgBuilder;
/// Scales `a` in place through tblis, panicking if the arguments are rejected.
///
/// This is the panicking counterpart of [`tblis_tensor_scale_f`]. `cfg.alpha` and `cfg.conj`
/// are written into `a.scalar` and `a.conj`, and those writes persist.
///
/// # Panics
///
/// Panics with the error message of [`tblis_tensor_scale_f`] whenever it returns `Err`, and on
/// a subscript/shape mismatch.
///
/// # Safety
///
/// `cfg.comm` and `cfg.cntx` are raw pointers passed to tblis unchecked. Each must be null or
/// point to a valid tblis object. The tensor is handed to the C library via `to_ffi_tensor`, so
/// the memory it describes must be valid for its shape and strides
/// (NOTE(review): inferred from the FFI hand-off; confirm the exact requirements on `TblisTensor`).
pub unsafe fn tblis_tensor_scale<T>(a: &mut TblisTensor<T>, idx_a: &str, cfg: Option<TblisScaleCfg<T>>)
where
    T: TblisFloatAPI,
{
    // SAFETY: arguments are forwarded unchanged; the caller upholds the contract above.
    let outcome = unsafe { tblis_tensor_scale_f(a, idx_a, cfg) };
    outcome.unwrap()
}
/// Fallible in-place scaling of `a` through tblis `tblis_tensor_scale`.
///
/// `cfg.alpha` is written into `a.scalar` and `cfg.conj` into `a.conj` before the call, and
/// both writes persist afterwards. NOTE(review): per tblis, `a.scalar` is presumably the scaling
/// factor applied to every element — confirm against the tblis documentation.
///
/// # Errors
///
/// - The subscript disagrees with `a.shape` (see `check_size_dict`).
/// - `char_parse` rejects the subscript.
///
/// # Safety
///
/// `cfg.comm` and `cfg.cntx` are raw pointers passed to tblis unchecked. Each must be null or
/// point to a valid tblis object. The tensor is handed to the C library via `to_ffi_tensor`, so
/// the memory it describes must be valid for its shape and strides
/// (NOTE(review): inferred from the FFI hand-off; confirm the exact requirements on `TblisTensor`).
pub unsafe fn tblis_tensor_scale_f<T>(
    a: &mut TblisTensor<T>,
    idx_a: &str,
    cfg: Option<TblisScaleCfg<T>>,
) -> Result<(), String>
where
    T: TblisFloatAPI,
{
    // Report shape problems through the `Result`, like the other `_f` functions, instead of panicking.
    check_size_dict(&[idx_a], &[&a.shape])?;
    let indices = char_parse(&[idx_a])?;
    let a_idx = indices[0].as_ptr();
    let TblisScaleCfg { comm, cntx, alpha, conj } = cfg.unwrap_or_default();
    a.scalar = alpha;
    a.conj = conj;
    // SAFETY: `indices` outlives the call; the remaining requirements are the caller's, per `# Safety`.
    unsafe {
        tblis_ffi::tblis::tblis_tensor_scale(comm, cntx, &mut a.to_ffi_tensor(), a_idx);
    }
    Ok(())
}
pub use TblisZeroCfg as TblisSetCfg;
pub use TblisZeroCfgBuilder as TblisSetCfgBuilder;
/// Sets `a` from `alpha` in place through tblis, panicking if the arguments are rejected.
///
/// This is the panicking counterpart of [`tblis_tensor_set_f`].
///
/// # Panics
///
/// Panics with the error message of [`tblis_tensor_set_f`] whenever it returns `Err`, and on a
/// subscript/shape mismatch.
///
/// # Safety
///
/// `cfg.comm` and `cfg.cntx` are raw pointers passed to tblis unchecked. Each must be null or
/// point to a valid tblis object. The tensor is handed to the C library via `to_ffi_tensor`, so
/// the memory it describes must be valid for its shape and strides
/// (NOTE(review): inferred from the FFI hand-off; confirm the exact requirements on `TblisTensor`).
pub unsafe fn tblis_tensor_set<T>(a: &mut TblisTensor<T>, idx_a: &str, alpha: T, cfg: Option<TblisSetCfg>)
where
    T: TblisFloatAPI,
{
    // SAFETY: arguments are forwarded unchanged; the caller upholds the contract above.
    let outcome = unsafe { tblis_tensor_set_f(a, idx_a, alpha, cfg) };
    outcome.unwrap()
}
/// Fallible in-place assignment of `alpha` to `a` through tblis `tblis_tensor_set`.
///
/// `a.scalar` and `a.conj` are left as they are. NOTE(review): per tblis this presumably
/// assigns `alpha` to every element of `a` — confirm against the tblis documentation.
///
/// # Errors
///
/// - The subscript disagrees with `a.shape` (see `check_size_dict`).
/// - `char_parse` rejects the subscript.
///
/// # Safety
///
/// `cfg.comm` and `cfg.cntx` are raw pointers passed to tblis unchecked. Each must be null or
/// point to a valid tblis object. The tensor is handed to the C library via `to_ffi_tensor`, so
/// the memory it describes must be valid for its shape and strides
/// (NOTE(review): inferred from the FFI hand-off; confirm the exact requirements on `TblisTensor`).
pub unsafe fn tblis_tensor_set_f<T>(
    a: &mut TblisTensor<T>,
    idx_a: &str,
    alpha: T,
    cfg: Option<TblisSetCfg>,
) -> Result<(), String>
where
    T: TblisFloatAPI,
{
    // Report shape problems through the `Result`, like the other `_f` functions, instead of panicking.
    check_size_dict(&[idx_a], &[&a.shape])?;
    let indices = char_parse(&[idx_a])?;
    let a_idx = indices[0].as_ptr();
    let TblisSetCfg { comm, cntx } = cfg.unwrap_or_default();
    // SAFETY: `indices` outlives the call; the remaining requirements are the caller's, per `# Safety`.
    unsafe {
        tblis_ffi::tblis::tblis_tensor_set(comm, cntx, &alpha.to_ffi_scalar(), &mut a.to_ffi_tensor(), a_idx);
    }
    Ok(())
}
pub use TblisUniCfg as TblisShiftCfg;
pub use TblisUniCfgBuilder as TblisShiftCfgBuilder;
/// Shifts `a` by `alpha` in place through tblis, panicking if the arguments are rejected.
///
/// This is the panicking counterpart of [`tblis_tensor_shift_f`]. `cfg.alpha` and `cfg.conj`
/// are written into `a.scalar` and `a.conj`, and those writes persist.
///
/// # Panics
///
/// Panics with the error message of [`tblis_tensor_shift_f`] whenever it returns `Err`, and on
/// a subscript/shape mismatch.
///
/// # Safety
///
/// `cfg.comm` and `cfg.cntx` are raw pointers passed to tblis unchecked. Each must be null or
/// point to a valid tblis object. The tensor is handed to the C library via `to_ffi_tensor`, so
/// the memory it describes must be valid for its shape and strides
/// (NOTE(review): inferred from the FFI hand-off; confirm the exact requirements on `TblisTensor`).
pub unsafe fn tblis_tensor_shift<T>(a: &mut TblisTensor<T>, idx_a: &str, alpha: T, cfg: Option<TblisShiftCfg<T>>)
where
    T: TblisFloatAPI,
{
    // SAFETY: arguments are forwarded unchanged; the caller upholds the contract above.
    let outcome = unsafe { tblis_tensor_shift_f(a, idx_a, alpha, cfg) };
    outcome.unwrap()
}
/// Fallible in-place shift of `a` by `alpha` through tblis `tblis_tensor_shift`.
///
/// The `alpha` argument is passed to tblis as the shift value. The separate `cfg.alpha` is
/// written into `a.scalar`, and `cfg.conj` into `a.conj`; both writes persist after the call.
/// NOTE(review): per tblis this presumably computes `a = alpha + a.scalar * a` — confirm
/// against the tblis documentation.
///
/// # Errors
///
/// - The subscript disagrees with `a.shape` (see `check_size_dict`).
/// - `char_parse` rejects the subscript.
///
/// # Safety
///
/// `cfg.comm` and `cfg.cntx` are raw pointers passed to tblis unchecked. Each must be null or
/// point to a valid tblis object. The tensor is handed to the C library via `to_ffi_tensor`, so
/// the memory it describes must be valid for its shape and strides
/// (NOTE(review): inferred from the FFI hand-off; confirm the exact requirements on `TblisTensor`).
pub unsafe fn tblis_tensor_shift_f<T>(
    a: &mut TblisTensor<T>,
    idx_a: &str,
    alpha: T,
    cfg: Option<TblisShiftCfg<T>>,
) -> Result<(), String>
where
    T: TblisFloatAPI,
{
    // Report shape problems through the `Result`, like the other `_f` functions, instead of panicking.
    check_size_dict(&[idx_a], &[&a.shape])?;
    let indices = char_parse(&[idx_a])?;
    let a_idx = indices[0].as_ptr();
    let TblisShiftCfg { comm, cntx, alpha: alpha_a, conj } = cfg.unwrap_or_default();
    a.scalar = alpha_a;
    a.conj = conj;
    // SAFETY: `indices` outlives the call; the remaining requirements are the caller's, per `# Safety`.
    unsafe {
        tblis_ffi::tblis::tblis_tensor_shift(comm, cntx, &alpha.to_ffi_scalar(), &mut a.to_ffi_tensor(), a_idx);
    }
    Ok(())
}