use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::{EinsumGraph, IrError};
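/// Memory-layout strategies a tensor can be stored with.
///
/// `RowMajor` and `ColumnMajor` are the classic dense orders; `Blocked` and
/// `Tiled` trade a single contiguous run for cache-friendly sub-blocks;
/// `ZOrder` and `Hilbert` follow space-filling curves that preserve 2D
/// locality.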
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum LayoutStrategy {
#[default]
RowMajor,
ColumnMajor,
Blocked { block_size: usize },
Tiled {
tile_height: usize,
tile_width: usize,
},
ZOrder,
Hilbert,
}
impl LayoutStrategy {
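    /// Picks a default strategy for a named operation kind; unknown names
    /// fall back to the default (`RowMajor`).
    ///
    /// A minimal usage sketch (marked `ignore`, so rustdoc does not compile
    /// it; it assumes `LayoutStrategy` is in scope):
    ///
    /// ```ignore
    /// let s = LayoutStrategy::for_operation("matmul");
    /// assert!(matches!(s, LayoutStrategy::Blocked { block_size: 32 }));
    /// ```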
pub fn for_operation(op: &str) -> Self {
match op {
"matmul" | "einsum" => Self::Blocked { block_size: 32 },
"transpose" => Self::ColumnMajor,
"conv2d" => Self::Tiled {
tile_height: 8,
tile_width: 8,
},
"scan" | "reduce" => Self::RowMajor,
_ => Self::default(),
}
}
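    /// True for layouts whose innermost elements form contiguous runs that
    /// SIMD loads can exploit.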
pub fn supports_vectorization(&self) -> bool {
matches!(
self,
Self::RowMajor | Self::Blocked { .. } | Self::Tiled { .. }
)
}
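    /// True for layouts that keep spatially nearby elements close together
    /// in memory.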
pub fn preserves_locality(&self) -> bool {
matches!(
self,
Self::Blocked { .. } | Self::Tiled { .. } | Self::ZOrder | Self::Hilbert
)
}
}
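/// Per-dimension element strides for one tensor, innermost dimension last.
///
/// An `alignment` of 0 is used here as "no explicit alignment requirement";
/// `with_alignment` overrides it.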
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct StridePattern {
pub strides: Vec<usize>,
pub is_contiguous: bool,
pub alignment: usize,
}
impl StridePattern {
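    /// Builds row-major (C-order) strides for the given dimension extents.
    ///
    /// A minimal sketch (`ignore`d by rustdoc, assuming `StridePattern` is
    /// in scope):
    ///
    /// ```ignore
    /// // Each stride is the product of the inner extents: [8 * 16, 16, 1].
    /// let p = StridePattern::row_major(&[4, 8, 16]);
    /// assert_eq!(p.strides, vec![128, 16, 1]);
    /// ```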
    pub fn row_major(dims: &[usize]) -> Self {
        // The innermost dimension is unit-stride; each outer stride is the
        // product of all inner extents. Building in place also avoids the
        // `len() - 1` underflow the old loop hit on empty `dims` (a scalar).
        let mut strides = vec![0usize; dims.len()];
        let mut acc = 1;
        for i in (0..dims.len()).rev() {
            strides[i] = acc;
            acc *= dims[i];
        }
        Self {
            strides,
            is_contiguous: true,
            alignment: 0,
        }
    }
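    /// Builds column-major (Fortran-order) strides for the given extents.
    ///
    /// A minimal sketch (`ignore`d by rustdoc):
    ///
    /// ```ignore
    /// // Each stride is the product of the earlier extents: [1, 4, 4 * 8].
    /// let p = StridePattern::column_major(&[4, 8, 16]);
    /// assert_eq!(p.strides, vec![1, 4, 32]);
    /// ```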
    pub fn column_major(dims: &[usize]) -> Self {
        // The first dimension is unit-stride; each later stride is the
        // product of all earlier extents. Iterating over `dims` directly
        // handles empty input without underflowing `len() - 1`.
        let mut strides = Vec::with_capacity(dims.len());
        let mut acc = 1;
        for &dim in dims {
            strides.push(acc);
            acc *= dim;
        }
        Self {
            strides,
            is_contiguous: true,
            alignment: 0,
        }
    }
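    /// Wraps caller-provided strides, detecting contiguity heuristically.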
pub fn custom(strides: Vec<usize>) -> Self {
let is_contiguous = is_contiguous_strides(&strides);
Self {
strides,
is_contiguous,
alignment: 0,
}
}
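    /// Sets the alignment requirement and returns the updated pattern.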
pub fn with_alignment(mut self, alignment: usize) -> Self {
self.alignment = alignment;
self
}
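    /// Vectorizable only when contiguous *and* the innermost stride is 1, so
    /// a multi-dimensional column-major pattern reports false despite being
    /// contiguous.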
pub fn is_vectorizable(&self) -> bool {
self.is_contiguous && self.strides.last().copied().unwrap_or(0) == 1
}
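    /// Relative access-cost model: 1.0 for contiguous data, otherwise a base
    /// penalty plus 0.1 per strided dimension.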
pub fn access_cost(&self) -> f64 {
if self.is_contiguous {
1.0
} else {
1.5 + (self.strides.len() as f64 * 0.1)
}
}
}
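/// The strategy chosen for a single tensor together with its concrete strides.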
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TensorLayout {
pub tensor_idx: usize,
pub strategy: LayoutStrategy,
pub strides: StridePattern,
pub is_mutable: bool,
}
impl TensorLayout {
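    /// Builds a layout for `tensor_idx`. Only row- and column-major have
    /// dedicated stride generation; every other strategy currently falls
    /// back to row-major strides.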
pub fn new(tensor_idx: usize, strategy: LayoutStrategy, dims: &[usize]) -> Self {
let strides = match strategy {
LayoutStrategy::RowMajor => StridePattern::row_major(dims),
LayoutStrategy::ColumnMajor => StridePattern::column_major(dims),
            // Blocked, tiled, and curve orders have no dedicated stride
            // generation yet, so they fall back to row-major strides.
            _ => StridePattern::row_major(dims),
        };
Self {
tensor_idx,
strategy,
strides,
is_mutable: true,
}
}
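    /// Scores expected access efficiency in (0.0, 1.0]: 0.9 base when
    /// contiguous, 0.5 otherwise, plus a 0.1 bonus for locality-preserving
    /// strategies.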
pub fn access_efficiency(&self) -> f64 {
let base_efficiency = if self.strides.is_contiguous { 0.9 } else { 0.5 };
let locality_bonus: f64 = if self.strategy.preserves_locality() {
0.1
} else {
0.0
};
(base_efficiency + locality_bonus).min(1.0f64)
}
}
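/// The outcome of a layout-optimization pass over a whole graph.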
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LayoutOptimizationResult {
pub layouts: HashMap<usize, TensorLayout>,
pub transformations_needed: usize,
pub estimated_improvement: f64,
pub estimated_speedup: f64,
}
impl LayoutOptimizationResult {
pub fn none() -> Self {
Self {
layouts: HashMap::new(),
transformations_needed: 0,
estimated_improvement: 0.0,
estimated_speedup: 1.0,
}
}
pub fn get_layout(&self, tensor_idx: usize) -> Option<&TensorLayout> {
self.layouts.get(&tensor_idx)
}
}
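/// Chooses a layout for every tensor in `graph` from the ops that consume it,
/// then estimates the aggregate improvement against a 0.7 efficiency baseline.
///
/// A minimal usage sketch (`ignore`d by rustdoc; assumes an `EinsumGraph`
/// built as in the tests below):
///
/// ```ignore
/// let result = optimize_layouts(&graph)?;
/// if let Some(layout) = result.get_layout(0) {
///     println!("tensor 0 uses {:?}", layout.strategy);
/// }
/// ```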
pub fn optimize_layouts(graph: &EinsumGraph) -> Result<LayoutOptimizationResult, IrError> {
let mut result = LayoutOptimizationResult::none();
for (tensor_idx, tensor_name) in graph.tensors.iter().enumerate() {
let dims = infer_dimensions(tensor_name, graph, tensor_idx);
let strategy = analyze_usage_pattern(graph, tensor_idx);
let layout = TensorLayout::new(tensor_idx, strategy, &dims);
result.layouts.insert(tensor_idx, layout);
}
result.transformations_needed = count_layout_conversions(&result.layouts);
let avg_efficiency: f64 = result
.layouts
.values()
.map(|l| l.access_efficiency())
.sum::<f64>()
/ result.layouts.len().max(1) as f64;
result.estimated_improvement = (avg_efficiency - 0.7).max(0.0);
result.estimated_speedup = 1.0 + result.estimated_improvement * 0.3;
Ok(result)
}
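/// Records each chosen layout on the graph as tensor metadata (`layout` and
/// `is_contiguous` attributes). Indices outside the graph are skipped.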
pub fn apply_layouts(
graph: &mut EinsumGraph,
layouts: &HashMap<usize, TensorLayout>,
) -> Result<(), IrError> {
for (tensor_idx, layout) in layouts {
if *tensor_idx < graph.tensors.len() {
let mut metadata = graph
.get_tensor_metadata(*tensor_idx)
.cloned()
.unwrap_or_else(crate::Metadata::new);
metadata
.attributes
.push(("layout".to_string(), format!("{:?}", layout.strategy)));
metadata.attributes.push((
"is_contiguous".to_string(),
layout.strides.is_contiguous.to_string(),
));
graph.add_tensor_metadata(*tensor_idx, metadata);
}
}
Ok(())
}
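/// Finds pairs of mutable tensors whose strategies differ; unifying such a
/// pair would eliminate a layout conversion. The scan is O(n^2) over tensors
/// and the output order follows `HashMap` iteration, so it is unordered.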
pub fn find_layout_fusion_opportunities(
layouts: &HashMap<usize, TensorLayout>,
) -> Vec<(usize, usize)> {
let mut opportunities = Vec::new();
let tensor_indices: Vec<_> = layouts.keys().copied().collect();
for i in 0..tensor_indices.len() {
for j in (i + 1)..tensor_indices.len() {
let idx1 = tensor_indices[i];
let idx2 = tensor_indices[j];
if let (Some(layout1), Some(layout2)) = (layouts.get(&idx1), layouts.get(&idx2)) {
if layout1.strategy != layout2.strategy && layout1.is_mutable && layout2.is_mutable
{
opportunities.push((idx1, idx2));
}
}
}
}
opportunities
}
fn infer_dimensions(_tensor_name: &str, _graph: &EinsumGraph, _tensor_idx: usize) -> Vec<usize> {
    // Placeholder: real dimension inference from tensor metadata is not yet
    // implemented, so every tensor is assumed to be 64x64.
    vec![64, 64]
}
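/// Chooses a strategy for `tensor_idx` by classifying every op that reads it;
/// contraction patterns take priority over permutations, then tiled reads,
/// then the row-major default.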
fn analyze_usage_pattern(graph: &EinsumGraph, tensor_idx: usize) -> LayoutStrategy {
let mut read_patterns = Vec::new();
for node in &graph.nodes {
if node.inputs.contains(&tensor_idx) {
let pattern = match &node.op {
crate::OpType::Einsum { spec } => analyze_einsum_pattern(spec),
crate::OpType::Reduce { .. } => "reduce",
crate::OpType::ElemUnary { .. } => "scan",
crate::OpType::ElemBinary { .. } => "scan",
};
read_patterns.push(pattern);
}
}
if read_patterns.contains(&"matmul") {
LayoutStrategy::Blocked { block_size: 32 }
} else if read_patterns.contains(&"transpose") {
LayoutStrategy::ColumnMajor
} else if read_patterns.contains(&"conv") {
LayoutStrategy::Tiled {
tile_height: 8,
tile_width: 8,
}
} else {
LayoutStrategy::RowMajor
}
}
/// Classifies a single-tensor einsum spec into a coarse access pattern.
fn analyze_einsum_pattern(spec: &str) -> &'static str {
    if spec.contains(',') {
        // Multiple inputs imply a contraction (matmul-like).
        "matmul"
    } else if let Some((lhs, rhs)) = spec.split_once("->") {
        if lhs.len() > rhs.len() {
            // The output drops indices, so this is a reduction.
            "reduce"
        } else if lhs != rhs && is_index_permutation(lhs, rhs) {
            // Same indices in a different order: a transpose-like
            // permutation, which `analyze_usage_pattern` maps to
            // column-major layout.
            "transpose"
        } else {
            "scan"
        }
    } else {
        "scan"
    }
}

/// True when `lhs` and `rhs` contain exactly the same index letters.
fn is_index_permutation(lhs: &str, rhs: &str) -> bool {
    let mut l: Vec<char> = lhs.chars().collect();
    let mut r: Vec<char> = rhs.chars().collect();
    l.sort_unstable();
    r.sort_unstable();
    l == r
}
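/// Counts tensors whose chosen layout differs from the row-major default and
/// would therefore need a conversion.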
fn count_layout_conversions(layouts: &HashMap<usize, TensorLayout>) -> usize {
layouts
.values()
.filter(|l| l.strategy != LayoutStrategy::RowMajor)
.count()
}
fn is_contiguous_strides(strides: &[usize]) -> bool {
    // Heuristic check for row-major (innermost-last) contiguity: the last
    // stride must be 1 and each outer stride must be a strictly larger
    // multiple of its inner neighbour. Without the dimension extents this
    // cannot be exact, so implausibly large ratios are rejected too.
    if strides.is_empty() {
        return true;
    }
    let mut prev = strides[strides.len() - 1];
    if prev != 1 {
        return false;
    }
    for &stride in strides.iter().rev().skip(1) {
        if stride <= prev || stride % prev != 0 {
            return false;
        }
        // A ratio this large almost certainly indicates padding or a gather
        // pattern rather than a dense dimension. (The old `ratio == 0` check
        // was unreachable, since `stride > prev >= 1` guarantees ratio >= 1.)
        if stride / prev > 10000 {
            return false;
        }
        prev = stride;
    }
    true
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_layout_strategy_default() {
assert_eq!(LayoutStrategy::default(), LayoutStrategy::RowMajor);
}
#[test]
fn test_layout_strategy_for_operation() {
let matmul_layout = LayoutStrategy::for_operation("matmul");
assert!(matches!(matmul_layout, LayoutStrategy::Blocked { .. }));
let transpose_layout = LayoutStrategy::for_operation("transpose");
assert_eq!(transpose_layout, LayoutStrategy::ColumnMajor);
let conv_layout = LayoutStrategy::for_operation("conv2d");
assert!(matches!(conv_layout, LayoutStrategy::Tiled { .. }));
}
#[test]
fn test_layout_strategy_vectorization() {
assert!(LayoutStrategy::RowMajor.supports_vectorization());
assert!(LayoutStrategy::Blocked { block_size: 32 }.supports_vectorization());
assert!(!LayoutStrategy::ZOrder.supports_vectorization());
}
#[test]
fn test_layout_strategy_locality() {
assert!(LayoutStrategy::Blocked { block_size: 32 }.preserves_locality());
assert!(LayoutStrategy::ZOrder.preserves_locality());
assert!(LayoutStrategy::Hilbert.preserves_locality());
assert!(!LayoutStrategy::RowMajor.preserves_locality());
}
#[test]
fn test_stride_pattern_row_major() {
let dims = vec![4, 8, 16];
let pattern = StridePattern::row_major(&dims);
assert_eq!(pattern.strides, vec![128, 16, 1]);
assert!(pattern.is_contiguous);
assert!(pattern.is_vectorizable());
}
#[test]
fn test_stride_pattern_column_major() {
let dims = vec![4, 8, 16];
let pattern = StridePattern::column_major(&dims);
assert_eq!(pattern.strides, vec![1, 4, 32]);
assert!(pattern.is_contiguous);
}
#[test]
fn test_stride_pattern_custom() {
let strides = vec![64, 8, 1];
let pattern = StridePattern::custom(strides.clone());
assert_eq!(pattern.strides, strides);
assert!(pattern.is_contiguous);
}
#[test]
fn test_stride_pattern_non_contiguous() {
        let strides = vec![100, 10, 2];
        let pattern = StridePattern::custom(strides);
assert!(!pattern.is_contiguous);
assert!(!pattern.is_vectorizable());
}
#[test]
fn test_stride_pattern_with_alignment() {
let pattern = StridePattern::row_major(&[4, 8]).with_alignment(64);
assert_eq!(pattern.alignment, 64);
}
#[test]
fn test_stride_pattern_access_cost() {
let contiguous = StridePattern::row_major(&[4, 8]);
let non_contiguous = StridePattern::custom(vec![100, 10, 2]);
assert!(contiguous.access_cost() < non_contiguous.access_cost());
}
#[test]
fn test_tensor_layout_creation() {
let layout = TensorLayout::new(0, LayoutStrategy::RowMajor, &[4, 8]);
assert_eq!(layout.tensor_idx, 0);
assert_eq!(layout.strategy, LayoutStrategy::RowMajor);
assert!(layout.is_mutable);
assert!(layout.strides.is_contiguous);
}
#[test]
fn test_tensor_layout_access_efficiency() {
let row_major = TensorLayout::new(0, LayoutStrategy::RowMajor, &[4, 8]);
let blocked = TensorLayout::new(0, LayoutStrategy::Blocked { block_size: 32 }, &[4, 8]);
let row_efficiency = row_major.access_efficiency();
let blocked_efficiency = blocked.access_efficiency();
assert!(row_efficiency > 0.0 && row_efficiency <= 1.0);
        assert!(blocked_efficiency > row_efficiency);
    }
#[test]
fn test_layout_optimization_result_none() {
let result = LayoutOptimizationResult::none();
assert!(result.layouts.is_empty());
assert_eq!(result.transformations_needed, 0);
assert_eq!(result.estimated_improvement, 0.0);
assert_eq!(result.estimated_speedup, 1.0);
}
#[test]
fn test_optimize_layouts_empty_graph() {
let graph = EinsumGraph::new();
        let result = optimize_layouts(&graph).expect("optimize_layouts should succeed on an empty graph");
assert!(result.layouts.is_empty());
}
#[test]
fn test_optimize_layouts_simple_graph() {
let mut graph = EinsumGraph::new();
let a = graph.add_tensor("A");
let b = graph.add_tensor("B");
let c = graph.add_tensor("C");
graph
.add_node(crate::EinsumNode::einsum("ik,kj->ij", vec![a, b], vec![c]))
.expect("unwrap");
        let result = optimize_layouts(&graph).expect("optimize_layouts should succeed");
assert_eq!(result.layouts.len(), 3);
assert!(result.estimated_speedup >= 1.0);
}
#[test]
fn test_apply_layouts() {
let mut graph = EinsumGraph::new();
let a = graph.add_tensor("A");
let mut layouts = HashMap::new();
layouts.insert(
a,
TensorLayout::new(a, LayoutStrategy::Blocked { block_size: 32 }, &[64, 64]),
);
        apply_layouts(&mut graph, &layouts).expect("apply_layouts should succeed for a valid tensor index");
let metadata = graph.get_tensor_metadata(a);
assert!(metadata.is_some());
}
#[test]
fn test_find_layout_fusion_opportunities() {
let mut layouts = HashMap::new();
layouts.insert(0, TensorLayout::new(0, LayoutStrategy::RowMajor, &[4, 8]));
layouts.insert(
1,
TensorLayout::new(1, LayoutStrategy::ColumnMajor, &[4, 8]),
);
layouts.insert(2, TensorLayout::new(2, LayoutStrategy::RowMajor, &[4, 8]));
let opportunities = find_layout_fusion_opportunities(&layouts);
assert!(!opportunities.is_empty());
}
#[test]
fn test_analyze_einsum_pattern() {
assert_eq!(analyze_einsum_pattern("ik,kj->ij"), "matmul");
assert_eq!(analyze_einsum_pattern("ijk->ij"), "reduce");
assert_eq!(analyze_einsum_pattern("ij->ij"), "scan");
}
#[test]
fn test_is_contiguous_strides() {
assert!(is_contiguous_strides(&[8, 4, 1]));
assert!(is_contiguous_strides(&[1]));
assert!(is_contiguous_strides(&[]));
        // Divisible, strictly increasing strides count as contiguous even
        // without knowing the dimension extents.
        assert!(is_contiguous_strides(&[8, 2, 1]));
        assert!(!is_contiguous_strides(&[8, 4, 2]));
        assert!(!is_contiguous_strides(&[9, 2, 1]));
    }
#[test]
fn test_count_layout_conversions() {
let mut layouts = HashMap::new();
layouts.insert(0, TensorLayout::new(0, LayoutStrategy::RowMajor, &[4, 8]));
layouts.insert(
1,
TensorLayout::new(1, LayoutStrategy::ColumnMajor, &[4, 8]),
);
layouts.insert(
2,
TensorLayout::new(2, LayoutStrategy::Blocked { block_size: 32 }, &[4, 8]),
);
let conversions = count_layout_conversions(&layouts);
        assert_eq!(conversions, 2);
    }
#[test]
fn test_layout_optimization_with_metadata() {
let mut graph = EinsumGraph::new();
let a = graph.add_tensor("A");
let b = graph.add_tensor("B");
let metadata = crate::Metadata::new().with_attribute("preferred_layout", "blocked");
graph.add_tensor_metadata(a, metadata);
        let result = optimize_layouts(&graph).expect("optimize_layouts should succeed");
assert!(result.get_layout(a).is_some());
assert!(result.get_layout(b).is_some());
}
}