Skip to main content

fso_graph_builder/graph/
csr.rs

1use atomic::Atomic;
2use byte_slice_cast::{AsByteSlice, AsMutByteSlice, ToByteSlice, ToMutByteSlice};
3use log::info;
4use std::{
5    convert::TryFrom,
6    fs::File,
7    io::{BufReader, Read, Write},
8    iter::FromIterator,
9    mem::{ManuallyDrop, MaybeUninit},
10    path::PathBuf,
11    sync::atomic::Ordering::Acquire,
12    time::Instant,
13};
14
15use rayon::prelude::*;
16
17use crate::{
18    compat::*,
19    graph_ops::{DeserializeGraphOp, SerializeGraphOp, ToUndirectedOp},
20    index::Idx,
21    input::{edgelist::Edges, Direction},
22    DirectedDegrees, DirectedNeighbors, DirectedNeighborsWithValues, Error, Graph,
23    NodeValues as NodeValuesTrait, SharedMut, Target, UndirectedDegrees, UndirectedNeighbors,
24    UndirectedNeighborsWithValues,
25};
26
27#[cfg(feature = "dotgraph")]
28use crate::input::DotGraph;
29#[cfg(feature = "dotgraph")]
30use std::hash::Hash;
31
/// Defines how the neighbor lists of individual nodes are organized within
/// the CSR target array.
///
/// `CsrLayout::default()` is [`CsrLayout::Unsorted`].
#[derive(Default, Clone, Copy, Debug)]
pub enum CsrLayout {
    /// Neighbor lists are sorted and may contain duplicate target ids.
    Sorted,
    /// Neighbor lists are not in any particular order. This is the default
    /// representation.
    #[default]
    Unsorted,
    /// Neighbor lists are sorted and do not contain duplicate target ids.
    /// Self-loops, i.e., edges in the form of `(u, u)` are removed.
    Deduplicated,
}
46
47/// A Compressed-Sparse-Row data structure to represent sparse graphs.
48///
49/// The data structure is composed of two arrays: `offsets` and `targets`. For a
50/// graph with node count `n` and edge count `m`, `offsets` has exactly `n + 1`
51/// and `targets` exactly `m` entries.
52///
53/// For a given node `u`, `offsets[u]` stores the start index of the neighbor
54/// list of `u` in `targets`. The degree of `u`, i.e., the length of the
55/// neighbor list is defined by `offsets[u + 1] - offsets[u]`. The neighbor list
56/// of `u` is defined by the slice `&targets[offsets[u]..offsets[u + 1]]`.
57#[derive(Debug)]
58pub struct Csr<Index: Idx, NI, EV> {
59    offsets: Box<[Index]>,
60    targets: Box<[Target<NI, EV>]>,
61}
62
63impl<Index: Idx, NI, EV> Csr<Index, NI, EV> {
64    pub(crate) fn new(offsets: Box<[Index]>, targets: Box<[Target<NI, EV>]>) -> Self {
65        Self { offsets, targets }
66    }
67
68    #[inline]
69    pub(crate) fn node_count(&self) -> Index {
70        Index::new(self.offsets.len() - 1)
71    }
72
73    #[inline]
74    pub(crate) fn edge_count(&self) -> Index {
75        Index::new(self.targets.len())
76    }
77
78    #[inline]
79    pub(crate) fn degree(&self, i: Index) -> Index {
80        let from = self.offsets[i.index()];
81        let to = self.offsets[(i + Index::new(1)).index()];
82
83        to - from
84    }
85
86    #[inline]
87    pub(crate) fn targets_with_values(&self, i: Index) -> &[Target<NI, EV>] {
88        let from = self.offsets[i.index()];
89        let to = self.offsets[(i + Index::new(1)).index()];
90
91        &self.targets[from.index()..to.index()]
92    }
93}
94
95impl<Index: Idx, NI> Csr<Index, NI, ()> {
96    #[inline]
97    pub(crate) fn targets(&self, i: Index) -> &[NI] {
98        assert_eq!(
99            std::mem::size_of::<Target<NI, ()>>(),
100            std::mem::size_of::<NI>()
101        );
102        assert_eq!(
103            std::mem::align_of::<Target<NI, ()>>(),
104            std::mem::align_of::<NI>()
105        );
106        let from = self.offsets[i.index()];
107        let to = self.offsets[(i + Index::new(1)).index()];
108
109        let len = (to - from).index();
110
111        let targets = &self.targets[from.index()..to.index()];
112
113        // SAFETY: len is within bounds as it is calculated above as `to - from`.
114        //         The types Target<T, ()> and T are verified to have the same
115        //         size and alignment.
116        unsafe { std::slice::from_raw_parts(targets.as_ptr() as *const _, len) }
117    }
118}
119
120pub trait SwapCsr<Index: Idx, NI, EV> {
121    fn swap_csr(&mut self, csr: Csr<Index, NI, EV>) -> &mut Self;
122}
123
124impl<NI, EV, E> From<(&'_ E, NI, Direction, CsrLayout)> for Csr<NI, NI, EV>
125where
126    NI: Idx,
127    EV: Copy + Send + Sync,
128    E: Edges<NI = NI, EV = EV>,
129{
130    fn from(
131        (edge_list, node_count, direction, csr_layout): (&'_ E, NI, Direction, CsrLayout),
132    ) -> Self {
133        let start = Instant::now();
134        let degrees = edge_list.degrees(node_count, direction);
135        info!("Computed degrees in {:?}", start.elapsed());
136
137        let start = Instant::now();
138        let offsets = prefix_sum_atomic(degrees);
139        info!("Computed prefix sum in {:?}", start.elapsed());
140
141        let start = Instant::now();
142        let edge_count = offsets[node_count.index()].load(Acquire).index();
143        let mut targets = Vec::<Target<NI, EV>>::with_capacity(edge_count);
144        let targets_ptr = SharedMut::new(targets.as_mut_ptr());
145
146        // The following loop writes all targets into their correct position.
147        // The offsets are a prefix sum of all degrees, which will produce
148        // non-overlapping positions for all node values.
149        //
150        // SAFETY: for any (s, t) tuple from the same edge_list we use the
151        // prefix_sum to find a unique position for the target value, so that we
152        // only write once into each position and every thread that might run
153        // will write into different positions.
154        if matches!(direction, Direction::Outgoing | Direction::Undirected) {
155            edge_list.edges().for_each(|(s, t, v)| {
156                let offset = NI::get_and_increment(&offsets[s.index()], Acquire);
157
158                unsafe {
159                    targets_ptr.add(offset.index()).write(Target::new(t, v));
160                }
161            })
162        }
163
164        if matches!(direction, Direction::Incoming | Direction::Undirected) {
165            edge_list.edges().for_each(|(s, t, v)| {
166                let offset = NI::get_and_increment(&offsets[t.index()], Acquire);
167
168                unsafe {
169                    targets_ptr.add(offset.index()).write(Target::new(s, v));
170                }
171            })
172        }
173
174        // SAFETY: The previous loops iterated the input edge list once (twice
175        // for undirected) and inserted one node id for each edge. The
176        // `edge_count` is defined by the highest offset value.
177        unsafe {
178            targets.set_len(edge_count);
179        }
180        info!("Computed target array in {:?}", start.elapsed());
181
182        let start = Instant::now();
183        let mut offsets = ManuallyDrop::new(offsets);
184        let (ptr, len, cap) = (offsets.as_mut_ptr(), offsets.len(), offsets.capacity());
185
186        // SAFETY: NI and NI::Atomic have the same memory layout
187        let mut offsets = unsafe {
188            let ptr = ptr as *mut _;
189            Vec::from_raw_parts(ptr, len, cap)
190        };
191
192        // Each insert into the target array in the previous loops incremented
193        // the offset for the corresponding node by one. As a consequence the
194        // offset values are shifted one index to the right. We need to correct
195        // this in order to get correct offsets.
196        offsets.rotate_right(1);
197        offsets[0] = NI::zero();
198        info!("Finalized offset array in {:?}", start.elapsed());
199
200        let (offsets, targets) = match csr_layout {
201            CsrLayout::Unsorted => (offsets, targets),
202            CsrLayout::Sorted => {
203                let start = Instant::now();
204                sort_targets(&offsets, &mut targets);
205                info!("Sorted targets in {:?}", start.elapsed());
206                (offsets, targets)
207            }
208            CsrLayout::Deduplicated => {
209                let start = Instant::now();
210                let offsets_targets = sort_and_deduplicate_targets(&offsets, &mut targets[..]);
211                info!("Sorted and deduplicated targets in {:?}", start.elapsed());
212                offsets_targets
213            }
214        };
215
216        Csr {
217            offsets: offsets.into_boxed_slice(),
218            targets: targets.into_boxed_slice(),
219        }
220    }
221}
222
223unsafe impl<NI, EV> ToByteSlice for Target<NI, EV>
224where
225    NI: ToByteSlice,
226    EV: ToByteSlice,
227{
228    fn to_byte_slice<S: AsRef<[Self]> + ?Sized>(slice: &S) -> &[u8] {
229        let slice = slice.as_ref();
230        let len = std::mem::size_of_val(slice);
231        unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, len) }
232    }
233}
234
235unsafe impl<NI, EV> ToMutByteSlice for Target<NI, EV>
236where
237    NI: ToMutByteSlice,
238    EV: ToMutByteSlice,
239{
240    fn to_mut_byte_slice<S: AsMut<[Self]> + ?Sized>(slice: &mut S) -> &mut [u8] {
241        let slice = slice.as_mut();
242        let len = std::mem::size_of_val(slice);
243        unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut u8, len) }
244    }
245}
246
247impl<NI, EV> Csr<NI, NI, EV>
248where
249    NI: Idx + ToByteSlice,
250    EV: ToByteSlice,
251{
252    fn serialize<W: Write>(&self, output: &mut W) -> Result<(), Error> {
253        let type_name = std::any::type_name::<NI>().as_bytes();
254        output.write_all([type_name.len()].as_byte_slice())?;
255        output.write_all(type_name)?;
256
257        let node_count = self.node_count();
258        let edge_count = self.edge_count();
259        let meta = [node_count, edge_count];
260        output.write_all(meta.as_byte_slice())?;
261
262        output.write_all(self.offsets.as_byte_slice())?;
263        output.write_all(self.targets.as_byte_slice())?;
264
265        Ok(())
266    }
267}
268
269impl<NI, EV> Csr<NI, NI, EV>
270where
271    NI: Idx + ToMutByteSlice,
272    EV: ToMutByteSlice,
273{
274    fn deserialize<R: Read>(read: &mut R) -> Result<Csr<NI, NI, EV>, Error> {
275        let mut type_name_len = [0_usize; 1];
276        read.read_exact(type_name_len.as_mut_byte_slice())?;
277        let [type_name_len] = type_name_len;
278
279        let mut type_name = vec![0_u8; type_name_len];
280        read.read_exact(type_name.as_mut_byte_slice())?;
281        let type_name = String::from_utf8(type_name).expect("could not read type name");
282
283        let expected_type_name = std::any::type_name::<NI>().to_string();
284
285        if type_name != expected_type_name {
286            return Err(Error::InvalidIdType {
287                expected: expected_type_name,
288                actual: type_name,
289            });
290        }
291
292        let mut meta = [NI::zero(); 2];
293        read.read_exact(meta.as_mut_byte_slice())?;
294
295        let [node_count, edge_count] = meta;
296
297        let mut offsets = Box::new_uninit_slice_compat(node_count.index() + 1);
298        let offsets_ptr = offsets.as_mut_ptr() as *mut NI;
299        let offsets_ptr =
300            unsafe { std::slice::from_raw_parts_mut(offsets_ptr, node_count.index() + 1) };
301        read.read_exact(offsets_ptr.as_mut_byte_slice())?;
302
303        let mut targets = Box::new_uninit_slice_compat(edge_count.index());
304        let targets_ptr = targets.as_mut_ptr() as *mut Target<NI, EV>;
305        let targets_ptr =
306            unsafe { std::slice::from_raw_parts_mut(targets_ptr, edge_count.index()) };
307        read.read_exact(targets_ptr.as_mut_byte_slice())?;
308
309        let offsets = unsafe { offsets.assume_init_compat() };
310        let targets = unsafe { targets.assume_init_compat() };
311
312        Ok(Csr::new(offsets, targets))
313    }
314}
315
/// One value per node; the value at index `i` belongs to node `i`.
pub struct NodeValues<NV>(pub(crate) Box<[NV]>);

