1use std::sync::Arc;
4use std::time::Duration;
5
6use thiserror::Error;
7use wasmtime::component::{Component, Linker, ResourceTable};
8use wasmtime::{AsContextMut, Engine as WtEngine, Store, StoreLimits, StoreLimitsBuilder};
9use wasmtime_wasi::{WasiCtx, WasiCtxBuilder, WasiView};
10
11use forge_ir::{Diagnostic, Ir, PluginInfo};
12use forge_ir_bindgen::bindings;
13use forge_ir_bindgen::convert::{self, ResourceKindRepr, StageErrorRepr};
14
/// Ways a plugin stage invocation can fail.
///
/// The `#[error(...)]` strings are the user-facing `Display` messages.
#[derive(Debug, Error)]
pub enum StageError {
    /// The plugin itself returned a rejection for the given input.
    #[error("plugin rejected input: {reason}")]
    Rejected {
        /// Rejection reason; interpolated into the display message.
        reason: String,
        /// Diagnostics that accompanied the rejection.
        diagnostics: Vec<Diagnostic>,
    },
    /// The plugin trapped or returned something the host could not interpret.
    /// The payload is a free-form description.
    #[error("plugin trap or malformed return: {0}")]
    PluginBug(String),
    /// The configuration handed to the plugin was reported as invalid.
    #[error("plugin config invalid: {0}")]
    ConfigInvalid(String),
    /// A resource budget was exhausted; the payload names which one.
    #[error("plugin exceeded {0:?}")]
    ResourceExceeded(ResourceKind),
}
33
/// Which resource budget was exhausted, as carried by
/// [`StageError::ResourceExceeded`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResourceKind {
    /// The store ran out of fuel.
    Fuel,
    /// A memory limit was hit.
    Memory,
    /// The epoch (wall-clock) deadline interrupted execution.
    Time,
    /// The produced output exceeded a file-count or byte-size limit.
    OutputSize,
}
41
/// Successful result of running a transformer plugin.
#[derive(Debug, Clone)]
pub struct TransformOutput {
    /// The IR returned by the plugin, converted back to the host representation.
    pub spec: Ir,
    /// Diagnostics the plugin emitted alongside the result.
    pub diagnostics: Vec<Diagnostic>,
}
47
/// Successful result of running a generator plugin.
#[derive(Debug, Clone)]
pub struct GenerationOutput {
    /// Files produced by the plugin.
    pub files: Vec<OutputFile>,
    /// Diagnostics the plugin emitted alongside the files.
    pub diagnostics: Vec<Diagnostic>,
}
53
/// A single file emitted by a generator plugin.
#[derive(Debug, Clone)]
pub struct OutputFile {
    /// Path as reported by the plugin (not validated in this type).
    pub path: String,
    /// Raw file contents.
    pub content: Vec<u8>,
    /// How the file should be treated; see [`FileMode`].
    pub mode: FileMode,
}
60
/// Host-side mirror of the generator API's file mode enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FileMode {
    /// Textual content.
    Text,
    /// Binary content.
    Binary,
    /// Content flagged as executable by the plugin.
    Executable,
}
67
/// Resource budget for one plugin invocation.
#[derive(Debug, Clone, Copy)]
pub struct Limits {
    /// Fuel handed to the store before the call.
    pub fuel: u64,
    /// Upper bound on linear memory size, in bytes, given to the store limiter.
    pub memory_bytes: usize,
    /// Wall-clock budget in milliseconds; translated into an epoch deadline.
    pub wall_clock_ms: u64,
    /// Maximum number of files a generator may return.
    pub output_files_max: u32,
    /// Maximum combined size, in bytes, of all returned files.
    pub output_total_bytes_max: u64,
    /// Per-file size cap in bytes (going by the field name).
    pub output_per_file_bytes_max: u64,
}

