use metal::{
Buffer, BufferRef, CommandBufferRef, ComputePipelineState, MTLSize,
NSUInteger,
};
use super::gpu_norm::{encode_rms_norm_bf16_into, RmsNormBf16Pipelines};
use super::metal::{MetalBackend, MetalError, MtlBuffer};
use super::variants::{Variant, GROUP_SIZE, VARIANT};
/// Optional tail for the batched-experts combine.
///
/// When supplied, the combine kernel writes into `combine_out` instead of
/// `MoeBuffers::moe_hidden`. An RMS norm (`encode_rms_norm_bf16_into`) over
/// `combine_out` is then encoded on the same command buffer.
///
/// In the code visible here, only `gpu_batched_experts_encode_pre_staged`
/// forwards a non-`None` value to `emit_batched_experts`.
pub(crate) struct ChainToNormed<'a> {
    /// Pipelines handed to `encode_rms_norm_bf16_into`.
    pub pipes: &'a RmsNormBf16Pipelines,
    /// Buffer containing the norm weights, read at `next_norm_off`.
    /// NOTE(review): presumably the packed model weight file — confirm.
    pub wf_buf: &'a Buffer,
    /// Offset into `wf_buf` passed to the norm encoder (presumably bytes — confirm).
    pub next_norm_off: u64,
    /// Destination of the combine kernel; also the input of the RMS norm.
    pub combine_out: &'a Buffer,
    /// Scratch buffer passed to the norm encoder (name suggests a sum of
    /// squares — confirm against `encode_rms_norm_bf16_into`).
    pub chain_sum_sq: &'a Buffer,
    /// Output buffer of the RMS norm.
    pub chain_normed: &'a Buffer,
    /// Epsilon passed to the RMS norm.
    pub eps: f32,
}
/// Maximum number of routed experts (top-K) handled by one batched dispatch.
///
/// Sizes the per-slot buffer arrays in `MoeBuffers` and bounds `actual_k`.
pub const MAX_K: usize = 16;

