use std::fmt;
use std::marker::PhantomData;
/// Private home of the sealing supertrait for [`Buf`].
mod sealed {
    /// Supertrait of [`super::Buf`]; because this module is private, code outside the
    /// enclosing module cannot name it and therefore cannot implement `Buf`.
    pub trait Sealed {}
}

/// Marker trait for the buffer tag types that parameterise [`BufId`].
///
/// Every type declared through the `tag!` macro in this module implements it. The
/// trait is sealed, so the set of tags is closed to this module.
pub trait Buf
where
    Self: sealed::Sealed + Copy + Clone + 'static,
{
}
/// A raw `u32` buffer index tagged at the type level with its buffer kind `B`.
///
/// Ids with different tags are different types, so they cannot be mixed up silently;
/// cross-tag conversions exist only where an explicit `From` impl is provided.
pub struct BufId<B>
where
    B: Buf,
{
    /// The untyped index this id wraps.
    pub(crate) idx: u32,
    /// Zero-sized marker carrying the tag type; stores no data.
    pub(crate) _tag: PhantomData<B>,
}

impl<B> BufId<B>
where
    B: Buf,
{
    /// Wraps a raw index as an id carrying tag `B`.
    #[inline]
    pub(crate) fn from_raw(idx: u32) -> Self {
        let _tag = PhantomData;
        BufId { idx, _tag }
    }

    /// Returns the untyped index, discarding the tag.
    #[inline]
    pub(crate) fn raw(self) -> u32 {
        let BufId { idx, .. } = self;
        idx
    }
}
/// A `BufId` is a `u32` plus a `PhantomData` marker, so it is freely copyable.
impl<B> Copy for BufId<B> where B: Buf {}

impl<B> Clone for BufId<B>
where
    B: Buf,
{
    /// Plain bitwise copy, the canonical `Clone` for a `Copy` type.
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}
// Equality and hashing consider only the raw index: the tag is fixed by the type, so
// two ids of the same type differ only in `idx`.
impl<B> PartialEq for BufId<B>
where
    B: Buf,
{
    #[inline]
    fn eq(&self, rhs: &Self) -> bool {
        let (lhs_idx, rhs_idx) = (self.idx, rhs.idx);
        lhs_idx == rhs_idx
    }
}

impl<B> Eq for BufId<B> where B: Buf {}

impl<B> std::hash::Hash for BufId<B>
where
    B: Buf,
{
    #[inline]
    fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
        // Must agree with `eq`: hash exactly the field equality compares.
        std::hash::Hash::hash(&self.idx, hasher);
    }
}
impl<B> PartialOrd for BufId<B>
where
    B: Buf,
{
    #[inline]
    fn partial_cmp(&self, rhs: &Self) -> Option<std::cmp::Ordering> {
        // Delegate to the total order so the two impls can never disagree.
        Some(Ord::cmp(self, rhs))
    }
}

