use burn::prelude::*;
use std::time::Instant;
use crate::local_grid_rho::{
CompiledLocalGridRhoPlan, LocalGridNeighborhood, LocalGridRhoPlanSpec, LocalGridShape2d,
local_grid_rho_profile_snapshot, supports_local_grid_rho_backend,
try_fused_local_grid_rho_attention_wgpu_head_decay_with_plan,
};
use crate::profiling::{
KernelProfileSite, KernelProfileSnapshot, profile_enabled, profile_record, profile_reset,
profile_snapshot,
};
/// Process-wide profiling site for structured-pyramid step bookkeeping;
/// read/reset via the `structured_pyramid_profile_*` functions below.
static STRUCTURED_PYRAMID_PROFILE: KernelProfileSite = KernelProfileSite::new();
/// Snapshot type returned by `structured_pyramid_profile_snapshot`.
pub type StructuredPyramidProfileSnapshot = KernelProfileSnapshot;
/// Static geometry of the two-level pyramid: a fine patch grid, a coarse
/// grid, the spatial stride relating them, and the number of global hub slots.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StructuredPyramidShape {
    /// Fine-resolution (patch) bank grid.
    pub patch: LocalGridShape2d,
    /// Coarse-resolution bank grid.
    pub coarse: LocalGridShape2d,
    /// Stride between the patch and coarse grids. Routing tensors are only
    /// compiled when `patch == coarse * coarse_stride` in both dimensions;
    /// consumers clamp with `coarse_stride.max(1)`.
    pub coarse_stride: usize,
    /// Number of hub slots; consumers clamp with `hub_count.max(1)`.
    pub hub_count: usize,
}
/// Inputs for one pyramid rho step with a single decay shared by all banks.
///
/// Expected shapes (enforced by `CompiledStructuredPyramidRhoPlan::matches`):
/// queries are `[batch, rank, h, w]`, values `[batch, value_dim, h, w]`,
/// rho states `[batch, rank, value_dim, h, w]`, and the hub state
/// `[batch, hub_count.max(1), rank, value_dim]`.
#[derive(Debug, Clone)]
pub struct StructuredPyramidRhoStepInput<B: Backend> {
    pub patch_query: Tensor<B, 4>,
    pub patch_value: Tensor<B, 4>,
    pub coarse_query: Tensor<B, 4>,
    pub coarse_value: Tensor<B, 4>,
    /// Patch-bank rho state, `[batch, rank, value_dim, patch_h, patch_w]`.
    pub patch_rho: Tensor<B, 5>,
    /// Coarse-bank rho state, `[batch, rank, value_dim, coarse_h, coarse_w]`.
    pub coarse_rho: Tensor<B, 5>,
    /// Hub rho state, `[batch, hub_count.max(1), rank, value_dim]`.
    pub hub_rho: Tensor<B, 4>,
    /// Optional mixing weights for patch-side hub reads/writes.
    pub patch_hub_weights: Option<Tensor<B, 4>>,
    /// Optional mixing weights for coarse-side hub reads/writes.
    pub coarse_hub_weights: Option<Tensor<B, 4>>,
    /// Local neighborhood shared by both banks in this (non-split) step.
    pub neighborhood: LocalGridNeighborhood,
    /// Single decay applied to patch, coarse, and hub updates alike.
    pub decay: Tensor<B, 1>,
}
/// Outputs of one pyramid rho step: the five read contexts plus the advanced
/// rho state for each bank.
#[derive(Debug, Clone)]
pub struct StructuredPyramidRhoStepOutput<B: Backend> {
    /// Patch-bank local read context.
    pub patch_local_context: Tensor<B, 4>,
    /// Coarse-bank local read context.
    pub coarse_local_context: Tensor<B, 4>,
    /// Patch-resolution context read from the coarse bank (cross-scale).
    pub patch_from_coarse_context: Tensor<B, 4>,
    /// Patch-resolution context read from the hub bank.
    pub patch_from_hub_context: Tensor<B, 4>,
    /// Coarse-resolution context read from the hub bank.
    pub coarse_from_hub_context: Tensor<B, 4>,
    /// Advanced patch rho state.
    pub next_patch_rho: Tensor<B, 5>,
    /// Advanced coarse rho state.
    pub next_coarse_rho: Tensor<B, 5>,
    /// Advanced hub rho state.
    pub next_hub_rho: Tensor<B, 4>,
}
/// Per-path enable flags for the split step: each flag gates one read or
/// write interaction between the patch, coarse, and hub banks. Disabled
/// reads return zeros; disabled writes contribute zero updates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StructuredPyramidBankMode {
    pub patch_local_read: bool,
    pub patch_local_write: bool,
    pub coarse_local_read: bool,
    pub coarse_local_write: bool,
    pub patch_from_coarse_read: bool,
    pub patch_from_hub_read: bool,
    pub coarse_from_hub_read: bool,
    pub patch_to_coarse_write: bool,
    pub patch_to_global_write: bool,
    pub coarse_to_global_write: bool,
}
impl Default for StructuredPyramidBankMode {
    /// Enables every bank interaction: all read paths and all write paths
    /// participate unless a caller opts specific ones out.
    fn default() -> Self {
        Self {
            // Reads.
            patch_local_read: true,
            coarse_local_read: true,
            patch_from_coarse_read: true,
            patch_from_hub_read: true,
            coarse_from_hub_read: true,
            // Writes.
            patch_local_write: true,
            coarse_local_write: true,
            patch_to_coarse_write: true,
            patch_to_global_write: true,
            coarse_to_global_write: true,
        }
    }
}
/// Inputs for one split-decay pyramid rho step: each bank interaction has
/// its own query projection, and each bank level has its own decay.
///
/// Expected shapes (enforced by `matches_split`): `patch_local_query` is
/// `[batch, patch_rank, patch_h, patch_w]`; `patch_query_for_coarse` uses the
/// coarse rank and `patch_query_for_global` the global rank, both at patch
/// resolution; coarse queries follow the same pattern at coarse resolution.
#[derive(Debug, Clone)]
pub struct StructuredPyramidSplitRhoStepInput<B: Backend> {
    /// Query for the patch bank's local read/write.
    pub patch_local_query: Tensor<B, 4>,
    /// Patch-resolution query used against the coarse bank.
    pub patch_query_for_coarse: Tensor<B, 4>,
    /// Patch-resolution query used against the hub (global) bank.
    pub patch_query_for_global: Tensor<B, 4>,
    pub patch_value: Tensor<B, 4>,
    /// Query for the coarse bank's local read/write.
    pub coarse_local_query: Tensor<B, 4>,
    /// Coarse-resolution query used against the hub (global) bank.
    pub coarse_query_for_global: Tensor<B, 4>,
    pub coarse_value: Tensor<B, 4>,
    /// Patch rho state, `[batch, patch_rank, value_dim, patch_h, patch_w]`.
    pub patch_rho: Tensor<B, 5>,
    /// Coarse rho state, `[batch, coarse_rank, value_dim, coarse_h, coarse_w]`.
    pub coarse_rho: Tensor<B, 5>,
    /// Hub rho state, `[batch, hub_count.max(1), global_rank, value_dim]`.
    pub hub_rho: Tensor<B, 4>,
    pub patch_hub_weights: Option<Tensor<B, 4>>,
    pub coarse_hub_weights: Option<Tensor<B, 4>>,
    /// Decay for the patch bank update.
    pub patch_decay: Tensor<B, 1>,
    /// Decay for the coarse bank update.
    pub coarse_decay: Tensor<B, 1>,
    /// Decay for the hub bank update.
    pub global_decay: Tensor<B, 1>,
    /// Which read/write paths participate in this step.
    pub bank_mode: StructuredPyramidBankMode,
}
/// Inputs for a coarse-only step in which the patch bank is carried through
/// untouched: `patch_rho` is only used for shape validation and to pass the
/// state forward unchanged.
#[derive(Debug, Clone)]
pub struct StructuredPyramidCoarseOnlyStepInput<B: Backend> {
    /// Patch rho state, passed through unchanged as `next_patch_rho`.
    pub patch_rho: Tensor<B, 5>,
    pub coarse_local_query: Tensor<B, 4>,
    pub coarse_query_for_global: Tensor<B, 4>,
    pub coarse_value: Tensor<B, 4>,
    pub coarse_rho: Tensor<B, 5>,
    pub hub_rho: Tensor<B, 4>,
    pub coarse_hub_weights: Option<Tensor<B, 4>>,
    pub coarse_decay: Tensor<B, 1>,
    pub global_decay: Tensor<B, 1>,
    pub bank_mode: StructuredPyramidBankMode,
}
/// Inputs for a coarse-only step that carries no patch state at all: only
/// the coarse and hub banks are read and advanced.
#[derive(Debug, Clone)]
pub struct StructuredPyramidCoarseOnlyNoPatchStepInput<B: Backend> {
    /// Query for the coarse bank's local read/write.
    pub coarse_local_query: Tensor<B, 4>,
    /// Coarse-resolution query used against the hub (global) bank.
    pub coarse_query_for_global: Tensor<B, 4>,
    pub coarse_value: Tensor<B, 4>,
    pub coarse_rho: Tensor<B, 5>,
    pub hub_rho: Tensor<B, 4>,
    pub coarse_hub_weights: Option<Tensor<B, 4>>,
    pub coarse_decay: Tensor<B, 1>,
    pub global_decay: Tensor<B, 1>,
    pub bank_mode: StructuredPyramidBankMode,
}
/// Outputs of a coarse-only step: coarse contexts plus the advanced coarse
/// and hub rho states (no patch-side outputs).
#[derive(Debug, Clone)]
pub struct StructuredPyramidCoarseOnlyStepOutput<B: Backend> {
    pub coarse_local_context: Tensor<B, 4>,
    pub coarse_from_hub_context: Tensor<B, 4>,
    pub next_coarse_rho: Tensor<B, 5>,
    pub next_hub_rho: Tensor<B, 4>,
}
/// Precompiled state for the shared-decay fused step: per-bank local-grid
/// plans plus optional routing tensors, together with the dims the plan was
/// compiled against (used by `matches` to validate inputs).
#[derive(Debug, Clone)]
pub struct CompiledStructuredPyramidRhoPlan<B: Backend> {
    /// Fused local-grid plan for the patch bank; `None` for a 1-token grid.
    patch_plan: Option<CompiledLocalGridRhoPlan<B>>,
    /// Fused local-grid plan for the coarse bank; `None` for a 1-token grid.
    coarse_plan: Option<CompiledLocalGridRhoPlan<B>>,
    /// Precomputed coarse->patch read route; only built for aligned grids
    /// with `coarse_stride` of exactly 2.
    patch_from_coarse_route: Option<Tensor<B, 3>>,
    /// Precomputed patch->coarse pooling route; built for aligned grids with
    /// any `coarse_stride > 1`.
    patch_to_coarse_pool: Option<Tensor<B, 3>>,
    /// Pyramid geometry this plan was compiled for.
    shape: StructuredPyramidShape,
    /// Neighborhood shared by both banks.
    neighborhood: LocalGridNeighborhood,
    /// Query rank shared by both banks.
    rank: usize,
    /// Value embedding width.
    value_dim: usize,
}
/// Precompiled state for the split-decay step family (`*_split_*` and
/// `*_coarse_only_*`): like `CompiledStructuredPyramidRhoPlan` but with
/// per-bank ranks and per-bank neighborhoods.
#[derive(Debug, Clone)]
pub struct CompiledStructuredPyramidSplitPlan<B: Backend> {
    /// Fused local-grid plan for the patch bank; `None` for a 1-token grid.
    patch_plan: Option<CompiledLocalGridRhoPlan<B>>,
    /// Fused local-grid plan for the coarse bank; `None` for a 1-token grid.
    coarse_plan: Option<CompiledLocalGridRhoPlan<B>>,
    /// Precomputed coarse->patch read route (aligned grids, stride exactly 2).
    patch_from_coarse_route: Option<Tensor<B, 3>>,
    /// Precomputed patch->coarse pooling route (aligned grids, stride > 1).
    patch_to_coarse_pool: Option<Tensor<B, 3>>,
    shape: StructuredPyramidShape,
    patch_neighborhood: LocalGridNeighborhood,
    coarse_neighborhood: LocalGridNeighborhood,
    /// Rank of the patch-local query bank (clamped to >= 1 at build time).
    patch_rank: usize,
    /// Rank of the coarse query bank (clamped to >= 1 at build time).
    coarse_rank: usize,
    /// Rank of the hub/global query bank (clamped to >= 1 at build time).
    global_rank: usize,
    value_dim: usize,
}
/// Constructor arguments for `CompiledStructuredPyramidSplitPlan::new`.
/// Ranks and `value_dim` of 0 are clamped to 1 during compilation.
#[derive(Debug, Clone, Copy)]
pub struct StructuredPyramidSplitPlanSpec<'a, B: Backend> {
    pub batch: usize,
    pub patch_rank: usize,
    pub coarse_rank: usize,
    pub global_rank: usize,
    pub value_dim: usize,
    pub shape: StructuredPyramidShape,
    /// Neighborhood used by the patch bank's local-grid plan.
    pub patch_neighborhood: LocalGridNeighborhood,
    /// Neighborhood used by the coarse bank's local-grid plan.
    pub coarse_neighborhood: LocalGridNeighborhood,
    pub device: &'a B::Device,
}
impl<B: Backend> CompiledStructuredPyramidRhoPlan<B> {
    /// Compiles a shared-decay fused-step plan for the given dims and shape.
    ///
    /// Components are only materialized when usable:
    /// - per-bank local-grid plans are skipped for degenerate 1-token grids;
    /// - `patch_from_coarse_route` additionally requires `coarse_stride <= 2`;
    /// - both routing tensors require the patch grid to be an exact
    ///   `coarse_stride` multiple of the coarse grid in both dimensions.
    pub fn new(
        batch: usize,
        rank: usize,
        value_dim: usize,
        shape: StructuredPyramidShape,
        neighborhood: LocalGridNeighborhood,
        device: &B::Device,
    ) -> Self {
        Self {
            patch_plan: (shape.patch.token_count() > 1).then(|| {
                CompiledLocalGridRhoPlan::new(
                    LocalGridRhoPlanSpec {
                        batch,
                        heads: rank,
                        value_heads: 1,
                        patch_tokens: shape.patch.token_count(),
                        latent: 1,
                        embd: value_dim,
                        grid: shape.patch,
                        neighborhood,
                    },
                    device,
                )
            }),
            coarse_plan: (shape.coarse.token_count() > 1).then(|| {
                CompiledLocalGridRhoPlan::new(
                    LocalGridRhoPlanSpec {
                        batch,
                        heads: rank,
                        value_heads: 1,
                        patch_tokens: shape.coarse.token_count(),
                        latent: 1,
                        embd: value_dim,
                        grid: shape.coarse,
                        neighborhood,
                    },
                    device,
                )
            }),
            // Coarse->patch read route: stride-2 only (`> 1 && <= 2`).
            patch_from_coarse_route: (shape.coarse_stride > 1
                && shape.coarse_stride <= 2
                && shape.patch.height == shape.coarse.height * shape.coarse_stride
                && shape.patch.width == shape.coarse.width * shape.coarse_stride)
                .then(|| patch_from_coarse_route(batch, shape.patch, shape.coarse, device)),
            // Patch->coarse pooling route: any stride > 1 with aligned grids.
            patch_to_coarse_pool: (shape.coarse_stride > 1
                && shape.patch.height == shape.coarse.height * shape.coarse_stride
                && shape.patch.width == shape.coarse.width * shape.coarse_stride)
                .then(|| {
                    patch_to_coarse_pool_route(batch, shape.patch, shape.coarse_stride, device)
                }),
            shape,
            neighborhood,
            rank,
            value_dim,
        }
    }

