1use std::any::Any;
2use std::collections::BTreeMap;
3use std::fmt;
4use std::rc::Rc;
5use std::cell::RefCell;
6
7use crate::aligned_pool::AlignedByteSlice;
8use crate::complex;
9use crate::det_map::DetMap;
10use crate::gc::GcRef;
11use crate::paged_kv::PagedKvCache;
12use crate::scratchpad::Scratchpad;
13use crate::sparse::SparseCsr;
14use crate::tensor::Tensor;
15
/// A bfloat16 value stored as its raw 16-bit pattern.
///
/// The bit layout is the upper half of an IEEE-754 `f32`: 1 sign bit,
/// 8 exponent bits and 7 mantissa bits. Arithmetic widens both operands to
/// `f32`, computes there, and narrows the result back with [`Bf16::from_f32`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Bf16(pub u16);

impl Bf16 {
    /// Narrows an `f32` to bf16 by keeping its upper 16 bits.
    ///
    /// Finite values are truncated, i.e. rounded toward zero. NaN inputs always
    /// yield a quiet NaN with the input's sign. Plain truncation would map a NaN
    /// whose payload lives only in the low 16 mantissa bits (for example
    /// `f32::from_bits(0x7F80_0001)`) to an all-zero mantissa, which is infinity.
    pub fn from_f32(v: f32) -> Self {
        // Most significant bf16 mantissa bit; setting it guarantees a NaN pattern.
        const QUIET_NAN_BIT: u16 = 0x0040;
        // The shift leaves at most 16 significant bits, so this cast cannot lose data.
        let high = (v.to_bits() >> 16) as u16;
        if v.is_nan() {
            Bf16(high | QUIET_NAN_BIT)
        } else {
            Bf16(high)
        }
    }

    /// Widens to `f32`. This is exact: the low 16 mantissa bits become zero.
    pub fn to_f32(self) -> f32 {
        f32::from_bits(u32::from(self.0) << 16)
    }

    /// Sum, computed in `f32` and narrowed back to bf16.
    pub fn add(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() + rhs.to_f32())
    }

    /// Difference, computed in `f32` and narrowed back to bf16.
    pub fn sub(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() - rhs.to_f32())
    }

    /// Product, computed in `f32` and narrowed back to bf16.
    pub fn mul(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() * rhs.to_f32())
    }

    /// Quotient, computed in `f32` and narrowed back to bf16.
    pub fn div(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() / rhs.to_f32())
    }

    /// Negation, computed in `f32` and narrowed back to bf16.
    pub fn neg(self) -> Self {
        Self::from_f32(-self.to_f32())
    }
}

