raphtory 0.17.0

raphtory, a temporal graph library
Documentation
/// Regression tests for the FastRP node-embedding algorithm.
#[cfg(test)]
mod fast_rp_test {

    use raphtory::{
        algorithms::embeddings::fast_rp::fast_rp, db::api::mutation::AdditionOps, prelude::*,
        test_storage,
    };
    use std::collections::HashMap;

    /// Runs `fast_rp` with a fixed seed on a small two-component graph and
    /// checks every node's embedding against a recorded baseline.
    ///
    /// The graph is a triangle on nodes 1, 2, 3 plus a fully connected block
    /// on nodes 4..=7, with node 8 attached to node 6. All edges are added at
    /// time 1. The comparison is exact (`==` on `f64`), so this pins the
    /// reproducibility of the seeded run rather than numerical closeness.
    #[test]
    fn simple_fast_rp_test() {
        let graph = Graph::new();

        // (src, dst, time). The edges are only iterated once, so a plain
        // array is enough; a heap-allocated `Vec` buys nothing here.
        let edges = [
            (1, 2, 1),
            (1, 3, 1),
            (2, 3, 1),
            (4, 5, 1),
            (4, 6, 1),
            (4, 7, 1),
            (5, 6, 1),
            (5, 7, 1),
            (6, 7, 1),
            (6, 8, 1),
        ];
        for (src, dst, ts) in edges {
            graph.add_edge(ts, src, dst, NO_PROPS, None).unwrap();
        }

        // Expected embedding per node, keyed by node name. Each row is a
        // fixed-size array, so the compiler rejects a row with the wrong
        // number of components. The values were presumably recorded from an
        // earlier run of this exact configuration; regenerate them only when
        // the algorithm or its RNG is changed on purpose.
        #[rustfmt::skip]
        const BASELINE: [(&str, [f64; 16]); 8] = [
            ("1", [
                1.6817928305074292, 0.4204482076268573, -0.4204482076268573, 0.0,
                0.0, 2.1022410381342866, 0.4204482076268573, 0.4204482076268573,
                2.1022410381342866, -0.8408964152537146, 0.0, 1.6817928305074292,
                0.0, -1.6817928305074292, 0.0, -0.8408964152537146,
            ]),
            ("2", [
                0.4204482076268573, 1.6817928305074292, -1.6817928305074292, 0.0,
                0.0, 0.8408964152537146, 1.6817928305074292, 1.6817928305074292,
                2.1022410381342866, -2.1022410381342866, 0.0, 0.4204482076268573,
                0.0, -0.4204482076268573, 0.0, -2.1022410381342866,
            ]),
            ("3", [
                0.4204482076268573, 0.4204482076268573, -0.4204482076268573, 0.0,
                0.0, 2.1022410381342866, 0.4204482076268573, 0.4204482076268573,
                0.8408964152537146, -2.1022410381342866, 0.0, 0.4204482076268573,
                0.0, -0.4204482076268573, 0.0, -2.1022410381342866,
            ]),
            ("4", [
                -1.4014940254228576, 0.560597610169143, 1.121195220338286, -0.2802988050845715,
                0.2802988050845715, -0.2802988050845715, 0.2802988050845715, 0.0,
                -1.6817928305074292, 0.0, 0.0, -0.2802988050845715,
                0.2802988050845715, 0.2802988050845715, -0.2802988050845715, -1.6817928305074292,
            ]),
            ("5", [
                0.0, 1.9620916355920008, -1.6817928305074292, -1.6817928305074292,
                0.2802988050845715, -0.2802988050845715, 0.2802988050845715, 1.4014940254228576,
                -0.2802988050845715, 0.0, 0.0, -1.6817928305074292,
                0.2802988050845715, 0.2802988050845715, -0.2802988050845715, 1.121195220338286,
            ]),
            ("6", [
                -0.21022410381342865, 0.6306723114402859, -1.6817928305074292, -1.4715687266940005,
                1.6817928305074292, -1.6817928305074292, 0.0, -1.4715687266940005,
                -0.21022410381342865, 0.0, 0.0, -0.4204482076268573,
                1.6817928305074292, 0.21022410381342865, -0.21022410381342865, -0.21022410381342865,
            ]),
            ("7", [
                1.4014940254228576, 1.9620916355920008, -0.2802988050845715, 1.121195220338286,
                0.2802988050845715, -0.2802988050845715, 1.6817928305074292, 0.0,
                -0.2802988050845715, 0.0, 0.0, -0.2802988050845715,
                0.2802988050845715, 1.6817928305074292, -1.6817928305074292, -1.6817928305074292,
            ]),
            ("8", [
                -1.6817928305074292, 1.6817928305074292, -0.8408964152537146, 0.8408964152537146,
                0.8408964152537146, -0.8408964152537146, -1.6817928305074292, -0.8408964152537146,
                0.0, 0.0, 0.0, -1.6817928305074292,
                0.8408964152537146, 0.0, 0.0, 0.0,
            ]),
        ];
        let baseline: HashMap<String, Vec<f64>> = BASELINE
            .iter()
            .map(|&(node, embedding)| (node.to_owned(), embedding.to_vec()))
            .collect();

        test_storage!(&graph, |graph| {
            // Arguments, presumably (confirm against `fast_rp`'s signature):
            // 16-dimensional embeddings (matching the baseline row width),
            // normalization strength 1.0, two iterations weighted 1.0 each,
            // RNG seed 42, and no explicit thread count.
            let results: HashMap<String, Vec<f64>> =
                fast_rp(graph, 16, 1.0, vec![1.0, 1.0], Some(42), None)
                    .to_hashmap(|value| value.embedding_state);

            // Compare node by node so that a failure names the offending node
            // instead of dumping two unordered maps. The length check plus
            // the per-key check together are equivalent to map equality.
            assert_eq!(
                results.len(),
                baseline.len(),
                "fast_rp should return exactly one embedding per node"
            );
            for (node, expected) in &baseline {
                assert_eq!(
                    results.get(node),
                    Some(expected),
                    "embedding for node {} differs from baseline",
                    node
                );
            }
        });
    }

