#![allow(dead_code, unused_variables)]
use std::mem;
use std::sync::Arc;
use crate::extension::{
Extension, ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureFunc, TypeDef,
};
use super::types::collect_signature_exts;
use super::types_mut::resolve_signature_exts;
use super::{ExtensionCollectionError, ExtensionResolutionError, WeakExtensionRegistry};
impl ExtensionRegistry {
    /// Builds a new registry from `extensions`, resolving the extension
    /// references held by each extension's type and operation definitions.
    ///
    /// The lookup table used for resolution starts from the weak handles that
    /// `new_cyclic` hands to the closure (the registry under construction) and
    /// then has every entry of `other_extensions` registered into it.
    ///
    /// NOTE(review): `other_extensions` entries are registered *after* the new
    /// registry's own handles; whether a clashing id replaces the new entry
    /// depends on `WeakExtensionRegistry::register` — confirm.
    ///
    /// # Errors
    ///
    /// Returns the first [`ExtensionResolutionError`] produced while resolving
    /// any of the extensions.
    pub fn new_with_extension_resolution(
        extensions: impl IntoIterator<Item = Extension>,
        other_extensions: &WeakExtensionRegistry,
    ) -> Result<ExtensionRegistry, ExtensionResolutionError> {
        Self::new_cyclic(extensions, |mut pending, own_handles| {
            // Combined lookup: the registry's own handles plus the extra ones.
            let mut lookup = own_handles.clone();
            for (extra_id, extra_handle) in other_extensions.iter() {
                lookup.register(extra_id.clone(), extra_handle.clone());
            }
            for extension in &mut pending {
                extension.resolve_references(&lookup)?;
            }
            Ok(pending)
        })
    }

