use std::fmt::Debug;
use std::hash::Hash;
use std::sync::LazyLock;
use rustc_hash::FxHashMap;
use vyre_spec::extension::{
ExtensionAtomicOp, ExtensionAtomicOpId, ExtensionBinOp, ExtensionBinOpId, ExtensionDataType,
ExtensionDataTypeId, ExtensionRuleConditionId, ExtensionUnOp, ExtensionUnOpId,
};
/// Compact 32-bit identifier for a generic extension.
///
/// Identifiers built with [`ExtensionId::from_name`] always have bit 31 set.
/// NOTE(review): presumably this keeps extension IDs disjoint from built-in IDs
/// that leave the high bit clear — confirm against the ID allocation scheme.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ExtensionId(pub u32);

impl ExtensionId {
    /// Derives a deterministic identifier from an extension's name.
    ///
    /// The name is hashed with BLAKE3. The first four digest bytes are read as a
    /// little-endian `u32`, and the most significant bit is then forced on.
    /// The same `name` always yields the same ID.
    #[must_use]
    pub fn from_name(name: &str) -> Self {
        const EXTENSION_BIT: u32 = 0x8000_0000;
        let digest = blake3::hash(name.as_bytes());
        let mut prefix = [0u8; 4];
        prefix.copy_from_slice(&digest.as_bytes()[..4]);
        Self(EXTENSION_BIT | u32::from_le_bytes(prefix))
    }
}
/// Behaviour required of an expression-level extension node.
///
/// Implementors must be `Debug + Send + Sync + 'static` so they can be shared
/// across threads and stored behind trait objects.
pub trait ExprExtensionNode: Debug + Send + Sync + 'static {
    /// Identifier of the extension that owns this node.
    fn extension_id(&self) -> ExtensionId;
    /// Human-readable rendering of the node.
    fn display(&self) -> String;
    /// Serialises the node to bytes.
    /// NOTE(review): presumably the inverse of the owning extension's decoder — confirm.
    fn encode(&self) -> Vec<u8>;
}
/// Behaviour required of a statement/node-level extension.
///
/// Implementors must be `Debug + Send + Sync + 'static` so they can be shared
/// across threads and stored behind trait objects.
pub trait NodeNode: Debug + Send + Sync + 'static {
    /// Identifier of the extension that owns this node.
    fn extension_id(&self) -> ExtensionId;
    /// Human-readable rendering of the node.
    fn display(&self) -> String;
    /// Serialises the node to bytes.
    /// NOTE(review): presumably the inverse of the owning extension's decoder — confirm.
    fn encode(&self) -> Vec<u8>;
}
/// Behaviour required of an extension-provided rule condition.
pub trait RuleConditionExt: Debug + Send + Sync + 'static {
    /// Identifier of the extension that owns this condition.
    fn extension_id(&self) -> ExtensionRuleConditionId;
    /// Evaluates the condition against a type-erased context.
    ///
    /// NOTE(review): the concrete type behind `ctx` is decided by the caller;
    /// implementors presumably downcast it — confirm the expected type.
    fn evaluate_opaque(&self, ctx: &dyn std::any::Any) -> bool;
    /// 32-byte fingerprint of the condition.
    /// NOTE(review): "stable" presumably means identical across runs/builds — confirm.
    fn stable_fingerprint(&self) -> [u8; 32];
    /// Buffers the condition needs declared. Defaults to none.
    fn required_buffers(&self) -> Vec<crate::ir::BufferDecl> {
        Vec::default()
    }
}
/// `inventory` submission binding an [`ExtensionDataTypeId`] to its vtable.
pub struct ExtensionDataTypeRegistration {
    /// Key under which the data type is resolved.
    pub id: ExtensionDataTypeId,
    /// Implementation returned by `resolve_data_type`.
    pub vtable: &'static dyn ExtensionDataType,
}

/// `inventory` submission binding an [`ExtensionBinOpId`] to its vtable.
pub struct ExtensionBinOpRegistration {
    /// Key under which the binary operator is resolved.
    pub id: ExtensionBinOpId,
    /// Implementation returned by `resolve_bin_op`.
    pub vtable: &'static dyn ExtensionBinOp,
}

/// `inventory` submission binding an [`ExtensionUnOpId`] to its vtable.
pub struct ExtensionUnOpRegistration {
    /// Key under which the unary operator is resolved.
    pub id: ExtensionUnOpId,
    /// Implementation returned by `resolve_un_op`.
    pub vtable: &'static dyn ExtensionUnOp,
}

