use std::collections::HashSet;
use furiosa_mapping::*;
/// Carves the `Src` mapping out of the `Dst` mapping and returns what remains.
///
/// When `allow_broadcast` is `true` the residue is returned unchecked. When it
/// is `false`, the residue is required to satisfy `Mapping::is_padding`.
///
/// # Panics
///
/// Panics if `allow_broadcast` is `false` and the residue's `is_padding()`
/// returns `false`.
pub(crate) fn transpose_broadcast<Src: M, Dst: M>(allow_broadcast: bool) -> Mapping {
    // Materialize Src before Dst, matching the original evaluation order.
    let src = Src::to_value();
    let broadcast = Dst::to_value().carve(&src);
    if allow_broadcast {
        return broadcast;
    }
    // Name kept as `broadcast`: `assert!` embeds the expression text in its panic message.
    assert!(broadcast.is_padding());
    broadcast
}
/// Asserts that two operands of an element-wise binary operation share the
/// exact same axis list (same terms, same order).
///
/// # Panics
///
/// Panics with a diagnostic showing both axis lists if they differ.
pub(crate) fn assert_zip(lhs_axes: &[Term], rhs_axes: &[Term]) {
    let (left, right) = (lhs_axes, rhs_axes);
    assert_eq!(left, right, "Tensors must have the same axes for element-wise binary operations");
}
/// Builds mappings from the two axis lists and carves the source mapping out
/// of the destination mapping, returning the residue.
pub(crate) fn reduce_broadcast(src_axes: &[Term], dst_axes: &[Term]) -> Mapping {
    // Destination is built first, matching the original evaluation order.
    let dst = Mapping::from_terms(dst_axes.iter().cloned());
    let src = Mapping::from_terms(src_axes.iter().cloned());
    dst.carve(&src)
}
/// Returns the axes of `dst` that share no identifier with `src`, i.e. the
/// axes that a broadcast from `src` to `dst` has to introduce.
///
/// Each `dst` axis is classified by the identifiers it mentions. An axis is kept
/// when none of its identifiers occur in `src`. Axes with no identifiers at all
/// are also kept. Kept axes preserve their order in `dst`.
///
/// # Panics
///
/// In debug builds, panics if a `dst` axis mixes identifiers present in `src`
/// with identifiers absent from it (a split straddling the src/dst boundary).
/// This per-axis classification cannot represent that case.
pub(crate) fn broadcast_axes(src: &Mapping, dst: &Mapping) -> Mapping {
    let src_idents: HashSet<Ident> = src.idents().into_iter().collect();
    // Scratch buffer reused across axes so each term does not allocate afresh.
    let mut ids = Vec::new();
    let mut axes = Vec::new();
    for term in dst.axes() {
        collect_term_idents(&term, &mut ids);
        // `|&ident|` yields `&Ident`, which is what `HashSet<Ident>::contains` expects.
        let in_src = ids.iter().filter(|&ident| src_idents.contains(ident)).count();
        debug_assert!(
            in_src == 0 || in_src == ids.len(),
            "broadcast_axes assumes each dst axis is wholly new or wholly shared, not a straddling split"
        );
        // Wholly new axis: none of its identifiers exist in `src`.
        if in_src == 0 {
            axes.push(term);
        }
    }
    Mapping::from_terms(axes)
}
/// Replaces the contents of `out` with the identifiers referenced by `term`.
///
/// A symbol atom contributes its single symbol. A composite atom contributes
/// every identifier reported by its inner `idents()`. `out` is emptied first so
/// callers can reuse one buffer across many terms.
fn collect_term_idents(term: &Term, out: &mut Vec<Ident>) {
    out.truncate(0);
    match &term.inner {
        Atom::Composite(nested) => out.extend(nested.idents()),
        Atom::Symbol { symbol: ident, .. } => out.push(*ident),
    }
}
/// Derives the operands of a scatter from `src` into `dst` keyed by `key`.
///
/// Returns `(payload, dst_term)`:
/// * `payload` is `src` with `key` carved out;
/// * `dst_term` is the first axis of `dst` after `payload` is carved out of it.
///
/// # Panics
///
/// Panics if carving `payload` out of `dst` leaves no axes.
pub(crate) fn scatter_params(src: &Mapping, dst: &Mapping, key: &Mapping) -> (Mapping, Term) {
    let payload = src.carve(key);
    let dst_residue = dst.carve(&payload);
    let first_axis = dst_residue.axes().into_iter().next();
    let dst_term = first_axis.expect("scatter dst residue has no live target axis");
    (payload, dst_term)
}
/// Operand layouts derived for a gather, as produced by `gather_params`.
#[derive(Debug)]
pub(crate) struct GatherParams {
    /// The destination mapping with the index mapping carved out.
    pub payload: Mapping,
    /// The destination mapping with `payload` carved out.
    pub idx_residue: Mapping,
    /// The first axis left in the source mapping after carving out `payload`.
    pub src_term: Term,
}
/// Derives the operand layouts of a gather from `src` into `dst` indexed by `idx`.
///
/// * `payload` — `dst` with `idx` carved out;
/// * `idx_residue` — `dst` with `payload` carved out;
/// * `src_term` — the first axis of `src` once `payload` is carved out of it.
///
/// # Panics
///
/// Panics if carving `payload` out of `src` leaves no axes.
pub(crate) fn gather_params(src: &Mapping, dst: &Mapping, idx: &Mapping) -> GatherParams {
    let payload = dst.carve(idx);
    let idx_residue = dst.carve(&payload);
    let src_residue = src.carve(&payload);
    let src_term = src_residue
        .axes()
        .into_iter()
        .next()
        .expect("gather src residue has no live target axis");
    GatherParams { src_term, idx_residue, payload }
}
#[cfg(test)]
mod tests {
    use furiosa_mapping::*;

    use super::broadcast_axes;

    axes![A = 4, B = 2, C = 8];

    /// `broadcast_axes(src, dst)` must agree with `dst.carve(&src)`, ignoring axis order.
    #[test]
    fn broadcast_axes_matches_carve() {
        let cases: [(Mapping, Mapping); 4] = [
            // `B` appears only in dst, between shared axes.
            (<m![A, C]>::to_value(), <m![A, B, C]>::to_value()),
            // `B` appears only in dst, after the shared axis.
            (<m![A]>::to_value(), <m![A, B]>::to_value()),
            // `A` and `C` appear only in dst, on both sides of the shared axis.
            (<m![B]>::to_value(), <m![A, B, C]>::to_value()),
            // src and dst list the same axes.
            (<m![A, B, C]>::to_value(), <m![A, B, C]>::to_value()),
        ];

        // Sort so the comparison is insensitive to axis order.
        let sorted = |mapping: &Mapping| {
            let mut terms = mapping.axes();
            terms.sort();
            terms
        };

        for (src, dst) in cases {
            let actual = sorted(&broadcast_axes(&src, &dst));
            let expected = sorted(&dst.carve(&src));
            assert_eq!(actual, expected, "broadcast_axes(src={src:?}, dst={dst:?})");
        }
    }
}