datasynth_graph/ml/
splits.rs

1//! Train/validation/test split utilities for graph data.
2
3use std::collections::HashSet;
4
5use chrono::NaiveDate;
6
7use crate::models::{EdgeId, Graph, NodeId};
8
/// Configuration for data splitting.
///
/// There is no explicit test ratio: the ratio-based strategies assign to the
/// test split whatever remains after the train and validation shares are taken.
#[derive(Debug, Clone)]
pub struct SplitConfig {
    /// Train split ratio (fraction of the items assigned to training).
    pub train_ratio: f64,
    /// Validation split ratio (fraction of the items assigned to validation).
    pub val_ratio: f64,
    /// Seed for the pseudo-random generator used by the shuffling strategies.
    pub random_seed: u64,
    /// Split strategy.
    pub strategy: SplitStrategy,
}
21
22impl Default for SplitConfig {
23    fn default() -> Self {
24        Self {
25            train_ratio: 0.7,
26            val_ratio: 0.15,
27            random_seed: 42,
28            strategy: SplitStrategy::Random,
29        }
30    }
31}
32
/// Strategy for splitting data.
#[derive(Debug, Clone)]
pub enum SplitStrategy {
    /// Random split.
    Random,
    /// Temporal split (by timestamp).
    Temporal {
        /// Edges timestamped strictly before this date go to the training split.
        train_cutoff: NaiveDate,
        /// Edges timestamped on or after `train_cutoff` and strictly before this
        /// date go to the validation split; all later edges go to the test split.
        val_cutoff: NaiveDate,
    },
    /// Stratified split (maintain class distribution).
    Stratified,
    /// K-fold cross validation.
    ///
    /// `k` is the number of folds and `fold` is the zero-based index of the fold
    /// that is held out; the held-out fold is reported as both validation and test.
    KFold { k: usize, fold: usize },
    /// Transductive split (nodes appear in all splits, but different edges).
    Transductive,
}
51
/// Result of a data split.
///
/// Each field lists the identifiers assigned to one split. The `node_masks` and
/// `edge_masks` methods convert these lists into boolean masks over a graph.
#[derive(Debug, Clone)]
pub struct DataSplit {
    /// Training node IDs.
    pub train_nodes: Vec<NodeId>,
    /// Validation node IDs.
    pub val_nodes: Vec<NodeId>,
    /// Test node IDs.
    pub test_nodes: Vec<NodeId>,
    /// Training edge IDs.
    pub train_edges: Vec<EdgeId>,
    /// Validation edge IDs.
    pub val_edges: Vec<EdgeId>,
    /// Test edge IDs.
    pub test_edges: Vec<EdgeId>,
}
68
69impl DataSplit {
70    /// Creates node masks for the graph.
71    pub fn node_masks(&self, graph: &Graph) -> (Vec<bool>, Vec<bool>, Vec<bool>) {
72        let n = graph.node_count();
73        let mut train_mask = vec![false; n];
74        let mut val_mask = vec![false; n];
75        let mut test_mask = vec![false; n];
76
77        // Create ID to index mapping
78        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
79        node_ids.sort();
80        let id_to_idx: std::collections::HashMap<_, _> = node_ids
81            .iter()
82            .enumerate()
83            .map(|(i, &id)| (id, i))
84            .collect();
85
86        for &id in &self.train_nodes {
87            if let Some(&idx) = id_to_idx.get(&id) {
88                train_mask[idx] = true;
89            }
90        }
91        for &id in &self.val_nodes {
92            if let Some(&idx) = id_to_idx.get(&id) {
93                val_mask[idx] = true;
94            }
95        }
96        for &id in &self.test_nodes {
97            if let Some(&idx) = id_to_idx.get(&id) {
98                test_mask[idx] = true;
99            }
100        }
101
102        (train_mask, val_mask, test_mask)
103    }
104
105    /// Creates edge masks for the graph.
106    pub fn edge_masks(&self, graph: &Graph) -> (Vec<bool>, Vec<bool>, Vec<bool>) {
107        let m = graph.edge_count();
108        let mut train_mask = vec![false; m];
109        let mut val_mask = vec![false; m];
110        let mut test_mask = vec![false; m];
111
112        // Create ID to index mapping
113        let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
114        edge_ids.sort();
115        let id_to_idx: std::collections::HashMap<_, _> = edge_ids
116            .iter()
117            .enumerate()
118            .map(|(i, &id)| (id, i))
119            .collect();
120
121        for &id in &self.train_edges {
122            if let Some(&idx) = id_to_idx.get(&id) {
123                train_mask[idx] = true;
124            }
125        }
126        for &id in &self.val_edges {
127            if let Some(&idx) = id_to_idx.get(&id) {
128                val_mask[idx] = true;
129            }
130        }
131        for &id in &self.test_edges {
132            if let Some(&idx) = id_to_idx.get(&id) {
133                test_mask[idx] = true;
134            }
135        }
136
137        (train_mask, val_mask, test_mask)
138    }
139}
140
/// Data splitter for graph data.
///
/// Wraps the [`SplitConfig`] that `split` consults to choose a strategy.
pub struct DataSplitter {
    /// Ratios, seed and strategy applied by this splitter.
    config: SplitConfig,
}
145
146impl DataSplitter {
147    /// Creates a new data splitter.
148    pub fn new(config: SplitConfig) -> Self {
149        Self { config }
150    }
151
152    /// Splits graph data according to configuration.
153    pub fn split(&self, graph: &Graph) -> DataSplit {
154        match &self.config.strategy {
155            SplitStrategy::Random => self.random_split(graph),
156            SplitStrategy::Temporal {
157                train_cutoff,
158                val_cutoff,
159            } => self.temporal_split(graph, *train_cutoff, *val_cutoff),
160            SplitStrategy::Stratified => self.stratified_split(graph),
161            SplitStrategy::KFold { k, fold } => self.kfold_split(graph, *k, *fold),
162            SplitStrategy::Transductive => self.transductive_split(graph),
163        }
164    }
165
166    /// Performs random split.
167    fn random_split(&self, graph: &Graph) -> DataSplit {
168        let mut rng = SimpleRng::new(self.config.random_seed);
169
170        // Split nodes
171        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
172        shuffle(&mut node_ids, &mut rng);
173
174        let n = node_ids.len();
175        let train_size = (n as f64 * self.config.train_ratio) as usize;
176        let val_size = (n as f64 * self.config.val_ratio) as usize;
177
178        let train_nodes: Vec<_> = node_ids[..train_size].to_vec();
179        let val_nodes: Vec<_> = node_ids[train_size..train_size + val_size].to_vec();
180        let test_nodes: Vec<_> = node_ids[train_size + val_size..].to_vec();
181
182        // Split edges
183        let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
184        shuffle(&mut edge_ids, &mut rng);
185
186        let m = edge_ids.len();
187        let train_edge_size = (m as f64 * self.config.train_ratio) as usize;
188        let val_edge_size = (m as f64 * self.config.val_ratio) as usize;
189
190        let train_edges: Vec<_> = edge_ids[..train_edge_size].to_vec();
191        let val_edges: Vec<_> = edge_ids[train_edge_size..train_edge_size + val_edge_size].to_vec();
192        let test_edges: Vec<_> = edge_ids[train_edge_size + val_edge_size..].to_vec();
193
194        DataSplit {
195            train_nodes,
196            val_nodes,
197            test_nodes,
198            train_edges,
199            val_edges,
200            test_edges,
201        }
202    }
203
204    /// Performs temporal split based on edge timestamps.
205    fn temporal_split(
206        &self,
207        graph: &Graph,
208        train_cutoff: NaiveDate,
209        val_cutoff: NaiveDate,
210    ) -> DataSplit {
211        let mut train_edges = Vec::new();
212        let mut val_edges = Vec::new();
213        let mut test_edges = Vec::new();
214
215        // Split edges by timestamp
216        for (&edge_id, edge) in &graph.edges {
217            if let Some(timestamp) = edge.timestamp {
218                if timestamp < train_cutoff {
219                    train_edges.push(edge_id);
220                } else if timestamp < val_cutoff {
221                    val_edges.push(edge_id);
222                } else {
223                    test_edges.push(edge_id);
224                }
225            } else {
226                // No timestamp - assign randomly
227                let r = edge_id % 100;
228                if (r as f64) < self.config.train_ratio * 100.0 {
229                    train_edges.push(edge_id);
230                } else if (r as f64) < (self.config.train_ratio + self.config.val_ratio) * 100.0 {
231                    val_edges.push(edge_id);
232                } else {
233                    test_edges.push(edge_id);
234                }
235            }
236        }
237
238        // Determine node splits based on when they first appear
239        let _train_edge_set: HashSet<_> = train_edges.iter().copied().collect();
240        let _val_edge_set: HashSet<_> = val_edges.iter().copied().collect();
241
242        let mut train_nodes = HashSet::new();
243        let mut val_nodes = HashSet::new();
244        let mut test_nodes = HashSet::new();
245
246        // Nodes that appear in training edges
247        for &edge_id in &train_edges {
248            if let Some(edge) = graph.edges.get(&edge_id) {
249                train_nodes.insert(edge.source);
250                train_nodes.insert(edge.target);
251            }
252        }
253
254        // Nodes that first appear in validation edges
255        for &edge_id in &val_edges {
256            if let Some(edge) = graph.edges.get(&edge_id) {
257                if !train_nodes.contains(&edge.source) {
258                    val_nodes.insert(edge.source);
259                }
260                if !train_nodes.contains(&edge.target) {
261                    val_nodes.insert(edge.target);
262                }
263            }
264        }
265
266        // Nodes that first appear in test edges
267        for &edge_id in &test_edges {
268            if let Some(edge) = graph.edges.get(&edge_id) {
269                if !train_nodes.contains(&edge.source) && !val_nodes.contains(&edge.source) {
270                    test_nodes.insert(edge.source);
271                }
272                if !train_nodes.contains(&edge.target) && !val_nodes.contains(&edge.target) {
273                    test_nodes.insert(edge.target);
274                }
275            }
276        }
277
278        DataSplit {
279            train_nodes: train_nodes.into_iter().collect(),
280            val_nodes: val_nodes.into_iter().collect(),
281            test_nodes: test_nodes.into_iter().collect(),
282            train_edges,
283            val_edges,
284            test_edges,
285        }
286    }
287
288    /// Performs stratified split maintaining anomaly distribution.
289    fn stratified_split(&self, graph: &Graph) -> DataSplit {
290        let mut rng = SimpleRng::new(self.config.random_seed);
291
292        // Separate normal and anomalous nodes
293        let mut normal_nodes: Vec<_> = graph
294            .nodes
295            .iter()
296            .filter(|(_, n)| !n.is_anomaly)
297            .map(|(&id, _)| id)
298            .collect();
299        let mut anomalous_nodes: Vec<_> = graph
300            .nodes
301            .iter()
302            .filter(|(_, n)| n.is_anomaly)
303            .map(|(&id, _)| id)
304            .collect();
305
306        shuffle(&mut normal_nodes, &mut rng);
307        shuffle(&mut anomalous_nodes, &mut rng);
308
309        // Split each class
310        let (normal_train, normal_val, normal_test) = split_by_ratio(
311            &normal_nodes,
312            self.config.train_ratio,
313            self.config.val_ratio,
314        );
315        let (anomaly_train, anomaly_val, anomaly_test) = split_by_ratio(
316            &anomalous_nodes,
317            self.config.train_ratio,
318            self.config.val_ratio,
319        );
320
321        // Combine
322        let mut train_nodes = normal_train;
323        train_nodes.extend(anomaly_train);
324
325        let mut val_nodes = normal_val;
326        val_nodes.extend(anomaly_val);
327
328        let mut test_nodes = normal_test;
329        test_nodes.extend(anomaly_test);
330
331        // Split edges similarly
332        let mut normal_edges: Vec<_> = graph
333            .edges
334            .iter()
335            .filter(|(_, e)| !e.is_anomaly)
336            .map(|(&id, _)| id)
337            .collect();
338        let mut anomalous_edges: Vec<_> = graph
339            .edges
340            .iter()
341            .filter(|(_, e)| e.is_anomaly)
342            .map(|(&id, _)| id)
343            .collect();
344
345        shuffle(&mut normal_edges, &mut rng);
346        shuffle(&mut anomalous_edges, &mut rng);
347
348        let (normal_train_e, normal_val_e, normal_test_e) = split_by_ratio(
349            &normal_edges,
350            self.config.train_ratio,
351            self.config.val_ratio,
352        );
353        let (anomaly_train_e, anomaly_val_e, anomaly_test_e) = split_by_ratio(
354            &anomalous_edges,
355            self.config.train_ratio,
356            self.config.val_ratio,
357        );
358
359        let mut train_edges = normal_train_e;
360        train_edges.extend(anomaly_train_e);
361
362        let mut val_edges = normal_val_e;
363        val_edges.extend(anomaly_val_e);
364
365        let mut test_edges = normal_test_e;
366        test_edges.extend(anomaly_test_e);
367
368        DataSplit {
369            train_nodes,
370            val_nodes,
371            test_nodes,
372            train_edges,
373            val_edges,
374            test_edges,
375        }
376    }
377
378    /// Performs k-fold cross validation split.
379    fn kfold_split(&self, graph: &Graph, k: usize, fold: usize) -> DataSplit {
380        let mut rng = SimpleRng::new(self.config.random_seed);
381
382        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
383        shuffle(&mut node_ids, &mut rng);
384
385        let fold_size = node_ids.len() / k;
386        let val_start = fold * fold_size;
387        let val_end = if fold == k - 1 {
388            node_ids.len()
389        } else {
390            (fold + 1) * fold_size
391        };
392
393        let val_nodes: Vec<_> = node_ids[val_start..val_end].to_vec();
394        let train_nodes: Vec<_> = node_ids
395            .iter()
396            .enumerate()
397            .filter(|(i, _)| *i < val_start || *i >= val_end)
398            .map(|(_, &id)| id)
399            .collect();
400
401        // Similarly for edges
402        let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
403        shuffle(&mut edge_ids, &mut rng);
404
405        let edge_fold_size = edge_ids.len() / k;
406        let edge_val_start = fold * edge_fold_size;
407        let edge_val_end = if fold == k - 1 {
408            edge_ids.len()
409        } else {
410            (fold + 1) * edge_fold_size
411        };
412
413        let val_edges: Vec<_> = edge_ids[edge_val_start..edge_val_end].to_vec();
414        let train_edges: Vec<_> = edge_ids
415            .iter()
416            .enumerate()
417            .filter(|(i, _)| *i < edge_val_start || *i >= edge_val_end)
418            .map(|(_, &id)| id)
419            .collect();
420
421        DataSplit {
422            train_nodes,
423            val_nodes: val_nodes.clone(),
424            test_nodes: val_nodes, // In k-fold, val and test are the same
425            train_edges,
426            val_edges: val_edges.clone(),
427            test_edges: val_edges,
428        }
429    }
430
431    /// Performs transductive split (all nodes available, edges split).
432    fn transductive_split(&self, graph: &Graph) -> DataSplit {
433        let mut rng = SimpleRng::new(self.config.random_seed);
434
435        // All nodes in all splits
436        let all_nodes: Vec<_> = graph.nodes.keys().copied().collect();
437
438        // Split edges only
439        let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
440        shuffle(&mut edge_ids, &mut rng);
441
442        let m = edge_ids.len();
443        let train_size = (m as f64 * self.config.train_ratio) as usize;
444        let val_size = (m as f64 * self.config.val_ratio) as usize;
445
446        let train_edges: Vec<_> = edge_ids[..train_size].to_vec();
447        let val_edges: Vec<_> = edge_ids[train_size..train_size + val_size].to_vec();
448        let test_edges: Vec<_> = edge_ids[train_size + val_size..].to_vec();
449
450        DataSplit {
451            train_nodes: all_nodes.clone(),
452            val_nodes: all_nodes.clone(),
453            test_nodes: all_nodes,
454            train_edges,
455            val_edges,
456            test_edges,
457        }
458    }
459}
460
/// Splits a slice by ratio.
///
/// The first `floor(len * train_ratio)` items become the training part, the
/// next `floor(len * val_ratio)` items the validation part, and the rest the
/// test part. Item order is preserved.
///
/// Both sizes are clamped to the items that are still available, so ratios
/// above 1 (or summing to more than 1) shrink the later parts instead of
/// slicing out of bounds. Negative or NaN ratios yield an empty part because
/// the float-to-integer cast saturates at zero.
fn split_by_ratio<T: Clone>(
    items: &[T],
    train_ratio: f64,
    val_ratio: f64,
) -> (Vec<T>, Vec<T>, Vec<T>) {
    let n = items.len();
    let train_size = ((n as f64 * train_ratio) as usize).min(n);
    let val_size = ((n as f64 * val_ratio) as usize).min(n - train_size);

    let (train, rest) = items.split_at(train_size);
    let (val, test) = rest.split_at(val_size);

    (train.to_vec(), val.to_vec(), test.to_vec())
}
477
/// Minimal deterministic pseudo-random generator (xorshift64).
struct SimpleRng {
    /// Current generator word; `new` never stores zero here.
    state: u64,
}

