use alloc::{boxed::Box, collections::BTreeMap, format};
use core::any::TypeId;
use midenc_hir_symbol::sync::{LazyLock, RwLock};
use midenc_session::diagnostics::DiagnosticsHandler;
use super::*;
use crate::Report;
/// Process-wide registry of passes and pipelines, lazily constructed from all `inventory`
/// submissions (via [`PassRegistry::default`], which forwards to [`PassRegistry::new`]).
static PASS_REGISTRY: LazyLock<PassRegistry> = LazyLock::new(PassRegistry::default);
/// Registry of named passes and named pass pipelines, each keyed by its argument name.
///
/// Each table sits behind its own [`RwLock`] so that new entries can be added through a shared
/// reference (see [`PassRegistry::register_pass`] and [`PassRegistry::register_pipeline`]).
pub struct PassRegistry {
    /// Individual passes, keyed by argument name.
    passes: RwLock<BTreeMap<&'static str, PassRegistryEntry>>,
    /// Pass pipelines, keyed by argument name.
    pipelines: RwLock<BTreeMap<&'static str, PassRegistryEntry>>,
}

impl Default for PassRegistry {
    /// Equivalent to [`PassRegistry::new`].
    #[inline]
    fn default() -> Self {
        PassRegistry::new()
    }
}
impl PassRegistry {
    /// Builds a registry pre-populated with every [`PassInfo`] and [`PassPipelineInfo`] that was
    /// submitted through `inventory`.
    ///
    /// Entries are inserted in `inventory` iteration order. No conflict check is performed here
    /// (unlike [`PassRegistry::register_pass`] / [`PassRegistry::register_pipeline`]): if two
    /// submissions share an argument name, the entry iterated last replaces the earlier one.
    pub fn new() -> Self {
        let mut passes = BTreeMap::default();
        for pass in inventory::iter::<PassInfo>() {
            passes.insert(pass.0.arg, pass.0.clone());
        }

        let mut pipelines = BTreeMap::default();
        for pipeline in inventory::iter::<PassPipelineInfo>() {
            pipelines.insert(pipeline.0.arg, pipeline.0.clone());
        }

        Self {
            passes: RwLock::new(passes),
            pipelines: RwLock::new(pipelines),
        }
    }

    /// Looks up the pass registered under `name`, returning a copy of its registration info.
    pub fn get_pass(&self, name: &str) -> Option<PassInfo> {
        self.passes.read().get(name).cloned().map(PassInfo)
    }

    /// Looks up the pipeline registered under `name`, returning a copy of its registration info.
    pub fn get_pipeline(&self, name: &str) -> Option<PassPipelineInfo> {
        self.pipelines.read().get(name).cloned().map(PassPipelineInfo)
    }

    /// Registers a pass under its argument name.
    ///
    /// Re-registering the same name for the same pass type is a no-op (the existing entry is
    /// kept).
    ///
    /// # Panics
    ///
    /// Panics if the name is already registered with a different `TypeId`.
    pub fn register_pass(&self, info: PassInfo) {
        use alloc::collections::btree_map::Entry;

        let mut passes = self.passes.write();
        match passes.entry(info.argument()) {
            Entry::Vacant(entry) => {
                entry.insert(info.0);
            }
            Entry::Occupied(entry) => {
                assert_eq!(
                    entry.get().type_id,
                    info.0.type_id,
                    "cannot register pass '{}': name already registered by a different type",
                    info.argument()
                );
            }
        }
    }

