ATriaPlugin/
lib.rs

1// Copyright (C) 2020 Joseph R. Quinn
2// SPDX-License-Identifier: MIT
3
4#![allow(non_snake_case)]
5
6//! # ATria-rs
7//!
8//! Library for the Ablatio Triadum (ATria) centrality algorithm
9//! (Cickovski et al, 2015, 2017).
10//!
11//! ATria can run on signed and weighted networks and produces a list of
12//! central nodes as both screen output and as a NOde Attribute (NOA)
13//! file for Cytoscape. The NOA file can subsequently be imported into
14//! Cytoscape resulting in centrality values becoming node attributes and
15//! enabling further analysis and visualization based on these values.
16//!
17//! The input network should be specified in CSV format with nodes as rows
18//! and columns and entry (i, j) representing the weight of the edge from
19//! node i to node j.
20//!
21//! The output is the NOA file, with both centrality value and rank as
22//! attributes. Larger magnitude values indicate higher centrality for
23//! both centrality and rank. This is typically more convenient for
24//! visualization, etc.
25//!
26//! ## GPU Acceleration
27//!
28//! Enable the `gpu` feature for cross-platform GPU acceleration (Vulkan/Metal/DX12):
29//!
30//! ```toml
31//! [dependencies]
32//! atria-rs = { version = "1.2", features = ["gpu"] }
33//! ```
34//!
35//! ## CUDA Acceleration
36//!
37//! Enable the `cuda` feature for NVIDIA CUDA acceleration:
38//!
39//! ```toml
40//! [dependencies]
41//! atria-rs = { version = "1.2", features = ["cuda"] }
42//! ```
43//!
44//! Then configure the plugin to use the desired backend:
45//!
46//! ```ignore
47//! let mut plugin = ATriaPlugin::default();
48//! plugin.set_backend(ComputeBackend::Gpu);   // wgpu
49//! plugin.set_backend(ComputeBackend::Cuda);  // NVIDIA CUDA
50//! plugin.set_backend(ComputeBackend::Auto);  // Best available
51//! ```
52//!
53//! Original code for the C++ version of this library may be
54//! found [here](https://github.com/movingpictures83/ATria).
55
56use std::fs::File;
57use std::io::prelude::*;
58use std::io::BufWriter;
59
60use log::*;
61use pluma_plugin_trait::PluMAPlugin;
62
63// GPU module - wgpu (conditional compilation)
64#[cfg(feature = "gpu")]
65pub mod gpu;
66
67// CUDA module (conditional compilation)
68#[cfg(feature = "cuda")]
69pub mod cuda;
70
71// PluMA FFI exports
72pub mod pluma_ffi;
73
/// Crate-wide result alias whose error type is `Box<dyn std::error::Error>`.
///
/// The success type defaults to `()`, so a bare `Result` stands in for
/// `std::result::Result<(), Box<dyn std::error::Error>>`.
type Result<T = ()> = std::result::Result<T, Box<dyn std::error::Error>>;
76
/// Configuration for compute backend selection.
///
/// The value stored on the plugin is a *request*: `ATriaPlugin::effective_backend`
/// resolves it to `Cpu` whenever the matching accelerator context is not
/// available.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ComputeBackend {
    /// Use CPU for computation (default)
    #[default]
    Cpu,
    /// Use wgpu GPU for computation (requires `gpu` feature, cross-platform)
    Gpu,
    /// Use NVIDIA CUDA for computation (requires `cuda` feature)
    Cuda,
    /// Automatically select best available backend (prefers CUDA > GPU > CPU)
    Auto,
}
90
/// State of the ATria plugin: the loaded network, the computed centrality
/// values and the compute backend selection.
///
/// The data fields are public and are normally filled by `input`, consumed
/// and mutated by `run`, and written out by `output`.
#[derive(Debug)]
pub struct ATriaPlugin {
    /// Number of bacteria (GSIZE in C++). `input` sets it to `bacteria.len()`.
    pub gsize: usize,
    /// Vector of the bacteria types in the CSV file (header cells, first column skipped).
    pub bacteria: Vec<String>,
    /// The original matrix being worked on by the ATria algorithm (2N x 2N),
    /// stored row-major. `run` zeroes entries in place as edges are removed.
    pub orig_graph: Vec<f32>,
    /// Output centrality values (stores pay values, NOT ranks).
    /// `output` reorders this vector and replaces each entry by its absolute value.
    pub output: Vec<f32>,
    /// Compute backend selection (the requested backend; see `effective_backend`)
    backend: ComputeBackend,
    /// wgpu GPU context (lazily initialized when gpu feature is enabled)
    #[cfg(feature = "gpu")]
    gpu_context: Option<gpu::GpuContext>,
    /// CUDA context (lazily initialized when cuda feature is enabled)
    #[cfg(feature = "cuda")]
    cuda_context: Option<cuda::CudaContext>,
}
110
111// Manual Default impl required due to conditional #[cfg] fields
112#[allow(clippy::derivable_impls)]
113impl Default for ATriaPlugin {
114    fn default() -> Self {
115        ATriaPlugin {
116            gsize: 0,
117            bacteria: Vec::new(),
118            orig_graph: Vec::new(),
119            output: Vec::new(),
120            backend: ComputeBackend::default(),
121            #[cfg(feature = "gpu")]
122            gpu_context: None,
123            #[cfg(feature = "cuda")]
124            cuda_context: None,
125        }
126    }
127}
128
129impl ATriaPlugin {
130    /// Create a new ATriaPlugin with the specified compute backend
131    pub fn with_backend(backend: ComputeBackend) -> Self {
132        let mut plugin = Self::default();
133        plugin.set_backend(backend);
134        plugin
135    }
136
137    #[inline(always)]
138    fn size(&self) -> usize {
139        self.gsize
140    }
141
142    /// Set the compute backend to use
143    pub fn set_backend(&mut self, backend: ComputeBackend) {
144        self.backend = backend;
145
146        // Initialize CUDA context if needed
147        #[cfg(feature = "cuda")]
148        {
149            if matches!(backend, ComputeBackend::Cuda | ComputeBackend::Auto)
150                && self.cuda_context.is_none()
151            {
152                self.cuda_context = cuda::CudaContext::new();
153                if self.cuda_context.is_some() {
154                    info!("CUDA context initialized successfully");
155                } else {
156                    warn!("Failed to initialize CUDA context");
157                }
158            }
159        }
160
161        #[cfg(not(feature = "cuda"))]
162        {
163            if matches!(backend, ComputeBackend::Cuda) {
164                warn!("CUDA backend requested but 'cuda' feature is not enabled, using CPU");
165            }
166        }
167
168        // Initialize wgpu GPU context if needed
169        #[cfg(feature = "gpu")]
170        {
171            if matches!(backend, ComputeBackend::Gpu | ComputeBackend::Auto)
172                && self.gpu_context.is_none()
173            {
174                self.gpu_context = gpu::GpuContext::new();
175                if self.gpu_context.is_some() {
176                    info!("wgpu GPU context initialized successfully");
177                } else {
178                    warn!("Failed to initialize wgpu GPU context");
179                }
180            }
181        }
182
183        #[cfg(not(feature = "gpu"))]
184        {
185            if matches!(backend, ComputeBackend::Gpu) {
186                warn!("GPU backend requested but 'gpu' feature is not enabled, using CPU");
187            }
188        }
189    }
190
191    /// Get the current compute backend
192    pub fn backend(&self) -> ComputeBackend {
193        self.backend
194    }
195
196    /// Enable GPU acceleration (convenience method)
197    pub fn set_use_gpu(&mut self, use_gpu: bool) {
198        self.set_backend(if use_gpu { ComputeBackend::Gpu } else { ComputeBackend::Cpu });
199    }
200
201    /// Check if wgpu GPU is available for computation
202    #[cfg(feature = "gpu")]
203    pub fn is_gpu_available(&self) -> bool {
204        self.gpu_context.is_some()
205    }
206
207    /// Check if wgpu GPU is available for computation
208    #[cfg(not(feature = "gpu"))]
209    pub fn is_gpu_available(&self) -> bool {
210        false
211    }
212
213    /// Check if CUDA is available for computation
214    #[cfg(feature = "cuda")]
215    pub fn is_cuda_available(&self) -> bool {
216        self.cuda_context.is_some()
217    }
218
219    /// Check if CUDA is available for computation
220    #[cfg(not(feature = "cuda"))]
221    pub fn is_cuda_available(&self) -> bool {
222        false
223    }
224
225    /// Get the effective backend that will be used for computation
226    pub fn effective_backend(&self) -> ComputeBackend {
227        match self.backend {
228            ComputeBackend::Cpu => ComputeBackend::Cpu,
229            ComputeBackend::Gpu => {
230                if self.is_gpu_available() {
231                    ComputeBackend::Gpu
232                } else {
233                    ComputeBackend::Cpu
234                }
235            }
236            ComputeBackend::Cuda => {
237                if self.is_cuda_available() {
238                    ComputeBackend::Cuda
239                } else {
240                    ComputeBackend::Cpu
241                }
242            }
243            ComputeBackend::Auto => {
244                // Prefer CUDA > GPU > CPU
245                if self.is_cuda_available() {
246                    ComputeBackend::Cuda
247                } else if self.is_gpu_available() {
248                    ComputeBackend::Gpu
249                } else {
250                    ComputeBackend::Cpu
251                }
252            }
253        }
254    }
255}
256
/// Modified Floyd-Warshall pass for ATria, run in place on `g`.
///
/// `g` is read as a row-major `n x n` matrix. For every intermediate node `k`
/// and every pair `(i, j)` with `i != j` and `j != k`, the candidate
/// `g[i][k] * g[k][j]` replaces `g[i][j]` when it is larger (if `i + j` is
/// even) or smaller (if `i + j` is odd). Updates are visible to later
/// iterations of the same pass.
///
/// Only the first `n * n` elements of `g` are accessed; any extra tail is left
/// untouched.
///
/// # Panics
///
/// Panics if `n * n` overflows `usize` or if `g` has fewer than `n * n`
/// elements. The hot loop uses unchecked pointer access, so this up-front
/// check is what keeps the function sound for arbitrary callers.
#[inline]
pub fn cpu_floyd(g: &mut [f32], n: usize) {
    let required = n
        .checked_mul(n)
        .expect("cpu_floyd: n * n overflows usize");
    assert!(
        g.len() >= required,
        "cpu_floyd: slice has {} elements but an {} x {} matrix needs {}",
        g.len(),
        n,
        n,
        required
    );

    let ptr = g.as_mut_ptr();

    for k in 0..n {
        let k_row_offset = k * n;

        for i in 0..n {
            let i_row_offset = i * n;

            // SAFETY: i, j and k are all < n, so every offset computed below is
            // < n * n, and the assertion above guarantees `g` holds at least
            // n * n elements. `ptr` comes from the exclusive borrow `g`, so no
            // other reference can observe these reads and writes.
            unsafe {
                // g[i][k] is never written inside the j loop (j == k is skipped),
                // so it can be read once per row.
                let g_i_k = *ptr.add(i_row_offset + k);

                for j in 0..n {
                    // Skip diagonal and when j == k
                    if i == j || j == k {
                        continue;
                    }

                    let curloc = i_row_offset + j;
                    let product = g_i_k * *ptr.add(k_row_offset + j);
                    let current = *ptr.add(curloc);

                    // Parity of i + j selects the direction: even cells keep
                    // the maximum, odd cells keep the minimum (negative paths).
                    let improves = if (i + j) & 1 == 0 {
                        current < product
                    } else {
                        current > product
                    };

                    if improves {
                        *ptr.add(curloc) = product;
                    }
                }
            }
        }
    }
}
305
306impl PluMAPlugin for ATriaPlugin {
307    /// Create a 2Nx2N adjacency matrix from the input CSV file.
308    fn input(&mut self, file_path: String) -> Result {
309        let mut reader = csv::Reader::from_path(&file_path).expect("Unable to open CSV file");
310
311        // First pass: count rows to determine GSIZE
312        {
313            let headers = reader
314                .headers()
315                .expect("Unable to read CSV headers")
316                .clone();
317
318            // Pre-allocate with expected capacity
319            self.bacteria.reserve(headers.len() - 1);
320            for header in headers.iter().skip(1) {
321                self.bacteria.push(header.to_string());
322            }
323        }
324
325        self.gsize = self.bacteria.len();
326        let gsize = self.gsize;
327
328        // Allocate 2N x 2N matrix
329        let matrix_size = (gsize * 2) * (gsize * 2);
330        self.orig_graph = vec![0.0f32; matrix_size];
331        self.output = vec![0.0f32; gsize];
332
333        // Re-read to populate matrix
334        let mut reader = csv::Reader::from_path(&file_path).expect("Unable to open CSV file");
335        let stride = 2 * gsize;
336
337        for (row_count, result) in reader.records().enumerate() {
338            let row = result.expect("Unable to read CSV row");
339            let bac1 = row_count;
340            let bac1_2 = bac1 * 2;
341            let bac1_2_1 = bac1_2 + 1;
342
343            for i in 1..row.len() {
344                let bac2 = i - 1;
345                let bac2_2 = bac2 * 2;
346                let bac2_2_1 = bac2_2 + 1;
347
348                // Pre-compute indices
349                let idx_00 = bac1_2 * stride + bac2_2;
350                let idx_11 = bac1_2_1 * stride + bac2_2_1;
351                let idx_10 = bac1_2_1 * stride + bac2_2;
352                let idx_01 = bac1_2 * stride + bac2_2_1;
353
354                if bac1 != bac2 {
355                    let weight: f32 = row[i].parse().expect("Unable to parse weight");
356
357                    if weight > 0.0 {
358                        self.orig_graph[idx_00] = weight;
359                        self.orig_graph[idx_11] = weight;
360                        // idx_10 and idx_01 already 0 from initialization
361                    } else if weight < 0.0 {
362                        self.orig_graph[idx_10] = weight;
363                        self.orig_graph[idx_01] = weight;
364                        // idx_00 and idx_11 already 0 from initialization
365                    }
366                    // weight == 0: all already initialized to 0
367                } else {
368                    // Diagonal: start at 1 because they are starting verts
369                    self.orig_graph[idx_00] = 1.0;
370                    self.orig_graph[idx_11] = 1.0;
371                    // idx_10 and idx_01 already 0 from initialization
372                }
373            }
374        }
375
376        Ok(())
377    }
378
379    /// Run the ATria algorithm over the input data.
380    fn run(&mut self) -> Result {
381        let effective_backend = self.effective_backend();
382        info!("Running ATria with {:?} backend", effective_backend);
383
384        let gsize = self.gsize;
385        let n = gsize * 2;  // Matrix dimension
386        let stride = n;
387        let matrix_len = n * n;
388
389        // Working copy of graph for Floyd-Warshall - allocate once
390        let mut h_g = vec![0.0f32; matrix_len];
391        // Pay values for each bacterium
392        let mut h_pay = vec![0.0f32; gsize];
393        // Pre-allocate maxnodes vector
394        let mut maxnodes = Vec::with_capacity(gsize);
395
396        for _ in 0..gsize {
397            // Copy original graph for computation - use fast copy
398            h_g.copy_from_slice(&self.orig_graph);
399
400            // Run modified Floyd-Warshall using selected backend
401            match effective_backend {
402                ComputeBackend::Cuda => {
403                    #[cfg(feature = "cuda")]
404                    {
405                        if let Some(ref cuda_ctx) = self.cuda_context {
406                            cuda_ctx.floyd_warshall(&mut h_g, n);
407                        } else {
408                            cpu_floyd(&mut h_g, n);
409                        }
410                    }
411                    #[cfg(not(feature = "cuda"))]
412                    {
413                        cpu_floyd(&mut h_g, n);
414                    }
415                }
416                ComputeBackend::Gpu => {
417                    #[cfg(feature = "gpu")]
418                    {
419                        if let Some(ref gpu_ctx) = self.gpu_context {
420                            gpu_ctx.floyd_warshall(&mut h_g, n);
421                        } else {
422                            cpu_floyd(&mut h_g, n);
423                        }
424                    }
425                    #[cfg(not(feature = "gpu"))]
426                    {
427                        cpu_floyd(&mut h_g, n);
428                    }
429                }
430                ComputeBackend::Cpu | ComputeBackend::Auto => {
431                    cpu_floyd(&mut h_g, n);
432                }
433            }
434
435            // Calculate pay for each bacterium - using iterators for better SIMD
436            for (i, pay) in h_pay.iter_mut().enumerate() {
437                let row_start = (i * 2) * stride;
438                *pay = h_g[row_start..row_start + n].iter().sum::<f32>() - 1.0;
439            }
440
441            // Find node(s) with maximum pay - find FIRST maximum (matching original behavior)
442            let mut mnode = 0usize;
443            let mut maxpay = -1.0f32;
444            for (i, &pay) in h_pay.iter().enumerate() {
445                if pay.abs() > maxpay {
446                    mnode = i;
447                    maxpay = pay.abs();
448                }
449            }
450
451            if maxpay == 0.0 {
452                break;
453            }
454
455            // Find all nodes with same max pay
456            maxnodes.clear();
457            maxnodes.push(mnode);
458            for (i, &pay) in h_pay.iter().enumerate() {
459                if i != mnode && pay.abs() == maxpay {
460                    maxnodes.push(i);
461                }
462            }
463
464            // Process each max node
465            for &maxnode in &maxnodes {
466                info!("Node with highest pay: {}: {}", self.bacteria[maxnode], h_pay[maxnode]);
467
468                // Only record centrality for the first node (mnode), not ties
469                if maxnode == mnode {
470                    self.output[maxnode] = h_pay[maxnode];
471                }
472
473                let maxnode_2 = maxnode * 2;
474                let maxnode_2_1 = maxnode_2 + 1;
475                let maxnode_row_even = maxnode_2 * stride;
476                let maxnode_row_odd = maxnode_2_1 * stride;
477
478                // Non-GPU Triad Removal
479                for i in 0..n {
480                    if (i / 2) != maxnode {
481                        let edge_even = self.orig_graph[maxnode_row_even + i] != 0.0;
482                        let edge_odd = self.orig_graph[maxnode_row_odd + i] != 0.0;
483
484                        if edge_even || edge_odd {
485                            let i_row = i * stride;
486
487                            for j in (i + 1)..n {
488                                if (j / 2) != maxnode {
489                                    let connected_j = self.orig_graph[maxnode_row_even + j] != 0.0
490                                        || self.orig_graph[maxnode_row_odd + j] != 0.0;
491
492                                    if connected_j && self.orig_graph[i_row + j] != 0.0 {
493                                        self.orig_graph[i_row + j] = 2.0;
494                                        self.orig_graph[j * stride + i] = 2.0;
495                                    }
496                                }
497                            }
498
499                            if edge_even {
500                                self.orig_graph[maxnode_row_even + i] = 2.0;
501                                self.orig_graph[i_row + maxnode_2] = 2.0;
502                            }
503
504                            if edge_odd {
505                                self.orig_graph[maxnode_row_odd + i] = 2.0;
506                                self.orig_graph[i_row + maxnode_2_1] = 2.0;
507                            }
508                        }
509                    }
510                }
511
512                // Sweep through and remove marked edges - vectorized
513                self.orig_graph.iter_mut().for_each(|v| {
514                    if *v == 2.0 {
515                        *v = 0.0;
516                    }
517                });
518            }
519        }
520
521        Ok(())
522    }
523
524    /// Write the results of the ATria calulations to a NOA file.
525    fn output(&mut self, file_path: String) -> Result {
526        // Use buffered writer for better I/O performance
527        let file = File::create(file_path).expect("Unable to open output file location");
528        let mut output_file = BufWriter::new(file);
529
530        // Sort by absolute value of output (descending) using bubble sort
531        // (maintains compatibility with original algorithm output)
532        let size = self.size();
533        for i in (0..size).rev() {
534            for j in 0..i {
535                if self.output[j].abs() < self.output[j + 1].abs() {
536                    self.output.swap(j, j + 1);
537                    self.bacteria.swap(j, j + 1);
538                }
539            }
540        }
541
542        writeln!(output_file, "Name\tCentrality\tRank")
543            .expect("Unable to write headers to output file");
544
545        for i in 0..size {
546            self.output[i] = self.output[i].abs();
547
548            writeln!(
549                output_file,
550                "{}\t{}\t\t{}",
551                self.bacteria[i],
552                self.output[i],
553                size - i
554            )
555            .expect("Unable to write to output file");
556        }
557
558        Ok(())
559    }
560}
561
#[cfg(test)]
mod tests {
    use super::*;