// The combine dispatch binds the per-slot expert outputs to buffer indices
// `3..3 + MAX_K` but the combine params to a *fixed* index 19. The
// combine-param writers put the shared-gate score at a *fixed* index 16 of an
// 18-float buffer. With MAX_K > 16 both fixed indices would collide with
// per-slot data and the combine would silently compute garbage, so reject
// such a value at compile time.
const _: () = assert!(
    MAX_K >= 1 && MAX_K <= 16,
    "MAX_K must be in 1..=16 to fit the moe_combine_residual binding layout"
);
/// Errors reported by the expert-forward entry points in this module.
///
/// Every variant other than `Metal` is an argument-validation failure. These
/// are raised before any command buffer is created.
#[derive(Debug, thiserror::Error)]
pub enum ExpertForwardError {
    /// `expert_data` has the wrong byte length. On the single-expert path
    /// `expected` is one expert's 4-bit size; on the batched path it is `k`
    /// times that.
    #[error(
        "expert_data is the wrong length: expected {expected} bytes \
         (4-bit layout), got {actual}"
    )]
    BadExpertDataLen { expected: usize, actual: usize },
    /// `h_post` is not `hidden_dim` floats long.
    #[error("h_post must be HIDDEN_DIM={expected} floats, got {actual}")]
    BadHPostLen { expected: usize, actual: usize },
    /// `expert_out` is not `hidden_dim` floats long.
    #[error("expert_out must be HIDDEN_DIM={expected} floats, got {actual}")]
    BadExpertOutLen { expected: usize, actual: usize },
    /// `h_mid` is not `hidden_dim` floats long.
    #[error("h_mid must be HIDDEN_DIM={expected} floats, got {actual}")]
    BadHMidLen { expected: usize, actual: usize },
    /// `shared_out` is not `hidden_dim` floats long.
    #[error("shared_out must be HIDDEN_DIM={expected} floats, got {actual}")]
    BadSharedOutLen { expected: usize, actual: usize },
    /// `hidden_out` is not `hidden_dim` floats long.
    #[error("hidden_out must be HIDDEN_DIM={expected} floats, got {actual}")]
    BadHiddenOutLen { expected: usize, actual: usize },
    /// `actual_k` lies outside `1..=MAX_K`; `max` carries `MAX_K`.
    #[error(
        "actual_K out of range: must be 1..={max}, got {actual}"
    )]
    BadK { actual: i32, max: usize },
    /// `expert_weights` does not hold exactly one weight per expert.
    #[error("expert_weights must be {expected} floats, got {actual}")]
    BadWeightsLen { expected: usize, actual: usize },
    /// A Metal backend failure, e.g. a pipeline lookup. Converted via `?`.
    #[error("Metal backend: {0}")]
    Metal(#[from] MetalError),
}
/// Runs one routed expert on the GPU and blocks until `expert_out` is filled.
///
/// The pipeline is: gate and up matvecs, then SwiGLU, then the down matvec.
///
/// * `expert_data` — one expert's packed 4-bit weights
///   (`VARIANT.expert_size_4bit()` bytes).
/// * `h_post` — input activation, `hidden_dim` floats.
/// * `expert_out` — receives the expert output, `hidden_dim` floats.
///
/// Every call allocates fresh Metal buffers.
///
/// # Errors
/// Returns a length error for the first mismatched argument, in the order
/// above, before any GPU work. Returns `Metal` if a pipeline lookup fails.
pub fn gpu_expert_forward(
    metal: &mut MetalBackend,
    expert_data: &[u8],
    h_post: &[f32],
    expert_out: &mut [f32],
) -> Result<(), ExpertForwardError> {
    let variant = VARIANT;
    let hidden = variant.hidden_dim;
    let inter = variant.moe_intermediate;

    let want_bytes = variant.expert_size_4bit();
    if expert_data.len() != want_bytes {
        return Err(ExpertForwardError::BadExpertDataLen {
            expected: want_bytes,
            actual: expert_data.len(),
        });
    }
    if h_post.len() != hidden {
        return Err(ExpertForwardError::BadHPostLen {
            expected: hidden,
            actual: h_post.len(),
        });
    }
    if expert_out.len() != hidden {
        return Err(ExpertForwardError::BadExpertOutLen {
            expected: hidden,
            actual: expert_out.len(),
        });
    }

    // Clone the pipelines first so the `&mut metal` borrow ends before the
    // shared borrows for the device and queue below.
    let matvec = metal.pipeline("dequant_matvec_4bit_v3")?.clone();
    let swiglu = metal.pipeline("swiglu_fused")?.clone();

    let device = metal.device();
    let weights = MtlBuffer::<u8>::with_data(device, expert_data);
    let x = MtlBuffer::<f32>::with_data(device, h_post);
    let gate_out = MtlBuffer::<f32>::with_len(device, inter);
    let up_out = MtlBuffer::<f32>::with_len(device, inter);
    let act = MtlBuffer::<f32>::with_len(device, inter);
    let y = MtlBuffer::<f32>::with_len(device, hidden);

    let inter_u32 = inter as u32;
    let hidden_u32 = hidden as u32;
    let cmdbuf = metal.queue().new_command_buffer();

    // Gate and up projections both map `x` (hidden) to intermediate.
    let projections = [
        (
            variant.gate_w_off_4bit(),
            variant.gate_s_off_4bit(),
            variant.gate_b_off_4bit(),
            &gate_out,
        ),
        (
            variant.up_w_off_4bit(),
            variant.up_s_off_4bit(),
            variant.up_b_off_4bit(),
            &up_out,
        ),
    ];
    for (w_off, s_off, b_off, dst) in projections {
        encode_matvec(
            cmdbuf, &matvec, &weights, w_off, s_off, b_off, &x, dst, inter_u32, hidden_u32,
        );
    }

    encode_swiglu(cmdbuf, &swiglu, &gate_out, &up_out, &act, inter_u32);

    // Down projection: intermediate -> hidden.
    encode_matvec(
        cmdbuf,
        &matvec,
        &weights,
        variant.down_w_off_4bit(),
        variant.down_s_off_4bit(),
        variant.down_b_off_4bit(),
        &act,
        &y,
        hidden_u32,
        inter_u32,
    );

    cmdbuf.commit();
    cmdbuf.wait_until_completed();
    expert_out.copy_from_slice(&y.to_vec());
    Ok(())
}
fn encode_matvec(
cmdbuf: &metal::CommandBufferRef,
pipeline: &metal::ComputePipelineState,
data: &MtlBuffer<u8>,
w_off: usize,
s_off: usize,
b_off: usize,
input: &MtlBuffer<f32>,
output: &MtlBuffer<f32>,
out_dim: u32,
in_dim: u32,
) {
let group_size = GROUP_SIZE as u32;
let enc = cmdbuf.new_compute_command_encoder();
enc.set_compute_pipeline_state(pipeline);
enc.set_buffer(0, Some(data.raw()), w_off as NSUInteger);
enc.set_buffer(1, Some(data.raw()), s_off as NSUInteger);
enc.set_buffer(2, Some(data.raw()), b_off as NSUInteger);
enc.set_buffer(3, Some(input.raw()), 0);
enc.set_buffer(4, Some(output.raw()), 0);
enc.set_bytes(5, 4, (&out_dim as *const u32).cast());
enc.set_bytes(6, 4, (&in_dim as *const u32).cast());
enc.set_bytes(7, 4, (&group_size as *const u32).cast());
let num_tgs = (out_dim + 7) / 8;
enc.dispatch_thread_groups(
MTLSize::new(num_tgs as NSUInteger, 1, 1),
MTLSize::new(256, 1, 1),
);
enc.end_encoding();
}
/// Encodes the `swiglu_fused` kernel over `dim` elements on `cmdbuf`, in its
/// own compute encoder.
///
/// `gate`, `up` and `act` are bound at indices 0, 1 and 2, and `dim` as a
/// `u32` at index 3. Dispatches `ceil(dim / 256)` threadgroups of 256 threads.
fn encode_swiglu(
    cmdbuf: &metal::CommandBufferRef,
    pipeline: &metal::ComputePipelineState,
    gate: &MtlBuffer<f32>,
    up: &MtlBuffer<f32>,
    act: &MtlBuffer<f32>,
    dim: u32,
) {
    const THREADS: u32 = 256;

    let encoder = cmdbuf.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipeline);
    for (index, buffer) in (0..).zip([gate, up, act]) {
        encoder.set_buffer(index, Some(buffer.raw()), 0);
    }
    encoder.set_bytes(3, 4, (&dim as *const u32).cast());

    let threadgroups = (dim + THREADS - 1) / THREADS;
    encoder.dispatch_thread_groups(
        MTLSize::new(threadgroups as NSUInteger, 1, 1),
        MTLSize::new(THREADS as NSUInteger, 1, 1),
    );
    encoder.end_encoding();
}
/// Reusable GPU buffers for the batched MoE path.
///
/// Sized once from `VARIANT` and `MAX_K` so per-call encoding does not
/// allocate.
pub struct MoeBuffers {
    /// Per-slot packed 4-bit expert weights written from the CPU
    /// (`SlotSource::Synced`).
    data_synced: [MtlBuffer<u8>; MAX_K],
    /// Two sets of per-slot expert weights (`SlotSource::Prefetched`).
    /// `prefetch_set` selects the set. NOTE(review): presumably
    /// double-buffered so one set can be filled while the other is read —
    /// confirm against callers.
    data_prefetch: [[MtlBuffer<u8>; MAX_K]; 2],
    /// Per-slot gate-projection outputs (`moe_intermediate` floats).
    gate: [MtlBuffer<f32>; MAX_K],
    /// Per-slot up-projection outputs (`moe_intermediate` floats).
    up: [MtlBuffer<f32>; MAX_K],
    /// Per-slot SwiGLU outputs (`moe_intermediate` floats).
    act: [MtlBuffer<f32>; MAX_K],
    /// Per-slot down-projection outputs (`hidden_dim` floats), read by the
    /// combine kernel.
    out: [MtlBuffer<f32>; MAX_K],
    /// Staged copy of `h_post` used by `gpu_batched_experts_encode`.
    input: MtlBuffer<f32>,
    /// Staged copy of `h_mid`.
    h_mid: MtlBuffer<f32>,
    /// Staged copy of the shared-expert output.
    shared_out: MtlBuffer<f32>,
    /// Combine output when no `ChainToNormed` redirects it.
    moe_hidden: MtlBuffer<f32>,
    /// 18 floats: expert weights in `[..k]` (rest zeroed) and the shared-gate
    /// score at index 16.
    combine_params: MtlBuffer<f32>,
}
impl MoeBuffers {
    /// Allocates every buffer the batched MoE path needs, sized from
    /// `VARIANT` and `MAX_K`.
    ///
    /// Each expert-weight set is also checked for 2 MiB alignment. At most
    /// one warning per set is printed to stderr; misaligned buffers are
    /// still used.
    pub fn new(device: &metal::Device) -> Self {
        const TWO_MIB: usize = 2 * 1024 * 1024;

        let variant: Variant = VARIANT;
        let expert_bytes = variant.expert_size_4bit();
        let weights_slot = || MtlBuffer::<u8>::with_len(device, expert_bytes);

        let data_synced: [MtlBuffer<u8>; MAX_K] = std::array::from_fn(|_| weights_slot());
        let data_prefetch: [[MtlBuffer<u8>; MAX_K]; 2] =
            std::array::from_fn(|_| std::array::from_fn(|_| weights_slot()));

        // Only the first misaligned slot of a set is reported.
        let warn_if_misaligned = |label: &str, set: &[MtlBuffer<u8>]| {
            let first_bad = set
                .iter()
                .map(|buf| buf.raw().contents() as usize)
                .enumerate()
                .find(|&(_, addr)| addr % TWO_MIB != 0);
            if let Some((slot, addr)) = first_bad {
                eprintln!(
                    "[moe] WARNING: data_{label} slot {slot} not 2 MB \
                     aligned (contents=0x{addr:x}, off=0x{off:x}); \
                     pread DMA may use scatter-gather. See slice 5d-6 \
                     plan.",
                    off = addr % TWO_MIB
                );
            }
        };
        warn_if_misaligned("synced", &data_synced[..]);
        warn_if_misaligned("prefetch[0]", &data_prefetch[0][..]);
        warn_if_misaligned("prefetch[1]", &data_prefetch[1][..]);

        let inter = variant.moe_intermediate;
        let hidden = variant.hidden_dim;
        let floats = |len: usize| MtlBuffer::<f32>::with_len(device, len);

        let gate = std::array::from_fn(|_| floats(inter));
        let up = std::array::from_fn(|_| floats(inter));
        let act = std::array::from_fn(|_| floats(inter));
        let out = std::array::from_fn(|_| floats(hidden));

        Self {
            data_synced,
            data_prefetch,
            gate,
            up,
            act,
            out,
            input: floats(hidden),
            h_mid: floats(hidden),
            shared_out: floats(hidden),
            moe_hidden: floats(hidden),
            combine_params: floats(18),
        }
    }

