trustfall_stubgen 0.1.0

Generate a Trustfall adapter stub for a given schema.
Documentation
use std::{
    collections::{BTreeMap, BTreeSet},
    io::Write,
    path::Path,
    sync::{Arc, OnceLock},
};

use maplit::btreemap;
use quote::quote;
use regex::Regex;
use trustfall::{Schema, SchemaAdapter, TryIntoStruct};

use super::{
    adapter_creator::make_adapter_file, edges_creator::make_edges_file,
    entrypoints_creator::make_entrypoints_file, properties_creator::make_properties_file,
};

/// Given a schema, make a Rust adapter stub for it in the given directory.
///
/// Generated code structure:
/// - adapter/mod.rs         connects everything together
/// - adapter/schema.graphql contains the schema for the adapter
/// - adapter/adapter.rs     contains the adapter implementation
/// - adapter/vertex.rs      contains the vertex type definition
/// - adapter/entrypoints.rs contains the entry points where all queries must start
/// - adapter/properties.rs  contains the property implementations
/// - adapter/edges.rs       contains the edge implementations
///
/// # Errors
///
/// Returns an error if `schema` fails to parse as a Trustfall schema, or if creating the
/// `adapter` directory under `target` or writing any of the generated files fails.
///
/// # Panics
///
/// Panics if the schema used to query `schema` (from [`SchemaAdapter::schema_text`]) is not
/// valid, or if code generation hits an internal invariant violation.
///
/// # Example
/// ```no_run
/// # use std::path::Path;
/// # use trustfall_stubgen::generate_rust_stub;
/// # fn main() {
/// let schema_text = std::fs::read_to_string("schema.graphql")
///     .expect("failed to read schema file");
/// generate_rust_stub(&schema_text, Path::new("crate/with/generated/stubs/src"))
///     .expect("stub generation failed");
/// # }
/// ```
pub fn generate_rust_stub(schema: &str, target: &Path) -> anyhow::Result<()> {
    let target_schema = Schema::parse(schema)?;

    // The stub is generated by running Trustfall queries over the target schema itself,
    // through `SchemaAdapter` and the schema describing that adapter.
    let querying_schema =
        Schema::parse(SchemaAdapter::schema_text()).expect("schema querying schema was not valid");
    let schema_adapter = Arc::new(SchemaAdapter::new(&target_schema));

    let mut stub = AdapterStub::with_standard_mod(schema);

    // Passed mutably to `make_entrypoints_file`, then by value to `make_adapter_file`
    // (presumably the entrypoint dispatch arms of the adapter).
    let mut entrypoint_match_arms = proc_macro2::TokenStream::new();

    make_vertex_file(&querying_schema, Arc::clone(&schema_adapter), &mut stub.vertex);
    make_entrypoints_file(
        &querying_schema,
        Arc::clone(&schema_adapter),
        &mut stub.entrypoints,
        &mut entrypoint_match_arms,
    );
    make_properties_file(
        &querying_schema,
        Arc::clone(&schema_adapter),
        &mut stub.properties,
    );
    make_edges_file(&querying_schema, Arc::clone(&schema_adapter), &mut stub.edges);

    // Last use of the adapter: move it instead of cloning the `Arc` once more.
    make_adapter_file(
        &querying_schema,
        schema_adapter,
        &mut stub.adapter,
        entrypoint_match_arms,
    );

    stub.write_to_directory(target)
}

/// An in-memory Rust source file: grouped `use` imports followed by top-level items.
///
/// Each import is stored as its list of path segments, e.g. `["std", "sync", "Arc"]`.
#[derive(Default, Debug)]
pub(crate) struct RustFile {
    /// Imports written as the first group (presumably `std` and other built-in crates).
    pub(crate) builtin_imports: BTreeSet<Vec<String>>,
    /// Imports written as the third and last group (presumably crate-local paths).
    pub(crate) internal_imports: BTreeSet<Vec<String>>,
    /// Imports written as the second group (presumably third-party crates).
    pub(crate) external_imports: BTreeSet<Vec<String>>,
    /// Items written after the imports, each pretty-printed, with an extra newline between them.
    pub(crate) top_level_items: Vec<proc_macro2::TokenStream>,
}

impl RustFile {
    /// Render this file and write it to `target`, replacing any existing contents.
    ///
    /// Import groups are written in the order builtin, external, internal; every non-empty
    /// group is followed by a blank line. The top-level items come next, with a newline
    /// written between consecutive items.
    ///
    /// # Panics
    ///
    /// Panics if there are no top-level items, or if an item is not valid Rust.
    fn write_to_file(self, target: &Path) -> anyhow::Result<()> {
        let mut contents: Vec<u8> = Vec::with_capacity(8192);

        let import_groups = [
            &self.builtin_imports,
            &self.external_imports,
            &self.internal_imports,
        ];
        for group in import_groups {
            write_import_tree(&mut contents, group)?;
            if !group.is_empty() {
                contents.write_all(b"\n")?;
            }
        }

        let mut items = self.top_level_items.into_iter();
        let first = items.next().expect("no items found");
        Self::pretty_print_item(&mut contents, first)?;
        for next in items {
            contents.write_all(b"\n")?;
            Self::pretty_print_item(&mut contents, next)?;
        }

        std::fs::write(target, contents)?;
        Ok(())
    }

