1use eure_document::identifier::Identifier;
7use std::collections::HashMap;
8use std::fmt;
9
/// A structural type synthesized by this module.
///
/// NOTE(review): presumably inferred from document values — confirm against callers.
///
/// Every variant has a compact textual rendering via the `Display` impl in this file.
#[derive(Debug, Clone, PartialEq)]
pub enum SynthType {
    /// Rendered as `null`; counted as primitive by `is_primitive`.
    Null,

    /// Rendered as `boolean`; counted as primitive by `is_primitive`.
    Boolean,

    /// Rendered as `integer`; counted as primitive by `is_primitive`.
    Integer,

    /// Rendered as `float`; counted as primitive by `is_primitive`.
    Float,

    /// Text with an optional tag, rendered as `text` or `text.<tag>`.
    ///
    /// NOTE(review): the `Display` impl binds the tag as `lang`, so it is
    /// presumably a language/format tag — confirm.
    Text(Option<String>),

    /// Array whose elements all have the boxed type; rendered as `[T]`.
    Array(Box<SynthType>),

    /// Fixed-length sequence of positional element types; rendered as `(A, B, ...)`.
    Tuple(Vec<SynthType>),

    /// Record type with named fields; see [`SynthRecord`].
    Record(SynthRecord),

    /// Union of alternatives, rendered as `A | B | ...`.
    ///
    /// Normally built through [`SynthUnion::from_variants`], which normalizes
    /// the variant list (see [`SynthUnion`]).
    Union(SynthUnion),

    /// Rendered as `any`.
    ///
    /// `is_complete` treats it as incomplete. [`SynthUnion::from_variants`]
    /// keeps it as an ordinary variant and does not let it absorb other variants.
    /// NOTE(review): presumably the top type — confirm intended semantics.
    Any,

    /// Rendered as `never`.
    ///
    /// [`SynthUnion::from_variants`] drops it from unions and returns it when
    /// no variants remain, so it acts as the identity element of union.
    /// `is_complete` treats it as incomplete.
    Never,

    /// Optionally named hole, rendered as `!` or `!<identifier>`.
    ///
    /// `has_holes` reports it, `is_complete` treats it as incomplete, and
    /// [`SynthUnion::from_variants`] drops it from unions.
    /// NOTE(review): presumably a placeholder for a not-yet-determined type — confirm.
    Hole(Option<Identifier>),
}
74
/// A record type: a set of named fields, each with its own type and optionality.
#[derive(Debug, Clone, PartialEq)]
pub struct SynthRecord {
    /// Field name → field description.
    ///
    /// Stored in a `HashMap`, so iteration order is unspecified.
    pub fields: HashMap<String, SynthField>,
}
81
/// One field of a [`SynthRecord`]: its value type and whether it is marked optional.
#[derive(Debug, Clone, PartialEq)]
pub struct SynthField {
    /// The type of the field's value.
    pub ty: SynthType,

