use anyhow::{bail, Result};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tensorlogic_ir::{EinsumGraph, TLExpr};
use crate::{config::CompilationConfig, CompilerContext};
/// Signature of a custom-operation lowering callback.
///
/// A handler is given the expression being compiled, the mutable compiler
/// context, the graph being built, and the invocation's [`CustomOpData`], and
/// yields a `usize` (the preset handlers return the value of
/// `EinsumGraph::add_tensor`). Handlers are reference-counted and bounded by
/// `Send + Sync` so one handler value can be shared freely.
pub type CustomOpHandler = Arc<
    dyn Fn(
            &TLExpr,
            &mut CompilerContext,
            &mut EinsumGraph,
            &CustomOpData,
        ) -> Result<usize>
        + Send
        + Sync,
>;
/// Descriptive information stored alongside a registered custom operation.
#[derive(Clone, Debug)]
pub struct CustomOpMetadata {
    /// Human-readable operation name.
    pub name: String,
    /// Free-form description of what the operation does.
    pub description: String,
    /// Required number of arguments, if the operation has a fixed arity.
    ///
    /// `CustomOpRegistry::invoke` compares this against the argument count of
    /// `TLExpr::Pred` expressions; `None` disables that check.
    pub expected_arity: Option<usize>,
    /// Whether the operation is flagged as differentiable.
    pub is_differentiable: bool,
}
/// Per-invocation parameters handed to a custom-operation handler.
///
/// Strings and numbers live in two independent maps, so the same key may be
/// present in both without conflict.
#[derive(Debug, Clone, Default)]
pub struct CustomOpData {
    /// String-valued parameters, keyed by name.
    pub string_data: HashMap<String, String>,
    /// Numeric parameters, keyed by name.
    pub numeric_data: HashMap<String, f64>,
}

impl CustomOpData {
    /// Creates an empty parameter set.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `self` with the string parameter `key` set to `value`,
    /// replacing any string previously stored under `key`.
    // `#[must_use]`: this consumes `self`, so discarding the result would
    // silently drop the whole parameter set.
    #[must_use]
    pub fn with_string(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.string_data.insert(key.into(), value.into());
        self
    }

    /// Returns `self` with the numeric parameter `key` set to `value`,
    /// replacing any number previously stored under `key`.
    #[must_use]
    pub fn with_numeric(mut self, key: impl Into<String>, value: f64) -> Self {
        self.numeric_data.insert(key.into(), value);
        self
    }

    /// Looks up a string parameter by key.
    pub fn get_string(&self, key: &str) -> Option<&String> {
        self.string_data.get(key)
    }

