// reductionml_core/reduction_factory.rs
use std::{
    any::Any,
    fmt::{Display, Formatter},
};

use schemars::schema::RootSchema;
use serde::{Deserialize, Serialize};

use crate::{
    error::{Error, Result},
    global_config::GlobalConfig,
    reduction::ReductionWrapper,
    reduction_registry::REDUCTION_REGISTRY,
    ModelIndex,
};
16
/// A reduction configuration in serialized form: the reduction's type name
/// plus its type-specific settings as raw JSON.
///
/// `parse_config` uses `typename` to find the matching factory in
/// `REDUCTION_REGISTRY` and hands `config` to that factory unchanged.
#[derive(Clone, Serialize, Deserialize)]
// camelCase renaming is currently a no-op: both field names are single words.
#[serde(rename_all = "camelCase")]
pub struct JsonReductionConfig {
    /// Name of the reduction type; the key used for the registry lookup.
    typename: PascalCaseString,
    /// Raw, type-specific configuration forwarded to the factory.
    config: serde_json::Value,
}
25
26impl JsonReductionConfig {
27 pub fn new(typename: PascalCaseString, config: serde_json::Value) -> Self {
28 JsonReductionConfig { typename, config }
29 }
30}
31
32impl JsonReductionConfig {
57 pub fn typename(&self) -> String {
58 self.typename.as_ref().to_owned()
59 }
60 pub fn json_value(&self) -> &serde_json::Value {
61 &self.config
62 }
63}
64
65#[derive(Clone, Serialize, Deserialize, Debug)]
66pub struct PascalCaseString(String);
67
68impl TryFrom<String> for PascalCaseString {
69 type Error = Error;
70
71 fn try_from(value: String) -> std::result::Result<Self, Self::Error> {
72 if value.is_empty() {
73 return Err(Error::InvalidArgument(
74 "typename must not be empty".to_owned(),
75 ));
76 }
77 if !value.chars().next().unwrap().is_ascii_uppercase() {
78 return Err(Error::InvalidArgument(
79 "typename must start with an uppercase letter".to_owned(),
80 ));
81 }
82 if value.chars().any(|c| !c.is_ascii_alphanumeric()) {
83 return Err(Error::InvalidArgument(
84 "typename must only contain alphanumeric characters".to_owned(),
85 ));
86 }
87 Ok(PascalCaseString(value))
88 }
89}
90
91impl TryFrom<&str> for PascalCaseString {
92 type Error = Error;
93
94 fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
95 if value.is_empty() {
96 return Err(Error::InvalidArgument(
97 "typename must not be empty".to_owned(),
98 ));
99 }
100 if !value.chars().next().unwrap().is_ascii_uppercase() {
101 return Err(Error::InvalidArgument(
102 "typename must start with an uppercase letter".to_owned(),
103 ));
104 }
105 if value.chars().any(|c| !c.is_ascii_alphanumeric()) {
106 return Err(Error::InvalidArgument(
107 "typename must only contain alphanumeric characters".to_owned(),
108 ));
109 }
110 Ok(PascalCaseString(value.to_owned()))
111 }
112}
113
114impl From<PascalCaseString> for String {
115 fn from(value: PascalCaseString) -> Self {
116 value.0
117 }
118}
119
120impl AsRef<str> for PascalCaseString {
121 fn as_ref(&self) -> &str {
122 &self.0
123 }
124}
125
126impl Display for PascalCaseString {
127 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
128 write!(f, "{}", self.0.clone())
129 }
130}
131
/// A parsed, type-specific reduction configuration.
///
/// Values of this trait are produced by `ReductionFactory::parse_config` and
/// consumed by `ReductionFactory::create`.
pub trait ReductionConfig: Any {
    /// Returns `self` as `&dyn Any`.
    ///
    /// NOTE(review): presumably used by factories to `downcast_ref` to their
    /// concrete config type — confirm against implementors.
    fn as_any(&self) -> &dyn Any;
    /// Type name of the reduction this config belongs to; `create_reduction`
    /// uses it to look up the factory in `REDUCTION_REGISTRY`.
    fn typename(&self) -> PascalCaseString;
}
136
/// Factory for a single reduction type.
///
/// Factories are looked up in `REDUCTION_REGISTRY` by type name: the free
/// functions `parse_config` and `create_reduction` use them to turn raw JSON
/// into a typed `ReductionConfig`, and then a config into a `ReductionWrapper`.
pub trait ReductionFactory {
    /// Parses this reduction type's raw JSON configuration into a boxed,
    /// typed config.
    ///
    /// # Errors
    ///
    /// Returns an error if `value` cannot be parsed into this factory's config.
    fn parse_config(&self, value: &serde_json::Value) -> Result<Box<dyn ReductionConfig>>;
    /// Builds a reduction from `config`.
    ///
    /// `create_reduction` selects the factory by `config.typename()`, so
    /// `config` is normally one this factory's `parse_config` produced.
    /// NOTE(review): `num_models_above` is passed straight through from
    /// `create_reduction`; presumably the number of models stacked above this
    /// reduction — confirm.
    fn create(
        &self,
        config: &dyn ReductionConfig,
        global_config: &GlobalConfig,
        num_models_above: ModelIndex,
    ) -> Result<ReductionWrapper>;
    /// The PascalCase type name this factory handles.
    fn typename(&self) -> PascalCaseString;
    /// JSON schema describing this reduction's configuration.
    fn get_config_schema(&self) -> RootSchema;
    /// This reduction's default configuration, as JSON.
    fn get_config_default(&self) -> serde_json::Value;
    /// Names of metrics suggested for this reduction; none by default.
    fn get_suggested_metrics(&self) -> Vec<String> {
        vec![]
    }
}
152
/// Expands to the boilerplate `ReductionFactory` methods for a factory whose
/// configuration type is `$config_type`.
///
/// Generates `typename`, `parse_config`, `get_config_schema` and
/// `get_config_default`. It is presumably invoked inside an
/// `impl ReductionFactory for ...` block, which must still provide `create`.
///
/// The expansion uses `PascalCaseString`, `Result`, `ReductionConfig`,
/// `RootSchema`, `schema_for!` and `serde_json` by their unqualified names, so
/// they must be in scope at the call site.
///
/// `$config_type` may be any type, not only a bare identifier. It must:
/// - implement `Deserialize`, `Serialize`, `JsonSchema` and `ReductionConfig`;
/// - provide `default()`.
///
/// # Panics
///
/// - `typename` panics if `$typename` is not a valid PascalCase name.
/// - `get_config_default` panics if the default config cannot be serialized
///   to JSON.
///
/// Both indicate a bug in the factory definition rather than bad input.
#[macro_export]
macro_rules! impl_default_factory_functions {
    ($typename: expr, $config_type: ty) => {
        fn typename(&self) -> PascalCaseString {
            $typename
                .try_into()
                .expect("factory typename must be a valid PascalCase string")
        }

        fn parse_config(&self, value: &serde_json::Value) -> Result<Box<dyn ReductionConfig>> {
            let res: $config_type = serde_json::from_value(value.clone())?;
            Ok(Box::new(res))
        }

        fn get_config_schema(&self) -> RootSchema {
            schema_for!($config_type)
        }

        fn get_config_default(&self) -> serde_json::Value {
            // `<T>::default()` resolves exactly like `T::default()` did, but
            // also works when `$config_type` is a generic or path type.
            serde_json::to_value(<$config_type>::default())
                .expect("default config must be serializable to JSON")
        }
    };
}
174
175pub fn parse_config(config: &JsonReductionConfig) -> Result<Box<dyn ReductionConfig>> {
176 match REDUCTION_REGISTRY
177 .read()
178 .unwrap()
179 .get(config.typename.as_ref())
180 {
181 Some(factory) => factory.parse_config(config.json_value()),
182 None => Err(crate::error::Error::InvalidArgument(format!(
183 "Unknown reduction type: {}",
184 &config.typename
185 ))),
186 }
187}
188
189pub fn create_reduction(
190 config: &dyn ReductionConfig,
191 global_config: &GlobalConfig,
192 num_models_above: ModelIndex,
193) -> Result<ReductionWrapper> {
194 match REDUCTION_REGISTRY
195 .read()
196 .unwrap()
197 .get(config.typename().as_ref())
198 {
199 Some(factory) => factory.create(config, global_config, num_models_above),
200 None => Err(crate::error::Error::InvalidArgument(format!(
201 "Unknown reduction type: {}",
202 config.typename()
203 ))),
204 }
205}