litehouse_config/
lib.rs

1mod hash_read;
2
3use std::{collections::HashMap, fmt::Display, num::NonZeroU8, path::Path, str::FromStr};
4
5use hash_read::HashRead;
6use miette::{Diagnostic, NamedSource, SourceOffset};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use tokio::io::{AsyncRead, AsyncWrite};
11
/// Separates a registry name from the plugin in an import: `registry::plugin`.
const REGISTRY_SEPARATOR: &str = "::";
/// Separates a plugin name from its version: `plugin@1.0.0`. Also used in cached file names.
const VERSION_SEPARATOR: &str = "@";
/// Separates the hash suffix from the rest of an import: `plugin~blake3:<hex>`.
///
/// The (misspelled) name is referenced elsewhere in this file, so it is kept as-is.
const SHA_SEPERATOR: &str = "~";
15
// Top-level litehouse configuration, loaded from and saved to `settings.json` by
// `LitehouseConfig::load` / `LitehouseConfig::save`. Every field except `plugins` has a serde
// default and is left out when saving if it is empty or equal to its default.
#[derive(JsonSchema, Serialize, Deserialize, Debug, Default)]
pub struct LitehouseConfig {
    /// The list of plugins to use in this litehouse
    // NOTE(review): keys are presumably the user-chosen instance names — confirm with the runtime.
    pub plugins: HashMap<String, PluginInstance>,
    /// Additional registries to look for plugins in
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub registries: Vec<Registry>,
    /// Additional plugins to import from registries. Without a registry prefix, uses the default.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    // `Import` is (de)serialized as a plain string, so the schema advertises a list of strings.
    #[schemars(with = "Vec<String>")]
    pub imports: Vec<Import>,
    /// The capabilities of this litehouse. Plugins that attempt to use capabilities not present in
    /// this list will fail. By default, plugins are not given any capabilities and are completely
    /// sandboxed.
    ///
    /// Can be one of the following:
    /// - `http-server:<port>`: Start an HTTP server on the given port
    /// - `http-client:<url>`: Make HTTP requests to the given URL
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    // `Capability` is (de)serialized as a `kind:value` string, so the schema advertises strings.
    #[schemars(with = "Vec<String>")]
    pub capabilities: Vec<Capability>,
    /// Advanced engine configuration
    // Omitted on save when it equals `Engine::default()`.
    #[serde(default, skip_serializing_if = "is_default")]
    pub engine: Engine,
}
41
42#[derive(JsonSchema, Serialize, Deserialize, Debug, Default, PartialEq)]
43pub struct Engine {
44    /// The strategy to use for sandboxing plugins. By default, each plugin instance is run in its
45    /// own storage sandbox, for maximum parallelism and isolation. If you are in a constrained
46    /// environment, you may want to put all plugins in the same storage instead.
47    #[serde(default, skip_serializing_if = "is_default")]
48    pub sandbox_strategy: SandboxStrategy,
49    #[serde(default, skip_serializing_if = "is_default")]
50    pub max_parallel_builds: MaxBuildCount,
51    #[serde(default, skip_serializing_if = "is_default")]
52    pub max_parallel_instantiations: MaxBuildCount,
53}
54
/// Returns `true` when `t` equals `T::default()`.
///
/// Used as a serde `skip_serializing_if` predicate so that settings left at their default are
/// not written back to the configuration file.
fn is_default<T: Default + PartialEq>(t: &T) -> bool {
    let default_value = T::default();
    t == &default_value
}
58
// How plugin storage sandboxes are shared. Serialized in snake_case
// (`"global"`, `"plugin"`, `"instance"`); the default is `Instance`.
#[derive(JsonSchema, Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum SandboxStrategy {
    /// All plugins are run in the same storage sandbox
    Global,
    /// Each plugin type is run in its own storage sandbox
    Plugin,
    /// Each plugin instance is run in its own storage sandbox
    #[default]
    Instance,
}
70
71#[derive(JsonSchema, Serialize, Deserialize, Debug, PartialEq)]
72pub struct MaxBuildCount(NonZeroU8);
73
74impl Default for MaxBuildCount {
75    fn default() -> Self {
76        MaxBuildCount(NonZeroU8::new(10).unwrap())
77    }
78}
79
80impl From<MaxBuildCount> for u8 {
81    fn from(count: MaxBuildCount) -> Self {
82        count.0.get()
83    }
84}
85
86impl LitehouseConfig {
87    pub fn load() -> Result<Self, ConfigError> {
88        let data = std::fs::read_to_string("settings.json")?;
89        let config: LitehouseConfig = serde_json::from_str(&data).map_err(|e| {
90            ConfigError::Parse(ParseError {
91                err_span: SourceOffset::from_location(&data, e.line() - 1, e.column()).into(),
92                src: NamedSource::new("settings.json", data),
93                error: e.to_string(),
94            })
95        })?;
96        Ok(config)
97    }
98
99    pub fn save(&self) -> Result<(), ConfigError> {
100        let file = std::fs::File::create("settings.json")?;
101        serde_json::to_writer_pretty(&file, self).map_err(ConfigError::Write)?;
102        Ok(())
103    }
104}
105
106pub fn directories() -> Option<directories_next::ProjectDirs> {
107    directories_next::ProjectDirs::from("com", "litehouse", "litehouse")
108}
109
110#[derive(thiserror::Error, Debug, miette::Diagnostic)]
111pub enum ConfigError {
112    #[error("io error")]
113    Io(#[from] std::io::Error),
114    #[error(transparent)]
115    #[diagnostic(transparent)]
116    Parse(#[from] ParseError),
117    #[error("write error")]
118    Write(serde_json::Error),
119}
120
#[derive(thiserror::Error, Debug, miette::Diagnostic)]
#[error("parse error")]
#[diagnostic(
    code(config::invalid),
    url(docsrs),
    help("check the configuration file for errors")
)]
/// Raised when there is an error parsing the configuration file.
/// This is likely a formatting issue.
pub struct ParseError {
    /// The configuration file contents, rendered as the diagnostic's source code
    /// (`LitehouseConfig::load` stores the whole of `settings.json` here).
    #[source_code]
    pub src: NamedSource<String>,
    /// The parser's error message; used as the label text.
    pub error: String,

