use std::any::Any;
use std::collections::BTreeMap;
use std::fmt;
use std::rc::Rc;
use std::cell::RefCell;
use crate::aligned_pool::AlignedByteSlice;
use crate::complex;
use crate::det_map::DetMap;
use crate::gc::GcRef;
use crate::paged_kv::PagedKvCache;
use crate::scratchpad::Scratchpad;
use crate::sparse::SparseCsr;
use crate::tensor::Tensor;
/// A bfloat16 ("brain floating point") number stored as its raw 16-bit pattern.
///
/// The bits are the upper half of an IEEE-754 `f32` (sign, 8 exponent bits,
/// 7 mantissa bits), so widening back to `f32` is exact. Equality and hashing
/// compare the raw bits: `+0.0` and `-0.0` are unequal, and a NaN equals
/// another NaN with the same bit pattern.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Bf16(pub u16);

impl Bf16 {
    /// Converts an `f32` to bf16, rounding to nearest with ties to even.
    ///
    /// Finite values beyond bf16's range round to ±infinity. NaN stays NaN:
    /// the sign and upper payload bits are kept and the quiet bit is forced
    /// on, because plain truncation would turn a NaN whose payload lives only
    /// in the discarded low 16 bits into ±infinity.
    pub fn from_f32(v: f32) -> Self {
        let bits = v.to_bits();
        if v.is_nan() {
            // 0x0040 is the top mantissa bit of bf16 (the quiet-NaN bit).
            return Bf16(((bits >> 16) as u16) | 0x0040);
        }
        // Adding 0x7FFF, plus 1 when the lowest kept bit is odd, carries into
        // the kept half exactly when the discarded half is above the midpoint,
        // or at the midpoint with an odd kept half (ties to even). This cannot
        // overflow: the largest non-NaN pattern is 0xFF80_0000 (-infinity).
        let lowest_kept_bit = (bits >> 16) & 1;
        let rounded = bits + 0x7FFF + lowest_kept_bit;
        Bf16((rounded >> 16) as u16)
    }

    /// Widens to `f32` exactly by placing the bits in the upper half.
    pub fn to_f32(self) -> f32 {
        f32::from_bits((self.0 as u32) << 16)
    }

