knok-compile 0.3.0

MLIR lowering and IREE compilation for knok
use knok_core::{Conv2dOptions, TensorType};

use super::lowerer::{Lowerer, Value};

impl Lowerer<'_, '_> {
    /// Lowers a 2-D convolution of `input` with `kernel`.
    ///
    /// The input is padded according to `options.padding` first. When
    /// `options.groups` is greater than one the work is delegated to
    /// [`Self::grouped_conv2d`]; otherwise a single convolution is emitted by
    /// [`Self::conv2d_direct`].
    ///
    /// # Errors
    ///
    /// Returns an error if either operand is not rank 4, if the shapes are
    /// inconsistent with `options`, or if any underlying lowering step fails.
    pub(super) fn conv2d(
        &mut self,
        input: Value,
        kernel: Value,
        options: &Conv2dOptions,
    ) -> anyhow::Result<Value> {
        // Every helper below indexes shape positions 0..=3, so reject other
        // ranks here instead of panicking on an out-of-bounds index later.
        anyhow::ensure!(
            input.ty.shape.len() == 4 && kernel.ty.shape.len() == 4,
            "conv2d expects rank-4 input and kernel, got ranks {} and {}",
            input.ty.shape.len(),
            kernel.ty.shape.len()
        );
        let input = self.pad_conv2d_input(input, options)?;
        if options.groups > 1 {
            return self.grouped_conv2d(input, kernel, options);
        }
        self.conv2d_direct(input, kernel, options)
    }

    /// Emits one `linalg.conv_2d_nhwc_hwcf` op over an already padded input.
    ///
    /// The result type keeps `input`'s element type and axis 0, takes its
    /// last axis from axis 3 of `kernel`, and computes each spatial extent as
    /// `(input - ((kernel - 1) * dilation + 1)) / stride + 1`.
    ///
    /// # Errors
    ///
    /// Returns an error when a kernel extent or stride is not positive, when
    /// the dilated kernel extent overflows or exceeds the input extent, or
    /// when emitting the accumulator or the op itself fails.
    fn conv2d_direct(
        &mut self,
        input: Value,
        kernel: Value,
        options: &Conv2dOptions,
    ) -> anyhow::Result<Value> {
        // Starts as the input's spatial extents (axes 1 and 2) and is
        // rewritten in place to the output's spatial extents.
        let mut spatial = [input.ty.shape[1], input.ty.shape[2]];
        for (axis, extent) in spatial.iter_mut().enumerate() {
            let input_extent = *extent;
            let kernel_extent = kernel.ty.shape[axis];
            let dilation = options.dilation[axis];
            let stride = options.stride[axis];
            anyhow::ensure!(
                kernel_extent >= 1 && stride >= 1,
                "conv2d kernel extent and stride must be positive, got kernel extent {} and stride {} on spatial axis {}",
                kernel_extent,
                stride,
                axis
            );
            // Extent covered by the dilated kernel: (kernel - 1) * dilation + 1.
            let window = (kernel_extent - 1)
                .checked_mul(dilation)
                .and_then(|span| span.checked_add(1))
                .ok_or_else(|| {
                    anyhow::anyhow!(
                        "conv2d dilated kernel extent overflows on spatial axis {}",
                        axis
                    )
                })?;
            anyhow::ensure!(
                window <= input_extent,
                "conv2d dilated kernel extent {} exceeds padded input extent {} on spatial axis {}",
                window,
                input_extent,
                axis
            );
            *extent = (input_extent - window) / stride + 1;
        }

        let ty = TensorType {
            elem: input.ty.elem,
            shape: vec![
                input.ty.shape[0],
                spatial[0],
                spatial[1],
                kernel.ty.shape[3],
            ],
        };
        let init = self.zero_initialized_tensor(&ty)?;
        let attrs = conv2d_attrs(options);
        self.append_named_linalg(
            "linalg.conv_2d_nhwc_hwcf",
            &[input, kernel],
            init,
            &ty,
            &attrs,
        )
    }

