tract_nnef/ast/
parse.rs

1use nom_language::error::{convert_error, VerboseError};
2use tract_core::internal::*;
3
4use nom::branch::alt;
5use nom::combinator::map;
6use nom::{bytes::complete::*, character::complete::*, combinator::*, multi::*, sequence::*};
7use nom::{Finish, IResult, Parser};
8
9use crate::ast::*;
10
11type R<'i, O> = IResult<&'i str, O, VerboseError<&'i str>>;
12
13pub(super) fn translate_error(e: nom::Err<VerboseError<&str>>) -> TractError {
14    format_err!("{}", e)
15}
16
17#[inline(never)]
18pub fn unwrap_parse<'s, P, O>(input: &'s str, parser: P) -> TractResult<O>
19where
20    P: Parser<&'s str, Output = O, Error = VerboseError<&'s str>>,
21{
22    all_consuming(parser)
23        .parse(input)
24        .finish()
25        .map(|(_, p)| p)
26        .map_err(|e| anyhow!(convert_error(input, e)))
27}
28
29pub fn parse_document(doc: &str) -> TractResult<Document> {
30    unwrap_parse(doc, document)
31}
32
33#[inline(never)]
34pub fn parse_fragments(doc: &str) -> TractResult<Vec<FragmentDef>> {
35    unwrap_parse(doc, fragments)
36}
37
38#[inline(never)]
39pub fn parse_fragment_decl(doc: &str) -> TractResult<FragmentDecl> {
40    unwrap_parse(doc, fragment_decl)
41}
42
43#[inline(never)]
44pub fn parse_parameters(doc: &str) -> TractResult<Vec<Parameter>> {
45    unwrap_parse(doc, parameter_list)
46}
47
48// <document> ::= <version> <extension>* <fragmentdefinition>* <graph-definition>
49fn document(i: &str) -> R<'_, Document> {
50    map(
51        (version, many0(extension), fragments, graph_def),
52        |(version, extension, fragments, graph_def)| Document {
53            version,
54            extension,
55            fragments,
56            graph_def,
57        },
58    )
59    .parse(i)
60}
61
62fn fragments(i: &str) -> R<'_, Vec<FragmentDef>> {
63    many0(fragment_def).parse(i)
64}
65
66// <version> ::= "version" <numeric-literal> ";"
67
68fn version(i: &str) -> R<'_, NumericLiteral> {
69    preceded(stag("version"), cut(terminated(numeric_literal, stag(";")))).parse(i)
70}
71
72// NNEF spec: <extension> ::= "extension" <identifier>+ ";"
73// tract accepts: <extension> ::= "extension" <identifier> <anything-but-;>";"
74fn extension(i: &str) -> R<'_, (Identifier, String)> {
75    delimited(
76        stag("extension"),
77        pair(spaced(identifier), map(take_until(";"), |s: &str| s.to_string())),
78        stag(";"),
79    )
80    .parse(i)
81}
82
83// FRAGMENT
84
85// <fragment-definition> ::= <fragment-declaration> (<body> | ";")
86fn fragment_def(i: &str) -> R<'_, FragmentDef> {
87    spaced(map(
88        pair(fragment_decl, alt((map(body, Some), map(stag(";"), |_| None)))),
89        |(decl, body)| FragmentDef { decl, body },
90    ))
91    .parse(i)
92}
93
94// <fragment-declaration> ::= "fragment" <identifier> [<generic-declaration>] "(" <parameter-list> ")" "->" "(" <result-list> ")"
95fn fragment_decl(i: &str) -> R<'_, FragmentDecl> {
96    preceded(stag("fragment"), cut(commited_fragment_decl)).parse(i)
97}
98
99fn commited_fragment_decl(i: &str) -> R<'_, FragmentDecl> {
100    let (i, id) = identifier(i)?;
101    let (i, generic_decl) = opt(generic_decl).parse(i)?;
102    let (i, _) = stag("(").parse(i)?;
103    let (i, parameters) = cut(parameter_list).parse(i)?;
104    let (i, _) = stag(")").parse(i)?;
105    let (i, _) = stag("->").parse(i)?;
106    let (i, _) = stag("(").parse(i)?;
107    let (i, results) = cut(result_list).parse(i)?;
108    let (i, _) = stag(")").parse(i)?;
109    Ok((i, FragmentDecl { id, parameters, results, generic_decl }))
110}
111
112// <generic-declaration> ::= "<" "?" ["=" <type-name>] ">"
113fn generic_decl(i: &str) -> R<'_, Option<TypeName>> {
114    let (i, _) = stag("<").parse(i)?;
115    let (i, _) = stag("?").parse(i)?;
116    let (i, name) = opt(preceded(stag("="), type_name)).parse(i)?;
117    let (i, _) = stag(">").parse(i)?;
118    Ok((i, name))
119}
120
121// <parameter-list> ::= <parameter> ("," <parameter>)*
122fn parameter_list(i: &str) -> R<'_, Vec<Parameter>> {
123    separated_list0(stag(","), parameter).parse(i)
124}
125
126// <result-list> ::= <result> ("," <result>)*
127fn result_list(i: &str) -> R<'_, Vec<Result_>> {
128    separated_list0(stag(","), result).parse(i)
129}
130
131// <parameter> ::= <identifier> ":" <type-spec> ["=" <literal-expr>]
132fn parameter(i: &str) -> R<'_, Parameter> {
133    map(
134        pair(
135            separated_pair(identifier, stag(":"), cut(type_spec)),
136            opt(preceded(stag("="), literal_expr)),
137        ),
138        |((id, spec), lit)| Parameter { id, spec, lit, doc: None },
139    )
140    .parse(i)
141}
142
143// <result> ::= <identifier> ":" <type-spec>
144fn result(i: &str) -> R<'_, Result_> {
145    map(separated_pair(identifier, stag(":"), cut(type_spec)), |(id, spec)| Result_ { id, spec })
146        .parse(i)
147}
148
149fn literal_expr(i: &str) -> R<'_, Literal> {
150    spaced(alt((
151        literal,
152        map(delimited(stag("["), separated_list0(stag(","), literal), stag("]")), Literal::Array),
153        map(delimited(stag("("), separated_list0(stag(","), literal), stag(")")), Literal::Tuple),
154    )))
155    .parse(i)
156}
157
158// <type-spec> ::= <type-name> | <tensor-type-spec> | <array-type-spec> | <tuple-type-spec>
159fn type_spec(i: &str) -> R<'_, TypeSpec> {
160    fn non_array_type(i: &str) -> R<'_, TypeSpec> {
161        alt((tuple_type_spec, map(type_name, TypeSpec::Single), tensor_type_spec)).parse(i)
162    }
163    alt((
164        (map(terminated(non_array_type, pair(stag("["), stag("]"))), |t| {
165            TypeSpec::Array(Box::new(t))
166        })),
167        non_array_type,
168    ))
169    .parse(i)
170}
171
172// <type-name> ::= "integer" | "scalar" | "logical" | "string" | "?"
173fn type_name(i: &str) -> R<'_, TypeName> {
174    spaced(alt((
175        map(tag("integer"), |_| TypeName::Integer),
176        map(tag("scalar"), |_| TypeName::Scalar),
177        map(tag("logical"), |_| TypeName::Logical),
178        map(tag("string"), |_| TypeName::String),
179        #[cfg(feature = "complex")]
180        map(tag("complex"), |_| TypeName::Complex),
181        map(tag("?"), |_| TypeName::Any),
182    )))
183    .parse(i)
184}
185
186// <tensor-type-spec> ::= "tensor" "<" [<type-name>] ">"
187fn tensor_type_spec(i: &str) -> R<'_, TypeSpec> {
188    map(delimited(pair(stag("tensor"), stag("<")), type_name, stag(">")), TypeSpec::Tensor).parse(i)
189}
190
191// <tuple-type-spec> ::= "(" <type-spec> ("," <type-spec>)+ ")"
192fn tuple_type_spec(i: &str) -> R<'_, TypeSpec> {
193    map(delimited(stag("("), separated_list0(stag(","), type_spec), stag(")")), TypeSpec::Tuple)
194        .parse(i)
195}
196
197// GRAPH
198
199// <graph-definition> ::= <graph-declaration> <body>
200// <graph-declaration> ::= "graph" <identifier> "(" <identifier-list> ")" "->" "(" <identifier-list> ")"
201// <identifier-list> ::= <identifier> ("," <identifier>)*
202fn graph_def(i: &str) -> R<'_, GraphDef> {
203    let (i, _) = stag("graph").parse(i)?;
204    let (i, id) = identifier(i)?;
205    let (i, _) = stag("(").parse(i)?;
206    let (i, parameters) = separated_list0(stag(","), identifier).parse(i)?;
207    let (i, _) = stag(")").parse(i)?;
208    let (i, _) = stag("->").parse(i)?;
209    let (i, _) = stag("(").parse(i)?;
210    let (i, results) = separated_list0(stag(","), identifier).parse(i)?;
211    let (i, _) = stag(")").parse(i)?;
212    let (i, body) = spaced(body).parse(i)?;
213    Ok((i, GraphDef { id, parameters, results, body }))
214}
215
216// BODY
217
218// <body> ::= "{" <assignment>+ "}"
219fn body(i: &str) -> R<'_, Vec<Assignment>> {
220    delimited(stag("{"), many0(assignment), stag("}")).parse(i)
221}
222
223// <assignment> ::= <lvalue-expr> "=" <rvalue-expr> ";"
224fn assignment(i: &str) -> R<'_, Assignment> {
225    spaced(terminated(
226        map(separated_pair(lvalue, stag("="), rvalue), |(left, right)| Assignment { left, right }),
227        stag(";"),
228    ))
229    .parse(i)
230}
231
232// <lvalue-expr> ::= <identifier> | <array-lvalue-expr> | <tuple-lvalue-expr>
233// <array-lvalue-expr> ::= "[" [<lvalue-expr> ("," <lvalue-expr>)* ] "]"
234// <tuple-lvalue-expr> ::= "(" <lvalue-expr> ("," <lvalue-expr>)+ ")" | <lvalue-expr> ("," <lvalue-expr>)+
235fn lvalue(i: &str) -> R<'_, LValue> {
236    fn inner_lvalue(i: &str) -> R<'_, LValue> {
237        alt((
238            map(
239                delimited(stag("["), separated_list0(stag(","), inner_lvalue), stag("]")),
240                LValue::Array,
241            ),
242            map(
243                delimited(stag("("), separated_list0(stag(","), inner_lvalue), stag(")")),
244                LValue::Tuple,
245            ),
246            map(spaced(identifier), LValue::Identifier),
247        ))
248        .parse(i)
249    }
250
251    map(separated_list0(stag(","), inner_lvalue), |mut iv| {
252        if iv.len() == 1 {
253            iv.remove(0)
254        } else {
255            LValue::Tuple(iv)
256        }
257    })
258    .parse(i)
259}
260
261// <invocation> ::= <identifier> ["<" <type-name> ">"] "(" <argument-list> ")"
262fn invocation(i: &str) -> R<'_, Invocation> {
263    let (i, id) = spaced(identifier).parse(i)?;
264    let (i, generic_type_name) = opt(delimited(stag("<"), type_name, stag(">"))).parse(i)?;
265    let (i, _) = stag("(").parse(i)?;
266    let (i, arguments) = argument_list.parse(i)?;
267    let (i, _) = stag(")").parse(i)?;
268    Ok((i, Invocation { id, generic_type_name, arguments }))
269}
270
271// <argument-list> ::= <argument> ("," <argument>)*
272fn argument_list(i: &str) -> R<'_, Vec<Argument>> {
273    separated_list0(stag(","), argument).parse(i)
274}
275
276// <argument> ::= <rvalue-expr> | <identifier> "=" <rvalue-expr>
277fn argument(i: &str) -> R<'_, Argument> {
278    spaced(map(pair(opt(terminated(identifier, stag("="))), rvalue), |(id, rvalue)| Argument {
279        id,
280        rvalue,
281    }))
282    .parse(i)
283}
284
285//<rvalue-expr> ::= <identifier> | <literal> | <binary-expr> | <unary-expr> | <paren-expr>
286//                  | <array-rvalue-expr> | <tuple-rvalue-expr> | <subscript-expr> | <if-else-expr>
287//                  | <comprehension-expr> | <builtin-expr> | <invocation>
288fn rvalue(i: &str) -> R<'_, RValue> {
289    fn atom(i: &str) -> R<'_, RValue> {
290        spaced(alt((
291            map(invocation, RValue::Invocation),
292            map(literal, RValue::Literal),
293            map(identifier, RValue::Identifier),
294            map(pair(spaced(recognize(one_of("+-!"))), rvalue), |(op, rv)| {
295                RValue::Unary(op.into(), Box::new(rv))
296            }),
297            map(delimited(tag("("), separated_list0(stag(","), rvalue), tag(")")), |mut rvs| {
298                if rvs.len() == 1 {
299                    rvs.remove(0)
300                } else {
301                    RValue::Tuple(rvs)
302                }
303            }),
304            map(comprehension_expr, |c| RValue::Comprehension(Box::new(c))),
305            map(delimited(tag("["), separated_list0(stag(","), rvalue), tag("]")), |rvs| {
306                RValue::Array(rvs)
307            }),
308        )))
309        .parse(i)
310    }
311    macro_rules! bin {
312        ($name:ident, $operand: ident, $operator: expr) => {
313            fn $name(i: &str) -> R<'_, RValue> {
314                let (i, init) = $operand(i)?;
315                fold_many0(
316                    pair($operator, $operand),
317                    move || init.clone(),
318                    |left, (op, right)| {
319                        RValue::Binary(Box::new(left), op.to_string(), Box::new(right))
320                    },
321                )
322                .parse(i)
323            }
324        };
325    }
326
327    // <subscript-expr> ::= <rvalue-expr> "[" (<rvalue-expr> | [<rvalue-expr>] ":" [<rvalue-expr>]) "]"
328    fn sub(i: &str) -> R<'_, RValue> {
329        alt((
330            map(
331                pair(
332                    atom,
333                    delimited(
334                        stag("["),
335                        alt((
336                            map(separated_pair(opt(rvalue), stag(":"), opt(rvalue)), |(a, b)| {
337                                Subscript::Range(a, b)
338                            }),
339                            map(rvalue, Subscript::Single),
340                        )),
341                        stag("]"),
342                    ),
343                ),
344                |(rv, range)| RValue::Subscript(Box::new(rv), Box::new(range)),
345            ),
346            atom,
347        ))
348        .parse(i)
349    }
350
351    bin!(exp, sub, tag("^"));
352    bin!(mul, exp, one_of("*/"));
353    bin!(add, mul, one_of("+-"));
354    bin!(comp, add, alt((tag("=="), tag("!="), tag("<"), tag(">"), tag("<="), tag(">="))));
355    bin!(boolean, comp, alt((tag("||"), tag("&&"))));
356    bin!(in_for, boolean, tag("in"));
357
358    // <if-else-expr> ::= <rvalue-expr> "if" <rvalue-expr> "else" <rvalue-expr>
359    fn ite(i: &str) -> R<'_, RValue> {
360        let (i, leftmost) = in_for(i)?;
361        let (i, _) = space_and_comments(i)?;
362        if i.starts_with("if") {
363            let (i, _) = stag("if").parse(i)?;
364            let (i, cond) = in_for(i)?;
365            let (i, _) = stag("else").parse(i)?;
366            let (i, otherwise) = in_for(i)?;
367            Ok((i, RValue::IfThenElse(Box::new(IfThenElse { cond, then: leftmost, otherwise }))))
368        } else {
369            Ok((i, leftmost))
370        }
371    }
372
373    ite(i)
374}
375
376// <comprehension-expr> ::= "[" "for" <loop-iter-list> ["if" <rvalue-expr>] "yield" <rvalue-expr> "]"
377fn comprehension_expr(i: &str) -> R<'_, Comprehension> {
378    delimited(
379        pair(stag("["), stag("for")),
380        map(separated_pair(loop_iters, stag("yield"), rvalue), |(loop_iters, yields)| {
381            Comprehension { loop_iters, filter: None, yields }
382        }),
383        stag("]"),
384    )
385    .parse(i)
386}
387
388// <loop-iter> ::= <identifier> "in" <rvalue-expr>
389// <loop-iter-list> ::= <loop-iter> ("," <loop-iter>)*
390fn loop_iters(i: &str) -> R<'_, Vec<(Identifier, RValue)>> {
391    separated_list0(stag(","), separated_pair(identifier, stag("in"), rvalue)).parse(i)
392}
393
394// TERMINALS
395
396// identifier: identifiers must consist of the following ASCII characters: _, [a-z], [A-Z], [0-9].
397// The identifier must not start with a digit.
398pub(super) fn identifier(i: &str) -> R<'_, Identifier> {
399    alt((escaped_identifier, direct_identifier)).parse(i)
400}
401
402pub(super) fn direct_identifier(i: &str) -> R<'_, Identifier> {
403    map(
404        recognize(pair(alt((alpha1, tag("_"))), many0(alt((alphanumeric1, tag("_")))))),
405        Identifier::from,
406    )
407    .parse(i)
408}
409
410pub(super) fn escaped_identifier(i: &str) -> R<'_, Identifier> {
411    map(preceded(tag("i"), string_literal), Identifier).parse(i)
412}
413
414// <literal> ::= <numeric-literal> | <string-literal> | <logical-literal>
415fn literal(i: &str) -> R<'_, Literal> {
416    spaced(alt((
417        map(numeric_literal, Literal::Numeric),
418        map(string_literal, Literal::String),
419        map(logical_literal, Literal::Logical),
420    )))
421    .parse(i)
422}
423
424pub(super) fn numeric_literal(i: &str) -> R<'_, String> {
425    fn exp_part(i: &str) -> R<'_, &str> {
426        recognize((one_of("eE"), opt(tag("-")), digit1)).parse(i)
427    }
428    fn frac_part(i: &str) -> R<'_, &str> {
429        recognize((tag("."), digit0)).parse(i)
430    }
431    spaced(map(
432        recognize((opt(tag("-")), alt((digit1, tag("inf"))), opt(frac_part), opt(exp_part))),
433        |s: &str| s.to_owned(),
434    ))
435    .parse(i)
436}
437
438fn string_literal(i: &str) -> R<'_, String> {
439    fn inner(i: &str) -> R<'_, String> {
440        map(
441            many0(alt((
442                preceded(tag("\\"), nom::character::complete::anychar),
443                nom::character::complete::none_of("\\\"'"),
444            ))),
445            |v: Vec<char>| v.into_iter().collect(),
446        )
447        .parse(i)
448    }
449    map(alt((delimited(tag("'"), inner, tag("'")), delimited(tag("\""), inner, tag("\"")))), |s| s)
450        .parse(i)
451}
452
453pub(super) fn logical_literal(i: &str) -> R<'_, bool> {
454    spaced(alt((map(tag("true"), |_| true), map(tag("false"), |_| false)))).parse(i)
455}
456
457// SPACES
458
459fn space_and_comments(i: &str) -> R<'_, ()> {
460    map(
461        many0(alt((recognize(one_of(" \t\n\r")), recognize((tag("#"), many0(none_of("\r\n"))))))),
462        |_| (),
463    )
464    .parse(i)
465}
466
467fn spaced<'s, O, F>(it: F) -> impl Parser<&'s str, Output = O, Error = VerboseError<&'s str>>
468where
469    F: Parser<&'s str, Output = O, Error = VerboseError<&'s str>>,
470{
471    delimited(space_and_comments, it, space_and_comments)
472}
473
474pub(super) fn stag<'s>(
475    t: &'static str,
476) -> impl Parser<&'s str, Output = &'s str, Error = VerboseError<&'s str>> {
477    spaced(tag(t))
478}
479
// Unit tests for the individual grammar parsers of this module, from terminals
// (literals, identifiers, spacing) up to full fragments and documents.
#[cfg(test)]
mod test {
    use super::*;
    use TypeName::*;
    use TypeSpec::*;

