pdl_compiler/backends/rust/
test.rs

1// Copyright 2023 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Generate Rust unit tests for canonical test vectors.
16
17use quote::{format_ident, quote};
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20
21#[derive(Debug, Deserialize)]
22struct Packet {
23    #[serde(rename = "packet")]
24    name: String,
25    tests: Vec<TestVector>,
26}
27
28#[derive(Debug, Deserialize)]
29struct TestVector {
30    packed: String,
31    unpacked: Value,
32    packet: Option<String>,
33}
34
35/// Convert a string of hexadecimal characters into a Rust vector of
36/// bytes.
37///
38/// The string `"80038302"` becomes `vec![0x80, 0x03, 0x83, 0x02]`.
39fn hexadecimal_to_vec(hex: &str) -> proc_macro2::TokenStream {
40    assert!(hex.len() % 2 == 0, "Expects an even number of hex digits");
41    let bytes = hex.as_bytes().chunks_exact(2).map(|chunk| {
42        let number = format!("0x{}", std::str::from_utf8(chunk).unwrap());
43        syn::parse_str::<syn::LitInt>(&number).unwrap()
44    });
45
46    quote! {
47        vec![#(#bytes),*]
48    }
49}
50
51/// Convert `value` to a JSON string literal.
52///
53/// The string literal is a raw literal to avoid escaping
54/// double-quotes.
55fn to_json<T: Serialize>(value: &T) -> syn::LitStr {
56    let json = serde_json::to_string(value).unwrap();
57    assert!(!json.contains("\"#"), "Please increase number of # for {json:?}");
58    syn::parse_str::<syn::LitStr>(&format!("r#\" {json} \"#")).unwrap()
59}
60
61fn generate_unit_tests(input: &str, packet_names: &[&str]) -> Result<String, String> {
62    eprintln!("Reading test vectors from {input}, will use {} packets", packet_names.len());
63
64    let data = std::fs::read_to_string(input)
65        .unwrap_or_else(|err| panic!("Could not read {input}: {err}"));
66    let packets: Vec<Packet> = serde_json::from_str(&data).expect("Could not parse JSON");
67
68    let mut tests = Vec::new();
69    for packet in &packets {
70        for (i, test_vector) in packet.tests.iter().enumerate() {
71            let test_packet = test_vector.packet.as_deref().unwrap_or(packet.name.as_str());
72            if !packet_names.contains(&test_packet) {
73                eprintln!("Skipping packet {}", test_packet);
74                continue;
75            }
76            eprintln!("Generating tests for packet {}", test_packet);
77
78            let parse_test_name = format_ident!(
79                "test_parse_{}_vector_{}_0x{}",
80                test_packet,
81                i + 1,
82                &test_vector.packed
83            );
84            let serialize_test_name = format_ident!(
85                "test_serialize_{}_vector_{}_0x{}",
86                test_packet,
87                i + 1,
88                &test_vector.packed
89            );
90            let packed = hexadecimal_to_vec(&test_vector.packed);
91            let packet_name = format_ident!("{}", test_packet);
92
93            let object = test_vector.unpacked.as_object().unwrap_or_else(|| {
94                panic!("Expected test vector object, found: {}", test_vector.unpacked)
95            });
96            let assertions = object.iter().map(|(key, value)| {
97                let getter = format_ident!("{key}");
98                let expected = format_ident!("expected_{key}");
99                let json = to_json(&value);
100                quote! {
101                    let #expected: serde_json::Value = serde_json::from_str(#json)
102                        .expect("Could not create expected value from canonical JSON data");
103                    assert_eq!(json!(actual.#getter()), #expected);
104                }
105            });
106
107            let json = to_json(&object);
108            tests.push(quote! {
109                #[test]
110                fn #parse_test_name() {
111                    let packed = #packed;
112                    let actual = #packet_name::decode_full(&packed).unwrap();
113                    assert_eq!(actual.encoded_len(), packed.len());
114                    #(#assertions)*
115                }
116
117                #[test]
118                fn #serialize_test_name() {
119                    let packet: #packet_name = serde_json::from_str(#json)
120                        .expect("Could not create packet from canonical JSON data");
121                    let packed: Vec<u8> = #packed;
122                    assert_eq!(packet.encoded_len(), packed.len());
123                    assert_eq!(packet.encode_to_vec(), Ok(packed));
124                }
125            });
126        }
127    }
128
129    // TODO(mgeisler): make the generated code clean from warnings.
130    let code = quote! {
131        #[allow(warnings, missing_docs)]
132        #[cfg(test)]
133        mod test {
134            use pdl_runtime::Packet;
135            use serde_json::json;
136            use super::*;
137
138            #(#tests)*
139        }
140    };
141    let syntax_tree = syn::parse2::<syn::File>(code).expect("Could not parse {code:#?}");
142    Ok(prettyplease::unparse(&syntax_tree))
143}
144
145pub fn generate_tests(input_file: &str) -> Result<String, String> {
146    // TODO(mgeisler): remove the `packet_names` argument when we
147    // support all canonical packets.
148    generate_unit_tests(
149        input_file,
150        &[
151            "EnumChild_A",
152            "EnumChild_B",
153            "Packet_Array_Field_ByteElement_ConstantSize",
154            "Packet_Array_Field_ByteElement_UnknownSize",
155            "Packet_Array_Field_ByteElement_VariableCount",
156            "Packet_Array_Field_ByteElement_VariableSize",
157            "Packet_Array_Field_EnumElement",
158            "Packet_Array_Field_EnumElement_ConstantSize",
159            "Packet_Array_Field_EnumElement_UnknownSize",
160            "Packet_Array_Field_EnumElement_VariableCount",
161            "Packet_Array_Field_EnumElement_VariableCount",
162            "Packet_Array_Field_ScalarElement",
163            "Packet_Array_Field_ScalarElement_ConstantSize",
164            "Packet_Array_Field_ScalarElement_UnknownSize",
165            "Packet_Array_Field_ScalarElement_VariableCount",
166            "Packet_Array_Field_ScalarElement_VariableSize",
167            "Packet_Array_Field_SizedElement_ConstantSize",
168            "Packet_Array_Field_SizedElement_UnknownSize",
169            "Packet_Array_Field_SizedElement_VariableCount",
170            "Packet_Array_Field_SizedElement_VariableSize",
171            "Packet_Array_Field_UnsizedElement_ConstantSize",
172            "Packet_Array_Field_UnsizedElement_UnknownSize",
173            "Packet_Array_Field_UnsizedElement_VariableCount",
174            "Packet_Array_Field_UnsizedElement_VariableSize",
175            "Packet_Array_Field_SizedElement_VariableSize_Padded",
176            "Packet_Array_Field_UnsizedElement_VariableCount_Padded",
177            "Packet_Array_Field_VariableElementSize_ConstantSize",
178            "Packet_Array_Field_VariableElementSize_VariableSize",
179            "Packet_Array_Field_VariableElementSize_VariableCount",
180            "Packet_Array_Field_VariableElementSize_UnknownSize",
181            "Packet_Optional_Scalar_Field",
182            "Packet_Optional_Enum_Field",
183            "Packet_Optional_Struct_Field",
184            "Packet_Body_Field_UnknownSize",
185            "Packet_Body_Field_UnknownSize_Terminal",
186            "Packet_Body_Field_VariableSize",
187            "Packet_Count_Field",
188            "Packet_Enum8_Field",
189            "Packet_Enum_Field",
190            "Packet_FixedEnum_Field",
191            "Packet_FixedScalar_Field",
192            "Packet_Payload_Field_UnknownSize",
193            "Packet_Payload_Field_UnknownSize_Terminal",
194            "Packet_Payload_Field_VariableSize",
195            "Packet_Payload_Field_SizeModifier",
196            "Packet_Reserved_Field",
197            "Packet_Scalar_Field",
198            "Packet_Size_Field",
199            "Packet_Struct_Field",
200            "ScalarChild_A",
201            "ScalarChild_B",
202            "Struct_Count_Field",
203            "Struct_Array_Field_ByteElement_ConstantSize",
204            "Struct_Array_Field_ByteElement_UnknownSize",
205            "Struct_Array_Field_ByteElement_UnknownSize",
206            "Struct_Array_Field_ByteElement_VariableCount",
207            "Struct_Array_Field_ByteElement_VariableCount",
208            "Struct_Array_Field_ByteElement_VariableSize",
209            "Struct_Array_Field_ByteElement_VariableSize",
210            "Struct_Array_Field_EnumElement_ConstantSize",
211            "Struct_Array_Field_EnumElement_UnknownSize",
212            "Struct_Array_Field_EnumElement_UnknownSize",
213            "Struct_Array_Field_EnumElement_VariableCount",
214            "Struct_Array_Field_EnumElement_VariableCount",
215            "Struct_Array_Field_EnumElement_VariableSize",
216            "Struct_Array_Field_EnumElement_VariableSize",
217            "Struct_Array_Field_ScalarElement_ConstantSize",
218            "Struct_Array_Field_ScalarElement_UnknownSize",
219            "Struct_Array_Field_ScalarElement_UnknownSize",
220            "Struct_Array_Field_ScalarElement_VariableCount",
221            "Struct_Array_Field_ScalarElement_VariableCount",
222            "Struct_Array_Field_ScalarElement_VariableSize",
223            "Struct_Array_Field_ScalarElement_VariableSize",
224            "Struct_Array_Field_SizedElement_ConstantSize",
225            "Struct_Array_Field_SizedElement_UnknownSize",
226            "Struct_Array_Field_SizedElement_UnknownSize",
227            "Struct_Array_Field_SizedElement_VariableCount",
228            "Struct_Array_Field_SizedElement_VariableCount",
229            "Struct_Array_Field_SizedElement_VariableSize",
230            "Struct_Array_Field_SizedElement_VariableSize",
231            "Struct_Array_Field_UnsizedElement_ConstantSize",
232            "Struct_Array_Field_UnsizedElement_UnknownSize",
233            "Struct_Array_Field_UnsizedElement_UnknownSize",
234            "Struct_Array_Field_UnsizedElement_VariableCount",
235            "Struct_Array_Field_UnsizedElement_VariableCount",
236            "Struct_Array_Field_UnsizedElement_VariableSize",
237            "Struct_Array_Field_UnsizedElement_VariableSize",
238            "Struct_Array_Field_SizedElement_VariableSize_Padded",
239            "Struct_Array_Field_UnsizedElement_VariableCount_Padded",
240            "Struct_Optional_Scalar_Field",
241            "Struct_Optional_Enum_Field",
242            "Struct_Optional_Struct_Field",
243            "Struct_Enum_Field",
244            "Struct_FixedEnum_Field",
245            "Struct_FixedScalar_Field",
246            "Struct_Size_Field",
247            "Struct_Struct_Field",
248            "Enum_Incomplete_Truncated_Closed",
249            "Enum_Incomplete_Truncated_Open",
250            "Enum_Incomplete_Truncated_Closed_WithRange",
251            "Enum_Incomplete_Truncated_Open_WithRange",
252            "Enum_Complete_Truncated",
253            "Enum_Complete_Truncated_WithRange",
254            "Enum_Complete_WithRange",
255        ],
256    )
257}