Skip to main content

basalt_bedrock/language/
language_set.rs

1use std::borrow::Cow;
2use std::collections::BTreeSet;
3use std::fmt;
4use std::ops::{Deref, DerefMut};
5
6use serde::de::{Deserializer, MapAccess, Visitor};
7use serde::ser::{SerializeMap, Serializer};
8use serde::{Deserialize, Serialize};
9
10use crate::language::Version;
11
12use super::{BuiltInLanguage, Language, Syntax};
13
14#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
15pub struct LanguageSet {
16    inner: BTreeSet<Language>,
17}
18
19impl LanguageSet {
20    pub fn new() -> Self {
21        Self {
22            inner: Default::default(),
23        }
24    }
25
26    pub fn get_by_str(&self, raw_name: &str) -> Option<&Language> {
27        self.inner.iter().find(|l| l.name() == raw_name)
28    }
29}
30
31impl Deref for LanguageSet {
32    type Target = BTreeSet<Language>;
33
34    fn deref(&self) -> &Self::Target {
35        &self.inner
36    }
37}
38
39impl DerefMut for LanguageSet {
40    fn deref_mut(&mut self) -> &mut Self::Target {
41        &mut self.inner
42    }
43}
44
45struct LanguageMapVisitor;
46
47impl<'de> Visitor<'de> for LanguageMapVisitor {
48    type Value = LanguageSet;
49
50    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
51        write!(f, "a map of languages")
52    }
53
54    fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
55    where
56        M: MapAccess<'de>,
57    {
58        let mut map = LanguageSet::new();
59
60        // TODO: Spans or something for better error messages
61        while let Some((key, value)) = access.next_entry::<Cow<'_, str>, TomlLanguage>()? {
62            let val = match value {
63                TomlLanguage::Latest => Language::BuiltIn {
64                    language: key.parse().map_err(|()| {
65                        serde::de::Error::custom(format!(
66                            "Unknown built-in language: '{}'. Known languages: {}",
67                            key,
68                            BuiltInLanguage::joined_variants()
69                        ))
70                    })?,
71                    version: Version::Latest,
72                },
73                TomlLanguage::Version(v) => {
74                    let language: BuiltInLanguage = key.parse().map_err(|()| {
75                        serde::de::Error::custom(format!(
76                            "Unknown built-in language: '{}'.  Known languages: {}",
77                            key,
78                            BuiltInLanguage::joined_variants()
79                        ))
80                    })?;
81                    let version = Version::Specific(v.clone().into());
82
83                    if let Err(versions) = language.has_version(&version) {
84                        return Err(serde::de::Error::custom(format!(
85                            "Unknown {} version: '{}'.  Known versions: {}",
86                            key,
87                            v,
88                            versions
89                                .into_iter()
90                                .map(|s| format!("'{}'", s))
91                                .collect::<Vec<_>>()
92                                .join(", ")
93                        )));
94                    }
95
96                    Language::BuiltIn { language, version }
97                }
98                TomlLanguage::Custom {
99                    display_name,
100                    build,
101                    run,
102                    source_file,
103                    syntax,
104                } => Language::Custom {
105                    name: key.clone().into_owned(),
106                    display_name: display_name.unwrap_or_else(|| key.clone()).into_owned(),
107                    build: build.map(Cow::into_owned),
108                    run: run.into_owned(),
109                    syntax: syntax
110                        .or_else(|| Syntax::from_string::<M::Error>(key).ok())
111                        .unwrap_or_default(),
112                    source_file: source_file.into_owned(),
113                },
114            };
115
116            map.insert(val);
117        }
118
119        Ok(map)
120    }
121}
122
123impl<'de> Deserialize<'de> for LanguageSet {
124    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
125    where
126        D: Deserializer<'de>,
127    {
128        deserializer.deserialize_map(LanguageMapVisitor)
129    }
130}
131
132impl Serialize for LanguageSet {
133    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
134    where
135        S: Serializer,
136    {
137        let mut map = serializer.serialize_map(Some(self.inner.len()))?;
138        for lang in &self.inner {
139            match lang {
140                Language::BuiltIn {
141                    language: name,
142                    version: value,
143                } => {
144                    map.serialize_entry(name.name(), &TomlLanguage::from(value))?;
145                }
146                Language::Custom {
147                    name,
148                    display_name,
149                    build,
150                    run,
151                    source_file,
152                    syntax,
153                } => {
154                    map.serialize_entry(
155                        name,
156                        &TomlLanguage::Custom {
157                            display_name: Some(display_name.into()),
158                            build: build.as_ref().map(Into::into),
159                            run: run.into(),
160                            source_file: source_file.into(),
161                            syntax: Some(*syntax),
162                        },
163                    )?;
164                }
165            }
166        }
167        map.end()
168    }
169}
170
171/// Language as represented in the toml file
172#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
173#[serde(rename_all = "kebab-case", deny_unknown_fields)]
174enum TomlLanguage<'a> {
175    #[serde(alias = "*")]
176    Latest,
177    #[serde(untagged)]
178    Version(Cow<'a, str>),
179    #[serde(untagged)]
180    Custom {
181        #[serde(alias = "name")]
182        display_name: Option<Cow<'a, str>>,
183        build: Option<Cow<'a, str>>,
184        run: Cow<'a, str>,
185        source_file: Cow<'a, str>,
186        syntax: Option<Syntax>,
187    },
188}
189
190impl<'a> From<&'a Version> for TomlLanguage<'a> {
191    fn from(value: &'a Version) -> Self {
192        match value {
193            Version::Latest => TomlLanguage::Latest,
194            Version::Specific(v) => TomlLanguage::Version(v.into()),
195        }
196    }
197}