use std::sync::Arc;
use wasmtime::component::{Component, Linker};
use wasmtime::{Config, Engine, Store};
use wasmtime_wasi::{WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};
// Generate host-side bindings for the `mutation-plugin` WIT world found in the crate's
// `wit` directory. This provides `MutationPlugin` and the `exports::ryo::transform::*`
// types that the conversion helpers in this file translate to and from.
wasmtime::component::bindgen!({
    path: "wit",
    world: "mutation-plugin",
});
pub use ryo_plugin_api::{
Capture, MatchResult, MutationCategory, MutationManifest, NodeKind, TextEdit, TransformContext,
TransformDef, TransformError, TypeHint, CURRENT_API_VERSION,
};
/// Errors produced while creating the WASM engine, loading a mutation plugin, or running it.
// NOTE(review): variants that both mark their inner error `#[source]` and interpolate it in
// the message (`{0}` / `{source}`) render the inner error twice when the whole error chain
// is printed (e.g. `anyhow`'s `{:#}`). Consider keeping only one of the two.
#[derive(Debug, thiserror::Error)]
pub enum LoaderError {
    /// `Engine::new` rejected the engine configuration.
    #[error("Failed to create WASM engine: {0}")]
    EngineCreation(#[source] wasmtime::Error),
    /// Registering the WASI preview 2 host functions on the linker failed.
    #[error("Failed to add WASI to linker: {0}")]
    WasiSetup(#[source] wasmtime::Error),
    /// The supplied bytes could not be compiled as a WASM component.
    #[error("Failed to parse WASM component: {0}")]
    ComponentParse(#[source] wasmtime::Error),
    /// Setting the store's fuel budget failed.
    #[error("Failed to set fuel limit: {0}")]
    FuelSetup(#[source] wasmtime::Error),
    /// Instantiating the component against the linker failed.
    #[error("Failed to instantiate WASM component: {0}")]
    Instantiation(#[source] wasmtime::Error),
    /// The plugin's manifest declares an API version the host does not support.
    #[error("API version mismatch: expected {expected}, got {actual}")]
    ApiVersionMismatch { expected: u32, actual: u32 },
    /// Calling a guest export failed (for example the guest trapped or ran out of fuel).
    #[error("Failed to call WASM function '{function}': {source}")]
    FunctionCall {
        /// WIT name of the export that was being called.
        function: &'static str,
        #[source]
        source: wasmtime::Error,
    },
    /// A transform reported an error; the payload is a human-readable description.
    #[error("Transform error: {0}")]
    TransformError(String),
    /// Wraps `std::io::Error`; `#[from]` lets `?` convert it automatically.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
/// Compiles and instantiates mutation-plugin WASM components.
///
/// Owns one [`Engine`] and one [`Linker`] that are reused for every plugin passed to
/// `load`. `PluginLoader::new` configures the engine with the component model, fuel
/// metering, and a 1 MiB WASM stack limit. It also pre-populates the linker with the WASI
/// preview 2 host functions.
pub struct PluginLoader {
    engine: Engine,
    // NOTE(review): in the code visible here the linker is only ever borrowed, so the `Arc`
    // looks unnecessary unless the linker is shared elsewhere. Confirm before removing.
    linker: Arc<Linker<PluginState>>,
}
/// Per-plugin host state stored inside each plugin [`Store`].
struct PluginState {
    /// WASI context. `PluginLoader::load` builds it with only stdout/stderr inherited.
    wasi_ctx: WasiCtx,
    /// Resource table backing WASI resource handles.
    resource_table: wasmtime::component::ResourceTable,
    /// Resource caps enforced through `Store::limiter`.
    ///
    /// Fuel meters executed instructions, not memory size. Without these limits, an
    /// untrusted plugin could grow its linear memory without bound.
    limits: wasmtime::StoreLimits,
}

impl WasiView for PluginState {
    /// Exposes the WASI context and resource table to the WASI host implementation.
    fn ctx(&mut self) -> WasiCtxView<'_> {
        WasiCtxView {
            ctx: &mut self.wasi_ctx,
            table: &mut self.resource_table,
        }
    }
}

impl PluginLoader {
    /// Fuel granted to a freshly created store.
    ///
    /// It must cover instantiation plus the `get-manifest` and `get-pattern-source` calls
    /// made by [`PluginLoader::load`].
    const LOAD_FUEL: u64 = 10_000_000;

    /// Upper bound, in bytes, on the size any single plugin linear memory may grow to.
    const MAX_MEMORY_BYTES: usize = 256 * 1024 * 1024;

    /// Creates a loader with a component-model engine and a WASI-enabled linker.
    ///
    /// The engine meters fuel and caps the WASM stack at 1 MiB. The linker is populated
    /// with the WASI preview 2 host functions.
    ///
    /// # Errors
    /// - [`LoaderError::EngineCreation`] if the engine configuration is rejected.
    /// - [`LoaderError::WasiSetup`] if the WASI host functions cannot be linked.
    pub fn new() -> Result<Self, LoaderError> {
        let mut config = Config::new();
        config.wasm_component_model(true);
        // Fuel metering bounds how much guest code a single store can execute.
        config.consume_fuel(true);
        config.max_wasm_stack(1024 * 1024);
        let engine = Engine::new(&config).map_err(LoaderError::EngineCreation)?;
        let mut linker = Linker::new(&engine);
        wasmtime_wasi::p2::add_to_linker_sync(&mut linker).map_err(LoaderError::WasiSetup)?;
        Ok(Self {
            engine,
            linker: Arc::new(linker),
        })
    }

    /// Compiles `wasm_bytes` as a component and instantiates it in a fresh, resource-limited
    /// store. It then reads the plugin's manifest and pattern source.
    ///
    /// # Errors
    /// - [`LoaderError::ComponentParse`] if the bytes are not a valid component.
    /// - [`LoaderError::FuelSetup`] if the fuel budget cannot be set.
    /// - [`LoaderError::Instantiation`] if the component cannot be instantiated.
    /// - [`LoaderError::FunctionCall`] if `get-manifest` or `get-pattern-source` fails.
    /// - [`LoaderError::ApiVersionMismatch`] if the manifest's API version differs from
    ///   `CURRENT_API_VERSION`.
    pub fn load(&self, wasm_bytes: &[u8]) -> Result<LoadedPlugin, LoaderError> {
        let component =
            Component::new(&self.engine, wasm_bytes).map_err(LoaderError::ComponentParse)?;

        // The plugin may print to the host's stdout/stderr. Nothing else is configured on
        // the WASI builder.
        let wasi_ctx = WasiCtxBuilder::new()
            .inherit_stdout()
            .inherit_stderr()
            .build();
        let limits = wasmtime::StoreLimitsBuilder::new()
            .memory_size(Self::MAX_MEMORY_BYTES)
            .build();
        let mut store = Store::new(
            &self.engine,
            PluginState {
                wasi_ctx,
                resource_table: wasmtime::component::ResourceTable::new(),
                limits,
            },
        );
        // Install the limiter before instantiation so the component's initial memories are
        // subject to it as well.
        store.limiter(|state| &mut state.limits);
        store
            .set_fuel(Self::LOAD_FUEL)
            .map_err(LoaderError::FuelSetup)?;

        let bindings = MutationPlugin::instantiate(&mut store, &component, &self.linker)
            .map_err(LoaderError::Instantiation)?;
        let iface = bindings.ryo_transform_mutation();

        let wasm_manifest =
            iface
                .call_get_manifest(&mut store)
                .map_err(|source| LoaderError::FunctionCall {
                    function: "get-manifest",
                    source,
                })?;
        // Reject incompatible plugins before calling any further exports.
        if wasm_manifest.api_version != CURRENT_API_VERSION {
            return Err(LoaderError::ApiVersionMismatch {
                expected: CURRENT_API_VERSION,
                actual: wasm_manifest.api_version,
            });
        }
        let manifest = convert_manifest(&wasm_manifest);

        let additional_patterns =
            iface
                .call_get_pattern_source(&mut store)
                .map_err(|source| LoaderError::FunctionCall {
                    function: "get-pattern-source",
                    source,
                })?;

        tracing::info!("Loaded mutation plugin: {}", manifest.name);
        Ok(LoadedPlugin {
            manifest,
            additional_patterns,
            bindings,
            store,
        })
    }
}
/// A plugin that has been instantiated and has passed the API-version check.
///
/// Keeps the generated bindings together with the [`Store`] they were instantiated in, so
/// transforms can be run repeatedly through `execute_transform`.
pub struct LoadedPlugin {
    /// Host-side copy of the manifest returned by the plugin's `get-manifest` export.
    pub manifest: MutationManifest,
    /// Text returned by the plugin's `get-pattern-source` export.
    pub additional_patterns: String,
    bindings: MutationPlugin,
    store: Store<PluginState>,
}
impl LoadedPlugin {
pub fn execute_transform(
&mut self,
matches: Vec<MatchResult>,
context: TransformContext,
) -> Result<Vec<TextEdit>, LoaderError> {
self.store
.set_fuel(1_000_000)
.map_err(LoaderError::FuelSetup)?;
let wasm_matches = matches
.iter()
.map(convert_match_to_wasm)
.collect::<Vec<_>>();
let wasm_context = convert_context_to_wasm(&context);
let iface = self.bindings.ryo_transform_mutation();
let result = iface
.call_execute_transform(&mut self.store, &wasm_matches, &wasm_context)
.map_err(|e| LoaderError::FunctionCall {
function: "execute-transform",
source: e,
})?;
match result {
Ok(edits) => Ok(edits.into_iter().map(convert_text_edit).collect()),
Err(e) => Err(LoaderError::TransformError(format_transform_error(&e))),
}
}
}
/// Converts the plugin's WIT manifest into the host-side [`MutationManifest`].
///
/// Owned fields are cloned. `category` and `transform` go through their dedicated
/// converters.
fn convert_manifest(
    wasm: &exports::ryo::transform::mutation::MutationManifest,
) -> MutationManifest {
    let category = convert_category(&wasm.category);
    let transform = convert_transform_def(&wasm.transform);
    MutationManifest {
        api_version: wasm.api_version,
        tier: wasm.tier,
        name: wasm.name.clone(),
        description: wasm.description.clone(),
        pattern: wasm.pattern.clone(),
        category,
        transform,
    }
}
fn convert_category(
wasm: &exports::ryo::transform::mutation::MutationCategory,
) -> MutationCategory {
match wasm {
exports::ryo::transform::mutation::MutationCategory::Idiom => MutationCategory::Idiom,
exports::ryo::transform::mutation::MutationCategory::Refactor => MutationCategory::Refactor,
exports::ryo::transform::mutation::MutationCategory::Generate => MutationCategory::Generate,
exports::ryo::transform::mutation::MutationCategory::Custom => MutationCategory::Custom,
}
}
fn convert_transform_def(wasm: &exports::ryo::transform::mutation::TransformDef) -> TransformDef {
match wasm {
exports::ryo::transform::mutation::TransformDef::Template(t) => {
TransformDef::Template(t.clone())
}
exports::ryo::transform::mutation::TransformDef::WasmExecute => TransformDef::WasmExecute,
}
}
/// Converts a host [`MatchResult`] (node kind, byte span, captures) into its WIT form.
fn convert_match_to_wasm(m: &MatchResult) -> exports::ryo::transform::mutation::MatchResult {
    let kind = convert_node_kind_to_wasm(&m.kind);
    let captures = m.captures.iter().map(convert_capture_to_wasm).collect();
    exports::ryo::transform::mutation::MatchResult {
        kind,
        start_byte: m.start_byte,
        end_byte: m.end_byte,
        captures,
    }
}
fn convert_node_kind_to_wasm(k: &NodeKind) -> exports::ryo::transform::types::NodeKind {
match k {
NodeKind::FnCall => exports::ryo::transform::types::NodeKind::FnCall,
NodeKind::MethodCall => exports::ryo::transform::types::NodeKind::MethodCall,
NodeKind::MatchExpr => exports::ryo::transform::types::NodeKind::MatchExpr,
NodeKind::IfExpr => exports::ryo::transform::types::NodeKind::IfExpr,
NodeKind::IfLetExpr => exports::ryo::transform::types::NodeKind::IfLetExpr,
NodeKind::LoopExpr => exports::ryo::transform::types::NodeKind::LoopExpr,
NodeKind::ForExpr => exports::ryo::transform::types::NodeKind::ForExpr,
NodeKind::WhileExpr => exports::ryo::transform::types::NodeKind::WhileExpr,
NodeKind::Block => exports::ryo::transform::types::NodeKind::Block,
NodeKind::Ident => exports::ryo::transform::types::NodeKind::Ident,
NodeKind::Literal => exports::ryo::transform::types::NodeKind::Literal,
NodeKind::BinaryExpr => exports::ryo::transform::types::NodeKind::BinaryExpr,
NodeKind::UnaryExpr => exports::ryo::transform::types::NodeKind::UnaryExpr,
NodeKind::FieldAccess => exports::ryo::transform::types::NodeKind::FieldAccess,
NodeKind::IndexExpr => exports::ryo::transform::types::NodeKind::IndexExpr,
NodeKind::Closure => exports::ryo::transform::types::NodeKind::Closure,
NodeKind::StructExpr => exports::ryo::transform::types::NodeKind::StructExpr,
NodeKind::TupleExpr => exports::ryo::transform::types::NodeKind::TupleExpr,
NodeKind::ArrayExpr => exports::ryo::transform::types::NodeKind::ArrayExpr,
NodeKind::Path => exports::ryo::transform::types::NodeKind::Path,
NodeKind::TypePath => exports::ryo::transform::types::NodeKind::TypePath,
}
}
/// Converts a host [`Capture`] into its WIT form, cloning the name and captured text.
fn convert_capture_to_wasm(c: &Capture) -> exports::ryo::transform::types::Capture {
    let (start_byte, end_byte) = (c.start_byte, c.end_byte);
    exports::ryo::transform::types::Capture {
        start_byte,
        end_byte,
        name: c.name.clone(),
        text: c.text.clone(),
    }
}
/// Converts the host [`TransformContext`] into its WIT form.
///
/// The file path, source text, and return type are cloned. Every type hint goes through
/// `convert_type_hint_to_wasm`.
fn convert_context_to_wasm(
    ctx: &TransformContext,
) -> exports::ryo::transform::mutation::TransformContext {
    let type_hints = ctx.type_hints.iter().map(convert_type_hint_to_wasm).collect();
    exports::ryo::transform::mutation::TransformContext {
        file_path: ctx.file_path.clone(),
        source_text: ctx.source_text.clone(),
        type_hints,
        fn_return_type: ctx.fn_return_type.clone(),
    }
}
/// Converts a host [`TypeHint`] into its WIT form.
///
/// Only `type_name` needs cloning; the id and boolean flags are copied.
fn convert_type_hint_to_wasm(h: &TypeHint) -> exports::ryo::transform::types::TypeHint {
    let type_name = h.type_name.clone();
    exports::ryo::transform::types::TypeHint {
        node_id: h.node_id,
        type_name,
        is_copy: h.is_copy,
        is_iterator: h.is_iterator,
        is_option: h.is_option,
        is_result: h.is_result,
    }
}
/// Converts an edit returned by the plugin into the host [`TextEdit`].
///
/// The replacement text is moved rather than copied.
fn convert_text_edit(e: exports::ryo::transform::types::TextEdit) -> TextEdit {
    let replacement = e.replacement;
    TextEdit {
        replacement,
        start_byte: e.start_byte,
        end_byte: e.end_byte,
    }
}
fn format_transform_error(e: &exports::ryo::transform::mutation::TransformError) -> String {
match e {
exports::ryo::transform::mutation::TransformError::MissingCapture(name) => {
format!("Missing capture: {}", name)
}
exports::ryo::transform::mutation::TransformError::InvalidContext(msg) => {
format!("Invalid context: {}", msg)
}
exports::ryo::transform::mutation::TransformError::TypeMismatch(msg) => {
format!("Type mismatch: {}", msg)
}
exports::ryo::transform::mutation::TransformError::PatternNotApplicable(msg) => {
format!("Pattern not applicable: {}", msg)
}
exports::ryo::transform::mutation::TransformError::Internal(msg) => {
format!("Internal error: {}", msg)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_loader_creation() {
        // `expect` reports the underlying `LoaderError` on failure, whereas a bare
        // `assert!(loader.is_ok())` would hide it.
        let _loader =
            PluginLoader::new().expect("PluginLoader::new should succeed with the default config");
    }

    #[test]
    fn test_load_rejects_invalid_component() {
        let loader = PluginLoader::new().expect("PluginLoader::new should succeed");
        // Just the WASM magic number with no version/layer header: cannot be a component.
        let result = loader.load(b"\0asm");
        assert!(matches!(result, Err(LoaderError::ComponentParse(_))));
    }

    #[test]
    fn test_api_version_mismatch_message() {
        let err = LoaderError::ApiVersionMismatch {
            expected: 1,
            actual: 2,
        };
        assert_eq!(err.to_string(), "API version mismatch: expected 1, got 2");
    }
}