// catgrad 0.2.1
//
// a categorical deep learning compiler
// Documentation
use crate::category::core::Dtype;
use std::fmt::{Display, Formatter, Result as FmtResult};

use super::interpreter::Value;
use super::value_types::*;

/// Displays a `Value` exactly as its payload displays; no variant tag is added.
impl Display for Value {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        // Every variant simply forwards to what it carries, so select the
        // payload once and format it in a single place.
        let payload: &dyn Display = match self {
            Value::Type(ty) => ty,
            Value::Shape(shape) => shape,
            Value::Nat(nat) => nat,
            Value::Dtype(dtype) => dtype,
            Value::Tensor(tensor) => tensor,
        };
        write!(f, "{payload}")
    }
}

/// Displays a type expression: a variable prints as `v<n>`, an array type
/// prints via its own `Display` impl.
impl Display for TypeExpr {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match self {
            // Type variables are shown with a `v` prefix.
            TypeExpr::Var(index) => write!(f, "v{}", index),
            TypeExpr::NdArrayType(nd_array) => write!(f, "{}", nd_array),
        }
    }
}

/// Displays a shape expression.
///
/// A variable prints as `v<n>`, the shape of a typed variable as
/// `shape_of(v<n>)`, and an explicit shape as a bracketed list of its
/// dimensions separated by `", "` (an empty list prints `[]`).
impl Display for ShapeExpr {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match self {
            ShapeExpr::Var(index) => write!(f, "v{}", index),
            ShapeExpr::OfType(index) => write!(f, "shape_of(v{})", index),
            ShapeExpr::Shape(dims) => {
                f.write_str("[")?;
                // Emit the first dimension bare, then prefix each later one
                // with the separator, so no trailing ", " is produced.
                let mut remaining = dims.iter();
                if let Some(first) = remaining.next() {
                    write!(f, "{first}")?;
                }
                for dim in remaining {
                    write!(f, ", {dim}")?;
                }
                f.write_str("]")
            }
        }
    }
}

/// Displays an array type from its `shape` and `dtype` fields.
///
/// When the shape is symbolic (`Var` / `OfType`) the output is
/// `<shape> : <dtype>`; when the shape is an explicit dimension list the
/// output is `<dtype>[d0, d1, ...]`.
impl Display for NdArrayType {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let dtype = &self.dtype;
        match &self.shape {
            ShapeExpr::Var(index) => write!(f, "v{index} : {dtype}"),
            ShapeExpr::OfType(index) => write!(f, "shape_of(v{index}) : {dtype}"),
            ShapeExpr::Shape(dims) => {
                // The dtype goes directly before the opening bracket.
                write!(f, "{dtype}[")?;
                let mut remaining = dims.iter();
                if let Some(first) = remaining.next() {
                    write!(f, "{first}")?;
                }
                for dim in remaining {
                    write!(f, ", {dim}")?;
                }
                f.write_str("]")
            }
        }
    }
}

/// Displays a natural-number expression.
///
/// Variables print as `v<n>` and constants as their value. A product or sum
/// prints `1` / `0` when it has no terms, the bare term when it has exactly
/// one, and otherwise a parenthesised list joined by `*` / `+` with no spaces.
impl Display for NatExpr {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        // Leaf variants return immediately; the two n-ary operators differ only
        // in their separator and in what an empty term list prints.
        let (terms, separator, when_empty) = match self {
            NatExpr::Var(index) => return write!(f, "v{index}"),
            NatExpr::Constant(value) => return write!(f, "{value}"),
            NatExpr::Mul(factors) => (factors, "*", "1"),
            NatExpr::Add(summands) => (summands, "+", "0"),
        };

        match terms.len() {
            0 => f.write_str(when_empty),
            // A lone term needs no parentheses.
            1 => write!(f, "{}", terms[0]),
            _ => {
                f.write_str("(")?;
                let mut remaining = terms.iter();
                if let Some(first) = remaining.next() {
                    write!(f, "{first}")?;
                }
                for term in remaining {
                    write!(f, "{separator}{term}")?;
                }
                f.write_str(")")
            }
        }
    }
}

/// Displays a dtype expression: a variable as `v<n>`, the dtype of a typed
/// variable as `dtype(v<n>)`, and a constant via the constant's own `Display`.
impl Display for DtypeExpr {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match self {
            DtypeExpr::Var(index) => write!(f, "v{}", index),
            DtypeExpr::OfType(index) => write!(f, "dtype(v{})", index),
            DtypeExpr::Constant(constant) => write!(f, "{}", constant),
        }
    }
}

/// Displays a concrete dtype by its lowercase name (`f32` or `u32`).
impl Display for Dtype {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let name = match self {
            Dtype::F32 => "f32",
            Dtype::U32 => "u32",
        };
        // `write_str` emits the name verbatim, as the literal `write!` did.
        f.write_str(name)
    }
}