impl Limits {
    /// Default budget for transformer plugins: 5e9 fuel, 128 MiB of memory,
    /// 5 s of wall clock, and zero for every output limit.
    pub const fn transformer() -> Self {
        const MIB: usize = 1 << 20;
        Limits {
            fuel: 5_000_000_000,
            memory_bytes: 128 * MIB,
            wall_clock_ms: 5_000,
            output_files_max: 0,
            output_total_bytes_max: 0,
            output_per_file_bytes_max: 0,
        }
    }

    /// Default budget for generator plugins: 5e10 fuel, 512 MiB of memory,
    /// 30 s of wall clock, at most 10 000 files, 256 MiB in total and
    /// 16 MiB per file.
    pub const fn generator() -> Self {
        const MIB: usize = 1 << 20;
        const MIB_U64: u64 = 1 << 20;
        Limits {
            fuel: 50_000_000_000,
            memory_bytes: 512 * MIB,
            wall_clock_ms: 30_000,
            output_files_max: 10_000,
            output_total_bytes_max: 256 * MIB_U64,
            output_per_file_bytes_max: 16 * MIB_U64,
        }
    }
}
105
/// Handle to the plugin engine.
///
/// Cloning copies an `Arc`, so all clones share the same underlying state.
#[derive(Clone)]
pub struct Engine {
    /// Shared engine state.
    inner: Arc<EngineInner>,
}
118
/// State shared between all clones of an [`Engine`].
struct EngineInner {
    /// The underlying Wasmtime engine.
    wt: WtEngine,
    /// Never read; held so the ticker's `Drop` runs when this struct is dropped.
    _ticker: EpochTicker,
}
124
125impl std::fmt::Debug for Engine {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 f.debug_struct("Engine").finish_non_exhaustive()
128 }
129}
130
131impl Engine {
132 pub fn new() -> Result<Self, EngineError> {
133 let mut cfg = wasmtime::Config::new();
134 cfg.wasm_component_model(true)
135 .consume_fuel(true)
136 .epoch_interruption(true);
137 cfg.relaxed_simd_deterministic(true);
139
140 let wt = WtEngine::new(&cfg).map_err(|e| EngineError::Init(e.to_string()))?;
141
142 let ticker = EpochTicker::spawn(wt.clone(), Duration::from_millis(10));
145
146 Ok(Engine {
147 inner: Arc::new(EngineInner {
148 wt,
149 _ticker: ticker,
150 }),
151 })
152 }
153
154 pub fn raw(&self) -> &WtEngine {
155 &self.inner.wt
156 }
157}
158
/// Errors raised while constructing an [`Engine`].
#[derive(Debug, Error)]
pub enum EngineError {
    /// Wasmtime refused to create the engine; the payload is the
    /// stringified underlying error.
    #[error("wasmtime engine init failed: {0}")]
    Init(String),
}
164
/// Owns the background thread that periodically bumps the engine's epoch.
struct EpochTicker {
    /// Flag shared with the thread; set to `true` to ask it to exit.
    stop: Arc<std::sync::atomic::AtomicBool>,
    /// Join handle for the thread; taken (left as `None`) when dropped.
    handle: Option<std::thread::JoinHandle<()>>,
}
171
172impl EpochTicker {
173 fn spawn(engine: WtEngine, cadence: Duration) -> Self {
174 let stop = Arc::new(std::sync::atomic::AtomicBool::new(false));
175 let stop_t = stop.clone();
176 let handle = std::thread::spawn(move || {
177 while !stop_t.load(std::sync::atomic::Ordering::Relaxed) {
178 std::thread::sleep(cadence);
179 engine.increment_epoch();
180 }
181 });
182 EpochTicker {
183 stop,
184 handle: Some(handle),
185 }
186 }
187}
188
189impl Drop for EpochTicker {
190 fn drop(&mut self) {
191 self.stop.store(true, std::sync::atomic::Ordering::Relaxed);
192 if let Some(h) = self.handle.take() {
193 let _ = h.join();
194 }
195 }
196}
197
/// Per-store host data made available to host functions during a plugin call.
pub struct HostState {
    /// The limits this state was created with.
    pub limits: Limits,
    /// Every message the plugin logged through the host API, with its level,
    /// in call order.
    pub log_lines: Vec<(forge_ir::LogLevel, String)>,
    /// Wasmtime store limits built from `limits.memory_bytes`.
    pub store_limits: StoreLimits,
    /// Resource table handed out through `WasiView::table`.
    pub resource_table: ResourceTable,
    /// WASI context handed out through `WasiView::ctx`.
    pub wasi: WasiCtx,
}
217
218impl std::fmt::Debug for HostState {
219 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220 f.debug_struct("HostState")
221 .field("limits", &self.limits)
222 .field("log_lines", &self.log_lines.len())
223 .finish_non_exhaustive()
224 }
225}
226
227impl HostState {
228 fn new(limits: Limits) -> Self {
229 let store_limits = StoreLimitsBuilder::new()
230 .memory_size(limits.memory_bytes)
231 .build();
232 let wasi = WasiCtxBuilder::new().build();
235 HostState {
236 limits,
237 log_lines: Vec::new(),
238 store_limits,
239 resource_table: ResourceTable::new(),
240 wasi,
241 }
242 }
243}
244
/// Exposes the WASI context and resource table stored in [`HostState`].
impl WasiView for HostState {
    /// Returns the WASI context stored in this state.
    fn ctx(&mut self) -> &mut WasiCtx {
        &mut self.wasi
    }
    /// Returns the resource table stored in this state.
    fn table(&mut self) -> &mut ResourceTable {
        &mut self.resource_table
    }
}
253
// Implements the host-side traits of one bindings world for `HostState`.
//
// `$world` is the module name under `bindings` (the file invokes this with
// `transformer` and `generator`). The `types` and `stage` Host traits get
// empty impls; `host_api` gets `log` and `case_convert`.
macro_rules! impl_host_api {
    ($world:ident) => {
        impl bindings::$world::forge::plugin::types::Host for HostState {}
        impl bindings::$world::forge::plugin::stage::Host for HostState {}

        impl bindings::$world::forge::plugin::host_api::Host for HostState {
            // Forwards a plugin log message to `tracing` (target "plugin")
            // at the matching level and records it in `log_lines`.
            fn log(
                &mut self,
                level: bindings::$world::forge::plugin::host_api::LogLevel,
                message: String,
            ) -> wasmtime::Result<()> {
                use bindings::$world::forge::plugin::host_api::LogLevel as L;
                // Translate the world-specific level into the shared
                // `forge_ir` level so both worlds store the same type.
                let lv = match level {
                    L::Trace => forge_ir::LogLevel::Trace,
                    L::Debug => forge_ir::LogLevel::Debug,
                    L::Info => forge_ir::LogLevel::Info,
                    L::Warn => forge_ir::LogLevel::Warn,
                    L::Error => forge_ir::LogLevel::Error,
                };
                // Each tracing level is a separate macro, hence the second match.
                match lv {
                    forge_ir::LogLevel::Trace => {
                        tracing::trace!(target: "plugin", "{message}")
                    }
                    forge_ir::LogLevel::Debug => {
                        tracing::debug!(target: "plugin", "{message}")
                    }
                    forge_ir::LogLevel::Info => {
                        tracing::info!(target: "plugin", "{message}")
                    }
                    forge_ir::LogLevel::Warn => {
                        tracing::warn!(target: "plugin", "{message}")
                    }
                    forge_ir::LogLevel::Error => {
                        tracing::error!(target: "plugin", "{message}")
                    }
                }
                self.log_lines.push((lv, message));
                Ok(())
            }

            // Converts `input` to the requested case style by delegating to
            // the local `case` module. Never returns an error.
            fn case_convert(
                &mut self,
                input: String,
                style: bindings::$world::forge::plugin::host_api::CaseStyle,
            ) -> wasmtime::Result<String> {
                use bindings::$world::forge::plugin::host_api::CaseStyle as S;
                let local = match style {
                    S::Snake => case::Style::Snake,
                    S::Kebab => case::Style::Kebab,
                    S::Camel => case::Style::Camel,
                    S::Pascal => case::Style::Pascal,
                    S::ScreamingSnake => case::Style::ScreamingSnake,
                };
                Ok(case::convert(&input, local))
            }
        }
    };
}
318
// Instantiate the host trait impls once per bindings world.
impl_host_api!(transformer);
impl_host_api!(generator);
321
/// Identifier case conversion used by the `case_convert` host function.
mod case {
    /// Target naming style.
    #[derive(Debug, Clone, Copy)]
    pub enum Style {
        Snake,
        Kebab,
        Camel,
        Pascal,
        ScreamingSnake,
    }