    /// The combine output buffer used when no `ChainToNormed` is supplied.
    pub(crate) fn moe_hidden(&self) -> &MtlBuffer<f32> {
        &self.moe_hidden
    }

    /// Down-projection output of expert `slot`. Panics if `slot >= MAX_K`.
    pub(crate) fn out(&self, slot: usize) -> &MtlBuffer<f32> {
        &self.out[slot]
    }

    /// CPU-writable views of every synced expert-weight slot.
    pub(crate) fn data_synced_slots_mut_array(
        &mut self,
    ) -> [&mut [u8]; MAX_K] {
        let slots = self.data_synced.each_mut();
        slots.map(|buf| buf.as_mut_slice())
    }

    /// CPU-writable views of every slot in prefetch set `set`.
    /// Panics (by indexing) if `set >= 2`.
    pub(crate) fn data_prefetch_slots_mut_array(
        &mut self,
        set: usize,
    ) -> [&mut [u8]; MAX_K] {
        debug_assert!(set < 2, "prefetch set index must be 0 or 1");
        let slots = self.data_prefetch[set].each_mut();
        slots.map(|buf| buf.as_mut_slice())
    }
}
impl std::fmt::Debug for MoeBuffers {
    /// Summarises the buffer geometry. Buffer contents are deliberately omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = VARIANT;
        let mut summary = f.debug_struct("MoeBuffers");
        summary.field("max_k", &MAX_K);
        summary.field("hidden_dim", &variant.hidden_dim);
        summary.field("moe_intermediate", &variant.moe_intermediate);
        summary.field("expert_size_4bit", &variant.expert_size_4bit());
        summary.finish()
    }
}
/// Runs `actual_k` routed experts plus the residual/shared-expert combine on
/// the GPU. Blocks until `hidden_out` holds the combined hidden state.
///
/// Inputs are staged into `bufs` by `gpu_batched_experts_encode`; see it for
/// the expected layout of `expert_data` and `expert_weights`.
///
/// # Errors
/// Returns `BadHiddenOutLen` if `hidden_out` is not `hidden_dim` floats
/// (checked first). Otherwise returns whatever `gpu_batched_experts_encode`
/// reports.
#[allow(clippy::too_many_arguments)]
pub fn gpu_batched_experts_forward(
    metal: &mut MetalBackend,
    bufs: &mut MoeBuffers,
    actual_k: i32,
    expert_data: &[u8],
    h_post: &[f32],
    h_mid: &[f32],
    shared_out: &[f32],
    expert_weights: &[f32],
    shared_gate_score: f32,
    hidden_out: &mut [f32],
) -> Result<(), ExpertForwardError> {
    let hidden = VARIANT.hidden_dim;
    if hidden_out.len() != hidden {
        return Err(ExpertForwardError::BadHiddenOutLen {
            expected: hidden,
            actual: hidden_out.len(),
        });
    }

    // `gpu_combine = true`: the combine kernel writes into `bufs.moe_hidden`.
    let pending = gpu_batched_experts_encode(
        metal,
        bufs,
        actual_k,
        expert_data,
        h_post,
        h_mid,
        shared_out,
        expert_weights,
        shared_gate_score,
        true,
    )?;
    pending.commit();
    pending.wait_until_completed();

    let combined = bufs.moe_hidden.to_vec();
    hidden_out.copy_from_slice(&combined);
    Ok(())
}
/// Validates and stages one batched-experts call into `bufs`, then encodes
/// (but does not commit) the GPU work on a new command buffer.
///
/// * `expert_data` — `actual_k` packed 4-bit experts back to back, copied
///   into the synced weight slots.
/// * `h_post`, `h_mid`, `shared_out` — `hidden_dim` floats each, copied into
///   `bufs.input`, `bufs.h_mid` and `bufs.shared_out`.
/// * `expert_weights` (`actual_k` floats) and `shared_gate_score` — written
///   into `bufs.combine_params`.
/// * `gpu_combine` — when `false`, no combine dispatch is encoded, so only
///   the per-slot outputs are produced.
///
/// The caller owns the returned command buffer and must commit it.
///
/// # Errors
/// Returns validation errors for K, expert data and weights (via
/// `validate_inputs`), then for the three hidden vectors. Returns `Metal` if
/// a pipeline lookup fails.
#[allow(clippy::too_many_arguments)]
pub(crate) fn gpu_batched_experts_encode(
    metal: &mut MetalBackend,
    bufs: &mut MoeBuffers,
    actual_k: i32,
    expert_data: &[u8],
    h_post: &[f32],
    h_mid: &[f32],
    shared_out: &[f32],
    expert_weights: &[f32],
    shared_gate_score: f32,
    gpu_combine: bool,
) -> Result<metal::CommandBuffer, ExpertForwardError> {
    let variant = VARIANT;
    let hidden = variant.hidden_dim;

    validate_inputs(actual_k, expert_data, expert_weights)?;
    let k = actual_k as usize;

    if h_post.len() != hidden {
        return Err(ExpertForwardError::BadHPostLen {
            expected: hidden,
            actual: h_post.len(),
        });
    }
    if h_mid.len() != hidden {
        return Err(ExpertForwardError::BadHMidLen {
            expected: hidden,
            actual: h_mid.len(),
        });
    }
    if shared_out.len() != hidden {
        return Err(ExpertForwardError::BadSharedOutLen {
            expected: hidden,
            actual: shared_out.len(),
        });
    }

    let matvec = metal.pipeline("dequant_matvec_4bit_v3")?.clone();
    let swiglu = metal.pipeline("swiglu_fused")?.clone();
    let combine = if gpu_combine {
        Some(metal.pipeline("moe_combine_residual")?.clone())
    } else {
        None
    };

    // Stage each expert's weights into its synced slot.
    let expert_size = variant.expert_size_4bit();
    for (slot, dst) in bufs.data_synced.iter_mut().take(k).enumerate() {
        let start = slot * expert_size;
        dst.as_mut_slice()
            .copy_from_slice(&expert_data[start..start + expert_size]);
    }
    bufs.input.as_mut_slice().copy_from_slice(h_post);
    bufs.h_mid.as_mut_slice().copy_from_slice(h_mid);
    bufs.shared_out.as_mut_slice().copy_from_slice(shared_out);

    // Combine params: weights in [..k], zeros elsewhere, shared gate at [16].
    let params = bufs.combine_params.as_mut_slice();
    params.fill(0.0);
    params[..k].copy_from_slice(expert_weights);
    params[16] = shared_gate_score;

    let all_synced = [super::SlotSource::Synced; MAX_K];
    let cmdbuf = metal.queue().new_command_buffer();
    emit_batched_experts(
        cmdbuf,
        &matvec,
        &swiglu,
        combine.as_ref(),
        bufs,
        bufs.input.raw(),
        bufs.h_mid.raw(),
        bufs.shared_out.raw(),
        k,
        variant,
        &all_synced,
        0,
        None,
    );
    Ok(cmdbuf.to_owned())
}
/// Encodes a batched-experts step whose activations and expert weights are
/// already resident on the GPU. The returned command buffer is uncommitted.
///
/// Only the combine parameters are written from the CPU. `input`, `h_mid`
/// and `shared_out` are bound as-is; their lengths are not checked here.
/// Each slot's weights come from the buffer selected by
/// `data_set_per_slot[slot]`: the synced slot, or prefetch set
/// `prefetch_set`.
///
/// The combine dispatch is always encoded. `chain` optionally redirects its
/// output and appends an RMS norm.
///
/// # Errors
/// * `BadK` if `actual_k` is outside `1..=MAX_K`.
/// * `BadWeightsLen` if `expert_weights.len()` differs from `actual_k`.
/// * `Metal` if a pipeline lookup fails.
#[allow(clippy::too_many_arguments)]
pub(crate) fn gpu_batched_experts_encode_pre_staged(
    metal: &mut MetalBackend,
    bufs: &mut MoeBuffers,
    actual_k: i32,
    input: &BufferRef,
    h_mid: &BufferRef,
    shared_out: &BufferRef,
    expert_weights: &[f32],
    shared_gate_score: f32,
    data_set_per_slot: &[super::SlotSource; MAX_K],
    prefetch_set: usize,
    chain: Option<ChainToNormed<'_>>,
) -> Result<metal::CommandBuffer, ExpertForwardError> {
    let Some(k) = usize::try_from(actual_k)
        .ok()
        .filter(|k| (1..=MAX_K).contains(k))
    else {
        return Err(ExpertForwardError::BadK {
            actual: actual_k,
            max: MAX_K,
        });
    };
    if expert_weights.len() != k {
        return Err(ExpertForwardError::BadWeightsLen {
            expected: k,
            actual: expert_weights.len(),
        });
    }

    let matvec = metal.pipeline("dequant_matvec_4bit_v3")?.clone();
    let swiglu = metal.pipeline("swiglu_fused")?.clone();
    let combine = metal.pipeline("moe_combine_residual")?.clone();

    // Combine params: weights in [..k], zeros elsewhere, shared gate at [16].
    let params = bufs.combine_params.as_mut_slice();
    params.fill(0.0);
    params[..k].copy_from_slice(expert_weights);
    params[16] = shared_gate_score;

    let cmdbuf = metal.queue().new_command_buffer();
    emit_batched_experts(
        cmdbuf,
        &matvec,
        &swiglu,
        Some(&combine),
        bufs,
        input,
        h_mid,
        shared_out,
        k,
        VARIANT,
        data_set_per_slot,
        prefetch_set,
        chain,
    );
    Ok(cmdbuf.to_owned())
}
/// Checks the arguments shared by the CPU-staged batched path, in this order:
///
/// 1. `actual_k` is in `1..=MAX_K`.
/// 2. `expert_data` holds exactly `actual_k` packed 4-bit experts.
/// 3. `expert_weights` has one weight per expert.
fn validate_inputs(
    actual_k: i32,
    expert_data: &[u8],
    expert_weights: &[f32],
) -> Result<(), ExpertForwardError> {
    let k = match usize::try_from(actual_k) {
        Ok(k) if (1..=MAX_K).contains(&k) => k,
        _ => {
            return Err(ExpertForwardError::BadK {
                actual: actual_k,
                max: MAX_K,
            })
        }
    };

    let expected_bytes = k * VARIANT.expert_size_4bit();
    if expert_data.len() != expected_bytes {
        return Err(ExpertForwardError::BadExpertDataLen {
            expected: expected_bytes,
            actual: expert_data.len(),
        });
    }

    if expert_weights.len() != k {
        return Err(ExpertForwardError::BadWeightsLen {
            expected: k,
            actual: expert_weights.len(),
        });
    }

    Ok(())
}
/// Encodes the per-expert FFN work for `k` slots onto `cmdbuf`, optionally
/// followed by the combine and an RMS-norm tail. Nothing is committed here.
///
/// For each slot, two compute encoders are created:
/// 1. The gate and up matvecs, both reading `input`, into `bufs.gate[slot]`
///    and `bufs.up[slot]`.
/// 2. SwiGLU into `bufs.act[slot]`, then the down matvec from
///    `bufs.act[slot]` into `bufs.out[slot]`.
///
/// NOTE(review): encoder 2 reads `act` produced by an earlier dispatch in the
/// same encoder. This relies on Metal's default (serial) dispatch ordering
/// within an encoder — re-check if the encoder's dispatch type ever changes.
///
/// `data_set_per_slot[slot]` selects where that slot's weights come from:
/// `bufs.data_synced[slot]` or `bufs.data_prefetch[prefetch_set][slot]`.
///
/// If `combine` is `Some`, one more encoder runs it over `h_mid`,
/// `shared_out`, all `MAX_K` slot outputs and `bufs.combine_params`. The
/// combine writes to `chain.combine_out` when `chain` is given, otherwise to
/// `bufs.moe_hidden`. With a `chain`, `encode_rms_norm_bf16_into` is then
/// encoded over that output. If `combine` is `None`, `chain` is ignored.
#[allow(clippy::too_many_arguments)]
fn emit_batched_experts(
    cmdbuf: &CommandBufferRef,
    matvec: &ComputePipelineState,
    swiglu: &ComputePipelineState,
    combine: Option<&ComputePipelineState>,
    bufs: &MoeBuffers,
    input: &BufferRef,
    h_mid: &BufferRef,
    shared_out: &BufferRef,
    k: usize,
    v: Variant,
    data_set_per_slot: &[super::SlotSource; MAX_K],
    prefetch_set: usize,
    chain: Option<ChainToNormed<'_>>,
) {
    debug_assert!(prefetch_set < 2, "prefetch set index must be 0 or 1");
    // Resolve the weight buffer for a slot. Indexing panics (even in release)
    // if a Prefetched slot is used with prefetch_set >= 2.
    let pick = |slot: usize| -> &MtlBuffer<u8> {
        match data_set_per_slot[slot] {
            super::SlotSource::Synced => &bufs.data_synced[slot],
            super::SlotSource::Prefetched => {
                &bufs.data_prefetch[prefetch_set][slot]
            }
        }
    };
    for slot in 0..k {
        let weights_buf = pick(slot);
        // Encoder 1: gate and up projections (hidden -> intermediate).
        {
            let enc = cmdbuf.new_compute_command_encoder();
            encode_matvec_into(
                enc,
                matvec,
                weights_buf,
                v.gate_w_off_4bit(),
                v.gate_s_off_4bit(),
                v.gate_b_off_4bit(),
                input,
                bufs.gate[slot].raw(),
                v.moe_intermediate as u32,
                v.hidden_dim as u32,
            );
            encode_matvec_into(
                enc,
                matvec,
                weights_buf,
                v.up_w_off_4bit(),
                v.up_s_off_4bit(),
                v.up_b_off_4bit(),
                input,
                bufs.up[slot].raw(),
                v.moe_intermediate as u32,
                v.hidden_dim as u32,
            );
            enc.end_encoding();
        }
        // Encoder 2: SwiGLU, then down projection (intermediate -> hidden)
        // reading the activation written by the SwiGLU dispatch.
        {
            let enc = cmdbuf.new_compute_command_encoder();
            encode_swiglu_into_buf(
                enc,
                swiglu,
                bufs.gate[slot].raw(),
                bufs.up[slot].raw(),
                bufs.act[slot].raw(),
                v.moe_intermediate as u32,
            );
            encode_matvec_into(
                enc,
                matvec,
                weights_buf,
                v.down_w_off_4bit(),
                v.down_s_off_4bit(),
                v.down_b_off_4bit(),
                bufs.act[slot].raw(),
                bufs.out[slot].raw(),
                v.hidden_dim as u32,
                v.moe_intermediate as u32,
            );
            enc.end_encoding();
        }
    }
    if let Some(combine) = combine {
        // A chain redirects the combine result away from moe_hidden.
        let combine_out: &BufferRef = match chain.as_ref() {
            Some(c) => c.combine_out,
            None => bufs.moe_hidden.raw(),
        };
        let enc = cmdbuf.new_compute_command_encoder();
        enc.set_compute_pipeline_state(combine);
        // Bindings: 0 = h_mid, 1 = shared_out, 2 = output.
        enc.set_buffer(0, Some(h_mid), 0);
        enc.set_buffer(1, Some(shared_out), 0);
        enc.set_buffer(2, Some(combine_out), 0);
        // Bindings 3..3+MAX_K: every slot output is bound, even slots >= k.
        // Presumably the kernel reads only the first k (k is passed at index
        // 21) — confirm against the shader.
        for slot in 0..MAX_K {
            enc.set_buffer(
                3 + slot as NSUInteger,
                Some(bufs.out[slot].raw()),
                0,
            );
        }
        // Binding 19: combine params. Constants 20/21: hidden_dim and k.
        enc.set_buffer(19, Some(bufs.combine_params.raw()), 0);
        let dim = v.hidden_dim as u32;
        let k_val = k as u32;
        enc.set_bytes(20, 4, (&dim as *const u32).cast());
        enc.set_bytes(21, 4, (&k_val as *const u32).cast());
        // ceil(dim / 256) threadgroups of 256 threads.
        let tgs = (dim + 255) / 256;
        enc.dispatch_thread_groups(
            MTLSize::new(tgs as NSUInteger, 1, 1),
            MTLSize::new(256, 1, 1),
        );
        enc.end_encoding();
        // Optional tail: RMS norm over the combine output on the same cmdbuf.
        if let Some(c) = chain {
            encode_rms_norm_bf16_into(
                cmdbuf,
                c.pipes,
                c.combine_out,
                c.wf_buf,
                c.next_norm_off,
                c.chain_sum_sq,
                c.chain_normed,
                v.hidden_dim as u32,
                c.eps,
            );
        }
    }
}
/// Records one 4-bit dequantising matvec into the already-open encoder `enc`.
/// The caller is responsible for ending the encoder.
///
/// Buffer table:
/// * 0/1/2 — `data` at byte offsets `w_off`/`s_off`/`b_off` (weights, scales,
///   biases).
/// * 3 — `input` (`in_dim` floats).
/// * 4 — `output` (`out_dim` floats).
///
/// Constants 5/6/7 are `out_dim`, `in_dim` and `GROUP_SIZE` as `u32`.
/// Dispatches `ceil(out_dim / 8)` threadgroups of 256 threads.
#[allow(clippy::too_many_arguments)] // one parameter per kernel argument
fn encode_matvec_into(
    enc: &metal::ComputeCommandEncoderRef,
    pipeline: &metal::ComputePipelineState,
    data: &MtlBuffer<u8>,
    w_off: usize,
    s_off: usize,
    b_off: usize,
    input: &BufferRef,
    output: &BufferRef,
    out_dim: u32,
    in_dim: u32,
) {
    const ROWS_PER_THREADGROUP: u32 = 8;
    const THREADS_PER_THREADGROUP: NSUInteger = 256;

    enc.set_compute_pipeline_state(pipeline);

    let packed: &BufferRef = data.raw();
    let sections = [(0, w_off), (1, s_off), (2, b_off)];
    for (index, byte_offset) in sections {
        enc.set_buffer(index, Some(packed), byte_offset as NSUInteger);
    }
    enc.set_buffer(3, Some(input), 0);
    enc.set_buffer(4, Some(output), 0);

    let scalars = [(5, out_dim), (6, in_dim), (7, GROUP_SIZE as u32)];
    for (index, value) in scalars {
        enc.set_bytes(index, 4, (&value as *const u32).cast());
    }

    let threadgroups = (out_dim + ROWS_PER_THREADGROUP - 1) / ROWS_PER_THREADGROUP;
    enc.dispatch_thread_groups(
        MTLSize::new(threadgroups as NSUInteger, 1, 1),
        MTLSize::new(THREADS_PER_THREADGROUP, 1, 1),
    );
}
/// Records the `swiglu_fused` kernel over `dim` elements into the
/// already-open encoder `enc`. The caller ends the encoder.
///
/// `gate`, `up` and `act` are bound at indices 0, 1 and 2, and `dim` as a
/// `u32` at index 3. Dispatches `ceil(dim / 256)` threadgroups of 256 threads.
fn encode_swiglu_into_buf(
    enc: &metal::ComputeCommandEncoderRef,
    pipeline: &metal::ComputePipelineState,
    gate: &BufferRef,
    up: &BufferRef,
    act: &BufferRef,
    dim: u32,
) {
    const THREADS: u32 = 256;

    enc.set_compute_pipeline_state(pipeline);
    for (index, buffer) in (0..).zip([gate, up, act]) {
        enc.set_buffer(index, Some(buffer), 0);
    }
    enc.set_bytes(3, 4, (&dim as *const u32).cast());

    let threadgroups = (dim + THREADS - 1) / THREADS;
    enc.dispatch_thread_groups(
        MTLSize::new(threadgroups as NSUInteger, 1, 1),
        MTLSize::new(THREADS as NSUInteger, 1, 1),
    );
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: one synthetic expert through the single-expert GPU path
    /// must produce finite, not-all-zero output.
    #[test]
    #[ignore = "needs Metal device + access to shaders.metal source"]
    fn gpu_expert_forward_runs_and_produces_finite_output() {
        let mut backend = MetalBackend::new().expect("MetalBackend::new");
        let weights = synth::expert_data_seeded();
        let activation = synth::h_post_seeded();
        let mut result = vec![0.0f32; VARIANT.hidden_dim];

        gpu_expert_forward(&mut backend, &weights, &activation, &mut result)
            .expect("gpu_expert_forward");

        let all_finite = result.iter().all(|x| x.is_finite());
        assert!(all_finite, "output has NaN/Inf");
        let any_nonzero = result.iter().any(|&x| x.abs() > 0.0);
        assert!(any_nonzero, "output is all zero — kernel didn't write?");
    }
}
/// Deterministic synthetic inputs for exercising the expert kernels without
/// real model weights.
pub mod synth {
    use super::*;

