// zust-vm 0.9.12
//
// Cranelift JIT runtime for executing Zust modules.
use anyhow::{Result, anyhow, bail};
use compiler::{SymbolTable, eval_const_int_type};
use dynamic::{Dynamic, Type};
use smol_str::SmolStr;
use std::{collections::BTreeMap, rc::Rc};

/// Byte layout of a struct type as stored in GPU buffers.
///
/// Built by `from_symbol_table` / `from_type` and used to convert between `Dynamic`
/// values and raw bytes (`pack_map`, `write_map`, `unpack_map`).
#[derive(Debug, Clone)]
pub struct GpuStructLayout {
    /// Name of the struct, as passed to the constructor.
    pub name: SmolStr,
    /// The struct type this layout was computed from.
    pub ty: Type,
    /// Total size in bytes, as reported by `Type::struct_layout`.
    pub size: usize,
    /// Alignment in bytes (the type's alignment clamped to `1..=8`).
    pub align: usize,
    /// Per-field layout, in the order of the struct type's fields.
    pub fields: Vec<GpuFieldLayout>,
}

/// Layout of a single field within a [`GpuStructLayout`].
#[derive(Debug, Clone)]
pub struct GpuFieldLayout {
    /// Field name.
    pub name: SmolStr,
    /// Field type.
    pub ty: Type,
    /// Byte offset of the field from the start of the struct.
    pub offset: usize,
    /// Size of the field in bytes (the type's storage width).
    pub size: usize,
    /// Alignment of the field in bytes (clamped to `1..=8`).
    pub align: usize,
}

impl GpuStructLayout {
    pub(crate) fn from_symbol_table(symbols: &SymbolTable, name: &str, params: &[Type]) -> Result<Self> {
        let id = symbols.get_id(name)?;
        let ty = resolve_gpu_type(symbols, &Type::Symbol { id, params: params.to_vec() })?;
        Self::from_type(name.into(), ty)
    }

    pub(crate) fn from_type(name: SmolStr, ty: Type) -> Result<Self> {
        let Type::Struct { fields, .. } = &ty else {
            bail!("{name} is not a struct type: {ty:?}");
        };
        let (size, offsets) = Type::struct_layout(fields);
        let fields = fields.iter().zip(offsets).map(|((name, ty), offset)| GpuFieldLayout { name: name.clone(), ty: ty.clone(), offset: offset as usize, size: gpu_type_size(ty), align: gpu_type_align(ty) }).collect();
        let align = gpu_type_align(&ty);
        Ok(Self { name, ty, size: size as usize, align, fields })
    }

    pub fn pack_map(&self, value: &Dynamic) -> Result<Vec<u8>> {
        let mut bytes = vec![0; self.size];
        self.write_map(value, &mut bytes)?;
        Ok(bytes)
    }

    pub fn write_map(&self, value: &Dynamic, dst: &mut [u8]) -> Result<()> {
        if dst.len() < self.size {
            bail!("destination buffer too small for {}: need {}, got {}", self.name, self.size, dst.len());
        }
        write_value(&self.ty, value, &mut dst[..self.size])
    }

    pub fn unpack_map(&self, bytes: &[u8]) -> Result<Dynamic> {
        if bytes.len() < self.size {
            bail!("source buffer too small for {}: need {}, got {}", self.name, self.size, bytes.len());
        }
        read_value(&self.ty, &bytes[..self.size])
    }
}

