use anyhow::{bail, Context, Result};
use heck::{AsSnakeCase, ToSnakeCase, ToUpperCamelCase};
use indexmap::IndexSet;
use std::{
borrow::Cow,
collections::{BTreeMap, BTreeSet},
fmt::{self, Write},
fs,
io::Read,
path::Path,
process::{Command, Stdio},
};
use warg_protocol::registry::PackageName;
use wit_bindgen_rust::to_rust_ident;
use wit_component::DecodedWasm;
use wit_parser::{
Function, Handle, Interface, Resolve, Type, TypeDef, TypeDefKind, TypeId, TypeOwner, WorldId,
WorldItem, WorldKey,
};
/// One node in the trie of module paths that must be imported.
///
/// `children` maps a (snake_case) path segment to the subtree beneath it, and
/// `tys` holds the names imported directly from the module at this node.
#[derive(Default)]
struct UseTrieNode {
    children: BTreeMap<String, UseTrieNode>,
    tys: BTreeSet<String>,
}

impl fmt::Display for UseTrieNode {
    /// Renders this subtree as the tail of a `use` declaration, e.g.
    /// `exports::{Bar, Foo}`.
    ///
    /// Child paths come first, then names, all comma-separated. The whole list is
    /// wrapped in braces only when the node holds more than one entry.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let grouped = self.children.len() + self.tys.len() > 1;
        if grouped {
            f.write_str("{")?;
        }

        // A separator precedes every entry except the very first one written,
        // regardless of whether it is a child path or a name.
        let mut first = true;
        for (segment, child) in &self.children {
            if !std::mem::take(&mut first) {
                f.write_str(", ")?;
            }
            write!(f, "{segment}::{child}")?;
        }
        for ty in &self.tys {
            if !std::mem::take(&mut first) {
                f.write_str(", ")?;
            }
            f.write_str(ty)?;
        }

        if grouped {
            f.write_str("}")?;
        }
        Ok(())
    }
}
#[derive(Default)]
struct UseTrie {
root: UseTrieNode,
segments: IndexSet<String>,
types: IndexSet<String>,
}
impl UseTrie {
fn get<'a>(&self, path: impl Iterator<Item = &'a str>) -> Option<impl Iterator<Item = &str>> {
let mut node = &self.root;
for segment in path {
node = node.children.get(segment)?;
}
Some(node.tys.iter().map(|ty| ty.as_str()))
}
fn insert<'a, I>(&mut self, path: I, ty: &str) -> Cow<str>
where
I: IntoIterator<Item = &'a str>,
I::IntoIter: Clone,
{
let (type_index, inserted) = self.types.insert_full(ty.to_upper_camel_case());
let ty: &String = &self.types[type_index];
if !inserted {
let path = path.into_iter();
if let Some(tys) = self.get(path.clone()) {
for existing in tys {
if ty == existing {
return ty.into();
}
}
}
return format!(
"{path}::{ty}",
path = path.enumerate().fold(String::new(), |mut s, (i, p)| {
if i > 0 {
s.push_str("::");
}
write!(s, "{p}", p = AsSnakeCase(p)).unwrap();
s
}),
ty = self.types[type_index],
)
.into();
}
let mut node = &mut self.root;
for segment in path {
assert!(!segment.is_empty());
let (segment_index, _) = self.segments.insert_full(segment.to_snake_case());
let segment = &self.segments[segment_index];
node = node.children.entry(segment.clone()).or_default();
}
let inserted = node.tys.insert(ty.clone());
assert!(inserted);
Cow::Borrowed(&self.types[type_index])
}
fn insert_interface_type(
&mut self,
resolve: &Resolve,
interface: &Interface,
ty: &str,
) -> Cow<str> {
let pkg = &resolve.packages[interface.package.expect("interface should have a package")];
let name = interface.name.as_deref().expect("unnamed interface");
self.insert(
[
"bindings",
"exports",
pkg.name.namespace.as_str(),
pkg.name.name.as_str(),
name,
],
ty,
)
}
fn insert_export_trait(&mut self, resolve: &Resolve, key: &WorldKey) -> Cow<str> {
match key {
WorldKey::Name(name) => self.insert(["bindings", "exports", name.as_str()], "Guest"),
WorldKey::Interface(id) => {
let iface = &resolve.interfaces[*id];
self.insert_interface_type(resolve, iface, "Guest")
}
}
}
fn is_empty(&self) -> bool {
self.root.children.is_empty() && self.root.tys.is_empty()
}
}
impl fmt::Display for UseTrie {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
assert!(self.root.tys.is_empty());
for (segment, child) in &self.root.children {
writeln!(f, "use {segment}::{child};")?;
}
Ok(())
}
}
/// Generates stub Rust source implementing the exports of a world from a WIT
/// package, with every function body left as `unimplemented!()`.
pub struct SourceGenerator<'a> {
    /// Name of the target package; used in error messages.
    name: &'a PackageName,
    /// Path to the encoded WIT package that is read and decoded.
    path: &'a Path,
    /// Whether the generated source is piped through `rustfmt`.
    format: bool,
}

