use crate::model::agent::{AgentError, AgentType};
use anyhow::anyhow;
use rib::ParsedFunctionName;
use std::path::Path;
use std::sync::{Arc, Mutex};
use tracing::{debug, error};
use wasmtime::component::types::{ComponentInstance, ComponentItem};
use wasmtime::component::{
Component, Func, Instance, Linker, LinkerInstance, ResourceTable, ResourceType, Type,
};
use wasmtime::{AsContextMut, Engine, Store};
use wasmtime_wasi::p2::{pipe, StdoutStream, WasiCtx, WasiView};
use wasmtime_wasi::{IoCtx, IoView};
use wit_parser::{PackageId, Resolve, WorldItem};
const INTERFACE_NAME: &str = "golem:agent/guest";
const FUNCTION_NAME: &str = "discover-agent-types";
pub async fn extract_agent_types_with_streams(
wasm_path: &Path,
stdout: Option<impl StdoutStream + 'static>,
stderr: Option<impl StdoutStream + 'static>,
fail_on_missing_discover_method: bool,
) -> anyhow::Result<Vec<AgentType>> {
let mut config = wasmtime::Config::default();
config.async_support(true);
config.wasm_component_model(true);
let engine = Engine::new(&config)?;
let mut linker: Linker<Host> = Linker::new(&engine);
linker.allow_shadowing(true);
wasmtime_wasi::p2::add_to_linker_with_options_async(
&mut linker,
&wasmtime_wasi::p2::bindings::LinkOptions::default(),
)?;
let mut builder = WasiCtx::builder();
if let Some(stdout) = stdout {
builder.stdout(stdout);
} else {
builder.inherit_stdout();
}
if let Some(stderr) = stderr {
builder.stderr(stderr);
} else {
builder.inherit_stderr();
}
let (wasi, io) = builder.env("RUST_BACKTRACE", "1").build();
let host = Host {
table: Arc::new(Mutex::new(ResourceTable::new())),
wasi: Arc::new(Mutex::new(wasi)),
io: Arc::new(Mutex::new(io)),
};
let component = Component::from_file(&engine, wasm_path)?;
let mut store = Store::new(&engine, host);
let mut linker_instance = linker.root();
let component_type = component.component_type();
for (name, item) in component_type.imports(&engine) {
let name = name.to_string();
match item {
ComponentItem::ComponentFunc(_) => {}
ComponentItem::CoreFunc(_) => {}
ComponentItem::Module(_) => {}
ComponentItem::Component(_) => {}
ComponentItem::ComponentInstance(ref inst) => {
dynamic_import(&name, &engine, &mut linker_instance, inst)?;
}
ComponentItem::Type(_) => {}
ComponentItem::Resource(_) => {}
}
}
debug!("Instantiating component");
let instance = linker.instantiate_async(&mut store, &component).await?;
let func = if let Some(func) = find_discover_function(&mut store, &instance) {
func
} else if fail_on_missing_discover_method {
return Err(anyhow!(
"Function {FUNCTION_NAME} not found in interface {INTERFACE_NAME}"
));
} else {
return Ok(Vec::new());
};
let typed_func = func.typed::<(), (
Result<
Vec<crate::model::agent::bindings::golem::agent::common::AgentType>,
crate::model::agent::bindings::golem::agent::common::AgentError,
>,
)>(&mut store)?;
let results = typed_func.call_async(&mut store, ()).await?;
typed_func.post_return_async(&mut store).await?;
match results.0 {
Ok(results) => {
let agent_types = results.into_iter().map(AgentType::from).collect();
debug!("Discovered agent types: {:#?}", agent_types);
Ok(agent_types)
}
Err(agent_error) => {
let agent_error: AgentError = agent_error.into();
error!("Error while discovering agent types: {agent_error}");
Err(anyhow!(agent_error.to_string()))
}
}
}
pub async fn extract_agent_types(
wasm_path: &Path,
fail_on_missing_discover_method: bool,
) -> anyhow::Result<Vec<AgentType>> {
extract_agent_types_with_streams(
wasm_path,
None::<pipe::MemoryOutputPipe>,
None::<pipe::MemoryOutputPipe>,
fail_on_missing_discover_method,
)
.await
}
pub fn is_agent(
resolve: &Resolve,
root_package_id: &PackageId,
world: Option<&str>,
) -> anyhow::Result<bool> {
let golem_agent_package = wit_parser::PackageName {
namespace: "golem".to_string(),
name: "agent".to_string(),
version: None,
};
const GOLEM_AGENT_INTERFACE_NAME: &str = "guest";
let world_id = resolve.select_world(*root_package_id, world)?;
let world = resolve
.worlds
.get(world_id)
.ok_or_else(|| anyhow!("Could not get {world_id:?}"))?;
let world_name = &world.name;
for (key, item) in &world.exports {
if let WorldItem::Interface { id, .. } = &item {
let interface = resolve.interfaces.get(*id).ok_or_else(|| {
anyhow!("Could not get exported interface {key:?} exported from world {world_name}")
})?;
if let Some(interface_name) = interface.name.as_ref() {
if interface_name == GOLEM_AGENT_INTERFACE_NAME {
if let Some(package_id) = &interface.package {
let package = resolve.packages.get(*package_id).ok_or_else(|| {
anyhow!(
"Could not get owner package of exported interface {interface_name}"
)
})?;
if package.name == golem_agent_package {
return Ok(true);
}
}
}
}
}
}
Ok(false)
}
fn find_discover_function(mut store: impl AsContextMut, instance: &Instance) -> Option<Func> {
let (_, exported_instance_id) = instance.get_export(&mut store, None, INTERFACE_NAME)?;
let (_, func_id) =
instance.get_export(&mut store, Some(&exported_instance_id), FUNCTION_NAME)?;
let func = instance.get_func(&mut store, func_id)?;
Some(func)
}
#[derive(Clone)]
struct Host {
pub table: Arc<Mutex<ResourceTable>>,
pub wasi: Arc<Mutex<WasiCtx>>,
pub io: Arc<Mutex<IoCtx>>,
}
impl IoView for Host {
fn table(&mut self) -> &mut ResourceTable {
Arc::get_mut(&mut self.table)
.expect("ResourceTable is shared and cannot be borrowed mutably")
.get_mut()
.expect("ResourceTable mutex must never fail")
}
fn io_ctx(&mut self) -> &mut IoCtx {
Arc::get_mut(&mut self.io)
.expect("IoCtx is shared and cannot be borrowed mutably")
.get_mut()
.expect("IoCtx mutex must never fail")
}
}
impl WasiView for Host {
fn ctx(&mut self) -> &mut WasiCtx {
Arc::get_mut(&mut self.wasi)
.expect("WasiCtx is shared and cannot be borrowed mutably")
.get_mut()
.expect("WasiCtx mutex must never fail")
}
}
fn dynamic_import(
name: &str,
engine: &Engine,
root: &mut LinkerInstance<Host>,
inst: &ComponentInstance,
) -> anyhow::Result<()> {
if name.starts_with("wasi:cli")
|| name.starts_with("wasi:clocks")
|| name.starts_with("wasi:filesystem")
|| name.starts_with("wasi:io")
|| name.starts_with("wasi:random")
|| name.starts_with("wasi:sockets")
{
Ok(())
} else {
let mut instance = root.instance(name)?;
let mut functions = Vec::new();
for (inner_name, inner_item) in inst.exports(engine) {
let name = name.to_owned();
let inner_name = inner_name.to_owned();
match inner_item {
ComponentItem::ComponentFunc(fun) => {
let param_types: Vec<Type> = fun.params().map(|(_, t)| t).collect();
let result_types: Vec<Type> = fun.results().collect();
let function_name = ParsedFunctionName::parse(format!(
"{name}.{{{inner_name}}}"
))
.map_err(|err| anyhow!(format!("Unexpected linking error: {name}.{{{inner_name}}} is not a valid function name: {err}")))?;
functions.push(FunctionInfo {
name: function_name,
params: param_types,
results: result_types,
});
}
ComponentItem::CoreFunc(_) => {}
ComponentItem::Module(_) => {}
ComponentItem::Component(_) => {}
ComponentItem::ComponentInstance(_) => {}
ComponentItem::Type(_) => {}
ComponentItem::Resource(_resource) => {
if &inner_name != "pollable"
&& inner_name != "wasi-io-pollable"
&& &inner_name != "input-stream"
&& &inner_name != "output-stream"
&& &inner_name != "incoming-value-async-body"
&& &inner_name != "outgoing-value-body-async"
{
instance.resource(
&inner_name,
ResourceType::host::<ResourceEntry>(),
|_store, _rep| Ok(()),
)?;
}
}
}
}
for function in functions {
instance.func_new_async(
&function.name.function.function_name(),
move |_store, _params, _results| {
let function_name = function.name.clone();
Box::new(async move {
error!(
"External function called in get-agent-definitions: {function_name}",
);
Err(anyhow!(
"External function called in get-agent-definitions: {function_name}"
))
})
},
)?;
}
Ok(())
}
}
#[allow(unused)]
struct MethodInfo {
method_name: String,
params: Vec<Type>,
results: Vec<Type>,
}
#[allow(unused)]
struct FunctionInfo {
name: ParsedFunctionName,
params: Vec<Type>,
results: Vec<Type>,
}
struct ResourceEntry;
#[cfg(test)]
mod tests {
use super::*;
use assert2::assert;
use std::path::PathBuf;
use std::str::FromStr;
use test_r::test;
#[test]
async fn can_extract_agent_types_from_component_with_dynamic_rpc() -> anyhow::Result<()> {
let result =
extract_agent_types(&PathBuf::from_str("../test-components/caller.wasm")?, false).await;
assert!(let Ok(_) = result);
Ok(())
}
}