// vkml 0.0.3
//
// High-level Vulkan-based machine learning library
use bytemuck;
use onnx_extractor::Bytes;

/// Source of initial data for a buffer.
///
/// Data-carrying variants can be viewed or converted to raw bytes through
/// `as_slice` and `into_cpu_buffer`. The variants without byte data (`None`,
/// `Xavier`, `Uniform`, `He`) make both of those methods panic with
/// `unimplemented!`.
#[derive(Default)]
pub enum Initialiser {
    /// No initial data. This is the `Default` value.
    #[default]
    None,
    /// A single byte buffer held as an `onnx_extractor::Bytes`.
    Bytes(Bytes),
    /// Several `Bytes` chunks. `into_cpu_buffer` concatenates them in order;
    /// `as_slice` does not support this variant.
    VecBytes(Vec<Bytes>),
    /// Owned boxed bytes; `into_cpu_buffer` hands this box back without copying.
    BoxU8(Box<[u8]>),
    /// Owned byte vector.
    VecU8(Vec<u8>),
    /// `f32` values, exposed as raw bytes via `bytemuck::cast_slice`.
    VecF32(Vec<f32>),
    /// `f64` values, exposed as raw bytes via `bytemuck::cast_slice`.
    VecF64(Vec<f64>),
    /// `i32` values, exposed as raw bytes via `bytemuck::cast_slice`.
    VecI32(Vec<i32>),
    /// `i64` values, exposed as raw bytes via `bytemuck::cast_slice`.
    VecI64(Vec<i64>),
    /// `u64` values, exposed as raw bytes via `bytemuck::cast_slice`.
    VecU64(Vec<u64>),
    /// Raw bytes; the conversions in this file treat it exactly like `VecU8`.
    // NOTE(review): presumably the bytes of a constant fill value - confirm with callers.
    Constant(Vec<u8>),
    /// Marker variant with no payload.
    // NOTE(review): presumably selects Xavier (Glorot) random initialisation - confirm.
    Xavier,
    /// Marker variant carrying two `f32` parameters.
    // NOTE(review): presumably the bounds of a uniform distribution; order unverified.
    Uniform(f32, f32),
    /// Marker variant with no payload.
    // NOTE(review): presumably selects He (Kaiming) random initialisation - confirm.
    He,
}

impl Initialiser {
    pub fn as_slice(&self) -> &[u8] {
        match self {
            Initialiser::Bytes(bytes) => bytes.as_ref(),
            Initialiser::BoxU8(boxed) => boxed.as_ref(),
            Initialiser::VecU8(vec) => vec.as_ref(),
            Initialiser::VecF32(v) => bytemuck::cast_slice(v),
            Initialiser::VecF64(v) => bytemuck::cast_slice(v),
            Initialiser::VecI32(v) => bytemuck::cast_slice(v),
            Initialiser::VecI64(v) => bytemuck::cast_slice(v),
            Initialiser::VecU64(v) => bytemuck::cast_slice(v),
            Initialiser::Constant(vec) => vec.as_ref(),

            Initialiser::None => unimplemented!("None"),
            Initialiser::VecBytes(_) => unimplemented!("BytesVec"),
            Initialiser::Xavier => unimplemented!("Xavier"),
            Initialiser::Uniform(_, _) => unimplemented!("Uniform"),
            Initialiser::He => unimplemented!("He"),
        }
    }

    // consumes self
    // a lot of these will require a copy for now
    pub fn into_cpu_buffer(self) -> Box<[u8]> {
        match self {
            Initialiser::Bytes(bytes) => bytes.to_vec().into(),
            Initialiser::VecBytes(parts) => parts.iter().flatten().copied().collect::<Box<[u8]>>(),
            Initialiser::BoxU8(boxed) => boxed,
            Initialiser::VecU8(vec) => vec.into(),
            Initialiser::VecF32(v) => bytemuck::cast_slice(&v).to_vec().into_boxed_slice(),
            Initialiser::VecF64(v) => bytemuck::cast_slice(&v).to_vec().into_boxed_slice(),
            Initialiser::VecI32(v) => bytemuck::cast_slice(&v).to_vec().into_boxed_slice(),
            Initialiser::VecI64(v) => bytemuck::cast_slice(&v).to_vec().into_boxed_slice(),
            Initialiser::VecU64(v) => bytemuck::cast_slice(&v).to_vec().into_boxed_slice(),
            Initialiser::Constant(vec) => vec.into(),

            Initialiser::None => unimplemented!("None"),
            Initialiser::Xavier => unimplemented!("Xavier"),
            Initialiser::Uniform(_, _) => unimplemented!("Uniform"),
            Initialiser::He => unimplemented!("He"),
        }
    }
}