1use std::collections::HashMap;
2
3use anyhow::Result;
4use pest::iterators::Pair;
5use pest::iterators::Pairs;
6use pest::pratt_parser::Assoc;
7use pest::pratt_parser::Op;
8use pest::pratt_parser::PrattParser;
9use pest::Parser;
10use pest::RuleType;
11use pest_derive::Parser;
12
13use self::AstNode::*;
14use crate::log;
15use log::error;
16
/// A single statement-level node of a parsed Ash program.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AstNode {
    /// Names listed in a function header (the children of `Rule::fn_header`).
    FnVar(Vec<String>),
    /// Variable assignment: target name, whether the declaration carried a leading
    /// keyword token (`is_let`), and the assigned expression.
    Stmt(String, bool, Expr),
    /// An expression evaluated without assigning its result (a bare function call).
    ExprUnassigned(Expr),
    /// Return statement with the returned expression.
    Rtrn(Expr),
    /// Static definition: name and value expression.
    StaticDef(String, Expr),
    /// Conditional: an `Expr::BoolOp` condition and the statements of its body.
    If(Expr, Vec<AstNode>),
    /// Loop: the iteration-count expression and the statements of its body.
    Loop(Expr, Vec<AstNode>),
    /// Declaration of an empty vector: name and its dimensions, parsed from literal indices.
    EmptyVecDef(String, Vec<usize>),
    /// Assignment into an indexed element: name, index expressions, assigned value.
    AssignVec(String, Vec<Expr>, Expr),
}

/// An expression node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Expr {
    /// Vector literal whose elements are themselves vector literals.
    VecVec(Vec<Expr>),
    /// Flat vector literal; each element is kept as its raw source text.
    VecLit(Vec<String>),
    /// Decimal literal, kept as its raw source text.
    Lit(String),
    /// Variable reference with zero or more index expressions.
    Val(String, Vec<Expr>),
    /// Function call: function name and argument expressions.
    FnCall(String, Vec<Expr>),
    /// Binary arithmetic expression.
    NumOp {
        lhs: Box<Expr>,
        op: NumOp,
        rhs: Box<Expr>,
    },
    /// Binary comparison expression (used as the condition of `AstNode::If`).
    BoolOp {
        lhs: Box<Expr>,
        bool_op: BoolOp,
        rhs: Box<Expr>,
    },
}

/// Comparison operators usable in a condition.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoolOp {
    Equal,
    NotEqual,
    GreaterThan,
    LessThan,
}