/// Recursively resolves `ty` into the concrete form used for GPU layout computation.
///
/// - `Symbol` and `Ident` references have their parameters resolved first, are then
///   looked up in `symbols`, and the result is resolved again.
/// - Struct members whose resolved type is `Type::Fn` are dropped from the result. The
///   struct's own `params` are passed through unchanged.
/// - An `ArrayParam` whose length evaluates to a constant integer (via
///   `eval_const_int_type`) becomes a fixed `Array`. Otherwise it stays an `ArrayParam`.
/// - `Vec`, `Array`, `Fn` and `Tuple` have their component types resolved. Every other
///   type is returned unchanged.
///
/// # Errors
/// Propagates a failed lookup of an instantiated `Symbol` and any error from resolving
/// nested types. Also fails if a constant array length does not fit in `u32`.
///
/// NOTE(review): there is no cycle detection. A self-referential alias in `symbols` would
/// recurse without bound — confirm `SymbolTable` rules this out.
pub(crate) fn resolve_gpu_type(symbols: &SymbolTable, ty: &Type) -> Result<Type> {
    // Best-effort first lookup: if `symbols` cannot resolve `ty`, continue with `ty` itself.
    let resolved = symbols.get_type(ty).unwrap_or_else(|_| ty.clone());
    match resolved {
        Type::Symbol { id, params } => {
            // Resolve generic arguments first, then look up the instantiated symbol. Unlike
            // the initial lookup above, a failure here is reported as an error.
            let params = params.iter().map(|param| resolve_gpu_type(symbols, param)).collect::<Result<Vec<_>>>()?;
            let resolved = symbols.get_type(&Type::Symbol { id, params })?;
            resolve_gpu_type(symbols, &resolved)
        }
        Type::Ident { name, params } => {
            let params = params.iter().map(|param| resolve_gpu_type(symbols, param)).collect::<Result<Vec<_>>>()?;
            let ident = Type::Ident { name, params };
            // An identifier that `symbols` cannot resolve is kept as-is (presumably a
            // generic placeholder or builtin name — TODO confirm).
            match symbols.get_type(&ident) {
                Ok(resolved) => resolve_gpu_type(symbols, &resolved),
                Err(_) => Ok(ident),
            }
        }
        Type::Struct { params, fields } => Ok(Type::Struct {
            params,
            fields: fields
                .iter()
                // Function-typed members (presumably methods) are filtered out of the
                // layout; resolution errors are still propagated.
                .filter_map(|(name, ty)| match resolve_gpu_type(symbols, ty) {
                    Ok(Type::Fn { .. }) => None,
                    Ok(ty) => Some(Ok((name.clone(), ty))),
                    Err(err) => Some(Err(err)),
                })
                .collect::<Result<Vec<_>>>()?,
        }),
        Type::Vec(elem, len) => Ok(Type::Vec(Rc::new(resolve_gpu_type(symbols, &elem)?), len)),
        Type::Array(elem, len) => Ok(Type::Array(Rc::new(resolve_gpu_type(symbols, &elem)?), len)),
        Type::ArrayParam(elem, len) => {
            let elem = resolve_gpu_type(symbols, &elem)?;
            let len = resolve_gpu_type(symbols, &len)?;
            // Collapse to a fixed-size array only when the length is a known constant.
            if let Some(len) = eval_const_int_type(&len) {
                let len = u32::try_from(len).map_err(|_| anyhow!("array length out of u32 range"))?;
                Ok(Type::Array(Rc::new(elem), len))
            } else {
                Ok(Type::ArrayParam(Rc::new(elem), Rc::new(len)))
            }
        }
        Type::Fn { tys, ret } => Ok(Type::Fn { tys: tys.iter().map(|ty| resolve_gpu_type(symbols, ty)).collect::<Result<Vec<_>>>()?, ret: Rc::new(resolve_gpu_type(symbols, &ret)?) }),
        Type::Tuple(items) => Ok(Type::Tuple(items.iter().map(|ty| resolve_gpu_type(symbols, ty)).collect::<Result<Vec<_>>>()?)),
        other => Ok(other),
    }
}

/// Alignment, in bytes, used for `ty` in GPU buffers: the type's natural alignment
/// clamped to the range `1..=8`.
fn gpu_type_align(ty: &Type) -> usize {
    let natural = ty.align();
    natural.clamp(1, 8) as usize
}

/// Number of bytes a value of `ty` occupies in a GPU buffer (its storage width).
fn gpu_type_size(ty: &Type) -> usize {
    let width = ty.storage_width();
    width as usize
}

