use super::super::{cpa, Elem, Item, Operator, Scope, Variable};
use crate::ir::{BinaryOperator, Vectorization};
use serde::{Deserialize, Serialize};
/// Reads `global` at `position` into `out`.
///
/// Expands to a single `Operator::Index` (`out = global[position]`); `position` is used
/// directly as the index, without any layout-based translation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ReadGlobal {
    /// Variable being read from.
    pub global: Variable,
    /// Destination of the read value.
    pub out: Variable,
    /// Index into `global`.
    pub position: Variable,
}
/// Reads several globals at the same linear `position`, translating that position
/// through the strides of `layout` into one index per global.
///
/// `globals[k]` is read into `outs[k]`. Reads that share both `layout` and `position`
/// can be combined with `try_merge` so a single index computation serves all of them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ReadGlobalWithLayout {
    /// Variables being read from.
    pub globals: Vec<Variable>,
    /// Destinations of the reads; `outs[k]` receives the value read from `globals[k]`.
    pub outs: Vec<Variable>,
    /// Variable whose strides are used to decompose `position` per dimension.
    pub layout: Variable,
    /// Linear position to read at.
    pub position: Variable,
}
impl ReadGlobal {
    /// Registers the read as an index operation: `out = global[position]`.
    #[allow(missing_docs)]
    pub fn expand(self, scope: &mut Scope) {
        let Self {
            global,
            out,
            position,
        } = self;

        let read = BinaryOperator {
            lhs: global,
            rhs: position,
            out,
        };
        scope.register(Operator::Index(read));
    }

    /// Returns a copy whose `global` and `out` are vectorized with `vectorization`;
    /// `position` is carried over untouched.
    pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
        let Self {
            global,
            out,
            position,
        } = self;

        Self {
            global: global.vectorize(vectorization),
            out: out.vectorize(vectorization),
            position: *position,
        }
    }
}
impl ReadGlobalWithLayout {
    /// Combines two reads into one so a single index computation serves both.
    ///
    /// Returns `None` unless both reads use the same `layout` and the same `position`.
    /// On success, the merged read's `globals` and `outs` are `self`'s entries followed
    /// by `other`'s, so the `globals[k]` / `outs[k]` pairing is preserved.
    pub fn try_merge(&self, other: &Self) -> Option<Self> {
        if self.layout != other.layout || self.position != other.position {
            return None;
        }

        Some(Self {
            globals: [self.globals.as_slice(), other.globals.as_slice()].concat(),
            outs: [self.outs.as_slice(), other.outs.as_slice()].concat(),
            layout: self.layout,
            position: self.position,
        })
    }

    /// Expands the read into `scope`.
    ///
    /// One `UInt` index local is created per global. [`IndexOffsetGlobalWithLayout`] fills
    /// those indexes over dimensions `0..Variable::Rank`. Then `outs[k] = globals[k][index_k]`
    /// is emitted for each pair.
    ///
    /// `globals` and `outs` must have the same length; this is checked in debug builds.
    #[allow(missing_docs)]
    pub fn expand(self, scope: &mut Scope) {
        let outputs = self.outs;
        let tensors = self.globals;
        debug_assert_eq!(
            tensors.len(),
            outputs.len(),
            "ReadGlobalWithLayout: every global needs exactly one output"
        );

        let indexes = tensors
            .iter()
            .map(|_| scope.create_local(Elem::UInt))
            .collect::<Vec<_>>();

        // The helper consumes its vectors, but we still need them below to emit the reads.
        IndexOffsetGlobalWithLayout {
            tensors: tensors.clone(),
            layout: self.layout,
            indexes: indexes.clone(),
            position: self.position,
            dim_start: 0u32.into(),
            dim_end: Variable::Rank,
        }
        .expand(scope);

        for ((tensor, output), index) in tensors.into_iter().zip(outputs).zip(indexes) {
            cpa!(scope, output = tensor[index]);
        }
    }

