// graph_base/impls/hypergraph.rs

1use std::fmt::Display;
2use std::sync::RwLock;
3
4use crate::interfaces::edge::{DirectedHyperedge, Hyperedge, NodeSet};
5use crate::interfaces::graph::SingleId;
6use crate::interfaces::hypergraph::{Hypergraph, IdVector};
7use crate::interfaces::typed::Type;
8use crate::interfaces::vertex::Vertex;
9
10use std::{collections::{HashMap, HashSet}, hash::Hash, ops::{Add, BitXor, Div, Mul, Sub}};
11use rand::{prelude::*, rng};
12use rand::distr::StandardUniform;
13use serde::{Serialize, Deserialize};
14use lazy_static::lazy_static;
15
16lazy_static!{
17    static ref clusters: RwLock<HashMap<usize, Desc>> = RwLock::new(HashMap::new());
18}
19#[derive(Clone, Serialize, Deserialize)]
20pub struct Desc([f64; 16]);
21
22impl Hash for Desc {
23    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
24        for value in &self.0 {
25            let bits = value.to_bits();
26            bits.hash(state);
27        }
28    }
29}
30
31impl PartialEq for Desc {
32    fn eq(&self, other: &Self) -> bool {
33        self.0.iter().zip(other.0.iter()).all(|(x, y)| x == y)
34    }
35}
36
37impl Eq for Desc {}
38
39impl Add for Desc {
40    type Output = Self;
41    fn add(self, other: Self) -> Self {
42        let mut result = [0.0; 16];
43        for i in 0..16 {
44            result[i] = self.0[i] + other.0[i];
45        }
46        Desc(result)
47    }
48}
49
50impl Sub for Desc {
51    type Output = Self;
52    fn sub(self, other: Self) -> Self {
53        let mut result = [0.0; 16];
54        for i in 0..16 {
55            result[i] = self.0[i] - other.0[i];
56        }
57        Desc(result)
58    }
59}
60
61impl Mul for Desc {
62    type Output = f64;
63    fn mul(self, other: Self) -> f64 {
64        let mut result = 0.0;
65        for i in 0..16 {
66            result += self.0[i] * other.0[i];
67        }
68        result
69    }
70}
71
72impl Mul<f64> for Desc {
73    type Output = Desc;
74    fn mul(self, scalar: f64) -> Desc {
75        let mut result = [0.0; 16];
76        for i in 0..16 {
77            result[i] = self.0[i] * scalar;
78        }
79        Desc(result)
80    }
81}
82
83impl Div<f64> for Desc {
84    type Output = Desc;
85    fn div(self, scalar: f64) -> Desc {
86        let mut result = [0.0; 16];
87        for i in 0..16 {
88            result[i] = self.0[i] / scalar;
89        }
90        Desc(result)
91    }
92}
93
94// cosine for Desc
95impl BitXor for Desc {
96    type Output = f64;
97    fn bitxor(self, other: Self) -> f64 {
98        let res = self.clone() * other.clone();
99        let norm1 = self.clone() * self.clone();
100        let norm2 = other.clone() * other.clone();
101        res / (norm1.sqrt() * norm2.sqrt())
102    }
103}
104
105#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
106pub struct NodeType(usize);
107
108impl Type for NodeType {
109    fn type_id(&self) -> usize {
110        self.0
111    }
112}
113
114impl Display for NodeType {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        write!(f, "{}", self.0)
117    }
118}
119
120impl NodeType {
121    pub fn new(id: usize) -> Self {
122        NodeType(id)
123    }
124}
125
126#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
127pub struct Node {
128    id: usize,
129    node_type: NodeType,
130    desc: Desc
131}
132
133impl BitXor for Node {
134    type Output = f64;
135    fn bitxor(self, other: Self) -> f64 {
136        return self.desc.clone() ^ other.desc.clone();
137    }
138}
139
140fn generate_orthogonal_unit(base: &Desc) -> Desc {
141    let base_norm = (base.clone() * base.clone()).sqrt();
142    let mut orthogonal = Desc([0.0; 16]);
143    
144    let mut rng = rng();
145    loop {
146        // 生成随机高斯向量
147        for i in 0..16 {
148            orthogonal.0[i] = rng.sample(StandardUniform);
149        }
150        
151        
152        // 计算与基向量的点积
153        let projection = (orthogonal.clone() * base.clone()) / base_norm;
154        
155        // 减去投影分量使其正交
156        // for i in 0..16 {
157        //     orthogonal[i] -= projection * base[i] / base_norm;
158        // }
159
160        orthogonal = orthogonal - (base.clone() * projection) / base_norm;
161        
162        // 归一化处理
163        let ortho_norm = (orthogonal.clone() * orthogonal.clone()).sqrt();
164        if ortho_norm > 1e-10 {
165            orthogonal = orthogonal / ortho_norm;
166            break;
167        }
168    }
169    orthogonal
170}
171
172
173impl Node {
174    pub fn from_random(id: usize, k: usize, p: f64, alpha: f64, rng: &mut impl Rng) -> Node {
175        // get A random [f64; 16]
176        let random_type = rng.random_range(0..k);
177        let desc = {
178            let random_vec: [f64; 16] = rng.sample(StandardUniform);
179            let desc = Desc(random_vec);
180            
181            if !clusters.read().unwrap().contains_key(&random_type) {
182                if clusters.read().unwrap().is_empty() {
183                    clusters.write().unwrap().insert(random_type, desc.clone());
184                    desc
185                } else {
186                    let avg_vec = clusters.read().unwrap().iter().map(|(_, v)| v.clone()).reduce(|a, b| a + b).unwrap();
187                    let orthogonal = generate_orthogonal_unit(&avg_vec);
188                    let res = orthogonal + desc;
189                    clusters.write().unwrap().insert(random_type, res.clone());
190                    res
191                }
192            } else {
193                let cluster_guard = clusters.read().unwrap();
194                let cluster_desc = cluster_guard.get(&random_type).unwrap().clone();
195    
196                if rng.random_bool(p) {
197                    let res = cluster_desc.clone() * (1.0 - alpha) + desc * alpha;
198                    res
199                } else {
200                    let orthogonal = generate_orthogonal_unit(&cluster_desc);
201                    let res = orthogonal * (1.0 - alpha) + desc * alpha;
202                    res
203                }
204            }
205        };
206
207        Node {
208            id,
209            node_type: NodeType::new(random_type),
210            desc
211        }
212    }
213}
214
215impl SingleId for Node {
216    fn id(&self) -> usize {
217        self.id
218    }
219}
220
221impl Display for Node {
222    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223        write!(f, "Node {{ id: {}, node_type: {} }}", self.id, self.node_type)
224    }
225}
226
227impl Vertex for Node {}
228
229
230#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
231pub struct HyperedgeImpl {
232    nodes: Vec<usize>,
233}
234
235impl IdVector for HyperedgeImpl {
236    fn id(&self) -> Vec<usize> {
237        self.nodes.clone()
238    }
239}
240
241impl NodeSet for HyperedgeImpl {
242    fn from_nodes(nodes: Vec<usize>) -> Self {
243        HyperedgeImpl { nodes }
244    }
245}
246
247impl Hyperedge for HyperedgeImpl {
248    fn id_set(&self) -> HashSet<usize> {
249        self.nodes.iter().cloned().collect()
250    }
251}
252
253#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
254pub struct DirectedHyperedgeImpl {
255    src: Vec<usize>,
256    dst: Vec<usize>,
257}
258
259impl IdVector for DirectedHyperedgeImpl {
260    fn id(&self) -> Vec<usize> {
261        self.src.iter().chain(self.dst.iter()).cloned().collect()
262    }
263}
264
265impl DirectedHyperedge for DirectedHyperedgeImpl {
266    fn src(&self) -> HashSet<usize> {
267        self.src.iter().cloned().collect()
268    }
269
270    fn dst(&self) -> HashSet<usize> {
271        self.dst.iter().cloned().collect()
272    }
273}
274
275#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
276pub struct HypergraphImpl {
277    nodes: Vec<Node>,
278    edges: Vec<HyperedgeImpl>,
279}
280
281impl<'a> Hypergraph<'a> for HypergraphImpl {
282    type Node = Node;
283    type Edge = HyperedgeImpl;
284
285    fn new() -> Self {
286        HypergraphImpl {
287            nodes: Vec::new(),
288            edges: Vec::new(),
289        }
290    }
291
292    fn nodes(&'a self) -> impl Iterator<Item = &'a Self::Node> {
293        self.nodes.iter()
294    }
295
296    fn hyperedges(&'a self) -> impl Iterator<Item = &'a Self::Edge> {
297        self.edges.iter()
298    }
299
300    fn add_node(&mut self, node: Self::Node) {
301        self.nodes.push(node);
302    }
303
304    fn add_hyperedge(&mut self, edge: Self::Edge) {
305        self.edges.push(edge);
306    }
307
308    fn get_node_by_id(&'a self, id: usize) -> Option<&'a Self::Node> {
309        self.nodes.iter().find(|node| node.id() == id)
310    }
311}
312
313
314#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
315pub struct DirectedHypergraphImpl {
316    nodes: Vec<Node>,
317    edges: Vec<DirectedHyperedgeImpl>,
318}
319
320impl DirectedHypergraphImpl {
321    pub fn new() -> Self {
322        DirectedHypergraphImpl {
323            nodes: Vec::new(),
324            edges: Vec::new(),
325        }
326    }
327
328    pub fn add_node(&mut self, node: Node) {
329        self.nodes.push(node);
330    }
331
332    pub fn add_hyperedge(&mut self, edge: DirectedHyperedgeImpl) {
333        self.edges.push(edge);
334    }
335}