impl<NV> NodeValues<NV> {
    /// Takes ownership of `node_values` as the per-node value storage.
    pub fn new(node_values: Vec<NV>) -> Self {
        let boxed: Box<[NV]> = node_values.into();
        NodeValues(boxed)
    }
}

impl<NV> FromIterator<NV> for NodeValues<NV> {
    /// Collects the values in iteration order.
    fn from_iter<T: IntoIterator<Item = NV>>(iter: T) -> Self {
        let values: Vec<NV> = iter.into_iter().collect();
        NodeValues::new(values)
    }
}
329
330impl<NV> NodeValues<NV>
331where
332    NV: ToByteSlice,
333{
334    fn serialize<W: Write>(&self, output: &mut W) -> Result<(), Error> {
335        let node_count = self.0.len();
336        let meta = [node_count];
337        output.write_all(meta.as_byte_slice())?;
338        output.write_all(self.0.as_byte_slice())?;
339        Ok(())
340    }
341}
342
343impl<NV> NodeValues<NV>
344where
345    NV: ToMutByteSlice,
346{
347    fn deserialize<R: Read>(read: &mut R) -> Result<Self, Error> {
348        let mut meta = [0_usize; 1];
349        read.read_exact(meta.as_mut_byte_slice())?;
350        let [node_count] = meta;
351
352        let mut node_values = Box::new_uninit_slice_compat(node_count);
353        let node_values_ptr = node_values.as_mut_ptr() as *mut NV;
354        let node_values_slice =
355            unsafe { std::slice::from_raw_parts_mut(node_values_ptr, node_count.index()) };
356        read.read_exact(node_values_slice.as_mut_byte_slice())?;
357
358        let offsets = unsafe { node_values.assume_init_compat() };
359
360        Ok(NodeValues(offsets))
361    }
362}
363
364pub struct DirectedCsrGraph<NI: Idx, NV = (), EV = ()> {
365    node_values: NodeValues<NV>,
366    csr_out: Csr<NI, NI, EV>,
367    csr_inc: Csr<NI, NI, EV>,
368}
369
370impl<NI: Idx, NV, EV> DirectedCsrGraph<NI, NV, EV> {
371    pub fn new(
372        node_values: NodeValues<NV>,
373        csr_out: Csr<NI, NI, EV>,
374        csr_inc: Csr<NI, NI, EV>,
375    ) -> Self {
376        let g = Self {
377            node_values,
378            csr_out,
379            csr_inc,
380        };
381        info!(
382            "Created directed graph (node_count = {:?}, edge_count = {:?})",
383            g.node_count(),
384            g.edge_count()
385        );
386
387        g
388    }
389}
390
391impl<NI, NV, EV> ToUndirectedOp for DirectedCsrGraph<NI, NV, EV>
392where
393    NI: Idx,
394    NV: Clone + Send + Sync,
395    EV: Copy + Send + Sync,
396{
397    type Undirected = UndirectedCsrGraph<NI, NV, EV>;
398
399    fn to_undirected(&self, layout: impl Into<Option<CsrLayout>>) -> Self::Undirected {
400        let node_values = NodeValues::new(self.node_values.0.to_vec());
401        let layout = layout.into().unwrap_or_default();
402        let edges = ToUndirectedEdges { g: self };
403
404        UndirectedCsrGraph::from((node_values, edges, layout))
405    }
406}
407
408struct ToUndirectedEdges<'g, NI: Idx, NV, EV> {
409    g: &'g DirectedCsrGraph<NI, NV, EV>,
410}
411
412impl<NI, NV, EV> Edges for ToUndirectedEdges<'_, NI, NV, EV>
413where
414    NI: Idx,
415    NV: Send + Sync,
416    EV: Copy + Send + Sync,
417{
418    type NI = NI;
419
420    type EV = EV;
421
422    type EdgeIter<'a>
423        = ToUndirectedEdgesIter<'a, NI, NV, EV>
424    where
425        Self: 'a;
426
427    fn edges(&self) -> Self::EdgeIter<'_> {
428        ToUndirectedEdgesIter { g: self.g }
429    }
430
431    fn max_node_id(&self) -> Self::NI {
432        self.g.node_count() - NI::new(1)
433    }
434
435    #[cfg(test)]
436    fn len(&self) -> usize {
437        unimplemented!("This type is not used in tests")
438    }
439}
440
441struct ToUndirectedEdgesIter<'g, NI: Idx, NV, EV> {
442    g: &'g DirectedCsrGraph<NI, NV, EV>,
443}
444
445impl<NI: Idx, NV: Send + Sync, EV: Copy + Send + Sync> ParallelIterator
446    for ToUndirectedEdgesIter<'_, NI, NV, EV>
447{
448    type Item = (NI, NI, EV);
449
450    fn drive_unindexed<C>(self, consumer: C) -> C::Result
451    where
452        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
453    {
454        let par_iter = (0..self.g.node_count().index())
455            .into_par_iter()
456            .flat_map_iter(|n| {
457                let n = NI::new(n);
458                self.g
459                    .out_neighbors_with_values(n)
460                    .map(move |t| (n, t.target, t.value))
461            });
462        par_iter.drive_unindexed(consumer)
463    }
464}
465
466impl<NI: Idx, NV, EV> Graph<NI> for DirectedCsrGraph<NI, NV, EV> {
467    delegate::delegate! {
468        to self.csr_out {
469            fn node_count(&self) -> NI;
470            fn edge_count(&self) -> NI;
471        }
472    }
473}
474
475impl<NI: Idx, NV, EV> NodeValuesTrait<NI, NV> for DirectedCsrGraph<NI, NV, EV> {
476    fn node_value(&self, node: NI) -> &NV {
477        &self.node_values.0[node.index()]
478    }
479}
480
481impl<NI: Idx, NV, EV> DirectedDegrees<NI> for DirectedCsrGraph<NI, NV, EV> {
482    fn out_degree(&self, node: NI) -> NI {
483        self.csr_out.degree(node)
484    }
485
486    fn in_degree(&self, node: NI) -> NI {
487        self.csr_inc.degree(node)
488    }
489}
490
491impl<NI: Idx, NV> DirectedNeighbors<NI> for DirectedCsrGraph<NI, NV, ()> {
492    type NeighborsIterator<'a>
493        = std::slice::Iter<'a, NI>
494    where
495        NV: 'a;
496
497    fn out_neighbors(&self, node: NI) -> Self::NeighborsIterator<'_> {
498        self.csr_out.targets(node).iter()
499    }
500
501    fn in_neighbors(&self, node: NI) -> Self::NeighborsIterator<'_> {
502        self.csr_inc.targets(node).iter()
503    }
504}
505
506impl<NI: Idx, NV, EV> DirectedNeighborsWithValues<NI, EV> for DirectedCsrGraph<NI, NV, EV> {
507    type NeighborsIterator<'a>
508        = std::slice::Iter<'a, Target<NI, EV>>
509    where
510        NV: 'a,
511        EV: 'a;
512
513    fn out_neighbors_with_values(&self, node: NI) -> Self::NeighborsIterator<'_> {
514        self.csr_out.targets_with_values(node).iter()
515    }
516
517    fn in_neighbors_with_values(&self, node: NI) -> Self::NeighborsIterator<'_> {
518        self.csr_inc.targets_with_values(node).iter()
519    }
520}
521
522impl<NI, EV, E> From<(E, CsrLayout)> for DirectedCsrGraph<NI, (), EV>
523where
524    NI: Idx,
525    EV: Copy + Send + Sync,
526    E: Edges<NI = NI, EV = EV>,
527{
528    fn from((edge_list, csr_option): (E, CsrLayout)) -> Self {
529        info!("Creating directed graph");
530        let node_count = edge_list.max_node_id() + NI::new(1);
531
532        let node_values = NodeValues::new(vec![(); node_count.index()]);
533
534        let start = Instant::now();
535        let csr_out = Csr::from((&edge_list, node_count, Direction::Outgoing, csr_option));
536        info!("Created outgoing csr in {:?}.", start.elapsed());
537
538        let start = Instant::now();
539        let csr_inc = Csr::from((&edge_list, node_count, Direction::Incoming, csr_option));
540        info!("Created incoming csr in {:?}.", start.elapsed());
541
542        DirectedCsrGraph::new(node_values, csr_out, csr_inc)
543    }
544}
545
546impl<NI, NV, EV, E> From<(NodeValues<NV>, E, CsrLayout)> for DirectedCsrGraph<NI, NV, EV>
547where
548    NI: Idx,
549    EV: Copy + Send + Sync,
550    E: Edges<NI = NI, EV = EV>,
551{
552    fn from((node_values, edge_list, csr_option): (NodeValues<NV>, E, CsrLayout)) -> Self {
553        info!("Creating directed graph");
554        let node_count = NI::new(node_values.0.len());
555        let node_count_from_edge_list = edge_list.max_node_id() + NI::new(1);
556
557        assert!(
558            node_count >= node_count_from_edge_list,
559            "number of node values ({}) does not match node count of edge list ({})",
560            node_count.index(),
561            node_count_from_edge_list.index()
562        );
563
564        let start = Instant::now();
565        let csr_out = Csr::from((&edge_list, node_count, Direction::Outgoing, csr_option));
566        info!("Created outgoing csr in {:?}.", start.elapsed());
567
568        let start = Instant::now();
569        let csr_inc = Csr::from((&edge_list, node_count, Direction::Incoming, csr_option));
570        info!("Created incoming csr in {:?}.", start.elapsed());
571
572        DirectedCsrGraph::new(node_values, csr_out, csr_inc)
573    }
574}
575
#[cfg(feature = "dotgraph")]
impl<NI, Label> From<(DotGraph<NI, Label>, CsrLayout)> for DirectedCsrGraph<NI, ()>
where
    NI: Idx,
    Label: Idx + Hash,
{
    /// Uses only the DOT graph's edge list; the labels are dropped.
    fn from((dot_graph, csr_layout): (DotGraph<NI, Label>, CsrLayout)) -> Self {
        let edges = dot_graph.edge_list;
        DirectedCsrGraph::from((edges, csr_layout))
    }
}