/// Serializes `value` into `dst` following the GPU layout of `ty`.
///
/// Scalars are written with `write_scalar` (their in-memory byte representation), and
/// `bool` becomes a single `0`/`1` byte. Structs are written field by field at the offsets
/// from `Type::struct_layout`. Tuples are written item by item at the offsets from
/// `tuple_offset`, and non-empty arrays/vectors element by element. Bytes not covered
/// by any field are left untouched.
///
/// # Errors
/// Fails when a scalar conversion fails, a struct field or tuple item is missing from
/// `value`, a sub-range does not fit in `dst`, or `ty` has no GPU write support here
/// (including zero-length arrays/vectors).
fn write_value(ty: &Type, value: &Dynamic, dst: &mut [u8]) -> Result<()> {
    match ty {
        Type::Bool => write_scalar(dst, u8::from(value.is_true())),
        Type::U8 => write_scalar(dst, u8::try_from(value.clone())?),
        Type::I8 => write_scalar(dst, i8::try_from(value.clone())?),
        Type::U16 => write_scalar(dst, u16::try_from(value.clone())?),
        Type::I16 => write_scalar(dst, i16::try_from(value.clone())?),
        Type::U32 => write_scalar(dst, u32::try_from(value.clone())?),
        Type::I32 => write_scalar(dst, i32::try_from(value.clone())?),
        Type::U64 => write_scalar(dst, u64::try_from(value.clone())?),
        Type::I64 => write_scalar(dst, i64::try_from(value.clone())?),
        Type::F32 => write_scalar(dst, f32::try_from(value.clone())?),
        Type::F64 => write_scalar(dst, f64::try_from(value.clone())?),
        Type::Struct { fields, .. } => {
            let (_, offsets) = Type::struct_layout(fields);
            fields.iter().zip(offsets).try_for_each(|((field_name, field_ty), offset)| {
                let field_value = value
                    .get_dynamic(field_name.as_str())
                    .ok_or_else(|| anyhow!("missing struct field {field_name}"))?;
                let begin = offset as usize;
                let slot = dst
                    .get_mut(begin..begin + gpu_type_size(field_ty))
                    .ok_or_else(|| anyhow!("field {field_name} out of buffer"))?;
                write_value(field_ty, &field_value, slot)
            })
        }
        Type::Array(elem, len) | Type::Vec(elem, len) if *len > 0 => write_sequence(elem, *len as usize, value, dst),
        Type::Tuple(items) => items.iter().enumerate().try_for_each(|(idx, item_ty)| {
            let item = value.get_idx(idx).ok_or_else(|| anyhow!("missing tuple item {idx}"))?;
            let begin = tuple_offset(items, idx);
            let slot = dst
                .get_mut(begin..begin + gpu_type_size(item_ty))
                .ok_or_else(|| anyhow!("tuple item {idx} out of buffer"))?;
            write_value(item_ty, &item, slot)
        }),
        Type::Void => Ok(()),
        other => bail!("unsupported GPU layout write type: {other:?}"),
    }
}

/// Writes the first `len` items of the list `value` into `dst`, back to back, each
/// occupying `gpu_type_size(elem)` bytes.
///
/// # Errors
/// Fails if `value` has fewer than `len` items, if an item is missing or does not fit
/// in `dst`, or if writing an item fails.
fn write_sequence(elem: &Type, len: usize, value: &Dynamic, dst: &mut [u8]) -> Result<()> {
    let available = value.len();
    if available < len {
        bail!("sequence needs {len} items, got {}", available);
    }
    let stride = gpu_type_size(elem);
    (0..len).try_for_each(|idx| {
        let item = value.get_idx(idx).ok_or_else(|| anyhow!("missing sequence item {idx}"))?;
        let offset = idx * stride;
        let slot = dst
            .get_mut(offset..offset + stride)
            .ok_or_else(|| anyhow!("sequence item {idx} out of buffer"))?;
        write_value(elem, &item, slot)
    })
}

