#[cfg(feature = "sublinear")]
use ruvector_attention::{ScaledDotProductAttention, Attention};
#[cfg(feature = "sublinear")]
use crate::config::SublinearConfig;
#[cfg(feature = "sublinear")]
use crate::error::{GraphTransformerError, Result};
#[cfg(feature = "sublinear")]
pub struct SublinearGraphAttention {
config: SublinearConfig,
attention: ScaledDotProductAttention,
embed_dim: usize,
}
#[cfg(feature = "sublinear")]
impl SublinearGraphAttention {
pub fn new(embed_dim: usize, config: SublinearConfig) -> Self {
let attention = ScaledDotProductAttention::new(embed_dim);
Self {
config,
attention,
embed_dim,
}
}
pub fn lsh_attention(
&self,
node_features: &[Vec<f32>],
) -> Result<Vec<Vec<f32>>> {
if node_features.is_empty() {
return Ok(Vec::new());
}
let dim = node_features[0].len();
if dim != self.embed_dim {
return Err(GraphTransformerError::DimensionMismatch {
expected: self.embed_dim,
actual: dim,
});
}
let n = node_features.len();
let num_buckets = self.config.lsh_buckets.max(1);
let mut buckets: Vec<Vec<usize>> = vec![Vec::new(); num_buckets];
for (i, feat) in node_features.iter().enumerate() {
let bucket = lsh_hash(feat, num_buckets);
buckets[bucket].push(i);
}
let mut outputs = vec![vec![0.0f32; dim]; n];
for bucket in &buckets {
if bucket.len() < 2 {
for &idx in bucket {
outputs[idx] = node_features[idx].clone();
}
continue;
}
for &query_idx in bucket {
let query = &node_features[query_idx];
let keys: Vec<&[f32]> = bucket
.iter()
.filter(|&&i| i != query_idx)
.map(|&i| node_features[i].as_slice())
.collect();
let values: Vec<&[f32]> = keys.clone();
if keys.is_empty() {
outputs[query_idx] = query.clone();
continue;
}
let result = self.attention.compute(query, &keys, &values)
.map_err(GraphTransformerError::Attention)?;
outputs[query_idx] = result;
}
}
Ok(outputs)
}
pub fn ppr_attention(
&self,
node_features: &[Vec<f32>],
edges: &[(usize, usize, f64)],
) -> Result<Vec<Vec<f32>>> {
if node_features.is_empty() {
return Ok(Vec::new());
}
let dim = node_features[0].len();
let n = node_features.len();
let k = self.config.ppr_samples.min(n - 1).max(1);
let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
for &(u, v, _w) in edges {
if u < n && v < n {
adj[u].push(v);
adj[v].push(u);
}
}
let mut outputs = vec![vec![0.0f32; dim]; n];
let mut rng = rand::thread_rng();
for i in 0..n {
let sampled = ppr_sample(&adj, i, k, &mut rng);
if sampled.is_empty() {
outputs[i] = node_features[i].clone();
continue;
}
let query = &node_features[i];
let keys: Vec<&[f32]> = sampled
.iter()
.map(|&j| node_features[j].as_slice())
.collect();
let values: Vec<&[f32]> = keys.clone();
let result = self.attention.compute(query, &keys, &values)
.map_err(GraphTransformerError::Attention)?;
outputs[i] = result;
}
Ok(outputs)
}
pub fn spectral_attention(
&self,
node_features: &[Vec<f32>],
edges: &[(usize, usize, f64)],
) -> Result<Vec<Vec<f32>>> {
if node_features.is_empty() {
return Ok(Vec::new());
}
let dim = node_features[0].len();
let n = node_features.len();
let sparsity = self.config.sparsification_factor;
let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
for &(u, v, w) in edges {
if u < n && v < n {
adj[u].push((v, w));
adj[v].push((u, w));
}
}
let mut outputs = vec![vec![0.0f32; dim]; n];
for i in 0..n {
let mut neighbors = adj[i].clone();
neighbors.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let keep = ((neighbors.len() as f32 * sparsity) as usize).max(1);
neighbors.truncate(keep);
if neighbors.is_empty() {
outputs[i] = node_features[i].clone();
continue;
}
let query = &node_features[i];
let keys: Vec<&[f32]> = neighbors
.iter()
.map(|&(j, _)| node_features[j].as_slice())
.collect();
let values: Vec<&[f32]> = keys.clone();
let result = self.attention.compute(query, &keys, &values)
.map_err(GraphTransformerError::Attention)?;
outputs[i] = result;
}
Ok(outputs)
}
pub fn embed_dim(&self) -> usize {
self.embed_dim
}
}
#[cfg(feature = "sublinear")]
fn lsh_hash(features: &[f32], num_buckets: usize) -> usize {
let mut h: u64 = 0;
for (i, &v) in features.iter().enumerate() {
let bits = v.to_bits() as u64;
h = h.wrapping_add(bits.wrapping_mul(i as u64 + 1));
}
h = h.wrapping_mul(0x517cc1b727220a95);
(h as usize) % num_buckets
}
#[cfg(feature = "sublinear")]
fn ppr_sample(
adj: &[Vec<usize>],
source: usize,
k: usize,
rng: &mut impl rand::Rng,
) -> Vec<usize> {
use std::collections::HashSet;
let alpha = 0.15; let mut visited = HashSet::new();
let max_walks = k * 4;
for _ in 0..max_walks {
if visited.len() >= k {
break;
}
let mut current = source;
for _ in 0..10 {
if rng.gen::<f64>() < alpha {
break;
}
if adj[current].is_empty() {
break;
}
let idx = rng.gen_range(0..adj[current].len());
current = adj[current][idx];
}
if current != source {
visited.insert(current);
}
}
visited.into_iter().collect()
}
#[cfg(test)]
#[cfg(feature = "sublinear")]
mod tests {
use super::*;
#[test]
fn test_lsh_attention_basic() {
let config = SublinearConfig {
lsh_buckets: 4,
ppr_samples: 8,
sparsification_factor: 0.5,
};
let attn = SublinearGraphAttention::new(8, config);
let features = vec![
vec![1.0; 8],
vec![0.5; 8],
vec![0.3; 8],
vec![0.8; 8],
];
let result = attn.lsh_attention(&features);
assert!(result.is_ok());
let outputs = result.unwrap();
assert_eq!(outputs.len(), 4);
for out in &outputs {
assert_eq!(out.len(), 8);
}
}
#[test]
fn test_lsh_attention_empty() {
let config = SublinearConfig::default();
let attn = SublinearGraphAttention::new(8, config);
let result = attn.lsh_attention(&[]);
assert!(result.is_ok());
assert!(result.unwrap().is_empty());
}
#[test]
fn test_ppr_attention_basic() {
let config = SublinearConfig {
lsh_buckets: 4,
ppr_samples: 2,
sparsification_factor: 0.5,
};
let attn = SublinearGraphAttention::new(4, config);
let features = vec![
vec![1.0, 0.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
vec![0.0, 0.0, 0.0, 1.0],
];
let edges = vec![
(0, 1, 1.0),
(1, 2, 1.0),
(2, 3, 1.0),
(3, 0, 1.0),
];
let result = attn.ppr_attention(&features, &edges);
assert!(result.is_ok());
let outputs = result.unwrap();
assert_eq!(outputs.len(), 4);
}
#[test]
fn test_spectral_attention_basic() {
let config = SublinearConfig {
lsh_buckets: 4,
ppr_samples: 4,
sparsification_factor: 0.5,
};
let attn = SublinearGraphAttention::new(4, config);
let features = vec![
vec![1.0, 0.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
];
let edges = vec![
(0, 1, 2.0),
(1, 2, 1.0),
(0, 2, 0.5),
];
let result = attn.spectral_attention(&features, &edges);
assert!(result.is_ok());
let outputs = result.unwrap();
assert_eq!(outputs.len(), 3);
}
#[test]
fn test_dimension_mismatch() {
let config = SublinearConfig::default();
let attn = SublinearGraphAttention::new(8, config);
let features = vec![vec![1.0; 4]]; let result = attn.lsh_attention(&features);
assert!(result.is_err());
}
#[test]
fn test_lsh_hash_deterministic() {
let f = vec![1.0, 2.0, 3.0, 4.0];
let h1 = lsh_hash(&f, 16);
let h2 = lsh_hash(&f, 16);
assert_eq!(h1, h2);
assert!(h1 < 16);
}
}