haloumi-llzk 0.5.15

Haloumi backend to LLZK.
use haloumi_synthesis::io::{AdviceIO, InstanceIO};
use llzk::{attributes::NamedAttribute, prelude::*};

use melior::{
    Context,
    ir::{Identifier, Location, Operation, Type},
};

/// Generates a pseudo-filename for use in location metadata in the generated IR.
pub fn filename(name: &str, section: Option<&str>) -> String {
    use std::fmt::Write;
    const STRUCT: &str = "struct ";
    const SEP: &str = " | ";
    let mut s = String::with_capacity(
        STRUCT.len() + name.len() + section.map(|s| s.len() + SEP.len()).unwrap_or_default(),
    );
    write!(s, "{STRUCT}{name}").expect("write to string");
    if let Some(section) = section {
        write!(s, "{SEP}{section}").expect("write to string");
    }
    s
}

fn struct_def_op_location<'c>(context: &'c Context, name: &str, index: usize) -> Location<'c> {
    Location::new(context, filename(name, None).as_str(), index, 0)
}

pub struct StructIO {
    private_inputs: usize,
    public_inputs: usize,
    private_outputs: usize,
    public_outputs: usize,
}

impl StructIO {
    fn fields<'c>(
        &self,
        context: &'c Context,
        header: &str,
    ) -> impl Iterator<Item = Result<MemberDefOp<'c>, LlzkError>> {
        let public_filename = filename(header, Some("public outputs"));
        let private_filename = filename(header, Some("private outputs"));
        std::iter::repeat_n(true, self.public_outputs)
            .enumerate()
            .chain(std::iter::repeat_n(false, self.private_outputs).enumerate())
            .map(move |(n, public)| {
                let filename = if public {
                    &public_filename
                } else {
                    &private_filename
                };
                (public, Location::new(context, filename, n, 0))
            })
            .enumerate()
            .map(|(n, (public, loc))| {
                let name = format!("out_{n}");
                dialect::r#struct::member(loc, &name, FeltType::new(context), false, public)
            })
    }

    pub fn args<'c>(&self, ctx: &'c Context, struct_name: &str) -> Vec<(Type<'c>, Location<'c>)> {
        let public_filename = filename(struct_name, Some("public inputs"));
        let private_filename = filename(struct_name, Some("private inputs"));
        let public_locs = std::iter::repeat(&public_filename)
            .enumerate()
            .take(self.public_inputs);
        let private_locs = std::iter::repeat(&private_filename)
            .enumerate()
            .take(self.private_inputs);
        let locs = public_locs
            .chain(private_locs)
            .map(|(n, filename)| Location::new(ctx, filename, n, 0));

        let types = std::iter::repeat_n(
            Type::from(FeltType::new(ctx)),
            self.public_inputs + self.private_inputs,
        );

        std::iter::zip(types, locs).collect()
    }

    /// Returns the list of argument attributes for the struct's functions.
    pub fn arg_attrs<'c>(&self, ctx: &'c Context) -> Vec<Vec<NamedAttribute<'c>>> {
        let pub_attr = (
            Identifier::new(ctx, "llzk.pub"),
            PublicAttribute::new(ctx).into(),
        );
        std::iter::repeat_n(vec![pub_attr], self.public_inputs)
            .chain(std::iter::repeat_n(vec![], self.private_inputs))
            .collect()
    }

    pub fn from_io(advice: &AdviceIO, instance: &InstanceIO) -> Self {
        Self {
            private_inputs: advice.inputs().len(),
            public_inputs: instance.inputs().len(),
            private_outputs: advice.outputs().len(),
            public_outputs: instance.outputs().len(),
        }
    }

    pub fn from_io_count(inputs: usize, outputs: usize) -> Self {
        Self {
            private_inputs: inputs,
            public_inputs: 0,
            private_outputs: outputs,
            public_outputs: 0,
        }
    }
}

pub fn create_struct<'c>(
    context: &'c Context,
    struct_name: &str,
    idx: usize,
    io: StructIO,
) -> Result<StructDefOp<'c>, LlzkError> {
    log::debug!("context = {context:?}");
    let loc = struct_def_op_location(context, struct_name, idx);
    log::debug!("Struct location: {loc:?}");
    let fields = io
        .fields(context, struct_name)
        .map(|r| r.map(Operation::from));

    let func_args = io.args(context, struct_name);
    let arg_attrs = io.arg_attrs(context);

    log::debug!("Creating function with arguments: {func_args:?}");

    let funcs = [
        dialect::r#struct::helpers::compute_fn(
            loc,
            StructType::from_str(context, struct_name),
            &func_args,
            Some(&arg_attrs),
        )
        .map(Operation::from),
        dialect::r#struct::helpers::constrain_fn(
            loc,
            StructType::from_str(context, struct_name),
            &func_args,
            Some(&arg_attrs),
        )
        .map(Operation::from),
    ];

    dialect::r#struct::def(loc, struct_name, fields.chain(funcs))
}