#[cfg(feature = "dotgraph")]
impl<NI, Label> From<(DotGraph<NI, Label>, CsrLayout)> for DirectedCsrGraph<NI, Label>
where
    NI: Idx,
    Label: Idx + Hash,
{
    /// Uses the DOT graph's labels as node values.
    fn from((dot_graph, csr_layout): (DotGraph<NI, Label>, CsrLayout)) -> Self {
        let labels = NodeValues::new(dot_graph.labels);
        let edges = dot_graph.edge_list;
        DirectedCsrGraph::from((labels, edges, csr_layout))
    }
}
605
606impl<W, NI, NV, EV> SerializeGraphOp<W> for DirectedCsrGraph<NI, NV, EV>
607where
608    W: Write,
609    NI: Idx + ToByteSlice,
610    NV: ToByteSlice,
611    EV: ToByteSlice,
612{
613    fn serialize(&self, mut output: W) -> Result<(), Error> {
614        let DirectedCsrGraph {
615            node_values,
616            csr_out,
617            csr_inc,
618        } = self;
619
620        node_values.serialize(&mut output)?;
621        csr_out.serialize(&mut output)?;
622        csr_inc.serialize(&mut output)?;
623
624        Ok(())
625    }
626}
627
628impl<R, NI, NV, EV> DeserializeGraphOp<R, Self> for DirectedCsrGraph<NI, NV, EV>
629where
630    R: Read,
631    NI: Idx + ToMutByteSlice,
632    NV: ToMutByteSlice,
633    EV: ToMutByteSlice,
634{
635    fn deserialize(mut read: R) -> Result<Self, Error> {
636        let node_values: NodeValues<NV> = NodeValues::deserialize(&mut read)?;
637        let csr_out: Csr<NI, NI, EV> = Csr::deserialize(&mut read)?;
638        let csr_inc: Csr<NI, NI, EV> = Csr::deserialize(&mut read)?;
639        Ok(DirectedCsrGraph::new(node_values, csr_out, csr_inc))
640    }
641}
642
643impl<NI, EV> TryFrom<(PathBuf, CsrLayout)> for DirectedCsrGraph<NI, EV>
644where
645    NI: Idx + ToMutByteSlice,
646    EV: ToMutByteSlice,
647{
648    type Error = Error;
649
650    fn try_from((path, _): (PathBuf, CsrLayout)) -> Result<Self, Self::Error> {
651        let reader = BufReader::new(File::open(path)?);
652        let graph = DirectedCsrGraph::deserialize(reader)?;
653
654        Ok(graph)
655    }
656}
657
658pub struct UndirectedCsrGraph<NI: Idx, NV = (), EV = ()> {
659    node_values: NodeValues<NV>,
660    csr: Csr<NI, NI, EV>,
661}
662
663impl<NI: Idx, EV> From<Csr<NI, NI, EV>> for UndirectedCsrGraph<NI, (), EV> {
664    fn from(csr: Csr<NI, NI, EV>) -> Self {
665        UndirectedCsrGraph::new(NodeValues::new(vec![(); csr.node_count().index()]), csr)
666    }
667}
668
669impl<NI: Idx, NV, EV> UndirectedCsrGraph<NI, NV, EV> {
670    pub fn new(node_values: NodeValues<NV>, csr: Csr<NI, NI, EV>) -> Self {
671        let g = Self { node_values, csr };
672        info!(
673            "Created undirected graph (node_count = {:?}, edge_count = {:?})",
674            g.node_count(),
675            g.edge_count()
676        );
677
678        g
679    }
680}
681
682impl<NI: Idx, NV, EV> Graph<NI> for UndirectedCsrGraph<NI, NV, EV> {
683    fn node_count(&self) -> NI {
684        self.csr.node_count()
685    }
686
687    fn edge_count(&self) -> NI {
688        self.csr.edge_count() / NI::new(2)
689    }
690}
691
692impl<NI: Idx, NV, EV> NodeValuesTrait<NI, NV> for UndirectedCsrGraph<NI, NV, EV> {
693    fn node_value(&self, node: NI) -> &NV {
694        &self.node_values.0[node.index()]
695    }
696}
697
698impl<NI: Idx, NV, EV> UndirectedDegrees<NI> for UndirectedCsrGraph<NI, NV, EV> {
699    fn degree(&self, node: NI) -> NI {
700        self.csr.degree(node)
701    }
702}
703
704impl<NI: Idx, NV> UndirectedNeighbors<NI> for UndirectedCsrGraph<NI, NV> {
705    type NeighborsIterator<'a>
706        = std::slice::Iter<'a, NI>
707    where
708        NV: 'a;
709
710    fn neighbors(&self, node: NI) -> Self::NeighborsIterator<'_> {
711        self.csr.targets(node).iter()
712    }
713}
714
715impl<NI: Idx, NV, EV> UndirectedNeighborsWithValues<NI, EV> for UndirectedCsrGraph<NI, NV, EV> {
716    type NeighborsIterator<'a>
717        = std::slice::Iter<'a, Target<NI, EV>>
718    where
719        NV: 'a,
720        EV: 'a;
721
722    fn neighbors_with_values(&self, node: NI) -> Self::NeighborsIterator<'_> {
723        self.csr.targets_with_values(node).iter()
724    }
725}
726
727impl<NI: Idx, NV, EV> SwapCsr<NI, NI, EV> for UndirectedCsrGraph<NI, NV, EV> {
728    fn swap_csr(&mut self, mut csr: Csr<NI, NI, EV>) -> &mut Self {
729        std::mem::swap(&mut self.csr, &mut csr);
730        self
731    }
732}
733
734impl<NI, EV, E> From<(E, CsrLayout)> for UndirectedCsrGraph<NI, (), EV>
735where
736    NI: Idx,
737    EV: Copy + Send + Sync,
738    E: Edges<NI = NI, EV = EV>,
739{
740    fn from((edge_list, csr_option): (E, CsrLayout)) -> Self {
741        info!("Creating undirected graph");
742        let node_count = edge_list.max_node_id() + NI::new(1);
743
744        let node_values = NodeValues::new(vec![(); node_count.index()]);
745
746        let start = Instant::now();
747        let csr = Csr::from((&edge_list, node_count, Direction::Undirected, csr_option));
748        info!("Created csr in {:?}.", start.elapsed());
749
750        UndirectedCsrGraph::new(node_values, csr)
751    }
752}
753
754impl<NI, NV, EV, E> From<(NodeValues<NV>, E, CsrLayout)> for UndirectedCsrGraph<NI, NV, EV>
755where
756    NI: Idx,
757    EV: Copy + Send + Sync,
758    E: Edges<NI = NI, EV = EV>,
759{
760    fn from((node_values, edge_list, csr_option): (NodeValues<NV>, E, CsrLayout)) -> Self {
761        info!("Creating undirected graph");
762        let node_count = NI::new(node_values.0.len());
763        let node_count_from_edge_list = edge_list.max_node_id() + NI::new(1);
764
765        assert!(
766            node_count >= node_count_from_edge_list,
767            "number of node values ({}) does not match node count of edge list ({})",
768            node_count.index(),
769            node_count_from_edge_list.index()
770        );
771
772        let start = Instant::now();
773        let csr = Csr::from((&edge_list, node_count, Direction::Undirected, csr_option));
774        info!("Created csr in {:?}.", start.elapsed());
775
776        UndirectedCsrGraph::new(node_values, csr)
777    }
778}
779
#[cfg(feature = "dotgraph")]
impl<NI, Label> From<(DotGraph<NI, Label>, CsrLayout)> for UndirectedCsrGraph<NI, ()>
where
    NI: Idx,
    Label: Idx + Hash,
{
    /// Uses only the DOT graph's edge list; the labels are dropped.
    fn from((dot_graph, csr_layout): (DotGraph<NI, Label>, CsrLayout)) -> Self {
        let edges = dot_graph.edge_list;
        UndirectedCsrGraph::from((edges, csr_layout))
    }
}

