1use serde::{Deserialize, Serialize};
8
/// Default embedding width; used as `geometry_dim` by `MeridianGeometryConfig::default()`.
const GEOMETRY_DIM: usize = 64;
/// Number of spatial coordinates per position; sizes the raw Fourier feature
/// count in `FourierPositionalEncoding::encode`.
const NUM_COORDS: usize = 3;
11
/// Minimal dense (affine) layer computing `y = W x + b` on `f32` slices.
#[derive(Debug, Clone)]
struct Linear {
    /// Flattened weight matrix; `forward` reads row `j` (one row per output)
    /// at `weights[j * in_f .. j * in_f + in_f]`.
    weights: Vec<f32>,
    /// Per-output bias; `forward` starts from a copy of this vector.
    bias: Vec<f32>,
    /// Number of input features consumed by `forward`.
    in_f: usize,
    /// Number of output features produced by `forward`.
    out_f: usize,
}
24
25impl Linear {
26 fn new(in_f: usize, out_f: usize, seed: u64) -> Self {
28 let k = (1.0 / in_f as f32).sqrt();
29 Linear {
30 weights: det_uniform(in_f * out_f, -k, k, seed),
31 bias: vec![0.0; out_f],
32 in_f,
33 out_f,
34 }
35 }
36
37 fn forward(&self, x: &[f32]) -> Vec<f32> {
38 debug_assert_eq!(x.len(), self.in_f);
39 let mut y = self.bias.clone();
40 for j in 0..self.out_f {
41 let off = j * self.in_f;
42 let mut s = 0.0f32;
43 for i in 0..self.in_f {
44 s += x[i] * self.weights[off + i];
45 }
46 y[j] += s;
47 }
48 y
49 }
50}
51
/// Deterministically generates `n` pseudo-random `f32` samples between `lo` and `hi`.
///
/// A 64-bit xorshift generator (shift triple 13, 7, 17) is seeded with
/// `seed + 0x9E37_79B9_7F4A_7C15` (wrapping). For every sample the top 24 bits
/// of the state are mapped to a unit value in `[0, 1)` and rescaled as
/// `lo + unit * (hi - lo)`; the unit value is strictly below 1, although
/// floating-point rounding of the final expression may still land on `hi`.
///
/// The same `(n, lo, hi, seed)` always produces the same vector.
///
/// An all-zero state is a fixed point of xorshift and would make every sample
/// equal to `lo`. That happens for exactly one seed (the one whose offset sum
/// wraps to zero), so that case is replaced by a fixed non-zero state. Output
/// for all other seeds is unaffected.
fn det_uniform(n: usize, lo: f32, hi: f32, seed: u64) -> Vec<f32> {
    const SEED_OFFSET: u64 = 0x9E37_79B9_7F4A_7C15;

    let span = hi - lo;
    // 2^24: the sample keeps 24 bits, all exactly representable in an f32.
    let denom = (1u64 << 24) as f32;

    let mut state = seed.wrapping_add(SEED_OFFSET);
    if state == 0 {
        // Escape the degenerate fixed point; any non-zero value works.
        state = SEED_OFFSET;
    }

    (0..n)
        .map(|_| {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            lo + (state >> 40) as f32 / denom * span
        })
        .collect()
}
66
/// In-place rectifier: every element strictly below zero is overwritten with `0.0`.
///
/// Elements that do not compare as `< 0.0` (positives, both zeros, NaN) are
/// left exactly as they were.
fn relu(v: &mut [f32]) {
    v.iter_mut()
        .filter(|value| **value < 0.0)
        .for_each(|value| *value = 0.0);
}
72
/// Shared configuration for the geometry encoder components in this file
/// (`FourierPositionalEncoding`, `DeepSets`, `GeometryEncoder`, `FilmLayer`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeridianGeometryConfig {
    /// Number of frequency bands per coordinate; each band contributes one
    /// sine and one cosine feature in `FourierPositionalEncoding::encode`.
    pub n_frequencies: usize,
    /// Multiplier applied to every coordinate before the Fourier features are computed.
    pub scale: f32,
    /// Width of the positional encoding output and of the linear layers built
    /// by `DeepSets::new` and `FilmLayer::new`.
    pub geometry_dim: usize,
    /// Base seed; each linear layer derives its own seed from it via a
    /// `wrapping_add` offset.
    pub seed: u64,
}
89
90impl Default for MeridianGeometryConfig {
91 fn default() -> Self {
92 MeridianGeometryConfig { n_frequencies: 10, scale: 1.0, geometry_dim: GEOMETRY_DIM, seed: 42 }
93 }
94}
95
/// Sinusoidal (Fourier-feature) encoder for a 3-coordinate position.
pub struct FourierPositionalEncoding {
    /// Number of frequency bands evaluated per coordinate.
    n_frequencies: usize,
    /// Factor each coordinate is multiplied by before encoding.
    scale: f32,
    /// Exact length of the vector returned by `encode`.
    output_dim: usize,
}
109
110impl FourierPositionalEncoding {
111 pub fn new(cfg: &MeridianGeometryConfig) -> Self {
113 FourierPositionalEncoding { n_frequencies: cfg.n_frequencies, scale: cfg.scale, output_dim: cfg.geometry_dim }
114 }
115
116 pub fn encode(&self, coords: &[f32; 3]) -> Vec<f32> {
118 let raw = NUM_COORDS * 2 * self.n_frequencies;
119 let mut enc = Vec::with_capacity(raw.max(self.output_dim));
120 for &c in coords {
121 let sc = c * self.scale;
122 for l in 0..self.n_frequencies {
123 let f = (2.0f32).powi(l as i32) * std::f32::consts::PI * sc;
124 enc.push(f.sin());
125 enc.push(f.cos());
126 }
127 }
128 enc.resize(self.output_dim, 0.0);
129 enc
130 }
131}
132
/// Permutation-invariant set encoder: a per-element layer, mean pooling, then
/// a layer applied to the pooled vector.
pub struct DeepSets {
    /// Layer applied to every set element before pooling.
    phi: Linear,
    /// Layer applied to the pooled (mean) vector.
    rho: Linear,
    /// Width of the element embeddings and of the output.
    dim: usize,
}
143
144impl DeepSets {
145 pub fn new(cfg: &MeridianGeometryConfig) -> Self {
147 let d = cfg.geometry_dim;
148 DeepSets { phi: Linear::new(d, d, cfg.seed.wrapping_add(1)), rho: Linear::new(d, d, cfg.seed.wrapping_add(2)), dim: d }
149 }
150
151 pub fn encode(&self, ap_embeddings: &[Vec<f32>]) -> Vec<f32> {
153 assert!(!ap_embeddings.is_empty(), "DeepSets: input set must be non-empty");
154 let n = ap_embeddings.len() as f32;
155 let mut pooled = vec![0.0f32; self.dim];
156 for emb in ap_embeddings {
157 debug_assert_eq!(emb.len(), self.dim);
158 let mut t = self.phi.forward(emb);
159 relu(&mut t);
160 for (p, v) in pooled.iter_mut().zip(t.iter()) { *p += *v; }
161 }
162 for p in pooled.iter_mut() { *p /= n; }
163 let mut out = self.rho.forward(&pooled);
164 relu(&mut out);
165 out
166 }
167}
168
/// Pipeline turning a set of 3-D positions into one geometry vector.
pub struct GeometryEncoder {
    /// Encodes each individual position.
    pos_embed: FourierPositionalEncoding,
    /// Aggregates the per-position encodings into a single vector.
    set_encoder: DeepSets,
}
178
179impl GeometryEncoder {
180 pub fn new(cfg: &MeridianGeometryConfig) -> Self {
182 GeometryEncoder { pos_embed: FourierPositionalEncoding::new(cfg), set_encoder: DeepSets::new(cfg) }
183 }
184
185 pub fn encode(&self, ap_positions: &[[f32; 3]]) -> Vec<f32> {
187 let embs: Vec<Vec<f32>> = ap_positions.iter().map(|p| self.pos_embed.encode(p)).collect();
188 self.set_encoder.encode(&embs)
189 }
190}
191
/// Feature-wise affine modulation layer: `out = gamma * features + beta`,
/// with `gamma` and `beta` computed from a geometry vector.
pub struct FilmLayer {
    /// Projects the geometry vector to the per-feature multiplier `gamma`.
    gamma_proj: Linear,
    /// Projects the geometry vector to the per-feature offset `beta`.
    beta_proj: Linear,
}
201
202impl FilmLayer {
203 pub fn new(cfg: &MeridianGeometryConfig) -> Self {
205 let d = cfg.geometry_dim;
206 let mut gamma_proj = Linear::new(d, d, cfg.seed.wrapping_add(3));
207 for b in gamma_proj.bias.iter_mut() { *b = 1.0; }
208 FilmLayer { gamma_proj, beta_proj: Linear::new(d, d, cfg.seed.wrapping_add(4)) }
209 }
210
211 pub fn modulate(&self, features: &[f32], geometry: &[f32]) -> Vec<f32> {
213 let gamma = self.gamma_proj.forward(geometry);
214 let beta = self.beta_proj.forward(geometry);
215 features.iter().zip(gamma.iter()).zip(beta.iter()).map(|((&f, &g), &b)| g * f + b).collect()
216 }
217}
218
// Unit tests for the positional encoding, set encoder, FiLM layer,
// configuration and the internal `Linear` layer.
#[cfg(test)]
mod tests {
    use super::*;

