use super::{ComputationGraph, GraphOperator, KernelType};
use crate::MobilePlatform;
use std::collections::HashMap;
use trustformers_core::errors::Result;
use trustformers_core::Tensor;
use trustformers_core::TrustformersError;
/// High-level strategy for adapting a kernel's memory traffic to the cache
/// hierarchy. Stored in `CacheOptimizer::optimization_cache`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheStrategy {
    /// Keep the working set within the L1 data cache.
    L1Optimized,
    /// Keep the working set within the L2 cache.
    L2Optimized,
    /// Data is touched once; stream it rather than cache it.
    Streaming,
    /// Hide latency with software prefetching.
    Prefetch,
}
/// Loop-tiling parameters produced by `CacheOptimizer::compute_tiling_config`.
/// The three vectors are indexed by loop dimension.
#[derive(Debug, Clone)]
pub struct TilingConfig {
    /// Tile extent for each loop dimension.
    pub tile_sizes: Vec<usize>,
    /// Permutation giving the nesting order of the tiled loops.
    pub loop_order: Vec<usize>,
    /// Unroll factor for each loop dimension.
    pub unroll_factors: Vec<usize>,
}
/// Physical memory layout of a 4-D tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataLayout {
    /// Batch, channel, height, width (channels-first).
    NCHW,
    /// Batch, height, width, channel (channels-last).
    NHWC,
    /// Channels packed in groups of four (zero-padded to a multiple of 4),
    /// with the 4-channel group innermost; see `pack_to_nc4hw4`.
    NC4HW4,
    /// Layout determined elsewhere; no converter in this file produces it.
    Custom,
}
/// Describes how a kernel walks a tensor's memory.
#[derive(Debug, Clone)]
pub struct AccessPattern {
    /// Element strides of the walk, one entry per traversed dimension.
    pub strides: Vec<isize>,
    /// Broad classification of the walk (sequential, strided, ...).
    pub access_type: AccessType,
    /// Approximate number of accesses between reuses of the same data.
    pub reuse_distance: usize,
    /// Approximate working-set footprint in bytes (see `estimate_working_set`).
    pub working_set_size: usize,
}
/// Broad classification of a tensor access pattern.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccessType {
    /// Contiguous, monotonically advancing addresses.
    Sequential,
    /// Fixed non-unit stride between consecutive accesses.
    Strided,
    /// No exploitable ordering (e.g. attention, per `CacheOptimizer::analyze`).
    Random,
    /// The same data is read by many consumers (e.g. weight tensors).
    Broadcast,
}
/// Cache-behavior hints generated for a single kernel invocation by
/// `CacheOptimizer::generate_hints`.
#[derive(Debug, Clone)]
pub struct CacheHints {
    /// How far ahead to prefetch; derived from the kernel's rough compute
    /// intensity (units are iterations — TODO confirm with consumers).
    pub prefetch_distance: usize,
    /// Names of input tensors too large for L2, which should bypass the cache.
    pub bypass_cache: Vec<String>,
    /// Expected temporal reuse, keyed by tensor name.
    pub temporal_hints: HashMap<String, TemporalHint>,
    /// Expected spatial access shape, keyed by tensor name.
    pub spatial_hints: HashMap<String, SpatialHint>,
}
/// Coarse temporal-reuse classification for one tensor.
#[derive(Debug, Clone, Copy)]
pub enum TemporalHint {
    /// Reused many times (e.g. weights across output positions).
    HighReuse,
    /// Reused a moderate number of times.
    MediumReuse,
    /// Reused rarely.
    LowReuse,
    /// Touched once; caching brings no benefit.
    NoReuse,
}
/// Coarse spatial-locality classification for one tensor.
#[derive(Debug, Clone, Copy)]
pub enum SpatialHint {
    /// Elements are visited in contiguous order.
    Contiguous,
    /// Elements are visited with a fixed element stride.
    Strided { stride: usize },
    /// Elements are visited in square blocks of `block_size`.
    Blocked { block_size: usize },
}
/// Selects data layouts, cache hints and tiling configurations tuned to a
/// mobile platform's cache hierarchy.
pub struct CacheOptimizer {
    /// Target platform; drives layout selection and cache parameters.
    platform: MobilePlatform,
    /// Per-platform cache parameters, filled in at construction.
    cache_hierarchy: CacheHierarchy,
    /// Memoized per-key strategies. NOTE(review): never read or written in
    /// this file — presumably reserved for future use; confirm before removal.
    optimization_cache: HashMap<String, CacheStrategy>,
}
/// Cache parameters for one device class. Sizes are in bytes.
#[derive(Debug, Clone)]
struct CacheHierarchy {
    /// L1 data-cache capacity in bytes.
    l1_size: usize,
    /// L1 cache-line size in bytes.
    l1_line_size: usize,
    /// L1 set associativity.
    l1_associativity: usize,
    /// L2 cache capacity in bytes.
    l2_size: usize,
    /// L2 cache-line size in bytes.
    l2_line_size: usize,
    /// L2 set associativity.
    l2_associativity: usize,
    /// L3 capacity; `None` for the mobile targets modelled here.
    l3_size: Option<usize>,
}
impl CacheOptimizer {
    /// Builds an optimizer for `platform`, resolving its cache hierarchy up
    /// front so later queries are cheap.
    pub fn new(platform: MobilePlatform) -> Self {
        let cache_hierarchy = Self::detect_cache_hierarchy(&platform);
        Self {
            platform,
            cache_hierarchy,
            optimization_cache: HashMap::new(),
        }
    }

