Skip to main content

webnn_graph/
parser.rs

1use crate::ast::{
2    new_graph_json, ConstDecl, ConstInit, DataType, Dimension, DynamicDimension, GraphJson, Node,
3    OperandDesc,
4};
5use pest::iterators::Pair;
6use pest::Parser;
7use pest_derive::Parser;
8use serde_json::{Map, Value};
9use std::collections::BTreeMap;
10use thiserror::Error;
11
/// Pest-generated parser for the WG text format.
///
/// The grammar lives in `wg.pest`; `#[derive(Parser)]` generates the `Rule`
/// enum that every parsing function in this file matches on.
#[derive(Parser)]
#[grammar = "wg.pest"]
struct WGParser;
15
16#[derive(Debug, Error)]
17pub enum ParseError {
18    #[error("parse error: {0}")]
19    Pest(Box<pest::error::Error<Rule>>),
20    #[error("invalid dtype: {0}")]
21    BadDType(String),
22    #[error("internal error: {0}")]
23    Internal(String),
24    #[error("constant shapes must be static")]
25    DynamicConstShape,
26}
27
28impl From<pest::error::Error<Rule>> for ParseError {
29    fn from(err: pest::error::Error<Rule>) -> Self {
30        ParseError::Pest(Box::new(err))
31    }
32}
33
/// Result of parsing a right-hand-side expression, as `(op, inputs, options, outputs)`.
///
/// - `op` is the called operation's name, or empty when the expression is a bare identifier.
/// - `inputs` holds the operand names passed positionally.
/// - `options` holds the `key=value` arguments.
/// - `outputs` is always `None` as produced by `parse_expr`/`parse_call` here.
type ParsedExpr = (String, Vec<String>, Map<String, Value>, Option<Vec<String>>);
35
36pub fn parse_wg_text(input: &str) -> Result<GraphJson, ParseError> {
37    let mut pairs = WGParser::parse(Rule::file, input)?;
38    let file = pairs
39        .next()
40        .ok_or_else(|| ParseError::Internal("missing file".into()))?;
41
42    let mut g = new_graph_json();
43    let mut nodes: Vec<Node> = Vec::new();
44
45    for p in file.into_inner() {
46        match p.as_rule() {
47            Rule::header => {
48                // Extract graph name, version, quantized flag from header
49                for inner in p.into_inner() {
50                    match inner.as_rule() {
51                        Rule::string => g.name = Some(unquote(inner.as_str())),
52                        Rule::int => {
53                            let version: u32 = inner
54                                .as_str()
55                                .parse()
56                                .map_err(|e| ParseError::Internal(format!("bad version: {}", e)))?;
57                            g.version = version;
58                        }
59                        Rule::quantized => g.quantized = true,
60                        _ => {}
61                    }
62                }
63            }
64            Rule::inputs_block => parse_inputs_block(p, &mut g.inputs)?,
65            Rule::consts_block => parse_consts_block(p, &mut g.consts)?,
66            Rule::nodes_block => parse_nodes_block(p, &mut nodes)?,
67            Rule::outputs_block => parse_outputs_block(p, &mut g.outputs)?,
68            _ => {}
69        }
70    }
71
72    g.nodes = nodes;
73    Ok(g)
74}
75
76fn parse_inputs_block(
77    p: Pair<Rule>,
78    out: &mut BTreeMap<String, OperandDesc>,
79) -> Result<(), ParseError> {
80    for inner in p.into_inner() {
81        if inner.as_rule() == Rule::input_decl {
82            let mut it = inner.into_inner();
83            let name = it.next().unwrap().as_str().to_string();
84            let (dt, shape) = parse_ty(it.next().unwrap())?;
85            out.insert(
86                name,
87                OperandDesc {
88                    data_type: dt,
89                    shape,
90                },
91            );
92        }
93    }
94    Ok(())
95}
96
97fn parse_consts_block(
98    p: Pair<Rule>,
99    out: &mut BTreeMap<String, ConstDecl>,
100) -> Result<(), ParseError> {
101    for inner in p.into_inner() {
102        if inner.as_rule() == Rule::const_decl {
103            let mut it = inner.into_inner();
104            let name = it.next().unwrap().as_str().to_string();
105            let (dt, shape) = parse_ty(it.next().unwrap())?;
106
107            let mut init: Option<ConstInit> = None;
108            for ann in it {
109                if ann.as_rule() == Rule::const_annot {
110                    let text = ann.as_str();
111                    if text.starts_with("@weights") {
112                        let s = ann
113                            .into_inner()
114                            .find(|p| p.as_rule() == Rule::string)
115                            .map(|p| unquote(p.as_str()))
116                            .unwrap_or_else(|| name.clone());
117                        init = Some(ConstInit::Weights { r#ref: s });
118                    } else if text.starts_with("@scalar") {
119                        let n = ann
120                            .into_inner()
121                            .find(|p| p.as_rule() == Rule::number)
122                            .map(|p| parse_number_value(p.as_str()))
123                            .unwrap_or(Value::Null);
124                        init = Some(ConstInit::Scalar { value: n });
125                    } else if text.starts_with("@bytes") {
126                        let bytes = ann
127                            .into_inner()
128                            .find(|p| p.as_rule() == Rule::byte_array)
129                            .map(|pair| {
130                                pair.into_inner()
131                                    .filter(|p| p.as_rule() == Rule::int)
132                                    .filter_map(|p| p.as_str().parse::<u32>().ok())
133                                    .map(|v| v as u8)
134                                    .collect::<Vec<u8>>()
135                            })
136                            .unwrap_or_default();
137                        init = Some(ConstInit::InlineBytes { bytes });
138                    }
139                }
140            }
141
142            let init = init.unwrap_or(ConstInit::Weights {
143                r#ref: name.clone(),
144            });
145            out.insert(
146                name,
147                ConstDecl {
148                    data_type: dt,
149                    shape: dims_to_static_shape(&shape)?,
150                    init,
151                },
152            );
153        }
154    }
155    Ok(())
156}
157
158fn parse_nodes_block(p: Pair<Rule>, out: &mut Vec<Node>) -> Result<(), ParseError> {
159    for inner in p.into_inner() {
160        if inner.as_rule() != Rule::stmt {
161            continue;
162        }
163        let stmt = inner.into_inner().next().unwrap();
164        match stmt.as_rule() {
165            Rule::assign => out.push(parse_assign(stmt)?),
166            Rule::multi_assign => out.push(parse_multi_assign(stmt)?),
167            _ => {}
168        }
169    }
170    Ok(())
171}
172
173fn parse_assign(p: Pair<Rule>) -> Result<Node, ParseError> {
174    let mut it = p.into_inner();
175    let id = it.next().unwrap().as_str().to_string();
176    let (op, inputs, options, outputs) = parse_expr(it.next().unwrap())?;
177    Ok(Node {
178        id,
179        op,
180        inputs,
181        options,
182        outputs,
183    })
184}
185
186fn parse_multi_assign(p: Pair<Rule>) -> Result<Node, ParseError> {
187    let mut it = p.into_inner();
188    let mut outs: Vec<String> = Vec::new();
189
190    // first items are idents inside [...]
191    // We receive them as a flat sequence of ident tokens due to grammar.
192    // Collect until we hit expr.
193    while let Some(next) = it.peek() {
194        if next.as_rule() == Rule::expr {
195            break;
196        }
197        let t = it.next().unwrap();
198        if t.as_rule() == Rule::ident {
199            outs.push(t.as_str().to_string());
200        }
201    }
202
203    let expr = it
204        .next()
205        .ok_or_else(|| ParseError::Internal("missing expr in multi_assign".into()))?;
206    let (op, inputs, options, _outputs_unused) = parse_expr(expr)?;
207    // Use the first output name as the node id for uniqueness; keep real outputs in Node.outputs.
208    let id = outs.first().cloned().unwrap_or_else(|| "tmp".into());
209    Ok(Node {
210        id,
211        op,
212        inputs,
213        options,
214        outputs: Some(outs),
215    })
216}
217
218fn parse_expr(p: Pair<Rule>) -> Result<ParsedExpr, ParseError> {
219    match p.as_rule() {
220        Rule::expr => parse_expr(p.into_inner().next().unwrap()),
221        Rule::call => parse_call(p),
222        Rule::ident => Ok((
223            String::new(),
224            vec![p.as_str().to_string()],
225            Map::new(),
226            None,
227        )),
228        _ => Err(ParseError::Internal(format!(
229            "unexpected expr rule: {:?}",
230            p.as_rule()
231        ))),
232    }
233}
234
235fn parse_call(p: Pair<Rule>) -> Result<ParsedExpr, ParseError> {
236    let mut it = p.into_inner();
237    let op = it.next().unwrap().as_str().to_string();
238    let mut inputs: Vec<String> = Vec::new();
239    let mut options: Map<String, Value> = Map::new();
240
241    // Debug: trace concat operations
242    let is_concat = op == "concat";
243    if is_concat {
244        crate::debug_println!("[PARSER DEBUG] Parsing concat operation");
245    }
246
247    if let Some(args) = it.next() {
248        if args.as_rule() == Rule::args {
249            for (arg_idx, arg) in args.into_inner().enumerate() {
250                if arg.as_rule() != Rule::arg {
251                    continue;
252                }
253                let mut a = arg.into_inner().peekable();
254
255                // Check if this is a named argument: ident '=' value
256                let first = match a.next() {
257                    Some(f) => f,
258                    None => continue,
259                };
260
261                if is_concat {
262                    crate::debug_println!(
263                        "[PARSER DEBUG]   arg[{}]: first.rule={:?}, first.as_str()={}, has_next={}",
264                        arg_idx,
265                        first.as_rule(),
266                        first.as_str(),
267                        a.peek().is_some()
268                    );
269                    if let Some(next) = a.peek() {
270                        crate::debug_println!(
271                            "[PARSER DEBUG]   arg[{}]: next.rule={:?}, next.as_str()={}",
272                            arg_idx,
273                            next.as_rule(),
274                            next.as_str()
275                        );
276                    }
277                }
278
279                if first.as_rule() == Rule::ident
280                    && a.peek().is_some()
281                    && a.peek().unwrap().as_rule() == Rule::value
282                {
283                    // Named argument
284                    let key = first.as_str().to_string();
285                    let val = parse_value(a.next().unwrap())?;
286                    if is_concat {
287                        crate::debug_println!("[PARSER DEBUG]   Named argument: {}={:?}", key, val);
288                    }
289                    options.insert(key, val);
290                } else {
291                    // Positional argument
292                    if is_concat {
293                        crate::debug_println!(
294                            "[PARSER DEBUG]   Positional argument: rule={:?}",
295                            first.as_rule()
296                        );
297                    }
298                    // Handle bracketed input lists like `concat([a, b], ...)`: `parse_value`
299                    // returns `Value::Array` for `[a, b]`, which we flatten into individual inputs.
300                    let v = parse_value(first)?;
301                    match v {
302                        Value::String(s) => inputs.push(s),
303                        Value::Array(arr) => {
304                            for item in arr {
305                                match item {
306                                    Value::String(s) => inputs.push(s),
307                                    other => {
308                                        if let Some(s) = other.as_str() {
309                                            inputs.push(s.to_string());
310                                        }
311                                    }
312                                }
313                            }
314                        }
315                        other => {
316                            if let Some(s) = other.as_str() {
317                                inputs.push(s.to_string());
318                            }
319                        }
320                    }
321                }
322            }
323        }
324    }
325
326    if is_concat {
327        crate::debug_println!(
328            "[PARSER DEBUG] Concat parsed: inputs={:?}, options={:?}",
329            inputs,
330            options
331        );
332    }
333
334    Ok((op, inputs, options, None))
335}
336
337fn parse_outputs_block(
338    p: Pair<Rule>,
339    out: &mut BTreeMap<String, String>,
340) -> Result<(), ParseError> {
341    // WG: outputs { probs }  OR outputs { a,b; }
342    // We'll map each output name to itself.
343    for inner in p.into_inner() {
344        if inner.as_rule() == Rule::output_item {
345            for item in inner.into_inner() {
346                if item.as_rule() == Rule::ident {
347                    let name = item.as_str().to_string();
348                    out.insert(name.clone(), name);
349                }
350            }
351        }
352    }
353    Ok(())
354}
355
356fn parse_ty(p: Pair<Rule>) -> Result<(DataType, Vec<Dimension>), ParseError> {
357    let mut it = p.into_inner();
358    let dt_s = it.next().unwrap().as_str();
359    let dt = DataType::from_wg(dt_s).ok_or_else(|| ParseError::BadDType(dt_s.to_string()))?;
360    let shape = parse_shape(it.next().unwrap())?;
361    Ok((dt, shape))
362}
363
364fn parse_shape(p: Pair<Rule>) -> Result<Vec<Dimension>, ParseError> {
365    let mut shape = Vec::new();
366    for inner in p.into_inner() {
367        if inner.as_rule() == Rule::shape_dim {
368            let item = inner
369                .into_inner()
370                .next()
371                .ok_or_else(|| ParseError::Internal("shape_dim missing inner value".to_string()))?;
372            match item.as_rule() {
373                Rule::int => {
374                    let v: u32 = item
375                        .as_str()
376                        .parse()
377                        .map_err(|_| ParseError::Internal("bad int".into()))?;
378                    shape.push(Dimension::Static(v));
379                }
380                Rule::dynamic_dim => {
381                    let mut it = item.into_inner();
382                    let name = it
383                        .next()
384                        .map(|p| unquote(p.as_str()))
385                        .ok_or_else(|| ParseError::Internal("dynamic_dim missing name".into()))?;
386                    let max_size: u32 = it
387                        .next()
388                        .ok_or_else(|| ParseError::Internal("dynamic_dim missing max".into()))?
389                        .as_str()
390                        .parse()
391                        .map_err(|_| ParseError::Internal("dynamic_dim bad max".into()))?;
392                    shape.push(Dimension::Dynamic(DynamicDimension { name, max_size }));
393                }
394                _ => return Err(ParseError::Internal("unexpected shape_dim rule".into())),
395            }
396        }
397    }
398    Ok(shape)
399}
400
401fn dims_to_static_shape(shape: &[Dimension]) -> Result<Vec<u32>, ParseError> {
402    let mut out = Vec::with_capacity(shape.len());
403    for dim in shape {
404        match dim {
405            Dimension::Static(v) => out.push(*v),
406            Dimension::Dynamic(_) => return Err(ParseError::DynamicConstShape),
407        }
408    }
409    Ok(out)
410}
411
412fn parse_value(p: Pair<Rule>) -> Result<Value, ParseError> {
413    match p.as_rule() {
414        Rule::value => parse_value(p.into_inner().next().unwrap()),
415        Rule::literal => parse_value(p.into_inner().next().unwrap()),
416        Rule::string => Ok(Value::String(unquote(p.as_str()))),
417        Rule::number => Ok(parse_number_value(p.as_str())),
418        Rule::boolean => Ok(Value::Bool(p.as_str() == "true")),
419        Rule::null => Ok(Value::Null),
420        Rule::array => {
421            let mut arr = Vec::new();
422            for inner in p.into_inner() {
423                if inner.as_rule() == Rule::value {
424                    arr.push(parse_value(inner)?);
425                }
426            }
427            Ok(Value::Array(arr))
428        }
429        Rule::object => {
430            let mut map = serde_json::Map::new();
431            for inner in p.into_inner() {
432                if inner.as_rule() == Rule::object_item {
433                    let mut it = inner.into_inner();
434                    let key_pair = it
435                        .next()
436                        .ok_or_else(|| ParseError::Internal("object key missing".into()))?;
437                    let key = match key_pair.as_rule() {
438                        Rule::string => unquote(key_pair.as_str()),
439                        Rule::ident => key_pair.as_str().to_string(),
440                        _ => {
441                            return Err(ParseError::Internal(
442                                "unexpected object key rule".to_string(),
443                            ));
444                        }
445                    };
446                    let value_pair = it
447                        .next()
448                        .ok_or_else(|| ParseError::Internal("object value missing".into()))?;
449                    map.insert(key, parse_value(value_pair)?);
450                }
451            }
452            Ok(Value::Object(map))
453        }
454        Rule::ident => Ok(Value::String(p.as_str().to_string())),
455        _ => Err(ParseError::Internal(format!(
456            "unexpected value rule: {:?}",
457            p.as_rule()
458        ))),
459    }
460}
461
462fn parse_number_value(s: &str) -> Value {
463    // Prefer i64 when exact, otherwise f64.
464    if !s.contains('.') && !s.contains('e') && !s.contains('E') {
465        if let Ok(i) = s.parse::<i64>() {
466            return Value::Number(i.into());
467        }
468    }
469    Value::Number(serde_json::Number::from_f64(s.parse::<f64>().unwrap_or(0.0)).unwrap())
470}
471
/// Strips one pair of surrounding double quotes (if both are present) and decodes escapes.
///
/// Only `\"` and `\\` are decoded; any other backslash sequence is kept verbatim.
/// Text that is not fully quoted is decoded without stripping anything.
fn unquote(s: &str) -> String {
    let body = s
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .unwrap_or(s);
    body.replace("\\\"", "\"").replace("\\\\", "\\")
}
480
// Unit tests for the WG text parser. Most tests feed a small WG document through
// `parse_wg_text` and inspect the resulting `GraphJson`; the last few call the
// string/number helpers directly.
#[cfg(test)]
mod tests {
    use super::*;

    // End-to-end smoke test: one input, one const, one node, one output.
    // The parser never sets `format`, so its value comes from `new_graph_json()`.
    #[test]
    fn test_parse_simple_graph() {
        let input = r#"
webnn_graph "test" v1 {
  inputs {
    x: f32[1, 10];
  }
  consts {
    W: f32[10, 5] @weights("W");
  }
  nodes {
    result = matmul(x, W);
  }
  outputs { result; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.format, "webnn-graph-json");
        assert_eq!(graph.version, 1);
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.consts.len(), 1);
        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.outputs.len(), 1);
    }

    // Input declarations keep their dtype and static shape.
    #[test]
    fn test_parse_inputs() {
        let input = r#"
webnn_graph "test" v1 {
  inputs {
    x: f32[1, 10];
    y: i32[5];
  }
  nodes {}
  outputs { x; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.inputs.len(), 2);
        assert!(graph.inputs.contains_key("x"));
        assert!(graph.inputs.contains_key("y"));

        let x_desc = &graph.inputs["x"];
        assert_eq!(x_desc.data_type, DataType::Float32);
        assert_eq!(
            x_desc.shape,
            vec![Dimension::Static(1), Dimension::Static(10)]
        );

        let y_desc = &graph.inputs["y"];
        assert_eq!(y_desc.data_type, DataType::Int32);
        assert_eq!(y_desc.shape, vec![Dimension::Static(5)]);
    }

    // `dyn("name", max)` in an input shape becomes `Dimension::Dynamic`.
    #[test]
    fn test_parse_dynamic_input_shape() {
        let input = r#"
webnn_graph "test" v2 {
  inputs {
    x: f32[dyn("batch_size", 8), 128];
  }
  nodes {}
  outputs { x; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let x_desc = &graph.inputs["x"];
        assert!(matches!(
            &x_desc.shape[0],
            Dimension::Dynamic(d) if d.name == "batch_size" && d.max_size == 8
        ));
        assert!(matches!(&x_desc.shape[1], Dimension::Static(128)));
    }

    // `@weights("ref")` stores the given reference, which may differ from the const name.
    #[test]
    fn test_parse_consts_with_weights() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[1]; }
  consts {
    W: f32[10, 5] @weights("W");
    b: f32[5] @weights("bias");
  }
  nodes {}
  outputs { x; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.consts.len(), 2);

        let w = &graph.consts["W"];
        assert_eq!(w.data_type, DataType::Float32);
        assert_eq!(w.shape, vec![10, 5]);
        assert!(matches!(&w.init, ConstInit::Weights { r#ref } if r#ref == "W"));

        let b = &graph.consts["b"];
        assert!(matches!(&b.init, ConstInit::Weights { r#ref } if r#ref == "bias"));
    }

    // `@scalar(number)` stores the literal as a JSON number.
    #[test]
    fn test_parse_consts_with_scalar() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[1]; }
  consts {
    scale: f32[1] @scalar(2.5);
  }
  nodes {}
  outputs { x; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let scale = &graph.consts["scale"];
        match &scale.init {
            ConstInit::Scalar { value } => {
                assert_eq!(value.as_f64().unwrap(), 2.5);
            }
            _ => panic!("Expected scalar init"),
        }
    }

    // A plain `name = op(...)` assignment becomes a node whose id is `name`.
    #[test]
    fn test_parse_nodes() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[1, 2048]; }
  consts { W: f32[2048, 1000] @weights("W"); }
  nodes {
    result = matmul(x, W);
  }
  outputs { result; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.nodes.len(), 1);

        let node = &graph.nodes[0];
        assert_eq!(node.id, "result");
        assert_eq!(node.op, "matmul");
        assert_eq!(node.inputs, vec!["x", "W"]);
        assert!(node.options.is_empty());
    }

    // `key=value` arguments land in `options`, not in `inputs`.
    #[test]
    fn test_parse_nodes_with_options() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[1, 10]; }
  nodes {
    result = softmax(x, axis=1);
  }
  outputs { result; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let node = &graph.nodes[0];
        assert_eq!(node.op, "softmax");
        assert_eq!(node.inputs, vec!["x"]);
        assert_eq!(node.options.get("axis").unwrap().as_i64().unwrap(), 1);
    }

    // `[a, b] = op(...)`: the first target is the node id; all targets are kept in `outputs`.
    #[test]
    fn test_parse_multi_assign() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[10]; }
  nodes {
    [a, b] = split(x);
  }
  outputs { a; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let node = &graph.nodes[0];
        assert_eq!(node.id, "a");
        assert_eq!(node.op, "split");
        assert_eq!(node.outputs, Some(vec!["a".to_string(), "b".to_string()]));
    }

    // A bracketed positional list, e.g. `concat([t0, t1], ...)`, is flattened into inputs.
    #[test]
    fn test_parse_concat_bracket_input_list() {
        let input = r#"
webnn_graph "model" v1 {
  inputs {
    tensors_0: f32[2, 3, 4, 5];
    tensors_1: f32[2, 3, 4, 5];
  }
  nodes {
    [operand_1] = concat([tensors_0, tensors_1], axis=0);
  }
  outputs { operand_1; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        let node = &graph.nodes[0];
        assert_eq!(node.op, "concat");
        assert_eq!(node.inputs, vec!["tensors_0", "tensors_1"]);
        assert_eq!(node.options.get("axis").unwrap().as_i64().unwrap(), 0);
    }

    // Every listed output name maps to itself.
    #[test]
    fn test_parse_outputs() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: f32[1]; }
  nodes {
    a = relu(x);
    b = sigmoid(x);
  }
  outputs { a; b; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.outputs.len(), 2);
        assert_eq!(graph.outputs.get("a").unwrap(), "a");
        assert_eq!(graph.outputs.get("b").unwrap(), "b");
    }

    // An unknown dtype spelling is expected to be rejected by the grammar (a Pest
    // error), not later by `DataType::from_wg`.
    #[test]
    fn test_parse_invalid_dtype() {
        let input = r#"
webnn_graph "test" v1 {
  inputs { x: float32[1]; }
  nodes {}
  outputs { x; }
}
"#;
        let result = parse_wg_text(input);
        assert!(result.is_err());
        // The pest parser should fail because "float32" doesn't match the dtype rule
        match result {
            Err(ParseError::Pest(_)) => {}
            Err(e) => panic!("Expected Pest parse error, got: {:?}", e),
            Ok(_) => panic!("Expected error but parsing succeeded"),
        }
    }

    // Surrounding quotes are stripped and `\"` / `\\` escapes decoded; other text passes through.
    #[test]
    fn test_unquote() {
        assert_eq!(unquote(r#""hello""#), "hello");
        assert_eq!(unquote(r#""hello\"world""#), "hello\"world");
        assert_eq!(unquote(r#""path\\to\\file""#), "path\\to\\file");
        assert_eq!(unquote("no_quotes"), "no_quotes");
    }

    // Integral text yields an integer; decimals and exponents yield f64.
    #[test]
    fn test_parse_number_value() {
        let int_val = parse_number_value("42");
        assert_eq!(int_val.as_i64().unwrap(), 42);

        let float_val = parse_number_value("3.12");
        assert_eq!(float_val.as_f64().unwrap(), 3.12);

        let sci_val = parse_number_value("1e-3");
        assert_eq!(sci_val.as_f64().unwrap(), 0.001);
    }

    // Identifiers may start with `$` in const names, node ids and node inputs.
    #[test]
    fn test_parse_dollar_sign_identifiers() {
        let input = r#"
webnn_graph "test" v1 {
  inputs {
    x: f32[1, 10];
  }
  consts {
    $_weight: f32[10, 5] @weights("W");
  }
  nodes {
    $_temp = relu(x);
    result = matmul($_temp, $_weight);
  }
  outputs { result; }
}
"#;
        let graph = parse_wg_text(input).unwrap();
        assert_eq!(graph.inputs.len(), 1);
        assert_eq!(graph.consts.len(), 1);
        assert!(graph.consts.contains_key("$_weight"));
        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.nodes[0].id, "$_temp");
        assert_eq!(graph.nodes[1].id, "result");
        assert_eq!(graph.nodes[1].inputs, vec!["$_temp", "$_weight"]);
    }
}