use furiosa_mapping::*;
use furiosa_opt_macro::primitive;
use crate::cast::FetchCast;
use crate::context::*;
use crate::engine::{CanApplyFetchCast, CanApplyFetchMask, CanApplyFetchTableLookup};
use crate::runtime::{Backend, CurrentBackend};
use crate::scalar::*;
use crate::tensor::tu::{Position, TuTensor};
/// Configuration record presumably consumed by fetch masking (see
/// `TuTensor::fetch_mask`).
///
/// NOTE(review): nothing in this chunk reads this struct, so the per-field
/// descriptions below are inferred from names and should be confirmed.
///
/// `Default` yields `last_axis = 0`, `valid_count_dim = ValidCountDim::Rightmost`
/// and an all-zero `rightmost_valid_count`.
#[derive(Debug, Clone, Default)]
pub struct FetchMaskConfig {
    /// Presumably the index of the last axis the mask applies to — TODO confirm.
    pub last_axis: usize,
    /// Selects which dimension(s) the valid counts are expressed against.
    pub valid_count_dim: ValidCountDim,
    /// Fixed-length table of eight valid counts for the rightmost dimension.
    /// Why the length is 8 is not visible here — TODO confirm.
    pub rightmost_valid_count: [u8; 8],
}
/// Which dimension(s) a fetch mask's valid counts refer to.
///
/// The default is [`ValidCountDim::Rightmost`]. `PartialEq`/`Eq`/`Hash` are
/// derived so configurations can be compared and used as map keys instead of
/// being checked only with `matches!`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum ValidCountDim {
    /// Valid counts apply to the rightmost dimension only (the default).
    #[default]
    Rightmost,
    /// Valid counts apply to the rightmost and second-rightmost dimensions.
    RightmostAndSecondRightmost,
    /// Valid counts are taken from the iterator dimension with the given index.
    /// NOTE(review): the exact meaning of the index is not visible here — confirm.
    Iterator(usize),
}
/// Position marker for tensors produced by `TuTensor::fetch_mask`.
#[derive(Debug)]
pub struct PositionFetchMask;
// Empty impl: only tags the marker type as a `Position`.
impl Position for PositionFetchMask {}
/// A [`TuTensor`] at the [`PositionFetchMask`] position, i.e. the return type
/// of `TuTensor::fetch_mask`. The backend parameter defaults to
/// [`CurrentBackend`].
pub type FetchMaskTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet, B = CurrentBackend> =
    TuTensor<'l, { T }, PositionFetchMask, D, Chip, Cluster, Slice, Time, Packet, B>;
/// Position marker for tensors produced by `TuTensor::fetch_table_lookup`.
#[derive(Debug)]
pub struct PositionFetchTableLookup;
// Empty impl: only tags the marker type as a `Position`.
impl Position for PositionFetchTableLookup {}
/// A [`TuTensor`] at the [`PositionFetchTableLookup`] position, i.e. the
/// return type of `TuTensor::fetch_table_lookup`. The backend parameter
/// defaults to [`CurrentBackend`].
pub type FetchTableLookupTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet, B = CurrentBackend> =
    TuTensor<'l, { T }, PositionFetchTableLookup, D, Chip, Cluster, Slice, Time, Packet, B>;
/// Position marker for tensors produced by `TuTensor::fetch_cast`.
#[derive(Debug)]
pub struct PositionFetchCast;
// Empty impl: only tags the marker type as a `Position`.
impl Position for PositionFetchCast {}
/// A [`TuTensor`] at the [`PositionFetchCast`] position, i.e. the return type
/// of `TuTensor::fetch_cast`. The backend parameter defaults to
/// [`CurrentBackend`].
pub type FetchCastTensor<'l, const T: Tu, D, Chip, Cluster, Slice, Time, Packet, B = CurrentBackend> =
    TuTensor<'l, { T }, PositionFetchCast, D, Chip, Cluster, Slice, Time, Packet, B>;