/// `inventory` submission binding an [`ExtensionAtomicOpId`] to its vtable.
pub struct ExtensionAtomicOpRegistration {
    /// Key under which the atomic operator is resolved.
    pub id: ExtensionAtomicOpId,
    /// Implementation returned by `resolve_atomic_op`.
    pub vtable: &'static dyn ExtensionAtomicOp,
}
/// Generic `inventory` submission describing one extension.
pub struct ExtensionRegistration {
    /// Unique key; registering two entries with the same `id` is rejected.
    pub id: ExtensionId,
    /// Name reported as the registrant in duplicate-registration diagnostics.
    pub name: &'static str,
    /// Category of the extension.
    pub kind: ExtensionKind,
    /// Payload decoder hook.
    /// NOTE(review): returns `()` on success, so it presumably only validates the
    /// payload — confirm with callers.
    pub decode: fn(&[u8]) -> Result<(), String>,
}
/// Category carried by an `ExtensionRegistration`.
///
/// Marked `#[non_exhaustive]`: matches outside this crate need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExtensionKind {
    /// Expression extension (presumably pairs with `ExprExtensionNode` — confirm).
    Expr,
    /// Node extension (presumably pairs with `NodeNode` — confirm).
    Node,
    /// Data-type extension.
    DataType,
    /// Rule-condition extension (presumably pairs with `RuleConditionExt` — confirm).
    RuleCondition,
}
// Declare each registration type as `inventory`-collectable. Linked crates add entries
// with `inventory::submit!`; the frozen registries read them with `inventory::iter`.
inventory::collect!(ExtensionRegistration);
inventory::collect!(ExtensionDataTypeRegistration);
inventory::collect!(ExtensionBinOpRegistration);
inventory::collect!(ExtensionUnOpRegistration);
inventory::collect!(ExtensionAtomicOpRegistration);
/// Decodes an opaque expression payload into a shared expression node.
/// Errors are reported as human-readable strings.
pub type ExprExtensionDeserializer =
fn(&[u8]) -> Result<std::sync::Arc<dyn crate::ir::ExprNode>, String>;
/// Decodes an opaque node payload into a shared node extension.
/// Errors are reported as human-readable strings.
pub type NodeExtensionDeserializer =
fn(&[u8]) -> Result<std::sync::Arc<dyn crate::ir::NodeExtension>, String>;
/// `inventory` submission mapping an opaque expression `kind` tag to its decoder.
pub struct OpaqueExprResolver {
    /// Tag looked up by `decode_opaque_expr`. It must be unique across all submissions.
    pub kind: &'static str,
    /// Decoder invoked with the payload bytes.
    pub deserialize: ExprExtensionDeserializer,
}