    // Shorthand for the default configuration used by most tests.
    fn cfg() -> MeridianGeometryConfig { MeridianGeometryConfig::default() }

    // The encoding length equals `geometry_dim` (64 by default).
    #[test]
    fn fourier_output_dimension_is_64() {
        let c = cfg();
        let out = FourierPositionalEncoding::new(&c).encode(&[1.0, 2.0, 3.0]);
        assert_eq!(out.len(), c.geometry_dim);
    }

    // Moving along any single axis must change the encoding.
    #[test]
    fn fourier_different_coords_different_outputs() {
        let enc = FourierPositionalEncoding::new(&cfg());
        let a = enc.encode(&[0.0, 0.0, 0.0]);
        let b = enc.encode(&[1.0, 0.0, 0.0]);
        let c = enc.encode(&[0.0, 1.0, 0.0]);
        let d = enc.encode(&[0.0, 0.0, 1.0]);
        assert_ne!(a, b); assert_ne!(a, c); assert_ne!(a, d); assert_ne!(b, c);
    }

    // Every feature is a sine, a cosine or zero padding, so |v| <= 1.
    #[test]
    fn fourier_values_bounded() {
        let out = FourierPositionalEncoding::new(&cfg()).encode(&[5.5, -3.2, 0.1]);
        for &v in &out { assert!(v.abs() <= 1.0 + 1e-6, "got {v}"); }
    }

