tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! Memory model and layout helpers.
//!
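//! `MemorySpace` (e.g. `Host`, `PinnedHost`, `Device`, `Unified`) and
//! `Layout` are the shared types that the backends consume to express
//! where a tensor lives and how its elements are arranged. This module
//! builds on them with `BufferLayout`, `DeviceBuffer`, and `TransferPlan`,
//! which describe concrete buffers and plan copies between memory spaces.
//! `MemorySpace` is part of the `ObjectMeta` fingerprint chain; changing
//! it invalidates the cap-fingerprint of any plan that depends on the tensor.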
//!
use crate::backend::hardware::{HardwareTarget, MemorySpace};
use crate::object::{Layout, ObjectMeta, Representation, Shape};
use crate::{Error, Result};

/// Concrete memory layout of a buffer.
///
/// Can be derived from an object's metadata with [`BufferLayout::from_meta`],
/// which maps each [`Layout`] (and a few non-tensor representations) onto one
/// of these variants.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BufferLayout {
    /// Produced for `Layout::Scalar`.
    Scalar,
    /// Produced for `Layout::Contiguous`.
    Contiguous {
        /// Static element count of the shape; `from_meta` stores `0` when the
        /// shape is not fully static.
        element_count: usize,
    },
    /// Produced for `Layout::Strided`.
    Strided {
        /// The object's shape (`meta.shape`).
        shape: Shape,
        /// Per-dimension strides copied from `Layout::Strided`. They are signed,
        /// so negative strides are representable (units not fixed here — confirm).
        strides: Vec<isize>,
    },
    /// Produced for `Layout::Blocked`.
    Blocked {
        /// The object's shape (`meta.shape`).
        shape: Shape,
        /// Block shape copied from `Layout::Blocked`.
        block_shape: Vec<usize>,
    },
    /// Produced for `Layout::CoverIndexed` and for cover-indexed sections.
    CoverIndexed {
        /// `from_meta` fills this with the shape's static element count
        /// (`0` when the shape is not fully static).
        open_count: usize,
    },
    /// Produced for `Layout::PadicLimbArray` and for p-adic scalars.
    PadicLimbArray {
        /// Number of p-adic elements (`1` for a p-adic scalar).
        element_count: usize,
        /// Precision parsed from a trailing `[N]` in the domain name by
        /// `from_meta`; `0` when that suffix is absent or unparsable.
        precision_digits: u32,
    },
    /// Produced for `Layout::PadicBall`.
    PadicBall {
        /// Static element count of the shape (`0` when not fully static).
        element_count: usize,
    },
    /// Produced for `Layout::Symbolic` and for lazy/symbolic representations;
    /// carries no size information.
    Symbolic,
}

impl BufferLayout {
    pub fn from_meta(meta: &ObjectMeta) -> Self {
        let element_count = static_element_count(&meta.shape).unwrap_or(0);
        match &meta.representation {
            Representation::DenseTensor { layout, .. } => match layout {
                Layout::Scalar => Self::Scalar,
                Layout::Contiguous => Self::Contiguous { element_count },
                Layout::Strided(strides) => Self::Strided {
                    shape: meta.shape.clone(),
                    strides: strides.clone(),
                },
                Layout::Blocked { block_shape } => Self::Blocked {
                    shape: meta.shape.clone(),
                    block_shape: block_shape.clone(),
                },
                Layout::CoverIndexed => Self::CoverIndexed {
                    open_count: element_count,
                },
                Layout::PadicLimbArray => Self::PadicLimbArray {
                    element_count,
                    precision_digits: parse_padic_precision(&meta.domain.0).unwrap_or(0),
                },
                Layout::PadicBall => Self::PadicBall { element_count },
                Layout::Symbolic => Self::Symbolic,
            },
            Representation::SparseTensor { layout, .. } => match layout {
                Layout::Contiguous => Self::Contiguous { element_count },
                Layout::Strided(strides) => Self::Strided {
                    shape: meta.shape.clone(),
                    strides: strides.clone(),
                },
                Layout::Blocked { block_shape } => Self::Blocked {
                    shape: meta.shape.clone(),
                    block_shape: block_shape.clone(),
                },
                Layout::CoverIndexed => Self::CoverIndexed {
                    open_count: element_count,
                },
                Layout::PadicLimbArray => Self::PadicLimbArray {
                    element_count,
                    precision_digits: parse_padic_precision(&meta.domain.0).unwrap_or(0),
                },
                Layout::PadicBall => Self::PadicBall { element_count },
                Layout::Scalar => Self::Scalar,
                Layout::Symbolic => Self::Symbolic,
            },
            Representation::CoverIndexedSection { .. } => Self::CoverIndexed {
                open_count: element_count,
            },
            Representation::PadicScalar => Self::PadicLimbArray {
                element_count: 1,
                precision_digits: parse_padic_precision(&meta.domain.0).unwrap_or(0),
            },
            Representation::LazyExpression | Representation::Symbolic => Self::Symbolic,
        }
    }
}

/// Descriptor of a tensor buffer. It records which hardware target and memory
/// space the buffer lives in, how it is laid out, and the metadata of the
/// object it holds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceBuffer {
    /// Caller-supplied identifier for the buffer.
    pub id: String,
    /// Hardware target the buffer belongs to.
    pub target: HardwareTarget,
    /// Memory space the buffer resides in.
    pub memory_space: MemorySpace,
    /// Element layout of the buffer.
    pub layout: BufferLayout,
    /// Metadata of the object stored in the buffer.
    pub meta: ObjectMeta,
    /// Size of the buffer in bytes, if recorded.
    pub byte_len: Option<usize>,
    /// Optional oracle-fixture name, set via `with_oracle_fixture`
    /// (presumably used for reference comparison — confirm with callers).
    pub oracle_fixture: Option<String>,
}

impl DeviceBuffer {
    /// Builds a descriptor for a buffer in `MemorySpace::Host` holding `meta`.
    ///
    /// The layout is derived with [`BufferLayout::from_meta`]. `byte_len` is
    /// `element_count * element_size` when every dimension of `meta.shape` is
    /// static and the product fits in `usize`. Otherwise it is `None`, because
    /// the byte length is not statically known; a fabricated or saturated size
    /// would be misleading.
    pub fn host_tensor(
        id: impl Into<String>,
        target: HardwareTarget,
        meta: ObjectMeta,
        element_size: usize,
    ) -> Self {
        // Dynamic/symbolic dimensions and overflowing sizes both map to `None`.
        let byte_len =
            static_element_count(&meta.shape).and_then(|count| count.checked_mul(element_size));
        Self {
            id: id.into(),
            target,
            memory_space: MemorySpace::Host,
            layout: BufferLayout::from_meta(&meta),
            meta,
            byte_len,
            oracle_fixture: None,
        }
    }

    /// Returns the buffer with its oracle fixture set to `fixture`, replacing
    /// any previous value.
    pub fn with_oracle_fixture(mut self, fixture: impl Into<String>) -> Self {
        self.oracle_fixture = Some(fixture.into());
        self
    }

    /// Returns `true` when the buffer's memory space is `Host`, `PinnedHost`,
    /// or `Unified`.
    pub fn is_host_accessible(&self) -> bool {
        matches!(
            self.memory_space,
            MemorySpace::Host | MemorySpace::PinnedHost | MemorySpace::Unified
        )
    }
}

/// Direction of a copy between two memory spaces, as classified by
/// `transfer_direction`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransferDirection {
    /// `Host` or `PinnedHost` to `Device`.
    HostToDevice,
    /// `Device` to `Host` or `PinnedHost`.
    DeviceToHost,
    /// `Device` to `Device`.
    DeviceToDevice,
    /// Every other memory-space pair, including pairs that involve `Unified`.
    HostToHost,
}

/// Outcome of planning a transfer with `TransferPlan::plan`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransferStatus {
    /// The transfer is needed and no fallback reason applies.
    Supported,
    /// Source and destination already share memory space, hardware target,
    /// and layout, so nothing needs to move.
    NoOp,
    /// The transfer cannot be performed, for the given reason.
    Unsupported(TransferFallbackReason),
}

/// Why a planned transfer is unsupported (see `transfer_fallback_reason`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransferFallbackReason {
    /// A `Host` <-> `Device` copy was requested, but no device-memory
    /// implementation is available.
    MissingDeviceMemory {
        source: MemorySpace,
        destination: MemorySpace,
    },
    /// A `PinnedHost` <-> `Device` copy was requested, but no pinned-memory
    /// transfer path is available.
    MissingPinnedMemory {
        source: MemorySpace,
        destination: MemorySpace,
    },
    /// Source and destination layouts differ; carries the destination layout.
    UnsupportedLayout {
        layout: BufferLayout,
    },
}

/// Result of planning a copy between two buffers (built by `TransferPlan::plan`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TransferPlan {
    /// Buffer being copied from.
    pub source: DeviceBuffer,
    /// Buffer being copied to.
    pub destination: DeviceBuffer,
    /// Direction classified from the two buffers' memory spaces.
    pub direction: TransferDirection,
    /// Whether the transfer is supported, a no-op, or unsupported.
    pub status: TransferStatus,
    /// `true` exactly when `reason` is `None`.
    pub preserves_semantics: bool,
    /// Fallback reason, if any. It equals the payload of
    /// `TransferStatus::Unsupported` whenever the status is unsupported.
    pub reason: Option<TransferFallbackReason>,
}

impl TransferPlan {
    pub fn plan(source: DeviceBuffer, destination: DeviceBuffer) -> Self {
        let direction = transfer_direction(source.memory_space, destination.memory_space);
        let reason = transfer_fallback_reason(&source, &destination);
        let status = if source.memory_space == destination.memory_space
            && source.target == destination.target
            && source.layout == destination.layout
        {
            TransferStatus::NoOp
        } else if let Some(reason) = reason.clone() {
            TransferStatus::Unsupported(reason)
        } else {
            TransferStatus::Supported
        };
        Self {
            source,
            destination,
            direction,
            status,
            preserves_semantics: reason.is_none(),
            reason,
        }
    }

    pub fn require_supported(&self) -> Result<()> {
        match &self.status {
            TransferStatus::Supported | TransferStatus::NoOp => Ok(()),
            TransferStatus::Unsupported(reason) => Err(Error::backend(format!(
                "unsupported transfer plan: {}",
                transfer_reason_message(reason)
            ))),
        }
    }
}

/// Returns the product of all dimensions of `shape`.
///
/// Yields `None` as soon as a dimension is not `Dim::Static` or the running
/// product overflows `usize`. A rank-0 shape (no dimensions) yields `Some(1)`.
pub fn static_element_count(shape: &Shape) -> Option<usize> {
    let mut total: usize = 1;
    for dim in shape.dims.iter() {
        let extent = match dim {
            crate::object::Dim::Static(extent) => *extent,
            _ => return None,
        };
        total = total.checked_mul(extent)?;
    }
    Some(total)
}

/// Classifies a copy between two memory spaces.
///
/// `Host` and `PinnedHost` count as host-side; only `Device` counts as
/// device-side. Every pair not covered by the first three cases falls back to
/// `HostToHost`.
fn transfer_direction(source: MemorySpace, destination: MemorySpace) -> TransferDirection {
    let is_host = |space: MemorySpace| matches!(space, MemorySpace::Host | MemorySpace::PinnedHost);
    let is_device = |space: MemorySpace| matches!(space, MemorySpace::Device);

    if is_host(source) && is_device(destination) {
        TransferDirection::HostToDevice
    } else if is_device(source) && is_host(destination) {
        TransferDirection::DeviceToHost
    } else if is_device(source) && is_device(destination) {
        TransferDirection::DeviceToDevice
    } else {
        // NOTE(review): this also covers pairs such as `Device` <-> `Unified`,
        // which are not host-to-host copies; confirm callers do not rely on
        // the direction for those pairs.
        TransferDirection::HostToHost
    }
}

/// Determines why a copy from `source` to `destination` cannot be performed.
///
/// A layout mismatch is reported first, carrying the destination layout.
/// Otherwise, `Host` <-> `Device` pairs report missing device memory and
/// `PinnedHost` <-> `Device` pairs report a missing pinned-memory path. All
/// other pairs yield `None`.
fn transfer_fallback_reason(
    source: &DeviceBuffer,
    destination: &DeviceBuffer,
) -> Option<TransferFallbackReason> {
    if source.layout != destination.layout {
        let layout = destination.layout.clone();
        return Some(TransferFallbackReason::UnsupportedLayout { layout });
    }

    let endpoints = (source.memory_space, destination.memory_space);
    let (from, to) = endpoints;
    let host_device = matches!(
        endpoints,
        (MemorySpace::Host, MemorySpace::Device) | (MemorySpace::Device, MemorySpace::Host)
    );
    let pinned_device = matches!(
        endpoints,
        (MemorySpace::PinnedHost, MemorySpace::Device)
            | (MemorySpace::Device, MemorySpace::PinnedHost)
    );

    if host_device {
        Some(TransferFallbackReason::MissingDeviceMemory {
            source: from,
            destination: to,
        })
    } else if pinned_device {
        Some(TransferFallbackReason::MissingPinnedMemory {
            source: from,
            destination: to,
        })
    } else {
        None
    }
}

/// Renders a human-readable explanation of a transfer fallback reason.
fn transfer_reason_message(reason: &TransferFallbackReason) -> String {
    match reason {
        TransferFallbackReason::MissingDeviceMemory {
            source: from,
            destination: to,
        } => format!(
            "missing device-memory implementation for {:?} -> {:?}",
            from, to
        ),
        TransferFallbackReason::MissingPinnedMemory {
            source: from,
            destination: to,
        } => format!(
            "missing pinned-memory transfer path for {:?} -> {:?}",
            from, to
        ),
        TransferFallbackReason::UnsupportedLayout { layout: target_layout } => format!(
            "unsupported destination buffer layout {:?}",
            target_layout
        ),
    }
}

/// Extracts the p-adic precision from a domain name ending in `[N]`.
///
/// Only the last `[` is considered, and the name must end with `]`. Returns
/// `None` when the suffix is missing or `N` does not parse as a `u32`.
fn parse_padic_precision(domain: &str) -> Option<u32> {
    let (_, bracketed) = domain.rsplit_once('[')?;
    let digits = bracketed.strip_suffix(']')?;
    digits.parse::<u32>().ok()
}