1use js_sys::{Array, Float64Array, Int32Array, Object, Reflect, Uint32Array, Uint8Array};
2use wasm_bindgen::prelude::*;
3
/// Typed storage for a single column of a `ModelOutput`.
///
/// `ModelOutput::into_js` tags each variant with a `type` string (`"f64"`, `"i32"`,
/// `"u32"`, `"bool"`, `"enum"`). It also copies the variant's data into the matching JS
/// typed array.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnData {
    /// 64-bit floats, exported through a `Float64Array`.
    F64(Vec<f64>),
    /// Signed 32-bit integers, exported through an `Int32Array`.
    I32(Vec<i32>),
    /// Unsigned 32-bit integers, exported through a `Uint32Array`.
    U32(Vec<u32>),
    /// Booleans, exported through a `Uint8Array` with one byte (0 or 1) per row.
    Bool(Vec<bool>),
    /// Categorical column: per-row indices plus the label strings they refer to.
    Enum {
        /// Per-row indices, exported through a `Uint32Array`.
        /// NOTE(review): presumably each value indexes into `labels`; this is not validated.
        indices: Vec<u32>,
        /// Label strings, exported as the column descriptor's `enumLabels` array.
        labels: Vec<String>,
    },
}
14
15pub struct ModelOutput {
16 length: usize,
17 columns: Vec<(String, ColumnData)>,
18}
19
20impl ModelOutput {
21 pub fn new(length: usize) -> Self {
22 Self {
23 length,
24 columns: Vec::new(),
25 }
26 }
27
28 pub fn add_f64(mut self, name: &str, data: Vec<f64>) -> Self {
29 self.columns.push((name.to_string(), ColumnData::F64(data)));
30 self
31 }
32
33 pub fn add_i32(mut self, name: &str, data: Vec<i32>) -> Self {
34 self.columns.push((name.to_string(), ColumnData::I32(data)));
35 self
36 }
37
38 pub fn add_u32(mut self, name: &str, data: Vec<u32>) -> Self {
39 self.columns.push((name.to_string(), ColumnData::U32(data)));
40 self
41 }
42
43 pub fn add_bool(mut self, name: &str, data: Vec<bool>) -> Self {
44 self.columns
45 .push((name.to_string(), ColumnData::Bool(data)));
46 self
47 }
48
49 pub fn add_enum(mut self, name: &str, indices: Vec<u32>, labels: Vec<&str>) -> Self {
50 self.columns.push((
51 name.to_string(),
52 ColumnData::Enum {
53 indices,
54 labels: labels.into_iter().map(|s| s.to_string()).collect(),
55 },
56 ));
57 self
58 }
59
60 pub fn into_js(self) -> JsValue {
61 let obj = Object::new();
62 Reflect::set(&obj, &"__modelOutput".into(), &true.into()).unwrap();
63 Reflect::set(&obj, &"length".into(), &(self.length as f64).into()).unwrap();
64
65 let cols = Array::new();
66 let bufs = Array::new();
67
68 for (name, data) in self.columns {
69 let desc = Object::new();
70 Reflect::set(&desc, &"name".into(), &name.into()).unwrap();
71
72 match data {
73 ColumnData::F64(v) => {
74 Reflect::set(&desc, &"type".into(), &"f64".into()).unwrap();
75 let arr = Float64Array::from(v.as_slice());
76 bufs.push(&arr.buffer());
77 }
78 ColumnData::I32(v) => {
79 Reflect::set(&desc, &"type".into(), &"i32".into()).unwrap();
80 let arr = Int32Array::from(v.as_slice());
81 bufs.push(&arr.buffer());
82 }
83 ColumnData::U32(v) => {
84 Reflect::set(&desc, &"type".into(), &"u32".into()).unwrap();
85 let arr = Uint32Array::from(v.as_slice());
86 bufs.push(&arr.buffer());
87 }
88 ColumnData::Bool(v) => {
89 Reflect::set(&desc, &"type".into(), &"bool".into()).unwrap();
90 let bytes: Vec<u8> = v.iter().map(|b| if *b { 1 } else { 0 }).collect();
91 let arr = Uint8Array::from(bytes.as_slice());
92 bufs.push(&arr.buffer());
93 }
94 ColumnData::Enum { indices, labels } => {
95 Reflect::set(&desc, &"type".into(), &"enum".into()).unwrap();
96 let labels_arr = Array::new();
97 for l in &labels {
98 labels_arr.push(&JsValue::from_str(l));
99 }
100 Reflect::set(&desc, &"enumLabels".into(), &labels_arr).unwrap();
101 let arr = Uint32Array::from(indices.as_slice());
102 bufs.push(&arr.buffer());
103 }
104 }
105 cols.push(&desc);
106 }
107
108 Reflect::set(&obj, &"columns".into(), &cols).unwrap();
109 Reflect::set(&obj, &"buffers".into(), &bufs).unwrap();
110 obj.into()
111 }
112}
113
114pub fn model_outputs<const N: usize>(pairs: [(&str, ModelOutput); N]) -> JsValue {
116 let obj = Object::new();
117 Reflect::set(&obj, &"__modelOutputs".into(), &true.into()).unwrap();
118
119 let outputs = Object::new();
120 for (name, output) in pairs {
121 Reflect::set(&outputs, &name.into(), &output.into_js()).unwrap();
122 }
123 Reflect::set(&obj, &"outputs".into(), &outputs).unwrap();
124
125 obj.into()
126}