use std::f64::consts::PI;
/// Encodes a scalar with `n_freq` sinusoidal frequency bands.
///
/// The returned vector has `1 + 2 * n_freq` elements laid out as
///
/// ```text
/// [x, sin(2^0·π·x), cos(2^0·π·x), sin(2^1·π·x), cos(2^1·π·x), …]
/// ```
///
/// i.e. the raw input first, followed by one interleaved `(sin, cos)` pair per
/// band `k in 0..n_freq` at angular frequency `2^k · π`.
///
/// With `n_freq == 0` the result is just `[x]`.
///
/// The frequency is carried as an `f64` that is doubled each band, so there is
/// no integer-shift overflow for large `n_freq`. From band `k = 1023` onward the
/// frequency itself overflows to infinity and those outputs become NaN.
pub fn positional_encode(x: f64, n_freq: usize) -> Vec<f64> {
    let mut out = Vec::with_capacity(1 + 2 * n_freq);
    out.push(x);

    // Doubling by 2.0 is exact in IEEE-754, so after k steps `freq` is exactly
    // `2^k * PI` as rounded from PI. This matches `(1 << k) as f64 * PI` bit for
    // bit, without the integer overflow at k >= 64.
    let mut freq = PI;
    for _ in 0..n_freq {
        let (s, c) = (freq * x).sin_cos();
        out.push(s);
        out.push(c);
        freq *= 2.0;
    }
    out
}
/// Encodes a 3-D point by running [`positional_encode`] on each coordinate.
///
/// The per-axis encodings are concatenated in the order `pos[0]`, `pos[1]`,
/// `pos[2]`. Each axis contributes `1 + 2 * n_freq` values, so the result has
/// `3 * (1 + 2 * n_freq)` elements in total.
pub fn encode_position(pos: &[f64; 3], n_freq: usize) -> Vec<f64> {
    let per_axis = 2 * n_freq + 1;
    let mut encoded = Vec::with_capacity(per_axis * 3);
    for axis_encoding in pos.iter().map(|&axis| positional_encode(axis, n_freq)) {
        encoded.extend(axis_encoding);
    }
    encoded
}
/// Encodes a 3-D direction vector.
///
/// This applies exactly the same per-axis sinusoidal encoding as
/// [`encode_position`] and returns the same `3 * (1 + 2 * n_freq)`-element
/// layout. The vector is not normalized here.
/// NOTE(review): callers presumably pass unit vectors — confirm.
pub fn encode_direction(dir: &[f64; 3], n_freq: usize) -> Vec<f64> {
    // Directions and positions share one encoding; only the caller-chosen
    // band count differs.
    encode_position(dir, n_freq)
}
/// Returns the length of a 3-D encoding with `n_freq` bands.
///
/// Each of the three axes contributes its raw value plus a `(sin, cos)` pair
/// per band. This equals the length returned by `encode_position` /
/// `encode_direction` for the same `n_freq`.
#[inline]
pub fn encoding_dim(n_freq: usize) -> usize {
    let per_axis = 2 * n_freq + 1;
    per_axis * 3
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_1_SQRT_2;

    /// `positional_encode` returns the raw value plus one (sin, cos) pair per band.
    #[test]
    fn test_positional_encode_dim() {
        for n_freq in [1_usize, 4, 10] {
            let enc = positional_encode(0.5, n_freq);
            assert_eq!(
                enc.len(),
                1 + 2 * n_freq,
                "n_freq={n_freq}: expected dim {}, got {}",
                1 + 2 * n_freq,
                enc.len()
            );
        }
    }

    /// The 3-D encoding is three per-axis encodings back to back.
    #[test]
    fn test_encode_position_dim() {
        for n_freq in [1_usize, 4, 10] {
            let enc = encode_position(&[0.1, 0.2, 0.3], n_freq);
            assert_eq!(enc.len(), 3 * (1 + 2 * n_freq));
        }
    }

    /// At x = 0 every sine is 0 and every cosine is 1, regardless of frequency.
    #[test]
    fn test_positional_encode_zero() {
        let enc = positional_encode(0.0, 3);
        assert!((enc[0] - 0.0).abs() < 1e-12);
        for k in 0..3 {
            let sin_idx = 1 + 2 * k;
            let cos_idx = 2 + 2 * k;
            assert!(
                (enc[sin_idx] - 0.0).abs() < 1e-12,
                "sin at k={k} should be 0, got {}",
                enc[sin_idx]
            );
            assert!(
                (enc[cos_idx] - 1.0).abs() < 1e-12,
                "cos at k={k} should be 1, got {}",
                enc[cos_idx]
            );
        }
    }

    /// Non-zero input pins the frequency schedule (pi, 2*pi, 4*pi) and the
    /// sin-before-cos ordering, which the x = 0 test cannot distinguish.
    #[test]
    fn test_positional_encode_values() {
        let enc = positional_encode(0.25, 3);
        assert_eq!(enc[0], 0.25);
        // k = 0 -> angle pi/4
        assert!((enc[1] - FRAC_1_SQRT_2).abs() < 1e-12);
        assert!((enc[2] - FRAC_1_SQRT_2).abs() < 1e-12);
        // k = 1 -> angle pi/2
        assert!((enc[3] - 1.0).abs() < 1e-12);
        assert!(enc[4].abs() < 1e-12);
        // k = 2 -> angle pi
        assert!(enc[5].abs() < 1e-12);
        assert!((enc[6] + 1.0).abs() < 1e-12);
    }

    /// Each axis occupies its own contiguous chunk, in x, y, z order.
    #[test]
    fn test_encode_position_layout() {
        let n_freq = 2;
        let per_axis = 1 + 2 * n_freq;
        let pos = [0.1, 0.2, 0.3];
        let enc = encode_position(&pos, n_freq);
        for (axis, &coord) in pos.iter().enumerate() {
            let chunk = &enc[axis * per_axis..(axis + 1) * per_axis];
            assert_eq!(chunk, positional_encode(coord, n_freq).as_slice());
        }
    }

    /// Directions use the same encoding as positions.
    #[test]
    fn test_encode_direction_matches_position() {
        let v = [0.0, 0.6, 0.8];
        assert_eq!(encode_direction(&v, 4), encode_position(&v, 4));
    }

    /// `encoding_dim` must agree with the actual encoder output length.
    #[test]
    fn test_encoding_dim_helper() {
        assert_eq!(encoding_dim(10), 3 * 21);
        assert_eq!(encoding_dim(4), 3 * 9);
        for n_freq in [0_usize, 1, 6] {
            assert_eq!(encoding_dim(n_freq), encode_position(&[1.0, 2.0, 3.0], n_freq).len());
        }
    }
}