use std::collections::HashMap;
/// Geometry of a halo tile: the core size plus the width of the halo ring
/// that surrounds it on every side.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HaloConfig {
    /// Number of core (non-halo) columns per tile.
    pub tile_width: usize,
    /// Number of core (non-halo) rows per tile.
    pub tile_height: usize,
    /// Width of the halo ring added on each side of the core.
    pub halo_size: usize,
}

impl Default for HaloConfig {
    /// A 16x16 core with a one-cell halo.
    fn default() -> Self {
        HaloConfig {
            tile_width: 16,
            tile_height: 16,
            halo_size: 1,
        }
    }
}

impl HaloConfig {
    /// Returns `(width, height)` of a tile including the halo on both sides.
    pub fn padded_dimensions(&self) -> (usize, usize) {
        let ring = 2 * self.halo_size;
        (self.tile_width + ring, self.tile_height + ring)
    }

    /// Number of cells (elements, not bytes) in one padded tile buffer.
    pub fn shared_memory_size(&self) -> usize {
        let (w, h) = self.padded_dimensions();
        w * h
    }
}
/// One tile of the DP matrix, stored together with a halo ring of values
/// copied from neighbouring tiles.
///
/// `dp_values` is a `rows x cols` buffer in row-major order; the core occupies
/// rows `halo_size..rows - halo_size` and columns `halo_size..cols - halo_size`.
#[derive(Debug, Clone)]
pub struct HaloTile {
    /// `(tile_row, tile_col)` position of this tile in the tile grid.
    pub tile_id: (usize, usize),
    /// Row-major cell values, core and halo together.
    pub dp_values: Vec<i32>,
    /// Padded height (core rows + `2 * halo_size`).
    pub rows: usize,
    /// Padded width (core columns + `2 * halo_size`).
    pub cols: usize,
    /// Width of the halo ring on each side of the core.
    pub halo_size: usize,
}

impl HaloTile {
    /// Allocates a padded tile sized by `config`, with every cell set to
    /// `i32::MIN / 2`.
    pub fn new(tile_id: (usize, usize), config: &HaloConfig) -> Self {
        let (w, h) = config.padded_dimensions();
        HaloTile {
            tile_id,
            // NOTE(review): presumably a "minus infinity" that leaves headroom
            // so adding a penalty cannot overflow — confirm with the DP kernel.
            dp_values: vec![i32::MIN / 2; w * h],
            rows: h,
            cols: w,
            halo_size: config.halo_size,
        }
    }

    /// Number of core (non-halo) rows.
    fn core_height(&self) -> usize {
        self.rows - 2 * self.halo_size
    }

    /// Number of core (non-halo) columns.
    fn core_width(&self) -> usize {
        self.cols - 2 * self.halo_size
    }

    /// Flat index of padded cell `(i, j)` in `dp_values`.
    fn flat_index(&self, i: usize, j: usize) -> usize {
        i * self.cols + j
    }

    /// Reads padded cell `(i, j)`; `None` if it lies outside the padded tile.
    pub fn get(&self, i: usize, j: usize) -> Option<i32> {
        (i < self.rows && j < self.cols).then(|| self.dp_values[self.flat_index(i, j)])
    }

    /// Writes padded cell `(i, j)`.
    ///
    /// Returns `false` (and writes nothing) if the cell lies outside the padded tile.
    pub fn set(&mut self, i: usize, j: usize, value: i32) -> bool {
        if i >= self.rows || j >= self.cols {
            return false;
        }
        let idx = self.flat_index(i, j);
        self.dp_values[idx] = value;
        true
    }

    /// Copies a row-major `core_height x core_width` block into the core.
    ///
    /// Returns `false` without modifying the tile if `core_data` is too short.
    /// Extra trailing elements are ignored.
    pub fn set_core(&mut self, core_data: &[i32]) -> bool {
        let (height, width) = (self.core_height(), self.core_width());
        if core_data.len() < height * width {
            return false;
        }
        for i in 0..height {
            let start = self.flat_index(i + self.halo_size, self.halo_size);
            self.dp_values[start..start + width]
                .copy_from_slice(&core_data[i * width..(i + 1) * width]);
        }
        true
    }

    /// Returns the core region as a row-major `core_height x core_width` vector.
    pub fn get_core(&self) -> Vec<i32> {
        let (height, width) = (self.core_height(), self.core_width());
        let mut core = Vec::with_capacity(height * width);
        for i in 0..height {
            let start = self.flat_index(i + self.halo_size, self.halo_size);
            core.extend_from_slice(&self.dp_values[start..start + width]);
        }
        core
    }

