1use crate::{
4 diagnostics::{SimpleDiagnostic, Validation},
5 groups::callsite::CallSite,
6 printer::IRPrintable,
7 stmt::IRStmt,
8 traits::{Canonicalize, ConstantFolding, Validatable},
9};
10use eqv::{EqvRelation, equiv};
11use haloumi_core::eqv::SymbolicEqv;
12use haloumi_lowering::{
13 Lowering,
14 lowerable::{LowerableExpr, LowerableStmt},
15};
16use std::fmt::Write;
17use thiserror::Error;
18
/// Key attached to a group.
///
/// A group whose key is `None` is the main group (see [`IRGroup::is_main`]). Two groups can
/// only be symbolically equivalent when both carry a key and the keys are equal.
pub type GroupKey = u64;

pub mod callsite;
23
/// A group of IR: a named unit with a fixed number of inputs and outputs.
///
/// A group owns gate constraints, copy (equality) constraints, lookups, calls to other
/// groups, and any IR injected after construction. `E` is the expression type used inside
/// the statements (lowered through `LowerableExpr`).
#[derive(Debug, Clone)]
pub struct IRGroup<E> {
    /// Name of the group, used in validation messages and when printing.
    name: String,
    /// Identifier of the group. Callsites reference their callee by this id, and validation
    /// expects it to equal the group's index in the list of groups.
    id: usize,
    /// Number of inputs a callsite to this group must pass.
    input_count: usize,
    /// Number of outputs a callsite to this group must have.
    output_count: usize,
    /// Key used for symbolic equivalence checks; `None` marks the main group.
    key: Option<GroupKey>,
    /// Gate constraints.
    gates: IRStmt<E>,
    /// Copy (equality) constraints.
    eq_constraints: IRStmt<E>,
    /// Calls from this group to other groups.
    callsites: Vec<CallSite<E>>,
    /// Lookup constraints.
    lookups: IRStmt<E>,
    /// Extra IR added through [`IRGroup::inject`]; lowered after every other section.
    injected: Vec<IRStmt<E>>,
    /// Whether lowering emits a comment before each section of the group.
    generate_debug_comments: bool,
}
40
41impl<E> IRGroup<E> {
42 pub fn new(name: String, id: usize) -> Self {
44 Self {
45 name,
46 id,
47 input_count: Default::default(),
48 output_count: Default::default(),
49 key: Default::default(),
50 gates: Default::default(),
51 eq_constraints: Default::default(),
52 callsites: Default::default(),
53 lookups: Default::default(),
54 injected: Default::default(),
55 generate_debug_comments: Default::default(),
56 }
57 }
58
59 pub fn with_input_count(mut self, input_count: usize) -> Self {
61 self.input_count = input_count;
62 self
63 }
64
65 pub fn with_output_count(mut self, output_count: usize) -> Self {
67 self.output_count = output_count;
68 self
69 }
70
71 pub fn with_key(mut self, key: Option<GroupKey>) -> Self {
73 self.key = key;
74 self
75 }
76
77 pub fn no_key(mut self) -> Self {
79 self.key = None;
80 self
81 }
82
83 pub fn with_gates(mut self, gates: impl IntoIterator<Item = IRStmt<E>>) -> Self {
85 self.gates = IRStmt::from_iter(gates);
86 self
87 }
88
89 pub fn with_copy_constraints(
91 mut self,
92 constraints: impl IntoIterator<Item = IRStmt<E>>,
93 ) -> Self {
94 self.eq_constraints = IRStmt::from_iter(constraints);
95 self
96 }
97
98 pub fn with_callsites(mut self, callsites: impl IntoIterator<Item = CallSite<E>>) -> Self {
100 self.callsites = Vec::from_iter(callsites);
101 self
102 }
103
104 pub fn with_lookups(mut self, lookups: impl IntoIterator<Item = IRStmt<E>>) -> Self {
106 self.lookups = IRStmt::from_iter(lookups);
107 self
108 }
109
110 pub fn inject(&mut self, ir: IRStmt<E>) {
112 self.injected.push(ir);
113 }
114
115 pub fn injected_count(&self) -> usize {
117 self.injected.len()
118 }
119
120 pub fn do_debug_comments(mut self, do_it: bool) -> Self {
122 self.generate_debug_comments = do_it;
123 self
124 }
125
126 pub fn is_main(&self) -> bool {
128 self.key.is_none()
129 }
130
131 pub fn name(&self) -> &str {
133 &self.name
134 }
135
136 pub fn name_mut(&mut self) -> &mut String {
138 &mut self.name
139 }
140
141 pub fn id(&self) -> usize {
143 self.id
144 }
145
146 pub fn set_id(&mut self, id: usize) {
148 self.id = id;
149 }
150
151 pub fn input_count(&self) -> usize {
153 self.input_count
154 }
155
156 pub fn output_count(&self) -> usize {
158 self.output_count
159 }
160
161 pub fn key(&self) -> Option<GroupKey> {
163 self.key
164 }
165
166 pub fn callsites(&self) -> &[CallSite<E>] {
168 &self.callsites
169 }
170
171 pub fn callsites_mut(&mut self) -> &mut Vec<CallSite<E>> {
173 &mut self.callsites
174 }
175
176 pub fn statements(&self) -> impl Iterator<Item = &IRStmt<E>> {
178 self.gates
179 .iter()
180 .chain(self.eq_constraints.iter())
181 .chain(self.lookups.iter())
182 .chain(self.injected.iter().flatten())
183 }
184
185 pub fn try_map<O, Err>(
187 self,
188 f: &mut impl FnMut(E) -> Result<O, Err>,
189 ) -> Result<IRGroup<O>, Err> {
190 Ok(IRGroup {
191 name: self.name,
192 id: self.id,
193 input_count: self.input_count,
194 output_count: self.output_count,
195 key: self.key,
196 gates: self.gates.try_map(f)?,
197 eq_constraints: self.eq_constraints.try_map(f)?,
198 callsites: self
199 .callsites
200 .into_iter()
201 .map(|cs| cs.try_map(f))
202 .collect::<Result<Vec<_>, _>>()?,
203 lookups: self.lookups.try_map(f)?,
204 injected: self
205 .injected
206 .into_iter()
207 .map(|i| i.try_map(f))
208 .collect::<Result<Vec<_>, _>>()?,
209 generate_debug_comments: self.generate_debug_comments,
210 })
211 }
212
213 fn validate_callsite(
214 &self,
215 callsite: &CallSite<E>,
216 groups: &[Self],
217 ) -> Result<(), ValidationErrors> {
218 let callee_id = callsite.callee_id();
219 let callee = groups
220 .get(callee_id)
221 .ok_or(ValidationErrors::CalleeNotFound { callee_id })?;
222 if callee.id() != callsite.callee_id() {
223 return Err(ValidationErrors::WrongCallee {
224 callsite_name: callsite.name().to_string(),
225 callsite_id: callee_id,
226 callee_name: callee.name().to_string(),
227 callee_id: callee.id(),
228 });
229 }
230 if callee.input_count != callsite.inputs().len() {
231 return Err(ValidationErrors::UnexpectedInputs {
232 callee_name: callee.name().to_string(),
233 callee_id: callee.id(),
234 callee_count: callee.input_count,
235 callsite_count: callsite.inputs().len(),
236 });
237 }
238 if callee.output_count != callsite.outputs().len() {
239 return Err(ValidationErrors::UnexpectedOutputs {
240 callee_name: callee.name().to_string(),
241 callee_id: callee.id(),
242 callee_count: callee.output_count,
243 callsite_count: callsite.outputs().len(),
244 });
245 }
246 if callsite.outputs().len() != callsite.output_vars().len() {
247 return Err(ValidationErrors::UnexpectedOutputsVars {
248 callsite_name: callsite.name().to_string(),
249 callsite_id: callee_id,
250 callsite_count: callsite.outputs().len(),
251 callsite_vars_count: callsite.output_vars().len(),
252 });
253 }
254
255 Ok(())
256 }
257
258 pub fn eq_constraints_mut(&mut self) -> &mut IRStmt<E> {
260 &mut self.eq_constraints
261 }
262}
263
264impl<E> Validatable for IRGroup<E>
265where
266 IRStmt<E>: Validatable<Diagnostic = SimpleDiagnostic, Context = ()>,
267{
268 type Diagnostic = SimpleDiagnostic;
269
270 type Context = [Self];
271
272 fn validate_with_context(
273 &self,
274 groups: &Self::Context,
275 ) -> Result<Vec<Self::Diagnostic>, Vec<Self::Diagnostic>> {
276 let mut validation = Validation::new();
277
278 validation.with_errors(self.callsites().iter().enumerate().filter_map(
280 |(call_no, callsite)| match self.validate_callsite(callsite, groups) {
281 Ok(_) => None,
282 Err(err) => Some(SimpleDiagnostic::error(format!(
283 "on callsite {call_no}: {err}"
284 ))),
285 },
286 ));
287
288 validation.append_from_result(self.gates.validate(), "on gates");
290 validation.append_from_result(self.eq_constraints.validate(), "on copy constraints");
291 validation.append_from_result(self.lookups.validate(), "on lookups");
292 for ir in &self.injected {
293 validation.append_from_result(ir.validate(), "on injected ir");
294 }
295
296 validation.into()
297 }
298}
299
/// Error summarizing a failed validation of a group.
#[derive(Error, Debug)]
#[error("Validation of group {name} failed with {error_count} errors")]
pub struct ValidationFailed {
    /// Name of the group whose validation failed.
    pub(crate) name: String,
    /// Number of errors the validation reported.
    pub(crate) error_count: usize,
}