    /// Registers a pass pipeline under its argument name.
    ///
    /// Re-registering the same name with an identical entry is a no-op (the existing entry is
    /// kept).
    ///
    /// # Panics
    ///
    /// Panics if the name is already registered with a different `TypeId` or a builder function
    /// at a different address.
    pub fn register_pipeline(&self, info: PassPipelineInfo) {
        use alloc::collections::btree_map::Entry;

        let mut pipelines = self.pipelines.write();
        match pipelines.entry(info.argument()) {
            Entry::Vacant(entry) => {
                entry.insert(info.0);
            }
            Entry::Occupied(entry) => {
                // Every pipeline constructor in this module stores `type_id: None`, so this check
                // only matters if that invariant is ever broken; the builder comparison below is
                // what actually distinguishes pipelines.
                assert_eq!(
                    entry.get().type_id,
                    info.0.type_id,
                    "cannot register pass pipeline '{}': name already registered by a different \
                     type",
                    info.argument()
                );
                // Pipelines are identified by their builder; compare the function addresses.
                assert!(
                    core::ptr::addr_eq(
                        entry.get().builder as *const (),
                        info.0.builder as *const ()
                    ),
                    "cannot register pass pipeline '{}': name already registered with a \
                     different builder",
                    info.argument()
                );
            }
        }
    }
}
// Allow `PassInfo` / `PassPipelineInfo` values submitted via `inventory` to be enumerated with
// `inventory::iter`, which `PassRegistry::new` uses to seed the registry.
inventory::collect!(PassInfo);
inventory::collect!(PassPipelineInfo);

/// Signature of a function that allocates a fresh, boxed [`OperationPass`] instance.
pub type PassAllocatorFunction = fn() -> Box<dyn OperationPass>;

/// Signature of a pass-manager registration callback: it receives the target [`OpPassManager`],
/// an options string, and a [`DiagnosticsHandler`]. This matches the signature of
/// [`default_registration`] and of the closures returned by [`default_registration_factory`].
pub type PassRegistryFunction =
    fn(&mut OpPassManager, &str, &DiagnosticsHandler) -> Result<(), Report>;
/// Common interface of registered passes and pass pipelines.
pub trait RegistryEntry {
    /// The argument name under which this entry is registered and looked up.
    fn argument(&self) -> &'static str;

    /// A human-readable description of this entry.
    fn description(&self) -> &'static str;

    /// Adds this entry to `pm`, configured from the `options` string, using `diagnostics` to
    /// build any error that is returned.
    fn add_to_pipeline(
        &self,
        pm: &mut OpPassManager,
        options: &str,
        diagnostics: &DiagnosticsHandler,
    ) -> Result<(), Report>;
}
/// Shared payload of [`PassInfo`] and [`PassPipelineInfo`].
#[derive(Clone)]
struct PassRegistryEntry {
    /// Argument name; also the key used in the registry tables.
    arg: &'static str,
    /// Human-readable description.
    description: &'static str,
    /// Allocates a new instance when the entry is added to a pass manager.
    builder: PassAllocatorFunction,
    /// `Some(TypeId)` of the pass type for passes; `None` for every pipeline constructed in
    /// this module.
    type_id: Option<TypeId>,
}
impl RegistryEntry for PassRegistryEntry {
    #[inline]
    fn argument(&self) -> &'static str {
        self.arg
    }

    #[inline]
    fn description(&self) -> &'static str {
        self.description
    }

    /// Builds the registration closure for this entry's allocator (see
    /// [`default_registration_factory`]) and immediately applies it to `pm`.
    #[inline]
    fn add_to_pipeline(
        &self,
        pm: &mut OpPassManager,
        options: &str,
        diagnostics: &DiagnosticsHandler,
    ) -> Result<(), Report> {
        let register = default_registration_factory(self.builder);
        register(pm, options, diagnostics)
    }
}
/// Registration info for a named pass pipeline.
pub struct PassPipelineInfo(PassRegistryEntry);

impl PassPipelineInfo {
    /// Creates registration info for a pipeline named `arg`, described by `description`, whose
    /// instances are produced by `builder`. The stored type id is always `None`.
    ///
    /// This is a `const fn` (like [`PassInfo::new`]) so it can be used to build values passed to
    /// `inventory::submit!`, which requires const-constructible items.
    ///
    /// NOTE(review): the type parameter `B` is not used by the body. Because it cannot be
    /// inferred, callers must name it explicitly (e.g. `PassPipelineInfo::new::<()>(..)`). It is
    /// kept so existing call sites continue to compile.
    pub const fn new<B>(
        arg: &'static str,
        description: &'static str,
        builder: PassAllocatorFunction,
    ) -> Self {
        Self(PassRegistryEntry {
            arg,
            description,
            type_id: None,
            builder,
        })
    }

