use crate::backend::hardware::{HardwareTarget, MemorySpace};
use crate::object::{Layout, ObjectMeta, Representation, Shape};
use crate::{Error, Result};
/// Storage layout of a [`DeviceBuffer`].
///
/// Built from object metadata by [`BufferLayout::from_meta`]. Transfer planning
/// compares the source and destination layouts for equality.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BufferLayout {
/// A single scalar value (from `Layout::Scalar`).
Scalar,
/// Densely packed elements.
Contiguous {
/// Static element count of the shape; `from_meta` records 0 when the shape has
/// non-static dimensions or the count overflows `usize`.
element_count: usize,
},
/// The object's shape together with the strides taken from `Layout::Strided`.
Strided {
shape: Shape,
strides: Vec<isize>,
},
/// The object's shape together with the block shape taken from `Layout::Blocked`.
Blocked {
shape: Shape,
block_shape: Vec<usize>,
},
/// Layout for cover-indexed data (from `Layout::CoverIndexed` or a
/// `CoverIndexedSection` representation).
CoverIndexed {
/// `from_meta` fills this from the shape's static element count.
/// NOTE(review): presumably the number of opens in the cover; confirm.
open_count: usize,
},
/// Array of p-adic values stored as limbs.
PadicLimbArray {
/// Number of elements (1 for a `PadicScalar` representation).
element_count: usize,
/// Precision parsed from a trailing `[N]` in the domain name; 0 when absent or malformed.
precision_digits: u32,
},
/// p-adic ball layout (from `Layout::PadicBall`).
PadicBall {
element_count: usize,
},
/// Used for `Layout::Symbolic` and for lazy-expression or symbolic representations.
Symbolic,
}
impl BufferLayout {
    /// Derives the buffer layout implied by an object's metadata.
    ///
    /// Dense and sparse tensors map their declared [`Layout`] one-to-one onto a
    /// `BufferLayout`. Cover-indexed sections become [`BufferLayout::CoverIndexed`].
    /// p-adic scalars become a one-element [`BufferLayout::PadicLimbArray`]. Lazy and
    /// symbolic representations become [`BufferLayout::Symbolic`].
    ///
    /// Element counts come from [`static_element_count`]. Shapes with non-static
    /// dimensions, or whose element count overflows `usize`, are recorded as `0`.
    /// The p-adic precision is parsed from a trailing `[N]` in the domain name and
    /// defaults to `0` when it is absent or malformed.
    pub fn from_meta(meta: &ObjectMeta) -> Self {
        let element_count = static_element_count(&meta.shape).unwrap_or(0);
        // Evaluated lazily: only the p-adic layouts need the parsed precision.
        let precision_digits = || parse_padic_precision(&meta.domain.0).unwrap_or(0);
        match &meta.representation {
            // Dense and sparse tensors share one layout mapping; keeping a single
            // copy prevents the two from drifting apart.
            Representation::DenseTensor { layout, .. }
            | Representation::SparseTensor { layout, .. } => match layout {
                Layout::Scalar => Self::Scalar,
                Layout::Contiguous => Self::Contiguous { element_count },
                Layout::Strided(strides) => Self::Strided {
                    shape: meta.shape.clone(),
                    strides: strides.clone(),
                },
                Layout::Blocked { block_shape } => Self::Blocked {
                    shape: meta.shape.clone(),
                    block_shape: block_shape.clone(),
                },
                Layout::CoverIndexed => Self::CoverIndexed {
                    open_count: element_count,
                },
                Layout::PadicLimbArray => Self::PadicLimbArray {
                    element_count,
                    precision_digits: precision_digits(),
                },
                Layout::PadicBall => Self::PadicBall { element_count },
                Layout::Symbolic => Self::Symbolic,
            },
            Representation::CoverIndexedSection { .. } => Self::CoverIndexed {
                open_count: element_count,
            },
            Representation::PadicScalar => Self::PadicLimbArray {
                element_count: 1,
                precision_digits: precision_digits(),
            },
            Representation::LazyExpression | Representation::Symbolic => Self::Symbolic,
        }
    }
}
/// Description of a buffer: where it lives, how it is laid out, and which object it holds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceBuffer {
/// Caller-supplied identifier.
pub id: String,
/// Hardware target the buffer belongs to; compared when planning transfers.
pub target: HardwareTarget,
/// Memory space holding the buffer (e.g. host, pinned host, unified, device).
pub memory_space: MemorySpace,
/// Storage layout of the buffer contents.
pub layout: BufferLayout,
/// Metadata of the object stored in the buffer.
pub meta: ObjectMeta,
/// Length in bytes. NOTE(review): `None` presumably means the size is unknown; confirm with producers.
pub byte_len: Option<usize>,
/// Optional oracle fixture name, attached via `with_oracle_fixture`.
pub oracle_fixture: Option<String>,
}
impl DeviceBuffer {
    /// Creates a host-resident buffer description for the tensor described by `meta`.
    ///
    /// The layout is derived with [`BufferLayout::from_meta`]. `byte_len` is
    /// `element_count * element_size` when every dimension of the shape is static
    /// and the product fits in `usize`. Otherwise it is `None`, because the size
    /// cannot be known up front.
    pub fn host_tensor(
        id: impl Into<String>,
        target: HardwareTarget,
        meta: ObjectMeta,
        element_size: usize,
    ) -> Self {
        // Report "unknown" rather than a fabricated length. Dynamic dimensions used
        // to be counted as a single element, and overflow used to saturate to
        // `usize::MAX`.
        let byte_len = static_element_count(&meta.shape)
            .and_then(|count| count.checked_mul(element_size));
        Self {
            id: id.into(),
            target,
            memory_space: MemorySpace::Host,
            layout: BufferLayout::from_meta(&meta),
            meta,
            byte_len,
            oracle_fixture: None,
        }
    }