    /// Converts `tensor` to the data layout judged optimal for `pattern`.
    ///
    /// Only two conversions are implemented (NCHW -> NHWC transposition and
    /// NC4HW4 channel packing); every other selection returns a plain clone.
    pub fn optimize_layout(&self, tensor: &Tensor, pattern: &AccessPattern) -> Result<Tensor> {
        let layout = self.select_optimal_layout(&tensor.shape(), pattern)?;
        match layout {
            DataLayout::NHWC if self.current_layout_is_nchw(tensor) => {
                self.transpose_nchw_to_nhwc(tensor)
            },
            DataLayout::NC4HW4 => self.pack_to_nc4hw4(tensor),
            _ => Ok(tensor.clone()),
        }
    }

    /// Produces prefetch, cache-bypass, temporal and spatial hints for one
    /// invocation of `kernel` with the given input shapes.
    pub fn generate_hints(
        &self,
        kernel: &KernelType,
        input_shapes: &[Vec<usize>],
    ) -> Result<CacheHints> {
        let prefetch_distance = self.calculate_prefetch_distance(kernel, input_shapes)?;
        let bypass_cache = self.identify_streaming_tensors(kernel, input_shapes)?;
        let temporal_hints = self.analyze_temporal_locality(kernel, input_shapes)?;
        let spatial_hints = self.analyze_spatial_locality(kernel, input_shapes)?;
        Ok(CacheHints {
            prefetch_distance,
            bypass_cache,
            temporal_hints,
            spatial_hints,
        })
    }

    /// Replaces tileable kernels in `graph` with their tiled variants.
    ///
    /// NOTE(review): tiling is applied only to operators that already carry
    /// `cache_hints` — the original code bound the hints mutably without using
    /// them, so the gate is now an explicit `is_some()` check. Confirm the
    /// gating itself is intentional.
    pub fn apply_tiling(&self, graph: &mut ComputationGraph) -> Result<()> {
        for operator in &mut graph.operators {
            if self.can_tile(&operator.kernel) && operator.cache_hints.is_some() {
                let tiling_config = self.compute_tiling_config(operator)?;
                self.apply_tiling_to_operator(operator, &tiling_config)?;
            }
        }
        Ok(())
    }