    /// Looks up the pipeline registered under `name` in the global registry.
    pub fn lookup(name: &str) -> Option<PassPipelineInfo> {
        PASS_REGISTRY.get_pipeline(name)
    }
}
impl RegistryEntry for PassPipelineInfo {
    fn argument(&self) -> &'static str {
        let Self(entry) = self;
        entry.argument()
    }

    fn description(&self) -> &'static str {
        let Self(entry) = self;
        entry.description()
    }

    /// Forwards to the wrapped entry's [`RegistryEntry::add_to_pipeline`].
    fn add_to_pipeline(
        &self,
        pm: &mut OpPassManager,
        options: &str,
        diagnostics: &DiagnosticsHandler,
    ) -> Result<(), Report> {
        let Self(entry) = self;
        entry.add_to_pipeline(pm, options, diagnostics)
    }
}
/// Registration info for a single pass.
pub struct PassInfo(PassRegistryEntry);

impl PassInfo {
    /// Creates registration info for pass type `P`, whose instances are built with
    /// `P::default()`.
    pub const fn new<P: Pass + Default>(arg: &'static str, description: &'static str) -> Self {
        Self::new_with_builder::<P>(arg, description, default_instance::<P>)
    }

    /// Creates registration info recording `P`'s `TypeId`, with instances produced by
    /// `builder`.
    ///
    /// NOTE(review): nothing here checks that `builder` actually produces a `P`; callers are
    /// presumably expected to guarantee that.
    pub const fn new_with_builder<P: Pass>(
        arg: &'static str,
        description: &'static str,
        builder: PassAllocatorFunction,
    ) -> Self {
        Self(PassRegistryEntry {
            arg,
            description,
            type_id: Some(TypeId::of::<P>()),
            builder,
        })
    }

