Skip to main content

litcheck_core/variables/
mod.rs

1use std::{
2    borrow::{Borrow, Cow},
3    collections::BTreeMap,
4    fmt,
5    str::FromStr,
6};
7
8use miette::Context;
9
10use crate::{
11    Symbol,
12    diagnostics::{
13        self, DiagResult, Diagnostic, Label, Report, SourceId, SourceSpan, Span, Spanned,
14    },
15    range::Range,
16};
17
18pub trait ValueParser {
19    type Value<'a>;
20
21    fn try_parse<'input>(s: Span<&'input str>) -> DiagResult<Self::Value<'input>>;
22}
23
24impl ValueParser for str {
25    type Value<'a> = &'a str;
26
27    #[inline(always)]
28    fn try_parse<'input>(s: Span<&'input str>) -> DiagResult<Self::Value<'input>> {
29        Ok(s.into_inner())
30    }
31}
32
33impl ValueParser for String {
34    type Value<'a> = String;
35
36    #[inline(always)]
37    fn try_parse<'input>(s: Span<&'input str>) -> DiagResult<Self::Value<'input>> {
38        Ok(s.into_inner().to_string())
39    }
40}
41impl ValueParser for Cow<'_, str> {
42    type Value<'a> = Cow<'a, str>;
43
44    #[inline(always)]
45    fn try_parse<'input>(s: Span<&'input str>) -> DiagResult<Self::Value<'input>> {
46        Ok(Cow::Borrowed(s.into_inner()))
47    }
48}
49impl ValueParser for i64 {
50    type Value<'a> = i64;
51
52    #[inline(always)]
53    fn try_parse<'input>(s: Span<&'input str>) -> DiagResult<Self::Value<'input>> {
54        let (span, s) = s.into_parts();
55        s.parse::<i64>().map_err(|err| {
56            Report::new(diagnostics::Diag::new(format!("{err}")).with_label(Label::at(span)))
57        })
58    }
59}
60
61pub trait TypedVariable: Clone + Sized {
62    type Key<'a>;
63    type Value<'a>;
64    type Variable<'a>: Clone + Sized;
65
66    fn try_parse<'input>(input: Span<&'input str>) -> Result<Self::Variable<'input>, Report>;
67}
68
69#[derive(Diagnostic, Debug)]
70pub enum VariableError {
71    #[diagnostic()]
72    Empty(#[label] SourceSpan),
73    #[diagnostic()]
74    EmptyName(#[label] SourceSpan),
75    #[diagnostic()]
76    MissingEquals(#[label] SourceSpan),
77}
78impl VariableError {
79    pub fn into_report(self) -> Report {
80        Report::from(self)
81    }
82}
83impl std::error::Error for VariableError {}
84impl fmt::Display for VariableError {
85    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
86        match self {
87            Self::Empty(_) => f.write_str("invalid variable definition: expected expression of the form `NAME=(VALUE)?`"),
88            Self::EmptyName(_) => f.write_str("invalid variable definition: name cannot be empty"),
89            Self::MissingEquals(_) => f.write_str(
90                "invalid variable definition: expected 'NAME=VALUE', but no '=' was found in the input",
91            ),
92        }
93    }
94}
95
96#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
97pub enum VariableName {
98    Pseudo(Span<Symbol>),
99    Global(Span<Symbol>),
100    User(Span<Symbol>),
101}
102
103impl ValueParser for VariableName {
104    type Value<'a> = VariableName;
105
106    fn try_parse(input: Span<&str>) -> DiagResult<Self::Value<'_>> {
107        let (span, s) = input.into_parts();
108        let (prefix, unprefixed) = if let Some(name) = s.strip_prefix('$') {
109            (Some('$'), name)
110        } else if let Some(name) = s.strip_prefix('@') {
111            (Some('@'), name)
112        } else {
113            (None, s)
114        };
115        if !is_valid_variable_name(unprefixed) {
116            let offset = prefix.is_some() as u32;
117            let span = SourceSpan::new(
118                span.source_id(),
119                Range::new(span.start() + offset, span.end()),
120            );
121            return Err(miette::miette!(
122                labels = vec![Label::at(span).into()],
123                help = "must be non-empty, and match the pattern `[A-Za-z_][A-Za-z0-9_]*`",
124                "invalid variable name"
125            )
126            .with_source_code(s.to_string()));
127        }
128
129        let name = Symbol::intern(unprefixed);
130        match prefix {
131            None => Ok(Self::User(Span::new(span, name))),
132            Some('$') => Ok(Self::Global(Span::new(span, name))),
133            Some(_) => Ok(Self::Pseudo(Span::new(span, name))),
134        }
135    }
136}
137
138impl Spanned for VariableName {
139    fn span(&self) -> SourceSpan {
140        match self {
141            Self::Pseudo(name) | Self::Global(name) | Self::User(name) => name.span(),
142        }
143    }
144}
145
146impl VariableName {
147    pub fn as_str(&self) -> &str {
148        match self {
149            Self::Pseudo(s) | Self::Global(s) | Self::User(s) => s.as_str(),
150        }
151    }
152
153    pub fn into_inner(self) -> Symbol {
154        match self {
155            Self::User(s) | Self::Global(s) | Self::Pseudo(s) => s.into_inner(),
156        }
157    }
158
159    pub fn to_global(self) -> Self {
160        match self {
161            global @ (Self::Global(_) | Self::Pseudo(_)) => global,
162            Self::User(name) => Self::Global(name),
163        }
164    }
165}
166
167impl fmt::Display for VariableName {
168    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169        match self {
170            Self::Global(name) => write!(f, "${name}"),
171            Self::Pseudo(name) => write!(f, "@{name}"),
172            Self::User(name) => f.write_str(name.as_str()),
173        }
174    }
175}
176
177impl<T> Borrow<T> for VariableName
178where
179    Symbol: Borrow<T>,
180{
181    fn borrow(&self) -> &T {
182        match self {
183            Self::Pseudo(s) | Self::Global(s) | Self::User(s) => s.inner().borrow(),
184        }
185    }
186}
187
188impl<T: ?Sized> AsRef<T> for VariableName
189where
190    Symbol: AsRef<T>,
191{
192    fn as_ref(&self) -> &T {
193        match self {
194            Self::Pseudo(s) | Self::Global(s) | Self::User(s) => (**s).as_ref(),
195        }
196    }
197}
198
199#[derive(Debug, PartialEq, Eq)]
200pub struct Variable<V> {
201    pub name: VariableName,
202    pub value: V,
203}
204impl<V> Clone for Variable<V>
205where
206    V: Clone,
207{
208    fn clone(&self) -> Self {
209        Self {
210            name: self.name,
211            value: self.value.clone(),
212        }
213    }
214}
215
216unsafe impl<V: Send> Send for Variable<V> {}
217
218unsafe impl<V: Sync> Sync for Variable<V> {}
219
220impl<V> Variable<V> {
221    pub fn new<T>(name: VariableName, value: T) -> Self
222    where
223        V: From<T>,
224    {
225        Self {
226            name,
227            value: V::from(value),
228        }
229    }
230
231    pub fn name(&self) -> &VariableName {
232        &self.name
233    }
234
235    pub fn is_pseudo(&self) -> bool {
236        matches!(self.name, VariableName::Pseudo(_))
237    }
238
239    pub fn is_global(&self) -> bool {
240        matches!(self.name, VariableName::Global(_) | VariableName::Pseudo(_))
241    }
242}
243
244impl<V> TypedVariable for Variable<V>
245where
246    V: FromStr + Clone,
247    <V as FromStr>::Err: Diagnostic + Send + Sync + 'static,
248{
249    type Key<'a> = VariableName;
250    type Value<'a> = V;
251    type Variable<'a> = Variable<V>;
252
253    fn try_parse<'input>(input: Span<&'input str>) -> Result<Self::Variable<'input>, Report> {
254        let (span, s) = input.into_parts();
255        if s.is_empty() {
256            Err(VariableError::Empty(span)
257                .into_report()
258                .with_source_code(s.to_string()))
259        } else if let Some((k, v)) = s.split_once('=') {
260            if k.is_empty() {
261                return Err(VariableError::EmptyName(span)
262                    .into_report()
263                    .with_source_code(s.to_string()));
264            }
265            let key_len = k.len() as u32;
266            let key_span = SourceSpan::new(span.source_id(), Range::new(0, key_len));
267            if !is_valid_variable_name(k) {
268                return Err(miette::miette!(
269                    labels = vec![Label::at(key_span).into()],
270                    help = "variable names must match the pattern `[A-Za-z_][A-Za-z0-9_]*`",
271                    "name contains invalid characters",
272                )
273                .with_source_code(s.to_string()));
274            }
275            let k = <VariableName as ValueParser>::try_parse(Span::new(key_span, k))
276                .wrap_err("invalid variable name")?;
277            let v = v
278                .parse::<V>()
279                .map_err(|err| Report::from(err).with_source_code(v.to_string()))
280                .wrap_err("invalid variable value")?;
281            Ok(Self::new(k, v))
282        } else {
283            Err(VariableError::MissingEquals(span)
284                .into_report()
285                .with_source_code(s.to_string()))
286        }
287    }
288}
289
290impl<V> clap::builder::ValueParserFactory for Variable<V>
291where
292    V: FromStr + Send + Sync + Clone + 'static,
293    <V as FromStr>::Err: Diagnostic + Send + Sync + Clone + 'static,
294    for<'a> Variable<V>:
295        TypedVariable<Key<'a> = VariableName, Value<'a> = V> + Send + Sync + Clone + 'static,
296{
297    type Parser = VariableParser<Variable<V>>;
298
299    fn value_parser() -> Self::Parser {
300        Default::default()
301    }
302}
303
/// A zero-sized `clap` value parser that turns an argument into a `T` via `TypedVariable`.
///
/// `T` appears only at the type level (through `PhantomData`), so `Clone`, `Copy` and `Default`
/// are implemented for every `T`. `Send`/`Sync` come from `PhantomData<T>`'s auto-trait impls:
/// the parser is `Send` iff `T: Send` and `Sync` iff `T: Sync`, with no `unsafe` needed.
#[derive(Debug)]
pub struct VariableParser<T>(core::marker::PhantomData<T>);

