use itertools::Itertools;
use std::{cmp::Ordering, collections::HashMap};
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
use thiserror::Error;
use serde::{
de::{self, DeserializeOwned, IntoDeserializer, Visitor},
Deserialize, Deserializer,
};
use serde_enum_str::{Deserialize_enum_str, Serialize_enum_str};
use super::{normalize_lua_value, DisplayAsLuaKV, DisplayLuaKV, DisplayLuaValue, LuaValueSeed};
/// A platform name as used in `supported_platforms` lists and in `platforms`
/// override tables.
///
/// Names are serialized and parsed in lowercase (e.g. `MacOSX` <-> `"macosx"`).
/// Names not listed here deserialize to [`PlatformIdentifier::Unknown`]
/// (presumably carrying the original string — confirm against `serde_enum_str`).
/// Platforms are partially ordered by specificity (see the `PartialOrd` impl),
/// and the `Default` value is the platform detected for the current build.
#[derive(Deserialize_enum_str, Serialize_enum_str, PartialEq, Eq, Hash, Debug, Clone, EnumIter)]
#[serde(rename_all = "lowercase")]
#[strum(serialize_all = "lowercase")]
pub enum PlatformIdentifier {
    Unix,
    Windows,
    Win32,
    Cygwin,
    MacOSX,
    Linux,
    FreeBSD,
    // Catch-all for unrecognised names. `EnumIter` also yields this variant
    // (with a default payload), which is why code enumerating the known
    // platforms filters `Unknown(_)` out.
    #[serde(other)]
    Unknown(String),
}
impl Default for PlatformIdentifier {
    /// Defaults to the platform detected for the current build
    /// (see `target_identifier`).
    fn default() -> PlatformIdentifier {
        target_identifier()
    }
}
impl PartialOrd for PlatformIdentifier {
    /// Orders platforms by specificity.
    ///
    /// A platform compares `Less` than the platforms that directly refine it,
    /// `Greater` than the platform it refines, and `Equal` to itself. Any other
    /// pair is incomparable (`None`).
    ///
    /// Refinements: `unix` -> `cygwin`, `macosx`, `linux`, `freebsd`;
    /// `win32` -> `windows`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        if self == other {
            return Some(Ordering::Equal);
        }

        // The platform that `platform` directly refines, if any.
        fn parent(platform: &PlatformIdentifier) -> Option<PlatformIdentifier> {
            match platform {
                PlatformIdentifier::Cygwin
                | PlatformIdentifier::MacOSX
                | PlatformIdentifier::Linux
                | PlatformIdentifier::FreeBSD => Some(PlatformIdentifier::Unix),
                PlatformIdentifier::Windows => Some(PlatformIdentifier::Win32),
                _ => None,
            }
        }

        if parent(other).as_ref() == Some(self) {
            Some(Ordering::Less)
        } else if parent(self).as_ref() == Some(other) {
            Some(Ordering::Greater)
        } else {
            None
        }
    }
}
/// Detects the [`PlatformIdentifier`] of the current build.
///
/// All checks except the last are compile-time `cfg!` tests on the target.
/// The final fallback runs `which::which("cygpath")` at call time. If a
/// `cygpath` executable is found, the platform is reported as `Cygwin`;
/// otherwise it is plain `Unix`.
///
/// NOTE(review): non-MSVC Windows targets (e.g. `*-windows-gnu`) do not match
/// the first check and therefore resolve to `Cygwin` or `Unix` — confirm this
/// is intended. The `which` probe is repeated on every call.
fn target_identifier() -> PlatformIdentifier {
    if cfg!(target_env = "msvc") {
        return PlatformIdentifier::Windows;
    }
    if cfg!(target_os = "linux") {
        return PlatformIdentifier::Linux;
    }
    if cfg!(any(target_os = "macos", target_vendor = "apple")) {
        return PlatformIdentifier::MacOSX;
    }
    if cfg!(target_os = "freebsd") {
        return PlatformIdentifier::FreeBSD;
    }
    match which::which("cygpath") {
        Ok(_) => PlatformIdentifier::Cygwin,
        Err(_) => PlatformIdentifier::Unix,
    }
}
impl PlatformIdentifier {
pub fn get_subsets(&self) -> Vec<Self> {
PlatformIdentifier::iter()
.filter(|identifier| identifier.is_subset_of(self))
.collect()
}
pub fn get_extended_platforms(&self) -> Vec<Self> {
PlatformIdentifier::iter()
.filter(|identifier| identifier.is_extension_of(self))
.collect()
}
fn is_subset_of(&self, other: &PlatformIdentifier) -> bool {
self.partial_cmp(other) == Some(Ordering::Less)
}
fn is_extension_of(&self, other: &PlatformIdentifier) -> bool {
self.partial_cmp(other) == Some(Ordering::Greater)
}
}
/// Platform support information built from a `supported_platforms` list.
///
/// Each platform in the map is marked `true` (supported) or `false`
/// (unsupported). [`PlatformSupport::is_supported`] treats platforms missing
/// from the map as unsupported.
#[derive(Clone, Debug, PartialEq)]
pub struct PlatformSupport {
    platform_map: HashMap<PlatformIdentifier, bool>,
}
impl Default for PlatformSupport {
fn default() -> Self {
Self {
platform_map: PlatformIdentifier::iter()
.filter(|identifier| !matches!(identifier, PlatformIdentifier::Unknown(_)))
.map(|identifier| (identifier, true))
.collect(),
}
}
}
impl<'de> Deserialize<'de> for PlatformSupport {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let platforms: Vec<String> = Vec::deserialize(deserializer)?;
Self::parse(&platforms).map_err(de::Error::custom)
}
}
impl DisplayAsLuaKV for PlatformSupport {
    /// Renders the support map as a Lua `supported_platforms` list.
    ///
    /// Supported platforms are emitted by name and unsupported ones as
    /// `!name`. The backing map is a `HashMap` with unspecified iteration
    /// order. Entries are therefore sorted by platform name, so the same
    /// `PlatformSupport` always renders identically.
    fn display_lua(&self) -> DisplayLuaKV {
        let mut entries: Vec<(String, bool)> = self
            .platforms()
            .iter()
            .map(|(platform, supported)| (platform.to_string(), *supported))
            .collect();
        // Each platform appears at most once, so sorting by name is total.
        entries.sort_unstable();
        DisplayLuaKV {
            key: "supported_platforms".to_string(),
            value: DisplayLuaValue::List(
                entries
                    .into_iter()
                    .map(|(platform, supported)| {
                        DisplayLuaValue::String(format!(
                            "{}{}",
                            if supported { "" } else { "!" },
                            platform,
                        ))
                    })
                    .collect(),
            ),
        }
    }
}
/// Errors produced while validating a `supported_platforms` list.
#[derive(Error, Debug)]
pub enum PlatformValidationError {
    /// A platform name could not be parsed as a [`PlatformIdentifier`];
    /// carries the parser's error message.
    #[error("error when parsing platform identifier: {0}")]
    ParseError(String),
    /// A platform was marked both supported and unsupported, either directly
    /// or through subset/extension propagation.
    #[error("conflicting supported platform entries")]
    ConflictingEntries,
}
impl PlatformSupport {
    /// Builds the explicit support map for a `supported_platforms` list.
    ///
    /// Each entry is a platform name, optionally prefixed with `!` to mark it
    /// unsupported. Entries also propagate:
    /// - a supported entry marks all of its extended (more specific)
    ///   platforms supported;
    /// - an unsupported entry marks all of its subset (less specific)
    ///   platforms unsupported.
    ///
    /// # Errors
    /// - [`PlatformValidationError::ParseError`] if a name fails to parse.
    /// - [`PlatformValidationError::ConflictingEntries`] if any platform would
    ///   be recorded as both supported and unsupported.
    fn validate_platforms(
        platforms: &[String],
    ) -> Result<HashMap<PlatformIdentifier, bool>, PlatformValidationError> {
        /// Stores `supported` for `identifier`, failing if the opposite value
        /// was already stored.
        fn record(
            support: &mut HashMap<PlatformIdentifier, bool>,
            identifier: PlatformIdentifier,
            supported: bool,
        ) -> Result<(), PlatformValidationError> {
            let conflicts =
                matches!(support.get(&identifier), Some(&previous) if previous != supported);
            if conflicts {
                return Err(PlatformValidationError::ConflictingEntries);
            }
            support.insert(identifier, supported);
            Ok(())
        }

        let mut support = HashMap::new();
        for entry in platforms {
            let (supported, name) = match entry.strip_prefix('!') {
                Some(negated) => (false, negated),
                None => (true, entry.as_str()),
            };
            let identifier = name
                .parse::<PlatformIdentifier>()
                .map_err(|err| PlatformValidationError::ParseError(err.to_string()))?;
            // Positive entries propagate towards more specific platforms,
            // negated entries towards less specific ones.
            let implied = if supported {
                identifier.get_extended_platforms()
            } else {
                identifier.get_subsets()
            };
            for target in std::iter::once(identifier).chain(implied) {
                record(&mut support, target, supported)?;
            }
        }
        Ok(support)
    }

