use haloumi_synthesis::io::{AdviceIO, InstanceIO};
use llzk::{attributes::NamedAttribute, prelude::*};
use melior::{
Context,
ir::{Identifier, Location, Operation, Type},
};
pub fn filename(name: &str, section: Option<&str>) -> String {
use std::fmt::Write;
const STRUCT: &str = "struct ";
const SEP: &str = " | ";
let mut s = String::with_capacity(
STRUCT.len() + name.len() + section.map(|s| s.len() + SEP.len()).unwrap_or_default(),
);
write!(s, "{STRUCT}{name}").expect("write to string");
if let Some(section) = section {
write!(s, "{SEP}{section}").expect("write to string");
}
s
}
fn struct_def_op_location<'c>(context: &'c Context, name: &str, index: usize) -> Location<'c> {
Location::new(context, filename(name, None).as_str(), index, 0)
}
pub struct StructIO {
private_inputs: usize,
public_inputs: usize,
private_outputs: usize,
public_outputs: usize,
}
impl StructIO {
fn fields<'c>(
&self,
context: &'c Context,
header: &str,
) -> impl Iterator<Item = Result<MemberDefOp<'c>, LlzkError>> {
let public_filename = filename(header, Some("public outputs"));
let private_filename = filename(header, Some("private outputs"));
std::iter::repeat_n(true, self.public_outputs)
.enumerate()
.chain(std::iter::repeat_n(false, self.private_outputs).enumerate())
.map(move |(n, public)| {
let filename = if public {
&public_filename
} else {
&private_filename
};
(public, Location::new(context, filename, n, 0))
})
.enumerate()
.map(|(n, (public, loc))| {
let name = format!("out_{n}");
dialect::r#struct::member(loc, &name, FeltType::new(context), false, public)
})
}
pub fn args<'c>(&self, ctx: &'c Context, struct_name: &str) -> Vec<(Type<'c>, Location<'c>)> {
let public_filename = filename(struct_name, Some("public inputs"));
let private_filename = filename(struct_name, Some("private inputs"));
let public_locs = std::iter::repeat(&public_filename)
.enumerate()
.take(self.public_inputs);
let private_locs = std::iter::repeat(&private_filename)
.enumerate()
.take(self.private_inputs);
let locs = public_locs
.chain(private_locs)
.map(|(n, filename)| Location::new(ctx, filename, n, 0));
let types = std::iter::repeat_n(
Type::from(FeltType::new(ctx)),
self.public_inputs + self.private_inputs,
);
std::iter::zip(types, locs).collect()
}
pub fn arg_attrs<'c>(&self, ctx: &'c Context) -> Vec<Vec<NamedAttribute<'c>>> {
let pub_attr = (
Identifier::new(ctx, "llzk.pub"),
PublicAttribute::new(ctx).into(),
);
std::iter::repeat_n(vec![pub_attr], self.public_inputs)
.chain(std::iter::repeat_n(vec![], self.private_inputs))
.collect()
}
pub fn from_io(advice: &AdviceIO, instance: &InstanceIO) -> Self {
Self {
private_inputs: advice.inputs().len(),
public_inputs: instance.inputs().len(),
private_outputs: advice.outputs().len(),
public_outputs: instance.outputs().len(),
}
}
pub fn from_io_count(inputs: usize, outputs: usize) -> Self {
Self {
private_inputs: inputs,
public_inputs: 0,
private_outputs: outputs,
public_outputs: 0,
}
}
}
pub fn create_struct<'c>(
context: &'c Context,
struct_name: &str,
idx: usize,
io: StructIO,
) -> Result<StructDefOp<'c>, LlzkError> {
log::debug!("context = {context:?}");
let loc = struct_def_op_location(context, struct_name, idx);
log::debug!("Struct location: {loc:?}");
let fields = io
.fields(context, struct_name)
.map(|r| r.map(Operation::from));
let func_args = io.args(context, struct_name);
let arg_attrs = io.arg_attrs(context);
log::debug!("Creating function with arguments: {func_args:?}");
let funcs = [
dialect::r#struct::helpers::compute_fn(
loc,
StructType::from_str(context, struct_name),
&func_args,
Some(&arg_attrs),
)
.map(Operation::from),
dialect::r#struct::helpers::constrain_fn(
loc,
StructType::from_str(context, struct_name),
&func_args,
Some(&arg_attrs),
)
.map(Operation::from),
];
dialect::r#struct::def(loc, struct_name, fields.chain(funcs))
}