1use itertools::Itertools;
2
3use std::{cmp::Ordering, collections::HashMap};
4use strum::IntoEnumIterator;
5use strum_macros::EnumIter;
6use thiserror::Error;
7
8use serde::{
9 de::{self, DeserializeOwned, IntoDeserializer, Visitor},
10 Deserialize, Deserializer,
11};
12use serde_enum_str::{Deserialize_enum_str, Serialize_enum_str};
13
14use super::{normalize_lua_value, DisplayAsLuaKV, DisplayLuaKV, DisplayLuaValue, LuaValueSeed};
15
/// A platform name as used in `supported_platforms` lists and per-platform override tables.
///
/// Variants are (de)serialized as lowercase strings (e.g. `macosx`, `freebsd`). Any other
/// string is routed to the `#[serde(other)]` variant [`PlatformIdentifier::Unknown`].
///
/// The `PartialOrd` impl below does not define a total order. It relates a generic platform to
/// the more specific platforms that extend it, and leaves unrelated platforms incomparable.
#[derive(Deserialize_enum_str, Serialize_enum_str, PartialEq, Eq, Hash, Debug, Clone, EnumIter)]
#[serde(rename_all = "lowercase")]
#[strum(serialize_all = "lowercase")]
pub enum PlatformIdentifier {
    /// Generic Unix; ordered below `Cygwin`, `MacOSX`, `Linux` and `FreeBSD`.
    Unix,
    /// Ordered above `Win32`.
    Windows,
    /// Ordered below `Windows`.
    Win32,
    /// Ordered above `Unix`.
    Cygwin,
    /// Ordered above `Unix`.
    MacOSX,
    /// Ordered above `Unix`.
    Linux,
    /// Ordered above `Unix`.
    FreeBSD,
    /// Catch-all for unrecognised platform names.
    /// NOTE(review): presumably holds the unrecognised string itself — confirm against
    /// `serde_enum_str`'s handling of `#[serde(other)]` newtype variants.
    #[serde(other)]
    Unknown(String),
}
33
34impl Default for PlatformIdentifier {
35 fn default() -> Self {
36 target_identifier()
37 }
38}
39
40impl PartialOrd for PlatformIdentifier {
42 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
43 match (self, other) {
44 (PlatformIdentifier::Unix, PlatformIdentifier::Cygwin) => Some(Ordering::Less),
45 (PlatformIdentifier::Unix, PlatformIdentifier::MacOSX) => Some(Ordering::Less),
46 (PlatformIdentifier::Unix, PlatformIdentifier::Linux) => Some(Ordering::Less),
47 (PlatformIdentifier::Unix, PlatformIdentifier::FreeBSD) => Some(Ordering::Less),
48 (PlatformIdentifier::Windows, PlatformIdentifier::Win32) => Some(Ordering::Greater),
49 (PlatformIdentifier::Win32, PlatformIdentifier::Windows) => Some(Ordering::Less),
50 (PlatformIdentifier::Cygwin, PlatformIdentifier::Unix) => Some(Ordering::Greater),
51 (PlatformIdentifier::MacOSX, PlatformIdentifier::Unix) => Some(Ordering::Greater),
52 (PlatformIdentifier::Linux, PlatformIdentifier::Unix) => Some(Ordering::Greater),
53 (PlatformIdentifier::FreeBSD, PlatformIdentifier::Unix) => Some(Ordering::Greater),
54 _ if self == other => Some(Ordering::Equal),
55 _ => None,
56 }
57 }
58}
59
60fn target_identifier() -> PlatformIdentifier {
67 if cfg!(target_env = "msvc") {
68 PlatformIdentifier::Windows
69 } else if cfg!(target_os = "linux") {
70 PlatformIdentifier::Linux
71 } else if cfg!(target_os = "macos") || cfg!(target_vendor = "apple") {
72 PlatformIdentifier::MacOSX
73 } else if cfg!(target_os = "freebsd") {
74 PlatformIdentifier::FreeBSD
75 } else if which::which("cygpath").is_ok() {
76 PlatformIdentifier::Cygwin
77 } else {
78 PlatformIdentifier::Unix
79 }
80}
81
82impl PlatformIdentifier {
83 pub fn get_subsets(&self) -> Vec<Self> {
86 PlatformIdentifier::iter()
87 .filter(|identifier| identifier.is_subset_of(self))
88 .collect()
89 }
90
91 pub fn get_extended_platforms(&self) -> Vec<Self> {
94 PlatformIdentifier::iter()
95 .filter(|identifier| identifier.is_extension_of(self))
96 .collect()
97 }
98
99 fn is_subset_of(&self, other: &PlatformIdentifier) -> bool {
101 self.partial_cmp(other) == Some(Ordering::Less)
102 }
103
104 fn is_extension_of(&self, other: &PlatformIdentifier) -> bool {
106 self.partial_cmp(other) == Some(Ordering::Greater)
107 }
108}
109
/// Which platforms are supported, as parsed from a `supported_platforms` list.
///
/// `PlatformSupport::is_supported` reports any platform missing from the map as unsupported.
#[derive(Clone, Debug, PartialEq)]
pub struct PlatformSupport {
    /// `true` marks a platform as supported and `false` as explicitly unsupported.
    platform_map: HashMap<PlatformIdentifier, bool>,
}
115
116impl Default for PlatformSupport {
117 fn default() -> Self {
118 Self {
119 platform_map: PlatformIdentifier::iter()
120 .filter(|identifier| !matches!(identifier, PlatformIdentifier::Unknown(_)))
121 .map(|identifier| (identifier, true))
122 .collect(),
123 }
124 }
125}
126
127impl<'de> Deserialize<'de> for PlatformSupport {
128 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
129 where
130 D: Deserializer<'de>,
131 {
132 let platforms: Vec<String> = Vec::deserialize(deserializer)?;
133 Self::parse(&platforms).map_err(de::Error::custom)
134 }
135}
136
137impl DisplayAsLuaKV for PlatformSupport {
138 fn display_lua(&self) -> DisplayLuaKV {
139 DisplayLuaKV {
140 key: "supported_platforms".to_string(),
141 value: DisplayLuaValue::List(
142 self.platforms()
143 .iter()
144 .map(|(platform, supported)| {
145 DisplayLuaValue::String(format!(
146 "{}{}",
147 if *supported { "" } else { "!" },
148 platform,
149 ))
150 })
151 .collect(),
152 ),
153 }
154 }
155}
156
/// Errors produced while parsing a `supported_platforms` list with `PlatformSupport::parse`.
#[derive(Error, Debug)]
pub enum PlatformValidationError {
    /// An entry (after stripping any leading `!`) failed to parse as a [`PlatformIdentifier`].
    #[error("error when parsing platform identifier: {0}")]
    ParseError(String),

