1use js_sys::{Array, Float64Array, Int32Array, Object, Reflect, Uint32Array, Uint8Array};
2use wasm_bindgen::prelude::*;
3
/// Typed storage for a single column of a [`ModelOutput`].
///
/// Each variant determines the `type` tag and the typed-array element type used
/// when the column is exported to JavaScript by `ModelOutput::into_js`.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnData {
    /// 64-bit floats; exported with type tag `"f64"` as a `Float64Array` buffer.
    F64(Vec<f64>),
    /// Signed 32-bit integers; exported with type tag `"i32"` as an `Int32Array` buffer.
    I32(Vec<i32>),
    /// Unsigned 32-bit integers; exported with type tag `"u32"` as a `Uint32Array` buffer.
    U32(Vec<u32>),
    /// Booleans; exported with type tag `"bool"` as a `Uint8Array` buffer (1 = true, 0 = false).
    Bool(Vec<bool>),
    /// Categorical column; exported with type tag `"enum"`.
    ///
    /// `indices` becomes a `Uint32Array` buffer and `labels` becomes the
    /// descriptor's `enumLabels` array. Presumably `indices[i]` selects an entry of
    /// `labels` — confirm against the JS consumer.
    Enum {
        indices: Vec<u32>,
        labels: Vec<String>,
    },
}
14
15pub struct ModelOutput {
16 length: usize,
17 columns: Vec<(String, ColumnData)>,
18}
19
20impl ModelOutput {
21 pub fn new(length: usize) -> Self {
22 Self {
23 length,
24 columns: Vec::new(),
25 }
26 }
27
28 pub fn add_f64(mut self, name: &str, data: Vec<f64>) -> Self {
29 self.columns
30 .push((name.to_string(), ColumnData::F64(data)));
31 self
32 }
33
34 pub fn add_i32(mut self, name: &str, data: Vec<i32>) -> Self {
35 self.columns
36 .push((name.to_string(), ColumnData::I32(data)));
37 self
38 }
39
40 pub fn add_u32(mut self, name: &str, data: Vec<u32>) -> Self {
41 self.columns
42 .push((name.to_string(), ColumnData::U32(data)));
43 self
44 }
45
46 pub fn add_bool(mut self, name: &str, data: Vec<bool>) -> Self {
47 self.columns
48 .push((name.to_string(), ColumnData::Bool(data)));
49 self
50 }
51
52 pub fn add_enum(mut self, name: &str, indices: Vec<u32>, labels: Vec<&str>) -> Self {
53 self.columns.push((
54 name.to_string(),
55 ColumnData::Enum {
56 indices,
57 labels: labels.into_iter().map(|s| s.to_string()).collect(),
58 },
59 ));
60 self
61 }
62
63 pub fn into_js(self) -> JsValue {
64 let obj = Object::new();
65 Reflect::set(&obj, &"__modelOutput".into(), &true.into()).unwrap();
66 Reflect::set(&obj, &"length".into(), &(self.length as f64).into()).unwrap();
67
68 let cols = Array::new();
69 let bufs = Array::new();
70
71 for (name, data) in self.columns {
72 let desc = Object::new();
73 Reflect::set(&desc, &"name".into(), &name.into()).unwrap();
74
75 match data {
76 ColumnData::F64(v) => {
77 Reflect::set(&desc, &"type".into(), &"f64".into()).unwrap();
78 let arr = Float64Array::from(v.as_slice());
79 bufs.push(&arr.buffer());
80 }
81 ColumnData::I32(v) => {
82 Reflect::set(&desc, &"type".into(), &"i32".into()).unwrap();
83 let arr = Int32Array::from(v.as_slice());
84 bufs.push(&arr.buffer());
85 }
86 ColumnData::U32(v) => {
87 Reflect::set(&desc, &"type".into(), &"u32".into()).unwrap();
88 let arr = Uint32Array::from(v.as_slice());
89 bufs.push(&arr.buffer());
90 }
91 ColumnData::Bool(v) => {
92 Reflect::set(&desc, &"type".into(), &"bool".into()).unwrap();
93 let bytes: Vec<u8> = v.iter().map(|b| if *b { 1 } else { 0 }).collect();
94 let arr = Uint8Array::from(bytes.as_slice());
95 bufs.push(&arr.buffer());
96 }
97 ColumnData::Enum { indices, labels } => {
98 Reflect::set(&desc, &"type".into(), &"enum".into()).unwrap();
99 let labels_arr = Array::new();
100 for l in &labels {
101 labels_arr.push(&JsValue::from_str(l));
102 }
103 Reflect::set(&desc, &"enumLabels".into(), &labels_arr).unwrap();
104 let arr = Uint32Array::from(indices.as_slice());
105 bufs.push(&arr.buffer());
106 }
107 }
108 cols.push(&desc);
109 }
110
111 Reflect::set(&obj, &"columns".into(), &cols).unwrap();
112 Reflect::set(&obj, &"buffers".into(), &bufs).unwrap();
113 obj.into()
114 }
115}
116
117pub fn model_outputs<const N: usize>(pairs: [(&str, ModelOutput); N]) -> JsValue {
119 let obj = Object::new();
120 Reflect::set(&obj, &"__modelOutputs".into(), &true.into()).unwrap();
121
122 let outputs = Object::new();
123 for (name, output) in pairs {
124 Reflect::set(&outputs, &name.into(), &output.into_js()).unwrap();
125 }
126 Reflect::set(&obj, &"outputs".into(), &outputs).unwrap();
127
128 obj.into()
129}