use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, TensorData};
use burn_dragon_core::LayerVizState;
use super::frame::{LAYER_GAP, VizConfig, VizFrame, clamp_history, clamp_layers, units_height};
use super::palette::{
COLOR_ACTIVITY_HIGH, COLOR_ACTIVITY_LOW, COLOR_INTERACTION_HIGH, COLOR_INTERACTION_LOW,
COLOR_MEMORY_HIGH, COLOR_MEMORY_LOW, COLOR_SEPARATOR, COLOR_WRITES_HIGH, COLOR_WRITES_LOW,
};
// Base epsilon for the soft activity mask in `encode_units`, which computes
// `v / (v + ACTIVE_EPS * 4.0)` so that exact zeros map to a mask of zero.
const ACTIVE_EPS: f32 = 1e-3;
// Knee of the saturating magnitude curve `s / (s + SOFT_CLIP)` applied to the
// gain-scaled values in `encode_units`.
const SOFT_CLIP: f32 = 0.35;
// Exponent applied to the masked magnitude to obtain the final intensity.
const INTENSITY_GAMMA: f32 = 1.15;
// Per-channel selector (R, G, B, A) that keeps the colour channels and zeroes alpha.
const RGB_MASK: [f32; 4] = [1.0, 1.0, 1.0, 0.0];
// Per-channel selector (R, G, B, A) that contributes an alpha of 1 and no colour.
const ALPHA_MASK: [f32; 4] = [0.0, 0.0, 0.0, 1.0];
/// Encodes per-layer visualization state into four RGBA history buffers
/// (`x`, `y`, `xy`, `rho`), writing one column per `step` call.
///
/// Each buffer has shape `[units_height, history, 4]`; the last axis holds
/// the R, G, B, A channels.
pub struct VizEncoder<B: Backend> {
/// Visualization settings; `history` is passed through `clamp_history` in `new`.
config: VizConfig,
/// Number of layer slots, as returned by `clamp_layers` in `new`.
layers_visible: usize,
/// Head count as passed to `new`.
heads: usize,
/// Latent units per head as passed to `new`.
latent_per_head: usize,
/// `heads * latent_per_head` (saturating), clamped to at least 1.
latent_total: usize,
/// Number of blank rows between consecutive layer slots (`LAYER_GAP`).
layer_gap: usize,
/// History buffer for the `x` signal (writes palette).
units_x: Tensor<B, 3>,
/// History buffer for the `y` signal (activity palette).
units_y: Tensor<B, 3>,
/// History buffer for the `xy` signal (interaction palette).
units_xy: Tensor<B, 3>,
/// History buffer for the `rho` signal (memory palette).
units_rho: Tensor<B, 3>,
// Colour ramp endpoints, each a `[1, 1, 4]` tensor built by `color_tensor`
// (RGB converted from sRGB to linear, alpha passed through).
color_x_low: Tensor<B, 3>,
color_x_high: Tensor<B, 3>,
color_y_low: Tensor<B, 3>,
color_y_high: Tensor<B, 3>,
color_xy_low: Tensor<B, 3>,
color_xy_high: Tensor<B, 3>,
color_rho_low: Tensor<B, 3>,
color_rho_high: Tensor<B, 3>,
/// `[1, 1, 4]` selector built from `RGB_MASK`.
rgb_mask: Tensor<B, 3>,
/// `[1, 1, 4]` selector built from `ALPHA_MASK`.
alpha_mask: Tensor<B, 3>,
/// `[latent_total, 1, 4]` overlay carrying the separator colour on the first
/// row of every head except the first; `None` when `latent_per_head` is 0.
separator_rgba: Option<Tensor<B, 3>>,
/// All-zero `[latent_total, 1, 4]` column; `step` also reads the device from it.
zero_units: Tensor<B, 3>,
/// All-zero `[layer_gap, 1, 4]` column used to blank the gap rows;
/// `None` when `layer_gap` is 0.
zero_gap: Option<Tensor<B, 3>>,
}
impl<B: Backend> VizEncoder<B> {
pub fn new(
mut config: VizConfig,
layers: usize,
heads: usize,
latent_per_head: usize,
device: &B::Device,
) -> Self {
config.history = clamp_history(config.history);
let history = config.history;
let latent_total = heads.saturating_mul(latent_per_head).max(1);
let layers = layers.max(1);
let layers_visible = clamp_layers(layers, latent_total);
let layer_gap = LAYER_GAP;
let units_height = units_height(layers_visible, latent_total);
let units_x = Tensor::<B, 3>::zeros([units_height, history, 4], device);
let units_y = Tensor::<B, 3>::zeros([units_height, history, 4], device);
let units_xy = Tensor::<B, 3>::zeros([units_height, history, 4], device);
let units_rho = Tensor::<B, 3>::zeros([units_height, history, 4], device);
let color_x_low = color_tensor::<B>(COLOR_WRITES_LOW, device);
let color_x_high = color_tensor::<B>(COLOR_WRITES_HIGH, device);
let color_y_low = color_tensor::<B>(COLOR_ACTIVITY_LOW, device);
let color_y_high = color_tensor::<B>(COLOR_ACTIVITY_HIGH, device);
let color_xy_low = color_tensor::<B>(COLOR_INTERACTION_LOW, device);
let color_xy_high = color_tensor::<B>(COLOR_INTERACTION_HIGH, device);
let color_rho_low = color_tensor::<B>(COLOR_MEMORY_LOW, device);
let color_rho_high = color_tensor::<B>(COLOR_MEMORY_HIGH, device);
let rgb_mask = color_tensor::<B>(RGB_MASK, device);
let alpha_mask = color_tensor::<B>(ALPHA_MASK, device);
let separator_rgba = build_separator::<B>(latent_total, latent_per_head, device);
let zero_units = Tensor::<B, 3>::zeros([latent_total, 1, 4], device);
let zero_gap = if layer_gap > 0 {
Some(Tensor::<B, 3>::zeros([layer_gap, 1, 4], device))
} else {
None
};
Self {
config,
layers_visible,
heads,
latent_per_head,
latent_total,
layer_gap,
units_x,
units_y,
units_xy,
units_rho,
color_x_low,
color_x_high,
color_y_low,
color_y_high,
color_xy_low,
color_xy_high,
color_rho_low,
color_rho_high,
rgb_mask,
alpha_mask,
separator_rgba,
zero_units,
zero_gap,
}
}
pub fn should_capture(&self, token_index: usize) -> bool {
let stride = self.config.stride_tokens.max(1);
token_index % stride == 0
}
pub fn step(&mut self, layers: &[Option<LayerVizState<B>>], token_index: usize) -> VizFrame<B> {
let device = self.zero_units.device();
let history = self.config.history.max(1);
let cursor = token_index % history;
let available_layers = layers.len().max(1);
let layer_count = self.layers_visible.min(available_layers).max(1);
let layer_start = available_layers.saturating_sub(layer_count);
for layer_idx in 0..layer_count {
let source_idx = layer_start + layer_idx;
let offset = layer_idx.saturating_mul(self.latent_total.saturating_add(self.layer_gap));
if let Some(gap) = &self.zero_gap {
if layer_idx > 0 {
let gap_start = offset.saturating_sub(self.layer_gap);
let gap_end = offset;
self.units_x = self
.units_x
.clone()
.slice_assign([gap_start..gap_end, cursor..cursor + 1, 0..4], gap.clone());
self.units_y = self
.units_y
.clone()
.slice_assign([gap_start..gap_end, cursor..cursor + 1, 0..4], gap.clone());
self.units_xy = self
.units_xy
.clone()
.slice_assign([gap_start..gap_end, cursor..cursor + 1, 0..4], gap.clone());
self.units_rho = self
.units_rho
.clone()
.slice_assign([gap_start..gap_end, cursor..cursor + 1, 0..4], gap.clone());
}
}
let (x_last, y_last, xy_last, rho_last) = layers
.get(source_idx)
.and_then(|layer| layer.as_ref())
.map(|layer| {
(
layer.x_last.clone(),
layer.y_last.clone(),
layer.xy_last.clone(),
layer.rho_last.clone(),
)
})
.unwrap_or_else(|| {
(
Tensor::<B, 2>::zeros([self.heads, self.latent_per_head], &device),
Tensor::<B, 2>::zeros([self.heads, self.latent_per_head], &device),
Tensor::<B, 2>::zeros([self.heads, self.latent_per_head], &device),
Tensor::<B, 2>::zeros([self.heads, self.latent_per_head], &device),
)
});
let x_flat = x_last.reshape([self.latent_total]);
let y_flat = y_last.reshape([self.latent_total]);
let xy_flat = xy_last.reshape([self.latent_total]);
let rho_flat = rho_last.reshape([self.latent_total]);
let units_x_col = self.encode_units(
x_flat,
self.config.gain_x,
&self.color_x_low,
&self.color_x_high,
);
let units_y_col = self.encode_units(
y_flat,
self.config.gain_xy,
&self.color_y_low,
&self.color_y_high,
);
let units_xy_col = self.encode_units(
xy_flat,
self.config.gain_xy,
&self.color_xy_low,
&self.color_xy_high,
);
let units_rho_col = self.encode_units(
rho_flat,
self.config.gain_xy,
&self.color_rho_low,
&self.color_rho_high,
);
let range = offset..offset + self.latent_total;
self.units_x = self
.units_x
.clone()
.slice_assign([range.clone(), cursor..cursor + 1, 0..4], units_x_col);
self.units_y = self
.units_y
.clone()
.slice_assign([range.clone(), cursor..cursor + 1, 0..4], units_y_col);
self.units_xy = self
.units_xy
.clone()
.slice_assign([range.clone(), cursor..cursor + 1, 0..4], units_xy_col);
self.units_rho = self
.units_rho
.clone()
.slice_assign([range, cursor..cursor + 1, 0..4], units_rho_col);
}
VizFrame {
units_x: self.units_x.clone(),
units_y: self.units_y.clone(),
units_xy: self.units_xy.clone(),
units_rho: self.units_rho.clone(),
cursor,
token_index,
}
}
fn encode_units(
&self,
values: Tensor<B, 1>,
gain: f32,
color_low: &Tensor<B, 3>,
color_high: &Tensor<B, 3>,
) -> Tensor<B, 3> {
if self.latent_total == 0 {
return self.zero_units.clone();
}
let values = values.clamp_min(0.0).reshape([self.latent_total, 1]);
let scaled = values.clone().mul_scalar(gain);
let mag = scaled.clone().div(scaled.add_scalar(SOFT_CLIP));
let mask = values
.clone()
.div(values.clone().add_scalar(ACTIVE_EPS * 4.0));
let intensity = (mag * mask).powf_scalar(INTENSITY_GAMMA);
let intensity = intensity.reshape([self.latent_total, 1, 1]);
let ramp = color_low.clone() + (color_high.clone() - color_low.clone()) * intensity.clone();
let rgb = ramp * intensity.clone() * self.rgb_mask.clone();
let alpha = self.alpha_mask.clone();
let mut column = rgb + alpha;
if let Some(sep) = &self.separator_rgba {
column = column.max_pair(sep.clone());
}
column
}
}
/// Builds a `[1, 1, 4]` tensor from an RGBA colour, converting the three
/// colour channels from sRGB to linear and leaving alpha untouched.
fn color_tensor<B: Backend>(rgba: [f32; 4], device: &B::Device) -> Tensor<B, 3> {
    let linear: Vec<f32> = rgba
        .iter()
        .enumerate()
        .map(|(channel, &value)| {
            if channel < 3 {
                srgb_to_linear(value)
            } else {
                value
            }
        })
        .collect();
    let data = TensorData::new(linear, [1, 1, 4]);
    Tensor::<B, 3>::from_data(data, device)
}
/// Converts one sRGB-encoded channel value to linear light using the
/// piecewise sRGB transfer function (linear segment up to 0.04045, power
/// curve above it).
fn srgb_to_linear(value: f32) -> f32 {
    const LINEAR_SEGMENT_MAX: f32 = 0.04045;
    match value <= LINEAR_SEGMENT_MAX {
        true => value / 12.92,
        false => ((value + 0.055) / 1.055).powf(2.4),
    }
}
/// Builds the `[latent_total, 1, 4]` head-separator overlay.
///
/// Rows whose index is a non-zero multiple of `latent_per_head` carry the
/// (linearised) `COLOR_SEPARATOR`; every other row is zero. Returns `None`
/// when either dimension is 0.
fn build_separator<B: Backend>(
    latent_total: usize,
    latent_per_head: usize,
    device: &B::Device,
) -> Option<Tensor<B, 3>> {
    if latent_total == 0 || latent_per_head == 0 {
        return None;
    }

    // 1.0 on the first row of every head except the very first one.
    let stripes: Vec<f32> = (0..latent_total)
        .map(|row| {
            if row > 0 && row % latent_per_head == 0 {
                1.0
            } else {
                0.0
            }
        })
        .collect();

    let stripes = Tensor::<B, 1>::from_data(TensorData::new(stripes, [latent_total]), device)
        .reshape([latent_total, 1, 1]);
    let separator = color_tensor::<B>(COLOR_SEPARATOR, device);
    Some(stripes * separator)
}
#[cfg(all(test, feature = "viz", feature = "cli"))]
mod tests {
    use super::{VizConfig, VizEncoder};
    use burn::tensor::{Int, Tensor, TensorData};
    use burn_ndarray::{NdArray, NdArrayDevice};
    use burn_dragon_core::{BDH, BDHConfig, LayerVizState, ModelState};

