1use super::node::NodeId;
14use crate::{GraphError, Result};
15
16#[derive(Debug, Clone)]
23pub struct CsrMatrix {
24 pub num_rows: usize,
26 pub num_cols: usize,
28 pub row_ptr: Vec<u64>,
30 pub col_idx: Vec<u32>,
32 pub values: Option<Vec<f64>>,
34}
35
36impl CsrMatrix {
37 pub fn empty(num_nodes: usize) -> Self {
39 Self {
40 num_rows: num_nodes,
41 num_cols: num_nodes,
42 row_ptr: vec![0; num_nodes + 1],
43 col_idx: Vec::new(),
44 values: None,
45 }
46 }
47
48 pub fn from_edges(num_nodes: usize, edges: &[(u32, u32)]) -> Self {
65 CsrMatrixBuilder::new(num_nodes).with_edges(edges).build()
66 }
67
68 pub fn from_weighted_edges(num_nodes: usize, edges: &[(u32, u32, f64)]) -> Self {
70 CsrMatrixBuilder::new(num_nodes)
71 .with_weighted_edges(edges)
72 .build()
73 }
74
75 pub fn num_nonzeros(&self) -> usize {
77 self.col_idx.len()
78 }
79
80 pub fn is_empty(&self) -> bool {
82 self.col_idx.is_empty()
83 }
84
85 pub fn degree(&self, node: NodeId) -> usize {
87 let i = node.0 as usize;
88 if i >= self.num_rows {
89 return 0;
90 }
91 (self.row_ptr[i + 1] - self.row_ptr[i]) as usize
92 }
93
94 pub fn neighbors(&self, node: NodeId) -> &[u32] {
96 let i = node.0 as usize;
97 if i >= self.num_rows {
98 return &[];
99 }
100 let start = self.row_ptr[i] as usize;
101 let end = self.row_ptr[i + 1] as usize;
102 &self.col_idx[start..end]
103 }
104
105 pub fn weighted_neighbors(&self, node: NodeId) -> Vec<(NodeId, f64)> {
107 let i = node.0 as usize;
108 if i >= self.num_rows {
109 return Vec::new();
110 }
111 let start = self.row_ptr[i] as usize;
112 let end = self.row_ptr[i + 1] as usize;
113
114 let neighbors = &self.col_idx[start..end];
115 match &self.values {
116 Some(vals) => neighbors
117 .iter()
118 .zip(&vals[start..end])
119 .map(|(&col, &w)| (NodeId(col), w))
120 .collect(),
121 None => neighbors.iter().map(|&col| (NodeId(col), 1.0)).collect(),
122 }
123 }
124
125 pub fn has_edge(&self, src: NodeId, dst: NodeId) -> bool {
127 self.neighbors(src).contains(&dst.0)
128 }
129
130 pub fn validate(&self) -> Result<()> {
132 if self.row_ptr.len() != self.num_rows + 1 {
134 return Err(GraphError::InvalidCsr(format!(
135 "row_ptr length {} != num_rows + 1 = {}",
136 self.row_ptr.len(),
137 self.num_rows + 1
138 )));
139 }
140
141 for i in 0..self.num_rows {
143 if self.row_ptr[i] > self.row_ptr[i + 1] {
144 return Err(GraphError::InvalidCsr(format!(
145 "row_ptr not monotonic at index {}",
146 i
147 )));
148 }
149 }
150
151 let nnz = *self.row_ptr.last().unwrap_or(&0) as usize;
153 if nnz != self.col_idx.len() {
154 return Err(GraphError::InvalidCsr(format!(
155 "row_ptr[-1] = {} != col_idx.len() = {}",
156 nnz,
157 self.col_idx.len()
158 )));
159 }
160
161 if let Some(ref vals) = self.values {
163 if vals.len() != self.col_idx.len() {
164 return Err(GraphError::InvalidCsr(format!(
165 "values.len() = {} != col_idx.len() = {}",
166 vals.len(),
167 self.col_idx.len()
168 )));
169 }
170 }
171
172 for &col in &self.col_idx {
174 if col as usize >= self.num_cols {
175 return Err(GraphError::InvalidCsr(format!(
176 "col_idx {} >= num_cols {}",
177 col, self.num_cols
178 )));
179 }
180 }
181
182 Ok(())
183 }
184
185 pub fn transpose(&self) -> Self {
187 let mut builder = CsrMatrixBuilder::new(self.num_cols);
188
189 let mut counts = vec![0u64; self.num_cols];
191 for &col in &self.col_idx {
192 counts[col as usize] += 1;
193 }
194
195 for row in 0..self.num_rows {
197 let start = self.row_ptr[row] as usize;
198 let end = self.row_ptr[row + 1] as usize;
199 for (i, &col) in self.col_idx[start..end].iter().enumerate() {
200 let weight = self.values.as_ref().map(|v| v[start + i]);
201 builder.edges.push((col, row as u32, weight));
202 }
203 }
204
205 builder.build()
206 }
207}
208
209#[derive(Debug, Default)]
211pub struct CsrMatrixBuilder {
212 num_nodes: usize,
213 edges: Vec<(u32, u32, Option<f64>)>,
214}
215
216impl CsrMatrixBuilder {
217 pub fn new(num_nodes: usize) -> Self {
219 Self {
220 num_nodes,
221 edges: Vec::new(),
222 }
223 }
224
225 pub fn with_edges(mut self, edges: &[(u32, u32)]) -> Self {
227 for &(src, dst) in edges {
228 self.edges.push((src, dst, None));
229 }
230 self
231 }
232
233 pub fn with_weighted_edges(mut self, edges: &[(u32, u32, f64)]) -> Self {
235 for &(src, dst, w) in edges {
236 self.edges.push((src, dst, Some(w)));
237 }
238 self
239 }
240
241 pub fn add_edge(&mut self, src: u32, dst: u32) {
243 self.edges.push((src, dst, None));
244 }
245
246 pub fn add_weighted_edge(&mut self, src: u32, dst: u32, weight: f64) {
248 self.edges.push((src, dst, Some(weight)));
249 }
250
251 pub fn build(mut self) -> CsrMatrix {
253 self.edges.sort_by_key(|e| e.0);
255
256 let has_weights = self.edges.iter().any(|e| e.2.is_some());
257
258 let mut row_ptr = vec![0u64; self.num_nodes + 1];
260 for &(src, _, _) in &self.edges {
261 if (src as usize) < self.num_nodes {
262 row_ptr[src as usize + 1] += 1;
263 }
264 }
265
266 for i in 1..=self.num_nodes {
268 row_ptr[i] += row_ptr[i - 1];
269 }
270
271 let col_idx: Vec<u32> = self.edges.iter().map(|e| e.1).collect();
273 let values = if has_weights {
274 Some(self.edges.iter().map(|e| e.2.unwrap_or(1.0)).collect())
275 } else {
276 None
277 };
278
279 CsrMatrix {
280 num_rows: self.num_nodes,
281 num_cols: self.num_nodes,
282 row_ptr,
283 col_idx,
284 values,
285 }
286 }
287}
288
#[cfg(test)]
mod tests {
    //! Unit tests for `CsrMatrix` construction, queries, transpose and validation.
    use super::*;