/// Decodes a value of type `ty` from `src` following its GPU layout.
///
/// This is the inverse of `write_value`. A `bool` is true when its byte is non-zero.
/// A struct becomes a map keyed by field name, and tuples and non-empty
/// arrays/vectors become lists. `Void` decodes to `Dynamic::Null`.
///
/// # Errors
/// Fails when a field, item, or scalar does not fit in `src`, or when `ty` has no GPU
/// read support here (including zero-length arrays/vectors).
fn read_value(ty: &Type, src: &[u8]) -> Result<Dynamic> {
    match ty {
        Type::Bool => read_scalar::<u8>(src).map(|byte| Dynamic::Bool(byte != 0)),
        Type::U8 => read_scalar(src).map(Dynamic::U8),
        Type::I8 => read_scalar(src).map(Dynamic::I8),
        Type::U16 => read_scalar(src).map(Dynamic::U16),
        Type::I16 => read_scalar(src).map(Dynamic::I16),
        Type::U32 => read_scalar(src).map(Dynamic::U32),
        Type::I32 => read_scalar(src).map(Dynamic::I32),
        Type::U64 => read_scalar(src).map(Dynamic::U64),
        Type::I64 => read_scalar(src).map(Dynamic::I64),
        Type::F32 => read_scalar(src).map(Dynamic::F32),
        Type::F64 => read_scalar(src).map(Dynamic::F64),
        Type::Struct { fields, .. } => {
            let (_, offsets) = Type::struct_layout(fields);
            let mut entries = BTreeMap::new();
            for ((field_name, field_ty), offset) in fields.iter().zip(offsets) {
                let begin = offset as usize;
                let bytes = src
                    .get(begin..begin + gpu_type_size(field_ty))
                    .ok_or_else(|| anyhow!("field {field_name} out of buffer"))?;
                let decoded = read_value(field_ty, bytes)?;
                entries.insert(field_name.clone(), decoded);
            }
            Ok(Dynamic::map(entries))
        }
        Type::Array(elem, len) | Type::Vec(elem, len) if *len > 0 => read_sequence(elem, *len as usize, src),
        Type::Tuple(items) => {
            let mut decoded = Vec::with_capacity(items.len());
            for (idx, item_ty) in items.iter().enumerate() {
                let begin = tuple_offset(items, idx);
                let bytes = src
                    .get(begin..begin + gpu_type_size(item_ty))
                    .ok_or_else(|| anyhow!("tuple item {idx} out of buffer"))?;
                decoded.push(read_value(item_ty, bytes)?);
            }
            Ok(Dynamic::list(decoded))
        }
        Type::Void => Ok(Dynamic::Null),
        other => bail!("unsupported GPU layout read type: {other:?}"),
    }
}

/// Decodes `len` consecutive values of type `elem` from `src`, each occupying
/// `gpu_type_size(elem)` bytes, and returns them as a list.
///
/// # Errors
/// Fails if an item does not fit in `src` or cannot be decoded.
fn read_sequence(elem: &Type, len: usize, src: &[u8]) -> Result<Dynamic> {
    let stride = gpu_type_size(elem);
    // No up-front reservation: `len` comes from the type and may be far larger than
    // what `src` can actually hold.
    let mut items = Vec::new();
    for idx in 0..len {
        let offset = idx * stride;
        let bytes = src
            .get(offset..offset + stride)
            .ok_or_else(|| anyhow!("sequence item {idx} out of buffer"))?;
        items.push(read_value(elem, bytes)?);
    }
    Ok(Dynamic::list(items))
}

fn tuple_offset(items: &[Type], idx: usize) -> usize {
    items.iter().take(idx).map(gpu_type_size).sum()
}

fn write_scalar<T: bytemuck::NoUninit>(dst: &mut [u8], value: T) -> Result<()> {
    let bytes = bytemuck::bytes_of(&value);
    let target = dst.get_mut(..bytes.len()).ok_or_else(|| anyhow!("buffer too small for scalar"))?;
    target.copy_from_slice(bytes);
    Ok(())
}

/// Reads a `T` from the first `size_of::<T>()` bytes of `src`. The source does not
/// need to be aligned.
///
/// # Errors
/// Fails if `src` is shorter than `size_of::<T>()`.
fn read_scalar<T: bytemuck::AnyBitPattern>(src: &[u8]) -> Result<T> {
    src.get(..std::mem::size_of::<T>())
        .map(bytemuck::pod_read_unaligned)
        .ok_or_else(|| anyhow!("buffer too small for scalar"))
}