lemma/planning/
types.rs

//! Type registry for managing custom type definitions and resolution
//!
//! This module provides the `TypeRegistry`, which handles:
//! - Registering user-defined types for each document
//! - Resolving type hierarchies and inheritance chains
//! - Detecting and preventing circular dependencies
//! - Applying overrides to create final type specifications

9use crate::error::LemmaError;
10use crate::parsing::ast::Span;
11use crate::semantic::{FactReference, LemmaType, TypeDef, TypeSpecification};
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14
15/// Fully resolved types for a single document
16/// After resolution, all imports are inlined - documents are independent
17#[derive(Debug)]
18pub struct ResolvedDocumentTypes {
19    /// Named types: type_name -> fully resolved type
20    pub named_types: HashMap<String, LemmaType>,
21
22    /// Inline type definitions: fact_reference -> fully resolved type
23    pub inline_type_definitions: HashMap<FactReference, LemmaType>,
24
25    /// Unit index: unit_name -> type that defines it
26    /// Built during resolution - if unit appears in multiple types, resolution fails
27    pub unit_index: HashMap<String, LemmaType>,
28}
29
30/// Registry for managing and resolving custom types
31///
32/// Types are organized per document and support inheritance through parent references.
33/// The registry handles cycle detection and accumulates overrides through the inheritance chain.
34#[derive(Debug, Clone)]
35pub struct TypeRegistry {
36    /// Named types per document: doc_name -> (type_name -> TypeDef)
37    /// Stores the raw definitions extracted from the AST
38    named_types: HashMap<String, HashMap<String, TypeDef>>,
39    /// Inline type definitions per document: doc_name -> (fact_reference -> TypeDef)
40    /// Stores inline type definitions keyed by their fact reference
41    inline_type_definitions: HashMap<String, HashMap<FactReference, TypeDef>>,
42}
43
44impl TypeRegistry {
45    /// Create a new, empty registry
46    pub fn new() -> Self {
47        TypeRegistry {
48            named_types: HashMap::new(),
49            inline_type_definitions: HashMap::new(),
50        }
51    }
52
53    /// Register a user-defined type for a given document
54    pub fn register_type(&mut self, doc: &str, def: TypeDef) -> Result<(), LemmaError> {
55        match &def {
56            TypeDef::Regular { name, .. } | TypeDef::Import { name, .. } => {
57                // Named type
58                let doc_types = self.named_types.entry(doc.to_string()).or_default();
59
60                // Check if this type already exists
61                if doc_types.contains_key(name) {
62                    return Err(LemmaError::engine(
63                        format!("Type '{}' is already defined in document '{}'", name, doc),
64                        Span {
65                            start: 0,
66                            end: 0,
67                            line: 1,
68                            col: 0,
69                        },
70                        "<unknown>",
71                        Arc::from(""),
72                        doc,
73                        1,
74                        None::<String>,
75                    ));
76                }
77
78                // Store the type definition
79                doc_types.insert(name.clone(), def);
80            }
81            TypeDef::Inline { fact_ref, .. } => {
82                // Inline type definition
83                let doc_inline_types = self
84                    .inline_type_definitions
85                    .entry(doc.to_string())
86                    .or_default();
87
88                // Check if this inline type definition already exists
89                if doc_inline_types.contains_key(fact_ref) {
90                    return Err(LemmaError::engine(
91                        format!(
92                            "Inline type definition for fact '{}' is already defined in document '{}'",
93                            fact_ref.fact, doc
94                        ),
95                        Span {
96                            start: 0,
97                            end: 0,
98                            line: 1,
99                            col: 0,
100                        },
101                        "<unknown>",
102                        Arc::from(""),
103                        doc,
104                        1,
105                        None::<String>,
106                    ));
107                }
108
109                // Store the inline type definition
110                doc_inline_types.insert(fact_ref.clone(), def);
111            }
112        }
113        Ok(())
114    }
115
116    /// Resolve all types for a certain document
117    ///
118    /// Returns fully resolved types for the document, including named types, inline type definitions,
119    /// and a unit index. After resolution, all imports are inlined - documents are independent.
120    /// Follows `parent` chains, accumulates overrides into `specifications`.
121    /// Handles cycle detection and cross-document references.
122    ///
123    /// # Errors
124    /// Returns an error if a unit appears in multiple types within the same document (ambiguous unit).
125    pub fn resolve_types(&self, doc: &str) -> Result<ResolvedDocumentTypes, LemmaError> {
126        self.resolve_types_internal(doc, true)
127    }
128
129    /// Resolve only named types (for validation before inline type definitions are registered)
130    pub fn resolve_named_types(&self, doc: &str) -> Result<ResolvedDocumentTypes, LemmaError> {
131        self.resolve_types_internal(doc, false)
132    }
133
134    fn resolve_types_internal(
135        &self,
136        doc: &str,
137        include_anonymous: bool,
138    ) -> Result<ResolvedDocumentTypes, LemmaError> {
139        let mut named_types = HashMap::new();
140        let mut inline_type_definitions = HashMap::new();
141        let mut visited = HashSet::new();
142
143        // Resolve named types
144        if let Some(doc_types) = self.named_types.get(doc) {
145            for type_name in doc_types.keys() {
146                match self.resolve_type_internal(doc, type_name, &mut visited)? {
147                    Some(resolved_type) => {
148                        named_types.insert(type_name.clone(), resolved_type);
149                    }
150                    None => {
151                        // Type was registered but couldn't be resolved - this is an error
152                        // (should not happen for named types that are in the registry)
153                        return Err(LemmaError::engine(
154                            format!("Type '{}' is registered but could not be resolved. This indicates an internal error.", type_name),
155                            Span { start: 0, end: 0, line: 1, col: 0 },
156                            "<unknown>",
157                            Arc::from(""),
158                            doc,
159                            1,
160                            None::<String>,
161                        ));
162                    }
163                }
164                visited.clear();
165            }
166        }
167
168        // Resolve inline type definitions (only if requested)
169        if include_anonymous {
170            if let Some(doc_inline_types) = self.inline_type_definitions.get(doc) {
171                for (fact_ref, type_def) in doc_inline_types {
172                    let mut visited = HashSet::new();
173                    match self.resolve_inline_type_definition(
174                        doc,
175                        fact_ref,
176                        type_def,
177                        &mut visited,
178                    )? {
179                        Some(resolved_type) => {
180                            inline_type_definitions.insert(fact_ref.clone(), resolved_type);
181                        }
182                        None => {
183                            // Inline type definition was registered but couldn't be resolved - this is an error
184                            return Err(LemmaError::engine(
185                                format!("Inline type definition for fact '{}' is registered but could not be resolved. This indicates an internal error.", fact_ref),
186                                Span { start: 0, end: 0, line: 1, col: 0 },
187                                "<unknown>",
188                                Arc::from(""),
189                                doc,
190                                1,
191                                None::<String>,
192                            ));
193                        }
194                    }
195                }
196            }
197        }
198
199        // Build unit index from both named and inline type definitions
200        let mut unit_index: HashMap<String, LemmaType> = HashMap::new();
201        let mut errors = Vec::new();
202
203        // Add units from named types (collect all errors)
204        for resolved_type in named_types.values() {
205            if let Err(e) = self.add_units_to_index(&mut unit_index, resolved_type, doc, || {
206                resolved_type
207                    .name
208                    .as_deref()
209                    .unwrap_or("inline")
210                    .to_string()
211            }) {
212                errors.push(e);
213            }
214        }
215
216        // Add units from inline type definitions (collect all errors)
217        for (fact_ref, resolved_type) in &inline_type_definitions {
218            if let Err(e) = self.add_units_to_index(&mut unit_index, resolved_type, doc, || {
219                format!("{}::{}", doc, fact_ref)
220            }) {
221                errors.push(e);
222            }
223        }
224
225        // Return all collected errors if any
226        if !errors.is_empty() {
227            // Combine all errors into a single error message for now
228            // (Future: LSP can use Vec<LemmaError> for better diagnostics)
229            let combined_message = errors
230                .iter()
231                .map(|e| match e {
232                    LemmaError::Engine(details) => details.message.clone(),
233                    LemmaError::CircularDependency { details, .. } => details.message.clone(),
234                    LemmaError::Parse(details) => details.message.clone(),
235                    LemmaError::Semantic(details) => details.message.clone(),
236                    LemmaError::Runtime(details) => details.message.clone(),
237                    LemmaError::MissingFact(details) => details.message.clone(),
238                    LemmaError::ResourceLimitExceeded {
239                        limit_name,
240                        limit_value,
241                        actual_value,
242                        suggestion,
243                    } => {
244                        format!(
245                            "Resource limit exceeded: {} (limit: {}, actual: {}). {}",
246                            limit_name, limit_value, actual_value, suggestion
247                        )
248                    }
249                    LemmaError::MultipleErrors(errs) => errs
250                        .iter()
251                        .map(|e| match e {
252                            LemmaError::Engine(details) => details.message.clone(),
253                            LemmaError::CircularDependency { details, .. } => {
254                                details.message.clone()
255                            }
256                            LemmaError::Parse(details) => details.message.clone(),
257                            LemmaError::Semantic(details) => details.message.clone(),
258                            LemmaError::Runtime(details) => details.message.clone(),
259                            LemmaError::MissingFact(details) => details.message.clone(),
260                            LemmaError::ResourceLimitExceeded {
261                                limit_name,
262                                limit_value,
263                                actual_value,
264                                suggestion,
265                            } => {
266                                format!(
267                                    "Resource limit exceeded: {} (limit: {}, actual: {}). {}",
268                                    limit_name, limit_value, actual_value, suggestion
269                                )
270                            }
271                            LemmaError::MultipleErrors(_) => "Multiple errors".to_string(),
272                        })
273                        .collect::<Vec<_>>()
274                        .join("; "),
275                })
276                .collect::<Vec<_>>()
277                .join("; ");
278            return Err(LemmaError::engine(
279                &combined_message,
280                Span {
281                    start: 0,
282                    end: 0,
283                    line: 1,
284                    col: 0,
285                },
286                "<unknown>",
287                Arc::from(""),
288                doc,
289                1,
290                None::<String>,
291            ));
292        }
293
294        Ok(ResolvedDocumentTypes {
295            named_types,
296            inline_type_definitions,
297            unit_index,
298        })
299    }
300
301    /// Resolve a single type with cycle detection
302    fn resolve_type_internal(
303        &self,
304        doc: &str,
305        name: &str,
306        visited: &mut HashSet<String>,
307    ) -> Result<Option<LemmaType>, LemmaError> {
308        // Cycle detection using doc::name key
309        let key = format!("{}::{}", doc, name);
310        if visited.contains(&key) {
311            return Err(LemmaError::circular_dependency(
312                format!("Circular dependency detected in type resolution: {}", key),
313                crate::parsing::ast::Span {
314                    start: 0,
315                    end: 0,
316                    line: 1,
317                    col: 0,
318                },
319                "<unknown>",
320                std::sync::Arc::from(""),
321                "<unknown>",
322                1,
323                vec![],
324                None::<String>,
325            ));
326        }
327        visited.insert(key.clone());
328
329        // Get the TypeDef from the document (check named types)
330        let type_def = match self.named_types.get(doc).and_then(|dt| dt.get(name)) {
331            Some(def) => def.clone(),
332            None => {
333                visited.remove(&key);
334                return Ok(None);
335            }
336        };
337
338        // Resolve the parent type (standard or custom)
339        let (parent, from, overrides, type_name) = match &type_def {
340            TypeDef::Regular {
341                name,
342                parent,
343                overrides,
344            } => (parent.clone(), None, overrides.clone(), name.clone()),
345            TypeDef::Import {
346                name,
347                source_type,
348                from,
349                overrides,
350            } => (
351                source_type.clone(),
352                Some(from.clone()),
353                overrides.clone(),
354                name.clone(),
355            ),
356            TypeDef::Inline { .. } => {
357                // Inline types are resolved separately
358                visited.remove(&key);
359                return Ok(None);
360            }
361        };
362
363        let parent_specs = match self.resolve_parent(doc, &parent, &from, visited) {
364            Ok(Some(specs)) => specs,
365            Ok(None) => {
366                // Parent type not found - this is an error for named types
367                // (inline type definitions might have forward references, but named types should be resolvable)
368                visited.remove(&key);
369                return Err(LemmaError::engine(
370                    format!("Unknown type: '{}'. Type must be defined before use. Valid standard types are: boolean, scale, number, ratio, text, date, time, duration, percent", parent),
371                    Span { start: 0, end: 0, line: 1, col: 0 },
372                    "<unknown>",
373                    Arc::from(""),
374                    doc,
375                    1,
376                    None::<String>,
377                ));
378            }
379            Err(e) => {
380                visited.remove(&key);
381                return Err(e);
382            }
383        };
384
385        // Apply overrides from the TypeDef
386        let final_specs = if let Some(overrides) = &overrides {
387            match self.apply_overrides(parent_specs, overrides) {
388                Ok(specs) => specs,
389                Err(errors) => {
390                    visited.remove(&key);
391                    // Combine all errors into a single error message for now
392                    // (Future: LSP can use Vec<LemmaError> for better diagnostics)
393                    let combined_message = errors
394                        .iter()
395                        .map(|e| match e {
396                            LemmaError::Engine(details) => details.message.clone(),
397                            LemmaError::CircularDependency { details, .. } => details.message.clone(),
398                            LemmaError::Parse(details) => details.message.clone(),
399                            LemmaError::Semantic(details) => details.message.clone(),
400                            LemmaError::Runtime(details) => details.message.clone(),
401                            LemmaError::MissingFact(details) => details.message.clone(),
402                            LemmaError::ResourceLimitExceeded { limit_name, limit_value, actual_value, suggestion } => {
403                                format!("Resource limit exceeded: {} (limit: {}, actual: {}). {}", limit_name, limit_value, actual_value, suggestion)
404                            },
405                            LemmaError::MultipleErrors(errs) => {
406                                errs.iter().map(|e| match e {
407                                    LemmaError::Engine(details) => details.message.clone(),
408                                    LemmaError::CircularDependency { details, .. } => details.message.clone(),
409                                    LemmaError::Parse(details) => details.message.clone(),
410                                    LemmaError::Semantic(details) => details.message.clone(),
411                                    LemmaError::Runtime(details) => details.message.clone(),
412                                    LemmaError::MissingFact(details) => details.message.clone(),
413                                    LemmaError::ResourceLimitExceeded { limit_name, limit_value, actual_value, suggestion } => {
414                                        format!("Resource limit exceeded: {} (limit: {}, actual: {}). {}", limit_name, limit_value, actual_value, suggestion)
415                                    },
416                                    LemmaError::MultipleErrors(_) => "Multiple errors".to_string(),
417                                }).collect::<Vec<_>>().join("; ")
418                            },
419                        })
420                        .collect::<Vec<_>>()
421                        .join("; ");
422                    return Err(LemmaError::engine(
423                        &combined_message,
424                        Span {
425                            start: 0,
426                            end: 0,
427                            line: 1,
428                            col: 0,
429                        },
430                        "<unknown>",
431                        Arc::from(""),
432                        doc,
433                        1,
434                        None::<String>,
435                    ));
436                }
437            }
438        } else {
439            parent_specs
440        };
441
442        visited.remove(&key);
443
444        Ok(Some(LemmaType {
445            name: Some(type_name),
446            specifications: final_specs,
447        }))
448    }
449
450    /// Resolve a parent type reference (standard or custom)
451    fn resolve_parent(
452        &self,
453        doc: &str,
454        parent: &str,
455        from: &Option<String>,
456        visited: &mut HashSet<String>,
457    ) -> Result<Option<TypeSpecification>, LemmaError> {
458        // Try standard types first
459        if let Some(specs) = self.resolve_standard_type(parent) {
460            return Ok(Some(specs));
461        }
462
463        // Otherwise resolve as a custom type in the specified document (or same document if not specified)
464        let parent_doc = from.as_deref().unwrap_or(doc);
465        match self.resolve_type_internal(parent_doc, parent, visited) {
466            Ok(Some(t)) => Ok(Some(t.specifications)),
467            Ok(None) => {
468                // Parent type not found - check if it was ever registered
469                let type_exists = if let Some(doc_types) = self.named_types.get(parent_doc) {
470                    doc_types.contains_key(parent)
471                } else {
472                    false
473                };
474
475                if !type_exists {
476                    // Type was never registered - invalid parent type
477                    Err(LemmaError::engine(
478                        format!("Unknown type: '{}'. Type must be defined before use. Valid standard types are: boolean, scale, number, ratio, text, date, time, duration, percent", parent),
479                        Span { start: 0, end: 0, line: 1, col: 0 },
480                        "<unknown>",
481                        Arc::from(""),
482                        doc,
483                        1,
484                        None::<String>,
485                    ))
486                } else {
487                    // Type exists but couldn't be resolved (circular dependency or other issue)
488                    // Return None - the caller will handle this appropriately
489                    Ok(None)
490                }
491            }
492            Err(e) => Err(e),
493        }
494    }
495
496    /// Resolve a standard type by name
497    pub fn resolve_standard_type(&self, name: &str) -> Option<TypeSpecification> {
498        match name {
499            "boolean" => Some(TypeSpecification::boolean()),
500            "scale" => Some(TypeSpecification::scale()),
501            "number" => Some(TypeSpecification::number()),
502            "ratio" => Some(TypeSpecification::ratio()),
503            "text" => Some(TypeSpecification::text()),
504            "date" => Some(TypeSpecification::date()),
505            "time" => Some(TypeSpecification::time()),
506            "duration" => Some(TypeSpecification::duration()),
507            "percent" => Some(TypeSpecification::ratio()),
508            _ => None,
509        }
510    }
511
512    /// Find which document a FactReference belongs to (for inline type definitions)
513    pub fn find_document_for_fact(&self, fact_ref: &FactReference) -> Option<String> {
514        for (doc_name, inline_types) in &self.inline_type_definitions {
515            if inline_types.contains_key(fact_ref) {
516                return Some(doc_name.clone());
517            }
518        }
519        None
520    }
521
522    /// Apply command-argument overrides to a TypeSpecification
523    fn apply_overrides(
524        &self,
525        mut specs: TypeSpecification,
526        overrides: &[(String, Vec<String>)],
527    ) -> Result<TypeSpecification, Vec<LemmaError>> {
528        // Extract existing units from parent type before applying overrides
529        let mut existing_units: Vec<String> = match &specs {
530            TypeSpecification::Scale { units, .. } => {
531                units.iter().map(|u| u.name.clone()).collect()
532            }
533            TypeSpecification::Ratio { units, .. } => {
534                units.iter().map(|u| u.name.clone()).collect()
535            }
536            _ => Vec::new(),
537        };
538
539        let mut errors = Vec::new();
540        let mut valid_overrides = Vec::new();
541
542        // First pass: validate all unit overrides and collect errors
543        for (command, args) in overrides {
544            if command == "unit" && !args.is_empty() {
545                let unit_name = &args[0];
546                if existing_units.iter().any(|u| u == unit_name) {
547                    errors.push(LemmaError::engine(
548                        format!("Unit '{}' already exists in parent type. Use a different unit name (e.g., 'my_{}') if you need another unit factor.", unit_name, unit_name),
549                        Span { start: 0, end: 0, line: 1, col: 0 },
550                        "<unknown>",
551                        Arc::from(""),
552                        "<unknown>",
553                        1,
554                        None::<String>,
555                    ));
556                    // Skip this override
557                } else {
558                    existing_units.push(unit_name.clone());
559                    valid_overrides.push((command.clone(), args.clone()));
560                }
561            } else {
562                // Non-unit override - always valid to apply
563                valid_overrides.push((command.clone(), args.clone()));
564            }
565        }
566
567        // Second pass: apply all valid overrides and collect any application errors
568        // We apply overrides sequentially, accumulating the result
569        // If one fails, we can't continue with a valid state, but we still collect all errors
570        for (command, args) in valid_overrides {
571            // Clone specs before applying override so we can continue even if this one fails
572            let specs_clone = specs.clone();
573            match specs.apply_override(&command, &args) {
574                Ok(updated_specs) => {
575                    specs = updated_specs;
576                }
577                Err(e) => {
578                    errors.push(LemmaError::engine(
579                        format!("Failed to apply override '{}': {}", command, e),
580                        Span {
581                            start: 0,
582                            end: 0,
583                            line: 1,
584                            col: 0,
585                        },
586                        "<unknown>",
587                        Arc::from(""),
588                        "<unknown>",
589                        1,
590                        None::<String>,
591                    ));
592                    // Restore from clone so we can continue trying other overrides
593                    specs = specs_clone;
594                }
595            }
596        }
597
598        if !errors.is_empty() {
599            return Err(errors);
600        }
601
602        Ok(specs)
603    }
604
605    /// Resolve an inline type definition from its definition
606    fn resolve_inline_type_definition(
607        &self,
608        doc: &str,
609        _fact_ref: &FactReference,
610        type_def: &TypeDef,
611        visited: &mut HashSet<String>,
612    ) -> Result<Option<LemmaType>, LemmaError> {
613        let TypeDef::Inline {
614            parent,
615            overrides,
616            fact_ref: _,
617            from,
618        } = type_def
619        else {
620            return Ok(None);
621        };
622
623        let parent_specs = match self.resolve_parent(doc, parent, from, visited) {
624            Ok(Some(specs)) => specs,
625            Ok(None) => {
626                // Parent type not found - this is an error for inline type definitions too
627                // Inline type definitions should have valid parent types
628                return Err(LemmaError::engine(
629                    format!("Unknown type: '{}'. Type must be defined before use. Valid standard types are: boolean, scale, number, ratio, text, date, time, duration, percent", parent),
630                    Span { start: 0, end: 0, line: 1, col: 0 },
631                    "<unknown>",
632                    Arc::from(""),
633                    doc,
634                    1,
635                    None::<String>,
636                ));
637            }
638            Err(e) => return Err(e),
639        };
640
641        let final_specs = if let Some(overrides) = overrides {
642            match self.apply_overrides(parent_specs, overrides) {
643                Ok(specs) => specs,
644                Err(errors) => {
645                    // Combine all errors into a single error message for now
646                    // (Future: LSP can use Vec<LemmaError> for better diagnostics)
647                    let combined_message = errors
648                        .iter()
649                        .map(|e| match e {
650                            LemmaError::Engine(details) => details.message.clone(),
651                            LemmaError::CircularDependency { details, .. } => details.message.clone(),
652                            LemmaError::Parse(details) => details.message.clone(),
653                            LemmaError::Semantic(details) => details.message.clone(),
654                            LemmaError::Runtime(details) => details.message.clone(),
655                            LemmaError::MissingFact(details) => details.message.clone(),
656                            LemmaError::ResourceLimitExceeded { limit_name, limit_value, actual_value, suggestion } => {
657                                format!("Resource limit exceeded: {} (limit: {}, actual: {}). {}", limit_name, limit_value, actual_value, suggestion)
658                            },
659                            LemmaError::MultipleErrors(errs) => {
660                                errs.iter().map(|e| match e {
661                                    LemmaError::Engine(details) => details.message.clone(),
662                                    LemmaError::CircularDependency { details, .. } => details.message.clone(),
663                                    LemmaError::Parse(details) => details.message.clone(),
664                                    LemmaError::Semantic(details) => details.message.clone(),
665                                    LemmaError::Runtime(details) => details.message.clone(),
666                                    LemmaError::MissingFact(details) => details.message.clone(),
667                                    LemmaError::ResourceLimitExceeded { limit_name, limit_value, actual_value, suggestion } => {
668                                        format!("Resource limit exceeded: {} (limit: {}, actual: {}). {}", limit_name, limit_value, actual_value, suggestion)
669                                    },
670                                    LemmaError::MultipleErrors(_) => "Multiple errors".to_string(),
671                                }).collect::<Vec<_>>().join("; ")
672                            },
673                        })
674                        .collect::<Vec<_>>()
675                        .join("; ");
676                    return Err(LemmaError::engine(
677                        &combined_message,
678                        Span {
679                            start: 0,
680                            end: 0,
681                            line: 1,
682                            col: 0,
683                        },
684                        "<unknown>",
685                        Arc::from(""),
686                        doc,
687                        1,
688                        None::<String>,
689                    ));
690                }
691            }
692        } else {
693            parent_specs
694        };
695
696        Ok(Some(LemmaType::without_name(final_specs)))
697    }
698
699    /// Get all units defined across all types, organized by unit name
700    ///
701    /// Returns a map from unit name to a list of (document_name, type_name) pairs
702    /// where that unit is defined. Useful for error messages and debugging.
703    ///
704    /// # Returns
705    /// A HashMap where:
706    /// - Key: unit name (e.g., "celsius", "kilogram")
707    /// - Value: list of (document_name, type_name) pairs where the unit is defined
708    pub fn get_all_units(&self) -> HashMap<String, Vec<(String, String)>> {
709        let mut unit_map: HashMap<String, Vec<(String, String)>> = HashMap::new();
710        let mut visited = HashSet::new();
711
712        // Process named types
713        for (doc_name, doc_types) in &self.named_types {
714            for type_name in doc_types.keys() {
715                if let Ok(Some(resolved_type)) =
716                    self.resolve_type_internal(doc_name, type_name, &mut visited)
717                {
718                    let units = self.extract_units_from_specs(&resolved_type.specifications);
719                    for unit in units {
720                        unit_map
721                            .entry(unit)
722                            .or_default()
723                            .push((doc_name.clone(), type_name.clone()));
724                    }
725                }
726                visited.clear();
727            }
728        }
729
730        // Process inline type definitions
731        for (doc_name, inline_types) in &self.inline_type_definitions {
732            for (fact_ref, type_def) in inline_types {
733                if let Ok(Some(resolved_type)) = self.resolve_inline_type_definition(
734                    doc_name,
735                    fact_ref,
736                    type_def,
737                    &mut HashSet::new(),
738                ) {
739                    let units = self.extract_units_from_specs(&resolved_type.specifications);
740                    let type_name = format!("{}::{}", doc_name, fact_ref);
741                    for unit in units {
742                        unit_map
743                            .entry(unit)
744                            .or_default()
745                            .push((doc_name.clone(), type_name.clone()));
746                    }
747                }
748            }
749        }
750
751        unit_map
752    }
753
754    /// Add units from a resolved type to the unit index
755    /// Returns an error if a unit is already defined in another type (ambiguous unit)
756    fn add_units_to_index<F>(
757        &self,
758        unit_index: &mut HashMap<String, LemmaType>,
759        resolved_type: &LemmaType,
760        doc: &str,
761        get_type_name: F,
762    ) -> Result<(), LemmaError>
763    where
764        F: FnOnce() -> String,
765    {
766        let units = self.extract_units_from_specs(&resolved_type.specifications);
767        let current_name = get_type_name(); // Call FnOnce before the loop
768        for unit in units {
769            if let Some(existing_type) = unit_index.get(&unit) {
770                let existing_name = existing_type.name.as_deref().unwrap_or("inline");
771
772                // Check if one type extends the other
773                // If the existing type's name matches the current type's name, they're the same type
774                let same_type = existing_type.name.as_deref() == resolved_type.name.as_deref();
775                if same_type {
776                    // Same type - not ambiguous, just continue
777                    continue;
778                }
779
780                // Check if types share the same base and if one might extend the other
781                // by comparing their specifications - if one has all units from the other plus MORE,
782                // it's likely an extension (strict superset required)
783                // Check both directions: current extends existing OR existing extends current
784                let might_be_inheritance =
785                    match (&existing_type.specifications, &resolved_type.specifications) {
786                        (
787                            TypeSpecification::Scale {
788                                units: existing_units,
789                                ..
790                            },
791                            TypeSpecification::Scale {
792                                units: current_units,
793                                ..
794                            },
795                        ) => {
796                            // Check if current extends existing (current has all of existing + more)
797                            let current_extends_existing = existing_units.len()
798                                < current_units.len()
799                                && existing_units
800                                    .iter()
801                                    .all(|eu| current_units.iter().any(|cu| cu.name == eu.name));
802                            // Check if existing extends current (existing has all of current + more)
803                            let existing_extends_current = current_units.len()
804                                < existing_units.len()
805                                && current_units
806                                    .iter()
807                                    .all(|cu| existing_units.iter().any(|eu| eu.name == cu.name));
808                            current_extends_existing || existing_extends_current
809                        }
810                        (
811                            TypeSpecification::Ratio {
812                                units: existing_units,
813                                ..
814                            },
815                            TypeSpecification::Ratio {
816                                units: current_units,
817                                ..
818                            },
819                        ) => {
820                            // Check if current extends existing (current has all of existing + more)
821                            let current_extends_existing = existing_units.len()
822                                < current_units.len()
823                                && existing_units
824                                    .iter()
825                                    .all(|eu| current_units.iter().any(|cu| cu.name == eu.name));
826                            // Check if existing extends current (existing has all of current + more)
827                            let existing_extends_current = current_units.len()
828                                < existing_units.len()
829                                && current_units
830                                    .iter()
831                                    .all(|cu| existing_units.iter().any(|eu| eu.name == cu.name));
832                            current_extends_existing || existing_extends_current
833                        }
834                        _ => false,
835                    };
836
837                if might_be_inheritance {
838                    // One type likely extends the other - allow unit sharing via inheritance
839                    continue;
840                }
841
842                return Err(LemmaError::engine(
843                    format!("Ambiguous unit '{}' in document '{}'. Defined in multiple types: {} and {}", unit, doc, existing_name, current_name),
844                    Span { start: 0, end: 0, line: 1, col: 0 },
845                    "<unknown>",
846                    Arc::from(""),
847                    doc,
848                    1,
849                    None::<String>,
850                ));
851            }
852            unit_index.insert(unit, resolved_type.clone());
853        }
854        Ok(())
855    }
856
857    /// Extract all unit names from a TypeSpecification
858    /// Only Scale types can have units (Number types are dimensionless)
859    fn extract_units_from_specs(&self, specs: &TypeSpecification) -> Vec<String> {
860        match specs {
861            TypeSpecification::Scale { units, .. } => {
862                units.iter().map(|unit| unit.name.clone()).collect()
863            }
864            TypeSpecification::Ratio { units, .. } => {
865                units.iter().map(|unit| unit.name.clone()).collect()
866            }
867            _ => Vec::new(),
868        }
869    }
870}
871
872impl Default for TypeRegistry {
873    fn default() -> Self {
874        Self::new()
875    }
876}
877
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse;
    use crate::ResourceLimits;
    use rust_decimal::Decimal;
    use std::str::FromStr;

    /// Parse `code`, register every type of its first document in a fresh
    /// registry, and return the registry together with that document's name.
    ///
    /// Panics when parsing or registration fails, so callers only need to
    /// assert on the outcome of `resolve_types`.
    fn registry_for(code: &str) -> (TypeRegistry, String) {
        let docs = parse(code, "test.lemma", &ResourceLimits::default()).unwrap();
        let doc = &docs[0];

        let mut registry = TypeRegistry::new();
        for type_def in &doc.types {
            registry.register_type(&doc.name, type_def.clone()).unwrap();
        }
        (registry, doc.name.to_string())
    }

    /// The plain `money = number` named type shared by the registration tests.
    fn money_type_def() -> TypeDef {
        TypeDef::Regular {
            name: "money".to_string(),
            parent: "number".to_string(),
            overrides: None,
        }
    }

    #[test]
    fn test_registry_creation() {
        let registry = TypeRegistry::new();
        assert!(registry.named_types.is_empty());
        assert!(registry.inline_type_definitions.is_empty());
    }

    #[test]
    fn test_resolve_standard_types() {
        let registry = TypeRegistry::new();

        assert!(registry.resolve_standard_type("boolean").is_some());
        assert!(registry.resolve_standard_type("scale").is_some());
        assert!(registry.resolve_standard_type("number").is_some());
        assert!(registry.resolve_standard_type("ratio").is_some());
        assert!(registry.resolve_standard_type("text").is_some());
        assert!(registry.resolve_standard_type("date").is_some());
        assert!(registry.resolve_standard_type("time").is_some());
        assert!(registry.resolve_standard_type("duration").is_some());
        assert!(registry.resolve_standard_type("unknown").is_none());
    }

    #[test]
    fn test_register_named_type() {
        let mut registry = TypeRegistry::new();

        let result = registry.register_type("test_doc", money_type_def());
        assert!(result.is_ok());
    }

    #[test]
    fn test_register_inline_type_definition() {
        use crate::semantic::FactReference;
        let mut registry = TypeRegistry::new();
        let fact_ref = FactReference::local("age".to_string());
        let type_def = TypeDef::Inline {
            parent: "number".to_string(),
            overrides: Some(vec![
                ("minimum".to_string(), vec!["0".to_string()]),
                ("maximum".to_string(), vec!["150".to_string()]),
            ]),
            fact_ref: fact_ref.clone(),
            from: None,
        };

        let result = registry.register_type("test_doc", type_def);
        assert!(result.is_ok());

        // Verify the inline type definition is registered
        assert!(registry
            .inline_type_definitions
            .get("test_doc")
            .unwrap()
            .contains_key(&fact_ref));
    }

    #[test]
    fn test_register_duplicate_type_fails() {
        let mut registry = TypeRegistry::new();

        registry
            .register_type("test_doc", money_type_def())
            .unwrap();
        let result = registry.register_type("test_doc", money_type_def());
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_custom_type_from_standard() {
        let mut registry = TypeRegistry::new();

        registry
            .register_type("test_doc", money_type_def())
            .unwrap();
        let resolved = registry.resolve_types("test_doc").unwrap();

        assert!(resolved.named_types.contains_key("money"));
        let money_type = resolved.named_types.get("money").unwrap();
        assert_eq!(money_type.name, Some("money".to_string()));
    }

    #[test]
    fn test_type_definition_resolution() {
        let code = r#"doc test
type dice = number -> minimum 0 -> maximum 6"#;

        let (registry, doc_name) = registry_for(code);
        let resolved_types = registry.resolve_types(&doc_name).unwrap();
        let dice_type = resolved_types.named_types.get("dice").unwrap();

        // Verify it's a Number type (dimensionless) with the correct constraints
        match &dice_type.specifications {
            TypeSpecification::Number {
                minimum, maximum, ..
            } => {
                assert_eq!(*minimum, Some(Decimal::from(0)));
                assert_eq!(*maximum, Some(Decimal::from(6)));
            }
            _ => panic!("Expected Number type specifications"),
        }
    }

    #[test]
    fn test_type_definition_with_multiple_commands() {
        let code = r#"doc test
type money = scale -> decimals 2 -> unit eur 1.0 -> unit usd 1.18"#;

        let (registry, doc_name) = registry_for(code);
        let resolved_types = registry.resolve_types(&doc_name).unwrap();
        let money_type = resolved_types.named_types.get("money").unwrap();

        match &money_type.specifications {
            TypeSpecification::Scale {
                decimals, units, ..
            } => {
                assert_eq!(*decimals, Some(2));
                assert_eq!(units.len(), 2);
                assert!(units.iter().any(|u| u.name == "eur"));
                assert!(units.iter().any(|u| u.name == "usd"));
            }
            _ => panic!("Expected Scale type specifications"),
        }
    }

    #[test]
    fn test_number_type_with_decimals() {
        let code = r#"doc test
type price = number -> decimals 2 -> minimum 0"#;

        let (registry, doc_name) = registry_for(code);
        let resolved_types = registry.resolve_types(&doc_name).unwrap();
        let price_type = resolved_types.named_types.get("price").unwrap();

        // Verify it's a Number type with decimals set to 2
        match &price_type.specifications {
            TypeSpecification::Number {
                decimals, minimum, ..
            } => {
                assert_eq!(*decimals, Some(2));
                assert_eq!(*minimum, Some(Decimal::from(0)));
            }
            _ => panic!("Expected Number type specifications with decimals"),
        }
    }

    #[test]
    fn test_number_type_decimals_only() {
        let code = r#"doc test
type precise_number = number -> decimals 4"#;

        let (registry, doc_name) = registry_for(code);
        let resolved_types = registry.resolve_types(&doc_name).unwrap();
        let precise_type = resolved_types.named_types.get("precise_number").unwrap();

        match &precise_type.specifications {
            TypeSpecification::Number { decimals, .. } => {
                assert_eq!(*decimals, Some(4));
            }
            _ => panic!("Expected Number type with decimals 4"),
        }
    }

    #[test]
    fn test_scale_type_decimals_only() {
        let code = r#"doc test
type weight = scale -> decimals 3"#;

        let (registry, doc_name) = registry_for(code);
        let resolved_types = registry.resolve_types(&doc_name).unwrap();
        let weight_type = resolved_types.named_types.get("weight").unwrap();

        match &weight_type.specifications {
            TypeSpecification::Scale { decimals, .. } => {
                assert_eq!(*decimals, Some(3));
            }
            _ => panic!("Expected Scale type with decimals 3"),
        }
    }

    #[test]
    fn test_ratio_type_rejects_decimals_command() {
        let code = r#"doc test
type ratio_type = ratio -> decimals 2"#;

        // register_type only stores the definition, it doesn't apply overrides
        let (registry, doc_name) = registry_for(code);

        // resolve_types applies overrides, so the error should occur here
        let result = registry.resolve_types(&doc_name);

        // Ratio type should reject decimals command since it doesn't have a decimals field
        assert!(
            result.is_err(),
            "Ratio type should reject decimals command during resolution"
        );
        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("decimals") || error_msg.contains("Invalid command"),
            "Error message should mention decimals or invalid command, got: {}",
            error_msg
        );
    }

    #[test]
    fn test_ratio_type_with_default_command() {
        let code = r#"doc test
type percentage = ratio -> minimum 0 -> maximum 1 -> default 0.5"#;

        let (registry, doc_name) = registry_for(code);
        let resolved_types = registry.resolve_types(&doc_name).unwrap();
        let percentage_type = resolved_types.named_types.get("percentage").unwrap();

        match &percentage_type.specifications {
            TypeSpecification::Ratio {
                minimum,
                maximum,
                default,
                ..
            } => {
                assert_eq!(*minimum, Some(Decimal::from(0)));
                assert_eq!(*maximum, Some(Decimal::from(1)));
                assert_eq!(*default, Some(Decimal::from_str("0.5").unwrap()));
            }
            _ => panic!("Expected Ratio type specifications"),
        }
    }

    #[test]
    fn test_invalid_parent_type_in_named_type_should_error() {
        let code = r#"doc test
type invalid = nonexistent_type -> minimum 0"#;

        let (registry, doc_name) = registry_for(code);
        let result = registry.resolve_types(&doc_name);
        assert!(result.is_err(), "Should reject invalid parent type");

        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("Unknown type") && error_msg.contains("nonexistent_type"),
            "Error should mention unknown type. Got: {}",
            error_msg
        );
    }

    #[test]
    fn test_invalid_standard_type_name_should_error() {
        // "choice" is not a standard type; this should fail resolution.
        let code = r#"doc test
type invalid = choice -> option "a""#;

        let (registry, doc_name) = registry_for(code);
        let result = registry.resolve_types(&doc_name);
        assert!(result.is_err(), "Should reject invalid type base 'choice'");

        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("Unknown type") && error_msg.contains("choice"),
            "Error should mention unknown type 'choice'. Got: {}",
            error_msg
        );
    }

    #[test]
    fn test_unit_override_validation_errors_are_reported() {
        // Regression guard: overriding existing units should not silently succeed.
        let code = r#"doc test
type money = scale
  -> unit eur 1.00
  -> unit usd 1.19

type money2 = money
  -> unit eur 1.20
  -> unit usd 1.21
  -> unit gbp 1.30"#;

        let (registry, doc_name) = registry_for(code);
        let result = registry.resolve_types(&doc_name);
        assert!(result.is_err(), "Expected unit override conflicts to error");

        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("eur") || error_msg.contains("usd"),
            "Error should mention the conflicting units. Got: {}",
            error_msg
        );
    }

    #[test]
    fn test_document_level_unit_ambiguity_errors_are_reported() {
        // Regression guard: the same unit name must not be defined by multiple types in one doc.
        let code = r#"doc test
type money_a = scale
  -> unit eur 1.00
  -> unit usd 1.19

type money_b = scale
  -> unit eur 1.00
  -> unit usd 1.20

type length_a = scale
  -> unit meter 1.0

type length_b = scale
  -> unit meter 1.0"#;

        let (registry, doc_name) = registry_for(code);
        let result = registry.resolve_types(&doc_name);
        assert!(
            result.is_err(),
            "Expected ambiguous unit definitions to error"
        );

        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("eur") || error_msg.contains("usd") || error_msg.contains("meter"),
            "Error should mention at least one ambiguous unit. Got: {}",
            error_msg
        );
    }

    #[test]
    fn test_number_type_cannot_have_units() {
        let code = r#"doc test
type price = number
  -> unit eur 1.00"#;

        let (registry, doc_name) = registry_for(code);
        let result = registry.resolve_types(&doc_name);
        assert!(result.is_err(), "Number types must reject unit commands");

        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("unit") && error_msg.contains("number"),
            "Error should mention units are invalid on number. Got: {}",
            error_msg
        );
    }

    #[test]
    fn test_scale_type_can_have_units() {
        let code = r#"doc test
type money = scale
  -> unit eur 1.00
  -> unit usd 1.19"#;

        let (registry, doc_name) = registry_for(code);
        let resolved = registry.resolve_types(&doc_name).unwrap();
        let money_type = resolved.named_types.get("money").unwrap();

        match &money_type.specifications {
            TypeSpecification::Scale { units, .. } => {
                assert_eq!(units.len(), 2);
                assert!(units.iter().any(|u| u.name == "eur"));
                assert!(units.iter().any(|u| u.name == "usd"));
            }
            other => panic!("Expected Scale type specifications, got {:?}", other),
        }
    }

    #[test]
    fn test_extending_type_inherits_units() {
        let code = r#"doc test
type money = scale
  -> unit eur 1.00
  -> unit usd 1.19

type my_money = money
  -> unit gbp 1.30"#;

        let (registry, doc_name) = registry_for(code);
        let resolved = registry.resolve_types(&doc_name).unwrap();
        let my_money_type = resolved.named_types.get("my_money").unwrap();

        match &my_money_type.specifications {
            TypeSpecification::Scale { units, .. } => {
                assert_eq!(units.len(), 3);
                assert!(units.iter().any(|u| u.name == "eur"));
                assert!(units.iter().any(|u| u.name == "usd"));
                assert!(units.iter().any(|u| u.name == "gbp"));
            }
            other => panic!("Expected Scale type specifications, got {:?}", other),
        }
    }

    #[test]
    fn test_duplicate_unit_in_same_type_is_rejected() {
        let code = r#"doc test
type money = scale
  -> unit eur 1.00
  -> unit eur 1.19"#;

        let (registry, doc_name) = registry_for(code);
        let result = registry.resolve_types(&doc_name);
        assert!(
            result.is_err(),
            "Duplicate units within a type should error"
        );

        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("Duplicate unit")
                || error_msg.contains("duplicate")
                || error_msg.contains("already exists")
                || error_msg.contains("eur"),
            "Error should mention duplicate unit issue. Got: {}",
            error_msg
        );
    }
}
1413}