    /// Parses a `supported_platforms` list.
    ///
    /// - An empty list yields [`PlatformSupport::default`], with every known
    ///   platform supported.
    /// - If at least one entry is negated (`!name`), every known platform not
    ///   mentioned (directly or by propagation) defaults to supported.
    /// - Otherwise only the listed platforms, plus their extensions, are
    ///   supported.
    ///
    /// NOTE(review): with a mix of positive and negated entries, unmentioned
    /// platforms end up supported. As far as I know, LuaRocks only allows
    /// unmentioned platforms when *all* entries are negated — confirm which
    /// behaviour is intended.
    ///
    /// # Errors
    /// See `validate_platforms`.
    pub fn parse(platforms: &[String]) -> Result<Self, PlatformValidationError> {
        if platforms.is_empty() {
            return Ok(Self::default());
        }
        let mut platform_map = Self::validate_platforms(platforms)?;
        let has_negated_entry = platforms.iter().any(|platform| platform.starts_with('!'));
        if has_negated_entry {
            let known = PlatformIdentifier::iter()
                .filter(|identifier| !matches!(identifier, PlatformIdentifier::Unknown(_)));
            for identifier in known {
                platform_map.entry(identifier).or_insert(true);
            }
        }
        Ok(Self { platform_map })
    }

