use furiosa_mapping::*;
use furiosa_opt_macro::primitive;
use crate::context::*;
use crate::engine::CanApplyFetch;
use crate::runtime::{Backend, CurrentBackend};
use crate::scalar::*;
use crate::tensor::tu::{Position, TuTensor};
use furiosa_opt_lower::{FETCH_ALIGN_BYTES, FetchError, config_fetch};
/// Type-level marker identifying the "fetch" position of a [`TuTensor`].
///
/// The struct has no fields and carries no runtime data. In this file it is
/// used as the position parameter of the [`FetchTensor`] alias.
#[derive(Debug)]
pub struct PositionFetch;
// Opts `PositionFetch` in to the `Position` trait; the impl provides no items.
impl Position for PositionFetch {}
/// A [`TuTensor`] whose position parameter is fixed to [`PositionFetch`].
///
/// Every other parameter (`T`, `D`, `Chip`, `Cluster`, `Slice`, `Time`,
/// `Packet`, `B`) is forwarded unchanged to [`TuTensor`]. The backend `B`
/// defaults to [`CurrentBackend`] when omitted.
pub type FetchTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet, B = CurrentBackend> =
TuTensor<'l, { T }, PositionFetch, D, Chip, Cluster, Slice, Time, Packet, B>;
// The fetch operation is offered on any `TuTensor` whose position `P`
// implements `CanApplyFetch`.
impl<'l, const T: Tu, P: CanApplyFetch, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M, B: Backend>
TuTensor<'l, T, P, D, Chip, Cluster, Slice, Time, Packet, B>
{
/// Consumes this tensor and returns a [`FetchTensor`] typed with the
/// caller-chosen `OutTime` and `OutPacket` mappings.
///
/// `T`, `D`, `Chip`, `Cluster`, `Slice` and `B` are carried over unchanged;
/// only the time and packet parameters are replaced. The requested
/// combination is validated by `verify_fetch` before the result is built.
///
/// # Panics
///
/// Panics if `verify_fetch` rejects the combination of `D`, `Cluster`,
/// `Slice`, `Time`, `Packet`, `OutTime` and `OutPacket`.
#[primitive(TuTensor::fetch)]
pub fn fetch<OutTime: M, OutPacket: M>(self) -> FetchTensor<'l, T, D, Chip, Cluster, Slice, OutTime, OutPacket, B> {
// Validate first so that an unsupported configuration panics before any
// result tensor is constructed.
verify_fetch::<D, Cluster, Slice, Time, Packet, OutTime, OutPacket>();
// NOTE(review): the meaning of the `true` flag passed to `transpose` is not
// visible in this file — confirm against the definition of `transpose`.
FetchTensor::new(self.ctx, self.inner.transpose(true))
}
}
/// Validates a fetch from the `Time` / `Packet` mappings to the `OutTime` /
/// `OutPacket` mappings for element type `D`.
///
/// The checks run in this order and the first failing one panics:
///
/// 1. `Cluster::SIZE` must be 1 or 2.
/// 2. `Slice::SIZE` must be 64, 128, 192 or 256.
/// 3. The byte size reported by `D::size_in_bytes_from_length` for
///    `OutPacket::SIZE` must be a multiple of `FETCH_ALIGN_BYTES`.
/// 4. `config_fetch` must accept the four mapping values.
///
/// # Panics
///
/// Panics with the `Display` text of the matching `FetchError` variant for
/// checks 1–3, and with the `Display` text of the error returned by
/// `config_fetch` for check 4.
fn verify_fetch<D: Scalar, Cluster: M, Slice: M, Time: M, Packet: M, OutTime: M, OutPacket: M>() {
    let cluster_size = Cluster::SIZE;
    assert!(
        cluster_size == 1 || cluster_size == 2,
        "{}",
        FetchError::ClusterSize(cluster_size)
    );

    let slice_size = Slice::SIZE;
    assert!(
        [64, 128, 192, 256].contains(&slice_size),
        "{}",
        FetchError::SliceSize(slice_size)
    );

    // Alignment is checked on the *output* packet only.
    let packet_bytes = D::size_in_bytes_from_length(OutPacket::SIZE);
    assert!(
        packet_bytes.is_multiple_of(FETCH_ALIGN_BYTES),
        "{}",
        FetchError::PacketAlignment { bytes: packet_bytes }
    );

    let in_time = Time::to_value();
    let in_packet = Packet::to_value();
    let out_time = OutTime::to_value();
    let out_packet = OutPacket::to_value();
    // Only success or failure matters here; a successful result is discarded.
    if let Err(err) = config_fetch(&in_time, &in_packet, &out_time, &out_packet) {
        panic!("{err}");
    }
}
// Smoke tests for `verify_fetch`: each test passes as long as the call does
// not panic. The type arguments are, in order:
// `D, Cluster, Slice, Time, Packet, OutTime, OutPacket`.
#[cfg(test)]
mod tests {
use super::*;
// Axes used by the mappings below. `X` is not referenced by the tests
// visible in this module.
axes![A = 8, B = 8, S = 64, X = 8, L = 16];
// Input packet `[A, B]` is read with `A` as output time and `B` as output packet.
#[test]
fn valid_read() {
verify_fetch::<i8, m![1], m![S], m![1], m![A, B], m![A], m![B]>();
}
// Same input as `valid_read`, but the output time/packet roles of `A` and `B`
// are swapped.
#[test]
fn transposed_read() {
verify_fetch::<i8, m![1], m![S], m![1], m![A, B], m![B], m![A]>();
}
// A single input axis `L` is divided between output time (`L / 8`) and
// output packet (`L % 8`).
#[test]
fn split_axis_across_packet_and_time() {
verify_fetch::<i8, m![1], m![S], m![1], m![L], m![L / 8], m![L % 8]>();
}
// Output time `m![4]` does not appear in the input mappings.
// NOTE(review): presumably this models a broadcast over time, per the test
// name — confirm against `config_fetch`.
#[test]
fn broadcast_time_read() {
verify_fetch::<i8, m![1], m![S], m![1], m![A], m![4], m![A]>();
}
// Unlike the tests above, the input time is `m![A]` rather than `m![1]`, and
// it is passed through as the output time.
#[test]
fn live_input_time_reads() {
verify_fetch::<i8, m![1], m![S], m![A], m![B], m![A], m![B]>();
}
}