reductionml_core/
workspace.rs1use std::sync::Arc;
2
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7use crate::{
8 dense_weights::{DenseWeights, DenseWeightsWithNDArray},
9 error::{Error, Result},
10 global_config::GlobalConfig,
11 object_pool::Pool,
12 reduction::{DepthInfo, ReductionWrapper},
13 reduction_factory::JsonReductionConfig,
14 sparse_namespaced_features::SparseFeatures,
15 types::{Features, Label, Prediction},
16};
17
18#[derive(Serialize, Deserialize)]
19pub struct Workspace {
20 global_config: GlobalConfig,
21 entry_reduction: ReductionWrapper,
22
23 #[serde(skip)]
24 features_pool: Arc<Pool<SparseFeatures>>,
25}
26
27#[derive(Serialize, Deserialize, JsonSchema)]
28#[serde(rename_all = "camelCase")]
29#[serde(deny_unknown_fields)]
30pub struct Configuration {
31 #[serde(rename = "$schema", default)]
33 _schema: Option<String>,
34 global_config: GlobalConfig,
35 #[schemars(schema_with = "crate::config_schema::gen_json_reduction_config_schema")]
36 entry_reduction: JsonReductionConfig,
37}
38
39impl Configuration {
40 pub fn from_json_str(json: &str) -> Result<Self> {
41 serde_json::from_str::<serde_json::Value>(json)?.try_into()
42 }
43
44 pub fn from_yaml_str(yaml: &str) -> Result<Self> {
45 serde_yaml::from_str::<serde_yaml::Value>(yaml)?.try_into()
46 }
47}
48
49impl TryInto<Configuration> for serde_json::Value {
50 type Error = Error;
51
52 fn try_into(self) -> std::result::Result<Configuration, Self::Error> {
53 serde_json::from_value(self)
54 .map_err(|e| Error::InvalidConfiguration(format!("Failed to parse configuration: {e}")))
55 }
56}
57
58impl TryInto<Configuration> for serde_yaml::Value {
59 type Error = Error;
60
61 fn try_into(self) -> std::result::Result<Configuration, Self::Error> {
62 serde_yaml::from_value::<serde_json::Value>(self)
63 .map_err(|e| Error::InvalidConfiguration(format!("Failed to parse yaml: {e}")))?
64 .try_into()
65 }
66}
67
68fn rewrite_json_ndarray_to_sparse(value: &mut serde_json::Value) {
70 match value {
71 serde_json::Value::Object(map) => {
72 if map.contains_key("weights")
73 && map.contains_key("feature_index_size")
74 && map.contains_key("model_index_size")
75 && map.contains_key("feature_state_size")
76 && map.contains_key("model_index_size_shift")
77 && map.contains_key("feature_state_size_shift")
78 {
79 let wts: DenseWeightsWithNDArray = serde_json::from_value(value.clone()).unwrap();
80 *value = serde_json::to_value(wts.to_dense_weights()).unwrap();
81 return;
82 }
83 for (_, v) in map {
84 rewrite_json_ndarray_to_sparse(v);
85 }
86 }
87 serde_json::Value::Array(vec) => {
88 for v in vec {
89 rewrite_json_ndarray_to_sparse(v);
90 }
91 }
92 _ => (),
93 }
94}
95
96fn rewrite_json_sparse_to_ndarray(value: &mut serde_json::Value) {
97 match value {
98 serde_json::Value::Object(map) => {
99 if map.contains_key("weights")
100 && map.contains_key("feature_index_size")
101 && map.contains_key("model_index_size")
102 && map.contains_key("feature_state_size")
103 && map.contains_key("model_index_size_shift")
104 && map.contains_key("feature_state_size_shift")
105 {
106 let wts: DenseWeights = serde_json::from_value(value.clone()).unwrap();
107 *value =
108 serde_json::to_value(DenseWeightsWithNDArray::from_dense_weights(wts)).unwrap();
109 return;
110 }
111 for (_, v) in map {
112 rewrite_json_sparse_to_ndarray(v);
113 }
114 }
115 serde_json::Value::Array(vec) => {
116 for v in vec {
117 rewrite_json_sparse_to_ndarray(v);
118 }
119 }
120 _ => (),
121 }
122}
123
124impl Workspace {
125 pub fn new(config: Configuration) -> Result<Workspace> {
126 let reduction_config = crate::reduction_factory::parse_config(&config.entry_reduction)?;
127 let entry_reduction = crate::reduction_factory::create_reduction(
128 reduction_config.as_ref(),
129 &config.global_config,
130 1.into(), )?;
132
133 Ok(Workspace {
134 global_config: config.global_config,
135 entry_reduction,
136 features_pool: Arc::new(Pool::new()),
137 })
138 }
139
140 pub fn create_from_model(json: &[u8]) -> Result<Workspace> {
142 let r = flexbuffers::Reader::get_root(json).unwrap();
143 Ok(Workspace::deserialize(r).unwrap())
144 }
145
146 pub fn serialize_model(&self) -> Result<Vec<u8>> {
148 let mut s = flexbuffers::FlexbufferSerializer::new();
149 self.serialize(&mut s).unwrap();
150 Ok(s.take_buffer())
151 }
152
153 pub fn serialize_to_json(&self) -> Result<Value> {
155 let mut value = serde_json::to_value(self).unwrap();
156 rewrite_json_sparse_to_ndarray(&mut value);
157 Ok(value)
158 }
159
160 pub fn deserialize_from_json(json: &Value) -> Result<Workspace> {
162 let mut value: serde_json::Value = serde_json::from_value(json.clone()).map_err(|e| {
163 Error::InvalidConfiguration(format!("Failed to parse configuration: {e}"))
164 })?;
165 rewrite_json_ndarray_to_sparse(&mut value);
166 serde_json::from_value(value)
167 .map_err(|e| Error::InvalidConfiguration(format!("Failed to parse model: {e}")))
168 }
169
170 pub fn predict(&self, features: &mut Features) -> Prediction {
171 let mut depth_info = DepthInfo::new();
172 self.entry_reduction
173 .predict(features, &mut depth_info, 0.into())
174 }
175
176 pub fn predict_then_learn(&mut self, features: &mut Features, label: &Label) -> Prediction {
177 let mut depth_info = DepthInfo::new();
178 self.entry_reduction
179 .predict_then_learn(features, label, &mut depth_info, 0.into())
180 }
181
182 pub fn learn(&mut self, features: &mut Features, label: &Label) {
183 let mut depth_info = DepthInfo::new();
184 self.entry_reduction
185 .learn(features, label, &mut depth_info, 0.into());
186 }
187
188 pub fn get_entry_reduction(&self) -> &ReductionWrapper {
189 &self.entry_reduction
190 }
191
192 pub fn global_config(&self) -> &GlobalConfig {
193 &self.global_config
194 }
195
196 pub fn features_pool(&self) -> &Arc<Pool<SparseFeatures>> {
197 &self.features_pool
198 }
199}
200
#[cfg(test)]
mod tests {

