kn_cuda_eval/
tester.rs

1use bytemuck::cast_slice_mut;
2use itertools::{enumerate, Itertools, zip_eq};
3
4use kn_cuda_sys::wrapper::handle::CudaDevice;
5use kn_graph::dispatch_dtensor_pair;
6use kn_graph::dtype::{DScalar, DTensor, DType, IntoDScalar, Tensor};
7use kn_graph::graph::{Graph, Value};
8use kn_graph::ndarray::{Dimension, IxDyn};
9
10use crate::executor::CudaExecutor;
11
12/// Check that the given graph produces the correct outputs as described by `check_data`,
13/// which typically comes from a `.bin` file next to the `.onnx` file.
14pub fn check_cudnn(graph: &Graph, check_data_bytes: &[u8]) {
15    let (batch_size, inputs, expected_outputs) = load_check_data(graph, check_data_bytes);
16    let outputs = eval_cudnn(graph, batch_size, &inputs, true);
17    assert_tensors_match(&expected_outputs, &outputs, true);
18}
19
20pub fn eval_cudnn(graph: &Graph, batch_size: usize, inputs: &[DTensor], print_executor: bool) -> Vec<DTensor> {
21    let mut executor = CudaExecutor::new(CudaDevice::new(0).unwrap(), graph, batch_size);
22    if print_executor {
23        println!("{:?}", executor);
24    }
25    executor.evaluate(inputs).to_owned()
26}
27
/// Absolute-difference threshold for float elements; an element is only flagged as a mismatch
/// when its absolute difference reaches this value AND its relative difference reaches
/// `TOLERANCE_REL_DIFF`.
const TOLERANCE_ABS_DIFF: f64 = 1e-3;
/// Relative-difference threshold for float elements (see `TOLERANCE_ABS_DIFF`).
const TOLERANCE_REL_DIFF: f64 = 1e-3;
/// Upper bound on the number of mismatching elements recorded per tensor for reporting.
const MAX_LOGGED_ERRORS: usize = 8;
31
32pub fn assert_tensors_match(expected: &[DTensor], actual: &[DTensor], print_match: bool) {
33    match check_tensors_match(expected, actual) {
34        Ok(Match {
35            diff_per_tensor: diff_per_output,
36        }) => {
37            if print_match {
38                for (i, diff) in enumerate(diff_per_output) {
39                    match diff {
40                        Difference::Float(DifferenceFloat {
41                            max_abs_diff,
42                            max_rel_diff,
43                        }) => {
44                            println!(
45                                "Output {} with shape {:?} and {:?} matched, max diff: abs {}, rel {}",
46                                i,
47                                actual[i].shape(),
48                                actual[i].dtype(),
49                                max_abs_diff,
50                                max_rel_diff
51                            );
52                        }
53                        Difference::IntMatch => {
54                            println!(
55                                "Output {} with shape {:?} and {:?} matched",
56                                i,
57                                actual[i].shape(),
58                                actual[i].dtype(),
59                            );
60                        }
61                    }
62                }
63            }
64        }
65        Err(Mismatch {
66            error_count,
67            total_count,
68            first_errors,
69        }) => {
70            eprintln!("Mismatch in {}/{} values:", error_count, total_count);
71
72            for error in &first_errors {
73                let Error {
74                    tensor,
75                    ref indices,
76                    expected_value,
77                    actual_value,
78                    more_omitted,
79                } = *error;
80
81                eprintln!(
82                    "  Wrong output value {:?}, expected {:?} at indices {:?} in tensor {} (shape {:?})",
83                    actual_value,
84                    expected_value,
85                    indices,
86                    tensor,
87                    expected[tensor].shape()
88                );
89
90                if more_omitted {
91                    eprintln!("  ...");
92                }
93            }
94
95            panic!("Output mismatch");
96        }
97    }
98}
99
100#[derive(Debug, Clone)]
101pub struct Match {
102    pub diff_per_tensor: Vec<Difference>,
103}
104
105#[derive(Debug, Clone)]
106pub enum Difference {
107    Float(DifferenceFloat),
108    IntMatch,
109}
110
// TODO: int/float enum? or just float/DScalar?
/// Largest deviations observed between an expected and an actual floating-point tensor.
#[derive(Clone, Copy, Debug)]
pub struct DifferenceFloat {
    /// Largest `|expected - actual| / |expected|` seen over the compared elements.
    pub max_rel_diff: f64,
    /// Largest `|expected - actual|` seen over the compared elements.
    pub max_abs_diff: f64,
}
117
118#[derive(Debug, Clone)]
119pub struct Mismatch {
120    pub error_count: u64,
121    pub total_count: u64,
122    pub first_errors: Vec<Error>,
123}
124
125#[derive(Debug, Clone)]
126pub struct Error {
127    pub tensor: usize,
128    pub indices: Vec<usize>,
129    pub expected_value: DScalar,
130    pub actual_value: DScalar,
131    pub more_omitted: bool,
132}
133
/// Running element and error counters accumulated over all tensors of one comparison.
#[derive(Clone, Debug, Default)]
pub struct Counts {
    total_element_count: u64,
    total_error_count: u64,
}
139
140pub fn check_tensors_match(expected: &[DTensor], actual: &[DTensor]) -> Result<Match, Mismatch> {
141    assert_eq!(expected.len(), actual.len(), "Wrong number of tensors");
142
143    let mut counts = Counts::default();
144
145    let mut diff_per_tensor = vec![];
146    let mut first_errors = vec![];
147
148    for (i, (expected_output, output)) in zip_eq(expected, actual).enumerate() {
149        let diff = check_tensor_match(i, expected_output, output, &mut counts, &mut first_errors);
150        diff_per_tensor.push(diff);
151    }
152
153    if counts.total_error_count == 0 {
154        Ok(Match { diff_per_tensor })
155    } else {
156        Err(Mismatch {
157            error_count: counts.total_error_count,
158            total_count: counts.total_element_count,
159            first_errors,
160        })
161    }
162}
163
164fn check_tensor_match(
165    i: usize,
166    expected_output: &DTensor,
167    output: &DTensor,
168    counts: &mut Counts,
169    first_errors: &mut Vec<Error>,
170) -> Difference {
171    assert_eq!(
172        expected_output.shape(),
173        output.shape(),
174        "Wrong output shape for tensor {}",
175        i
176    );
177    assert_eq!(
178        expected_output.dtype(),
179        output.dtype(),
180        "Wrong output dtype for tensor {}",
181        i
182    );
183    let dtype = expected_output.dtype();
184
185    match dtype {
186        DType::F32 => Difference::Float(check_tensor_match_approx(
187            i,
188            &expected_output.unwrap_f32().unwrap().mapv(|x| x as f64).into_shared(),
189            &output.unwrap_f32().unwrap().mapv(|x| x as f64).into_shared(),
190            counts,
191            first_errors,
192        )),
193        DType::F64 => Difference::Float(check_tensor_match_approx(
194            i,
195            expected_output.unwrap_f64().unwrap(),
196            output.unwrap_f64().unwrap(),
197            counts,
198            first_errors,
199        )),
200        DType::U8
201        | DType::U16
202        | DType::U32
203        | DType::U64
204        | DType::I8
205        | DType::I16
206        | DType::I32
207        | DType::I64
208        | DType::Bool => {
209            dispatch_dtensor_pair!(expected_output, output, |_T, _f, expected_output, output| {
210                check_tensor_match_exact(i, expected_output, output, counts, first_errors)
211            })
212        }
213    }
214}
215
216fn check_tensor_match_exact<T: IntoDScalar>(
217    i: usize,
218    expected_output: &Tensor<T>,
219    output: &Tensor<T>,
220    counts: &mut Counts,
221    first_errors: &mut Vec<Error>,
222) -> Difference {
223    assert!(T::DTYPE.is_int() || T::DTYPE.is_bool());
224
225    let mut current_error_count = 0;
226
227    for ((indices, &expected_value), &value) in zip_eq(expected_output.indexed_iter(), output.iter()) {
228        counts.total_element_count += 1;
229
230        if expected_value != value {
231            counts.total_error_count += 1;
232            current_error_count += 1;
233
234            if current_error_count < MAX_LOGGED_ERRORS {
235                first_errors.push(Error {
236                    tensor: i,
237                    indices: indices.slice().to_vec(),
238                    expected_value: expected_value.to_dscalar(),
239                    actual_value: value.to_dscalar(),
240                    more_omitted: false,
241                });
242            } else {
243                first_errors.last_mut().unwrap().more_omitted = true;
244            }
245        }
246    }
247
248    Difference::IntMatch
249}
250
251fn check_tensor_match_approx(
252    i: usize,
253    expected_output: &Tensor<f64>,
254    output: &Tensor<f64>,
255    counts: &mut Counts,
256    first_errors: &mut Vec<Error>,
257) -> DifferenceFloat {
258    let mut max_abs_diff = 0.0;
259    let mut max_rel_diff = 0.0;
260
261    let mut current_error_count = 0;
262
263    for ((indices, &expected_value), &value) in zip_eq(expected_output.indexed_iter(), output.iter()) {
264        let (abs_diff, rel_diff) = if expected_value == value || expected_value.is_nan() || value.is_nan() {
265            (0.0, 0.0)
266        } else {
267            let abs_diff = (expected_value - value).abs();
268            let rel_diff = abs_diff / expected_value.abs();
269            (abs_diff, rel_diff)
270        };
271
272        max_abs_diff = f64::max(max_abs_diff, abs_diff);
273        max_rel_diff = f64::max(max_rel_diff, rel_diff);
274
275        counts.total_element_count += 1;
276
277        let exceeds_tolerance = abs_diff >= TOLERANCE_ABS_DIFF && rel_diff >= TOLERANCE_REL_DIFF;
278        let nan_mismatch = expected_value.is_nan() != value.is_nan();
279
280        if exceeds_tolerance || nan_mismatch {
281            counts.total_error_count += 1;
282            current_error_count += 1;
283
284            if current_error_count < MAX_LOGGED_ERRORS {
285                first_errors.push(Error {
286                    tensor: i,
287                    indices: indices.slice().to_vec(),
288                    expected_value: expected_value.to_dscalar(),
289                    actual_value: value.to_dscalar(),
290                    more_omitted: false,
291                });
292            } else {
293                first_errors.last_mut().unwrap().more_omitted = true;
294            }
295        }
296    }
297
298    DifferenceFloat {
299        max_rel_diff,
300        max_abs_diff,
301    }
302}
303
304/// Load the check data into `(batch_size, inputs, expected_outputs)`.
305pub fn load_check_data(graph: &Graph, check_data_bytes: &[u8]) -> (usize, Vec<DTensor>, Vec<DTensor>) {
306    assert!(
307        !check_data_bytes.is_empty(),
308        "Check data must have at least one byte, the batch size"
309    );
310    let batch_size = check_data_bytes[0] as usize;
311
312    assert_eq!(
313        (check_data_bytes.len() - 1) % 4,
314        0,
315        "Data byte count must be multiple of 4 + 1 to be able to cast to float, got {}",
316        check_data_bytes.len()
317    );
318
319    // copy the data into a float array instead of just casting it to ensure it's properly aligned
320    let mut check_data = vec![0.0; (check_data_bytes.len() - 1) / 4];
321    cast_slice_mut(&mut check_data).copy_from_slice(&check_data_bytes[1..]);
322
323    let mut buf = &*check_data;
324    let inputs = load_check_values(graph, batch_size, &mut buf, graph.inputs());
325    let expected_outputs = load_check_values(graph, batch_size, &mut buf, graph.outputs());
326
327    assert!(buf.is_empty(), "Leftover elements in check data buffer: {}", buf.len());
328
329    (batch_size, inputs, expected_outputs)
330}
331
332/// Load the given values from the buffer while advancing it.
333fn load_check_values(graph: &Graph, batch_size: usize, buf: &mut &[f32], values: &[Value]) -> Vec<DTensor> {
334    // TODO support loading non-f32 values
335    values
336        .iter()
337        .map(|&value| {
338            let shape = graph[value].shape.eval(batch_size);
339            let tensor =
340                DTensor::F32(Tensor::from_shape_vec(IxDyn(&shape.dims), buf[0..shape.size()].to_vec()).unwrap());
341            *buf = &buf[shape.size()..];
342            tensor
343        })
344        .collect_vec()
345}