use super::lowering::LlzkStructLowering;
use super::state::LlzkCodegenState;
use super::{LlzkOutput, counter::Counter};
use haloumi_backend::codegen::Codegen;
use llzk::prelude::*;
use melior::{
Context,
ir::{Location, Module},
};
use crate::error::Error;
use crate::factory::StructIO;
use haloumi_synthesis::io::{AdviceIO, InstanceIO};
use super::factory;
/// Code generator that emits an LLZK MLIR module (named `haloumi`).
///
/// Every scope defined through [`Codegen`] becomes a struct definition appended to `module`.
/// The module is verified, and optionally optimized, in `generate_output`.
#[derive(Debug)]
pub struct LlzkCodegen<'c, 's> {
    /// Shared codegen state; provides the MLIR context, the parameters and the optimize flag.
    state: &'s LlzkCodegenState<'c>,
    /// Module that every generated struct definition is appended to.
    module: Module<'c>,
    /// Counter whose next value is passed to `factory::create_struct` for each new struct
    /// (presumably to give every struct a distinct index — TODO confirm in `factory`).
    struct_count: Counter,
}
impl<'c, 's> LlzkCodegen<'c, 's> {
    /// Appends the struct definition `s` to the module body and returns a mutable handle to it.
    ///
    /// # Errors
    ///
    /// Returns an error if the appended operation cannot be converted back into a
    /// `StructDefOpRef`.
    fn add_struct(&self, s: StructDefOp<'c>) -> Result<StructDefOpRefMut<'c, 's>, Error> {
        let s: StructDefOpRef = self.module.body().append_operation(s.into()).try_into()?;
        // SAFETY: NOTE(review): `s` was just appended to `self.module`, so its raw handle refers
        // to an operation the module currently holds. The returned handle is typed with `'s`
        // (the state borrow), not with the borrow of `self`. Nothing in the types therefore
        // stops it from outliving the module, which `generate_output` moves out of `self`.
        // Callers must stop using the handle before the module is dropped — confirm this
        // holds for every caller.
        Ok(unsafe { StructDefOpRefMut::from_raw(s.to_raw()) })
    }

    /// Creates a struct named `name` with the given IO, appends it to the module, and wraps it
    /// in an [`LlzkStructLowering`].
    ///
    /// # Errors
    ///
    /// Propagates failures from `factory::create_struct`, from appending the struct, and from
    /// `LlzkStructLowering::new`.
    fn create_lowering_scope(
        &self,
        name: &str,
        io: StructIO,
    ) -> Result<LlzkStructLowering<'c, 's>, Error> {
        let s = factory::create_struct(self.context(), name, self.struct_count.next(), io)?;
        LlzkStructLowering::new(self.context(), self.add_struct(s)?)
    }

    /// Returns the MLIR context taken from the codegen state.
    fn context(&self) -> &'c Context {
        self.state.context()
    }
}
impl<'c: 's, 's> Codegen<'c, 's> for LlzkCodegen<'c, 's> {
type FuncOutput = LlzkStructLowering<'c, 's>;
type Output = LlzkOutput<'c>;
type State = LlzkCodegenState<'c>;
type Error = Error;
fn initialize(state: &'s Self::State) -> Self {
let module = llzk_module(Location::unknown(state.context()), Some("haloumi"));
Self {
state,
module,
struct_count: Default::default(),
}
}
fn define_main_function(
&self,
advice_io: &AdviceIO,
instance_io: &InstanceIO,
) -> Result<Self::FuncOutput, Self::Error> {
let name = self.state.params().top_level().unwrap_or("Main");
log::debug!("Creating Main struct with name '{name}'");
self.create_lowering_scope(name, StructIO::from_io(advice_io, instance_io))
}
fn define_function(
&self,
name: &str,
inputs: usize,
outputs: usize,
) -> Result<Self::FuncOutput, Self::Error> {
self.create_lowering_scope(name, StructIO::from_io_count(inputs, outputs))
}
fn on_scope_end(&self, _: Self::FuncOutput) -> Result<(), Self::Error> {
Ok(())
}
fn generate_output(mut self) -> Result<Self::Output, Self::Error> {
verify_operation_with_diags(&self.module.as_operation()).map_err(|err| {
Error::VerificationFailed {
err,
note: if self.state.optimize() {
" (before optimization)"
} else {
""
},
}
})?;
if self.state.optimize() {
let pipeline = create_pipeline(self.context());
pipeline.run(&mut self.module)?;
}
Ok(self.module.into())
}
}
/// Builds the pass pipeline that `generate_output` runs when optimization is enabled.
///
/// The pipeline runs these passes in order:
/// 1. The member-write validator on every `struct.def` nested in `builtin.module`.
/// 2. Canonicalization.
/// 3. Common-subexpression elimination.
/// 4. Redundant read/write elimination.
/// 5. The member-write validator again, on the optimized IR.
///
/// The assembled pipeline is logged at debug level.
fn create_pipeline(context: &Context) -> PassManager<'_> {
    /// Schedules the member-write validator on each `struct.def` inside the `builtin.module`.
    fn add_member_write_validator(pm: &PassManager<'_>) {
        pm.nested_under("builtin.module")
            .nested_under("struct.def")
            .add_pass(llzk_passes::create_member_write_validator_pass());
    }

    let pm = PassManager::new(context);

    add_member_write_validator(&pm);
    pm.add_pass(melior_passes::create_canonicalizer());
    pm.add_pass(melior_passes::create_cse());
    pm.add_pass(llzk_passes::create_redundant_read_and_write_elimination_pass());
    add_member_write_validator(&pm);

    log::debug!(
        "Optimization pipeline: {}",
        pm.as_operation_pass_manager()
    );
    pm
}
#[cfg(test)]
mod tests {
    //! Snapshot tests for `define_main_function`.
    //!
    //! Each test lowers a main struct for one advice/instance IO configuration, with
    //! optimization disabled. It then compares the resulting module against an expected
    //! `.mlir` file.
    use crate::params::LlzkParams;

