pub struct AttentionCache { /* private fields */ }
Stores per-layer attention weights from a forward pass.
Each tensor has shape [batch, heads, seq_q, seq_k] — the post-softmax
attention pattern for one layer.
§Example
use candle_mi::AttentionCache;
use candle_core::{Device, Tensor};
let mut cache = AttentionCache::with_capacity(32);
// shape [batch=1, heads=8, seq=10, seq=10]
cache.push(Tensor::zeros((1, 8, 10, 10), candle_core::DType::F32, &Device::Cpu).unwrap());
// Query what position 5 attends to in layer 0
let attn_row = cache.attention_from_position(0, 5).unwrap();

Implementations§
impl AttentionCache
pub fn with_capacity(n_layers: usize) -> Self
Create an empty cache with capacity for n_layers layers.
pub fn attention_from_position(
    &self,
    layer: usize,
    position: usize,
) -> Result<Vec<f32>>
Get attention weights FROM a specific query position, averaged across heads.
Returns a vector of length seq_k representing how much the given query
position attends to each key position, averaged over all attention heads
(batch index 0).
§Shapes
- returns: [seq_k] as Vec<f32>
§Errors
Returns MIError::Hook if the layer is not in the cache or the
position is out of range.
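The head-averaging semantics can be sketched in plain Rust, using nested Vecs in place of candle tensors. This is a hypothetical stand-in for illustration, not the crate's actual implementation; `attn` plays the role of one layer's cached weights, indexed [heads][seq_q][seq_k] for batch index 0:

```rust
// Sketch: attention FROM a query position, averaged across heads.
// `attn` is a stand-in for a cached layer: [heads][seq_q][seq_k], batch 0.
fn attention_from_position(attn: &[Vec<Vec<f32>>], position: usize) -> Option<Vec<f32>> {
    let heads = attn.len();
    let seq_k = attn.first()?.first()?.len();
    if position >= attn[0].len() {
        return None; // query position out of range
    }
    let mut row = vec![0.0f32; seq_k];
    for head in attn {
        for (k, w) in head[position].iter().enumerate() {
            row[k] += w / heads as f32; // running average across heads
        }
    }
    Some(row)
}

fn main() {
    // Two heads over a length-3 sequence: head 0 attends uniformly,
    // head 1 attends only to position 0.
    let h0 = vec![vec![1.0f32 / 3.0; 3]; 3];
    let h1 = vec![vec![1.0f32, 0.0, 0.0]; 3];
    let row = attention_from_position(&[h0, h1], 1).unwrap();
    // Averaging rows that each sum to 1 keeps the sum at 1.
    assert!((row.iter().sum::<f32>() - 1.0).abs() < 1e-6);
}
```

Because each head's row is post-softmax, the head-averaged row still sums to 1, which is why the result can be read directly as an attention distribution over key positions.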
pub fn attention_to_position(
    &self,
    layer: usize,
    position: usize,
) -> Result<Vec<f32>>
Get attention weights TO a specific key position, averaged across heads.
Returns a vector of length seq_q representing how much every query
position attends to the given key position, averaged over all attention
heads (batch index 0).
§Shapes
- returns: [seq_q] as Vec<f32>
§Errors
Returns MIError::Hook if the layer is not in the cache or the
position is out of range.
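Where attention_from_position reads a row of the [seq_q, seq_k] pattern, this method reads a column. A plain-Rust sketch of that column extraction (again a hypothetical stand-in using nested Vecs, not the crate's implementation):

```rust
// Sketch: attention TO a key position, averaged across heads.
// `attn` is a stand-in for a cached layer: [heads][seq_q][seq_k], batch 0.
fn attention_to_position(attn: &[Vec<Vec<f32>>], position: usize) -> Option<Vec<f32>> {
    let heads = attn.len() as f32;
    let seq_q = attn.first()?.len();
    if position >= attn[0].first()?.len() {
        return None; // key position out of range
    }
    let mut col = vec![0.0f32; seq_q];
    for head in attn {
        for (q, row) in head.iter().enumerate() {
            col[q] += row[position] / heads; // average the column across heads
        }
    }
    Some(col)
}

fn main() {
    let h0 = vec![vec![0.5f32, 0.5], vec![0.2, 0.8]];
    let h1 = vec![vec![1.0f32, 0.0], vec![0.0, 1.0]];
    let col = attention_to_position(&[h0, h1], 0).unwrap();
    // query 0: (0.5 + 1.0)/2 = 0.75; query 1: (0.2 + 0.0)/2 = 0.1
    assert!((col[0] - 0.75).abs() < 1e-6);
}
```

Note that unlike a row, a column of a post-softmax pattern carries no normalization guarantee: its entries need not sum to 1.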
pub fn top_attended_positions(
    &self,
    layer: usize,
    from_position: usize,
    k: usize,
) -> Result<Vec<(usize, f32)>>
Get the top-k key positions that a given query position attends to most.
Returns up to k pairs of (key_position, weight) sorted by
descending attention weight, averaged across heads (batch index 0).
§Errors
Returns MIError::Hook if the layer is not in the cache or the
position is out of range.
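The ranking step can be sketched independently of the cache: given one head-averaged query row (as returned by attention_from_position), pair each key position with its weight, sort descending, and keep at most k. A minimal illustrative version, not the crate's implementation:

```rust
// Sketch: rank key positions by head-averaged attention weight.
// `row` plays the role of one query position's averaged attention row.
fn top_attended_positions(row: &[f32], k: usize) -> Vec<(usize, f32)> {
    let mut pairs: Vec<(usize, f32)> = row.iter().copied().enumerate().collect();
    // Sort descending by weight. Post-softmax weights are finite,
    // so partial_cmp never returns None here.
    pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    pairs.truncate(k); // return at most k entries
    pairs
}

fn main() {
    let top = top_attended_positions(&[0.1f32, 0.6, 0.3], 2);
    // Highest-weight key positions first: position 1, then position 2.
    assert_eq!(top[0].0, 1);
    assert_eq!(top[1].0, 2);
}
```

`truncate` is what makes the "up to k pairs" behavior natural: asking for more positions than the sequence length simply returns the whole sorted row.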
Trait Implementations§
Auto Trait Implementations§
impl Freeze for AttentionCache
impl !RefUnwindSafe for AttentionCache
impl Send for AttentionCache
impl Sync for AttentionCache
impl Unpin for AttentionCache
impl UnsafeUnpin for AttentionCache
impl !UnwindSafe for AttentionCache
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true, otherwise into a Right variant.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true, otherwise into a Right variant.