    /// Derives a coarse access pattern for `tensor_name` under `kernel`.
    ///
    /// The classification is keyed purely off the kernel class; `tensor_name`
    /// is currently unused and kept for interface stability.
    pub fn analyze(&self, tensor_name: &str, kernel: &KernelType) -> Result<AccessPattern> {
        let access_type = match kernel {
            KernelType::Conv2d => AccessType::Strided,
            KernelType::Linear => AccessType::Sequential,
            KernelType::Attention => AccessType::Random,
            _ => AccessType::Sequential,
        };
        Ok(AccessPattern {
            strides: self.compute_strides(kernel)?,
            access_type,
            reuse_distance: self.estimate_reuse_distance(kernel)?,
            working_set_size: self.estimate_working_set(kernel)?,
        })
    }

    /// Cache parameters per platform (sizes in bytes).
    ///
    /// NOTE(review): despite the name these are representative defaults, not
    /// runtime-probed values.
    fn detect_cache_hierarchy(platform: &MobilePlatform) -> CacheHierarchy {
        match platform {
            MobilePlatform::Ios => CacheHierarchy {
                l1_size: 64 * 1024,
                l1_line_size: 64,
                l1_associativity: 4,
                l2_size: 3 * 1024 * 1024,
                l2_line_size: 128,
                l2_associativity: 8,
                l3_size: None,
            },
            MobilePlatform::Android => CacheHierarchy {
                l1_size: 32 * 1024,
                l1_line_size: 64,
                l1_associativity: 4,
                l2_size: 1024 * 1024,
                l2_line_size: 64,
                l2_associativity: 8,
                l3_size: None,
            },
            MobilePlatform::Generic => CacheHierarchy {
                l1_size: 32 * 1024,
                l1_line_size: 64,
                l1_associativity: 4,
                l2_size: 256 * 1024,
                l2_line_size: 64,
                l2_associativity: 8,
                l3_size: None,
            },
        }
    }

    /// Picks a layout from the access type alone; `shape` is currently unused
    /// and kept for interface stability.
    fn select_optimal_layout(
        &self,
        shape: &[usize],
        pattern: &AccessPattern,
    ) -> Result<DataLayout> {
        match pattern.access_type {
            AccessType::Sequential => Ok(DataLayout::NCHW),
            // Channels-last only on Android; other platforms stay NCHW.
            AccessType::Strided => {
                if self.platform == MobilePlatform::Android {
                    Ok(DataLayout::NHWC)
                } else {
                    Ok(DataLayout::NCHW)
                }
            },
            AccessType::Random => Ok(DataLayout::Custom),
            AccessType::Broadcast => Ok(DataLayout::NC4HW4),
        }
    }

    /// Heuristic: any 4-D tensor is assumed to currently be NCHW.
    fn current_layout_is_nchw(&self, tensor: &Tensor) -> bool {
        tensor.shape().len() == 4
    }

    /// Transposes a 4-D NCHW tensor to NHWC with an element-wise copy.
    ///
    /// # Errors
    /// Returns a tensor-op error if `tensor` is not 4-dimensional.
    fn transpose_nchw_to_nhwc(&self, tensor: &Tensor) -> Result<Tensor> {
        let shape = tensor.shape();
        if shape.len() != 4 {
            return Err(TrustformersError::tensor_op_error(
                "Expected 4D tensor",
                "transpose_nchw_to_nhwc",
            ));
        }
        let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
        let mut transposed_data = vec![0.0f32; n * h * w * c];
        let src_data = tensor.data()?;
        for batch in 0..n {
            for channel in 0..c {
                for row in 0..h {
                    for col in 0..w {
                        // src is N,C,H,W-major; dst is N,H,W,C-major.
                        let src_idx = batch * c * h * w + channel * h * w + row * w + col;
                        let dst_idx = batch * h * w * c + row * w * c + col * c + channel;
                        transposed_data[dst_idx] = src_data[src_idx];
                    }
                }
            }
        }
        Tensor::from_vec(transposed_data, &[n, h, w, c])
    }