impl SimpleRng {
    /// Seeds the generator, substituting 1 for a zero seed.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 1,
            nonzero => nonzero,
        };
        SimpleRng { state }
    }

    /// Advances the state with the 13 / 7 / 17 shift triple and returns it.
    fn next(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }
}
499
500/// Fisher-Yates shuffle.
501fn shuffle<T>(items: &mut [T], rng: &mut SimpleRng) {
502    for i in (1..items.len()).rev() {
503        let j = (rng.next() % (i as u64 + 1)) as usize;
504        items.swap(i, j);
505    }
506}
507
508/// Creates negative edge samples for link prediction.
509pub fn sample_negative_edges(
510    graph: &Graph,
511    num_samples: usize,
512    seed: u64,
513) -> Vec<(NodeId, NodeId)> {
514    let mut rng = SimpleRng::new(seed);
515    let node_ids: Vec<_> = graph.nodes.keys().copied().collect();
516    let n = node_ids.len();
517
518    if n < 2 {
519        return Vec::new();
520    }
521
522    // Build existing edge set
523    let existing_edges: HashSet<_> = graph
524        .edges
525        .values()
526        .map(|e| (e.source.min(e.target), e.source.max(e.target)))
527        .collect();
528
529    let mut negative_edges = Vec::with_capacity(num_samples);
530    let max_attempts = num_samples * 10;
531    let mut attempts = 0;
532
533    while negative_edges.len() < num_samples && attempts < max_attempts {
534        let i = (rng.next() % n as u64) as usize;
535        let j = (rng.next() % n as u64) as usize;
536
537        if i == j {
538            attempts += 1;
539            continue;
540        }
541
542        let src = node_ids[i];
543        let tgt = node_ids[j];
544        let key = (src.min(tgt), src.max(tgt));
545
546        if !existing_edges.contains(&key) {
547            negative_edges.push((src, tgt));
548        }
549
550        attempts += 1;
551    }
552
553    negative_edges
554}
555
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{EdgeType, GraphEdge, GraphNode, GraphType, NodeType};

    /// Builds the shared fixture: adds ten account nodes (the one created at
    /// `i == 5` is flagged anomalous) and nine edges dated 2024-01-01 through
    /// 2024-01-09.
    fn create_test_graph() -> Graph {
        let mut graph = Graph::new("test", GraphType::Transaction);

        for i in 0..10 {
            let mut node = GraphNode::new(
                0,
                NodeType::Account,
                i.to_string(),
                format!("Account {}", i),
            );
            if i == 5 {
                node.is_anomaly = true;
            }
            graph.add_node(node);
        }

        for i in 0..9 {
            let edge = GraphEdge::new(0, i + 1, i + 2, EdgeType::Transaction)
                .with_timestamp(chrono::NaiveDate::from_ymd_opt(2024, 1, i as u32 + 1).unwrap());
            graph.add_edge(edge);
        }

        graph.compute_statistics();
        graph
    }

    #[test]
    fn test_random_split() {
        let graph = create_test_graph();
        let splitter = DataSplitter::new(SplitConfig::default());
        let split = splitter.split(&graph);

        assert_eq!(
            split.train_nodes.len() + split.val_nodes.len() + split.test_nodes.len(),
            graph.node_count()
        );
        assert_eq!(
            split.train_edges.len() + split.val_edges.len() + split.test_edges.len(),
            graph.edges.len()
        );
    }

    #[test]
    fn test_temporal_split() {
        let graph = create_test_graph();
        let config = SplitConfig {
            strategy: SplitStrategy::Temporal {
                train_cutoff: chrono::NaiveDate::from_ymd_opt(2024, 1, 4).unwrap(),
                val_cutoff: chrono::NaiveDate::from_ymd_opt(2024, 1, 7).unwrap(),
            },
            ..Default::default()
        };
        let splitter = DataSplitter::new(config);
        let split = splitter.split(&graph);

        // Train edges should be before cutoff
        assert!(!split.train_edges.is_empty());
        // Every edge lands in exactly one bucket.
        assert_eq!(
            split.train_edges.len() + split.val_edges.len() + split.test_edges.len(),
            graph.edges.len()
        );
    }

    #[test]
    fn test_stratified_split() {
        let graph = create_test_graph();
        let config = SplitConfig {
            strategy: SplitStrategy::Stratified,
            ..Default::default()
        };
        let splitter = DataSplitter::new(config);
        let split = splitter.split(&graph);

        assert!(!split.train_nodes.is_empty());
    }

    #[test]
    fn test_kfold_split() {
        let graph = create_test_graph();
        let config = SplitConfig {
            strategy: SplitStrategy::KFold { k: 3, fold: 1 },
            ..Default::default()
        };
        let split = DataSplitter::new(config).split(&graph);

        // The held-out fold doubles as validation and test.
        assert_eq!(split.val_nodes.len(), split.test_nodes.len());
        assert_eq!(split.val_edges.len(), split.test_edges.len());
        assert_eq!(
            split.train_nodes.len() + split.val_nodes.len(),
            graph.nodes.len()
        );
        assert_eq!(
            split.train_edges.len() + split.val_edges.len(),
            graph.edges.len()
        );
    }

    #[test]
    fn test_transductive_split() {
        let graph = create_test_graph();
        let config = SplitConfig {
            strategy: SplitStrategy::Transductive,
            ..Default::default()
        };
        let split = DataSplitter::new(config).split(&graph);

        // Every node is visible in every split; only edges are partitioned.
        assert_eq!(split.train_nodes.len(), graph.nodes.len());
        assert_eq!(split.val_nodes.len(), graph.nodes.len());
        assert_eq!(split.test_nodes.len(), graph.nodes.len());
        assert_eq!(
            split.train_edges.len() + split.val_edges.len() + split.test_edges.len(),
            graph.edges.len()
        );
    }

    #[test]
    fn test_split_by_ratio_sizes() {
        let items: Vec<u32> = (0..8).collect();
        let (train, val, test) = split_by_ratio(&items, 0.5, 0.25);

        assert_eq!(train, vec![0, 1, 2, 3]);
        assert_eq!(val, vec![4, 5]);
        assert_eq!(test, vec![6, 7]);
    }

    #[test]
    fn test_shuffle_is_seeded_permutation() {
        let original: Vec<u32> = (0..20).collect();

        let mut first = original.clone();
        shuffle(&mut first, &mut SimpleRng::new(7));
        let mut second = original.clone();
        shuffle(&mut second, &mut SimpleRng::new(7));

        // Same seed, same order; and no element is lost or duplicated.
        assert_eq!(first, second);
        first.sort_unstable();
        assert_eq!(first, original);
    }

    #[test]
    fn test_rng_zero_seed_matches_seed_one() {
        let mut zero = SimpleRng::new(0);
        let mut one = SimpleRng::new(1);

        assert_eq!(zero.next(), one.next());
    }

    #[test]
    fn test_negative_sampling() {
        let graph = create_test_graph();
        let negatives = sample_negative_edges(&graph, 5, 42);

        assert!(negatives.len() <= 5);
        for (src, tgt) in &negatives {
            assert_ne!(src, tgt);
        }
    }
}