    /// Projection blocks in one packed expert: gate, up, down.
    const PROJECTION_BLOCKS: usize = 3;
    /// LCG multiplier used to generate weight bytes.
    const LCG_MUL: u64 = 6364136223846793005;
    /// LCG increment used to generate weight bytes.
    const LCG_INC: u64 = 1442695040888963407;

    /// Fills one packed 4-bit expert occupying the front of `dst`.
    ///
    /// For each projection block:
    /// * The weight bytes are drawn from an LCG seeded with
    ///   `seed_for_block(block)`; each byte is the state's bits 32..40.
    /// * Every 2-byte scale is set to `[0x00, 0x3C]`.
    /// * Remaining bytes of the block (presumably biases) are left untouched.
    ///
    /// Panics if `dst` is shorter than the layout described by `VARIANT`.
    fn fill_expert_4bit(dst: &mut [u8], seed_for_block: impl Fn(usize) -> u64) {
        let v: Variant = VARIANT;
        for block in 0..PROJECTION_BLOCKS {
            let block_off = block * v.expert_block_bytes_4bit();
            let w_end = block_off + v.expert_weight_bytes_4bit();
            let mut state: u64 = seed_for_block(block);
            for byte in &mut dst[block_off..w_end] {
                state = state.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
                *byte = (state >> 32) as u8;
            }
            let s_end = w_end + v.expert_scale_bytes();
            for chunk in dst[w_end..s_end].chunks_exact_mut(2) {
                chunk[0] = 0x00;
                chunk[1] = 0x3C;
            }
        }
    }