    /// Packs a 4-D NCHW tensor into NC4HW4: channels are grouped in fours
    /// (zero-padded up to a multiple of 4) with the group index innermost.
    ///
    /// # Errors
    /// Returns a tensor-op error if `tensor` is not 4-dimensional.
    fn pack_to_nc4hw4(&self, tensor: &Tensor) -> Result<Tensor> {
        let shape = tensor.shape();
        if shape.len() != 4 {
            return Err(TrustformersError::tensor_op_error(
                "Expected 4D tensor",
                // Fixed: the error previously reported the wrong operation
                // name ("transpose_nchw_to_nhwc").
                "pack_to_nc4hw4",
            ));
        }
        let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
        let c_padded = c.div_ceil(4) * 4;
        let mut packed_data = vec![0.0f32; n * c_padded * h * w];
        let src_data = tensor.data()?;
        for batch in 0..n {
            for c_group in 0..c.div_ceil(4) {
                for row in 0..h {
                    for col in 0..w {
                        for c_offset in 0..4 {
                            let c_idx = c_group * 4 + c_offset;
                            // Padding channels (c_idx >= c) stay zero.
                            if c_idx < c {
                                let src_idx =
                                    batch * c * h * w + c_idx * h * w + row * w + col;
                                let dst_idx = batch * c_padded * h * w
                                    + c_group * 4 * h * w
                                    + row * w * 4
                                    + col * 4
                                    + c_offset;
                                packed_data[dst_idx] = src_data[src_idx];
                            }
                        }
                    }
                }
            }
        }
        Tensor::from_vec(packed_data, &[n, c_padded, h, w])
    }

    /// Prefetch distance scaled from the kernel's rough compute intensity.
    /// `input_shapes` is currently unused.
    fn calculate_prefetch_distance(
        &self,
        kernel: &KernelType,
        _input_shapes: &[Vec<usize>],
    ) -> Result<usize> {
        // Relative arithmetic-intensity weights per kernel class.
        let compute_intensity = match kernel {
            KernelType::Conv2d => 10.0,
            KernelType::Linear => 2.0,
            KernelType::Attention => 5.0,
            _ => 1.0,
        };
        Ok((compute_intensity * 2.0) as usize)
    }

    /// Flags inputs whose f32 footprint exceeds L2 capacity as streaming
    /// candidates that should bypass the cache. `kernel` is currently unused.
    fn identify_streaming_tensors(
        &self,
        _kernel: &KernelType,
        input_shapes: &[Vec<usize>],
    ) -> Result<Vec<String>> {
        let mut streaming = Vec::new();
        for (idx, shape) in input_shapes.iter().enumerate() {
            // Footprint assumes 4-byte (f32) elements.
            let size_bytes = shape.iter().product::<usize>() * 4;
            if size_bytes > self.cache_hierarchy.l2_size {
                streaming.push(format!("input_{}", idx));
            }
        }
        Ok(streaming)
    }

    /// Per-tensor temporal-reuse hints keyed by kernel class. Kernels without
    /// a dedicated arm get no hints. `input_shapes` is currently unused.
    fn analyze_temporal_locality(
        &self,
        kernel: &KernelType,
        _input_shapes: &[Vec<usize>],
    ) -> Result<HashMap<String, TemporalHint>> {
        let mut hints = HashMap::new();
        match kernel {
            KernelType::Conv2d => {
                hints.insert("weights".to_string(), TemporalHint::HighReuse);
                hints.insert("input".to_string(), TemporalHint::MediumReuse);
            },
            KernelType::Linear => {
                hints.insert("weights".to_string(), TemporalHint::HighReuse);
                hints.insert("input".to_string(), TemporalHint::LowReuse);
            },
            KernelType::BatchNorm => {
                hints.insert("mean".to_string(), TemporalHint::HighReuse);
                hints.insert("variance".to_string(), TemporalHint::HighReuse);
            },
            _ => {},
        }
        Ok(hints)
    }