    /// Returns `true` when every input tensor's dims agree with the dims this
    /// plan was compiled for. Batch is taken from `patch_query` and only
    /// checked for consistency across the inputs; rho/hub shapes must match
    /// exactly, and the presence of each sub-plan must match whether its grid
    /// is degenerate.
    fn matches(&self, input: &StructuredPyramidRhoStepInput<B>) -> bool {
        let [batch, rank, patch_h, patch_w] = input.patch_query.shape().dims::<4>();
        let [patch_value_batch, value_dim, patch_value_h, patch_value_w] =
            input.patch_value.shape().dims::<4>();
        let [coarse_batch, coarse_rank, coarse_h, coarse_w] =
            input.coarse_query.shape().dims::<4>();
        let [
            coarse_value_batch,
            coarse_value_dim,
            coarse_value_h,
            coarse_value_w,
        ] = input.coarse_value.shape().dims::<4>();
        // A plan must exist iff the corresponding grid is non-degenerate.
        let patch_kernel_ok = if self.shape.patch.token_count() > 1 {
            self.patch_plan.is_some()
        } else {
            self.patch_plan.is_none()
        };
        let coarse_kernel_ok = if self.shape.coarse.token_count() > 1 {
            self.coarse_plan.is_some()
        } else {
            self.coarse_plan.is_none()
        };
        batch == patch_value_batch
            && batch == coarse_batch
            && batch == coarse_value_batch
            && rank == self.rank
            && coarse_rank == self.rank
            && value_dim == self.value_dim
            && coarse_value_dim == self.value_dim
            && patch_h == self.shape.patch.height
            && patch_w == self.shape.patch.width
            && patch_value_h == self.shape.patch.height
            && patch_value_w == self.shape.patch.width
            && coarse_h == self.shape.coarse.height
            && coarse_w == self.shape.coarse.width
            && coarse_value_h == self.shape.coarse.height
            && coarse_value_w == self.shape.coarse.width
            && input.patch_rho.shape().dims::<5>()
                == [
                    batch,
                    self.rank,
                    self.value_dim,
                    self.shape.patch.height,
                    self.shape.patch.width,
                ]
            && input.coarse_rho.shape().dims::<5>()
                == [
                    batch,
                    self.rank,
                    self.value_dim,
                    self.shape.coarse.height,
                    self.shape.coarse.width,
                ]
            && input.hub_rho.shape().dims::<4>()
                == [
                    batch,
                    self.shape.hub_count.max(1),
                    self.rank,
                    self.value_dim,
                ]
            && input.neighborhood == self.neighborhood
            && patch_kernel_ok
            && coarse_kernel_ok
    }
}
impl<B: Backend> CompiledStructuredPyramidSplitPlan<B> {
    /// Compiles a split-decay plan from `spec`.
    ///
    /// Ranks and `value_dim` are clamped with `.max(1)` so degenerate specs
    /// still compile. Sub-plans are skipped for 1-token grids; the
    /// coarse->patch route requires aligned grids and `coarse_stride <= 2`,
    /// while the patch->coarse pool only requires aligned grids.
    pub fn new(spec: StructuredPyramidSplitPlanSpec<'_, B>) -> Self {
        Self {
            patch_plan: (spec.shape.patch.token_count() > 1).then(|| {
                CompiledLocalGridRhoPlan::new(
                    LocalGridRhoPlanSpec {
                        batch: spec.batch,
                        heads: spec.patch_rank.max(1),
                        value_heads: 1,
                        patch_tokens: spec.shape.patch.token_count(),
                        latent: 1,
                        embd: spec.value_dim.max(1),
                        grid: spec.shape.patch,
                        neighborhood: spec.patch_neighborhood,
                    },
                    spec.device,
                )
            }),
            coarse_plan: (spec.shape.coarse.token_count() > 1).then(|| {
                CompiledLocalGridRhoPlan::new(
                    LocalGridRhoPlanSpec {
                        batch: spec.batch,
                        heads: spec.coarse_rank.max(1),
                        value_heads: 1,
                        patch_tokens: spec.shape.coarse.token_count(),
                        latent: 1,
                        embd: spec.value_dim.max(1),
                        grid: spec.shape.coarse,
                        neighborhood: spec.coarse_neighborhood,
                    },
                    spec.device,
                )
            }),
            // Coarse->patch read route: stride-2 only (`> 1 && <= 2`).
            patch_from_coarse_route: (spec.shape.coarse_stride > 1
                && spec.shape.coarse_stride <= 2
                && spec.shape.patch.height == spec.shape.coarse.height * spec.shape.coarse_stride
                && spec.shape.patch.width == spec.shape.coarse.width * spec.shape.coarse_stride)
                .then(|| {
                    patch_from_coarse_route(
                        spec.batch,
                        spec.shape.patch,
                        spec.shape.coarse,
                        spec.device,
                    )
                }),
            // Patch->coarse pooling route: any stride > 1 with aligned grids.
            patch_to_coarse_pool: (spec.shape.coarse_stride > 1
                && spec.shape.patch.height == spec.shape.coarse.height * spec.shape.coarse_stride
                && spec.shape.patch.width == spec.shape.coarse.width * spec.shape.coarse_stride)
                .then(|| {
                    patch_to_coarse_pool_route(
                        spec.batch,
                        spec.shape.patch,
                        spec.shape.coarse_stride,
                        spec.device,
                    )
                }),
            shape: spec.shape,
            patch_neighborhood: spec.patch_neighborhood,
            coarse_neighborhood: spec.coarse_neighborhood,
            patch_rank: spec.patch_rank.max(1),
            coarse_rank: spec.coarse_rank.max(1),
            global_rank: spec.global_rank.max(1),
            value_dim: spec.value_dim.max(1),
        }
    }

    /// Validates a full split-step input against the compiled dims: every
    /// query must carry the rank of the bank it targets (local queries use
    /// their own bank's rank, `*_for_coarse` the coarse rank, `*_for_global`
    /// the global rank) at the resolution of the bank it originates from.
    fn matches_split(&self, input: &StructuredPyramidSplitRhoStepInput<B>) -> bool {
        let [batch, patch_rank, patch_h, patch_w] = input.patch_local_query.shape().dims::<4>();
        let [patch_value_batch, value_dim, patch_value_h, patch_value_w] =
            input.patch_value.shape().dims::<4>();
        let [
            patch_coarse_batch,
            patch_coarse_rank,
            patch_coarse_h,
            patch_coarse_w,
        ] = input.patch_query_for_coarse.shape().dims::<4>();
        let [
            patch_global_batch,
            patch_global_rank,
            patch_global_h,
            patch_global_w,
        ] = input.patch_query_for_global.shape().dims::<4>();
        let [coarse_batch, coarse_rank, coarse_h, coarse_w] =
            input.coarse_local_query.shape().dims::<4>();
        let [
            coarse_global_batch,
            coarse_global_rank,
            coarse_global_h,
            coarse_global_w,
        ] = input.coarse_query_for_global.shape().dims::<4>();
        let [
            coarse_value_batch,
            coarse_value_dim,
            coarse_value_h,
            coarse_value_w,
        ] = input.coarse_value.shape().dims::<4>();
        // A plan must exist iff the corresponding grid is non-degenerate.
        let patch_kernel_ok = if self.shape.patch.token_count() > 1 {
            self.patch_plan.is_some()
        } else {
            self.patch_plan.is_none()
        };
        let coarse_kernel_ok = if self.shape.coarse.token_count() > 1 {
            self.coarse_plan.is_some()
        } else {
            self.coarse_plan.is_none()
        };
        batch == patch_value_batch
            && batch == patch_coarse_batch
            && batch == patch_global_batch
            && batch == coarse_batch
            && batch == coarse_global_batch
            && batch == coarse_value_batch
            && patch_rank == self.patch_rank
            && patch_coarse_rank == self.coarse_rank
            && patch_global_rank == self.global_rank
            && coarse_rank == self.coarse_rank
            && coarse_global_rank == self.global_rank
            && value_dim == self.value_dim
            && coarse_value_dim == self.value_dim
            && patch_h == self.shape.patch.height
            && patch_w == self.shape.patch.width
            && patch_value_h == self.shape.patch.height
            && patch_value_w == self.shape.patch.width
            && patch_coarse_h == self.shape.patch.height
            && patch_coarse_w == self.shape.patch.width
            && patch_global_h == self.shape.patch.height
            && patch_global_w == self.shape.patch.width
            && coarse_h == self.shape.coarse.height
            && coarse_w == self.shape.coarse.width
            && coarse_global_h == self.shape.coarse.height
            && coarse_global_w == self.shape.coarse.width
            && coarse_value_h == self.shape.coarse.height
            && coarse_value_w == self.shape.coarse.width
            && input.patch_rho.shape().dims::<5>()
                == [
                    batch,
                    self.patch_rank,
                    self.value_dim,
                    self.shape.patch.height,
                    self.shape.patch.width,
                ]
            && input.coarse_rho.shape().dims::<5>()
                == [
                    batch,
                    self.coarse_rank,
                    self.value_dim,
                    self.shape.coarse.height,
                    self.shape.coarse.width,
                ]
            && input.hub_rho.shape().dims::<4>()
                == [
                    batch,
                    self.shape.hub_count.max(1),
                    self.global_rank,
                    self.value_dim,
                ]
            && patch_kernel_ok
            && coarse_kernel_ok
    }

