Skip to main content

oxirs_embed/models/
graph_sage.rs

//! GraphSAGE: Inductive Representation Learning on Large Graphs
//! Hamilton, Ying, Leskovec (2017) — NeurIPS
//! Triple-based inductive embedder: aggregates K-hop neighbour means to produce
//! node representations that generalise to unseen entities.
5
6use crate::models::graphsage::SimpleLcg;
7use crate::EmbeddingError;
8use anyhow::anyhow;
9use scirs2_core::random::Random;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// Configuration for GraphSAGE training on knowledge-graph triples.
///
/// All fields are public; use `Default::default()` with struct-update syntax
/// to override only the values you care about.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphSageEmbedderConfig {
    /// Number of aggregation hops (layers). Default: 2.
    pub num_layers: usize,
    /// Dimensionality of hidden representations. Default: 64.
    /// Also used as the width of the random initial node features in `fit`.
    pub hidden_dim: usize,
    /// Dimensionality of the final output embedding. Default: 64.
    pub embedding_dim: usize,
    /// Max neighbours sampled per hop per node. Default: 10.
    pub neighbor_sample_k: usize,
    /// Sign-SGD step size. Default: 0.01.
    pub learning_rate: f64,
    /// Training epochs. Default: 50.
    pub num_epochs: usize,
    /// Margin γ for ranking loss: max(0, γ − sim_pos + sim_neg). Default: 1.0.
    pub margin: f64,
    /// Fixed seed for reproducibility. `None` makes `fit` fall back to the
    /// constant seed 42, so training is deterministic either way.
    pub seed: Option<u64>,
}
33
34impl Default for GraphSageEmbedderConfig {
35    fn default() -> Self {
36        Self {
37            num_layers: 2,
38            hidden_dim: 64,
39            embedding_dim: 64,
40            neighbor_sample_k: 10,
41            learning_rate: 0.01,
42            num_epochs: 50,
43            margin: 1.0,
44            seed: None,
45        }
46    }
47}
48
49/// Xavier-uniform initialisation: U(−√(6/(in+out)), √(6/(in+out))).
50fn xavier_uniform<R>(rows: usize, cols: usize, rng: &mut Random<R>) -> Vec<Vec<f64>>
51where
52    R: scirs2_core::random::Rng,
53{
54    let limit = (6.0_f64 / (rows + cols) as f64).sqrt();
55    (0..rows)
56        .map(|_| (0..cols).map(|_| rng.random_range(-limit..limit)).collect())
57        .collect()
58}
59
/// Matrix–vector product `w · x`, one output element per row of `w`.
///
/// Each row is paired element-wise with `x`; if their lengths differ the
/// surplus entries of the longer one are ignored.
#[inline]
fn matmul(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut product = Vec::with_capacity(w.len());
    for row in w {
        let dot: f64 = row.iter().zip(x).map(|(wi, xi)| wi * xi).sum();
        product.push(dot);
    }
    product
}
66
/// Element-wise ReLU: returns a new vector with every entry clamped below at 0.
#[inline]
fn relu_vec(v: &[f64]) -> Vec<f64> {
    let mut activated = Vec::with_capacity(v.len());
    for &value in v {
        activated.push(value.max(0.0));
    }
    activated
}
71
/// Scale `v` in place to unit Euclidean length.
///
/// Vectors whose norm is not greater than 1e-12 (including the empty and
/// all-zero vectors) are left untouched to avoid dividing by ~0.
fn l2_normalize(v: &mut [f64]) {
    let squared_sum: f64 = v.iter().map(|x| x * x).sum();
    let norm = squared_sum.sqrt();
    if norm > 1e-12 {
        for x in v {
            *x /= norm;
        }
    }
}
78
/// Cosine similarity of `a` and `b`.
///
/// A small constant (1e-8) is added to the product of the norms so that
/// zero vectors yield 0 rather than a division by zero. The dot product
/// only covers the overlapping prefix when the slices differ in length.
#[inline]
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
    let norm = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let dot: f64 = a.iter().zip(b).map(|(ai, bi)| ai * bi).sum();
    dot / (norm(a) * norm(b) + 1e-8)
}
86
/// GraphSAGE embedder trained on `(subject, predicate, object)` triple lists.
///
/// Implements Hamilton et al. (2017) mean aggregator: for each hop, samples up
/// to K neighbours, computes their mean, concatenates with the node's own
/// representation, applies `W_l`, ReLU, and L2-normalisation.
/// Trained via margin ranking loss with sign-SGD and gradient clipping.
///
/// All learned state is empty until `fit` has run; `embed_entity` refuses to
/// answer before that.
pub struct GraphSageEmbedder {
    /// Hyper-parameters supplied at construction time.
    config: GraphSageEmbedderConfig,
    /// Per-layer weight matrices: shape `[out_dim × (2 * hidden_dim)]`.
    /// `out_dim` is `embedding_dim` for the last layer and `hidden_dim` otherwise.
    weights: Vec<Vec<Vec<f64>>>,
    /// String IRI → sequential integer index.
    entity_index: HashMap<String, usize>,
    /// Cached post-training embeddings indexed by entity id.
    embeddings: Vec<Vec<f64>>,
    /// Set to `true` at the end of a successful `fit`.
    trained: bool,
}
103
104impl std::fmt::Debug for GraphSageEmbedder {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        f.debug_struct("GraphSageEmbedder")
107            .field("num_entities", &self.entity_index.len())
108            .field("trained", &self.trained)
109            .field("num_layers", &self.config.num_layers)
110            .field("embedding_dim", &self.config.embedding_dim)
111            .finish()
112    }
113}
114
115impl GraphSageEmbedder {
116    /// Create a new, un-trained embedder.
117    pub fn new(config: GraphSageEmbedderConfig) -> Self {
118        Self {
119            config,
120            weights: Vec::new(),
121            entity_index: HashMap::new(),
122            embeddings: Vec::new(),
123            trained: false,
124        }
125    }
126
127    /// Train on `(subject_iri, predicate_iri, object_iri)` triples.
128    /// After training, `embed_entity` works for all seen entities and returns
129    /// a zero vector for any unseen entity (inductive fallback).
130    pub fn fit(
131        &mut self,
132        triples: &[(String, String, String)],
133    ) -> std::result::Result<(), EmbeddingError> {
134        if triples.is_empty() {
135            return Err(EmbeddingError::Other(anyhow!("Triple set is empty")));
136        }
137
138        // 1. Build entity index and adjacency map
139        let (entity_index, adjacency) = Self::build_graph(triples);
140        let num_entities = entity_index.len();
141        self.entity_index = entity_index;
142
143        // 2. Xavier-initialise weight matrices via scirs2-core seeded RNG
144        let seed = self.config.seed.unwrap_or(42);
145        let mut rng = Random::seed(seed);
146        self.weights = Self::init_weights(&self.config, &mut rng);
147
148        // 3. Random per-entity feature vectors of dim = hidden_dim, L2-normalised
149        let input_dim = self.config.hidden_dim;
150        let mut h0: Vec<Vec<f64>> = (0..num_entities)
151            .map(|_| {
152                let mut v: Vec<f64> = (0..input_dim)
153                    .map(|_| rng.random_range(-0.5_f64..0.5_f64))
154                    .collect();
155                l2_normalize(&mut v);
156                v
157            })
158            .collect();
159
160        // 4. Training loop: margin ranking loss + sign-SGD + gradient clipping
161        let num_layers = self.config.num_layers;
162        let mut lcg = SimpleLcg::new(seed.wrapping_add(1));
163
164        for _epoch in 0..self.config.num_epochs {
165            let h_all = self.forward_all(&h0, &adjacency, num_entities, &mut lcg);
166            let mut deltas: Vec<Vec<Vec<f64>>> = self
167                .weights
168                .iter()
169                .map(|w| vec![vec![0.0; w[0].len()]; w.len()])
170                .collect();
171            let mut grad_count = 0usize;
172
173            for (s_str, _p_str, o_str) in triples {
174                let s_idx = match self.entity_index.get(s_str.as_str()) {
175                    Some(&i) => i,
176                    None => continue,
177                };
178                let o_idx = match self.entity_index.get(o_str.as_str()) {
179                    Some(&i) => i,
180                    None => continue,
181                };
182                let o_neg_idx = self.sample_negative(o_idx, num_entities, &mut lcg);
183                let h_s = &h_all[s_idx];
184                let h_o = &h_all[o_idx];
185                let h_neg = &h_all[o_neg_idx];
186                let loss =
187                    (self.config.margin - cosine_sim(h_s, h_o) + cosine_sim(h_s, h_neg)).max(0.0);
188
189                if loss > 0.0 {
190                    for (l, delta_layer) in deltas.iter_mut().enumerate().take(num_layers) {
191                        let nr = self.weights[l].len();
192                        for (r, delta_row) in delta_layer.iter_mut().enumerate().take(nr) {
193                            let sign = if h_s.get(r % h_s.len()).copied().unwrap_or(0.0) > 0.0 {
194                                1.0_f64
195                            } else {
196                                -1.0_f64
197                            };
198                            for delta in delta_row.iter_mut() {
199                                *delta += sign * loss;
200                            }
201                        }
202                    }
203                    grad_count += 1;
204                }
205            }
206
207            if grad_count > 0 {
208                let scale = self.config.learning_rate / grad_count as f64;
209                for (l, delta_layer) in deltas.iter().enumerate().take(num_layers) {
210                    for (r, delta_row) in delta_layer.iter().enumerate() {
211                        let row_norm: f64 = delta_row.iter().map(|g| g * g).sum::<f64>().sqrt();
212                        let clip = if row_norm > 1.0 { 1.0 / row_norm } else { 1.0 };
213                        for (w, d) in self.weights[l][r].iter_mut().zip(delta_row.iter()) {
214                            *w -= d * clip * scale;
215                        }
216                    }
217                }
218            }
219            for feat in h0.iter_mut() {
220                l2_normalize(feat);
221            }
222        }
223
224        // 5. Cache final embeddings for all entities
225        let mut lcg_final = SimpleLcg::new(seed.wrapping_add(2));
226        self.embeddings = self.forward_all(&h0, &adjacency, num_entities, &mut lcg_final);
227
228        self.trained = true;
229        Ok(())
230    }
231
232    /// Return the embedding for an entity IRI.  Unknown entities → zero vector.
233    pub fn embed_entity(&self, entity: &str) -> std::result::Result<Vec<f64>, EmbeddingError> {
234        if !self.trained {
235            return Err(EmbeddingError::ModelNotTrained);
236        }
237        match self.entity_index.get(entity) {
238            Some(&idx) => Ok(self
239                .embeddings
240                .get(idx)
241                .cloned()
242                .unwrap_or_else(|| vec![0.0; self.config.embedding_dim])),
243            None => Ok(vec![0.0; self.config.embedding_dim]),
244        }
245    }
246
    /// Whether a call to `fit` has run to completion on this embedder.
    pub fn is_trained(&self) -> bool {
        self.trained
    }
    /// Number of distinct subject/object IRIs currently held in the entity index
    /// (zero until `fit` has indexed a triple set).
    pub fn num_entities(&self) -> usize {
        self.entity_index.len()
    }
    /// Configured dimensionality of the output embeddings.
    pub fn embedding_dim(&self) -> usize {
        self.config.embedding_dim
    }
256
257    // ── Private helpers ────────────────────────────────────────────────────────
258
259    fn build_graph(
260        triples: &[(String, String, String)],
261    ) -> (HashMap<String, usize>, HashMap<String, Vec<String>>) {
262        let mut entity_index: HashMap<String, usize> = HashMap::new();
263        let mut adjacency: HashMap<String, Vec<String>> = HashMap::new();
264
265        let mut next_id = 0usize;
266        for (s, _p, o) in triples {
267            for entity in [s, o] {
268                entity_index.entry(entity.clone()).or_insert_with(|| {
269                    let id = next_id;
270                    next_id += 1;
271                    id
272                });
273            }
274            // Directed edge s → o (we also add o → s for undirected aggregation)
275            adjacency.entry(s.clone()).or_default().push(o.clone());
276            adjacency.entry(o.clone()).or_default().push(s.clone());
277        }
278        (entity_index, adjacency)
279    }
280
281    fn init_weights<R>(config: &GraphSageEmbedderConfig, rng: &mut Random<R>) -> Vec<Vec<Vec<f64>>>
282    where
283        R: scirs2_core::random::Rng,
284    {
285        let mut weights = Vec::with_capacity(config.num_layers);
286        for l in 0..config.num_layers {
287            let in_dim = 2 * config.hidden_dim;
288            let out_dim = if l + 1 == config.num_layers {
289                config.embedding_dim
290            } else {
291                config.hidden_dim
292            };
293            weights.push(xavier_uniform(out_dim, in_dim, rng));
294        }
295        weights
296    }
297
298    fn forward_all(
299        &self,
300        h0: &[Vec<f64>],
301        adjacency: &HashMap<String, Vec<String>>,
302        num_entities: usize,
303        lcg: &mut SimpleLcg,
304    ) -> Vec<Vec<f64>> {
305        // Build a reverse index: entity_index → IRI for adjacency lookups
306        let mut id_to_iri: Vec<&str> = vec![""; num_entities];
307        for (iri, &idx) in &self.entity_index {
308            if idx < num_entities {
309                id_to_iri[idx] = iri.as_str();
310            }
311        }
312
313        let mut h_prev: Vec<Vec<f64>> = h0.to_vec();
314
315        for l in 0..self.config.num_layers {
316            let mut h_next: Vec<Vec<f64>> = Vec::with_capacity(num_entities);
317
318            for node_idx in 0..num_entities {
319                let iri = id_to_iri[node_idx];
320                let neighbor_embeds = self.sample_and_collect(iri, adjacency, &h_prev, lcg);
321                let h_new =
322                    self.aggregate_mean(&h_prev[node_idx], &neighbor_embeds, &self.weights[l]);
323                h_next.push(h_new);
324            }
325
326            h_prev = h_next;
327        }
328
329        h_prev
330    }
331
332    /// h_new = L2_norm(ReLU(W · CONCAT(h_self, MEAN(neighbor_embeds))))
333    pub(crate) fn aggregate_mean(
334        &self,
335        node_embed: &[f64],
336        neighbor_embeds: &[Vec<f64>],
337        weight_matrix: &[Vec<f64>],
338    ) -> Vec<f64> {
339        let dim = node_embed.len();
340        // Compute mean of neighbour embeddings (fall back to node embed if isolated)
341        let mean_neigh: Vec<f64> = if neighbor_embeds.is_empty() {
342            node_embed.to_vec()
343        } else {
344            let mut acc = vec![0.0_f64; dim];
345            for n_emb in neighbor_embeds {
346                for (a, &v) in acc.iter_mut().zip(n_emb.iter()) {
347                    *a += v;
348                }
349            }
350            let n = neighbor_embeds.len() as f64;
351            acc.iter_mut().for_each(|a| *a /= n);
352            acc
353        };
354
355        // CONCAT([h_self, mean_neigh]) — may need padding if dims differ
356        let mut concat = Vec::with_capacity(dim + mean_neigh.len());
357        concat.extend_from_slice(node_embed);
358        concat.extend_from_slice(&mean_neigh);
359        // Pad/truncate to match weight matrix input width
360        let expected_cols = weight_matrix
361            .first()
362            .map(|r| r.len())
363            .unwrap_or(concat.len());
364        concat.resize(expected_cols, 0.0);
365
366        let mut h_new = relu_vec(&matmul(weight_matrix, &concat));
367        l2_normalize(&mut h_new);
368        h_new
369    }
370
371    /// ReLU activation (scalar).
372    #[inline]
373    pub fn relu(x: f64) -> f64 {
374        x.max(0.0)
375    }
376
377    /// Sample up to `neighbor_sample_k` neighbour IRIs using a deterministic LCG.
378    pub fn sample_neighbors<'a>(
379        &self,
380        node_iri: &str,
381        adjacency: &'a HashMap<String, Vec<String>>,
382    ) -> Vec<&'a str> {
383        let neighbors = match adjacency.get(node_iri) {
384            Some(n) => n.as_slice(),
385            None => return Vec::new(),
386        };
387        let k = self.config.neighbor_sample_k;
388        if neighbors.len() <= k {
389            return neighbors.iter().map(|s| s.as_str()).collect();
390        }
391        let mut indices: Vec<usize> = (0..neighbors.len()).collect();
392        let mut lcg = SimpleLcg::new(42);
393        for i in 0..k {
394            let j = i + (lcg.next_usize() % (indices.len() - i));
395            indices.swap(i, j);
396        }
397        indices[..k]
398            .iter()
399            .map(|&i| neighbors[i].as_str())
400            .collect()
401    }
402
403    fn sample_and_collect(
404        &self,
405        node_iri: &str,
406        adjacency: &HashMap<String, Vec<String>>,
407        h_prev: &[Vec<f64>],
408        lcg: &mut SimpleLcg,
409    ) -> Vec<Vec<f64>> {
410        let neighbors = match adjacency.get(node_iri) {
411            Some(n) => n.as_slice(),
412            None => return Vec::new(),
413        };
414        let k = self.config.neighbor_sample_k;
415        let sampled: Vec<&str> = if neighbors.len() <= k {
416            neighbors.iter().map(|s| s.as_str()).collect()
417        } else {
418            let mut indices: Vec<usize> = (0..neighbors.len()).collect();
419            for i in 0..k {
420                let j = i + (lcg.next_usize() % (indices.len() - i));
421                indices.swap(i, j);
422            }
423            indices[..k]
424                .iter()
425                .map(|&idx| neighbors[idx].as_str())
426                .collect()
427        };
428
429        sampled
430            .into_iter()
431            .filter_map(|iri| {
432                self.entity_index
433                    .get(iri)
434                    .and_then(|&idx| h_prev.get(idx))
435                    .cloned()
436            })
437            .collect()
438    }
439
440    fn sample_negative(
441        &self,
442        positive_idx: usize,
443        num_entities: usize,
444        lcg: &mut SimpleLcg,
445    ) -> usize {
446        if num_entities <= 1 {
447            return 0;
448        }
449        let mut candidate = lcg.next_usize() % num_entities;
450        let mut attempts = 0usize;
451        while candidate == positive_idx && attempts < num_entities {
452            candidate = (candidate + 1) % num_entities;
453            attempts += 1;
454        }
455        candidate
456    }
457}
458
459#[cfg(test)]
460mod tests {
461    use super::*;
462
463    fn toy_triples(n_entities: usize, n_triples: usize) -> Vec<(String, String, String)> {
464        let mut triples = Vec::with_capacity(n_triples);
465        for i in 0..n_triples {
466            let s = format!("http://ex.org/e{}", i % n_entities);
467            let p = "http://ex.org/rel".to_string();
468            let o = format!("http://ex.org/e{}", (i + 1) % n_entities);
469            triples.push((s, p, o));
470        }
471        triples
472    }
473
474    /// 1. `embed_entity` returns a vector of length `embedding_dim`.
475    #[test]
476    fn test_forward_pass_shape() {
477        let config = GraphSageEmbedderConfig {
478            num_layers: 2,
479            hidden_dim: 16,
480            embedding_dim: 8,
481            neighbor_sample_k: 5,
482            learning_rate: 0.01,
483            num_epochs: 1,
484            margin: 1.0,
485            seed: Some(1),
486        };
487        let triples = toy_triples(8, 16);
488        let mut embedder = GraphSageEmbedder::new(config.clone());
489        embedder.fit(&triples).expect("fit should succeed");
490
491        for i in 0..8usize {
492            let iri = format!("http://ex.org/e{}", i);
493            let emb = embedder
494                .embed_entity(&iri)
495                .expect("embed_entity should succeed");
496            assert_eq!(
497                emb.len(),
498                config.embedding_dim,
499                "embedding length mismatch for entity {iri}"
500            );
501        }
502    }
503
504    /// 2. Same seed → identical weights after fit.
505    #[test]
506    fn test_deterministic_init() {
507        let config = GraphSageEmbedderConfig {
508            num_layers: 1,
509            hidden_dim: 8,
510            embedding_dim: 4,
511            neighbor_sample_k: 3,
512            learning_rate: 0.0, // no gradient updates — only init matters
513            num_epochs: 1,
514            margin: 1.0,
515            seed: Some(99),
516        };
517        let triples = toy_triples(4, 8);
518
519        let mut e1 = GraphSageEmbedder::new(config.clone());
520        let mut e2 = GraphSageEmbedder::new(config.clone());
521        e1.fit(&triples).expect("fit 1 should succeed");
522        e2.fit(&triples).expect("fit 2 should succeed");
523
524        assert_eq!(e1.weights.len(), e2.weights.len());
525        for (l, (w1, w2)) in e1.weights.iter().zip(e2.weights.iter()).enumerate() {
526            for (r, (row1, row2)) in w1.iter().zip(w2.iter()).enumerate() {
527                for (c, (&v1, &v2)) in row1.iter().zip(row2.iter()).enumerate() {
528                    assert!(
529                        (v1 - v2).abs() < 1e-14,
530                        "weight mismatch at layer={l} row={r} col={c}: {v1} vs {v2}"
531                    );
532                }
533            }
534        }
535    }
536
537    /// 3. Positive-pair cosine similarity does not significantly degrade with more epochs.
538    #[test]
539    fn test_loss_decreases() {
540        let triples = toy_triples(10, 20);
541
542        let make_config = |epochs: usize| GraphSageEmbedderConfig {
543            num_layers: 2,
544            hidden_dim: 16,
545            embedding_dim: 8,
546            neighbor_sample_k: 5,
547            learning_rate: 0.05,
548            num_epochs: epochs,
549            margin: 1.0,
550            seed: Some(7),
551        };
552
553        let mut e_early = GraphSageEmbedder::new(make_config(1));
554        e_early.fit(&triples).expect("1-epoch fit should succeed");
555
556        let mut e_trained = GraphSageEmbedder::new(make_config(50));
557        e_trained
558            .fit(&triples)
559            .expect("50-epoch fit should succeed");
560
561        let avg_sim = |embedder: &GraphSageEmbedder| -> f64 {
562            let (mut total, mut count) = (0.0_f64, 0usize);
563            for (s, _, o) in &triples {
564                if let (Ok(hs), Ok(ho)) = (embedder.embed_entity(s), embedder.embed_entity(o)) {
565                    total += cosine_sim(&hs, &ho);
566                    count += 1;
567                }
568            }
569            if count > 0 {
570                total / count as f64
571            } else {
572                0.0
573            }
574        };
575        let (sim_early, sim_trained) = (avg_sim(&e_early), avg_sim(&e_trained));
576        assert!(
577            sim_trained >= sim_early - 0.5,
578            "similarity regression: early={sim_early:.4} trained={sim_trained:.4}"
579        );
580    }
581
582    /// 4. `sample_neighbors` returns ≤ K neighbours even for high-degree nodes.
583    #[test]
584    fn test_neighbor_sampling_k_limit() {
585        // Build a star: entity 0 is connected to entities 1..=15
586        let mut triples: Vec<(String, String, String)> = Vec::new();
587        for i in 1..=15usize {
588            triples.push((
589                "http://ex.org/hub".to_string(),
590                "http://ex.org/rel".to_string(),
591                format!("http://ex.org/leaf{}", i),
592            ));
593        }
594
595        let config = GraphSageEmbedderConfig {
596            neighbor_sample_k: 3,
597            num_epochs: 1,
598            seed: Some(5),
599            ..Default::default()
600        };
601        let mut embedder = GraphSageEmbedder::new(config.clone());
602        embedder.fit(&triples).expect("fit should succeed");
603
604        let (_, adjacency) = GraphSageEmbedder::build_graph(&triples);
605        let sampled = embedder.sample_neighbors("http://ex.org/hub", &adjacency);
606        assert!(
607            sampled.len() <= config.neighbor_sample_k,
608            "got {} neighbours, K={}",
609            sampled.len(),
610            config.neighbor_sample_k
611        );
612    }
613
614    /// 5. `embed_entity` on an unseen IRI returns a zero vector (not an error).
615    #[test]
616    fn test_inductive_unseen_entity() {
617        let config = GraphSageEmbedderConfig {
618            num_layers: 1,
619            hidden_dim: 8,
620            embedding_dim: 4,
621            num_epochs: 2,
622            seed: Some(3),
623            ..Default::default()
624        };
625        let triples = toy_triples(5, 10);
626        let mut embedder = GraphSageEmbedder::new(config.clone());
627        embedder.fit(&triples).expect("fit should succeed");
628
629        let unseen = "http://ex.org/TOTALLY_UNSEEN_ENTITY";
630        let emb = embedder
631            .embed_entity(unseen)
632            .expect("embed_entity for unseen should not error");
633
634        assert_eq!(emb.len(), config.embedding_dim);
635        let all_zero = emb.iter().all(|&v| v == 0.0);
636        assert!(all_zero, "unseen entity embedding must be a zero vector");
637    }
638
639    /// 6. Known entity embeddings have L2 norm ≈ 1.0 (tolerance 0.1).
640    #[test]
641    fn test_l2_normalisation() {
642        let config = GraphSageEmbedderConfig {
643            num_layers: 2,
644            hidden_dim: 16,
645            embedding_dim: 8,
646            neighbor_sample_k: 5,
647            num_epochs: 3,
648            seed: Some(11),
649            ..Default::default()
650        };
651        let triples = toy_triples(6, 12);
652        let mut embedder = GraphSageEmbedder::new(config.clone());
653        embedder.fit(&triples).expect("fit should succeed");
654
655        for i in 0..6usize {
656            let iri = format!("http://ex.org/e{}", i);
657            let emb = embedder
658                .embed_entity(&iri)
659                .expect("embed_entity should succeed");
660            let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
661            // Allow for collapsed (all-zero) embeddings when ReLU kills all activations
662            if norm > 1e-12 {
663                assert!(
664                    (norm - 1.0).abs() < 0.1,
665                    "L2 norm out of tolerance for {iri}: {norm}"
666                );
667            }
668        }
669    }
670}