    /// Looks up a numeric parameter by key, returning a copy of the value.
    pub fn get_numeric(&self, key: &str) -> Option<f64> {
        self.numeric_data.get(key).copied()
    }
}
pub struct CustomOpRegistry {
handlers: RwLock<HashMap<String, (CustomOpHandler, CustomOpMetadata)>>,
}
impl Default for CustomOpRegistry {
fn default() -> Self {
Self::new()
}
}
impl CustomOpRegistry {
pub fn new() -> Self {
Self {
handlers: RwLock::new(HashMap::new()),
}
}
pub fn register(
&mut self,
name: impl Into<String>,
metadata: CustomOpMetadata,
handler: CustomOpHandler,
) -> Result<()> {
let name = name.into();
let mut handlers = self.handlers.write().unwrap();
if handlers.contains_key(&name) {
bail!("Custom operation '{}' is already registered", name);
}
handlers.insert(name, (handler, metadata));
Ok(())
}
pub fn unregister(&mut self, name: &str) -> Result<()> {
let mut handlers = self.handlers.write().unwrap();
if handlers.remove(name).is_none() {
bail!("Custom operation '{}' not found", name);
}
Ok(())
}
pub fn has_operation(&self, name: &str) -> bool {
let handlers = self.handlers.read().unwrap();
handlers.contains_key(name)
}
pub fn get_metadata(&self, name: &str) -> Option<CustomOpMetadata> {
let handlers = self.handlers.read().unwrap();
handlers.get(name).map(|(_, meta)| meta.clone())
}
pub fn list_operations(&self) -> Vec<String> {
let handlers = self.handlers.read().unwrap();
handlers.keys().cloned().collect()
}
pub fn invoke(
&self,
name: &str,
expr: &TLExpr,
ctx: &mut CompilerContext,
graph: &mut EinsumGraph,
data: &CustomOpData,
) -> Result<usize> {
let handlers = self.handlers.read().unwrap();
let (handler, metadata) = handlers
.get(name)
.ok_or_else(|| anyhow::anyhow!("Custom operation '{}' not found", name))?;
if let Some(expected) = metadata.expected_arity {
if let TLExpr::Pred { args, .. } = expr {
if args.len() != expected {
bail!(
"Custom operation '{}' expects {} arguments, got {}",
name,
expected,
args.len()
);
}
}
}
handler(expr, ctx, graph, data)
}
}
/// A [`CompilerContext`] bundled with a custom-operation registry and custom
/// data.
///
/// Cloning shares `custom_ops` (it is an `Arc`) and clones the other fields.
#[derive(Clone)]
pub struct ExtendedCompilerContext {
    /// The wrapped core compiler context.
    pub base_context: CompilerContext,
    /// Registry of custom operations; shared between clones of this context.
    pub custom_ops: Arc<CustomOpRegistry>,
    /// Custom data carried by this context.
    pub custom_data: CustomOpData,
}

impl ExtendedCompilerContext {
    /// Creates a context with a fresh `CompilerContext`, an empty registry and
    /// empty custom data.
    pub fn new() -> Self {
        Self::from_context(CompilerContext::new())
    }

    /// Wraps an existing `CompilerContext` with an empty registry and empty
    /// custom data.
    pub fn from_context(ctx: CompilerContext) -> Self {
        Self {
            base_context: ctx,
            custom_ops: Arc::new(CustomOpRegistry::new()),
            custom_data: CustomOpData::new(),
        }
    }

    /// Replaces the base context with `CompilerContext::with_config(config)`.
    ///
    /// NOTE(review): this discards the previous `base_context` entirely
    /// (including one supplied via `from_context`), not just its
    /// configuration — confirm this is intended.
    #[must_use]
    pub fn with_config(mut self, config: CompilationConfig) -> Self {
        self.base_context = CompilerContext::with_config(config);
        self
    }

    /// Replaces the custom data carried by this context.
    #[must_use]
    pub fn with_custom_data(mut self, data: CustomOpData) -> Self {
        self.custom_data = data;
        self
    }

    /// Mutable access to the registry if this context is its only owner.
    ///
    /// Returns `None` when other strong or weak references to `custom_ops`
    /// exist (for example after cloning this context).
    pub fn try_custom_ops_mut(&mut self) -> Option<&mut CustomOpRegistry> {
        Arc::get_mut(&mut self.custom_ops)
    }

    /// Mutable access to the registry.
    ///
    /// # Panics
    /// Panics if `custom_ops` is shared (for example after cloning this
    /// context); use [`try_custom_ops_mut`](Self::try_custom_ops_mut) to
    /// handle that case without panicking.
    pub fn custom_ops_mut(&mut self) -> &mut CustomOpRegistry {
        self.try_custom_ops_mut()
            .expect("Cannot get mutable access to shared CustomOpRegistry")
    }
}