    /// Whether `platform` is marked supported. Platforms absent from the map
    /// count as unsupported.
    pub fn is_supported(&self, platform: &PlatformIdentifier) -> bool {
        matches!(self.platform_map.get(platform), Some(&true))
    }

    /// The underlying platform -> supported map.
    pub(crate) fn platforms(&self) -> &HashMap<PlatformIdentifier, bool> {
        &self.platform_map
    }
}
/// Values that can be combined with a (partial) override of the same type.
pub trait PartialOverride: Sized {
    /// Error returned when an override cannot be applied.
    type Err: std::error::Error;
    /// Returns a new value combining `self` (the base) with `override_val`.
    ///
    /// NOTE(review): the naming suggests values in `override_val` take
    /// precedence over `self` — confirm against the implementations.
    fn apply_overrides(&self, override_val: &Self) -> Result<Self, Self::Err>;
}
/// Values that can appear in a per-platform table (see [`PerPlatform`]).
pub trait PlatformOverridable: PartialOverride {
    /// Error returned by `on_nil`. It is distinct from `PartialOverride::Err`,
    /// hence the fully qualified path in the signature below.
    type Err: std::error::Error;
    /// Produces the `PerPlatform` used when the input is nil/unit. The
    /// `PerPlatform` deserializer calls this from `visit_unit`.
    fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
    where
        T: PlatformOverridable,
        T: Default;
}
/// A value with optional platform-specific variants.
///
/// `per_platform` holds entries for specific platforms. `default` is used
/// for platforms that have no entry of their own and no entry for a less
/// specific platform; see [`PerPlatform::get`] for the lookup rules.
#[derive(Clone, Debug, PartialEq)]
pub struct PerPlatform<T> {
    pub(crate) default: T,
    pub(crate) per_platform: HashMap<PlatformIdentifier, T>,
}
impl<T> PerPlatform<T> {
    /// Creates a `PerPlatform` with only a default value and no
    /// platform-specific entries.
    pub(crate) fn new(default: T) -> Self {
        Self {
            default,
            per_platform: HashMap::default(),
        }
    }

