calyx_opt/traversal/
construct.rs

1use super::Visitor;
2use calyx_ir as ir;
3use calyx_utils::{CalyxResult, OutputFile};
4use itertools::Itertools;
5use linked_hash_map::LinkedHashMap;
6use std::iter;
7
8#[derive(Clone)]
9/// The value returned from parsing an option.
10pub enum ParseVal {
11    /// A boolean option.
12    Bool(bool),
13    /// A number option.
14    Num(i64),
15    /// A list of values.
16    List(Vec<ParseVal>),
17    /// An output stream (stdout, stderr, file name)
18    OutStream(OutputFile),
19}
20
21impl ParseVal {
22    pub fn bool(&self) -> bool {
23        let ParseVal::Bool(b) = self else {
24            panic!("Expected bool, got {self}");
25        };
26        *b
27    }
28
29    pub fn num(&self) -> i64 {
30        let ParseVal::Num(n) = self else {
31            panic!("Expected number, got {self}");
32        };
33        *n
34    }
35
36    pub fn pos_num(&self) -> Option<u64> {
37        let n = self.num();
38        if n < 0 {
39            None
40        } else {
41            Some(n as u64)
42        }
43    }
44
45    pub fn num_list(&self) -> Vec<i64> {
46        match self {
47            ParseVal::List(l) => {
48                l.iter().map(ParseVal::num).collect::<Vec<_>>()
49            }
50            _ => panic!("Expected list of numbers, got {self}"),
51        }
52    }
53
54    /// Parse a list that should have exactly N elements. If elements are missing, then add None
55    /// to the end of the list.
56    pub fn num_list_exact<const N: usize>(&self) -> [Option<i64>; N] {
57        let list = self.num_list();
58        let len = list.len();
59        if len > N {
60            panic!("Expected list of {N} numbers, got {len}");
61        }
62        list.into_iter()
63            .map(Some)
64            .chain(iter::repeat(None).take(N - len))
65            .collect::<Vec<_>>()
66            .try_into()
67            .unwrap()
68    }
69
70    /// Returns an output stream if it is not the null stream
71    pub fn not_null_outstream(&self) -> Option<OutputFile> {
72        match self {
73            ParseVal::OutStream(o) => {
74                if matches!(o, OutputFile::Null) {
75                    None
76                } else {
77                    Some(o.clone())
78                }
79            }
80            _ => panic!("Expected output stream, got {self}"),
81        }
82    }
83}
84impl std::fmt::Display for ParseVal {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        match self {
87            ParseVal::Bool(b) => write!(f, "{b}"),
88            ParseVal::Num(n) => write!(f, "{n}"),
89            ParseVal::List(l) => {
90                write!(f, "[")?;
91                for (i, e) in l.iter().enumerate() {
92                    if i != 0 {
93                        write!(f, ", ")?;
94                    }
95                    write!(f, "{e}")?;
96                }
97                write!(f, "]")
98            }
99            ParseVal::OutStream(o) => write!(f, "{}", o.to_string()),
100        }
101    }
102}
103
104/// Option that can be passed to a pass.
105pub struct PassOpt {
106    name: &'static str,
107    description: &'static str,
108    default: ParseVal,
109    parse: fn(&str) -> Option<ParseVal>,
110}
111
112impl PassOpt {
113    pub const fn new(
114        name: &'static str,
115        description: &'static str,
116        default: ParseVal,
117        parse: fn(&str) -> Option<ParseVal>,
118    ) -> Self {
119        Self {
120            name,
121            description,
122            default,
123            parse,
124        }
125    }
126
127    pub const fn name(&self) -> &'static str {
128        self.name
129    }
130
131    pub const fn description(&self) -> &'static str {
132        self.description
133    }
134
135    pub const fn default(&self) -> &ParseVal {
136        &self.default
137    }
138
139    fn parse(&self, s: &str) -> Option<ParseVal> {
140        (self.parse)(s)
141    }
142
143    /// Parse of list using parser for the elements.
144    /// Returns `None` if any of the elements fail to parse.
145    fn parse_list(
146        s: &str,
147        parse: fn(&str) -> Option<ParseVal>,
148    ) -> Option<ParseVal> {
149        let mut res = Vec::new();
150        for e in s.split(',') {
151            res.push(parse(e)?);
152        }
153        Some(ParseVal::List(res))
154    }
155
156    pub fn parse_bool(s: &str) -> Option<ParseVal> {
157        match s {
158            "true" => Some(ParseVal::Bool(true)),
159            "false" => Some(ParseVal::Bool(false)),
160            _ => None,
161        }
162    }
163
164    /// Parse a number from a string.
165    pub fn parse_num(s: &str) -> Option<ParseVal> {
166        s.parse::<i64>().ok().map(ParseVal::Num)
167    }
168
169    /// Parse a list of numbers from a string.
170    pub fn parse_num_list(s: &str) -> Option<ParseVal> {
171        Self::parse_list(s, Self::parse_num)
172    }
173
174    pub fn parse_outstream(s: &str) -> Option<ParseVal> {
175        s.parse::<OutputFile>().ok().map(ParseVal::OutStream)
176    }
177}
178
179/// Trait that describes named things. Calling [`do_pass`](Visitor::do_pass) and [`do_pass_default`](Visitor::do_pass_default).
180/// require this to be implemented.
181///
182/// This has to be a separate trait from [`Visitor`] because these methods don't recieve `self` which
183/// means that it is impossible to create dynamic trait objects.
184pub trait Named {
185    /// The name of a pass. Is used for identifying passes.
186    fn name() -> &'static str;
187    /// A short description of the pass.
188    fn description() -> &'static str;
189    /// Set of options that can be passed to the pass.
190    /// The options contains a tuple of the option name and a description.
191    fn opts() -> Vec<PassOpt> {
192        vec![]
193    }
194}
195
196/// Trait defining method that can be used to construct a Visitor from an
197/// [ir::Context].
198/// This is useful when a pass needs to construct information using the context
199/// *before* visiting the components.
200///
201/// For passes that don't need to use the context, this trait can be automatically
202/// be derived from [Default].
203pub trait ConstructVisitor {
204    fn get_opts(ctx: &ir::Context) -> LinkedHashMap<&'static str, ParseVal>
205    where
206        Self: Named,
207    {
208        let opts = Self::opts();
209        let n = Self::name();
210        let mut values: LinkedHashMap<&'static str, ParseVal> = ctx
211            .extra_opts
212            .iter()
213            .filter_map(|opt| {
214                // The format is either -x pass:opt or -x pass:opt=val
215                let mut splits = opt.split(':');
216                if let Some(pass) = splits.next() {
217                    if pass == n {
218                        let mut splits = splits.next()?.split('=');
219                        let opt = splits.next()?.to_string();
220                        let Some(opt) = opts.iter().find(|o| o.name == opt) else {
221                            log::warn!("Ignoring unknown option for pass `{n}`: {opt}");
222                                return None;
223                        };
224                        let val = if let Some(v) = splits.next() {
225                            let Some(v) = opt.parse(v) else {
226                                log::warn!(
227                                    "Ignoring invalid value for option `{n}:{}`: {v}",
228                                    opt.name(),
229                                );
230                                return None;
231                            };
232                            v
233                        } else {
234                            ParseVal::Bool(true)
235                        };
236                        return Some((opt.name(), val));
237                    }
238                }
239                None
240            })
241            .collect();
242
243        if log::log_enabled!(log::Level::Debug) {
244            log::debug!(
245                "Extra options for {}: {}",
246                Self::name(),
247                values.iter().map(|(o, v)| format!("{o}->{v}")).join(", ")
248            );
249        }
250
251        // For all options that were not provided with values, fill in the defaults.
252        for opt in opts {
253            if !values.contains_key(opt.name()) {
254                values.insert(opt.name(), opt.default.clone());
255            }
256        }
257
258        values
259    }
260
261    /// Construct the visitor using information from the Context
262    fn from(_ctx: &ir::Context) -> CalyxResult<Self>
263    where
264        Self: Sized;
265
266    /// Clear the data stored in the visitor. Called before traversing the
267    /// next component by [ir::traversal::Visitor].
268    fn clear_data(&mut self);
269}
270
271/// Derive ConstructVisitor when [Default] is provided for a visitor.
272impl<T: Default + Sized + Visitor> ConstructVisitor for T {
273    fn from(_ctx: &ir::Context) -> CalyxResult<Self> {
274        Ok(T::default())
275    }
276
277    fn clear_data(&mut self) {
278        *self = T::default();
279    }
280}