    /// Breaks `input` into lowercase-leading words.
    ///
    /// `_`, `-` and whitespace separate words and are dropped. An ASCII
    /// uppercase letter starts a new word only when the previous character
    /// was an ASCII lowercase letter; it is stored lowercased. Every other
    /// character is appended unchanged.
    fn split(input: &str) -> Vec<String> {
        // Moves the pending word into `out` unless it is empty.
        fn flush(pending: &mut String, out: &mut Vec<String>) {
            if !pending.is_empty() {
                out.push(std::mem::take(pending));
            }
        }

        let mut out: Vec<String> = Vec::new();
        let mut pending = String::new();
        let mut after_lower = false;

        for c in input.chars() {
            match c {
                sep if matches!(sep, '_' | '-') || sep.is_whitespace() => {
                    flush(&mut pending, &mut out);
                    after_lower = false;
                }
                upper if upper.is_ascii_uppercase() => {
                    if after_lower {
                        flush(&mut pending, &mut out);
                    }
                    pending.push(upper.to_ascii_lowercase());
                    after_lower = false;
                }
                other => {
                    pending.push(other);
                    after_lower = other.is_ascii_lowercase();
                }
            }
        }

        flush(&mut pending, &mut out);
        out
    }

    /// Converts `input` to `style` by splitting it into words and re-joining them.
    pub fn convert(input: &str, style: Style) -> String {
        let words = split(input);
        match style {
            Style::Snake => words.join("_"),
            Style::Kebab => words.join("-"),
            // The separator is unaffected by ASCII uppercasing, so the
            // joined string can be uppercased in one pass.
            Style::ScreamingSnake => words.join("_").to_ascii_uppercase(),
            Style::Camel => {
                let mut out = String::new();
                if let Some((head, tail)) = words.split_first() {
                    out.push_str(head);
                    for word in tail {
                        out.push_str(&capitalize(word));
                    }
                }
                out
            }
            Style::Pascal => {
                let mut out = String::new();
                for word in &words {
                    out.push_str(&capitalize(word));
                }
                out
            }
        }
    }