    /// A platform ended up both supported and unsupported. This can happen directly, or via a
    /// platform implied by another entry (its extensions or subsets).
    #[error("conflicting supported platform entries")]
    ConflictingEntries,
}
165
166impl PlatformSupport {
167 fn validate_platforms(
168 platforms: &[String],
169 ) -> Result<HashMap<PlatformIdentifier, bool>, PlatformValidationError> {
170 platforms
171 .iter()
172 .try_fold(HashMap::new(), |mut platforms, platform| {
173 let (is_positive_assertion, platform) = platform
177 .strip_prefix('!')
178 .map(|str| (false, str))
179 .unwrap_or((true, platform));
180
181 let platform_identifier = platform
182 .parse::<PlatformIdentifier>()
183 .map_err(|err| PlatformValidationError::ParseError(err.to_string()))?;
184
185 if platforms
189 .get(&platform_identifier)
190 .unwrap_or(&is_positive_assertion)
191 != &is_positive_assertion
192 {
193 return Err(PlatformValidationError::ConflictingEntries);
194 }
195
196 platforms.insert(platform_identifier.clone(), is_positive_assertion);
197
198 let subset_or_extended_platforms = if is_positive_assertion {
199 platform_identifier.get_extended_platforms()
200 } else {
201 platform_identifier.get_subsets()
202 };
203
204 for sub_platform in subset_or_extended_platforms {
205 if platforms
206 .get(&sub_platform)
207 .unwrap_or(&is_positive_assertion)
208 != &is_positive_assertion
209 {
210 return Err(PlatformValidationError::ConflictingEntries);
212 }
213
214 platforms.insert(sub_platform, is_positive_assertion);
215 }
216
217 Ok(platforms)
218 })
219 }
220
221 pub fn parse(platforms: &[String]) -> Result<Self, PlatformValidationError> {
222 match platforms {
226 [] => Ok(Self::default()),
227 platforms if platforms.iter().any(|platform| platform.starts_with('!')) => {
228 let mut platform_map = Self::validate_platforms(platforms)?;
229
230 for identifier in PlatformIdentifier::iter() {
233 if !matches!(identifier, PlatformIdentifier::Unknown(_)) {
234 platform_map.entry(identifier).or_insert(true);
235 }
236 }
237
238 Ok(Self { platform_map })
239 }
240 platforms => Ok(Self {
242 platform_map: Self::validate_platforms(platforms)?,
243 }),
244 }
245 }
246
247 pub fn is_supported(&self, platform: &PlatformIdentifier) -> bool {
248 self.platform_map.get(platform).cloned().unwrap_or(false)
249 }
250
251 pub(crate) fn platforms(&self) -> &HashMap<PlatformIdentifier, bool> {
252 &self.platform_map
253 }
254}
255
/// A value that can be combined with another value of the same type acting as an override.
pub trait PartialOverride: Sized {
    /// Error returned when an override cannot be applied.
    type Err: std::error::Error;

