1mod hash_read;
2
3use std::{collections::HashMap, fmt::Display, num::NonZeroU8, path::Path, str::FromStr};
4
5use hash_read::HashRead;
6use miette::{Diagnostic, NamedSource, SourceOffset};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use tokio::io::{AsyncRead, AsyncWrite};
11
/// Separates the registry from the rest of an import (`registry::plugin`).
const REGISTRY_SEPARATOR: &str = "::";
/// Separates the plugin name from its version (`plugin@1.0.0`).
const VERSION_SEPARATOR: &str = "@";
/// Separates the rest of an import from its digest (`plugin~blake3:<hex>`).
// NOTE(review): "SEPERATOR" is misspelled; kept as-is because other items refer to it.
const SHA_SEPERATOR: &str = "~";
15
/// Top-level Litehouse configuration, stored as JSON in `settings.json`
/// (see [`LitehouseConfig::load`] and [`LitehouseConfig::save`]).
#[derive(JsonSchema, Serialize, Deserialize, Debug, Default)]
pub struct LitehouseConfig {
    /// Plugin instances, each stored under a string key.
    pub plugins: HashMap<String, PluginInstance>,
    /// Plugin registries. Defaults to empty and is omitted from the output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub registries: Vec<Registry>,
    /// Plugin imports, each written as `[registry::]plugin[@version][~blake3:<hex>]`.
    /// Defaults to empty and is omitted from the output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    #[schemars(with = "Vec<String>")]
    pub imports: Vec<Import>,
    /// Capabilities, each written as `http-server:<port>` or `http-client:<url>`.
    /// Defaults to empty and is omitted from the output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    #[schemars(with = "Vec<String>")]
    pub capabilities: Vec<Capability>,
    /// Engine tuning. Omitted from the output when every setting is at its default.
    #[serde(default, skip_serializing_if = "is_default")]
    pub engine: Engine,
}
41
/// Engine settings governing sandboxing and build/instantiation parallelism.
///
/// Each field is omitted from the serialized output while it equals its default.
#[derive(JsonSchema, Serialize, Deserialize, Debug, Default, PartialEq)]
pub struct Engine {
    /// Sandboxing granularity; see [`SandboxStrategy`]. Defaults to `instance`.
    #[serde(default, skip_serializing_if = "is_default")]
    pub sandbox_strategy: SandboxStrategy,
    /// Upper bound on builds running in parallel. Defaults to 10.
    #[serde(default, skip_serializing_if = "is_default")]
    pub max_parallel_builds: MaxBuildCount,
    /// Upper bound on instantiations running in parallel. Defaults to 10.
    #[serde(default, skip_serializing_if = "is_default")]
    pub max_parallel_instantiations: MaxBuildCount,
}
54
/// Reports whether `t` is equal to `T::default()`.
///
/// Serves as a `skip_serializing_if` predicate so fields still at their
/// default value are left out of the serialized configuration.
fn is_default<T: Default + PartialEq>(t: &T) -> bool {
    let baseline = T::default();
    t == &baseline
}
58
/// Sandboxing granularity used by the engine, serialized in snake_case
/// (`"global"`, `"plugin"`, `"instance"`).
// NOTE(review): the per-variant meanings below are inferred from the variant names — confirm against the engine.
#[derive(JsonSchema, Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum SandboxStrategy {
    /// Global scope (presumably a single sandbox shared by everything).
    Global,
    /// Per-plugin scope (presumably one sandbox per plugin).
    Plugin,
    /// Per-instance scope (presumably one sandbox per plugin instance). This is the default.
    #[default]
    Instance,
}
70
/// A non-zero limit on how many operations may run in parallel (1..=255).
///
/// Backed by `NonZeroU8`, so a zero value cannot be represented. The default is 10.
#[derive(JsonSchema, Serialize, Deserialize, Debug, PartialEq)]
pub struct MaxBuildCount(NonZeroU8);
73
74impl Default for MaxBuildCount {
75 fn default() -> Self {
76 MaxBuildCount(NonZeroU8::new(10).unwrap())
77 }
78}
79
80impl From<MaxBuildCount> for u8 {
81 fn from(count: MaxBuildCount) -> Self {
82 count.0.get()
83 }
84}
85
86impl LitehouseConfig {
87 pub fn load() -> Result<Self, ConfigError> {
88 let data = std::fs::read_to_string("settings.json")?;
89 let config: LitehouseConfig = serde_json::from_str(&data).map_err(|e| {
90 ConfigError::Parse(ParseError {
91 err_span: SourceOffset::from_location(&data, e.line() - 1, e.column()).into(),
92 src: NamedSource::new("settings.json", data),
93 error: e.to_string(),
94 })
95 })?;
96 Ok(config)
97 }
98
99 pub fn save(&self) -> Result<(), ConfigError> {
100 let file = std::fs::File::create("settings.json")?;
101 serde_json::to_writer_pretty(&file, self).map_err(ConfigError::Write)?;
102 Ok(())
103 }
104}
105
106pub fn directories() -> Option<directories_next::ProjectDirs> {
107 directories_next::ProjectDirs::from("com", "litehouse", "litehouse")
108}
109
110#[derive(thiserror::Error, Debug, miette::Diagnostic)]
111pub enum ConfigError {
112 #[error("io error")]
113 Io(#[from] std::io::Error),
114 #[error(transparent)]
115 #[diagnostic(transparent)]
116 Parse(#[from] ParseError),
117 #[error("write error")]
118 Write(serde_json::Error),
119}
120
/// A configuration parse failure, rendered by miette with the file contents
/// and a label at the error location.
#[derive(thiserror::Error, Debug, miette::Diagnostic)]
#[error("parse error")]
#[diagnostic(
    code(config::invalid),
    url(docsrs),
    help("check the configuration file for errors")
)]
pub struct ParseError {
    /// The named file contents shown as the diagnostic's source code.
    #[source_code]
    pub src: NamedSource<String>,
    /// The underlying parser message; used as the label text.
    pub error: String,

