1use bytemuck::cast_slice_mut;
2use itertools::{enumerate, Itertools, zip_eq};
3
4use kn_cuda_sys::wrapper::handle::CudaDevice;
5use kn_graph::dispatch_dtensor_pair;
6use kn_graph::dtype::{DScalar, DTensor, DType, IntoDScalar, Tensor};
7use kn_graph::graph::{Graph, Value};
8use kn_graph::ndarray::{Dimension, IxDyn};
9
10use crate::executor::CudaExecutor;
11
12pub fn check_cudnn(graph: &Graph, check_data_bytes: &[u8]) {
15 let (batch_size, inputs, expected_outputs) = load_check_data(graph, check_data_bytes);
16 let outputs = eval_cudnn(graph, batch_size, &inputs, true);
17 assert_tensors_match(&expected_outputs, &outputs, true);
18}
19
20pub fn eval_cudnn(graph: &Graph, batch_size: usize, inputs: &[DTensor], print_executor: bool) -> Vec<DTensor> {
21 let mut executor = CudaExecutor::new(CudaDevice::new(0).unwrap(), graph, batch_size);
22 if print_executor {
23 println!("{:?}", executor);
24 }
25 executor.evaluate(inputs).to_owned()
26}
27
/// Absolute difference at or above which a float element may count as a mismatch.
/// An element is only flagged when it also reaches `TOLERANCE_REL_DIFF`.
const TOLERANCE_ABS_DIFF: f64 = 1e-3;
/// Difference relative to the expected value at or above which a float element may count as a
/// mismatch. An element is only flagged when it also reaches `TOLERANCE_ABS_DIFF`.
const TOLERANCE_REL_DIFF: f64 = 1e-3;
/// Upper bound on how many mismatching elements are recorded per tensor for the error report.
const MAX_LOGGED_ERRORS: usize = 8;
31
32pub fn assert_tensors_match(expected: &[DTensor], actual: &[DTensor], print_match: bool) {
33 match check_tensors_match(expected, actual) {
34 Ok(Match {
35 diff_per_tensor: diff_per_output,
36 }) => {
37 if print_match {
38 for (i, diff) in enumerate(diff_per_output) {
39 match diff {
40 Difference::Float(DifferenceFloat {
41 max_abs_diff,
42 max_rel_diff,
43 }) => {
44 println!(
45 "Output {} with shape {:?} and {:?} matched, max diff: abs {}, rel {}",
46 i,
47 actual[i].shape(),
48 actual[i].dtype(),
49 max_abs_diff,
50 max_rel_diff
51 );
52 }
53 Difference::IntMatch => {
54 println!(
55 "Output {} with shape {:?} and {:?} matched",
56 i,
57 actual[i].shape(),
58 actual[i].dtype(),
59 );
60 }
61 }
62 }
63 }
64 }
65 Err(Mismatch {
66 error_count,
67 total_count,
68 first_errors,
69 }) => {
70 eprintln!("Mismatch in {}/{} values:", error_count, total_count);
71
72 for error in &first_errors {
73 let Error {
74 tensor,
75 ref indices,
76 expected_value,
77 actual_value,
78 more_omitted,
79 } = *error;
80
81 eprintln!(
82 " Wrong output value {:?}, expected {:?} at indices {:?} in tensor {} (shape {:?})",
83 actual_value,
84 expected_value,
85 indices,
86 tensor,
87 expected[tensor].shape()
88 );
89
90 if more_omitted {
91 eprintln!(" ...");
92 }
93 }
94
95 panic!("Output mismatch");
96 }
97 }
98}
99
/// Successful comparison result from `check_tensors_match`: every element was within tolerance.
#[derive(Debug, Clone)]
pub struct Match {
    /// One entry per compared tensor, in the same order as the inputs to the comparison.
    pub diff_per_tensor: Vec<Difference>,
}
104
/// Per-tensor summary of how closely a matching tensor agreed with its expectation.
#[derive(Debug, Clone)]
pub enum Difference {
    /// Float tensor (`F32` or `F64`), compared approximately; carries the largest differences seen.
    Float(DifferenceFloat),
    /// Integer or boolean tensor, compared for exact equality.
    IntMatch,
}
110
/// Largest differences observed while approximately comparing one float tensor.
#[derive(Debug, Copy, Clone)]
pub struct DifferenceFloat {
    /// Largest `|expected - actual| / |expected|` seen; may be infinite when an expected value is 0.
    pub max_rel_diff: f64,
    /// Largest `|expected - actual|` seen.
    pub max_abs_diff: f64,
}
117
/// Failed comparison result from `check_tensors_match`: at least one element was out of tolerance.
#[derive(Debug, Clone)]
pub struct Mismatch {
    /// Number of mismatching elements across all tensors.
    pub error_count: u64,
    /// Number of elements compared across all tensors.
    pub total_count: u64,
    /// A capped sample of the mismatching elements, in encounter order.
    pub first_errors: Vec<Error>,
}
124
/// A single mismatching element recorded during comparison.
#[derive(Debug, Clone)]
pub struct Error {
    /// Position of the tensor within the compared tensor list.
    pub tensor: usize,
    /// Multi-dimensional index of the element inside that tensor.
    pub indices: Vec<usize>,
    /// Value from the expected tensor.
    pub expected_value: DScalar,
    /// Value from the actual tensor.
    pub actual_value: DScalar,
    /// Set when further mismatches in the same tensor were encountered after the logging cap and
    /// therefore not recorded.
    pub more_omitted: bool,
}
133
/// Running totals accumulated over every tensor compared by `check_tensors_match`.
#[derive(Default, Debug, Clone)]
pub struct Counts {
    // Number of element pairs compared so far.
    total_element_count: u64,
    // Number of element pairs that were flagged as mismatches so far.
    total_error_count: u64,
}
139
140pub fn check_tensors_match(expected: &[DTensor], actual: &[DTensor]) -> Result<Match, Mismatch> {
141 assert_eq!(expected.len(), actual.len(), "Wrong number of tensors");
142
143 let mut counts = Counts::default();
144
145 let mut diff_per_tensor = vec![];
146 let mut first_errors = vec![];
147
148 for (i, (expected_output, output)) in zip_eq(expected, actual).enumerate() {
149 let diff = check_tensor_match(i, expected_output, output, &mut counts, &mut first_errors);
150 diff_per_tensor.push(diff);
151 }
152
153 if counts.total_error_count == 0 {
154 Ok(Match { diff_per_tensor })
155 } else {
156 Err(Mismatch {
157 error_count: counts.total_error_count,
158 total_count: counts.total_element_count,
159 first_errors,
160 })
161 }
162}
163
164fn check_tensor_match(
165 i: usize,
166 expected_output: &DTensor,
167 output: &DTensor,
168 counts: &mut Counts,
169 first_errors: &mut Vec<Error>,
170) -> Difference {
171 assert_eq!(
172 expected_output.shape(),
173 output.shape(),
174 "Wrong output shape for tensor {}",
175 i
176 );
177 assert_eq!(
178 expected_output.dtype(),
179 output.dtype(),
180 "Wrong output dtype for tensor {}",
181 i
182 );
183 let dtype = expected_output.dtype();
184
185 match dtype {
186 DType::F32 => Difference::Float(check_tensor_match_approx(
187 i,
188 &expected_output.unwrap_f32().unwrap().mapv(|x| x as f64).into_shared(),
189 &output.unwrap_f32().unwrap().mapv(|x| x as f64).into_shared(),
190 counts,
191 first_errors,
192 )),
193 DType::F64 => Difference::Float(check_tensor_match_approx(
194 i,
195 expected_output.unwrap_f64().unwrap(),
196 output.unwrap_f64().unwrap(),
197 counts,
198 first_errors,
199 )),
200 DType::U8
201 | DType::U16
202 | DType::U32
203 | DType::U64
204 | DType::I8
205 | DType::I16
206 | DType::I32
207 | DType::I64
208 | DType::Bool => {
209 dispatch_dtensor_pair!(expected_output, output, |_T, _f, expected_output, output| {
210 check_tensor_match_exact(i, expected_output, output, counts, first_errors)
211 })
212 }
213 }
214}
215
216fn check_tensor_match_exact<T: IntoDScalar>(
217 i: usize,
218 expected_output: &Tensor<T>,
219 output: &Tensor<T>,
220 counts: &mut Counts,
221 first_errors: &mut Vec<Error>,
222) -> Difference {
223 assert!(T::DTYPE.is_int() || T::DTYPE.is_bool());
224
225 let mut current_error_count = 0;
226
227 for ((indices, &expected_value), &value) in zip_eq(expected_output.indexed_iter(), output.iter()) {
228 counts.total_element_count += 1;
229
230 if expected_value != value {
231 counts.total_error_count += 1;
232 current_error_count += 1;
233
234 if current_error_count < MAX_LOGGED_ERRORS {
235 first_errors.push(Error {
236 tensor: i,
237 indices: indices.slice().to_vec(),
238 expected_value: expected_value.to_dscalar(),
239 actual_value: value.to_dscalar(),
240 more_omitted: false,
241 });
242 } else {
243 first_errors.last_mut().unwrap().more_omitted = true;
244 }
245 }
246 }
247
248 Difference::IntMatch
249}
250
251fn check_tensor_match_approx(
252 i: usize,
253 expected_output: &Tensor<f64>,
254 output: &Tensor<f64>,
255 counts: &mut Counts,
256 first_errors: &mut Vec<Error>,
257) -> DifferenceFloat {
258 let mut max_abs_diff = 0.0;
259 let mut max_rel_diff = 0.0;
260
261 let mut current_error_count = 0;
262
263 for ((indices, &expected_value), &value) in zip_eq(expected_output.indexed_iter(), output.iter()) {
264 let (abs_diff, rel_diff) = if expected_value == value || expected_value.is_nan() || value.is_nan() {
265 (0.0, 0.0)
266 } else {
267 let abs_diff = (expected_value - value).abs();
268 let rel_diff = abs_diff / expected_value.abs();
269 (abs_diff, rel_diff)
270 };
271
272 max_abs_diff = f64::max(max_abs_diff, abs_diff);
273 max_rel_diff = f64::max(max_rel_diff, rel_diff);
274
275 counts.total_element_count += 1;
276
277 let exceeds_tolerance = abs_diff >= TOLERANCE_ABS_DIFF && rel_diff >= TOLERANCE_REL_DIFF;
278 let nan_mismatch = expected_value.is_nan() != value.is_nan();
279
280 if exceeds_tolerance || nan_mismatch {
281 counts.total_error_count += 1;
282 current_error_count += 1;
283
284 if current_error_count < MAX_LOGGED_ERRORS {
285 first_errors.push(Error {
286 tensor: i,
287 indices: indices.slice().to_vec(),
288 expected_value: expected_value.to_dscalar(),
289 actual_value: value.to_dscalar(),
290 more_omitted: false,
291 });
292 } else {
293 first_errors.last_mut().unwrap().more_omitted = true;
294 }
295 }
296 }
297
298 DifferenceFloat {
299 max_rel_diff,
300 max_abs_diff,
301 }
302}
303
304pub fn load_check_data(graph: &Graph, check_data_bytes: &[u8]) -> (usize, Vec<DTensor>, Vec<DTensor>) {
306 assert!(
307 !check_data_bytes.is_empty(),
308 "Check data must have at least one byte, the batch size"
309 );
310 let batch_size = check_data_bytes[0] as usize;
311
312 assert_eq!(
313 (check_data_bytes.len() - 1) % 4,
314 0,
315 "Data byte count must be multiple of 4 + 1 to be able to cast to float, got {}",
316 check_data_bytes.len()
317 );
318
319 let mut check_data = vec![0.0; (check_data_bytes.len() - 1) / 4];
321 cast_slice_mut(&mut check_data).copy_from_slice(&check_data_bytes[1..]);
322
323 let mut buf = &*check_data;
324 let inputs = load_check_values(graph, batch_size, &mut buf, graph.inputs());
325 let expected_outputs = load_check_values(graph, batch_size, &mut buf, graph.outputs());
326
327 assert!(buf.is_empty(), "Leftover elements in check data buffer: {}", buf.len());
328
329 (batch_size, inputs, expected_outputs)
330}
331
332fn load_check_values(graph: &Graph, batch_size: usize, buf: &mut &[f32], values: &[Value]) -> Vec<DTensor> {
334 values
336 .iter()
337 .map(|&value| {
338 let shape = graph[value].shape.eval(batch_size);
339 let tensor =
340 DTensor::F32(Tensor::from_shape_vec(IxDyn(&shape.dims), buf[0..shape.size()].to_vec()).unwrap());
341 *buf = &buf[shape.size()..];
342 tensor
343 })
344 .collect_vec()
345}