use eure_document::identifier::Identifier;
use std::collections::HashMap;
use std::fmt;
/// Structural type representation used by this module.
///
/// NOTE(review): the name suggests these types are synthesized (inferred) from
/// values rather than declared; confirm against callers.
#[derive(Debug, Clone, PartialEq)]
pub enum SynthType {
    /// The null type; displayed as `null`.
    Null,
    /// Displayed as `boolean`.
    Boolean,
    /// Displayed as `integer`.
    Integer,
    /// Displayed as `float`.
    Float,
    /// Text, optionally tagged; displayed as `text` or `text.<tag>`.
    /// NOTE(review): the tag looks like a language name (tests use "rust") — confirm.
    Text(Option<String>),
    /// Array whose elements all have the boxed type; displayed as `[elem]`.
    Array(Box<SynthType>),
    /// Tuple with one type per position; displayed as `(a, b, ...)`.
    Tuple(Vec<SynthType>),
    /// Record of named, possibly optional fields.
    Record(SynthRecord),
    /// Union of alternatives; `SynthUnion::from_variants` builds a normalized one.
    Union(SynthUnion),
    /// Displayed as `any`; `is_complete` treats it as incomplete.
    /// NOTE(review): presumably the top type — confirm.
    Any,
    /// Displayed as `never`. `SynthUnion::from_variants` drops it from variants and
    /// returns it when no variants remain; `is_complete` treats it as incomplete.
    Never,
    /// A hole (placeholder), optionally named; displayed as `!` or `!<name>`.
    /// Dropped by `SynthUnion::from_variants`, detected by `has_holes`, and
    /// treated as incomplete by `is_complete`.
    Hole(Option<Identifier>),
}
/// A record type: a set of named fields.
#[derive(Debug, Clone, PartialEq)]
pub struct SynthRecord {
    /// Fields keyed by name. A `HashMap`, so iteration order is unspecified.
    pub fields: HashMap<String, SynthField>,
}
/// One field of a `SynthRecord`.
#[derive(Debug, Clone, PartialEq)]
pub struct SynthField {
    /// The field's value type.
    pub ty: SynthType,
    /// Whether the field may be absent; optional fields render as `name?: type`.
    pub optional: bool,
}
/// The alternatives of a union type.
///
/// Unions built by `SynthUnion::from_variants` have at least two variants, contain
/// no nested unions, no `Never`, no `Hole`, and no duplicates. A union constructed
/// directly through the public field is not checked for any of this.
#[derive(Debug, Clone, PartialEq)]
pub struct SynthUnion {
    /// Alternatives in first-seen order.
    pub variants: Vec<SynthType>,
}
impl SynthRecord {
    /// Returns a record that has no fields at all.
    pub fn empty() -> Self {
        Self::new(std::iter::empty())
    }

