#![allow(dead_code, unused_variables)]
use std::mem;
use std::sync::Arc;
use semver::Version;
use crate::extension::resolution::types::collect_func_type_exts;
use crate::extension::{
Extension, ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureFunc, TypeDef,
};
use super::types_mut::resolve_func_type_exts;
use super::{ExtensionCollectionError, ExtensionResolutionError, WeakExtensionRegistry};
impl ExtensionRegistry {
    /// Builds a registry containing `extensions`, resolving the extension
    /// references held by their type and operation definitions.
    ///
    /// References are looked up in the weak registry supplied by `new_cyclic`
    /// (presumably a view of the registry under construction, letting the new
    /// extensions refer to one another), extended with every entry of
    /// `other_extensions`.
    ///
    /// NOTE(review): `other_extensions` entries are registered *after* the
    /// weak view is cloned; whether an entry with the same id and version
    /// replaces one of the extensions being defined depends on
    /// `WeakExtensionRegistry::register` — confirm this is intended.
    ///
    /// # Errors
    ///
    /// Returns the first [`ExtensionResolutionError`] produced while resolving
    /// the references of any of `extensions`.
    pub fn new_with_extension_resolution(
        extensions: impl IntoIterator<Item = Extension>,
        other_extensions: &WeakExtensionRegistry,
    ) -> Result<ExtensionRegistry, ExtensionResolutionError> {
        Self::new_cyclic(extensions, |mut exts, self_view| {
            // Lookup table = weak view from `new_cyclic` + externally provided extensions.
            let mut lookup = self_view.clone();
            other_extensions
                .iter_all()
                .for_each(|(id, version, weak_ext)| {
                    lookup.register(id.clone(), version.clone(), weak_ext.clone());
                });

            for ext in &mut exts {
                ext.resolve_references(&lookup)?;
            }
            Ok(exts)
        })
    }

