1use std::any::Any;
2use std::collections::BTreeMap;
3use std::fmt;
4use std::rc::Rc;
5use std::cell::RefCell;
6
7use crate::aligned_pool::AlignedByteSlice;
8use crate::complex;
9use crate::det_map::DetMap;
10use crate::gc::GcRef;
11use crate::paged_kv::PagedKvCache;
12use crate::scratchpad::Scratchpad;
13use crate::sparse::SparseCsr;
14use crate::tensor::Tensor;
15
/// A bfloat16 number stored as its raw 16-bit pattern.
///
/// The pattern is the upper half of an IEEE-754 `f32`: 1 sign bit, 8 exponent
/// bits and 7 mantissa bits. `PartialEq`/`Eq`/`Hash` are derived on the raw
/// bits, so e.g. `+0.0` and `-0.0` compare unequal and identical NaN patterns
/// compare equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Bf16(pub u16);

impl Bf16 {
    /// Quiet-NaN bit (the most significant mantissa bit) in the bf16 layout.
    const QUIET_NAN_BIT: u16 = 0x0040;

    /// Converts an `f32` to bf16 by keeping its upper 16 bits (round toward zero).
    ///
    /// NaN inputs always yield a quiet NaN with the input's sign. Without this,
    /// a NaN whose payload lives only in the discarded low 16 bits would
    /// truncate to an infinity.
    pub fn from_f32(v: f32) -> Self {
        // The shift leaves at most 16 significant bits, so the cast is lossless.
        let upper = (v.to_bits() >> 16) as u16;
        if v.is_nan() {
            // e.g. 0x7F80_0001 would otherwise become 0x7F80 (= +inf).
            Bf16(upper | Self::QUIET_NAN_BIT)
        } else {
            Bf16(upper)
        }
    }

    /// Widens to `f32` exactly by placing the 16 bits in the upper half.
    pub fn to_f32(self) -> f32 {
        f32::from_bits((self.0 as u32) << 16)
    }