/// `inventory` submission mapping an opaque node `kind` tag to its decoder.
pub struct OpaqueNodeResolver {
    /// Tag looked up by `decode_opaque_node`. It must be unique across all submissions.
    pub kind: &'static str,
    /// Decoder invoked with the payload bytes.
    pub deserialize: NodeExtensionDeserializer,
}
// Make the opaque resolvers collectable so owning crates can `inventory::submit!` them.
inventory::collect!(OpaqueExprResolver);
inventory::collect!(OpaqueNodeResolver);
/// Builds a lookup map from `(key, value, owner)` triples and rejects duplicate keys.
///
/// `owner` is used only for diagnostics. On the first repeated key the function
/// returns an error naming `registry_name`, the key, and both owners. No partial
/// map is returned.
fn collect_unique_by<K, V, I>(
    registrations: I,
    registry_name: &str,
) -> Result<FxHashMap<K, V>, String>
where
    K: Eq + Hash + Copy + std::fmt::Debug,
    I: IntoIterator<Item = (K, V, &'static str)>,
{
    use std::collections::hash_map::Entry;

    let mut first_owner_of: FxHashMap<K, &'static str> = FxHashMap::default();
    let mut resolved: FxHashMap<K, V> = FxHashMap::default();
    for (key, value, owner) in registrations {
        match first_owner_of.entry(key) {
            Entry::Occupied(slot) => {
                let previous_owner = *slot.get();
                return Err(format!(
                    "{registry_name} duplicate registration for {key:?}: first registrant `{previous_owner}`, second registrant `{owner}`. Fix: pick one stable tag/kind owner."
                ));
            }
            Entry::Vacant(slot) => {
                slot.insert(owner);
            }
        }
        resolved.insert(key, value);
    }
    Ok(resolved)
}
/// Process-wide map from opaque expression `kind` tags to their decoders.
///
/// The map is built from `inventory` submissions on first use and then cached for
/// the life of the process. A construction error is cached too, so every later
/// call returns a clone of the same message.
fn frozen_opaque_expr_registry(
) -> Result<&'static FxHashMap<&'static str, ExprExtensionDeserializer>, String> {
    type Frozen = Result<FxHashMap<&'static str, ExprExtensionDeserializer>, String>;
    static FROZEN: LazyLock<Frozen> = LazyLock::new(|| {
        // `OpaqueExprResolver` has no separate owner field, so the kind tag doubles as
        // the registrant name in duplicate diagnostics.
        let entries = inventory::iter::<OpaqueExprResolver>
            .into_iter()
            .map(|resolver| (resolver.kind, resolver.deserialize, resolver.kind));
        collect_unique_by(entries, "OpaqueExprResolver")
    });
    match &*FROZEN {
        Ok(registry) => Ok(registry),
        Err(message) => Err(message.clone()),
    }
}
/// Process-wide map from opaque node `kind` tags to their decoders.
///
/// The map is built from `inventory` submissions on first use and then cached for
/// the life of the process. A construction error is cached too, so every later
/// call returns a clone of the same message.
fn frozen_opaque_node_registry(
) -> Result<&'static FxHashMap<&'static str, NodeExtensionDeserializer>, String> {
    type Frozen = Result<FxHashMap<&'static str, NodeExtensionDeserializer>, String>;
    static FROZEN: LazyLock<Frozen> = LazyLock::new(|| {
        // `OpaqueNodeResolver` has no separate owner field, so the kind tag doubles as
        // the registrant name in duplicate diagnostics.
        let entries = inventory::iter::<OpaqueNodeResolver>
            .into_iter()
            .map(|resolver| (resolver.kind, resolver.deserialize, resolver.kind));
        collect_unique_by(entries, "OpaqueNodeResolver")
    });
    match &*FROZEN {
        Ok(registry) => Ok(registry),
        Err(message) => Err(message.clone()),
    }
}
/// Decodes an opaque expression payload with the resolver registered for `kind`.
///
/// # Errors
/// Returns an error if the resolver registry failed to build, if no resolver is
/// registered for `kind`, or if that resolver's decoder rejects `payload`.
pub fn decode_opaque_expr(kind: &str, payload: &[u8]) -> Result<crate::ir::Expr, String> {
    let resolvers = frozen_opaque_expr_registry()?;
    let Some(deserialize) = resolvers.get(kind) else {
        return Err(format!(
            "Fix: no OpaqueExprResolver registered for extension kind `{kind}`. Link the crate that owns this extension and ensure it submits `inventory::submit! {{ OpaqueExprResolver {{ kind, deserialize }} }}`."
        ));
    };
    let node = deserialize(payload)?;
    Ok(crate::ir::Expr::Opaque(node))
}
/// Decodes an opaque node payload with the resolver registered for `kind`.
///
/// # Errors
/// Returns an error if the resolver registry failed to build, if no resolver is
/// registered for `kind`, or if that resolver's decoder rejects `payload`.
pub fn decode_opaque_node(kind: &str, payload: &[u8]) -> Result<crate::ir::Node, String> {
    let resolvers = frozen_opaque_node_registry()?;
    let Some(deserialize) = resolvers.get(kind) else {
        return Err(format!(
            "Fix: no OpaqueNodeResolver registered for extension kind `{kind}`. Link the crate that owns this extension and ensure it submits `inventory::submit! {{ OpaqueNodeResolver {{ kind, deserialize }} }}`."
        ));
    };
    let extension = deserialize(payload)?;
    Ok(crate::ir::Node::Opaque(extension))
}
/// Process-wide map from [`ExtensionId`] to its generic registration.
///
/// The map is built lazily from `inventory` submissions and cached, including any
/// duplicate-ID error. Each registration's `name` is the owner in diagnostics.
fn frozen_generic_registry(
) -> Result<&'static FxHashMap<ExtensionId, &'static ExtensionRegistration>, String> {
    type Frozen = Result<FxHashMap<ExtensionId, &'static ExtensionRegistration>, String>;
    static FROZEN: LazyLock<Frozen> = LazyLock::new(|| {
        let entries = inventory::iter::<ExtensionRegistration>
            .into_iter()
            .map(|registration| (registration.id, registration, registration.name));
        collect_unique_by(entries, "ExtensionRegistration")
    });
    match &*FROZEN {
        Ok(registry) => Ok(registry),
        Err(message) => Err(message.clone()),
    }
}
/// Process-wide map from [`ExtensionDataTypeId`] to its vtable.
///
/// The map is built lazily from `inventory` submissions and cached, including any
/// duplicate-ID error. The vtable's `display_name()` is the owner in diagnostics.
fn frozen_data_type_registry(
) -> Result<&'static FxHashMap<ExtensionDataTypeId, &'static dyn ExtensionDataType>, String> {
    type Frozen = Result<FxHashMap<ExtensionDataTypeId, &'static dyn ExtensionDataType>, String>;
    static FROZEN: LazyLock<Frozen> = LazyLock::new(|| {
        let entries = inventory::iter::<ExtensionDataTypeRegistration>
            .into_iter()
            .map(|registration| {
                let vtable = registration.vtable;
                (registration.id, vtable, vtable.display_name())
            });
        collect_unique_by(entries, "ExtensionDataTypeRegistration")
    });
    match &*FROZEN {
        Ok(registry) => Ok(registry),
        Err(message) => Err(message.clone()),
    }
}
/// Process-wide map from [`ExtensionBinOpId`] to its vtable.
///
/// The map is built lazily from `inventory` submissions and cached, including any
/// duplicate-ID error. The vtable's `display_name()` is the owner in diagnostics.
fn frozen_bin_op_registry(
) -> Result<&'static FxHashMap<ExtensionBinOpId, &'static dyn ExtensionBinOp>, String> {
    type Frozen = Result<FxHashMap<ExtensionBinOpId, &'static dyn ExtensionBinOp>, String>;
    static FROZEN: LazyLock<Frozen> = LazyLock::new(|| {
        let entries = inventory::iter::<ExtensionBinOpRegistration>
            .into_iter()
            .map(|registration| {
                let vtable = registration.vtable;
                (registration.id, vtable, vtable.display_name())
            });
        collect_unique_by(entries, "ExtensionBinOpRegistration")
    });
    match &*FROZEN {
        Ok(registry) => Ok(registry),
        Err(message) => Err(message.clone()),
    }
}
/// Process-wide map from [`ExtensionUnOpId`] to its vtable.
///
/// The map is built lazily from `inventory` submissions and cached, including any
/// duplicate-ID error. The vtable's `display_name()` is the owner in diagnostics.
fn frozen_un_op_registry(
) -> Result<&'static FxHashMap<ExtensionUnOpId, &'static dyn ExtensionUnOp>, String> {
    type Frozen = Result<FxHashMap<ExtensionUnOpId, &'static dyn ExtensionUnOp>, String>;
    static FROZEN: LazyLock<Frozen> = LazyLock::new(|| {
        let entries = inventory::iter::<ExtensionUnOpRegistration>
            .into_iter()
            .map(|registration| {
                let vtable = registration.vtable;
                (registration.id, vtable, vtable.display_name())
            });
        collect_unique_by(entries, "ExtensionUnOpRegistration")
    });
    match &*FROZEN {
        Ok(registry) => Ok(registry),
        Err(message) => Err(message.clone()),
    }
}
/// Process-wide map from [`ExtensionAtomicOpId`] to its vtable.
///
/// The map is built lazily from `inventory` submissions and cached, including any
/// duplicate-ID error. The vtable's `display_name()` is the owner in diagnostics.
fn frozen_atomic_op_registry(
) -> Result<&'static FxHashMap<ExtensionAtomicOpId, &'static dyn ExtensionAtomicOp>, String> {
    type Frozen =
        Result<FxHashMap<ExtensionAtomicOpId, &'static dyn ExtensionAtomicOp>, String>;
    static FROZEN: LazyLock<Frozen> = LazyLock::new(|| {
        let entries = inventory::iter::<ExtensionAtomicOpRegistration>
            .into_iter()
            .map(|registration| {
                let vtable = registration.vtable;
                (registration.id, vtable, vtable.display_name())
            });
        collect_unique_by(entries, "ExtensionAtomicOpRegistration")
    });
    match &*FROZEN {
        Ok(registry) => Ok(registry),
        Err(message) => Err(message.clone()),
    }
}
/// Looks up the data-type vtable registered for `id`.
///
/// Best-effort: returns `None` both when nothing is registered and when the registry
/// failed to build. Use [`try_resolve_data_type`] to tell the two apart.
#[must_use]
pub fn resolve_data_type(id: ExtensionDataTypeId) -> Option<&'static dyn ExtensionDataType> {
    try_resolve_data_type(id).unwrap_or_default()
}