    /// Builds a record from `(name, field)` pairs.
    ///
    /// When the same name occurs more than once, the field from the later pair
    /// replaces the earlier one (standard `HashMap` collection behavior).
    pub fn new(fields: impl IntoIterator<Item = (String, SynthField)>) -> Self {
        let by_name: HashMap<String, SynthField> = fields.into_iter().collect();
        Self { fields: by_name }
    }
}
impl SynthField {
pub fn required(ty: SynthType) -> Self {
Self {
ty,
optional: false,
}
}
pub fn optional(ty: SynthType) -> Self {
Self { ty, optional: true }
}
}
impl SynthUnion {
    /// Builds a normalized union type from `variants`.
    ///
    /// Normalization rules:
    /// - nested `Union` variants are flattened at every depth,
    /// - `Never` and `Hole` variants are dropped wherever they appear,
    /// - duplicates (by `PartialEq`) are removed, keeping the first occurrence's position.
    ///
    /// Returns `SynthType::Never` when no variants survive, the single surviving
    /// variant when exactly one survives, and a `SynthType::Union` otherwise.
    ///
    /// Deduplication is a linear `contains` scan per variant. `SynthType` is not
    /// `Hash` (it contains a `HashMap`), so this is quadratic in the variant count.
    pub fn from_variants(variants: impl IntoIterator<Item = SynthType>) -> SynthType {
        /// Appends `ty` to `flat`, recursing into nested unions and skipping absorbed variants.
        fn push_normalized(flat: &mut Vec<SynthType>, ty: SynthType) {
            match ty {
                // Recurse so a directly constructed union whose variants are
                // themselves unions, holes, or `Never` is fully normalized too.
                SynthType::Union(inner) => {
                    for v in inner.variants {
                        push_normalized(flat, v);
                    }
                }
                SynthType::Never | SynthType::Hole(_) => {}
                other => {
                    if !flat.contains(&other) {
                        flat.push(other);
                    }
                }
            }
        }

        let mut flat: Vec<SynthType> = Vec::new();
        for variant in variants {
            push_normalized(&mut flat, variant);
        }
        match flat.len() {
            0 => SynthType::Never,
            // Length is exactly one here, so this cannot panic.
            1 => flat.swap_remove(0),
            _ => SynthType::Union(SynthUnion { variants: flat }),
        }
    }
}
impl fmt::Display for SynthType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SynthType::Null => write!(f, "null"),
SynthType::Boolean => write!(f, "boolean"),
SynthType::Integer => write!(f, "integer"),
SynthType::Float => write!(f, "float"),
SynthType::Text(None) => write!(f, "text"),
SynthType::Text(Some(lang)) => write!(f, "text.{}", lang),
SynthType::Array(inner) => write!(f, "[{}]", inner),
SynthType::Tuple(elems) => {
write!(f, "(")?;
for (i, elem) in elems.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", elem)?;
}
write!(f, ")")
}
SynthType::Record(rec) => write!(f, "{}", rec),
SynthType::Union(union) => write!(f, "{}", union),
SynthType::Any => write!(f, "any"),
SynthType::Never => write!(f, "never"),
SynthType::Hole(None) => write!(f, "!"),
SynthType::Hole(Some(id)) => write!(f, "!{}", id),
}
}
}
impl fmt::Display for SynthRecord {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{{")?;
let mut first = true;
for (name, field) in &self.fields {
if !first {
write!(f, ", ")?;
}
first = false;
write!(f, "{}", name)?;
if field.optional {
write!(f, "?")?;
}
write!(f, ": {}", field.ty)?;
}
write!(f, "}}")
}
}
impl fmt::Display for SynthUnion {
    /// Renders the variants in order, separated by ` | `.
    ///
    /// A union with no variants renders as the empty string.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut remaining = self.variants.iter();
        if let Some(head) = remaining.next() {
            write!(f, "{}", head)?;
            for alternative in remaining {
                write!(f, " | {}", alternative)?;
            }
        }
        Ok(())
    }
}
impl SynthType {
pub fn is_primitive(&self) -> bool {
matches!(
self,
SynthType::Null
| SynthType::Boolean
| SynthType::Integer
| SynthType::Float
| SynthType::Text(_)
)
}
pub fn is_compound(&self) -> bool {
matches!(
self,
SynthType::Array(_) | SynthType::Tuple(_) | SynthType::Record(_)
)
}
pub fn has_holes(&self) -> bool {
match self {
SynthType::Hole(_) => true,
SynthType::Array(inner) => inner.has_holes(),
SynthType::Tuple(elems) => elems.iter().any(|e| e.has_holes()),
SynthType::Record(rec) => rec.fields.values().any(|f| f.ty.has_holes()),
SynthType::Union(union) => union.variants.iter().any(|v| v.has_holes()),
_ => false,
}
}
pub fn is_complete(&self) -> bool {
match self {
SynthType::Hole(_) | SynthType::Any | SynthType::Never => false,
SynthType::Array(inner) => inner.is_complete(),
SynthType::Tuple(elems) => elems.iter().all(|e| e.is_complete()),
SynthType::Record(rec) => rec.fields.values().all(|f| f.ty.is_complete()),
SynthType::Union(union) => union.variants.iter().all(|v| v.is_complete()),
_ => true,
}
}
}
/// Unit tests for union normalization, rendering, and the type predicates.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_union_flattening() {
        let inner = SynthUnion::from_variants([SynthType::Integer, SynthType::Boolean]);
        let outer = SynthUnion::from_variants([inner, SynthType::Text(None)]);
        assert_eq!(
            outer,
            SynthType::Union(SynthUnion {
                variants: vec![
                    SynthType::Integer,
                    SynthType::Boolean,
                    SynthType::Text(None)
                ]
            })
        );
    }

    #[test]
    fn test_union_dedup() {
        let union =
            SynthUnion::from_variants([SynthType::Integer, SynthType::Integer, SynthType::Boolean]);
        assert_eq!(
            union,
            SynthType::Union(SynthUnion {
                variants: vec![SynthType::Integer, SynthType::Boolean]
            })
        );
    }

    #[test]
    fn test_union_single_collapses() {
        let union = SynthUnion::from_variants([SynthType::Integer]);
        assert_eq!(union, SynthType::Integer);
    }

    #[test]
    fn test_union_absorbs_holes() {
        let union = SynthUnion::from_variants([SynthType::Integer, SynthType::Hole(None)]);
        assert_eq!(union, SynthType::Integer);
    }

    #[test]
    fn test_union_absorbs_never() {
        let union = SynthUnion::from_variants([SynthType::Integer, SynthType::Never]);
        assert_eq!(union, SynthType::Integer);
    }

    // When every variant is absorbed (or there are none), the result is `Never`.
    #[test]
    fn test_union_of_only_absorbed_variants_is_never() {
        let union = SynthUnion::from_variants([SynthType::Hole(None), SynthType::Never]);
        assert_eq!(union, SynthType::Never);
        assert_eq!(SynthUnion::from_variants(Vec::new()), SynthType::Never);
    }

    #[test]
    fn test_display() {
        assert_eq!(SynthType::Integer.to_string(), "integer");
        assert_eq!(
            SynthType::Text(Some("rust".to_string())).to_string(),
            "text.rust"
        );
        assert_eq!(
            SynthType::Array(Box::new(SynthType::Integer)).to_string(),
            "[integer]"
        );
    }

    // Single-field and empty records only, so the assertions do not depend on
    // HashMap iteration order.
    #[test]
    fn test_display_compound() {
        let tuple = SynthType::Tuple(vec![SynthType::Integer, SynthType::Text(None)]);
        assert_eq!(tuple.to_string(), "(integer, text)");
        assert_eq!(SynthType::Tuple(Vec::new()).to_string(), "()");
        let union = SynthUnion::from_variants([SynthType::Null, SynthType::Boolean]);
        assert_eq!(union.to_string(), "null | boolean");
        assert_eq!(SynthType::Hole(None).to_string(), "!");
        let record = SynthType::Record(SynthRecord::new([(
            "a".to_string(),
            SynthField::optional(SynthType::Float),
        )]));
        assert_eq!(record.to_string(), "{a?: float}");
        assert_eq!(SynthType::Record(SynthRecord::empty()).to_string(), "{}");
    }

    #[test]
    fn test_has_holes() {
        assert!(!SynthType::Integer.has_holes());
        assert!(SynthType::Hole(None).has_holes());
        assert!(SynthType::Array(Box::new(SynthType::Hole(None))).has_holes());
        let record = SynthType::Record(SynthRecord::new([(
            "x".to_string(),
            SynthField::required(SynthType::Hole(None)),
        )]));
        assert!(record.has_holes());
    }

    #[test]
    fn test_is_complete() {
        assert!(SynthType::Integer.is_complete());
        assert!(!SynthType::Any.is_complete());
        assert!(!SynthType::Never.is_complete());
        assert!(!SynthType::Hole(None).is_complete());
        assert!(!SynthType::Tuple(vec![SynthType::Integer, SynthType::Any]).is_complete());
        assert!(SynthType::Array(Box::new(SynthType::Boolean)).is_complete());
    }

    #[test]
    fn test_classification() {
        assert!(SynthType::Text(None).is_primitive());
        assert!(!SynthType::Any.is_primitive());
        assert!(SynthType::Tuple(Vec::new()).is_compound());
        assert!(!SynthUnion::from_variants([SynthType::Null, SynthType::Integer]).is_compound());
    }
}