use std::{
collections::{HashMap, HashSet},
fmt,
sync::Arc,
};
use crate::{
arith::{CompleteConstraints, Constraint, ConstraintSet, Num},
types::ParamQuantifier,
LengthVar, PrimitiveType, Tuple, TupleLen, Type, TypeVar,
};
/// Constraints placed on the type and length params of a function.
///
/// When attached to a [`Function`], non-empty constraints are rendered inside
/// the `for<...>` quantifier of the function's `Display` output.
#[derive(Debug, Clone)]
pub(crate) struct ParamConstraints<Prim: PrimitiveType> {
    /// Constraints for each constrained type param, keyed by the type param index.
    pub type_params: HashMap<usize, CompleteConstraints<Prim>>,
    /// Indexes of length params marked as static (rendered with a `len!` prefix).
    pub static_lengths: HashSet<usize>,
}
impl<Prim: PrimitiveType> Default for ParamConstraints<Prim> {
    /// Creates constraints with no constrained type params and no static lengths.
    // Written by hand because `#[derive(Default)]` would add a `Prim: Default` bound.
    fn default() -> Self {
        Self {
            static_lengths: HashSet::default(),
            type_params: HashMap::default(),
        }
    }
}
impl<Prim: PrimitiveType> fmt::Display for ParamConstraints<Prim> {
    /// Formats the constraints as they appear inside a `for<...>` quantifier,
    /// e.g. `len! N; 'T: Lin`.
    ///
    /// Static lengths come first, prefixed with `len! ` and followed by `; ` if there are
    /// any type params. Then come type params with their constraints, separated by `, `.
    /// Both groups are emitted in ascending index order, so the output does not depend
    /// on the iteration order of the underlying hash collections.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        if !self.static_lengths.is_empty() {
            // `HashSet` iteration order is unspecified; sort the indexes so that the
            // output is deterministic and ordered the same way as type params.
            let mut static_lengths: Vec<usize> = self.static_lengths.iter().copied().collect();
            static_lengths.sort_unstable();

            formatter.write_str("len! ")?;
            for (i, len) in static_lengths.into_iter().enumerate() {
                if i > 0 {
                    formatter.write_str(", ")?;
                }
                write!(formatter, "{}", LengthVar::param_str(len))?;
            }
            if !self.type_params.is_empty() {
                formatter.write_str("; ")?;
            }
        }

        // `type_params()` already yields params sorted by index.
        for (i, (idx, constraints)) in self.type_params().enumerate() {
            if i > 0 {
                formatter.write_str(", ")?;
            }
            write!(formatter, "'{}: {}", TypeVar::param_str(idx), constraints)?;
        }
        Ok(())
    }
}
impl<Prim: PrimitiveType> ParamConstraints<Prim> {
    /// Returns `true` if there are neither static lengths nor constrained type params.
    fn is_empty(&self) -> bool {
        self.static_lengths.is_empty() && self.type_params.is_empty()
    }