    // Reordering the set elements must not change the pooled output
    // (up to floating-point tolerance).
    #[test]
    fn deepsets_permutation_invariant() {
        let c = cfg();
        let enc = FourierPositionalEncoding::new(&c);
        let ds = DeepSets::new(&c);
        let (a, b, d) = (enc.encode(&[1.0,0.0,0.0]), enc.encode(&[0.0,2.0,0.0]), enc.encode(&[0.0,0.0,3.0]));
        let abc = ds.encode(&[a.clone(), b.clone(), d.clone()]);
        let cba = ds.encode(&[d.clone(), b.clone(), a.clone()]);
        let bac = ds.encode(&[b.clone(), a.clone(), d.clone()]);
        for i in 0..c.geometry_dim {
            assert!((abc[i] - cba[i]).abs() < 1e-5, "dim {i}: abc={} cba={}", abc[i], cba[i]);
            assert!((abc[i] - bac[i]).abs() < 1e-5, "dim {i}: abc={} bac={}", abc[i], bac[i]);
        }
    }

    // Sets of 1, 3 and 6 elements all yield `geometry_dim` outputs, and
    // different sets yield different outputs.
    #[test]
    fn deepsets_variable_ap_count() {
        let c = cfg();
        let enc = FourierPositionalEncoding::new(&c);
        let ds = DeepSets::new(&c);
        let one = ds.encode(&[enc.encode(&[1.0,0.0,0.0])]);
        assert_eq!(one.len(), c.geometry_dim);
        let three = ds.encode(&[enc.encode(&[1.0,0.0,0.0]), enc.encode(&[0.0,2.0,0.0]), enc.encode(&[0.0,0.0,3.0])]);
        assert_eq!(three.len(), c.geometry_dim);
        let six = ds.encode(&[
            enc.encode(&[1.0,0.0,0.0]), enc.encode(&[0.0,2.0,0.0]), enc.encode(&[0.0,0.0,3.0]),
            enc.encode(&[-1.0,0.0,0.0]), enc.encode(&[0.0,-2.0,0.0]), enc.encode(&[0.0,0.0,-3.0]),
        ]);
        assert_eq!(six.len(), c.geometry_dim);
        assert_ne!(one, three); assert_ne!(three, six);
    }