    /// Pretty-print one item (or group of items) into `buffer`.
    ///
    /// Formatting is done by `prettyplease`, which does not put blank lines between sibling
    /// items. A regex pass then inserts a blank line before every line indented by exactly
    /// four spaces that begins with `pub`, `fn` or `use`, unless the preceding character is `{`.
    ///
    /// NOTE(review): the regex checks only the single character before the newline and matches
    /// by prefix. As a result, an attribute or doc comment directly above such an item is
    /// separated from it, and identifiers such as `user` also trigger the extra blank line.
    fn pretty_print_item(
        buffer: &mut impl std::io::Write,
        item: proc_macro2::TokenStream,
    ) -> anyhow::Result<()> {
        // Compiled once, then shared by every call.
        static SIBLING_ITEM: OnceLock<Regex> = OnceLock::new();
        let sibling_item = SIBLING_ITEM
            .get_or_init(|| Regex::new("([^{])\n    (pub|fn|use)").expect("invalid regex"));

        let parsed: syn::File = syn::parse_str(&item.to_string()).expect("not valid Rust");
        let formatted = prettyplease::unparse(&parsed);
        let spaced = sibling_item.replace_all(&formatted, "$1\n\n    $2");

        buffer.write_all(spaced.as_bytes())?;
        Ok(())
    }
}

/// A tree of `use` path segments, used to merge imports that share a common prefix.
///
/// In a `Node`, each key is the next path segment. The special key `"self"` records that the
/// path ending at this node is itself imported; it is printed as `self` inside a braced group.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
enum NodeOrLeaf<'a> {
    /// The import path ends here.
    Leaf,
    /// The path continues with any of these segments.
    Node(BTreeMap<&'a str, NodeOrLeaf<'a>>),
}

impl<'a> NodeOrLeaf<'a> {
    /// Merge the import `path`, relative to this tree position, into the tree.
    ///
    /// When a `Leaf` has to grow children, it turns into a `Node` that keeps the shorter path
    /// imported through a `"self"` entry.
    fn insert(&mut self, path: &'a [String]) {
        use std::collections::btree_map::Entry;

        let Some((head, tail)) = path.split_first() else {
            // The path ends at this position. A `Leaf` already says so; a `Node` needs an
            // explicit `self` entry.
            if let Self::Node(children) = self {
                children.insert("self", Self::Leaf);
            }
            return;
        };

        match self {
            Self::Node(children) => match children.entry(head.as_str()) {
                Entry::Occupied(mut existing) => existing.get_mut().insert(tail),
                Entry::Vacant(slot) => {
                    slot.insert(Self::from_path(tail));
                }
            },
            Self::Leaf => {
                let mut children = BTreeMap::new();
                children.insert("self", Self::Leaf);
                children.insert(head.as_str(), Self::from_path(tail));
                *self = Self::Node(children);
            }
        }
    }

    /// Build a single-branch tree for `path`: `["a", "b"]` becomes `a -> b -> Leaf`,
    /// and an empty path is just `Leaf`.
    fn from_path(path: &[String]) -> NodeOrLeaf<'_> {
        path.iter().rev().fold(NodeOrLeaf::Leaf, |subtree, segment| {
            let mut children = BTreeMap::new();
            children.insert(segment.as_str(), subtree);
            NodeOrLeaf::Node(children)
        })
    }
}

/// Group a set of import paths into a forest keyed by each path's first segment.
///
/// Each import is a list of path segments (e.g. `["std", "sync", "Arc"]`). Paths that share
/// a prefix are merged into one tree, so each tree can be printed as a single nested `use`
/// statement. An empty `imports` set produces an empty forest.
fn make_import_forest(imports: &BTreeSet<Vec<String>>) -> BTreeMap<&str, NodeOrLeaf<'_>> {
    // Start from an empty interior node so that every import, including the first, goes
    // through `insert`. This also makes empty input return an empty forest instead of
    // panicking.
    let mut forest = NodeOrLeaf::Node(BTreeMap::new());
    for import in imports {
        forest.insert(import.as_slice());
    }

    match forest {
        NodeOrLeaf::Node(map) => map,
        // `insert` only ever replaces a `Leaf` with a `Node`, never the other way around.
        NodeOrLeaf::Leaf => {
            unreachable!("unexpectedly got a leaf node for the top level of the import forest")
        }
    }
}

/// Write `imports` to `writer` as `use` statements, one per distinct first path segment.
///
/// Paths sharing a first segment are merged into one statement. For example, `std::sync::Arc`
/// and `std::path::Path` become `use std::{path::Path, sync::Arc};`. Nothing is written when
/// `imports` is empty.
fn write_import_tree<W: std::io::Write>(
    writer: &mut W,
    imports: &BTreeSet<Vec<String>>,
) -> anyhow::Result<()> {
    if !imports.is_empty() {
        for (first_segment, subtree) in make_import_forest(imports) {
            writer.write_all(b"use ")?;
            writer.write_all(first_segment.as_bytes())?;
            write_import_subtree(writer, subtree)?;
            writer.write_all(b";\n")?;
        }
    }

    Ok(())
}

