1use std::fmt::Display;
2use std::sync::RwLock;
3
4use crate::interfaces::edge::{DirectedHyperedge, Hyperedge, NodeSet};
5use crate::interfaces::graph::SingleId;
6use crate::interfaces::hypergraph::{Hypergraph, IdVector};
7use crate::interfaces::typed::Type;
8use crate::interfaces::vertex::Vertex;
9
10use std::{collections::{HashMap, HashSet}, hash::Hash, ops::{Add, BitXor, Div, Mul, Sub}};
11use rand::{prelude::*, rng};
12use rand::distr::StandardUniform;
13use serde::{Serialize, Deserialize};
14use lazy_static::lazy_static;
15
lazy_static!{
    /// Process-wide registry mapping a node-type index to the descriptor stored for that type.
    ///
    /// Starts empty. `Node::from_random` inserts an entry the first time it draws a type
    /// that has no entry yet, and reads the stored entry on later draws of the same type.
    static ref clusters: RwLock<HashMap<usize, Desc>> = RwLock::new(HashMap::new());
}
/// Descriptor vector: a newtype over a fixed array of sixteen `f64` components.
///
/// `Clone` and the serde traits are derived; hashing and equality are implemented
/// by hand because `f64` provides neither `Hash` nor `Eq`.
#[derive(Clone, Serialize, Deserialize)]
pub struct Desc([f64; 16]);
21
22impl Hash for Desc {
23 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
24 for value in &self.0 {
25 let bits = value.to_bits();
26 bits.hash(state);
27 }
28 }
29}
30
31impl PartialEq for Desc {
32 fn eq(&self, other: &Self) -> bool {
33 self.0.iter().zip(other.0.iter()).all(|(x, y)| x == y)
34 }
35}
36
// Marker impl: declares that `==` on `Desc` is a full equivalence relation.
// NOTE(review): `PartialEq for Desc` compares with `f64`'s `==`, so a descriptor holding
// `NaN` is not equal to itself — confirm `NaN` components never occur.
impl Eq for Desc {}
38
39impl Add for Desc {
40 type Output = Self;
41 fn add(self, other: Self) -> Self {
42 let mut result = [0.0; 16];
43 for i in 0..16 {
44 result[i] = self.0[i] + other.0[i];
45 }
46 Desc(result)
47 }
48}
49
50impl Sub for Desc {
51 type Output = Self;
52 fn sub(self, other: Self) -> Self {
53 let mut result = [0.0; 16];
54 for i in 0..16 {
55 result[i] = self.0[i] - other.0[i];
56 }
57 Desc(result)
58 }
59}
60
61impl Mul for Desc {
62 type Output = f64;
63 fn mul(self, other: Self) -> f64 {
64 let mut result = 0.0;
65 for i in 0..16 {
66 result += self.0[i] * other.0[i];
67 }
68 result
69 }
70}
71
72impl Mul<f64> for Desc {
73 type Output = Desc;
74 fn mul(self, scalar: f64) -> Desc {
75 let mut result = [0.0; 16];
76 for i in 0..16 {
77 result[i] = self.0[i] * scalar;
78 }
79 Desc(result)
80 }
81}
82
83impl Div<f64> for Desc {
84 type Output = Desc;
85 fn div(self, scalar: f64) -> Desc {
86 let mut result = [0.0; 16];
87 for i in 0..16 {
88 result[i] = self.0[i] / scalar;
89 }
90 Desc(result)
91 }
92}
93
94impl BitXor for Desc {
96 type Output = f64;
97 fn bitxor(self, other: Self) -> f64 {
98 let res = self.clone() * other.clone();
99 let norm1 = self.clone() * self.clone();
100 let norm2 = other.clone() * other.clone();
101 res / (norm1.sqrt() * norm2.sqrt())
102 }
103}
104
/// Node type label: a newtype over the type's `usize` index.
#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
pub struct NodeType(usize);
107
108impl Type for NodeType {
109 fn type_id(&self) -> usize {
110 self.0
111 }
112}
113
114impl Display for NodeType {
115 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116 write!(f, "{}", self.0)
117 }
118}
119
120impl NodeType {
121 pub fn new(id: usize) -> Self {
122 NodeType(id)
123 }
124}
125
/// A vertex carrying an identifier, a type label and a descriptor vector.
#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
pub struct Node {
    /// Identifier reported through `SingleId::id`.
    id: usize,
    /// Type label; `from_random` draws it from `0..k`.
    node_type: NodeType,
    /// Descriptor compared by the `^` operator on nodes.
    desc: Desc
}
132
133impl BitXor for Node {
134 type Output = f64;
135 fn bitxor(self, other: Self) -> f64 {
136 return self.desc.clone() ^ other.desc.clone();
137 }
138}
139
140fn generate_orthogonal_unit(base: &Desc) -> Desc {
141 let base_norm = (base.clone() * base.clone()).sqrt();
142 let mut orthogonal = Desc([0.0; 16]);
143
144 let mut rng = rng();
145 loop {
146 for i in 0..16 {
148 orthogonal.0[i] = rng.sample(StandardUniform);
149 }
150
151
152 let projection = (orthogonal.clone() * base.clone()) / base_norm;
154
155 orthogonal = orthogonal - (base.clone() * projection) / base_norm;
161
162 let ortho_norm = (orthogonal.clone() * orthogonal.clone()).sqrt();
164 if ortho_norm > 1e-10 {
165 orthogonal = orthogonal / ortho_norm;
166 break;
167 }
168 }
169 orthogonal
170}
171
172
173impl Node {
174 pub fn from_random(id: usize, k: usize, p: f64, alpha: f64, rng: &mut impl Rng) -> Node {
175 let random_type = rng.random_range(0..k);
177 let desc = {
178 let random_vec: [f64; 16] = rng.sample(StandardUniform);
179 let desc = Desc(random_vec);
180
181 if !clusters.read().unwrap().contains_key(&random_type) {
182 if clusters.read().unwrap().is_empty() {
183 clusters.write().unwrap().insert(random_type, desc.clone());
184 desc
185 } else {
186 let avg_vec = clusters.read().unwrap().iter().map(|(_, v)| v.clone()).reduce(|a, b| a + b).unwrap();
187 let orthogonal = generate_orthogonal_unit(&avg_vec);
188 let res = orthogonal + desc;
189 clusters.write().unwrap().insert(random_type, res.clone());
190 res
191 }
192 } else {
193 let cluster_guard = clusters.read().unwrap();
194 let cluster_desc = cluster_guard.get(&random_type).unwrap().clone();
195
196 if rng.random_bool(p) {
197 let res = cluster_desc.clone() * (1.0 - alpha) + desc * alpha;
198 res
199 } else {
200 let orthogonal = generate_orthogonal_unit(&cluster_desc);
201 let res = orthogonal * (1.0 - alpha) + desc * alpha;
202 res
203 }
204 }
205 };
206
207 Node {
208 id,
209 node_type: NodeType::new(random_type),
210 desc
211 }
212 }
213}
214
215impl SingleId for Node {
216 fn id(&self) -> usize {
217 self.id
218 }
219}
220
221impl Display for Node {
222 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223 write!(f, "Node {{ id: {}, node_type: {} }}", self.id, self.node_type)
224 }
225}
226
// Marks `Node` as a `Vertex`; the impl supplies no items of its own.
// NOTE(review): whatever `Vertex` requires is presumably met by the other impls on `Node` — confirm against the trait.
impl Vertex for Node {}
228
229
/// Undirected hyperedge stored as a list of node ids.
#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
pub struct HyperedgeImpl {
    /// Node ids in the order given to `from_nodes`; duplicates are kept as supplied.
    nodes: Vec<usize>,
}
234
235impl IdVector for HyperedgeImpl {
236 fn id(&self) -> Vec<usize> {
237 self.nodes.clone()
238 }
239}
240
241impl NodeSet for HyperedgeImpl {
242 fn from_nodes(nodes: Vec<usize>) -> Self {
243 HyperedgeImpl { nodes }
244 }
245}
246
247impl Hyperedge for HyperedgeImpl {
248 fn id_set(&self) -> HashSet<usize> {
249 self.nodes.iter().cloned().collect()
250 }
251}
252
/// Directed hyperedge stored as two lists of node ids.
#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
pub struct DirectedHyperedgeImpl {
    /// Ids reported by `DirectedHyperedge::src`.
    src: Vec<usize>,
    /// Ids reported by `DirectedHyperedge::dst`.
    dst: Vec<usize>,
}
258
259impl IdVector for DirectedHyperedgeImpl {
260 fn id(&self) -> Vec<usize> {
261 self.src.iter().chain(self.dst.iter()).cloned().collect()
262 }
263}
264
265impl DirectedHyperedge for DirectedHyperedgeImpl {
266 fn src(&self) -> HashSet<usize> {
267 self.src.iter().cloned().collect()
268 }
269
270 fn dst(&self) -> HashSet<usize> {
271 self.dst.iter().cloned().collect()
272 }
273}
274
/// Undirected hypergraph backed by two vectors.
#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
pub struct HypergraphImpl {
    /// Nodes in insertion order.
    nodes: Vec<Node>,
    /// Hyperedges in insertion order.
    edges: Vec<HyperedgeImpl>,
}
280
281impl<'a> Hypergraph<'a> for HypergraphImpl {
282 type Node = Node;
283 type Edge = HyperedgeImpl;
284
285 fn new() -> Self {
286 HypergraphImpl {
287 nodes: Vec::new(),
288 edges: Vec::new(),
289 }
290 }
291
292 fn nodes(&'a self) -> impl Iterator<Item = &'a Self::Node> {
293 self.nodes.iter()
294 }
295
296 fn hyperedges(&'a self) -> impl Iterator<Item = &'a Self::Edge> {
297 self.edges.iter()
298 }
299
300 fn add_node(&mut self, node: Self::Node) {
301 self.nodes.push(node);
302 }
303
304 fn add_hyperedge(&mut self, edge: Self::Edge) {
305 self.edges.push(edge);
306 }
307
308 fn get_node_by_id(&'a self, id: usize) -> Option<&'a Self::Node> {
309 self.nodes.iter().find(|node| node.id() == id)
310 }
311}
312
313
/// Directed hypergraph backed by two vectors.
#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
pub struct DirectedHypergraphImpl {
    /// Nodes in insertion order.
    nodes: Vec<Node>,
    /// Directed hyperedges in insertion order.
    edges: Vec<DirectedHyperedgeImpl>,
}
319
320impl DirectedHypergraphImpl {
321 pub fn new() -> Self {
322 DirectedHypergraphImpl {
323 nodes: Vec::new(),
324 edges: Vec::new(),
325 }
326 }
327
328 pub fn add_node(&mut self, node: Node) {
329 self.nodes.push(node);
330 }
331
332 pub fn add_hyperedge(&mut self, edge: DirectedHyperedgeImpl) {
333 self.edges.push(edge);
334 }
335}