    /// Returns a new value built from `self` and `override_val`.
    /// NOTE(review): presumably `override_val` takes precedence where the two conflict — the
    /// exact merge semantics are defined by each implementor; confirm per type.
    fn apply_overrides(&self, override_val: &Self) -> Result<Self, Self::Err>;
}
261
/// A [`PartialOverride`] type that `PerPlatform`'s `Deserialize` impl accepts as its element.
pub trait PlatformOverridable: PartialOverride {
    /// Error returned by `on_nil`.
    type Err: std::error::Error;

    /// Builds the [`PerPlatform`] value to use when the deserialized input is a unit value
    /// (Lua `nil`) rather than a table.
    fn on_nil<T>() -> Result<PerPlatform<T>, <Self as PlatformOverridable>::Err>
    where
        T: PlatformOverridable,
        T: Default;
}
270
/// A value with optional platform-specific variants.
#[derive(Clone, Debug, PartialEq)]
pub struct PerPlatform<T> {
    /// Value used when no platform-specific entry applies.
    pub(crate) default: T,
    /// Platform-specific values.
    ///
    /// `PerPlatform::get` checks for an exact match first, then the closest platform that the
    /// requested one extends.
    pub(crate) per_platform: HashMap<PlatformIdentifier, T>,
}
279
280impl<T> PerPlatform<T> {
281 pub(crate) fn new(default: T) -> Self {
282 Self {
283 default,
284 per_platform: HashMap::default(),
285 }
286 }
287
288 pub fn current_platform(&self) -> &T {
291 self.for_platform_identifier(&target_identifier())
292 }
293
294 fn for_platform_identifier(&self, identifier: &PlatformIdentifier) -> &T {
295 self.get(identifier)
296 }
297
298 pub fn get(&self, platform: &PlatformIdentifier) -> &T {
299 self.per_platform.get(platform).unwrap_or(
300 platform
301 .get_subsets()
302 .into_iter()
303 .sorted_by(|a, b| b.partial_cmp(a).unwrap_or(Ordering::Equal))
307 .find(|identifier| self.per_platform.contains_key(identifier))
308 .and_then(|identifier| self.per_platform.get(&identifier))
309 .unwrap_or(&self.default),
310 )
311 }
312
313 pub(crate) fn map<U, F>(&self, cb: F) -> PerPlatform<U>
314 where
315 F: Fn(&T) -> U,
316 {
317 PerPlatform {
318 default: cb(&self.default),
319 per_platform: self
320 .per_platform
321 .iter()
322 .map(|(identifier, value)| (identifier.clone(), cb(value)))
323 .collect(),
324 }
325 }
326}
327
328impl<U, E> PerPlatform<Result<U, E>>
329where
330 E: std::error::Error,
331{
332 pub fn transpose(self) -> Result<PerPlatform<U>, E> {
333 Ok(PerPlatform {
334 default: self.default?,
335 per_platform: self
336 .per_platform
337 .into_iter()
338 .map(|(identifier, value)| Ok((identifier, value?)))
339 .try_collect()?,
340 })
341 }
342}
343
344impl<T: Default> Default for PerPlatform<T> {
345 fn default() -> Self {
346 Self {
347 default: T::default(),
348 per_platform: HashMap::default(),
349 }
350 }
351}
352
/// Serde visitor backing the `Deserialize` impl of [`PerPlatform`]; `T` is the element type.
struct PerPlatformVisitor<T>(std::marker::PhantomData<T>);
354
355impl<'de, T> Visitor<'de> for PerPlatformVisitor<T>
356where
357 T: DeserializeOwned,
358 T: PlatformOverridable,
359 T: Default,
360 T: Clone,
361{
362 type Value = PerPlatform<T>;
363
364 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
365 formatter.write_str("a table or nil")
366 }
367
368 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
369 where
370 A: de::MapAccess<'de>,
371 {
372 use serde_value::Value;
373
374 let mut platforms_val: Option<Value> = None;
375 let mut other_entries: Vec<(Value, Value)> = Vec::new();
376
377 while let Some(key) = map.next_key_seed(LuaValueSeed)? {
378 if key == Value::String("platforms".to_string()) {
379 platforms_val = Some(map.next_value_seed(LuaValueSeed)?);
380 } else {
381 other_entries.push((key, map.next_value_seed(LuaValueSeed)?));
382 }
383 }
384
385 let mut per_platform = match platforms_val {
386 Some(val) => match val {
387 Value::Map(_) => val
388 .deserialize_into::<HashMap<PlatformIdentifier, T>>()
389 .map_err(de::Error::custom)?,
390 Value::Unit => HashMap::default(),
391 val => {
392 return Err(de::Error::custom(format!(
393 "Expected platforms to be a table or nil, but got {val:?}",
394 )))
395 }
396 },
397 None => HashMap::default(),
398 };
399
400 let obj = normalize_lua_value(Value::Map(other_entries.into_iter().collect()));
404 let default = T::deserialize(obj.into_deserializer()).map_err(de::Error::custom)?;
405 apply_per_platform_overrides(&mut per_platform, &default)
406 .map_err(|err: <T as PartialOverride>::Err| de::Error::custom(err.to_string()))?;
407 Ok(PerPlatform {
408 default,
409 per_platform,
410 })
411 }
412
413 fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
414 where
415 A: de::SeqAccess<'de>,
416 {
417 let default = T::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
421 Ok(PerPlatform::new(default))
422 }
423
424 fn visit_unit<E>(self) -> Result<Self::Value, E>
425 where
426 E: de::Error,
427 {
428 T::on_nil().map_err(de::Error::custom)
429 }
430
431 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
432 where
433 E: de::Error,
434 {
435 let s = std::str::from_utf8(v).map_err(de::Error::custom)?;
436 self.visit_str(s)
437 }
438
439 fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
440 where
441 E: de::Error,
442 {
443 self.visit_bytes(&v)
444 }
445
446 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
447 where
448 E: de::Error,
449 {
450 let default = T::deserialize(v.into_deserializer())?;
451 Ok(PerPlatform::new(default))
452 }
453}
454
455impl<'de, T> Deserialize<'de> for PerPlatform<T>
456where
457 T: DeserializeOwned,
458 T: PlatformOverridable,
459 T: Default,
460 T: Clone,
461{
462 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
463 deserializer.deserialize_any(PerPlatformVisitor(std::marker::PhantomData))
464 }
465}
466
467pub(crate) fn per_platform_from_intermediate<'de, D, I, T>(
471 deserializer: D,
472) -> Result<PerPlatform<T>, D::Error>
473where
474 D: Deserializer<'de>,
475 I: PlatformOverridable<Err: ToString>,
476 I: DeserializeOwned,
477 I: Default,
478 I: Clone,
479 T: TryFrom<I, Error: ToString>,
480{
481 PerPlatform::<I>::deserialize(deserializer)?
482 .map(|internal| {
483 T::try_from(internal.clone()).map_err(|err| serde::de::Error::custom(err.to_string()))
484 })
485 .transpose()
486}
487
488fn apply_per_platform_overrides<T>(
489 per_platform: &mut HashMap<PlatformIdentifier, T>,
490 base: &T,
491) -> Result<(), T::Err>
492where
493 T: PartialOverride,
494 T: Clone,
495{
496 let per_platform_raw = per_platform.clone();
497 for (platform, overrides) in per_platform.clone() {
498 let overridden = base.apply_overrides(&overrides)?;
500 per_platform.insert(platform, overridden);
501 }
502 for (platform, overrides) in per_platform_raw {
503 for extended_platform in &platform.get_extended_platforms() {
505 if let Some(extended_overrides) = per_platform.get(extended_platform) {
506 per_platform.insert(
507 extended_platform.to_owned(),
508 extended_overrides.apply_overrides(&overrides)?,
509 );
510 }
511 }
512 }
513 Ok(())
514}
515
#[cfg(test)]
mod tests {
    // All tests here are synchronous, so they use the plain `#[test]` harness instead of
    // starting a Tokio runtime per test.

