use std::collections::{HashMap, hash_map::Entry};
use super::{ExtensionFile, SimpleExtensions, SimpleExtensionsError, types::CustomType};
use crate::urn::Urn;
/// A collection of parsed simple extensions, indexed by the URN of the
/// extension file each one came from.
#[derive(Debug)]
pub struct Registry {
    /// Parsed extensions keyed by URN; each URN maps to exactly one extension.
    extensions: HashMap<Urn, SimpleExtensions>,
}
impl Registry {
    /// Builds a registry from a collection of extension files.
    ///
    /// Every [`ExtensionFile`] is split into its URN and its parsed extension,
    /// and the extension is stored under that URN.
    ///
    /// # Errors
    ///
    /// Returns [`SimpleExtensionsError::DuplicateUrn`], carrying the offending
    /// URN, as soon as a second file with an already-registered URN is seen.
    pub fn new<I: IntoIterator<Item = ExtensionFile>>(
        extensions: I,
    ) -> Result<Self, SimpleExtensionsError> {
        let mut map = HashMap::new();
        for ExtensionFile { urn, extension } in extensions {
            // Hand the URN to the map by value and clone the stored key only
            // on the error path, so the common (unique URN) case never clones.
            match map.entry(urn) {
                Entry::Occupied(occupied) => {
                    return Err(SimpleExtensionsError::DuplicateUrn(occupied.key().clone()));
                }
                Entry::Vacant(vacant) => {
                    vacant.insert(extension);
                }
            }
        }
        Ok(Self { extensions: map })
    }

    /// Iterates over every registered extension together with its URN.
    ///
    /// The backing store is a [`HashMap`], so the iteration order is
    /// unspecified.
    pub fn extensions(&self) -> impl Iterator<Item = (&Urn, &SimpleExtensions)> {
        self.extensions.iter()
    }

    /// Builds a registry from the core extensions exposed by
    /// `crate::extensions::EXTENSIONS`.
    ///
    /// Only compiled when the `extensions` feature is enabled.
    ///
    /// Two core extension files are deliberately left out:
    /// `extension:io.substrait:extension_types` and
    /// `extension:io.substrait:unknown`.
    // NOTE(review): the tests in this file attribute the `extension_types`
    // skip to a "missing u! prefix bug"; the reason for skipping `unknown` is
    // not visible here. Confirm before removing either entry.
    ///
    /// # Panics
    ///
    /// Panics if any of the remaining core extensions cannot be turned into an
    /// [`ExtensionFile`]. In debug builds it additionally asserts that the URN
    /// reported by the created file equals the key it was stored under in
    /// `EXTENSIONS`.
    #[cfg(feature = "extensions")]
    pub fn from_core_extensions() -> Self {
        use crate::extensions::EXTENSIONS;

        /// URNs of core extension files that are not loaded into the registry.
        const SKIPPED_URNS: [&str; 2] = [
            "extension:io.substrait:extension_types",
            "extension:io.substrait:unknown",
        ];

        let extensions: HashMap<Urn, SimpleExtensions> = EXTENSIONS
            .iter()
            .filter_map(|(orig_urn, simple_extensions)| {
                let urn_str = orig_urn.to_string();
                if SKIPPED_URNS.contains(&urn_str.as_str()) {
                    return None;
                }
                let ExtensionFile { urn, extension } =
                    ExtensionFile::create(simple_extensions.clone()).unwrap_or_else(|err| {
                        panic!("Core extensions should be valid, but failed to create extension file for {orig_urn}: {err}")
                    });
                debug_assert_eq!(orig_urn, &urn);
                Some((urn, extension))
            })
            .collect();
        Self { extensions }
    }

    /// Looks up the extension registered under `urn`, if any.
    fn get_extension(&self, urn: &Urn) -> Option<&SimpleExtensions> {
        self.extensions.get(urn)
    }

    /// Looks up the custom type called `name` in the extension registered
    /// under `urn`.
    ///
    /// Returns `None` when no extension is registered under `urn`, or when
    /// that extension's own type lookup yields nothing for `name`.
    pub fn get_type(&self, urn: &Urn, name: &str) -> Option<&CustomType> {
        self.get_extension(urn)?.get_type(name)
    }

    /// Looks up the scalar function called `name` in the extension registered
    /// under `urn`.
    ///
    /// Returns `None` when no extension is registered under `urn`, or when
    /// that extension's own scalar-function lookup yields nothing for `name`.
    pub fn get_scalar_function(&self, urn: &Urn, name: &str) -> Option<&super::ScalarFunction> {
        self.get_extension(urn)?.get_scalar_function(name)
    }
}
#[cfg(test)]
mod tests {
    use super::{ExtensionFile, Registry};
    use crate::parse::text::simple_extensions::{
        SimpleExtensionsError, scalar_functions::ScalarFunctionError, types::ExtensionTypeError,
    };
    use crate::text::simple_extensions::{SimpleExtensions, SimpleExtensionsTypesItem};
    use crate::urn::Urn;
    use std::str::FromStr;

