#[cfg(feature = "candle")]
use candle_core::{Device, Tensor};
/// Multiplier of the linear congruential step used by [`pseudo_random_vec`].
pub const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
/// Additive increment of the linear congruential step.
pub const LCG_INCREMENT: u64 = 1;
/// Right shift applied to the 64-bit state; the remaining 31 high bits form the sample.
pub const LCG_SHIFT: u32 = 33;

/// Produces `dim` deterministic pseudo-random values from a 64-bit LCG seeded with `seed`.
///
/// Each step advances the state with wrapping arithmetic, keeps the bits left after
/// shifting by [`LCG_SHIFT`], and divides them by `i32::MAX as f32`. Because the shifted
/// value occupies only 31 bits it is never negative, so every element lies in `[0.0, 1.0]`.
///
/// The same `(dim, seed)` pair always yields the same vector; `dim == 0` yields an empty one.
pub fn pseudo_random_vec(dim: usize, seed: u64) -> Vec<f32> {
    let scale = i32::MAX as f32;
    let mut state = seed;
    let mut values = Vec::with_capacity(dim);
    for _ in 0..dim {
        // Advance first, then sample: the raw seed itself is never emitted.
        state = state
            .wrapping_mul(LCG_MULTIPLIER)
            .wrapping_add(LCG_INCREMENT);
        let high_bits = (state >> LCG_SHIFT) as i32;
        values.push(high_bits as f32 / scale);
    }
    values
}
/// Frequency multiplied into `(index + seed)` before taking the sine for K values.
pub const K_SEED_FREQ: f32 = 0.0137;
/// Frequency multiplied into `(index + seed + V_SEED_OFFSET)` before taking the cosine for V values.
pub const V_SEED_FREQ: f32 = 0.0213;
/// Constant added to the V index so the V waveform is offset from the K waveform.
pub const V_SEED_OFFSET: f32 = 1000.0;
/// Peak magnitude of generated K values.
pub const K_AMPLITUDE: f32 = 2.0;
/// Peak magnitude of generated V values.
pub const V_AMPLITUDE: f32 = 1.5;