    /// Writes the first `core_width` values of `neighbor_bottom_row` into the
    /// innermost top halo row, above the core columns.
    ///
    /// Does nothing if the slice is shorter than the core width, or if the
    /// tile has no halo (`halo_size == 0`).
    pub fn update_top_halo(&mut self, neighbor_bottom_row: &[i32]) {
        let width = self.core_width();
        // The halo_size check keeps `halo_size - 1` from underflowing.
        if self.halo_size == 0 || neighbor_bottom_row.len() < width {
            return;
        }
        let start = self.flat_index(self.halo_size - 1, self.halo_size);
        self.dp_values[start..start + width].copy_from_slice(&neighbor_bottom_row[..width]);
    }

    /// Writes the first `core_height` values of `neighbor_right_col` into the
    /// innermost left halo column, beside the core rows.
    ///
    /// Does nothing if the slice is shorter than the core height, or if the
    /// tile has no halo (`halo_size == 0`).
    pub fn update_left_halo(&mut self, neighbor_right_col: &[i32]) {
        let height = self.core_height();
        // The halo_size check keeps `halo_size - 1` from underflowing.
        if self.halo_size == 0 || neighbor_right_col.len() < height {
            return;
        }
        let col = self.halo_size - 1;
        for (i, &value) in neighbor_right_col[..height].iter().enumerate() {
            let idx = self.flat_index(self.halo_size + i, col);
            self.dp_values[idx] = value;
        }
    }

    /// Returns a copy of the last core row; empty if the core has no rows.
    pub fn get_bottom_core_row(&self) -> Vec<i32> {
        let height = self.core_height();
        if height == 0 {
            return Vec::new();
        }
        let start = self.flat_index(self.halo_size + height - 1, self.halo_size);
        self.dp_values[start..start + self.core_width()].to_vec()
    }

    /// Returns a copy of the last core column; empty if the core has no columns.
    pub fn get_right_core_col(&self) -> Vec<i32> {
        let width = self.core_width();
        if width == 0 {
            return Vec::new();
        }
        let col = self.halo_size + width - 1;
        (0..self.core_height())
            .map(|i| self.dp_values[self.flat_index(self.halo_size + i, col)])
            .collect()
    }
}
/// Owns the halo tiles covering a `seq_len1 x seq_len2` DP matrix and moves
/// boundary values between neighbouring tiles.
///
/// Tiles are created lazily by [`HaloBufferManager::get_tile`].
pub struct HaloBufferManager {
    config: HaloConfig,
    tiles: HashMap<(usize, usize), HaloTile>,
    /// Number of DP rows covered by the tile grid (the extra boundary row 0
    /// produced by `assemble_result` is not counted).
    pub seq_len1: usize,
    /// Number of DP columns covered by the tile grid (the extra boundary
    /// column 0 produced by `assemble_result` is not counted).
    pub seq_len2: usize,
}

impl HaloBufferManager {
    /// Creates a manager with no tiles allocated yet.
    pub fn new(seq_len1: usize, seq_len2: usize, config: HaloConfig) -> Self {
        HaloBufferManager {
            config,
            tiles: HashMap::new(),
            seq_len1,
            seq_len2,
        }
    }

    /// Returns the tile at `(tile_row, tile_col)`, creating a fresh
    /// [`HaloTile`] on first access.
    pub fn get_tile(&mut self, tile_row: usize, tile_col: usize) -> &mut HaloTile {
        // HaloConfig is Copy. Copying it means the closure does not borrow
        // `self` while `self.tiles` is mutably borrowed.
        let config = self.config;
        self.tiles
            .entry((tile_row, tile_col))
            .or_insert_with(|| HaloTile::new((tile_row, tile_col), &config))
    }

    /// Number of tile rows needed to cover `seq_len1` (ceiling division).
    ///
    /// # Panics
    /// Panics if `tile_height` is zero.
    pub fn num_tile_rows(&self) -> usize {
        (self.seq_len1 + self.config.tile_height - 1) / self.config.tile_height
    }

    /// Number of tile columns needed to cover `seq_len2` (ceiling division).
    ///
    /// # Panics
    /// Panics if `tile_width` is zero.
    pub fn num_tile_cols(&self) -> usize {
        (self.seq_len2 + self.config.tile_width - 1) / self.config.tile_width
    }