/// Looks up the data-type vtable registered for `id`.
///
/// # Errors
/// Returns the cached registry-construction error (for example, duplicate IDs).
pub fn try_resolve_data_type(
    id: ExtensionDataTypeId,
) -> Result<Option<&'static dyn ExtensionDataType>, String> {
    frozen_data_type_registry().map(|registry| registry.get(&id).copied())
}
/// Looks up the binary-operator vtable registered for `id`.
///
/// Best-effort: returns `None` both when nothing is registered and when the registry
/// failed to build. Use [`try_resolve_bin_op`] to tell the two apart.
#[must_use]
pub fn resolve_bin_op(id: ExtensionBinOpId) -> Option<&'static dyn ExtensionBinOp> {
    try_resolve_bin_op(id).unwrap_or_default()
}

/// Looks up the binary-operator vtable registered for `id`.
///
/// # Errors
/// Returns the cached registry-construction error (for example, duplicate IDs).
pub fn try_resolve_bin_op(
    id: ExtensionBinOpId,
) -> Result<Option<&'static dyn ExtensionBinOp>, String> {
    frozen_bin_op_registry().map(|registry| registry.get(&id).copied())
}
/// Looks up the unary-operator vtable registered for `id`.
///
/// Best-effort: returns `None` both when nothing is registered and when the registry
/// failed to build. Use [`try_resolve_un_op`] to tell the two apart.
#[must_use]
pub fn resolve_un_op(id: ExtensionUnOpId) -> Option<&'static dyn ExtensionUnOp> {
    try_resolve_un_op(id).unwrap_or_default()
}

