use crate::{
components::global::{MaxGlobalReaderPlanes, specialization::roles::PlaneFlowCounts},
definition::StageIdent,
};
/// Selects, for each of the two matmul inputs, which [`InputLoadFlow`] loads it.
///
/// The derived `Default` sets both sides to [`InputLoadFlow::MainOnly`],
/// i.e. no specialization.
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct LoadFlows {
/// Load flow used for the `lhs` input.
pub lhs: InputLoadFlow,
/// Load flow used for the `rhs` input.
pub rhs: InputLoadFlow,
}
/// Which flow is responsible for loading a single input.
///
/// Only [`InputLoadFlow::LoadOnly`] is reported as specialization by
/// [`InputLoadFlow::has_specialization`].
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum InputLoadFlow {
/// The input is loaded by the main flow. This is the default variant.
#[default]
MainOnly,
/// The input is loaded by the load-only flow.
LoadOnly,
}
impl LoadFlows {
pub fn has_specialization(&self) -> bool {
self.lhs.has_specialization() || self.rhs.has_specialization()
}
}
impl InputLoadFlow {
pub fn has_specialization(&self) -> bool {
match self {
InputLoadFlow::MainOnly => false,
InputLoadFlow::LoadOnly => true,
}
}
}
impl LoadFlows {
pub fn to_plane_flow_counts(
&self,
main_flow: u32,
reader_tasks: MaxGlobalReaderPlanes,
) -> PlaneFlowCounts {
use InputLoadFlow::*;
let ideal_load_only = match (self.lhs, self.rhs) {
(MainOnly, MainOnly) => 0,
(MainOnly, LoadOnly) => reader_tasks.rhs,
(LoadOnly, MainOnly) => reader_tasks.lhs,
(LoadOnly, LoadOnly) => gcd(reader_tasks.lhs, reader_tasks.rhs),
};
let load_only = best_divisor_close_to_reference(ideal_load_only, main_flow);
PlaneFlowCounts {
main_flow,
load_only,
}
}
}
/// Returns the divisor of `dividible_value` that is closest to `reference`.
///
/// * When two divisors are equally close to `reference`, the larger one wins.
/// * When `dividible_value` is `0` there is nothing to divide, so `0` is
///   returned. Previously this case fell through to the seed value `1`, which
///   made a configuration with an ideal count of zero still end up with one.
///
/// The search is a linear scan over `1..=dividible_value`.
fn best_divisor_close_to_reference(dividible_value: u32, reference: u32) -> u32 {
    (1..=dividible_value)
        .filter(|&candidate| dividible_value.is_multiple_of(candidate))
        // Smallest distance wins; `Reverse` makes the larger divisor win ties.
        .min_by_key(|&candidate| {
            (
                reference.abs_diff(candidate),
                std::cmp::Reverse(candidate),
            )
        })
        // The range is empty only when `dividible_value == 0` (1 divides
        // every positive value), and zero must map to zero.
        .unwrap_or(0)
}
/// Set of input sides that a flow is responsible for loading.
///
/// Membership is queried with [`LoadingSides::includes`].
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum LoadingSides {
/// Includes every side.
Both,
/// Includes only `StageIdent::Lhs`.
Lhs,
/// Includes only `StageIdent::Rhs`.
Rhs,
/// Includes no side.
None,
}
impl LoadingSides {
    /// Returns `true` if the left-hand side is part of this set
    /// (`Both` or `Lhs`).
    pub fn includes_lhs(&self) -> bool {
        matches!(self, LoadingSides::Both | LoadingSides::Lhs)
    }

    /// Returns `true` if the right-hand side is part of this set
    /// (`Both` or `Rhs`).
    pub fn includes_rhs(&self) -> bool {
        matches!(self, LoadingSides::Both | LoadingSides::Rhs)
    }

