odra_ir/module_item/
module_struct.rs

1use super::{ModuleEvent, ModuleEvents};
2use crate::{attrs::partition_attributes, utils::FieldsValidator};
3use anyhow::{Context, Result};
4use syn::{parse_quote, punctuated::Punctuated, Field, Fields, FieldsNamed, Token};
5
6use super::ModuleConfiguration;
7
8/// Odra module struct.
9///
10/// Wraps up [syn::ItemStruct].
11pub struct ModuleStruct {
12    pub is_instantiable: bool,
13    pub item: syn::ItemStruct,
14    pub events: ModuleEvents,
15    pub delegated_fields: Vec<DelegatedField>
16}
17
18pub struct DelegatedField {
19    pub field: syn::Field,
20    pub delegated_fields: Vec<String>
21}
22
23impl DelegatedField {
24    pub(crate) fn validate(&self, fields: &syn::Fields) -> Result<(), syn::Error> {
25        let fields = fields
26            .iter()
27            .filter_map(|f| f.ident.clone())
28            .map(|i| i.to_string())
29            .collect::<Vec<_>>();
30        let is_valid = self.delegated_fields.iter().find(|f| !fields.contains(f));
31        if let Some(invalid_ref) = is_valid {
32            let error_msg = format!("Using non-existing field {}", invalid_ref);
33            return Err(syn::Error::new_spanned(&self.field, error_msg));
34        }
35
36        Ok(())
37    }
38}
39
40impl TryFrom<syn::Field> for DelegatedField {
41    type Error = syn::Error;
42
43    fn try_from(value: syn::Field) -> std::result::Result<Self, Self::Error> {
44        let (odra_attrs, other_attrs) = partition_attributes(value.attrs.clone()).unwrap();
45
46        let delegated_fields = odra_attrs
47            .iter()
48            .flat_map(|attr| attr.using())
49            .collect::<Vec<_>>();
50
51        let field_ident = value.ident.clone().unwrap().to_string();
52
53        if delegated_fields.contains(&field_ident) {
54            return Err(syn::Error::new_spanned(&value, "Self-using is not allowed"));
55        }
56
57        Ok(Self {
58            field: Field {
59                attrs: other_attrs,
60                ..value
61            },
62            delegated_fields
63        })
64    }
65}
66
67impl ModuleStruct {
68    pub fn with_config(mut self, mut config: ModuleConfiguration) -> Result<Self, syn::Error> {
69        let submodules = self
70            .item
71            .fields
72            .iter()
73            .filter(|field| field.ident.is_some())
74            .filter_map(filter_primitives)
75            .map(|ident| ModuleEvent { ty: ident })
76            .collect::<Vec<_>>();
77
78        let mut mappings = self
79            .item
80            .fields
81            .iter()
82            .filter(|field| field.ident.is_some())
83            .filter_map(|f| match &f.ty {
84                syn::Type::Path(path) => extract_mapping_value_ident_from_path(path).ok(),
85                _ => None
86            })
87            .map(|ty| ModuleEvent { ty })
88            .collect::<Vec<_>>();
89        mappings.dedup();
90
91        config.events.submodules_events.extend(submodules);
92        config.events.mappings_events.extend(mappings);
93
94        self.events = config.events;
95
96        Ok(self)
97    }
98}
99
100impl TryFrom<syn::ItemStruct> for ModuleStruct {
101    type Error = syn::Error;
102
103    fn try_from(value: syn::ItemStruct) -> std::result::Result<Self, Self::Error> {
104        FieldsValidator::from(&value).result()?;
105
106        let (_, other_attrs) = partition_attributes(value.attrs).unwrap();
107
108        let named = value
109            .fields
110            .clone()
111            .into_iter()
112            .map(|field| {
113                let (_, other_attrs) = partition_attributes(field.attrs).unwrap();
114                Field {
115                    attrs: other_attrs,
116                    ..field
117                }
118            })
119            .collect::<Punctuated<Field, Token![,]>>();
120
121        let fields: Fields = Fields::Named(FieldsNamed {
122            brace_token: Default::default(),
123            named
124        });
125
126        let delegated_fields = value
127            .fields
128            .into_iter()
129            .map(TryInto::try_into)
130            .collect::<Result<Vec<DelegatedField>, syn::Error>>()?;
131
132        delegated_fields
133            .iter()
134            .try_for_each(|f| f.validate(&fields))?;
135
136        Ok(Self {
137            is_instantiable: true,
138            item: syn::ItemStruct {
139                attrs: other_attrs,
140                fields,
141                ..value
142            },
143            events: Default::default(),
144            delegated_fields
145        })
146    }
147}
148
149fn extract_mapping_value_ident_from_path(path: &syn::TypePath) -> Result<syn::Type> {
150    // Eg. odra::type::Mapping<String, Mapping<String, Mapping<u8, String>>>
151    let mut segment = path
152        .path
153        .segments
154        .last()
155        .cloned()
156        .context("At least one segment expected")?;
157    if segment.ident != "Mapping" {
158        return Err(anyhow::anyhow!(
159            "Mapping expected but found {}",
160            segment.ident
161        ));
162    }
163    let mut result: Option<syn::Type> = None;
164    loop {
165        let args = &segment.arguments;
166        if args.is_empty() {
167            break;
168        }
169        if let syn::PathArguments::AngleBracketed(args) = args {
170            match args
171                .args
172                .last()
173                .context("syn::GenericArgument expected but not found")?
174            {
175                syn::GenericArgument::Type(syn::Type::Path(path)) => {
176                    result = Some(parse_quote!(#path));
177                    let path = &path.path;
178                    segment = path
179                        .segments
180                        .last()
181                        .cloned()
182                        .context("At least one segment expected")?;
183                }
184                other => {
185                    return Err(anyhow::anyhow!(
186                        "syn::TypePath expected but found {:?}",
187                        other
188                    ))
189                }
190            }
191        } else {
192            return Err(anyhow::anyhow!(
193                "syn::AngleBracketedGenericArguments expected but found {:?}",
194                args
195            ));
196        }
197    }
198    Ok(result.unwrap())
199}
200
201fn filter_primitives(field: &syn::Field) -> Option<syn::Type> {
202    filter_ident(field, &["Variable", "Mapping", "List", "Sequence"])
203}
204
205fn filter_ident(field: &syn::Field, exclusions: &'static [&str]) -> Option<syn::Type> {
206    match &field.ty {
207        syn::Type::Path(path) => {
208            let path = &path.path;
209            match &path.segments.last() {
210                Some(segment) => {
211                    if exclusions.contains(&segment.ident.to_string().as_str()) {
212                        return None;
213                    }
214                    Some(field.ty.clone())
215                }
216                _ => None
217            }
218        }
219        _ => None
220    }
221}
222
#[cfg(test)]
mod test {
    use quote::ToTokens;

    use super::*;

    /// Parses `src` as a type path and runs the mapping value extractor on it.
    fn extract(src: &str) -> anyhow::Result<syn::Type> {
        let path = syn::parse_str::<syn::TypePath>(src).unwrap();
        extract_mapping_value_ident_from_path(&path)
    }

    #[test]
    fn nested_mapping_yields_innermost_value_type() {
        let value =
            extract("Mapping<String, Mapping<String, Mapping<u8, a::b::String>>>").unwrap();
        assert_eq!(value.into_token_stream().to_string(), "a :: b :: String");
    }

    #[test]
    fn non_mapping_path_is_rejected() {
        // Mapping expected but found String
        assert!(extract("String<i32, u8, u16>").is_err());
    }

    #[test]
    fn parenthesized_arguments_are_rejected() {
        // Invalid Mapping - parenthesized arguments instead of angle bracketed
        assert!(extract("Mapping<String, Mapping<String, Mapping(u8, String)>>").is_err());
    }

    #[test]
    fn function_type_argument_is_rejected() {
        // Invalid Mapping - function type instead of type
        assert!(extract("Mapping<String, Mapping<String, Mapping<fn(usize) -> bool>>>").is_err());
    }
}