use std::collections::HashMap;
use crate::text::simple_extensions::{
NullabilityHandling as RawNullabilityHandling, Options as RawOptions,
ScalarFunction as RawScalarFunction, ScalarFunctionImplsItem as RawImpl, Type as RawType,
VariadicBehavior as RawVariadicBehavior, VariadicBehaviorParameterConsistency,
};
use super::argument::{ArgumentsItem, ArgumentsItemError};
use super::extensions::TypeContext;
use super::type_ast::{TypeExpr, TypeParseError};
use super::types::{ConcreteType, ExtensionTypeError};
use crate::parse::Parse;
use thiserror::Error;
/// Errors raised while converting a raw scalar function definition into its
/// validated representation.
#[derive(Debug, Error)]
pub enum ScalarFunctionError {
/// The raw function definition contained an empty list of implementations.
#[error("Scalar function '{name}' must have at least one implementation")]
NoImplementations {
/// Name of the function that had no implementations.
name: String,
},
/// A variadic bound failed validation (it was negative or not a whole number).
#[error("Variadic behavior {field} must be a non-negative integer, got {value}")]
InvalidVariadicBehavior {
/// Which bound was invalid (`"min"` or `"max"`).
field: String,
/// The offending raw value.
value: f64,
},
/// The variadic lower bound exceeded the upper bound.
#[error("Variadic min ({min}) must be less than or equal to max ({max})")]
VariadicMinGreaterThanMax {
/// The validated lower bound.
min: u32,
/// The validated upper bound.
max: u32,
},
/// Parsing one of the implementation's arguments failed.
#[error("Argument error: {0}")]
ArgumentError(#[from] ArgumentsItemError),
/// The return type could not be converted into a concrete type.
#[error("Type error: {0}")]
TypeError(#[from] ExtensionTypeError),
/// The return type string could not be parsed as a type expression.
#[error("Type parse error: {0}")]
TypeParseError(#[from] TypeParseError),
/// The definition uses a feature that is not supported yet; the payload
/// describes the feature.
#[error("Not yet implemented: {0}")]
NotYetImplemented(String),
}
/// A validated scalar function: a name, an optional description, and its
/// implementations.
#[derive(Clone, Debug, PartialEq)]
pub struct ScalarFunction {
/// Function name, copied from the raw definition.
pub name: String,
/// Optional description, copied from the raw definition.
pub description: Option<String>,
/// Validated implementations; `ScalarFunction::from_raw` rejects an empty list.
pub impls: Vec<Impl>,
}
impl ScalarFunction {
pub(super) fn from_raw(
raw: RawScalarFunction,
ctx: &mut TypeContext,
) -> Result<Self, ScalarFunctionError> {
if raw.impls.is_empty() {
return Err(ScalarFunctionError::NoImplementations { name: raw.name });
}
let impls = raw
.impls
.into_iter()
.map(|impl_| Impl::from_raw(impl_, ctx))
.collect::<Result<Vec<_>, _>>()?;
Ok(ScalarFunction {
name: raw.name,
description: raw.description,
impls,
})
}
}
/// One validated implementation (signature) of a scalar function.
#[derive(Clone, Debug, PartialEq)]
pub struct Impl {
/// Parsed arguments; empty when the raw implementation declares none.
pub args: Vec<ArgumentsItem>,
/// Options of this implementation; empty when the raw implementation has none.
pub options: Options,
/// Validated variadic behavior, if the raw implementation specifies one.
pub variadic: Option<VariadicBehavior>,
/// Session-dependence flag; `Impl::from_raw` uses `false` when it is absent.
pub session_dependent: bool,
/// Determinism flag; `Impl::from_raw` uses `true` when it is absent.
pub deterministic: bool,
/// Nullability handling; `Impl::from_raw` uses `Mirror` when it is absent.
pub nullability: NullabilityHandling,
/// Concrete return type resolved from the raw return type string.
pub return_type: ConcreteType,
/// Implementation map copied from the raw definition; empty when absent.
pub implementation: HashMap<String, String>,
}
impl Impl {
pub(super) fn from_raw(
raw: RawImpl,
ctx: &mut TypeContext,
) -> Result<Self, ScalarFunctionError> {
let return_type = match raw.return_.0 {
RawType::String(s) => {
if s.contains('\n') {
return Err(ScalarFunctionError::NotYetImplemented(
"Type derivation expressions - issue #449".to_string(),
));
}
let type_expr = TypeExpr::parse(&s)?;
type_expr.visit_references(&mut |name| ctx.linked(name));
match ConcreteType::try_from(type_expr) {
Ok(concrete) => concrete,
Err(ExtensionTypeError::InvalidAnyTypeVariable { .. })
| Err(ExtensionTypeError::InvalidParameter(_))
| Err(ExtensionTypeError::InvalidParameterKind { .. }) => {
return Err(ScalarFunctionError::NotYetImplemented(
"Type variables in function signatures - issue #452".to_string(),
));
}
Err(ExtensionTypeError::UnknownTypeName { name }) => {
return Err(ScalarFunctionError::TypeError(
ExtensionTypeError::UnknownTypeName { name },
));
}
Err(e) => return Err(ScalarFunctionError::TypeError(e)),
}
}
RawType::Object(_) => {
return Err(ScalarFunctionError::NotYetImplemented(
"Struct return types - issue #450".to_string(),
));
}
};
let variadic = raw.variadic.map(|v| v.try_into()).transpose()?;
let args = match raw.args {
Some(a) => {
a.0.into_iter()
.map(|raw_arg| raw_arg.parse(ctx))
.collect::<Result<Vec<_>, _>>()?
}
None => Vec::new(),
};
Ok(Impl {
args,
options: raw.options.as_ref().map(Options::from).unwrap_or_default(),
variadic,
session_dependent: raw.session_dependent.map(|b| b.0).unwrap_or(false),
deterministic: raw.deterministic.map(|b| b.0).unwrap_or(true),
nullability: raw
.nullability
.map(Into::into)
.unwrap_or(NullabilityHandling::Mirror),
return_type,
implementation: raw
.implementation
.map(|i| i.0.into_iter().collect())
.unwrap_or_default(),
})
}
}
/// Validated variadic behavior of a function implementation.
#[derive(Clone, Debug, PartialEq)]
pub struct VariadicBehavior {
/// Lower bound; the conversion from the raw form uses `0` when it is absent.
pub min: u32,
/// Upper bound; `None` when the raw form does not specify one.
pub max: Option<u32>,
/// Parameter consistency setting, if the raw form specifies one.
pub parameter_consistency: Option<ParameterConsistency>,
}
/// Parameter consistency setting of a variadic implementation; mirrors the
/// variants of the raw `VariadicBehaviorParameterConsistency` enum.
#[derive(Clone, Debug, PartialEq)]
pub enum ParameterConsistency {
/// Counterpart of the raw `Consistent` variant.
Consistent,
/// Counterpart of the raw `Inconsistent` variant.
Inconsistent,
}
impl From<VariadicBehaviorParameterConsistency> for ParameterConsistency {
    /// Maps each raw consistency variant onto the variant of the same name.
    fn from(raw: VariadicBehaviorParameterConsistency) -> Self {
        match raw {
            VariadicBehaviorParameterConsistency::Consistent => Self::Consistent,
            VariadicBehaviorParameterConsistency::Inconsistent => Self::Inconsistent,
        }
    }
}
impl TryFrom<RawVariadicBehavior> for VariadicBehavior {
    type Error = ScalarFunctionError;

    /// Validates the raw variadic bounds and converts them to integers.
    ///
    /// A missing `min` defaults to `0`; a missing `max` stays `None`.
    ///
    /// # Errors
    ///
    /// * [`ScalarFunctionError::InvalidVariadicBehavior`] when a bound is
    ///   negative, fractional, non-finite, or larger than `u32::MAX`.
    /// * [`ScalarFunctionError::VariadicMinGreaterThanMax`] when both bounds
    ///   are valid but `min` exceeds `max`.
    fn try_from(raw: RawVariadicBehavior) -> Result<Self, Self::Error> {
        /// Converts one raw bound to `u32`, rejecting anything that is not a
        /// whole number in `0..=u32::MAX`.
        fn parse_bound(value: f64, field: &str) -> Result<u32, ScalarFunctionError> {
            // The range check also rejects NaN (comparisons with NaN are
            // false). Without the upper limit, `value as u32` would silently
            // saturate oversized bounds to `u32::MAX`.
            let representable = (0.0..=f64::from(u32::MAX)).contains(&value);
            if !representable || value.fract() != 0.0 {
                return Err(ScalarFunctionError::InvalidVariadicBehavior {
                    field: field.to_string(),
                    value,
                });
            }
            // Lossless: `value` is integral and within the `u32` range.
            Ok(value as u32)
        }

        let min = raw
            .min
            .map(|v| parse_bound(v, "min"))
            .transpose()?
            .unwrap_or(0);
        let max = raw.max.map(|v| parse_bound(v, "max")).transpose()?;

        if let Some(max_val) = max {
            if min > max_val {
                return Err(ScalarFunctionError::VariadicMinGreaterThanMax { min, max: max_val });
            }
        }

        Ok(VariadicBehavior {
            min,
            max,
            parameter_consistency: raw.parameter_consistency.map(Into::into),
        })
    }
}
/// Nullability handling of a function implementation; mirrors the variants of
/// the raw `NullabilityHandling` enum. `Impl::from_raw` uses `Mirror` when the
/// raw implementation does not specify one.
#[derive(Clone, Debug, PartialEq)]
pub enum NullabilityHandling {
/// Counterpart of the raw `Mirror` variant.
Mirror,
/// Counterpart of the raw `DeclaredOutput` variant.
DeclaredOutput,
/// Counterpart of the raw `Discrete` variant.
Discrete,
}
impl From<RawNullabilityHandling> for NullabilityHandling {
    /// Maps each raw nullability variant onto the variant of the same name.
    fn from(raw: RawNullabilityHandling) -> Self {
        match raw {
            RawNullabilityHandling::Mirror => Self::Mirror,
            RawNullabilityHandling::DeclaredOutput => Self::DeclaredOutput,
            RawNullabilityHandling::Discrete => Self::Discrete,
        }
    }
}
/// Options of a function implementation: maps each option name to the list of
/// values given for it in the raw definition.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Options(pub HashMap<String, Vec<String>>);
impl From<&RawOptions> for Options {
fn from(raw: &RawOptions) -> Self {
Options(
raw.0
.iter()
.map(|(k, v)| (k.clone(), v.values.clone()))
.collect(),
)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a raw variadic specification with the given bounds and no
    /// parameter-consistency setting.
    fn raw_variadic(min: Option<f64>, max: Option<f64>) -> RawVariadicBehavior {
        RawVariadicBehavior {
            min,
            max,
            parameter_consistency: None,
        }
    }

    /// Negative, fractional, and inverted bounds must all be rejected.
    #[test]
    fn test_variadic_invalid_values() {
        let invalid_cases = [
            ("negative min", Some(-1.0), None),
            ("negative max", None, Some(-2.5)),
            ("non-integer min", Some(7.2), None),
            ("non-integer max", None, Some(3.5)),
            ("min greater than max", Some(5.0), Some(3.0)),
        ];

        for (description, min, max) in invalid_cases {
            let outcome = VariadicBehavior::try_from(raw_variadic(min, max));
            assert!(outcome.is_err(), "expected error for {}", description);
        }
    }

    /// Whole-number bounds with `min <= max` convert to integers.
    #[test]
    fn test_variadic_valid() {
        let parsed = VariadicBehavior::try_from(raw_variadic(Some(1.0), Some(5.0))).unwrap();

        assert_eq!(parsed.min, 1);
        assert_eq!(parsed.max, Some(5));
    }

    /// Absent bounds yield `min == 0` and no maximum.
    #[test]
    fn test_variadic_none_values() {
        let parsed = VariadicBehavior::try_from(raw_variadic(None, None)).unwrap();

        assert_eq!(parsed.min, 0);
        assert_eq!(parsed.max, None);
    }

    /// A function without implementations is rejected and the error names it.
    #[test]
    fn test_no_implementations_error() {
        use crate::text::simple_extensions::ScalarFunction as RawScalarFunction;

        let raw = RawScalarFunction {
            name: "empty_function".to_string(),
            description: None,
            metadata: Default::default(),
            impls: vec![],
        };
        let mut ctx = TypeContext::default();

        match ScalarFunction::from_raw(raw, &mut ctx) {
            Err(ScalarFunctionError::NoImplementations { name }) => {
                assert_eq!(name, "empty_function");
            }
            other => panic!("expected NoImplementations error, got {:?}", other),
        }
    }

    /// A single implementation returning `i32` converts to a non-nullable
    /// builtin `I32` return type.
    #[test]
    fn test_scalar_function_with_single_impl() {
        use super::super::types::{BasicBuiltinType, ConcreteTypeKind};
        use crate::text::simple_extensions::{
            ReturnValue, ScalarFunction as RawScalarFunction, ScalarFunctionImplsItem, Type,
        };

        let only_impl = ScalarFunctionImplsItem {
            args: None,
            options: None,
            variadic: None,
            session_dependent: None,
            deterministic: None,
            nullability: None,
            return_: ReturnValue(Type::String("i32".to_string())),
            implementation: None,
        };
        let raw = RawScalarFunction {
            name: "add".to_string(),
            description: Some("Addition function".to_string()),
            metadata: Default::default(),
            impls: vec![only_impl],
        };
        let mut ctx = TypeContext::default();

        let function = ScalarFunction::from_raw(raw, &mut ctx).unwrap();

        assert_eq!(function.name, "add");
        assert_eq!(function.description, Some("Addition function".to_string()));
        assert_eq!(function.impls.len(), 1);

        let return_type = &function.impls[0].return_type;
        assert!(!return_type.nullable, "i32 should not be nullable");
        assert!(matches!(
            &return_type.kind,
            ConcreteTypeKind::Builtin(BasicBuiltinType::I32)
        ));
    }

    /// Converting raw options keeps each option name and its values.
    #[test]
    fn test_options_conversion() {
        use crate::text::simple_extensions::{Options as RawOptions, OptionsValue};
        use indexmap::IndexMap;

        let expected_values = vec!["SILENT".to_string(), "ERROR".to_string()];

        let mut entries = IndexMap::new();
        entries.insert(
            "overflow".to_string(),
            OptionsValue {
                values: expected_values.clone(),
                description: None,
            },
        );

        let converted = Options::from(&RawOptions(entries));

        assert_eq!(converted.0.len(), 1);
        assert_eq!(converted.0.get("overflow"), Some(&expected_values));
    }
}