use super::Error;
use crate::de::ErrorKind;
use serde::de;
use std::{cell::Cell, fmt::Display, ops::Range, thread::LocalKey};
use taml::{
diagnostics::{
Diagnostic, DiagnosticLabel, DiagnosticLabelPriority, DiagnosticType,
Reporter as diagReporter,
},
parsing::TamlValue,
Position,
};
use tap::Pipe;
thread_local! {
    /// Per-thread TAML type override, empty (`None`) until one of this module's
    /// `from_*` helpers stores a `ForcedTamlValueType` in it.
    ///
    /// NOTE(review): presumably read and cleared by the TAML deserializer in the
    /// parent module when it decodes the next value — confirm against the caller.
    pub(super) static OVERRIDE: Cell<Option<ForcedTamlValueType>> = Cell::new(None);
}
/// Convenience accessors for the thread-local type override slot (`OVERRIDE`).
///
/// Every method takes a `'static` receiver, which is what `LocalKey::with` needs.
///
/// NOTE: since Rust 1.73, `LocalKey<Cell<T>>` also has *inherent* `set` and `take`
/// methods, and inherent methods take precedence in method-call syntax. Therefore
/// `OVERRIDE.set(x)` resolves to the std `set(Option<_>)`, not to this trait's method.
/// Call this trait's version fully qualified: `Override::set(&OVERRIDE, x)`.
pub(super) trait Override {
    /// Stores `value` as the current override, replacing whatever was there.
    fn set(&'static self, value: ForcedTamlValueType);

    /// Stores `fallback` only if no override is currently present.
    fn insert_if_none(&'static self, fallback: ForcedTamlValueType);

    /// Removes and returns the current override, leaving the slot empty.
    fn take(&'static self) -> Option<ForcedTamlValueType>;
}
/// `Override` for the thread-local slot type used by `OVERRIDE`.
impl Override for LocalKey<Cell<Option<ForcedTamlValueType>>> {
    fn set(&'static self, force: ForcedTamlValueType) {
        let forced = Some(force);
        self.with(move |slot| slot.set(forced));
    }

    fn take(&'static self) -> Option<ForcedTamlValueType> {
        // Swap in `None` and hand back the previous contents.
        self.with(|slot| slot.replace(None))
    }

    fn insert_if_none(&'static self, new_default: ForcedTamlValueType) {
        // `Option::or` keeps an existing override and only fills an empty slot.
        self.with(|slot| slot.set(slot.get().or(Some(new_default))));
    }
}
/// A TAML value type that a field can demand its value to have.
///
/// `ForcedTamlValueType::pick` checks a parsed `TamlValue` against it.
///
/// NOTE: the derived `PartialOrd`/`Ord` compare variants by declaration order, so
/// reordering the variants changes comparison results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(super) enum ForcedTamlValueType {
    /// A data literal, written `<…;…>`.
    DataLiteral,
    /// A decimal number. An integer written in its place is rejected with a hint
    /// to append `.0`.
    Decimal,
    /// An enum variant, written `key` or `key(…)`.
    EnumVariant,
    /// An integer.
    Integer,
    /// A list, written `(…)`.
    List,
    /// A string, written `"…"`.
    String,
    /// A struct; `pick` matches it against `TamlValue::Map`.
    Struct,
}
impl ForcedTamlValueType {
pub fn pick<'a, 'de, P: Position, Reporter: diagReporter<P>>(
self,
value: &'a TamlValue<'de, P>,
span: &Range<P>,
reporter: &mut Reporter,
) -> Result<&'a TamlValue<'de, P>, Error> {
#[allow(
clippy::match_same_arms,
clippy::non_ascii_literal,
clippy::single_match_else
)]
match self {
ForcedTamlValueType::String => match value {
v @ TamlValue::String(_) => Ok(v),
_ => {
reporter.report_with(|| Diagnostic {
type_: DiagnosticType::InvalidType,
labels: vec![DiagnosticLabel::new(
r#"Expected string (`"…"`)."#,
span.clone(),
DiagnosticLabelPriority::Primary,
)],
});
Err(ErrorKind::Reported.into())
}
},
ForcedTamlValueType::DataLiteral => match value {
v @ TamlValue::DataLiteral(_) => Ok(v),
_ => {
reporter.report_with(|| Diagnostic {
type_: DiagnosticType::InvalidType,
labels: vec![DiagnosticLabel::new(
"Expected data literal (`<…;…>`).",
span.clone(),
DiagnosticLabelPriority::Primary,
)],
});
Err(ErrorKind::Reported.into())
}
},
ForcedTamlValueType::Integer => match value {
v @ TamlValue::Integer(_) => Ok(v),
_ => {
reporter.report_with(|| Diagnostic {
type_: DiagnosticType::InvalidType,
labels: vec![DiagnosticLabel::new(
"Expected integer.",
span.clone(),
DiagnosticLabelPriority::Primary,
)],
});
Err(ErrorKind::Reported.into())
}
},
ForcedTamlValueType::Decimal => match value {
TamlValue::Integer(i) => {
let span = span.clone().pipe(Some);
reporter.report_with(|| Diagnostic {
type_: DiagnosticType::InvalidType,
labels: vec![
DiagnosticLabel::new(
"Expected decimal.",
span.clone(),
DiagnosticLabelPriority::Primary,
),
DiagnosticLabel::new(
format!("Hint: Try `{}.0`.", i),
span,
DiagnosticLabelPriority::Auxiliary,
),
],
});
Err(ErrorKind::Reported.into())
}
v @ TamlValue::Decimal(_) => Ok(v),
_ => {
reporter.report_with(|| Diagnostic {
type_: DiagnosticType::InvalidType,
labels: vec![DiagnosticLabel::new(
"Expected decimal.",
span.clone(),
DiagnosticLabelPriority::Primary,
)],
});
Err(ErrorKind::Reported.into())
}
},
ForcedTamlValueType::EnumVariant => match value {
v @ TamlValue::EnumVariant { .. } => Ok(v),
_ => {
reporter.report_with(|| Diagnostic {
type_: DiagnosticType::InvalidType,
labels: vec![DiagnosticLabel::new(
r#"Expected enum variant (`key` or `key(…)`)."#,
span.clone(),
DiagnosticLabelPriority::Primary,
)],
});
Err(ErrorKind::Reported.into())
}
},
ForcedTamlValueType::List => match value {
v @ TamlValue::List(_) => Ok(v),
_ => {
reporter.report_with(|| Diagnostic {
type_: DiagnosticType::InvalidType,
labels: vec![DiagnosticLabel::new(
"Expected list (`(…)`).",
span.clone(),
DiagnosticLabelPriority::Primary,
)],
});
Err(ErrorKind::Reported.into())
}
},
ForcedTamlValueType::Struct => match value {
v @ TamlValue::Map(_) => Ok(v),
_ => {
reporter.report_with(|| Diagnostic {
type_: DiagnosticType::InvalidType,
labels: vec![DiagnosticLabel::new(
"Expected struct.",
span.clone(),
DiagnosticLabelPriority::Primary,
)],
});
Err(ErrorKind::Reported.into())
}
},
}
}
}
impl Display for ForcedTamlValueType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(match self {
ForcedTamlValueType::DataLiteral => "data literal",
ForcedTamlValueType::Decimal => "decimal",
ForcedTamlValueType::EnumVariant => "enum variant",
ForcedTamlValueType::Integer => "integer",
ForcedTamlValueType::List => "list",
ForcedTamlValueType::String => "string",
ForcedTamlValueType::Struct => "struct",
})
}
}
/// Resolves an optional type override against the type actually being parsed.
pub(crate) trait AssertAcceptableAndUnwrapOrDefault<Kind> {
    /// Returns `default` when no override is present. Otherwise returns the override,
    /// if it is acceptable: equal to `default` or contained in `other_acceptable`.
    ///
    /// Implementations may panic on an unacceptable override. The
    /// `Option<ForcedTamlValueType>` implementation in this module does.
    fn assert_acceptable_and_unwrap(self, default: Kind, other_acceptable: &[Kind]) -> Kind;
}
impl AssertAcceptableAndUnwrapOrDefault<ForcedTamlValueType> for Option<ForcedTamlValueType> {
    /// Returns `default` if no override is set. Otherwise returns the override, provided
    /// it equals `default` or appears in `other_acceptable`.
    ///
    /// # Panics
    ///
    /// Panics if an override is set that is neither `default` nor in `other_acceptable`.
    /// NOTE(review): presumably a programming error (a `from_*` helper applied to an
    /// incompatible field) rather than bad input data — confirm.
    fn assert_acceptable_and_unwrap(
        self,
        default: ForcedTamlValueType,
        other_acceptable: &[ForcedTamlValueType],
    ) -> ForcedTamlValueType {
        let forced = match self {
            Some(forced) => forced,
            None => return default,
        };
        let acceptable = forced == default || other_acceptable.contains(&forced);
        assert!(
            acceptable,
            "Unsupported TAML type override: Can't expect {} when parsing {}.",
            forced, default
        );
        forced
    }
}
pub fn from_data_literal<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: de::Deserializer<'de>,
T: de::Deserialize<'de>,
{
OVERRIDE.set(ForcedTamlValueType::DataLiteral);
T::deserialize(deserializer)
}
pub fn from_decimal<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: de::Deserializer<'de>,
T: de::Deserialize<'de>,
{
OVERRIDE.set(ForcedTamlValueType::Decimal);
T::deserialize(deserializer)
}
pub fn from_enum_variant<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: de::Deserializer<'de>,
T: de::Deserialize<'de>,
{
OVERRIDE.set(ForcedTamlValueType::EnumVariant);
T::deserialize(deserializer)
}
pub fn from_integer<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: de::Deserializer<'de>,
T: de::Deserialize<'de>,
{
OVERRIDE.set(ForcedTamlValueType::Integer);
T::deserialize(deserializer)
}
pub fn from_list<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: de::Deserializer<'de>,
T: de::Deserialize<'de>,
{
OVERRIDE.set(ForcedTamlValueType::List);
T::deserialize(deserializer)
}
pub fn from_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: de::Deserializer<'de>,
T: de::Deserialize<'de>,
{
OVERRIDE.set(ForcedTamlValueType::String);
T::deserialize(deserializer)
}
pub fn from_struct<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: de::Deserializer<'de>,
T: de::Deserialize<'de>,
{
OVERRIDE.set(ForcedTamlValueType::Struct);
T::deserialize(deserializer)
}