reductionml_core/
reduction_factory.rs

1use std::{
2    any::Any,
3    fmt::{Display, Formatter},
4};
5
6use schemars::schema::RootSchema;
7use serde::{Deserialize, Serialize};
8
9use crate::{
10    error::{Error, Result},
11    global_config::GlobalConfig,
12    reduction::ReductionWrapper,
13    reduction_registry::REDUCTION_REGISTRY,
14    ModelIndex,
15};
16
17// This intentionally does not derive JsonSchema
18// Use gen_json_reduction_config_schema instead with schema_with
19#[derive(Clone, Serialize, Deserialize)]
20#[serde(rename_all = "camelCase")]
21pub struct JsonReductionConfig {
22    typename: PascalCaseString,
23    config: serde_json::Value,
24}
25
26impl JsonReductionConfig {
27    pub fn new(typename: PascalCaseString, config: serde_json::Value) -> Self {
28        JsonReductionConfig { typename, config }
29    }
30}
31
32// impl TryFrom<serde_json::Value> for JsonReductionConfig {
33//     type Error = Error;
34
35//     fn try_from(value: serde_json::Value) -> std::result::Result<Self, Self::Error> {
36//         let typename = value["typename"]
37//             .as_str()
38//             .ok_or(Error::InvalidArgument(
39//                 "typename must be a string".to_owned(),
40//             ))?
41//             .to_string();
42//         let config = value["config"].clone();
43//         Ok(JsonReductionConfig { typename, config })
44//     }
45// }
46
47// impl Into<serde_json::Value> for JsonReductionConfig {
48//     fn into(self) -> serde_json::Value {
49//         json!({
50//             "typename": self.typename,
51//             "config": self.config
52//         })
53//     }
54// }
55
56impl JsonReductionConfig {
57    pub fn typename(&self) -> String {
58        self.typename.as_ref().to_owned()
59    }
60    pub fn json_value(&self) -> &serde_json::Value {
61        &self.config
62    }
63}
64
65#[derive(Clone, Serialize, Deserialize, Debug)]
66pub struct PascalCaseString(String);
67
68impl TryFrom<String> for PascalCaseString {
69    type Error = Error;
70
71    fn try_from(value: String) -> std::result::Result<Self, Self::Error> {
72        if value.is_empty() {
73            return Err(Error::InvalidArgument(
74                "typename must not be empty".to_owned(),
75            ));
76        }
77        if !value.chars().next().unwrap().is_ascii_uppercase() {
78            return Err(Error::InvalidArgument(
79                "typename must start with an uppercase letter".to_owned(),
80            ));
81        }
82        if value.chars().any(|c| !c.is_ascii_alphanumeric()) {
83            return Err(Error::InvalidArgument(
84                "typename must only contain alphanumeric characters".to_owned(),
85            ));
86        }
87        Ok(PascalCaseString(value))
88    }
89}
90
91impl TryFrom<&str> for PascalCaseString {
92    type Error = Error;
93
94    fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
95        if value.is_empty() {
96            return Err(Error::InvalidArgument(
97                "typename must not be empty".to_owned(),
98            ));
99        }
100        if !value.chars().next().unwrap().is_ascii_uppercase() {
101            return Err(Error::InvalidArgument(
102                "typename must start with an uppercase letter".to_owned(),
103            ));
104        }
105        if value.chars().any(|c| !c.is_ascii_alphanumeric()) {
106            return Err(Error::InvalidArgument(
107                "typename must only contain alphanumeric characters".to_owned(),
108            ));
109        }
110        Ok(PascalCaseString(value.to_owned()))
111    }
112}
113
114impl From<PascalCaseString> for String {
115    fn from(value: PascalCaseString) -> Self {
116        value.0
117    }
118}
119
120impl AsRef<str> for PascalCaseString {
121    fn as_ref(&self) -> &str {
122        &self.0
123    }
124}
125
126impl Display for PascalCaseString {
127    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
128        write!(f, "{}", self.0.clone())
129    }
130}
131
132pub trait ReductionConfig: Any {
133    fn as_any(&self) -> &dyn Any;
134    fn typename(&self) -> PascalCaseString;
135}
136
137pub trait ReductionFactory {
138    fn parse_config(&self, value: &serde_json::Value) -> Result<Box<dyn ReductionConfig>>;
139    fn create(
140        &self,
141        config: &dyn ReductionConfig,
142        global_config: &GlobalConfig,
143        num_models_above: ModelIndex,
144    ) -> Result<ReductionWrapper>;
145    fn typename(&self) -> PascalCaseString;
146    fn get_config_schema(&self) -> RootSchema;
147    fn get_config_default(&self) -> serde_json::Value;
148    fn get_suggested_metrics(&self) -> Vec<String> {
149        vec![]
150    }
151}
152
/// Expands to the boilerplate `ReductionFactory` methods `typename`,
/// `parse_config`, `get_config_schema` and `get_config_default`.
///
/// * `$typename` — expression converted into a `PascalCaseString` via
///   `try_into()`; the generated `typename` panics if the conversion fails.
/// * `$config_type` — the configuration type. It is deserialized with
///   `serde_json::from_value`, boxed as `dyn ReductionConfig`, passed to
///   `schema_for!`, and its `default()` is serialized with
///   `serde_json::to_value`.
///
/// The expansion refers to `PascalCaseString`, `Result`, `ReductionConfig`,
/// `RootSchema`, `schema_for!` and `serde_json` by their bare names, so they
/// must be in scope where the macro is invoked.
#[macro_export]
macro_rules! impl_default_factory_functions {
    ($typename: expr, $config_type: ident) => {
        fn typename(&self) -> PascalCaseString {
            // The name is supplied by the implementor, so a failure here is a
            // programming error rather than a runtime condition.
            $typename
                .try_into()
                .expect("factory typename must be a valid PascalCaseString")
        }

        fn parse_config(&self, value: &serde_json::Value) -> Result<Box<dyn ReductionConfig>> {
            let res: $config_type = serde_json::from_value(value.clone())?;
            Ok(Box::new(res))
        }

        fn get_config_schema(&self) -> RootSchema {
            schema_for!($config_type)
        }

        fn get_config_default(&self) -> serde_json::Value {
            serde_json::to_value($config_type::default())
                .expect("default reduction config must be serializable to JSON")
        }
    };
}
174
175pub fn parse_config(config: &JsonReductionConfig) -> Result<Box<dyn ReductionConfig>> {
176    match REDUCTION_REGISTRY
177        .read()
178        .unwrap()
179        .get(config.typename.as_ref())
180    {
181        Some(factory) => factory.parse_config(config.json_value()),
182        None => Err(crate::error::Error::InvalidArgument(format!(
183            "Unknown reduction type: {}",
184            &config.typename
185        ))),
186    }
187}
188
189pub fn create_reduction(
190    config: &dyn ReductionConfig,
191    global_config: &GlobalConfig,
192    num_models_above: ModelIndex,
193) -> Result<ReductionWrapper> {
194    match REDUCTION_REGISTRY
195        .read()
196        .unwrap()
197        .get(config.typename().as_ref())
198    {
199        Some(factory) => factory.create(config, global_config, num_models_above),
200        None => Err(crate::error::Error::InvalidArgument(format!(
201            "Unknown reduction type: {}",
202            config.typename()
203        ))),
204    }
205}