Skip to main content

lemma/planning/
types.rs

1//! Type registry for managing custom type definitions and resolution
2//!
3//! This module provides the `TypeResolver` (formerly TypeRegistry) which handles:
4//! - Registering user-defined types for each spec
5//! - Resolving type hierarchies and inheritance chains
6//! - Detecting and preventing circular dependencies
7//! - Applying constraints to create final type specifications
8
9use crate::error::Error;
10use crate::parsing::ast::{self as ast, CommandArg, LemmaSpec, Reference, TypeDef};
11use crate::planning::semantics::{self, LemmaType, TypeExtends, TypeSpecification};
12use crate::planning::validation::validate_type_specifications;
13
14use std::collections::{HashMap, HashSet};
15use std::sync::Arc;
16
17/// Fully resolved types for a single spec.
18/// After resolution, all imports are inlined, so the types of one spec no longer depend on other specs.
19#[derive(Debug, Clone)]
20pub struct ResolvedSpecTypes {
21    /// Named types: type_name -> fully resolved type.
22    pub named_types: HashMap<String, LemmaType>,
23
24    /// Inline type definitions: fact reference -> fully resolved type.
25    pub inline_type_definitions: HashMap<Reference, LemmaType>,
26
27    /// Unit index: unit_name -> type recorded as the owner of that unit.
28    /// Built during resolution. A unit claimed by unrelated types makes resolution fail; scale types in a direct parent/child relation or in the same scale family may share a unit.
29    pub unit_index: HashMap<String, LemmaType>,
30}
31
32/// Registry for managing and resolving custom types.
33///
34/// Types are organized per spec (keyed by Arc<LemmaSpec>) and support inheritance through parent references.
35/// The registry handles cycle detection and accumulates constraints through the inheritance chain.
36/// name_to_arc maps base spec name to the earliest Arc for that name (by effective_from); it is consulted for cross-spec parents and also for same-spec parents when a definition has no `from`.
37#[derive(Debug, Clone)]
38pub struct TypeResolver {
    /// Named (`Regular` / `Import`) definitions per spec, keyed by type name.
39    named_types: HashMap<Arc<LemmaSpec>, HashMap<String, TypeDef>>,
    /// Inline definitions per spec, keyed by the fact reference they belong to.
40    inline_type_definitions: HashMap<Arc<LemmaSpec>, HashMap<Reference, TypeDef>>,
41    /// Earliest spec Arc per base name, used to locate the spec that owns a parent type.
42    name_to_arc: HashMap<String, Arc<LemmaSpec>>,
43}
44
45impl TypeResolver {
46    pub fn new() -> Self {
47        TypeResolver {
48            named_types: HashMap::new(),
49            inline_type_definitions: HashMap::new(),
50            name_to_arc: HashMap::new(),
51        }
52    }
53
54    /// Register all named types from a spec (skips inline types).
55    pub fn register_all(&mut self, spec: &Arc<LemmaSpec>) -> Vec<Error> {
56        let mut errors = Vec::new();
57        for type_def in &spec.types {
58            let type_name = match type_def {
59                ast::TypeDef::Regular { name, .. } | ast::TypeDef::Import { name, .. } => {
60                    Some(name.as_str())
61                }
62                ast::TypeDef::Inline { .. } => None,
63            };
64            if let Some(name) = type_name {
65                if let Err(e) = crate::limits::check_max_length(
66                    name,
67                    crate::limits::MAX_TYPE_NAME_LENGTH,
68                    "type",
69                ) {
70                    errors.push(e);
71                    continue;
72                }
73            }
74            if let Err(e) = self.register_type(spec, type_def.clone()) {
75                errors.push(e);
76            }
77        }
78        errors
79    }
80
81    /// Resolve all named types for every spec and validate their specifications.
82    /// Produces an entry for every spec (even those without named types) because
83    /// every spec needs a unit_index containing at least the primitive ratio units.
84    pub fn resolve(
85        &self,
86        all_specs: impl IntoIterator<Item = Arc<LemmaSpec>>,
87    ) -> (HashMap<Arc<LemmaSpec>, ResolvedSpecTypes>, Vec<Error>) {
88        let mut result = HashMap::new();
89        let mut errors = Vec::new();
90
91        for spec_arc in all_specs {
92            let spec_arc = &spec_arc;
93            match self.resolve_named_types(spec_arc) {
94                Ok(resolved_types) => {
95                    for (type_name, lemma_type) in &resolved_types.named_types {
96                        let source = spec_arc
97                            .types
98                            .iter()
99                            .find(|td| match td {
100                                ast::TypeDef::Regular { name, .. }
101                                | ast::TypeDef::Import { name, .. } => name == type_name,
102                                ast::TypeDef::Inline { .. } => false,
103                            })
104                            .map(|td| td.source_location().clone())
105                            .unwrap_or_else(|| {
106                                unreachable!(
107                                    "BUG: resolved named type '{}' has no corresponding TypeDef in spec '{}'",
108                                    type_name, spec_arc.name
109                                )
110                            });
111                        let mut spec_errors = validate_type_specifications(
112                            &lemma_type.specifications,
113                            type_name,
114                            &source,
115                        );
116                        errors.append(&mut spec_errors);
117                    }
118                    result.insert(Arc::clone(spec_arc), resolved_types);
119                }
120                Err(es) => errors.extend(es),
121            }
122        }
123
124        (result, errors)
125    }
126
127    /// Register a user-defined type for a given spec (keyed by Arc<LemmaSpec>).
128    /// Updates name_to_arc to keep the earliest spec per base name for cross-spec resolution.
129    pub fn register_type(&mut self, spec: &Arc<LemmaSpec>, def: TypeDef) -> Result<(), Error> {
130        self.name_to_arc
131            .entry(spec.name.clone())
132            .and_modify(|existing| {
133                if spec.effective_from() < existing.effective_from() {
134                    *existing = Arc::clone(spec);
135                }
136            })
137            .or_insert_with(|| Arc::clone(spec));
138
139        let def_loc = def.source_location().clone();
140        let spec_name = &spec.name;
141        match &def {
142            TypeDef::Regular { name, .. } | TypeDef::Import { name, .. } => {
143                let spec_types = self.named_types.entry(Arc::clone(spec)).or_default();
144                if spec_types.contains_key(name) {
145                    return Err(Error::validation(
146                        format!("Type '{}' is already defined in spec '{}'", name, spec_name),
147                        Some(def_loc.clone()),
148                        None::<String>,
149                    ));
150                }
151                spec_types.insert(name.clone(), def);
152            }
153            TypeDef::Inline { fact_ref, .. } => {
154                let spec_inline_types = self
155                    .inline_type_definitions
156                    .entry(Arc::clone(spec))
157                    .or_default();
158                if spec_inline_types.contains_key(fact_ref) {
159                    return Err(Error::validation(
160                        format!(
161                            "Inline type definition for fact '{}' is already defined in spec '{}'",
162                            fact_ref.name, spec_name
163                        ),
164                        Some(def_loc.clone()),
165                        None::<String>,
166                    ));
167                }
168                spec_inline_types.insert(fact_ref.clone(), def);
169            }
170        }
171        Ok(())
172    }
173
174    /// Resolve all types for a certain spec (keyed by Arc<LemmaSpec>).
175    pub fn resolve_types(&self, spec: &Arc<LemmaSpec>) -> Result<ResolvedSpecTypes, Vec<Error>> {
176        self.resolve_types_internal(spec, true)
177    }
178
179    /// Resolve only named types (for validation before inline type definitions are registered).
180    pub fn resolve_named_types(
181        &self,
182        spec: &Arc<LemmaSpec>,
183    ) -> Result<ResolvedSpecTypes, Vec<Error>> {
184        self.resolve_types_internal(spec, false)
185    }
186
187    /// Resolve only inline type definitions and merge them into an existing
188    /// `ResolvedSpecTypes` that already contains the named types.
189    pub fn resolve_inline_types(
190        &self,
191        spec: &Arc<LemmaSpec>,
192        mut existing: ResolvedSpecTypes,
193    ) -> Result<ResolvedSpecTypes, Vec<Error>> {
194        let mut errors = Vec::new();
195
196        if let Some(spec_inline_types) = self.inline_type_definitions.get(spec) {
197            for (fact_ref, type_def) in spec_inline_types {
198                let mut visited = HashSet::new();
199                match self.resolve_inline_type_definition(spec, type_def, &mut visited) {
200                    Ok(Some(resolved_type)) => {
201                        existing
202                            .inline_type_definitions
203                            .insert(fact_ref.clone(), resolved_type);
204                    }
205                    Ok(None) => {
206                        unreachable!(
207                            "BUG: registered inline type definition for fact '{}' could not be resolved (spec='{}')",
208                            fact_ref, spec.name
209                        );
210                    }
211                    Err(es) => return Err(es),
212                }
213            }
214        }
215
216        for (fact_ref, resolved_type) in &existing.inline_type_definitions {
217            let inline_type_name = format!("{}::{}", spec.name, fact_ref);
218            let e: Result<(), Error> = if resolved_type.is_scale() {
219                self.add_scale_units_to_index(
220                    &mut existing.unit_index,
221                    resolved_type,
222                    spec,
223                    &inline_type_name,
224                )
225            } else if resolved_type.is_ratio() {
226                self.add_ratio_units_to_index(
227                    &mut existing.unit_index,
228                    resolved_type,
229                    spec,
230                    &inline_type_name,
231                )
232            } else {
233                Ok(())
234            };
235            if let Err(e) = e {
236                errors.push(e);
237            }
238        }
239
240        if !errors.is_empty() {
241            return Err(errors);
242        }
243
244        Ok(existing)
245    }
246
247    fn resolve_types_internal(
248        &self,
249        spec: &Arc<LemmaSpec>,
250        include_anonymous: bool,
251    ) -> Result<ResolvedSpecTypes, Vec<Error>> {
252        let mut named_types = HashMap::new();
253        let mut inline_type_definitions = HashMap::new();
254        let mut visited = HashSet::new();
255
256        if let Some(spec_types) = self.named_types.get(spec) {
257            for type_name in spec_types.keys() {
258                match self.resolve_type_internal(spec, type_name, &mut visited) {
259                    Ok(Some(resolved_type)) => {
260                        named_types.insert(type_name.clone(), resolved_type);
261                    }
262                    Ok(None) => {
263                        unreachable!(
264                            "BUG: registered named type '{}' could not be resolved (spec='{}')",
265                            type_name, spec.name
266                        );
267                    }
268                    Err(es) => return Err(es),
269                }
270                visited.clear();
271            }
272        }
273
274        if include_anonymous {
275            if let Some(spec_inline_types) = self.inline_type_definitions.get(spec) {
276                for (fact_ref, type_def) in spec_inline_types {
277                    let mut visited = HashSet::new();
278                    match self.resolve_inline_type_definition(spec, type_def, &mut visited) {
279                        Ok(Some(resolved_type)) => {
280                            inline_type_definitions.insert(fact_ref.clone(), resolved_type);
281                        }
282                        Ok(None) => {
283                            unreachable!(
284                                "BUG: registered inline type definition for fact '{}' could not be resolved (spec='{}')",
285                                fact_ref, spec.name
286                            );
287                        }
288                        Err(es) => return Err(es),
289                    }
290                }
291            }
292        }
293
294        // Build unit index from types that have units (primitive types first, then spec types)
295        let mut unit_index: HashMap<String, LemmaType> = HashMap::new();
296        let mut errors = Vec::new();
297
298        if let Err(error) = self.add_ratio_units_to_index(
299            &mut unit_index,
300            semantics::primitive_ratio(),
301            spec,
302            "ratio",
303        ) {
304            errors.push(error);
305        }
306
307        // Add units from named types (collect all errors)
308        for resolved_type in named_types.values() {
309            let type_name = resolved_type.name.as_deref().unwrap_or("inline");
310            let e: Result<(), Error> = if resolved_type.is_scale() {
311                self.add_scale_units_to_index(&mut unit_index, resolved_type, spec, type_name)
312            } else if resolved_type.is_ratio() {
313                self.add_ratio_units_to_index(&mut unit_index, resolved_type, spec, type_name)
314            } else {
315                Ok(())
316            };
317            if let Err(e) = e {
318                errors.push(e);
319            }
320        }
321
322        // Add units from inline type definitions (collect all errors)
323        for (fact_ref, resolved_type) in &inline_type_definitions {
324            let inline_type_name = format!("{}::{}", spec.name, fact_ref);
325            let e: Result<(), Error> = if resolved_type.is_scale() {
326                self.add_scale_units_to_index(
327                    &mut unit_index,
328                    resolved_type,
329                    spec,
330                    &inline_type_name,
331                )
332            } else if resolved_type.is_ratio() {
333                self.add_ratio_units_to_index(
334                    &mut unit_index,
335                    resolved_type,
336                    spec,
337                    &inline_type_name,
338                )
339            } else {
340                Ok(())
341            };
342            if let Err(e) = e {
343                errors.push(e);
344            }
345        }
346
347        if !errors.is_empty() {
348            return Err(errors);
349        }
350
351        Ok(ResolvedSpecTypes {
352            named_types,
353            inline_type_definitions,
354            unit_index,
355        })
356    }
357
358    fn resolve_type_internal(
359        &self,
360        spec: &Arc<LemmaSpec>,
361        name: &str,
362        visited: &mut HashSet<String>,
363    ) -> Result<Option<LemmaType>, Vec<Error>> {
364        let key = format!("{}::{}", spec.name, name);
365        if visited.contains(&key) {
366            let source_location = self
367                .named_types
368                .get(spec)
369                .and_then(|dt| dt.get(name))
370                .map(|td| td.source_location().clone())
371                .unwrap_or_else(|| {
372                    unreachable!(
373                        "BUG: circular dependency detected for type '{}::{}' but type definition not found in registry",
374                        spec.name, name
375                    )
376                });
377            return Err(vec![Error::validation(
378                format!("Circular dependency detected in type resolution: {}", key),
379                Some(source_location),
380                None::<String>,
381            )]);
382        }
383        visited.insert(key.clone());
384
385        let type_def = match self.named_types.get(spec).and_then(|dt| dt.get(name)) {
386            Some(def) => def.clone(),
387            None => {
388                visited.remove(&key);
389                return Ok(None);
390            }
391        };
392
393        // Resolve the parent type (standard or custom)
394        let (parent, from, constraints, type_name) = match &type_def {
395            TypeDef::Regular {
396                name,
397                parent,
398                constraints,
399                ..
400            } => (parent.clone(), None, constraints.clone(), name.clone()),
401            TypeDef::Import {
402                name,
403                source_type,
404                from,
405                constraints,
406                ..
407            } => (
408                source_type.clone(),
409                Some(from.clone()),
410                constraints.clone(),
411                name.clone(),
412            ),
413            TypeDef::Inline { .. } => {
414                // Inline types are resolved separately
415                visited.remove(&key);
416                return Ok(None);
417            }
418        };
419
420        let parent_specs = match self.resolve_parent(
421            spec,
422            &parent,
423            &from,
424            visited,
425            type_def.source_location(),
426        ) {
427            Ok(Some(specs)) => specs,
428            Ok(None) => {
429                // Parent type not found - this is an error for named types
430                // (inline type definitions might have forward references, but named types should be resolvable)
431                visited.remove(&key);
432                let source = type_def.source_location().clone();
433                return Err(vec![Error::validation(
434                    format!("Unknown type: '{}'. Type must be defined before use. Valid primitive types are: boolean, scale, number, ratio, text, date, time, duration, percent", parent),
435                    Some(source.clone()),
436                    None::<String>,
437                )]);
438            }
439            Err(es) => {
440                visited.remove(&key);
441                return Err(es);
442            }
443        };
444
445        let final_specs = if let Some(constraints) = &constraints {
446            match self.apply_constraints(parent_specs, constraints, type_def.source_location()) {
447                Ok(specs) => specs,
448                Err(errors) => {
449                    visited.remove(&key);
450                    return Err(errors);
451                }
452            }
453        } else {
454            parent_specs
455        };
456
457        visited.remove(&key);
458
459        let extends = if self.resolve_primitive_type(&parent).is_some() {
460            TypeExtends::Primitive
461        } else {
462            let parent_spec_name = from
463                .as_ref()
464                .map(|r| r.name.as_str())
465                .unwrap_or(spec.name.as_str());
466            let parent_arc = self.name_to_arc.get(parent_spec_name);
467            let family = match parent_arc {
468                Some(arc) => match self.resolve_type_internal(arc, &parent, visited) {
469                    Ok(Some(parent_type)) => parent_type
470                        .scale_family_name()
471                        .map(String::from)
472                        .unwrap_or_else(|| parent.clone()),
473                    Ok(None) => parent.clone(),
474                    Err(es) => return Err(es),
475                },
476                None => parent.clone(),
477            };
478            TypeExtends::Custom {
479                parent: parent.clone(),
480                family,
481            }
482        };
483
484        Ok(Some(LemmaType {
485            name: Some(type_name),
486            specifications: final_specs,
487            extends,
488        }))
489    }
490
491    fn resolve_parent(
492        &self,
493        spec: &Arc<LemmaSpec>,
494        parent: &str,
495        from: &Option<crate::parsing::ast::SpecRef>,
496        visited: &mut HashSet<String>,
497        source: &crate::Source,
498    ) -> Result<Option<TypeSpecification>, Vec<Error>> {
499        if let Some(specs) = self.resolve_primitive_type(parent) {
500            return Ok(Some(specs));
501        }
502
503        let parent_spec_name = from
504            .as_ref()
505            .map(|r| r.name.as_str())
506            .unwrap_or(spec.name.as_str());
507        let parent_arc = self.name_to_arc.get(parent_spec_name);
508        let result = match parent_arc {
509            Some(arc) => self.resolve_type_internal(arc, parent, visited),
510            None => Ok(None),
511        };
512        match result {
513            Ok(Some(t)) => Ok(Some(t.specifications)),
514            Ok(None) => {
515                let type_exists = parent_arc
516                    .and_then(|arc| self.named_types.get(arc))
517                    .map(|spec_types| spec_types.contains_key(parent))
518                    .unwrap_or(false);
519
520                if !type_exists {
521                    let suggestion = from.as_ref().filter(|r| r.is_registry).map(|r| {
522                        format!(
523                            "Run `lemma get` or `lemma get @{}` to fetch this dependency.",
524                            r.name
525                        )
526                    });
527                    Err(vec![Error::validation(
528                        format!("Unknown type: '{}'. Type must be defined before use. Valid primitive types are: boolean, scale, number, ratio, text, date, time, duration, percent", parent),
529                        Some(source.clone()),
530                        suggestion,
531                    )])
532                } else {
533                    Ok(None)
534                }
535            }
536            Err(es) => Err(es),
537        }
538    }
539
540    /// Resolve a primitive type by name
541    pub fn resolve_primitive_type(&self, name: &str) -> Option<TypeSpecification> {
542        match name {
543            "boolean" => Some(TypeSpecification::boolean()),
544            "scale" => Some(TypeSpecification::scale()),
545            "number" => Some(TypeSpecification::number()),
546            "ratio" => Some(TypeSpecification::ratio()),
547            "text" => Some(TypeSpecification::text()),
548            "date" => Some(TypeSpecification::date()),
549            "time" => Some(TypeSpecification::time()),
550            "duration" => Some(TypeSpecification::duration()),
551            "percent" => Some(TypeSpecification::ratio()),
552            _ => None,
553        }
554    }
555
556    /// Apply command-argument constraints to a TypeSpecification.
557    /// Each TypeSpecification variant handles its own commands; we just apply them in order.
558    fn apply_constraints(
559        &self,
560        mut specs: TypeSpecification,
561        constraints: &[(String, Vec<CommandArg>)],
562        source: &crate::Source,
563    ) -> Result<TypeSpecification, Vec<Error>> {
564        let mut errors = Vec::new();
565        for (command, args) in constraints {
566            let specs_clone = specs.clone();
567            match specs.apply_constraint(command, args) {
568                Ok(updated_specs) => specs = updated_specs,
569                Err(e) => {
570                    errors.push(Error::validation(
571                        format!("Failed to apply constraint '{}': {}", command, e),
572                        Some(source.clone()),
573                        None::<String>,
574                    ));
575                    specs = specs_clone;
576                }
577            }
578        }
579        if !errors.is_empty() {
580            return Err(errors);
581        }
582        Ok(specs)
583    }
584
585    fn resolve_inline_type_definition(
586        &self,
587        spec: &Arc<LemmaSpec>,
588        type_def: &TypeDef,
589        visited: &mut HashSet<String>,
590    ) -> Result<Option<LemmaType>, Vec<Error>> {
591        let def_loc = type_def.source_location().clone();
592        let TypeDef::Inline {
593            parent,
594            constraints,
595            fact_ref: _,
596            from,
597            ..
598        } = type_def
599        else {
600            return Ok(None);
601        };
602
603        let parent_specs = match self.resolve_parent(spec, parent, from, visited, &def_loc) {
604            Ok(Some(specs)) => specs,
605            Ok(None) => {
606                return Err(vec![Error::validation(
607                    format!("Unknown type: '{}'. Type must be defined before use. Valid primitive types are: boolean, scale, number, ratio, text, date, time, duration, percent", parent),
608                    Some(def_loc.clone()),
609                    None::<String>,
610                )]);
611            }
612            Err(es) => return Err(es),
613        };
614
615        let final_specs = if let Some(constraints) = constraints {
616            self.apply_constraints(parent_specs, constraints, &def_loc)?
617        } else {
618            parent_specs
619        };
620
621        let extends = if self.resolve_primitive_type(parent).is_some() {
622            TypeExtends::Primitive
623        } else {
624            let parent_spec_name = from
625                .as_ref()
626                .map(|r| r.name.as_str())
627                .unwrap_or(spec.name.as_str());
628            let family = match self.name_to_arc.get(parent_spec_name) {
629                Some(arc) => match self.resolve_type_internal(arc, parent, visited) {
630                    Ok(Some(parent_type)) => parent_type
631                        .scale_family_name()
632                        .map(String::from)
633                        .unwrap_or_else(|| parent.to_string()),
634                    Ok(None) => parent.to_string(),
635                    Err(es) => return Err(es),
636                },
637                None => parent.to_string(),
638            };
639            TypeExtends::Custom {
640                parent: parent.to_string(),
641                family,
642            }
643        };
644
645        Ok(Some(LemmaType::without_name(final_specs, extends)))
646    }
647
648    fn add_scale_units_to_index(
649        &self,
650        unit_index: &mut HashMap<String, LemmaType>,
651        resolved_type: &LemmaType,
652        spec: &Arc<LemmaSpec>,
653        type_name: &str,
654    ) -> Result<(), Error> {
655        let units = self.extract_units_from_specs(&resolved_type.specifications);
656        for unit in units {
657            if let Some(existing_type) = unit_index.get(&unit) {
658                let existing_name = existing_type.name.as_deref().unwrap_or("inline");
659                let same_type = existing_type.name.as_deref() == resolved_type.name.as_deref();
660
661                if same_type {
662                    let source = self
663                        .named_types
664                        .get(spec)
665                        .and_then(|defs| defs.get(type_name))
666                        .map(|def| def.source_location().clone())
667                        .expect("BUG: named type definition must have source location");
668
669                    return Err(Error::validation(
670                        format!(
671                            "Unit '{}' is defined more than once in type '{}'",
672                            unit, type_name
673                        ),
674                        Some(source.clone()),
675                        None::<String>,
676                    ));
677                }
678
679                let current_extends_existing = resolved_type
680                    .extends
681                    .parent_name()
682                    .map(|p| existing_name == p)
683                    .unwrap_or(false);
684                let existing_extends_current = existing_type
685                    .extends
686                    .parent_name()
687                    .map(|p| p == resolved_type.name.as_deref().unwrap_or(""))
688                    .unwrap_or(false);
689
690                if existing_type.is_scale()
691                    && (current_extends_existing || existing_extends_current)
692                {
693                    if current_extends_existing {
694                        unit_index.insert(unit, resolved_type.clone());
695                    }
696                    continue;
697                }
698
699                // Siblings in the same scale family (e.g. both extend "money")
700                // inherit the same unit — not ambiguous.
701                if existing_type.same_scale_family(resolved_type) {
702                    continue;
703                }
704
705                let source = self
706                    .named_types
707                    .get(spec)
708                    .and_then(|defs| defs.get(type_name))
709                    .map(|def| def.source_location().clone())
710                    .expect("BUG: named type definition must have source location");
711
712                return Err(Error::validation(
713                    format!(
714                        "Ambiguous unit '{}' in spec '{}'. Defined in multiple types: {} and {}",
715                        unit, spec.name, existing_name, type_name
716                    ),
717                    Some(source.clone()),
718                    None::<String>,
719                ));
720            }
721            unit_index.insert(unit, resolved_type.clone());
722        }
723        Ok(())
724    }
725
726    fn add_ratio_units_to_index(
727        &self,
728        unit_index: &mut HashMap<String, LemmaType>,
729        resolved_type: &LemmaType,
730        spec: &Arc<LemmaSpec>,
731        type_name: &str,
732    ) -> Result<(), Error> {
733        let units = self.extract_units_from_specs(&resolved_type.specifications);
734        for unit in units {
735            if let Some(existing_type) = unit_index.get(&unit) {
736                if existing_type.is_ratio() {
737                    continue;
738                }
739                let existing_name = existing_type.name.as_deref().unwrap_or("inline");
740                let source = self
741                    .named_types
742                    .get(spec)
743                    .and_then(|defs| defs.get(type_name))
744                    .map(|def| def.source_location().clone())
745                    .expect("BUG: named type definition must have source location");
746
747                return Err(Error::validation(
748                    format!(
749                        "Ambiguous unit '{}' in spec '{}'. Defined in multiple types: {} and {}",
750                        unit, spec.name, existing_name, type_name
751                    ),
752                    Some(source.clone()),
753                    None::<String>,
754                ));
755            }
756            unit_index.insert(unit, resolved_type.clone());
757        }
758        Ok(())
759    }
760
761    /// Extract all unit names from a TypeSpecification
762    /// Only Scale types can have units (Number types are dimensionless)
763    fn extract_units_from_specs(&self, specs: &TypeSpecification) -> Vec<String> {
764        match specs {
765            TypeSpecification::Scale { units, .. } => {
766                units.iter().map(|unit| unit.name.clone()).collect()
767            }
768            TypeSpecification::Ratio { units, .. } => {
769                units.iter().map(|unit| unit.name.clone()).collect()
770            }
771            _ => Vec::new(),
772        }
773    }
774}
775
776impl Default for TypeResolver {
777    fn default() -> Self {
778        Self::new()
779    }
780}
781
782#[cfg(test)]
783mod tests {
784    use super::*;
785    use crate::parse;
786    use crate::parsing::ast::LemmaSpec;
787    use crate::ResourceLimits;
788    use rust_decimal::Decimal;
789    use std::sync::Arc;
790
791    fn test_registry() -> TypeResolver {
792        TypeResolver::new()
793    }
794
795    fn test_spec_arc() -> Arc<LemmaSpec> {
796        Arc::new(LemmaSpec::new("test_spec".to_string()))
797    }
798
799    #[test]
800    fn test_registry_creation() {
801        let registry = test_registry();
802        let spec_arc = test_spec_arc();
803        let resolved = registry.resolve_types(&spec_arc).unwrap();
804        assert!(resolved.named_types.is_empty());
805        assert!(resolved.inline_type_definitions.is_empty());
806    }
807
808    #[test]
809    fn test_resolve_primitive_types() {
810        let registry = test_registry();
811
812        assert!(registry.resolve_primitive_type("boolean").is_some());
813        assert!(registry.resolve_primitive_type("scale").is_some());
814        assert!(registry.resolve_primitive_type("number").is_some());
815        assert!(registry.resolve_primitive_type("ratio").is_some());
816        assert!(registry.resolve_primitive_type("text").is_some());
817        assert!(registry.resolve_primitive_type("date").is_some());
818        assert!(registry.resolve_primitive_type("time").is_some());
819        assert!(registry.resolve_primitive_type("duration").is_some());
820        assert!(registry.resolve_primitive_type("unknown").is_none());
821    }
822
823    #[test]
824    fn test_register_named_type() {
825        let mut registry = test_registry();
826        let type_def = TypeDef::Regular {
827            source_location: crate::Source::new(
828                "<test>",
829                crate::parsing::ast::Span {
830                    start: 0,
831                    end: 0,
832                    line: 1,
833                    col: 0,
834                },
835                "test_spec",
836                Arc::from("spec test\nfact x: 1"),
837            ),
838            name: "money".to_string(),
839            parent: "number".to_string(),
840            constraints: None,
841        };
842
843        let result = registry.register_type(&test_spec_arc(), type_def);
844        assert!(result.is_ok());
845    }
846
847    #[test]
848    fn test_register_inline_type_definition() {
849        use crate::parsing::ast::Reference;
850        let mut registry = test_registry();
851        let fact_ref = Reference::local("age".to_string());
852        let type_def = TypeDef::Inline {
853            source_location: crate::Source::new(
854                "<test>",
855                crate::parsing::ast::Span {
856                    start: 0,
857                    end: 0,
858                    line: 1,
859                    col: 0,
860                },
861                "test_spec",
862                Arc::from("spec test\nfact x: 1"),
863            ),
864            parent: "number".to_string(),
865            constraints: Some(vec![
866                (
867                    "minimum".to_string(),
868                    vec![CommandArg::Number("0".to_string())],
869                ),
870                (
871                    "maximum".to_string(),
872                    vec![CommandArg::Number("150".to_string())],
873                ),
874            ]),
875            fact_ref: fact_ref.clone(),
876            from: None,
877        };
878
879        let spec_arc = test_spec_arc();
880        let result = registry.register_type(&spec_arc, type_def);
881        assert!(result.is_ok());
882        let resolved = registry.resolve_types(&spec_arc).unwrap();
883        assert!(resolved.inline_type_definitions.contains_key(&fact_ref));
884    }
885
886    #[test]
887    fn test_register_duplicate_type_fails() {
888        let mut registry = test_registry();
889        let type_def = TypeDef::Regular {
890            source_location: crate::Source::new(
891                "<test>",
892                crate::parsing::ast::Span {
893                    start: 0,
894                    end: 0,
895                    line: 1,
896                    col: 0,
897                },
898                "test_spec",
899                Arc::from("spec test\nfact x: 1"),
900            ),
901            name: "money".to_string(),
902            parent: "number".to_string(),
903            constraints: None,
904        };
905
906        let spec_arc = test_spec_arc();
907        registry.register_type(&spec_arc, type_def.clone()).unwrap();
908        let result = registry.register_type(&spec_arc, type_def);
909        assert!(result.is_err());
910    }
911
912    #[test]
913    fn test_resolve_custom_type_from_primitive() {
914        let mut registry = test_registry();
915        let type_def = TypeDef::Regular {
916            source_location: crate::Source::new(
917                "<test>",
918                crate::parsing::ast::Span {
919                    start: 0,
920                    end: 0,
921                    line: 1,
922                    col: 0,
923                },
924                "test_spec",
925                Arc::from("spec test\nfact x: 1"),
926            ),
927            name: "money".to_string(),
928            parent: "number".to_string(),
929            constraints: None,
930        };
931
932        let spec_arc = test_spec_arc();
933        registry.register_type(&spec_arc, type_def).unwrap();
934        let resolved = registry.resolve_types(&spec_arc).unwrap();
935
936        assert!(resolved.named_types.contains_key("money"));
937        let money_type = resolved.named_types.get("money").unwrap();
938        assert_eq!(money_type.name, Some("money".to_string()));
939    }
940
941    #[test]
942    fn test_type_definition_resolution() {
943        let code = r#"spec test
944type dice: number -> minimum 0 -> maximum 6"#;
945
946        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
947        let spec = &specs[0];
948
949        // Use TypeResolver to resolve the type
950        let mut registry = test_registry();
951        registry
952            .register_type(&Arc::new(spec.clone()), spec.types[0].clone())
953            .unwrap();
954
955        let resolved_types = registry.resolve_types(&Arc::new(spec.clone())).unwrap();
956        let dice_type = resolved_types.named_types.get("dice").unwrap();
957
958        // Verify it's a Number type (dimensionless) with the correct constraints
959        match &dice_type.specifications {
960            TypeSpecification::Number {
961                minimum, maximum, ..
962            } => {
963                assert_eq!(*minimum, Some(Decimal::from(0)));
964                assert_eq!(*maximum, Some(Decimal::from(6)));
965            }
966            _ => panic!("Expected Number type specifications"),
967        }
968    }
969
970    #[test]
971    fn test_type_definition_with_multiple_commands() {
972        let code = r#"spec test
973type money: scale -> decimals 2 -> unit eur 1.0 -> unit usd 1.18"#;
974
975        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
976        let spec = &specs[0];
977        let type_def = &spec.types[0];
978
979        // Use TypeResolver to resolve the type
980        let mut registry = test_registry();
981        registry
982            .register_type(&Arc::new(spec.clone()), type_def.clone())
983            .unwrap();
984
985        let resolved_types = registry.resolve_types(&Arc::new(spec.clone())).unwrap();
986        let money_type = resolved_types.named_types.get("money").unwrap();
987
988        match &money_type.specifications {
989            TypeSpecification::Scale {
990                decimals, units, ..
991            } => {
992                assert_eq!(*decimals, Some(2));
993                assert_eq!(units.len(), 2);
994                assert!(units.iter().any(|u| u.name == "eur"));
995                assert!(units.iter().any(|u| u.name == "usd"));
996            }
997            _ => panic!("Expected Scale type specifications"),
998        }
999    }
1000
1001    #[test]
1002    fn test_number_type_with_decimals() {
1003        let code = r#"spec test
1004type price: number -> decimals 2 -> minimum 0"#;
1005
1006        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
1007        let spec = &specs[0];
1008
1009        // Use TypeResolver to resolve the type
1010        let mut registry = test_registry();
1011        registry
1012            .register_type(&Arc::new(spec.clone()), spec.types[0].clone())
1013            .unwrap();
1014
1015        let resolved_types = registry.resolve_types(&Arc::new(spec.clone())).unwrap();
1016        let price_type = resolved_types.named_types.get("price").unwrap();
1017
1018        // Verify it's a Number type with decimals set to 2
1019        match &price_type.specifications {
1020            TypeSpecification::Number {
1021                decimals, minimum, ..
1022            } => {
1023                assert_eq!(*decimals, Some(2));
1024                assert_eq!(*minimum, Some(Decimal::from(0)));
1025            }
1026            _ => panic!("Expected Number type specifications with decimals"),
1027        }
1028    }
1029
1030    #[test]
1031    fn test_number_type_decimals_only() {
1032        let code = r#"spec test
1033type precise_number: number -> decimals 4"#;
1034
1035        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
1036        let spec = &specs[0];
1037
1038        let mut registry = test_registry();
1039        registry
1040            .register_type(&Arc::new(spec.clone()), spec.types[0].clone())
1041            .unwrap();
1042
1043        let resolved_types = registry.resolve_types(&Arc::new(spec.clone())).unwrap();
1044        let precise_type = resolved_types.named_types.get("precise_number").unwrap();
1045
1046        match &precise_type.specifications {
1047            TypeSpecification::Number { decimals, .. } => {
1048                assert_eq!(*decimals, Some(4));
1049            }
1050            _ => panic!("Expected Number type with decimals 4"),
1051        }
1052    }
1053
1054    #[test]
1055    fn test_scale_type_decimals_only() {
1056        let code = r#"spec test
1057type weight: scale -> unit kg 1 -> decimals 3"#;
1058
1059        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
1060        let spec = &specs[0];
1061
1062        let mut registry = test_registry();
1063        registry
1064            .register_type(&Arc::new(spec.clone()), spec.types[0].clone())
1065            .unwrap();
1066
1067        let resolved_types = registry.resolve_types(&Arc::new(spec.clone())).unwrap();
1068        let weight_type = resolved_types.named_types.get("weight").unwrap();
1069
1070        match &weight_type.specifications {
1071            TypeSpecification::Scale { decimals, .. } => {
1072                assert_eq!(*decimals, Some(3));
1073            }
1074            _ => panic!("Expected Scale type with decimals 3"),
1075        }
1076    }
1077
1078    #[test]
1079    fn test_ratio_type_accepts_optional_decimals_command() {
1080        let code = r#"spec test
1081type ratio_type: ratio -> decimals 2"#;
1082
1083        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
1084        let spec = &specs[0];
1085
1086        let mut registry = test_registry();
1087        registry
1088            .register_type(&Arc::new(spec.clone()), spec.types[0].clone())
1089            .unwrap();
1090
1091        let resolved_types = registry.resolve_types(&Arc::new(spec.clone())).unwrap();
1092        let ratio_type = resolved_types.named_types.get("ratio_type").unwrap();
1093
1094        match &ratio_type.specifications {
1095            TypeSpecification::Ratio { decimals, .. } => {
1096                assert_eq!(
1097                    *decimals,
1098                    Some(2),
1099                    "ratio type should accept decimals command"
1100                );
1101            }
1102            _ => panic!("Expected Ratio type with decimals 2"),
1103        }
1104    }
1105
1106    #[test]
1107    fn test_ratio_type_with_default_command() {
1108        let code = r#"spec test
1109type percentage: ratio -> minimum 0 -> maximum 1 -> default 0.5"#;
1110
1111        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
1112        let spec = &specs[0];
1113
1114        let mut registry = test_registry();
1115        registry
1116            .register_type(&Arc::new(spec.clone()), spec.types[0].clone())
1117            .unwrap();
1118
1119        let resolved_types = registry.resolve_types(&Arc::new(spec.clone())).unwrap();
1120        let percentage_type = resolved_types.named_types.get("percentage").unwrap();
1121
1122        match &percentage_type.specifications {
1123            TypeSpecification::Ratio {
1124                minimum,
1125                maximum,
1126                default,
1127                ..
1128            } => {
1129                assert_eq!(
1130                    *minimum,
1131                    Some(Decimal::from(0)),
1132                    "ratio type should have minimum 0"
1133                );
1134                assert_eq!(
1135                    *maximum,
1136                    Some(Decimal::from(1)),
1137                    "ratio type should have maximum 1"
1138                );
1139                assert_eq!(
1140                    *default,
1141                    Some(Decimal::from_i128_with_scale(5, 1)),
1142                    "ratio type with default command must work"
1143                );
1144            }
1145            _ => panic!("Expected Ratio type with minimum, maximum, and default"),
1146        }
1147    }
1148
1149    #[test]
1150    fn test_scale_extension_chain_same_family_units_allowed() {
1151        let code = r#"spec test
1152type money: scale -> unit eur 1
1153type money2: money -> unit usd 1.24"#;
1154
1155        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
1156        let spec = &specs[0];
1157
1158        let mut registry = test_registry();
1159        for type_def in &spec.types {
1160            registry
1161                .register_type(&Arc::new(spec.clone()), type_def.clone())
1162                .unwrap();
1163        }
1164
1165        let result = registry.resolve_types(&Arc::new(spec.clone()));
1166        assert!(
1167            result.is_ok(),
1168            "Scale extension chain should resolve: {:?}",
1169            result.err()
1170        );
1171
1172        let resolved = result.unwrap();
1173        assert!(
1174            resolved.unit_index.contains_key("eur"),
1175            "eur should be in unit_index"
1176        );
1177        assert!(
1178            resolved.unit_index.contains_key("usd"),
1179            "usd should be in unit_index"
1180        );
1181        let eur_type = resolved.unit_index.get("eur").unwrap();
1182        let usd_type = resolved.unit_index.get("usd").unwrap();
1183        assert_eq!(
1184            eur_type.name.as_deref(),
1185            Some("money2"),
1186            "more derived type (money2) should own eur for conversion"
1187        );
1188        assert_eq!(usd_type.name.as_deref(), Some("money2"));
1189    }
1190
1191    #[test]
1192    fn test_invalid_parent_type_in_named_type_should_error() {
1193        let code = r#"spec test
1194type invalid: nonexistent_type -> minimum 0"#;
1195
1196        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
1197        let spec = &specs[0];
1198
1199        let mut registry = test_registry();
1200        registry
1201            .register_type(&Arc::new(spec.clone()), spec.types[0].clone())
1202            .unwrap();
1203
1204        let result = registry.resolve_types(&Arc::new(spec.clone()));
1205        assert!(result.is_err(), "Should reject invalid parent type");
1206
1207        let errs = result.unwrap_err();
1208        assert!(!errs.is_empty(), "expected at least one error");
1209        let error_msg = errs[0].to_string();
1210        assert!(
1211            error_msg.contains("Unknown type") && error_msg.contains("nonexistent_type"),
1212            "Error should mention unknown type. Got: {}",
1213            error_msg
1214        );
1215    }
1216
1217    #[test]
1218    fn test_invalid_primitive_type_name_should_error() {
1219        // "choice" is not a primitive type; this should fail resolution.
1220        let code = r#"spec test
1221type invalid: choice -> option "a""#;
1222
1223        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
1224        let spec = &specs[0];
1225
1226        let mut registry = test_registry();
1227        registry
1228            .register_type(&Arc::new(spec.clone()), spec.types[0].clone())
1229            .unwrap();
1230
1231        let result = registry.resolve_types(&Arc::new(spec.clone()));
1232        assert!(result.is_err(), "Should reject invalid type base 'choice'");
1233
1234        let errs = result.unwrap_err();
1235        assert!(!errs.is_empty(), "expected at least one error");
1236        let error_msg = errs[0].to_string();
1237        assert!(
1238            error_msg.contains("Unknown type") && error_msg.contains("choice"),
1239            "Error should mention unknown type 'choice'. Got: {}",
1240            error_msg
1241        );
1242    }
1243
1244    #[test]
1245    fn test_unit_constraint_validation_errors_are_reported() {
1246        // Regression guard: overriding existing units should not silently succeed.
1247        let code = r#"spec test
1248type money: scale
1249  -> unit eur 1.00
1250  -> unit usd 1.19
1251
1252type money2: money
1253  -> unit eur 1.20
1254  -> unit usd 1.21
1255  -> unit gbp 1.30"#;
1256
1257        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
1258        let spec = &specs[0];
1259
1260        let mut registry = test_registry();
1261        for type_def in &spec.types {
1262            registry
1263                .register_type(&Arc::new(spec.clone()), type_def.clone())
1264                .unwrap();
1265        }
1266
1267        let result = registry.resolve_types(&Arc::new(spec.clone()));
1268        assert!(
1269            result.is_err(),
1270            "Expected unit constraint conflicts to error"
1271        );
1272
1273        let errs = result.unwrap_err();
1274        assert!(!errs.is_empty(), "expected at least one error");
1275        let error_msg = errs
1276            .iter()
1277            .map(ToString::to_string)
1278            .collect::<Vec<_>>()
1279            .join("; ");
1280        assert!(
1281            error_msg.contains("eur") || error_msg.contains("usd"),
1282            "Error should mention the conflicting units. Got: {}",
1283            error_msg
1284        );
1285    }
1286
1287    #[test]
1288    fn test_spec_level_unit_ambiguity_errors_are_reported() {
1289        // Regression guard: the same unit name must not be defined by multiple types in one spec.
1290        let code = r#"spec test
1291type money_a: scale
1292  -> unit eur 1.00
1293  -> unit usd 1.19
1294
1295type money_b: scale
1296  -> unit eur 1.00
1297  -> unit usd 1.20
1298
1299type length_a: scale
1300  -> unit meter 1.0
1301
1302type length_b: scale
1303  -> unit meter 1.0"#;
1304
1305        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
1306        let spec = &specs[0];
1307
1308        let mut registry = test_registry();
1309        for type_def in &spec.types {
1310            registry
1311                .register_type(&Arc::new(spec.clone()), type_def.clone())
1312                .unwrap();
1313        }
1314
1315        let result = registry.resolve_types(&Arc::new(spec.clone()));
1316        assert!(
1317            result.is_err(),
1318            "Expected ambiguous unit definitions to error"
1319        );
1320
1321        let errs = result.unwrap_err();
1322        assert!(!errs.is_empty(), "expected at least one error");
1323        let error_msg = errs
1324            .iter()
1325            .map(ToString::to_string)
1326            .collect::<Vec<_>>()
1327            .join("; ");
1328        assert!(
1329            error_msg.contains("eur") || error_msg.contains("usd") || error_msg.contains("meter"),
1330            "Error should mention at least one ambiguous unit. Got: {}",
1331            error_msg
1332        );
1333    }
1334
1335    #[test]
1336    fn test_number_type_cannot_have_units() {
1337        let code = r#"spec test
1338type price: number
1339  -> unit eur 1.00"#;
1340
1341        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
1342        let spec = &specs[0];
1343
1344        let mut registry = test_registry();
1345        registry
1346            .register_type(&Arc::new(spec.clone()), spec.types[0].clone())
1347            .unwrap();
1348
1349        let result = registry.resolve_types(&Arc::new(spec.clone()));
1350        assert!(result.is_err(), "Number types must reject unit commands");
1351
1352        let errs = result.unwrap_err();
1353        assert!(!errs.is_empty(), "expected at least one error");
1354        let error_msg = errs[0].to_string();
1355        assert!(
1356            error_msg.contains("unit") && error_msg.contains("number"),
1357            "Error should mention units are invalid on number. Got: {}",
1358            error_msg
1359        );
1360    }
1361
1362    #[test]
1363    fn test_scale_type_can_have_units() {
1364        let code = r#"spec test
1365type money: scale
1366  -> unit eur 1.00
1367  -> unit usd 1.19"#;
1368
1369        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
1370        let spec = &specs[0];
1371
1372        let mut registry = test_registry();
1373        registry
1374            .register_type(&Arc::new(spec.clone()), spec.types[0].clone())
1375            .unwrap();
1376
1377        let resolved = registry.resolve_types(&Arc::new(spec.clone())).unwrap();
1378        let money_type = resolved.named_types.get("money").unwrap();
1379
1380        match &money_type.specifications {
1381            TypeSpecification::Scale { units, .. } => {
1382                assert_eq!(units.len(), 2);
1383                assert!(units.iter().any(|u| u.name == "eur"));
1384                assert!(units.iter().any(|u| u.name == "usd"));
1385            }
1386            other => panic!("Expected Scale type specifications, got {:?}", other),
1387        }
1388    }
1389
1390    #[test]
1391    fn test_extending_type_inherits_units() {
1392        let code = r#"spec test
1393type money: scale
1394  -> unit eur 1.00
1395  -> unit usd 1.19
1396
1397type my_money: money
1398  -> unit gbp 1.30"#;
1399
1400        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
1401        let spec = &specs[0];
1402
1403        let mut registry = test_registry();
1404        for type_def in &spec.types {
1405            registry
1406                .register_type(&Arc::new(spec.clone()), type_def.clone())
1407                .unwrap();
1408        }
1409
1410        let resolved = registry.resolve_types(&Arc::new(spec.clone())).unwrap();
1411        let my_money_type = resolved.named_types.get("my_money").unwrap();
1412
1413        match &my_money_type.specifications {
1414            TypeSpecification::Scale { units, .. } => {
1415                assert_eq!(units.len(), 3);
1416                assert!(units.iter().any(|u| u.name == "eur"));
1417                assert!(units.iter().any(|u| u.name == "usd"));
1418                assert!(units.iter().any(|u| u.name == "gbp"));
1419            }
1420            other => panic!("Expected Scale type specifications, got {:?}", other),
1421        }
1422    }
1423
1424    #[test]
1425    fn test_duplicate_unit_in_same_type_is_rejected() {
1426        let code = r#"spec test
1427type money: scale
1428  -> unit eur 1.00
1429  -> unit eur 1.19"#;
1430
1431        let specs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
1432        let spec = &specs[0];
1433
1434        let mut registry = test_registry();
1435        registry
1436            .register_type(&Arc::new(spec.clone()), spec.types[0].clone())
1437            .unwrap();
1438
1439        let result = registry.resolve_types(&Arc::new(spec.clone()));
1440        assert!(
1441            result.is_err(),
1442            "Duplicate units within a type should error"
1443        );
1444
1445        let errs = result.unwrap_err();
1446        assert!(!errs.is_empty(), "expected at least one error");
1447        let error_msg = errs[0].to_string();
1448        assert!(
1449            error_msg.contains("Duplicate unit")
1450                || error_msg.contains("duplicate")
1451                || error_msg.contains("already exists")
1452                || error_msg.contains("eur"),
1453            "Error should mention duplicate unit issue. Got: {}",
1454            error_msg
1455        );
1456    }
1457}