    type Backend = NdArray<f32>;

    /// History length used by the encoder test.
    const HISTORY: usize = 4;
    /// Number of channels (R, G, B, A) per cell.
    const CHANNELS: usize = 4;

    fn device() -> NdArrayDevice {
        NdArrayDevice::default()
    }

    /// Copies a history buffer into a flat, row-major `Vec<f32>`.
    fn to_vec(tensor: &Tensor<Backend, 3>, label: &str) -> Vec<f32> {
        tensor
            .to_data()
            .convert::<f32>()
            .into_vec::<f32>()
            .expect(label)
    }

    /// Sums the R, G and B channels of unit `row` at history column 0.
    ///
    /// Alpha is deliberately excluded: the encoder sets it to 1 on every cell
    /// it writes, so including it would make any "is lit" check pass trivially.
    fn rgb_sum_at_first_column(values: &[f32], row: usize) -> f32 {
        let start = row * HISTORY * CHANNELS;
        values[start..start + 3].iter().sum()
    }

    /// A forward pass must leave one viz state per layer, each shaped
    /// `[n_head, latent_per_head]`.
    #[test]
    fn viz_state_collects_last_token() {
        let device = device();
        let mut config = BDHConfig::default();
        config.n_layer = 2;
        config.n_embd = 8;
        config.n_head = 2;
        config.mlp_internal_dim_multiplier = 2;
        config.vocab_size = 16;
        config.dropout = 0.0;
        config.fused_kernels.enabled = false;
        let model = BDH::<Backend>::new(config.clone(), &device);
        let tokens = Tensor::<Backend, 2, Int>::from_data(
            TensorData::new(vec![1_i64, 2, 3], [1, 3]),
            &device,
        );
        let mut state = ModelState::<Backend>::new(config.n_layer);
        let _ = model.forward_with_state(tokens, &mut state);
        let layers: Vec<Option<LayerVizState<Backend>>> = state.take_viz();
        assert_eq!(layers.len(), config.n_layer);
        for layer in layers {
            let layer = layer.expect("expected viz state per layer");
            let dims = layer.x_last.shape().dims::<2>();
            assert_eq!(dims[0], config.n_head);
            assert_eq!(dims[1], config.latent_per_head());
            assert_eq!(layer.y_last.shape().dims::<2>(), dims);
            assert_eq!(layer.xy_last.shape().dims::<2>(), dims);
            assert_eq!(layer.rho_last.shape().dims::<2>(), dims);
        }
    }

