// graph_simulation/utils/validation.rs

1use std::{clone, collections::{HashMap, HashSet}, hash::Hash, ops::{Add, BitXor, Div, Mul, Sub}};
2use rand::{prelude::*, rng};
3use rand::distr::StandardUniform;
4use serde::{Serialize, Deserialize};
5use std::sync::RwLock;
6use rand_pcg::Pcg64;
7use lazy_static::lazy_static;
8
9lazy_static!{
10    static ref clusters: RwLock<HashMap<u64, Desc>> = RwLock::new(HashMap::new()); 
11}
12
13#[derive(Clone, Serialize, Deserialize)]
14struct Desc([f64; 16]);
15
16impl Hash for Desc {
17    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
18        for value in &self.0 {
19            let bits = value.to_bits();
20            bits.hash(state);
21        }
22    }
23}
24
25impl PartialEq for Desc {
26    fn eq(&self, other: &Self) -> bool {
27        self.0.iter().zip(other.0.iter()).all(|(x, y)| x == y)
28    }
29}
30
31impl Eq for Desc {}
32
33impl Add for Desc {
34    type Output = Self;
35    fn add(self, other: Self) -> Self {
36        let mut result = [0.0; 16];
37        for i in 0..16 {
38            result[i] = self.0[i] + other.0[i];
39        }
40        Desc(result)
41    }
42}
43
44impl Sub for Desc {
45    type Output = Self;
46    fn sub(self, other: Self) -> Self {
47        let mut result = [0.0; 16];
48        for i in 0..16 {
49            result[i] = self.0[i] - other.0[i];
50        }
51        Desc(result)
52    }
53}
54
55impl Mul for Desc {
56    type Output = f64;
57    fn mul(self, other: Self) -> f64 {
58        let mut result = 0.0;
59        for i in 0..16 {
60            result += self.0[i] * other.0[i];
61        }
62        result
63    }
64}
65
66impl Mul<f64> for Desc {
67    type Output = Desc;
68    fn mul(self, scalar: f64) -> Desc {
69        let mut result = [0.0; 16];
70        for i in 0..16 {
71            result[i] = self.0[i] * scalar;
72        }
73        Desc(result)
74    }
75}
76
77impl Div<f64> for Desc {
78    type Output = Desc;
79    fn div(self, scalar: f64) -> Desc {
80        let mut result = [0.0; 16];
81        for i in 0..16 {
82            result[i] = self.0[i] / scalar;
83        }
84        Desc(result)
85    }
86}
87
88// cosine for Desc
89impl BitXor for Desc {
90    type Output = f64;
91    fn bitxor(self, other: Self) -> f64 {
92        let res = self.clone() * other.clone();
93        let norm1 = self.clone() * self.clone();
94        let norm2 = other.clone() * other.clone();
95        res / (norm1.sqrt() * norm2.sqrt())
96    }
97}
98
99#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
100pub struct Node {
101    id: u64,
102    node_type: u64,
103    desc: Desc
104}
105
106impl BitXor for Node {
107    type Output = f64;
108    fn bitxor(self, other: Self) -> f64 {
109        return self.desc.clone() ^ other.desc.clone();
110    }
111}
112
113fn generate_orthogonal_unit(base: &Desc) -> Desc {
114    let base_norm = (base.clone() * base.clone()).sqrt();
115    let mut orthogonal = Desc([0.0; 16]);
116    
117    let mut rng = rng();
118    loop {
119        // 生成随机高斯向量
120        for i in 0..16 {
121            orthogonal.0[i] = rng.sample(StandardUniform);
122        }
123        
124        
125        // 计算与基向量的点积
126        let projection = (orthogonal.clone() * base.clone()) / base_norm;
127        
128        // 减去投影分量使其正交
129        // for i in 0..16 {
130        //     orthogonal[i] -= projection * base[i] / base_norm;
131        // }
132
133        orthogonal = orthogonal - (base.clone() * projection) / base_norm;
134        
135        // 归一化处理
136        let ortho_norm = (orthogonal.clone() * orthogonal.clone()).sqrt();
137        if ortho_norm > 1e-10 {
138            orthogonal = orthogonal / ortho_norm;
139            break;
140        }
141    }
142    orthogonal
143}
144
145
146impl Node {
147    pub fn from_random(id: u64, k: u64, p: f64, alpha: f64) -> Node {
148        // get A random [f64; 16]
149        let mut rng = rng();
150        let random_type = rng.random_range(0..k);
151        let desc = {
152            let random_vec: [f64; 16] = rng.sample(StandardUniform);
153            let desc =    Desc(random_vec);
154            
155            if !clusters.read().unwrap().contains_key(&random_type) {
156                if clusters.read().unwrap().is_empty() {
157                    clusters.write().unwrap().insert(random_type, desc.clone());
158                    desc
159                } else {
160                    let avg_vec = clusters.read().unwrap().iter().map(|(_, v)| v.clone()).reduce(|a, b| a + b).unwrap();
161                    let orthogonal = generate_orthogonal_unit(&avg_vec);
162                    let res = orthogonal + desc;
163                    clusters.write().unwrap().insert(random_type, res.clone());
164                    res
165                }
166            } else {
167                let cluster_guard = clusters.read().unwrap();
168                let cluster_desc = cluster_guard.get(&random_type).unwrap().clone();
169    
170                if rng.random_bool(p) {
171                    let res = cluster_desc.clone() * (1.0 - alpha) + desc * alpha;
172                    res
173                } else {
174                    let orthogonal = generate_orthogonal_unit(&cluster_desc);
175                    let res = orthogonal * (1.0 - alpha) + desc * alpha;
176                    res
177                }
178            }
179        };
180
181        Node {
182            id,
183            node_type: random_type,
184            desc
185        }
186    }
187}
188
/// An undirected hyperedge, stored as an unordered set of `u64` ids.
struct Hyperedge {
    /// Ids joined by this hyperedge (presumably `Node` ids — confirm against
    /// callers).
    id_set: HashSet<u64>,
}
192
/// A directed hyperedge: one set of `u64` ids on the source side and one on
/// the destination side.
struct DiHyperedge {
    /// Ids on the source side (presumably `Node` ids — confirm against callers).
    src: HashSet<u64>,
    /// Ids on the destination side.
    dst: HashSet<u64>,
}
197