impl ValidationFailed {
    /// Returns how many errors the failed validation reported.
    pub fn error_count(&self) -> usize {
        self.error_count
    }
}
314
/// Reasons a single callsite can fail validation.
#[derive(Error, Debug)]
enum ValidationErrors {
    /// The callee id is not a valid index into the list of groups.
    #[error("Callee with id {callee_id} was not found")]
    CalleeNotFound { callee_id: usize },
    /// The group found at the callee's index reports a different id.
    #[error(
        "Callsite points to \"{callsite_name}\" ({callsite_id}) but callee was \"{callee_name}\" ({callee_id})"
    )]
    WrongCallee {
        callsite_name: String,
        callsite_id: usize,
        callee_name: String,
        callee_id: usize,
    },
    /// The callsite passes a different number of inputs than the callee declares.
    #[error(
        "Callee \"{callee_name}\" ({callee_id}) was expecting {callee_count} inputs but callsite has {callsite_count}"
    )]
    UnexpectedInputs {
        callee_name: String,
        callee_id: usize,
        callee_count: usize,
        callsite_count: usize,
    },
    /// The callsite has a different number of outputs than the callee declares.
    #[error(
        "Callee \"{callee_name}\" ({callee_id}) was expecting {callee_count} outputs but callsite has {callsite_count}"
    )]
    UnexpectedOutputs {
        callee_name: String,
        callee_id: usize,
        callee_count: usize,
        callsite_count: usize,
    },
    /// The callsite's outputs and its declared output variables differ in number.
    #[error(
        "Call to \"{callsite_name}\" ({callsite_id}) has {callsite_count} outputs but declared {callsite_vars_count} output variables"
    )]
    UnexpectedOutputsVars {
        callsite_name: String,
        callsite_id: usize,
        callsite_count: usize,
        callsite_vars_count: usize,
    },
}
356
357impl<E: ConstantFolding> ConstantFolding for IRGroup<E>
358where
359 IRStmt<E>: ConstantFolding,
360{
361 type Error = ConstantFoldingError<E>;
362
363 type T = ();
364
365 fn constant_fold(&mut self) -> Result<(), Self::Error> {
366 self.gates
367 .constant_fold()
368 .map_err(ConstantFoldingError::Stmt)?;
369 self.eq_constraints
370 .constant_fold()
371 .map_err(ConstantFoldingError::Stmt)?;
372 self.lookups
373 .constant_fold()
374 .map_err(ConstantFoldingError::Stmt)?;
375 self.callsites
376 .constant_fold()
377 .map_err(ConstantFoldingError::CallsiteArg)?;
378 self.injected
379 .constant_fold()
380 .map_err(ConstantFoldingError::Stmt)
381 }
382}
383
/// Error produced while constant folding an [`IRGroup`].
#[derive(Error)]
pub enum ConstantFoldingError<E>
where
    IRStmt<E>: ConstantFolding,
    E: ConstantFolding,
{
    /// Folding one of the group's statements (gates, copy constraints, lookups or
    /// injected IR) failed.
    #[error(transparent)]
    Stmt(<IRStmt<E> as ConstantFolding>::Error),
    /// Folding the group's callsites failed. The error is the expression type `E`'s
    /// folding error, presumably raised by a callsite argument — confirm in `callsite`.
    #[error(transparent)]
    CallsiteArg(<E as ConstantFolding>::Error),
}
398
399impl<E, Err1, Err2> std::fmt::Debug for ConstantFoldingError<E>
400where
401 IRStmt<E>: ConstantFolding<Error = Err1>,
402 E: ConstantFolding<Error = Err2>,
403 Err1: std::fmt::Debug,
404 Err2: std::fmt::Debug,
405{
406 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
407 match self {
408 Self::Stmt(e) => std::fmt::Debug::fmt(e, f),
409 Self::CallsiteArg(e) => std::fmt::Debug::fmt(e, f),
410 }
411 }
412}
413
414impl<E> Canonicalize for IRGroup<E>
415where
416 IRStmt<E>: Canonicalize,
417 CallSite<E>: Canonicalize,
418{
419 fn canonicalize(&mut self) {
420 self.gates.canonicalize();
421 self.eq_constraints.canonicalize();
422 self.lookups.canonicalize();
423 self.callsites.canonicalize();
424 self.injected.canonicalize();
425 }
426}
427
428impl<E> EqvRelation<IRGroup<E>> for SymbolicEqv
429where
430 SymbolicEqv: EqvRelation<E>,
431{
432 fn equivalent(lhs: &IRGroup<E>, rhs: &IRGroup<E>) -> bool {
436 if lhs.is_main() || rhs.is_main() {
438 return false;
439 }
440
441 let lhs_key = lhs.key.unwrap();
442 let rhs_key = rhs.key.unwrap();
443
444 let k = lhs_key == rhs_key;
445 log::debug!("[equivalent({} ~ {})] key: {k}", lhs.id(), rhs.id());
446 let io = lhs.input_count == rhs.input_count && lhs.output_count == rhs.output_count;
447 log::debug!("[equivalent({} ~ {})] io: {io}", lhs.id(), rhs.id());
448 let gates = equiv! { Self | &lhs.gates, &rhs.gates };
449 log::debug!("[equivalent({} ~ {})] gates: {gates}", lhs.id(), rhs.id());
450 let eqc = equiv! { Self | &lhs.eq_constraints, &rhs.eq_constraints };
451 log::debug!("[equivalent({} ~ {})] eqc: {eqc}", lhs.id(), rhs.id());
452 let lookups = equiv! { Self | &lhs.lookups, &rhs.lookups };
453 log::debug!(
454 "[equivalent({} ~ {})] lookups: {lookups}",
455 lhs.id(),
456 rhs.id()
457 );
458 let callsites = equiv! { Self | &lhs.callsites, &rhs.callsites };
459 log::debug!(
460 "[equivalent({} ~ {})] callsites: {callsites}",
461 lhs.id(),
462 rhs.id()
463 );
464
465 k && io && gates && eqc && lookups && callsites
466 }
467}
468
469impl<E> LowerableStmt for IRGroup<E>
470where
471 E: LowerableExpr + std::fmt::Debug,
472 CallSite<E>: LowerableStmt,
473 IRStmt<E>: LowerableStmt,
474{
475 fn lower<L>(self, l: &L) -> haloumi_lowering::Result<()>
476 where
477 L: Lowering + ?Sized,
478 {
479 log::debug!("Lowering {self:?}");
480 if self.generate_debug_comments {
481 l.generate_comment("Calls to subgroups".to_owned())?;
482 }
483 log::debug!(" Lowering callsites");
484 for callsite in self.callsites {
485 callsite.lower(l)?;
486 }
487 if self.generate_debug_comments {
488 l.generate_comment("Gate constraints".to_owned())?;
489 }
490 log::debug!(" Lowering gates");
491 self.gates.lower(l)?;
492 if self.generate_debug_comments {
493 l.generate_comment("Equality constraints".to_owned())?;
494 }
495 log::debug!(" Lowering equality constraints");
496 self.eq_constraints.lower(l)?;
497 if self.generate_debug_comments {
498 l.generate_comment("Lookups".to_owned())?;
499 }
500 log::debug!(" Lowering lookups");
501 self.lookups.lower(l)?;
502 if self.generate_debug_comments {
503 l.generate_comment("Injected".to_owned())?;
504 }
505 log::debug!(" Lowering injected IR");
506 for stmt in self.injected {
507 stmt.lower(l)?;
508 }
509
510 Ok(())
511 }
512}
513
514impl<E: IRPrintable> IRPrintable for IRGroup<E> {
515 fn fmt(&self, ctx: &mut crate::printer::IRPrinterCtx<'_, '_>) -> crate::printer::Result {
516 ctx.block("group", |ctx| {
517 writeln!(
518 ctx,
519 "{} \"{}\" (inputs {}) (outputs {})",
520 self.id(),
521 self.name(),
522 self.input_count(),
523 self.output_count()
524 )?;
525
526 for callsite in self.callsites() {
527 ctx.fmt_call(
528 callsite.name(),
529 callsite.inputs(),
530 callsite.output_vars(),
531 Some(callsite.callee_id()),
532 )?;
533 ctx.nl()?;
534 }
535
536 for stmt in self.statements() {
537 stmt.fmt(ctx)?;
538 ctx.nl()?;
539 }
540
541 Ok(())
542 })
543 }
544}