    use super::*;
    use haloumi_core::{
        query::{Advice, Instance},
        table::Column,
    };
    use log::LevelFilter;
    use rstest::{fixture, rstest};
    use simplelog::{Config, TestLogger};

    /// Installs a debug-level test logger.
    ///
    /// The result of `init` is ignored. Presumably the error case is a logger already
    /// installed by another test in the same process.
    #[fixture]
    fn common() {
        let _ = TestLogger::init(LevelFilter::Debug, Config::default());
    }

    /// Provides a fresh `LlzkContext` after running the `common` fixture.
    // `common` is requested only for its side effect (logger setup) and is never read, hence
    // the `unused_variables` allowance.
    #[fixture]
    #[allow(unused_variables)]
    fn ctx(common: ()) -> LlzkContext {
        LlzkContext::new()
    }

    /// Generates an rstest that checks one main-struct IO configuration.
    ///
    /// - `$test_name`: name of the generated test.
    /// - `$expected`: path of the expected `.mlir` file.
    /// - `$io`: expression evaluating to an `(AdviceIO, InstanceIO)` pair.
    ///
    /// The test defines the main function with optimization disabled, ends its scope, and
    /// generates the output. It then re-verifies the module and asserts it equals the
    /// expected file.
    macro_rules! main_function_test {
        ($test_name:ident, $expected:literal, $io:expr $(,)?) => {
            #[rstest]
            fn $test_name(ctx: LlzkContext) {
                let state: LlzkCodegenState = LlzkParams::new(&ctx).no_optimize().into();
                let codegen = LlzkCodegen::initialize(&state);
                let (advice_io, instance_io) = $io;
                let main = codegen
                    .define_main_function(&advice_io, &instance_io)
                    .unwrap();
                codegen.on_scope_end(main).unwrap();
                let op = codegen.generate_output().unwrap();
                verify_operation_with_diags(&op.module().as_operation()).unwrap();
                mlir_testutils::assert_module_eq_to_file!(op.module(), $expected);
            }
        };
    }

    // No inputs or outputs at all.
    main_function_test! {
        define_main_function_empty_io,
        "test_files/empty_io.mlir",
        (AdviceIO::empty(), InstanceIO::empty()),
    }

    // Three instance cells used as inputs.
    main_function_test! {
        define_main_function_public_inputs,
        "test_files/public_inputs.mlir",
        (
            AdviceIO::empty(),
            InstanceIO::new(&[(Column::new(0, Instance), &[0, 1, 2])], &[]).unwrap()
        )
    }

    // Three advice cells used as inputs.
    main_function_test! {
        define_main_function_private_inputs,
        "test_files/private_inputs.mlir",
        (
            AdviceIO::new(&[(Column::new(0, Advice), &[0, 1, 2])], &[]).unwrap(),
            InstanceIO::empty()
        )
    }

    // Three instance cells used as outputs.
    main_function_test! {
        define_main_function_public_outputs,
        "test_files/public_outputs.mlir",
        (
            AdviceIO::empty(),
            InstanceIO::new(&[], &[(Column::new(0, Instance), &[0, 1, 2])]).unwrap()
        )
    }

    // Three advice cells used as outputs.
    main_function_test! {
        define_main_function_private_outputs,
        "test_files/private_outputs.mlir",
        (
            AdviceIO::new(&[], &[(Column::new(0, Advice), &[0, 1, 2])]).unwrap(),
            InstanceIO::empty()
        )
    }

    // Inputs and outputs on both an advice and an instance column.
    main_function_test! {
        define_main_function_mixed_io,
        "test_files/mixed_io.mlir",
        {
            let advice_col = Column::new(0, Advice);
            let instance_col = Column::new(0, Instance);
            (
                AdviceIO::new(&[(advice_col, &[0, 1, 2])], &[(advice_col, &[3, 4])]).unwrap(),
                InstanceIO::new(&[(instance_col, &[0, 1])], &[(instance_col, &[2, 3])]).unwrap()
            )
        }
    }
}