    /// Where in `src` the error was detected.
    #[label = "{error}"]
    pub err_span: miette::SourceSpan,
}
138
/// A permission granted to plugins in this litehouse.
///
/// Written as `<kind>:<value>` in the configuration; the `Display` impl produces that form.
#[derive(Debug, Clone)]
pub enum Capability {
    /// `http-server:<port>` — start an HTTP server on the given port.
    HttpServer(usize),
    /// `http-client:<url>` — make HTTP requests to the given URL.
    HttpClient(String),
}

impl Display for Capability {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (kind, value): (&str, &dyn Display) = match self {
            Self::HttpServer(port) => ("http-server", port),
            Self::HttpClient(url) => ("http-client", url),
        };
        write!(f, "{kind}:{value}")
    }
}
153
154impl FromStr for Capability {
155    type Err = CapabilityParseError;
156
157    fn from_str(s: &str) -> Result<Self, Self::Err> {
158        let (name, value) = s
159            .split_once(':')
160            .map(|(name, value)| (name, value.to_string()))
161            .ok_or_else(|| CapabilityParseError::MissingDelimiter)?;
162        match name {
163            "http-server" => Ok(value
164                .parse()
165                .map(Capability::HttpServer)
166                .map_err(|_| CapabilityParseError::InvalidPort(value)))?,
167            "http-client" => Ok(Capability::HttpClient(value)),
168            variant => Err(CapabilityParseError::UnknownVariant(variant.to_string())),
169        }
170    }
171}
172
/// Error returned when a capability string cannot be parsed by `Capability::from_str`.
#[derive(Error, Diagnostic, Debug)]
#[error("invalid capability")]
#[diagnostic(
    code(config::invalid_capability),
    url(docsrs),
    help("check the capability name and value")
)]
pub enum CapabilityParseError {
    /// The part before `:` is not a known capability kind.
    #[error("unknown variant: {0}")]
    UnknownVariant(String),
    /// The string contains no `:` separating the kind from the value.
    #[error("missing delimiter")]
    MissingDelimiter,
    /// The value of an `http-server` capability could not be parsed as a port.
    #[error("invalid port: {0}")]
    InvalidPort(String),
}
188
189impl Serialize for Capability {
190    fn serialize<S>(&self, serializer: S) -> std::prelude::v1::Result<S::Ok, S::Error>
191    where
192        S: serde::Serializer,
193    {
194        let string = self.to_string();
195        serializer.serialize_str(&string)
196    }
197}
198
199impl<'de> Deserialize<'de> for Capability {
200    fn deserialize<D>(deserializer: D) -> std::prelude::v1::Result<Self, D::Error>
201    where
202        D: serde::Deserializer<'de>,
203    {
204        let s = String::deserialize(deserializer)?;
205        s.parse().map_err(serde::de::Error::custom)
206    }
207}
208
// An additional plugin registry, listed under `registries` in the configuration.
// NOTE(review): `name` is presumably what imports use as their `registry::` prefix — confirm.
#[derive(JsonSchema, Serialize, Deserialize, Debug)]
pub struct Registry {
    /// The local name of the registry
    pub name: String,
    /// The url to use for interacting with the registry
    pub url: String,
}
216
/// A plugin import. Serializes to a string with the format
/// `[registry::]plugin[@version][~blake3:<hex>]`, where only `plugin` is required.
#[derive(Debug)]
pub struct Import {
    /// Registry name given before `::`, if any.
    pub registry: Option<String>,
    /// The plugin name.
    pub plugin: String,
    /// Semver version given after `@`, if any.
    pub version: Option<semver::Version>,
    /// Expected BLAKE3 hash of the plugin file, given after `~`, if any.
    pub sha: Option<Blake3>,
}
225
226impl Import {
227    pub fn file_name(&self) -> String {
228        let version = self
229            .version
230            .as_ref()
231            .map(|v| format!("{}{}", VERSION_SEPARATOR, v))
232            .unwrap_or_default();
233        format!("{}{}.wasm", self.plugin, version)
234    }
235
236    pub async fn read_sha(&mut self, base_dir: &Path) {
237        use futures::StreamExt;
238
239        // if there is no version, we need to resolve it
240        if self.version.is_none() {
241            let files = tokio::fs::read_dir(base_dir).await.unwrap();
242            let stream = tokio_stream::wrappers::ReadDirStream::new(files);
243            let max_version = stream
244                .filter_map(|entry| {
245                    let import = Import::from_str(
246                        entry
247                            .unwrap()
248                            .file_name()
249                            .to_string_lossy()
250                            .strip_suffix(".wasm")
251                            .unwrap(),
252                    )
253                    .unwrap();
254                    let plugin = &self.plugin;
255                    async move {
256                        if import.plugin.eq(plugin) {
257                            Some(import)
258                        } else {
259                            None
260                        }
261                    }
262                })
263                .collect::<Vec<_>>()
264                .await
265                .into_iter()
266                .max();
267
268            if let Some(import) = max_version {
269                self.version = import.version;
270            } else {
271                return;
272            }
273        }
274
275        let plugin_path = base_dir.join(self.file_name());
276        let hasher = blake3::Hasher::new();
277        let file = tokio::fs::File::open(plugin_path).await.unwrap();
278        let mut hasher = HashRead::new(file, hasher);
279        tokio::io::copy(&mut hasher, &mut tokio::io::empty())
280            .await
281            .unwrap();
282        let output = hasher.finalize();
283        let b: [u8; 32] = output.as_slice().try_into().unwrap();
284        self.sha = Some(Blake3(b));
285    }
286
287    /// Verify that the plugin at this path matches
288    /// this import. This validates the version
289    /// via the file name as well as the sha if
290    /// one is specified.
291    pub async fn verify(&self, path: &Path) -> Option<()> {
292        self.sha.as_ref()?;
293
294        let mut file = tokio::fs::File::open(path).await.unwrap();
295        self.copy(&mut file, &mut tokio::io::empty())
296            .await
297            .map(|_| ())
298    }
299
300    /// Copy the plugin from src to dest, validating the sha in the
301    /// process.
302    pub async fn copy<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
303        &self,
304        src: R,
305        dest: &mut W,
306    ) -> Option<u64> {
307        let hasher = blake3::Hasher::new();
308        let mut hasher = HashRead::new(src, hasher);
309        let bytes = tokio::io::copy(&mut hasher, dest).await.unwrap();
310        let output = hasher.finalize();
311
312        if let Some(Blake3(sha)) = self.sha {
313            // maybe consider constant time comparison fn
314            if *output != sha {
315                eprintln!("sha mismatch\n  got {:02X?}\n  exp {:02X?}", &*output, sha);
316                return None;
317            }
318        }
319
320        Some(bytes)
321    }
322}
323
324impl PartialEq for Import {
325    fn eq(&self, other: &Self) -> bool {
326        self.plugin == other.plugin && self.version == other.version && self.sha == other.sha
327    }
328}
329
330impl Eq for Import {}
331
332impl PartialOrd for Import {
333    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
334        Some(self.cmp(other))
335    }
336}
337
338impl Ord for Import {
339    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
340        match self.plugin.cmp(&other.plugin) {
341            std::cmp::Ordering::Equal => self.version.cmp(&other.version),
342            other => other,
343        }
344    }
345}
346
347impl Serialize for Import {
348    fn serialize<S>(&self, serializer: S) -> std::prelude::v1::Result<S::Ok, S::Error>
349    where
350        S: serde::Serializer,
351    {
352        let string = self.to_string();
353        serializer.serialize_str(&string)
354    }
355}
356
357impl<'de> Deserialize<'de> for Import {
358    fn deserialize<D>(deserializer: D) -> std::prelude::v1::Result<Self, D::Error>
359    where
360        D: serde::Deserializer<'de>,
361    {
362        let s = String::deserialize(deserializer)?;
363        s.parse().map_err(serde::de::Error::custom)
364    }
365}
366
367#[derive(Error, Debug, Diagnostic)]
368#[error("failed to parse import")]
369pub enum ImportParseError {
370    SemverParseError(#[from] SemverParseError),
371    Blake3ParseError(#[from] Blake3ParseError),
372}
373
/// An import's version (the part after `@`) is not valid semver.
#[derive(Error, Debug, Diagnostic)]
#[error("failed to parse import")]
#[diagnostic(
    code(import::invalid_format),
    url(docsrs),
    help("check the documentation for the correct format")
)]
pub struct SemverParseError {
    /// The full import string that was being parsed.
    #[source_code]
    src: String,

