use std::collections::HashSet;
use anyhow::{anyhow, bail, Context, Result};
use heck::{ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use wit_parser::{
Function, Handle, Interface, InterfaceId, Resolve, Type, TypeDefKind, TypeId, TypeOwner,
WorldId, WorldItem,
};
use super::bindings_index::{bindings_path_tokens, BindingsItem, BindingsPath, WrapperBindings};
/// One interface reachable from the world's imports or exports, paired with the
/// module path (under `bindings::`) computed for it by `module_path_for_interface`.
struct IfaceEntry {
/// Arena id of the interface inside the `Resolve`.
id: InterfaceId,
/// Module path segments; starts with `exports` when `is_export` is true.
path: BindingsPath,
/// `true` when the world exports this interface, `false` when it imports it.
is_export: bool,
}
pub fn build_ir(
resolve: &Resolve,
world_id: WorldId,
bindings: &WrapperBindings,
) -> Result<WrapperIR> {
let world = &resolve.worlds[world_id];
let mut ifaces: Vec<IfaceEntry> = Vec::new();
let mut seen_iface_ids: HashSet<InterfaceId> = HashSet::new();
let exports = world.exports.iter().map(|(_, item)| (item, true));
let imports = world.imports.iter().map(|(_, item)| (item, false));
for (item, is_export) in exports.chain(imports) {
let id = require_iface(item)?;
if !seen_iface_ids.insert(id) {
continue;
}
let path = module_path_for_interface(resolve, id, is_export).with_context(|| {
let side = if is_export { "exported" } else { "imported" };
format!("could not derive module path for {side} interface")
})?;
ifaces.push(IfaceEntry {
id,
path,
is_export,
});
}
let mut types: Vec<NamedType> = Vec::new();
let mut seen: HashSet<(BindingsPath, String)> = HashSet::new();
for entry in &ifaces {
let iface = &resolve.interfaces[entry.id];
for (wit_name, type_id) in &iface.types {
let nt = build_named_type(resolve, &ifaces, *type_id, wit_name, &entry.path, bindings)?;
if let Some(nt) = nt {
let key = (
match &nt.location {
TypeLocation::InBindings { path } => path.clone(),
TypeLocation::TopLevel => Vec::new(),
},
nt.rust_ident.to_string(),
);
if seen.insert(key) {
types.push(nt);
}
}
}
}
let mut args_records: Vec<NamedType> = Vec::new();
for entry in ifaces.iter().filter(|e| e.is_export) {
let iface = &resolve.interfaces[entry.id];
let iface_pascal = iface
.name
.as_ref()
.ok_or_else(|| anyhow!("exported interface has no name"))?
.to_upper_camel_case();
for (fn_name, func) in &iface.functions {
let args = synth_args_record(resolve, &iface_pascal, fn_name, func, &ifaces)?;
args_records.push(args);
}
}
Ok(WrapperIR {
types,
args_records,
})
}
/// Intermediate representation produced by `build_ir` for one world.
pub struct WrapperIR {
/// Named WIT records, variants, enums, and flags from the world's interfaces,
/// each located inside the wit-bindgen output (`TypeLocation::InBindings`).
pub types: Vec<NamedType>,
/// One synthesized record per function of every exported interface, holding the
/// function's parameters as fields; located at `TypeLocation::TopLevel`.
pub args_records: Vec<NamedType>,
}
/// A named Rust type the wrapper mirrors: where it lives, its name, and its shape.
pub struct NamedType {
/// Where the Rust type is defined (inside the bindings, or at the top level).
pub location: TypeLocation,
/// UpperCamelCase Rust name of the type.
pub rust_ident: syn::Ident,
/// Structural description (record fields, variant/enum cases, or flag members).
pub kind: NamedKind,
}
/// Where a [`NamedType`] is defined.
pub enum TypeLocation {
/// Inside the wit-bindgen output at `bindings::<path>::<rust_ident>`.
InBindings { path: BindingsPath },
/// Referenced by its bare identifier (no module path); used for the
/// synthesized args records.
TopLevel,
}
impl NamedType {
    /// Returns the tokens that name this type from generated code.
    ///
    /// Top-level types are referenced by their bare identifier. Types inside the
    /// bindings get their full module path via `bindings_path_tokens`.
    pub fn rust_path_tokens(&self) -> TokenStream {
        let Self {
            location,
            rust_ident,
            ..
        } = self;
        match location {
            TypeLocation::TopLevel => quote!(#rust_ident),
            TypeLocation::InBindings { path } => bindings_path_tokens(path, Some(rust_ident)),
        }
    }
}
/// Shape of a [`NamedType`], one variant per supported WIT type definition.
pub enum NamedKind {
/// WIT `record` (also used for synthesized args records): ordered fields.
Record { fields: Vec<RecordField> },
/// WIT `variant`: cases that may carry a payload.
Variant { cases: Vec<VariantCase> },
/// WIT `enum`: payload-free cases.
Enum { cases: Vec<EnumCase> },
/// WIT `flags`: members rendered as `bitflags!` constants.
Flags { members: Vec<FlagMember> },
}
/// One field of a record (or one parameter of a synthesized args record).
pub struct RecordField {
/// Field name as written in WIT (kebab-case).
pub wit_name: String,
/// snake_case Rust field name, with a trailing `_` if it would be a keyword.
pub rust_ident: syn::Ident,
/// Rust-side type of the field.
pub ty: WitTypeRef,
}
/// One case of a WIT variant.
pub struct VariantCase {
/// Case name as written in WIT.
pub wit_name: String,
/// UpperCamelCase Rust enum-variant name.
pub rust_ident: syn::Ident,
/// Payload type, or `None` for a unit case.
pub payload: Option<WitTypeRef>,
}
/// One case of a WIT enum.
pub struct EnumCase {
/// Case name as written in WIT.
pub wit_name: String,
/// UpperCamelCase Rust enum-variant name.
pub rust_ident: syn::Ident,
}
/// One member of a WIT flags type.
pub struct FlagMember {
/// Member name as written in WIT.
pub wit_name: String,
/// SHOUTY_SNAKE_CASE name of the corresponding `bitflags!` constant.
pub rust_ident: syn::Ident,
}
/// Reference to a named type (record, variant, enum, or flags) used as a field,
/// parameter, or payload type.
pub struct NamedRef {
/// Module path under `bindings::` of the interface owning the type (empty for
/// types with no owner).
pub path: BindingsPath,
/// UpperCamelCase Rust name of the referenced type.
pub rust_ident: syn::Ident,
}
/// A WIT type in field/parameter/payload position, already resolved through
/// aliases, ready to be rendered as Rust type tokens.
pub enum WitTypeRef {
/// A primitive (`bool`, integers, floats, `char`, `string`).
Primitive(Prim),
/// WIT `list<T>`, rendered as `Vec<T>`.
List(Box<WitTypeRef>),
/// WIT `option<T>`, rendered as `Option<T>`.
Option(Box<WitTypeRef>),
/// WIT `result<ok, err>`; an absent side is rendered as `()`.
Result {
ok: Option<Box<WitTypeRef>>,
err: Option<Box<WitTypeRef>>,
},
/// WIT `tuple<...>`, rendered as a Rust tuple.
Tuple(Vec<WitTypeRef>),
/// A named record/variant/enum/flags type living in the bindings.
Named(NamedRef),
}
impl WitTypeRef {
pub fn to_tokens(&self) -> TokenStream {
match self {
WitTypeRef::Primitive(p) => p.to_tokens(),
WitTypeRef::List(inner) => {
let t = inner.to_tokens();
quote!(::std::vec::Vec<#t>)
}
WitTypeRef::Option(inner) => {
let t = inner.to_tokens();
quote!(::core::option::Option<#t>)
}
WitTypeRef::Result { ok, err } => {
let ok_ty = match ok {
Some(t) => t.to_tokens(),
None => quote!(()),
};
let err_ty = match err {
Some(t) => t.to_tokens(),
None => quote!(()),
};
quote!(::core::result::Result<#ok_ty, #err_ty>)
}
WitTypeRef::Tuple(elems) => {
let ts: Vec<_> = elems.iter().map(|t| t.to_tokens()).collect();
if ts.len() == 1 {
let t = &ts[0];
quote!((#t,))
} else {
quote!((#(#ts),*))
}
}
WitTypeRef::Named(NamedRef { path, rust_ident }) => {
bindings_path_tokens(path, Some(rust_ident))
}
}
}
}
/// WIT primitive types that map directly onto a Rust primitive (or `String`).
///
/// `Debug`, `PartialEq`, `Eq`, and `Hash` are derived so values can be compared
/// with `assert_eq!`, printed in diagnostics, and used as set/map keys. The enum
/// is fieldless, so these derives cost nothing.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Prim {
    /// WIT `bool`.
    Bool,
    /// WIT `u8`.
    U8,
    /// WIT `u16`.
    U16,
    /// WIT `u32`.
    U32,
    /// WIT `u64`.
    U64,
    /// WIT `s8` (Rust `i8`).
    S8,
    /// WIT `s16` (Rust `i16`).
    S16,
    /// WIT `s32` (Rust `i32`).
    S32,
    /// WIT `s64` (Rust `i64`).
    S64,
    /// WIT `f32`.
    F32,
    /// WIT `f64`.
    F64,
    /// WIT `char`.
    Char,
    /// WIT `string` (Rust `String`).
    String,
}
impl Prim {
fn to_tokens(self) -> TokenStream {
match self {
Prim::Bool => quote!(bool),
Prim::U8 => quote!(u8),
Prim::U16 => quote!(u16),
Prim::U32 => quote!(u32),
Prim::U64 => quote!(u64),
Prim::S8 => quote!(i8),
Prim::S16 => quote!(i16),
Prim::S32 => quote!(i32),
Prim::S64 => quote!(i64),
Prim::F32 => quote!(f32),
Prim::F64 => quote!(f64),
Prim::Char => quote!(char),
Prim::String => quote!(::std::string::String),
}
}
}
fn build_named_type(
resolve: &Resolve,
ifaces: &[IfaceEntry],
type_id: TypeId,
wit_name: &str,
bindings_path: &BindingsPath,
bindings: &WrapperBindings,
) -> Result<Option<NamedType>> {
let td = &resolve.types[type_id];
let rust_ident_str = wit_name.to_upper_camel_case();
let rust_ident = syn::Ident::new(&rust_ident_str, Span::call_site());
let kind = match &td.kind {
TypeDefKind::Record(r) => {
let item = bindings.index.get(bindings_path, &rust_ident_str);
match item {
Some(BindingsItem::Struct) => {}
_ => bail!(
"wit-bindgen did not emit a struct for WIT record {wit_name:?} at \
bindings::{} (got {})",
bindings_path.join("::"),
describe_item(item),
),
};
let fields = record_fields_from(
resolve,
ifaces,
r.fields.iter().map(|f| (f.name.as_str(), f.ty)),
)?;
NamedKind::Record { fields }
}
TypeDefKind::Variant(v) => {
let item = bindings.index.get(bindings_path, &rust_ident_str);
match item {
Some(BindingsItem::Enum) => {}
_ => bail!(
"wit-bindgen did not emit an enum for WIT variant {wit_name:?} at \
bindings::{} (got {})",
bindings_path.join("::"),
describe_item(item),
),
};
let cases = v
.cases
.iter()
.map(|c| {
let pascal = c.name.to_upper_camel_case();
Ok(VariantCase {
wit_name: c.name.clone(),
rust_ident: syn::Ident::new(&pascal, Span::call_site()),
payload: c
.ty
.as_ref()
.map(|t| type_to_ref(resolve, ifaces, t))
.transpose()?,
})
})
.collect::<Result<_>>()?;
NamedKind::Variant { cases }
}
TypeDefKind::Enum(e) => {
let item = bindings.index.get(bindings_path, &rust_ident_str);
match item {
Some(BindingsItem::Enum) => {}
_ => bail!(
"wit-bindgen did not emit an enum for WIT enum {wit_name:?} at \
bindings::{} (got {})",
bindings_path.join("::"),
describe_item(item),
),
};
let cases = e
.cases
.iter()
.map(|c| EnumCase {
wit_name: c.name.clone(),
rust_ident: syn::Ident::new(&c.name.to_upper_camel_case(), Span::call_site()),
})
.collect();
NamedKind::Enum { cases }
}
TypeDefKind::Flags(f) => {
let item = bindings.index.get(bindings_path, &rust_ident_str);
match item {
Some(BindingsItem::BitflagsMacro) => {}
_ => bail!(
"wit-bindgen did not emit a bitflags! macro for WIT flags {wit_name:?} \
at bindings::{} (got {})",
bindings_path.join("::"),
describe_item(item),
),
};
let members = f
.flags
.iter()
.map(|flag| FlagMember {
wit_name: flag.name.clone(),
rust_ident: syn::Ident::new(
&flag.name.to_shouty_snake_case(),
Span::call_site(),
),
})
.collect();
NamedKind::Flags { members }
}
TypeDefKind::Type(_) => return Ok(None),
TypeDefKind::Resource | TypeDefKind::Handle(_) => {
bail!("resource and handle types are not supported (encountered {wit_name:?})")
}
TypeDefKind::Future(_) | TypeDefKind::Stream(_) => {
bail!("future and stream types are not supported (encountered {wit_name:?})")
}
TypeDefKind::Map(..) | TypeDefKind::FixedLengthList(..) => bail!(
"{} types are not supported (encountered {wit_name:?})",
td.kind.as_str()
),
TypeDefKind::List(_)
| TypeDefKind::Option(_)
| TypeDefKind::Result(_)
| TypeDefKind::Tuple(_) => bail!(
"top-level named {} types are not supported (encountered {wit_name:?})",
td.kind.as_str()
),
TypeDefKind::Unknown => bail!("unresolved WIT type {wit_name:?}"),
};
Ok(Some(NamedType {
location: TypeLocation::InBindings {
path: bindings_path.clone(),
},
rust_ident,
kind,
}))
}
/// Short label for a bindings-index lookup result, used in error messages.
///
/// A missing entry is reported as `"nothing"`. Otherwise the label is the
/// `BindingsItem` variant name.
fn describe_item(item: Option<&BindingsItem>) -> &'static str {
    item.map_or("nothing", |found| match found {
        BindingsItem::Struct => "Struct",
        BindingsItem::Enum => "Enum",
        BindingsItem::BitflagsMacro => "BitflagsMacro",
    })
}
fn type_to_ref(resolve: &Resolve, ifaces: &[IfaceEntry], ty: &Type) -> Result<WitTypeRef> {
Ok(match ty {
Type::Bool => WitTypeRef::Primitive(Prim::Bool),
Type::U8 => WitTypeRef::Primitive(Prim::U8),
Type::U16 => WitTypeRef::Primitive(Prim::U16),
Type::U32 => WitTypeRef::Primitive(Prim::U32),
Type::U64 => WitTypeRef::Primitive(Prim::U64),
Type::S8 => WitTypeRef::Primitive(Prim::S8),
Type::S16 => WitTypeRef::Primitive(Prim::S16),
Type::S32 => WitTypeRef::Primitive(Prim::S32),
Type::S64 => WitTypeRef::Primitive(Prim::S64),
Type::F32 => WitTypeRef::Primitive(Prim::F32),
Type::F64 => WitTypeRef::Primitive(Prim::F64),
Type::Char => WitTypeRef::Primitive(Prim::Char),
Type::String => WitTypeRef::Primitive(Prim::String),
Type::ErrorContext => bail!("error-context type not supported"),
Type::Id(id) => {
let td = &resolve.types[*id];
match &td.kind {
TypeDefKind::List(t) => {
WitTypeRef::List(Box::new(type_to_ref(resolve, ifaces, t)?))
}
TypeDefKind::Option(t) => {
WitTypeRef::Option(Box::new(type_to_ref(resolve, ifaces, t)?))
}
TypeDefKind::Result(r) => WitTypeRef::Result {
ok: r
.ok
.as_ref()
.map(|t| type_to_ref(resolve, ifaces, t).map(Box::new))
.transpose()?,
err: r
.err
.as_ref()
.map(|t| type_to_ref(resolve, ifaces, t).map(Box::new))
.transpose()?,
},
TypeDefKind::Tuple(t) => WitTypeRef::Tuple(
t.types
.iter()
.map(|ty| type_to_ref(resolve, ifaces, ty))
.collect::<Result<_>>()?,
),
TypeDefKind::Type(inner) => return type_to_ref(resolve, ifaces, inner),
TypeDefKind::Record(_)
| TypeDefKind::Variant(_)
| TypeDefKind::Enum(_)
| TypeDefKind::Flags(_) => {
let (path, rust_ident) = named_ref_for(resolve, ifaces, *id)?;
WitTypeRef::Named(NamedRef { path, rust_ident })
}
TypeDefKind::Resource
| TypeDefKind::Handle(Handle::Own(_))
| TypeDefKind::Handle(Handle::Borrow(_)) => {
bail!("resource/handle in field position not supported")
}
TypeDefKind::Future(_) | TypeDefKind::Stream(_) => {
bail!("future/stream in field position not supported")
}
TypeDefKind::Map(..) | TypeDefKind::FixedLengthList(..) => {
bail!("{} in field position not supported", td.kind.as_str())
}
TypeDefKind::Unknown => bail!("unresolved type in field position"),
}
}
})
}
/// Locates the named type `type_id` for use as a field/payload reference.
///
/// Returns the bindings module path of its owning interface (which must be one
/// of `ifaces`) together with its UpperCamelCase Rust name. A type with no owner
/// gets an empty path.
///
/// # Errors
///
/// Fails when the type has no name, is owned by a world, or is owned by an
/// interface that is not among `ifaces`.
fn named_ref_for(
    resolve: &Resolve,
    ifaces: &[IfaceEntry],
    type_id: TypeId,
) -> Result<(BindingsPath, syn::Ident)> {
    let typedef = &resolve.types[type_id];
    let name = match &typedef.name {
        Some(name) => name,
        None => bail!("named-WIT type at field position has no name"),
    };
    let path = match typedef.owner {
        TypeOwner::None => Vec::new(),
        TypeOwner::World(_) => bail!("world-owned types not supported"),
        TypeOwner::Interface(owner) => {
            let entry = ifaces.iter().find(|e| e.id == owner).ok_or_else(|| {
                anyhow!(
                    "type {name:?} is owned by an interface not reachable from the world's \
                     imports or exports"
                )
            })?;
            entry.path.clone()
        }
    };
    let rust_ident = syn::Ident::new(&name.to_upper_camel_case(), Span::call_site());
    Ok((path, rust_ident))
}
/// Returns the interface id behind a world import/export.
///
/// # Errors
///
/// Bare world-level functions and world-level types are rejected.
fn require_iface(item: &WorldItem) -> Result<InterfaceId> {
    let id = match item {
        WorldItem::Function(_) => {
            bail!("world-level functions (outside an interface) are not supported")
        }
        WorldItem::Type { .. } => bail!("world-level type aliases are not supported"),
        WorldItem::Interface { id, .. } => id,
    };
    Ok(*id)
}
fn module_path_for_interface(
resolve: &Resolve,
interface_id: InterfaceId,
is_export: bool,
) -> Result<BindingsPath> {
let iface: &Interface = &resolve.interfaces[interface_id];
let iface_name = iface
.name
.as_ref()
.ok_or_else(|| anyhow!("interface has no name"))?;
let pkg_id = iface
.package
.ok_or_else(|| anyhow!("interface {iface_name:?} has no package"))?;
let pkg = &resolve.packages[pkg_id];
let mut path = Vec::new();
if is_export {
path.push("exports".to_string());
}
path.push(pkg.name.namespace.to_snake_case());
path.push(pkg.name.name.to_snake_case());
path.push(iface_name.to_snake_case());
Ok(path)
}
/// Synthesizes the top-level args record for function `fn_name` of the exported
/// interface `iface_pascal`.
///
/// The record has one field per parameter, in declaration order, and is named
/// by `args_struct_ident`.
fn synth_args_record(
    resolve: &Resolve,
    iface_pascal: &str,
    fn_name: &str,
    func: &Function,
    ifaces: &[IfaceEntry],
) -> Result<NamedType> {
    let params = func.params.iter().map(|param| (param.name.as_str(), param.ty));
    Ok(NamedType {
        location: TypeLocation::TopLevel,
        rust_ident: args_struct_ident(iface_pascal, fn_name),
        kind: NamedKind::Record {
            fields: record_fields_from(resolve, ifaces, params)?,
        },
    })
}
/// Builds record fields from `(wit_name, type)` pairs, preserving their order.
///
/// Each field keeps its WIT name, gets a keyword-safe snake_case Rust ident, and
/// has its type converted with [`type_to_ref`]. The first conversion error
/// aborts the whole list.
fn record_fields_from<'a, I>(
    resolve: &Resolve,
    ifaces: &[IfaceEntry],
    pairs: I,
) -> Result<Vec<RecordField>>
where
    I: IntoIterator<Item = (&'a str, Type)>,
{
    let mut fields = Vec::new();
    for (name, wit_ty) in pairs {
        let rust_ident = mirror_field_ident(name);
        let ty = type_to_ref(resolve, ifaces, &wit_ty)?;
        fields.push(RecordField {
            wit_name: name.to_owned(),
            rust_ident,
            ty,
        });
    }
    Ok(fields)
}
/// Name of the synthesized args record: `<InterfacePascal><MethodPascal>Args`.
///
/// For example, interface `Ops` with method `add` gives `OpsAddArgs`.
pub(super) fn args_struct_ident(interface_pascal: &str, method_name: &str) -> syn::Ident {
    let mut name = String::from(interface_pascal);
    name.push_str(&method_name.to_upper_camel_case());
    name.push_str("Args");
    syn::Ident::new(&name, Span::call_site())
}
/// Rust field ident for a WIT field/parameter name.
///
/// The name is snake_cased, and a trailing `_` is appended when the result is a
/// Rust keyword (e.g. `loop` becomes `loop_`).
fn mirror_field_ident(wit_name: &str) -> syn::Ident {
    let mut name = wit_name.to_snake_case();
    if is_rust_keyword(&name) {
        name.push('_');
    }
    syn::Ident::new(&name, Span::call_site())
}
/// Returns `true` when syn refuses `s` as an identifier.
///
/// This covers Rust keywords such as `loop`, `type`, and `self`. It is also
/// `true` for any string that is not a well-formed identifier at all (e.g. the
/// empty string).
pub(super) fn is_rust_keyword(s: &str) -> bool {
    let parsed: syn::Result<syn::Ident> = syn::parse_str(s);
    parsed.is_err()
}
// End-to-end tests: each runs wit-bindgen on a small WIT world, indexes the
// generated bindings, and inspects the IR that `build_ir` produces.
#[cfg(test)]
mod tests {
use super::*;
use crate::adapter::typed::bindgen::run_wit_bindgen_rust;
use crate::adapter::typed::bindings_index::build_bindings_index;
// Runs wit-bindgen on `wit` for world `world`, indexes its output, and builds
// the IR; panics on any failure along the way.
fn build(wit: &str, world: &str) -> WrapperIR {
let (resolve, world_id, src) = run_wit_bindgen_rust(wit, Some(world)).unwrap();
let bindings = build_bindings_index(&src).unwrap();
build_ir(&resolve, world_id, &bindings).unwrap()
}
// Record fields keep WIT names and resolve primitive and list/option types.
#[test]
fn record_field_types_capture_primitives_and_compositions() {
let wit = r#"
package test:pkg@0.1.0;
interface ops {
record point { x: u32, y: u32, label: string }
record blob { bytes: list<u8>, tag: option<string> }
use-p: func(p: point) -> blob;
}
world w { export ops; }
"#;
let ir = build(wit, "w");
let point = ir
.types
.iter()
.find(|t| t.rust_ident == "Point")
.expect("Point in IR");
match &point.kind {
NamedKind::Record { fields } => {
assert_eq!(fields.len(), 3);
assert_eq!(fields[0].wit_name, "x");
assert!(matches!(fields[0].ty, WitTypeRef::Primitive(Prim::U32)));
assert_eq!(fields[2].wit_name, "label");
assert!(matches!(fields[2].ty, WitTypeRef::Primitive(Prim::String)));
}
_ => panic!("expected Record"),
}
let blob = ir
.types
.iter()
.find(|t| t.rust_ident == "Blob")
.expect("Blob in IR");
match &blob.kind {
NamedKind::Record { fields } => {
assert_eq!(fields[0].wit_name, "bytes");
assert!(
matches!(&fields[0].ty, WitTypeRef::List(inner)
if matches!(**inner, WitTypeRef::Primitive(Prim::U8))),
"expected list<u8>"
);
assert_eq!(fields[1].wit_name, "tag");
assert!(
matches!(&fields[1].ty, WitTypeRef::Option(inner)
if matches!(**inner, WitTypeRef::Primitive(Prim::String))),
"expected option<string>"
);
}
_ => panic!("expected Record"),
}
}
// Flags members keep kebab-case WIT names and map to SHOUTY_SNAKE_CASE idents.
#[test]
fn flags_kind_recovers_kebab_and_shouting_snake_idents() {
let wit = r#"
package test:pkg@0.1.0;
interface ops {
flags perms { read, write, exec-x }
check: func(p: perms);
}
world w { export ops; }
"#;
let ir = build(wit, "w");
let perms = ir
.types
.iter()
.find(|t| t.rust_ident == "Perms")
.expect("Perms in IR");
match &perms.kind {
NamedKind::Flags { members } => {
assert_eq!(members.len(), 3);
assert_eq!(members[0].wit_name, "read");
assert_eq!(members[0].rust_ident, "READ");
assert_eq!(members[2].wit_name, "exec-x");
assert_eq!(members[2].rust_ident, "EXEC_X");
}
_ => panic!("expected Flags"),
}
}
// Variant cases distinguish unit cases from payload-carrying ones.
#[test]
fn variant_kind_captures_unit_and_payload_cases() {
let wit = r#"
package test:pkg@0.1.0;
interface ops {
variant outcome { miss, hit(u32), report(string) }
go: func() -> outcome;
}
world w { export ops; }
"#;
let ir = build(wit, "w");
let outcome = ir
.types
.iter()
.find(|t| t.rust_ident == "Outcome")
.expect("Outcome in IR");
match &outcome.kind {
NamedKind::Variant { cases } => {
assert_eq!(cases.len(), 3);
assert!(cases[0].payload.is_none());
assert!(
matches!(cases[1].payload, Some(WitTypeRef::Primitive(Prim::U32))),
"hit should carry u32"
);
assert_eq!(cases[2].rust_ident, "Report");
}
_ => panic!("expected Variant"),
}
}
// Enum cases become UpperCamelCase Rust idents.
#[test]
fn enum_kind_collects_unit_cases() {
let wit = r#"
package test:pkg@0.1.0;
interface ops {
enum color { red, green, blue }
tag: func(c: color);
}
world w { export ops; }
"#;
let ir = build(wit, "w");
let color = ir
.types
.iter()
.find(|t| t.rust_ident == "Color")
.expect("Color in IR");
match &color.kind {
NamedKind::Enum { cases } => {
assert_eq!(cases.len(), 3);
assert_eq!(cases[0].wit_name, "red");
assert_eq!(cases[0].rust_ident, "Red");
}
_ => panic!("expected Enum"),
}
}
// Every exported function gets an `<Iface><Fn>Args` record, even with no params.
#[test]
fn args_records_synthesized_per_method() {
let wit = r#"
package test:pkg@0.1.0;
interface ops {
add: func(a: u32, b: u32) -> u32;
noop: func();
}
world w { export ops; }
"#;
let ir = build(wit, "w");
assert_eq!(ir.args_records.len(), 2);
let add_args = ir
.args_records
.iter()
.find(|t| t.rust_ident == "OpsAddArgs")
.expect("OpsAddArgs in args_records");
match &add_args.kind {
NamedKind::Record { fields } => assert_eq!(fields.len(), 2),
_ => panic!("args should be Records"),
}
let noop_args = ir
.args_records
.iter()
.find(|t| t.rust_ident == "OpsNoopArgs")
.expect("OpsNoopArgs in args_records");
match &noop_args.kind {
NamedKind::Record { fields } => assert_eq!(fields.len(), 0),
_ => panic!("args should be Records"),
}
}
// WIT field names that are Rust keywords get a trailing `_` on the Rust side.
#[test]
fn keyword_field_idents_get_trailing_underscore() {
let wit = r#"
package test:pkg@0.1.0;
interface ops {
record %type {
%loop: u32,
%match: u32,
}
use-t: func(t: %type);
}
world w { export ops; }
"#;
let ir = build(wit, "w");
let t = ir.types.iter().find(|t| t.rust_ident == "Type").unwrap();
match &t.kind {
NamedKind::Record { fields } => {
let names: Vec<String> = fields.iter().map(|f| f.rust_ident.to_string()).collect();
assert!(names.contains(&"loop_".to_string()));
assert!(names.contains(&"match_".to_string()));
}
_ => panic!("expected Record"),
}
}
// A resource in an interface makes `build_ir` fail with a message naming it.
#[test]
fn resources_are_rejected_loudly() {
let wit = r#"
package test:pkg@0.1.0;
interface ops {
resource thing { }
make: func() -> thing;
}
world w { export ops; }
"#;
let (resolve, world_id, src) = run_wit_bindgen_rust(wit, Some("w")).unwrap();
let bindings = build_bindings_index(&src).unwrap();
let err = match build_ir(&resolve, world_id, &bindings) {
Ok(_) => panic!("expected resource rejection"),
Err(e) => e,
};
let msg = format!("{err:#}");
assert!(
msg.contains("resource"),
"expected resource rejection; got: {msg}"
);
}
// Type aliases produce no IR type; uses resolve through to the aliased type.
#[test]
fn type_alias_is_transparent() {
let wit = r#"
package test:pkg@0.1.0;
interface ops {
type %id = u32;
fetch: func(i: %id) -> %id;
}
world w { export ops; }
"#;
let ir = build(wit, "w");
assert!(!ir.types.iter().any(|t| t.rust_ident == "Id"));
let args = ir
.args_records
.iter()
.find(|t| t.rust_ident == "OpsFetchArgs")
.unwrap();
match &args.kind {
NamedKind::Record { fields } => {
assert_eq!(fields.len(), 1);
assert!(matches!(fields[0].ty, WitTypeRef::Primitive(Prim::U32)));
}
_ => panic!(),
}
}
}