    /// Validates a coarse-only input: the patch state is only checked for
    /// shape (it is passed through, not processed), so only the coarse
    /// sub-plan's presence is verified.
    fn matches_coarse_only(&self, input: &StructuredPyramidCoarseOnlyStepInput<B>) -> bool {
        let [patch_batch, patch_rank, patch_value_dim, patch_h, patch_w] =
            input.patch_rho.shape().dims::<5>();
        let [coarse_batch, coarse_rank, coarse_h, coarse_w] =
            input.coarse_local_query.shape().dims::<4>();
        let [
            coarse_global_batch,
            coarse_global_rank,
            coarse_global_h,
            coarse_global_w,
        ] = input.coarse_query_for_global.shape().dims::<4>();
        let [
            coarse_value_batch,
            coarse_value_dim,
            coarse_value_h,
            coarse_value_w,
        ] = input.coarse_value.shape().dims::<4>();
        let coarse_kernel_ok = if self.shape.coarse.token_count() > 1 {
            self.coarse_plan.is_some()
        } else {
            self.coarse_plan.is_none()
        };
        patch_batch == coarse_batch
            && patch_batch == coarse_global_batch
            && patch_batch == coarse_value_batch
            && patch_rank == self.patch_rank
            && patch_value_dim == self.value_dim
            && patch_h == self.shape.patch.height
            && patch_w == self.shape.patch.width
            && coarse_rank == self.coarse_rank
            && coarse_global_rank == self.global_rank
            && coarse_value_dim == self.value_dim
            && coarse_h == self.shape.coarse.height
            && coarse_w == self.shape.coarse.width
            && coarse_global_h == self.shape.coarse.height
            && coarse_global_w == self.shape.coarse.width
            && coarse_value_h == self.shape.coarse.height
            && coarse_value_w == self.shape.coarse.width
            && input.coarse_rho.shape().dims::<5>()
                == [
                    patch_batch,
                    self.coarse_rank,
                    self.value_dim,
                    self.shape.coarse.height,
                    self.shape.coarse.width,
                ]
            && input.hub_rho.shape().dims::<4>()
                == [
                    patch_batch,
                    self.shape.hub_count.max(1),
                    self.global_rank,
                    self.value_dim,
                ]
            && coarse_kernel_ok
    }