    /// Returns a copy with `globals`, `outs` and `layout` vectorized with `vectorization`.
    /// `position` is carried over unchanged.
    pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
        Self {
            globals: self
                .globals
                .iter()
                .map(|g| g.vectorize(vectorization))
                .collect(),
            layout: self.layout.vectorize(vectorization),
            outs: self
                .outs
                .iter()
                .map(|o| o.vectorize(vectorization))
                .collect(),
            position: self.position,
        }
    }
}
/// Computes, for each of `tensors`, the index that corresponds to the linear `position`
/// decomposed through the strides of `layout`. The per-dimension contributions are
/// accumulated over `dim_start..dim_end`, and the result is written to the paired
/// entry of `indexes`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct IndexOffsetGlobalWithLayout {
    /// Tensors to compute an index for.
    pub tensors: Vec<Variable>,
    /// Output index variables; `indexes[k]` receives the index into `tensors[k]`.
    pub indexes: Vec<Variable>,
    /// Variable whose strides decompose `position` per dimension.
    pub layout: Variable,
    /// Linear position to translate.
    pub position: Variable,
    /// Start of the dimension range passed to `range(dim_start, dim_end)`.
    pub dim_start: Variable,
    /// End of the dimension range passed to `range(dim_start, dim_end)`.
    pub dim_end: Variable,
}
impl IndexOffsetGlobalWithLayout {
    /// Expands the index computation into `scope`.
    ///
    /// Every `indexes[k]` is first set to zero. Then, for each dimension `i` in
    /// `range(dim_start, dim_end)`, with `vf` the vectorization factor of `tensors[0]`:
    ///
    /// ```text
    /// indexes[k] += (((position * vf) / stride(layout, i)) % shape(tensors[k], i)) * stride(tensors[k], i)
    /// ```
    ///
    /// Finally every `indexes[k]` is divided by `vf`.
    ///
    /// When `tensors` is empty, the indexes are only zeroed and no loop is emitted.
    #[allow(missing_docs)]
    pub fn expand(self, scope: &mut Scope) {
        let layout = self.layout;
        let index_item_ty = Item::new(Elem::UInt);
        let offset_ref = self.position;
        let zero: Variable = 0u32.into();

        for index in self.indexes.iter() {
            cpa!(scope, index = zero);
        }

        // With no tensors there is nothing to accumulate and no vectorization factor to
        // read; return with the indexes zeroed instead of panicking on `tensors[0]`.
        let vectorization_factor: u8 = match self.tensors.first() {
            Some(tensor) => tensor.item().vectorization,
            None => return,
        };
        // NOTE(review): all tensors are assumed to share the first tensor's
        // vectorization factor — confirm with callers.
        let vectorization_factor: Variable = u32::from(vectorization_factor).into();

        // `position * vectorization_factor` does not depend on the dimension, so it is
        // computed once here instead of on every loop iteration. Presumably this converts
        // a vectorized position into element units — TODO confirm.
        let offset_scaled = scope.create_local(index_item_ty);
        cpa!(scope, offset_scaled = offset_ref * vectorization_factor);

        cpa!(
            scope,
            range(self.dim_start, self.dim_end).for_each(|i, scope| {
                let stride_layout = scope.create_local(index_item_ty);
                let ogwl = scope.create_local(index_item_ty);
                cpa!(scope, stride_layout = stride(layout, i));
                cpa!(scope, ogwl = offset_scaled / stride_layout);
                for (tensor, index) in self.tensors.iter().zip(self.indexes.iter()) {
                    let stride = scope.create_local(index_item_ty);
                    let shape = scope.create_local(index_item_ty);
                    let tmp = scope.create_local(index_item_ty);
                    cpa!(scope, stride = stride(tensor, i));
                    cpa!(scope, shape = shape(tensor, i));
                    // Wrap into this tensor's extent along `i`; a size-1 dimension
                    // always contributes 0.
                    cpa!(scope, tmp = ogwl % shape);
                    cpa!(scope, tmp = tmp * stride);
                    cpa!(scope, index = index + tmp);
                }
            })
        );

        // Undo the scaling applied to `position` above.
        for index in self.indexes {
            cpa!(scope, index = index / vectorization_factor);
        }
    }

    /// Returns a copy with `tensors`, `indexes`, `layout` and `position` vectorized with
    /// `vectorization`. `dim_start` and `dim_end` are carried over unchanged.
    ///
    /// NOTE(review): unlike `ReadGlobal::vectorize` and `ReadGlobalWithLayout::vectorize`,
    /// `position` is vectorized here — confirm this is intended.
    pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
        Self {
            tensors: self
                .tensors
                .iter()
                .map(|t| t.vectorize(vectorization))
                .collect(),
            indexes: self
                .indexes
                .iter()
                .map(|t| t.vectorize(vectorization))
                .collect(),
            layout: self.layout.vectorize(vectorization),
            position: self.position.vectorize(vectorization),
            dim_start: self.dim_start,
            dim_end: self.dim_end,
        }
    }
}