use std::any::Any;
use std::cell::RefCell;
use std::error::Error;
use std::fmt::Debug;
use std::fs;
use std::io::BufReader;
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, LazyLock, Mutex};
use serde::de::DeserializeOwned;
use crate::context::Context;
use crate::error::IxaError;
use crate::{trace, ContextBase, HashMap, HashMapExt};
/// Type-erased setter stored in [`GLOBAL_PROPERTIES`].
///
/// Called as `setter(context, name, value)`, where `name` is the key as it was found in
/// the configuration and `value` is the raw JSON value to deserialize and store on `context`.
type PropertySetterFn =
dyn Fn(&mut Context, &str, serde_json::Value) -> Result<(), IxaError> + Send + Sync;
/// Process-wide registry mapping a global property's registered name to its setter.
///
/// Entries are added by `add_global_property` and looked up when
/// `load_global_properties` applies a JSON configuration file.
// The nested Mutex<RefCell<HashMap<..>>> type trips clippy::type_complexity.
#[allow(clippy::type_complexity)]
#[doc(hidden)]
pub static GLOBAL_PROPERTIES: LazyLock<Mutex<RefCell<HashMap<String, Arc<PropertySetterFn>>>>> =
LazyLock::new(|| Mutex::new(RefCell::new(HashMap::new())));
/// The id that will be handed to the next global property that initializes its id.
static NEXT_GLOBAL_PROPERTY_ID: Mutex<usize> = Mutex::new(0);

/// Returns how many global property ids have been assigned so far.
///
/// # Panics
///
/// Panics if the id counter mutex has been poisoned.
pub fn get_global_property_count() -> usize {
    let next_id = NEXT_GLOBAL_PROPERTY_ID.lock().unwrap();
    *next_id
}

/// Assigns a fresh id to `global_property_id` if it still holds the `usize::MAX` sentinel,
/// and returns the id it ends up holding.
///
/// If the atomic already holds an id (any value other than `usize::MAX`), that id is returned
/// and the global counter is left unchanged.
///
/// # Panics
///
/// Panics if the id counter mutex has been poisoned.
pub fn initialize_global_property_id(global_property_id: &AtomicUsize) -> usize {
    // The counter lock is held across the compare-exchange, so the counter is only
    // advanced when this call actually claimed the proposed id.
    let mut next_id = NEXT_GLOBAL_PROPERTY_ID.lock().unwrap();
    let proposed = *next_id;
    if let Err(already_assigned) = global_property_id.compare_exchange(
        usize::MAX,
        proposed,
        Ordering::AcqRel,
        Ordering::Acquire,
    ) {
        return already_assigned;
    }
    *next_id += 1;
    proposed
}
#[allow(clippy::missing_panics_doc)]
pub fn add_global_property<T: GlobalProperty>(name: &str)
where
for<'de> <T as GlobalProperty>::Value: serde::Deserialize<'de>,
{
trace!("Adding global property {name}");
let properties = GLOBAL_PROPERTIES.lock().unwrap();
properties
.borrow_mut()
.insert(
name.to_string(),
Arc::new(
|context: &mut Context, name, value| -> Result<(), IxaError> {
let val: T::Value = serde_json::from_value(value)?;
if context.get_global_property_value(T::new()).is_some() {
return Err(IxaError::DuplicateProperty {
name: name.to_string(),
});
}
context.set_global_property_value(T::new(), val)?;
Ok(())
},
),
)
.inspect(|_| panic!("Duplicate global property {}", name));
}
/// Looks up the setter registered under exactly `name`.
///
/// Returns a cloned `Arc` handle; the registry lock is released when this function
/// returns, before the caller invokes the setter.
fn get_global_property_setter(name: &str) -> Option<Arc<PropertySetterFn>> {
    let registry = GLOBAL_PROPERTIES.lock().unwrap();
    let setters = registry.borrow();
    setters.get(name).cloned()
}
/// Resolves a configuration key to a registered setter.
///
/// Tries the key verbatim first. If that fails and the key contains `-`, it retries
/// with every `-` replaced by `_`.
fn get_global_property_setter_for_config_key(name: &str) -> Option<Arc<PropertySetterFn>> {
    if let Some(setter) = get_global_property_setter(name) {
        return Some(setter);
    }
    // Only dashed keys get a second, underscore-normalized lookup.
    if !name.contains('-') {
        return None;
    }
    let normalized = name.replace('-', "_");
    get_global_property_setter(&normalized)
}
/// A typed global property whose value is stored on a context.
///
/// Implementations are expected to be generated by the `define_global_property!` macro.
pub trait GlobalProperty: Any {
    /// Type of the value stored for this property.
    type Value: Any;