    /// Per-tensor spatial-locality hints keyed by kernel class.
    fn analyze_spatial_locality(
        &self,
        kernel: &KernelType,
        input_shapes: &[Vec<usize>],
    ) -> Result<HashMap<String, SpatialHint>> {
        let mut hints = HashMap::new();
        match kernel {
            KernelType::Conv2d => {
                hints.insert("input".to_string(), SpatialHint::Blocked { block_size: 16 });
                hints.insert("output".to_string(), SpatialHint::Contiguous);
            },
            KernelType::Linear => {
                hints.insert("input".to_string(), SpatialHint::Contiguous);
                // Weights are walked with a stride of the first input's second
                // dimension. Previously this indexed input_shapes[0][1]
                // unchecked and panicked on empty or rank-1 inputs; fall back
                // to a contiguous hint when the dimension is unavailable.
                let weight_hint = input_shapes
                    .first()
                    .and_then(|s| s.get(1))
                    .map(|&stride| SpatialHint::Strided { stride })
                    .unwrap_or(SpatialHint::Contiguous);
                hints.insert("weights".to_string(), weight_hint);
            },
            _ => {
                hints.insert("default".to_string(), SpatialHint::Contiguous);
            },
        }
        Ok(hints)
    }

    /// Only convolution, linear and attention kernels have tiled variants.
    fn can_tile(&self, kernel: &KernelType) -> bool {
        matches!(
            kernel,
            KernelType::Conv2d | KernelType::Linear | KernelType::Attention
        )
    }

    /// Chooses tile sizes, loop order and unroll factors for `operator`.
    fn compute_tiling_config(&self, operator: &GraphOperator) -> Result<TilingConfig> {
        let tile_size = self.compute_optimal_tile_size(operator)?;
        let config = match operator.kernel {
            KernelType::Conv2d => TilingConfig {
                tile_sizes: vec![1, tile_size, tile_size, tile_size],
                loop_order: vec![0, 2, 3, 1],
                unroll_factors: vec![1, 4, 1, 1],
            },
            KernelType::Linear => TilingConfig {
                tile_sizes: vec![tile_size, tile_size],
                loop_order: vec![0, 1],
                unroll_factors: vec![4, 4],
            },
            _ => TilingConfig {
                tile_sizes: vec![tile_size],
                loop_order: vec![0],
                unroll_factors: vec![4],
            },
        };
        Ok(config)
    }

    /// Largest power-of-two tile (64 down to 8) whose square f32 footprint
    /// fits in ~75% of L1.
    ///
    /// NOTE(review): the original also computed the operator's total working
    /// set and then discarded it (dead code, removed here); the operator is
    /// currently unused but kept in the signature for future shape-aware
    /// sizing.
    fn compute_optimal_tile_size(&self, _operator: &GraphOperator) -> Result<usize> {
        let element_size = 4; // f32
        // Reserve a quarter of L1 for data outside the tile.
        let available_cache = (self.cache_hierarchy.l1_size as f32 * 0.75) as usize;
        let mut tile_size = 64;
        while tile_size * tile_size * element_size > available_cache && tile_size > 8 {
            tile_size /= 2;
        }
        Ok(tile_size)
    }

    /// Swaps a tileable kernel for its named tiled variant in place; other
    /// kernels are left untouched. `config` is currently unused by the
    /// rewrite itself.
    fn apply_tiling_to_operator(
        &self,
        operator: &mut GraphOperator,
        _config: &TilingConfig,
    ) -> Result<()> {
        // Match on the place directly (no bindings), so no clone is needed.
        match operator.kernel {
            KernelType::Conv2d => {
                operator.kernel = KernelType::Custom("TiledConv2d".to_string());
            },
            KernelType::Linear => {
                operator.kernel = KernelType::Custom("TiledLinear".to_string());
            },
            _ => {},
        }
        Ok(())
    }