    /// Adds in `f32` precision, then rounds the result back to bf16.
    pub fn add(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() + rhs.to_f32())
    }

    /// Subtracts in `f32` precision, then rounds the result back to bf16.
    pub fn sub(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() - rhs.to_f32())
    }

    /// Multiplies in `f32` precision, then rounds the result back to bf16.
    pub fn mul(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() * rhs.to_f32())
    }

    /// Divides in `f32` precision, then rounds the result back to bf16.
    pub fn div(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() / rhs.to_f32())
    }

    /// Negates the value (via `f32`; exact, since only the sign changes).
    pub fn neg(self) -> Self {
        Self::from_f32(-self.to_f32())
    }
}
impl fmt::Display for Bf16 {
    /// Prints the value as the `f32` it widens to.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let widened = self.to_f32();
        write!(f, "{widened}")
    }
}
/// A first-class reference to a named function.
#[derive(Clone, Debug)]
pub struct FnValue {
    /// Name of the referenced function.
    pub name: String,
    /// Number of arguments the function takes.
    pub arity: usize,
    /// Identifier of the function body; presumably an index into the
    /// interpreter's body table (TODO confirm against the caller).
    pub body_id: usize,
}
/// A dynamically typed value; each variant carries its own payload.
///
/// Cloning is shallow for the `Rc`-backed variants: the clone shares the same
/// allocation, so mutations made through a `RefCell` payload (e.g. `Bytes`,
/// `Map`) are visible through every clone. Variants that hold their
/// containers directly (`Struct`, `Closure`, `Enum`, `Regex`) copy those
/// containers on clone, but any `Rc` payloads nested inside are still shared.
#[derive(Debug, Clone)]
pub enum Value {
    /// 64-bit signed integer.
    Int(i64),
    /// 64-bit floating-point number.
    Float(f64),
    /// Boolean.
    Bool(bool),
    /// Immutable string, shared between clones.
    String(Rc<String>),
    /// Mutable byte buffer, shared between clones (mutated through the `RefCell`).
    Bytes(Rc<RefCell<Vec<u8>>>),
    /// Immutable byte sequence, shared between clones; displayed as a `b"..."` literal.
    ByteSlice(Rc<Vec<u8>>),
    /// Byte sequence displayed as UTF-8 text (`<invalid utf8>` if it is not valid UTF-8).
    StrView(Rc<Vec<u8>>),
    /// Unsigned 8-bit integer.
    U8(u8),
    /// Dense tensor (`crate::tensor::Tensor`).
    Tensor(Tensor),
    /// Sparse tensor stored as a `SparseCsr`.
    SparseTensor(SparseCsr),
    /// Mutable map, shared between clones (mutated through the `RefCell`).
    Map(Rc<RefCell<DetMap>>),
    /// Immutable array, shared between clones.
    Array(Rc<Vec<Value>>),
    /// Named record. Fields live in a `BTreeMap`, so they iterate (and
    /// display) in key order rather than declaration order.
    Struct {
        /// Struct type name.
        name: String,
        /// Field name to field value.
        fields: BTreeMap<String, Value>,
    },
    /// Immutable tuple, shared between clones.
    Tuple(Rc<Vec<Value>>),
    /// Handle to an object managed by `crate::gc` (displayed as `<object@INDEX>`).
    ClassRef(GcRef),
    /// Reference to a named function.
    Fn(FnValue),
    /// Function bundled with stored values.
    Closure {
        /// Name of the underlying function.
        fn_name: String,
        /// Values stored with the closure; presumably its captured variables (confirm).
        env: Vec<Value>,
        /// Number of arguments the closure takes.
        arity: usize,
    },
    /// Value of a user-defined enum variant.
    Enum {
        /// Name of the enum type (not shown by `Display`).
        enum_name: String,
        /// Name of the variant.
        variant: String,
        /// Positional payload; empty for unit-like variants.
        fields: Vec<Value>,
    },
    /// Regular expression kept as source text (pattern plus flag string); this
    /// enum stores no compiled form.
    Regex { pattern: String, flags: String },
    /// bfloat16 number.
    Bf16(Bf16),
    /// 16-bit float from `crate::f16`.
    F16(crate::f16::F16),
    /// Complex number (`complex::ComplexF64`).
    Complex(complex::ComplexF64),
    /// Shared, mutable `Scratchpad`.
    Scratchpad(Rc<RefCell<Scratchpad>>),
    /// Shared, mutable `PagedKvCache`.
    PagedKvCache(Rc<RefCell<PagedKvCache>>),
    /// Byte buffer from `crate::aligned_pool`.
    AlignedBytes(AlignedByteSlice),
    /// Type-erased, shared, mutable handle; the concrete type is hidden behind
    /// `dyn Any` (presumably recovered by downcasting in the owning module).
    GradGraph(Rc<RefCell<dyn Any>>),
    /// Type-erased, shared, mutable handle (`dyn Any`).
    OptimizerState(Rc<RefCell<dyn Any>>),
    /// Type-erased, shared, immutable handle (`dyn Any`).
    TidyView(Rc<dyn Any>),
    /// Type-erased, shared, immutable handle (`dyn Any`).
    GroupedTidyView(Rc<dyn Any>),
    /// Type-erased, shared, immutable handle (`dyn Any`).
    VizorPlot(Rc<dyn Any>),
    /// Type-erased, shared, mutable handle (`dyn Any`).
    QuantumState(Rc<RefCell<dyn Any>>),
    /// Missing value; displayed as `NA`.
    Na,
    /// Absence of a value; displayed as `void`.
    Void,
}
impl Value {
    /// Returns the user-facing name of this value's runtime type.
    ///
    /// The name is a `'static` string literal, so it may outlive the value
    /// (e.g. be kept for an error message after the value is moved or
    /// dropped). Every name matches its variant except `U8`, which reports the
    /// lowercase `"u8"`.
    #[must_use]
    pub fn type_name(&self) -> &'static str {
        match self {
            Value::Int(_) => "Int",
            Value::Float(_) => "Float",
            Value::Bool(_) => "Bool",
            Value::String(_) => "String",
            Value::Bytes(_) => "Bytes",
            Value::ByteSlice(_) => "ByteSlice",
            Value::StrView(_) => "StrView",
            Value::U8(_) => "u8",
            Value::Tensor(_) => "Tensor",
            Value::SparseTensor(_) => "SparseTensor",
            Value::Map(_) => "Map",
            Value::Array(_) => "Array",
            Value::Struct { .. } => "Struct",
            Value::Tuple(_) => "Tuple",
            Value::ClassRef(_) => "ClassRef",
            Value::Fn(_) => "Fn",
            Value::Closure { .. } => "Closure",
            Value::Enum { .. } => "Enum",
            Value::Regex { .. } => "Regex",
            Value::Bf16(_) => "Bf16",
            Value::F16(_) => "F16",
            Value::Complex(_) => "Complex",
            Value::Scratchpad(_) => "Scratchpad",
            Value::PagedKvCache(_) => "PagedKvCache",
            Value::AlignedBytes(_) => "AlignedBytes",
            Value::GradGraph(_) => "GradGraph",
            Value::OptimizerState(_) => "OptimizerState",
            Value::TidyView(_) => "TidyView",
            Value::GroupedTidyView(_) => "GroupedTidyView",
            Value::VizorPlot(_) => "VizorPlot",
            Value::QuantumState(_) => "QuantumState",
            Value::Na => "Na",
            Value::Void => "Void",
        }
    }
}
impl fmt::Display for Value {
    /// Formats the value in its user-facing textual form.
    ///
    /// - Scalars and strings print bare; strings are not quoted.
    /// - Sequences print as `[a, b]` / `(a, b)`.
    /// - `ByteSlice` prints as an escaped `b"..."` literal.
    /// - Type-erased handles print as `<TypeName>` placeholders.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Value::Int(v) => write!(f, "{v}"),
            Value::Float(v) => write!(f, "{v}"),
            Value::Bool(v) => write!(f, "{v}"),
            Value::String(v) => write!(f, "{v}"),
            Value::Bytes(b) => {
                f.write_str("Bytes([")?;
                write_comma_separated(f, b.borrow().iter())?;
                f.write_str("])")
            }
            Value::ByteSlice(b) => write_byte_string_literal(f, b),
            Value::StrView(b) => {
                let s = std::str::from_utf8(b).unwrap_or("<invalid utf8>");
                write!(f, "{s}")
            }
            Value::U8(v) => write!(f, "{v}"),
            Value::Tensor(t) => write!(f, "{t}"),
            Value::SparseTensor(s) => {
                write!(f, "SparseTensor({}x{}, nnz={})", s.nrows, s.ncols, s.nnz())
            }
            Value::Map(m) => {
                let m = m.borrow();
                f.write_str("Map({")?;
                for (i, (k, v)) in m.iter().enumerate() {
                    if i > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{k}: {v}")?;
                }
                f.write_str("})")
            }
            Value::Array(items) => {
                f.write_str("[")?;
                write_comma_separated(f, items.iter())?;
                f.write_str("]")
            }
            Value::Tuple(items) => {
                f.write_str("(")?;
                write_comma_separated(f, items.iter())?;
                f.write_str(")")
            }
            Value::Struct { name, fields } => {
                write!(f, "{name} {{ ")?;
                for (i, (k, v)) in fields.iter().enumerate() {
                    if i > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{k}: {v}")?;
                }
                f.write_str(" }")
            }
            Value::Enum { variant, fields, .. } => {
                write!(f, "{variant}")?;
                // Unit-like variants print just the variant name.
                if !fields.is_empty() {
                    f.write_str("(")?;
                    write_comma_separated(f, fields.iter())?;
                    f.write_str(")")?;
                }
                Ok(())
            }
            // An empty flag string contributes nothing, giving `/pattern/`.
            Value::Regex { pattern, flags } => write!(f, "/{pattern}/{flags}"),
            Value::Bf16(v) => write!(f, "{}", v.to_f32()),
            Value::F16(v) => write!(f, "{}", v.to_f64()),
            Value::Complex(z) => write!(f, "{z}"),
            Value::ClassRef(r) => write!(f, "<object@{}>", r.index),
            Value::Fn(fv) => write!(f, "<fn {}({})>", fv.name, fv.arity),
            Value::Closure { fn_name, arity, .. } => write!(f, "<closure {fn_name}({arity})>"),
            Value::Scratchpad(s) => write!(f, "{}", s.borrow()),
            Value::PagedKvCache(c) => write!(f, "{}", c.borrow()),
            Value::AlignedBytes(a) => write!(f, "{a}"),
            Value::GradGraph(_) => f.write_str("<GradGraph>"),
            Value::OptimizerState(_) => f.write_str("<OptimizerState>"),
            Value::TidyView(_) => f.write_str("<TidyView>"),
            Value::GroupedTidyView(_) => f.write_str("<GroupedTidyView>"),
            Value::VizorPlot(_) => f.write_str("<VizorPlot>"),
            Value::QuantumState(_) => f.write_str("<QuantumState>"),
            Value::Na => f.write_str("NA"),
            Value::Void => f.write_str("void"),
        }
    }
}