/// Looks up the unary-operator vtable registered for `id`.
///
/// # Errors
/// Returns the cached registry-construction error (for example, duplicate IDs).
pub fn try_resolve_un_op(
    id: ExtensionUnOpId,
) -> Result<Option<&'static dyn ExtensionUnOp>, String> {
    frozen_un_op_registry().map(|registry| registry.get(&id).copied())
}
/// Looks up the atomic-operator vtable registered for `id`.
///
/// Best-effort: returns `None` both when nothing is registered and when the registry
/// failed to build. Use [`try_resolve_atomic_op`] to tell the two apart.
#[must_use]
pub fn resolve_atomic_op(id: ExtensionAtomicOpId) -> Option<&'static dyn ExtensionAtomicOp> {
    try_resolve_atomic_op(id).unwrap_or_default()
}

/// Looks up the atomic-operator vtable registered for `id`.
///
/// # Errors
/// Returns the cached registry-construction error (for example, duplicate IDs).
pub fn try_resolve_atomic_op(
    id: ExtensionAtomicOpId,
) -> Result<Option<&'static dyn ExtensionAtomicOp>, String> {
    frozen_atomic_op_registry().map(|registry| registry.get(&id).copied())
}
/// Looks up the generic registration for `id`.
///
/// Best-effort: returns `None` both when nothing is registered and when the registry
/// failed to build. Use [`try_find_extension`] to tell the two apart.
#[must_use]
pub fn find_extension(id: ExtensionId) -> Option<&'static ExtensionRegistration> {
    try_find_extension(id).unwrap_or_default()
}

/// Looks up the generic registration for `id`.
///
/// # Errors
/// Returns the cached registry-construction error (for example, duplicate IDs).
pub fn try_find_extension(
    id: ExtensionId,
) -> Result<Option<&'static ExtensionRegistration>, String> {
    frozen_generic_registry().map(|registry| registry.get(&id).copied())
}
/// Lists every generic extension registration, sorted by ascending `ExtensionId` value.
///
/// Best-effort: if the registry failed to build (for example, duplicate IDs), this
/// returns an empty vector instead of the error. Use [`try_registered_extensions`]
/// to see the error.
#[must_use]
pub fn registered_extensions() -> Vec<&'static ExtensionRegistration> {
    try_registered_extensions().unwrap_or_default()
}

