1use std::collections::{HashMap, HashSet};
21use std::net::IpAddr;
22use std::path::Path;
23use std::sync::Arc;
24use std::time::Duration;
25
26use tokio::sync::Mutex;
27use wasmtime::component::{Component, HasSelf, Linker, ResourceTable};
28use wasmtime::{Engine, Store};
29use wasmtime_wasi::{WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};
30
31mod tool;
32
33pub use tool::WasmTool;
34
/// Callback receiving plugin log output as `(level, plugin_name, message)`,
/// where `level` is one of `"debug"`, `"info"`, `"warn"` or `"error"`.
///
/// Shared across threads, so it must be `Send + Sync`.
pub type LogSink = Arc<dyn Fn(&str, &str, &str) + Send + Sync + 'static>;
44
// Bindings for the v1 `grain-plugin` world. This generates `GrainPlugin`, the
// `grain::plugin::host` import trait (implemented by `PluginState`) and the
// `exports::grain::plugin::plugin` guest interface re-exported below.
wasmtime::component::bindgen!({
    path: "wit/grain-plugin.wit",
    world: "grain-plugin",
});
53
/// Bindings for the `grain-plugin-v2` world, which provides the
/// `orchestration` export (roles and hooks) used by the runtime.
///
/// It lives in its own module because the generated module names
/// (`grain::plugin::host`, `exports::...`) mirror the v1 bindings above and
/// would otherwise presumably collide.
mod v2_bindings {
    wasmtime::component::bindgen!({
        path: "wit/grain-plugin.wit",
        world: "grain-plugin-v2",
    });
}
60
61pub use exports::grain::plugin::plugin as wit_plugin;
66pub use grain::plugin::host as wit_host;
67
68use v2_bindings::exports::grain::plugin::orchestration as wit_orchestration;
69
/// Upper bound for each outbound HTTP request a plugin makes through the host
/// (installed as the reqwest client timeout).
const HOST_HTTP_TIMEOUT: Duration = Duration::from_millis(30_000);
71
/// Host features a plugin has been granted. Every flag defaults to `false`.
#[derive(Debug, Clone, Default)]
pub struct Capabilities {
    /// Allows the `log` host import to emit anything.
    pub log: bool,
    /// Allows the `env-get` host import to return values.
    pub env: bool,
    /// Allows the HTTP GET/POST host imports.
    pub http: bool,
    /// Enables the v2 orchestration interface (roles and hooks).
    pub role_orchestration: bool,
}