    /// Iterates over constrained type params as `(index, constraints)` pairs in
    /// ascending index order.
    fn type_params(&self) -> impl Iterator<Item = (usize, &CompleteConstraints<Prim>)> + '_ {
        let mut sorted = Vec::with_capacity(self.type_params.len());
        sorted.extend(
            self.type_params
                .iter()
                .map(|(idx, constraints)| (*idx, constraints)),
        );
        // Map keys are unique, so an unstable sort still yields a deterministic order.
        sorted.sort_unstable_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));
        sorted.into_iter()
    }
}
/// Params of a function: type params with their constraints, length params,
/// and (optionally) the param constraints used for display.
#[derive(Debug)]
pub(crate) struct FnParams<Prim: PrimitiveType> {
    /// Type params as `(index, constraints)` pairs.
    pub type_params: Vec<(usize, CompleteConstraints<Prim>)>,
    /// Length params as `(index, flag)` pairs.
    // NOTE(review): the `bool` presumably marks whether the length is static — confirm
    // against the code that fills in these params (`ParamQuantifier`).
    pub len_params: Vec<(usize, bool)>,
    /// Optional param constraints; when present and non-empty, they are shown in the
    /// `for<...>` quantifier of the function's `Display` output. Not compared by `PartialEq`.
    pub constraints: Option<ParamConstraints<Prim>>,
}
impl<Prim: PrimitiveType> Default for FnParams<Prim> {
    /// Creates params with no type params, no length params and no constraints.
    // Written by hand because `#[derive(Default)]` would add a `Prim: Default` bound.
    fn default() -> Self {
        Self {
            type_params: Vec::new(),
            len_params: Vec::new(),
            constraints: None,
        }
    }
}
impl<Prim: PrimitiveType> PartialEq for FnParams<Prim> {
    /// Compares type params and length params; the `constraints` field is not compared.
    fn eq(&self, other: &Self) -> bool {
        let Self {
            type_params,
            len_params,
            ..
        } = self;
        *type_params == other.type_params && *len_params == other.len_params
    }
}
impl<Prim: PrimitiveType> FnParams<Prim> {
    /// Returns `true` if there are neither type params nor length params.
    /// The `constraints` field is not considered.
    fn is_empty(&self) -> bool {
        self.type_params.is_empty() && self.len_params.is_empty()
    }
}
/// Functional type: argument types, a return type and (optionally) type / length params.
///
/// Outside the crate, functions are constructed via [`Function::builder()`].
#[derive(Debug, Clone, PartialEq)]
pub struct Function<Prim: PrimitiveType = Num> {
    /// Function arguments.
    pub(crate) args: Tuple<Prim>,
    /// Function return type.
    pub(crate) return_type: Type<Prim>,
    /// Type / length params of the function. `Function::new` leaves this as `None`;
    /// `Function::set_params` fills it in.
    pub(crate) params: Option<Arc<FnParams<Prim>>>,
}
impl<Prim: PrimitiveType> fmt::Display for Function<Prim> {
    /// Formats the function as `(args) -> return_type`. The output is prefixed with a
    /// `for<...> ` quantifier when the function's params carry non-empty constraints.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let quantifier = match &self.params {
            Some(params) => params
                .constraints
                .as_ref()
                .filter(|constraints| !constraints.is_empty()),
            None => None,
        };
        if let Some(constraints) = quantifier {
            write!(formatter, "for<{}> ", constraints)?;
        }

        self.args.format_as_tuple(formatter)?;
        write!(formatter, " -> {}", self.return_type)
    }
}
impl<Prim: PrimitiveType> Function<Prim> {
    /// Creates a function with the given args and return type and no params.
    pub(crate) fn new(args: Tuple<Prim>, return_type: Type<Prim>) -> Self {
        Self {
            args,
            return_type,
            params: None,
        }
    }

    /// Returns a builder for constructing a function argument by argument.
    pub fn builder() -> FunctionBuilder<Prim> {
        FunctionBuilder::default()
    }

    /// Gets the argument types of this function.
    pub fn args(&self) -> &Tuple<Prim> {
        &self.args
    }

    /// Gets the return type of this function.
    pub fn return_type(&self) -> &Type<Prim> {
        &self.return_type
    }

    /// Sets the params of this function, replacing any previously set params.
    pub(crate) fn set_params(&mut self, params: FnParams<Prim>) {
        self.params = Some(Arc::new(params));
    }

    /// Returns `true` if params are set and contain at least one type or length param.
    pub(crate) fn is_parametric(&self) -> bool {
        self.params
            .as_ref()
            .map_or(false, |params| !params.is_empty())
    }

    /// Checks whether both the args and the return type of this function are concrete.
    pub fn is_concrete(&self) -> bool {
        self.args.is_concrete() && self.return_type.is_concrete()
    }