impl Default for ExtendedCompilerContext {
    fn default() -> Self {
        Self::new()
    }
}
pub mod presets {
use super::*;
pub fn create_soft_threshold_and(sharpness: f64) -> (CustomOpMetadata, CustomOpHandler) {
let metadata = CustomOpMetadata {
name: "soft_threshold_and".to_string(),
description: format!("Soft threshold AND with sharpness parameter {}", sharpness),
expected_arity: Some(2),
is_differentiable: true,
};
let handler = Arc::new(
move |_expr: &TLExpr,
_ctx: &mut CompilerContext,
graph: &mut EinsumGraph,
data: &CustomOpData| {
let _k = data.get_numeric("sharpness").unwrap_or(sharpness);
let tensor_idx = graph.add_tensor("soft_threshold_and_result");
Ok(tensor_idx)
},
) as CustomOpHandler;
(metadata, handler)
}
pub fn create_weighted_or(w1: f64, w2: f64) -> (CustomOpMetadata, CustomOpHandler) {
let metadata = CustomOpMetadata {
name: "weighted_or".to_string(),
description: format!("Weighted OR with weights {} and {}", w1, w2),
expected_arity: Some(2),
is_differentiable: true,
};
let handler = Arc::new(
move |_expr: &TLExpr,
_ctx: &mut CompilerContext,
graph: &mut EinsumGraph,
data: &CustomOpData| {
let weight1 = data.get_numeric("w1").unwrap_or(w1);
let weight2 = data.get_numeric("w2").unwrap_or(w2);
let tensor_idx =
graph.add_tensor(format!("weighted_or_result_{}_{}", weight1, weight2));
Ok(tensor_idx)
},
) as CustomOpHandler;
(metadata, handler)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_custom_op_data() {
        let data = CustomOpData::new()
            .with_string("mode", "test")
            .with_numeric("threshold", 0.5);
        assert_eq!(data.get_string("mode"), Some(&"test".to_string()));
        assert_eq!(data.get_numeric("threshold"), Some(0.5));
        assert_eq!(data.get_string("nonexistent"), None);
    }

    #[test]
    fn test_extended_context() {
        let ctx = ExtendedCompilerContext::new();
        assert_eq!(ctx.base_context.domains.len(), 0);
    }

    #[test]
    fn test_preset_soft_threshold_and() {
        let (metadata, _handler) = presets::create_soft_threshold_and(2.0);
        assert_eq!(metadata.name, "soft_threshold_and");
        assert_eq!(metadata.expected_arity, Some(2));
        assert!(metadata.is_differentiable);
    }

    #[test]
    fn test_preset_weighted_or() {
        let (metadata, _handler) = presets::create_weighted_or(0.6, 0.4);
        assert_eq!(metadata.name, "weighted_or");
        assert_eq!(metadata.expected_arity, Some(2));
        assert!(metadata.is_differentiable);
    }

    /// Covers the registry lifecycle: register, duplicate rejection, lookup,
    /// listing and unregister (including unregistering a missing name).
    #[test]
    fn test_registry_register_and_unregister() {
        let mut registry = CustomOpRegistry::new();
        let (metadata, handler) = presets::create_weighted_or(0.6, 0.4);
        registry.register("weighted_or", metadata, handler).unwrap();

        assert!(registry.has_operation("weighted_or"));
        assert_eq!(registry.list_operations(), vec!["weighted_or".to_string()]);
        assert_eq!(
            registry
                .get_metadata("weighted_or")
                .map(|meta| meta.expected_arity),
            Some(Some(2))
        );

        let (metadata, handler) = presets::create_weighted_or(0.5, 0.5);
        assert!(registry.register("weighted_or", metadata, handler).is_err());

        registry.unregister("weighted_or").unwrap();
        assert!(!registry.has_operation("weighted_or"));
        assert!(registry.get_metadata("weighted_or").is_none());
        assert!(registry.list_operations().is_empty());
        assert!(registry.unregister("weighted_or").is_err());
    }

    /// A freshly created context owns its registry exclusively, so mutable
    /// access succeeds and registrations are visible through the shared Arc.
    #[test]
    fn test_extended_context_custom_ops_mut() {
        let mut ctx = ExtendedCompilerContext::new();
        let (metadata, handler) = presets::create_soft_threshold_and(2.0);
        ctx.custom_ops_mut()
            .register("soft_threshold_and", metadata, handler)
            .unwrap();
        assert!(ctx.custom_ops.has_operation("soft_threshold_and"));
    }
}