    /// Nominal element strides per kernel class.
    fn compute_strides(&self, kernel: &KernelType) -> Result<Vec<isize>> {
        Ok(match kernel {
            KernelType::Conv2d => vec![1, 1],
            KernelType::Pooling => vec![2, 2],
            _ => vec![1],
        })
    }

    /// Rough reuse-distance estimate per kernel class.
    fn estimate_reuse_distance(&self, kernel: &KernelType) -> Result<usize> {
        Ok(match kernel {
            KernelType::Conv2d => 1024,
            KernelType::Linear => 256,
            KernelType::BatchNorm => 64,
            _ => 128,
        })
    }

    /// Rough working-set estimate in bytes per kernel class.
    fn estimate_working_set(&self, kernel: &KernelType) -> Result<usize> {
        Ok(match kernel {
            KernelType::Conv2d => 256 * 1024,
            KernelType::Linear => 128 * 1024,
            KernelType::Attention => 512 * 1024,
            _ => 64 * 1024,
        })
    }
}
impl AccessPattern {
    /// Derives a coarse access pattern for the named tensor under `kernel`.
    ///
    /// Weight-like tensors (name contains "weight") are classified as
    /// broadcast reads; attention kernels as random access; everything else
    /// as sequential. Strides, reuse distance and working-set size are fixed
    /// placeholder values.
    pub fn analyze(tensor_name: &str, kernel: &KernelType) -> Result<Self> {
        let access_type = if tensor_name.contains("weight") {
            AccessType::Broadcast
        } else {
            match kernel {
                KernelType::Attention => AccessType::Random,
                _ => AccessType::Sequential,
            }
        };
        let pattern = Self {
            strides: vec![1],
            access_type,
            reuse_distance: 128,
            working_set_size: 64 * 1024,
        };
        Ok(pattern)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_optimizer_creation() {
        // Construction records the requested platform verbatim.
        let opt = CacheOptimizer::new(MobilePlatform::Generic);
        assert_eq!(opt.platform, MobilePlatform::Generic);
    }

    #[test]
    fn test_layout_optimization() {
        // A strided pattern on Android selects NHWC, so an NCHW tensor of
        // shape [1, 64, 32, 32] must come back transposed to [1, 32, 32, 64].
        let pattern = AccessPattern {
            strides: vec![1, 1],
            access_type: AccessType::Strided,
            reuse_distance: 128,
            working_set_size: 64 * 1024,
        };
        let opt = CacheOptimizer::new(MobilePlatform::Android);
        let source = Tensor::ones(&[1, 64, 32, 32]).expect("Operation failed");
        let result = opt.optimize_layout(&source, &pattern).expect("Operation failed");
        assert_eq!(result.shape(), &[1, 32, 32, 64]);
    }

    #[test]
    fn test_tiling_config() {
        // A Conv2d operator must yield a non-trivial tiling configuration.
        let op = GraphOperator {
            id: 0,
            kernel: KernelType::Conv2d,
            inputs: vec!["input".to_string()],
            outputs: vec!["output".to_string()],
            input_shapes: vec![vec![1, 64, 32, 32]],
            output_shape: vec![1, 128, 16, 16],
            cache_hints: None,
        };
        let opt = CacheOptimizer::new(MobilePlatform::Ios);
        let config = opt.compute_tiling_config(&op).expect("Operation failed");
        assert!(!config.tile_sizes.is_empty());
        assert!(!config.loop_order.is_empty());
    }

    #[test]
    fn test_cache_hints_generation() {
        // Conv2d always produces temporal hints and a positive prefetch
        // distance.
        let shapes = vec![vec![1, 3, 224, 224]];
        let opt = CacheOptimizer::new(MobilePlatform::Generic);
        let hints = opt
            .generate_hints(&KernelType::Conv2d, &shapes)
            .expect("Operation failed");
        assert!(hints.prefetch_distance > 0);
        assert!(!hints.temporal_hints.is_empty());
    }
}