impl<'a> SourceGenerator<'a> {
    /// Creates a generator for the package `name` stored at `path`.
    pub fn new(name: &'a PackageName, path: &'a Path, format: bool) -> Self {
        Self { name, path, format }
    }

    /// Generates the stub source for a world of the target package.
    ///
    /// `world` is forwarded to `Resolve::select_world` to pick the world.
    /// The output declares `mod bindings;` and the required `use` items. It
    /// defines `struct Component;` with one `impl` per exported interface, plus
    /// one for the world-level `Guest` trait if the world exports any
    /// freestanding functions.
    ///
    /// # Errors
    ///
    /// Fails if the package cannot be read, decoded, or has no matching world. It
    /// also fails if an export uses a type this generator cannot spell, or if
    /// `rustfmt` fails when formatting is enabled.
    pub fn generate(&self, world: Option<&str>) -> Result<String> {
        let (resolve, world) = self.decode(world)?;
        let world = &resolve.worlds[world];
        let mut trie = UseTrie::default();
        let mut impls = Vec::new();
        let mut function_exports = Vec::new();

        for (key, item) in &world.exports {
            match item {
                // Freestanding exports are gathered into one world-level impl below.
                WorldItem::Function(f) => function_exports.push(f),
                WorldItem::Interface(id) => {
                    let interface = &resolve.interfaces[*id];
                    // Own the name so `trie` can be borrowed mutably again below.
                    let trait_name = trie.insert_export_trait(&resolve, key).into_owned();
                    impls.push(Self::print_impl(
                        &resolve,
                        &trait_name,
                        interface.functions.values(),
                        &mut trie,
                    )?);
                }
                WorldItem::Type(_) => {}
            }
        }

        if !function_exports.is_empty() {
            let trait_name = trie.insert(["bindings"], "Guest").into_owned();
            impls.push(Self::print_impl(
                &resolve,
                &trait_name,
                function_exports.iter().copied(),
                &mut trie,
            )?);
        }

        let mut source = String::new();
        writeln!(source, "mod bindings;")?;
        writeln!(source)?;
        if !trie.is_empty() {
            writeln!(source, "{trie}")?;
        }
        source.push_str("struct Component;\n");
        source.push_str(&impls.join("\n\n"));

        if self.format {
            source = Self::format_with_rustfmt(&source)?;
        }
        Ok(source)
    }

