use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};
use crate::{Graph, Node, NodeId, Shape};
/// Context handed to [`OpExtension::vmap`].
///
/// NOTE(review): the field descriptions below are inferred from the field
/// names only; confirm them against the code that builds this context.
pub struct VmapContext<'a> {
    /// Node ids for the node's inputs (presumably already lifted into `out`,
    /// one per input — TODO confirm).
    pub lifted_inputs: &'a [NodeId],
    /// Per-input flags (presumably parallel to `lifted_inputs`: whether each
    /// input carries the batch dimension — TODO confirm).
    pub is_batched: &'a [bool],
    /// Batch size being mapped over (assumed from the name — TODO confirm).
    pub batch_size: usize,
    /// Mutable graph the implementation can add nodes to (presumably where
    /// the batched node is emitted — TODO confirm).
    pub out: &'a mut Graph,
}
/// Context handed to [`OpExtension::jvp`].
///
/// NOTE(review): field semantics are inferred from names; confirm against the
/// forward-mode driver.
pub struct JvpContext<'a> {
    /// Optional tangent node per input (presumably `None` when an input has
    /// no tangent — TODO confirm).
    pub tangents: &'a [Option<NodeId>],
    /// Node-id mapping (presumably original-graph ids to their counterparts
    /// in `bwd` — TODO confirm).
    pub fwd_map: &'a HashMap<NodeId, NodeId>,
    /// Mutable graph the implementation can add nodes to.
    pub bwd: &'a mut Graph,
}
/// Context handed to [`OpExtension::vjp`].
///
/// NOTE(review): field semantics are inferred from names; confirm against the
/// reverse-mode driver.
pub struct VjpContext<'a> {
    /// Node in `bwd` holding the incoming (upstream) gradient.
    pub upstream: NodeId,
    /// Node-id mapping (presumably forward-graph ids to their counterparts
    /// in `bwd` — TODO confirm).
    pub fwd_map: &'a HashMap<NodeId, NodeId>,
    /// Mutable graph the implementation adds gradient nodes to.
    pub bwd: &'a mut Graph,
}
/// A pluggable graph operation that can be stored in an [`OpRegistry`].
///
/// Implementors must supply a name, an input count and shape inference.
/// Differentiation (`vjp`, `jvp`) and batching (`vmap`) hooks are optional.
/// Their provided defaults return an empty `Vec` or `None`.
pub trait OpExtension: Send + Sync {
    /// Key under which the op is stored by [`OpRegistry::register`].
    fn name(&self) -> &str;

    /// Number of inputs reported by the op.
    fn num_inputs(&self) -> usize;

    /// Computes the output shape from the input shapes and the raw attribute
    /// bytes (`attrs`; the encoding is op-defined — TODO confirm).
    fn infer_shape(&self, inputs: &[&Shape], attrs: &[u8]) -> Shape;

    /// Reverse-mode hook: returns `(input_index, node)` pairs built in
    /// `ctx.bwd`. The default yields no pairs.
    fn vjp(&self, _: &Node, _: &mut VjpContext) -> Vec<(usize, NodeId)> {
        vec![]
    }

    /// Forward-mode hook: returns a node in `ctx.bwd`, if any.
    /// The default returns `None`.
    fn jvp(&self, _: &Node, _: &mut JvpContext) -> Option<NodeId> {
        None
    }

    /// Batching hook: returns a node in `ctx.out`, if any.
    /// The default returns `None`.
    fn vmap(&self, _: &Node, _: &mut VmapContext) -> Option<NodeId> {
        None
    }
}
/// Name-keyed registry of [`OpExtension`] implementations.
///
/// The map sits behind an [`RwLock`], so [`lookup`](Self::lookup) and
/// [`list`](Self::list) take a shared lock while [`register`](Self::register)
/// takes an exclusive one.
pub struct OpRegistry {
    ops: RwLock<HashMap<String, Arc<dyn OpExtension>>>,
}