#[cfg(feature = "dotgraph")]
impl<NI, Label> From<(DotGraph<NI, Label>, CsrLayout)> for UndirectedCsrGraph<NI, Label>
where
    NI: Idx,
    Label: Idx + Hash,
{
    /// Uses the DOT graph's labels as node values.
    fn from((dot_graph, csr_layout): (DotGraph<NI, Label>, CsrLayout)) -> Self {
        let labels = NodeValues::new(dot_graph.labels);
        let edges = dot_graph.edge_list;
        UndirectedCsrGraph::from((labels, edges, csr_layout))
    }
}
809
810impl<W, NI, NV, EV> SerializeGraphOp<W> for UndirectedCsrGraph<NI, NV, EV>
811where
812    W: Write,
813    NI: Idx + ToByteSlice,
814    NV: ToByteSlice,
815    EV: ToByteSlice,
816{
817    fn serialize(&self, mut output: W) -> Result<(), Error> {
818        let UndirectedCsrGraph { node_values, csr } = self;
819
820        node_values.serialize(&mut output)?;
821        csr.serialize(&mut output)?;
822
823        Ok(())
824    }
825}
826
827impl<R, NI, NV, EV> DeserializeGraphOp<R, Self> for UndirectedCsrGraph<NI, NV, EV>
828where
829    R: Read,
830    NI: Idx + ToMutByteSlice,
831    NV: ToMutByteSlice,
832    EV: ToMutByteSlice,
833{
834    fn deserialize(mut read: R) -> Result<Self, Error> {
835        let node_values = NodeValues::deserialize(&mut read)?;
836        let csr: Csr<NI, NI, EV> = Csr::deserialize(&mut read)?;
837        Ok(UndirectedCsrGraph::new(node_values, csr))
838    }
839}
840
841impl<NI, EV> TryFrom<(PathBuf, CsrLayout)> for UndirectedCsrGraph<NI, EV>
842where
843    NI: Idx + ToMutByteSlice,
844    EV: ToMutByteSlice,
845{
846    type Error = Error;
847
848    fn try_from((path, _): (PathBuf, CsrLayout)) -> Result<Self, Self::Error> {
849        let reader = BufReader::new(File::open(path)?);
850        UndirectedCsrGraph::deserialize(reader)
851    }
852}
853
854fn prefix_sum_atomic<NI: Idx>(degrees: Vec<Atomic<NI>>) -> Vec<Atomic<NI>> {
855    let mut last = degrees.last().unwrap().load(Acquire);
856    let mut sums = degrees
857        .into_iter()
858        .scan(NI::zero(), |total, degree| {
859            let value = *total;
860            *total += degree.into_inner();
861            Some(Atomic::new(value))
862        })
863        .collect::<Vec<_>>();
864
865    last += sums.last().unwrap().load(Acquire);
866    sums.push(Atomic::new(last));
867
868    sums
869}
870
871pub(crate) fn prefix_sum<NI: Idx>(degrees: Vec<NI>) -> Vec<NI> {
872    let mut last = *degrees.last().unwrap();
873    let mut sums = degrees
874        .into_iter()
875        .scan(NI::zero(), |total, degree| {
876            let value = *total;
877            *total += degree;
878            Some(value)
879        })
880        .collect::<Vec<_>>();
881    last += *sums.last().unwrap();
882    sums.push(last);
883    sums
884}
885
886pub(crate) fn sort_targets<NI, T, EV>(offsets: &[NI], targets: &mut [Target<T, EV>])
887where
888    NI: Idx,
889    T: Copy + Send + Ord,
890    EV: Send,
891{
892    to_mut_slices(offsets, targets)
893        .par_iter_mut()
894        .for_each(|list| list.sort_unstable());
895}
896
897fn sort_and_deduplicate_targets<NI, EV>(
898    offsets: &[NI],
899    targets: &mut [Target<NI, EV>],
900) -> (Vec<NI>, Vec<Target<NI, EV>>)
901where
902    NI: Idx,
903    EV: Copy + Send,
904{
905    let node_count = offsets.len() - 1;
906
907    let mut new_degrees = Vec::with_capacity(node_count);
908    let mut target_slices = to_mut_slices(offsets, targets);
909
910    target_slices
911        .par_iter_mut()
912        .enumerate()
913        .map(|(node, slice)| {
914            slice.sort_unstable();
915            // deduplicate
916            let (dedup, _) = slice.partition_dedup_compat();
917            let mut new_degree = dedup.len();
918            // remove self loops .. there is at most once occurence of node inside dedup
919            if let Ok(idx) = dedup.binary_search_by_key(&NI::new(node), |t| t.target) {
920                dedup[idx..].rotate_left(1);
921                new_degree -= 1;
922            }
923            NI::new(new_degree)
924        })
925        .collect_into_vec(&mut new_degrees);
926
927    let new_offsets = prefix_sum(new_degrees);
928    debug_assert_eq!(new_offsets.len(), node_count + 1);
929
930    let edge_count = new_offsets[node_count].index();
931    let mut new_targets: Vec<Target<NI, EV>> = Vec::with_capacity(edge_count);
932    let new_target_slices = to_mut_slices(&new_offsets, new_targets.spare_capacity_mut());
933
934    target_slices
935        .into_par_iter()
936        .zip(new_target_slices.into_par_iter())
937        .for_each(|(old_slice, new_slice)| {
938            MaybeUninit::write_slice_compat(new_slice, &old_slice[..new_slice.len()]);
939        });
940
941    // SAFETY: We copied all (potentially shortened) target ids from the old
942    // target list to the new one.
943    unsafe {
944        new_targets.set_len(edge_count);
945    }
946
947    (new_offsets, new_targets)
948}
949
950fn to_mut_slices<'targets, NI: Idx, T>(
951    offsets: &[NI],
952    targets: &'targets mut [T],
953) -> Vec<&'targets mut [T]> {
954    let node_count = offsets.len() - 1;
955    let mut target_slices = Vec::with_capacity(node_count);
956    let mut tail = targets;
957    let mut prev_offset = offsets[0];
958
959    for &offset in &offsets[1..] {
960        let (list, remainder) = tail.split_at_mut((offset - prev_offset).index());
961        target_slices.push(list);
962        tail = remainder;
963        prev_offset = offset;
964    }
965
966    target_slices
967}
968
// Unit tests for the CSR helper functions (slicing, sorting, deduplication,
// prefix sums) and for serialization / layout conversion of CSR graphs.
#[cfg(test)]
mod tests {
    use std::{
        io::{Seek, SeekFrom},
        sync::atomic::Ordering::SeqCst,
    };