    /// Returns `true` if `ident` is part of this set.
    ///
    /// `Both` includes every identifier, `Lhs` / `Rhs` include only the
    /// matching identifier, and `None` includes nothing.
    pub fn includes(&self, ident: StageIdent) -> bool {
        match self {
            LoadingSides::Both => true,
            LoadingSides::Lhs => matches!(ident, StageIdent::Lhs),
            LoadingSides::Rhs => matches!(ident, StageIdent::Rhs),
            LoadingSides::None => false,
        }
    }
}
/// Which input sides each of the two flows loads.
///
/// Built from a `LoadFlows` through the `From<LoadFlows>` implementation.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct SpecializedLoadingSides {
/// Sides loaded by the main flow.
pub main_flow: LoadingSides,
/// Sides loaded by the load-only flow.
pub load_only: LoadingSides,
}
impl SpecializedLoadingSides {
    /// Counts the planes that load the input identified by `ident`.
    ///
    /// * `specialized` - when `false`, the result is simply
    ///   `plane_flow_counts.main_flow`, regardless of `ident`.
    /// * `ident` - the side being queried.
    /// * `plane_flow_counts` - plane counts of the main and load-only flows.
    ///
    /// When `specialized` is `true`, each flow contributes its plane count
    /// only if its loading sides include `ident`.
    pub fn num_loading_planes(
        &self,
        specialized: bool,
        ident: StageIdent,
        plane_flow_counts: PlaneFlowCounts,
    ) -> u32 {
        if !specialized {
            return plane_flow_counts.main_flow;
        }

        let from_main_flow = if self.main_flow.includes(ident) {
            plane_flow_counts.main_flow
        } else {
            0
        };
        let from_load_only = if self.load_only.includes(ident) {
            plane_flow_counts.load_only
        } else {
            0
        };

        from_main_flow + from_load_only
    }
}
impl From<LoadFlows> for SpecializedLoadingSides {
fn from(lsc: LoadFlows) -> Self {
use InputLoadFlow::*;
match (lsc.lhs, lsc.rhs) {
(MainOnly, MainOnly) => SpecializedLoadingSides {
main_flow: LoadingSides::Both,
load_only: LoadingSides::None,
},
(MainOnly, LoadOnly) => SpecializedLoadingSides {
main_flow: LoadingSides::Lhs,
load_only: LoadingSides::Rhs,
},
(LoadOnly, MainOnly) => SpecializedLoadingSides {
main_flow: LoadingSides::Rhs,
load_only: LoadingSides::Lhs,
},
(LoadOnly, LoadOnly) => SpecializedLoadingSides {
main_flow: LoadingSides::None,
load_only: LoadingSides::Both,
},
}
}
}
/// Greatest common divisor of `a` and `b`, computed with the Euclidean
/// algorithm.
///
/// If either argument is `0` the other one is returned, so `gcd(0, 0) == 0`.
pub(crate) fn gcd(mut a: u32, mut b: u32) -> u32 {
    loop {
        if b == 0 {
            return a;
        }
        // Replace the pair (a, b) by (b, a mod b) until the remainder vanishes.
        (a, b) = (b, a % b);
    }
}
/// Plane counts for a matmul, as computed by `MatmulPlaneCounts::new`.
pub struct MatmulPlaneCounts {
/// Number of planes that load the `lhs` input.
pub lhs: u32,
/// Number of planes that load the `rhs` input.
pub rhs: u32,
/// Number of planes counted for `out`: the main-flow count when specialized,
/// otherwise `total`. NOTE(review): presumably the planes producing the
/// output — confirm against callers.
pub out: u32,
/// Total number of planes, taken from `PlaneFlowCounts::total_count()`.
pub total: u32,
}
impl MatmulPlaneCounts {
pub fn new(tensor_load_flows: LoadFlows, plane_flow_counts: PlaneFlowCounts) -> Self {
let total = plane_flow_counts.total_count();
match tensor_load_flows.has_specialization() {
true => {
let loading_sides: SpecializedLoadingSides = tensor_load_flows.into();
Self {
lhs: loading_sides.num_loading_planes(true, StageIdent::Lhs, plane_flow_counts),
rhs: loading_sides.num_loading_planes(true, StageIdent::Rhs, plane_flow_counts),
out: plane_flow_counts.main_flow,
total,
}
}
false => Self {
lhs: total,
rhs: total,
out: total,
total,
},
}
}
}