Skip to main content

cfasim_model/
lib.rs

1use js_sys::{Array, Float64Array, Int32Array, Object, Reflect, Uint32Array, Uint8Array};
2use wasm_bindgen::prelude::*;
3
/// Typed payload of a single output column.
///
/// [`ModelOutput::into_js`] copies each variant into a JS typed array and
/// tags the column descriptor with a `type` string:
/// `"f64"`, `"i32"`, `"u32"`, `"bool"` (one byte per value, 0 or 1), or
/// `"enum"` (`u32` indices plus an `enumLabels` string array).
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnData {
    F64(Vec<f64>),
    I32(Vec<i32>),
    U32(Vec<u32>),
    Bool(Vec<bool>),
    /// Categorical column: each entry of `indices` selects an entry of `labels`.
    Enum {
        indices: Vec<u32>,
        labels: Vec<String>,
    },
}

/// A set of named columns sharing a common row count, built with the
/// `add_*` methods and converted to a JS object with `into_js`.
#[derive(Debug, Clone, PartialEq)]
pub struct ModelOutput {
    /// Declared number of rows; reported to JS as `length`.
    length: usize,
    /// Columns in insertion order, each paired with its name.
    columns: Vec<(String, ColumnData)>,
}
19
20impl ModelOutput {
21    pub fn new(length: usize) -> Self {
22        Self {
23            length,
24            columns: Vec::new(),
25        }
26    }
27
28    pub fn add_f64(mut self, name: &str, data: Vec<f64>) -> Self {
29        self.columns.push((name.to_string(), ColumnData::F64(data)));
30        self
31    }
32
33    pub fn add_i32(mut self, name: &str, data: Vec<i32>) -> Self {
34        self.columns.push((name.to_string(), ColumnData::I32(data)));
35        self
36    }
37
38    pub fn add_u32(mut self, name: &str, data: Vec<u32>) -> Self {
39        self.columns.push((name.to_string(), ColumnData::U32(data)));
40        self
41    }
42
43    pub fn add_bool(mut self, name: &str, data: Vec<bool>) -> Self {
44        self.columns
45            .push((name.to_string(), ColumnData::Bool(data)));
46        self
47    }
48
49    pub fn add_enum(mut self, name: &str, indices: Vec<u32>, labels: Vec<&str>) -> Self {
50        self.columns.push((
51            name.to_string(),
52            ColumnData::Enum {
53                indices,
54                labels: labels.into_iter().map(|s| s.to_string()).collect(),
55            },
56        ));
57        self
58    }
59
60    pub fn into_js(self) -> JsValue {
61        let obj = Object::new();
62        Reflect::set(&obj, &"__modelOutput".into(), &true.into()).unwrap();
63        Reflect::set(&obj, &"length".into(), &(self.length as f64).into()).unwrap();
64
65        let cols = Array::new();
66        let bufs = Array::new();
67
68        for (name, data) in self.columns {
69            let desc = Object::new();
70            Reflect::set(&desc, &"name".into(), &name.into()).unwrap();
71
72            match data {
73                ColumnData::F64(v) => {
74                    Reflect::set(&desc, &"type".into(), &"f64".into()).unwrap();
75                    let arr = Float64Array::from(v.as_slice());
76                    bufs.push(&arr.buffer());
77                }
78                ColumnData::I32(v) => {
79                    Reflect::set(&desc, &"type".into(), &"i32".into()).unwrap();
80                    let arr = Int32Array::from(v.as_slice());
81                    bufs.push(&arr.buffer());
82                }
83                ColumnData::U32(v) => {
84                    Reflect::set(&desc, &"type".into(), &"u32".into()).unwrap();
85                    let arr = Uint32Array::from(v.as_slice());
86                    bufs.push(&arr.buffer());
87                }
88                ColumnData::Bool(v) => {
89                    Reflect::set(&desc, &"type".into(), &"bool".into()).unwrap();
90                    let bytes: Vec<u8> = v.iter().map(|b| if *b { 1 } else { 0 }).collect();
91                    let arr = Uint8Array::from(bytes.as_slice());
92                    bufs.push(&arr.buffer());
93                }
94                ColumnData::Enum { indices, labels } => {
95                    Reflect::set(&desc, &"type".into(), &"enum".into()).unwrap();
96                    let labels_arr = Array::new();
97                    for l in &labels {
98                        labels_arr.push(&JsValue::from_str(l));
99                    }
100                    Reflect::set(&desc, &"enumLabels".into(), &labels_arr).unwrap();
101                    let arr = Uint32Array::from(indices.as_slice());
102                    bufs.push(&arr.buffer());
103                }
104            }
105            cols.push(&desc);
106        }
107
108        Reflect::set(&obj, &"columns".into(), &cols).unwrap();
109        Reflect::set(&obj, &"buffers".into(), &bufs).unwrap();
110        obj.into()
111    }
112}
113
114/// Create a multi-output result from named ModelOutput pairs.
115pub fn model_outputs<const N: usize>(pairs: [(&str, ModelOutput); N]) -> JsValue {
116    let obj = Object::new();
117    Reflect::set(&obj, &"__modelOutputs".into(), &true.into()).unwrap();
118
119    let outputs = Object::new();
120    for (name, output) in pairs {
121        Reflect::set(&outputs, &name.into(), &output.into_js()).unwrap();
122    }
123    Reflect::set(&obj, &"outputs".into(), &outputs).unwrap();
124
125    obj.into()
126}