    use approx::assert_relative_eq;
    use serde_json::json;

    use crate::{sparse_namespaced_features::SparseFeatures, utils::AsInner, ScalarPrediction};

    use super::*;

    /// End-to-end smoke test: build a Coin workspace from JSON, check the
    /// untrained prediction is 0, train twice on label 0.5, and check the
    /// prediction reaches 0.5.
    #[test]
    fn test_create_workspace() {
        let config = json!({
            "globalConfig": { "numBits": 4 },
            "entryReduction": {
                "typename": "Coin",
                "config": { "alpha": 10 }
            }
        });
        let mut workspace = Workspace::new(config.try_into().unwrap()).unwrap();

        // Three unit-valued features in the default namespace.
        let mut sparse = SparseFeatures::new();
        let default_ns =
            sparse.get_or_create_namespace(crate::sparse_namespaced_features::Namespace::Default);
        for index in [0, 2, 3] {
            default_ns.add_feature(index.into(), 1.0);
        }
        let mut features = Features::SparseSimple(sparse);

        let initial_pred = workspace.predict(&mut features);
        let initial: &ScalarPrediction = initial_pred.as_inner().unwrap();
        assert_relative_eq!(initial.prediction, 0.0);

        let label = Label::Simple(0.5.into());
        for _ in 0..2 {
            workspace.learn(&mut features, &label);
        }

        let trained_pred = workspace.predict(&mut features);
        let trained: &ScalarPrediction = trained_pred.as_inner().unwrap();
        assert_relative_eq!(trained.prediction, 0.5);
    }
}