pdl_compiler/backends/rust/
test.rs

// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Generate Rust unit tests for canonical test vectors.

17use quote::{format_ident, quote};
18use serde::Serialize;
19
20use crate::backends::common::test::Packet;
21
22/// Convert a string of hexadecimal characters into a Rust vector of
23/// bytes.
24///
25/// The string `"80038302"` becomes `vec![0x80, 0x03, 0x83, 0x02]`.
26fn hexadecimal_to_vec(hex: &str) -> proc_macro2::TokenStream {
27    assert!(hex.len() % 2 == 0, "Expects an even number of hex digits");
28    let bytes = hex.as_bytes().chunks_exact(2).map(|chunk| {
29        let number = format!("0x{}", std::str::from_utf8(chunk).unwrap());
30        syn::parse_str::<syn::LitInt>(&number).unwrap()
31    });
32
33    quote! {
34        vec![#(#bytes),*]
35    }
36}
37
38/// Convert `value` to a JSON string literal.
39///
40/// The string literal is a raw literal to avoid escaping
41/// double-quotes.
42fn to_json<T: Serialize>(value: &T) -> syn::LitStr {
43    let json = serde_json::to_string(value).unwrap();
44    assert!(!json.contains("\"#"), "Please increase number of # for {json:?}");
45    syn::parse_str::<syn::LitStr>(&format!("r#\" {json} \"#")).unwrap()
46}
47
48fn generate_unit_tests(input: &str, packet_names: &[&str]) -> Result<String, String> {
49    eprintln!("Reading test vectors from {input}, will use {} packets", packet_names.len());
50
51    let data = std::fs::read_to_string(input)
52        .unwrap_or_else(|err| panic!("Could not read {input}: {err}"));
53    let packets: Vec<Packet> = serde_json::from_str(&data).expect("Could not parse JSON");
54
55    let mut tests = Vec::new();
56    for packet in &packets {
57        for (i, test_vector) in packet.tests.iter().enumerate() {
58            let test_packet = test_vector.packet.as_deref().unwrap_or(packet.name.as_str());
59            if !packet_names.contains(&test_packet) {
60                eprintln!("Skipping packet {test_packet}");
61                continue;
62            }
63            eprintln!("Generating tests for packet {test_packet}");
64
65            let parse_test_name = format_ident!(
66                "test_parse_{}_vector_{}_0x{}",
67                test_packet,
68                i + 1,
69                &test_vector.packed
70            );
71            let serialize_test_name = format_ident!(
72                "test_serialize_{}_vector_{}_0x{}",
73                test_packet,
74                i + 1,
75                &test_vector.packed
76            );
77            let packed = hexadecimal_to_vec(&test_vector.packed);
78            let packet_name = format_ident!("{}", test_packet);
79
80            let object = test_vector.unpacked.as_object().unwrap_or_else(|| {
81                panic!("Expected test vector object, found: {}", test_vector.unpacked)
82            });
83            let assertions = object.iter().map(|(key, value)| {
84                let getter = format_ident!("{key}");
85                let expected = format_ident!("expected_{key}");
86                let json = to_json(&value);
87                quote! {
88                    let #expected: serde_json::Value = serde_json::from_str(#json)
89                        .expect("Could not create expected value from canonical JSON data");
90                    assert_eq!(json!(actual.#getter()), #expected);
91                }
92            });
93
94            let json = to_json(&object);
95            tests.push(quote! {
96                #[test]
97                fn #parse_test_name() {
98                    let packed = #packed;
99                    let actual = #packet_name::decode_full(&packed).unwrap();
100                    assert_eq!(actual.encoded_len(), packed.len());
101                    #(#assertions)*
102                }
103
104                #[test]
105                fn #serialize_test_name() {
106                    let packet: #packet_name = serde_json::from_str(#json)
107                        .expect("Could not create packet from canonical JSON data");
108                    let packed: Vec<u8> = #packed;
109                    assert_eq!(packet.encoded_len(), packed.len());
110                    assert_eq!(packet.encode_to_vec(), Ok(packed));
111                }
112            });
113        }
114    }
115
116    // TODO(mgeisler): make the generated code clean from warnings.
117    let code = quote! {
118        #[allow(warnings, missing_docs)]
119        #[cfg(test)]
120        mod test {
121            use pdl_runtime::Packet;
122            use serde_json::json;
123            use super::*;
124
125            #(#tests)*
126        }
127    };
128    let syntax_tree = syn::parse2::<syn::File>(code).expect("Could not parse {code:#?}");
129    Ok(prettyplease::unparse(&syntax_tree))
130}
131
132pub fn generate_tests(input_file: &str) -> Result<String, String> {
133    // TODO(mgeisler): remove the `packet_names` argument when we
134    // support all canonical packets.
135    generate_unit_tests(
136        input_file,
137        &[
138            "EnumChild_A",
139            "EnumChild_B",
140            "Packet_Array_Field_ByteElement_ConstantSize",
141            "Packet_Array_Field_ByteElement_UnknownSize",
142            "Packet_Array_Field_ByteElement_VariableCount",
143            "Packet_Array_Field_ByteElement_VariableSize",
144            "Packet_Array_Field_EnumElement",
145            "Packet_Array_Field_EnumElement_ConstantSize",
146            "Packet_Array_Field_EnumElement_UnknownSize",
147            "Packet_Array_Field_EnumElement_VariableCount",
148            "Packet_Array_Field_EnumElement_VariableCount",
149            "Packet_Array_Field_ScalarElement",
150            "Packet_Array_Field_ScalarElement_ConstantSize",
151            "Packet_Array_Field_ScalarElement_UnknownSize",
152            "Packet_Array_Field_ScalarElement_VariableCount",
153            "Packet_Array_Field_ScalarElement_VariableSize",
154            "Packet_Array_Field_SizedElement_ConstantSize",
155            "Packet_Array_Field_SizedElement_UnknownSize",
156            "Packet_Array_Field_SizedElement_VariableCount",
157            "Packet_Array_Field_SizedElement_VariableSize",
158            "Packet_Array_Field_UnsizedElement_ConstantSize",
159            "Packet_Array_Field_UnsizedElement_UnknownSize",
160            "Packet_Array_Field_UnsizedElement_VariableCount",
161            "Packet_Array_Field_UnsizedElement_VariableSize",
162            "Packet_Array_Field_SizedElement_VariableSize_Padded",
163            "Packet_Array_Field_UnsizedElement_VariableCount_Padded",
164            "Packet_Array_Field_VariableElementSize_ConstantSize",
165            "Packet_Array_Field_VariableElementSize_VariableSize",
166            "Packet_Array_Field_VariableElementSize_VariableCount",
167            "Packet_Array_Field_VariableElementSize_UnknownSize",
168            "Packet_Optional_Scalar_Field",
169            "Packet_Optional_Enum_Field",
170            "Packet_Optional_Struct_Field",
171            "Packet_Body_Field_UnknownSize",
172            "Packet_Body_Field_UnknownSize_Terminal",
173            "Packet_Body_Field_VariableSize",
174            "Packet_Count_Field",
175            "Packet_Enum8_Field",
176            "Packet_Enum_Field",
177            "Packet_FixedEnum_Field",
178            "Packet_FixedScalar_Field",
179            "Packet_Payload_Field_UnknownSize",
180            "Packet_Payload_Field_UnknownSize_Terminal",
181            "Packet_Payload_Field_VariableSize",
182            "Packet_Payload_Field_SizeModifier",
183            "Packet_Reserved_Field",
184            "Packet_Scalar_Field",
185            "Packet_Size_Field",
186            "Packet_Struct_Field",
187            "ScalarChild_A",
188            "ScalarChild_B",
189            "Struct_Count_Field",
190            "Struct_Array_Field_ByteElement_ConstantSize",
191            "Struct_Array_Field_ByteElement_UnknownSize",
192            "Struct_Array_Field_ByteElement_UnknownSize",
193            "Struct_Array_Field_ByteElement_VariableCount",
194            "Struct_Array_Field_ByteElement_VariableCount",
195            "Struct_Array_Field_ByteElement_VariableSize",
196            "Struct_Array_Field_ByteElement_VariableSize",
197            "Struct_Array_Field_EnumElement_ConstantSize",
198            "Struct_Array_Field_EnumElement_UnknownSize",
199            "Struct_Array_Field_EnumElement_UnknownSize",
200            "Struct_Array_Field_EnumElement_VariableCount",
201            "Struct_Array_Field_EnumElement_VariableCount",
202            "Struct_Array_Field_EnumElement_VariableSize",
203            "Struct_Array_Field_EnumElement_VariableSize",
204            "Struct_Array_Field_ScalarElement_ConstantSize",
205            "Struct_Array_Field_ScalarElement_UnknownSize",
206            "Struct_Array_Field_ScalarElement_UnknownSize",
207            "Struct_Array_Field_ScalarElement_VariableCount",
208            "Struct_Array_Field_ScalarElement_VariableCount",
209            "Struct_Array_Field_ScalarElement_VariableSize",
210            "Struct_Array_Field_ScalarElement_VariableSize",
211            "Struct_Array_Field_SizedElement_ConstantSize",
212            "Struct_Array_Field_SizedElement_UnknownSize",
213            "Struct_Array_Field_SizedElement_UnknownSize",
214            "Struct_Array_Field_SizedElement_VariableCount",
215            "Struct_Array_Field_SizedElement_VariableCount",
216            "Struct_Array_Field_SizedElement_VariableSize",
217            "Struct_Array_Field_SizedElement_VariableSize",
218            "Struct_Array_Field_UnsizedElement_ConstantSize",
219            "Struct_Array_Field_UnsizedElement_UnknownSize",
220            "Struct_Array_Field_UnsizedElement_UnknownSize",
221            "Struct_Array_Field_UnsizedElement_VariableCount",
222            "Struct_Array_Field_UnsizedElement_VariableCount",
223            "Struct_Array_Field_UnsizedElement_VariableSize",
224            "Struct_Array_Field_UnsizedElement_VariableSize",
225            "Struct_Array_Field_SizedElement_VariableSize_Padded",
226            "Struct_Array_Field_UnsizedElement_VariableCount_Padded",
227            "Struct_Optional_Scalar_Field",
228            "Struct_Optional_Enum_Field",
229            "Struct_Optional_Struct_Field",
230            "Struct_Enum_Field",
231            "Struct_FixedEnum_Field",
232            "Struct_FixedScalar_Field",
233            "Struct_Size_Field",
234            "Struct_Struct_Field",
235            "Enum_Incomplete_Truncated_Closed",
236            "Enum_Incomplete_Truncated_Open",
237            "Enum_Incomplete_Truncated_Closed_WithRange",
238            "Enum_Incomplete_Truncated_Open_WithRange",
239            "Enum_Complete_Truncated",
240            "Enum_Complete_Truncated_WithRange",
241            "Enum_Complete_WithRange",
242        ],
243    )
244}