    /// Location in `src` that the error refers to.
    #[label = "{error}"]
    pub err_span: miette::SourceSpan,
}
138
/// A capability entry, written in configuration as `kind:value`
/// (see the `Display` and `FromStr` impls).
#[derive(Debug, Clone)]
pub enum Capability {
    /// `http-server:<port>`; holds the port number.
    HttpServer(usize),
    /// `http-client:<url>`; holds everything after the first `:`.
    HttpClient(String),
}
144
145impl Display for Capability {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 match self {
148 Capability::HttpServer(port) => write!(f, "http-server:{}", port),
149 Capability::HttpClient(url) => write!(f, "http-client:{}", url),
150 }
151 }
152}
153
154impl FromStr for Capability {
155 type Err = CapabilityParseError;
156
157 fn from_str(s: &str) -> Result<Self, Self::Err> {
158 let (name, value) = s
159 .split_once(':')
160 .map(|(name, value)| (name, value.to_string()))
161 .ok_or_else(|| CapabilityParseError::MissingDelimiter)?;
162 match name {
163 "http-server" => Ok(value
164 .parse()
165 .map(Capability::HttpServer)
166 .map_err(|_| CapabilityParseError::InvalidPort(value)))?,
167 "http-client" => Ok(Capability::HttpClient(value)),
168 variant => Err(CapabilityParseError::UnknownVariant(variant.to_string())),
169 }
170 }
171}
172
/// Error returned when a capability string cannot be parsed.
#[derive(Error, Diagnostic, Debug)]
#[error("invalid capability")]
#[diagnostic(
    code(config::invalid_capability),
    url(docsrs),
    help("check the capability name and value")
)]
pub enum CapabilityParseError {
    /// The part before the `:` is not a known capability kind; holds that part.
    #[error("unknown variant: {0}")]
    UnknownVariant(String),
    /// The string contains no `:` separator.
    #[error("missing delimiter")]
    MissingDelimiter,
    /// An `http-server` value could not be parsed as a port number; holds the value.
    #[error("invalid port: {0}")]
    InvalidPort(String),
}
188
189impl Serialize for Capability {
190 fn serialize<S>(&self, serializer: S) -> std::prelude::v1::Result<S::Ok, S::Error>
191 where
192 S: serde::Serializer,
193 {
194 let string = self.to_string();
195 serializer.serialize_str(&string)
196 }
197}
198
199impl<'de> Deserialize<'de> for Capability {
200 fn deserialize<D>(deserializer: D) -> std::prelude::v1::Result<Self, D::Error>
201 where
202 D: serde::Deserializer<'de>,
203 {
204 let s = String::deserialize(deserializer)?;
205 s.parse().map_err(serde::de::Error::custom)
206 }
207}
208
/// A named plugin registry.
// NOTE(review): `name` is presumably what an import's `registry::` prefix refers to — confirm.
#[derive(JsonSchema, Serialize, Deserialize, Debug)]
pub struct Registry {
    /// Name of the registry.
    pub name: String,
    /// URL of the registry.
    pub url: String,
}
216
/// A reference to a plugin, written as `[registry::]plugin[@version][~blake3:<hex>]`.
#[derive(Debug)]
pub struct Import {
    /// Registry name (the part before `::`), if given.
    pub registry: Option<String>,
    /// Plugin name.
    pub plugin: String,
    /// Plugin version (the part after `@`), if pinned.
    pub version: Option<semver::Version>,
    /// Expected BLAKE3 digest of the plugin's `.wasm` file (the part after `~`), if given.
    pub sha: Option<Blake3>,
}
225
226impl Import {
227 pub fn file_name(&self) -> String {
228 let version = self
229 .version
230 .as_ref()
231 .map(|v| format!("{}{}", VERSION_SEPARATOR, v))
232 .unwrap_or_default();
233 format!("{}{}.wasm", self.plugin, version)
234 }
235
236 pub async fn read_sha(&mut self, base_dir: &Path) {
237 use futures::StreamExt;
238
239 if self.version.is_none() {
241 let files = tokio::fs::read_dir(base_dir).await.unwrap();
242 let stream = tokio_stream::wrappers::ReadDirStream::new(files);
243 let max_version = stream
244 .filter_map(|entry| {
245 let import = Import::from_str(
246 entry
247 .unwrap()
248 .file_name()
249 .to_string_lossy()
250 .strip_suffix(".wasm")
251 .unwrap(),
252 )
253 .unwrap();
254 let plugin = &self.plugin;
255 async move {
256 if import.plugin.eq(plugin) {
257 Some(import)
258 } else {
259 None
260 }
261 }
262 })
263 .collect::<Vec<_>>()
264 .await
265 .into_iter()
266 .max();
267
268 if let Some(import) = max_version {
269 self.version = import.version;
270 } else {
271 return;
272 }
273 }
274
275 let plugin_path = base_dir.join(self.file_name());
276 let hasher = blake3::Hasher::new();
277 let file = tokio::fs::File::open(plugin_path).await.unwrap();
278 let mut hasher = HashRead::new(file, hasher);
279 tokio::io::copy(&mut hasher, &mut tokio::io::empty())
280 .await
281 .unwrap();
282 let output = hasher.finalize();
283 let b: [u8; 32] = output.as_slice().try_into().unwrap();
284 self.sha = Some(Blake3(b));
285 }
286
287 pub async fn verify(&self, path: &Path) -> Option<()> {
292 self.sha.as_ref()?;
293
294 let mut file = tokio::fs::File::open(path).await.unwrap();
295 self.copy(&mut file, &mut tokio::io::empty())
296 .await
297 .map(|_| ())
298 }
299
300 pub async fn copy<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
303 &self,
304 src: R,
305 dest: &mut W,
306 ) -> Option<u64> {
307 let hasher = blake3::Hasher::new();
308 let mut hasher = HashRead::new(src, hasher);
309 let bytes = tokio::io::copy(&mut hasher, dest).await.unwrap();
310 let output = hasher.finalize();
311
312 if let Some(Blake3(sha)) = self.sha {
313 if *output != sha {
315 eprintln!("sha mismatch\n got {:02X?}\n exp {:02X?}", &*output, sha);
316 return None;
317 }
318 }
319
320 Some(bytes)
321 }
322}
323
324impl PartialEq for Import {
325 fn eq(&self, other: &Self) -> bool {
326 self.plugin == other.plugin && self.version == other.version && self.sha == other.sha
327 }
328}
329
330impl Eq for Import {}
331
332impl PartialOrd for Import {
333 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
334 Some(self.cmp(other))
335 }
336}
337
338impl Ord for Import {
339 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
340 match self.plugin.cmp(&other.plugin) {
341 std::cmp::Ordering::Equal => self.version.cmp(&other.version),
342 other => other,
343 }
344 }
345}
346
347impl Serialize for Import {
348 fn serialize<S>(&self, serializer: S) -> std::prelude::v1::Result<S::Ok, S::Error>
349 where
350 S: serde::Serializer,
351 {
352 let string = self.to_string();
353 serializer.serialize_str(&string)
354 }
355}
356
357impl<'de> Deserialize<'de> for Import {
358 fn deserialize<D>(deserializer: D) -> std::prelude::v1::Result<Self, D::Error>
359 where
360 D: serde::Deserializer<'de>,
361 {
362 let s = String::deserialize(deserializer)?;
363 s.parse().map_err(serde::de::Error::custom)
364 }
365}
366
367#[derive(Error, Debug, Diagnostic)]
368#[error("failed to parse import")]
369pub enum ImportParseError {
370 SemverParseError(#[from] SemverParseError),
371 Blake3ParseError(#[from] Blake3ParseError),
372}
373
/// Error for an import whose `@version` part is not valid semver.
#[derive(Error, Debug, Diagnostic)]
#[error("failed to parse import")]
#[diagnostic(
    code(import::invalid_format),
    url(docsrs),
    help("check the documentation for the correct format")
)]
pub struct SemverParseError {
    /// The full import string being parsed, shown as the diagnostic source.
    #[source_code]
    src: String,

