use crate::ir_inner::model::program::Program;
use lasso::ThreadedRodeo;
use std::sync::{Arc, OnceLock};
use vyre_spec::{AlgebraicLaw, CpuFn};
/// Numeric identifier for an interned op name.
///
/// Wraps the raw `u32` value of the key that `intern_string` obtains from the
/// shared interner.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct InternedOpId(pub u32);
/// Returns the shared string interner, constructing it on first use.
///
/// The interner is stored in a function-local `OnceLock`, so every caller
/// receives a reference to the same `ThreadedRodeo` instance.
fn get_interner() -> &'static ThreadedRodeo {
    static SHARED_RODEO: OnceLock<ThreadedRodeo> = OnceLock::new();
    SHARED_RODEO.get_or_init(ThreadedRodeo::new)
}
/// Interns `s` in the shared interner and returns its numeric id.
///
/// The returned [`InternedOpId`] carries the raw `u32` value of the interner
/// key; interning the same string again goes through `get_or_intern` and so
/// resolves to the same key.
#[must_use]
pub fn intern_string(s: &str) -> InternedOpId {
    let raw_key = get_interner().get_or_intern(s).into_inner();
    InternedOpId(raw_key.get())
}
/// Alias for `vyre_spec::CpuFn`; this is the type of the `cpu_ref` slot in
/// [`LoweringTable`].
pub type ReferenceKind = CpuFn;
/// Context value passed by reference to every lowering builder function type
/// declared in this file.
///
/// It currently carries no data: its only field is a `PhantomData` marker that
/// keeps the `'a` lifetime parameter in use.
#[derive(Default, Debug, Clone)]
pub struct LoweringCtx<'a> {
/// Zero-sized marker that ties the struct to the `'a` lifetime.
pub unused: std::marker::PhantomData<&'a ()>,
}
/// Textual module value returned by a [`SecondaryTextBuilder`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TextModule {
/// Module text. NOTE(review): presumably assembly source, going by the field
/// name — confirm with the backend that produces it.
pub asm: String,
/// Version number accompanying the text. NOTE(review): the encoding is not
/// defined in this file (the unit test pairs `70` with `.version 7.0`).
pub version: u32,
}
/// Binary module value returned by a [`NativeModuleBuilder`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NativeModule {
/// Raw module bytes. NOTE(review): presumably a serialized AST, going by the
/// field name — the byte format is not visible here.
pub ast: Vec<u8>,
/// Entry name stored alongside the bytes (the unit test uses `"main"`).
pub entry: String,
}
/// Function type for the `primary_text` slot of [`LoweringTable`].
///
/// Returns `Ok(())` or an error message; note that no text payload is returned
/// through this signature.
pub type PrimaryTextBuilder = fn(&LoweringCtx<'_>) -> Result<(), String>;
/// Function type for the `primary_binary` slot: produces a vector of `u32` words.
pub type PrimaryBinaryBuilder = fn(&LoweringCtx<'_>) -> Vec<u32>;
/// Function type for the `secondary_text` slot: produces a [`TextModule`].
pub type SecondaryTextBuilder = fn(&LoweringCtx<'_>) -> TextModule;
/// Function type for the `native_module` slot: produces a [`NativeModule`].
pub type NativeModuleBuilder = fn(&LoweringCtx<'_>) -> NativeModule;
/// Function type for builders registered per backend id in
/// `LoweringTable::extensions`: yields raw bytes or an error message.
pub type ExtensionLoweringFn =
fn(&LoweringCtx<'_>) -> Result<std::vec::Vec<u8>, std::string::String>;
/// Set of lowering entry points for one op.
///
/// `cpu_ref` is always present; the four named builder slots are optional, and
/// further builders can be registered under a backend id in `extensions`.
#[derive(Clone)]
pub struct LoweringTable {
/// CPU reference function (mandatory).
pub cpu_ref: ReferenceKind,
/// Optional builder of type [`PrimaryTextBuilder`].
pub primary_text: Option<PrimaryTextBuilder>,
/// Optional builder of type [`PrimaryBinaryBuilder`].
pub primary_binary: Option<PrimaryBinaryBuilder>,
/// Optional builder of type [`SecondaryTextBuilder`].
pub secondary_text: Option<SecondaryTextBuilder>,
/// Optional builder of type [`NativeModuleBuilder`].
pub native_module: Option<NativeModuleBuilder>,
/// Extension builders keyed by a static backend id string.
pub extensions: rustc_hash::FxHashMap<&'static str, ExtensionLoweringFn>,
}
impl Default for LoweringTable {
fn default() -> Self {
Self::empty()
}
}
impl LoweringTable {
    /// Creates a table whose only populated slot is the CPU reference `cpu_ref`.
    ///
    /// All four optional builders are `None` and the extension map is empty.
    #[must_use]
    pub fn new(cpu_ref: ReferenceKind) -> Self {
        let extensions = rustc_hash::FxHashMap::default();
        Self {
            cpu_ref,
            primary_text: None,
            primary_binary: None,
            secondary_text: None,
            native_module: None,
            extensions,
        }
    }

    /// Creates a table with no optional builders and no extensions, using
    /// `crate::cpu_op::structured_intrinsic_cpu` as the CPU reference.
    #[must_use]
    pub fn empty() -> Self {
        Self::new(crate::cpu_op::structured_intrinsic_cpu)
    }

    /// Registers `builder` under `backend_id` and returns the updated table.
    ///
    /// A builder previously registered under the same id is replaced.
    #[must_use]
    pub fn with_extension(
        mut self,
        backend_id: &'static str,
        builder: ExtensionLoweringFn,
    ) -> Self {
        // The displaced builder (if any) is intentionally discarded.
        let _replaced = self.extensions.insert(backend_id, builder);
        self
    }

    /// Returns the extension builder registered under `backend_id`, or `None`
    /// when no builder is registered for that id.
    #[must_use]
    pub fn extension(&self, backend_id: &str) -> Option<ExtensionLoweringFn> {
        let builder = self.extensions.get(backend_id)?;
        Some(*builder)
    }
}
/// Manual `Debug` implementation: function pointers are rendered as the
/// placeholder `"<fn>"` rather than as addresses, and extensions are shown by
/// backend id only.
impl std::fmt::Debug for LoweringTable {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Hash-map iteration order is arbitrary, so sort the ids to make the
        // rendered output stable for tables with equal contents.
        let mut extension_ids = self
            .extensions
            .keys()
            .copied()
            .collect::<std::vec::Vec<_>>();
        extension_ids.sort_unstable();

        f.debug_struct("LoweringTable")
            .field("cpu_ref", &"<fn>")
            .field("primary_text", &self.primary_text.map(|_| "<fn>"))
            .field("primary_binary", &self.primary_binary.map(|_| "<fn>"))
            .field("secondary_text", &self.secondary_text.map(|_| "<fn>"))
            .field("native_module", &self.native_module.map(|_| "<fn>"))
            .field("extensions", &extension_ids)
            .finish()
    }
}
/// Type tag for an op attribute, used as the `ty` field of [`AttrSchema`].
///
/// Marked `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum AttrType {
/// `u32` attribute.
U32,
/// `i32` attribute.
I32,
/// `f32` attribute.
F32,
/// Boolean attribute.
Bool,
/// Byte-string attribute.
Bytes,
/// Text attribute.
String,
/// Enumerated attribute carrying a static list of strings.
/// NOTE(review): presumably the permitted variant names — confirm with the
/// code that validates attributes.
Enum(&'static [&'static str]),
/// Attribute whose type is not known.
Unknown,
}
/// Schema entry describing one attribute in a [`Signature`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttrSchema {
/// Attribute name.
pub name: &'static str,
/// Attribute type tag.
pub ty: AttrType,
/// Optional default, stored as a static string. How the string is parsed
/// into a value of type `ty` is not visible in this file.
pub default: Option<&'static str>,
}
/// Named parameter used for the inputs and outputs of a [`Signature`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypedParam {
/// Parameter name.
pub name: &'static str,
/// Parameter type, stored as a string. NOTE(review): the accepted type
/// vocabulary is defined outside this file.
pub ty: &'static str,
}
/// Static description of an op's interface: inputs, outputs and attributes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Signature {
/// Input parameters.
pub inputs: &'static [TypedParam],
/// Output parameters.
pub outputs: &'static [TypedParam],
/// Attribute schemas.
pub attrs: &'static [AttrSchema],
/// Flag set to `true` by `Signature::bytes_extractor`. Its meaning for
/// consumers is defined outside this file.
pub bytes_extraction: bool,
}
impl Signature {
#[must_use]
pub const fn bytes_extractor(
inputs: &'static [TypedParam],
outputs: &'static [TypedParam],
attrs: &'static [AttrSchema],
) -> Self {
Self {
inputs,
outputs,
attrs,
bytes_extraction: true,
}
}
}
/// Classification stored in [`OpDef::category`].
///
/// NOTE(review): the precise meaning of each category is not defined in this
/// file; the variant descriptions below are inferred from the names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Category {
/// Presumably an op defined by composing other ops (see `OpDef::compose`).
Composite,
/// Presumably an op contributed by an extension.
Extension,
/// Category used by `OpDef::default()`.
Intrinsic,
}
/// Definition of a single op: identity, interface, lowerings and laws.
#[derive(Debug, Clone)]
pub struct OpDef {
/// Op identifier; also passed to `with_entry_op_id` by `OpDef::program`.
pub id: &'static str,
/// Name of the dialect the op is declared under.
pub dialect: &'static str,
/// Op category.
pub category: Category,
/// Inputs, outputs and attribute schemas of the op.
pub signature: Signature,
/// Lowering entry points for the op.
pub lowerings: LoweringTable,
/// Algebraic laws declared for the op (`vyre_spec::AlgebraicLaw`).
pub laws: &'static [AlgebraicLaw],
/// Optional function that builds the op's `Program`; `None` makes
/// `OpDef::program` return `None`.
pub compose: Option<fn() -> Program>,
}
impl OpDef {
    /// Returns the op identifier.
    #[must_use]
    pub const fn id(&self) -> &'static str {
        self.id
    }

    /// Builds the op's program, if a `compose` function is set.
    ///
    /// The composed program is tagged with this op's id through
    /// `with_entry_op_id`. Returns `None` when `compose` is `None`.
    #[must_use]
    pub fn program(&self) -> Option<Program> {
        let build = self.compose?;
        Some(build().with_entry_op_id(self.id))
    }
}
/// Default op definition: empty id and dialect, `Category::Intrinsic`, an
/// empty signature, an empty lowering table, no laws and no `compose` function.
impl Default for OpDef {
    fn default() -> Self {
        let signature = Signature {
            inputs: &[],
            outputs: &[],
            attrs: &[],
            bytes_extraction: false,
        };
        let lowerings = LoweringTable::empty();
        Self {
            id: "",
            dialect: "",
            category: Category::Intrinsic,
            signature,
            lowerings,
            laws: &[],
            compose: None,
        }
    }
}
/// Support module for the `DialectLookup` supertrait bound; hidden from
/// generated documentation.
#[doc(hidden)]
pub mod private {
/// Marker trait that every `DialectLookup` implementor must also implement.
/// NOTE(review): the module is `pub`, so this is hidden from docs rather than
/// unreachable from other crates — confirm whether that is intended.
pub trait Sealed {}
}
/// Provider interface for resolving ops, installed process-wide through
/// `install_dialect_lookup`.
///
/// Implementors must also implement `private::Sealed` and be `Send + Sync`.
pub trait DialectLookup: private::Sealed + Send + Sync {
/// Identifier of this provider. `install_dialect_lookup` compares it against
/// the already-installed provider's id to detect conflicting installs.
fn provider_id(&self) -> &'static str;
/// Maps an op name to an [`InternedOpId`].
fn intern_op(&self, name: &str) -> InternedOpId;
/// Returns the definition registered for `id`, or `None` if there is none.
fn lookup(&self, id: InternedOpId) -> Option<&'static OpDef>;
}
/// Process-wide slot for the installed `DialectLookup` provider. It is written
/// at most once (by `install_dialect_lookup`) and read by `dialect_lookup`.
static DIALECT_LOOKUP: OnceLock<Arc<dyn DialectLookup>> = OnceLock::new();
/// Installs `lookup` as the process-wide dialect lookup provider.
///
/// The first successful call stores the provider. Later calls never replace
/// it: they succeed only when the incoming provider reports the same
/// `provider_id` as the installed one.
///
/// # Errors
///
/// Returns an error message when a provider with a different id is already
/// installed, or when the slot reports a failed `set` yet holds no value.
pub fn install_dialect_lookup(lookup: Arc<dyn DialectLookup>) -> Result<(), String> {
    // A provider is already installed: only the ids need to be compared.
    if let Some(installed) = DIALECT_LOOKUP.get() {
        return ensure_same_provider(installed.provider_id(), lookup.provider_id());
    }

    // `set` returns the rejected value if the slot was filled in the meantime.
    let Err(rejected) = DIALECT_LOOKUP.set(lookup) else {
        return Ok(());
    };

    match DIALECT_LOOKUP.get() {
        Some(winner) => ensure_same_provider(winner.provider_id(), rejected.provider_id()),
        None => Err(
            "dialect lookup install lost the value after OnceLock::set failed. Fix: report this impossible OnceLock state."
                .to_string(),
        ),
    }
}
/// Checks that a second installer reports the same provider id as the first.
///
/// # Errors
///
/// Returns a message naming both ids when `existing_id` and `incoming_id`
/// differ.
fn ensure_same_provider(existing_id: &str, incoming_id: &str) -> Result<(), String> {
    if existing_id != incoming_id {
        return Err(format!(
            "dialect lookup already installed by provider `{existing_id}`; second installer `{incoming_id}` reports a different id. Fix: pick one provider for the process or reuse the first provider's id. Silent replacement is refused because two divergent lookups would mis-resolve op ids at runtime."
        ));
    }
    Ok(())
}
/// Returns the installed dialect lookup provider, or `None` when
/// `install_dialect_lookup` has not stored one yet.
#[must_use]
pub fn dialect_lookup() -> Option<&'static dyn DialectLookup> {
    let installed = DIALECT_LOOKUP.get()?;
    Some(Arc::as_ref(installed))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Interning the same name twice yields the same id.
    #[test]
    fn intern_string_is_deterministic() {
        let first = intern_string("test::op::add");
        let second = intern_string("test::op::add");
        assert_eq!(first, second);
    }

    /// Two different names never share an id.
    #[test]
    fn intern_string_distinct_for_different_ops() {
        let add_id = intern_string("test::op::add");
        let mul_id = intern_string("test::op::mul");
        assert_ne!(add_id, mul_id);
    }

    /// `LoweringTable::empty` leaves every optional slot unset.
    #[test]
    fn lowering_table_empty_has_no_native_builders() {
        let LoweringTable {
            primary_text,
            primary_binary,
            secondary_text,
            native_module,
            extensions,
            ..
        } = LoweringTable::empty();
        assert!(primary_text.is_none());
        assert!(primary_binary.is_none());
        assert!(secondary_text.is_none());
        assert!(native_module.is_none());
        assert!(extensions.is_empty());
    }

    /// A registered extension is found by id; an unregistered id is not.
    #[test]
    fn lowering_table_extension_lookup() {
        fn dummy_builder(_: &LoweringCtx<'_>) -> Result<Vec<u8>, String> {
            Ok(vec![1, 2, 3])
        }
        let table = LoweringTable::empty().with_extension("my-extension", dummy_builder);
        let registered = table.extension("my-extension");
        let missing = table.extension("nonexistent");
        assert!(registered.is_some());
        assert!(missing.is_none());
    }

    /// The default op definition has an empty id and no composed program.
    #[test]
    fn opdef_default_has_empty_id() {
        let def = OpDef::default();
        assert_eq!(def.id(), "");
        assert!(def.program().is_none());
    }

    /// `Signature::bytes_extractor` turns the `bytes_extraction` flag on.
    #[test]
    fn signature_bytes_extractor_sets_flag() {
        assert!(Signature::bytes_extractor(&[], &[], &[]).bytes_extraction);
    }

    /// Two text modules with equal fields compare equal.
    #[test]
    fn secondary_text_module_equality() {
        let make = || TextModule {
            asm: ".version 7.0".into(),
            version: 70,
        };
        assert_eq!(make(), make());
    }

    /// Two native modules with equal fields compare equal.
    #[test]
    fn native_module_module_equality() {
        let make = || NativeModule {
            ast: vec![1, 2, 3],
            entry: "main".into(),
        };
        assert_eq!(make(), make());
    }

    /// `Debug` for `Category` prints the bare variant name.
    #[test]
    fn category_debug() {
        let expectations = [
            (Category::Composite, "Composite"),
            (Category::Extension, "Extension"),
            (Category::Intrinsic, "Intrinsic"),
        ];
        for (variant, expected) in expectations {
            assert_eq!(format!("{:?}", variant), expected);
        }
    }
}