// Hand-written rather than derived: `#[derive(Clone, Copy)]` would add `T: Clone`/`T: Copy`
// bounds even though no `T` is ever stored.
impl<T> Clone for VariableParser<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for VariableParser<T> {}

impl<T> Default for VariableParser<T> {
    fn default() -> Self {
        Self(core::marker::PhantomData)
    }
}
322impl<T, V> clap::builder::TypedValueParser for VariableParser<T>
323where
324    V: Send + Sync + Clone + 'static,
325    for<'a> T: TypedVariable<Key<'a> = VariableName, Value<'a> = V, Variable<'a> = T>
326        + Send
327        + Sync
328        + Clone
329        + 'static,
330{
331    type Value = T;
332
333    fn parse_ref(
334        &self,
335        _cmd: &clap::Command,
336        _arg: Option<&clap::Arg>,
337        value: &std::ffi::OsStr,
338    ) -> Result<Self::Value, clap::Error> {
339        use clap::error::{Error, ErrorKind};
340
341        let raw = value
342            .to_str()
343            .ok_or_else(|| Error::new(ErrorKind::InvalidUtf8))?;
344
345        let span = SourceSpan::new(SourceId::UNKNOWN, Range::new(0, raw.len() as u32));
346        <T as TypedVariable>::try_parse(Span::new(span, raw)).map_err(|err| {
347            let err = if err.source_code().is_none() {
348                err.with_source_code(raw.to_string())
349            } else {
350                err
351            };
352            let diag = crate::reporting::PrintDiagnostic::new(err);
353            Error::raw(ErrorKind::InvalidValue, format!("{diag}"))
354        })
355    }
356}
357
358pub struct Variables<V>(BTreeMap<VariableName, V>);
359
360impl<V> FromIterator<Variable<V>> for Variables<V>
361where
362    V: TypedVariable,
363{
364    fn from_iter<T>(iter: T) -> Self
365    where
366        T: IntoIterator<Item = Variable<V>>,
367    {
368        Self(iter.into_iter().map(|var| (var.name, var.value)).collect())
369    }
370}
371
372impl<V: TypedVariable> Variables<V> {
373    pub fn is_defined<Q>(&self, k: &Q) -> bool
374    where
375        Q: Ord + Eq,
376        VariableName: Borrow<Q>,
377    {
378        self.0.contains_key(k)
379    }
380
381    pub fn get<Q>(&self, k: &Q) -> Option<&V>
382    where
383        Q: Ord + Eq,
384        VariableName: Borrow<Q>,
385    {
386        self.0.get(k)
387    }
388
389    pub fn define(&mut self, k: impl Into<VariableName>, v: V) -> Option<V> {
390        self.0.insert(k.into(), v)
391    }
392
393    pub fn delete<Q>(&mut self, k: &Q) -> Option<Variable<V>>
394    where
395        Q: Ord + Eq,
396        VariableName: Borrow<Q>,
397    {
398        self.0.remove_entry(k).map(|(k, v)| Variable::new(k, v))
399    }
400}
401
/// Returns `true` if `name` is a valid variable name without any sigil.
///
/// A valid name is non-empty, starts with an ASCII letter or `_`, and continues with ASCII
/// letters, digits or `_`; that is, it matches `[A-Za-z_][A-Za-z0-9_]*`.
pub fn is_valid_variable_name(name: &str) -> bool {
    let mut chars = name.chars();
    let valid_head = matches!(
        chars.next(),
        Some(head) if head == '_' || head.is_ascii_alphabetic()
    );
    valid_head && chars.all(|tail| tail == '_' || tail.is_ascii_alphanumeric())
}