use indexmap::IndexMap;
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use super::{SimpleExtensionsError, scalar_functions::ScalarFunction, types::CustomType};
use crate::{
parse::{Context, Parse},
text::simple_extensions::SimpleExtensions as RawExtensions,
urn::Urn,
};
/// Parsed contents of a simple-extensions document, indexed by name.
///
/// Custom types and scalar functions are each stored in a map keyed by the
/// item's name.
#[derive(Debug, Default, Clone)]
pub struct SimpleExtensions {
    /// Registered custom types, keyed by type name.
    types: HashMap<String, CustomType>,
    /// Registered scalar functions, keyed by function name.
    scalar_functions: HashMap<String, ScalarFunction>,
}
impl SimpleExtensions {
    /// Registers `custom_type` under its name, storing a clone of it.
    ///
    /// # Errors
    ///
    /// Returns [`SimpleExtensionsError::DuplicateTypeName`] if a type with the
    /// same name is already registered; the existing entry is left untouched.
    pub fn add_type(&mut self, custom_type: &CustomType) -> Result<(), SimpleExtensionsError> {
        use std::collections::hash_map::Entry;

        // One hash lookup via the entry API instead of `contains_key` followed by `insert`.
        match self.types.entry(custom_type.name.clone()) {
            Entry::Occupied(_) => Err(SimpleExtensionsError::DuplicateTypeName {
                name: custom_type.name.clone(),
            }),
            Entry::Vacant(slot) => {
                slot.insert(custom_type.clone());
                Ok(())
            }
        }
    }

    /// Looks up a registered custom type by name.
    pub fn get_type(&self, name: &str) -> Option<&CustomType> {
        self.types.get(name)
    }

    /// Iterates over all registered custom types, in unspecified (hash-map) order.
    pub fn types(&self) -> impl Iterator<Item = &CustomType> {
        self.types.values()
    }

    /// Consumes `self`, returning the name-keyed map of custom types.
    ///
    /// Any registered scalar functions are discarded.
    pub(crate) fn into_types(self) -> HashMap<String, CustomType> {
        self.types
    }

    /// Registers a scalar function.
    ///
    /// If a function with the same name already exists, `scalar_function` is
    /// merged into it (see `merge_scalar_function`) instead of replacing it.
    pub(super) fn add_scalar_function(&mut self, scalar_function: ScalarFunction) {
        use std::collections::hash_map::Entry;

        match self.scalar_functions.entry(scalar_function.name.clone()) {
            Entry::Vacant(slot) => {
                slot.insert(scalar_function);
            }
            Entry::Occupied(mut slot) => {
                Self::merge_scalar_function(slot.get_mut(), scalar_function);
            }
        }
    }

    /// Folds `new` into `existing`.
    ///
    /// `existing.impls` is extended with `new.impls`. `existing`'s description
    /// is kept if present; otherwise `new`'s description is adopted. No other
    /// fields of `new` are merged.
    fn merge_scalar_function(existing: &mut ScalarFunction, new: ScalarFunction) {
        existing.impls.extend(new.impls);
        existing.description = existing.description.take().or(new.description);
    }

    /// Looks up a registered scalar function by name.
    pub fn get_scalar_function(&self, name: &str) -> Option<&ScalarFunction> {
        self.scalar_functions.get(name)
    }

    /// Iterates over all registered scalar functions, in unspecified (hash-map) order.
    pub fn scalar_functions(&self) -> impl Iterator<Item = &ScalarFunction> {
        self.scalar_functions.values()
    }
}
/// Bookkeeping for custom type names encountered while parsing.
///
/// Names reported via `found` go into `known`. Names reported via `linked`
/// that are not already known go into `linked`, and they stay there until a
/// matching `found` call removes them.
#[derive(Default, Debug)]
pub(crate) struct TypeContext {
    /// Type names that have been reported as found.
    known: HashSet<String>,
    /// Type names referenced but not (yet) found.
    linked: HashSet<String>,
}
impl TypeContext {
    /// Records that the type `name` has been found.
    ///
    /// Any pending unresolved reference to `name` is cleared, and later
    /// `linked` calls for the same name will have no effect.
    pub fn found(&mut self, name: &str) {
        self.known.insert(name.to_owned());
        self.linked.remove(name);
    }

    /// Records a reference to the type `name`.
    ///
    /// The reference is remembered as unresolved unless `name` was already
    /// reported via `found`.
    pub fn linked(&mut self, name: &str) {
        if self.known.contains(name) {
            return;
        }
        self.linked.insert(name.to_owned());
    }
}
// Lets `TypeContext` serve as the parsing context for `Parse` impls such as
// the `RawExtensions` one below; this impl provides no items of its own.
impl Context for TypeContext {}
impl Parse<TypeContext> for RawExtensions {
    type Parsed = (Urn, SimpleExtensions);
    type Error = super::SimpleExtensionsError;

    /// Parses a raw simple-extensions document into its URN and a validated
    /// [`SimpleExtensions`].
    ///
    /// Only the `urn`, `types` and `scalar_functions` fields of the raw
    /// document are consumed; all other fields are ignored. Types are parsed
    /// first, then scalar functions. Scalar functions whose parsing reports
    /// `ScalarFunctionError::NotYetImplemented` are skipped.
    ///
    /// # Errors
    ///
    /// - the URN string fails to parse;
    /// - a type fails to parse, or two types share a name;
    /// - a scalar function fails to parse with any error other than
    ///   `NotYetImplemented`;
    /// - after parsing, `ctx` still holds a linked type name that was never
    ///   found (`UnresolvedTypeReference`). If several are unresolved, the
    ///   lexicographically smallest is reported so the error is deterministic.
    fn parse(self, ctx: &mut TypeContext) -> Result<Self::Parsed, Self::Error> {
        let RawExtensions {
            urn,
            types,
            scalar_functions,
            ..
        } = self;

        let urn = Urn::from_str(&urn)?;
        let mut extension = SimpleExtensions::default();

        for type_item in types {
            let custom_type = Parse::parse(type_item, ctx)?;
            extension.add_type(&custom_type)?;
        }

        for scalar_fn in scalar_functions {
            match ScalarFunction::from_raw(scalar_fn, ctx) {
                Ok(parsed_fn) => extension.add_scalar_function(parsed_fn),
                // Not-yet-supported functions are skipped rather than failing the document.
                Err(super::scalar_functions::ScalarFunctionError::NotYetImplemented(_)) => {}
                Err(e) => return Err(e.into()),
            }
        }

        // Checked last: `from_raw` also receives `ctx`, so scalar functions may
        // record linked type names too. `HashSet` iteration order is random, so
        // pick the smallest name to keep the reported error stable across runs.
        if let Some(missing) = ctx.linked.iter().min() {
            return Err(super::SimpleExtensionsError::UnresolvedTypeReference {
                type_name: missing.clone(),
            });
        }

        Ok((urn, extension))
    }
}
impl From<(Urn, SimpleExtensions)> for RawExtensions {
fn from((urn, extension): (Urn, SimpleExtensions)) -> Self {
let types = extension
.into_types()
.into_values()
.map(Into::into)
.collect();
RawExtensions {
urn: urn.to_string(),
aggregate_functions: vec![],
dependencies: IndexMap::new(),
metadata: Default::default(),
scalar_functions: vec![],
type_variations: vec![],
types,
window_functions: vec![],
}
}
}