    /// Renders `impl {trait_name} for Component { ... }` with one stub method per
    /// function in `funcs`, separated by blank lines.
    fn print_impl<'f>(
        resolve: &Resolve,
        trait_name: &str,
        funcs: impl IntoIterator<Item = &'f Function>,
        trie: &mut UseTrie,
    ) -> Result<String> {
        let mut imp = String::new();
        writeln!(imp, "\nimpl {trait_name} for Component {{")?;
        for (i, func) in funcs.into_iter().enumerate() {
            if i > 0 {
                imp.push('\n');
            }
            Self::print_unimplemented_func(resolve, func, &mut imp, trie)?;
        }
        imp.push_str("}\n");
        Ok(imp)
    }

    /// Pipes `source` through `rustfmt` and returns the formatted text.
    ///
    /// The child process is always waited on, even when writing to or reading
    /// from it fails.
    fn format_with_rustfmt(source: &str) -> Result<String> {
        let mut child = Command::new("rustfmt")
            .arg("--edition=2018")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .spawn()
            .context("failed to spawn `rustfmt`")?;

        // The stdin handle is a temporary dropped at the end of this statement,
        // closing the pipe so `rustfmt` sees end-of-input.
        let written = std::io::Write::write_all(
            &mut child.stdin.take().expect("stdin is piped"),
            source.as_bytes(),
        );
        let mut formatted = String::new();
        let read = child
            .stdout
            .take()
            .expect("stdout is piped")
            .read_to_string(&mut formatted);
        let status = child.wait().context("failed to wait for `rustfmt`")?;

        written.context("failed to write to `rustfmt`")?;
        read.context("failed to read from `rustfmt`")?;
        if !status.success() {
            bail!("execution of `rustfmt` failed with {status}");
        }
        Ok(formatted)
    }

    /// Reads the target package from `self.path`, decodes it, and selects `world`.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be read or decoded. It also fails if the decoded
    /// value is a component rather than a WIT package, or if no world can be
    /// selected.
    fn decode(&self, world: Option<&str>) -> Result<(Resolve, WorldId)> {
        let bytes = fs::read(self.path).with_context(|| {
            format!(
                "failed to read the content of target package `{name}` path `{path}`",
                name = self.name,
                path = self.path.display()
            )
        })?;
        let decoded = wit_component::decode(&bytes).with_context(|| {
            format!(
                "failed to decode the content of target package `{name}` path `{path}`",
                name = self.name,
                path = self.path.display()
            )
        })?;
        match decoded {
            DecodedWasm::WitPackage(resolve, package) => {
                let world = resolve.select_world(package, world).with_context(|| {
                    format!(
                        "failed to select world from target package `{name}`",
                        name = self.name
                    )
                })?;
                Ok((resolve, world))
            }
            DecodedWasm::Component(..) => bail!("target is not a WIT package"),
        }
    }

    /// Appends a stub method for `func` to `source`.
    ///
    /// Parameter and result types are derived from the WIT signature; more than
    /// one result becomes a tuple. The body is `unimplemented!()`.
    fn print_unimplemented_func(
        resolve: &Resolve,
        func: &Function,
        source: &mut String,
        trie: &mut UseTrie,
    ) -> Result<()> {
        write!(source, "    fn {name}(", name = to_rust_ident(&func.name)).unwrap();
        for (i, (name, param)) in func.params.iter().enumerate() {
            if i > 0 {
                source.push_str(", ");
            }
            source.push_str(&to_rust_ident(name));
            source.push_str(": ");
            Self::print_type(resolve, param, source, trie)?;
        }
        source.push(')');
        match func.results.len() {
            0 => {}
            1 => {
                source.push_str(" -> ");
                Self::print_type(
                    resolve,
                    func.results.iter_types().next().unwrap(),
                    source,
                    trie,
                )?;
            }
            _ => {
                source.push_str(" -> (");
                for (i, ty) in func.results.iter_types().enumerate() {
                    if i > 0 {
                        source.push_str(", ");
                    }
                    Self::print_type(resolve, ty, source, trie)?;
                }
                source.push(')');
            }
        }
        source.push_str(" {\n        unimplemented!()\n    }\n");
        Ok(())
    }

    /// Appends the Rust spelling of `ty` to `source`; primitives map directly.
    fn print_type(
        resolve: &Resolve,
        ty: &Type,
        source: &mut String,
        trie: &mut UseTrie,
    ) -> Result<()> {
        match ty {
            Type::Bool => source.push_str("bool"),
            Type::U8 => source.push_str("u8"),
            Type::U16 => source.push_str("u16"),
            Type::U32 => source.push_str("u32"),
            Type::U64 => source.push_str("u64"),
            Type::S8 => source.push_str("i8"),
            Type::S16 => source.push_str("i16"),
            Type::S32 => source.push_str("i32"),
            Type::S64 => source.push_str("i64"),
            Type::Float32 => source.push_str("f32"),
            Type::Float64 => source.push_str("f64"),
            Type::Char => source.push_str("char"),
            Type::String => source.push_str("String"),
            Type::Id(id) => Self::print_type_id(resolve, *id, source, trie)?,
        }
        Ok(())
    }

    /// Appends the Rust spelling of the type definition `id` to `source`.
    ///
    /// Named types are referenced by path via `trie`; anonymous ones are spelled
    /// out structurally.
    ///
    /// # Errors
    ///
    /// Fails on anonymous variant, record, flags, enum, and resource types.
    fn print_type_id(
        resolve: &Resolve,
        id: TypeId,
        source: &mut String,
        trie: &mut UseTrie,
    ) -> Result<()> {
        let ty = &resolve.types[id];
        if ty.name.is_some() {
            Self::print_type_path(resolve, ty, source, trie);
            return Ok(());
        }

        match &ty.kind {
            TypeDefKind::List(ty) => {
                source.push_str("Vec<");
                Self::print_type(resolve, ty, source, trie)?;
                source.push('>');
            }
            TypeDefKind::Option(ty) => {
                source.push_str("Option<");
                Self::print_type(resolve, ty, source, trie)?;
                source.push('>');
            }
            TypeDefKind::Result(r) => {
                source.push_str("Result<");
                Self::print_optional_type(resolve, r.ok.as_ref(), source, trie)?;
                source.push_str(", ");
                Self::print_optional_type(resolve, r.err.as_ref(), source, trie)?;
                source.push('>');
            }
            TypeDefKind::Variant(_) => {
                bail!("unsupported anonymous variant type found in WIT package")
            }
            TypeDefKind::Tuple(t) => {
                source.push('(');
                for (i, ty) in t.types.iter().enumerate() {
                    if i > 0 {
                        source.push_str(", ");
                    }
                    Self::print_type(resolve, ty, source, trie)?;
                }
                source.push(')');
            }
            TypeDefKind::Record(_) => {
                bail!("unsupported anonymous record type found in WIT package")
            }
            TypeDefKind::Flags(_) => {
                bail!("unsupported anonymous flags type found in WIT package")
            }
            TypeDefKind::Enum(_) => {
                bail!("unsupported anonymous enum type found in WIT package")
            }
            TypeDefKind::Future(ty) => {
                source.push_str("Future<");
                Self::print_optional_type(resolve, ty.as_ref(), source, trie)?;
                source.push('>');
            }
            TypeDefKind::Stream(stream) => {
                source.push_str("Stream<");
                Self::print_optional_type(resolve, stream.element.as_ref(), source, trie)?;
                source.push_str(", ");
                Self::print_optional_type(resolve, stream.end.as_ref(), source, trie)?;
                source.push('>');
            }
            TypeDefKind::Type(ty) => Self::print_type(resolve, ty, source, trie)?,
            TypeDefKind::Handle(Handle::Own(id)) => {
                Self::print_type_id(resolve, *id, source, trie)?
            }
            TypeDefKind::Handle(Handle::Borrow(id)) => {
                source.push('&');
                Self::print_type_id(resolve, *id, source, trie)?
            }
            TypeDefKind::Resource => {
                bail!("unsupported anonymous resource type found in WIT package")
            }
            TypeDefKind::Unknown => unreachable!(),
        }

        Ok(())
    }

    /// Appends the (possibly qualified) path of the named type `ty` to `source`.
    ///
    /// Types owned by an interface with a package resolve under that interface's
    /// export module; all others resolve directly under `bindings`. Callers must
    /// only pass types whose `name` is set.
    fn print_type_path(resolve: &Resolve, ty: &TypeDef, source: &mut String, trie: &mut UseTrie) {
        if let TypeOwner::Interface(id) = ty.owner {
            let interface = &resolve.interfaces[id];
            if interface.package.is_some() {
                write!(
                    source,
                    "{name}",
                    name =
                        trie.insert_interface_type(resolve, interface, ty.name.as_deref().unwrap())
                )
                .unwrap();
                return;
            }
        }

        write!(
            source,
            "{name}",
            name = trie.insert(["bindings"], ty.name.as_deref().unwrap())
        )
        .unwrap();
    }

    /// Appends `ty`'s spelling to `source`, or `()` when there is no type.
    fn print_optional_type(
        resolve: &Resolve,
        ty: Option<&Type>,
        source: &mut String,
        trie: &mut UseTrie,
    ) -> Result<()> {
        match ty {
            Some(ty) => Self::print_type(resolve, ty, source, trie)?,
            None => source.push_str("()"),
        }
        Ok(())
    }
}