pub(super) mod stream_adapter;
pub(super) mod trf_sequencer;
use furiosa_mapping::*;
use furiosa_opt_lower::config_divide_exact;
use furiosa_opt_macro::primitive;
use crate::cast::{Cast, ContractionCast};
use crate::context::*;
use crate::engine::CanApplyContractOuter;
use crate::runtime::{Backend, CurrentBackend};
use crate::scalar::*;
use crate::tensor::Tensor;
use crate::tensor::memory::TrfTensor;
use crate::tensor::tu::TuTensor;
/// Result of [`TuTensor::contract_outer`].
///
/// It holds the TU context borrow together with a tensor whose elements are in the widened
/// contraction type `<D as ContractionCast>::Output`. The tensor is laid out over the nested
/// mapping `Chip × Cluster × Slice × Lane × Time × Packet`. When built by `contract_outer`,
/// the `Time`/`Packet` parameters are that call's `OutTime`/`OutPacket`.
#[derive(Debug)]
pub struct ContractOuterTensor<
'l,
const T: Tu,
D: Scalar + ContractionCast,
Chip: M,
Cluster: M,
Slice: M,
Lane: M,
Time: M,
Packet: M,
B: Backend = CurrentBackend,
> {
/// Exclusive borrow of the TU context. `contract_outer` moves it over from the consumed
/// `TuTensor`.
pub(crate) ctx: &'l mut TuContext<{ T }>,
/// Payload tensor. `contract_outer` fills it with the element-wise product of the adapted
/// stream operand and the sequenced TRF operand, both cast to the contraction output type.
pub(crate) inner: Tensor<
<D as ContractionCast>::Output,
Pair<Chip, Pair<Cluster, Pair<Slice, Pair<Lane, Pair<Time, Packet>>>>>,
B,
>,
}
impl<
        'l,
        const T: Tu,
        P: CanApplyContractOuter,
        D: Scalar + ContractionCast,
        Chip: M,
        Cluster: M,
        Slice: M,
        Time: M,
        Packet: M,
        B: Backend,
    > TuTensor<'l, T, P, D, Chip, Cluster, Slice, Time, Packet, B>
{
    /// Combines this streamed tensor with a TRF-resident operand into a [`ContractOuterTensor`].
    ///
    /// The steps are:
    /// 1. Check the type-level layouts.
    /// 2. Run `self` through the stream adapter and `trf_tensor` through the TRF sequencer,
    ///    so both target the `(Lane, OutTime, OutPacket)` output layout.
    /// 3. Cast every element of both operands to `<D as ContractionCast>::Output`.
    /// 4. Multiply the two operands element by element.
    ///
    /// `self` is consumed and its context borrow moves into the result. `trf_tensor` is only
    /// borrowed.
    ///
    /// # Panics
    ///
    /// Panics if the layout checks in `verify_contract_outer` fail. For example, the input
    /// `Time` must divide the expected output time.
    #[primitive(TuTensor::contract_outer)]
    pub fn contract_outer<OutTime: M, OutPacket: M, Lane: M, TrfElement: M>(
        self,
        trf_tensor: &TrfTensor<D, Chip, Cluster, Slice, Lane, TrfElement, B>,
    ) -> ContractOuterTensor<'l, T, D, Chip, Cluster, Slice, Lane, OutTime, OutPacket, B>
    where
        D: Cast<<D as ContractionCast>::Output>,
    {
        // Reject incompatible layouts before any data is touched.
        verify_contract_outer::<D, Lane, Time, Packet, OutTime, OutPacket>();

        // Both operands are brought into the same output layout so they can be zipped.
        let streamed = stream_adapter::contract_outer::<
            D,
            Chip,
            Cluster,
            Slice,
            Lane,
            Time,
            Packet,
            OutTime,
            OutPacket,
            B,
        >(self.inner);
        let sequenced = trf_sequencer::contract_outer::<
            D,
            Chip,
            Cluster,
            Slice,
            Lane,
            TrfElement,
            OutTime,
            OutPacket,
            B,
        >(trf_tensor);

        // Widen each wrapped element to the contraction output type.
        let widened_stream = streamed.map(|slot| slot.map(|elem| elem.cast()));
        let widened_trf = sequenced.map(|slot| slot.map(|elem| elem.cast()));

        ContractOuterTensor {
            ctx: self.ctx,
            inner: widened_stream.zip_with(&widened_trf, |lhs, rhs| lhs * rhs),
        }
    }
}
/// Checks that the type-level layouts requested for `contract_outer` are mutually consistent.
///
/// It first runs the stream adapter's own checks. It then builds the expected output time
/// axis by pairing `OutTime` with `OutPacket` re-strided by `D::length_from_bytes(FLIT_BYTES)`
/// (padding removed, then normalized). Finally it requires the input `Time` to divide that
/// axis exactly.
///
/// # Panics
///
/// Panics if `config_divide_exact` reports that `Time` does not divide the expected time
/// axis. It may also panic inside `verify_stream_adapter`.
fn verify_contract_outer<D: Scalar, Lane: M, Time: M, Packet: M, OutTime: M, OutPacket: M>() {
    stream_adapter::verify_stream_adapter::<D, Lane, Time, Packet, OutTime, OutPacket>();

    let elems_per_flit = D::length_from_bytes(crate::engine::FLIT_BYTES);
    let out_packet_in_flits = OutPacket::to_value()
        .stride(elems_per_flit)
        .remove_padding();
    let expected_time = OutTime::to_value().pair(out_packet_in_flits).normalize();

    config_divide_exact(&expected_time, &Time::to_value())
        .expect("`contract_outer`: Time does not divide OutTime");
}