    /// Returns `w` with its first character ASCII-uppercased.
    fn capitalize(w: &str) -> String {
        let mut out = String::with_capacity(w.len());
        let mut rest = w.chars();
        if let Some(first) = rest.next() {
            out.push(first.to_ascii_uppercase());
            out.push_str(rest.as_str());
        }
        out
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        #[test]
        fn snake() {
            assert_eq!(convert("HelloWorld", Style::Snake), "hello_world");
            assert_eq!(convert("hello-world", Style::Snake), "hello_world");
            assert_eq!(convert("hello world", Style::Snake), "hello_world");
        }
        #[test]
        fn pascal() {
            assert_eq!(convert("hello_world", Style::Pascal), "HelloWorld");
        }
        #[test]
        fn camel() {
            assert_eq!(convert("hello_world", Style::Camel), "helloWorld");
        }
        #[test]
        fn kebab() {
            assert_eq!(convert("HelloWorld", Style::Kebab), "hello-world");
        }
        #[test]
        fn screaming() {
            assert_eq!(convert("helloWorld", Style::ScreamingSnake), "HELLO_WORLD");
        }
    }
}
417
/// Errors raised while loading a plugin component.
///
/// Each payload is the stringified underlying error.
#[derive(Debug, Error)]
pub enum LoadError {
    /// The component bytes could not be compiled.
    #[error("failed to compile plugin component: {0}")]
    Compile(String),
    /// Building the linker failed.
    #[error("failed to link plugin: {0}")]
    Link(String),
    /// Instantiating the component failed.
    #[error("failed to instantiate plugin: {0}")]
    Instantiate(String),
    /// Calling the plugin's `info` or `config_schema` export failed.
    #[error("failed to fetch plugin info: {0}")]
    Info(String),
    /// Converting the plugin's reported info into the host type failed.
    #[error("plugin info failed conversion: {0}")]
    Convert(String),
}
435
/// A compiled plugin component together with the metadata read at load time.
pub struct Plugin {
    /// Engine the component was compiled with; used again for each invocation.
    engine: Engine,
    /// The compiled component; instantiated afresh for each invocation.
    component: Component,
    /// Metadata returned by the plugin's `info` export.
    info: PluginInfo,
    /// String returned by the plugin's `config_schema` export.
    config_schema: String,
    /// Which loader produced this plugin.
    kind: PluginKind,
}
445
446impl std::fmt::Debug for Plugin {
447 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
448 f.debug_struct("Plugin")
449 .field("info", &self.info)
450 .field("kind", &self.kind)
451 .finish_non_exhaustive()
452 }
453}
454
/// Which plugin world a [`Plugin`] was loaded as.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PluginKind {
    /// Loaded via `Plugin::load_transformer`.
    Transformer,
    /// Loaded via `Plugin::load_generator`.
    Generator,
}
460
461impl Plugin {
462 pub fn info(&self) -> &PluginInfo {
463 &self.info
464 }
465
466 pub fn config_schema(&self) -> &str {
467 &self.config_schema
468 }
469
470 pub fn kind(&self) -> PluginKind {
471 self.kind
472 }
473
474 pub fn load_transformer(engine: &Engine, bytes: &[u8]) -> Result<Self, LoadError> {
476 let component =
477 Component::new(engine.raw(), bytes).map_err(|e| LoadError::Compile(e.to_string()))?;
478 let linker = build_transformer_linker(engine, &component).map_err(LoadError::Link)?;
479 let mut store = make_store(engine, Limits::transformer());
480 let inst =
481 bindings::transformer::IrTransformer::instantiate(&mut store, &component, &linker)
482 .map_err(|e| LoadError::Instantiate(e.to_string()))?;
483 let info_wit = inst
484 .forge_plugin_transformer_api()
485 .call_info(&mut store)
486 .map_err(|e| LoadError::Info(e.to_string()))?;
487 let schema = inst
488 .forge_plugin_transformer_api()
489 .call_config_schema(&mut store)
490 .map_err(|e| LoadError::Info(e.to_string()))?;
491 let info = convert::transformer::plugin_info_from_wit(info_wit);
492 Ok(Plugin {
493 engine: engine.clone(),
494 component,
495 info,
496 config_schema: schema,
497 kind: PluginKind::Transformer,
498 })
499 }
500
501 pub fn load_generator(engine: &Engine, bytes: &[u8]) -> Result<Self, LoadError> {
503 let component =
504 Component::new(engine.raw(), bytes).map_err(|e| LoadError::Compile(e.to_string()))?;
505 let linker = build_generator_linker(engine, &component).map_err(LoadError::Link)?;
506 let mut store = make_store(engine, Limits::generator());
507 let inst = bindings::generator::CodeGenerator::instantiate(&mut store, &component, &linker)
508 .map_err(|e| LoadError::Instantiate(e.to_string()))?;
509 let info_wit = inst
510 .forge_plugin_generator_api()
511 .call_info(&mut store)
512 .map_err(|e| LoadError::Info(e.to_string()))?;
513 let schema = inst
514 .forge_plugin_generator_api()
515 .call_config_schema(&mut store)
516 .map_err(|e| LoadError::Info(e.to_string()))?;
517 let info = convert::generator::plugin_info_from_wit(info_wit);
518 Ok(Plugin {
519 engine: engine.clone(),
520 component,
521 info,
522 config_schema: schema,
523 kind: PluginKind::Generator,
524 })
525 }
526
527 pub fn transform(
529 &self,
530 spec: Ir,
531 config: &str,
532 limits: Limits,
533 ) -> Result<TransformOutput, StageError> {
534 if self.kind != PluginKind::Transformer {
535 return Err(StageError::PluginBug(
536 "plugin loaded as transformer but called as generator".into(),
537 ));
538 }
539 let linker = build_transformer_linker(&self.engine, &self.component)
540 .map_err(|e| StageError::PluginBug(format!("link: {e}")))?;
541 let mut store = make_store(&self.engine, limits);
542 let inst =
543 bindings::transformer::IrTransformer::instantiate(&mut store, &self.component, &linker)
544 .map_err(|e| StageError::PluginBug(format!("instantiate: {e}")))?;
545 let wit_ir = convert::transformer::ir_to_wit(spec);
546 let result = inst.forge_plugin_transformer_api().call_transform(
547 store.as_context_mut(),
548 &wit_ir,
549 config,
550 );
551 let result = map_call_error(result, &store)?;
552 match result {
553 Ok(out) => {
554 let spec = convert::transformer::ir_from_wit(out.spec)
555 .map_err(|e| StageError::PluginBug(format!("ir convert: {e}")))?;
556 let diagnostics = out
557 .diagnostics
558 .into_iter()
559 .map(convert::transformer::diagnostic_from_wit)
560 .collect::<Result<Vec<_>, _>>()
561 .map_err(|e| StageError::PluginBug(format!("diag convert: {e}")))?;
562 Ok(TransformOutput { spec, diagnostics })
563 }
564 Err(stage_err) => Err(stage_error_from_repr(
565 convert::transformer::stage_error_from_wit(stage_err),
566 )),
567 }
568 }
569
570 pub fn generate(
572 &self,
573 spec: Ir,
574 config: &str,
575 limits: Limits,
576 ) -> Result<GenerationOutput, StageError> {
577 if self.kind != PluginKind::Generator {
578 return Err(StageError::PluginBug(
579 "plugin loaded as generator but called as transformer".into(),
580 ));
581 }
582 let linker = build_generator_linker(&self.engine, &self.component)
583 .map_err(|e| StageError::PluginBug(format!("link: {e}")))?;
584 let mut store = make_store(&self.engine, limits);
585 let inst =
586 bindings::generator::CodeGenerator::instantiate(&mut store, &self.component, &linker)
587 .map_err(|e| StageError::PluginBug(format!("instantiate: {e}")))?;
588 let wit_ir = convert::generator::ir_to_wit(spec);
589 let result = inst.forge_plugin_generator_api().call_generate(
590 store.as_context_mut(),
591 &wit_ir,
592 config,
593 );
594 let result = map_call_error(result, &store)?;
595 match result {
596 Ok(out) => {
597 let mut total_bytes: u64 = 0;
598 let files: Vec<OutputFile> = out
599 .files
600 .into_iter()
601 .map(|f| {
602 total_bytes = total_bytes.saturating_add(f.content.len() as u64);
603 OutputFile {
604 path: f.path,
605 content: f.content,
606 mode: match f.mode {
607 bindings::generator::exports::forge::plugin::generator_api::FileMode::Text => FileMode::Text,
608 bindings::generator::exports::forge::plugin::generator_api::FileMode::Binary => FileMode::Binary,
609 bindings::generator::exports::forge::plugin::generator_api::FileMode::Executable => FileMode::Executable,
610 },
611 }
612 })
613 .collect();
614 if files.len() as u64 > limits.output_files_max as u64 {
615 return Err(StageError::ResourceExceeded(ResourceKind::OutputSize));
616 }
617 if total_bytes > limits.output_total_bytes_max {
618 return Err(StageError::ResourceExceeded(ResourceKind::OutputSize));
619 }
620 let diagnostics = out
621 .diagnostics
622 .into_iter()
623 .map(convert::generator::diagnostic_from_wit)
624 .collect::<Result<Vec<_>, _>>()
625 .map_err(|e| StageError::PluginBug(format!("diag convert: {e}")))?;
626 Ok(GenerationOutput { files, diagnostics })
627 }
628 Err(stage_err) => Err(stage_error_from_repr(
629 convert::generator::stage_error_from_wit(stage_err),
630 )),
631 }
632 }
633}
634
635fn stage_error_from_repr(r: StageErrorRepr) -> StageError {
639 match r {
640 StageErrorRepr::Rejected {
641 reason,
642 diagnostics,
643 } => StageError::Rejected {
644 reason,
645 diagnostics,
646 },
647 StageErrorRepr::PluginBug(s) => StageError::PluginBug(s),
648 StageErrorRepr::ConfigInvalid(s) => StageError::ConfigInvalid(s),
649 StageErrorRepr::ResourceExceeded(k) => StageError::ResourceExceeded(match k {
650 ResourceKindRepr::Fuel => ResourceKind::Fuel,
651 ResourceKindRepr::Memory => ResourceKind::Memory,
652 ResourceKindRepr::Time => ResourceKind::Time,
653 ResourceKindRepr::OutputSize => ResourceKind::OutputSize,
654 }),
655 }
656}
657
658fn build_transformer_linker(
667 engine: &Engine,
668 _component: &Component,
669) -> Result<Linker<HostState>, String> {
670 let mut linker = Linker::<HostState>::new(engine.raw());
671 bindings::transformer::IrTransformer::add_to_linker(&mut linker, |s: &mut HostState| s)
672 .map_err(|e| e.to_string())?;
673 wasmtime_wasi::add_to_linker_sync(&mut linker).map_err(|e| e.to_string())?;
674 Ok(linker)
675}
676
677fn build_generator_linker(
678 engine: &Engine,
679 _component: &Component,
680) -> Result<Linker<HostState>, String> {
681 let mut linker = Linker::<HostState>::new(engine.raw());
682 bindings::generator::CodeGenerator::add_to_linker(&mut linker, |s: &mut HostState| s)
683 .map_err(|e| e.to_string())?;
684 wasmtime_wasi::add_to_linker_sync(&mut linker).map_err(|e| e.to_string())?;
685 Ok(linker)
686}
687
688fn make_store(engine: &Engine, limits: Limits) -> Store<HostState> {
689 let mut store = Store::new(engine.raw(), HostState::new(limits));
690 let _ = store.set_fuel(limits.fuel);
691 let deadline = limits.wall_clock_ms.div_ceil(10).max(1);
693 store.set_epoch_deadline(deadline);
694 store.epoch_deadline_trap();
695 store.limiter(|s| &mut s.store_limits);
696 store
697}
698
699fn map_call_error<T>(res: wasmtime::Result<T>, store: &Store<HostState>) -> Result<T, StageError> {
704 match res {
705 Ok(v) => Ok(v),
706 Err(e) => {
707 if let Some(t) = e.downcast_ref::<wasmtime::Trap>() {
710 match t {
711 wasmtime::Trap::OutOfFuel => {
712 return Err(StageError::ResourceExceeded(ResourceKind::Fuel))
713 }
714 wasmtime::Trap::Interrupt => {
715 return Err(StageError::ResourceExceeded(ResourceKind::Time))
716 }
717 _ => {}
718 }
719 }
720 let msg = format!("{e:#}");
723 if msg.contains("memory") && store.data().limits.memory_bytes > 0 {
724 if msg.contains("grow") || msg.contains("limit") {
727 return Err(StageError::ResourceExceeded(ResourceKind::Memory));
728 }
729 }
730 Err(StageError::PluginBug(msg))
731 }
732 }
733}