    /// The error reported by `semver`.
    err: semver::Error,

    /// Byte range of the version within `src`; labelled with `err`.
    #[label("{err}")]
    err_span: miette::SourceSpan,
}
390
/// An import's hash (the part after `~`) is not a valid hex-encoded BLAKE3 digest.
#[derive(Error, Debug, Diagnostic)]
#[error("failed to parse import")]
#[diagnostic(
    code(import::invalid_format),
    url(docsrs),
    help("check the documentation for the correct format")
)]
pub struct Blake3ParseError {
    /// The full import string that was being parsed.
    #[source_code]
    src: String,

    /// The error reported by `blake3` while decoding the hex digest.
    err: blake3::HexError,

    /// Location of the hash within `src`; labelled with `err`.
    #[label("{err}")]
    err_span: miette::SourceSpan,
}
407
408impl FromStr for Import {
409    type Err = ImportParseError;
410
411    fn from_str(s: &str) -> Result<Self, Self::Err> {
412        let rest = s.strip_suffix(".wasm").unwrap_or(s); // remove file extension
413        let (registry, rest) = rest
414            .split_once(REGISTRY_SEPARATOR)
415            .map(|(registry, rest)| (Some(registry), rest))
416            .unwrap_or((None, rest));
417        let (sha, rest) = rest
418            .rsplit_once(SHA_SEPERATOR)
419            .map(|(rest, sha)| (Some(sha), rest))
420            .unwrap_or((None, rest));
421        let (package, version) = rest
422            .split_once(VERSION_SEPARATOR)
423            .map(|(package, version)| {
424                version
425                    .parse()
426                    .map(|v| (package, Some(v)))
427                    .map_err(|e| (e, version))
428            })
429            .unwrap_or(Ok((rest, None)))
430            .map_err(|(e, version)| SemverParseError {
431                err: e,
432                src: s.to_string(),
433                err_span: s
434                    .find(version)
435                    .map(|i| i..i + version.len())
436                    .unwrap()
437                    .into(),
438            })?;
439
440        Ok(Import {
441            registry: registry.map(str::to_string),
442            plugin: package.to_string(),
443            version,
444            sha: sha
445                .map(|sha| {
446                    Blake3::from_str(sha).map_err(|e| Blake3ParseError {
447                        err: e,
448                        err_span: s.find(sha).map(|i| i..i + s.len()).unwrap().into(),
449                        src: s.to_string(),
450                    })
451                })
452                .transpose()?,
453        })
454    }
455}
456
457impl Display for Import {
458    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
459        let registry = self
460            .registry
461            .as_deref()
462            .map(|s| format!("{}{}", s, REGISTRY_SEPARATOR))
463            .unwrap_or_default();
464        let version = self
465            .version
466            .as_ref()
467            .map(|v| format!("{}{}", VERSION_SEPARATOR, v))
468            .unwrap_or_default();
469        let sha = self
470            .sha
471            .as_ref()
472            .map(|v| format!("{}{}", SHA_SEPERATOR, v.to_string()))
473            .unwrap_or_default();
474
475        write!(f, "{}{}{}{}", registry, self.plugin, version, sha)
476    }
477}
478
479#[derive(Debug, PartialEq, Eq)]
480pub struct Blake3([u8; blake3::OUT_LEN]);
481
482impl FromStr for Blake3 {
483    type Err = blake3::HexError;
484    fn from_str(s: &str) -> Result<Self, Self::Err> {
485        let hash = s.strip_prefix("blake3:").unwrap();
486        Ok(Self(blake3::Hash::from_hex(hash)?.as_bytes().to_owned()))
487    }
488}
489
490impl Display for Blake3 {
491    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
492        let hash = blake3::Hash::from_bytes(self.0);
493        write!(f, "blake3:{}", hash.to_hex())
494    }
495}
496
// One configured plugin instance: which plugin to run, plus optional free-form JSON config.
#[derive(JsonSchema, Serialize, Deserialize, Debug)]
pub struct PluginInstance {
    // Written as an import string such as `registry::plugin@1.0.0` (see `Import`), hence the
    // `String` schema override.
    #[schemars(with = "String")]
    pub plugin: Import,
    // Arbitrary JSON; `None` when the key is absent.
    // NOTE(review): presumably handed to the plugin as its configuration — confirm.
    pub config: Option<serde_json::Value>,
}
503
#[cfg(test)]
mod test {
    use super::*;
    use test_case::test_case;

    // Every supported import shape must survive a parse -> display round trip unchanged, and
    // the plugin name must be extracted correctly.
    #[test_case("package" ; "just package")]
    #[test_case("registry::package" ; "registry")]
    #[test_case("registry::package@1.0.0" ; "version")]
    #[test_case("registry::package@1.0.0~blake3:deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" ; "everything")]
    #[test_case("registry::package~blake3:deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" ; "no version")]
    #[test_case("package~blake3:deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" ; "just sha")]
    fn roundtrip(import_exp: &str) {
        let parsed: Import = import_exp.parse().unwrap();
        assert_eq!(parsed.plugin, "package");
        let rendered = format!("{parsed}");
        assert_eq!(import_exp, rendered);
    }
}