    // Full pipeline: positions in, finite `geometry_dim`-wide vector out.
    #[test]
    fn geometry_encoder_end_to_end() {
        let c = cfg();
        let g = GeometryEncoder::new(&c).encode(&[[1.0,0.0,2.5],[0.0,3.0,2.5],[-2.0,1.0,2.5]]);
        assert_eq!(g.len(), c.geometry_dim);
        for &v in &g { assert!(v.is_finite()); }
    }

    // A single-position set is accepted.
    #[test]
    fn geometry_encoder_single_ap() {
        let c = cfg();
        assert_eq!(GeometryEncoder::new(&c).encode(&[[0.0,0.0,0.0]]).len(), c.geometry_dim);
    }

    // With zero geometry, gamma equals its bias (1) and beta its bias (0),
    // so the features pass through unchanged.
    #[test]
    fn film_identity_when_geometry_zero() {
        let c = cfg();
        let film = FilmLayer::new(&c);
        let feat = vec![1.0f32; c.geometry_dim];
        let out = film.modulate(&feat, &vec![0.0f32; c.geometry_dim]);
        assert_eq!(out.len(), c.geometry_dim);
        for i in 0..c.geometry_dim {
            assert!((out[i] - feat[i]).abs() < 1e-5, "dim {i}: expected {}, got {}", feat[i], out[i]);
        }
    }

    // Non-zero geometry must alter at least one feature and keep values finite.
    #[test]
    fn film_nontrivial_modulation() {
        let c = cfg();
        let film = FilmLayer::new(&c);
        let feat: Vec<f32> = (0..c.geometry_dim).map(|i| i as f32 * 0.1).collect();
        let geom: Vec<f32> = (0..c.geometry_dim).map(|i| (i as f32 - 32.0) * 0.01).collect();
        let out = film.modulate(&feat, &geom);
        assert_eq!(out.len(), c.geometry_dim);
        assert!(out.iter().zip(feat.iter()).any(|(o, f)| (o - f).abs() > 1e-6));
        for &v in &out { assert!(v.is_finite()); }
    }

    // With zeroed weights, gamma and beta equal their biases regardless of the
    // geometry input, so the output is exactly `gamma * feature + beta`.
    #[test]
    fn film_explicit_gamma_beta() {
        let c = MeridianGeometryConfig { geometry_dim: 4, ..cfg() };
        let mut film = FilmLayer::new(&c);
        film.gamma_proj.weights = vec![0.0; 16];
        film.gamma_proj.bias = vec![2.0, 3.0, 0.5, 1.0];
        film.beta_proj.weights = vec![0.0; 16];
        film.beta_proj.bias = vec![10.0, 20.0, 30.0, 40.0];
        let out = film.modulate(&[1.0, 2.0, 3.0, 4.0], &[999.0; 4]);
        let exp = [12.0, 26.0, 31.5, 44.0];
        for i in 0..4 { assert!((out[i] - exp[i]).abs() < 1e-5, "dim {i}"); }
    }

    // Pins the documented default values.
    #[test]
    fn config_defaults() {
        let c = MeridianGeometryConfig::default();
        assert_eq!(c.n_frequencies, 10);
        assert!((c.scale - 1.0).abs() < 1e-6);
        assert_eq!(c.geometry_dim, 64);
        assert_eq!(c.seed, 42);
    }

    // JSON serialization followed by deserialization preserves every field.
    #[test]
    fn config_serde_round_trip() {
        let c = MeridianGeometryConfig { n_frequencies: 8, scale: 0.5, geometry_dim: 32, seed: 123 };
        let j = serde_json::to_string(&c).unwrap();
        let d: MeridianGeometryConfig = serde_json::from_str(&j).unwrap();
        assert_eq!(d.n_frequencies, 8); assert!((d.scale - 0.5).abs() < 1e-6);
        assert_eq!(d.geometry_dim, 32); assert_eq!(d.seed, 123);
    }

    // `forward` returns one value per output feature.
    #[test]
    fn linear_forward_dim() {
        assert_eq!(Linear::new(8, 4, 0).forward(&vec![1.0; 8]).len(), 4);
    }

    // A zero input contributes nothing, leaving only the bias.
    #[test]
    fn linear_zero_input_gives_bias() {
        let lin = Linear::new(4, 3, 0);
        let out = lin.forward(&[0.0; 4]);
        for i in 0..3 { assert!((out[i] - lin.bias[i]).abs() < 1e-6); }
    }
}