// slicec/patchers/type_ref_patcher.rs

1// Copyright (c) ZeroC, Inc.
2
3use crate::ast::node::Node;
4use crate::ast::{Ast, LookupError};
5use crate::compilation_state::CompilationState;
6use crate::diagnostics::*;
7use crate::grammar::attributes::Deprecated;
8use crate::grammar::*;
9use crate::utils::ptr_util::{OwnedPtr, WeakPtr};
10
11pub unsafe fn patch_ast(compilation_state: &mut CompilationState) {
12    let mut patcher = TypeRefPatcher {
13        type_ref_patches: Vec::new(),
14        diagnostics: &mut compilation_state.diagnostics,
15    };
16
17    // TODO why explain we split this logic so that we can for sure have an immutable AST.
18    patcher.compute_patches(&compilation_state.ast);
19    patcher.apply_patches(&mut compilation_state.ast);
20}
21
22struct TypeRefPatcher<'a> {
23    type_ref_patches: Vec<PatchKind>,
24    diagnostics: &'a mut Diagnostics,
25}
26
27impl TypeRefPatcher<'_> {
28    fn compute_patches(&mut self, ast: &Ast) {
29        for node in ast.as_slice() {
30            let patch = match node {
31                Node::Class(class_ptr) => class_ptr
32                    .borrow()
33                    .base
34                    .as_ref()
35                    .and_then(|type_ref| self.resolve_definition(type_ref, ast))
36                    .map(PatchKind::BaseClass),
37                Node::Exception(exception_ptr) => exception_ptr
38                    .borrow()
39                    .base
40                    .as_ref()
41                    .and_then(|type_ref| self.resolve_definition(type_ref, ast))
42                    .map(PatchKind::BaseException),
43                Node::Field(field_ptr) => {
44                    let type_ref = &field_ptr.borrow().data_type;
45                    self.resolve_definition(type_ref, ast).map(PatchKind::FieldType)
46                }
47                Node::Interface(interface_ptr) => {
48                    interface_ptr.borrow().bases.iter()
49                        .map(|type_ref| self.resolve_definition(type_ref, ast))
50                        .collect::<Option<Vec<_>>>() // None if any of the bases couldn't be resolved.
51                        .map(PatchKind::BaseInterfaces)
52                }
53                Node::Operation(operation_ptr) => {
54                    operation_ptr.borrow().exception_specification.iter()
55                        .map(|type_ref| self.resolve_definition(type_ref, ast))
56                        .collect::<Option<Vec<_>>>() // None if any of the exceptions couldn't be resolved.
57                        .map(PatchKind::ExceptionSpecification)
58                }
59                Node::Parameter(parameter_ptr) => {
60                    let type_ref = &parameter_ptr.borrow().data_type;
61                    self.resolve_definition(type_ref, ast).map(PatchKind::ParameterType)
62                }
63                Node::Enum(enum_ptr) => enum_ptr
64                    .borrow()
65                    .underlying
66                    .as_ref()
67                    .and_then(|type_ref| self.resolve_definition(type_ref, ast))
68                    .map(PatchKind::EnumUnderlyingType),
69                Node::TypeAlias(type_alias_ptr) => {
70                    let type_ref = &type_alias_ptr.borrow().underlying;
71                    self.resolve_definition(type_ref, ast)
72                        .map(PatchKind::TypeAliasUnderlyingType)
73                }
74                Node::ResultType(result_ptr) => {
75                    let result_type = result_ptr.borrow();
76                    let success_patch = self.resolve_definition(&result_type.success_type, ast);
77                    let failure_patch = self.resolve_definition(&result_type.failure_type, ast);
78                    Some(PatchKind::ResultTypes(success_patch, failure_patch))
79                }
80                Node::Sequence(sequence_ptr) => {
81                    let type_ref = &sequence_ptr.borrow().element_type;
82                    self.resolve_definition(type_ref, ast).map(PatchKind::SequenceType)
83                }
84                Node::Dictionary(dictionary_ptr) => {
85                    let dictionary_def = dictionary_ptr.borrow();
86                    let key_patch = self.resolve_definition(&dictionary_def.key_type, ast);
87                    let value_patch = self.resolve_definition(&dictionary_def.value_type, ast);
88                    Some(PatchKind::DictionaryTypes(key_patch, value_patch))
89                }
90                _ => None,
91            };
92            self.type_ref_patches.push(patch.unwrap_or_default());
93        }
94    }
95
96    unsafe fn apply_patches(self, ast: &mut Ast) {
97        let elements = ast.as_mut_slice();
98
99        // There should be 1 patch per AST node.
100        debug_assert_eq!(elements.len(), self.type_ref_patches.len());
101
102        // Simultaneously iterate through patches and AST nodes, and apply each patch to its corresponding node.
103        //
104        // Each match arm is broken into 2 steps, separated by a comment. First we navigate to the TypeRefs that needs
105        // patching, then we patch in its definition and any attributes it might of picked up from type aliases.
106        for (patch, element) in self.type_ref_patches.into_iter().zip(elements) {
107            match patch {
108                PatchKind::BaseClass((base_class_ptr, attributes)) => {
109                    let class_ptr: &mut OwnedPtr<Class> = element.try_into().unwrap();
110                    let base_class_ref = class_ptr.borrow_mut().base.as_mut().unwrap();
111                    base_class_ref.patch(base_class_ptr, attributes);
112                }
113                PatchKind::BaseException((base_exception_ptr, attributes)) => {
114                    let exception_ptr: &mut OwnedPtr<Exception> = element.try_into().unwrap();
115                    let base_exception_ref = exception_ptr.borrow_mut().base.as_mut().unwrap();
116                    base_exception_ref.patch(base_exception_ptr, attributes);
117                }
118                PatchKind::BaseInterfaces(base_interface_patches) => {
119                    let interface_ptr: &mut OwnedPtr<Interface> = element.try_into().unwrap();
120                    // Ensure the number of patches is equal to the number of base interfaces.
121                    debug_assert_eq!(interface_ptr.borrow().bases.len(), base_interface_patches.len());
122
123                    // Iterate through and patch each base interface.
124                    for (j, patch) in base_interface_patches.into_iter().enumerate() {
125                        let (base_interface_ptr, attributes) = patch;
126                        let base_interface_ref = &mut interface_ptr.borrow_mut().bases[j];
127                        base_interface_ref.patch(base_interface_ptr, attributes);
128                    }
129                }
130                PatchKind::FieldType((field_type_ptr, attributes)) => {
131                    let field_ptr: &mut OwnedPtr<Field> = element.try_into().unwrap();
132                    let field_type_ref = &mut field_ptr.borrow_mut().data_type;
133                    field_type_ref.patch(field_type_ptr, attributes);
134                }
135                PatchKind::ParameterType((parameter_type_ptr, attributes)) => {
136                    let parameter_ptr: &mut OwnedPtr<Parameter> = element.try_into().unwrap();
137                    let parameter_type_ref = &mut parameter_ptr.borrow_mut().data_type;
138                    parameter_type_ref.patch(parameter_type_ptr, attributes);
139                }
140                PatchKind::ExceptionSpecification(exception_patches) => {
141                    let operation_ptr: &mut OwnedPtr<Operation> = element.try_into().unwrap();
142                    let exception_specification = &mut operation_ptr.borrow_mut().exception_specification;
143                    // Ensure the number of patches is equal to the number of exceptions.
144                    debug_assert_eq!(exception_specification.len(), exception_patches.len());
145
146                    // Iterate through and patch each exception type.
147                    for (j, patch) in exception_patches.into_iter().enumerate() {
148                        let (exception_type_ptr, attributes) = patch;
149                        let exception_type_ref = &mut exception_specification[j];
150                        exception_type_ref.patch(exception_type_ptr, attributes);
151                    }
152                }
153                PatchKind::EnumUnderlyingType((enum_underlying_type_ptr, attributes)) => {
154                    let enum_ptr: &mut OwnedPtr<Enum> = element.try_into().unwrap();
155                    let enum_underlying_type_ref = enum_ptr.borrow_mut().underlying.as_mut().unwrap();
156                    enum_underlying_type_ref.patch(enum_underlying_type_ptr, attributes);
157                }
158                PatchKind::TypeAliasUnderlyingType((type_alias_underlying_type_ptr, attributes)) => {
159                    let type_alias_ptr: &mut OwnedPtr<TypeAlias> = element.try_into().unwrap();
160                    let type_alias_underlying_type_ref = &mut type_alias_ptr.borrow_mut().underlying;
161                    type_alias_underlying_type_ref.patch(type_alias_underlying_type_ptr, attributes);
162                }
163                PatchKind::ResultTypes(success_patch, failure_patch) => {
164                    let result_ptr: &mut OwnedPtr<ResultType> = element.try_into().unwrap();
165                    if let Some((success_type_ptr, attributes)) = success_patch {
166                        result_ptr.borrow_mut().success_type.patch(success_type_ptr, attributes);
167                    }
168                    if let Some((failure_type_ptr, attributes)) = failure_patch {
169                        result_ptr.borrow_mut().failure_type.patch(failure_type_ptr, attributes);
170                    }
171                }
172                PatchKind::SequenceType((element_type_ptr, attributes)) => {
173                    let sequence_ptr: &mut OwnedPtr<Sequence> = element.try_into().unwrap();
174                    let element_type_ref = &mut sequence_ptr.borrow_mut().element_type;
175                    element_type_ref.patch(element_type_ptr, attributes);
176                }
177                PatchKind::DictionaryTypes(key_patch, value_patch) => {
178                    let dictionary_ptr: &mut OwnedPtr<Dictionary> = element.try_into().unwrap();
179                    if let Some((key_type_ptr, attributes)) = key_patch {
180                        dictionary_ptr.borrow_mut().key_type.patch(key_type_ptr, attributes);
181                    }
182                    if let Some((value_type_ptr, attributes)) = value_patch {
183                        dictionary_ptr.borrow_mut().value_type.patch(value_type_ptr, attributes);
184                    }
185                }
186                PatchKind::None => {}
187            }
188        }
189    }
190
191    fn resolve_definition<'a, T>(&mut self, type_ref: &TypeRef<T>, ast: &'a Ast) -> Option<Patch<T>>
192    where
193        T: Element + ?Sized,
194        &'a Node: TryInto<WeakPtr<T>, Error = LookupError>,
195    {
196        // If the definition is already patched, we skip the function and return `None` immediately.
197        // Otherwise we retrieve the type string and try to resolve it in the ast.
198        let TypeRefDefinition::Unpatched(identifier) = &type_ref.definition else { return None };
199
200        // There are 3 steps to type resolution.
201        // First, lookup the type as a node in the AST.
202        // Second, handle the case where the type is an alias (by resolving down to its concrete underlying type).
203        // Third, get the type's pointer from its node and attempt to cast it to `T` (the required Slice type).
204        let lookup_result = ast
205            .find_node_with_scope(&identifier.value, type_ref.module_scope())
206            .and_then(|node| {
207                // We perform the deprecation check here instead of the validators since we need to check type-aliases
208                // which are resolved and erased after TypeRef patching is completed.
209                self.check_for_deprecated_type(type_ref, node);
210
211                if let Node::TypeAlias(type_alias) = node {
212                    self.resolve_type_alias(type_alias.borrow(), ast)
213                } else {
214                    try_into_patch(node, Vec::new())
215                }
216            });
217
218        // If we resolved a definition for the type reference, return it, otherwise report what went wrong.
219        match lookup_result {
220            Ok(definition) => Some(definition),
221            Err(err) => {
222                let mapped_error = match err {
223                    LookupError::DoesNotExist { identifier } => Error::DoesNotExist { identifier },
224                    LookupError::TypeMismatch {
225                        expected,
226                        actual,
227                        is_concrete,
228                    } => Error::TypeMismatch {
229                        expected,
230                        actual,
231                        is_concrete,
232                    },
233                };
234                Diagnostic::new(mapped_error)
235                    .set_span(identifier.span())
236                    .push_into(self.diagnostics);
237                None
238            }
239        }
240    }
241
242    fn check_for_deprecated_type<T: Element + ?Sized>(&mut self, type_ref: &TypeRef<T>, node: &Node) {
243        // Check if the type is an entity, and if so, check if it has the `deprecated` attribute.
244        // Only entities can be deprecated, so this check is sufficient.
245        if let Ok(entity) = <&dyn Entity>::try_from(node) {
246            if let Some(deprecated) = entity.find_attribute::<Deprecated>() {
247                // Compute the lint message. The `deprecated` attribute can have either 0 or 1 arguments, so we
248                // only check the first argument. If it's present, we attach it to the lint message.
249                let identifier = entity.identifier().to_owned();
250                let reason = deprecated.reason.clone();
251                Diagnostic::new(Lint::Deprecated { identifier, reason })
252                    .set_span(type_ref.span())
253                    .set_scope(type_ref.parser_scope())
254                    .add_note(
255                        format!("{} was deprecated here:", entity.identifier()),
256                        Some(entity.span()),
257                    )
258                    .push_into(self.diagnostics);
259            }
260        }
261    }
262
263    fn resolve_type_alias<'a, T>(&mut self, type_alias: &'a TypeAlias, ast: &'a Ast) -> Result<Patch<T>, LookupError>
264    where
265        T: Element + ?Sized,
266        &'a Node: TryInto<WeakPtr<T>, Error = LookupError>,
267    {
268        // TODO this function is run once per type-alias usage, so we will report multiple errors for cyclic aliases,
269        // once for each use. It would be better to only report a single error per cyclic alias.
270
271        // In case there's a chain of type aliases, we maintain a stack of all the ones we've seen.
272        // While resolving the chain, if we see a type alias already in this vector, a cycle is present.
273        let mut type_alias_chain = Vec::new();
274
275        let mut attributes: Vec<WeakPtr<Attribute>> = Vec::new();
276        let mut current_type_alias = type_alias;
277        loop {
278            let type_alias_id = current_type_alias.module_scoped_identifier();
279
280            // If we've already seen the current type alias, it must have a cycle in it's definition.
281            // So we return a `DoesNotExist` error, since there's no way to resolve the original type alias.
282            if type_alias_chain.contains(&type_alias_id) {
283                // If the current type alias is the one we started with, we know it's the cause of a cycle, so we report
284                // an error for it. This check makes sure we don't report errors for type aliases that aren't cyclic,
285                // but use another type-alias which is. In this case, the chain contains it, but it won't be the first.
286                if type_alias_chain.first() == Some(&type_alias_id) {
287                    Diagnostic::new(Error::SelfReferentialTypeAliasNeedsConcreteType {
288                        identifier: current_type_alias.module_scoped_identifier(),
289                    })
290                    .set_span(current_type_alias.span())
291                    .add_note("failed to resolve type due to a cycle in its definition", None)
292                    .add_note(
293                        format!("cycle: {} -> {}", type_alias_chain.join(" -> "), type_alias_id),
294                        None,
295                    )
296                    .push_into(self.diagnostics);
297                }
298                return Err(LookupError::DoesNotExist {
299                    identifier: current_type_alias.module_scoped_identifier(),
300                });
301            }
302
303            // If we reach this point, we haven't hit a cycle in the type aliases yet.
304
305            type_alias_chain.push(current_type_alias.module_scoped_identifier());
306            let underlying_type = &current_type_alias.underlying;
307            attributes.extend(underlying_type.attributes.clone());
308
309            // If we hit a type alias that is already patched, we immediately return its underlying type.
310            // Otherwise we retrieve the alias' type string and try to resolve it in the ast.
311            let identifier = match &underlying_type.definition {
312                TypeRefDefinition::Patched(ptr) => {
313                    // Lookup the node that is being aliased in the AST, and convert it into a patch.
314                    // TODO: when `T = dyn Type` we can skip this, and use `ptr.clone()` directly.
315                    let node = ast.as_slice().iter().find(|node| ptr == &<&dyn Element>::from(*node));
316                    return try_into_patch(node.unwrap(), attributes);
317                }
318                TypeRefDefinition::Unpatched(identifier) => identifier,
319            };
320
321            // We hit another unpatched alias; try to resolve its underlying type's identifier in the AST.
322            let node = ast.find_node_with_scope(&identifier.value, underlying_type.module_scope())?;
323            // If the resolved node is another type alias, push it onto the chain and loop again, otherwise return it.
324            if let Node::TypeAlias(next_type_alias) = node {
325                current_type_alias = next_type_alias.borrow();
326            } else {
327                return try_into_patch(node, attributes);
328            }
329        }
330    }
331}
332
333type Patch<T> = (WeakPtr<T>, Vec<WeakPtr<Attribute>>);
334
335#[derive(Default)]
336enum PatchKind {
337    #[default]
338    None,
339    BaseClass(Patch<Class>),
340    BaseException(Patch<Exception>),
341    BaseInterfaces(Vec<Patch<Interface>>),
342    FieldType(Patch<dyn Type>),
343    ParameterType(Patch<dyn Type>),
344    ExceptionSpecification(Vec<Patch<Exception>>),
345    EnumUnderlyingType(Patch<Primitive>),
346    TypeAliasUnderlyingType(Patch<dyn Type>),
347    ResultTypes(Option<Patch<dyn Type>>, Option<Patch<dyn Type>>),
348    SequenceType(Patch<dyn Type>),
349    DictionaryTypes(Option<Patch<dyn Type>>, Option<Patch<dyn Type>>),
350}
351
352fn try_into_patch<'a, T: ?Sized>(node: &'a Node, attributes: Vec<WeakPtr<Attribute>>) -> Result<Patch<T>, LookupError>
353where
354    &'a Node: TryInto<WeakPtr<T>, Error = LookupError>,
355{
356    node.try_into().map(|ptr| (ptr, attributes))
357}