use heck::*;
use std::io::{Read, Write};
use std::path::Path;
use std::process::{Command, Stdio};
use witx::*;
/// Generates Rust bindings for the witx documents at `witx_paths`.
///
/// The output starts with a fixed header (imports and a `Result` alias),
/// followed by every named type and then every module, and is finally
/// passed through `rustfmt`.
///
/// # Panics
///
/// Panics if the witx documents cannot be loaded, if `rustfmt` cannot be
/// spawned, if piping to/from it fails, or if it exits unsuccessfully.
pub fn generate<P: AsRef<Path>>(witx_paths: &[P]) -> String {
    let doc = witx::load(witx_paths).expect("failed to load witx documents");
    let mut raw = String::new();
    raw.push_str(
        "\
// This file is automatically generated, DO NOT EDIT
//
// To regenerate this file run the `crates/witx-bindgen` command
use core::mem::MaybeUninit;
pub use crate::error::Error;
pub type Result<T, E = Error> = core::result::Result<T, E>;
",
    );
    for ty in doc.typenames() {
        ty.render(&mut raw);
        raw.push('\n');
    }
    for m in doc.modules() {
        m.render(&mut raw);
        raw.push('\n');
    }
    run_rustfmt(&raw)
}

/// Pipes `src` through an external `rustfmt` process and returns its
/// formatted stdout.
///
/// # Panics
///
/// Panics with a descriptive message if `rustfmt` cannot be spawned, if
/// writing/reading its pipes fails (including non-UTF-8 output), or if it
/// exits with a non-success status.
fn run_rustfmt(src: &str) -> String {
    let mut child = Command::new("rustfmt")
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .spawn()
        .expect("failed to spawn `rustfmt`; is it installed and on PATH?");
    {
        // Scoped so the stdin handle is dropped here, closing the pipe and
        // signalling EOF to rustfmt before we start reading its output.
        let mut stdin = child.stdin.take().expect("rustfmt stdin was not piped");
        stdin
            .write_all(src.as_bytes())
            .expect("failed to write generated source to rustfmt");
    }
    let mut formatted = String::new();
    child
        .stdout
        .take()
        .expect("rustfmt stdout was not piped")
        .read_to_string(&mut formatted)
        .expect("failed to read rustfmt output (or it was not UTF-8)");
    let status = child.wait().expect("failed to wait for rustfmt");
    assert!(status.success(), "rustfmt exited unsuccessfully: {}", status);
    formatted
}
/// Something that can append its generated Rust source text to a buffer.
///
/// Implementations only ever push onto the end of `dst`; they never inspect
/// or rewrite what is already there.
trait Render {
    /// Appends this item's Rust rendering to `dst`.
    fn render(&self, dst: &mut String);
}
impl Render for NamedType {
    /// Renders a top-level type definition, dispatching on the kind of type
    /// it names. References to other named types, and anonymous
    /// arrays/pointers/builtins, all become plain `pub type` aliases.
    fn render(&self, src: &mut String) {
        let ident = self.name.as_str();
        let ty = match &self.tref {
            TypeRef::Name(_) => return render_alias(src, ident, &self.tref),
            TypeRef::Value(ty) => ty,
        };
        match &**ty {
            Type::Enum(e) => render_enum(src, ident, e),
            Type::Flags(f) => render_flags(src, ident, f),
            Type::Int(c) => render_const(src, ident, c),
            Type::Struct(s) => render_struct(src, ident, s),
            Type::Union(u) => render_union(src, ident, u),
            Type::Handle(h) => render_handle(src, ident, h),
            Type::Array { .. }
            | Type::Pointer { .. }
            | Type::ConstPointer { .. }
            | Type::Builtin { .. } => render_alias(src, ident, &self.tref),
        }
    }
}
/// Renders an int datatype as a `pub type` alias of its repr, followed by
/// one documented `pub const {NAME}_{CONST}` per declared constant value.
fn render_const(src: &mut String, name: &str, c: &IntDatatype) {
    let type_ident = name.to_camel_case();
    let prefix = name.to_shouty_snake_case();
    src.push_str("pub type ");
    src.push_str(&type_ident);
    src.push_str(" = ");
    c.repr.render(src);
    src.push_str(";\n");
    for constant in c.consts.iter() {
        rustdoc(&constant.docs, src);
        let const_ident = constant.name.as_str().to_shouty_snake_case();
        src.push_str(&format!(
            "pub const {}_{}: {} = {};",
            prefix, const_ident, type_ident, constant.value
        ));
    }
}
fn render_union(src: &mut String, name: &str, u: &UnionDatatype) {
src.push_str("#[repr(C)]\n");
src.push_str("#[derive(Copy, Clone)]\n");
src.push_str(&format!("pub union {}U {{\n", name.to_camel_case()));
for variant in u.variants.iter() {
if let Some(ref tref) = variant.tref {
rustdoc(&variant.docs, src);
src.push_str("pub ");
variant.name.render(src);
src.push_str(": ");
tref.render(src);
src.push_str(",\n");
}
}
src.push_str("}\n");
src.push_str("#[repr(C)]\n");
src.push_str("#[derive(Copy, Clone)]\n");
src.push_str(&format!("pub struct {} {{\n", name.to_camel_case()));
src.push_str(&format!(
"pub tag: {},\n",
u.tag.name.as_str().to_camel_case()
));
src.push_str(&format!("pub u: {}U,\n", name.to_camel_case()));
src.push_str("}\n");
}
/// Renders a `#[repr(C)]` struct with one documented `pub` field per member.
///
/// Generated unions only derive `Copy, Clone`, so `Debug` is derived only
/// when no member (transitively) contains a union.
fn render_struct(src: &mut String, name: &str, s: &StructDatatype) {
    let derives = if struct_contains_union(s) {
        "Copy, Clone"
    } else {
        "Copy, Clone, Debug"
    };
    src.push_str("#[repr(C)]\n");
    src.push_str(&format!("#[derive({})]\n", derives));
    src.push_str(&format!("pub struct {} {{\n", name.to_camel_case()));
    for field in s.members.iter() {
        rustdoc(&field.docs, src);
        src.push_str("pub ");
        field.name.render(src);
        src.push_str(": ");
        field.tref.render(src);
        src.push_str(",\n");
    }
    src.push('}');
}
fn render_flags(src: &mut String, name: &str, f: &FlagsDatatype) {
src.push_str(&format!("pub type {} = ", name.to_camel_case()));
f.repr.render(src);
src.push_str(";\n");
for (i, variant) in f.flags.iter().enumerate() {
rustdoc(&variant.docs, src);
src.push_str(&format!(
"pub const {}_{}: {} = 0x{:x};",
name.to_shouty_snake_case(),
variant.name.as_str().to_shouty_snake_case(),
name.to_camel_case(),
1 << i
));
}
}
/// Renders an enum as a `pub type` alias of its integer repr, followed by one
/// documented `pub const {NAME}_{VARIANT}` per variant, numbered from 0 in
/// declaration order.
///
/// For the type named `errno`, it also emits a crate-private
/// `strerror(code) -> &'static str`. That function maps each constant to the
/// variant's trimmed doc text and falls back to `"Unknown error."`.
fn render_enum(src: &mut String, name: &str, e: &EnumDatatype) {
    let type_name = name.to_camel_case();
    let const_prefix = name.to_shouty_snake_case();

    src.push_str(&format!("pub type {} = ", type_name));
    e.repr.render(src);
    src.push_str(";\n");
    for (i, variant) in e.variants.iter().enumerate() {
        rustdoc(&variant.docs, src);
        src.push_str(&format!(
            "pub const {}_{}: {} = {};",
            const_prefix,
            variant.name.as_str().to_shouty_snake_case(),
            type_name,
            i
        ));
    }

    if name == "errno" {
        // Type `code` with the enum's own repr rather than a hard-coded
        // integer type, so the constant patterns below always match it.
        let mut repr = String::new();
        e.repr.render(&mut repr);
        src.push_str(&format!(
            "pub(crate) fn strerror(code: {}) -> &'static str {{",
            repr
        ));
        src.push_str("match code {");
        for variant in e.variants.iter() {
            // `{:?}` emits a correctly escaped Rust string literal, so quotes
            // or backslashes in the witx docs cannot corrupt the output.
            src.push_str(&format!(
                "{}_{} => {:?},",
                const_prefix,
                variant.name.as_str().to_shouty_snake_case(),
                variant.docs.trim()
            ));
        }
        src.push_str("_ => \"Unknown error.\",");
        src.push_str("}");
        src.push_str("}");
    }
}
impl Render for IntRepr {
    /// Renders the unsigned Rust primitive of matching width.
    fn render(&self, src: &mut String) {
        let prim = match self {
            IntRepr::U8 => "u8",
            IntRepr::U16 => "u16",
            IntRepr::U32 => "u32",
            IntRepr::U64 => "u64",
        };
        src.push_str(prim);
    }
}
/// Renders `pub type {Name} = {dest};`.
///
/// Aliases of pointer/length-pair types (slices, strings) borrow data, so
/// they get an `'a` lifetime parameter. The type named `size` is always
/// aliased to `usize`, whatever `dest` is.
fn render_alias(src: &mut String, name: &str, dest: &TypeRef) {
    let generics = if dest.type_().passed_by() == TypePassedBy::PointerLengthPair {
        "<'a>"
    } else {
        ""
    };
    src.push_str(&format!("pub type {}{} = ", name.to_camel_case(), generics));
    match name {
        "size" => src.push_str("usize"),
        _ => dest.render(src),
    }
    src.push(';');
}
impl Render for TypeRef {
    /// Renders a type reference in type position.
    ///
    /// Named types become their CamelCase alias, with an elided `'_` lifetime
    /// when they borrow a pointer/length pair. Anonymous builtins, arrays
    /// (`&'a [T]`) and raw pointers are spelled out inline.
    ///
    /// # Panics
    ///
    /// Panics on any other anonymous type, which cannot be referenced inline.
    fn render(&self, src: &mut String) {
        let value = match self {
            TypeRef::Name(named) => {
                src.push_str(&named.name.as_str().to_camel_case());
                if named.type_().passed_by() == TypePassedBy::PointerLengthPair {
                    src.push_str("<'_>");
                }
                return;
            }
            TypeRef::Value(value) => value,
        };
        match &**value {
            Type::Builtin(builtin) => builtin.render(src),
            Type::Array(elem) => {
                src.push_str("&'a [");
                elem.render(src);
                src.push(']');
            }
            Type::Pointer(pointee) => {
                src.push_str("*mut ");
                pointee.render(src);
            }
            Type::ConstPointer(pointee) => {
                src.push_str("*const ");
                pointee.render(src);
            }
            other => panic!("reference to anonymous {} not possible!", other.kind()),
        }
    }
}
impl Render for BuiltinType {
    /// Renders the Rust spelling of a witx builtin. Strings borrow as `&str`;
    /// 8-bit chars are emitted as raw `u8` bytes.
    fn render(&self, src: &mut String) {
        let rust_ty = match self {
            BuiltinType::String => "&str",
            BuiltinType::U8 | BuiltinType::Char8 => "u8",
            BuiltinType::U16 => "u16",
            BuiltinType::U32 => "u32",
            BuiltinType::U64 => "u64",
            BuiltinType::S8 => "i8",
            BuiltinType::S16 => "i16",
            BuiltinType::S32 => "i32",
            BuiltinType::S64 => "i64",
            BuiltinType::F32 => "f32",
            BuiltinType::F64 => "f64",
            BuiltinType::USize => "usize",
        };
        src.push_str(rust_ty);
    }
}
impl Render for Module {
    /// Renders a witx module in two parts:
    /// 1. one `unsafe` wrapper per function at the current level (via
    ///    `render_highlevel`);
    /// 2. a nested `pub mod {snake_name}` holding the raw `extern "C"`
    ///    declarations, linked against the wasm import module of the same
    ///    (unconverted) name.
    fn render(&self, src: &mut String) {
        let wasm_module = self.name.as_str();
        let rust_name = wasm_module.to_snake_case();

        for func in self.funcs() {
            render_highlevel(&func, &rust_name, src);
            src.push_str("\n\n");
        }

        src.push_str(&format!("pub mod {}{{\nuse super::*;", rust_name));
        src.push_str(&format!("#[link(wasm_import_module =\"{}\")]\n", wasm_module));
        src.push_str("extern \"C\" {\n");
        for func in self.funcs() {
            func.render(src);
            src.push('\n');
        }
        // Close the `extern` block, then the `mod`.
        src.push_str("}}");
    }
}
/// Emits a `pub unsafe fn` wrapper around the raw import `func`.
///
/// The wrapper takes the witx parameters as Rust values and forwards them to
/// the raw declaration `{module}::{name}`, which is emitted separately by
/// `Render for InterfaceFunc`. It then converts the outcome into a `Result`:
///
/// * The first result, if present, must be named `error` (asserted below).
///   The raw call's return value is bound to `rc`, and the wrapper returns
///   `Err(err)` when `Error::from_raw_error(rc)` yields `Some(err)`.
/// * Every further result is passed to the raw call as an out-pointer into a
///   `MaybeUninit` local, and is `assume_init`-ed on success.
/// * Exactly one extra result is returned bare (`Result<T>`); zero or several
///   extra results are returned as a tuple (`Result<()>`, `Result<(A, B,)>`).
/// * A function with no results at all gets no return type.
///
/// When this generator is compiled with the `multi-module` feature, the
/// wrapper name is prefixed with `{module}_` (presumably so that functions
/// from different modules cannot collide).
///
/// # Panics
///
/// Panics if the first result is not named `error`, or if any parameter is
/// passed by `TypePassedBy::Pointer`, which this wrapper cannot translate.
fn render_highlevel(func: &InterfaceFunc, module: &str, src: &mut String) {
    // Snake-cased Rust identifier of the function (keywords already escaped
    // by `Id::render`).
    let mut rust_name = String::new();
    func.name.render(&mut rust_name);
    let rust_name = rust_name.to_snake_case();
    rustdoc(&func.docs, src);
    rustdoc_params(&func.params, "Parameters", src);
    rustdoc_params(&func.results, "Return", src);
    src.push_str("pub unsafe fn ");
    // Resolved when the generator itself is compiled, not in generated code.
    cfg_if::cfg_if! {
        if #[cfg(feature = "multi-module")] {
            src.push_str(&[module, &rust_name].join("_"));
        } else {
            src.push_str(&rust_name);
        }
    }
    src.push_str("(");
    for param in func.params.iter() {
        param.name.render(src);
        src.push_str(": ");
        param.tref.render(src);
        src.push_str(",");
    }
    src.push_str(")");
    // Return type: `Result<T>` for a single extra result, otherwise a tuple
    // (including `()` when only the `error` result exists).
    if let Some(first) = func.results.get(0) {
        assert_eq!(first.name.as_str(), "error");
        src.push_str(" -> Result<");
        if func.results.len() != 2 {
            src.push_str("(");
        }
        for result in func.results.iter().skip(1) {
            result.tref.render(src);
            src.push_str(",");
        }
        if func.results.len() != 2 {
            src.push_str(")");
        }
        src.push_str(">");
    }
    src.push_str("{");
    // One uninitialized out-slot per non-error result.
    for result in func.results.iter().skip(1) {
        src.push_str("let mut ");
        result.name.render(src);
        src.push_str(" = MaybeUninit::uninit();");
    }
    if func.results.len() > 0 {
        src.push_str("let rc = ");
    }
    src.push_str(module);
    src.push_str("::");
    src.push_str(&rust_name);
    src.push_str("(");
    // Forward parameters; slices/strings are split into `(ptr, len)` to
    // match the raw declaration's `*_ptr` / `*_len` parameters.
    for param in func.params.iter() {
        match param.tref.type_().passed_by() {
            TypePassedBy::Value(_) => param.name.render(src),
            TypePassedBy::Pointer => unreachable!(
                "unable to translate parameter `{}` of type `{}` in function `{}`",
                param.name.as_str(),
                param.tref.type_name(),
                func.name.as_str()
            ),
            TypePassedBy::PointerLengthPair => {
                param.name.render(src);
                src.push_str(".as_ptr(), ");
                param.name.render(src);
                src.push_str(".len()");
            }
        }
        src.push_str(",");
    }
    // Out-pointers for the non-error results come after the parameters.
    for result in func.results.iter().skip(1) {
        result.name.render(src);
        src.push_str(".as_mut_ptr(),");
    }
    src.push_str(");");
    // Map the raw return code to `Err`, otherwise read back the out-slots,
    // using the same bare-vs-tuple shape as the signature above.
    if func.results.len() > 0 {
        src.push_str("if let Some(err) = Error::from_raw_error(rc) { ");
        src.push_str("Err(err)");
        src.push_str("} else {");
        src.push_str("Ok(");
        if func.results.len() != 2 {
            src.push_str("(");
        }
        for result in func.results.iter().skip(1) {
            result.name.render(src);
            src.push_str(".assume_init(),");
        }
        if func.results.len() != 2 {
            src.push_str(")");
        }
        src.push_str(") }");
    }
    src.push_str("}");
}
impl Render for InterfaceFunc {
    /// Renders the raw `extern "C"` declaration of an imported function.
    ///
    /// Parameters are followed by one `*mut T` out-pointer per non-error
    /// result. The first result (if any) becomes the return type, and
    /// `proc_exit` without results is declared as diverging (`-> !`). A
    /// `#[link_name]` is added when the Rust name differs from the witx name.
    fn render(&self, src: &mut String) {
        let wasm_name = self.name.as_str();
        rustdoc(&self.docs, src);
        if wasm_name != wasm_name.to_snake_case() {
            src.push_str(&format!("#[link_name = \"{}\"]\n", wasm_name));
        }

        let mut ident = String::new();
        self.name.render(&mut ident);
        src.push_str("pub fn ");
        src.push_str(&ident.to_snake_case());
        src.push('(');
        for param in &self.params {
            param.render(src);
            src.push(',');
        }
        for out in self.results.iter().skip(1) {
            out.name.render(src);
            src.push_str(": *mut ");
            out.tref.render(src);
            src.push(',');
        }
        src.push(')');

        match self.results.get(0) {
            Some(ret) => {
                src.push_str(" -> ");
                ret.render(src);
            }
            None if wasm_name == "proc_exit" => src.push_str(" -> !"),
            None => {}
        }
        src.push(';');
    }
}
impl Render for InterfaceFuncParam {
    /// Renders a parameter (`name: type`) or, for a result, just its type,
    /// as used in the raw `extern "C"` declaration.
    ///
    /// Pointer-passed values become `*mut T`. Pointer/length pairs (slices and
    /// strings) expand into `{name}_ptr: *const T, {name}_len: usize`.
    ///
    /// # Panics
    ///
    /// Panics if a pointer/length pair appears in result position, or if its
    /// underlying type is neither an array nor a string.
    fn render(&self, src: &mut String) {
        let is_param = if let InterfaceFuncParamPosition::Param(_) = self.position {
            true
        } else {
            false
        };
        match self.tref.type_().passed_by() {
            TypePassedBy::PointerLengthPair => {
                assert!(is_param);
                let base = self.name.as_str();
                src.push_str(&format!("{}_ptr: *const ", base));
                match &*self.tref.type_() {
                    Type::Array(elem) => elem.render(src),
                    Type::Builtin(BuiltinType::String) => src.push_str("u8"),
                    x => panic!("unexpected pointer length pair type {:?}", x),
                }
                src.push_str(&format!(", {}_len: usize", base));
            }
            passing => {
                if is_param {
                    self.name.render(src);
                    src.push_str(": ");
                }
                if let TypePassedBy::Pointer = passing {
                    src.push_str("*mut ");
                }
                self.tref.render(src);
            }
        }
    }
}
impl Render for Id {
    /// Renders the identifier, escaping the witx names `in`, `type` and
    /// `yield` as raw identifiers (`r#...`) because they are Rust keywords.
    fn render(&self, src: &mut String) {
        let ident = self.as_str();
        match ident {
            "in" | "type" | "yield" => src.push_str("r#"),
            _ => {}
        }
        src.push_str(ident);
    }
}
/// Renders a handle type as a plain `u32` alias; the handle's own
/// datatype details are not used.
fn render_handle(src: &mut String, name: &str, _h: &HandleDatatype) {
    src.push_str("pub type ");
    src.push_str(&name.to_camel_case());
    src.push_str(" = u32;");
}
/// Appends `docs` to `dst` as `///` doc-comment lines, one per source line.
///
/// Emits nothing when `docs` is empty or whitespace-only.
fn rustdoc(docs: &str, dst: &mut String) {
    if !docs.trim().is_empty() {
        docs.lines().for_each(|text| {
            dst.push_str("/// ");
            dst.push_str(text);
            dst.push('\n');
        });
    }
}
fn rustdoc_params(docs: &[InterfaceFuncParam], header: &str, dst: &mut String) {
let docs = docs
.iter()
.filter(|param| param.docs.trim().len() > 0)
.collect::<Vec<_>>();
if docs.len() == 0 {
return;
}
dst.push_str("///\n");
dst.push_str("/// ## ");
dst.push_str(header);
dst.push_str("\n");
dst.push_str("///\n");
for param in docs {
for (i, line) in param.docs.lines().enumerate() {
dst.push_str("/// ");
if i == 0 {
dst.push_str("* `");
param.name.render(dst);
dst.push_str("` - ");
} else {
dst.push_str(" ");
}
dst.push_str(line);
dst.push_str("\n");
}
}
}
/// Returns `true` if any member of `s` is, or transitively contains, a union.
fn struct_contains_union(s: &StructDatatype) -> bool {
    for member in s.members.iter() {
        if type_contains_union(&member.tref.type_()) {
            return true;
        }
    }
    false
}
/// Returns `true` if `ty` is a union, or is an array or struct that
/// (recursively) contains one. All other types, including pointers, count
/// as union-free.
fn type_contains_union(ty: &Type) -> bool {
    match ty {
        Type::Struct(inner) => struct_contains_union(inner),
        Type::Array(elem) => type_contains_union(&elem.type_()),
        Type::Union(_) => true,
        _ => false,
    }
}