    use super::*;
    use proptest::prelude::*;

    /// Generates any known (non-`Unknown`) platform identifier.
    fn platform_identifier_strategy() -> impl Strategy<Value = PlatformIdentifier> {
        prop_oneof![
            Just(PlatformIdentifier::Unix),
            Just(PlatformIdentifier::Windows),
            Just(PlatformIdentifier::Win32),
            Just(PlatformIdentifier::Cygwin),
            Just(PlatformIdentifier::MacOSX),
            Just(PlatformIdentifier::Linux),
            Just(PlatformIdentifier::FreeBSD),
        ]
    }

    #[test]
    fn sort_platform_identifier_more_specific_last() {
        let mut platforms = vec![
            PlatformIdentifier::Cygwin,
            PlatformIdentifier::Linux,
            PlatformIdentifier::Unix,
        ];
        platforms.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
        assert_eq!(
            platforms,
            vec![
                PlatformIdentifier::Unix,
                PlatformIdentifier::Cygwin,
                PlatformIdentifier::Linux
            ]
        );
        let mut platforms = vec![PlatformIdentifier::Windows, PlatformIdentifier::Win32];
        platforms.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
        assert_eq!(
            platforms,
            vec![PlatformIdentifier::Win32, PlatformIdentifier::Windows]
        )
    }