    #[test]
    fn test_empty_matrix() {
        let matrix = CsrMatrix::empty(5);
        // An empty matrix still has the requested dimension but no entries.
        assert!(matrix.is_empty());
        assert_eq!(matrix.num_nonzeros(), 0);
        assert_eq!(matrix.num_rows, 5);
    }

    #[test]
    fn test_from_edges() {
        let matrix = CsrMatrix::from_edges(4, &[(0, 1), (1, 2), (1, 3)]);

        assert_eq!(matrix.num_rows, 4);
        assert_eq!(matrix.num_nonzeros(), 3);
        assert!(matrix.validate().is_ok());
    }

    #[test]
    fn test_neighbors() {
        let matrix = CsrMatrix::from_edges(3, &[(0, 1), (0, 2), (1, 2)]);

        let row0 = matrix.neighbors(NodeId(0));
        assert_eq!(row0.len(), 2);
        for expected in &[1u32, 2] {
            assert!(row0.contains(expected));
        }

        let row1 = matrix.neighbors(NodeId(1));
        assert_eq!(row1.len(), 1);
        assert!(row1.contains(&2));

        // Node 2 has no outgoing edges.
        assert!(matrix.neighbors(NodeId(2)).is_empty());
    }

    #[test]
    fn test_degree() {
        let matrix = CsrMatrix::from_edges(4, &[(0, 1), (0, 2), (0, 3), (1, 2)]);

        let expected = [(0, 3), (1, 1), (2, 0), (3, 0)];
        for &(node, degree) in &expected {
            assert_eq!(matrix.degree(NodeId(node)), degree);
        }
    }

    #[test]
    fn test_has_edge() {
        let matrix = CsrMatrix::from_edges(3, &[(0, 1), (1, 2)]);

        // (src, dst, whether the directed edge should be present)
        let cases = [(0, 1, true), (1, 2, true), (0, 2, false), (2, 0, false)];
        for &(src, dst, present) in &cases {
            assert_eq!(matrix.has_edge(NodeId(src), NodeId(dst)), present);
        }
    }

    #[test]
    fn test_weighted_edges() {
        let matrix =
            CsrMatrix::from_weighted_edges(3, &[(0, 1, 1.5), (0, 2, 2.5), (1, 2, 3.0)]);
        assert!(matrix.values.is_some());

        let row0 = matrix.weighted_neighbors(NodeId(0));
        assert_eq!(row0.len(), 2);
        assert!(row0.contains(&(NodeId(1), 1.5)));
        assert!(row0.contains(&(NodeId(2), 2.5)));
    }

    #[test]
    fn test_transpose() {
        // Path 0 -> 1 -> 2 becomes 2 -> 1 -> 0 after transposing.
        let transposed = CsrMatrix::from_edges(3, &[(0, 1), (1, 2)]).transpose();

        assert!(transposed.has_edge(NodeId(1), NodeId(0)));
        assert!(transposed.has_edge(NodeId(2), NodeId(1)));
        assert!(!transposed.has_edge(NodeId(0), NodeId(1)));
    }

    #[test]
    fn test_builder() {
        let mut builder = CsrMatrixBuilder::new(4);
        builder.add_edge(0, 1);
        builder.add_edge(0, 2);
        builder.add_weighted_edge(1, 3, 2.5);

        // A single weighted edge is enough to make the whole matrix weighted.
        let matrix = builder.build();
        assert_eq!(matrix.num_nonzeros(), 3);
        assert!(matrix.values.is_some());
    }

    #[test]
    fn test_validation() {
        assert!(CsrMatrix::from_edges(3, &[(0, 1), (1, 2)])
            .validate()
            .is_ok());

        // Column index 10 lies outside the 3-column matrix.
        let out_of_range = CsrMatrix {
            num_rows: 3,
            num_cols: 3,
            row_ptr: vec![0, 1, 2, 2],
            col_idx: vec![1, 10],
            values: None,
        };
        assert!(out_of_range.validate().is_err());
    }
}