/// Builds a deterministic `(K, V)` tensor pair on the CPU.
///
/// Both tensors have shape `(1, num_kv_heads, seq_len, head_dim)` and hold `f32` data.
/// Element `i` of the flattened K buffer is `sin((i + seed) * K_SEED_FREQ) * K_AMPLITUDE`;
/// element `i` of V is `cos((i + seed + V_SEED_OFFSET) * V_SEED_FREQ) * V_AMPLITUDE`.
///
/// # Panics
///
/// Panics if candle fails to build either tensor from the generated buffer.
#[cfg(feature = "candle")]
pub fn make_kv(
    seq_len: usize,
    num_kv_heads: usize,
    head_dim: usize,
    seed: u32,
) -> (Tensor, Tensor) {
    let shape = (1usize, num_kv_heads, seq_len, head_dim);
    let total = num_kv_heads * seq_len * head_dim;
    let phase = seed as f32;

    // Closed-form waveforms: the data depends only on the flat index and the seed.
    let key_at = |i: usize| ((i as f32 + phase) * K_SEED_FREQ).sin() * K_AMPLITUDE;
    let value_at =
        |i: usize| ((i as f32 + phase + V_SEED_OFFSET) * V_SEED_FREQ).cos() * V_AMPLITUDE;

    let keys: Vec<f32> = (0..total).map(key_at).collect();
    let values: Vec<f32> = (0..total).map(value_at).collect();

    let k = Tensor::from_vec(keys, shape, &Device::Cpu).unwrap();
    let v = Tensor::from_vec(values, shape, &Device::Cpu).unwrap();
    (k, v)
}
/// Builds an all-zero `f32` query tensor of shape `(1, num_attn_heads, seq_len, head_dim)`
/// on the CPU.
///
/// # Panics
///
/// Panics if candle fails to allocate the tensor.
#[cfg(feature = "candle")]
pub fn make_q(seq_len: usize, num_attn_heads: usize, head_dim: usize) -> Tensor {
    let shape = (1usize, num_attn_heads, seq_len, head_dim);
    let zeros = Tensor::zeros(shape, candle_core::DType::F32, &Device::Cpu);
    zeros.unwrap()
}
/// Computes the cosine similarity of two tensors after flattening both to `f32` vectors.
///
/// Returns `0.0` when either vector's Euclidean norm is below `1e-10`, avoiding a
/// division by (near) zero. The dot product pairs elements with `zip`, so it stops at
/// the shorter input, whereas each norm is taken over its full vector.
///
/// # Panics
///
/// Panics if a tensor cannot be converted to `f32`, flattened, or copied out to a `Vec`.
#[cfg(feature = "candle")]
pub fn cosine_sim(a: &Tensor, b: &Tensor) -> f32 {
    // Convert a tensor of any dtype/shape into a flat host-side f32 buffer.
    let flatten = |t: &Tensor| -> Vec<f32> {
        let as_f32 = t.to_dtype(candle_core::DType::F32).unwrap();
        let flat = as_f32.flatten_all().unwrap();
        flat.to_vec1::<f32>().unwrap()
    };
    let l2_norm = |v: &[f32]| -> f32 { v.iter().map(|x| x * x).sum::<f32>().sqrt() };

    let lhs = flatten(a);
    let rhs = flatten(b);

    let dot: f32 = lhs.iter().zip(rhs.iter()).map(|(x, y)| x * y).sum();
    let norm_a = l2_norm(&lhs);
    let norm_b = l2_norm(&rhs);

    // Keep the `<` comparisons as written: a NaN norm must fall through to the division.
    if norm_a < 1e-10 || norm_b < 1e-10 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
/// Increment added to the SplitMix64 state on every step.
pub const SPLITMIX_GAMMA: u64 = 0x9e37_79b9_7f4a_7c15;
/// First multiplier of the output mixing function.
pub const SPLITMIX_MUL1: u64 = 0xbf58_476d_1ce4_e5b9;
/// Second multiplier of the output mixing function.
pub const SPLITMIX_MUL2: u64 = 0x94d0_49bb_1331_11eb;
/// Xor-shift distance applied before the first multiplication.
pub const SPLITMIX_SHIFT_1: u32 = 30;
/// Xor-shift distance applied before the second multiplication.
pub const SPLITMIX_SHIFT_2: u32 = 27;
/// Final xor-shift distance.
pub const SPLITMIX_SHIFT_3: u32 = 31;
/// Number of high bits of a `u64` kept when converting to `f64`.
pub const F64_MANTISSA_BITS: u32 = 53;
/// Right shift that reduces a `u64` to its top [`F64_MANTISSA_BITS`] bits.
pub const U64_TO_F64_SHIFT: u32 = 64 - F64_MANTISSA_BITS;

/// Small deterministic 64-bit generator: an additive counter followed by a
/// multiply/xor-shift mixing function.
///
/// The same seed always reproduces the same stream. Cloning captures the current
/// position, so a clone continues with exactly the same outputs as the original.
#[derive(Debug, Clone)]
pub struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    /// Creates a generator whose internal state starts at `seed`.
    pub fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Advances the state by [`SPLITMIX_GAMMA`] and returns the mixed 64-bit output.
    pub fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(SPLITMIX_GAMMA);
        let stage1 = (self.state ^ (self.state >> SPLITMIX_SHIFT_1)).wrapping_mul(SPLITMIX_MUL1);
        let stage2 = (stage1 ^ (stage1 >> SPLITMIX_SHIFT_2)).wrapping_mul(SPLITMIX_MUL2);
        stage2 ^ (stage2 >> SPLITMIX_SHIFT_3)
    }

    /// Returns a uniform sample `u` with `0.0 < u <= 1.0`.
    ///
    /// The top [`F64_MANTISSA_BITS`] bits of the next output are offset by `0.5` and
    /// divided by `2^53`. The offset guarantees the result is never `0.0`, so `ln(u)`
    /// is always finite. For the very largest inputs the `+ 0.5` rounds up in `f64`,
    /// so the result can be exactly `1.0`.
    pub fn next_open01(&mut self) -> f64 {
        const SCALE: f64 = (1u64 << F64_MANTISSA_BITS) as f64;
        let top_bits = self.next_u64() >> U64_TO_F64_SHIFT;
        (top_bits as f64 + 0.5) / SCALE
    }
}
/// Returns a deterministic vector of `dim` Gaussian samples scaled to unit Euclidean length.
///
/// Samples come from the Box–Muller transform driven by a [`SplitMix64`] seeded with
/// `seed`: each pair of uniforms produces two normal deviates. Generation and
/// normalisation happen in `f64`; only the final values are narrowed to `f32`.
///
/// `dim == 0` yields an empty vector. The division by the norm has no zero guard.
pub fn random_unit_vec(dim: usize, seed: u64) -> Vec<f32> {
    let mut rng = SplitMix64::new(seed);

    // One closure call per pair; `take(dim)` drops the surplus deviate when `dim` is odd.
    let samples: Vec<f64> = (0..dim.div_ceil(2))
        .flat_map(|_| {
            let u_radius = rng.next_open01();
            let u_angle = rng.next_open01();
            let radius = (-2.0 * u_radius.ln()).sqrt();
            let angle = 2.0 * std::f64::consts::PI * u_angle;
            [radius * angle.cos(), radius * angle.sin()]
        })
        .take(dim)
        .collect();

    let length = samples.iter().map(|g| g * g).sum::<f64>().sqrt();
    samples.into_iter().map(|g| (g / length) as f32).collect()
}
/// Returns `dim` deterministic values in `[-1.0, 1.0]` drawn from a [`SplitMix64`]
/// seeded with `seed`.
///
/// Each 64-bit output is reinterpreted as a signed integer and divided by
/// `i64::MAX as f32`.
// NOTE(review): `dead_code` is allowed presumably because not every build that
// includes this file calls this helper — confirm before removing the attribute.
#[allow(dead_code)]
pub fn splitmix_random_vec(dim: usize, seed: u64) -> Vec<f32> {
    let scale = i64::MAX as f32;
    let mut rng = SplitMix64::new(seed);
    let mut values = Vec::with_capacity(dim);
    for _ in 0..dim {
        let signed = rng.next_u64() as i64;
        values.push(signed as f32 / scale);
    }
    values
}
/// Returns `dim` deterministic standard-normal samples as `f32`.
///
/// Uses the Box–Muller transform on uniforms from a [`SplitMix64`] seeded with `seed`.
/// Every pair of uniforms yields a cosine deviate followed by a sine deviate; when `dim`
/// is odd the last sine deviate is not emitted.
pub fn random_normal_vec(dim: usize, seed: u64) -> Vec<f32> {
    let mut rng = SplitMix64::new(seed);
    let mut samples = Vec::with_capacity(dim);

    while samples.len() < dim {
        // Draw order matters for reproducibility: radius uniform first, angle second.
        let radius = (-2.0 * rng.next_open01().ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * rng.next_open01();

        samples.push((radius * angle.cos()) as f32);
        if samples.len() == dim {
            break;
        }
        samples.push((radius * angle.sin()) as f32);
    }

    samples
}