    /// Builds one raw custom-type declaration per name, with every optional
    /// field left unset.
    fn type_items(names: &[&str]) -> Vec<SimpleExtensionsTypesItem> {
        names
            .iter()
            .map(|name| SimpleExtensionsTypesItem {
                name: (*name).to_string(),
                description: None,
                metadata: Default::default(),
                parameters: None,
                structure: None,
                variadic: None,
            })
            .collect()
    }

    /// Creates an [`ExtensionFile`] with the given URN that declares only the
    /// named custom types (no functions, dependencies or type variations).
    fn extension_file(urn: &str, type_names: &[&str]) -> ExtensionFile {
        let raw = SimpleExtensions {
            scalar_functions: vec![],
            aggregate_functions: vec![],
            window_functions: vec![],
            dependencies: Default::default(),
            metadata: Default::default(),
            type_variations: vec![],
            types: type_items(type_names),
            urn: urn.to_string(),
        };
        ExtensionFile::create(raw).expect("valid extension file")
    }

    /// Builds a raw (not yet validated) extension containing a single scalar
    /// function whose only implementation returns `return_type`, alongside the
    /// custom types named in `defined_types`.
    fn extension_with_custom_type_reference(
        urn: &str,
        function_name: &str,
        return_type: &str,
        defined_types: &[&str],
    ) -> SimpleExtensions {
        use crate::text::simple_extensions;

        SimpleExtensions {
            scalar_functions: vec![simple_extensions::ScalarFunction {
                name: function_name.to_string(),
                description: None,
                metadata: Default::default(),
                impls: vec![simple_extensions::ScalarFunctionImplsItem {
                    args: None,
                    options: None,
                    variadic: None,
                    session_dependent: None,
                    deterministic: None,
                    nullability: None,
                    return_: simple_extensions::ReturnValue(simple_extensions::Type::String(
                        return_type.to_string(),
                    )),
                    implementation: None,
                }],
            }],
            aggregate_functions: vec![],
            window_functions: vec![],
            dependencies: Default::default(),
            metadata: Default::default(),
            type_variations: vec![],
            types: type_items(defined_types),
            urn: urn.to_string(),
        }
    }

    /// Every registered extension is visible through `Registry::extensions`.
    #[test]
    fn test_registry_iteration() {
        let urns = [
            "extension:example.com:first",
            "extension:example.com:second",
        ];
        let registry =
            Registry::new(urns.iter().map(|&urn| extension_file(urn, &["type"]))).unwrap();

        let collected: Vec<String> = registry
            .extensions()
            .map(|(urn, _)| urn.to_string())
            .collect();

        assert_eq!(collected.len(), urns.len());
        for urn in urns {
            assert!(collected.iter().any(|candidate| candidate == urn));
        }
    }

    /// Registering two files with the same URN is rejected, and the error
    /// reports the URN that collided.
    #[test]
    fn test_duplicate_urn_is_rejected() {
        let urn = "extension:example.com:duplicate";
        let result = Registry::new([
            extension_file(urn, &["first_type"]),
            extension_file(urn, &["second_type"]),
        ]);

        match result {
            Err(SimpleExtensionsError::DuplicateUrn(duplicate)) => {
                assert_eq!(duplicate.to_string(), urn);
            }
            other => panic!("Expected DuplicateUrn error, got {other:?}"),
        }
    }

    /// Type lookup succeeds only for a known type under a registered URN.
    #[test]
    fn test_type_lookup() {
        let urn = Urn::from_str("extension:example.com:test").unwrap();
        let other_urn = Urn::from_str("extension:example.com:other").unwrap();
        let registry =
            Registry::new([extension_file(&urn.to_string(), &["test_type"])]).unwrap();

        let cases = [
            (&urn, "test_type", true),
            (&urn, "missing", false),
            (&other_urn, "test_type", false),
        ];
        for (query_urn, type_name, expected) in cases {
            assert_eq!(
                registry.get_type(query_urn, type_name).is_some(),
                expected,
                "unexpected lookup result for {query_urn}:{type_name}"
            );
        }
    }

