use std::sync::Arc;
use semver::Version;
use super::{Extension, ExtensionCollectionError, ExtensionResolutionError};
use crate::Node;
use crate::extension::{ExtensionId, ExtensionRegistry, OpDef};
use crate::ops::custom::OpaqueOpError;
use crate::ops::{DataflowOpTrait, ExtensionOp, NamedOp, OpNameRef, OpType};
/// Returns the extension that defines `op`, if `op` is an [`OpType::ExtensionOp`].
///
/// Any other kind of operation yields `Ok(None)`.
///
/// # Errors
///
/// Returns [`ExtensionCollectionError::dropped_op_extension`] (tagged with `node`,
/// `op` and the definition's extension id) when the extension reference held by
/// the op's definition can no longer be upgraded to an `Arc`.
pub(crate) fn collect_op_extension(
    node: Option<Node>,
    op: &OpType,
) -> Result<Option<Arc<Extension>>, ExtensionCollectionError> {
    match op {
        OpType::ExtensionOp(ext_op) => {
            let def = ext_op.def();
            def.extension()
                .upgrade()
                .map(Some)
                .ok_or_else(|| {
                    ExtensionCollectionError::dropped_op_extension(
                        node,
                        op,
                        [def.extension_id().clone()],
                    )
                })
        }
        _ => Ok(None),
    }
}
pub(crate) fn resolve_op_extensions<'e>(
node: Node,
op: &mut OpType,
extensions: &'e ExtensionRegistry,
) -> Result<Option<&'e Arc<Extension>>, ExtensionResolutionError> {
fn op_and_ext<'a, 'e>(
extensions: &'e ExtensionRegistry,
node: Node,
ext_id: &'a ExtensionId,
ext_version: Option<&'a Version>,
qualified_id: impl AsRef<OpNameRef>,
unqualified_id: impl AsRef<OpNameRef>,
) -> Result<(&'e Arc<OpDef>, &'e Arc<Extension>), ExtensionResolutionError> {
let Some(extension) = extensions.get_req(ext_id, ext_version) else {
return Err(ExtensionResolutionError::MissingOpExtension {
node: Some(node),
op: qualified_id.as_ref().into(),
missing_extension: ext_id.clone(),
available_extensions: extensions.ids().cloned().collect(),
});
};
let Some(op_def) = extension.get_op(unqualified_id.as_ref()) else {
return Err(OpaqueOpError::OpNotFoundInExtension {
node,
op: qualified_id.as_ref().into(),
extension: ext_id.clone(),
available_ops: extension
.operations()
.map(|(name, _)| name.clone())
.collect(),
}
.into());
};
Ok((op_def, extension))
}
match op {
OpType::ExtensionOp(ext_op) => {
let ext_id = ext_op.extension_id();
let ext_version = ext_op.extension_version();
let (op_def, extension) = op_and_ext(
extensions,
node,
ext_id,
Some(&ext_version),
ext_op.qualified_id(),
ext_op.unqualified_id(),
)?;
ext_op
.relink_def(op_def.clone())
.map_err(|e| OpaqueOpError::SignatureError {
node,
name: ext_op.qualified_id(),
cause: e,
})?;
Ok(Some(extension))
}
OpType::OpaqueOp(opaque) => {
let ext_id = opaque.extension();
let ext_version = opaque.extension_version();
let (op_def, extension) = op_and_ext(
extensions,
node,
ext_id,
ext_version,
opaque.qualified_id(),
opaque.unqualified_id(),
)?;
let ext_op =
ExtensionOp::new_with_cached(op_def.clone(), opaque.args().to_vec(), opaque)
.map_err(|e| OpaqueOpError::SignatureError {
node,
name: opaque.name().clone(),
cause: e,
})?;
if opaque.signature().io() != ext_op.signature().io() {
return Err(OpaqueOpError::SignatureMismatch {
node,
extension: opaque.extension().clone(),
op: op_def.name().clone(),
computed: Box::new(ext_op.signature().into_owned()),
stored: Box::new(opaque.signature().into_owned()),
}
.into());
}
*op = ext_op.into();
Ok(Some(extension))
}
_ => Ok(None),
}
}