/// Write the rest of a `use` path, starting just after the segment that owns `nodes`.
///
/// A `Leaf` writes nothing. A node with exactly one child writes `::child...`; a node with
/// several children writes `::{a..., b...}` in key order.
///
/// # Panics
///
/// Panics if `nodes` is a `Node` with no children.
fn write_import_subtree<W: std::io::Write>(
    writer: &mut W,
    nodes: NodeOrLeaf<'_>,
) -> anyhow::Result<()> {
    let NodeOrLeaf::Node(children) = nodes else {
        // The path ends here.
        return Ok(());
    };

    writer.write_all(b"::")?;

    // A single continuation is written inline; anything else needs a braced group.
    let braced = children.len() != 1;
    if braced {
        writer.write_all(b"{")?;
    }

    let mut entries = children.into_iter();
    let (segment, subtree) = entries.next().expect("empty map found");
    writer.write_all(segment.as_bytes())?;
    write_import_subtree(writer, subtree)?;

    for (segment, subtree) in entries {
        writer.write_all(b", ")?;
        writer.write_all(segment.as_bytes())?;
        write_import_subtree(writer, subtree)?;
    }

    if braced {
        writer.write_all(b"}")?;
    }

    Ok(())
}

#[derive(Debug)]
struct AdapterStub<'a> {
    mod_: RustFile,
    schema: &'a str,
    adapter: RustFile,
    vertex: RustFile,
    entrypoints: RustFile,
    properties: RustFile,
    edges: RustFile,
}

impl<'a> AdapterStub<'a> {
    fn with_standard_mod(schema: &'a str) -> Self {
        let mut mod_ = RustFile::default();

        mod_.top_level_items.push(quote! {
            mod adapter;
            mod vertex;
            mod entrypoints;
            mod properties;
            mod edges;
        });
        mod_.top_level_items.push(quote! {
            pub use adapter::Adapter;
            pub use vertex::Vertex;
        });

        Self {
            mod_,
            schema,
            adapter: Default::default(),
            vertex: Default::default(),
            entrypoints: Default::default(),
            properties: Default::default(),
            edges: Default::default(),
        }
    }

    fn write_to_directory(self, target: &Path) -> anyhow::Result<()> {
        let mut path_buf = target.to_path_buf();
        path_buf.push("adapter");
        std::fs::create_dir_all(&path_buf)?;

        path_buf.push("schema.graphql");
        std::fs::write(path_buf.as_path(), self.schema)?;
        path_buf.pop();

        path_buf.push("mod.rs");
        self.mod_.write_to_file(path_buf.as_path())?;
        path_buf.pop();

        path_buf.push("adapter.rs");
        self.adapter.write_to_file(path_buf.as_path())?;
        path_buf.pop();

        path_buf.push("vertex.rs");
        self.vertex.write_to_file(path_buf.as_path())?;
        path_buf.pop();

        path_buf.push("entrypoints.rs");
        self.entrypoints.write_to_file(path_buf.as_path())?;
        path_buf.pop();

        path_buf.push("properties.rs");
        self.properties.write_to_file(path_buf.as_path())?;
        path_buf.pop();

        path_buf.push("edges.rs");
        self.edges.write_to_file(path_buf.as_path())?;
        path_buf.pop();

        Ok(())
    }
}

/// Generate the `Vertex` enum, with one variant per vertex type in the target schema.
///
/// Each variant carries a `()` payload (presumably a placeholder for the stub's user to
/// replace). The enum derives `trustfall::provider::TrustfallEnumVertex`. The generated item
/// is appended to `vertex_file`.
///
/// # Panics
///
/// Panics if the query fails to execute, if a result row cannot be converted, or if a vertex
/// type name is not a valid Rust identifier.
fn make_vertex_file(
    querying_schema: &Schema,
    adapter: Arc<SchemaAdapter<'_>>,
    vertex_file: &mut RustFile,
) {
    // Shape of each row returned by the query below.
    #[derive(Debug, serde::Deserialize)]
    struct ResultRow {
        name: String,
    }

    // List the name of every vertex type in the schema being stubbed.
    let query = r#"
{
    VertexType {
        name @output
    }
}"#;
    let variables: BTreeMap<String, String> = BTreeMap::new();

    let rows = trustfall::execute_query(querying_schema, adapter, query, variables)
        .expect("invalid query");

    let variants: proc_macro2::TokenStream = rows
        .map(|row| {
            let row = row
                .try_into_struct::<ResultRow>()
                .expect("invalid conversion");
            let ident = syn::Ident::new(&row.name, proc_macro2::Span::call_site());
            quote! {
                #ident(()),
            }
        })
        .collect();

    vertex_file.top_level_items.push(quote! {
        #[derive(Debug, Clone, trustfall::provider::TrustfallEnumVertex)]
        pub enum Vertex {
            #variants
        }
    });
}