use std::collections::{BTreeMap, HashSet};
use apollo_infra_utils::dumping::serialize_to_file;
use itertools::chain;
use serde::Serialize;
use serde_json::{json, Value};
use crate::{
ConfigError,
ParamPath,
ParamPrivacy,
ParamPrivacyInput,
SerializationType,
SerializedContent,
SerializedParam,
FIELD_SEPARATOR,
IS_NONE_MARK,
};
/// A pointer target: the path of the target parameter paired with its serialized form.
type PointerTarget = (ParamPath, SerializedParam);
/// A set of parameter paths. Used both for the params that point to a target and for the
/// whitelist of params that are deliberately not pointers.
pub type Pointers = HashSet<ParamPath>;
/// A list of pointer targets, each paired with the set of parameter paths pointing to it.
pub type ConfigPointers = Vec<(PointerTarget, Pointers)>;
/// Builds the pointer configuration for every parameter dumped by `default_instance`.
///
/// For each `(param_path, serialized_param)` entry of `default_instance.dump()`:
/// - the pointer target is created by `serialized_param_to_pointer_target`, with its path
///   being `target_prefix` chained with `param_path`;
/// - the pointing params are every prefix in `pointer_prefixes` chained with `param_path`.
///
/// # Panics
/// Panics whenever `serialized_param_to_pointer_target` panics for one of the dumped params
/// (e.g. when a dumped param is itself a pointer).
pub fn generate_struct_pointer<T: SerializeConfig>(
    target_prefix: ParamPath,
    default_instance: &T,
    pointer_prefixes: HashSet<ParamPath>,
) -> ConfigPointers {
    default_instance
        .dump()
        .into_iter()
        .map(|(param_path, serialized_param)| {
            let pointer_target = serialized_param_to_pointer_target(
                target_prefix.clone(),
                &param_path,
                &serialized_param,
            );
            // Every prefix contributes one pointing param for this field.
            let pointers: HashSet<ParamPath> = pointer_prefixes
                .iter()
                .map(|pointer| chain_param_paths(&[pointer, &param_path]))
                .collect();
            (pointer_target, pointers)
        })
        .collect()
}
/// Builds the pointer configuration for an optional config struct.
///
/// The result contains everything `generate_struct_pointer` produces for the given instance
/// (or for `T::default()` when `default_instance` is `None`), followed by one extra entry for
/// the "is none" flag: its target is `ser_is_param_none(target_prefix, ..)` and its pointing
/// params are each prefix in `pointer_prefixes` with the is-none mark appended.
pub fn generate_optional_struct_pointer<T: SerializeConfig + Default>(
    target_prefix: ParamPath,
    default_instance: Option<&T>,
    pointer_prefixes: HashSet<ParamPath>,
) -> ConfigPointers {
    // Holds the fallback instance; only initialized when no instance was supplied.
    let fallback_instance;
    let instance_to_dump = match default_instance {
        Some(instance) => instance,
        None => {
            fallback_instance = T::default();
            &fallback_instance
        }
    };

    // The flag entry only needs borrowed data, so it is prepared before the prefix and the
    // prefix set are handed over by value below.
    let is_none_target = ser_is_param_none(target_prefix.as_str(), default_instance.is_none());
    let is_none_pointers: Pointers = pointer_prefixes
        .iter()
        .map(|prefix| format!("{prefix}{FIELD_SEPARATOR}{IS_NONE_MARK}"))
        .collect();

    let mut config_pointers =
        generate_struct_pointer(target_prefix, instance_to_dump, pointer_prefixes);
    config_pointers.push((is_none_target, is_none_pointers));
    config_pointers
}
/// Converts a dumped parameter into a pointer target located under `target_prefix`.
///
/// A required param becomes a required pointer target; its description is passed on without
/// the required-param prefix (the prefix is re-added by `ser_pointer_target_required_param`).
/// Any other param must carry a default value, which becomes the target's value.
///
/// # Panics
/// Panics if a non-required param is a pointer or holds only a param type, and if a required
/// param has no serialization type.
fn serialized_param_to_pointer_target(
    target_prefix: ParamPath,
    param_path: &ParamPath,
    serialized_param: &SerializedParam,
) -> PointerTarget {
    let full_param_path = chain_param_paths(&[&target_prefix, param_path]);

    if !serialized_param.is_required() {
        return match &serialized_param.content {
            SerializedContent::DefaultValue(default_value) => ser_pointer_target_param(
                &full_param_path,
                default_value,
                &serialized_param.description,
            ),
            SerializedContent::PointerTarget(_) => panic!("Pointers to pointer is not supported."),
            SerializedContent::ParamType(_) => {
                panic!("Generated pointer targets are not supported.")
            }
        };
    }

    // Drop the required-param prefix (if present) and the whitespace that follows it.
    let full_description = serialized_param.description.as_str();
    let description = full_description
        .strip_prefix(REQUIRED_PARAM_DESCRIPTION_PREFIX)
        .unwrap_or(full_description)
        .trim_start();
    let serialization_type = serialized_param.content.get_serialization_type().unwrap();
    ser_pointer_target_required_param(&full_param_path, serialization_type, description)
}
/// Joins the given path segments into a single parameter path, placing `FIELD_SEPARATOR`
/// between consecutive segments. An empty slice yields an empty path.
fn chain_param_paths(param_paths: &[&str]) -> ParamPath {
    let separator: &str = FIELD_SEPARATOR;
    param_paths.join(separator)
}
/// Serialization of a config into a flat map of parameters.
pub trait SerializeConfig {
    /// Returns the config as a map from parameter path to its serialized form.
    fn dump(&self) -> BTreeMap<ParamPath, SerializedParam>;