    use rayon::ThreadPoolBuilder;

    use crate::builder::GraphBuilder;

    use super::*;

    #[test]
    fn to_mut_slices_test() {
        let offsets = &[0, 2, 5, 5, 8];
        let targets = &mut [0, 1, 2, 3, 4, 5, 6, 7];
        let slices = to_mut_slices::<usize, usize>(offsets, targets);

        assert_eq!(
            slices,
            vec![vec![0, 1], vec![2, 3, 4], vec![], vec![5, 6, 7]]
        );
    }

    fn t<T>(t: T) -> Target<T, ()> {
        Target::new(t, ())
    }

    #[test]
    fn sort_targets_test() {
        let offsets = &[0, 2, 5, 5, 8];
        let mut targets = vec![t(1), t(0), t(4), t(2), t(3), t(5), t(6), t(7)];
        sort_targets::<usize, _, _>(offsets, &mut targets);

        assert_eq!(
            targets,
            vec![t(0), t(1), t(2), t(3), t(4), t(5), t(6), t(7)]
        );
    }

    #[test]
    fn sort_and_deduplicate_targets_test() {
        let offsets = &[0, 3, 7, 7, 10];
        // 0: [1, 1, 0]    => [1] (removed duplicate and self loop)
        // 1: [4, 2, 3, 2] => [2, 3, 4] (removed duplicate)
        let mut targets = vec![t(1), t(1), t(0), t(4), t(2), t(3), t(2), t(5), t(6), t(7)];
        let (offsets, targets) = sort_and_deduplicate_targets::<usize, _>(offsets, &mut targets);

        assert_eq!(offsets, vec![0, 1, 4, 4, 7]);
        assert_eq!(targets, vec![t(1), t(2), t(3), t(4), t(5), t(6), t(7)]);
    }

