use std::sync::Arc;
use super::{Extension, ExtensionCollectionError, ExtensionResolutionError};
use crate::Node;
use crate::extension::ExtensionRegistry;
use crate::ops::custom::OpaqueOpError;
use crate::ops::{DataflowOpTrait, ExtensionOp, NamedOp, OpType};
/// Returns the extension that defines `op`, if `op` is an [`OpType::ExtensionOp`].
///
/// The extension is obtained by upgrading the reference stored in the operation's
/// definition.
///
/// # Returns
///
/// - `Ok(Some(ext))` for an extension op whose extension reference can still be upgraded.
/// - `Ok(None)` for every other kind of operation.
///
/// # Errors
///
/// Returns [`ExtensionCollectionError::dropped_op_extension`] when the upgrade fails. The
/// error carries `node`, `op`, and the id of the extension that was referenced.
pub(crate) fn collect_op_extension(
    node: Option<Node>,
    op: &OpType,
) -> Result<Option<Arc<Extension>>, ExtensionCollectionError> {
    match op {
        OpType::ExtensionOp(ext_op) => ext_op
            .def()
            .extension()
            .upgrade()
            .map(Some)
            .ok_or_else(|| {
                ExtensionCollectionError::dropped_op_extension(
                    node,
                    op,
                    [ext_op.def().extension_id().clone()],
                )
            }),
        // Only extension ops carry a reference to a defining extension.
        _ => Ok(None),
    }
}
/// Resolves the extension used by `op` against `extensions` and, for opaque operations,
/// replaces `op` in place with a concrete [`ExtensionOp`].
///
/// Non-opaque operations are left untouched; only the registry entry for their extension
/// (if they reference one) is returned.
///
/// For an [`OpType::OpaqueOp`], the operation definition is looked up in the resolved
/// extension. An [`ExtensionOp`] is built from that definition, the opaque op's arguments,
/// and the opaque op itself. It replaces `op` only if its signature's `io()` matches the
/// signature stored in the opaque op.
///
/// # Returns
///
/// The registry entry of the extension the operation belongs to, or `None` if the
/// operation references no extension.
///
/// # Errors
///
/// - The extension is not present in `extensions` (see `operation_extension`).
/// - [`OpaqueOpError::OpNotFoundInExtension`]: the extension has no op with the opaque op's
///   unqualified id. The error lists the ops that are available.
/// - [`OpaqueOpError::SignatureError`]: building the [`ExtensionOp`] failed.
/// - [`OpaqueOpError::SignatureMismatch`]: the computed signature differs from the stored one.
///
/// # Panics
///
/// Panics if `op` is opaque but no extension was resolved for it. That would break the
/// expectation that `operation_extension` yields an extension or an error for opaque ops.
pub(crate) fn resolve_op_extensions<'e>(
    node: Node,
    op: &mut OpType,
    extensions: &'e ExtensionRegistry,
) -> Result<Option<&'e Arc<Extension>>, ExtensionResolutionError> {
    let resolved = operation_extension(node, op, extensions)?;

    // Anything that is not an opaque op needs no rewriting.
    let OpType::OpaqueOp(opaque) = op else {
        return Ok(resolved);
    };
    let ext = resolved.expect("OpaqueOp should have an extension");

    let op_def = ext
        .get_op(opaque.unqualified_id())
        .ok_or_else(|| OpaqueOpError::OpNotFoundInExtension {
            node,
            op: opaque.name().clone(),
            extension: ext.name().clone(),
            available_ops: ext
                .operations()
                .map(|(name, _)| name.clone())
                .collect(),
        })?;

    let resolved_op =
        ExtensionOp::new_with_cached(op_def.clone(), opaque.args().to_vec(), opaque).map_err(
            |cause| OpaqueOpError::SignatureError {
                node,
                name: opaque.name().clone(),
                cause,
            },
        )?;

    // Reject the rewrite if the definition's signature disagrees with the serialized one.
    // Both signatures are compared inline so no borrow of `op` outlives the assignment below.
    if opaque.signature().io() != resolved_op.signature().io() {
        return Err(OpaqueOpError::SignatureMismatch {
            node,
            extension: opaque.extension().clone(),
            op: op_def.name().clone(),
            computed: Box::new(resolved_op.signature().into_owned()),
            stored: Box::new(opaque.signature().into_owned()),
        }
        .into());
    }

    *op = resolved_op.into();
    Ok(Some(ext))
}
/// Looks up, in `extensions`, the extension referenced by `op.extension_id()`.
///
/// # Returns
///
/// - `Ok(None)` when `op` does not reference an extension.
/// - `Ok(Some(ext))` with the registry entry when the referenced extension is registered.
///
/// # Errors
///
/// Returns [`ExtensionResolutionError::missing_op_extension`] when `op` references an
/// extension id that is not in `extensions`. The error is built from `node`, `op`, the
/// missing id, and the registry.
fn operation_extension<'e>(
    node: Node,
    op: &OpType,
    extensions: &'e ExtensionRegistry,
) -> Result<Option<&'e Arc<Extension>>, ExtensionResolutionError> {
    op.extension_id()
        .map(|ext_id| {
            extensions.get(ext_id).ok_or_else(|| {
                ExtensionResolutionError::missing_op_extension(Some(node), op, ext_id, extensions)
            })
        })
        .transpose()
}