    /// `true` when the field is marked optional.
    ///
    /// The record `Display` impl renders this as a `?` after the field name.
    pub optional: bool,
}
94
/// The variant list of a union type.
///
/// [`SynthUnion::from_variants`] does the following:
/// - flattens nested union arguments;
/// - drops `Never` and `Hole` variants;
/// - removes duplicates;
/// - only produces a `SynthType::Union` when at least two variants remain.
///
/// `variants` is a public field, so a directly constructed value carries none
/// of these guarantees.
#[derive(Debug, Clone, PartialEq)]
pub struct SynthUnion {
    /// The alternatives, in order; rendered joined by ` | `.
    pub variants: Vec<SynthType>,
}
109
110impl SynthRecord {
113 pub fn empty() -> Self {
115 Self {
116 fields: HashMap::new(),
117 }
118 }
119
120 pub fn new(fields: impl IntoIterator<Item = (String, SynthField)>) -> Self {
122 Self {
123 fields: fields.into_iter().collect(),
124 }
125 }
126}
127
128impl SynthField {
129 pub fn required(ty: SynthType) -> Self {
131 Self {
132 ty,
133 optional: false,
134 }
135 }
136
137 pub fn optional(ty: SynthType) -> Self {
139 Self { ty, optional: true }
140 }
141}
142
143impl SynthUnion {
144 pub fn from_variants(variants: impl IntoIterator<Item = SynthType>) -> SynthType {
152 let mut flat: Vec<SynthType> = Vec::new();
153
154 for variant in variants {
155 match variant {
156 SynthType::Union(inner) => {
158 for v in inner.variants {
159 if !flat.contains(&v) {
160 flat.push(v);
161 }
162 }
163 }
164 SynthType::Never => {}
166 SynthType::Hole(_) => {}
168 other => {
170 if !flat.contains(&other) {
171 flat.push(other);
172 }
173 }
174 }
175 }
176
177 match flat.len() {
178 0 => SynthType::Never,
179 1 => flat.pop().unwrap(),
180 _ => SynthType::Union(SynthUnion { variants: flat }),
181 }
182 }
183}
184
185impl fmt::Display for SynthType {
188 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
189 match self {
190 SynthType::Null => write!(f, "null"),
191 SynthType::Boolean => write!(f, "boolean"),
192 SynthType::Integer => write!(f, "integer"),
193 SynthType::Float => write!(f, "float"),
194 SynthType::Text(None) => write!(f, "text"),
195 SynthType::Text(Some(lang)) => write!(f, "text.{}", lang),
196 SynthType::Array(inner) => write!(f, "[{}]", inner),
197 SynthType::Tuple(elems) => {
198 write!(f, "(")?;
199 for (i, elem) in elems.iter().enumerate() {
200 if i > 0 {
201 write!(f, ", ")?;
202 }
203 write!(f, "{}", elem)?;
204 }
205 write!(f, ")")
206 }
207 SynthType::Record(rec) => write!(f, "{}", rec),
208 SynthType::Union(union) => write!(f, "{}", union),
209 SynthType::Any => write!(f, "any"),
210 SynthType::Never => write!(f, "never"),
211 SynthType::Hole(None) => write!(f, "!"),
212 SynthType::Hole(Some(id)) => write!(f, "!{}", id),
213 }
214 }
215}
216
217impl fmt::Display for SynthRecord {
218 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219 write!(f, "{{")?;
220 let mut first = true;
221 for (name, field) in &self.fields {
222 if !first {
223 write!(f, ", ")?;
224 }
225 first = false;
226 write!(f, "{}", name)?;
227 if field.optional {
228 write!(f, "?")?;
229 }
230 write!(f, ": {}", field.ty)?;
231 }
232 write!(f, "}}")
233 }
234}
235
236impl fmt::Display for SynthUnion {
237 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238 for (i, variant) in self.variants.iter().enumerate() {
239 if i > 0 {
240 write!(f, " | ")?;
241 }
242 write!(f, "{}", variant)?;
243 }
244 Ok(())
245 }
246}
247
248impl SynthType {
251 pub fn is_primitive(&self) -> bool {
253 matches!(
254 self,
255 SynthType::Null
256 | SynthType::Boolean
257 | SynthType::Integer
258 | SynthType::Float
259 | SynthType::Text(_)
260 )
261 }
262
263 pub fn is_compound(&self) -> bool {
265 matches!(
266 self,
267 SynthType::Array(_) | SynthType::Tuple(_) | SynthType::Record(_)
268 )
269 }
270
271 pub fn has_holes(&self) -> bool {
273 match self {
274 SynthType::Hole(_) => true,
275 SynthType::Array(inner) => inner.has_holes(),
276 SynthType::Tuple(elems) => elems.iter().any(|e| e.has_holes()),
277 SynthType::Record(rec) => rec.fields.values().any(|f| f.ty.has_holes()),
278 SynthType::Union(union) => union.variants.iter().any(|v| v.has_holes()),
279 _ => false,
280 }
281 }
282
283 pub fn is_complete(&self) -> bool {
285 match self {
286 SynthType::Hole(_) | SynthType::Any | SynthType::Never => false,
287 SynthType::Array(inner) => inner.is_complete(),
288 SynthType::Tuple(elems) => elems.iter().all(|e| e.is_complete()),
289 SynthType::Record(rec) => rec.fields.values().all(|f| f.ty.is_complete()),
290 SynthType::Union(union) => union.variants.iter().all(|v| v.is_complete()),
291 _ => true,
292 }
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Union` value directly, bypassing `from_variants` normalization.
    fn raw_union(variants: Vec<SynthType>) -> SynthType {
        SynthType::Union(SynthUnion { variants })
    }

    #[test]
    fn test_union_flattening() {
        let int_or_bool = SynthUnion::from_variants([SynthType::Integer, SynthType::Boolean]);
        let nested = SynthUnion::from_variants([int_or_bool, SynthType::Text(None)]);

        let expected = raw_union(vec![
            SynthType::Integer,
            SynthType::Boolean,
            SynthType::Text(None),
        ]);
        assert_eq!(nested, expected);
    }

    #[test]
    fn test_union_dedup() {
        let deduped =
            SynthUnion::from_variants([SynthType::Integer, SynthType::Integer, SynthType::Boolean]);

        let expected = raw_union(vec![SynthType::Integer, SynthType::Boolean]);
        assert_eq!(deduped, expected);
    }

    #[test]
    fn test_union_single_collapses() {
        let collapsed = SynthUnion::from_variants([SynthType::Integer]);
        assert_eq!(collapsed, SynthType::Integer);
    }

    #[test]
    fn test_union_absorbs_holes() {
        let without_hole = SynthUnion::from_variants([SynthType::Integer, SynthType::Hole(None)]);
        assert_eq!(without_hole, SynthType::Integer);
    }

    #[test]
    fn test_union_absorbs_never() {
        let without_never = SynthUnion::from_variants([SynthType::Integer, SynthType::Never]);
        assert_eq!(without_never, SynthType::Integer);
    }

    #[test]
    fn test_display() {
        let cases = [
            (SynthType::Integer, "integer"),
            (SynthType::Text(Some("rust".to_string())), "text.rust"),
            (SynthType::Array(Box::new(SynthType::Integer)), "[integer]"),
        ];
        for (ty, rendered) in cases {
            assert_eq!(ty.to_string(), rendered);
        }
    }

    #[test]
    fn test_has_holes() {
        let plain = SynthType::Integer;
        let hole = SynthType::Hole(None);
        let array_of_hole = SynthType::Array(Box::new(SynthType::Hole(None)));

        assert!(!plain.has_holes());
        assert!(hole.has_holes());
        assert!(array_of_hole.has_holes());
    }
}