    /// Combines the dumped config with `config_pointers` and passes the result, together with
    /// `file_path`, to `serialize_to_file`.
    ///
    /// # Errors
    /// Returns the error produced by `combine_config_map_and_pointers`; in that case
    /// `serialize_to_file` is not called.
    fn dump_to_file(
        &self,
        config_pointers: &ConfigPointers,
        non_pointer_params: &Pointers,
        file_path: &str,
    ) -> Result<(), ConfigError> {
        combine_config_map_and_pointers(self.dump(), config_pointers, non_pointer_params).map(
            |combined_map| {
                serialize_to_file(&combined_map, file_path);
            },
        )
    }
}
/// Prefixes every key of a sub-config dump with `sub_config_name` and `FIELD_SEPARATOR`,
/// leaving the serialized values untouched.
pub fn prepend_sub_config_name(
    sub_config_dump: BTreeMap<ParamPath, SerializedParam>,
    sub_config_name: &str,
) -> BTreeMap<ParamPath, SerializedParam> {
    sub_config_dump
        .into_iter()
        .map(|(field_name, serialized_param)| {
            let prefixed_name = format!("{sub_config_name}{FIELD_SEPARATOR}{field_name}");
            (prefixed_name, serialized_param)
        })
        .collect()
}
/// Shared constructor behind the `ser_*` helpers: pairs an owned copy of `name` with a
/// `SerializedParam` built from the given description, content and privacy.
fn common_ser_param(
    name: &str,
    content: SerializedContent,
    description: &str,
    privacy: ParamPrivacy,
) -> (String, SerializedParam) {
    let serialized_param =
        SerializedParam { description: description.to_owned(), content, privacy };
    (name.to_owned(), serialized_param)
}
/// Serializes a single param of a config.
///
/// The returned `(path, param)` pair stores `value`, converted to JSON, as the param's
/// default value.
pub fn ser_param<T: Serialize>(
    name: &str,
    value: &T,
    description: &str,
    privacy: ParamPrivacyInput,
) -> (String, SerializedParam) {
    let content = SerializedContent::DefaultValue(json!(value));
    common_ser_param(name, content, description, privacy.into())
}
/// Serializes a required param of a config.
///
/// The param carries only its serialization type (no default value), and its description is
/// wrapped by `required_param_description`.
pub fn ser_required_param(
    name: &str,
    serialization_type: SerializationType,
    description: &str,
    privacy: ParamPrivacyInput,
) -> (String, SerializedParam) {
    let required_description = required_param_description(description);
    let content = SerializedContent::ParamType(serialization_type);
    common_ser_param(name, content, &required_description, privacy.into())
}
/// Serializes a param whose value may be generated when none is provided.
///
/// The param carries only its serialization type, and a note about value generation is
/// appended to `description`.
pub fn ser_generated_param(
    name: &str,
    serialization_type: SerializationType,
    description: &str,
    privacy: ParamPrivacyInput,
) -> (String, SerializedParam) {
    let generated_description =
        format!("{description} If no value is provided, the system will generate one.");
    let content = SerializedContent::ParamType(serialization_type);
    common_ser_param(name, content, &generated_description, privacy.into())
}
pub fn ser_optional_sub_config<T: SerializeConfig + Default>(
optional_config: &Option<T>,
name: &str,
) -> BTreeMap<ParamPath, SerializedParam> {
chain!(
BTreeMap::from_iter([ser_is_param_none(name, optional_config.is_none())]),
prepend_sub_config_name(
match optional_config {
None => T::default().dump(),
Some(config) => config.dump(),
},
name,
),
)
.collect()
}
/// Serializes an optional param.
///
/// The result holds the "is none" flag for `name` and the param itself, whose value is the
/// contained value when present and `default_value` otherwise.
pub fn ser_optional_param<T: Serialize>(
    optional_param: &Option<T>,
    default_value: T,
    name: &str,
    description: &str,
    privacy: ParamPrivacyInput,
) -> BTreeMap<ParamPath, SerializedParam> {
    let is_none_flag = ser_is_param_none(name, optional_param.is_none());
    let effective_value = optional_param.as_ref().unwrap_or(&default_value);
    let param = ser_param(name, effective_value, description, privacy);
    BTreeMap::from([is_none_flag, param])
}
/// Serializes the "is none" flag of an optional field.
///
/// The flag lives at `name` followed by `FIELD_SEPARATOR` and `IS_NONE_MARK`, holds `is_none`
/// as its default value, and uses `ParamPrivacy::TemporaryValue`.
pub fn ser_is_param_none(name: &str, is_none: bool) -> (String, SerializedParam) {
    let flag_path = format!("{name}{FIELD_SEPARATOR}{IS_NONE_MARK}");
    let content = SerializedContent::DefaultValue(json!(is_none));
    common_ser_param(&flag_path, content, "Flag for an optional field.", ParamPrivacy::TemporaryValue)
}
/// Serializes a pointer target that has a default value.
///
/// `value` is converted to JSON and stored as the default value; the param uses
/// `ParamPrivacy::TemporaryValue`.
pub fn ser_pointer_target_param<T: Serialize>(
    name: &str,
    value: &T,
    description: &str,
) -> (String, SerializedParam) {
    let content = SerializedContent::DefaultValue(json!(value));
    common_ser_param(name, content, description, ParamPrivacy::TemporaryValue)
}
/// Serializes a required pointer target.
///
/// The param carries only its serialization type, its description is wrapped by
/// `required_param_description`, and it uses `ParamPrivacy::TemporaryValue`.
pub fn ser_pointer_target_required_param(
    name: &str,
    serialization_type: SerializationType,
    description: &str,
) -> (String, SerializedParam) {
    let required_description = required_param_description(description);
    let content = SerializedContent::ParamType(serialization_type);
    common_ser_param(name, content, &required_description, ParamPrivacy::TemporaryValue)
}
/// Merges pointer targets into a dumped config map and returns the result as JSON.
///
/// For every pointer entry, the target param is inserted into the map (replacing any param
/// already stored at that path), and each pointing param has its content replaced by a
/// pointer to the target while keeping its own description and privacy.
///
/// # Errors
/// Returns `ConfigError::PointerSourceNotFound` if a pointing param is missing from the map.
///
/// # Panics
/// Panics if `verify_pointing_params_by_name` finds a param whose path ends with a target's
/// name but that neither points to the target nor appears in `non_pointer_params`.
pub fn combine_config_map_and_pointers(
    mut config_map: BTreeMap<ParamPath, SerializedParam>,
    pointers: &ConfigPointers,
    non_pointer_params: &Pointers,
) -> Result<Value, ConfigError> {
    for ((target_param, serialized_pointer), pointing_params) in pointers {
        config_map.insert(target_param.clone(), serialized_pointer.clone());
        for pointing_param in pointing_params {
            // Build the error lazily so the success path does not allocate, and update the
            // entry in place instead of cloning its description and privacy.
            let pointing_serialized_param =
                config_map.get_mut(pointing_param).ok_or_else(|| {
                    ConfigError::PointerSourceNotFound { pointing_param: pointing_param.to_owned() }
                })?;
            pointing_serialized_param.content =
                SerializedContent::PointerTarget(target_param.to_owned());
        }
    }
    verify_pointing_params_by_name(&config_map, pointers, non_pointer_params);
    Ok(json!(config_map))
}
pub fn set_pointing_param_paths(param_path_list: &[&str]) -> Pointers {
let mut param_paths = HashSet::new();
for ¶m_path in param_path_list {
assert!(
param_paths.insert(param_path.to_string()),
"Duplicate parameter path found: {param_path}"
);
}
param_paths
}
/// Prefix that marks the description of a required param.
pub(crate) const REQUIRED_PARAM_DESCRIPTION_PREFIX: &str = "A required param!";