    #[test]
    fn test_is_subset_of() {
        assert!(PlatformIdentifier::Unix.is_subset_of(&PlatformIdentifier::Linux));
        assert!(PlatformIdentifier::Unix.is_subset_of(&PlatformIdentifier::MacOSX));
        assert!(!PlatformIdentifier::Linux.is_subset_of(&PlatformIdentifier::Unix));
    }

    #[test]
    fn test_is_extension_of() {
        assert!(PlatformIdentifier::Linux.is_extension_of(&PlatformIdentifier::Unix));
        assert!(PlatformIdentifier::MacOSX.is_extension_of(&PlatformIdentifier::Unix));
        assert!(!PlatformIdentifier::Unix.is_extension_of(&PlatformIdentifier::Linux));
    }

    #[test]
    fn per_platform() {
        let foo = PerPlatform {
            default: "default",
            per_platform: vec![
                (PlatformIdentifier::Unix, "unix"),
                (PlatformIdentifier::FreeBSD, "freebsd"),
                (PlatformIdentifier::Cygwin, "cygwin"),
                (PlatformIdentifier::Linux, "linux"),
            ]
            .into_iter()
            .collect(),
        };
        assert_eq!(*foo.get(&PlatformIdentifier::MacOSX), "unix");
        assert_eq!(*foo.get(&PlatformIdentifier::Linux), "linux");
        assert_eq!(*foo.get(&PlatformIdentifier::FreeBSD), "freebsd");
        assert_eq!(*foo.get(&PlatformIdentifier::Cygwin), "cygwin");
        assert_eq!(*foo.get(&PlatformIdentifier::Windows), "default");
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_target_identifier() {
        run_test_target_identifier(PlatformIdentifier::Linux)
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn test_target_identifier() {
        run_test_target_identifier(PlatformIdentifier::MacOSX)
    }

    #[cfg(target_env = "msvc")]
    #[test]
    fn test_target_identifier() {
        run_test_target_identifier(PlatformIdentifier::Windows)
    }

    #[cfg(target_os = "android")]
    #[test]
    fn test_target_identifier() {
        run_test_target_identifier(PlatformIdentifier::Unix)
    }

    fn run_test_target_identifier(expected: PlatformIdentifier) {
        assert_eq!(expected, target_identifier());
    }

    proptest! {
        #[test]
        fn supported_platforms(identifier in platform_identifier_strategy()) {
            let identifier_str = identifier.to_string();
            let platforms = vec![identifier_str];
            let platform_support = PlatformSupport::parse(&platforms).unwrap();
            prop_assert!(platform_support.is_supported(&identifier))
        }

        #[test]
        fn unsupported_platforms_only(unsupported in platform_identifier_strategy(), supported in platform_identifier_strategy()) {
            if supported == unsupported
                || unsupported.is_extension_of(&supported) {
                return Ok(());
            }
            let identifier_str = format!("!{unsupported}");
            let platforms = vec![identifier_str];
            let platform_support = PlatformSupport::parse(&platforms).unwrap();
            prop_assert!(!platform_support.is_supported(&unsupported));
            prop_assert!(platform_support.is_supported(&supported))
        }

        #[test]
        fn supported_and_unsupported_platforms(unsupported in platform_identifier_strategy(), unspecified in platform_identifier_strategy()) {
            if unspecified == unsupported
                || unsupported.is_extension_of(&unspecified) {
                return Ok(());
            }
            let supported_str = unspecified.to_string();
            let unsupported_str = format!("!{unsupported}");
            let platforms = vec![supported_str, unsupported_str];
            let platform_support = PlatformSupport::parse(&platforms).unwrap();
            prop_assert!(platform_support.is_supported(&unspecified));
            prop_assert!(!platform_support.is_supported(&unsupported));
        }

        #[test]
        fn all_platforms_supported_if_none_are_specified(identifier in platform_identifier_strategy()) {
            let platforms: Vec<String> = vec![];
            let platform_support = PlatformSupport::parse(&platforms).unwrap();
            prop_assert!(platform_support.is_supported(&identifier))
        }

        #[test]
        fn conflicting_platforms(identifier in platform_identifier_strategy()) {
            let identifier_str = identifier.to_string();
            let identifier_str_negated = format!("!{identifier}");
            let platforms = vec![identifier_str, identifier_str_negated];
            let _ = PlatformSupport::parse(&platforms).unwrap_err();
        }

        #[test]
        fn extended_platforms_supported_if_supported(identifier in platform_identifier_strategy()) {
            let identifier_str = identifier.to_string();
            let platforms = vec![identifier_str];
            let platform_support = PlatformSupport::parse(&platforms).unwrap();
            for identifier in identifier.get_extended_platforms() {
                prop_assert!(platform_support.is_supported(&identifier))
            }
        }

        #[test]
        fn sub_platforms_unsupported_if_unsupported(identifier in platform_identifier_strategy()) {
            let identifier_str = format!("!{identifier}");
            let platforms = vec![identifier_str];
            let platform_support = PlatformSupport::parse(&platforms).unwrap();
            for identifier in identifier.get_subsets() {
                prop_assert!(!platform_support.is_supported(&identifier))
            }
        }

        #[test]
        fn conflicting_extended_platform_definitions(identifier in platform_identifier_strategy()) {
            let extended_platforms = identifier.get_extended_platforms();
            if extended_platforms.is_empty() {
                return Ok(());
            }
            let supported_str = identifier.to_string();
            let mut platforms: Vec<String> = extended_platforms.into_iter().map(|ident| format!("!{ident}")).collect();
            platforms.push(supported_str);
            let _ = PlatformSupport::parse(&platforms).unwrap_err();
        }
    }
}