    /// Validates a coarse-only input that carries no patch state; batch is
    /// taken from `coarse_local_query`.
    fn matches_coarse_only_no_patch(
        &self,
        input: &StructuredPyramidCoarseOnlyNoPatchStepInput<B>,
    ) -> bool {
        let [coarse_batch, coarse_rank, coarse_h, coarse_w] =
            input.coarse_local_query.shape().dims::<4>();
        let [
            coarse_global_batch,
            coarse_global_rank,
            coarse_global_h,
            coarse_global_w,
        ] = input.coarse_query_for_global.shape().dims::<4>();
        let [
            coarse_value_batch,
            coarse_value_dim,
            coarse_value_h,
            coarse_value_w,
        ] = input.coarse_value.shape().dims::<4>();
        let coarse_kernel_ok = if self.shape.coarse.token_count() > 1 {
            self.coarse_plan.is_some()
        } else {
            self.coarse_plan.is_none()
        };
        coarse_batch == coarse_global_batch
            && coarse_batch == coarse_value_batch
            && coarse_rank == self.coarse_rank
            && coarse_global_rank == self.global_rank
            && coarse_value_dim == self.value_dim
            && coarse_h == self.shape.coarse.height
            && coarse_w == self.shape.coarse.width
            && coarse_global_h == self.shape.coarse.height
            && coarse_global_w == self.shape.coarse.width
            && coarse_value_h == self.shape.coarse.height
            && coarse_value_w == self.shape.coarse.width
            && input.coarse_rho.shape().dims::<5>()
                == [
                    coarse_batch,
                    self.coarse_rank,
                    self.value_dim,
                    self.shape.coarse.height,
                    self.shape.coarse.width,
                ]
            && input.hub_rho.shape().dims::<4>()
                == [
                    coarse_batch,
                    self.shape.hub_count.max(1),
                    self.global_rank,
                    self.value_dim,
                ]
            && coarse_kernel_ok
    }
}
/// Resets the process-wide structured-pyramid profiling counters.
pub fn structured_pyramid_profile_reset() {
    profile_reset(&STRUCTURED_PYRAMID_PROFILE);
}
/// Returns a copy of the current structured-pyramid profiling counters.
pub fn structured_pyramid_profile_snapshot() -> StructuredPyramidProfileSnapshot {
    profile_snapshot(&STRUCTURED_PYRAMID_PROFILE)
}
/// Reference (unfused, backend-agnostic) implementation of one pyramid rho
/// step with a single shared decay.
///
/// Computes the five read contexts from the pre-update state, then advances
/// each bank: patch and coarse absorb their own query x value outer-product
/// updates under `decay`, the coarse bank additionally absorbs the patch
/// update pooled down to coarse resolution, and the hub absorbs the
/// token-summed patch and coarse updates.
pub fn reference_structured_pyramid_rho_step<B: Backend>(
    shape: StructuredPyramidShape,
    input: StructuredPyramidRhoStepInput<B>,
) -> StructuredPyramidRhoStepOutput<B> {
    let prof_enabled = profile_enabled();
    // Only sample the clock when profiling is enabled.
    let total_start = prof_enabled.then(Instant::now);
    // Local reads from each bank's own (pre-update) rho.
    let patch_local_context = local_read(
        input.patch_rho.clone(),
        input.patch_query.clone(),
        shape.patch,
        input.neighborhood,
    );
    let coarse_local_context = local_read(
        input.coarse_rho.clone(),
        input.coarse_query.clone(),
        shape.coarse,
        input.neighborhood,
    );
    // Patch-resolution read from the coarse bank.
    let patch_from_coarse_context = cross_scale_read(
        input.coarse_rho.clone(),
        input.patch_query.clone(),
        shape.coarse,
        shape.patch,
        shape.coarse_stride.max(1),
    );
    // Hub reads for both banks, from the pre-update hub state.
    let (patch_from_hub_context, coarse_from_hub_context) = hub_read_pair(
        input.hub_rho.clone(),
        input.patch_query.clone(),
        input.coarse_query.clone(),
        input.patch_hub_weights,
        input.coarse_hub_weights,
        shape.patch,
        shape.coarse,
    );
    // Query x value outer products in target-major layout.
    let patch_update = spatial_outer_target_major(input.patch_query, input.patch_value);
    let coarse_update = spatial_outer_target_major(input.coarse_query, input.coarse_value);
    // Pool the patch update to coarse resolution; the fast path (no
    // precomputed route here, hence `None`) may decline, in which case the
    // generic pooling path is used.
    let pooled_patch_update = pool_target_major_outer_fast(
        patch_update.clone(),
        shape.patch,
        shape.coarse,
        shape.coarse_stride.max(1),
        None,
    )
    .unwrap_or_else(|| {
        pool_target_major_outer(
            patch_update.clone(),
            shape.patch,
            shape.coarse_stride.max(1),
        )
    });
    // Decay-and-add in target-major layout, then convert back to rho layout.
    let next_patch_rho = rho_from_target_major(
        target_major_decay_add(
            rho_to_target_major(input.patch_rho),
            patch_update.clone(),
            input.decay.clone(),
        ),
        shape.patch,
    );
    // The coarse bank absorbs its own update plus the pooled patch update.
    let next_coarse_rho = rho_from_target_major(
        target_major_decay_add(
            rho_to_target_major(input.coarse_rho),
            coarse_update.clone().add(pooled_patch_update),
            input.decay.clone(),
        ),
        shape.coarse,
    );
    // Hub absorbs both updates summed over dim 1 (the token axis of the
    // target-major layout).
    let next_hub_rho = update_hub_from_deltas(
        input.hub_rho,
        patch_update.sum_dims_squeeze::<3, usize>(&[1]),
        coarse_update.sum_dims_squeeze::<3, usize>(&[1]),
        shape.hub_count.max(1),
        input.decay,
    );
    let output = StructuredPyramidRhoStepOutput {
        patch_local_context,
        coarse_local_context,
        patch_from_coarse_context,
        patch_from_hub_context,
        coarse_from_hub_context,
        next_patch_rho,
        next_coarse_rho,
        next_hub_rho,
    };
    if let Some(start) = total_start {
        profile_record(&STRUCTURED_PYRAMID_PROFILE, |state| {
            state.calls = state.calls.saturating_add(1);
            state.total_ns = state.total_ns.saturating_add(start.elapsed().as_nanos());
            // Fixed accounting constants for the reference path's temporaries
            // and per-call bookkeeping.
            state.transient_allocations = state.transient_allocations.saturating_add(6);
            state.resident_rollout_steps = state.resident_rollout_steps.saturating_add(1);
            state.metadata_reuse_hits = state.metadata_reuse_hits.saturating_add(1);
        });
    }
    output
}
/// One-shot fused step: compiles a fresh plan from the input shapes on the
/// input's device, then delegates to
/// `try_fused_structured_pyramid_rho_step_wgpu_with_plan`.
///
/// Because the plan is rebuilt per call, on success the profile counters are
/// adjusted: two reuse hits and their byte estimate are reclassified from
/// "reuse" to "upload".
/// NOTE(review): the fixed `2` hits / `2 * 11 * size_of::<f32>()` bytes look
/// like an estimate for the two routing tensors — confirm against the
/// accounting conventions in `crate::profiling`.
pub fn try_fused_structured_pyramid_rho_step_wgpu<B: Backend>(
    shape: StructuredPyramidShape,
    input: StructuredPyramidRhoStepInput<B>,
) -> Option<StructuredPyramidRhoStepOutput<B>>
where
    B::FloatTensorPrimitive: 'static,
{
    let [batch, rank, _, _] = input.patch_query.shape().dims::<4>();
    let value_dim = input.patch_value.shape().dims::<4>()[1];
    let plan = CompiledStructuredPyramidRhoPlan::new(
        batch,
        rank,
        value_dim,
        shape,
        input.neighborhood,
        &input.patch_query.device(),
    );
    let output = try_fused_structured_pyramid_rho_step_wgpu_with_plan(shape, input, &plan);
    if output.is_some() {
        profile_record(&STRUCTURED_PYRAMID_PROFILE, |state| {
            let reused_bytes = (2 * 11 * core::mem::size_of::<f32>()) as u64;
            state.metadata_reuse_hits = state.metadata_reuse_hits.saturating_sub(2);
            state.metadata_reuse_bytes = state.metadata_reuse_bytes.saturating_sub(reused_bytes);
            state.metadata_upload_bytes = state.metadata_upload_bytes.saturating_add(reused_bytes);
        });
    }
    output
}
/// Whether the fused structured-pyramid path is available for backend `B`;
/// delegates to the local-grid backend probe.
pub fn supports_structured_pyramid_rho_backend<B: Backend>() -> bool
where
    B::FloatTensorPrimitive: 'static,
{
    supports_local_grid_rho_backend::<B>()
}
/// Fused (wgpu) implementation of one shared-decay pyramid rho step using a
/// precompiled plan.
///
/// Returns `None` — so the caller can fall back to
/// `reference_structured_pyramid_rho_step` — when the backend lacks fused
/// support, the plan disagrees with `shape` or the input dims, or a fused
/// sub-step itself declines.
pub fn try_fused_structured_pyramid_rho_step_wgpu_with_plan<B: Backend>(
    shape: StructuredPyramidShape,
    input: StructuredPyramidRhoStepInput<B>,
    plan: &CompiledStructuredPyramidRhoPlan<B>,
) -> Option<StructuredPyramidRhoStepOutput<B>>
where
    B::FloatTensorPrimitive: 'static,
{
    if !supports_structured_pyramid_rho_backend::<B>()
        || plan.shape != shape
        || !plan.matches(&input)
    {
        return None;
    }
    let prof_enabled = profile_enabled();
    let total_start = prof_enabled.then(Instant::now);
    // Snapshot the local-grid profile so only this call's deltas are folded
    // into the pyramid profile afterwards.
    let local_before = local_grid_rho_profile_snapshot();
    // Patch bank: reference path for a degenerate 1-token grid, fused kernel
    // otherwise. Yields (local context, next patch rho).
    let patch_local = if shape.patch.token_count() == 1 {
        reference_local_step(
            input.patch_query.clone(),
            input.patch_value.clone(),
            input.patch_rho.clone(),
            shape.patch,
            input.neighborhood,
            input.decay.clone(),
        )
    } else {
        fused_local_step(
            input.patch_query.clone(),
            input.patch_value.clone(),
            input.patch_rho.clone(),
            shape.patch,
            input.decay.clone(),
            plan.patch_plan
                .as_ref()
                .expect("non-degenerate patch bank requires compiled local-grid plan"),
        )?
    };
    // Coarse bank: same split, but its next rho is kept in target-major
    // layout so the pooled patch update can be added before converting back.
    let coarse_local = if shape.coarse.token_count() == 1 {
        let reference = reference_local_step(
            input.coarse_query.clone(),
            input.coarse_value.clone(),
            input.coarse_rho.clone(),
            shape.coarse,
            input.neighborhood,
            input.decay.clone(),
        );
        (reference.0, rho_to_target_major(reference.1))
    } else {
        fused_local_step_target_major(
            input.coarse_query.clone(),
            input.coarse_value.clone(),
            input.coarse_rho.clone(),
            shape.coarse,
            input.decay.clone(),
            plan.coarse_plan
                .as_ref()
                .expect("non-degenerate coarse bank requires compiled local-grid plan"),
        )?
    };
    // Query x value outer products in target-major layout.
    let patch_update =
        spatial_outer_target_major(input.patch_query.clone(), input.patch_value.clone());
    let coarse_update =
        spatial_outer_target_major(input.coarse_query.clone(), input.coarse_value.clone());
    // Pool the patch update to coarse resolution, preferring the plan's
    // precomputed routing tensor; fall back to the generic pooling path.
    let pooled_patch_update = pool_target_major_outer_fast(
        patch_update.clone(),
        shape.patch,
        shape.coarse,
        shape.coarse_stride.max(1),
        plan.patch_to_coarse_pool.clone(),
    )
    .unwrap_or_else(|| {
        pool_target_major_outer(
            patch_update.clone(),
            shape.patch,
            shape.coarse_stride.max(1),
        )
    });
    let next_coarse_rho = rho_from_target_major(
        coarse_local.1.clone().add(pooled_patch_update),
        shape.coarse,
    );
    // Hub absorbs both updates summed over dim 1 (token axis of the
    // target-major layout).
    let next_hub_rho = update_hub_from_deltas(
        input.hub_rho.clone(),
        patch_update.sum_dims_squeeze::<3, usize>(&[1]),
        coarse_update.sum_dims_squeeze::<3, usize>(&[1]),
        shape.hub_count.max(1),
        input.decay.clone(),
    );
    // Hub reads use the pre-update hub state.
    let (patch_from_hub_context, coarse_from_hub_context) = hub_read_pair(
        input.hub_rho.clone(),
        input.patch_query.clone(),
        input.coarse_query.clone(),
        input.patch_hub_weights.clone(),
        input.coarse_hub_weights.clone(),
        shape.patch,
        shape.coarse,
    );
    let output = StructuredPyramidRhoStepOutput {
        patch_local_context: patch_local.0,
        coarse_local_context: coarse_local.0,
        // Cross-scale read also prefers the plan's precomputed route.
        patch_from_coarse_context: cross_scale_read_fast(
            input.coarse_rho.clone(),
            input.patch_query.clone(),
            shape.coarse,
            shape.patch,
            shape.coarse_stride.max(1),
            plan.patch_from_coarse_route.clone(),
        )
        .unwrap_or_else(|| {
            cross_scale_read(
                input.coarse_rho.clone(),
                input.patch_query.clone(),
                shape.coarse,
                shape.patch,
                shape.coarse_stride.max(1),
            )
        }),
        patch_from_hub_context,
        coarse_from_hub_context,
        next_patch_rho: patch_local.1,
        next_coarse_rho,
        next_hub_rho,
    };
    if let Some(start) = total_start {
        let local_after = local_grid_rho_profile_snapshot();
        // Fold this call's local-grid deltas plus fixed per-call bookkeeping
        // into the pyramid profile.
        profile_record(&STRUCTURED_PYRAMID_PROFILE, |state| {
            state.calls = state.calls.saturating_add(1);
            state.total_ns = state.total_ns.saturating_add(start.elapsed().as_nanos());
            state.launches = state
                .launches
                .saturating_add(local_after.launches.saturating_sub(local_before.launches));
            state.dispatch_ns = state.dispatch_ns.saturating_add(
                local_after
                    .dispatch_ns
                    .saturating_sub(local_before.dispatch_ns),
            );
            // Fixed +4 accounts for this wrapper's own temporaries on top of
            // the kernels' counts.
            state.transient_allocations = state.transient_allocations.saturating_add(
                local_after
                    .transient_allocations
                    .saturating_sub(local_before.transient_allocations)
                    .saturating_add(4),
            );
            state.metadata_upload_bytes = state.metadata_upload_bytes.saturating_add(
                local_after
                    .metadata_upload_bytes
                    .saturating_sub(local_before.metadata_upload_bytes),
            );
            state.metadata_reuse_hits = state.metadata_reuse_hits.saturating_add(
                local_after
                    .metadata_reuse_hits
                    .saturating_sub(local_before.metadata_reuse_hits),
            );
            state.metadata_reuse_bytes = state.metadata_reuse_bytes.saturating_add(
                local_after
                    .metadata_reuse_bytes
                    .saturating_sub(local_before.metadata_reuse_bytes),
            );
            state.resident_rollout_steps = state.resident_rollout_steps.saturating_add(1);
        });
    }
    Some(output)
}
/// Fused (wgpu) split-decay pyramid step: each bank has its own query
/// projections and decay, and `input.bank_mode` gates every read/write path.
///
/// Disabled reads are replaced by zero contexts; disabled writes substitute
/// zero values/updates so the bank state still advances under its decay.
/// Returns `None` when the backend lacks fused support, the plan disagrees
/// with `shape` or the input dims, or a fused sub-step declines.
pub fn try_fused_structured_pyramid_split_step_wgpu_with_plan<B: Backend>(
    shape: StructuredPyramidShape,
    input: StructuredPyramidSplitRhoStepInput<B>,
    plan: &CompiledStructuredPyramidSplitPlan<B>,
) -> Option<StructuredPyramidRhoStepOutput<B>>
where
    B::FloatTensorPrimitive: 'static,
{
    if !supports_structured_pyramid_rho_backend::<B>()
        || plan.shape != shape
        || !plan.matches_split(&input)
    {
        return None;
    }
    let prof_enabled = profile_enabled();
    let total_start = prof_enabled.then(Instant::now);
    // Snapshot the local-grid profile so only this call's deltas are folded
    // into the pyramid profile afterwards.
    let local_before = local_grid_rho_profile_snapshot();
    let bank_mode = input.bank_mode;
    // Lazily constructed zero contexts with each bank's spatial layout.
    let patch_zero = || zeros_like_spatial(&input.patch_value);
    let coarse_zero = || zeros_like_spatial(&input.coarse_value);
    // With local writes disabled, feed zeros so the local step applies decay
    // without adding new content.
    let patch_local_value = if bank_mode.patch_local_write {
        input.patch_value.clone()
    } else {
        zeros_like_spatial(&input.patch_value)
    };
    let coarse_local_value = if bank_mode.coarse_local_write {
        input.coarse_value.clone()
    } else {
        zeros_like_spatial(&input.coarse_value)
    };
    // Patch bank: reference path for a 1-token grid, fused kernel otherwise.
    let patch_local = if shape.patch.token_count() == 1 {
        reference_local_step(
            input.patch_local_query.clone(),
            patch_local_value,
            input.patch_rho.clone(),
            shape.patch,
            plan.patch_neighborhood,
            input.patch_decay.clone(),
        )
    } else {
        fused_local_step(
            input.patch_local_query.clone(),
            patch_local_value,
            input.patch_rho.clone(),
            shape.patch,
            input.patch_decay.clone(),
            plan.patch_plan.as_ref().expect("patch plan"),
        )?
    };
    let patch_local_context = if bank_mode.patch_local_read {
        patch_local.0
    } else {
        patch_zero()
    };
    // Coarse bank: next rho kept target-major so the patch->coarse update
    // can be added before converting back.
    let coarse_local = if shape.coarse.token_count() == 1 {
        let (context, next_rho) = reference_local_step(
            input.coarse_local_query.clone(),
            coarse_local_value,
            input.coarse_rho.clone(),
            shape.coarse,
            plan.coarse_neighborhood,
            input.coarse_decay.clone(),
        );
        (context, rho_to_target_major(next_rho))
    } else {
        fused_local_step_target_major(
            input.coarse_local_query.clone(),
            coarse_local_value,
            input.coarse_rho.clone(),
            shape.coarse,
            input.coarse_decay.clone(),
            plan.coarse_plan.as_ref().expect("coarse plan"),
        )?
    };
    let coarse_local_context = if bank_mode.coarse_local_read {
        coarse_local.0
    } else {
        coarse_zero()
    };
    // Cross-scale read (patch queries against the coarse bank), preferring
    // the plan's precomputed route.
    let patch_from_coarse_context = if bank_mode.patch_from_coarse_read {
        cross_scale_read_fast(
            input.coarse_rho.clone(),
            input.patch_query_for_coarse.clone(),
            shape.coarse,
            shape.patch,
            shape.coarse_stride.max(1),
            plan.patch_from_coarse_route.clone(),
        )
        .unwrap_or_else(|| {
            cross_scale_read(
                input.coarse_rho.clone(),
                input.patch_query_for_coarse.clone(),
                shape.coarse,
                shape.patch,
                shape.coarse_stride.max(1),
            )
        })
    } else {
        patch_zero()
    };
    // Hub reads: use the paired helper when both sides read; otherwise read
    // each side individually (or return zeros when disabled).
    let (patch_from_hub_context, coarse_from_hub_context) =
        if bank_mode.patch_from_hub_read && bank_mode.coarse_from_hub_read {
            hub_read_pair(
                input.hub_rho.clone(),
                input.patch_query_for_global.clone(),
                input.coarse_query_for_global.clone(),
                input.patch_hub_weights.clone(),
                input.coarse_hub_weights.clone(),
                shape.patch,
                shape.coarse,
            )
        } else {
            let patch_from_hub_context = if bank_mode.patch_from_hub_read {
                hub_read(
                    input.hub_rho.clone(),
                    input.patch_query_for_global.clone(),
                    input.patch_hub_weights.clone(),
                )
            } else {
                patch_zero()
            };
            let coarse_from_hub_context = if bank_mode.coarse_from_hub_read {
                hub_read(
                    input.hub_rho.clone(),
                    input.coarse_query_for_global.clone(),
                    input.coarse_hub_weights.clone(),
                )
            } else {
                coarse_zero()
            };
            (patch_from_hub_context, coarse_from_hub_context)
        };
    // Patch->coarse write: pooled outer-product update, or an explicit
    // zero target-major rho when the path is disabled.
    let patch_to_coarse_update = if bank_mode.patch_to_coarse_write {
        let patch_update = spatial_outer_target_major(
            input.patch_query_for_coarse.clone(),
            input.patch_value.clone(),
        );
        pool_target_major_outer_fast(
            patch_update.clone(),
            shape.patch,
            shape.coarse,
            shape.coarse_stride.max(1),
            plan.patch_to_coarse_pool.clone(),
        )
        .unwrap_or_else(|| {
            pool_target_major_outer(patch_update, shape.patch, shape.coarse_stride.max(1))
        })
    } else {
        zeros_like_target_major_rho(
            input.coarse_rho.shape().dims::<5>()[0],
            shape.coarse.token_count(),
            plan.coarse_rank,
            plan.value_dim,
            &input.coarse_rho.device(),
        )
    };
    let next_coarse_rho =
        rho_from_target_major(coarse_local.1.add(patch_to_coarse_update), shape.coarse);
    // Hub writes: each bank contributes its global-query outer product only
    // when the corresponding write path is enabled.
    let patch_to_global_update = bank_mode
        .patch_to_global_write
        .then(|| spatial_outer_spatial(input.patch_query_for_global, input.patch_value));
    let coarse_to_global_update = bank_mode
        .coarse_to_global_write
        .then(|| spatial_outer_spatial(input.coarse_query_for_global, input.coarse_value));
    let next_hub_rho = update_hub_from_spatial_updates(
        input.hub_rho,
        patch_to_global_update,
        coarse_to_global_update,
        input.patch_hub_weights,
        input.coarse_hub_weights,
        shape.hub_count.max(1),
        input.global_decay,
    );
    let output = StructuredPyramidRhoStepOutput {
        patch_local_context,
        coarse_local_context,
        patch_from_coarse_context,
        patch_from_hub_context,
        coarse_from_hub_context,
        next_patch_rho: patch_local.1,
        next_coarse_rho,
        next_hub_rho,
    };
    if let Some(start) = total_start {
        let local_after = local_grid_rho_profile_snapshot();
        // Fold this call's local-grid deltas plus fixed per-call bookkeeping
        // into the pyramid profile.
        profile_record(&STRUCTURED_PYRAMID_PROFILE, |state| {
            state.calls = state.calls.saturating_add(1);
            state.total_ns = state.total_ns.saturating_add(start.elapsed().as_nanos());
            state.launches = state
                .launches
                .saturating_add(local_after.launches.saturating_sub(local_before.launches));
            state.dispatch_ns = state.dispatch_ns.saturating_add(
                local_after
                    .dispatch_ns
                    .saturating_sub(local_before.dispatch_ns),
            );
            // Fixed +4 accounts for this wrapper's own temporaries on top of
            // the kernels' counts.
            state.transient_allocations = state.transient_allocations.saturating_add(
                local_after
                    .transient_allocations
                    .saturating_sub(local_before.transient_allocations)
                    .saturating_add(4),
            );
            state.metadata_upload_bytes = state.metadata_upload_bytes.saturating_add(
                local_after
                    .metadata_upload_bytes
                    .saturating_sub(local_before.metadata_upload_bytes),
            );
            state.metadata_reuse_hits = state.metadata_reuse_hits.saturating_add(
                local_after
                    .metadata_reuse_hits
                    .saturating_sub(local_before.metadata_reuse_hits),
            );
            state.metadata_reuse_bytes = state.metadata_reuse_bytes.saturating_add(
                local_after
                    .metadata_reuse_bytes
                    .saturating_sub(local_before.metadata_reuse_bytes),
            );
            state.resident_rollout_steps = state.resident_rollout_steps.saturating_add(1);
        });
    }
    Some(output)
}
/// Coarse-only pyramid step that carries the patch bank through untouched.
///
/// Validates the full coarse-only input, delegates the coarse/hub work to
/// `try_fused_structured_pyramid_coarse_only_no_patch_step_wgpu_with_plan`,
/// then widens the result to a full step output: patch contexts come back as
/// zeros in the patch bank's spatial layout and `next_patch_rho` is the
/// unchanged input state.
pub fn try_fused_structured_pyramid_coarse_only_step_wgpu_with_plan<B: Backend>(
    shape: StructuredPyramidShape,
    input: StructuredPyramidCoarseOnlyStepInput<B>,
    plan: &CompiledStructuredPyramidSplitPlan<B>,
) -> Option<StructuredPyramidRhoStepOutput<B>>
where
    B::FloatTensorPrimitive: 'static,
{
    // Bail out unless the backend, plan shape, and input dims all line up.
    if !supports_structured_pyramid_rho_backend::<B>()
        || plan.shape != shape
        || !plan.matches_coarse_only(&input)
    {
        return None;
    }
    // The patch bank is inert in this mode: read off its layout for the zero
    // contexts and keep a copy of its rho to pass through.
    let [batch, _, value_dim, height, width] = input.patch_rho.shape().dims::<5>();
    let device = input.patch_rho.device();
    let carried_patch_rho = input.patch_rho.clone();
    let zero_patch = || Tensor::<B, 4>::zeros([batch, value_dim, height, width], &device);
    // Delegate the coarse/hub work to the no-patch variant.
    let coarse_only = try_fused_structured_pyramid_coarse_only_no_patch_step_wgpu_with_plan(
        shape,
        StructuredPyramidCoarseOnlyNoPatchStepInput {
            coarse_local_query: input.coarse_local_query,
            coarse_query_for_global: input.coarse_query_for_global,
            coarse_value: input.coarse_value,
            coarse_rho: input.coarse_rho,
            hub_rho: input.hub_rho,
            coarse_hub_weights: input.coarse_hub_weights,
            coarse_decay: input.coarse_decay,
            global_decay: input.global_decay,
            bank_mode: input.bank_mode,
        },
        plan,
    )?;
    Some(StructuredPyramidRhoStepOutput {
        patch_local_context: zero_patch(),
        coarse_local_context: coarse_only.coarse_local_context,
        patch_from_coarse_context: zero_patch(),
        patch_from_hub_context: zero_patch(),
        coarse_from_hub_context: coarse_only.coarse_from_hub_context,
        next_patch_rho: carried_patch_rho,
        next_coarse_rho: coarse_only.next_coarse_rho,
        next_hub_rho: coarse_only.next_hub_rho,
    })
}
/// Fused coarse-only pyramid step that never touches the patch bank.
///
/// Runs the coarse local read/update (fused kernel when the grid has more than
/// one token, reference math otherwise), the coarse-from-hub read, and the hub
/// bank decay/update, honoring the read/write gates in `input.bank_mode`.
/// Returns `None` when the backend, plan shape, or input layout is not
/// eligible for the fused path.
pub fn try_fused_structured_pyramid_coarse_only_no_patch_step_wgpu_with_plan<B: Backend>(
    shape: StructuredPyramidShape,
    input: StructuredPyramidCoarseOnlyNoPatchStepInput<B>,
    plan: &CompiledStructuredPyramidSplitPlan<B>,
) -> Option<StructuredPyramidCoarseOnlyStepOutput<B>>
where
    B::FloatTensorPrimitive: 'static,
{
    if !supports_structured_pyramid_rho_backend::<B>()
        || plan.shape != shape
        || !plan.matches_coarse_only_no_patch(&input)
    {
        return None;
    }
    let prof_enabled = profile_enabled();
    // Wall-clock timing only starts when profiling is on.
    let total_start = prof_enabled.then(Instant::now);
    // NOTE(review): this before-snapshot of the local-grid counters is taken
    // even when profiling is disabled — presumably cheap; confirm.
    let local_before = local_grid_rho_profile_snapshot();
    let bank_mode = input.bank_mode;
    let coarse_zero = || zeros_like_spatial(&input.coarse_value);
    // When the coarse local write gate is closed, feed a zero value tensor so
    // the rho decay still runs but no new mass is written into the bank.
    let coarse_local_value = if bank_mode.coarse_local_write {
        input.coarse_value.clone()
    } else {
        zeros_like_spatial(&input.coarse_value)
    };
    // Single-token coarse grids take the reference path. NOTE(review):
    // presumably the fused kernel needs a non-degenerate grid — confirm.
    let coarse_local = if shape.coarse.token_count() == 1 {
        let (context, next_rho) = reference_local_step(
            input.coarse_local_query.clone(),
            coarse_local_value,
            input.coarse_rho.clone(),
            shape.coarse,
            plan.coarse_neighborhood,
            input.coarse_decay.clone(),
        );
        // Normalize to the target-major layout the fused path returns.
        (context, rho_to_target_major(next_rho))
    } else {
        fused_local_step_target_major(
            input.coarse_local_query.clone(),
            coarse_local_value,
            input.coarse_rho.clone(),
            shape.coarse,
            input.coarse_decay.clone(),
            plan.coarse_plan.as_ref().expect("coarse plan"),
        )?
    };
    // Closed read gates replace the corresponding context with zeros.
    let coarse_local_context = if bank_mode.coarse_local_read {
        coarse_local.0
    } else {
        coarse_zero()
    };
    let coarse_from_hub_context = if bank_mode.coarse_from_hub_read {
        hub_read(
            input.hub_rho.clone(),
            input.coarse_query_for_global.clone(),
            input.coarse_hub_weights.clone(),
        )
    } else {
        coarse_zero()
    };
    // Outer-product update from coarse tokens into the hub bank, only when
    // the coarse-to-global write gate is open.
    let coarse_to_global_update = bank_mode
        .coarse_to_global_write
        .then(|| spatial_outer_spatial(input.coarse_query_for_global, input.coarse_value));
    let next_hub_rho = update_hub_from_spatial_updates(
        input.hub_rho,
        None,
        coarse_to_global_update,
        None,
        input.coarse_hub_weights,
        shape.hub_count.max(1),
        input.global_decay,
    );
    let output = StructuredPyramidCoarseOnlyStepOutput {
        coarse_local_context,
        coarse_from_hub_context,
        next_coarse_rho: rho_from_target_major(coarse_local.1, shape.coarse),
        next_hub_rho,
    };
    if let Some(start) = total_start {
        // Attribute the local-grid kernel counter deltas accumulated during
        // this call to the structured-pyramid profile site.
        let local_after = local_grid_rho_profile_snapshot();
        profile_record(&STRUCTURED_PYRAMID_PROFILE, |state| {
            state.calls = state.calls.saturating_add(1);
            state.total_ns = state.total_ns.saturating_add(start.elapsed().as_nanos());
            state.launches = state
                .launches
                .saturating_add(local_after.launches.saturating_sub(local_before.launches));
            state.dispatch_ns = state.dispatch_ns.saturating_add(
                local_after
                    .dispatch_ns
                    .saturating_sub(local_before.dispatch_ns),
            );
            // NOTE(review): the extra +2 presumably accounts for transient
            // buffers allocated by this wrapper itself — confirm what it covers.
            state.transient_allocations = state.transient_allocations.saturating_add(
                local_after
                    .transient_allocations
                    .saturating_sub(local_before.transient_allocations)
                    .saturating_add(2),
            );
            state.metadata_upload_bytes = state.metadata_upload_bytes.saturating_add(
                local_after
                    .metadata_upload_bytes
                    .saturating_sub(local_before.metadata_upload_bytes),
            );
            state.metadata_reuse_hits = state.metadata_reuse_hits.saturating_add(
                local_after
                    .metadata_reuse_hits
                    .saturating_sub(local_before.metadata_reuse_hits),
            );
            state.metadata_reuse_bytes = state.metadata_reuse_bytes.saturating_add(
                local_after
                    .metadata_reuse_bytes
                    .saturating_sub(local_before.metadata_reuse_bytes),
            );
            state.resident_rollout_steps = state.resident_rollout_steps.saturating_add(1);
        });
    }
    Some(output)
}
/// Flatten a `[batch, channels, height, width]` map into token-major
/// `[batch, height * width, channels]` layout.
fn tokens_to_target_major<B: Backend>(input: Tensor<B, 4>) -> Tensor<B, 3> {
    let [batch, channels, height, width] = input.shape().dims::<4>();
    // NCHW -> NHWC, then collapse the two spatial axes into one token axis.
    let nhwc = input.swap_dims(1, 3).swap_dims(1, 2);
    nhwc.reshape([batch, height * width, channels])
}
/// Zero tensor with the same shape and device as `input`.
fn zeros_like_spatial<B: Backend>(input: &Tensor<B, 4>) -> Tensor<B, 4> {
    let dims = input.shape().dims::<4>();
    Tensor::<B, 4>::zeros(dims, &input.device())
}
/// Zero rho bank in target-major layout `[batch, tokens, rank, value_dim]`.
/// Non-batch extents are clamped to at least one so the tensor is never empty.
fn zeros_like_target_major_rho<B: Backend>(
    batch: usize,
    tokens: usize,
    rank: usize,
    value_dim: usize,
    device: &B::Device,
) -> Tensor<B, 4> {
    let shape = [batch, tokens.max(1), rank.max(1), value_dim.max(1)];
    Tensor::<B, 4>::zeros(shape, device)
}
/// Inverse of `tokens_to_target_major`: expand the token axis back into the
/// spatial grid and restore channel-first layout.
fn tokens_from_target_major<B: Backend>(
    input: Tensor<B, 3>,
    shape: LocalGridShape2d,
) -> Tensor<B, 4> {
    let [batch, tokens, channels] = input.shape().dims::<3>();
    assert_eq!(tokens, shape.token_count());
    // NHWC view of the tokens, then back to NCHW.
    let nhwc = input.reshape([batch, shape.height, shape.width, channels]);
    nhwc.swap_dims(1, 3).swap_dims(2, 3)
}
/// Flatten a spatial rho bank `[b, rank, value, h, w]` into target-major
/// `[b, h * w, rank, value]` layout.
fn rho_to_target_major<B: Backend>(input: Tensor<B, 5>) -> Tensor<B, 4> {
    let [batch, rank, value_dim, height, width] = input.shape().dims::<5>();
    // [b, r, v, h, w] -> [b, h, w, r, v], then fold (h, w) into tokens.
    let permuted = input.swap_dims(1, 3).swap_dims(2, 4);
    permuted.reshape([batch, height * width, rank, value_dim])
}
/// Inverse of `rho_to_target_major`: lay the token axis back out on the grid.
fn rho_from_target_major<B: Backend>(input: Tensor<B, 4>, shape: LocalGridShape2d) -> Tensor<B, 5> {
    let [batch, tokens, rank, value_dim] = input.shape().dims::<4>();
    assert_eq!(tokens, shape.token_count());
    let expanded = input.reshape([batch, shape.height, shape.width, rank, value_dim]);
    expanded.swap_dims(2, 4).swap_dims(1, 3)
}
/// Per-token contraction of `query` against `rho`, performed in token-major
/// layout and mapped back onto the spatial grid.
fn contract<B: Backend>(
    rho: Tensor<B, 5>,
    query: Tensor<B, 4>,
    shape: LocalGridShape2d,
) -> Tensor<B, 4> {
    let query_tm = tokens_to_target_major(query);
    let rho_tm = rho_to_target_major(rho);
    tokens_from_target_major(target_major_identity_read(query_tm, rho_tm), shape)
}
/// Repack a spatial map into the local-grid kernel's query layout
/// `[batch, channels, tokens, 1]`.
fn spatial_tokens_to_local_grid<B: Backend>(input: Tensor<B, 4>) -> Tensor<B, 4> {
    let [batch, channels, height, width] = input.shape().dims::<4>();
    let channel_major = tokens_to_target_major(input).swap_dims(1, 2);
    channel_major.reshape([batch, channels, height * width, 1])
}
/// Repack a spatial map into the local-grid kernel's value layout
/// `[batch, 1, tokens, channels]`.
fn spatial_values_to_local_grid<B: Backend>(input: Tensor<B, 4>) -> Tensor<B, 4> {
    let [batch, channels, height, width] = input.shape().dims::<4>();
    let token_major = tokens_to_target_major(input);
    token_major.reshape([batch, 1, height * width, channels])
}
/// Repack a spatial rho bank into the local-grid kernel's rho layout
/// `[batch, rank, tokens, 1, value_dim]`.
fn spatial_rho_to_local_grid<B: Backend>(input: Tensor<B, 5>) -> Tensor<B, 5> {
    let [batch, rank, value_dim, height, width] = input.shape().dims::<5>();
    let rank_major = rho_to_target_major(input).swap_dims(1, 2);
    rank_major.reshape([batch, rank, height * width, 1, value_dim])
}
/// Collapse the kernel's dim-1 axis of a context tensor and lay the tokens
/// back out on the spatial grid.
fn local_grid_context_to_spatial<B: Backend>(
    input: Tensor<B, 4>,
    shape: LocalGridShape2d,
) -> Tensor<B, 4> {
    let [batch, _, tokens, value_dim] = input.shape().dims::<4>();
    let summed = input.sum_dim(1);
    tokens_from_target_major(summed.reshape([batch, tokens, value_dim]), shape)
}
/// Kernel rho layout -> token-major -> spatial rho layout.
fn local_grid_rho_to_spatial<B: Backend>(
    input: Tensor<B, 5>,
    shape: LocalGridShape2d,
) -> Tensor<B, 5> {
    let target_major = local_grid_rho_to_target_major(input);
    rho_from_target_major(target_major, shape)
}
/// Drop the kernel rho layout's singleton axis (dim 3) and move tokens in
/// front of rank, yielding `[batch, tokens, rank, value_dim]`.
fn local_grid_rho_to_target_major<B: Backend>(input: Tensor<B, 5>) -> Tensor<B, 4> {
    let [batch, rank, tokens, _, value_dim] = input.shape().dims::<5>();
    let squeezed = input.reshape([batch, rank, tokens, value_dim]);
    squeezed.swap_dims(1, 2)
}
/// Run one fused local-grid step: repack the spatial tensors into the kernel's
/// layouts, invoke it, and map both outputs back onto the spatial grid.
/// Returns `None` when the fused kernel declines the call.
fn fused_local_step<B: Backend>(
    query: Tensor<B, 4>,
    value: Tensor<B, 4>,
    rho: Tensor<B, 5>,
    shape: LocalGridShape2d,
    decay: Tensor<B, 1>,
    plan: &CompiledLocalGridRhoPlan<B>,
) -> Option<(Tensor<B, 4>, Tensor<B, 5>)>
where
    B::FloatTensorPrimitive: 'static,
{
    let kernel_query = spatial_tokens_to_local_grid(query);
    let kernel_value = spatial_values_to_local_grid(value);
    let kernel_rho = spatial_rho_to_local_grid(rho);
    let output = try_fused_local_grid_rho_attention_wgpu_head_decay_with_plan(
        &kernel_query,
        &kernel_value,
        Some(&kernel_rho),
        &decay,
        plan,
    )?;
    let context = local_grid_context_to_spatial(output.context, shape);
    let next_rho = local_grid_rho_to_spatial(output.rho, shape);
    Some((context, next_rho))
}
/// Like `fused_local_step`, but hands the next rho back in target-major
/// layout `[batch, tokens, rank, value_dim]` instead of the spatial one.
fn fused_local_step_target_major<B: Backend>(
    query: Tensor<B, 4>,
    value: Tensor<B, 4>,
    rho: Tensor<B, 5>,
    shape: LocalGridShape2d,
    decay: Tensor<B, 1>,
    plan: &CompiledLocalGridRhoPlan<B>,
) -> Option<(Tensor<B, 4>, Tensor<B, 4>)>
where
    B::FloatTensorPrimitive: 'static,
{
    let kernel_query = spatial_tokens_to_local_grid(query);
    let kernel_value = spatial_values_to_local_grid(value);
    let kernel_rho = spatial_rho_to_local_grid(rho);
    let output = try_fused_local_grid_rho_attention_wgpu_head_decay_with_plan(
        &kernel_query,
        &kernel_value,
        Some(&kernel_rho),
        &decay,
        plan,
    )?;
    let context = local_grid_context_to_spatial(output.context, shape);
    let next_rho = local_grid_rho_to_target_major(output.rho);
    Some((context, next_rho))
}
/// Pure-tensor fallback for the fused local step: read the neighborhood from
/// the current rho first, then decay rho and add the new outer-product update.
fn reference_local_step<B: Backend>(
    query: Tensor<B, 4>,
    value: Tensor<B, 4>,
    rho: Tensor<B, 5>,
    shape: LocalGridShape2d,
    neighborhood: LocalGridNeighborhood,
    decay: Tensor<B, 1>,
) -> (Tensor<B, 4>, Tensor<B, 5>) {
    let context = local_read(rho.clone(), query.clone(), shape, neighborhood);
    let update = spatial_outer_target_major(query, value);
    let decayed = target_major_decay_add(rho_to_target_major(rho), update, decay);
    (context, rho_from_target_major(decayed, shape))
}
/// Per-token outer product of query and value, in target-major layout.
fn spatial_outer_target_major<B: Backend>(
    query: Tensor<B, 4>,
    value: Tensor<B, 4>,
) -> Tensor<B, 4> {
    let query_tm = tokens_to_target_major(query);
    let value_tm = tokens_to_target_major(value);
    target_major_outer_product(query_tm, value_tm)
}
/// Per-token outer product laid out on the query's spatial grid.
fn spatial_outer_spatial<B: Backend>(query: Tensor<B, 4>, value: Tensor<B, 4>) -> Tensor<B, 5> {
    let [_, _, height, width] = query.shape().dims::<4>();
    let grid = LocalGridShape2d::new(height, width);
    rho_from_target_major(spatial_outer_target_major(query, value), grid)
}
/// Shift a `[batch, channels, height, width]` tensor by `(dy, dx)` grid cells,
/// zero-filling the vacated cells.
///
/// Positive `dy` moves content toward larger row indices, positive `dx` toward
/// larger column indices; content pushed past the edge is discarded.
fn shift_spatial<B: Backend>(input: Tensor<B, 4>, dy: isize, dx: isize) -> Tensor<B, 4> {
    // The row (dim 2) and column (dim 3) shifts are independent, so apply the
    // same single-axis helper twice instead of duplicating the logic inline.
    let shifted_rows = shift_spatial_dim(input, 2, dy);
    shift_spatial_dim(shifted_rows, 3, dx)
}