    /// Returns the buffer with its oracle fixture name set to `fixture`.
    pub fn with_oracle_fixture(mut self, fixture: impl Into<String>) -> Self {
        self.oracle_fixture = Some(fixture.into());
        self
    }

    /// Returns `true` when the buffer's memory space is `Host`, `PinnedHost` or `Unified`.
    pub fn is_host_accessible(&self) -> bool {
        matches!(
            self.memory_space,
            MemorySpace::Host | MemorySpace::PinnedHost | MemorySpace::Unified
        )
    }
}
/// Coarse direction of a transfer, as classified by `transfer_direction`.
///
/// Only `Host`/`PinnedHost` -> `Device`, `Device` -> `Host`/`PinnedHost` and
/// `Device` -> `Device` get a dedicated variant. Every other pair of memory spaces
/// (including pairs involving `Unified`) is reported as `HostToHost`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransferDirection {
HostToDevice,
DeviceToHost,
DeviceToDevice,
HostToHost,
}
/// Outcome of planning a transfer with `TransferPlan::plan`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransferStatus {
/// A transfer is needed and no fallback reason applies.
Supported,
/// Source and destination share memory space, hardware target and layout.
NoOp,
/// The transfer is rejected; the payload explains why.
Unsupported(TransferFallbackReason),
}
/// Why a transfer plan was rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransferFallbackReason {
/// Reported for `Host` <-> `Device` transfers, which lack a device-memory implementation.
MissingDeviceMemory {
source: MemorySpace,
destination: MemorySpace,
},
/// Reported for `PinnedHost` <-> `Device` transfers, which lack a pinned-memory path.
MissingPinnedMemory {
source: MemorySpace,
destination: MemorySpace,
},
/// Source and destination layouts differ; carries the destination's layout.
UnsupportedLayout {
layout: BufferLayout,
},
}
/// A planned move of data from `source` into `destination`, produced by `TransferPlan::plan`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TransferPlan {
/// Buffer the data comes from.
pub source: DeviceBuffer,
/// Buffer the data goes to.
pub destination: DeviceBuffer,
/// Direction derived from the two buffers' memory spaces.
pub direction: TransferDirection,
/// Whether the transfer is a no-op, supported, or unsupported.
pub status: TransferStatus,
/// Set by `plan` to `true` exactly when no fallback reason was found.
pub preserves_semantics: bool,
/// Fallback reason found by `plan`, if any.
pub reason: Option<TransferFallbackReason>,
}
impl TransferPlan {
pub fn plan(source: DeviceBuffer, destination: DeviceBuffer) -> Self {
let direction = transfer_direction(source.memory_space, destination.memory_space);
let reason = transfer_fallback_reason(&source, &destination);
let status = if source.memory_space == destination.memory_space
&& source.target == destination.target
&& source.layout == destination.layout
{
TransferStatus::NoOp
} else if let Some(reason) = reason.clone() {
TransferStatus::Unsupported(reason)
} else {
TransferStatus::Supported
};
Self {
source,
destination,
direction,
status,
preserves_semantics: reason.is_none(),
reason,
}
}
pub fn require_supported(&self) -> Result<()> {
match &self.status {
TransferStatus::Supported | TransferStatus::NoOp => Ok(()),
TransferStatus::Unsupported(reason) => Err(Error::backend(format!(
"unsupported transfer plan: {}",
transfer_reason_message(reason)
))),
}
}
}
/// Returns the total number of elements in `shape` when every dimension is static.
///
/// Returns `None` if any dimension is not `Dim::Static`, or if the product overflows
/// `usize`. A shape with no dimensions yields `Some(1)`.
pub fn static_element_count(shape: &Shape) -> Option<usize> {
    let mut product: usize = 1;
    for dim in shape.dims.iter() {
        match dim {
            crate::object::Dim::Static(extent) => product = product.checked_mul(*extent)?,
            _ => return None,
        }
    }
    Some(product)
}
/// Classifies a transfer between two memory spaces.
///
/// `Host`/`PinnedHost` -> `Device` is host-to-device, `Device` -> `Host`/`PinnedHost`
/// is device-to-host, and `Device` -> `Device` is device-to-device. Every other pair
/// falls back to `HostToHost`.
fn transfer_direction(source: MemorySpace, destination: MemorySpace) -> TransferDirection {
    let on_device = |space: MemorySpace| matches!(space, MemorySpace::Device);
    let on_host = |space: MemorySpace| matches!(space, MemorySpace::Host | MemorySpace::PinnedHost);

    match (on_device(source), on_device(destination)) {
        (true, true) => TransferDirection::DeviceToDevice,
        (false, true) if on_host(source) => TransferDirection::HostToDevice,
        (true, false) if on_host(destination) => TransferDirection::DeviceToHost,
        _ => TransferDirection::HostToHost,
    }
}
/// Returns why a transfer from `source` to `destination` must fall back, or `None`
/// when nothing blocks it.
///
/// A layout mismatch takes precedence and reports the destination layout. Otherwise
/// `Host` <-> `Device` yields `MissingDeviceMemory`, `PinnedHost` <-> `Device` yields
/// `MissingPinnedMemory`, and every other pair of memory spaces is accepted.
fn transfer_fallback_reason(
    source: &DeviceBuffer,
    destination: &DeviceBuffer,
) -> Option<TransferFallbackReason> {
    if source.layout != destination.layout {
        let layout = destination.layout.clone();
        return Some(TransferFallbackReason::UnsupportedLayout { layout });
    }

    let (from, to) = (source.memory_space, destination.memory_space);
    let crosses_host_device = matches!(
        (from, to),
        (MemorySpace::Host, MemorySpace::Device) | (MemorySpace::Device, MemorySpace::Host)
    );
    let crosses_pinned_device = matches!(
        (from, to),
        (MemorySpace::PinnedHost, MemorySpace::Device)
            | (MemorySpace::Device, MemorySpace::PinnedHost)
    );

    if crosses_host_device {
        Some(TransferFallbackReason::MissingDeviceMemory {
            source: from,
            destination: to,
        })
    } else if crosses_pinned_device {
        Some(TransferFallbackReason::MissingPinnedMemory {
            source: from,
            destination: to,
        })
    } else {
        None
    }
}
/// Renders a human-readable explanation of a transfer fallback reason.
fn transfer_reason_message(reason: &TransferFallbackReason) -> String {
    let message = match reason {
        TransferFallbackReason::UnsupportedLayout { layout } => {
            format!("unsupported destination buffer layout {layout:?}")
        }
        TransferFallbackReason::MissingPinnedMemory {
            source,
            destination,
        } => format!("missing pinned-memory transfer path for {source:?} -> {destination:?}"),
        TransferFallbackReason::MissingDeviceMemory {
            source,
            destination,
        } => format!("missing device-memory implementation for {source:?} -> {destination:?}"),
    };
    message
}
/// Extracts the precision from a domain name ending in `[N]`, e.g. `"Qp[20]"` -> `Some(20)`.
///
/// The text after the last `[` must end with `]` and parse as a `u32`. Otherwise the
/// result is `None`.
fn parse_padic_precision(domain: &str) -> Option<u32> {
    let (_, tail) = domain.rsplit_once('[')?;
    let digits = tail.strip_suffix(']')?;
    digits.parse::<u32>().ok()
}