impl fmt::Display for Bf16 {
    /// Formats the widened `f32` value.
    ///
    /// Forwarding the caller's `Formatter` means width/precision flags such as
    /// `{:.2}` or `{:>6}` are honored. Plain `{}` output is the same as before.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.to_f32(), f)
    }
}
62
/// A function value.
///
/// Unlike `Value::Closure`, it carries no `env` of captured values.
/// `Value`'s `Display` renders it as `<fn name(arity)>`.
#[derive(Debug, Clone)]
pub struct FnValue {
    /// The function's name.
    pub name: String,
    /// Declared parameter count, as shown in the `<fn name(arity)>` rendering.
    pub arity: usize,
    /// Identifier of the function body.
    /// NOTE(review): presumably an index into the interpreter's body table — confirm.
    pub body_id: usize,
}
73
/// A runtime value.
///
/// Variants whose payload is wrapped in `Rc` share that payload. For them, the
/// derived `Clone` copies the handle rather than the data. The `Rc<RefCell<..>>`
/// variants are shared *and* mutable through every clone.
#[derive(Debug, Clone)]
pub enum Value {
    /// 64-bit signed integer.
    Int(i64),
    /// 64-bit floating-point number.
    Float(f64),
    /// Boolean.
    Bool(bool),
    /// Reference-counted, immutable string.
    String(Rc<String>),
    /// Shared, mutable byte buffer; displayed as `Bytes([b0, b1, ...])`.
    Bytes(Rc<RefCell<Vec<u8>>>),
    /// Shared, immutable bytes; displayed as a `b"..."` literal.
    ByteSlice(Rc<Vec<u8>>),
    /// Shared bytes displayed as UTF-8 text (`<invalid utf8>` when they are not valid UTF-8).
    StrView(Rc<Vec<u8>>),
    /// Unsigned byte; `type_name` reports it as lowercase `"u8"`.
    U8(u8),
    /// Dense tensor.
    Tensor(Tensor),
    /// Sparse tensor (`SparseCsr`); displayed with its shape and non-zero count.
    SparseTensor(SparseCsr),
    /// Shared, mutable map (`DetMap`).
    Map(Rc<RefCell<DetMap>>),
    /// Reference-counted, immutable array of values.
    Array(Rc<Vec<Value>>),
    /// Named record value.
    Struct {
        /// Struct type name.
        name: String,
        /// Field values keyed by field name; `BTreeMap` keeps them sorted by name.
        fields: BTreeMap<String, Value>,
    },
    /// Reference-counted, immutable tuple of values.
    Tuple(Rc<Vec<Value>>),
    /// Handle to a GC-managed object; displayed as `<object@index>`.
    ClassRef(GcRef),
    /// Function value without captured environment.
    Fn(FnValue),
    /// Function paired with an environment of values.
    Closure {
        /// Name of the underlying function.
        fn_name: String,
        /// NOTE(review): presumably the closure's captured values — confirm layout/order.
        env: Vec<Value>,
        /// Declared parameter count.
        arity: usize,
    },
    /// Enum variant value.
    Enum {
        /// Name of the enum type (not included in the `Display` output).
        enum_name: String,
        /// Variant name.
        variant: String,
        /// Positional payload values (empty for unit-like variants).
        fields: Vec<Value>,
    },
    /// Regular-expression source and flags; displayed as `/pattern/flags`.
    Regex { pattern: String, flags: String },
    /// bfloat16 scalar.
    Bf16(Bf16),
    /// Half-precision scalar.
    F16(crate::f16::F16),
    /// Complex number with `f64` components (per the `ComplexF64` type name).
    Complex(complex::ComplexF64),
    /// Shared, mutable scratchpad.
    Scratchpad(Rc<RefCell<Scratchpad>>),
    /// Shared, mutable paged KV cache.
    PagedKvCache(Rc<RefCell<PagedKvCache>>),
    /// Aligned byte slice.
    AlignedBytes(AlignedByteSlice),
    /// Type-erased, shared, mutable state; displayed as `<GradGraph>`.
    /// NOTE(review): presumably downcast by the autograd module — confirm.
    GradGraph(Rc<RefCell<dyn Any>>),
    /// Type-erased, shared, mutable state; displayed as `<OptimizerState>`.
    OptimizerState(Rc<RefCell<dyn Any>>),
    /// Type-erased, shared, immutable view; displayed as `<TidyView>`.
    TidyView(Rc<dyn Any>),
    /// Type-erased, shared, immutable grouped view; displayed as `<GroupedTidyView>`.
    GroupedTidyView(Rc<dyn Any>),
    /// Type-erased, shared, immutable plot handle; displayed as `<VizorPlot>`.
    VizorPlot(Rc<dyn Any>),
    /// Absence of a value; displayed as `void`.
    Void,
}
166
167impl Value {
168 pub fn type_name(&self) -> &str {
170 match self {
171 Value::Int(_) => "Int",
172 Value::Float(_) => "Float",
173 Value::Bool(_) => "Bool",
174 Value::String(_) => "String",
175 Value::Bytes(_) => "Bytes",
176 Value::ByteSlice(_) => "ByteSlice",
177 Value::StrView(_) => "StrView",
178 Value::U8(_) => "u8",
179 Value::Tensor(_) => "Tensor",
180 Value::SparseTensor(_) => "SparseTensor",
181 Value::Map(_) => "Map",
182 Value::Array(_) => "Array",
183 Value::Tuple(_) => "Tuple",
184 Value::Struct { .. } => "Struct",
185 Value::Enum { .. } => "Enum",
186 Value::ClassRef(_) => "ClassRef",
187 Value::Fn(_) => "Fn",
188 Value::Closure { .. } => "Closure",
189 Value::Regex { .. } => "Regex",
190 Value::Bf16(_) => "Bf16",
191 Value::F16(_) => "F16",
192 Value::Complex(_) => "Complex",
193 Value::Scratchpad(_) => "Scratchpad",
194 Value::PagedKvCache(_) => "PagedKvCache",
195 Value::AlignedBytes(_) => "AlignedBytes",
196 Value::GradGraph(_) => "GradGraph",
197 Value::OptimizerState(_) => "OptimizerState",
198 Value::TidyView(_) => "TidyView",
199 Value::GroupedTidyView(_) => "GroupedTidyView",
200 Value::VizorPlot(_) => "VizorPlot",
201 Value::Void => "Void",
202 }
203 }
204}
205
206impl fmt::Display for Value {
207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 match self {
209 Value::Int(v) => write!(f, "{v}"),
210 Value::Float(v) => write!(f, "{v}"),
211 Value::Bool(v) => write!(f, "{v}"),
212 Value::String(v) => write!(f, "{v}"),
213 Value::Bytes(b) => {
214 let b = b.borrow();
215 write!(f, "Bytes([")?;
216 for (i, byte) in b.iter().enumerate() {
217 if i > 0 { write!(f, ", ")?; }
218 write!(f, "{byte}")?;
219 }
220 write!(f, "])")
221 }
222 Value::ByteSlice(b) => {
223 match std::str::from_utf8(b) {
225 Ok(s) => write!(f, "b\"{s}\""),
226 Err(_) => {
227 write!(f, "b\"")?;
228 for &byte in b.iter() {
229 if byte.is_ascii_graphic() || byte == b' ' {
230 write!(f, "{}", byte as char)?;
231 } else {
232 write!(f, "\\x{byte:02x}")?;
233 }
234 }
235 write!(f, "\"")
236 }
237 }
238 }
239 Value::StrView(b) => {
240 let s = std::str::from_utf8(b).unwrap_or("<invalid utf8>");
242 write!(f, "{s}")
243 }
244 Value::U8(v) => write!(f, "{v}"),
245 Value::Tensor(t) => write!(f, "{t}"),
246 Value::SparseTensor(s) => write!(f, "SparseTensor({}x{}, nnz={})", s.nrows, s.ncols, s.nnz()),
247 Value::Map(m) => {
248 let m = m.borrow();
249 write!(f, "Map({{")?;
250 for (i, (k, v)) in m.iter().enumerate() {
251 if i > 0 {
252 write!(f, ", ")?;
253 }
254 write!(f, "{k}: {v}")?;
255 }
256 write!(f, "}})")
257 }
258 Value::Array(arr) => {
259 write!(f, "[")?;
260 for (i, v) in arr.iter().enumerate() {
261 if i > 0 {
262 write!(f, ", ")?;
263 }
264 write!(f, "{v}")?;
265 }
266 write!(f, "]")
267 }
268 Value::Tuple(elems) => {
269 write!(f, "(")?;
270 for (i, v) in elems.iter().enumerate() {
271 if i > 0 {
272 write!(f, ", ")?;
273 }
274 write!(f, "{v}")?;
275 }
276 write!(f, ")")
277 }
278 Value::Struct { name, fields } => {
279 write!(f, "{name} {{ ")?;
280 for (i, (k, v)) in fields.iter().enumerate() {
281 if i > 0 {
282 write!(f, ", ")?;
283 }
284 write!(f, "{k}: {v}")?;
285 }
286 write!(f, " }}")
287 }
288 Value::Enum {
289 enum_name: _,
290 variant,
291 fields,
292 } => {
293 write!(f, "{variant}")?;
294 if !fields.is_empty() {
295 write!(f, "(")?;
296 for (i, v) in fields.iter().enumerate() {
297 if i > 0 {
298 write!(f, ", ")?;
299 }
300 write!(f, "{v}")?;
301 }
302 write!(f, ")")?;
303 }
304 Ok(())
305 }
306 Value::Regex { pattern, flags } => {
307 write!(f, "/{pattern}/")?;
308 if !flags.is_empty() {
309 write!(f, "{flags}")?;
310 }
311 Ok(())
312 }
313 Value::Bf16(v) => write!(f, "{}", v.to_f32()),
314 Value::F16(v) => write!(f, "{}", v.to_f64()),
315 Value::Complex(z) => write!(f, "{z}"),
316 Value::ClassRef(r) => write!(f, "<object@{}>", r.index),
317 Value::Fn(fv) => write!(f, "<fn {}({})>", fv.name, fv.arity),
318 Value::Closure {
319 fn_name, arity, ..
320 } => write!(f, "<closure {}({})>", fn_name, arity),
321 Value::Scratchpad(s) => write!(f, "{}", s.borrow()),
322 Value::PagedKvCache(c) => write!(f, "{}", c.borrow()),
323 Value::AlignedBytes(a) => write!(f, "{}", a),
324 Value::GradGraph(_) => write!(f, "<GradGraph>"),
325 Value::OptimizerState(_) => write!(f, "<OptimizerState>"),
326 Value::TidyView(_) => write!(f, "<TidyView>"),
327 Value::GroupedTidyView(_) => write!(f, "<GroupedTidyView>"),
328 Value::VizorPlot(_) => write!(f, "<VizorPlot>"),
329 Value::Void => write!(f, "void"),
330 }
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use std::rc::Rc;

    #[test]
    fn int_display() {
        assert_eq!(format!("{}", Value::Int(42)), "42");
        assert_eq!(format!("{}", Value::Int(-1)), "-1");
    }

    #[test]
    fn float_display() {
        // Exactly representable values give a stable rendering. Using 2.5 also
        // avoids clippy's deny-by-default `approx_constant` lint that `3.14` (~PI) hits.
        assert_eq!(format!("{}", Value::Float(2.5)), "2.5");
        assert_eq!(format!("{}", Value::Float(-0.5)), "-0.5");
    }

    #[test]
    fn bool_display() {
        assert_eq!(format!("{}", Value::Bool(true)), "true");
        assert_eq!(format!("{}", Value::Bool(false)), "false");
    }

    #[test]
    fn string_display() {
        let v = Value::String(Rc::new("hello".to_string()));
        assert_eq!(format!("{v}"), "hello");
    }

    #[test]
    fn void_display() {
        assert_eq!(format!("{}", Value::Void), "void");
    }

    #[test]
    fn u8_and_bf16_display() {
        assert_eq!(format!("{}", Value::U8(7)), "7");
        assert_eq!(format!("{}", Value::Bf16(Bf16::from_f32(1.5))), "1.5");
    }

    #[test]
    fn type_name_coverage() {
        assert_eq!(Value::Int(0).type_name(), "Int");
        assert_eq!(Value::Float(0.0).type_name(), "Float");
        assert_eq!(Value::Bool(true).type_name(), "Bool");
        assert_eq!(Value::String(Rc::new(String::new())).type_name(), "String");
        assert_eq!(Value::U8(0).type_name(), "u8");
        assert_eq!(
            Value::Bytes(Rc::new(std::cell::RefCell::new(Vec::new()))).type_name(),
            "Bytes"
        );
        assert_eq!(Value::Array(Rc::new(Vec::new())).type_name(), "Array");
        assert_eq!(Value::Tuple(Rc::new(Vec::new())).type_name(), "Tuple");
        assert_eq!(
            Value::Regex { pattern: String::new(), flags: String::new() }.type_name(),
            "Regex"
        );
        assert_eq!(Value::Bf16(Bf16(0)).type_name(), "Bf16");
        assert_eq!(Value::Void.type_name(), "Void");
    }

    #[test]
    fn tuple_display() {
        let t = Value::Tuple(Rc::new(vec![Value::Int(1), Value::Bool(true)]));
        assert_eq!(format!("{t}"), "(1, true)");
    }

    #[test]
    fn array_display() {
        let a = Value::Array(Rc::new(vec![Value::Int(10), Value::Int(20)]));
        assert_eq!(format!("{a}"), "[10, 20]");
        assert_eq!(format!("{}", Value::Array(Rc::new(Vec::new()))), "[]");
    }

    #[test]
    fn nested_display() {
        let inner = Value::Array(Rc::new(vec![Value::Int(1)]));
        let outer = Value::Array(Rc::new(vec![inner, Value::Void]));
        assert_eq!(format!("{outer}"), "[[1], void]");
    }

    #[test]
    fn struct_value_display() {
        let mut fields = std::collections::BTreeMap::new();
        // Inserted out of order: BTreeMap sorts fields by name for display.
        fields.insert("y".to_string(), Value::Int(2));
        fields.insert("x".to_string(), Value::Int(1));
        let sv = Value::Struct {
            name: "Point".to_string(),
            fields,
        };
        assert_eq!(format!("{sv}"), "Point { x: 1, y: 2 }");
    }

    #[test]
    fn enum_value_display() {
        let some = Value::Enum {
            enum_name: "Option".to_string(),
            variant: "Some".to_string(),
            fields: vec![Value::Int(42)],
        };
        assert_eq!(format!("{some}"), "Some(42)");

        let none = Value::Enum {
            enum_name: "Option".to_string(),
            variant: "None".to_string(),
            fields: Vec::new(),
        };
        assert_eq!(format!("{none}"), "None");
    }

    #[test]
    fn map_display() {
        let m = Value::Map(Rc::new(std::cell::RefCell::new(crate::det_map::DetMap::new())));
        assert_eq!(format!("{m}"), "Map({})");
    }

    #[test]
    fn bytes_display() {
        let b = Value::Bytes(Rc::new(std::cell::RefCell::new(vec![1u8, 2, 255])));
        assert_eq!(format!("{b}"), "Bytes([1, 2, 255])");
    }

    #[test]
    fn byte_slice_display() {
        let valid = Value::ByteSlice(Rc::new(b"hi there".to_vec()));
        assert_eq!(format!("{valid}"), "b\"hi there\"");

        let invalid = Value::ByteSlice(Rc::new(vec![0xff, b'a', b'\n']));
        assert_eq!(format!("{invalid}"), "b\"\\xffa\\x0a\"");
    }

    #[test]
    fn str_view_display() {
        assert_eq!(format!("{}", Value::StrView(Rc::new(b"abc".to_vec()))), "abc");
        assert_eq!(format!("{}", Value::StrView(Rc::new(vec![0xff]))), "<invalid utf8>");
    }

    #[test]
    fn regex_display() {
        let with_flags = Value::Regex { pattern: "a+".to_string(), flags: "i".to_string() };
        assert_eq!(format!("{with_flags}"), "/a+/i");
        let no_flags = Value::Regex { pattern: "a+".to_string(), flags: String::new() };
        assert_eq!(format!("{no_flags}"), "/a+/");
    }

    #[test]
    fn callable_display() {
        let func = Value::Fn(FnValue {
            name: "add".to_string(),
            arity: 2,
            body_id: 0,
        });
        assert_eq!(format!("{func}"), "<fn add(2)>");

        let closure = Value::Closure {
            fn_name: "inc".to_string(),
            env: vec![Value::Int(1)],
            arity: 1,
        };
        assert_eq!(format!("{closure}"), "<closure inc(1)>");
    }
}
430