    /// Transitively adds to this registry every extension found by
    /// `collect_extension_deps` for the extensions already registered.
    ///
    /// Traversal is a worklist over `Arc<Extension>`s. An extension id that is
    /// already registered, or was already discovered, is not visited again.
    /// Newly discovered extensions are added with `register_updated`.
    ///
    /// # Errors
    ///
    /// Propagates the [`ExtensionCollectionError`] returned by
    /// `collect_extension_deps` for any visited extension.
    pub fn extend_with_dependencies(&mut self) -> Result<(), ExtensionCollectionError> {
        let mut visited: std::collections::BTreeSet<ExtensionId> =
            self.exts.keys().cloned().collect();
        let mut worklist: Vec<Arc<Extension>> = self.exts.values().cloned().collect();

        while let Some(current) = worklist.pop() {
            for dependency in collect_extension_deps(&current)? {
                // `insert` returns false when the id was already known: skip it.
                if !visited.insert(dependency.name().clone()) {
                    continue;
                }
                self.register_updated(Arc::clone(&dependency));
                worklist.push(dependency);
            }
        }
        Ok(())
    }
}
fn collect_extension_deps(
extension: &Extension,
) -> Result<ExtensionRegistry, ExtensionCollectionError> {
let mut used = WeakExtensionRegistry::default();
let mut missing = ExtensionSet::new();
for (_, op_def) in extension.operations() {
if let Some(signature) = op_def.signature_func().poly_func_type() {
let mut local_missing = ExtensionSet::new();
collect_signature_exts(signature.body(), &mut used, &mut local_missing);
for ext in local_missing {
missing.insert(ext);
}
}
}
if missing.is_empty() {
Ok(used.try_into().expect("All extensions are valid"))
} else {
Err(ExtensionCollectionError::DroppedTransitiveExtensions {
extension: extension.name().to_string(),
missing_extensions: missing.into_iter().collect(),
})
}
}
impl Extension {
    /// Resolves the extension references of every type and operation
    /// definition in this extension against `extensions`.
    ///
    /// Type definitions are processed first, then operation definitions. Both
    /// are updated in place, so a failure part-way through leaves
    /// `self.operations` complete (never partially drained). Definitions
    /// visited before the error keep their updated references.
    ///
    /// # Errors
    ///
    /// Returns the first [`ExtensionResolutionError`] produced by
    /// `resolve_typedef_exts` or `resolve_opdef_exts`.
    ///
    /// # Panics
    ///
    /// If an operation's `Arc<OpDef>` is shared, because `Arc::get_mut`
    /// requires unique ownership to mutate it in place.
    fn resolve_references(
        &mut self,
        extensions: &WeakExtensionRegistry,
    ) -> Result<(), ExtensionResolutionError> {
        let mut used_extensions = WeakExtensionRegistry::default();
        for type_def in self.types.values_mut() {
            resolve_typedef_exts(&self.name, type_def, extensions, &mut used_extensions)?;
        }
        // Mutate the map in place rather than `mem::take`-ing it and
        // re-inserting entries: with take/reinsert, an early `?` would discard
        // the failing op and every op not yet visited. Borrowing `self.name`
        // alongside `self.operations` is fine because the fields are disjoint.
        for op_def in self.operations.values_mut() {
            let op_ref = Arc::<OpDef>::get_mut(op_def).expect("OpDef is not unique");
            resolve_opdef_exts(&self.name, op_ref, extensions, &mut used_extensions)?;
        }
        Ok(())
    }
}
/// Replaces `def`'s extension reference with the handle stored in
/// `extensions` under `def.extension_id()`.
///
/// `extension` is the id of the extension being resolved and is used only
/// for error reporting. `used_extensions` is currently not touched.
///
/// # Errors
///
/// Returns [`ExtensionResolutionError::WrongTypeDefExtension`] when
/// `extensions` has no entry for `def.extension_id()`.
pub(super) fn resolve_typedef_exts(
    extension: &ExtensionId,
    def: &mut TypeDef,
    extensions: &WeakExtensionRegistry,
    used_extensions: &mut WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
    let owner = extensions.get(def.extension_id()).ok_or_else(|| {
        ExtensionResolutionError::WrongTypeDefExtension {
            extension: extension.clone(),
            def: def.name().clone(),
            wrong_extension: def.extension_id().clone(),
        }
    })?;
    *def.extension_mut() = owner.clone();
    Ok(())
}
/// Replaces `def`'s extension reference with the handle stored in
/// `extensions` under `def.extension_id()`. It then resolves the extension
/// references inside the operation's signature function.
///
/// `extension` is the id of the extension being resolved. Here it is used for
/// error reporting and is forwarded to `resolve_signature_func_exts`.
///
/// # Errors
///
/// Returns [`ExtensionResolutionError::WrongOpDefExtension`] when
/// `extensions` has no entry for `def.extension_id()`. Otherwise it returns
/// whatever `resolve_signature_func_exts` reports.
pub(super) fn resolve_opdef_exts(
    extension: &ExtensionId,
    def: &mut OpDef,
    extensions: &WeakExtensionRegistry,
    used_extensions: &mut WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
    let owner = extensions.get(def.extension_id()).ok_or_else(|| {
        ExtensionResolutionError::WrongOpDefExtension {
            extension: extension.clone(),
            def: def.name().clone(),
            wrong_extension: def.extension_id().clone(),
        }
    })?;
    *def.extension_mut() = owner.clone();

    resolve_signature_func_exts(
        extension,
        def.signature_func_mut(),
        extensions,
        used_extensions,
    )
}
/// Resolves the extension references in the stored signature body of
/// `signature`, delegating to `resolve_signature_exts` (with no node).
///
/// `SignatureFunc::CustomFunc` and `SignatureFunc::MissingComputeFunc` are
/// left untouched and succeed immediately. `extension` is currently unused.
///
/// # Errors
///
/// Propagates any [`ExtensionResolutionError`] from `resolve_signature_exts`.
pub(super) fn resolve_signature_func_exts(
    extension: &ExtensionId,
    signature: &mut SignatureFunc,
    extensions: &WeakExtensionRegistry,
    used_extensions: &mut WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
    let body = match signature {
        // No stored signature body is resolved for these variants.
        SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => return Ok(()),
        SignatureFunc::PolyFuncType(poly) => poly.body_mut(),
        SignatureFunc::MissingValidateFunc(poly) => poly.body_mut(),
        SignatureFunc::CustomValidator(validator) => validator.poly_func_mut().body_mut(),
    };
    resolve_signature_exts(None, body, extensions, used_extensions)
}