webnn_graph/
parser.rs

1use crate::ast::{new_graph_json, ConstDecl, ConstInit, DataType, GraphJson, Node, OperandDesc};
2use pest::iterators::Pair;
3use pest::Parser;
4use pest_derive::Parser;
5use serde_json::{Map, Value};
6use std::collections::BTreeMap;
7use thiserror::Error;
8
9#[derive(Parser)]
10#[grammar = "wg.pest"]
11struct WGParser;
12
13#[derive(Debug, Error)]
14pub enum ParseError {
15    #[error("parse error: {0}")]
16    Pest(Box<pest::error::Error<Rule>>),
17    #[error("invalid dtype: {0}")]
18    BadDType(String),
19    #[error("internal error: {0}")]
20    Internal(String),
21}
22
23impl From<pest::error::Error<Rule>> for ParseError {
24    fn from(err: pest::error::Error<Rule>) -> Self {
25        ParseError::Pest(Box::new(err))
26    }
27}
28
29type ParsedExpr = (String, Vec<String>, Map<String, Value>, Option<Vec<String>>);
30
31pub fn parse_wg_text(input: &str) -> Result<GraphJson, ParseError> {
32    let mut pairs = WGParser::parse(Rule::file, input)?;
33    let file = pairs
34        .next()
35        .ok_or_else(|| ParseError::Internal("missing file".into()))?;
36
37    let mut g = new_graph_json();
38    let mut nodes: Vec<Node> = Vec::new();
39
40    for p in file.into_inner() {
41        match p.as_rule() {
42            Rule::inputs_block => parse_inputs_block(p, &mut g.inputs)?,
43            Rule::consts_block => parse_consts_block(p, &mut g.consts)?,
44            Rule::nodes_block => parse_nodes_block(p, &mut nodes)?,
45            Rule::outputs_block => parse_outputs_block(p, &mut g.outputs)?,
46            _ => {}
47        }
48    }
49
50    g.nodes = nodes;
51    Ok(g)
52}
53
54fn parse_inputs_block(
55    p: Pair<Rule>,
56    out: &mut BTreeMap<String, OperandDesc>,
57) -> Result<(), ParseError> {
58    for inner in p.into_inner() {
59        if inner.as_rule() == Rule::input_decl {
60            let mut it = inner.into_inner();
61            let name = it.next().unwrap().as_str().to_string();
62            let (dt, shape) = parse_ty(it.next().unwrap())?;
63            out.insert(
64                name,
65                OperandDesc {
66                    data_type: dt,
67                    shape,
68                },
69            );
70        }
71    }
72    Ok(())
73}
74
75fn parse_consts_block(
76    p: Pair<Rule>,
77    out: &mut BTreeMap<String, ConstDecl>,
78) -> Result<(), ParseError> {
79    for inner in p.into_inner() {
80        if inner.as_rule() == Rule::const_decl {
81            let mut it = inner.into_inner();
82            let name = it.next().unwrap().as_str().to_string();
83            let (dt, shape) = parse_ty(it.next().unwrap())?;
84
85            let mut init: Option<ConstInit> = None;
86            for ann in it {
87                if ann.as_rule() == Rule::const_annot {
88                    let text = ann.as_str();
89                    if text.starts_with("@weights") {
90                        let s = ann
91                            .into_inner()
92                            .find(|p| p.as_rule() == Rule::string)
93                            .map(|p| unquote(p.as_str()))
94                            .unwrap_or_else(|| name.clone());
95                        init = Some(ConstInit::Weights { r#ref: s });
96                    } else if text.starts_with("@scalar") {
97                        let n = ann
98                            .into_inner()
99                            .find(|p| p.as_rule() == Rule::number)
100                            .map(|p| parse_number_value(p.as_str()))
101                            .unwrap_or(Value::Null);
102                        init = Some(ConstInit::Scalar { value: n });
103                    }
104                }
105            }
106
107            let init = init.unwrap_or(ConstInit::Weights {
108                r#ref: name.clone(),
109            });
110            out.insert(
111                name,
112                ConstDecl {
113                    data_type: dt,
114                    shape,
115                    init,
116                },
117            );
118        }
119    }
120    Ok(())
121}
122
123fn parse_nodes_block(p: Pair<Rule>, out: &mut Vec<Node>) -> Result<(), ParseError> {
124    for inner in p.into_inner() {
125        if inner.as_rule() != Rule::stmt {
126            continue;
127        }
128        let stmt = inner.into_inner().next().unwrap();
129        match stmt.as_rule() {
130            Rule::assign => out.push(parse_assign(stmt)?),
131            Rule::multi_assign => out.push(parse_multi_assign(stmt)?),
132            _ => {}
133        }
134    }
135    Ok(())
136}
137
138fn parse_assign(p: Pair<Rule>) -> Result<Node, ParseError> {
139    let mut it = p.into_inner();
140    let id = it.next().unwrap().as_str().to_string();
141    let (op, inputs, options, outputs) = parse_expr(it.next().unwrap())?;
142    Ok(Node {
143        id,
144        op,
145        inputs,
146        options,
147        outputs,
148    })
149}
150
151fn parse_multi_assign(p: Pair<Rule>) -> Result<Node, ParseError> {
152    let mut it = p.into_inner();
153    let mut outs: Vec<String> = Vec::new();
154
155    // first items are idents inside [...]
156    // We receive them as a flat sequence of ident tokens due to grammar.
157    // Collect until we hit expr.
158    while let Some(next) = it.peek() {
159        if next.as_rule() == Rule::expr {
160            break;
161        }
162        let t = it.next().unwrap();
163        if t.as_rule() == Rule::ident {
164            outs.push(t.as_str().to_string());
165        }
166    }
167
168    let expr = it
169        .next()
170        .ok_or_else(|| ParseError::Internal("missing expr in multi_assign".into()))?;
171    let (op, inputs, options, _outputs_unused) = parse_expr(expr)?;
172    // Use the first output name as the node id for uniqueness; keep real outputs in Node.outputs.
173    let id = outs.first().cloned().unwrap_or_else(|| "tmp".into());
174    Ok(Node {
175        id,
176        op,
177        inputs,
178        options,
179        outputs: Some(outs),
180    })
181}
182
183fn parse_expr(p: Pair<Rule>) -> Result<ParsedExpr, ParseError> {
184    match p.as_rule() {
185        Rule::expr => parse_expr(p.into_inner().next().unwrap()),
186        Rule::call => parse_call(p),
187        Rule::ident => Ok((
188            String::new(),
189            vec![p.as_str().to_string()],
190            Map::new(),
191            None,
192        )),
193        _ => Err(ParseError::Internal(format!(
194            "unexpected expr rule: {:?}",
195            p.as_rule()
196        ))),
197    }
198}
199
200fn parse_call(p: Pair<Rule>) -> Result<ParsedExpr, ParseError> {
201    let mut it = p.into_inner();
202    let op = it.next().unwrap().as_str().to_string();
203    let mut inputs: Vec<String> = Vec::new();
204    let mut options: Map<String, Value> = Map::new();
205
206    if let Some(args) = it.next() {
207        if args.as_rule() == Rule::args {
208            for arg in args.into_inner() {
209                if arg.as_rule() != Rule::arg {
210                    continue;
211                }
212                let mut a = arg.into_inner().peekable();
213
214                // Check if this is a named argument: ident '=' value
215                let first = match a.next() {
216                    Some(f) => f,
217                    None => continue,
218                };
219
220                if first.as_rule() == Rule::ident
221                    && a.peek().is_some()
222                    && a.peek().unwrap().as_rule() == Rule::value
223                {
224                    // Named argument
225                    let key = first.as_str().to_string();
226                    let val = parse_value(a.next().unwrap())?;
227                    options.insert(key, val);
228                } else {
229                    // Positional argument
230                    match first.as_rule() {
231                        Rule::value => {
232                            let v = parse_value(first)?;
233                            if let Value::String(s) = v {
234                                inputs.push(s);
235                            } else if let Some(sym) = v.as_str() {
236                                inputs.push(sym.to_string());
237                            }
238                        }
239                        Rule::ident => inputs.push(first.as_str().to_string()),
240                        _ => {}
241                    }
242                }
243            }
244        }
245    }
246
247    Ok((op, inputs, options, None))
248}
249
250fn parse_outputs_block(
251    p: Pair<Rule>,
252    out: &mut BTreeMap<String, String>,
253) -> Result<(), ParseError> {
254    // WG: outputs { probs }  OR outputs { a,b; }
255    // We'll map each output name to itself.
256    for inner in p.into_inner() {
257        if inner.as_rule() == Rule::output_item {
258            for item in inner.into_inner() {
259                if item.as_rule() == Rule::ident {
260                    let name = item.as_str().to_string();
261                    out.insert(name.clone(), name);
262                }
263            }
264        }
265    }
266    Ok(())
267}
268
269fn parse_ty(p: Pair<Rule>) -> Result<(DataType, Vec<u32>), ParseError> {
270    let mut it = p.into_inner();
271    let dt_s = it.next().unwrap().as_str();
272    let dt = DataType::from_wg(dt_s).ok_or_else(|| ParseError::BadDType(dt_s.to_string()))?;
273    let shape = parse_shape(it.next().unwrap())?;
274    Ok((dt, shape))
275}
276
277fn parse_shape(p: Pair<Rule>) -> Result<Vec<u32>, ParseError> {
278    let mut shape = Vec::new();
279    for inner in p.into_inner() {
280        if inner.as_rule() == Rule::int {
281            let v: u32 = inner
282                .as_str()
283                .parse()
284                .map_err(|_| ParseError::Internal("bad int".into()))?;
285            shape.push(v);
286        }
287    }
288    Ok(shape)
289}
290
291fn parse_value(p: Pair<Rule>) -> Result<Value, ParseError> {
292    match p.as_rule() {
293        Rule::value => parse_value(p.into_inner().next().unwrap()),
294        Rule::literal => parse_value(p.into_inner().next().unwrap()),
295        Rule::string => Ok(Value::String(unquote(p.as_str()))),
296        Rule::number => Ok(parse_number_value(p.as_str())),
297        Rule::boolean => Ok(Value::Bool(p.as_str() == "true")),
298        Rule::null => Ok(Value::Null),
299        Rule::array => {
300            let mut arr = Vec::new();
301            for inner in p.into_inner() {
302                if inner.as_rule() == Rule::value {
303                    arr.push(parse_value(inner)?);
304                }
305            }
306            Ok(Value::Array(arr))
307        }
308        Rule::ident => Ok(Value::String(p.as_str().to_string())),
309        _ => Err(ParseError::Internal(format!(
310            "unexpected value rule: {:?}",
311            p.as_rule()
312        ))),
313    }
314}
315
316fn parse_number_value(s: &str) -> Value {
317    // Prefer i64 when exact, otherwise f64.
318    if !s.contains('.') && !s.contains('e') && !s.contains('E') {
319        if let Ok(i) = s.parse::<i64>() {
320            return Value::Number(i.into());
321        }
322    }
323    Value::Number(serde_json::Number::from_f64(s.parse::<f64>().unwrap_or(0.0)).unwrap())
324}
325
326fn unquote(s: &str) -> String {
327    let mut t = s.to_string();
328    if t.starts_with('"') && t.ends_with('"') && t.len() >= 2 {
329        t.remove(0);
330        t.pop();
331    }
332    t.replace("\\\"", "\"").replace("\\\\", "\\")
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src`, panicking if the parser rejects it.
    fn parse_ok(src: &str) -> GraphJson {
        parse_wg_text(src).unwrap()
    }

    #[test]
    fn test_parse_simple_graph() {
        let input = r#"
webnn_graph "test" v1 {
  inputs {
    x: f32[1, 10];
  }
  consts {
    W: f32[10, 5] @weights("W");
  }
  nodes {
    result = matmul(x, W);
  }
  outputs { result; }
}
"#;
        let graph = parse_ok(input);
        assert_eq!(graph.format, "webnn-graph-json");
        assert_eq!(graph.version, 1);
        let section_sizes = [
            graph.inputs.len(),
            graph.consts.len(),
            graph.nodes.len(),
            graph.outputs.len(),
        ];
        assert_eq!(section_sizes, [1, 1, 1, 1]);
    }

    #[test]
    fn test_parse_inputs() {
        let input = r#"
webnn_graph "test" v1 {
  inputs {
    x: f32[1, 10];
    y: i32[5];
  }
  nodes {}
  outputs { x; }
}
"#;
        let graph = parse_ok(input);
        assert_eq!(graph.inputs.len(), 2);

        let expected = [
            ("x", DataType::Float32, vec![1u32, 10]),
            ("y", DataType::Int32, vec![5u32]),
        ];
        for (name, data_type, shape) in expected {
            assert!(graph.inputs.contains_key(name));
            let desc = &graph.inputs[name];
            assert_eq!(desc.data_type, data_type);
            assert_eq!(desc.shape, shape);
        }
    }

    #[test]
    fn test_parse_consts_with_weights() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[1]; }
  consts {
    W: f32[10, 5] @weights("W");
    b: f32[5] @weights("bias");
  }
  nodes {}
  outputs { x; }
}
"#;
        let graph = parse_ok(input);
        assert_eq!(graph.consts.len(), 2);

        let w = &graph.consts["W"];
        assert_eq!(w.data_type, DataType::Float32);
        assert_eq!(w.shape, vec![10, 5]);

        for (name, expected_ref) in [("W", "W"), ("b", "bias")] {
            let init = &graph.consts[name].init;
            assert!(matches!(init, ConstInit::Weights { r#ref } if r#ref == expected_ref));
        }
    }

    #[test]
    fn test_parse_consts_with_scalar() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[1]; }
  consts {
    scale: f32[1] @scalar(2.5);
  }
  nodes {}
  outputs { x; }
}
"#;
        let graph = parse_ok(input);
        if let ConstInit::Scalar { value } = &graph.consts["scale"].init {
            assert_eq!(value.as_f64().unwrap(), 2.5);
        } else {
            panic!("Expected scalar init");
        }
    }

    #[test]
    fn test_parse_nodes() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[1, 2048]; }
  consts { W: f32[2048, 1000] @weights("W"); }
  nodes {
    result = matmul(x, W);
  }
  outputs { result; }
}
"#;
        let graph = parse_ok(input);
        assert_eq!(graph.nodes.len(), 1);

        let node = &graph.nodes[0];
        assert_eq!(node.id, "result");
        assert_eq!(node.op, "matmul");
        assert_eq!(node.inputs, vec!["x", "W"]);
        assert!(node.options.is_empty());
    }

    #[test]
    fn test_parse_nodes_with_options() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[1, 10]; }
  nodes {
    result = softmax(x, axis=1);
  }
  outputs { result; }
}
"#;
        let graph = parse_ok(input);
        let node = &graph.nodes[0];
        assert_eq!(node.op, "softmax");
        assert_eq!(node.inputs, vec!["x"]);
        let axis = node.options.get("axis").and_then(Value::as_i64);
        assert_eq!(axis, Some(1));
    }

    #[test]
    fn test_parse_multi_assign() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[10]; }
  nodes {
    [a, b] = split(x);
  }
  outputs { a; }
}
"#;
        let graph = parse_ok(input);
        let node = &graph.nodes[0];
        assert_eq!(node.id, "a");
        assert_eq!(node.op, "split");
        let expected_outputs: Vec<String> = vec!["a".into(), "b".into()];
        assert_eq!(node.outputs, Some(expected_outputs));
    }

    #[test]
    fn test_parse_outputs() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[1]; }
  nodes {
    a = relu(x);
    b = sigmoid(x);
  }
  outputs { a; b; }
}
"#;
        let graph = parse_ok(input);
        assert_eq!(graph.outputs.len(), 2);
        for name in ["a", "b"] {
            assert_eq!(graph.outputs.get(name).unwrap(), name);
        }
    }

    #[test]
    fn test_parse_invalid_dtype() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: float32[1]; }
  nodes {}
  outputs { x; }
}
"#;
        // "float32" does not match the grammar's dtype rule, so pest itself rejects it.
        match parse_wg_text(input) {
            Err(ParseError::Pest(_)) => {}
            Err(e) => panic!("Expected Pest parse error, got: {:?}", e),
            Ok(_) => panic!("Expected error but parsing succeeded"),
        }
    }

    #[test]
    fn test_unquote() {
        let cases = [
            (r#""hello""#, "hello"),
            (r#""hello\"world""#, "hello\"world"),
            (r#""path\\to\\file""#, "path\\to\\file"),
            ("no_quotes", "no_quotes"),
        ];
        for (raw, expected) in cases {
            assert_eq!(unquote(raw), expected);
        }
    }

    #[test]
    fn test_parse_number_value() {
        assert_eq!(parse_number_value("42").as_i64().unwrap(), 42);
        assert_eq!(parse_number_value("3.12").as_f64().unwrap(), 3.12);
        assert_eq!(parse_number_value("1e-3").as_f64().unwrap(), 0.001);
    }
}