use anyhow::{Context, Result};
use indexmap::{IndexMap, IndexSet};
use std::borrow::Borrow;
use std::collections::HashSet;
use std::fmt::Display;
use std::hash::Hash;
/// The set of bounds a type must satisfy to be used as an ID.
///
/// IDs can be cloned, displayed, compared, hashed, viewed as a `&str` and built from an
/// owned `String`. The blanket impl below means any type meeting these bounds is
/// automatically `IDLike`; it never needs to be implemented by hand.
pub trait IDLike: Clone + Display + Eq + Hash + Borrow<str> + From<String> {}

impl<T: Clone + Display + Eq + Hash + Borrow<str> + From<String>> IDLike for T {}
/// Declares a new ID newtype named `$name`, wrapping an `Rc<str>`.
///
/// The generated type:
/// * derives `Clone`, `Display`, `Hash`, ordering/equality, `Debug` and `Serialize`;
/// * can be borrowed as `&str` and built from `&str` or `String` (or via `new`);
/// * deserialises from a string, trimming surrounding whitespace and rejecting empty IDs
///   and the reserved values `all` / `annual` (compared ASCII case-insensitively).
macro_rules! define_id_type {
    ($name:ident) => {
        /// String-backed identifier generated by `define_id_type!`.
        #[derive(
            Clone,
            derive_more::Display,
            std::hash::Hash,
            PartialOrd,
            Ord,
            PartialEq,
            Eq,
            Debug,
            serde::Serialize,
        )]
        pub struct $name(pub std::rc::Rc<str>);

        impl $name {
            /// Builds an ID from a string slice, copying it into a new `Rc<str>`.
            pub fn new(id: &str) -> Self {
                Self(std::rc::Rc::from(id))
            }
        }

        impl std::borrow::Borrow<str> for $name {
            fn borrow(&self) -> &str {
                &*self.0
            }
        }

        impl From<&str> for $name {
            fn from(s: &str) -> Self {
                Self::new(s)
            }
        }

        impl From<String> for $name {
            fn from(s: String) -> Self {
                Self(s.into())
            }
        }

        impl<'de> serde::Deserialize<'de> for $name {
            fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                // Words reserved elsewhere (e.g. as wildcards), so they may not name an ID.
                const RESERVED: &[&str] = &["all", "annual"];

                let raw: String = serde::Deserialize::deserialize(deserializer)?;
                let trimmed = raw.trim();

                if trimmed.is_empty() {
                    return Err(<D::Error as serde::de::Error>::custom("IDs cannot be empty"));
                }

                if RESERVED.iter().any(|word| trimmed.eq_ignore_ascii_case(word)) {
                    return Err(<D::Error as serde::de::Error>::custom(format!(
                        "'{trimmed}' is an invalid value for an ID"
                    )));
                }

                Ok(Self::new(trimmed))
            }
        }
    };
}
pub(crate) use define_id_type;
// Concrete ID type instantiated only in test builds, so that the code generated by
// `define_id_type!` can be exercised by the unit tests in this file.
#[cfg(test)]
define_id_type!(GenericID);
/// Types that expose an ID of type `ID`.
pub trait HasID<ID>
where
    ID: IDLike,
{
    /// Returns a reference to this value's ID.
    fn get_id(&self) -> &ID;
}
/// Implements `HasID<$id_type>` for `$owner` by returning a reference to its `id` field.
///
/// The `$owner` type must have a field named `id` of type `$id_type`.
macro_rules! define_id_getter {
    ($owner:ty, $id_type:ty) => {
        impl crate::id::HasID<$id_type> for $owner {
            fn get_id(&self) -> &$id_type {
                &self.id
            }
        }
    };
}
pub(crate) use define_id_getter;
/// A collection in which IDs of type `ID` can be looked up by their string form.
pub trait IDCollection<ID>
where
    ID: IDLike,
{
    /// Finds `id` in the collection and returns a reference to the stored ID.
    ///
    /// # Errors
    ///
    /// Returns an error if `id` is not present in the collection.
    fn get_id<T>(&self, id: &T) -> Result<&ID>
    where
        T: Borrow<str> + Display + ?Sized;
}
/// Expands to an `IDCollection::get_id` implementation for set-like collections whose
/// `get` method accepts a borrowed `&str` and returns `Option<&ID>`.
macro_rules! define_id_methods {
    () => {
        fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID> {
            self.get(id.borrow())
                .with_context(|| format!("Unknown ID {id} found"))
        }
    };
}
/// Looks up IDs in a `HashSet`, whatever hasher the set was built with.
///
/// Generic over `S` so sets using a non-default hasher are supported too; `HashSet::get`
/// only requires `S: BuildHasher`.
impl<ID: IDLike, S: std::hash::BuildHasher> IDCollection<ID> for HashSet<ID, S> {
    define_id_methods!();
}
/// Looks up IDs in an `IndexSet`, whatever hasher the set was built with.
///
/// Generic over `S` so sets using a non-default hasher are supported too; `IndexSet::get`
/// only requires `S: BuildHasher`.
impl<ID: IDLike, S: std::hash::BuildHasher> IDCollection<ID> for IndexSet<ID, S> {
    define_id_methods!();
}
/// Looks up IDs among the keys of an `IndexMap`, whatever hasher the map was built with.
///
/// Generic over `S` so maps using a non-default hasher are supported too;
/// `IndexMap::get_key_value` only requires `S: BuildHasher`.
impl<ID: IDLike, V, S: std::hash::BuildHasher> IDCollection<ID> for IndexMap<ID, V, S> {
    fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID> {
        // Only the stored key is returned; the associated value is discarded.
        let (found, _) = self
            .get_key_value(id.borrow())
            .with_context(|| format!("Unknown ID {id} found"))?;
        Ok(found)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    struct Record {
        id: GenericID,
    }

    /// Deserialises a one-field TOML record whose `id` is set to the given string.
    fn deserialise_id(id: &str) -> Result<Record> {
        Ok(toml::from_str(&format!("id = \"{id}\""))?)
    }

    #[rstest]
    #[case("commodity1")]
    #[case("some commodity")]
    #[case("PROCESS")]
    #[case("café")]
    fn deserialise_id_valid(#[case] id: &str) {
        assert_eq!(deserialise_id(id).unwrap().id.to_string(), id);
    }

    #[rstest]
    #[case(" commodity1", "commodity1")]
    #[case("commodity1  ", "commodity1")]
    #[case("  some commodity  ", "some commodity")]
    fn deserialise_id_trims_whitespace(#[case] raw: &str, #[case] expected: &str) {
        assert_eq!(deserialise_id(raw).unwrap().id.to_string(), expected);
    }

    #[rstest]
    #[case("")]
    #[case("   ")]
    #[case("all")]
    #[case("annual")]
    #[case("ALL")]
    #[case("Annual")]
    #[case(" ALL ")]
    fn deserialise_id_invalid(#[case] id: &str) {
        deserialise_id(id).unwrap_err();
    }

    #[test]
    fn get_id_hash_set() {
        let ids: HashSet<GenericID> = ["a", "b"].into_iter().map(GenericID::new).collect();
        assert_eq!(ids.get_id("a").unwrap(), &GenericID::new("a"));
        assert_eq!(
            ids.get_id("c").unwrap_err().to_string(),
            "Unknown ID c found"
        );
    }

    #[test]
    fn get_id_index_set() {
        let ids: IndexSet<GenericID> = ["a", "b"].into_iter().map(GenericID::new).collect();
        assert_eq!(ids.get_id("b").unwrap(), &GenericID::new("b"));
        assert_eq!(
            ids.get_id("c").unwrap_err().to_string(),
            "Unknown ID c found"
        );
    }

    #[test]
    fn get_id_index_map() {
        let map: IndexMap<GenericID, u32> = [("a", 1), ("b", 2)]
            .into_iter()
            .map(|(key, value)| (GenericID::new(key), value))
            .collect();
        assert_eq!(map.get_id("a").unwrap(), &GenericID::new("a"));
        assert_eq!(
            map.get_id("c").unwrap_err().to_string(),
            "Unknown ID c found"
        );
    }
}