1use std::cmp::min;
2use std::fmt::{Debug, Display, Formatter};
3use std::time::Instant;
4
5use indexmap::IndexMap;
6use itertools::Itertools;
7use ndarray::{
8 ArcArray, Array3, Array4, ArrayView, ArrayView3, ArrayView4, Ix3, Ix4, IxDyn, LinalgScalar, s, SliceInfo,
9 SliceInfoElem, Zip,
10};
11
12use crate::dtype::{
13 dispatch_dtensor, dispatch_dtype, DTensor, DType, IntoDScalar, map_dtensor, map_dtensor_pair, Tensor,
14};
15use crate::graph::{ConvDetails, Graph, Operation, SliceRange, Value, ValueInfo};
16use crate::ndarray::{Array, ArrayBase, Axis};
17use crate::shape::ConcreteShape;
18use crate::wrap_debug::WrapDebug;
19
20pub fn cpu_eval_graph(graph: &Graph, batch_size: usize, inputs: &[DTensor]) -> Vec<DTensor> {
21 let exec = cpu_eval_graph_exec(graph, batch_size, inputs, false);
22 exec.output_tensors()
23}
24
25pub fn cpu_eval_graph_exec(graph: &Graph, batch_size: usize, inputs: &[DTensor], keep_all: bool) -> ExecutionInfo {
33 assert_eq!(
34 graph.inputs().len(),
35 inputs.len(),
36 "Wrong input count, graph has {} but got {}",
37 graph.inputs().len(),
38 inputs.len()
39 );
40
41 let mut map: IndexMap<Value, CalculatedValue> = IndexMap::default();
42
43 for output in graph.values() {
44 let info = &graph[output];
45
46 let start_time = Instant::now();
47 let result = run_cpu_operation(info, &map, inputs, batch_size);
48 let end_time = Instant::now();
49
50 let tensor_shape = ConcreteShape::new(result.shape().to_vec());
51
52 let mut output_calc = CalculatedValue {
53 value: output,
54 tensor: Some(result),
55 tensor_shape,
56 uses_seen: 0,
57 time_spent: (end_time - start_time).as_secs_f32(),
58 };
59
60 if !keep_all {
62 if graph.is_hidden_with_uses(output, 0) {
64 output_calc.tensor = None
65 }
66
67 for input in graph[output].operation.inputs() {
69 let input_calc: &mut CalculatedValue = map.get_mut(&input).unwrap();
70 input_calc.uses_seen += 1;
71
72 if graph.is_hidden_with_uses(input, input_calc.uses_seen) {
73 input_calc.tensor = None;
74 }
75 }
76 }
77
78 let prev = map.insert(output, output_calc);
80 assert!(prev.is_none());
81 }
82
83 ExecutionInfo {
84 batch_size,
85 values: map,
86 outputs: graph.outputs().to_owned(),
87 }
88}
89
/// Reasons why evaluating a single operation can fail without panicking.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum OperationError {
    /// The output shape is not fixed and no batch size was supplied to evaluate it.
    NoBatchSize,
    /// Not produced by the code visible in this file.
    // NOTE(review): presumably returned by operand-lookup callbacks when an operand is not
    // available -- confirm against the callers of `run_cpu_const_operation`.
    MissingOperand,
    /// A graph input was requested while evaluating without inputs
    /// (see `run_cpu_const_operation`).
    MissingInput,
}
96
/// Result of evaluating a single operation: the computed tensor or the reason it could not be computed.
pub(crate) type OperationResult = Result<DTensor, OperationError>;
98
99fn run_cpu_operation(
100 info: &ValueInfo,
101 map: &IndexMap<Value, CalculatedValue>,
102 inputs: &[DTensor],
103 batch_size: usize,
104) -> DTensor {
105 try_run_cpu_operation(
106 info,
107 |value| Ok(map.get(&value).unwrap().tensor.as_ref().unwrap().clone()),
108 |index| Ok(inputs[index].clone()),
109 Some(batch_size),
110 )
111 .unwrap()
112}
113
114pub(crate) fn run_cpu_const_operation(info: &ValueInfo, map: impl FnMut(Value) -> OperationResult) -> OperationResult {
115 try_run_cpu_operation(info, map, |_| Err(OperationError::MissingInput), None)
116}
117
/// Evaluates the single operation described by `info` and returns the resulting tensor.
///
/// * `map` resolves an operand [`Value`] to its tensor.
/// * `input` resolves a graph input index to its tensor.
/// * `batch_size` is only required when `info.shape` is not fixed.
///
/// # Errors
///
/// Returns [`OperationError::NoBatchSize`] if the shape is not fixed and `batch_size` is `None`.
/// Errors returned by `map` or `input` are propagated unchanged.
///
/// # Panics
///
/// Panics if the computed tensor does not have the expected output shape, and on the various
/// `unwrap`s below (for example a non-`f32` input to softmax or layernorm).
fn try_run_cpu_operation(
    info: &ValueInfo,
    mut map: impl FnMut(Value) -> OperationResult,
    input: impl Fn(usize) -> OperationResult,
    batch_size: Option<usize>,
) -> OperationResult {
    // Prefer the fixed shape; only fall back to evaluating with the batch size when necessary.
    let output_shape = match info.shape.as_fixed() {
        Some(shape) => shape,
        None => batch_size
            .map(|batch_size| info.shape.eval(batch_size))
            .ok_or(OperationError::NoBatchSize)?,
    };
    let output_shape_dyn = IxDyn(&output_shape.dims);
    let dtype = info.dtype;

    let result: DTensor = match info.operation {
        Operation::Input { index } => input(index)?,
        Operation::Constant { tensor: WrapDebug(ref tensor) } => tensor.clone(),
        Operation::View { input } => {
            let input = map(input)?;
            input.reshape(output_shape_dyn)
        }
        Operation::Broadcast { input } => {
            let input = map(input)?;
            map_dtensor!(input, |input| input.broadcast(output_shape_dyn).unwrap().to_shared())
        }
        Operation::Permute { input, ref permutation } => {
            let input = map(input)?;
            map_dtensor!(input, |input| input
                .view()
                .permuted_axes(permutation.clone())
                .to_shared())
        }
        Operation::Slice { input, axis, range } => {
            let input = map(input)?;
            map_dtensor!(input, |input| cpu_slice(&input, axis, range))
        }
        Operation::Flip { input, axis } => {
            let input = map(input)?;
            map_dtensor!(input, |input| cpu_flip(&input, axis))
        }
        Operation::Gather { input, axis, indices } => {
            let input = map(input)?;
            let indices = map(indices)?;
            map_dtensor!(input, |input| cpu_gather(&input, axis, indices))
        }
        Operation::Concat { ref inputs, axis } => {
            // Resolves all operands, unwraps each as the given `DTensor` variant and concatenates
            // the views along `axis`.
            macro_rules! concat {
                (inputs, axis, $dtype:path) => {{
                    let inputs: Vec<_> = inputs.iter().map(|&x| map(x)).try_collect()?;
                    let inputs_viewed = inputs.iter().map(|x| unwrap_match::unwrap_match!(x, $dtype(x) => x).view()).collect_vec();
                    $dtype(concatenate(output_shape_dyn, axis, &inputs_viewed))
                }}
            }

            // Dispatch on the expected output dtype, so this also works when there are no operands.
            match dtype {
                DType::F32 => concat!(inputs, axis, DTensor::F32),
                DType::F64 => concat!(inputs, axis, DTensor::F64),
                DType::I8 => concat!(inputs, axis, DTensor::I8),
                DType::I16 => concat!(inputs, axis, DTensor::I16),
                DType::I32 => concat!(inputs, axis, DTensor::I32),
                DType::I64 => concat!(inputs, axis, DTensor::I64),
                DType::U8 => concat!(inputs, axis, DTensor::U8),
                DType::U16 => concat!(inputs, axis, DTensor::U16),
                DType::U32 => concat!(inputs, axis, DTensor::U32),
                DType::U64 => concat!(inputs, axis, DTensor::U64),
                DType::Bool => concat!(inputs, axis, DTensor::Bool),
            }
        }
        Operation::Conv {
            input,
            filter,
            details: conv_shape,
        } => {
            let input = map(input)?;
            let filter = map(filter)?;

            // Both operands must be rank 4, otherwise the dimensionality conversion panics.
            map_dtensor_pair!(input, filter, |input, filter| {
                convolution(
                    conv_shape,
                    input.view().into_dimensionality::<Ix4>().unwrap(),
                    filter.view().into_dimensionality::<Ix4>().unwrap(),
                )
                .into_dyn()
                .into_shared()
            })
        }
        Operation::MatMul { left, right } => {
            let left = map(left)?;
            let right = map(right)?;

            // Both operands must be rank 3, otherwise the dimensionality conversion panics.
            map_dtensor_pair!(left, right, |left, right| {
                batched_mat_mul(
                    left.view().into_dimensionality::<Ix3>().unwrap(),
                    right.view().into_dimensionality::<Ix3>().unwrap(),
                )
                .into_dyn()
                .into_shared()
            })
        }
        Operation::Unary { input, op } => {
            let input = map(input)?;

            // First apply the operation on dscalars, independent of the concrete element type.
            let general = dispatch_dtensor!(input, |_T, _f, input| input.map(|x| op.map(x.to_dscalar())));

            // Sanity check the result dtype using the first element, if there is one.
            if let Some(y) = general.iter().next() {
                let y_dtype = y.dtype();
                assert_eq!(
                    dtype, y_dtype,
                    "Unary operation wrong dtype: expected {:?}: {:?} -> {:?}, got {:?}",
                    op, dtype, dtype, y_dtype
                );
            }

            // Convert the dscalars back into a tensor of the expected dtype.
            dispatch_dtype!(dtype, |T, _fs, ft| ft(general
                .mapv(|x| T::from_dscalar(x).unwrap())
                .into_shared()))
        }
        Operation::Binary { left, right, op } => {
            let left = map(left)?;
            let right = map(right)?;

            // Elementwise combination of both operands.
            map_dtensor_pair!(left, right, |left, right| {
                Zip::from(&left)
                    .and(&right)
                    .map_collect(|&l, &r| op.map_t(l, r))
                    .into_shared()
            })
        }
        Operation::Softmax { input, axis } => {
            let input = map(input)?;
            // Only f32 tensors are supported here, anything else panics.
            let input = input.unwrap_f32().unwrap();
            DTensor::F32(softmax(input.view(), Axis(axis)).into_shared())
        }
        Operation::Layernorm { input, axis, eps } => {
            let input = map(input)?;
            // Only f32 tensors are supported here, anything else panics.
            let input = input.unwrap_f32().unwrap();
            DTensor::F32(layernorm(input.view(), Axis(axis), eps.into_inner()).into_shared())
        }
        Operation::Reduce { input, ref axes, op } => {
            let input = map(input)?;

            // Reduce one axis at a time. The reduced axis is re-inserted after each step, so the
            // rank is unchanged while the remaining axes are processed; the final reshape then
            // produces the expected output shape.
            map_dtensor!(input, |input| {
                axes.iter()
                    .fold(input.to_shared(), |curr, &axis| {
                        Zip::from(curr.lanes(Axis(axis)))
                            .map_collect(|lane| op.reduce_t(lane.iter().copied()))
                            .into_shared()
                            .insert_axis(Axis(axis))
                    })
                    .reshape(output_shape_dyn)
            })
        }
    };

    assert_eq!(result.shape(), &output_shape.dims, "Wrong output shape");
    Ok(result)
}
278
279pub fn cpu_flip<T: Clone>(input: &Tensor<T>, axis: usize) -> Tensor<T> {
280 let info = slice_info(input.ndim(), axis, 0, None, -1);
282
283 input.slice(info).to_shared()
284}
285
286pub fn cpu_slice<T: Clone>(input: &Tensor<T>, axis: usize, range: SliceRange) -> Tensor<T> {
287 let axis_len = input.shape()[axis];
291 let clamped_end = min(range.end, axis_len);
292
293 let info = slice_info(
294 input.ndim(),
295 axis,
296 range.start as isize,
297 Some(clamped_end as isize),
298 range.step as isize,
299 );
300
301 input.slice(info).to_shared()
302}
303
304pub fn cpu_gather<T: Clone>(input: &Tensor<T>, axis: usize, indices: DTensor) -> Tensor<T> {
305 assert_eq!(indices.rank(), 1);
306 let mut output_shape = input.shape().to_vec();
307 output_shape[axis] = indices.len();
308
309 let indices = match indices {
310 DTensor::F32(_) | DTensor::F64(_) | DTensor::Bool(_) => {
311 unreachable!("gather indices should be unsigned integers, got {:?}", indices.dtype())
312 }
313 DTensor::U8(indices) => indices.mapv(|x| x as u64).into_shared(),
314 DTensor::U16(indices) => indices.mapv(|x| x as u64).into_shared(),
315 DTensor::U32(indices) => indices.mapv(|x| x as u64).into_shared(),
316 DTensor::U64(indices) => indices,
317 DTensor::I8(indices) => indices.mapv(|x| x.try_into().unwrap()).into_shared(),
319 DTensor::I16(indices) => indices.mapv(|x| x.try_into().unwrap()).into_shared(),
320 DTensor::I32(indices) => indices.mapv(|x| x.try_into().unwrap()).into_shared(),
321 DTensor::I64(indices) => indices.mapv(|x| x.try_into().unwrap()).into_shared(),
322 };
323
324 let slices = indices
325 .iter()
326 .map(|&f| {
327 let i: isize = f.try_into().expect("Index out of bounds");
328 input.slice(slice_info(input.ndim(), axis, i, Some(i + 1), 1))
329 })
330 .collect_vec();
331
332 concatenate(IxDyn(&output_shape), axis, slices.as_slice())
333}
334
335pub fn concatenate<T: Clone>(output_shape: IxDyn, axis: usize, inputs: &[ArrayView<T, IxDyn>]) -> ArcArray<T, IxDyn> {
337 let result = if inputs.is_empty() {
338 ArcArray::from_shape_fn(output_shape.clone(), |_| unreachable!("empty array has no elements"))
339 } else {
340 ndarray::concatenate(Axis(axis), inputs).unwrap().into_shared()
341 };
342
343 assert_eq!(result.dim(), output_shape);
344 result
345}
346
347pub fn convolution<T: IntoDScalar>(details: ConvDetails, input: ArrayView4<T>, kernel: ArrayView4<T>) -> Array4<T> {
348 let ConvDetails {
349 dtype,
350 batch_size: _,
351 input_channels,
352 output_channels,
353 input_h,
354 input_w,
355 kernel_h,
356 kernel_w,
357 stride_y,
358 stride_x,
359 padding_y,
360 padding_x,
361 output_h,
362 output_w,
363 } = details;
364 assert_eq!(T::DTYPE, dtype);
365
366 assert!(
367 kernel_h % 2 == 1 && kernel_w % 2 == 1,
368 "Only odd kernels supported for now"
369 );
370 let batch_size = input.shape()[0];
371
372 let input_matrix = {
378 let mut input_matrix = Array::zeros((batch_size, output_h, output_w, input_channels, kernel_h, kernel_w));
379
380 for oy in 0..output_h {
383 for ox in 0..output_w {
384 for fy in 0..kernel_h {
385 for fx in 0..kernel_w {
386 let iy = (oy * stride_y) as isize + fy as isize - padding_y as isize;
387 let ix = (ox * stride_x) as isize + fx as isize - padding_x as isize;
388
389 if (0..input_h as isize).contains(&iy) && (0..input_w as isize).contains(&ix) {
390 input_matrix
391 .slice_mut(s![.., oy, ox, .., fy, fx])
392 .assign(&input.slice(s![.., .., iy as usize, ix]));
393 }
394 }
396 }
397 }
398 }
399
400 input_matrix
401 .into_shape((batch_size * output_h * output_w, input_channels * kernel_h * kernel_w))
402 .unwrap()
403 };
404
405 let kernel_permuted = kernel.permuted_axes([1, 2, 3, 0]);
406 let kernel_matrix = kernel_permuted
407 .as_standard_layout()
408 .into_shape((input_channels * kernel_h * kernel_w, output_channels))
409 .unwrap();
410
411 let result_matrix = input_matrix.dot(&kernel_matrix);
412
413 let result = result_matrix
414 .into_shape((batch_size, output_h, output_w, output_channels))
415 .unwrap()
416 .permuted_axes([0, 3, 1, 2]);
417
418 result
419}
420
421pub fn batched_mat_mul<T: LinalgScalar>(left: ArrayView3<T>, right: ArrayView3<T>) -> Array3<T> {
422 let (n0, p, q0) = left.dim();
423 let (n1, q1, r) = right.dim();
424 assert!(
425 n0 == n1 && q0 == q1,
426 "Invalid matmul dimensions: {:?} and {:?}",
427 left.dim(),
428 right.dim()
429 );
430
431 let mut result = Array3::zeros((n0, p, r));
432 for i in 0..n0 {
433 let slice = s![i, .., ..];
434 result
435 .slice_mut(&slice)
436 .assign(&left.slice(&slice).dot(&right.slice(&slice)));
437 }
438 result
439}
440
441pub fn softmax<S, D>(array: ArrayBase<S, D>, axis: Axis) -> Array<f32, D>
444where
445 D: ndarray::RemoveAxis,
446 S: ndarray::RawData + ndarray::Data + ndarray::RawData<Elem = f32>,
447{
448 let mut result = array.to_owned();
449
450 let max = result.fold_axis(axis, f32::NEG_INFINITY, |&a, &x| a.max(x));
451 result -= &max.insert_axis(axis);
452
453 result.map_inplace(|x: &mut f32| *x = x.exp());
454 let sum = result.sum_axis(axis).insert_axis(axis);
455 result /= ∑
456
457 result
458}
459
460pub fn layernorm<S, D>(array: ArrayBase<S, D>, axis: Axis, eps: f32) -> Array<f32, D>
462where
463 D: ndarray::RemoveAxis,
464 S: ndarray::RawData + ndarray::Data + ndarray::RawData<Elem = f32>,
465{
466 let mut result = array.to_owned();
467
468 let mean = result.mean_axis(axis).unwrap();
469 result -= &mean.insert_axis(axis);
470
471 let std = result
472 .mapv(|f| f.powi(2))
473 .mean_axis(axis)
474 .unwrap()
475 .mapv(|f| (f + eps).sqrt());
476 result /= &std.insert_axis(axis);
477
478 result
479}
480
481pub fn slice_info(
482 rank: usize,
483 axis: usize,
484 start: isize,
485 end: Option<isize>,
486 step: isize,
487) -> SliceInfo<Vec<SliceInfoElem>, IxDyn, IxDyn> {
488 assert_ne!(step, 0);
489
490 let vec = (0..rank)
491 .map(|r| {
492 if r == axis {
493 SliceInfoElem::Slice { start, end, step }
495 } else {
496 SliceInfoElem::Slice {
498 start: 0,
499 end: None,
500 step: 1,
501 }
502 }
503 })
504 .collect_vec();
505
506 unsafe { SliceInfo::new(vec).unwrap() }
508}
509
/// The result of evaluating a graph on the CPU, including per-value bookkeeping.
#[derive(Debug)]
pub struct ExecutionInfo {
    /// The batch size the graph was evaluated with.
    pub batch_size: usize,
    /// One entry per evaluated graph value, in evaluation order.
    pub values: IndexMap<Value, CalculatedValue>,
    /// The output values of the graph.
    pub outputs: Vec<Value>,
}
516
/// Bookkeeping for a single evaluated graph value.
pub struct CalculatedValue {
    /// The graph value this entry belongs to.
    pub value: Value,
    /// The computed tensor, or `None` if it was discarded after it was no longer needed.
    pub tensor: Option<DTensor>,
    /// The shape of the computed tensor, kept even when the tensor itself is discarded.
    pub tensor_shape: ConcreteShape,
    /// How often this value has been used as an operand so far.
    /// Only updated when evaluating with `keep_all == false`.
    pub uses_seen: usize,
    /// Time spent computing this value, in seconds.
    pub time_spent: f32,
}
524
525impl ExecutionInfo {
526 pub fn output_tensors(self) -> Vec<DTensor> {
527 self.outputs
528 .iter()
529 .map(|v| {
530 let tensor = self.values.get(v).unwrap().tensor.as_ref().unwrap();
532 map_dtensor!(tensor, |tensor| tensor.as_standard_layout().to_shared())
533 })
534 .collect_vec()
535 }
536}
537
538impl Debug for CalculatedValue {
539 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
540 f.debug_struct("CalculatedTensor")
541 .field("value", &self.value)
542 .field("kept", &self.tensor.is_some())
543 .field("shape", &self.tensor_shape)
544 .field("time_spent", &self.time_spent)
545 .finish()
546 }
547}
548
549impl Display for ExecutionInfo {
550 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
551 writeln!(f, "ExecutionInfo {{")?;
552 for (_, value) in &self.values {
553 writeln!(f, " {:?}", value)?;
554 }
555 writeln!(f, "}}")?;
556
557 Ok(())
558 }
559}