/// Lists every generic extension registration, sorted by ascending `ExtensionId` value.
///
/// The backing hash map's iteration order is unspecified. Sorting gives callers
/// (listings, diagnostics, fingerprints) a reproducible order.
///
/// # Errors
/// Returns the cached registry-construction error (for example, duplicate IDs).
pub fn try_registered_extensions() -> Result<Vec<&'static ExtensionRegistration>, String> {
    let mut extensions: Vec<&'static ExtensionRegistration> =
        frozen_generic_registry()?.values().copied().collect();
    // IDs are unique (the registry rejects duplicates when it is built), so an unstable
    // sort still yields a fully deterministic order.
    extensions.sort_unstable_by_key(|registration| registration.id.0);
    Ok(extensions)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extension_id_has_high_bit_set() {
        let id = ExtensionId::from_name("example.crate");
        assert_ne!(id.0 & 0x8000_0000, 0);
    }

    #[test]
    fn extension_id_is_deterministic() {
        let a = ExtensionId::from_name("vyre-example-ext");
        let b = ExtensionId::from_name("vyre-example-ext");
        assert_eq!(a, b);
    }

    #[test]
    fn per_kind_resolvers_are_empty_by_default() {
        // Assert through the fallible `try_*` variants. The infallible wrappers map a
        // registry-construction error (e.g. duplicate IDs) to `None`, which would let
        // such a bug pass this test unnoticed.
        let data_type_id = ExtensionDataTypeId::from_name("tensor.gather");
        assert!(try_resolve_data_type(data_type_id)
            .expect("Fix: data-type registry must build without duplicates")
            .is_none());
        assert!(resolve_data_type(data_type_id).is_none());

        let bin_op_id = ExtensionBinOpId::from_name("bit.parity");
        assert!(try_resolve_bin_op(bin_op_id)
            .expect("Fix: bin-op registry must build without duplicates")
            .is_none());
        assert!(resolve_bin_op(bin_op_id).is_none());

        let un_op_id = ExtensionUnOpId::from_name("bit.reverse_nibbles");
        assert!(try_resolve_un_op(un_op_id)
            .expect("Fix: un-op registry must build without duplicates")
            .is_none());
        assert!(resolve_un_op(un_op_id).is_none());

        let atomic_id = ExtensionAtomicOpId::from_name("atomic.clamp");
        assert!(try_resolve_atomic_op(atomic_id)
            .expect("Fix: atomic-op registry must build without duplicates")
            .is_none());
        assert!(resolve_atomic_op(atomic_id).is_none());
    }

    #[test]
    fn generic_registry_is_empty_by_default() {
        let extensions = try_registered_extensions()
            .expect("Fix: generic extension registry must build without duplicates");
        assert!(extensions.is_empty());
        assert!(registered_extensions().is_empty());
    }

    #[test]
    fn opaque_decoders_reject_unknown_kinds() {
        let kind = "vyre.tests.unregistered-kind";
        let expr_err = decode_opaque_expr(kind, &[])
            .err()
            .expect("Fix: unknown expression kinds must be rejected");
        assert!(expr_err.contains(kind));
        let node_err = decode_opaque_node(kind, &[])
            .err()
            .expect("Fix: unknown node kinds must be rejected");
        assert!(node_err.contains(kind));
    }

    #[test]
    fn unique_registrations_are_all_kept() {
        let map = collect_unique_by(
            [
                (ExtensionId(1), 10usize, "dialect.alpha"),
                (ExtensionId(2), 20usize, "dialect.beta"),
            ],
            "ExtensionRegistration",
        )
        .expect("Fix: distinct keys must not be reported as duplicates");
        assert_eq!(map.len(), 2);
        assert_eq!(map.get(&ExtensionId(1)), Some(&10));
        assert_eq!(map.get(&ExtensionId(2)), Some(&20));
    }

    #[test]
    fn empty_registrations_build_an_empty_map() {
        let map = collect_unique_by(
            std::iter::empty::<(ExtensionId, usize, &'static str)>(),
            "ExtensionRegistration",
        )
        .expect("Fix: an empty registry must build");
        assert!(map.is_empty());
    }

    #[test]
    fn duplicate_extension_ids_name_both_registrants() {
        let err = collect_unique_by(
            [
                (ExtensionId(1), 10usize, "dialect.alpha"),
                (ExtensionId(1), 20usize, "dialect.beta"),
            ],
            "ExtensionRegistration",
        )
        .expect_err("Fix: duplicate registrations must return an error");
        assert!(err.contains("ExtensionRegistration"));
        assert!(err.contains("dialect.alpha"));
        assert!(err.contains("dialect.beta"));
    }
}