impl Capabilities {
    /// Builds a capability set from names such as `"log"`, `"env"`, `"http"`
    /// and `"role-orchestration"`.
    ///
    /// Matching is exact and case-sensitive. Unknown names are ignored and
    /// duplicates are harmless.
    pub fn from_list(caps: &[String]) -> Self {
        caps.iter().fold(Capabilities::default(), |mut granted, cap| {
            match cap.as_str() {
                "log" => granted.log = true,
                "env" => granted.env = true,
                "http" => granted.http = true,
                "role-orchestration" => granted.role_orchestration = true,
                _ => {}
            }
            granted
        })
    }
}
97
/// Host-side state held by each wasmtime `Store` that runs a plugin instance.
///
/// It carries the WASI context/table plus the grants and configuration the
/// grain host imports consult on every guest call.
pub struct PluginState {
    /// WASI context exposed via `WasiView`. The runtime builds it with a default
    /// `WasiCtxBuilder`.
    wasi_ctx: WasiCtx,
    /// Resource table handed to WASI together with `wasi_ctx`.
    table: ResourceTable,
    /// Grants checked by `log`, `env_get`, `http_get` and `http_post`.
    capabilities: Capabilities,
    /// Name used to label this plugin's log output.
    plugin_name: String,
    /// Runtime handle used to `block_on` async HTTP requests from the
    /// synchronous host imports.
    /// NOTE(review): Tokio's `Handle::block_on` panics when called from a thread
    /// already inside a runtime context, so guest code must run off the async
    /// worker threads.
    rt_handle: tokio::runtime::Handle,
    /// Destination for guest log lines. `None` falls back to stderr.
    log_sink: Option<LogSink>,
    /// Per-plugin environment overrides, consulted before the host process
    /// environment in `env_get`.
    env_map: HashMap<String, String>,
}
112
113impl WasiView for PluginState {
114 fn ctx(&mut self) -> WasiCtxView<'_> {
115 WasiCtxView {
116 ctx: &mut self.wasi_ctx,
117 table: &mut self.table,
118 }
119 }
120}
121
122impl wit_host::Host for PluginState {
127 fn log(&mut self, level: wit_host::LogLevel, msg: String) {
128 if !self.capabilities.log {
129 return;
130 }
131 let tag = match level {
132 wit_host::LogLevel::Debug => "debug",
133 wit_host::LogLevel::Info => "info",
134 wit_host::LogLevel::Warn => "warn",
135 wit_host::LogLevel::Error => "error",
136 };
137 if let Some(sink) = &self.log_sink {
138 sink(tag, &self.plugin_name, &msg);
139 } else {
140 eprintln!("[{tag}] wasm plugin '{}': {msg}", self.plugin_name);
144 }
145 }
146
147 fn env_get(&mut self, key: String) -> Option<String> {
148 if !self.capabilities.env {
149 return None;
150 }
151 if let Some(val) = self.env_map.get(&key) {
153 return Some(val.clone());
154 }
155 std::env::var(&key).ok()
156 }
157
158 fn http_get(
159 &mut self,
160 url: String,
161 headers: Vec<(String, String)>,
162 ) -> Result<wit_host::HttpResponse, String> {
163 if !self.capabilities.http {
164 return Err("http capability not granted".into());
165 }
166 self.rt_handle
167 .block_on(async { do_http_request("GET", &url, &headers, None).await })
168 }
169
170 fn http_post(
171 &mut self,
172 url: String,
173 headers: Vec<(String, String)>,
174 body: String,
175 ) -> Result<wit_host::HttpResponse, String> {
176 if !self.capabilities.http {
177 return Err("http capability not granted".into());
178 }
179 self.rt_handle
180 .block_on(async { do_http_request("POST", &url, &headers, Some(&body)).await })
181 }
182}
183
184impl v2_bindings::grain::plugin::host::Host for PluginState {
185 fn log(&mut self, level: v2_bindings::grain::plugin::host::LogLevel, msg: String) {
186 if !self.capabilities.log {
187 return;
188 }
189 let tag = match level {
190 v2_bindings::grain::plugin::host::LogLevel::Debug => "debug",
191 v2_bindings::grain::plugin::host::LogLevel::Info => "info",
192 v2_bindings::grain::plugin::host::LogLevel::Warn => "warn",
193 v2_bindings::grain::plugin::host::LogLevel::Error => "error",
194 };
195 if let Some(sink) = &self.log_sink {
196 sink(tag, &self.plugin_name, &msg);
197 } else {
198 eprintln!("[{tag}] wasm plugin '{}': {msg}", self.plugin_name);
199 }
200 }
201
202 fn env_get(&mut self, key: String) -> Option<String> {
203 if !self.capabilities.env {
204 return None;
205 }
206 if let Some(val) = self.env_map.get(&key) {
207 return Some(val.clone());
208 }
209 std::env::var(&key).ok()
210 }
211
212 fn http_get(
213 &mut self,
214 url: String,
215 headers: Vec<(String, String)>,
216 ) -> Result<v2_bindings::grain::plugin::host::HttpResponse, String> {
217 if !self.capabilities.http {
218 return Err("http capability not granted".into());
219 }
220 let response = self
221 .rt_handle
222 .block_on(async { do_http_request("GET", &url, &headers, None).await })?;
223 Ok(v2_bindings::grain::plugin::host::HttpResponse {
224 status: response.status,
225 headers: response.headers,
226 body: response.body,
227 })
228 }
229
230 fn http_post(
231 &mut self,
232 url: String,
233 headers: Vec<(String, String)>,
234 body: String,
235 ) -> Result<v2_bindings::grain::plugin::host::HttpResponse, String> {
236 if !self.capabilities.http {
237 return Err("http capability not granted".into());
238 }
239 let response = self
240 .rt_handle
241 .block_on(async { do_http_request("POST", &url, &headers, Some(&body)).await })?;
242 Ok(v2_bindings::grain::plugin::host::HttpResponse {
243 status: response.status,
244 headers: response.headers,
245 body: response.body,
246 })
247 }
248}
249
250async fn do_http_request(
251 method: &str,
252 url: &str,
253 headers: &[(String, String)],
254 body: Option<&str>,
255) -> Result<wit_host::HttpResponse, String> {
256 let mut client = reqwest::Client::builder().timeout(HOST_HTTP_TIMEOUT);
257 if should_bypass_proxy_for_url(url) {
258 client = client.no_proxy();
259 }
260 let client = client.build().map_err(|e| e.to_string())?;
261 let mut builder = match method {
262 "POST" => client.post(url),
263 _ => client.get(url),
264 };
265 for (k, v) in headers {
266 builder = builder.header(k.as_str(), v.as_str());
267 }
268 if let Some(b) = body {
269 builder = builder.body(b.to_string());
270 }
271 let resp = builder.send().await.map_err(|e| e.to_string())?;
272 let status = resp.status().as_u16();
273 let resp_headers: Vec<(String, String)> = resp
274 .headers()
275 .iter()
276 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
277 .collect();
278 let resp_body = resp.text().await.map_err(|e| e.to_string())?;
279 Ok(wit_host::HttpResponse {
280 status,
281 headers: resp_headers,
282 body: resp_body,
283 })
284}
285
286fn should_bypass_proxy_for_url(url: &str) -> bool {
287 let Ok(parsed) = reqwest::Url::parse(url) else {
288 return false;
289 };
290 let Some(host) = parsed.host_str() else {
291 return false;
292 };
293 host.eq_ignore_ascii_case("localhost")
294 || host
295 .parse::<IpAddr>()
296 .map(|ip| ip.is_loopback())
297 .unwrap_or(false)
298}
299
/// Errors produced by `WasmPluginRuntime` while loading plugins or invoking
/// their exports.
#[derive(Debug, thiserror::Error)]
pub enum WasmPluginError {
    /// Engine, linker, compilation or instantiation failure, or an error/trap
    /// surfaced by a guest call.
    #[error("wasmtime: {0}")]
    Wasmtime(#[from] wasmtime::Error),
    /// Plugin initialisation failed (e.g. the guest's `init` export returned an
    /// error string).
    #[error("plugin init failed: {0}")]
    InitFailed(String),
    /// A tool or hook call could not be completed: unknown plugin id, missing
    /// capability, a guest-reported hook error, or a panicked worker thread.
    #[error("tool call failed: {0}")]
    ToolCallFailed(String),
    /// Reading the component file failed.
    #[error("io: {0}")]
    Io(#[from] std::io::Error),
}
315
/// Metadata gathered while loading a plugin with `WasmPluginRuntime::load`.
#[derive(Debug, Clone)]
pub struct LoadedPlugin {
    /// Name and version reported by the guest's `init` export.
    pub info: PluginInfo,
    /// Tools advertised by the guest's tool listing.
    pub tool_defs: Vec<ToolDef>,
    /// Roles and hooks from the v2 orchestration interface. `None` unless the
    /// plugin was granted the `role-orchestration` capability.
    pub orchestration: Option<OrchestrationDef>,
}
327
/// Identity the plugin reports from its `init` export.
#[derive(Debug, Clone)]
pub struct PluginInfo {
    /// Plugin-reported name (not necessarily the host-assigned `plugin_name`).
    pub name: String,
    /// Plugin-reported version string. Its format is not validated here.
    pub version: String,
}
335
/// A tool advertised by a plugin, copied verbatim from the guest's tool listing.
#[derive(Debug, Clone)]
pub struct ToolDef {
    /// Identifier passed back to `call_tool` as `tool_name`.
    pub name: String,
    /// Display label. Presumably human-facing; confirm with the UI consumer.
    pub label: String,
    /// Free-text description of what the tool does.
    pub description: String,
    /// Parameter description as a JSON string (likely a JSON Schema — TODO
    /// confirm). It is not parsed or validated here.
    pub parameters_json: String,
}
344
/// Orchestration metadata a plugin exposes through the v2 `orchestration`
/// interface.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OrchestrationDef {
    /// Roles returned by the guest's role listing.
    pub roles: Vec<RoleDef>,
    /// Hooks returned by the guest's hook listing.
    pub hooks: Vec<HookDef>,
}
351
/// A role declared by an orchestration plugin, copied field-for-field from the
/// guest.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RoleDef {
    /// Role identifier (presumably what `HostAction::SwitchRole` refers to —
    /// confirm).
    pub name: String,
    /// Model identifier for this role. Not validated here.
    pub model: String,
    /// Prompt text for this role.
    pub prompt: String,
    /// Tool names associated with the role.
    pub tools: Vec<String>,
    /// Optional thinking-level name. Accepted values are not checked here.
    pub thinking_level: Option<String>,
}
361
/// A hook registration declared by an orchestration plugin.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HookDef {
    /// Point at which the hook is invoked.
    pub point: HookPoint,
    /// Plugin-chosen name for the hook.
    pub name: String,
}
368
/// Host-side mirror of the v2 `hook-point` enum, passed to
/// `WasmPluginRuntime::call_hook`.
///
/// The variant names suggest where in the agent loop each hook fires; the exact
/// timing is decided by the caller of `call_hook` (not visible here).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookPoint {
    BeforeAgentStart,
    AfterToolCall,
    PrepareNextTurn,
    ShouldStopAfterTurn,
}
377
/// An action a hook asks the host to perform, converted from the v2
/// `host-action` variant.
///
/// How each action is applied is up to the caller of `call_hook`. The notes
/// below only restate the payloads.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HostAction {
    /// Role name to switch to.
    SwitchRole(String),
    /// Model identifier to switch to.
    SwitchModel(String),
    /// Replacement system prompt text.
    SetSystemPrompt(String),
    /// Tool names to make active.
    SetActiveTools(Vec<String>),
    /// User message text to inject.
    InjectUserMessage(String),
    /// Whether to stop after the current turn.
    StopAfterTurn(bool),
    /// Opaque plugin-defined payload.
    EmitCustom(String),
    /// Header fields for the UI.
    SetUiHeader(UiHeader),
    /// Status text for the UI.
    SetUiStatus(String),
}
391
/// Payload of `HostAction::SetUiHeader`. Each field is optional.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UiHeader {
    /// Provider label to show, if any.
    pub provider: Option<String>,
    /// Model label to show, if any.
    pub model: Option<String>,
}
398
/// Loads grain WebAssembly component plugins and runs their tool and hook
/// exports.
///
/// Compiled components and per-plugin configuration are kept in a registry;
/// every tool/hook call instantiates the component in a fresh `Store`.
pub struct WasmPluginRuntime {
    /// Shared engine, configured with the component model enabled.
    engine: Engine,
    /// Linker for the v1 `grain-plugin` world (WASI p2 + grain host imports).
    linker: Linker<PluginState>,
    /// Linker for the `grain-plugin-v2` world, used for orchestration exports.
    linker_v2: Linker<PluginState>,
    /// Registry of loaded plugins, looked up by plugin id.
    components: Mutex<Vec<PluginEntry>>,
    /// Sink given to every plugin's store. `None` means guest logs go to stderr.
    log_sink: Option<LogSink>,
}
414
/// Registry record for one loaded plugin.
///
/// It holds everything needed to instantiate the compiled component again for a
/// later tool or hook call.
struct PluginEntry {
    /// Caller-supplied id used for lookups.
    id: String,
    /// Compiled component, instantiated anew for every call.
    component: Component,
    /// Grants copied into each new `PluginState`.
    capabilities: Capabilities,
    /// Host-assigned name used in log output.
    plugin_name: String,
    /// Environment overrides copied into each new `PluginState`.
    env_map: HashMap<String, String>,
}
423
424impl std::fmt::Debug for WasmPluginRuntime {
425 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
426 f.debug_struct("WasmPluginRuntime").finish_non_exhaustive()
427 }
428}
429
430impl WasmPluginRuntime {
431 pub fn new() -> Result<Self, WasmPluginError> {
433 let mut config = wasmtime::Config::new();
434 config.wasm_component_model(true);
435 let engine = Engine::new(&config)?;
436 let mut linker = Linker::<PluginState>::new(&engine);
437 wasmtime_wasi::p2::add_to_linker_sync(&mut linker)?;
439 GrainPlugin::add_to_linker::<_, HasSelf<_>>(&mut linker, |s| s)?;
440 let mut linker_v2 = Linker::<PluginState>::new(&engine);
441 wasmtime_wasi::p2::add_to_linker_sync(&mut linker_v2)?;
442 v2_bindings::GrainPluginV2::add_to_linker::<_, HasSelf<_>>(&mut linker_v2, |s| s)?;
443 Ok(WasmPluginRuntime {
444 engine,
445 linker,
446 linker_v2,
447 components: Mutex::new(Vec::new()),
448 log_sink: None,
449 })
450 }
451
452 pub fn with_log_sink(mut self, sink: LogSink) -> Self {
456 self.log_sink = Some(sink);
457 self
458 }
459
460 pub async fn load(
463 &self,
464 path: &Path,
465 plugin_id: &str,
466 capabilities: Capabilities,
467 plugin_name: &str,
468 env_map: HashMap<String, String>,
469 ) -> Result<LoadedPlugin, WasmPluginError> {
470 let wasm_bytes = tokio::fs::read(path).await?;
471 let component = Component::new(&self.engine, &wasm_bytes)?;
472
473 let wasi = WasiCtxBuilder::new().build();
475 let state = PluginState {
476 wasi_ctx: wasi,
477 table: ResourceTable::new(),
478 capabilities: capabilities.clone(),
479 plugin_name: plugin_name.to_string(),
480 rt_handle: tokio::runtime::Handle::current(),
481 log_sink: self.log_sink.clone(),
482 env_map: env_map.clone(),
483 };
484 let mut store = Store::new(&self.engine, state);
485 let bindings = GrainPlugin::instantiate(&mut store, &component, &self.linker)?;
486
487 let guest = bindings.grain_plugin_plugin();
489 let info_raw = guest
490 .call_init(&mut store)?
491 .map_err(WasmPluginError::InitFailed)?;
492 let info = PluginInfo {
493 name: info_raw.name,
494 version: info_raw.version,
495 };
496
497 let tools_raw = guest.call_list_tools(&mut store)?;
499 let tool_defs: Vec<ToolDef> = tools_raw
500 .into_iter()
501 .map(|t| ToolDef {
502 name: t.name,
503 label: t.label,
504 description: t.description,
505 parameters_json: t.parameters_json,
506 })
507 .collect();
508
509 let orchestration = self
510 .load_orchestration_metadata(&component, &capabilities, plugin_name, &env_map)
511 .await?;
512
513 self.components.lock().await.push(PluginEntry {
515 id: plugin_id.to_string(),
516 component,
517 capabilities,
518 plugin_name: plugin_name.to_string(),
519 env_map,
520 });
521
522 Ok(LoadedPlugin {
523 info,
524 tool_defs,
525 orchestration,
526 })
527 }
528
529 async fn load_orchestration_metadata(
530 &self,
531 component: &Component,
532 capabilities: &Capabilities,
533 plugin_name: &str,
534 env_map: &HashMap<String, String>,
535 ) -> Result<Option<OrchestrationDef>, WasmPluginError> {
536 if !capabilities.role_orchestration {
537 return Ok(None);
538 }
539 let wasi = WasiCtxBuilder::new().build();
540 let state = PluginState {
541 wasi_ctx: wasi,
542 table: ResourceTable::new(),
543 capabilities: capabilities.clone(),
544 plugin_name: plugin_name.to_string(),
545 rt_handle: tokio::runtime::Handle::current(),
546 log_sink: self.log_sink.clone(),
547 env_map: env_map.clone(),
548 };
549 let mut store = Store::new(&self.engine, state);
550 let bindings =
551 v2_bindings::GrainPluginV2::instantiate(&mut store, component, &self.linker_v2)?;
552 let guest = bindings.grain_plugin_orchestration();
553 let roles = guest
554 .call_list_roles(&mut store)?
555 .into_iter()
556 .map(role_from_wit)
557 .collect();
558 let hooks = guest
559 .call_list_hooks(&mut store)?
560 .into_iter()
561 .map(hook_from_wit)
562 .collect();
563 Ok(Some(OrchestrationDef { roles, hooks }))
564 }
565
566 pub async fn call_tool(
578 &self,
579 plugin_id: &str,
580 tool_name: &str,
581 args_json: &str,
582 host_rt_handle: tokio::runtime::Handle,
583 ) -> Result<CallToolResult, WasmPluginError> {
584 std::thread::scope(|scope| {
585 scope
586 .spawn(move || {
587 self.call_tool_blocking(plugin_id, tool_name, args_json, host_rt_handle)
588 })
589 .join()
590 .map_err(|_| WasmPluginError::ToolCallFailed("tool call panicked".into()))?
591 })
592 }
593
594 pub fn call_tool_blocking(
602 &self,
603 plugin_id: &str,
604 tool_name: &str,
605 args_json: &str,
606 host_rt_handle: tokio::runtime::Handle,
607 ) -> Result<CallToolResult, WasmPluginError> {
608 let entries = self.components.blocking_lock();
609 let entry = entries.iter().find(|e| e.id == plugin_id).ok_or_else(|| {
610 WasmPluginError::ToolCallFailed(format!("plugin '{plugin_id}' not loaded"))
611 })?;
612
613 let wasi = WasiCtxBuilder::new().build();
614 let state = PluginState {
615 wasi_ctx: wasi,
616 table: ResourceTable::new(),
617 capabilities: entry.capabilities.clone(),
618 plugin_name: entry.plugin_name.clone(),
619 rt_handle: host_rt_handle,
620 log_sink: self.log_sink.clone(),
621 env_map: entry.env_map.clone(),
622 };
623 let mut store = Store::new(&self.engine, state);
624 let bindings = GrainPlugin::instantiate(&mut store, &entry.component, &self.linker)?;
625
626 let guest = bindings.grain_plugin_plugin();
627
628 let _ = guest.call_init(&mut store)?;
630
631 let result = guest.call_call_tool(&mut store, tool_name, args_json)?;
632 Ok(CallToolResult {
633 content_json: result.content_json,
634 is_error: result.is_error,
635 })
636 }
637
638 pub async fn call_hook(
644 &self,
645 plugin_id: &str,
646 point: HookPoint,
647 context_json: &str,
648 host_rt_handle: tokio::runtime::Handle,
649 ) -> Result<Vec<HostAction>, WasmPluginError> {
650 let point_raw = point_to_wit(point);
651 std::thread::scope(|scope| {
652 scope
653 .spawn(move || {
654 self.call_hook_blocking(plugin_id, point_raw, context_json, host_rt_handle)
655 })
656 .join()
657 .map_err(|_| WasmPluginError::ToolCallFailed("hook call panicked".into()))?
658 })
659 }
660
661 fn call_hook_blocking(
662 &self,
663 plugin_id: &str,
664 point: wit_orchestration::HookPoint,
665 context_json: &str,
666 host_rt_handle: tokio::runtime::Handle,
667 ) -> Result<Vec<HostAction>, WasmPluginError> {
668 let entries = self.components.blocking_lock();
669 let entry = entries.iter().find(|e| e.id == plugin_id).ok_or_else(|| {
670 WasmPluginError::ToolCallFailed(format!("plugin '{plugin_id}' not loaded"))
671 })?;
672 if !entry.capabilities.role_orchestration {
673 return Err(WasmPluginError::ToolCallFailed(format!(
674 "plugin '{plugin_id}' does not have role-orchestration capability"
675 )));
676 }
677
678 let wasi = WasiCtxBuilder::new().build();
679 let state = PluginState {
680 wasi_ctx: wasi,
681 table: ResourceTable::new(),
682 capabilities: entry.capabilities.clone(),
683 plugin_name: entry.plugin_name.clone(),
684 rt_handle: host_rt_handle,
685 log_sink: self.log_sink.clone(),
686 env_map: entry.env_map.clone(),
687 };
688 let mut store = Store::new(&self.engine, state);
689 let bindings =
690 v2_bindings::GrainPluginV2::instantiate(&mut store, &entry.component, &self.linker_v2)?;
691 let guest = bindings.grain_plugin_orchestration();
692 let actions = guest
693 .call_call_hook(&mut store, point, context_json)?
694 .map_err(WasmPluginError::ToolCallFailed)?
695 .into_iter()
696 .map(action_from_wit)
697 .collect();
698 Ok(actions)
699 }
700}
701
/// Outcome of a plugin tool call, copied from the guest's `call-tool` result.
#[derive(Debug, Clone)]
pub struct CallToolResult {
    /// Tool output as a JSON string. The format is plugin-defined and is not
    /// validated here.
    pub content_json: String,
    /// Guest-reported flag marking the output as an error result. The host
    /// call itself still returned `Ok`.
    pub is_error: bool,
}
708
709fn role_from_wit(role: wit_orchestration::RoleDef) -> RoleDef {
710 RoleDef {
711 name: role.name,
712 model: role.model,
713 prompt: role.prompt,
714 tools: role.tools,
715 thinking_level: role.thinking_level,
716 }
717}
718
719fn hook_from_wit(hook: wit_orchestration::HookDef) -> HookDef {
720 HookDef {
721 point: hook_point_from_wit(hook.point),
722 name: hook.name,
723 }
724}
725
726fn hook_point_from_wit(point: wit_orchestration::HookPoint) -> HookPoint {
727 match point {
728 wit_orchestration::HookPoint::BeforeAgentStart => HookPoint::BeforeAgentStart,
729 wit_orchestration::HookPoint::AfterToolCall => HookPoint::AfterToolCall,
730 wit_orchestration::HookPoint::PrepareNextTurn => HookPoint::PrepareNextTurn,
731 wit_orchestration::HookPoint::ShouldStopAfterTurn => HookPoint::ShouldStopAfterTurn,
732 }
733}
734
735fn point_to_wit(point: HookPoint) -> wit_orchestration::HookPoint {
736 match point {
737 HookPoint::BeforeAgentStart => wit_orchestration::HookPoint::BeforeAgentStart,
738 HookPoint::AfterToolCall => wit_orchestration::HookPoint::AfterToolCall,
739 HookPoint::PrepareNextTurn => wit_orchestration::HookPoint::PrepareNextTurn,
740 HookPoint::ShouldStopAfterTurn => wit_orchestration::HookPoint::ShouldStopAfterTurn,
741 }
742}
743
744fn action_from_wit(action: wit_orchestration::HostAction) -> HostAction {
745 match action {
746 wit_orchestration::HostAction::SwitchRole(role) => HostAction::SwitchRole(role),
747 wit_orchestration::HostAction::SwitchModel(model) => HostAction::SwitchModel(model),
748 wit_orchestration::HostAction::SetSystemPrompt(prompt) => {
749 HostAction::SetSystemPrompt(prompt)
750 }
751 wit_orchestration::HostAction::SetActiveTools(tools) => HostAction::SetActiveTools(tools),
752 wit_orchestration::HostAction::InjectUserMessage(message) => {
753 HostAction::InjectUserMessage(message)
754 }
755 wit_orchestration::HostAction::StopAfterTurn(stop) => HostAction::StopAfterTurn(stop),
756 wit_orchestration::HostAction::EmitCustom(value) => HostAction::EmitCustom(value),
757 wit_orchestration::HostAction::SetUiHeader(header) => HostAction::SetUiHeader(UiHeader {
758 provider: header.provider,
759 model: header.model,
760 }),
761 wit_orchestration::HostAction::SetUiStatus(status) => HostAction::SetUiStatus(status),
762 }
763}