    /// Index of this property's storage slot in the context.
    fn id() -> usize;

    /// Constructs an instance of the property type, used as the handle passed to the
    /// context accessors.
    fn new() -> Self;

    /// Display name of the property: the last `::`-separated segment of the type name.
    ///
    /// NOTE(review): for generic types the segment may include generic arguments
    /// (e.g. `Bar>`); properties are assumed to be non-generic.
    fn name() -> &'static str {
        let qualified = std::any::type_name::<Self>();
        match qualified.rsplit_once("::") {
            Some((_, last)) => last,
            None => qualified,
        }
    }

    /// Checks whether `value` is acceptable for this property.
    ///
    /// # Errors
    ///
    /// Returns the reason the value is rejected. When setting a property, this error becomes
    /// the `source` of `IxaError::IllegalGlobalPropertyValue`.
    fn validate(value: &Self::Value) -> Result<(), Box<dyn Error + Send + Sync + 'static>>;
}
/// Context extension for storing, reading, and loading global properties.
pub trait ContextGlobalPropertiesExt: ContextBase {
    /// Validates `value` and stores it as the value of `property`.
    ///
    /// # Errors
    ///
    /// Returns an error if validation fails or the property already has a value.
    fn set_global_property_value<T: GlobalProperty + 'static>(
        &mut self,
        property: T,
        value: T::Value,
    ) -> Result<(), IxaError>;

    /// Returns the stored value of the property, or `None` if it has not been set.
    fn get_global_property_value<T: GlobalProperty + 'static>(
        &self,
        _property: T,
    ) -> Option<&T::Value>;

    /// Deserializes a `T` from the JSON file at `file_name`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or its contents do not deserialize
    /// into `T`.
    fn load_parameters_from_json<T: 'static + Debug + DeserializeOwned>(
        &mut self,
        file_name: &Path,
    ) -> Result<T, IxaError> {
        trace!("Loading parameters from JSON: {file_name:?}");
        let reader = fs::File::open(file_name).map(BufReader::new)?;
        Ok(serde_json::from_reader(reader)?)
    }

    /// Reads a JSON object from `file_name` and sets one global property per top-level key.
    ///
    /// # Errors
    ///
    /// Returns an error for I/O or JSON failures, unknown keys, or rejected values.
    fn load_global_properties(&mut self, file_name: &Path) -> Result<(), IxaError>;
}
impl ContextGlobalPropertiesExt for Context {
    /// Validates `value`, then stores it in the property's slot if that slot is still empty.
    ///
    /// # Errors
    ///
    /// `IxaError::IllegalGlobalPropertyValue` if `T::validate` rejects the value;
    /// `IxaError::EntryAlreadyExists` if the property already has a value.
    ///
    /// # Panics
    ///
    /// Panics if `T::id()` does not index an existing slot.
    fn set_global_property_value<T: GlobalProperty + 'static>(
        &mut self,
        _property: T,
        value: T::Value,
    ) -> Result<(), IxaError> {
        // Validation runs first, so an invalid value is reported even if the slot is filled.
        if let Err(source) = T::validate(&value) {
            return Err(IxaError::IllegalGlobalPropertyValue {
                name: T::name().to_string(),
                source,
            });
        }

        let index = T::id();
        let Some(slot) = self.global_properties.get_mut(index) else {
            panic!(
                "No global property found with index = {index:?}. You must use the \
                 `define_global_property!` macro to create a global property."
            );
        };

        // Global properties are write-once: an existing value is never replaced.
        if slot.get().is_some() {
            return Err(IxaError::EntryAlreadyExists);
        }
        let _ = slot.set(Box::new(value));
        Ok(())
    }

    /// Returns the value stored for `T`, or `None` if it has not been set.
    ///
    /// # Panics
    ///
    /// Panics if `T::id()` does not index an existing slot, or if the stored value is not a
    /// `T::Value`.
    fn get_global_property_value<T: GlobalProperty + 'static>(
        &self,
        _property: T,
    ) -> Option<&T::Value> {
        let index = T::id();
        let Some(slot) = self.global_properties.get(index) else {
            panic!(
                "No global property found with index = {index:?}. You must use the \
                 `define_global_property!` macro to create a global property."
            );
        };

        // An unset property is reported as `None`, not as an error.
        let stored = slot.get()?;
        let value = stored.downcast_ref::<T::Value>().expect(
            "TypeID does not match global property type. You must use the \
             `define_global_property!` macro to create a global property.",
        );
        Some(value)
    }

    /// Applies every top-level entry of the JSON object in `file_name` as a global property.
    ///
    /// Entries are applied one at a time. If a later entry fails, properties set by earlier
    /// entries remain set.
    ///
    /// # Errors
    ///
    /// I/O and JSON errors. `IxaError::NoGlobalProperty` for a key with no registered setter.
    /// Any error returned by a property's setter.
    fn load_global_properties(&mut self, file_name: &Path) -> Result<(), IxaError> {
        trace!("Loading global properties from {file_name:?}");
        let reader = BufReader::new(fs::File::open(file_name)?);
        let entries: serde_json::Map<String, serde_json::Value> =
            serde_json::from_reader(reader)?;

        for (key, value) in entries {
            let Some(setter) = get_global_property_setter_for_config_key(&key) else {
                return Err(IxaError::NoGlobalProperty { name: key });
            };
            // The registry lock is no longer held here, so the setter may freely use it.
            setter(self, &key, value)?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod test {
    use std::error::Error;
    use std::fmt;
    use std::path::PathBuf;

    use serde::{Deserialize, Serialize};
    use tempfile::tempdir;

    use super::*;
    use crate::context::Context;
    use crate::define_global_property;
    use crate::error::IxaError;

    /// Validation error used by `Property3`: only `field_int == 0` is accepted.
    #[derive(Debug)]
    struct InvalidProperty3Value {
        field_int: u32,
    }

    impl fmt::Display for InvalidProperty3Value {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "field_int must be zero, got {}", self.field_int)
        }
    }

    impl Error for InvalidProperty3Value {}

    #[derive(Serialize, Deserialize, Debug, Clone)]
    pub struct ParamType {
        pub days: usize,
        pub diseases: usize,
    }

    define_global_property!(DiseaseParams, ParamType);

    #[test]
    fn set_get_global_property() {
        let params: ParamType = ParamType {
            days: 10,
            diseases: 2,
        };
        let params2: ParamType = ParamType {
            days: 11,
            diseases: 3,
        };
        let mut context = Context::new();
        context
            .set_global_property_value(DiseaseParams, params.clone())
            .unwrap();
        let global_params = context
            .get_global_property_value(DiseaseParams)
            .unwrap()
            .clone();
        assert_eq!(global_params.days, params.days);
        assert_eq!(global_params.diseases, params.diseases);
        // A second set must fail and leave the original value in place.
        assert!(context
            .set_global_property_value(DiseaseParams, params2.clone())
            .is_err());
        let global_params = context
            .get_global_property_value(DiseaseParams)
            .unwrap()
            .clone();
        assert_eq!(global_params.days, params.days);
        assert_eq!(global_params.diseases, params.diseases);
    }

    #[test]
    fn get_global_property_missing() {
        let context = Context::new();
        let global_params = context.get_global_property_value(DiseaseParams);
        assert!(global_params.is_none());
    }

    #[test]
    fn set_parameters() {
        let mut context = Context::new();
        let temp_dir = tempdir().unwrap();
        let config_path = PathBuf::from(temp_dir.path());
        let file_name = "test.json";
        let file_path = config_path.join(file_name);
        let config = fs::File::create(&file_path).unwrap();
        let params: ParamType = ParamType {
            days: 10,
            diseases: 2,
        };
        define_global_property!(Parameters, ParamType);
        // Fail loudly if the fixture cannot be written instead of ignoring the error.
        serde_json::to_writer(config, &params).unwrap();
        let params_json = context
            .load_parameters_from_json::<ParamType>(&file_path)
            .unwrap();
        context
            .set_global_property_value(Parameters, params_json)
            .unwrap();
        let params_read = context
            .get_global_property_value(Parameters)
            .unwrap()
            .clone();
        assert_eq!(params_read.days, params.days);
        assert_eq!(params_read.diseases, params.diseases);
    }

    #[derive(Serialize, Deserialize)]
    pub struct Property1Type {
        field_int: u32,
        field_str: String,
    }

    define_global_property!(Property1, Property1Type);

    #[derive(Serialize, Deserialize)]
    pub struct Property2Type {
        field_int: u32,
    }

    define_global_property!(Property2, Property2Type);

    #[test]
    fn read_global_properties() {
        let mut context = Context::new();
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/global_properties_test1.json");
        context.load_global_properties(&path).unwrap();
        let p1 = context.get_global_property_value(Property1).unwrap();
        assert_eq!(p1.field_int, 1);
        assert_eq!(p1.field_str, "test");
        let p2 = context.get_global_property_value(Property2).unwrap();
        assert_eq!(p2.field_int, 2);
    }

    #[test]
    fn read_unknown_property() {
        let mut context = Context::new();
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/global_properties_missing.json");
        match context.load_global_properties(&path) {
            Err(IxaError::NoGlobalProperty { name }) => assert_eq!(name, "ixa.PropertyUnknown"),
            _ => panic!("Unexpected error type"),
        }
    }

    #[test]
    fn read_malformed_property() {
        let mut context = Context::new();
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/global_properties_malformed.json");
        let error = context.load_global_properties(&path);
        println!("Error {error:?}");
        match error {
            Err(IxaError::JsonError(_)) => {}
            _ => panic!("Unexpected error type"),
        }
    }

    #[test]
    fn read_duplicate_property() {
        let mut context = Context::new();
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/global_properties_test1.json");
        context.load_global_properties(&path).unwrap();
        // Loading the same file again must hit the setter's duplicate check.
        let error = context.load_global_properties(&path);
        match error {
            Err(IxaError::DuplicateProperty { .. }) => {}
            _ => panic!("Unexpected error type"),
        }
    }

    #[derive(Serialize, Deserialize)]
    pub struct Property3Type {
        field_int: u32,
    }

    define_global_property!(Property3, Property3Type, |v: &Property3Type| {
        match v.field_int {
            0 => Ok(()),
            _ => Err(Box::new(InvalidProperty3Value {
                field_int: v.field_int,
            }) as Box<dyn Error + Send + Sync + 'static>),
        }
    });

    #[test]
    fn validate_property_set_success() {
        let mut context = Context::new();
        context
            .set_global_property_value(Property3, Property3Type { field_int: 0 })
            .unwrap();
    }

    #[test]
    fn validate_property_set_failure() {
        let mut context = Context::new();
        let error = context
            .set_global_property_value(Property3, Property3Type { field_int: 1 })
            .unwrap_err();
        assert_eq!(
            error.to_string(),
            "illegal value for global property `Property3`: field_int must be zero, got 1"
        );
        match error {
            IxaError::IllegalGlobalPropertyValue { name, source } => {
                assert_eq!(name, "Property3");
                assert_eq!(source.to_string(), "field_int must be zero, got 1");
            }
            _ => panic!("Unexpected error type"),
        }
    }

    #[test]
    fn validate_property_load_success() {
        let mut context = Context::new();
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/global_properties_valid.json");
        context.load_global_properties(&path).unwrap();
    }

    #[test]
    fn validate_property_load_failure() {
        let mut context = Context::new();
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/global_properties_invalid.json");
        let error = context.load_global_properties(&path).unwrap_err();
        assert_eq!(
            error.to_string(),
            "illegal value for global property `Property3`: field_int must be zero, got 42"
        );
        match error {
            IxaError::IllegalGlobalPropertyValue { name, source } => {
                assert_eq!(name, "Property3");
                assert_eq!(source.to_string(), "field_int must be zero, got 42");
            }
            _ => panic!("Unexpected error type"),
        }
    }
}