    /// Adds to this registry every extension transitively required by the
    /// operation signatures of the extensions it already contains.
    ///
    /// A stack of extensions still to inspect is processed until empty. A
    /// dependency is registered and pushed onto the stack the first time its
    /// `(name, version)` pair is encountered.
    ///
    /// # Errors
    ///
    /// Returns [`ExtensionCollectionError`] when collecting the dependencies of
    /// an inspected extension fails (see `collect_extension_deps`). Extensions
    /// registered before the failure stay registered.
    pub fn extend_with_dependencies(&mut self) -> Result<(), ExtensionCollectionError> {
        let mut visited: std::collections::BTreeSet<(ExtensionId, crate::extension::Version)> =
            std::collections::BTreeSet::new();
        let mut pending: Vec<Arc<Extension>> = Vec::new();
        for registered in self.iter_all() {
            visited.insert((registered.name().clone(), registered.version().clone()));
            pending.push(registered.clone());
        }

        while let Some(current) = pending.pop() {
            let required = collect_extension_deps(&current)?;
            for dep in required.iter_all() {
                let key = (dep.name().clone(), dep.version().clone());
                if !visited.insert(key) {
                    continue;
                }
                self.register(dep.clone());
                pending.push(dep.clone());
            }
        }
        Ok(())
    }
}
/// Collects the extensions referenced by the operation signatures of
/// `extension` into a new [`ExtensionRegistry`].
///
/// Only operations whose signature function exposes a polymorphic function
/// type (`poly_func_type()` returns `Some`) are inspected. Type definitions are
/// not examined here.
///
/// # Errors
///
/// Returns [`ExtensionCollectionError::DroppedTransitiveExtensions`] listing
/// every extension that `collect_func_type_exts` reported as missing.
///
/// # Panics
///
/// Panics if the collected weak registry cannot be converted into an
/// [`ExtensionRegistry`]. This is treated as an invariant once no extensions
/// are reported missing.
fn collect_extension_deps(
    extension: &Extension,
) -> Result<ExtensionRegistry, ExtensionCollectionError> {
    let mut referenced = WeakExtensionRegistry::default();
    let mut unavailable = ExtensionSet::new();

    let signatures = extension
        .operations()
        .filter_map(|(_, op_def)| op_def.signature_func().poly_func_type());
    for poly_func in signatures {
        let mut op_unavailable = ExtensionSet::new();
        collect_func_type_exts(poly_func.body(), &mut referenced, &mut op_unavailable);
        for ext_id in op_unavailable {
            unavailable.insert(ext_id);
        }
    }

    if !unavailable.is_empty() {
        return Err(ExtensionCollectionError::DroppedTransitiveExtensions {
            extension: extension.name().to_string(),
            missing_extensions: unavailable.into_iter().collect(),
        });
    }
    Ok(referenced.try_into().expect("All extensions are valid"))
}
impl Extension {
    /// Resolves the extension references held by this extension's type
    /// definitions and operation definitions against `extensions`.
    ///
    /// Type definitions are resolved first, then operation definitions. The
    /// operations are mutated in place through [`Arc::get_mut`].
    ///
    /// # Errors
    ///
    /// Returns the first error produced by [`resolve_typedef_exts`] or
    /// [`resolve_opdef_exts`]. Once an operation fails, the remaining
    /// operations are left unresolved, but *every* operation is still put back
    /// into `self.operations`, so the operation map is never truncated.
    /// Operations processed before the failure stay resolved. The failing
    /// operation may be partially updated, because its extension pointer is
    /// set before its signature is resolved.
    ///
    /// # Panics
    ///
    /// Panics with "OpDef is not unique" if an operation's `Arc<OpDef>` is
    /// shared when it is reached. NOTE(review): `Arc::make_mut` would avoid
    /// this panic if `OpDef` implemented `Clone`.
    fn resolve_references(
        &mut self,
        extensions: &WeakExtensionRegistry,
    ) -> Result<(), ExtensionResolutionError> {
        // Accumulates extensions referenced during resolution; not read
        // afterwards by this method.
        let mut used_extensions = WeakExtensionRegistry::default();

        for type_def in self.types.values_mut() {
            resolve_typedef_exts(
                &self.name,
                &self.version,
                type_def,
                extensions,
                &mut used_extensions,
            )?;
        }

        // The map is taken so each `Arc<OpDef>` can be mutated while `self.name`
        // and `self.version` are borrowed. Every entry is reinserted below, even
        // after an error, so an early return can never drop operations.
        let ops = mem::take(&mut self.operations);
        let mut outcome: Result<(), ExtensionResolutionError> = Ok(());
        for (op_id, mut op_def) in ops {
            if outcome.is_ok() {
                let op_ref = Arc::<OpDef>::get_mut(&mut op_def).expect("OpDef is not unique");
                outcome = resolve_opdef_exts(
                    &self.name,
                    &self.version,
                    op_ref,
                    extensions,
                    &mut used_extensions,
                );
            }
            self.operations.insert(op_id, op_def);
        }
        outcome
    }
}
/// Points the type definition `def` at its extension as found in `extensions`.
///
/// The lookup uses `def`'s own extension id, constrained by `version` (the
/// version of the owning extension `extension`). On success the definition's
/// extension version is filled in with `version`, and its extension handle is
/// replaced with the one found in `extensions`. `used_extensions` is not
/// touched by this function.
///
/// # Errors
///
/// Returns [`ExtensionResolutionError::WrongTypeDefExtension`] when
/// `extensions` has no match for `def`'s extension id and `version`.
pub(super) fn resolve_typedef_exts(
    extension: &ExtensionId,
    version: &Version,
    def: &mut TypeDef,
    extensions: &WeakExtensionRegistry,
    used_extensions: &mut WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
    let (_, resolved) = extensions
        .get_req(def.extension_id(), Some(version))
        .ok_or_else(|| ExtensionResolutionError::WrongTypeDefExtension {
            extension: extension.clone(),
            def: def.name().clone(),
            wrong_extension: def.extension_id().clone(),
        })?;

    def.fill_extension_version(version);
    *def.extension_mut() = resolved.clone();
    Ok(())
}
/// Points the operation definition `def` at its extension as found in
/// `extensions`, then resolves the extension references in its signature.
///
/// The lookup uses `def`'s own extension id, constrained by `version` (the
/// version of the owning extension `extension`). On success the definition's
/// extension version is filled in with `version`, and its extension handle is
/// replaced before the signature function is processed.
///
/// # Errors
///
/// Returns [`ExtensionResolutionError::WrongOpDefExtension`] when
/// `extensions` has no match for `def`'s extension id and `version`, or any
/// error from [`resolve_signature_func_exts`].
pub(super) fn resolve_opdef_exts(
    extension: &ExtensionId,
    version: &Version,
    def: &mut OpDef,
    extensions: &WeakExtensionRegistry,
    used_extensions: &mut WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
    let (_, resolved) = extensions
        .get_req(def.extension_id(), Some(version))
        .ok_or_else(|| ExtensionResolutionError::WrongOpDefExtension {
            extension: extension.clone(),
            def: def.name().clone(),
            wrong_extension: def.extension_id().clone(),
        })?;

    def.fill_extension_version(version);
    *def.extension_mut() = resolved.clone();

    resolve_signature_func_exts(
        extension,
        def.signature_func_mut(),
        extensions,
        used_extensions,
    )
}
/// Resolves the extension references inside the polymorphic function type
/// carried by `signature`, if it carries one.
///
/// `PolyFuncType`, `CustomValidator` and `MissingValidateFunc` signatures have
/// their function-type body resolved via `resolve_func_type_exts`.
/// `CustomFunc` and `MissingComputeFunc` carry no such body and are left
/// untouched (returns `Ok(())`). `extension` is currently unused.
///
/// # Errors
///
/// Propagates any error from `resolve_func_type_exts`.
pub(super) fn resolve_signature_func_exts(
    extension: &ExtensionId,
    signature: &mut SignatureFunc,
    extensions: &WeakExtensionRegistry,
    used_extensions: &mut WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
    let body = match signature {
        SignatureFunc::PolyFuncType(poly) => Some(poly.body_mut()),
        SignatureFunc::CustomValidator(validator) => Some(validator.poly_func_mut().body_mut()),
        SignatureFunc::MissingValidateFunc(poly) => Some(poly.body_mut()),
        SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None,
    };

    match body {
        Some(body) => resolve_func_type_exts(None, body, extensions, used_extensions),
        None => Ok(()),
    }
}