/// Ids of the same tag are ordered by their raw index.
impl<B> Ord for BufId<B>
where
    B: Buf,
{
    #[inline]
    fn cmp(&self, rhs: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.idx, &rhs.idx)
    }
}
/// Diagnostic form that includes the tag, e.g. `BufId<QBuf>(3)`.
///
/// The previous output (`BufId<_>(3)`) dropped the tag, which made ids of different
/// buffer kinds indistinguishable in logs even though they are distinct types.
impl<B: Buf> fmt::Debug for BufId<B> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `type_name` returns a path whose exact form is unspecified (it is meant for
        // diagnostics only); keep just the last segment for compact output.
        let full = std::any::type_name::<B>();
        let tag = full.rsplit("::").next().unwrap_or(full);
        write!(f, "BufId<{tag}>({})", self.idx)
    }
}
/// User-facing form: a `%` sigil followed by the raw index, e.g. `%13`.
///
/// Width/fill flags on the outer formatter are not forwarded to the index.
impl<B> fmt::Display for BufId<B>
where
    B: Buf,
{
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        out.write_fmt(format_args!("%{}", self.idx))
    }
}
/// Declares a unit-struct buffer tag and registers it as a [`Buf`].
///
/// Any outer attributes (typically `///` docs) written before the name are forwarded
/// onto the generated struct.
macro_rules! tag {
    ($(#[$attr:meta])* $tag:ident) => {
        $(#[$attr])*
        #[derive(Copy, Clone, Debug)]
        pub struct $tag;

        impl Buf for $tag {}
        impl sealed::Sealed for $tag {}
    };
}
// Concrete buffer tags.
tag!(EmbedOutBuf);
tag!(ResidualBuf);
tag!(HiddenBuf);
tag!(AttnInputBuf);
tag!(MoeInputBuf);
tag!(TailNormedBuf);
tag!(QProjOutBuf);
tag!(QBuf);
tag!(KProjOutBuf);
tag!(VProjOutBuf);
tag!(QGateBuf);
tag!(AttnOutBuf);
tag!(OProjOutBuf);
tag!(KvCacheKBuf);
tag!(KvCacheVBuf);
tag!(RouterLogitsBuf);
tag!(RouterIdxBuf);
tag!(RouterWeightsBuf);
tag!(SharedGateBuf);
tag!(SharedFfnGateBuf);
tag!(SharedFfnUpBuf);
tag!(SharedFfnActBuf);
tag!(SharedFfnDownBuf);
tag!(MoeOutSumBuf);
tag!(BucketInputBuf);
tag!(BucketGateBuf);
tag!(BucketUpBuf);
tag!(BucketActBuf);
tag!(BucketOutBuf);
tag!(BucketTokenIdxBuf);
tag!(BucketWeightsBuf);
tag!(ExpertIndicesBuf);
tag!(ExpertBaseBuf);
tag!(HtpeBuf);
tag!(HidsBuf);
tag!(GateMidBuf);
tag!(UpMidBuf);
tag!(DownMidBuf);
tag!(ConvOutBuf);
tag!(ConvStateBuf);
tag!(QkvStackBuf);
tag!(ZStackBuf);
tag!(AlphaStackBuf);
tag!(BetaStackBuf);
tag!(GDecayBuf);
tag!(BetaGateBuf);
tag!(DeltaStateBuf);
tag!(DeltaOutBuf);
tag!(ValueOutBuf);
tag!(RopeInvFreqBuf);
tag!(TokenIdsBuf);
tag!(LogitsBuf);
tag!(DeprecatedCogitoBuf);

// Role ("union") tags: several concrete tags convert into each of these through the
// `impl_from!` table in this module.
tag!(MatvecIn);
tag!(MatvecOut);
tag!(RmsNormIn);
tag!(RmsNormOut);
/// Implements `From<BufId<Src>> for BufId<Dst>` for each listed source tag, keeping the
/// raw index unchanged.
///
/// Conversions are one-way: only the listed `Src -> Dst` direction is generated.
/// Accepts a single source (`impl_from!(A -> D);`) or several sources that share one
/// destination (`impl_from!(A, B, C -> D);`).
macro_rules! impl_from {
    ($($src:ident),+ -> $dst:ident) => {
        $(
            impl From<BufId<$src>> for BufId<$dst> {
                #[inline]
                fn from(x: BufId<$src>) -> Self {
                    BufId::from_raw(x.raw())
                }
            }
        )+
    };
}
// Bucket buffers -> mid buffers (one-way: only this direction is implemented here).
impl_from!(BucketGateBuf -> GateMidBuf);
impl_from!(BucketUpBuf -> UpMidBuf);
impl_from!(BucketOutBuf -> DownMidBuf);

// A `HiddenBuf` id may be used where these concrete ids are expected.
impl_from!(HiddenBuf -> EmbedOutBuf);
impl_from!(HiddenBuf -> TailNormedBuf);

// Sources accepted as `MatvecIn`.
impl_from!(AttnInputBuf -> MatvecIn);
impl_from!(MoeInputBuf -> MatvecIn);
impl_from!(AttnOutBuf -> MatvecIn);
impl_from!(ValueOutBuf -> MatvecIn);
impl_from!(SharedFfnActBuf -> MatvecIn);
impl_from!(TailNormedBuf -> MatvecIn);

// Sources accepted as `MatvecOut`.
impl_from!(QProjOutBuf -> MatvecOut);
impl_from!(KProjOutBuf -> MatvecOut);
impl_from!(VProjOutBuf -> MatvecOut);
impl_from!(OProjOutBuf -> MatvecOut);
impl_from!(RouterLogitsBuf -> MatvecOut);
impl_from!(SharedGateBuf -> MatvecOut);
impl_from!(SharedFfnGateBuf -> MatvecOut);
impl_from!(SharedFfnUpBuf -> MatvecOut);
impl_from!(SharedFfnDownBuf -> MatvecOut);
impl_from!(LogitsBuf -> MatvecOut);
impl_from!(QkvStackBuf -> MatvecOut);
impl_from!(ZStackBuf -> MatvecOut);
impl_from!(AlphaStackBuf -> MatvecOut);
impl_from!(BetaStackBuf -> MatvecOut);

// Sources accepted as `RmsNormIn`.
impl_from!(EmbedOutBuf -> RmsNormIn);
impl_from!(ResidualBuf -> RmsNormIn);
impl_from!(HiddenBuf -> RmsNormIn);
impl_from!(ConvOutBuf -> RmsNormIn);
impl_from!(QBuf -> RmsNormIn);
impl_from!(KProjOutBuf -> RmsNormIn);

// Sources accepted as `RmsNormOut`.
impl_from!(AttnInputBuf -> RmsNormOut);
impl_from!(MoeInputBuf -> RmsNormOut);
impl_from!(TailNormedBuf -> RmsNormOut);
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Hashes `value` with a fresh `DefaultHasher` so two hashes can be compared.
    fn hash_of<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn buf_id_construction_and_round_trip() {
        let id: BufId<MoeInputBuf> = BufId::from_raw(42);
        assert_eq!(id.raw(), 42);
    }

    #[test]
    fn buf_id_is_copy() {
        let a: BufId<HiddenBuf> = BufId::from_raw(3);
        let b = a;
        // `a` is still usable after the move because `BufId` is `Copy`.
        assert_eq!(a.raw(), b.raw());
    }

    #[test]
    fn buf_id_eq_hash_by_index_only() {
        let a: BufId<MoeInputBuf> = BufId::from_raw(7);
        let b: BufId<MoeInputBuf> = BufId::from_raw(7);
        let c: BufId<MoeInputBuf> = BufId::from_raw(8);
        assert_eq!(a, b);
        assert_ne!(a, c);
        assert_eq!(hash_of(&a), hash_of(&b));
    }

    #[test]
    fn buf_id_ord_follows_index() {
        let lo: BufId<QBuf> = BufId::from_raw(1);
        let hi: BufId<QBuf> = BufId::from_raw(2);
        assert!(lo < hi);
        assert_eq!(lo.cmp(&hi), std::cmp::Ordering::Less);
        assert_eq!(lo.partial_cmp(&hi), Some(lo.cmp(&hi)));
        let mut ids = vec![hi, lo, hi];
        ids.sort();
        assert_eq!(ids, vec![lo, hi, hi]);
    }

    #[test]
    fn buf_id_display_matches_legacy() {
        let id: BufId<MoeInputBuf> = BufId::from_raw(13);
        assert_eq!(format!("{id}"), "%13");
    }

    #[test]
    fn unidirectional_bucket_to_mid_conversion() {
        let bucket: BufId<BucketGateBuf> = BufId::from_raw(1);
        let mid: BufId<GateMidBuf> = bucket.into();
        assert_eq!(mid.raw(), 1);
    }

    #[test]
    fn union_conversion_from_concrete() {
        let normed: BufId<AttnInputBuf> = BufId::from_raw(5);
        let matvec_in: BufId<MatvecIn> = normed.into();
        assert_eq!(matvec_in.raw(), 5);
    }

    #[test]
    fn moe_input_converts_to_matvec_and_rms_out() {
        let moe_in: BufId<MoeInputBuf> = BufId::from_raw(99);
        let m: BufId<MatvecIn> = moe_in.into();
        assert_eq!(m.raw(), 99);
        let r: BufId<RmsNormOut> = moe_in.into();
        assert_eq!(r.raw(), 99);
    }

    #[test]
    fn projection_output_converts_to_matvec_out() {
        let q: BufId<QProjOutBuf> = BufId::from_raw(21);
        let out: BufId<MatvecOut> = q.into();
        assert_eq!(out.raw(), 21);
    }

    #[test]
    fn concrete_buffers_convert_into_rms_norm_roles() {
        let residual: BufId<ResidualBuf> = BufId::from_raw(11);
        let rms_in: BufId<RmsNormIn> = residual.into();
        assert_eq!(rms_in.raw(), 11);
        let tail: BufId<TailNormedBuf> = BufId::from_raw(12);
        let rms_out: BufId<RmsNormOut> = tail.into();
        assert_eq!(rms_out.raw(), 12);
    }

    #[test]
    fn chained_conversions_preserve_index() {
        let hidden: BufId<HiddenBuf> = BufId::from_raw(5);
        let tail: BufId<TailNormedBuf> = hidden.into();
        let matvec_in: BufId<MatvecIn> = tail.into();
        assert_eq!(matvec_in.raw(), 5);
    }
}