    /// Adds in `f32` precision, then converts the result back with [`Bf16::from_f32`].
    pub fn add(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() + rhs.to_f32())
    }

    /// Subtracts in `f32` precision, then converts the result back with [`Bf16::from_f32`].
    pub fn sub(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() - rhs.to_f32())
    }

    /// Multiplies in `f32` precision, then converts the result back with [`Bf16::from_f32`].
    pub fn mul(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() * rhs.to_f32())
    }

    /// Divides in `f32` precision, then converts the result back with [`Bf16::from_f32`].
    pub fn div(self, rhs: Self) -> Self {
        Self::from_f32(self.to_f32() / rhs.to_f32())
    }

    /// Negates the value (computed via `f32`).
    pub fn neg(self) -> Self {
        Self::from_f32(-self.to_f32())
    }
}
56
57impl fmt::Display for Bf16 {
58 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59 write!(f, "{}", self.to_f32())
60 }
61}
62
/// A named function value, as carried by [`Value::Fn`].
#[derive(Debug, Clone)]
pub struct FnValue {
    /// Function name; `Value`'s `Display` renders it as `<fn name(arity)>`.
    pub name: String,
    /// Arity printed in the `Display` output (presumably the declared
    /// parameter count — confirm against the interpreter).
    pub arity: usize,
    /// Identifier of the function body. Not read anywhere in this file;
    /// presumably an index into a body table owned elsewhere — TODO confirm.
    pub body_id: usize,
}
73
/// A runtime value.
///
/// The derived `Clone` copies `Rc` pointers rather than their payloads, so
/// clones of `Rc`-backed variants share data. For the `Rc<RefCell<..>>`
/// variants (`Bytes`, `Map`, `Scratchpad`, `PagedKvCache`, ...), a mutation
/// through one clone is visible through the others.
#[derive(Debug, Clone)]
pub enum Value {
    /// Signed 64-bit integer.
    Int(i64),
    /// 64-bit floating-point number.
    Float(f64),
    /// Boolean.
    Bool(bool),
    /// Immutable, reference-counted string.
    String(Rc<String>),
    /// Shared, mutable byte buffer (interior mutability via `RefCell`);
    /// displayed as `Bytes([b0, b1, ...])`.
    Bytes(Rc<RefCell<Vec<u8>>>),
    /// Shared, immutable bytes; displayed as a `b"..."` literal.
    ByteSlice(Rc<Vec<u8>>),
    /// Shared, immutable bytes displayed as UTF-8 text, or as
    /// `<invalid utf8>` when they are not valid UTF-8.
    StrView(Rc<Vec<u8>>),
    /// Unsigned 8-bit integer (its `type_name` is lowercase `"u8"`).
    U8(u8),
    /// Dense tensor (`crate::tensor::Tensor`).
    Tensor(Tensor),
    /// Sparse tensor stored as `SparseCsr`; displayed as
    /// `SparseTensor(<nrows>x<ncols>, nnz=<n>)`.
    SparseTensor(SparseCsr),
    /// Shared, mutable map (`DetMap`).
    Map(Rc<RefCell<DetMap>>),
    /// Immutable, reference-counted array of values.
    Array(Rc<Vec<Value>>),
    /// Named struct instance. Fields live in a `BTreeMap`, so they iterate
    /// (and display) in field-name order.
    Struct {
        name: String,
        fields: BTreeMap<String, Value>,
    },
    /// Immutable, reference-counted tuple of values.
    Tuple(Rc<Vec<Value>>),
    /// Object handle (`crate::gc::GcRef`); displayed as `<object@index>`.
    ClassRef(GcRef),
    /// Named function (see [`FnValue`]).
    Fn(FnValue),
    /// Closure over the function `fn_name`. `env` presumably holds captured
    /// values (not read in this file — confirm).
    Closure {
        fn_name: String,
        env: Vec<Value>,
        arity: usize,
    },
    /// Enum variant instance with positional payload `fields`. Only `variant`
    /// and `fields` appear in the `Display` output; `enum_name` does not.
    Enum {
        enum_name: String,
        variant: String,
        fields: Vec<Value>,
    },
    /// Regular-expression source plus flag string, displayed as `/pattern/flags`.
    Regex { pattern: String, flags: String },
    /// bfloat16 scalar.
    Bf16(Bf16),
    /// Half-precision scalar (`crate::f16::F16`); displayed via `to_f64`.
    F16(crate::f16::F16),
    /// Complex number with `f64` components (`complex::ComplexF64`).
    Complex(complex::ComplexF64),
    /// Shared, mutable scratchpad.
    Scratchpad(Rc<RefCell<Scratchpad>>),
    /// Shared, mutable paged KV cache.
    PagedKvCache(Rc<RefCell<PagedKvCache>>),
    /// Aligned byte slice (`crate::aligned_pool::AlignedByteSlice`).
    AlignedBytes(AlignedByteSlice),
    /// Type-erased, mutable payload (`dyn Any` behind a `RefCell`);
    /// presumably downcast by the module that created it — confirm.
    GradGraph(Rc<RefCell<dyn Any>>),
    /// Type-erased, mutable payload; presumably downcast by its owning module.
    OptimizerState(Rc<RefCell<dyn Any>>),
    /// Type-erased, immutable payload (`Rc<dyn Any>`).
    TidyView(Rc<dyn Any>),
    /// Type-erased, immutable payload (`Rc<dyn Any>`).
    GroupedTidyView(Rc<dyn Any>),
    /// Type-erased, immutable payload (`Rc<dyn Any>`).
    VizorPlot(Rc<dyn Any>),
    /// Type-erased, mutable payload (`dyn Any` behind a `RefCell`).
    QuantumState(Rc<RefCell<dyn Any>>),
    /// Carries no payload; displayed as `void`.
    Void,
}
171
172impl Value {
173 pub fn type_name(&self) -> &str {
175 match self {
176 Value::Int(_) => "Int",
177 Value::Float(_) => "Float",
178 Value::Bool(_) => "Bool",
179 Value::String(_) => "String",
180 Value::Bytes(_) => "Bytes",
181 Value::ByteSlice(_) => "ByteSlice",
182 Value::StrView(_) => "StrView",
183 Value::U8(_) => "u8",
184 Value::Tensor(_) => "Tensor",
185 Value::SparseTensor(_) => "SparseTensor",
186 Value::Map(_) => "Map",
187 Value::Array(_) => "Array",
188 Value::Tuple(_) => "Tuple",
189 Value::Struct { .. } => "Struct",
190 Value::Enum { .. } => "Enum",
191 Value::ClassRef(_) => "ClassRef",
192 Value::Fn(_) => "Fn",
193 Value::Closure { .. } => "Closure",
194 Value::Regex { .. } => "Regex",
195 Value::Bf16(_) => "Bf16",
196 Value::F16(_) => "F16",
197 Value::Complex(_) => "Complex",
198 Value::Scratchpad(_) => "Scratchpad",
199 Value::PagedKvCache(_) => "PagedKvCache",
200 Value::AlignedBytes(_) => "AlignedBytes",
201 Value::GradGraph(_) => "GradGraph",
202 Value::OptimizerState(_) => "OptimizerState",
203 Value::TidyView(_) => "TidyView",
204 Value::GroupedTidyView(_) => "GroupedTidyView",
205 Value::VizorPlot(_) => "VizorPlot",
206 Value::QuantumState(_) => "QuantumState",
207 Value::Void => "Void",
208 }
209 }
210}
211
212impl fmt::Display for Value {
213 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214 match self {
215 Value::Int(v) => write!(f, "{v}"),
216 Value::Float(v) => write!(f, "{v}"),
217 Value::Bool(v) => write!(f, "{v}"),
218 Value::String(v) => write!(f, "{v}"),
219 Value::Bytes(b) => {
220 let b = b.borrow();
221 write!(f, "Bytes([")?;
222 for (i, byte) in b.iter().enumerate() {
223 if i > 0 { write!(f, ", ")?; }
224 write!(f, "{byte}")?;
225 }
226 write!(f, "])")
227 }
228 Value::ByteSlice(b) => {
229 match std::str::from_utf8(b) {
231 Ok(s) => write!(f, "b\"{s}\""),
232 Err(_) => {
233 write!(f, "b\"")?;
234 for &byte in b.iter() {
235 if byte.is_ascii_graphic() || byte == b' ' {
236 write!(f, "{}", byte as char)?;
237 } else {
238 write!(f, "\\x{byte:02x}")?;
239 }
240 }
241 write!(f, "\"")
242 }
243 }
244 }
245 Value::StrView(b) => {
246 let s = std::str::from_utf8(b).unwrap_or("<invalid utf8>");
248 write!(f, "{s}")
249 }
250 Value::U8(v) => write!(f, "{v}"),
251 Value::Tensor(t) => write!(f, "{t}"),
252 Value::SparseTensor(s) => write!(f, "SparseTensor({}x{}, nnz={})", s.nrows, s.ncols, s.nnz()),
253 Value::Map(m) => {
254 let m = m.borrow();
255 write!(f, "Map({{")?;
256 for (i, (k, v)) in m.iter().enumerate() {
257 if i > 0 {
258 write!(f, ", ")?;
259 }
260 write!(f, "{k}: {v}")?;
261 }
262 write!(f, "}})")
263 }
264 Value::Array(arr) => {
265 write!(f, "[")?;
266 for (i, v) in arr.iter().enumerate() {
267 if i > 0 {
268 write!(f, ", ")?;
269 }
270 write!(f, "{v}")?;
271 }
272 write!(f, "]")
273 }
274 Value::Tuple(elems) => {
275 write!(f, "(")?;
276 for (i, v) in elems.iter().enumerate() {
277 if i > 0 {
278 write!(f, ", ")?;
279 }
280 write!(f, "{v}")?;
281 }
282 write!(f, ")")
283 }
284 Value::Struct { name, fields } => {
285 write!(f, "{name} {{ ")?;
286 for (i, (k, v)) in fields.iter().enumerate() {
287 if i > 0 {
288 write!(f, ", ")?;
289 }
290 write!(f, "{k}: {v}")?;
291 }
292 write!(f, " }}")
293 }
294 Value::Enum {
295 enum_name: _,
296 variant,
297 fields,
298 } => {
299 write!(f, "{variant}")?;
300 if !fields.is_empty() {
301 write!(f, "(")?;
302 for (i, v) in fields.iter().enumerate() {
303 if i > 0 {
304 write!(f, ", ")?;
305 }
306 write!(f, "{v}")?;
307 }
308 write!(f, ")")?;
309 }
310 Ok(())
311 }
312 Value::Regex { pattern, flags } => {
313 write!(f, "/{pattern}/")?;
314 if !flags.is_empty() {
315 write!(f, "{flags}")?;
316 }
317 Ok(())
318 }
319 Value::Bf16(v) => write!(f, "{}", v.to_f32()),
320 Value::F16(v) => write!(f, "{}", v.to_f64()),
321 Value::Complex(z) => write!(f, "{z}"),
322 Value::ClassRef(r) => write!(f, "<object@{}>", r.index),
323 Value::Fn(fv) => write!(f, "<fn {}({})>", fv.name, fv.arity),
324 Value::Closure {
325 fn_name, arity, ..
326 } => write!(f, "<closure {}({})>", fn_name, arity),
327 Value::Scratchpad(s) => write!(f, "{}", s.borrow()),
328 Value::PagedKvCache(c) => write!(f, "{}", c.borrow()),
329 Value::AlignedBytes(a) => write!(f, "{}", a),
330 Value::GradGraph(_) => write!(f, "<GradGraph>"),
331 Value::OptimizerState(_) => write!(f, "<OptimizerState>"),
332 Value::TidyView(_) => write!(f, "<TidyView>"),
333 Value::GroupedTidyView(_) => write!(f, "<GroupedTidyView>"),
334 Value::VizorPlot(_) => write!(f, "<VizorPlot>"),
335 Value::QuantumState(_) => write!(f, "<QuantumState>"),
336 Value::Void => write!(f, "void"),
337 }
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;
    use std::rc::Rc;

    #[test]
    fn int_display() {
        assert_eq!(format!("{}", Value::Int(42)), "42");
        assert_eq!(format!("{}", Value::Int(-1)), "-1");
    }

    #[test]
    fn float_display() {
        // 2.5 is exactly representable, so its rendering is stable. The
        // previous 3.14 also tripped clippy's deny-by-default `approx_constant`.
        assert_eq!(format!("{}", Value::Float(2.5)), "2.5");
    }

    #[test]
    fn bool_display() {
        assert_eq!(format!("{}", Value::Bool(true)), "true");
        assert_eq!(format!("{}", Value::Bool(false)), "false");
    }

    #[test]
    fn string_display() {
        let v = Value::String(Rc::new("hello".to_string()));
        assert_eq!(format!("{v}"), "hello");
    }

    #[test]
    fn void_display() {
        assert_eq!(format!("{}", Value::Void), "void");
    }

    #[test]
    fn u8_display() {
        assert_eq!(format!("{}", Value::U8(7)), "7");
    }

    #[test]
    fn bytes_display() {
        let v = Value::Bytes(Rc::new(RefCell::new(vec![1, 2, 255])));
        assert_eq!(format!("{v}"), "Bytes([1, 2, 255])");
        let empty = Value::Bytes(Rc::new(RefCell::new(Vec::new())));
        assert_eq!(format!("{empty}"), "Bytes([])");
    }

    #[test]
    fn byte_slice_display() {
        let utf8 = Value::ByteSlice(Rc::new(b"hi".to_vec()));
        assert_eq!(format!("{utf8}"), "b\"hi\"");
        // Invalid UTF-8: printable ASCII and spaces kept, other bytes hex-escaped.
        let binary = Value::ByteSlice(Rc::new(vec![b'a', 0xff, b' ']));
        assert_eq!(format!("{binary}"), "b\"a\\xff \"");
    }

    #[test]
    fn str_view_display() {
        let ok = Value::StrView(Rc::new(b"abc".to_vec()));
        assert_eq!(format!("{ok}"), "abc");
        let bad = Value::StrView(Rc::new(vec![0xff]));
        assert_eq!(format!("{bad}"), "<invalid utf8>");
    }

    #[test]
    fn regex_display() {
        let with_flags = Value::Regex {
            pattern: "a+".to_string(),
            flags: "i".to_string(),
        };
        assert_eq!(format!("{with_flags}"), "/a+/i");
        let no_flags = Value::Regex {
            pattern: "a+".to_string(),
            flags: String::new(),
        };
        assert_eq!(format!("{no_flags}"), "/a+/");
    }

    #[test]
    fn bf16_display() {
        assert_eq!(format!("{}", Value::Bf16(Bf16::from_f32(1.5))), "1.5");
    }

    #[test]
    fn type_name_coverage() {
        assert_eq!(Value::Int(0).type_name(), "Int");
        assert_eq!(Value::Float(0.0).type_name(), "Float");
        assert_eq!(Value::Bool(true).type_name(), "Bool");
        assert_eq!(Value::String(Rc::new(String::new())).type_name(), "String");
        assert_eq!(Value::U8(0).type_name(), "u8");
        assert_eq!(
            Value::Bytes(Rc::new(RefCell::new(Vec::new()))).type_name(),
            "Bytes"
        );
        assert_eq!(Value::Array(Rc::new(Vec::new())).type_name(), "Array");
        assert_eq!(Value::Tuple(Rc::new(Vec::new())).type_name(), "Tuple");
        assert_eq!(Value::Void.type_name(), "Void");
    }

    #[test]
    fn tuple_display() {
        let t = Value::Tuple(Rc::new(vec![Value::Int(1), Value::Bool(true)]));
        assert_eq!(format!("{t}"), "(1, true)");
    }

    #[test]
    fn array_display() {
        let a = Value::Array(Rc::new(vec![Value::Int(10), Value::Int(20)]));
        assert_eq!(format!("{a}"), "[10, 20]");
        assert_eq!(format!("{}", Value::Array(Rc::new(Vec::new()))), "[]");
    }

    #[test]
    fn struct_value_display() {
        let mut fields = std::collections::BTreeMap::new();
        fields.insert("y".to_string(), Value::Int(2));
        fields.insert("x".to_string(), Value::Int(1));
        let sv = Value::Struct {
            name: "Point".to_string(),
            fields,
        };
        // BTreeMap iteration puts fields in name order regardless of insertion.
        assert_eq!(format!("{sv}"), "Point { x: 1, y: 2 }");
    }

    #[test]
    fn enum_value_display() {
        let ev = Value::Enum {
            enum_name: "Option".to_string(),
            variant: "Some".to_string(),
            fields: vec![Value::Int(42)],
        };
        assert_eq!(format!("{ev}"), "Some(42)");
        let unit = Value::Enum {
            enum_name: "Option".to_string(),
            variant: "None".to_string(),
            fields: Vec::new(),
        };
        assert_eq!(format!("{unit}"), "None");
    }

    #[test]
    fn map_display() {
        let m = Value::Map(Rc::new(std::cell::RefCell::new(crate::det_map::DetMap::new())));
        assert_eq!(format!("{m}"), "Map({})");
    }
}
437