hpt-display 0.1.2

An internal library for displaying tensors
Documentation
use hpt_traits::tensor::CommonBounds;
use hpt_types::into_scalar::Cast;

pub(crate) fn format_float<T: CommonBounds + Cast<f64>>(val: T, precision: usize) -> String {
    match T::STR {
        "bf16" => {
            let f64_val: f64 = val.cast();
            if f64_val - (f64_val as i64 as f64) != 0.0 {
                format!("{:.prec$}", f64_val, prec = precision)
            } else {
                format!("{:.0}.", f64_val)
            }
        }
        "f16" => {
            let f64_val: f64 = val.cast();
            if f64_val - (f64_val as i64 as f64) != 0.0 {
                format!("{:.prec$}", val, prec = precision)
            } else {
                format!("{:.0}.", val)
            }
        }
        "f32" => {
            let tmp_val: f64 = val.cast();
            if tmp_val - (tmp_val as i64 as f64) != 0.0 {
                format!("{:.prec$}", val, prec = precision)
            } else {
                format!("{:.0}.", val)
            }
        }
        "f64" => {
            let tmp_val: f64 = val.cast();
            if tmp_val - (tmp_val as i64 as f64) != 0.0 {
                format!("{:.prec$}", val, prec = precision)
            } else {
                format!("{:.0}.", val)
            }
        }
        _ => panic!("{} is not a floating point type", T::STR),
    }
}

pub(crate) fn format_complex<T: CommonBounds>(val: T, precision: usize) -> String {
    let mut val = val.to_string();
    match T::STR {
        "c32" => {
            let tmp_val: num_complex::Complex32 = val
                .parse::<num_complex::Complex32>()
                .expect("Failed to parse c32");
            let re = tmp_val.re;
            let im = tmp_val.im;
            let re_str = format_float(re, precision);
            let im_str = format_float(im, precision);
            val = format!("{} + {}i", re_str, im_str);
        }
        "c64" => {
            let tmp_val: num_complex::Complex64 = val
                .parse::<num_complex::Complex64>()
                .expect("Failed to parse c64");
            let re = tmp_val.re;
            let im = tmp_val.im;
            let re_str = format_float(re, precision);
            let im_str = format_float(im, precision);
            val = format!("{} + {}i", re_str, im_str);
        }
        _ => panic!("{} is not a complex type", T::STR),
    }
    val
}

pub(crate) fn format_val<T: CommonBounds + Cast<f64>>(val: T, precision: usize) -> String {
    match T::STR {
        "bf16" | "f16" | "f32" | "f64" => format_float(val, precision),

        "c32" | "c64" => format_complex(val, precision),
        _ => val.to_string(),
    }
}

#[test]
fn test_complex() {
    let val = num_complex::Complex32::new(1.0, 2.0);
    let val = format_complex(val, 2);
    println!("{}", val);
}