Skip to main content

cfasim_model/
lib.rs

1use js_sys::{Array, Float64Array, Int32Array, Object, Reflect, Uint32Array, Uint8Array};
2use wasm_bindgen::prelude::*;
3
/// Values of a single output column, tagged by element type.
///
/// Each variant owns its values; `ModelOutput::into_js` copies them into a
/// JS typed array whose element type matches the variant.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnData {
    /// 64-bit floats (exported as a `Float64Array` buffer, type tag `"f64"`).
    F64(Vec<f64>),
    /// Signed 32-bit integers (exported as an `Int32Array` buffer, type tag `"i32"`).
    I32(Vec<i32>),
    /// Unsigned 32-bit integers (exported as a `Uint32Array` buffer, type tag `"u32"`).
    U32(Vec<u32>),
    /// Booleans (exported as a `Uint8Array` buffer of 0/1 bytes, type tag `"bool"`).
    Bool(Vec<bool>),
    /// Categorical column: `u32` indices plus the label strings exported
    /// alongside them as `enumLabels`. Presumably each index selects an entry
    /// of `labels` — confirm against the JS consumer.
    Enum {
        indices: Vec<u32>,
        labels: Vec<String>,
    },
}
14
15pub struct ModelOutput {
16    length: usize,
17    columns: Vec<(String, ColumnData)>,
18}
19
20impl ModelOutput {
21    pub fn new(length: usize) -> Self {
22        Self {
23            length,
24            columns: Vec::new(),
25        }
26    }
27
28    pub fn add_f64(mut self, name: &str, data: Vec<f64>) -> Self {
29        self.columns
30            .push((name.to_string(), ColumnData::F64(data)));
31        self
32    }
33
34    pub fn add_i32(mut self, name: &str, data: Vec<i32>) -> Self {
35        self.columns
36            .push((name.to_string(), ColumnData::I32(data)));
37        self
38    }
39
40    pub fn add_u32(mut self, name: &str, data: Vec<u32>) -> Self {
41        self.columns
42            .push((name.to_string(), ColumnData::U32(data)));
43        self
44    }
45
46    pub fn add_bool(mut self, name: &str, data: Vec<bool>) -> Self {
47        self.columns
48            .push((name.to_string(), ColumnData::Bool(data)));
49        self
50    }
51
52    pub fn add_enum(mut self, name: &str, indices: Vec<u32>, labels: Vec<&str>) -> Self {
53        self.columns.push((
54            name.to_string(),
55            ColumnData::Enum {
56                indices,
57                labels: labels.into_iter().map(|s| s.to_string()).collect(),
58            },
59        ));
60        self
61    }
62
63    pub fn into_js(self) -> JsValue {
64        let obj = Object::new();
65        Reflect::set(&obj, &"__modelOutput".into(), &true.into()).unwrap();
66        Reflect::set(&obj, &"length".into(), &(self.length as f64).into()).unwrap();
67
68        let cols = Array::new();
69        let bufs = Array::new();
70
71        for (name, data) in self.columns {
72            let desc = Object::new();
73            Reflect::set(&desc, &"name".into(), &name.into()).unwrap();
74
75            match data {
76                ColumnData::F64(v) => {
77                    Reflect::set(&desc, &"type".into(), &"f64".into()).unwrap();
78                    let arr = Float64Array::from(v.as_slice());
79                    bufs.push(&arr.buffer());
80                }
81                ColumnData::I32(v) => {
82                    Reflect::set(&desc, &"type".into(), &"i32".into()).unwrap();
83                    let arr = Int32Array::from(v.as_slice());
84                    bufs.push(&arr.buffer());
85                }
86                ColumnData::U32(v) => {
87                    Reflect::set(&desc, &"type".into(), &"u32".into()).unwrap();
88                    let arr = Uint32Array::from(v.as_slice());
89                    bufs.push(&arr.buffer());
90                }
91                ColumnData::Bool(v) => {
92                    Reflect::set(&desc, &"type".into(), &"bool".into()).unwrap();
93                    let bytes: Vec<u8> = v.iter().map(|b| if *b { 1 } else { 0 }).collect();
94                    let arr = Uint8Array::from(bytes.as_slice());
95                    bufs.push(&arr.buffer());
96                }
97                ColumnData::Enum { indices, labels } => {
98                    Reflect::set(&desc, &"type".into(), &"enum".into()).unwrap();
99                    let labels_arr = Array::new();
100                    for l in &labels {
101                        labels_arr.push(&JsValue::from_str(l));
102                    }
103                    Reflect::set(&desc, &"enumLabels".into(), &labels_arr).unwrap();
104                    let arr = Uint32Array::from(indices.as_slice());
105                    bufs.push(&arr.buffer());
106                }
107            }
108            cols.push(&desc);
109        }
110
111        Reflect::set(&obj, &"columns".into(), &cols).unwrap();
112        Reflect::set(&obj, &"buffers".into(), &bufs).unwrap();
113        obj.into()
114    }
115}
116
117/// Create a multi-output result from named ModelOutput pairs.
118pub fn model_outputs<const N: usize>(pairs: [(&str, ModelOutput); N]) -> JsValue {
119    let obj = Object::new();
120    Reflect::set(&obj, &"__modelOutputs".into(), &true.into()).unwrap();
121
122    let outputs = Object::new();
123    for (name, output) in pairs {
124        Reflect::set(&outputs, &name.into(), &output.into_js()).unwrap();
125    }
126    Reflect::set(&obj, &"outputs".into(), &outputs).unwrap();
127
128    obj.into()
129}