    /// Pushes a finished tile's edge values into the halos of the tiles that follow it.
    ///
    /// - The last core row goes into the top halo of the tile below.
    /// - The last core column goes into the left halo of the tile to the right.
    /// - The bottom-right core cell goes into the top-left halo corner of the
    ///   diagonal neighbour.
    ///
    /// Neighbours outside the tile grid are skipped (but are created on demand
    /// when inside it). Does nothing if the source tile has not been created.
    pub fn propagate_boundaries(&mut self, tile_row: usize, tile_col: usize) {
        let (bottom_row, right_col) = match self.tiles.get(&(tile_row, tile_col)) {
            Some(tile) => (tile.get_bottom_core_row(), tile.get_right_core_col()),
            None => return,
        };
        let has_below = tile_row + 1 < self.num_tile_rows();
        let has_right = tile_col + 1 < self.num_tile_cols();
        if has_below {
            self.get_tile(tile_row + 1, tile_col)
                .update_top_halo(&bottom_row);
        }
        if has_right {
            self.get_tile(tile_row, tile_col + 1)
                .update_left_halo(&right_col);
        }
        // The diagonal neighbour's top-left halo corner is the cell just above
        // and left of its core, i.e. this tile's bottom-right core cell. Neither
        // the top nor the left halo update writes it, so copy it explicitly.
        if has_below && has_right {
            if let Some(&corner) = bottom_row.last() {
                let diagonal = self.get_tile(tile_row + 1, tile_col + 1);
                if let Some(edge) = diagonal.halo_size.checked_sub(1) {
                    diagonal.set(edge, edge, corner);
                }
            }
        }
    }

    /// Creates every tile in the grid and zeroes the outer boundary.
    ///
    /// For tiles in the first tile row, the whole innermost top halo row is
    /// zeroed. For tiles in the first tile column, the whole innermost left
    /// halo column is zeroed. These are the cells that `assemble_result`
    /// places in its row 0 and column 0.
    pub fn initialize_boundaries(&mut self) {
        let num_rows = self.num_tile_rows();
        let num_cols = self.num_tile_cols();
        for tile_row in 0..num_rows {
            for tile_col in 0..num_cols {
                let tile = self.get_tile(tile_row, tile_col);
                // With halo_size == 0 there is no halo to seed. The tile is
                // still created, and `halo_size - 1` is never computed.
                let edge = match tile.halo_size.checked_sub(1) {
                    Some(edge) => edge,
                    None => continue,
                };
                // The shared corner (edge, edge) is covered by both loops.
                if tile_row == 0 {
                    for j in 0..tile.cols {
                        tile.set(edge, j, 0);
                    }
                }
                if tile_col == 0 {
                    for i in 0..tile.rows {
                        tile.set(i, edge, 0);
                    }
                }
            }
        }
    }

    /// Bytes needed to hold every tile of the full grid as padded `i32` buffers.
    ///
    /// The count covers the whole grid, regardless of how many tiles have
    /// actually been created so far.
    pub fn total_gpu_memory(&self) -> usize {
        let tile_mem = self.config.shared_memory_size() * std::mem::size_of::<i32>();
        let num_tiles = self.num_tile_rows() * self.num_tile_cols();
        num_tiles * tile_mem
    }