    /// One `step` must produce correctly shaped buffers, light up the RGB
    /// channels of active units, and leave inactive units dark.
    #[test]
    fn viz_encoder_emits_expected_shapes() {
        let device = device();
        let config = VizConfig {
            history: HISTORY,
            layer_focus: 0,
            stride_tokens: 1,
            gain_x: 1.0,
            gain_xy: 1.0,
        };
        let mut encoder = VizEncoder::<Backend>::new(config, 2, 2, 2, &device);
        let layer0 = LayerVizState {
            x_last: Tensor::<Backend, 2>::from_data(
                TensorData::new(vec![1.0, 0.0, 0.0, 0.0], [2, 2]),
                &device,
            ),
            y_last: Tensor::<Backend, 2>::from_data(
                TensorData::new(vec![0.0, 1.0, 0.0, 0.0], [2, 2]),
                &device,
            ),
            xy_last: Tensor::<Backend, 2>::from_data(
                TensorData::new(vec![0.5, 0.0, 0.0, 0.0], [2, 2]),
                &device,
            ),
            rho_last: Tensor::<Backend, 2>::from_data(
                TensorData::new(vec![0.2, 0.0, 0.0, 0.0], [2, 2]),
                &device,
            ),
        };
        let layer1 = LayerVizState {
            x_last: Tensor::<Backend, 2>::zeros([2, 2], &device),
            y_last: Tensor::<Backend, 2>::zeros([2, 2], &device),
            xy_last: Tensor::<Backend, 2>::zeros([2, 2], &device),
            rho_last: Tensor::<Backend, 2>::zeros([2, 2], &device),
        };
        let frame = encoder.step(&[Some(layer0), Some(layer1)], 0);

        let expected_units = super::units_height(2, 2 * 2);
        let expected_shape = [expected_units, HISTORY, CHANNELS];
        assert_eq!(frame.units_x.shape().dims::<3>(), expected_shape);
        assert_eq!(frame.units_y.shape().dims::<3>(), expected_shape);
        assert_eq!(frame.units_xy.shape().dims::<3>(), expected_shape);
        assert_eq!(frame.units_rho.shape().dims::<3>(), expected_shape);
        assert_eq!(frame.cursor, 0);

        let units_x = to_vec(&frame.units_x, "units x vec");
        let units_y = to_vec(&frame.units_y, "units y vec");
        let units_xy = to_vec(&frame.units_xy, "units xy vec");
        let units_rho = to_vec(&frame.units_rho, "units rho vec");

        // Active units must have non-zero colour, not merely the constant alpha.
        assert!(rgb_sum_at_first_column(&units_x, 0) > 0.0);
        assert!(rgb_sum_at_first_column(&units_y, 1) > 0.0);
        assert!(rgb_sum_at_first_column(&units_xy, 0) > 0.0);
        assert!(rgb_sum_at_first_column(&units_rho, 0) > 0.0);

        // Unit 1 of `x` is inactive and is not a head separator row, so its
        // colour channels must stay dark.
        assert_eq!(rgb_sum_at_first_column(&units_x, 1), 0.0);
    }
}