    /// Read a whole file into a `String`, panicking with the given message
    /// when opening or reading fails.
    fn slurp(path: &str, open_msg: &str, read_msg: &str) -> String {
        let mut handle = File::open(path).expect(open_msg);
        let mut text = String::new();
        handle.read_to_string(&mut text).expect(read_msg);
        text
    }

    /// The sample CSV yields one name per node column.
    #[test]
    fn it_should_load_bacteria() {
        let mut plugin = ATriaPlugin::default();

        plugin
            .input("./tests/corrP.never.csv".to_string())
            .expect("Failed to read CSV file");

        assert_eq!(126, plugin.bacteria.len());
    }

    /// `run` completes on a hand-built two-node network.
    #[test]
    fn it_can_run() {
        let mut plugin = ATriaPlugin {
            gsize: 2,
            orig_graph: vec![
                1.0, 0.0, 0.5, 0.0,
                0.0, 1.0, 0.0, 0.5,
                0.5, 0.0, 1.0, 0.0,
                0.0, 0.5, 0.0, 1.0,
            ],
            bacteria: vec![String::from("Test Bac 1"), String::from("Test Bac 2")],
            output: vec![0.0; 2],
            ..ATriaPlugin::default()
        };

        assert!(plugin.run().is_ok());
    }

    /// End to end: the generated NOA file matches the stored expected output.
    #[test]
    fn it_works() {
        let mut plugin = ATriaPlugin::default();

        plugin
            .input("./tests/corrP.never.csv".to_string())
            .expect("Failed to read CSV file...");

        plugin.run().expect("Failed to run ATria...");

        plugin
            .output("./tests/corrP.never.noa".to_string())
            .expect("Failed to write NOA file...");

        let expected_content = slurp(
            "./tests/corrP.never.noa.expected",
            "Failed to open expected output file...",
            "Failed to read expected output file.",
        );

        let actual_content = slurp(
            "./tests/corrP.never.noa",
            "Failed to open generated output file...",
            "Failed to read actual output to file.",
        );

        assert_eq!(expected_content, actual_content);
    }
}