/// Writes `items` separated by `", "`, with no surrounding delimiters.
fn write_comma_separated<I>(f: &mut fmt::Formatter<'_>, items: I) -> fmt::Result
where
    I: IntoIterator,
    I::Item: fmt::Display,
{
    for (i, item) in items.into_iter().enumerate() {
        if i > 0 {
            f.write_str(", ")?;
        }
        write!(f, "{item}")?;
    }
    Ok(())
}

/// Writes `bytes` as an unambiguous `b"..."` literal.
///
/// ASCII bytes go through [`write_escaped_ascii_byte`] in both cases. When
/// `bytes` is valid UTF-8, non-ASCII characters are kept as text. Otherwise
/// every non-ASCII byte is written as `\xNN`.
fn write_byte_string_literal(f: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result {
    f.write_str("b\"")?;
    match std::str::from_utf8(bytes) {
        Ok(text) => {
            for ch in text.chars() {
                if ch.is_ascii() {
                    write_escaped_ascii_byte(f, ch as u8)?;
                } else {
                    write!(f, "{ch}")?;
                }
            }
        }
        Err(_) => {
            for &byte in bytes {
                write_escaped_ascii_byte(f, byte)?;
            }
        }
    }
    f.write_str("\"")
}

/// Writes one byte of a byte-string literal.
///
/// `"` and `\` are backslash-escaped. Other printable ASCII and space are
/// written as-is. Everything else (control characters, bytes >= 0x80)
/// becomes `\xNN`.
fn write_escaped_ascii_byte(f: &mut fmt::Formatter<'_>, byte: u8) -> fmt::Result {
    match byte {
        b'"' => f.write_str("\\\""),
        b'\\' => f.write_str("\\\\"),
        _ if byte.is_ascii_graphic() || byte == b' ' => write!(f, "{}", byte as char),
        _ => write!(f, "\\x{byte:02x}"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;
    use std::rc::Rc;

    #[test]
    fn int_display() {
        assert_eq!(Value::Int(42).to_string(), "42");
        assert_eq!(Value::Int(-1).to_string(), "-1");
    }

    #[test]
    fn float_display() {
        // Avoid literals like 3.14, which trip clippy's deny-by-default
        // `approx_constant` lint.
        assert_eq!(Value::Float(2.5).to_string(), "2.5");
        assert_eq!(Value::Float(2.0).to_string(), "2");
    }

    #[test]
    fn bool_display() {
        assert_eq!(Value::Bool(true).to_string(), "true");
        assert_eq!(Value::Bool(false).to_string(), "false");
    }

    #[test]
    fn string_display() {
        let v = Value::String(Rc::new("hello".to_string()));
        assert_eq!(v.to_string(), "hello");
    }

    #[test]
    fn void_display() {
        assert_eq!(Value::Void.to_string(), "void");
    }

    #[test]
    fn na_display() {
        assert_eq!(Value::Na.to_string(), "NA");
    }

    #[test]
    fn bytes_display() {
        let v = Value::Bytes(Rc::new(RefCell::new(vec![1, 2, 255])));
        assert_eq!(v.to_string(), "Bytes([1, 2, 255])");
        let empty = Value::Bytes(Rc::new(RefCell::new(Vec::new())));
        assert_eq!(empty.to_string(), "Bytes([])");
    }

    #[test]
    fn byte_slice_display() {
        let text = Value::ByteSlice(Rc::new(b"abc".to_vec()));
        assert_eq!(text.to_string(), "b\"abc\"");
        let invalid = Value::ByteSlice(Rc::new(vec![0xff, b'a']));
        assert_eq!(invalid.to_string(), "b\"\\xffa\"");
    }

    #[test]
    fn str_view_display() {
        assert_eq!(Value::StrView(Rc::new(b"hi".to_vec())).to_string(), "hi");
        assert_eq!(Value::StrView(Rc::new(vec![0xff])).to_string(), "<invalid utf8>");
    }

    #[test]
    fn bf16_display() {
        // 0x3FC0 is the upper half of 1.5f32 (0x3FC0_0000).
        assert_eq!(Value::Bf16(Bf16(0x3FC0)).to_string(), "1.5");
    }

    #[test]
    fn type_name_coverage() {
        assert_eq!(Value::Int(0).type_name(), "Int");
        assert_eq!(Value::Float(0.0).type_name(), "Float");
        assert_eq!(Value::Bool(true).type_name(), "Bool");
        assert_eq!(Value::String(Rc::new(String::new())).type_name(), "String");
        assert_eq!(Value::U8(0).type_name(), "u8");
        assert_eq!(Value::Array(Rc::new(Vec::new())).type_name(), "Array");
        assert_eq!(Value::Tuple(Rc::new(Vec::new())).type_name(), "Tuple");
        assert_eq!(Value::Na.type_name(), "Na");
        assert_eq!(Value::Void.type_name(), "Void");
    }

    #[test]
    fn tuple_display() {
        let t = Value::Tuple(Rc::new(vec![Value::Int(1), Value::Bool(true)]));
        assert_eq!(t.to_string(), "(1, true)");
    }

    #[test]
    fn array_display() {
        let a = Value::Array(Rc::new(vec![Value::Int(10), Value::Int(20)]));
        assert_eq!(a.to_string(), "[10, 20]");
        assert_eq!(Value::Array(Rc::new(Vec::new())).to_string(), "[]");
    }

    #[test]
    fn nested_display() {
        let nested = Value::Array(Rc::new(vec![
            Value::Tuple(Rc::new(vec![Value::Int(1)])),
            Value::String(Rc::new("x".to_string())),
        ]));
        assert_eq!(nested.to_string(), "[(1), x]");
    }

    #[test]
    fn struct_value_display() {
        let mut fields = std::collections::BTreeMap::new();
        fields.insert("y".to_string(), Value::Int(2));
        fields.insert("x".to_string(), Value::Int(1));
        let sv = Value::Struct {
            name: "Point".to_string(),
            fields,
        };
        // Fields are printed in key order, not insertion order.
        assert_eq!(sv.to_string(), "Point { x: 1, y: 2 }");
    }

    #[test]
    fn enum_value_display() {
        let some = Value::Enum {
            enum_name: "Option".to_string(),
            variant: "Some".to_string(),
            fields: vec![Value::Int(42)],
        };
        assert_eq!(some.to_string(), "Some(42)");
        let none = Value::Enum {
            enum_name: "Option".to_string(),
            variant: "None".to_string(),
            fields: Vec::new(),
        };
        assert_eq!(none.to_string(), "None");
    }

    #[test]
    fn regex_display() {
        let with_flags = Value::Regex {
            pattern: "a+".to_string(),
            flags: "i".to_string(),
        };
        assert_eq!(with_flags.to_string(), "/a+/i");
        let no_flags = Value::Regex {
            pattern: "a+".to_string(),
            flags: String::new(),
        };
        assert_eq!(no_flags.to_string(), "/a+/");
    }

    #[test]
    fn callable_display() {
        let f = Value::Fn(FnValue {
            name: "add".to_string(),
            arity: 2,
            body_id: 0,
        });
        assert_eq!(f.to_string(), "<fn add(2)>");
        let c = Value::Closure {
            fn_name: "inc".to_string(),
            env: Vec::new(),
            arity: 1,
        };
        assert_eq!(c.to_string(), "<closure inc(1)>");
    }

    #[test]
    fn map_display() {
        let m = Value::Map(Rc::new(RefCell::new(crate::det_map::DetMap::new())));
        assert_eq!(m.to_string(), "Map({})");
    }
}