calyx_opt/traversal/
construct.rs1use super::Visitor;
2use calyx_ir as ir;
3use calyx_utils::{CalyxResult, OutputFile};
4use itertools::Itertools;
5use linked_hash_map::LinkedHashMap;
6use std::iter;
7
8#[derive(Clone)]
9pub enum ParseVal {
11 Bool(bool),
13 Num(i64),
15 List(Vec<ParseVal>),
17 OutStream(OutputFile),
19}
20
21impl ParseVal {
22 pub fn bool(&self) -> bool {
23 let ParseVal::Bool(b) = self else {
24 panic!("Expected bool, got {self}");
25 };
26 *b
27 }
28
29 pub fn num(&self) -> i64 {
30 let ParseVal::Num(n) = self else {
31 panic!("Expected number, got {self}");
32 };
33 *n
34 }
35
36 pub fn pos_num(&self) -> Option<u64> {
37 let n = self.num();
38 if n < 0 {
39 None
40 } else {
41 Some(n as u64)
42 }
43 }
44
45 pub fn num_list(&self) -> Vec<i64> {
46 match self {
47 ParseVal::List(l) => {
48 l.iter().map(ParseVal::num).collect::<Vec<_>>()
49 }
50 _ => panic!("Expected list of numbers, got {self}"),
51 }
52 }
53
54 pub fn num_list_exact<const N: usize>(&self) -> [Option<i64>; N] {
57 let list = self.num_list();
58 let len = list.len();
59 if len > N {
60 panic!("Expected list of {N} numbers, got {len}");
61 }
62 list.into_iter()
63 .map(Some)
64 .chain(iter::repeat(None).take(N - len))
65 .collect::<Vec<_>>()
66 .try_into()
67 .unwrap()
68 }
69
70 pub fn not_null_outstream(&self) -> Option<OutputFile> {
72 match self {
73 ParseVal::OutStream(o) => {
74 if matches!(o, OutputFile::Null) {
75 None
76 } else {
77 Some(o.clone())
78 }
79 }
80 _ => panic!("Expected output stream, got {self}"),
81 }
82 }
83}
84impl std::fmt::Display for ParseVal {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 match self {
87 ParseVal::Bool(b) => write!(f, "{b}"),
88 ParseVal::Num(n) => write!(f, "{n}"),
89 ParseVal::List(l) => {
90 write!(f, "[")?;
91 for (i, e) in l.iter().enumerate() {
92 if i != 0 {
93 write!(f, ", ")?;
94 }
95 write!(f, "{e}")?;
96 }
97 write!(f, "]")
98 }
99 ParseVal::OutStream(o) => write!(f, "{}", o.to_string()),
100 }
101 }
102}
103
104pub struct PassOpt {
106 name: &'static str,
107 description: &'static str,
108 default: ParseVal,
109 parse: fn(&str) -> Option<ParseVal>,
110}
111
112impl PassOpt {
113 pub const fn new(
114 name: &'static str,
115 description: &'static str,
116 default: ParseVal,
117 parse: fn(&str) -> Option<ParseVal>,
118 ) -> Self {
119 Self {
120 name,
121 description,
122 default,
123 parse,
124 }
125 }
126
127 pub const fn name(&self) -> &'static str {
128 self.name
129 }
130
131 pub const fn description(&self) -> &'static str {
132 self.description
133 }
134
135 pub const fn default(&self) -> &ParseVal {
136 &self.default
137 }
138
139 fn parse(&self, s: &str) -> Option<ParseVal> {
140 (self.parse)(s)
141 }
142
143 fn parse_list(
146 s: &str,
147 parse: fn(&str) -> Option<ParseVal>,
148 ) -> Option<ParseVal> {
149 let mut res = Vec::new();
150 for e in s.split(',') {
151 res.push(parse(e)?);
152 }
153 Some(ParseVal::List(res))
154 }
155
156 pub fn parse_bool(s: &str) -> Option<ParseVal> {
157 match s {
158 "true" => Some(ParseVal::Bool(true)),
159 "false" => Some(ParseVal::Bool(false)),
160 _ => None,
161 }
162 }
163
164 pub fn parse_num(s: &str) -> Option<ParseVal> {
166 s.parse::<i64>().ok().map(ParseVal::Num)
167 }
168
169 pub fn parse_num_list(s: &str) -> Option<ParseVal> {
171 Self::parse_list(s, Self::parse_num)
172 }
173
174 pub fn parse_outstream(s: &str) -> Option<ParseVal> {
175 s.parse::<OutputFile>().ok().map(ParseVal::OutStream)
176 }
177}
178
179pub trait Named {
185 fn name() -> &'static str;
187 fn description() -> &'static str;
189 fn opts() -> Vec<PassOpt> {
192 vec![]
193 }
194}
195
196pub trait ConstructVisitor {
204 fn get_opts(ctx: &ir::Context) -> LinkedHashMap<&'static str, ParseVal>
205 where
206 Self: Named,
207 {
208 let opts = Self::opts();
209 let n = Self::name();
210 let mut values: LinkedHashMap<&'static str, ParseVal> = ctx
211 .extra_opts
212 .iter()
213 .filter_map(|opt| {
214 let mut splits = opt.split(':');
216 if let Some(pass) = splits.next() {
217 if pass == n {
218 let mut splits = splits.next()?.split('=');
219 let opt = splits.next()?.to_string();
220 let Some(opt) = opts.iter().find(|o| o.name == opt) else {
221 log::warn!("Ignoring unknown option for pass `{n}`: {opt}");
222 return None;
223 };
224 let val = if let Some(v) = splits.next() {
225 let Some(v) = opt.parse(v) else {
226 log::warn!(
227 "Ignoring invalid value for option `{n}:{}`: {v}",
228 opt.name(),
229 );
230 return None;
231 };
232 v
233 } else {
234 ParseVal::Bool(true)
235 };
236 return Some((opt.name(), val));
237 }
238 }
239 None
240 })
241 .collect();
242
243 if log::log_enabled!(log::Level::Debug) {
244 log::debug!(
245 "Extra options for {}: {}",
246 Self::name(),
247 values.iter().map(|(o, v)| format!("{o}->{v}")).join(", ")
248 );
249 }
250
251 for opt in opts {
253 if !values.contains_key(opt.name()) {
254 values.insert(opt.name(), opt.default.clone());
255 }
256 }
257
258 values
259 }
260
261 fn from(_ctx: &ir::Context) -> CalyxResult<Self>
263 where
264 Self: Sized;
265
266 fn clear_data(&mut self);
269}
270
271impl<T: Default + Sized + Visitor> ConstructVisitor for T {
273 fn from(_ctx: &ir::Context) -> CalyxResult<Self> {
274 Ok(T::default())
275 }
276
277 fn clear_data(&mut self) {
278 *self = T::default();
279 }
280}