    /// Lowers a grouped convolution as one direct convolution per group.
    ///
    /// For each group, a slice of `input` and a slice of `kernel` are taken
    /// along axis 3 and convolved with [`Self::conv2d_direct`]. The per-group
    /// results are concatenated along axis 3 in group order.
    ///
    /// # Errors
    ///
    /// Returns an error when `options.groups` is zero, when axis 3 of `input`
    /// or `kernel` is not divisible by the group count, or when any slice,
    /// convolution or concat step fails.
    fn grouped_conv2d(
        &mut self,
        input: Value,
        kernel: Value,
        options: &Conv2dOptions,
    ) -> anyhow::Result<Value> {
        let groups = options.groups;
        anyhow::ensure!(groups != 0, "conv2d groups must be non-zero");

        let input_channels = input.ty.shape[3];
        let output_channels = kernel.ty.shape[3];
        // Truncating division would silently drop the trailing channels, so
        // require an exact split.
        anyhow::ensure!(
            input_channels % groups == 0 && output_channels % groups == 0,
            "conv2d input channels ({}) and output channels ({}) must both be divisible by groups ({})",
            input_channels,
            output_channels,
            groups
        );
        let input_channels_per_group = input_channels / groups;
        let output_channels_per_group = output_channels / groups;

        // Every group uses the same slice types, so build them once.
        let input_slice_ty = TensorType {
            elem: input.ty.elem,
            shape: vec![
                input.ty.shape[0],
                input.ty.shape[1],
                input.ty.shape[2],
                input_channels_per_group,
            ],
        };
        let kernel_slice_ty = TensorType {
            elem: kernel.ty.elem,
            shape: vec![
                kernel.ty.shape[0],
                kernel.ty.shape[1],
                kernel.ty.shape[2],
                output_channels_per_group,
            ],
        };

        let mut output = None;
        for group in 0..groups {
            let input_slice = self.slice(
                input.clone(),
                &input_slice_ty,
                &[0, 0, 0, group * input_channels_per_group],
            )?;
            let kernel_slice = self.slice(
                kernel.clone(),
                &kernel_slice_ty,
                &[0, 0, 0, group * output_channels_per_group],
            )?;
            let group_output = self.conv2d_direct(input_slice, kernel_slice, options)?;
            output = Some(match output {
                Some(previous) => self.concat(previous, group_output, 3)?,
                None => group_output,
            });
        }
        output.ok_or_else(|| anyhow::anyhow!("conv2d groups must be non-zero"))
    }

    /// Pads axes 1 and 2 of `input` with the element type's zero value.
    ///
    /// Axis 1 grows by `padding.top + padding.bottom` and axis 2 by
    /// `padding.left + padding.right`; axes 0 and 3 are left untouched. When
    /// all four padding amounts are zero, `input` is returned unchanged and
    /// nothing is emitted.
    ///
    /// # Errors
    ///
    /// Propagates any failure from emitting the zero constant or the pad op.
    fn pad_conv2d_input(&mut self, input: Value, options: &Conv2dOptions) -> anyhow::Result<Value> {
        let padding = options.padding;
        if padding.top == 0 && padding.bottom == 0 && padding.left == 0 && padding.right == 0 {
            return Ok(input);
        }
        let ty = TensorType {
            elem: input.ty.elem,
            shape: vec![
                input.ty.shape[0],
                input.ty.shape[1] + padding.top + padding.bottom,
                input.ty.shape[2] + padding.left + padding.right,
                input.ty.shape[3],
            ],
        };
        let zero = self.constant(ty.elem.zero_literal(), ty.elem)?;
        self.append_tensor_pad(
            input,
            &ty,
            // Padding amounts before, then after, each of the four axes.
            &[0, padding.top, padding.left, 0],
            &[0, padding.bottom, padding.right, 0],
            zero,
        )
    }
}

/// Builds the attribute list for a 2-D convolution op.
///
/// Returns an empty list when both `options.dilation` and `options.stride`
/// equal `[1, 1]`. Otherwise returns a `dilations` entry followed by a
/// `strides` entry, each rendered as a two-element `vector<2xi64>` dense
/// literal.
fn conv2d_attrs(options: &Conv2dOptions) -> Vec<(String, String)> {
    let unit_dilation = options.dilation == [1, 1];
    let unit_stride = options.stride == [1, 1];
    if unit_dilation && unit_stride {
        return Vec::new();
    }

    // Both attributes share one textual form, so render them from a table.
    [
        ("dilations", &options.dilation[..]),
        ("strides", &options.stride[..]),
    ]
    .iter()
    .map(|&(name, values)| {
        (
            name.to_string(),
            format!("dense<[{}, {}]> : vector<2xi64>", values[0], values[1]),
        )
    })
    .collect()
}