    /// The underlying semver error; its message is used as the label text.
    err: semver::Error,

    /// Span in `src` intended to cover the offending version text.
    #[label("{err}")]
    err_span: miette::SourceSpan,
}
390
/// Error for an import whose `~sha` part is not a valid BLAKE3 digest.
#[derive(Error, Debug, Diagnostic)]
#[error("failed to parse import")]
#[diagnostic(
    code(import::invalid_format),
    url(docsrs),
    help("check the documentation for the correct format")
)]
pub struct Blake3ParseError {
    /// The full import string being parsed, shown as the diagnostic source.
    #[source_code]
    src: String,

    /// The underlying hex decoding error; its message is used as the label text.
    err: blake3::HexError,

    /// Span in `src` intended to cover the offending digest text.
    #[label("{err}")]
    err_span: miette::SourceSpan,
}
407
408impl FromStr for Import {
409 type Err = ImportParseError;
410
411 fn from_str(s: &str) -> Result<Self, Self::Err> {
412 let rest = s.strip_suffix(".wasm").unwrap_or(s); let (registry, rest) = rest
414 .split_once(REGISTRY_SEPARATOR)
415 .map(|(registry, rest)| (Some(registry), rest))
416 .unwrap_or((None, rest));
417 let (sha, rest) = rest
418 .rsplit_once(SHA_SEPERATOR)
419 .map(|(rest, sha)| (Some(sha), rest))
420 .unwrap_or((None, rest));
421 let (package, version) = rest
422 .split_once(VERSION_SEPARATOR)
423 .map(|(package, version)| {
424 version
425 .parse()
426 .map(|v| (package, Some(v)))
427 .map_err(|e| (e, version))
428 })
429 .unwrap_or(Ok((rest, None)))
430 .map_err(|(e, version)| SemverParseError {
431 err: e,
432 src: s.to_string(),
433 err_span: s
434 .find(version)
435 .map(|i| i..i + version.len())
436 .unwrap()
437 .into(),
438 })?;
439
440 Ok(Import {
441 registry: registry.map(str::to_string),
442 plugin: package.to_string(),
443 version,
444 sha: sha
445 .map(|sha| {
446 Blake3::from_str(sha).map_err(|e| Blake3ParseError {
447 err: e,
448 err_span: s.find(sha).map(|i| i..i + s.len()).unwrap().into(),
449 src: s.to_string(),
450 })
451 })
452 .transpose()?,
453 })
454 }
455}
456
457impl Display for Import {
458 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
459 let registry = self
460 .registry
461 .as_deref()
462 .map(|s| format!("{}{}", s, REGISTRY_SEPARATOR))
463 .unwrap_or_default();
464 let version = self
465 .version
466 .as_ref()
467 .map(|v| format!("{}{}", VERSION_SEPARATOR, v))
468 .unwrap_or_default();
469 let sha = self
470 .sha
471 .as_ref()
472 .map(|v| format!("{}{}", SHA_SEPERATOR, v.to_string()))
473 .unwrap_or_default();
474
475 write!(f, "{}{}{}{}", registry, self.plugin, version, sha)
476 }
477}
478
/// A raw BLAKE3 digest (`blake3::OUT_LEN` bytes), written as `blake3:<hex>`
/// by its `Display` impl.
#[derive(Debug, PartialEq, Eq)]
pub struct Blake3([u8; blake3::OUT_LEN]);
481
482impl FromStr for Blake3 {
483 type Err = blake3::HexError;
484 fn from_str(s: &str) -> Result<Self, Self::Err> {
485 let hash = s.strip_prefix("blake3:").unwrap();
486 Ok(Self(blake3::Hash::from_hex(hash)?.as_bytes().to_owned()))
487 }
488}
489
490impl Display for Blake3 {
491 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
492 let hash = blake3::Hash::from_bytes(self.0);
493 write!(f, "blake3:{}", hash.to_hex())
494 }
495}
496
/// A configured plugin instance.
#[derive(JsonSchema, Serialize, Deserialize, Debug)]
pub struct PluginInstance {
    /// The plugin to use, written as an import string
    /// (`[registry::]plugin[@version][~blake3:<hex>]`).
    #[schemars(with = "String")]
    pub plugin: Import,
    /// Optional free-form JSON configuration for this instance.
    pub config: Option<serde_json::Value>,
}
503
#[cfg(test)]
mod test {
    use super::*;
    use test_case::test_case;

    // Parsing an import and displaying it again must reproduce the input
    // exactly; every fixture names its plugin `package`.
    #[test_case("package" ; "just package")]
    #[test_case("registry::package" ; "registry")]
    #[test_case("registry::package@1.0.0" ; "version")]
    #[test_case("registry::package@1.0.0~blake3:deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" ; "everything")]
    #[test_case("registry::package~blake3:deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" ; "no version")]
    #[test_case("package~blake3:deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" ; "just sha")]
    fn roundtrip(import_exp: &str) {
        let parsed: Import = import_exp.parse().unwrap();
        let rendered = parsed.to_string();
        assert_eq!(import_exp, rendered);
        assert_eq!(parsed.plugin, "package");
    }
}