    /// One synthetic expert (`VARIANT.expert_size_4bit()` bytes).
    pub fn expert_data_seeded() -> Vec<u8> {
        let mut data = vec![0u8; VARIANT.expert_size_4bit()];
        fill_expert_4bit(&mut data, |block| 0xCAFE_BEEF + block as u64);
        data
    }

    /// A small deterministic ramp of `hidden_dim` floats centred on zero.
    pub fn h_post_seeded() -> Vec<f32> {
        let v = VARIANT;
        (0..v.hidden_dim)
            .map(|i| {
                (i as f32 - v.hidden_dim as f32 / 2.0) * 1e-3
                    / v.hidden_dim as f32
            })
            .collect()
    }

    /// `k` distinct synthetic experts laid out back to back. This matches the
    /// `expert_data` layout the batched path expects: slot `s` occupies bytes
    /// `s * expert_size_4bit() .. (s + 1) * expert_size_4bit()`.
    pub fn k_expert_data_seeded(k: usize) -> Vec<u8> {
        let per_expert = VARIANT.expert_size_4bit();
        let mut data = vec![0u8; k * per_expert];
        for slot in 0..k {
            let dst = &mut data[slot * per_expert..(slot + 1) * per_expert];
            fill_expert_4bit(dst, move |block| {
                0xCAFE_BEEF ^ ((slot as u64) << 32) ^ (block as u64)
            });
        }
        data
    }

    /// Small-amplitude sine pattern of `hidden_dim` floats.
    pub fn h_mid_seeded() -> Vec<f32> {
        let v = VARIANT;
        (0..v.hidden_dim)
            .map(|i| (i as f32 * 0.0007 - 0.05).sin() * 0.001)
            .collect()
    }

    /// Small-amplitude cosine pattern of `hidden_dim` floats.
    pub fn shared_out_seeded() -> Vec<f32> {
        let v = VARIANT;
        (0..v.hidden_dim)
            .map(|i| (i as f32 * 0.0011 + 0.03).cos() * 0.001)
            .collect()
    }

    /// `k` positive routing weights normalised to sum to 1.
    /// Returns an empty vector for `k == 0`.
    pub fn expert_weights_seeded(k: usize) -> Vec<f32> {
        let raw: Vec<f32> = (0..k)
            .map(|i| ((i as f32) * 0.37 + 1.0).abs())
            .collect();
        let total: f32 = raw.iter().sum();
        raw.iter().map(|w| w / total).collect()
    }
}