    /// The core registry is non-empty, exposes a known core type, and omits
    /// the core extension files that `from_core_extensions` skips.
    #[cfg(feature = "extensions")]
    #[test]
    fn test_from_core_extensions() {
        let registry = Registry::from_core_extensions();
        assert!(registry.extensions().count() > 0);

        let urn = Urn::from_str("extension:io.substrait:functions_geometry").unwrap();
        let core_extension = registry
            .get_extension(&urn)
            .expect("Should find functions_geometry extension");
        let geometry_type = core_extension.get_type("geometry");
        assert!(
            geometry_type.is_some(),
            "Should find 'geometry' type in functions_geometry extension"
        );

        // The registry-level lookup must agree with the per-extension lookup.
        let type_via_registry = registry.get_type(&urn, "geometry");
        assert!(type_via_registry.is_some());

        let extension_types_urn = Urn::from_str("extension:io.substrait:extension_types").unwrap();
        assert!(
            registry.get_extension(&extension_types_urn).is_none(),
            "extension_types should be skipped due to missing u! prefix bug"
        );

        // `from_core_extensions` filters this URN out as well.
        let unknown_urn = Urn::from_str("extension:io.substrait:unknown").unwrap();
        assert!(
            registry.get_extension(&unknown_urn).is_none(),
            "unknown should be skipped by from_core_extensions"
        );
    }

    /// A return type that is neither built in nor `u!`-prefixed is reported as
    /// an unknown type name.
    #[test]
    fn test_unknown_type_without_prefix_fails() {
        let invalid_extension = extension_with_custom_type_reference(
            "extension:example.com:invalid",
            "bad_function",
            "point",
            &[],
        );

        let result = ExtensionFile::create(invalid_extension);
        assert!(
            result.is_err(),
            "Should fail when type is missing u! prefix"
        );
        match result {
            Err(SimpleExtensionsError::ScalarFunctionError(ScalarFunctionError::TypeError(
                ExtensionTypeError::UnknownTypeName { name },
            ))) => {
                assert_eq!(name, "point");
            }
            other => panic!("Expected UnknownTypeName error, got {other:?}"),
        }
    }

    /// A `u!`-prefixed reference to a type declared in the same file is valid.
    #[test]
    fn test_custom_type_reference_valid() {
        let extension = extension_with_custom_type_reference(
            "extension:example.com:valid",
            "get_point",
            "u!point",
            &["point"],
        );

        let result = ExtensionFile::create(extension);
        assert!(
            result.is_ok(),
            "Should succeed when referenced type exists with u! prefix"
        );
    }

    /// A `u!`-prefixed reference to a type that is not declared is reported as
    /// an unresolved type reference carrying the bare type name.
    #[test]
    fn test_custom_type_reference_missing() {
        let extension = extension_with_custom_type_reference(
            "extension:example.com:invalid",
            "get_rectangle",
            "u!rectangle",
            &[],
        );

        let result = ExtensionFile::create(extension);
        assert!(
            result.is_err(),
            "Should fail when referenced type doesn't exist"
        );
        match result {
            Err(SimpleExtensionsError::UnresolvedTypeReference { type_name }) => {
                assert_eq!(type_name, "rectangle");
            }
            other => panic!("Expected UnresolvedTypeReference error, got {other:?}"),
        }
    }

    /// The first implementation of the core `add` function is compared
    /// field-by-field against a hand-built expected value.
    #[cfg(feature = "extensions")]
    #[test]
    fn test_scalar_function_parses_completely() {
        use super::super::{
            argument::ArgumentsItem,
            scalar_functions::{Impl, NullabilityHandling, Options},
            types::*,
        };
        use crate::parse::Parse;
        use crate::text::simple_extensions;
        use std::collections::HashMap;

        let registry = Registry::from_core_extensions();
        let functions_arithmetic_urn =
            Urn::from_str("extension:io.substrait:functions_arithmetic").unwrap();
        let add = registry
            .get_scalar_function(&functions_arithmetic_urn, "add")
            .expect("add function should exist");

        assert_eq!(add.name, "add");
        assert_eq!(add.description, Some("Add two values.".to_string()));
        assert!(
            !add.impls.is_empty(),
            "add should have at least one implementation"
        );

        let mut ctx = super::super::extensions::TypeContext::default();
        // Both arguments of the expected implementation are `i8` values that
        // differ only in name, so parse them through one closure.
        let mut i8_argument = |name: &str| {
            ArgumentsItem::ValueArgument(
                simple_extensions::ValueArg {
                    name: Some(name.to_string()),
                    description: None,
                    value: simple_extensions::Type::String("i8".to_string()),
                    constant: None,
                }
                .parse(&mut ctx)
                .unwrap(),
            )
        };

        let expected_impl = Impl {
            args: vec![i8_argument("x"), i8_argument("y")],
            options: Options(HashMap::from([(
                "overflow".to_string(),
                vec![
                    "SILENT".to_string(),
                    "SATURATE".to_string(),
                    "ERROR".to_string(),
                ],
            )])),
            variadic: None,
            session_dependent: false,
            deterministic: true,
            nullability: NullabilityHandling::Mirror,
            return_type: ConcreteType {
                kind: ConcreteTypeKind::Builtin(BasicBuiltinType::I8),
                nullable: false,
            },
            implementation: HashMap::new(),
        };
        assert_eq!(&add.impls[0], &expected_impl);
    }
}