/// Shift `input` by `delta` along `dim` (2 = rows, 3 = columns), zero-padding
/// the vacated slots. A shift of at least the axis extent yields all zeros.
fn shift_spatial_dim<B: Backend>(input: Tensor<B, 4>, dim: usize, delta: isize) -> Tensor<B, 4> {
    if delta == 0 {
        return input;
    }
    let dims = input.shape().dims::<4>();
    let extent = dims[dim];
    let device = input.device();
    let shift = delta.unsigned_abs();
    if shift >= extent {
        // Everything falls off the edge.
        return Tensor::<B, 4>::zeros(dims, &device);
    }
    // Zero pad whose extent along `dim` equals the shift amount.
    let mut pad_dims = dims;
    pad_dims[dim] = shift;
    let pad = Tensor::<B, 4>::zeros(pad_dims, &device);
    if delta > 0 {
        // Toward larger indices: pad in front, drop the tail.
        Tensor::cat(vec![pad, input.slice_dim(dim, 0..extent - shift)], dim)
    } else {
        // Toward smaller indices: drop the head, pad at the back.
        Tensor::cat(vec![input.slice_dim(dim, shift..extent), pad], dim)
    }
}
/// Reference neighborhood read: each token contracts its query against the rho
/// entries of its (optionally diagonal, optionally self-including) neighbors
/// within `neighborhood.radius`.
fn local_read<B: Backend>(
    rho: Tensor<B, 5>,
    query: Tensor<B, 4>,
    shape: LocalGridShape2d,
    neighborhood: LocalGridNeighborhood,
) -> Tensor<B, 4> {
    let [batch, _, value_dim, _, _] = rho.shape().dims::<5>();
    // Accumulator over the full grid (extents clamped so it is never empty).
    let mut context = Tensor::<B, 4>::zeros(
        [batch, value_dim, shape.height.max(1), shape.width.max(1)],
        &rho.device(),
    );
    if neighborhood.self_edges {
        // Each token also reads its own rho entry.
        context = context.add(contract(rho.clone(), query.clone(), shape));
    }
    let radius = neighborhood.radius as isize;
    for dy in -radius..=radius {
        for dx in -radius..=radius {
            let is_center = dy == 0 && dx == 0;
            let is_diagonal = dy != 0 && dx != 0;
            if is_center || (is_diagonal && !neighborhood.diagonals) {
                continue;
            }
            // Move the queries onto the neighbor cells, contract there, then
            // shift the resulting message back to the querying token.
            let neighbor_query = shift_spatial(query.clone(), dy, dx);
            let message = contract(rho.clone(), neighbor_query, shape);
            context = context.add(shift_spatial(message, -dy, -dx));
        }
    }
    context
}
/// Read patch-grid context out of the coarse rho bank by upsampling the
/// coarse bank to the patch resolution.
///
/// `repeat_dim` tiles the whole coarse grid along each axis, so a patch cell
/// `(y, x)` reads coarse cell `(y % coarse_h, x % coarse_w)` after upsampling
/// — the same modulo mapping that `patch_from_coarse_route` encodes as a
/// one-hot route.
///
/// NOTE(review): when the upsampled grid is *smaller* than the patch grid,
/// `effective` no longer matches the query's token count and the assertion
/// inside `contract`'s read path will trip; presumably callers guarantee
/// `patch >= coarse * stride` only in the exact-tiling form — confirm.
fn cross_scale_read<B: Backend>(
    coarse_rho: Tensor<B, 5>,
    patch_query: Tensor<B, 4>,
    coarse_shape: LocalGridShape2d,
    patch_shape: LocalGridShape2d,
    coarse_stride: usize,
) -> Tensor<B, 4> {
    // Aligned grids: plain per-token contraction, no upsampling.
    if coarse_stride <= 1 {
        return contract(coarse_rho, patch_query, patch_shape);
    }
    let mut up = coarse_rho
        .repeat_dim(3, coarse_stride)
        .repeat_dim(4, coarse_stride);
    let [_, _, _, up_h, up_w] = up.shape().dims::<5>();
    // Crop any overshoot so the upsampled bank never exceeds the patch grid.
    if up_h != patch_shape.height {
        up = up.slice_dim(3, 0..patch_shape.height.min(up_h));
    }
    if up_w != patch_shape.width {
        up = up.slice_dim(4, 0..patch_shape.width.min(up_w));
    }
    let effective = LocalGridShape2d::new(
        patch_shape.height.min(coarse_shape.height * coarse_stride),
        patch_shape.width.min(coarse_shape.width * coarse_stride),
    );
    contract(up, patch_query, effective)
}
/// Patch-from-coarse read that prefers the routed matmul path and falls back
/// to the tiled broadcast read when routing is unavailable or mismatched.
fn cross_scale_read_fast<B: Backend>(
    coarse_rho: Tensor<B, 5>,
    patch_query: Tensor<B, 4>,
    coarse_shape: LocalGridShape2d,
    patch_shape: LocalGridShape2d,
    coarse_stride: usize,
    route: Option<Tensor<B, 3>>,
) -> Option<Tensor<B, 4>> {
    cross_scale_read_with_route(
        coarse_rho.clone(),
        patch_query.clone(),
        coarse_shape,
        patch_shape,
        coarse_stride,
        route,
    )
    .or_else(|| {
        cross_scale_read_tiled(
            coarse_rho,
            patch_query,
            coarse_shape,
            patch_shape,
            coarse_stride,
        )
    })
}
/// Routed patch-from-coarse read: gather each patch token's coarse rho row via
/// a batched matmul with the one-hot `route`, then contract over rank.
/// Returns `None` when no route is supplied or any shape disagrees.
fn cross_scale_read_with_route<B: Backend>(
    coarse_rho: Tensor<B, 5>,
    patch_query: Tensor<B, 4>,
    coarse_shape: LocalGridShape2d,
    patch_shape: LocalGridShape2d,
    coarse_stride: usize,
    route: Option<Tensor<B, 3>>,
) -> Option<Tensor<B, 4>> {
    // Aligned grids need no routing at all.
    if coarse_stride <= 1 {
        return Some(contract(coarse_rho, patch_query, patch_shape));
    }
    let route = route?;
    let [batch, patch_tokens, coarse_tokens] = route.shape().dims::<3>();
    let [rho_batch, rank, value_dim, coarse_h, coarse_w] = coarse_rho.shape().dims::<5>();
    let [query_batch, query_rank, patch_h, patch_w] = patch_query.shape().dims::<4>();
    // Every operand must agree with the declared grids; otherwise signal the
    // caller to take a different path.
    let consistent = rho_batch == batch
        && query_batch == batch
        && query_rank == rank
        && (coarse_h, coarse_w) == (coarse_shape.height, coarse_shape.width)
        && (patch_h, patch_w) == (patch_shape.height, patch_shape.width)
        && coarse_tokens == coarse_shape.token_count()
        && patch_tokens == patch_shape.token_count();
    if !consistent {
        return None;
    }
    // route: [b, patch_tokens, coarse_tokens] x [b, coarse_tokens, r * v]
    // selects each patch token's coarse rho entry.
    let flat_rho =
        rho_to_target_major(coarse_rho).reshape([batch, coarse_tokens, rank * value_dim]);
    let routed = route
        .matmul(flat_rho)
        .reshape([batch, patch_tokens, rank, value_dim]);
    // Contract the gathered rho rows against the query over the rank axis.
    let query_tm = tokens_to_target_major(patch_query);
    let context = routed
        .mul(query_tm.unsqueeze_dim::<4>(3))
        .sum_dims_squeeze::<3, usize>(&[2]);
    Some(tokens_from_target_major(context, patch_shape))
}
/// Build the one-hot routing matrix `[batch, patch_tokens, coarse_tokens]`:
/// patch token `(py, px)` reads coarse token `(py % coarse_h, px % coarse_w)`
/// — the same modulo tiling the broadcast read path uses.
fn patch_from_coarse_route<B: Backend>(
    batch: usize,
    patch_shape: LocalGridShape2d,
    coarse_shape: LocalGridShape2d,
    device: &B::Device,
) -> Tensor<B, 3> {
    let patch_tokens = patch_shape.token_count();
    let coarse_tokens = coarse_shape.token_count();
    let mut weights = vec![0.0_f32; batch * patch_tokens * coarse_tokens];
    for b in 0..batch {
        let base = b * patch_tokens * coarse_tokens;
        for py in 0..patch_shape.height {
            for px in 0..patch_shape.width {
                let patch_idx = py * patch_shape.width + px;
                let cy = py % coarse_shape.height;
                let cx = px % coarse_shape.width;
                let coarse_idx = cy * coarse_shape.width + cx;
                weights[base + patch_idx * coarse_tokens + coarse_idx] = 1.0;
            }
        }
    }
    Tensor::<B, 3>::from_data(
        TensorData::new(weights, [batch, patch_tokens, coarse_tokens]),
        device,
    )
}
/// Build a one-hot pooling route `[batch, coarse_tokens, patch_tokens]` that
/// sums each `stride x stride` tile of the patch grid into one pooled token.
///
/// NOTE(review): assumes patch extents are multiples of the stride; for
/// non-divisible extents the integer-division tile index can exceed the pooled
/// grid, which would index out of range — confirm callers guarantee divisibility.
fn patch_to_coarse_pool_route<B: Backend>(
    batch: usize,
    patch_shape: LocalGridShape2d,
    coarse_stride: usize,
    device: &B::Device,
) -> Tensor<B, 3> {
    // Clamp the stride once instead of re-clamping inside the inner loop.
    let stride = coarse_stride.max(1);
    let patch_tokens = patch_shape.token_count();
    let pooled_height = (patch_shape.height / stride).max(1);
    let pooled_width = (patch_shape.width / stride).max(1);
    let coarse_tokens = pooled_height * pooled_width;
    let mut data = vec![0.0_f32; batch * coarse_tokens * patch_tokens];
    for b in 0..batch {
        let batch_offset = b * coarse_tokens * patch_tokens;
        for py in 0..patch_shape.height {
            for px in 0..patch_shape.width {
                let patch_idx = py * patch_shape.width + px;
                // Integer division maps each patch cell to its pooled tile.
                let coarse_idx = (py / stride) * pooled_width + (px / stride);
                data[batch_offset + coarse_idx * patch_tokens + patch_idx] = 1.0;
            }
        }
    }
    Tensor::<B, 3>::from_data(
        TensorData::new(data, [batch, coarse_tokens, patch_tokens]),
        device,
    )
}
/// Broadcast patch-from-coarse read that avoids materializing an upsampled
/// rho bank: the coarse rho is broadcast over `stride x stride` query tiles
/// and contracted over rank in one fused multiply/sum.
///
/// Only valid for exact tilings (`patch = coarse * stride` on both axes);
/// any shape disagreement returns `None` so the caller can use another path.
fn cross_scale_read_tiled<B: Backend>(
    coarse_rho: Tensor<B, 5>,
    patch_query: Tensor<B, 4>,
    coarse_shape: LocalGridShape2d,
    patch_shape: LocalGridShape2d,
    coarse_stride: usize,
) -> Option<Tensor<B, 4>> {
    // Aligned grids need no tiling.
    if coarse_stride <= 1 {
        return Some(contract(coarse_rho, patch_query, patch_shape));
    }
    let [batch, rank, value_dim, coarse_h, coarse_w] = coarse_rho.shape().dims::<5>();
    let [query_batch, query_rank, patch_h, patch_w] = patch_query.shape().dims::<4>();
    if query_batch != batch
        || query_rank != rank
        || coarse_h != coarse_shape.height
        || coarse_w != coarse_shape.width
        || patch_h != patch_shape.height
        || patch_w != patch_shape.width
        || patch_h != coarse_h * coarse_stride
        || patch_w != coarse_w * coarse_stride
    {
        return None;
    }
    // rho -> [b, r, v, 1, ch, 1, cw]; query -> [b, r, 1, s, ch, s, cw].
    // NOTE(review): the query reshape splits each patch axis as
    // [stride, coarse] (row-major), i.e. patch row y maps to coarse row
    // y % coarse_h — the same modulo layout as `patch_from_coarse_route`;
    // confirm that is the intended correspondence.
    let coarse_rho = coarse_rho.unsqueeze_dim::<6>(3).unsqueeze_dim::<7>(5);
    let patch_query = patch_query
        .reshape([
            batch,
            rank,
            coarse_stride,
            coarse_h,
            coarse_stride,
            coarse_w,
        ])
        .unsqueeze_dim::<7>(2);
    // Broadcast multiply, then contract the rank axis (dim 1).
    let context = coarse_rho
        .mul(patch_query)
        .sum_dims_squeeze::<6, usize>(&[1]);
    Some(context.reshape([batch, value_dim, patch_h, patch_w]))
}
/// Read context for every spatial token from the hub bank.
///
/// With `weights`, each token takes a weighted mixture over hubs; without
/// them, hubs are averaged (or the single hub is taken as-is).
fn hub_read<B: Backend>(
    hub_rho: Tensor<B, 4>,
    query: Tensor<B, 4>,
    weights: Option<Tensor<B, 4>>,
) -> Tensor<B, 4> {
    let [_, hubs, _, _] = hub_rho.shape().dims::<4>();
    let [_, _, height, width] = query.shape().dims::<4>();
    // [b, 1, tokens, rank] x [b, hubs, rank, value] -> [b, hubs, tokens, value]
    let query_tm = tokens_to_target_major(query).unsqueeze_dim::<4>(1);
    let per_hub = query_tm.matmul(hub_rho);
    let reduced = match weights {
        Some(weights) => {
            // Token-wise hub mixture weights, shaped [b, hubs, tokens, 1].
            let weights = tokens_to_target_major(weights)
                .swap_dims(1, 2)
                .unsqueeze_dim::<4>(3);
            per_hub.mul(weights).sum_dims_squeeze::<3, usize>(&[1])
        }
        None if hubs > 1 => per_hub
            .sum_dims_squeeze::<3, usize>(&[1])
            .div_scalar(hubs as f32),
        None => per_hub.sum_dims_squeeze::<3, usize>(&[1]),
    };
    tokens_from_target_major(reduced, LocalGridShape2d::new(height, width))
}
/// Hub read for the patch and coarse query sets at once: concatenate both
/// token sets along the token axis so a single matmul serves both scales,
/// then split and reduce each scale's slice separately.
fn hub_read_pair<B: Backend>(
    hub_rho: Tensor<B, 4>,
    patch_query: Tensor<B, 4>,
    coarse_query: Tensor<B, 4>,
    patch_weights: Option<Tensor<B, 4>>,
    coarse_weights: Option<Tensor<B, 4>>,
    patch_shape: LocalGridShape2d,
    coarse_shape: LocalGridShape2d,
) -> (Tensor<B, 4>, Tensor<B, 4>) {
    let [_, hubs, _, _] = hub_rho.shape().dims::<4>();
    let patch_tokens = patch_shape.token_count();
    let coarse_tokens = coarse_shape.token_count();
    let stacked = Tensor::cat(
        vec![
            tokens_to_target_major(patch_query),
            tokens_to_target_major(coarse_query),
        ],
        1,
    )
    .unsqueeze_dim::<4>(1);
    let context = stacked.matmul(hub_rho);
    // Patch tokens occupy [0, patch_tokens); coarse tokens follow.
    let patch_context = hub_reduce_slice(
        context.clone(),
        0,
        patch_tokens,
        patch_shape,
        patch_weights,
        hubs,
    );
    let coarse_context = hub_reduce_slice(
        context,
        patch_tokens,
        coarse_tokens,
        coarse_shape,
        coarse_weights,
        hubs,
    );
    (patch_context, coarse_context)
}
/// Carve one scale's tokens out of a concatenated hub context and collapse the
/// hub axis: weighted mixture when `weights` are given, mean over hubs when
/// there are several, plain sum for a single hub.
fn hub_reduce_slice<B: Backend>(
    context: Tensor<B, 4>,
    start: usize,
    len: usize,
    shape: LocalGridShape2d,
    weights: Option<Tensor<B, 4>>,
    hubs: usize,
) -> Tensor<B, 4> {
    let scale_context = context.slice_dim(2, start..start + len);
    let reduced = match weights {
        Some(weights) => {
            let weights = tokens_to_target_major(weights)
                .swap_dims(1, 2)
                .unsqueeze_dim::<4>(3);
            scale_context
                .mul(weights)
                .sum_dims_squeeze::<3, usize>(&[1])
        }
        None if hubs > 1 => scale_context
            .sum_dims_squeeze::<3, usize>(&[1])
            .div_scalar(hubs as f32),
        None => scale_context.sum_dims_squeeze::<3, usize>(&[1]),
    };
    tokens_from_target_major(reduced, shape)
}
/// Sum-pool a token-major outer-product update down by `stride` along both
/// spatial axes; a unit stride is the identity.
fn pool_target_major_outer<B: Backend>(
    update: Tensor<B, 4>,
    patch_shape: LocalGridShape2d,
    stride: usize,
) -> Tensor<B, 4> {
    if stride <= 1 {
        return update;
    }
    let [batch, tokens, rank, value_dim] = update.shape().dims::<4>();
    assert_eq!(tokens, patch_shape.token_count());
    let pooled_height = patch_shape.height / stride;
    let pooled_width = patch_shape.width / stride;
    // Split each spatial axis into (pooled, stride) directly — the previous
    // intermediate [b, h, w, r*v] reshape was a row-major no-op — and sum out
    // the two stride axes.
    let tiled = update.reshape([
        batch,
        pooled_height,
        stride,
        pooled_width,
        stride,
        rank * value_dim,
    ]);
    tiled
        .sum_dims_squeeze::<4, usize>(&[2, 4])
        .reshape([batch, pooled_height * pooled_width, rank, value_dim])
}
/// Route-based variant of `pool_target_major_outer`: one matmul with the
/// one-hot pooling `route` performs all the tile sums. Returns `None` without
/// a route or on any shape mismatch.
fn pool_target_major_outer_fast<B: Backend>(
    update: Tensor<B, 4>,
    patch_shape: LocalGridShape2d,
    coarse_shape: LocalGridShape2d,
    stride: usize,
    route: Option<Tensor<B, 3>>,
) -> Option<Tensor<B, 4>> {
    // Unit stride: nothing to pool.
    if stride <= 1 {
        return Some(update);
    }
    let route = route?;
    let [batch, patch_tokens, rank, value_dim] = update.shape().dims::<4>();
    let [route_batch, coarse_tokens, route_patch_tokens] = route.shape().dims::<3>();
    let consistent = route_batch == batch
        && route_patch_tokens == patch_tokens
        && coarse_tokens == coarse_shape.token_count()
        && patch_tokens == patch_shape.token_count()
        && patch_shape.height == coarse_shape.height * stride
        && patch_shape.width == coarse_shape.width * stride;
    if !consistent {
        return None;
    }
    let flat = update.reshape([batch, patch_tokens, rank * value_dim]);
    let pooled = route.matmul(flat);
    Some(pooled.reshape([batch, coarse_tokens, rank, value_dim]))
}
/// Decay the hub bank and add the combined patch + coarse deltas, splitting
/// the combined mass evenly when several hubs share it.
fn update_hub_from_deltas<B: Backend>(
    hub_rho: Tensor<B, 4>,
    patch_delta: Tensor<B, 3>,
    coarse_delta: Tensor<B, 3>,
    hub_count: usize,
    decay: Tensor<B, 1>,
) -> Tensor<B, 4> {
    // Merge both scale deltas and broadcast over the hub axis.
    let combined = patch_delta.add(coarse_delta).unsqueeze_dim::<4>(1);
    let delta = if hub_count > 1 {
        combined.div_scalar(hub_count as f32)
    } else {
        combined
    };
    target_major_decay_add(hub_rho, delta, decay)
}
/// Decay the hub bank and fold in optional spatial outer-product updates from
/// the patch and coarse scales.
///
/// With a single hub the weights are ignored and each bank contributes its
/// full spatial sum; with several hubs the updates are apportioned by the
/// (possibly uniform) hub weights.
fn update_hub_from_spatial_updates<B: Backend>(
    hub_rho: Tensor<B, 4>,
    patch_update: Option<Tensor<B, 5>>,
    coarse_update: Option<Tensor<B, 5>>,
    patch_weights: Option<Tensor<B, 4>>,
    coarse_weights: Option<Tensor<B, 4>>,
    hub_count: usize,
    decay: Tensor<B, 1>,
) -> Tensor<B, 4> {
    if hub_count <= 1 {
        // Shared fallback: an absent update contributes a zero sum shaped
        // like the hub bank's [batch, rank, value_dim].
        let spatial_sum = |update: Option<Tensor<B, 5>>| {
            update
                .map(|u| u.sum_dims_squeeze::<3, usize>(&[3, 4]))
                .unwrap_or_else(|| {
                    let [batch, _, rank, value_dim] = hub_rho.shape().dims::<4>();
                    Tensor::<B, 3>::zeros([batch, rank, value_dim], &hub_rho.device())
                })
        };
        let delta = spatial_sum(patch_update)
            .add(spatial_sum(coarse_update))
            .unsqueeze_dim::<4>(1);
        return target_major_decay_add(hub_rho, delta, decay);
    }
    let delta = weighted_global_sum_pair(
        patch_update,
        coarse_update,
        patch_weights,
        coarse_weights,
        hub_count,
    )
    .unwrap_or_else(|| {
        // Neither scale supplied an update: the bank only decays.
        let [batch, hubs, rank, value_dim] = hub_rho.shape().dims::<4>();
        Tensor::<B, 4>::zeros([batch, hubs, rank, value_dim], &hub_rho.device())
    });
    target_major_decay_add(hub_rho, delta, decay)
}
/// Reduce a spatial update into per-hub deltas `[batch, hubs, rank, value]`,
/// weighting each token's contribution by the hub weights.
fn weighted_global_sum<B: Backend>(
    update: Tensor<B, 5>,
    weights: Option<Tensor<B, 4>>,
    hub_count: usize,
) -> Tensor<B, 4> {
    let [batch, rank, value_dim, height, width] = update.shape().dims::<5>();
    let hubs = hub_count.max(1);
    // Missing weights fall back to a uniform 1/hubs split across all hubs.
    let weights = weights.unwrap_or_else(|| {
        Tensor::<B, 4>::ones([batch, hubs, height, width], &update.device())
            .div_scalar(hubs as f32)
    });
    let tokens = height * width;
    let flat_update = update.reshape([batch, 1, rank, value_dim, tokens]);
    let flat_weights = weights.reshape([batch, hubs, 1, 1, tokens]);
    flat_update
        .mul(flat_weights)
        .sum_dims_squeeze::<4, usize>(&[4])
}
/// Weighted hub deltas for the patch and coarse updates together.
///
/// When both updates exist and agree on batch/rank/value dims, their token
/// axes are concatenated so one broadcast multiply/sum serves both banks;
/// otherwise each side is reduced separately via `weighted_global_sum`.
/// Returns `None` only when there is nothing to accumulate.
fn weighted_global_sum_pair<B: Backend>(
    patch_update: Option<Tensor<B, 5>>,
    coarse_update: Option<Tensor<B, 5>>,
    patch_weights: Option<Tensor<B, 4>>,
    coarse_weights: Option<Tensor<B, 4>>,
    hub_count: usize,
) -> Option<Tensor<B, 4>> {
    match (patch_update, coarse_update) {
        (None, None) => None,
        (Some(patch), None) => Some(weighted_global_sum(patch, patch_weights, hub_count)),
        (None, Some(coarse)) => Some(weighted_global_sum(coarse, coarse_weights, hub_count)),
        (Some(patch), Some(coarse)) => {
            let [
                batch,
                patch_rank,
                patch_value_dim,
                patch_height,
                patch_width,
            ] = patch.shape().dims::<5>();
            let [
                coarse_batch,
                coarse_rank,
                coarse_value_dim,
                coarse_height,
                coarse_width,
            ] = coarse.shape().dims::<5>();
            // Incompatible batch/rank/value dims: fall back to two separate
            // reductions and add the results.
            if coarse_batch != batch
                || patch_rank != coarse_rank
                || patch_value_dim != coarse_value_dim
            {
                return Some(
                    weighted_global_sum(patch, patch_weights, hub_count).add(weighted_global_sum(
                        coarse,
                        coarse_weights,
                        hub_count,
                    )),
                );
            }
            let patch_tokens = patch_height * patch_width;
            let coarse_tokens = coarse_height * coarse_width;
            // Missing weights default to a uniform 1/hubs split, matching
            // `weighted_global_sum`.
            let patch_weights = patch_weights.unwrap_or_else(|| {
                Tensor::<B, 4>::ones(
                    [batch, hub_count.max(1), patch_height, patch_width],
                    &patch.device(),
                )
                .div_scalar(hub_count.max(1) as f32)
            });
            let coarse_weights = coarse_weights.unwrap_or_else(|| {
                Tensor::<B, 4>::ones(
                    [batch, hub_count.max(1), coarse_height, coarse_width],
                    &coarse.device(),
                )
                .div_scalar(hub_count.max(1) as f32)
            });
            let patch_update = patch.reshape([batch, patch_rank, patch_value_dim, patch_tokens]);
            let coarse_update =
                coarse.reshape([batch, coarse_rank, coarse_value_dim, coarse_tokens]);
            let patch_weights = patch_weights.reshape([batch, hub_count.max(1), patch_tokens]);
            let coarse_weights = coarse_weights.reshape([batch, hub_count.max(1), coarse_tokens]);
            // Concatenate the token axes of updates and weights alike, then
            // reduce both banks in one weighted sum over tokens.
            let update = Tensor::cat(vec![patch_update, coarse_update], 3).unsqueeze_dim::<5>(1);
            let weights = Tensor::cat(vec![patch_weights, coarse_weights], 2).reshape([
                batch,
                hub_count.max(1),
                1,
                1,
                patch_tokens + coarse_tokens,
            ]);
            Some(update.mul(weights).sum_dims_squeeze::<4, usize>(&[4]))
        }
    }
}
/// Per-target contraction: `context[b, t, v] = sum_r rho[b, t, r, v] * query[b, t, r]`.
fn target_major_identity_read<B: Backend>(query: Tensor<B, 3>, rho: Tensor<B, 4>) -> Tensor<B, 3> {
    let [batch, targets, rank] = query.shape().dims::<3>();
    let [rho_batch, rho_targets, rho_rank, value_dim] = rho.shape().dims::<4>();
    assert_eq!(rho_batch, batch);
    assert_eq!(rho_targets, targets);
    assert_eq!(rho_rank, rank);
    let weighted = rho.mul(query.unsqueeze_dim::<4>(3));
    weighted
        .sum_dims_squeeze::<3, usize>(&[2])
        .reshape([batch, targets, value_dim])
}
/// Per-target outer product: `update[b, t, r, v] = query[b, t, r] * value[b, t, v]`.
fn target_major_outer_product<B: Backend>(
    query: Tensor<B, 3>,
    value: Tensor<B, 3>,
) -> Tensor<B, 4> {
    let [batch, targets, rank] = query.shape().dims::<3>();
    let [value_batch, value_targets, value_dim] = value.shape().dims::<3>();
    assert_eq!(value_batch, batch);
    assert_eq!(value_targets, targets);
    let lhs = query.unsqueeze_dim::<4>(3);
    let rhs = value.unsqueeze_dim::<4>(2);
    lhs.mul(rhs).reshape([batch, targets, rank, value_dim])
}
/// Apply a per-rank decay to `rho` and add `update`:
/// `rho' = decay * rho + update`, with the decay broadcast over all other axes.
///
/// A length-1 `decay` is shared across every rank; anything else must match
/// the rank count exactly (panics otherwise).
fn target_major_decay_add<B: Backend>(
    rho: Tensor<B, 4>,
    update: Tensor<B, 4>,
    decay: Tensor<B, 1>,
) -> Tensor<B, 4> {
    let [_, _, rank, _] = rho.shape().dims::<4>();
    let [decay_len] = decay.shape().dims::<1>();
    let per_rank = if decay_len == 1 {
        decay.repeat_dim(0, rank.max(1))
    } else if decay_len == rank {
        decay
    } else {
        panic!("structured pyramid decay length {decay_len} must be 1 or {rank}")
    };
    rho.mul(per_rank.reshape([1, 1, rank, 1])).add(update)
}
// Unit tests for this module live in the sibling `tests.rs` file and are
// compiled only in test builds.
#[cfg(test)]
mod tests;