kn_graph/
cpu.rs

1use std::cmp::min;
2use std::fmt::{Debug, Display, Formatter};
3use std::time::Instant;
4
5use indexmap::IndexMap;
6use itertools::Itertools;
7use ndarray::{
8    ArcArray, Array3, Array4, ArrayView, ArrayView3, ArrayView4, Ix3, Ix4, IxDyn, LinalgScalar, s, SliceInfo,
9    SliceInfoElem, Zip,
10};
11
12use crate::dtype::{
13    dispatch_dtensor, dispatch_dtype, DTensor, DType, IntoDScalar, map_dtensor, map_dtensor_pair, Tensor,
14};
15use crate::graph::{ConvDetails, Graph, Operation, SliceRange, Value, ValueInfo};
16use crate::ndarray::{Array, ArrayBase, Axis};
17use crate::shape::ConcreteShape;
18use crate::wrap_debug::WrapDebug;
19
20pub fn cpu_eval_graph(graph: &Graph, batch_size: usize, inputs: &[DTensor]) -> Vec<DTensor> {
21    let exec = cpu_eval_graph_exec(graph, batch_size, inputs, false);
22    exec.output_tensors()
23}
24
25/// Evaluate the given graph on the CPU, with the given batch size and inputs,
26/// returning the full execution state including profiling information.
27///
28/// Prefer using [cpu_eval_graph] if only the output are necessary.
29///
30/// `keep_all` controls whether all intermediate tensors are kept in memory,
31/// or freed as soon as they are no longer necessary.
32pub fn cpu_eval_graph_exec(graph: &Graph, batch_size: usize, inputs: &[DTensor], keep_all: bool) -> ExecutionInfo {
33    assert_eq!(
34        graph.inputs().len(),
35        inputs.len(),
36        "Wrong input count, graph has {} but got {}",
37        graph.inputs().len(),
38        inputs.len()
39    );
40
41    let mut map: IndexMap<Value, CalculatedValue> = IndexMap::default();
42
43    for output in graph.values() {
44        let info = &graph[output];
45
46        let start_time = Instant::now();
47        let result = run_cpu_operation(info, &map, inputs, batch_size);
48        let end_time = Instant::now();
49
50        let tensor_shape = ConcreteShape::new(result.shape().to_vec());
51
52        let mut output_calc = CalculatedValue {
53            value: output,
54            tensor: Some(result),
55            tensor_shape,
56            uses_seen: 0,
57            time_spent: (end_time - start_time).as_secs_f32(),
58        };
59
60        // free tensors that won't be used again
61        if !keep_all {
62            // immediately discard this output
63            if graph.is_hidden_with_uses(output, 0) {
64                output_calc.tensor = None
65            }
66
67            // discard inputs that just got used for the last time
68            for input in graph[output].operation.inputs() {
69                let input_calc: &mut CalculatedValue = map.get_mut(&input).unwrap();
70                input_calc.uses_seen += 1;
71
72                if graph.is_hidden_with_uses(input, input_calc.uses_seen) {
73                    input_calc.tensor = None;
74                }
75            }
76        }
77
78        // store output for later
79        let prev = map.insert(output, output_calc);
80        assert!(prev.is_none());
81    }
82
83    ExecutionInfo {
84        batch_size,
85        values: map,
86        outputs: graph.outputs().to_owned(),
87    }
88}
89
90#[derive(Debug, Copy, Clone, Eq, PartialEq)]
91pub(crate) enum OperationError {
92    NoBatchSize,
93    MissingOperand,
94    MissingInput,
95}
96
97pub(crate) type OperationResult = Result<DTensor, OperationError>;
98
99fn run_cpu_operation(
100    info: &ValueInfo,
101    map: &IndexMap<Value, CalculatedValue>,
102    inputs: &[DTensor],
103    batch_size: usize,
104) -> DTensor {
105    try_run_cpu_operation(
106        info,
107        |value| Ok(map.get(&value).unwrap().tensor.as_ref().unwrap().clone()),
108        |index| Ok(inputs[index].clone()),
109        Some(batch_size),
110    )
111    .unwrap()
112}
113
114pub(crate) fn run_cpu_const_operation(info: &ValueInfo, map: impl FnMut(Value) -> OperationResult) -> OperationResult {
115    try_run_cpu_operation(info, map, |_| Err(OperationError::MissingInput), None)
116}
117
118fn try_run_cpu_operation(
119    info: &ValueInfo,
120    mut map: impl FnMut(Value) -> OperationResult,
121    input: impl Fn(usize) -> OperationResult,
122    batch_size: Option<usize>,
123) -> OperationResult {
124    let output_shape = match info.shape.as_fixed() {
125        Some(shape) => shape,
126        None => batch_size
127            .map(|batch_size| info.shape.eval(batch_size))
128            .ok_or(OperationError::NoBatchSize)?,
129    };
130    let output_shape_dyn = IxDyn(&output_shape.dims);
131    let dtype = info.dtype;
132
133    let result: DTensor = match info.operation {
134        Operation::Input { index } => input(index)?,
135        Operation::Constant { tensor: WrapDebug(ref tensor) } => tensor.clone(),
136        Operation::View { input } => {
137            let input = map(input)?;
138            input.reshape(output_shape_dyn)
139        }
140        Operation::Broadcast { input } => {
141            let input = map(input)?;
142            map_dtensor!(input, |input| input.broadcast(output_shape_dyn).unwrap().to_shared())
143        }
144        Operation::Permute { input, ref permutation } => {
145            let input = map(input)?;
146            map_dtensor!(input, |input| input
147                .view()
148                .permuted_axes(permutation.clone())
149                .to_shared())
150        }
151        Operation::Slice { input, axis, range } => {
152            let input = map(input)?;
153            map_dtensor!(input, |input| cpu_slice(&input, axis, range))
154        }
155        Operation::Flip { input, axis } => {
156            let input = map(input)?;
157            map_dtensor!(input, |input| cpu_flip(&input, axis))
158        }
159        Operation::Gather { input, axis, indices } => {
160            let input = map(input)?;
161            let indices = map(indices)?;
162            map_dtensor!(input, |input| cpu_gather(&input, axis, indices))
163        }
164        Operation::Concat { ref inputs, axis } => {
165            macro_rules! concat {
166                (inputs, axis, $dtype:path) => {{
167                    let inputs: Vec<_> = inputs.iter().map(|&x| map(x)).try_collect()?;
168                    let inputs_viewed = inputs.iter().map(|x| unwrap_match::unwrap_match!(x, $dtype(x) => x).view()).collect_vec();
169                    $dtype(concatenate(output_shape_dyn, axis, &inputs_viewed))
170                }}
171            }
172
173            match dtype {
174                DType::F32 => concat!(inputs, axis, DTensor::F32),
175                DType::F64 => concat!(inputs, axis, DTensor::F64),
176                DType::I8 => concat!(inputs, axis, DTensor::I8),
177                DType::I16 => concat!(inputs, axis, DTensor::I16),
178                DType::I32 => concat!(inputs, axis, DTensor::I32),
179                DType::I64 => concat!(inputs, axis, DTensor::I64),
180                DType::U8 => concat!(inputs, axis, DTensor::U8),
181                DType::U16 => concat!(inputs, axis, DTensor::U16),
182                DType::U32 => concat!(inputs, axis, DTensor::U32),
183                DType::U64 => concat!(inputs, axis, DTensor::U64),
184                DType::Bool => concat!(inputs, axis, DTensor::Bool),
185            }
186        }
187        Operation::Conv {
188            input,
189            filter,
190            details: conv_shape,
191        } => {
192            let input = map(input)?;
193            let filter = map(filter)?;
194
195            map_dtensor_pair!(input, filter, |input, filter| {
196                convolution(
197                    conv_shape,
198                    input.view().into_dimensionality::<Ix4>().unwrap(),
199                    filter.view().into_dimensionality::<Ix4>().unwrap(),
200                )
201                .into_dyn()
202                .into_shared()
203            })
204        }
205        Operation::MatMul { left, right } => {
206            let left = map(left)?;
207            let right = map(right)?;
208
209            map_dtensor_pair!(left, right, |left, right| {
210                batched_mat_mul(
211                    left.view().into_dimensionality::<Ix3>().unwrap(),
212                    right.view().into_dimensionality::<Ix3>().unwrap(),
213                )
214                .into_dyn()
215                .into_shared()
216            })
217        }
218        Operation::Unary { input, op } => {
219            let input = map(input)?;
220
221            // TODO this is really slow (since we're boxing), is there no faster way?
222            //   worst case just fully write out all possible type and unary op combinations
223            let general = dispatch_dtensor!(input, |_T, _f, input| input.map(|x| op.map(x.to_dscalar())));
224
225            if let Some(y) = general.iter().next() {
226                let y_dtype = y.dtype();
227                assert_eq!(
228                    dtype, y_dtype,
229                    "Unary operation wrong dtype: expected {:?}: {:?} -> {:?}, got {:?}",
230                    op, dtype, dtype, y_dtype
231                );
232            }
233
234            dispatch_dtype!(dtype, |T, _fs, ft| ft(general
235                .mapv(|x| T::from_dscalar(x).unwrap())
236                .into_shared()))
237        }
238        Operation::Binary { left, right, op } => {
239            let left = map(left)?;
240            let right = map(right)?;
241
242            map_dtensor_pair!(left, right, |left, right| {
243                Zip::from(&left)
244                    .and(&right)
245                    .map_collect(|&l, &r| op.map_t(l, r))
246                    .into_shared()
247            })
248        }
249        Operation::Softmax { input, axis } => {
250            let input = map(input)?;
251            let input = input.unwrap_f32().unwrap();
252            DTensor::F32(softmax(input.view(), Axis(axis)).into_shared())
253        }
254        Operation::Layernorm { input, axis, eps } => {
255            let input = map(input)?;
256            let input = input.unwrap_f32().unwrap();
257            DTensor::F32(layernorm(input.view(), Axis(axis), eps.into_inner()).into_shared())
258        }
259        Operation::Reduce { input, ref axes, op } => {
260            let input = map(input)?;
261
262            map_dtensor!(input, |input| {
263                axes.iter()
264                    .fold(input.to_shared(), |curr, &axis| {
265                        Zip::from(curr.lanes(Axis(axis)))
266                            .map_collect(|lane| op.reduce_t(lane.iter().copied()))
267                            .into_shared()
268                            .insert_axis(Axis(axis))
269                    })
270                    .reshape(output_shape_dyn)
271            })
272        }
273    };
274
275    assert_eq!(result.shape(), &output_shape.dims, "Wrong output shape");
276    Ok(result)
277}
278
279pub fn cpu_flip<T: Clone>(input: &Tensor<T>, axis: usize) -> Tensor<T> {
280    // slice with negative step (ndarray convention is different from python)
281    let info = slice_info(input.ndim(), axis, 0, None, -1);
282
283    input.slice(info).to_shared()
284}
285
286pub fn cpu_slice<T: Clone>(input: &Tensor<T>, axis: usize, range: SliceRange) -> Tensor<T> {
287    // We have to clamp the end:
288    // * SliceRange requires that `(end - start) % step == 0`
289    // * SliceInfo instead requires that `end <= len`.
290    let axis_len = input.shape()[axis];
291    let clamped_end = min(range.end, axis_len);
292
293    let info = slice_info(
294        input.ndim(),
295        axis,
296        range.start as isize,
297        Some(clamped_end as isize),
298        range.step as isize,
299    );
300
301    input.slice(info).to_shared()
302}
303
304pub fn cpu_gather<T: Clone>(input: &Tensor<T>, axis: usize, indices: DTensor) -> Tensor<T> {
305    assert_eq!(indices.rank(), 1);
306    let mut output_shape = input.shape().to_vec();
307    output_shape[axis] = indices.len();
308
309    let indices = match indices {
310        DTensor::F32(_) | DTensor::F64(_) | DTensor::Bool(_) => {
311            unreachable!("gather indices should be unsigned integers, got {:?}", indices.dtype())
312        }
313        DTensor::U8(indices) => indices.mapv(|x| x as u64).into_shared(),
314        DTensor::U16(indices) => indices.mapv(|x| x as u64).into_shared(),
315        DTensor::U32(indices) => indices.mapv(|x| x as u64).into_shared(),
316        DTensor::U64(indices) => indices,
317        // ensure no underflow
318        DTensor::I8(indices) => indices.mapv(|x| x.try_into().unwrap()).into_shared(),
319        DTensor::I16(indices) => indices.mapv(|x| x.try_into().unwrap()).into_shared(),
320        DTensor::I32(indices) => indices.mapv(|x| x.try_into().unwrap()).into_shared(),
321        DTensor::I64(indices) => indices.mapv(|x| x.try_into().unwrap()).into_shared(),
322    };
323
324    let slices = indices
325        .iter()
326        .map(|&f| {
327            let i: isize = f.try_into().expect("Index out of bounds");
328            input.slice(slice_info(input.ndim(), axis, i, Some(i + 1), 1))
329        })
330        .collect_vec();
331
332    concatenate(IxDyn(&output_shape), axis, slices.as_slice())
333}
334
335/// Wrapper around [ndarray::concatenate()] that can handle an empty input list.
336pub fn concatenate<T: Clone>(output_shape: IxDyn, axis: usize, inputs: &[ArrayView<T, IxDyn>]) -> ArcArray<T, IxDyn> {
337    let result = if inputs.is_empty() {
338        ArcArray::from_shape_fn(output_shape.clone(), |_| unreachable!("empty array has no elements"))
339    } else {
340        ndarray::concatenate(Axis(axis), inputs).unwrap().into_shared()
341    };
342
343    assert_eq!(result.dim(), output_shape);
344    result
345}
346
347pub fn convolution<T: IntoDScalar>(details: ConvDetails, input: ArrayView4<T>, kernel: ArrayView4<T>) -> Array4<T> {
348    let ConvDetails {
349        dtype,
350        batch_size: _,
351        input_channels,
352        output_channels,
353        input_h,
354        input_w,
355        kernel_h,
356        kernel_w,
357        stride_y,
358        stride_x,
359        padding_y,
360        padding_x,
361        output_h,
362        output_w,
363    } = details;
364    assert_eq!(T::DTYPE, dtype);
365
366    assert!(
367        kernel_h % 2 == 1 && kernel_w % 2 == 1,
368        "Only odd kernels supported for now"
369    );
370    let batch_size = input.shape()[0];
371
372    // We compute the convolution via im2col
373    //   * create the input matrix by repeating input values around F^2 times with some padding
374    //   * permute and reshape the kernel weights into a flat matrix
375    //   * compute the dot product
376    //   * permute and reshape the output back into a tensor
377    let input_matrix = {
378        let mut input_matrix = Array::zeros((batch_size, output_h, output_w, input_channels, kernel_h, kernel_w));
379
380        // copy over entire (batch_size, input_channels) slices at once
381        //   this mostly helps with non-optimized build performance, which is nice to have
382        for oy in 0..output_h {
383            for ox in 0..output_w {
384                for fy in 0..kernel_h {
385                    for fx in 0..kernel_w {
386                        let iy = (oy * stride_y) as isize + fy as isize - padding_y as isize;
387                        let ix = (ox * stride_x) as isize + fx as isize - padding_x as isize;
388
389                        if (0..input_h as isize).contains(&iy) && (0..input_w as isize).contains(&ix) {
390                            input_matrix
391                                .slice_mut(s![.., oy, ox, .., fy, fx])
392                                .assign(&input.slice(s![.., .., iy as usize, ix]));
393                        }
394                        // leave the padding values at zero
395                    }
396                }
397            }
398        }
399
400        input_matrix
401            .into_shape((batch_size * output_h * output_w, input_channels * kernel_h * kernel_w))
402            .unwrap()
403    };
404
405    let kernel_permuted = kernel.permuted_axes([1, 2, 3, 0]);
406    let kernel_matrix = kernel_permuted
407        .as_standard_layout()
408        .into_shape((input_channels * kernel_h * kernel_w, output_channels))
409        .unwrap();
410
411    let result_matrix = input_matrix.dot(&kernel_matrix);
412
413    let result = result_matrix
414        .into_shape((batch_size, output_h, output_w, output_channels))
415        .unwrap()
416        .permuted_axes([0, 3, 1, 2]);
417
418    result
419}
420
421pub fn batched_mat_mul<T: LinalgScalar>(left: ArrayView3<T>, right: ArrayView3<T>) -> Array3<T> {
422    let (n0, p, q0) = left.dim();
423    let (n1, q1, r) = right.dim();
424    assert!(
425        n0 == n1 && q0 == q1,
426        "Invalid matmul dimensions: {:?} and {:?}",
427        left.dim(),
428        right.dim()
429    );
430
431    let mut result = Array3::zeros((n0, p, r));
432    for i in 0..n0 {
433        let slice = s![i, .., ..];
434        result
435            .slice_mut(&slice)
436            .assign(&left.slice(&slice).dot(&right.slice(&slice)));
437    }
438    result
439}
440
441/// Softmax along the given axis of the tensor.
442/// Implementation (and more importantly, the generic bounds) based on softmax within the onnxruntime crate
443pub fn softmax<S, D>(array: ArrayBase<S, D>, axis: Axis) -> Array<f32, D>
444where
445    D: ndarray::RemoveAxis,
446    S: ndarray::RawData + ndarray::Data + ndarray::RawData<Elem = f32>,
447{
448    let mut result = array.to_owned();
449
450    let max = result.fold_axis(axis, f32::NEG_INFINITY, |&a, &x| a.max(x));
451    result -= &max.insert_axis(axis);
452
453    result.map_inplace(|x: &mut f32| *x = x.exp());
454    let sum = result.sum_axis(axis).insert_axis(axis);
455    result /= &sum;
456
457    result
458}
459
460/// Layernorm along the given axis of the tensor.
461pub fn layernorm<S, D>(array: ArrayBase<S, D>, axis: Axis, eps: f32) -> Array<f32, D>
462where
463    D: ndarray::RemoveAxis,
464    S: ndarray::RawData + ndarray::Data + ndarray::RawData<Elem = f32>,
465{
466    let mut result = array.to_owned();
467
468    let mean = result.mean_axis(axis).unwrap();
469    result -= &mean.insert_axis(axis);
470
471    let std = result
472        .mapv(|f| f.powi(2))
473        .mean_axis(axis)
474        .unwrap()
475        .mapv(|f| (f + eps).sqrt());
476    result /= &std.insert_axis(axis);
477
478    result
479}
480
481pub fn slice_info(
482    rank: usize,
483    axis: usize,
484    start: isize,
485    end: Option<isize>,
486    step: isize,
487) -> SliceInfo<Vec<SliceInfoElem>, IxDyn, IxDyn> {
488    assert_ne!(step, 0);
489
490    let vec = (0..rank)
491        .map(|r| {
492            if r == axis {
493                // grab the relevant range
494                SliceInfoElem::Slice { start, end, step }
495            } else {
496                // grab everything
497                SliceInfoElem::Slice {
498                    start: 0,
499                    end: None,
500                    step: 1,
501                }
502            }
503        })
504        .collect_vec();
505
506    // safety: we pass an owned Vec, whose .as_ref will always return the same reference
507    unsafe { SliceInfo::new(vec).unwrap() }
508}
509
510#[derive(Debug)]
511pub struct ExecutionInfo {
512    pub batch_size: usize,
513    pub values: IndexMap<Value, CalculatedValue>,
514    pub outputs: Vec<Value>,
515}
516
517pub struct CalculatedValue {
518    pub value: Value,
519    pub tensor: Option<DTensor>,
520    pub tensor_shape: ConcreteShape,
521    pub uses_seen: usize,
522    pub time_spent: f32,
523}
524
525impl ExecutionInfo {
526    pub fn output_tensors(self) -> Vec<DTensor> {
527        self.outputs
528            .iter()
529            .map(|v| {
530                // convert to standard layout so users get easily get slices if they want
531                let tensor = self.values.get(v).unwrap().tensor.as_ref().unwrap();
532                map_dtensor!(tensor, |tensor| tensor.as_standard_layout().to_shared())
533            })
534            .collect_vec()
535    }
536}
537
538impl Debug for CalculatedValue {
539    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
540        f.debug_struct("CalculatedTensor")
541            .field("value", &self.value)
542            .field("kept", &self.tensor.is_some())
543            .field("shape", &self.tensor_shape)
544            .field("time_spent", &self.time_spent)
545            .finish()
546    }
547}
548
549impl Display for ExecutionInfo {
550    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
551        writeln!(f, "ExecutionInfo {{")?;
552        for (_, value) in &self.values {
553            writeln!(f, "  {:?}", value)?;
554        }
555        writeln!(f, "}}")?;
556
557        Ok(())
558    }
559}