    /// Stitches the tile cores into a `(seq_len1 + 1) x (seq_len2 + 1)` matrix.
    ///
    /// Core cell `(i, j)` of tile `(r, c)` lands at
    /// `[r * tile_height + i + 1][c * tile_width + j + 1]`. Row 0, column 0,
    /// cells beyond the sequence lengths, and tiles that were never created
    /// stay 0.
    pub fn assemble_result(&self) -> Vec<Vec<i32>> {
        let mut result = vec![vec![0i32; self.seq_len2 + 1]; self.seq_len1 + 1];
        let core_height = self.config.tile_height;
        let core_width = self.config.tile_width;
        for tile_row in 0..self.num_tile_rows() {
            for tile_col in 0..self.num_tile_cols() {
                let tile = match self.tiles.get(&(tile_row, tile_col)) {
                    Some(tile) => tile,
                    None => continue,
                };
                let core = tile.get_core();
                let global_row_start = tile_row * core_height;
                let global_col_start = tile_col * core_width;
                for i in 0..core_height {
                    for j in 0..core_width {
                        let global_i = global_row_start + i + 1;
                        let global_j = global_col_start + j + 1;
                        if global_i <= self.seq_len1 && global_j <= self.seq_len2 {
                            result[global_i][global_j] = core[i * core_width + j];
                        }
                    }
                }
            }
        }
        result
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config pads a 16x16 core with a one-cell halo.
    #[test]
    fn test_halo_config_dimensions() {
        let config = HaloConfig::default();
        let (w, h) = config.padded_dimensions();
        assert_eq!(w, 18);
        assert_eq!(h, 18);
        assert_eq!(config.shared_memory_size(), 18 * 18);
    }

    /// A new tile allocates the full padded buffer.
    #[test]
    fn test_halo_tile_creation() {
        let config = HaloConfig::default();
        let tile = HaloTile::new((0, 0), &config);
        assert_eq!(tile.rows, 18);
        assert_eq!(tile.cols, 18);
        assert_eq!(tile.dp_values.len(), 18 * 18);
    }

    /// In-bounds writes are readable; out-of-bounds reads yield `None`.
    #[test]
    fn test_halo_tile_get_set() {
        let config = HaloConfig::default();
        let mut tile = HaloTile::new((0, 0), &config);
        assert!(tile.set(5, 5, 42));
        assert_eq!(tile.get(5, 5), Some(42));
        assert_eq!(tile.get(20, 20), None);
    }

    /// Tile grid sizing and lazy tile creation.
    #[test]
    fn test_halo_buffer_manager() {
        let config = HaloConfig::default();
        let mut manager = HaloBufferManager::new(32, 32, config);
        assert_eq!(manager.num_tile_rows(), 2);
        assert_eq!(manager.num_tile_cols(), 2);
        let tile = manager.get_tile(0, 0);
        assert_eq!(tile.tile_id, (0, 0));
    }

    /// The first tile's innermost top halo row and left halo column are zeroed.
    #[test]
    fn test_boundary_initialization() {
        let config = HaloConfig::default();
        let mut manager = HaloBufferManager::new(16, 16, config);
        manager.initialize_boundaries();
        let tile = manager.get_tile(0, 0);
        let edge = tile.halo_size - 1;
        for j in 0..tile.cols {
            assert_eq!(tile.get(edge, j), Some(0));
        }
        for i in 0..tile.rows {
            assert_eq!(tile.get(i, edge), Some(0));
        }
    }

    /// `get_core` returns exactly the non-halo cells in row-major order.
    #[test]
    fn test_core_region_extraction() {
        let config = HaloConfig::default();
        let mut tile = HaloTile::new((0, 0), &config);
        for i in 0..16 {
            for j in 0..16 {
                tile.set(i + tile.halo_size, j + tile.halo_size, (i * 16 + j) as i32);
            }
        }
        let core = tile.get_core();
        assert_eq!(core.len(), 16 * 16);
        assert_eq!(core[0], 0);
        assert_eq!(core[255], 255);
    }

    /// The bottom core row of tile (0, 0) lands in the top halo of tile (1, 0).
    #[test]
    fn test_halo_propagation() {
        let config = HaloConfig::default();
        // HaloConfig is Copy, so it can be passed by value and reused below.
        let mut manager = HaloBufferManager::new(32, 32, config);
        let tile1 = manager.get_tile(0, 0);
        let core_width = config.tile_width;
        for j in 0..core_width {
            let idx = (config.tile_height - 1 + config.halo_size) * tile1.cols
                + (config.halo_size + j);
            tile1.dp_values[idx] = (j as i32) * 10;
        }
        manager.propagate_boundaries(0, 0);
        let tile2 = manager.get_tile(1, 0);
        for j in 0..core_width {
            let top_halo_idx = (config.halo_size - 1) * tile2.cols + (config.halo_size + j);
            assert_eq!(tile2.dp_values[top_halo_idx], (j as i32) * 10);
        }
    }

    /// Core data written via `set_core` round-trips and is placed at +1 offsets
    /// by `assemble_result`.
    #[test]
    fn test_core_roundtrip_and_assembly() {
        let config = HaloConfig::default();
        let mut manager = HaloBufferManager::new(16, 16, config);
        let data: Vec<i32> = (0..256).collect();
        assert!(manager.get_tile(0, 0).set_core(&data));
        assert_eq!(manager.get_tile(0, 0).get_core(), data);
        let result = manager.assemble_result();
        assert_eq!(result.len(), 17);
        assert_eq!(result[0][0], 0);
        assert_eq!(result[1][1], 0);
        assert_eq!(result[1][16], 15);
        assert_eq!(result[16][16], 255);
    }
}