    /// Marks the type params with the specified `indexes` as having `constraint`.
    ///
    /// Further constraints / static lengths can be added to the returned value before
    /// converting it into a [`Function`] or [`Type`].
    ///
    /// # Panics
    ///
    /// Panics if params have already been set on this function.
    pub fn with_constraints<C: Constraint<Prim>>(
        self,
        indexes: &[usize],
        constraint: C,
    ) -> FnWithConstraints<Prim> {
        self.assert_params_not_computed();

        let constraints = CompleteConstraints::from(ConstraintSet::just(constraint));
        let type_params = indexes
            .iter()
            .map(|&idx| (idx, constraints.clone()))
            .collect();
        FnWithConstraints {
            function: self,
            constraints: ParamConstraints {
                type_params,
                static_lengths: HashSet::new(),
            },
        }
    }

    /// Marks the length params with the specified `indexes` as static.
    ///
    /// Further constraints / static lengths can be added to the returned value before
    /// converting it into a [`Function`] or [`Type`].
    ///
    /// # Panics
    ///
    /// Panics if params have already been set on this function.
    pub fn with_static_lengths(self, indexes: &[usize]) -> FnWithConstraints<Prim> {
        self.assert_params_not_computed();

        FnWithConstraints {
            function: self,
            constraints: ParamConstraints {
                type_params: HashMap::new(),
                static_lengths: indexes.iter().copied().collect(),
            },
        }
    }

    /// Shared precondition of `with_constraints` / `with_static_lengths`: constraints
    /// can only be attached while no params have been set on the function.
    // `#[track_caller]` keeps the panic location pointing at the public method.
    #[track_caller]
    fn assert_params_not_computed(&self) {
        assert!(
            self.params.is_none(),
            "Cannot attach constraints to a function with computed params: `{}`",
            self
        );
    }
}
/// Function together with constraints on its type / length params that have not been
/// applied to it yet.
///
/// Produced by [`Function::with_constraints()`] or [`Function::with_static_lengths()`].
/// Converting it into a [`Function`] or [`Type`] applies the accumulated constraints.
#[derive(Debug)]
pub struct FnWithConstraints<Prim: PrimitiveType> {
    /// The function the constraints will be attached to.
    function: Function<Prim>,
    /// Constraints accumulated so far.
    constraints: ParamConstraints<Prim>,
}
impl<Prim: PrimitiveType> fmt::Display for FnWithConstraints<Prim> {
    /// Formats the function, prefixed by a `for<...> ` quantifier listing the accumulated
    /// constraints (the quantifier is omitted if there are no constraints).
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.constraints.is_empty() {
            // No quantifier: delegate directly, passing the formatter through.
            return fmt::Display::fmt(&self.function, formatter);
        }
        write!(formatter, "for<{}> ", self.constraints)?;
        write!(formatter, "{}", self.function)
    }
}
impl<Prim: PrimitiveType> FnWithConstraints<Prim> {
    /// Adds `constraint` to each type param listed in `indexes`.
    ///
    /// Type params that were not constrained before start from default constraints.
    pub fn with_constraint<C>(mut self, indexes: &[usize], constraint: &C) -> Self
    where
        C: Constraint<Prim> + Clone,
    {
        let type_params = &mut self.constraints.type_params;
        indexes.iter().for_each(|&idx| {
            type_params
                .entry(idx)
                .or_default()
                .simple
                .insert(constraint.clone());
        });
        self
    }