    #[test]
    fn prefix_sum_test() {
        let degrees = vec![42, 0, 1337, 4, 2, 0];
        let prefix_sum = prefix_sum::<usize>(degrees);

        assert_eq!(prefix_sum, vec![0, 42, 42, 1379, 1383, 1385, 1385]);
    }

    #[test]
    fn prefix_sum_atomic_test() {
        let degrees = vec![42, 0, 1337, 4, 2, 0]
            .into_iter()
            .map(Atomic::<usize>::new)
            .collect::<Vec<_>>();

        let prefix_sum = prefix_sum_atomic(degrees)
            .into_iter()
            .map(|n| n.load(SeqCst))
            .collect::<Vec<_>>();

        assert_eq!(prefix_sum, vec![0, 42, 42, 1379, 1383, 1385, 1385]);
    }

    #[test]
    fn serialize_directed_usize_graph_test() {
        let mut file = tempfile::tempfile().unwrap();

        let g0: DirectedCsrGraph<usize> = GraphBuilder::new()
            .edges(vec![(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (3, 1)])
            .build();

        assert!(g0.serialize(&file).is_ok());

        file.seek(SeekFrom::Start(0)).unwrap();
        let g1 = DirectedCsrGraph::<usize>::deserialize(file).unwrap();

        assert_eq!(g0.node_count(), g1.node_count());
        assert_eq!(g0.edge_count(), g1.edge_count());

        assert_eq!(
            g0.out_neighbors(0).as_slice(),
            g1.out_neighbors(0).as_slice()
        );
        assert_eq!(
            g0.out_neighbors(1).as_slice(),
            g1.out_neighbors(1).as_slice()
        );
        assert_eq!(
            g0.out_neighbors(2).as_slice(),
            g1.out_neighbors(2).as_slice()
        );
        assert_eq!(
            g0.out_neighbors(3).as_slice(),
            g1.out_neighbors(3).as_slice()
        );

        assert_eq!(g0.in_neighbors(0).as_slice(), g1.in_neighbors(0).as_slice());
        assert_eq!(g0.in_neighbors(1).as_slice(), g1.in_neighbors(1).as_slice());
        assert_eq!(g0.in_neighbors(2).as_slice(), g1.in_neighbors(2).as_slice());
        assert_eq!(g0.in_neighbors(3).as_slice(), g1.in_neighbors(3).as_slice());
    }

    #[test]
    fn serialize_undirected_usize_graph_test() {
        let mut file = tempfile::tempfile().unwrap();

        let g0: UndirectedCsrGraph<usize> = GraphBuilder::new()
            .edges(vec![(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (3, 1)])
            .build();

        assert!(g0.serialize(&file).is_ok());

        file.seek(SeekFrom::Start(0)).unwrap();

        let g1 = UndirectedCsrGraph::<usize>::deserialize(file).unwrap();

        assert_eq!(g0.node_count(), g1.node_count());
        assert_eq!(g0.edge_count(), g1.edge_count());

        assert_eq!(g0.neighbors(0).as_slice(), g1.neighbors(0).as_slice());
        assert_eq!(g0.neighbors(1).as_slice(), g1.neighbors(1).as_slice());
        assert_eq!(g0.neighbors(2).as_slice(), g1.neighbors(2).as_slice());
        assert_eq!(g0.neighbors(3).as_slice(), g1.neighbors(3).as_slice());
    }

    #[test]
    fn serialize_directed_u32_graph_test() {
        let mut file = tempfile::tempfile().unwrap();

        let g0: DirectedCsrGraph<u32> = GraphBuilder::new()
            .edges(vec![(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (3, 1)])
            .build();

        assert!(g0.serialize(&file).is_ok());

        file.seek(SeekFrom::Start(0)).unwrap();
        let g1 = DirectedCsrGraph::<u32>::deserialize(file).unwrap();

        assert_eq!(g0.node_count(), g1.node_count());
        assert_eq!(g0.edge_count(), g1.edge_count());

        assert_eq!(
            g0.out_neighbors(0).as_slice(),
            g1.out_neighbors(0).as_slice()
        );
        assert_eq!(
            g0.out_neighbors(1).as_slice(),
            g1.out_neighbors(1).as_slice()
        );
        assert_eq!(
            g0.out_neighbors(2).as_slice(),
            g1.out_neighbors(2).as_slice()
        );
        assert_eq!(
            g0.out_neighbors(3).as_slice(),
            g1.out_neighbors(3).as_slice()
        );

        assert_eq!(g0.in_neighbors(0).as_slice(), g1.in_neighbors(0).as_slice());
        assert_eq!(g0.in_neighbors(1).as_slice(), g1.in_neighbors(1).as_slice());
        assert_eq!(g0.in_neighbors(2).as_slice(), g1.in_neighbors(2).as_slice());
        assert_eq!(g0.in_neighbors(3).as_slice(), g1.in_neighbors(3).as_slice());
    }

    #[test]
    fn serialize_undirected_u32_graph_test() {
        let mut file = tempfile::tempfile().unwrap();

        let g0: UndirectedCsrGraph<u32> = GraphBuilder::new()
            .edges(vec![(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (3, 1)])
            .build();

        assert!(g0.serialize(&file).is_ok());

        file.seek(SeekFrom::Start(0)).unwrap();

        let g1 = UndirectedCsrGraph::<u32>::deserialize(file).unwrap();

        assert_eq!(g0.node_count(), g1.node_count());
        assert_eq!(g0.edge_count(), g1.edge_count());

        assert_eq!(g0.neighbors(0).as_slice(), g1.neighbors(0).as_slice());
        assert_eq!(g0.neighbors(1).as_slice(), g1.neighbors(1).as_slice());
        assert_eq!(g0.neighbors(2).as_slice(), g1.neighbors(2).as_slice());
        assert_eq!(g0.neighbors(3).as_slice(), g1.neighbors(3).as_slice());
    }

    #[test]
    fn serialize_invalid_id_size() {
        let mut file = tempfile::tempfile().unwrap();

        let g0: UndirectedCsrGraph<u32> = GraphBuilder::new()
            .edges(vec![(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (3, 1)])
            .build();

        assert!(g0.serialize(&file).is_ok());

        file.seek(SeekFrom::Start(0)).unwrap();

        let res: Result<UndirectedCsrGraph<usize>, Error> =
            UndirectedCsrGraph::<usize>::deserialize(file);

        assert!(res.is_err());

        // Match the variant and its payload explicitly. Writing
        // `matches!(res, some_local)` would treat `some_local` as a fresh
        // binding pattern that matches anything, so nothing would be checked.
        assert!(
            matches!(
                &res,
                Err(Error::InvalidIdType { expected, actual })
                    if expected.as_str() == "usize" && actual.as_str() == "u32"
            ),
            "expected Error::InvalidIdType {{ expected: \"usize\", actual: \"u32\" }}"
        );
    }

    #[test]
    fn test_to_undirected() {
        // we need a deterministic order of loading, so we're doing stuff in serial
        let pool = ThreadPoolBuilder::new().num_threads(1).build().unwrap();
        pool.install(|| {
            let g: DirectedCsrGraph<u32> = GraphBuilder::new()
                .edges(vec![(0, 1), (3, 0), (0, 3), (7, 0), (0, 42), (21, 0)])
                .build();

            let ug = g.to_undirected(None);
            assert_eq!(ug.degree(0), 6);
            assert_eq!(ug.neighbors(0).as_slice(), &[1, 3, 42, 3, 7, 21]);

            let ug = g.to_undirected(CsrLayout::Unsorted);
            assert_eq!(ug.degree(0), 6);
            assert_eq!(ug.neighbors(0).as_slice(), &[1, 3, 42, 3, 7, 21]);

            let ug = g.to_undirected(CsrLayout::Sorted);
            assert_eq!(ug.degree(0), 6);
            assert_eq!(ug.neighbors(0).as_slice(), &[1, 3, 3, 7, 21, 42]);

            let ug = g.to_undirected(CsrLayout::Deduplicated);
            assert_eq!(ug.degree(0), 5);
            assert_eq!(ug.neighbors(0).as_slice(), &[1, 3, 7, 21, 42]);
        });
    }

    #[test]
    fn directed_from_node_values_exceeding_edge_list_max_id() {
        let g0: DirectedCsrGraph<u32, u32> = GraphBuilder::new()
            .edges(vec![(0, 1), (1, 2)])
            .node_values(vec![0, 1, 2, 3])
            .build();

        assert_eq!(g0.node_count(), 4);
        for node in 0..4 {
            assert_eq!(g0.node_value(node), &node);
        }

        assert_eq!(g0.out_degree(0), 1);
        assert_eq!(g0.out_degree(1), 1);
        assert_eq!(g0.out_degree(2), 0);
        assert_eq!(g0.out_degree(3), 0);
    }

    #[test]
    fn undirected_from_node_values_exceeding_edge_list_max_id() {
        let g0: UndirectedCsrGraph<u32, u32> = GraphBuilder::new()
            .edges(vec![(0, 1), (1, 2)])
            .node_values(vec![0, 1, 2, 3])
            .build();

        assert_eq!(g0.node_count(), 4);
        for node in 0..4 {
            assert_eq!(g0.node_value(node), &node);
        }

        assert_eq!(g0.degree(0), 1);
        assert_eq!(g0.degree(1), 2);
        assert_eq!(g0.degree(2), 1);
        assert_eq!(g0.degree(3), 0);
    }
}