    // NOTE(Wyatt): the simple fast_rp test is more of a validation of idempotency than
    // correctness (although the results are expected).
    // The test below, still in progress, is meant to validate that the algorithm
    // preserves pairwise topological distances.
    // NOTE(review): the disabled code still imports `crate::io::csv_loader::CsvLoader` and
    // calls `get_all_with_names`, while the live test above goes through `raphtory::` paths
    // and `to_hashmap`. It will likely need updating before it can be re-enabled.
    /*
    use crate::io::csv_loader::CsvLoader;
    use serde::{Deserialize, Serialize};
    use std::path::PathBuf;

    fn print_samples(map: &HashMap<String, Vec<f64>>, n: usize) {
        let mut count = 0;

        for (key, value) in map {
            println!("Key: {}, Value: {:#?}", key, value);

            count += 1;
            if count >= n {
                break;
            }
        }
    }

    fn top_k_neighbors(
        data: &HashMap<String, Vec<f64>>,
        k: usize,
    ) -> HashMap<String, Vec<String>> {
        let mut neighbors: HashMap<String, Vec<String>> = HashMap::new();

        // Iterate over each ID to find its top K neighbors
        for (id, vector) in data {
            // Collect distances to all other IDs
            let mut distances: Vec<(&String, f64)> = Vec::new();
            for (other_id, other_vector) in data {
                if id == other_id {
                    continue; // Skip self
                }
                // Compute Euclidean distance
                let distance = euclidean_distance(vector, other_vector);
                distances.push((other_id, distance));
            }
            // Sort the distances in ascending order
            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
            // Collect top K neighbor IDs
            let top_k: Vec<String> = distances
                .iter()
                .take(k)
                .map(|(other_id, _)| (*other_id).clone())
                .collect();
            // Insert into the neighbors map
            neighbors.insert(id.clone(), top_k);
        }

        neighbors
    }

    fn euclidean_distance(a: &Vec<f64>, b: &Vec<f64>) -> f64 {
        assert_eq!(a.len(), b.len(), "Vectors must be of the same length");
        a.iter()
            .zip(b.iter())
            .map(|(&x, &y)| (x - y).powi(2))
            .sum::<f64>()
            .sqrt()
    }

    #[test]
    fn big_fast_rp_test() {
        let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        d.push("resources/test");
        let loader = CsvLoader::new(d.join("test.csv")).set_delimiter(",");
        let graph = Graph::new();

        #[derive(Deserialize, Serialize, Debug)]
        struct CsvEdge {
            src: u64,
            dst: u64,
        }

        loader
            .load_into_graph(&graph, |e: CsvEdge, g| {
                g.add_edge(1, e.src, e.dst, NO_PROPS, None).unwrap();
                g.add_edge(1, e.dst, e.src, NO_PROPS, None).unwrap();
            })
            .unwrap();

        test_storage!(&graph, |graph| {
            let results = fast_rp(
                graph,
                32,
                1.0,
                vec![1.0, 1.0, 0.5],
                Some(42),
                None,
            ).get_all_with_names();
            // println!("Result: {:#?}", results);
            print_samples(&results, 10);
        });
    }
     */
}