    /// Looks up the pass registered under `name` in the global registry.
    pub fn lookup(name: &str) -> Option<PassInfo> {
        PASS_REGISTRY.get_pass(name)
    }
}
impl RegistryEntry for PassInfo {
    fn argument(&self) -> &'static str {
        let Self(entry) = self;
        entry.argument()
    }

    fn description(&self) -> &'static str {
        let Self(entry) = self;
        entry.description()
    }

    /// Forwards to the wrapped entry's [`RegistryEntry::add_to_pipeline`].
    fn add_to_pipeline(
        &self,
        pm: &mut OpPassManager,
        options: &str,
        diagnostics: &DiagnosticsHandler,
    ) -> Result<(), Report> {
        let Self(entry) = self;
        entry.add_to_pipeline(pm, options, diagnostics)
    }
}
/// Registers a pass pipeline named `arg` in the global registry at runtime.
///
/// # Panics
///
/// Panics (via [`PassRegistry::register_pipeline`]) if `arg` is already registered with a
/// different type id or a builder at a different address.
pub fn register_pass_pipeline(
    arg: &'static str,
    description: &'static str,
    builder: PassAllocatorFunction,
) {
    let entry = PassRegistryEntry {
        arg,
        description,
        type_id: None,
        builder,
    };
    PASS_REGISTRY.register_pipeline(PassPipelineInfo(entry));
}
/// Registers the pass produced by `builder` in the global registry at runtime.
///
/// `builder` is invoked once to obtain the pass's argument name, description and concrete
/// `TypeId`; that instance is then dropped.
///
/// # Panics
///
/// Panics if the pass reports an empty argument name, or (via [`PassRegistry::register_pass`])
/// if the name is already registered by a different pass type.
pub fn register_pass(builder: PassAllocatorFunction) {
    let instance = builder();
    let type_id = instance.as_any().type_id();
    let arg = instance.argument();
    if arg.is_empty() {
        panic!(
            "attempted to register pass '{}' without specifying an argument name",
            instance.name()
        );
    }
    let description = instance.description();

    let entry = PassRegistryEntry {
        arg,
        description,
        type_id: Some(type_id),
        builder,
    };
    PASS_REGISTRY.register_pass(PassInfo(entry));
}
pub fn default_registration<P: Pass + Default>(
pm: &mut OpPassManager,
options: &str,
diagnostics: &DiagnosticsHandler,
) -> Result<(), Report> {
use midenc_session::diagnostics::Severity;
let mut pass = Box::<P>::default() as Box<dyn OperationPass>;
let result = pass.initialize_options(options);
let pm_op_name = pm.name();
let pass_op_name = pass.target_name(&pm.context());
let pass_op_name = pass_op_name.as_ref();
if matches!(pm.nesting(), Nesting::Explicit)
&& (pass_op_name.is_some_and(|p| pm_op_name.is_none_or(|p2| p != p2))
|| (pm_op_name.is_some_and(|p| pass_op_name.is_some_and(|p2| p != p2))))
{
return Err(diagnostics
.diagnostic(Severity::Error)
.with_message(format!(
"registration error for pass '{}': can't add pass restricted to '{}' on a pass \
manager intended to run on '{}', did you intend to nest?",
pass.name(),
crate::formatter::DisplayOptional(pass_op_name.as_ref()),
crate::formatter::DisplayOptional(pm_op_name),
))
.into_report());
}
pm.add_pass(pass);
result
}
/// Returns a registration closure that allocates a pass with `builder` and adds it to a pass
/// manager.
///
/// When invoked, the closure:
/// 1. allocates a pass and initializes its options from the options string;
/// 2. if the pass manager uses explicit nesting and the pass targets a specific operation that
///    differs from the manager's operation (or the manager has no operation name), returns an
///    error diagnostic without adding the pass;
/// 3. otherwise adds the pass and returns the result of step 1. An option-parsing error is
///    therefore reported only after the pass has already been added.
pub fn default_registration_factory(
    builder: PassAllocatorFunction,
) -> impl Fn(&mut OpPassManager, &str, &DiagnosticsHandler) -> Result<(), Report> + Send + Sync + 'static
{
    use midenc_session::diagnostics::Severity;

    move |pm: &mut OpPassManager,
          options: &str,
          diagnostics: &DiagnosticsHandler|
          -> Result<(), Report> {
        let mut pass = builder();
        // Held until the end: the pass is still added even if option parsing failed.
        let result = pass.initialize_options(options);

        let pm_op_name = pm.name();
        let pass_op_name = pass.target_name(&pm.context());
        let pass_op_name = pass_op_name.as_ref();

        // An op-specific pass is misplaced under explicit nesting unless the manager is anchored
        // on exactly that operation. (This single clause also covers "both named, names differ",
        // which the previous formulation checked a second, redundant time.)
        let misplaced = matches!(pm.nesting(), Nesting::Explicit)
            && pass_op_name.is_some_and(|p| pm_op_name.is_none_or(|p2| p != p2));

        if misplaced {
            return Err(diagnostics
                .diagnostic(Severity::Error)
                .with_message(format!(
                    "registration error for pass '{}': can't add pass restricted to '{}' on a \
                     pass manager intended to run on '{}', did you intend to nest?",
                    pass.name(),
                    crate::formatter::DisplayOptional(pass_op_name.as_ref()),
                    crate::formatter::DisplayOptional(pm_op_name),
                ))
                .into_report());
        }

        pm.add_pass(pass);
        result
    }
}
/// Allocates a default-constructed `P` and returns it as a boxed [`OperationPass`].
fn default_instance<P: Pass + Default>() -> Box<dyn OperationPass> {
    let instance: Box<P> = Box::default();
    instance
}