/// Arithmetic operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumOp {
    Add,
    Sub,
    /// Parsed from `Rule::inv`, at the same precedence as `Mul`; presumably
    /// division / multiply-by-inverse — TODO confirm against the evaluator.
    Inv,
    Mul,
}
76
77#[derive(Parser)]
78#[grammar = "grammar.pest"] pub struct AshPestParser;
80
81pub struct AshParser {
84 pub ast: Vec<AstNode>,
85 pub fn_names: HashMap<String, u64>,
86 pub entry_fn_name: String,
87}
88
89impl AshParser {
90 pub fn parse(source: &str, name: &str) -> Result<Self> {
93 let source = format!("{source}\n");
96 let mut out = Self {
97 ast: Vec::new(),
98 fn_names: HashMap::new(),
99 entry_fn_name: name.to_string(),
100 };
101
102 match AshPestParser::parse(Rule::program, &source) {
103 Ok(pairs) => {
104 let ast = out.build_ast_from_lines(pairs);
105 if let Err(e) = ast {
106 return error!(&format!("error building program ast: {e}"));
107 }
108 }
109 Err(e) => {
110 return Err(anyhow::anyhow!(log::parse_error(e, name)));
111 }
112 }
113 Ok(out)
114 }
115
116 fn mark_fn_call(&mut self, name: String) {
117 let count = self.fn_names.entry(name).or_insert(0);
118 *count += 1;
119 }
120
121 pub fn next_or_error<'a, T: RuleType>(pairs: &'a mut Pairs<T>) -> Result<Pair<'a, T>> {
122 if let Some(n) = pairs.next() {
123 Ok(n)
124 } else {
125 anyhow::bail!("Expected next token but found none")
126 }
127 }
128
129 fn build_ast_from_lines(&mut self, pairs: Pairs<Rule>) -> Result<()> {
130 for pair in pairs {
131 match pair.as_rule() {
132 Rule::fn_header => {
133 let pair = pair.into_inner();
138 let mut vars: Vec<String> = Vec::new();
139 for v in pair {
140 vars.push(v.as_str().to_string());
141 }
142 self.ast.push(FnVar(vars));
144 }
145 Rule::stmt => {
146 let mut pair = pair.into_inner();
147 let next = AshParser::next_or_error(&mut pair)?;
148 let ast = self.build_ast_from_pair(next)?;
149 self.ast.push(ast);
150 }
151 Rule::return_stmt => {
152 let mut pair = pair.into_inner();
153 let next = AshParser::next_or_error(&mut pair)?;
154 let expr = self.build_expr_from_pair(next)?;
155 self.ast.push(Rtrn(expr));
156 }
157 Rule::EOI => {}
158 _ => anyhow::bail!("unexpected line pair rule: {:?}", pair.as_rule()),
159 }
160 }
161 Ok(())
162 }
163
164 fn build_ast_from_pair(&mut self, pair: pest::iterators::Pair<Rule>) -> Result<AstNode> {
165 match pair.as_rule() {
166 Rule::var_index_assign => {
167 let mut pair = pair.into_inner();
168 let next = AshParser::next_or_error(&mut pair)?;
169 let v = self.build_expr_from_pair(next)?;
170 let name;
171 let indices;
172 match v {
173 Expr::Val(n, i) => {
174 name = n;
175 indices = i;
176 }
177 _ => {
178 anyhow::bail!("unexpected expr in var_index_assign: {:?}, expected Val", v)
179 }
180 }
181 let next = AshParser::next_or_error(&mut pair)?;
182 let expr = self.build_expr_from_pair(next)?;
183 Ok(AssignVec(name, indices, expr))
184 }
185 Rule::var_vec_def => {
186 let mut pair = pair.into_inner();
187 let _ = AshParser::next_or_error(&mut pair)?;
188 let next = AshParser::next_or_error(&mut pair)?;
189 let expr = self.build_expr_from_pair(next)?;
190 match expr {
191 Expr::Val(name, indices) => {
192 let mut indices_static: Vec<usize> = Vec::new();
193 for i in indices {
194 match i {
195 Expr::Lit(v) => {
196 indices_static.push(v.parse::<usize>().unwrap());
197 }
198 _ => {
199 anyhow::bail!(
200 "unexpected expr in var_vec_def: {:?}, expected Lit",
201 i
202 )
203 }
204 }
205 }
206 Ok(EmptyVecDef(name, indices_static))
207 }
208 _ => {
209 anyhow::bail!("unexpected expr in var_vec_def: {:?}, expected Val", expr)
210 }
211 }
212 }
213 Rule::loop_stmt => {
214 let mut pair = pair.into_inner();
215 let iter_count = AshParser::next_or_error(&mut pair)?;
216 let iter_count_expr = self.build_expr_from_pair(iter_count)?;
217 let block = AshParser::next_or_error(&mut pair)?;
218 let block_inner = block.into_inner();
219 let block_ast = block_inner
220 .map(|v| match v.as_rule() {
221 Rule::stmt => {
222 let mut pair = v.into_inner();
223 let next = AshParser::next_or_error(&mut pair)?;
224 self.build_ast_from_pair(next)
225 }
226 _ => Err(anyhow::anyhow!("invalid expression in block")),
227 })
228 .collect::<Result<Vec<AstNode>>>()?;
229 Ok(Loop(iter_count_expr, block_ast))
230 }
231 Rule::function_call => Ok(ExprUnassigned(self.build_expr_from_pair(pair)?)),
232 Rule::var_def => {
233 let mut pair = pair.into_inner();
235 let next = AshParser::next_or_error(&mut pair)?;
236 let mut varpair = next.into_inner();
237 let name;
238 let is_let;
239 if varpair.len() == 2 {
240 AshParser::next_or_error(&mut varpair)?;
242 name = AshParser::next_or_error(&mut varpair)?.as_str().to_string();
243 is_let = true;
244 } else if varpair.len() == 1 {
245 name = AshParser::next_or_error(&mut varpair)?.as_str().to_string();
247 is_let = false;
248 } else {
249 return Err(anyhow::anyhow!("invalid varpait"));
250 }
251
252 let n = AshParser::next_or_error(&mut pair)?;
253 Ok(Stmt(name, is_let, self.build_expr_from_pair(n)?))
254 }
255 Rule::static_def => {
256 let mut pair = pair.into_inner();
257 let name = AshParser::next_or_error(&mut pair)?.as_str().to_string();
258 let expr = AshParser::next_or_error(&mut pair)?;
259 Ok(StaticDef(
260 name.as_str().to_string(),
261 self.build_expr_from_pair(expr)?,
262 ))
263 }
264 Rule::if_stmt => {
265 let mut pair = pair.into_inner();
266 let bool_expr = AshParser::next_or_error(&mut pair)?;
267 let mut bool_expr_pair = bool_expr.into_inner();
268 let expr1 =
269 self.build_expr_from_pair(AshParser::next_or_error(&mut bool_expr_pair)?)?;
270 let bool_op = match AshParser::next_or_error(&mut bool_expr_pair)?.as_rule() {
271 Rule::equal => BoolOp::Equal,
272 Rule::not_equal => BoolOp::NotEqual,
273 Rule::gt => BoolOp::GreaterThan,
274 Rule::lt => BoolOp::LessThan,
275 _ => anyhow::bail!("invalid bool op"),
276 };
277 let expr2 =
278 self.build_expr_from_pair(AshParser::next_or_error(&mut bool_expr_pair)?)?;
279 let block = AshParser::next_or_error(&mut pair)?;
280 let block_inner = block.into_inner();
281 let block_ast = block_inner
282 .map(|v| match v.as_rule() {
283 Rule::stmt => {
284 let mut pair = v.into_inner();
285 let next = AshParser::next_or_error(&mut pair)?;
286 self.build_ast_from_pair(next)
287 }
288 _ => anyhow::bail!("invalid expression in block"),
289 })
290 .collect::<Result<Vec<AstNode>>>()?;
291 Ok(If(
292 Expr::BoolOp {
293 lhs: Box::new(expr1),
294 bool_op,
295 rhs: Box::new(expr2),
296 },
297 block_ast,
298 ))
299 }
300 unknown_expr => anyhow::bail!(
301 "Unable to build ast node, unexpected expression: {:?}",
302 unknown_expr
303 ),
304 }
305 }
306
307 fn build_expr_from_pair(&mut self, pair: pest::iterators::Pair<Rule>) -> Result<Expr> {
308 match pair.as_rule() {
309 Rule::var_indexed => {
310 let mut pair = pair.into_inner();
311 let name = AshParser::next_or_error(&mut pair)?.as_str().to_string();
312 let mut indices: Vec<Expr> = Vec::new();
313 for v in pair {
314 indices.push(self.build_expr_from_pair(v)?);
315 }
316 Ok(Expr::Val(name, indices))
317 }
318 Rule::literal_dec => Ok(Expr::Lit(pair.as_str().to_string())),
319 Rule::vec => {
320 let mut pair = pair.into_inner();
321 let next = AshParser::next_or_error(&mut pair)?;
322 if next.as_rule() == Rule::vec {
323 let mut out: Vec<Expr> = Vec::new();
324 out.push(self.build_expr_from_pair(next.clone())?);
325 for next in pair {
326 out.push(self.build_expr_from_pair(next.clone())?);
327 }
329 Ok(Expr::VecVec(out))
330 } else {
331 let mut out: Vec<String> = Vec::new();
332 out.push(next.as_str().to_string());
333 for next in pair {
334 out.push(next.as_str().to_string());
335 }
336 Ok(Expr::VecLit(out))
337 }
338 }
339 Rule::function_call => {
340 let mut pair = pair.into_inner();
341 let next = AshParser::next_or_error(&mut pair)?;
342 let fn_name = next.as_str().to_string();
343 let arg_pair = AshParser::next_or_error(&mut pair)?.into_inner();
344 let mut vars: Vec<Expr> = Vec::new();
345 for v in arg_pair {
346 vars.push(self.build_expr_from_pair(v)?);
347 }
348 self.mark_fn_call(fn_name.clone());
349 Ok(Expr::FnCall(fn_name, vars))
350 }
351 Rule::atom => {
352 let mut pair = pair.into_inner();
353 let n = AshParser::next_or_error(&mut pair)?;
354 match n.as_rule() {
355 Rule::function_call => Ok(self.build_expr_from_pair(n)?),
356 Rule::varname => Ok(Expr::Val(n.as_str().to_string(), vec![])),
357 Rule::var_indexed => {
358 let mut pair = n.into_inner();
359 let name = AshParser::next_or_error(&mut pair)?.as_str().to_string();
360 let mut indices: Vec<Expr> = Vec::new();
361 for v in pair {
362 indices.push(self.build_expr_from_pair(v)?);
363 }
364 Ok(Expr::Val(name, indices))
365 }
366 Rule::literal_dec => Ok(Expr::Lit(n.as_str().to_string())),
367 _ => anyhow::bail!("invalid atom"),
368 }
369 }
370 Rule::expr => {
371 let mut pair = pair.into_inner();
372 if pair.len() == 1 {
373 return self.build_expr_from_pair(AshParser::next_or_error(&mut pair)?);
374 }
375 let pratt = PrattParser::new()
376 .op(Op::infix(Rule::add, Assoc::Left) | Op::infix(Rule::sub, Assoc::Left))
377 .op(Op::infix(Rule::mul, Assoc::Left) | Op::infix(Rule::inv, Assoc::Left));
378 pratt
379 .map_primary(|primary| match primary.as_rule() {
380 Rule::atom => self.build_expr_from_pair(primary),
381 Rule::expr => self.build_expr_from_pair(primary),
382 _ => Err(anyhow::anyhow!("unexpected rule in pratt parser")),
383 })
384 .map_infix(|lhs, op, rhs| match op.as_rule() {
385 Rule::add => Ok(Expr::NumOp {
386 lhs: Box::new(lhs?),
387 op: NumOp::Add,
388 rhs: Box::new(rhs?),
389 }),
390 Rule::sub => Ok(Expr::NumOp {
391 lhs: Box::new(lhs?),
392 op: NumOp::Sub,
393 rhs: Box::new(rhs?),
394 }),
395 Rule::mul => Ok(Expr::NumOp {
396 lhs: Box::new(lhs?),
397 op: NumOp::Mul,
398 rhs: Box::new(rhs?),
399 }),
400 Rule::inv => Ok(Expr::NumOp {
401 lhs: Box::new(lhs?),
402 op: NumOp::Inv,
403 rhs: Box::new(rhs?),
404 }),
405 _ => unreachable!(),
406 })
407 .parse(pair)
408 }
409 unknown_expr => anyhow::bail!(
410 "Unable to build expression, unexpected rule: {:?}",
411 unknown_expr
412 ),
413 }
414 }
415}