/// `fetch_mask` for tensors whose position `P` implements [`CanApplyFetchMask`].
impl<'l, const T: Tu, P: CanApplyFetchMask, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M, B: Backend>
    TuTensor<'l, T, P, D, Chip, Cluster, Slice, Time, Packet, B>
{
    /// Moves the tensor to [`PositionFetchMask`], re-typing its `Time` and
    /// `Packet` mappings as `OutTime` and `OutPacket`. The element type `D` and
    /// the `Chip`/`Cluster`/`Slice` mappings are unchanged.
    ///
    /// # Panics
    ///
    /// Currently always panics: `verify_fetch_mask` is a `todo!()` stub.
    #[primitive(TuTensor::fetch_mask)]
    pub fn fetch_mask<OutTime: M, OutPacket: M>(
        self,
    ) -> FetchMaskTensor<'l, T, D, Chip, Cluster, Slice, OutTime, OutPacket, B> {
        verify_fetch_mask::<Time, Packet, OutTime, OutPacket>();
        // NOTE(review): `transpose(true)` presumably re-lays the backing data
        // for the new Time/Packet mapping; the flag's meaning is not visible here.
        FetchMaskTensor::new(self.ctx, self.inner.transpose(true))
    }
}
/// `fetch_table_lookup` for tensors whose position `P` implements
/// [`CanApplyFetchTableLookup`].
impl<
    'l,
    const T: Tu,
    P: CanApplyFetchTableLookup,
    D: Scalar,
    Chip: M,
    Cluster: M,
    Slice: M,
    Time: M,
    Packet: M,
    B: Backend,
> TuTensor<'l, T, P, D, Chip, Cluster, Slice, Time, Packet, B>
{
    /// Moves the tensor to [`PositionFetchTableLookup`], changing the element
    /// type from `D` to `OutD`. All mappings are kept as-is.
    ///
    /// # Panics
    ///
    /// Currently always panics. `verify_fetch_table_lookup` is a `todo!()` stub,
    /// and the output data below is itself a `todo!()`.
    #[primitive(TuTensor::fetch_table_lookup)]
    // The `todo!()` argument to `new` diverges, which makes the call itself
    // unreachable code; the lint is silenced until the lookup is implemented.
    #[allow(unreachable_code)]
    pub fn fetch_table_lookup<OutD: Scalar>(
        self,
    ) -> FetchTableLookupTensor<'l, T, OutD, Chip, Cluster, Slice, Time, Packet, B> {
        verify_fetch_table_lookup::<D, OutD, Time, Packet>();
        FetchTableLookupTensor::new(self.ctx, todo!())
    }
}
/// `fetch_cast` for tensors whose position `P` implements [`CanApplyFetchCast`].
impl<'l, const T: Tu, P: CanApplyFetchCast, D: Scalar, Chip: M, Cluster: M, Slice: M, Time: M, Packet: M, B: Backend>
    TuTensor<'l, T, P, D, Chip, Cluster, Slice, Time, Packet, B>
{
    /// Converts the element type from `D` to `OutD` and moves the tensor to
    /// [`PositionFetchCast`]. All mappings are kept as-is.
    ///
    /// Requires `D: FetchCast<OutD>`, which supplies the per-value conversion.
    #[primitive(TuTensor::fetch_cast)]
    pub fn fetch_cast<OutD: Scalar>(self) -> FetchCastTensor<'l, T, OutD, Chip, Cluster, Slice, Time, Packet, B>
    where
        D: FetchCast<OutD>,
    {
        // The outer `map` walks the backing storage. The inner `map` presumably
        // lifts the cast over an Option-like per-element wrapper (e.g. invalid
        // or masked entries) — NOTE(review): confirm the element type of `inner`.
        FetchCastTensor::new(self.ctx, self.inner.map(|v| v.map(|v| v.cast())))
    }
}
/// Intended to validate the mapping change performed by `fetch_mask`
/// (`Time`/`Packet` to `OutTime`/`OutPacket`) — presumably at compile or
/// trace time; TODO confirm once implemented.
///
/// # Panics
///
/// Always panics: not yet implemented.
// The type parameters are the inputs of the eventual check. The stub body does
// not use them yet, which would otherwise trigger this clippy lint.
#[allow(clippy::extra_unused_type_parameters)]
fn verify_fetch_mask<Time: M, Packet: M, OutTime: M, OutPacket: M>() {
    todo!("fetch_mask is not yet implemented")
}
/// Intended to validate that a table lookup from element type `D` to `OutD` is
/// legal for the given `Time`/`Packet` mappings — presumably; TODO confirm once
/// implemented.
///
/// # Panics
///
/// Always panics: not yet implemented.
// The type parameters are the inputs of the eventual check. The stub body does
// not use them yet, which would otherwise trigger this clippy lint.
#[allow(clippy::extra_unused_type_parameters)]
fn verify_fetch_table_lookup<D: Scalar, OutD: Scalar, Time: M, Packet: M>() {
    todo!("fetch_table_lookup is not yet implemented")
}