reductionml_core/
workspace.rs

1use std::sync::Arc;
2
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7use crate::{
8    dense_weights::{DenseWeights, DenseWeightsWithNDArray},
9    error::{Error, Result},
10    global_config::GlobalConfig,
11    object_pool::Pool,
12    reduction::{DepthInfo, ReductionWrapper},
13    reduction_factory::JsonReductionConfig,
14    sparse_namespaced_features::SparseFeatures,
15    types::{Features, Label, Prediction},
16};
17
18#[derive(Serialize, Deserialize)]
19pub struct Workspace {
20    global_config: GlobalConfig,
21    entry_reduction: ReductionWrapper,
22
23    #[serde(skip)]
24    features_pool: Arc<Pool<SparseFeatures>>,
25}
26
27#[derive(Serialize, Deserialize, JsonSchema)]
28#[serde(rename_all = "camelCase")]
29#[serde(deny_unknown_fields)]
30pub struct Configuration {
31    // $schema,
32    #[serde(rename = "$schema", default)]
33    _schema: Option<String>,
34    global_config: GlobalConfig,
35    #[schemars(schema_with = "crate::config_schema::gen_json_reduction_config_schema")]
36    entry_reduction: JsonReductionConfig,
37}
38
39impl Configuration {
40    pub fn from_json_str(json: &str) -> Result<Self> {
41        serde_json::from_str::<serde_json::Value>(json)?.try_into()
42    }
43
44    pub fn from_yaml_str(yaml: &str) -> Result<Self> {
45        serde_yaml::from_str::<serde_yaml::Value>(yaml)?.try_into()
46    }
47}
48
49impl TryInto<Configuration> for serde_json::Value {
50    type Error = Error;
51
52    fn try_into(self) -> std::result::Result<Configuration, Self::Error> {
53        serde_json::from_value(self)
54            .map_err(|e| Error::InvalidConfiguration(format!("Failed to parse configuration: {e}")))
55    }
56}
57
58impl TryInto<Configuration> for serde_yaml::Value {
59    type Error = Error;
60
61    fn try_into(self) -> std::result::Result<Configuration, Self::Error> {
62        serde_yaml::from_value::<serde_json::Value>(self)
63            .map_err(|e| Error::InvalidConfiguration(format!("Failed to parse yaml: {e}")))?
64            .try_into()
65    }
66}
67
68// We need to search until we find an object with the keys weights, feature_index_size, model_index_size, feature_state_size, model_index_size_shift, feature_state_size_shift
69fn rewrite_json_ndarray_to_sparse(value: &mut serde_json::Value) {
70    match value {
71        serde_json::Value::Object(map) => {
72            if map.contains_key("weights")
73                && map.contains_key("feature_index_size")
74                && map.contains_key("model_index_size")
75                && map.contains_key("feature_state_size")
76                && map.contains_key("model_index_size_shift")
77                && map.contains_key("feature_state_size_shift")
78            {
79                let wts: DenseWeightsWithNDArray = serde_json::from_value(value.clone()).unwrap();
80                *value = serde_json::to_value(wts.to_dense_weights()).unwrap();
81                return;
82            }
83            for (_, v) in map {
84                rewrite_json_ndarray_to_sparse(v);
85            }
86        }
87        serde_json::Value::Array(vec) => {
88            for v in vec {
89                rewrite_json_ndarray_to_sparse(v);
90            }
91        }
92        _ => (),
93    }
94}
95
96fn rewrite_json_sparse_to_ndarray(value: &mut serde_json::Value) {
97    match value {
98        serde_json::Value::Object(map) => {
99            if map.contains_key("weights")
100                && map.contains_key("feature_index_size")
101                && map.contains_key("model_index_size")
102                && map.contains_key("feature_state_size")
103                && map.contains_key("model_index_size_shift")
104                && map.contains_key("feature_state_size_shift")
105            {
106                let wts: DenseWeights = serde_json::from_value(value.clone()).unwrap();
107                *value =
108                    serde_json::to_value(DenseWeightsWithNDArray::from_dense_weights(wts)).unwrap();
109                return;
110            }
111            for (_, v) in map {
112                rewrite_json_sparse_to_ndarray(v);
113            }
114        }
115        serde_json::Value::Array(vec) => {
116            for v in vec {
117                rewrite_json_sparse_to_ndarray(v);
118            }
119        }
120        _ => (),
121    }
122}
123
124impl Workspace {
125    pub fn new(config: Configuration) -> Result<Workspace> {
126        let reduction_config = crate::reduction_factory::parse_config(&config.entry_reduction)?;
127        let entry_reduction = crate::reduction_factory::create_reduction(
128            reduction_config.as_ref(),
129            &config.global_config,
130            1.into(), // Top of the stack must be passed as 1
131        )?;
132
133        Ok(Workspace {
134            global_config: config.global_config,
135            entry_reduction,
136            features_pool: Arc::new(Pool::new()),
137        })
138    }
139
140    // TODO move to bincode or msgpack
141    pub fn create_from_model(json: &[u8]) -> Result<Workspace> {
142        let r = flexbuffers::Reader::get_root(json).unwrap();
143        Ok(Workspace::deserialize(r).unwrap())
144    }
145
146    // TODO move to bincode or msgpack
147    pub fn serialize_model(&self) -> Result<Vec<u8>> {
148        let mut s = flexbuffers::FlexbufferSerializer::new();
149        self.serialize(&mut s).unwrap();
150        Ok(s.take_buffer())
151    }
152
153    // experimental
154    pub fn serialize_to_json(&self) -> Result<Value> {
155        let mut value = serde_json::to_value(self).unwrap();
156        rewrite_json_sparse_to_ndarray(&mut value);
157        Ok(value)
158    }
159
160    // experimental
161    pub fn deserialize_from_json(json: &Value) -> Result<Workspace> {
162        let mut value: serde_json::Value = serde_json::from_value(json.clone()).map_err(|e| {
163            Error::InvalidConfiguration(format!("Failed to parse configuration: {e}"))
164        })?;
165        rewrite_json_ndarray_to_sparse(&mut value);
166        serde_json::from_value(value)
167            .map_err(|e| Error::InvalidConfiguration(format!("Failed to parse model: {e}")))
168    }
169
170    pub fn predict(&self, features: &mut Features) -> Prediction {
171        let mut depth_info = DepthInfo::new();
172        self.entry_reduction
173            .predict(features, &mut depth_info, 0.into())
174    }
175
176    pub fn predict_then_learn(&mut self, features: &mut Features, label: &Label) -> Prediction {
177        let mut depth_info = DepthInfo::new();
178        self.entry_reduction
179            .predict_then_learn(features, label, &mut depth_info, 0.into())
180    }
181
182    pub fn learn(&mut self, features: &mut Features, label: &Label) {
183        let mut depth_info = DepthInfo::new();
184        self.entry_reduction
185            .learn(features, label, &mut depth_info, 0.into());
186    }
187
188    pub fn get_entry_reduction(&self) -> &ReductionWrapper {
189        &self.entry_reduction
190    }
191
192    pub fn global_config(&self) -> &GlobalConfig {
193        &self.global_config
194    }
195
196    pub fn features_pool(&self) -> &Arc<Pool<SparseFeatures>> {
197        &self.features_pool
198    }
199}
200
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;
    use serde_json::json;

    use crate::{
        sparse_namespaced_features::{Namespace, SparseFeatures},
        utils::AsInner,
        ScalarPrediction,
    };

    use super::*;

    /// Builds a workspace around a `Coin` entry reduction, checks that the
    /// untrained prediction is 0.0, then learns the same example twice and
    /// checks that the prediction becomes the label value 0.5.
    #[test]
    fn test_create_workspace() {
        let config: Configuration = json!({
            "globalConfig": {
                "numBits": 4
            },
            "entryReduction": {
                "typename": "Coin",
                "config": {
                    "alpha": 10
                }
            }
        })
        .try_into()
        .unwrap();
        let mut workspace = Workspace::new(config).unwrap();

        // Three unit-valued features in the default namespace.
        let mut sparse = SparseFeatures::new();
        let default_ns = sparse.get_or_create_namespace(Namespace::Default);
        default_ns.add_feature(0.into(), 1.0);
        default_ns.add_feature(2.into(), 1.0);
        default_ns.add_feature(3.into(), 1.0);
        let mut features = Features::SparseSimple(sparse);

        let untrained = workspace.predict(&mut features);
        let untrained_scalar: &ScalarPrediction = untrained.as_inner().unwrap();
        assert_relative_eq!(untrained_scalar.prediction, 0.0);

        let label = Label::Simple(0.5.into());

        // For some reason two calls to learn are required to get a non-zero prediction for coin?
        for _ in 0..2 {
            workspace.learn(&mut features, &label);
        }

        let trained = workspace.predict(&mut features);
        let trained_scalar: &ScalarPrediction = trained.as_inner().unwrap();
        assert_relative_eq!(trained_scalar.prediction, 0.5);
    }
}