    /// Marks the length params with the specified `indexes` as static, in addition to
    /// any static lengths marked previously.
    pub fn with_static_lengths(mut self, indexes: &[usize]) -> FnWithConstraints<Prim> {
        self.constraints
            .static_lengths
            .extend(indexes.iter().copied());
        self
    }
}
impl<Prim: PrimitiveType> From<FnWithConstraints<Prim>> for Function<Prim> {
fn from(value: FnWithConstraints<Prim>) -> Self {
let mut function = value.function;
ParamQuantifier::set_params(&mut function, value.constraints);
function
}
}
impl<Prim: PrimitiveType> From<FnWithConstraints<Prim>> for Type<Prim> {
fn from(value: FnWithConstraints<Prim>) -> Self {
Function::from(value).into()
}
}
/// Builder for [`Function`]s.
///
/// Argument types are accumulated with `with_arg` / `with_varargs`;
/// `returning` finishes the function by specifying its return type.
#[derive(Debug, Clone)]
pub struct FunctionBuilder<Prim: PrimitiveType = Num> {
    /// Arguments collected so far.
    args: Tuple<Prim>,
}
impl<Prim: PrimitiveType> Default for FunctionBuilder<Prim> {
    /// Creates a builder with no arguments.
    // Written by hand because `#[derive(Default)]` would add a `Prim: Default` bound.
    fn default() -> Self {
        let args = Tuple::empty();
        Self { args }
    }
}
impl<Prim: PrimitiveType> FunctionBuilder<Prim> {
    /// Appends an argument of the given type.
    pub fn with_arg(mut self, arg: impl Into<Type<Prim>>) -> Self {
        let arg_type = arg.into();
        self.args.push(arg_type);
        self
    }

    /// Sets the variadic (middle) part of the arguments: `len` repeated `element`s.
    pub fn with_varargs(
        mut self,
        element: impl Into<Type<Prim>>,
        len: impl Into<TupleLen>,
    ) -> Self {
        let (element_type, middle_len) = (element.into(), len.into());
        self.args.set_middle(element_type, middle_len);
        self
    }

    /// Finishes building by specifying the return type. The resulting function has no params.
    pub fn returning(self, return_type: impl Into<Type<Prim>>) -> Function<Prim> {
        let Self { args } = self;
        Function::new(args, return_type.into())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{arith::Linearity, UnknownLen};

    /// Complete constraints consisting of just the linearity constraint.
    fn lin_constraints() -> CompleteConstraints<Num> {
        CompleteConstraints::from(ConstraintSet::<Num>::just(Linearity))
    }

    #[test]
    fn constraints_display() {
        let only_types = ParamConstraints {
            type_params: std::iter::once((0, lin_constraints())).collect(),
            static_lengths: HashSet::new(),
        };
        assert_eq!(only_types.to_string(), "'T: Lin");

        let types_and_lengths: ParamConstraints<Num> = ParamConstraints {
            type_params: std::iter::once((0, lin_constraints())).collect(),
            static_lengths: std::iter::once(0).collect(),
        };
        assert_eq!(types_and_lengths.to_string(), "len! N; 'T: Lin");
    }

    #[test]
    fn fn_with_constraints_display() {
        let constrained_sum = <Function>::builder()
            .with_arg(Type::param(0).repeat(UnknownLen::param(0)))
            .returning(Type::param(0))
            .with_constraints(&[0], Linearity);

        assert_eq!(
            constrained_sum.to_string(),
            "for<'T: Lin> (['T; N]) -> 'T"
        );
    }

    #[test]
    fn fn_builder_with_quantified_arg() {
        let summing: Function = Function::builder()
            .with_arg(Type::NUM.repeat(UnknownLen::param(0)))
            .returning(Type::NUM)
            .with_constraints(&[], Linearity)
            .into();
        assert_eq!(summing.to_string(), "([Num; N]) -> Num");

        let nested: Function = Function::builder()
            .with_arg(Type::NUM)
            .with_arg(summing.clone())
            .returning(Type::NUM)
            .with_constraints(&[], Linearity)
            .into();
        assert_eq!(nested.to_string(), "(Num, ([Num; N]) -> Num) -> Num");

        let with_varargs: Function = Function::builder()
            .with_varargs(Type::NUM, UnknownLen::param(0))
            .with_arg(summing)
            .returning(Type::NUM)
            .with_constraints(&[], Linearity)
            .into();
        assert_eq!(
            with_varargs.to_string(),
            "(...[Num; N], ([Num; N]) -> Num) -> Num"
        );
    }
}