impl OpRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            ops: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `op` under the name returned by [`OpExtension::name`].
    ///
    /// An op already registered under that name is replaced, and a warning is
    /// printed to stderr.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned, or if writing the warning to
    /// stderr fails (as `eprintln!` does).
    pub fn register(&self, op: Arc<dyn OpExtension>) {
        let name = op.name().to_string();
        // The write guard is a temporary, so it is released at the end of this
        // statement. The warning is printed and the replaced op is dropped only
        // afterwards. If `eprintln!` panicked under the guard, it would poison
        // the lock for every later caller. Running the old op's destructor
        // under the guard would deadlock if that destructor used this registry.
        let previous = self.ops.write().unwrap().insert(name.clone(), op);
        if previous.is_some() {
            eprintln!(
                "rlx-ir: OpExtension '{name}' was already registered — \
                 replacing the previous entry"
            );
        }
    }

    /// Returns the op registered under `name`, if any.
    pub fn lookup(&self, name: &str) -> Option<Arc<dyn OpExtension>> {
        self.ops.read().unwrap().get(name).cloned()
    }

    /// Returns the names of all registered ops, sorted ascending.
    ///
    /// Sorting makes the output deterministic, since `HashMap` iterates in
    /// arbitrary order.
    pub fn list(&self) -> Vec<String> {
        let mut names: Vec<String> = self.ops.read().unwrap().keys().cloned().collect();
        names.sort_unstable();
        names
    }
}
impl Default for OpRegistry {
fn default() -> Self {
Self::new()
}
}
/// Returns the process-wide [`OpRegistry`].
///
/// The registry is created empty on first access and lives for the rest of
/// the program.
pub fn global_registry() -> &'static OpRegistry {
    static GLOBAL: OnceLock<OpRegistry> = OnceLock::new();
    GLOBAL.get_or_init(OpRegistry::new)
}
/// Adds `op` to the process-wide registry returned by [`global_registry`].
///
/// This forwards to [`OpRegistry::register`].
pub fn register_op(op: Arc<dyn OpExtension>) {
    let registry = global_registry();
    registry.register(op);
}
/// Looks up `name` in the process-wide registry returned by
/// [`global_registry`].
///
/// This forwards to [`OpRegistry::lookup`].
pub fn lookup_op(name: &str) -> Option<Arc<dyn OpExtension>> {
    let registry = global_registry();
    registry.lookup(name)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DType;

    /// One-input op whose output shape is its first input's shape.
    struct DummyOp;
    impl OpExtension for DummyOp {
        fn name(&self) -> &str {
            "dummy"
        }
        fn num_inputs(&self) -> usize {
            1
        }
        fn infer_shape(&self, inputs: &[&Shape], _: &[u8]) -> Shape {
            inputs[0].clone()
        }
    }

    /// Shares `DummyOp`'s name but reports a different arity, so a test can
    /// tell which of the two a lookup returned.
    struct DummyOpV2;
    impl OpExtension for DummyOpV2 {
        fn name(&self) -> &str {
            "dummy"
        }
        fn num_inputs(&self) -> usize {
            2
        }
        fn infer_shape(&self, inputs: &[&Shape], _: &[u8]) -> Shape {
            inputs[0].clone()
        }
    }

    #[test]
    fn register_and_lookup() {
        let reg = OpRegistry::new();
        reg.register(Arc::new(DummyOp));
        let op = reg.lookup("dummy").expect("should find");
        assert_eq!(op.name(), "dummy");
        assert_eq!(op.num_inputs(), 1);
        let s = Shape::new(&[2, 3], DType::F32);
        let out = op.infer_shape(&[&s], &[]);
        assert_eq!(out, s);
    }

    #[test]
    fn lookup_unregistered_is_none() {
        let reg = OpRegistry::new();
        assert!(reg.lookup("dummy").is_none());
        assert!(reg.list().is_empty());
    }

    #[test]
    fn register_replaces_existing_entry() {
        let reg = OpRegistry::new();
        reg.register(Arc::new(DummyOp));
        reg.register(Arc::new(DummyOpV2));
        let op = reg.lookup("dummy").expect("should find");
        // The second registration must win and the name must appear only once.
        assert_eq!(op.num_inputs(), 2);
        assert_eq!(reg.list(), vec!["dummy".to_string()]);
    }

    #[test]
    fn vjp_default_is_empty() {
        let d = DummyOp;
        let mut bwd = Graph::new("b");
        let map = HashMap::new();
        let upstream = bwd.input("u", Shape::new(&[1], DType::F32));
        let node = bwd.nodes()[upstream.0 as usize].clone();
        let mut ctx = VjpContext {
            upstream,
            fwd_map: &map,
            bwd: &mut bwd,
        };
        let grads = d.vjp(&node, &mut ctx);
        assert!(grads.is_empty());
    }

    #[test]
    fn jvp_default_is_none() {
        let d = DummyOp;
        let mut bwd = Graph::new("b");
        let map = HashMap::new();
        let x = bwd.input("x", Shape::new(&[1], DType::F32));
        let node = bwd.nodes()[x.0 as usize].clone();
        let tangents: [Option<NodeId>; 1] = [None];
        let mut ctx = JvpContext {
            tangents: &tangents,
            fwd_map: &map,
            bwd: &mut bwd,
        };
        assert!(d.jvp(&node, &mut ctx).is_none());
    }

    #[test]
    fn vmap_default_is_none() {
        let d = DummyOp;
        let mut out = Graph::new("v");
        let x = out.input("x", Shape::new(&[1], DType::F32));
        let node = out.nodes()[x.0 as usize].clone();
        let lifted = [x];
        let batched = [true];
        let mut ctx = VmapContext {
            lifted_inputs: &lifted,
            is_batched: &batched,
            batch_size: 2,
            out: &mut out,
        };
        assert!(d.vmap(&node, &mut ctx).is_none());
    }
}