    /// Returns the value for the platform detected by `target_identifier`.
    pub fn current_platform(&self) -> &T {
        self.for_platform_identifier(&target_identifier())
    }

    fn for_platform_identifier(&self, identifier: &PlatformIdentifier) -> &T {
        self.get(identifier)
    }

    /// Looks up the value for `platform`.
    ///
    /// Resolution order:
    /// 1. the entry for `platform` itself;
    /// 2. otherwise, the entry of the most specific subset platform
    ///    (see `PlatformIdentifier::get_subsets`) that has one;
    /// 3. otherwise, the default value.
    pub fn get(&self, platform: &PlatformIdentifier) -> &T {
        if let Some(value) = self.per_platform.get(platform) {
            return value;
        }
        // Only computed on a miss: `get_subsets` allocates and scans every
        // identifier, which the previous eager `unwrap_or(..)` did on every call.
        let mut subsets = platform.get_subsets();
        // Most specific first. The stable sort keeps incomparable identifiers
        // in their original relative order.
        subsets.sort_by(|a, b| b.partial_cmp(a).unwrap_or(Ordering::Equal));
        subsets
            .iter()
            .find_map(|identifier| self.per_platform.get(identifier))
            .unwrap_or(&self.default)
    }

    /// Applies `cb` to the default and to every per-platform value, keeping
    /// the same platform keys.
    pub(crate) fn map<U, F>(&self, cb: F) -> PerPlatform<U>
    where
        F: Fn(&T) -> U,
    {
        PerPlatform {
            default: cb(&self.default),
            per_platform: self
                .per_platform
                .iter()
                .map(|(identifier, value)| (identifier.clone(), cb(value)))
                .collect(),
        }
    }
}
impl<U, E> PerPlatform<Result<U, E>>
where
E: std::error::Error,
{
pub fn transpose(self) -> Result<PerPlatform<U>, E> {
Ok(PerPlatform {
default: self.default?,
per_platform: self
.per_platform
.into_iter()
.map(|(identifier, value)| Ok((identifier, value?)))
.try_collect()?,
})
}
}
impl<T: Default> Default for PerPlatform<T> {
fn default() -> Self {
Self {
default: T::default(),
per_platform: HashMap::default(),
}
}
}
/// serde [`Visitor`] that builds a [`PerPlatform<T>`]. The `PhantomData` only
/// carries the target type `T`.
struct PerPlatformVisitor<T>(std::marker::PhantomData<T>);
/// Deserialization logic for [`PerPlatform<T>`].
///
/// Accepted inputs:
/// - a map/table: the optional `platforms` key holds per-platform entries,
///   and every other key forms the default value;
/// - a sequence, a string, or UTF-8 bytes: deserialized directly as the
///   default value, with no per-platform entries;
/// - unit/nil: delegated to `T::on_nil()`.
impl<'de, T> Visitor<'de> for PerPlatformVisitor<T>
where
    T: DeserializeOwned,
    T: PlatformOverridable,
    T: Default,
    T: Clone,
{
    type Value = PerPlatform<T>;
    // NOTE(review): sequences and strings are accepted too, although the
    // message only mentions tables and nil.
    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        formatter.write_str("a table or nil")
    }
    /// Splits the table into its `platforms` entry and the remaining keys.
    /// The remaining keys become the default value. Each platform entry is
    /// then merged with that default via `apply_per_platform_overrides`.
    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
    where
        A: de::MapAccess<'de>,
    {
        use serde_value::Value;
        // The whole map is buffered, because the default can only be
        // deserialized once every non-`platforms` key has been seen.
        let mut platforms_val: Option<Value> = None;
        let mut other_entries: Vec<(Value, Value)> = Vec::new();
        while let Some(key) = map.next_key_seed(LuaValueSeed)? {
            if key == Value::String("platforms".to_string()) {
                // A repeated `platforms` key replaces the earlier value.
                platforms_val = Some(map.next_value_seed(LuaValueSeed)?);
            } else {
                other_entries.push((key, map.next_value_seed(LuaValueSeed)?));
            }
        }
        let mut per_platform = match platforms_val {
            Some(val) => match val {
                Value::Map(_) => val
                    .deserialize_into::<HashMap<PlatformIdentifier, T>>()
                    .map_err(de::Error::custom)?,
                // An explicit nil `platforms` is treated like a missing one.
                Value::Unit => HashMap::default(),
                val => {
                    return Err(de::Error::custom(format!(
                        "Expected platforms to be a table or nil, but got {val:?}",
                    )))
                }
            },
            None => HashMap::default(),
        };
        // `normalize_lua_value` comes from the parent module; presumably it
        // adapts Lua-specific value shapes before deserialization — TODO confirm.
        let obj = normalize_lua_value(Value::Map(other_entries.into_iter().collect()));
        let default = T::deserialize(obj.into_deserializer()).map_err(de::Error::custom)?;
        apply_per_platform_overrides(&mut per_platform, &default)
            .map_err(|err: <T as PartialOverride>::Err| de::Error::custom(err.to_string()))?;
        Ok(PerPlatform {
            default,
            per_platform,
        })
    }
    /// A sequence is deserialized as the default value; no per-platform entries.
    fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
    where
        A: de::SeqAccess<'de>,
    {
        let default = T::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
        Ok(PerPlatform::new(default))
    }
    /// Unit/nil input is delegated to the type's `on_nil` hook.
    fn visit_unit<E>(self) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        T::on_nil().map_err(de::Error::custom)
    }
    /// Bytes must be valid UTF-8 and are then handled like a string.
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        let s = std::str::from_utf8(v).map_err(de::Error::custom)?;
        self.visit_str(s)
    }
    fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        self.visit_bytes(&v)
    }
    /// A string is deserialized as the default value; no per-platform entries.
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        let default = T::deserialize(v.into_deserializer())?;
        Ok(PerPlatform::new(default))
    }
}
impl<'de, T> Deserialize<'de> for PerPlatform<T>
where
    T: DeserializeOwned,
    T: PlatformOverridable,
    T: Default,
    T: Clone,
{
    /// Uses `deserialize_any` so that `PerPlatformVisitor` can accept tables,
    /// sequences, strings, bytes or nil.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let visitor = PerPlatformVisitor::<T>(std::marker::PhantomData);
        deserializer.deserialize_any(visitor)
    }
}
pub(crate) fn per_platform_from_intermediate<'de, D, I, T>(
deserializer: D,
) -> Result<PerPlatform<T>, D::Error>
where
D: Deserializer<'de>,
I: PlatformOverridable<Err: ToString>,
I: DeserializeOwned,
I: Default,
I: Clone,
T: TryFrom<I, Error: ToString>,
{
PerPlatform::<I>::deserialize(deserializer)?
.map(|internal| {
T::try_from(internal.clone()).map_err(|err| serde::de::Error::custom(err.to_string()))
})
.transpose()
}
fn apply_per_platform_overrides<T>(
per_platform: &mut HashMap<PlatformIdentifier, T>,
base: &T,
) -> Result<(), T::Err>
where
T: PartialOverride,
T: Clone,
{
let per_platform_raw = per_platform.clone();
for (platform, overrides) in per_platform.clone() {
let overridden = base.apply_overrides(&overrides)?;
per_platform.insert(platform, overridden);
}
for (platform, overrides) in per_platform_raw {
for extended_platform in &platform.get_extended_platforms() {
if let Some(extended_overrides) = per_platform.get(extended_platform) {
per_platform.insert(
extended_platform.to_owned(),
extended_overrides.apply_overrides(&overrides)?,
);
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // None of these tests await anything, so they use the plain `#[test]`
    // harness rather than `#[tokio::test]`, which builds an async runtime per test.

    fn platform_identifier_strategy() -> impl Strategy<Value = PlatformIdentifier> {
        prop_oneof![
            Just(PlatformIdentifier::Unix),
            Just(PlatformIdentifier::Windows),
            Just(PlatformIdentifier::Win32),
            Just(PlatformIdentifier::Cygwin),
            Just(PlatformIdentifier::MacOSX),
            Just(PlatformIdentifier::Linux),
            Just(PlatformIdentifier::FreeBSD),
        ]
    }

    #[test]
    fn sort_platform_identifier_more_specific_last() {
        let mut platforms = vec![
            PlatformIdentifier::Cygwin,
            PlatformIdentifier::Linux,
            PlatformIdentifier::Unix,
        ];
        platforms.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
        assert_eq!(
            platforms,
            vec![
                PlatformIdentifier::Unix,
                PlatformIdentifier::Cygwin,
                PlatformIdentifier::Linux
            ]
        );
        let mut platforms = vec![PlatformIdentifier::Windows, PlatformIdentifier::Win32];
        platforms.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
        assert_eq!(
            platforms,
            vec![PlatformIdentifier::Win32, PlatformIdentifier::Windows]
        )
    }

    #[test]
    fn test_is_subset_of() {
        assert!(PlatformIdentifier::Unix.is_subset_of(&PlatformIdentifier::Linux));
        assert!(PlatformIdentifier::Unix.is_subset_of(&PlatformIdentifier::MacOSX));
        assert!(!PlatformIdentifier::Linux.is_subset_of(&PlatformIdentifier::Unix));
    }

    #[test]
    fn test_is_extension_of() {
        assert!(PlatformIdentifier::Linux.is_extension_of(&PlatformIdentifier::Unix));
        assert!(PlatformIdentifier::MacOSX.is_extension_of(&PlatformIdentifier::Unix));
        assert!(!PlatformIdentifier::Unix.is_extension_of(&PlatformIdentifier::Linux));
    }

    #[test]
    fn per_platform() {
        let foo = PerPlatform {
            default: "default",
            per_platform: vec![
                (PlatformIdentifier::Unix, "unix"),
                (PlatformIdentifier::FreeBSD, "freebsd"),
                (PlatformIdentifier::Cygwin, "cygwin"),
                (PlatformIdentifier::Linux, "linux"),
            ]
            .into_iter()
            .collect(),
        };
        assert_eq!(*foo.get(&PlatformIdentifier::MacOSX), "unix");
        assert_eq!(*foo.get(&PlatformIdentifier::Linux), "linux");
        assert_eq!(*foo.get(&PlatformIdentifier::FreeBSD), "freebsd");
        assert_eq!(*foo.get(&PlatformIdentifier::Cygwin), "cygwin");
        assert_eq!(*foo.get(&PlatformIdentifier::Windows), "default");
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_target_identifier() {
        run_test_target_identifier(PlatformIdentifier::Linux)
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn test_target_identifier() {
        run_test_target_identifier(PlatformIdentifier::MacOSX)
    }

    #[cfg(target_os = "freebsd")]
    #[test]
    fn test_target_identifier() {
        run_test_target_identifier(PlatformIdentifier::FreeBSD)
    }

    #[cfg(target_env = "msvc")]
    #[test]
    fn test_target_identifier() {
        run_test_target_identifier(PlatformIdentifier::Windows)
    }

    #[cfg(target_os = "android")]
    #[test]
    fn test_target_identifier() {
        run_test_target_identifier(PlatformIdentifier::Unix)
    }

    fn run_test_target_identifier(expected: PlatformIdentifier) {
        assert_eq!(expected, target_identifier());
    }

    proptest! {
        #[test]
        fn supported_platforms(identifier in platform_identifier_strategy()) {
            let identifier_str = identifier.to_string();
            let platforms = vec![identifier_str];
            let platform_support = PlatformSupport::parse(&platforms).unwrap();
            prop_assert!(platform_support.is_supported(&identifier))
        }

        #[test]
        fn unsupported_platforms_only(unsupported in platform_identifier_strategy(), supported in platform_identifier_strategy()) {
            if supported == unsupported
                || unsupported.is_extension_of(&supported) {
                return Ok(());
            }
            let identifier_str = format!("!{unsupported}");
            let platforms = vec![identifier_str];
            let platform_support = PlatformSupport::parse(&platforms).unwrap();
            prop_assert!(!platform_support.is_supported(&unsupported));
            prop_assert!(platform_support.is_supported(&supported))
        }

        #[test]
        fn supported_and_unsupported_platforms(unsupported in platform_identifier_strategy(), unspecified in platform_identifier_strategy()) {
            if unspecified == unsupported
                || unsupported.is_extension_of(&unspecified) {
                return Ok(());
            }
            let supported_str = unspecified.to_string();
            let unsupported_str = format!("!{unsupported}");
            let platforms = vec![supported_str, unsupported_str];
            let platform_support = PlatformSupport::parse(&platforms).unwrap();
            prop_assert!(platform_support.is_supported(&unspecified));
            prop_assert!(!platform_support.is_supported(&unsupported));
        }

        #[test]
        fn all_platforms_supported_if_none_are_specified(identifier in platform_identifier_strategy()) {
            let platforms = vec![];
            let platform_support = PlatformSupport::parse(&platforms).unwrap();
            prop_assert!(platform_support.is_supported(&identifier))
        }

        #[test]
        fn conflicting_platforms(identifier in platform_identifier_strategy()) {
            let identifier_str = identifier.to_string();
            let identifier_str_negated = format!("!{identifier}");
            let platforms = vec![identifier_str, identifier_str_negated];
            let _ = PlatformSupport::parse(&platforms).unwrap_err();
        }

        #[test]
        fn extended_platforms_supported_if_supported(identifier in platform_identifier_strategy()) {
            let identifier_str = identifier.to_string();
            let platforms = vec![identifier_str];
            let platform_support = PlatformSupport::parse(&platforms).unwrap();
            for identifier in identifier.get_extended_platforms() {
                prop_assert!(platform_support.is_supported(&identifier))
            }
        }

        #[test]
        fn sub_platforms_unsupported_if_unsupported(identifier in platform_identifier_strategy()) {
            let identifier_str = format!("!{identifier}");
            let platforms = vec![identifier_str];
            let platform_support = PlatformSupport::parse(&platforms).unwrap();
            for identifier in identifier.get_subsets() {
                prop_assert!(!platform_support.is_supported(&identifier))
            }
        }

        #[test]
        fn conflicting_extended_platform_definitions(identifier in platform_identifier_strategy()) {
            let extended_platforms = identifier.get_extended_platforms();
            if extended_platforms.is_empty() {
                return Ok(());
            }
            let supported_str = identifier.to_string();
            let mut platforms: Vec<String> = extended_platforms.into_iter().map(|ident| format!("!{ident}")).collect();
            platforms.push(supported_str);
            let _ = PlatformSupport::parse(&platforms).unwrap_err();
        }
    }
}