    // Runs `parser` on `i` wrapped in `all_consuming` and returns its output. Panics (failing the
    // calling test) if parsing fails or leaves input unconsumed.
    fn p<'s, P, O, E>(parser: P, i: &'s str) -> O
    where
        O: std::fmt::Debug,
        P: Fn(&'s str) -> IResult<&'s str, O, E>,
        E: nom::error::ParseError<&'s str> + std::fmt::Debug,
    {
        let res = all_consuming(parser).parse(i).unwrap();
        res.1
    }

    // Builds an expected `Parameter` with no default literal and no doc.
    fn param(s: impl Into<std::string::String>, t: TypeSpec) -> Parameter {
        Parameter { id: Identifier(s.into()), spec: t, lit: None, doc: None }
    }

    // Builds an expected `Result_`.
    fn result(s: impl Into<std::string::String>, t: TypeSpec) -> Result_ {
        Result_ { id: Identifier(s.into()), spec: t }
    }

    #[test]
    fn test_type_spec() {
        assert_eq!(p(type_spec, "scalar"), Single(Scalar));
        assert_eq!(p(type_spec, "scalar[]"), Array(Box::new(Single(Scalar))));
        assert_eq!(p(type_spec, "tensor<scalar>[]"), Array(Box::new(Tensor(TypeName::Scalar))));
        assert_eq!(
            p(type_spec, "(scalar,scalar[],tensor<scalar>)"),
            Tuple(vec!(Single(Scalar), Array(Box::new(Single(Scalar))), Tensor(Scalar)))
        );
        assert_eq!(p(type_spec, "tensor<?>[]"), Array(Box::new(Tensor(TypeName::Any))));
        assert_eq!(p(type_spec, "scalar[ ]"), Array(Box::new(Single(Scalar))));
        assert_eq!(
            p(type_spec, " ( scalar , scalar [ ] , tensor < scalar > ) "),
            Tuple(vec!(Single(Scalar), Array(Box::new(Single(Scalar))), Tensor(Scalar)))
        );
        #[cfg(feature = "complex")]
        assert_eq!(p(type_spec, "tensor<complex>[]"), Array(Box::new(Tensor(TypeName::Complex))));
    }

    // Generic declaration with a default type (`<? = scalar>`).
    #[test]
    fn test_fragment_decl_fizz() {
        let parsed = p(
            fragment_decl,
            "fragment fizz<? = scalar>( shape: integer[] ) -> ( output: tensor<?> )",
        );
        assert_eq!(
            parsed,
            FragmentDecl {
                id: "fizz".into(),
                generic_decl: Some(Some(Scalar)),
                parameters: vec!(param("shape", Array(Box::new(Single(Integer)))),),
                results: vec!(result("output", Tensor(Any))),
            }
        );
    }

    #[test]
    fn test_fragment_decl_logarithmic_quantize() {
        let parsed = p(fragment_decl,
                           "fragment logarithmic_quantize(x: tensor<scalar>, max: tensor<scalar>, bits: integer ) -> ( y: tensor<scalar> )"
                          );
        assert_eq!(
            parsed,
            FragmentDecl {
                id: "logarithmic_quantize".into(),
                generic_decl: None,
                parameters: vec!(
                    param("x", Tensor(Scalar)),
                    param("max", Tensor(Scalar)),
                    param("bits", Single(Integer))
                ),
                results: vec!(result("y", Tensor(Scalar))),
            }
        );
    }

    #[test]
    fn test_fragment_decl_external() {
        p(
            fragment_decl,
            "fragment external<? = scalar>( shape: integer[] ) -> ( output: tensor<?> )",
        );
    }

    // Declaration-only fragment (terminated by `;`) with literal default values.
    #[test]
    fn test_fragment_reshape() {
        p(fragments, "fragment reshape<?>( input: tensor<?>, shape: integer[], axis_start: integer = 0, axis_count: integer = -1 ) -> ( output: tensor<?> );");
    }

    #[test]
    fn test_fragment_conv() {
        p(
            fragments,
            r#"
            fragment conv(
                input: tensor<scalar>,
                filter: tensor<scalar>,
                bias: tensor<scalar> = 0.0,
                border: string = 'constant',
                padding: (integer,integer)[] = [],
                stride: integer[] = [],
                dilation: integer[] = [],
                groups: integer = 1 )
            -> ( output: tensor<scalar> );
            "#,
        );
    }

    // Fragment with a body containing named arguments and operators.
    #[test]
    fn test_fragment_local_response_normalization() {
        p(
            fragments,
            r#"
            fragment local_response_normalization(
                input: tensor<scalar>,
                size: integer[],
                alpha: scalar = 1.0,
                beta: scalar = 0.5,
                bias: scalar = 1.0 )
            -> ( output: tensor<scalar> )
            {
                sigma = bias + alpha * box(sqr(input), size = size, normalize = true);
                output = input / (sigma ^ beta);
            }
            "#,
        );
    }

    #[test]
    fn test_batch_normalization() {
        p(
            fragments,
            r#"
            fragment batch_normalization( input: tensor<scalar>, mean: tensor<scalar>, variance: tensor<scalar>, offset: tensor<scalar>, scale: tensor<scalar>, epsilon: scalar )
            -> ( output: tensor<scalar> )
            {
                output = offset + scale * (input - mean) / sqrt(variance + epsilon);
            }
            "#,
        );
    }

    // Body using a comprehension and subscripts.
    #[test]
    fn test_avg_roi_align() {
        p(
            fragments,
            r#"
                fragment avg_roi_align(
                    input: tensor<scalar>,
                    rois: tensor<scalar>,
                    batch_index: tensor<integer>,
                    output_size: integer[],
                    sampling_rate: integer[],
                    resize_method: string = 'symmetric' )
                -> ( output: tensor<scalar> )
                {
                    size = [for i in range_of(output_size) yield output_size[i] * sampling_rate[i]];
                    resized = roi_resample(input, rois, batch_index, output_size = size,
                                         method = resize_method);
                    output = avg_pool(resized, size = sampling_rate, stride = sampling_rate);
                }
            "#,
        );
    }

    // Body using `&&` and an if-else expression.
    #[test]
    fn test_min_max_linear_quantize() {
        p(
            fragments,
            r#"
                fragment min_max_linear_quantize(
                    x: tensor<scalar>,
                    min: tensor<scalar>,
                    max: tensor<scalar>,
                    bits: integer,
                    signed: logical,
                    symmetric: logical )
                -> ( y: tensor<scalar> )
                {
                    r = scalar(2 ^ bits - 1 - integer(signed && symmetric));
                    z = clamp(x, min, max);
                    p = scalar(2 ^ (bits - 1) - integer(symmetric) if signed else 0);
                    q = round((z - min) / (max - min) * r) - p;
                    y = (q + p) / r * (max - min) + min;
}
            "#,
        );
    }

    #[test]
    fn test_numeric() {
        p(numeric_literal, "12.0");
    }

    // Both quote styles; a backslash keeps the next character verbatim.
    #[test]
    fn test_string() {
        assert_eq!(p(string_literal, r#""""#), "");
        assert_eq!(p(string_literal, r#""foo""#), "foo");
        assert_eq!(p(string_literal, r#"''"#), "");
        assert_eq!(p(string_literal, r#"'foo'"#), "foo");

        assert_eq!(p(string_literal, r"'f\oo'"), "foo");
        assert_eq!(p(string_literal, r"'f\'oo'"), "f'oo");
        assert_eq!(p(string_literal, r#"'f\"oo'"#), "f\"oo");
    }

    #[test]
    fn test_identifier() {
        p(identifier, "foo");
        assert!(identifier("1").is_err());
        assert!(identifier("1foo").is_err());
    }

    #[test]
    fn test_spacing() {
        p(space_and_comments, "");
        p(space_and_comments, "\n");
        p(space_and_comments, "#comment\n");
        p(space_and_comments, "#boum");
    }

    // `spaced` skips whitespace and `#` comments on both sides of the wrapped parser.
    #[test]
    fn test_spaced() {
        assert!(spaced(identifier).parse("foo").is_ok());
        assert!(spaced(identifier).parse(" foo ").is_ok());
        assert!(many1(spaced(identifier)).parse(" foo bar ").is_ok());
        assert_eq!(
            many1(spaced(identifier)).parse(" foo bar\n").unwrap().1,
            &[Identifier("foo".to_string()), Identifier("bar".to_string())]
        );
        assert_eq!(
            many1(spaced(identifier)).parse(" foo # bar\n").unwrap().1,
            &[Identifier("foo".to_string())]
        );
        assert_eq!(
            many1(spaced(identifier)).parse(" foo # bar\nbaz").unwrap().1,
            &[Identifier("foo".to_string()), Identifier("baz".to_string())]
        );
    }

    #[test]
    fn test_document() {
        assert!(document("version 1.0; graph foo() -> () {}").is_ok());
    }

    #[test]
    fn test_version() {
        p(version, "version 1.0;");
    }

    #[test]
    fn test_body() {
        p(body, "{}");
        p(body, "{foo=bar;}");
    }

    #[test]
    fn test_lvalue() {
        p(lvalue, "foo");
        p(lvalue, "foo,bar");
        p(lvalue, "foo , bar");
        p(lvalue, "(foo,bar)");
    }

    #[test]
    fn test_graph_def() {
        p(graph_def, "graph foo() -> () {}");
    }

    #[test]
    fn test_assignment() {
        p(assignment, "input = external(12);");
        p(assignment, "input = external(shape = [1, 3, 224, 224]);");
        p(assignment, "sigma = bias + alpha * box(sqr(input), size = size, normalize = true);");
        p(assignment, "output = offset + scale * (input - mean) / sqrt(variance + epsilon);");
        p(
            assignment,
            "size = [for i in range_of(output_size) yield output_size[i] * sampling_rate[i]];",
        );
        p(assignment, "r = scalar(2 ^ bits - 1 - integer(signed && symmetric));");
        p(assignment, "output, index = max_pool_with_index(input, size = size, border = border, padding = padding, stride = stride, dilation = dilation);");
    }

    #[test]
    fn test_invocation() {
        p(invocation, "external(12)");
        p(invocation, "sqrt(var + eps)");
    }

    #[test]
    fn test_arguments() {
        p(argument, "2");
        p(argument, "12");
        p(argument, "shape = [1, 3, 224, 224]");
    }

    #[test]
    fn test_rvalue() {
        p(rvalue, "12");
        p(rvalue, "(0, 0)");
        p(rvalue, "x ^ 2.0");
        p(rvalue, "1+2");
        p(rvalue, "1+sqrt(var)");
        p(rvalue, "1+sqrt(var+eps)");
        p(rvalue, "1 + sqrt(var + eps)");
        p(rvalue, "[for i in range_of(output_size) yield output_size[i] * sampling_rate[i]]");
        p(rvalue, "scalar(2 ^ (bits - 1) - integer(symmetric) if signed else 0)");
    }

    #[test]
    fn test_comprehenion() {
        p(comprehension_expr, "[for i in range_of(output_size) yield output_size * sampling_rate]");
    }

    // A full document with a deeply nested expression spread over many lines.
    #[test]
    fn test_freeze() {
        p(
            document,
            r#"
version 1.0;

graph y( x, s, bias ) -> ( y ) {
  x = external<scalar>(shape = [1, 2, 1, 3]);
  s = external<scalar>(shape = [2]);
  bias = external<scalar>(shape = [2]);
  y = add(
        mul(
            mul(
                sub(
                    x,
                    mul(
                        0.33333334,
                        sum_reduce(
                            x,
                            axes = [0, 2, 3]
                        )
                    )
                ),
                rsqrt(
                    add(
                        0.00001,
                        mul(
                            0.33333334,
                            sum_reduce(
                                square(
                                    sub(
                                        x,
                                        mul(
                                            0.33333334,
                                            sum_reduce(
                                                x,
                                                axes = [0, 2, 3]
                                            )
                                        )
                                    )
                                ),
                                axes = [0, 2, 3]
                            )
                        )
                    )
                )
            ),
            unsqueeze(
                unsqueeze(
                    unsqueeze(
                        s,
                        axes = [0]
                    ),
                axes = [2]
                ),
            axes = [2]
            )
        ),
        unsqueeze(
            unsqueeze(
                unsqueeze(
                    bias,
                    axes = [0]
                ),
                axes = [2]
            ),
            axes = [2]
        )
    );
}

"#,
        );
    }

    #[test]
    fn test_fragments() {
        p(
            fragments,
            r#"
            fragment add( x: tensor<scalar>, y: tensor<scalar> ) -> ( z: tensor<scalar> );
            fragment sub( x: tensor<scalar>, y: tensor<scalar> ) -> ( z: tensor<scalar> );
            "#,
        );
    }
}