/// Returns `description` preceded by `REQUIRED_PARAM_DESCRIPTION_PREFIX` and a single space.
pub(crate) fn required_param_description(description: &str) -> String {
    [REQUIRED_PARAM_DESCRIPTION_PREFIX, description].join(" ")
}
/// Checks that every param named like a pointer target actually points to that target.
///
/// A param is considered "named like" a target when its path ends with `FIELD_SEPARATOR`
/// followed by the target's path. Params listed in `non_pointer_params` are exempt.
///
/// # Panics
/// Panics on the first matching, non-exempt param whose content is not a pointer to the
/// target.
fn verify_pointing_params_by_name(
    config_map: &BTreeMap<ParamPath, SerializedParam>,
    pointers: &ConfigPointers,
    non_pointer_params: &Pointers,
) {
    for (param_path, serialized_param) in config_map {
        for ((target_param, _), _) in pointers {
            let target_suffix = format!("{FIELD_SEPARATOR}{target_param}");
            if !param_path.ends_with(target_suffix.as_str()) {
                continue;
            }
            // Whitelisted params may share the target's name without pointing to it.
            if non_pointer_params.contains(param_path) {
                continue;
            }
            let expected_content = SerializedContent::PointerTarget(target_param.to_owned());
            assert!(
                serialized_param.content == expected_content,
                "The target param {param_path} should point to {target_param}, or to be \
                 whitelisted. You can use set_pointing_param_paths to point it to a value."
            );
        }
    }
}