1use std::collections::HashSet;
4
5use chrono::NaiveDate;
6
7use crate::models::{EdgeId, Graph, NodeId};
8
9#[derive(Debug, Clone)]
11pub struct SplitConfig {
12 pub train_ratio: f64,
14 pub val_ratio: f64,
16 pub random_seed: u64,
18 pub strategy: SplitStrategy,
20}
21
22impl Default for SplitConfig {
23 fn default() -> Self {
24 Self {
25 train_ratio: 0.7,
26 val_ratio: 0.15,
27 random_seed: 42,
28 strategy: SplitStrategy::Random,
29 }
30 }
31}
32
33#[derive(Debug, Clone)]
35pub enum SplitStrategy {
36 Random,
38 Temporal {
40 train_cutoff: NaiveDate,
42 val_cutoff: NaiveDate,
43 },
44 Stratified,
46 KFold { k: usize, fold: usize },
48 Transductive,
50}
51
52#[derive(Debug, Clone)]
54pub struct DataSplit {
55 pub train_nodes: Vec<NodeId>,
57 pub val_nodes: Vec<NodeId>,
59 pub test_nodes: Vec<NodeId>,
61 pub train_edges: Vec<EdgeId>,
63 pub val_edges: Vec<EdgeId>,
65 pub test_edges: Vec<EdgeId>,
67}
68
69impl DataSplit {
70 pub fn node_masks(&self, graph: &Graph) -> (Vec<bool>, Vec<bool>, Vec<bool>) {
72 let n = graph.node_count();
73 let mut train_mask = vec![false; n];
74 let mut val_mask = vec![false; n];
75 let mut test_mask = vec![false; n];
76
77 let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
79 node_ids.sort();
80 let id_to_idx: std::collections::HashMap<_, _> = node_ids
81 .iter()
82 .enumerate()
83 .map(|(i, &id)| (id, i))
84 .collect();
85
86 for &id in &self.train_nodes {
87 if let Some(&idx) = id_to_idx.get(&id) {
88 train_mask[idx] = true;
89 }
90 }
91 for &id in &self.val_nodes {
92 if let Some(&idx) = id_to_idx.get(&id) {
93 val_mask[idx] = true;
94 }
95 }
96 for &id in &self.test_nodes {
97 if let Some(&idx) = id_to_idx.get(&id) {
98 test_mask[idx] = true;
99 }
100 }
101
102 (train_mask, val_mask, test_mask)
103 }
104
105 pub fn edge_masks(&self, graph: &Graph) -> (Vec<bool>, Vec<bool>, Vec<bool>) {
107 let m = graph.edge_count();
108 let mut train_mask = vec![false; m];
109 let mut val_mask = vec![false; m];
110 let mut test_mask = vec![false; m];
111
112 let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
114 edge_ids.sort();
115 let id_to_idx: std::collections::HashMap<_, _> = edge_ids
116 .iter()
117 .enumerate()
118 .map(|(i, &id)| (id, i))
119 .collect();
120
121 for &id in &self.train_edges {
122 if let Some(&idx) = id_to_idx.get(&id) {
123 train_mask[idx] = true;
124 }
125 }
126 for &id in &self.val_edges {
127 if let Some(&idx) = id_to_idx.get(&id) {
128 val_mask[idx] = true;
129 }
130 }
131 for &id in &self.test_edges {
132 if let Some(&idx) = id_to_idx.get(&id) {
133 test_mask[idx] = true;
134 }
135 }
136
137 (train_mask, val_mask, test_mask)
138 }
139}
140
141pub struct DataSplitter {
143 config: SplitConfig,
144}
145
146impl DataSplitter {
147 pub fn new(config: SplitConfig) -> Self {
149 Self { config }
150 }
151
152 pub fn split(&self, graph: &Graph) -> DataSplit {
154 match &self.config.strategy {
155 SplitStrategy::Random => self.random_split(graph),
156 SplitStrategy::Temporal {
157 train_cutoff,
158 val_cutoff,
159 } => self.temporal_split(graph, *train_cutoff, *val_cutoff),
160 SplitStrategy::Stratified => self.stratified_split(graph),
161 SplitStrategy::KFold { k, fold } => self.kfold_split(graph, *k, *fold),
162 SplitStrategy::Transductive => self.transductive_split(graph),
163 }
164 }
165
166 fn random_split(&self, graph: &Graph) -> DataSplit {
168 let mut rng = SimpleRng::new(self.config.random_seed);
169
170 let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
172 shuffle(&mut node_ids, &mut rng);
173
174 let n = node_ids.len();
175 let train_size = (n as f64 * self.config.train_ratio) as usize;
176 let val_size = (n as f64 * self.config.val_ratio) as usize;
177
178 let train_nodes: Vec<_> = node_ids[..train_size].to_vec();
179 let val_nodes: Vec<_> = node_ids[train_size..train_size + val_size].to_vec();
180 let test_nodes: Vec<_> = node_ids[train_size + val_size..].to_vec();
181
182 let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
184 shuffle(&mut edge_ids, &mut rng);
185
186 let m = edge_ids.len();
187 let train_edge_size = (m as f64 * self.config.train_ratio) as usize;
188 let val_edge_size = (m as f64 * self.config.val_ratio) as usize;
189
190 let train_edges: Vec<_> = edge_ids[..train_edge_size].to_vec();
191 let val_edges: Vec<_> = edge_ids[train_edge_size..train_edge_size + val_edge_size].to_vec();
192 let test_edges: Vec<_> = edge_ids[train_edge_size + val_edge_size..].to_vec();
193
194 DataSplit {
195 train_nodes,
196 val_nodes,
197 test_nodes,
198 train_edges,
199 val_edges,
200 test_edges,
201 }
202 }
203
204 fn temporal_split(
206 &self,
207 graph: &Graph,
208 train_cutoff: NaiveDate,
209 val_cutoff: NaiveDate,
210 ) -> DataSplit {
211 let mut train_edges = Vec::new();
212 let mut val_edges = Vec::new();
213 let mut test_edges = Vec::new();
214
215 for (&edge_id, edge) in &graph.edges {
217 if let Some(timestamp) = edge.timestamp {
218 if timestamp < train_cutoff {
219 train_edges.push(edge_id);
220 } else if timestamp < val_cutoff {
221 val_edges.push(edge_id);
222 } else {
223 test_edges.push(edge_id);
224 }
225 } else {
226 let r = edge_id % 100;
228 if (r as f64) < self.config.train_ratio * 100.0 {
229 train_edges.push(edge_id);
230 } else if (r as f64) < (self.config.train_ratio + self.config.val_ratio) * 100.0 {
231 val_edges.push(edge_id);
232 } else {
233 test_edges.push(edge_id);
234 }
235 }
236 }
237
238 let _train_edge_set: HashSet<_> = train_edges.iter().copied().collect();
240 let _val_edge_set: HashSet<_> = val_edges.iter().copied().collect();
241
242 let mut train_nodes = HashSet::new();
243 let mut val_nodes = HashSet::new();
244 let mut test_nodes = HashSet::new();
245
246 for &edge_id in &train_edges {
248 if let Some(edge) = graph.edges.get(&edge_id) {
249 train_nodes.insert(edge.source);
250 train_nodes.insert(edge.target);
251 }
252 }
253
254 for &edge_id in &val_edges {
256 if let Some(edge) = graph.edges.get(&edge_id) {
257 if !train_nodes.contains(&edge.source) {
258 val_nodes.insert(edge.source);
259 }
260 if !train_nodes.contains(&edge.target) {
261 val_nodes.insert(edge.target);
262 }
263 }
264 }
265
266 for &edge_id in &test_edges {
268 if let Some(edge) = graph.edges.get(&edge_id) {
269 if !train_nodes.contains(&edge.source) && !val_nodes.contains(&edge.source) {
270 test_nodes.insert(edge.source);
271 }
272 if !train_nodes.contains(&edge.target) && !val_nodes.contains(&edge.target) {
273 test_nodes.insert(edge.target);
274 }
275 }
276 }
277
278 DataSplit {
279 train_nodes: train_nodes.into_iter().collect(),
280 val_nodes: val_nodes.into_iter().collect(),
281 test_nodes: test_nodes.into_iter().collect(),
282 train_edges,
283 val_edges,
284 test_edges,
285 }
286 }
287
288 fn stratified_split(&self, graph: &Graph) -> DataSplit {
290 let mut rng = SimpleRng::new(self.config.random_seed);
291
292 let mut normal_nodes: Vec<_> = graph
294 .nodes
295 .iter()
296 .filter(|(_, n)| !n.is_anomaly)
297 .map(|(&id, _)| id)
298 .collect();
299 let mut anomalous_nodes: Vec<_> = graph
300 .nodes
301 .iter()
302 .filter(|(_, n)| n.is_anomaly)
303 .map(|(&id, _)| id)
304 .collect();
305
306 shuffle(&mut normal_nodes, &mut rng);
307 shuffle(&mut anomalous_nodes, &mut rng);
308
309 let (normal_train, normal_val, normal_test) = split_by_ratio(
311 &normal_nodes,
312 self.config.train_ratio,
313 self.config.val_ratio,
314 );
315 let (anomaly_train, anomaly_val, anomaly_test) = split_by_ratio(
316 &anomalous_nodes,
317 self.config.train_ratio,
318 self.config.val_ratio,
319 );
320
321 let mut train_nodes = normal_train;
323 train_nodes.extend(anomaly_train);
324
325 let mut val_nodes = normal_val;
326 val_nodes.extend(anomaly_val);
327
328 let mut test_nodes = normal_test;
329 test_nodes.extend(anomaly_test);
330
331 let mut normal_edges: Vec<_> = graph
333 .edges
334 .iter()
335 .filter(|(_, e)| !e.is_anomaly)
336 .map(|(&id, _)| id)
337 .collect();
338 let mut anomalous_edges: Vec<_> = graph
339 .edges
340 .iter()
341 .filter(|(_, e)| e.is_anomaly)
342 .map(|(&id, _)| id)
343 .collect();
344
345 shuffle(&mut normal_edges, &mut rng);
346 shuffle(&mut anomalous_edges, &mut rng);
347
348 let (normal_train_e, normal_val_e, normal_test_e) = split_by_ratio(
349 &normal_edges,
350 self.config.train_ratio,
351 self.config.val_ratio,
352 );
353 let (anomaly_train_e, anomaly_val_e, anomaly_test_e) = split_by_ratio(
354 &anomalous_edges,
355 self.config.train_ratio,
356 self.config.val_ratio,
357 );
358
359 let mut train_edges = normal_train_e;
360 train_edges.extend(anomaly_train_e);
361
362 let mut val_edges = normal_val_e;
363 val_edges.extend(anomaly_val_e);
364
365 let mut test_edges = normal_test_e;
366 test_edges.extend(anomaly_test_e);
367
368 DataSplit {
369 train_nodes,
370 val_nodes,
371 test_nodes,
372 train_edges,
373 val_edges,
374 test_edges,
375 }
376 }
377
378 fn kfold_split(&self, graph: &Graph, k: usize, fold: usize) -> DataSplit {
380 let mut rng = SimpleRng::new(self.config.random_seed);
381
382 let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
383 shuffle(&mut node_ids, &mut rng);
384
385 let fold_size = node_ids.len() / k;
386 let val_start = fold * fold_size;
387 let val_end = if fold == k - 1 {
388 node_ids.len()
389 } else {
390 (fold + 1) * fold_size
391 };
392
393 let val_nodes: Vec<_> = node_ids[val_start..val_end].to_vec();
394 let train_nodes: Vec<_> = node_ids
395 .iter()
396 .enumerate()
397 .filter(|(i, _)| *i < val_start || *i >= val_end)
398 .map(|(_, &id)| id)
399 .collect();
400
401 let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
403 shuffle(&mut edge_ids, &mut rng);
404
405 let edge_fold_size = edge_ids.len() / k;
406 let edge_val_start = fold * edge_fold_size;
407 let edge_val_end = if fold == k - 1 {
408 edge_ids.len()
409 } else {
410 (fold + 1) * edge_fold_size
411 };
412
413 let val_edges: Vec<_> = edge_ids[edge_val_start..edge_val_end].to_vec();
414 let train_edges: Vec<_> = edge_ids
415 .iter()
416 .enumerate()
417 .filter(|(i, _)| *i < edge_val_start || *i >= edge_val_end)
418 .map(|(_, &id)| id)
419 .collect();
420
421 DataSplit {
422 train_nodes,
423 val_nodes: val_nodes.clone(),
424 test_nodes: val_nodes, train_edges,
426 val_edges: val_edges.clone(),
427 test_edges: val_edges,
428 }
429 }
430
431 fn transductive_split(&self, graph: &Graph) -> DataSplit {
433 let mut rng = SimpleRng::new(self.config.random_seed);
434
435 let all_nodes: Vec<_> = graph.nodes.keys().copied().collect();
437
438 let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
440 shuffle(&mut edge_ids, &mut rng);
441
442 let m = edge_ids.len();
443 let train_size = (m as f64 * self.config.train_ratio) as usize;
444 let val_size = (m as f64 * self.config.val_ratio) as usize;
445
446 let train_edges: Vec<_> = edge_ids[..train_size].to_vec();
447 let val_edges: Vec<_> = edge_ids[train_size..train_size + val_size].to_vec();
448 let test_edges: Vec<_> = edge_ids[train_size + val_size..].to_vec();
449
450 DataSplit {
451 train_nodes: all_nodes.clone(),
452 val_nodes: all_nodes.clone(),
453 test_nodes: all_nodes,
454 train_edges,
455 val_edges,
456 test_edges,
457 }
458 }
459}
460
/// Cuts `items` into consecutive `(train, val, test)` parts.
///
/// The training part holds the first `floor(len * train_ratio)` items, the
/// validation part the next `floor(len * val_ratio)` items, and the test part
/// everything that is left.
///
/// Sizes are clamped to the available items, so ratios that are too large
/// (individually or combined) shrink the later parts instead of panicking.
/// Negative or NaN ratios yield an empty part, because float-to-integer `as`
/// casts saturate.
fn split_by_ratio<T: Clone>(
    items: &[T],
    train_ratio: f64,
    val_ratio: f64,
) -> (Vec<T>, Vec<T>, Vec<T>) {
    let n = items.len();
    let train_size = ((n as f64 * train_ratio) as usize).min(n);
    let val_size = ((n as f64 * val_ratio) as usize).min(n - train_size);

    let (train, rest) = items.split_at(train_size);
    let (val, test) = rest.split_at(val_size);

    (train.to_vec(), val.to_vec(), test.to_vec())
}
477
/// Small deterministic pseudo-random generator used for seeded shuffling
/// and sampling.
struct SimpleRng {
    /// Current 64-bit generator state; replaced on every draw.
    state: u64,
}
482
483impl SimpleRng {
484 fn new(seed: u64) -> Self {
485 Self {
486 state: if seed == 0 { 1 } else { seed },
487 }
488 }
489
490 fn next(&mut self) -> u64 {
491 let mut x = self.state;
492 x ^= x << 13;
493 x ^= x >> 7;
494 x ^= x << 17;
495 self.state = x;
496 x
497 }
498}
499
500fn shuffle<T>(items: &mut [T], rng: &mut SimpleRng) {
502 for i in (1..items.len()).rev() {
503 let j = (rng.next() % (i as u64 + 1)) as usize;
504 items.swap(i, j);
505 }
506}
507
508pub fn sample_negative_edges(
510 graph: &Graph,
511 num_samples: usize,
512 seed: u64,
513) -> Vec<(NodeId, NodeId)> {
514 let mut rng = SimpleRng::new(seed);
515 let node_ids: Vec<_> = graph.nodes.keys().copied().collect();
516 let n = node_ids.len();
517
518 if n < 2 {
519 return Vec::new();
520 }
521
522 let existing_edges: HashSet<_> = graph
524 .edges
525 .values()
526 .map(|e| (e.source.min(e.target), e.source.max(e.target)))
527 .collect();
528
529 let mut negative_edges = Vec::with_capacity(num_samples);
530 let max_attempts = num_samples * 10;
531 let mut attempts = 0;
532
533 while negative_edges.len() < num_samples && attempts < max_attempts {
534 let i = (rng.next() % n as u64) as usize;
535 let j = (rng.next() % n as u64) as usize;
536
537 if i == j {
538 attempts += 1;
539 continue;
540 }
541
542 let src = node_ids[i];
543 let tgt = node_ids[j];
544 let key = (src.min(tgt), src.max(tgt));
545
546 if !existing_edges.contains(&key) {
547 negative_edges.push((src, tgt));
548 }
549
550 attempts += 1;
551 }
552
553 negative_edges
554}
555
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{EdgeType, GraphEdge, GraphNode, GraphType, NodeType};

    /// Builds the fixture graph: ten account nodes (the one created at loop
    /// index 5 is flagged anomalous) and nine transaction edges dated
    /// 2024-01-01 through 2024-01-09.
    fn create_test_graph() -> Graph {
        let mut graph = Graph::new("test", GraphType::Transaction);

        for i in 0..10 {
            let code = format!("{}", i);
            let label = format!("Account {}", i);
            let mut node = GraphNode::new(0, NodeType::Account, code, label);
            if i == 5 {
                node.is_anomaly = true;
            }
            graph.add_node(node);
        }

        for i in 0..9 {
            let day = i as u32 + 1;
            let date = chrono::NaiveDate::from_ymd_opt(2024, 1, day).unwrap();
            let edge =
                GraphEdge::new(0, i + 1, i + 2, EdgeType::Transaction).with_timestamp(date);
            graph.add_edge(edge);
        }

        graph.compute_statistics();
        graph
    }

    /// The default random split must place every node in exactly one set.
    #[test]
    fn test_random_split() {
        let graph = create_test_graph();
        let split = DataSplitter::new(SplitConfig::default()).split(&graph);

        let assigned = split.train_nodes.len() + split.val_nodes.len() + split.test_nodes.len();
        assert_eq!(assigned, graph.node_count());
    }

    /// With cutoffs inside the fixture's date range there are training edges.
    #[test]
    fn test_temporal_split() {
        let graph = create_test_graph();
        let strategy = SplitStrategy::Temporal {
            train_cutoff: chrono::NaiveDate::from_ymd_opt(2024, 1, 4).unwrap(),
            val_cutoff: chrono::NaiveDate::from_ymd_opt(2024, 1, 7).unwrap(),
        };
        let config = SplitConfig {
            strategy,
            ..Default::default()
        };
        let split = DataSplitter::new(config).split(&graph);

        assert!(!split.train_edges.is_empty());
    }

    /// The stratified strategy yields a non-empty training node set.
    #[test]
    fn test_stratified_split() {
        let graph = create_test_graph();
        let config = SplitConfig {
            strategy: SplitStrategy::Stratified,
            ..Default::default()
        };
        let split = DataSplitter::new(config).split(&graph);

        assert!(!split.train_nodes.is_empty());
    }

    /// Negative sampling never exceeds the requested count or emits self-pairs.
    #[test]
    fn test_negative_sampling() {
        let graph = create_test_graph();
        let negatives = sample_negative_edges(&graph, 5, 42);

        assert!(negatives.len() <= 5);
        for (src, tgt) in &negatives {
            assert_ne!(src, tgt);
        }
    }
}