1#![allow(non_snake_case)]
5
6use std::fs::File;
57use std::io::prelude::*;
58use std::io::BufWriter;
59
60use log::*;
61use pluma_plugin_trait::PluMAPlugin;
62
63#[cfg(feature = "gpu")]
65pub mod gpu;
66
67#[cfg(feature = "cuda")]
69pub mod cuda;
70
71pub mod pluma_ffi;
73
/// Result type returned by the plugin entry points: `Ok(())` by default, with
/// any failure boxed as a `std::error::Error` trait object.
type Result<T = ()> = std::result::Result<T, Box<dyn std::error::Error + 'static>>;
76
/// Compute backend used for the Floyd–Warshall step of ATria.
///
/// `Gpu` and `Cuda` only take effect when the matching cargo feature is
/// compiled in and a context could be created; otherwise the plugin falls
/// back to the CPU (see `ATriaPlugin::effective_backend`).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ComputeBackend {
    /// Plain CPU implementation (`cpu_floyd`). This is the default.
    #[default]
    Cpu,
    /// wgpu-based GPU implementation (requires the `gpu` feature).
    Gpu,
    /// CUDA implementation (requires the `cuda` feature).
    Cuda,
    /// Use CUDA if available, else the wgpu GPU, else the CPU.
    Auto,
}
90
91#[derive(Debug)]
92pub struct ATriaPlugin {
93 pub gsize: usize,
95 pub bacteria: Vec<String>,
97 pub orig_graph: Vec<f32>,
99 pub output: Vec<f32>,
101 backend: ComputeBackend,
103 #[cfg(feature = "gpu")]
105 gpu_context: Option<gpu::GpuContext>,
106 #[cfg(feature = "cuda")]
108 cuda_context: Option<cuda::CudaContext>,
109}
110
111#[allow(clippy::derivable_impls)]
113impl Default for ATriaPlugin {
114 fn default() -> Self {
115 ATriaPlugin {
116 gsize: 0,
117 bacteria: Vec::new(),
118 orig_graph: Vec::new(),
119 output: Vec::new(),
120 backend: ComputeBackend::default(),
121 #[cfg(feature = "gpu")]
122 gpu_context: None,
123 #[cfg(feature = "cuda")]
124 cuda_context: None,
125 }
126 }
127}
128
129impl ATriaPlugin {
130 pub fn with_backend(backend: ComputeBackend) -> Self {
132 let mut plugin = Self::default();
133 plugin.set_backend(backend);
134 plugin
135 }
136
137 #[inline(always)]
138 fn size(&self) -> usize {
139 self.gsize
140 }
141
142 pub fn set_backend(&mut self, backend: ComputeBackend) {
144 self.backend = backend;
145
146 #[cfg(feature = "cuda")]
148 {
149 if matches!(backend, ComputeBackend::Cuda | ComputeBackend::Auto)
150 && self.cuda_context.is_none()
151 {
152 self.cuda_context = cuda::CudaContext::new();
153 if self.cuda_context.is_some() {
154 info!("CUDA context initialized successfully");
155 } else {
156 warn!("Failed to initialize CUDA context");
157 }
158 }
159 }
160
161 #[cfg(not(feature = "cuda"))]
162 {
163 if matches!(backend, ComputeBackend::Cuda) {
164 warn!("CUDA backend requested but 'cuda' feature is not enabled, using CPU");
165 }
166 }
167
168 #[cfg(feature = "gpu")]
170 {
171 if matches!(backend, ComputeBackend::Gpu | ComputeBackend::Auto)
172 && self.gpu_context.is_none()
173 {
174 self.gpu_context = gpu::GpuContext::new();
175 if self.gpu_context.is_some() {
176 info!("wgpu GPU context initialized successfully");
177 } else {
178 warn!("Failed to initialize wgpu GPU context");
179 }
180 }
181 }
182
183 #[cfg(not(feature = "gpu"))]
184 {
185 if matches!(backend, ComputeBackend::Gpu) {
186 warn!("GPU backend requested but 'gpu' feature is not enabled, using CPU");
187 }
188 }
189 }
190
191 pub fn backend(&self) -> ComputeBackend {
193 self.backend
194 }
195
196 pub fn set_use_gpu(&mut self, use_gpu: bool) {
198 self.set_backend(if use_gpu { ComputeBackend::Gpu } else { ComputeBackend::Cpu });
199 }
200
201 #[cfg(feature = "gpu")]
203 pub fn is_gpu_available(&self) -> bool {
204 self.gpu_context.is_some()
205 }
206
207 #[cfg(not(feature = "gpu"))]
209 pub fn is_gpu_available(&self) -> bool {
210 false
211 }
212
213 #[cfg(feature = "cuda")]
215 pub fn is_cuda_available(&self) -> bool {
216 self.cuda_context.is_some()
217 }
218
219 #[cfg(not(feature = "cuda"))]
221 pub fn is_cuda_available(&self) -> bool {
222 false
223 }
224
225 pub fn effective_backend(&self) -> ComputeBackend {
227 match self.backend {
228 ComputeBackend::Cpu => ComputeBackend::Cpu,
229 ComputeBackend::Gpu => {
230 if self.is_gpu_available() {
231 ComputeBackend::Gpu
232 } else {
233 ComputeBackend::Cpu
234 }
235 }
236 ComputeBackend::Cuda => {
237 if self.is_cuda_available() {
238 ComputeBackend::Cuda
239 } else {
240 ComputeBackend::Cpu
241 }
242 }
243 ComputeBackend::Auto => {
244 if self.is_cuda_available() {
246 ComputeBackend::Cuda
247 } else if self.is_gpu_available() {
248 ComputeBackend::Gpu
249 } else {
250 ComputeBackend::Cpu
251 }
252 }
253 }
254 }
255}
256
/// CPU implementation of ATria's signed Floyd–Warshall pass, in place on the
/// row-major `n x n` matrix stored in the first `n * n` elements of `g`.
///
/// For every pivot `k` and row `i`, each cell `g[i][j]` with `j != i` and
/// `j != k` is compared with the path product `g[i][k] * g[k][j]`:
/// when `i + j` is even the larger value is kept, when it is odd the smaller.
/// A cell is only overwritten when the product strictly wins, so a NaN on
/// either side leaves it unchanged.
///
/// Rows are updated in place in increasing `i`. Row `k`'s own update during
/// pivot `k` is therefore seen by rows `i > k` but not by rows `i < k`. This
/// is the exact update order of the earlier unchecked-pointer version, so
/// results are bit-identical to it. Elements of `g` past `n * n` are not
/// touched.
///
/// # Panics
/// Panics if `n * n` overflows `usize` or if `g.len() < n * n`. The previous
/// implementation read and wrote out of bounds in that case (undefined
/// behaviour behind a safe `pub fn`).
#[inline]
pub fn cpu_floyd(g: &mut [f32], n: usize) {
    let len = n
        .checked_mul(n)
        .expect("cpu_floyd: n * n overflows usize");
    assert!(
        g.len() >= len,
        "cpu_floyd: matrix has {} elements but n * n = {}",
        g.len(),
        len
    );
    let g = &mut g[..len];

    for k in 0..n {
        for i in 0..n {
            if i == k {
                // Pivot row: g[i][k] and g[k][j] both live in this row, and
                // each cell is read immediately before it may be written.
                let row = &mut g[k * n..(k + 1) * n];
                let g_kk = row[k];
                for (j, cell) in row.iter_mut().enumerate() {
                    if j != k {
                        floyd_relax(cell, g_kk * *cell, (i + j) % 2 == 0);
                    }
                }
            } else {
                let (row_i, row_k) = floyd_rows_mut(g, n, i, k);
                // g[i][k] is never written while processing row i (j == k is
                // skipped), so it can be read once up front.
                let g_ik = row_i[k];
                for (j, (cell, &g_kj)) in row_i.iter_mut().zip(row_k.iter()).enumerate() {
                    if j == i || j == k {
                        continue;
                    }
                    floyd_relax(cell, g_ik * g_kj, (i + j) % 2 == 0);
                }
            }
        }
    }
}

/// Replaces `*cell` with `candidate` when it is strictly larger (`maximize`)
/// or strictly smaller (`!maximize`). Comparisons with NaN are false, so NaN
/// never triggers a write.
#[inline]
fn floyd_relax(cell: &mut f32, candidate: f32, maximize: bool) {
    let improves = if maximize {
        *cell < candidate
    } else {
        *cell > candidate
    };
    if improves {
        *cell = candidate;
    }
}

/// Borrows row `i` mutably and row `k` immutably from a row-major matrix with
/// `n` columns. Requires `i != k` and both rows to be in bounds.
#[inline]
fn floyd_rows_mut(g: &mut [f32], n: usize, i: usize, k: usize) -> (&mut [f32], &[f32]) {
    debug_assert_ne!(i, k);
    if i < k {
        let (head, tail) = g.split_at_mut(k * n);
        (&mut head[i * n..(i + 1) * n], &tail[..n])
    } else {
        let (head, tail) = g.split_at_mut(i * n);
        (&mut tail[..n], &head[k * n..(k + 1) * n])
    }
}
305
306impl PluMAPlugin for ATriaPlugin {
307 fn input(&mut self, file_path: String) -> Result {
309 let mut reader = csv::Reader::from_path(&file_path).expect("Unable to open CSV file");
310
311 {
313 let headers = reader
314 .headers()
315 .expect("Unable to read CSV headers")
316 .clone();
317
318 self.bacteria.reserve(headers.len() - 1);
320 for header in headers.iter().skip(1) {
321 self.bacteria.push(header.to_string());
322 }
323 }
324
325 self.gsize = self.bacteria.len();
326 let gsize = self.gsize;
327
328 let matrix_size = (gsize * 2) * (gsize * 2);
330 self.orig_graph = vec![0.0f32; matrix_size];
331 self.output = vec![0.0f32; gsize];
332
333 let mut reader = csv::Reader::from_path(&file_path).expect("Unable to open CSV file");
335 let stride = 2 * gsize;
336
337 for (row_count, result) in reader.records().enumerate() {
338 let row = result.expect("Unable to read CSV row");
339 let bac1 = row_count;
340 let bac1_2 = bac1 * 2;
341 let bac1_2_1 = bac1_2 + 1;
342
343 for i in 1..row.len() {
344 let bac2 = i - 1;
345 let bac2_2 = bac2 * 2;
346 let bac2_2_1 = bac2_2 + 1;
347
348 let idx_00 = bac1_2 * stride + bac2_2;
350 let idx_11 = bac1_2_1 * stride + bac2_2_1;
351 let idx_10 = bac1_2_1 * stride + bac2_2;
352 let idx_01 = bac1_2 * stride + bac2_2_1;
353
354 if bac1 != bac2 {
355 let weight: f32 = row[i].parse().expect("Unable to parse weight");
356
357 if weight > 0.0 {
358 self.orig_graph[idx_00] = weight;
359 self.orig_graph[idx_11] = weight;
360 } else if weight < 0.0 {
362 self.orig_graph[idx_10] = weight;
363 self.orig_graph[idx_01] = weight;
364 }
366 } else {
368 self.orig_graph[idx_00] = 1.0;
370 self.orig_graph[idx_11] = 1.0;
371 }
373 }
374 }
375
376 Ok(())
377 }
378
379 fn run(&mut self) -> Result {
381 let effective_backend = self.effective_backend();
382 info!("Running ATria with {:?} backend", effective_backend);
383
384 let gsize = self.gsize;
385 let n = gsize * 2; let stride = n;
387 let matrix_len = n * n;
388
389 let mut h_g = vec![0.0f32; matrix_len];
391 let mut h_pay = vec![0.0f32; gsize];
393 let mut maxnodes = Vec::with_capacity(gsize);
395
396 for _ in 0..gsize {
397 h_g.copy_from_slice(&self.orig_graph);
399
400 match effective_backend {
402 ComputeBackend::Cuda => {
403 #[cfg(feature = "cuda")]
404 {
405 if let Some(ref cuda_ctx) = self.cuda_context {
406 cuda_ctx.floyd_warshall(&mut h_g, n);
407 } else {
408 cpu_floyd(&mut h_g, n);
409 }
410 }
411 #[cfg(not(feature = "cuda"))]
412 {
413 cpu_floyd(&mut h_g, n);
414 }
415 }
416 ComputeBackend::Gpu => {
417 #[cfg(feature = "gpu")]
418 {
419 if let Some(ref gpu_ctx) = self.gpu_context {
420 gpu_ctx.floyd_warshall(&mut h_g, n);
421 } else {
422 cpu_floyd(&mut h_g, n);
423 }
424 }
425 #[cfg(not(feature = "gpu"))]
426 {
427 cpu_floyd(&mut h_g, n);
428 }
429 }
430 ComputeBackend::Cpu | ComputeBackend::Auto => {
431 cpu_floyd(&mut h_g, n);
432 }
433 }
434
435 for (i, pay) in h_pay.iter_mut().enumerate() {
437 let row_start = (i * 2) * stride;
438 *pay = h_g[row_start..row_start + n].iter().sum::<f32>() - 1.0;
439 }
440
441 let mut mnode = 0usize;
443 let mut maxpay = -1.0f32;
444 for (i, &pay) in h_pay.iter().enumerate() {
445 if pay.abs() > maxpay {
446 mnode = i;
447 maxpay = pay.abs();
448 }
449 }
450
451 if maxpay == 0.0 {
452 break;
453 }
454
455 maxnodes.clear();
457 maxnodes.push(mnode);
458 for (i, &pay) in h_pay.iter().enumerate() {
459 if i != mnode && pay.abs() == maxpay {
460 maxnodes.push(i);
461 }
462 }
463
464 for &maxnode in &maxnodes {
466 info!("Node with highest pay: {}: {}", self.bacteria[maxnode], h_pay[maxnode]);
467
468 if maxnode == mnode {
470 self.output[maxnode] = h_pay[maxnode];
471 }
472
473 let maxnode_2 = maxnode * 2;
474 let maxnode_2_1 = maxnode_2 + 1;
475 let maxnode_row_even = maxnode_2 * stride;
476 let maxnode_row_odd = maxnode_2_1 * stride;
477
478 for i in 0..n {
480 if (i / 2) != maxnode {
481 let edge_even = self.orig_graph[maxnode_row_even + i] != 0.0;
482 let edge_odd = self.orig_graph[maxnode_row_odd + i] != 0.0;
483
484 if edge_even || edge_odd {
485 let i_row = i * stride;
486
487 for j in (i + 1)..n {
488 if (j / 2) != maxnode {
489 let connected_j = self.orig_graph[maxnode_row_even + j] != 0.0
490 || self.orig_graph[maxnode_row_odd + j] != 0.0;
491
492 if connected_j && self.orig_graph[i_row + j] != 0.0 {
493 self.orig_graph[i_row + j] = 2.0;
494 self.orig_graph[j * stride + i] = 2.0;
495 }
496 }
497 }
498
499 if edge_even {
500 self.orig_graph[maxnode_row_even + i] = 2.0;
501 self.orig_graph[i_row + maxnode_2] = 2.0;
502 }
503
504 if edge_odd {
505 self.orig_graph[maxnode_row_odd + i] = 2.0;
506 self.orig_graph[i_row + maxnode_2_1] = 2.0;
507 }
508 }
509 }
510 }
511
512 self.orig_graph.iter_mut().for_each(|v| {
514 if *v == 2.0 {
515 *v = 0.0;
516 }
517 });
518 }
519 }
520
521 Ok(())
522 }
523
524 fn output(&mut self, file_path: String) -> Result {
526 let file = File::create(file_path).expect("Unable to open output file location");
528 let mut output_file = BufWriter::new(file);
529
530 let size = self.size();
533 for i in (0..size).rev() {
534 for j in 0..i {
535 if self.output[j].abs() < self.output[j + 1].abs() {
536 self.output.swap(j, j + 1);
537 self.bacteria.swap(j, j + 1);
538 }
539 }
540 }
541
542 writeln!(output_file, "Name\tCentrality\tRank")
543 .expect("Unable to write headers to output file");
544
545 for i in 0..size {
546 self.output[i] = self.output[i].abs();
547
548 writeln!(
549 output_file,
550 "{}\t{}\t\t{}",
551 self.bacteria[i],
552 self.output[i],
553 size - i
554 )
555 .expect("Unable to write to output file");
556 }
557
558 Ok(())
559 }
560}
561
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the bundled fixture and checks the sizes derived from its header.
    #[test]
    fn it_should_load_bacteria() {
        let mut plugin = ATriaPlugin::default();

        plugin
            .input("./tests/corrP.never.csv".to_string())
            .expect("Failed to read CSV file");

        assert_eq!(126, plugin.bacteria.len());
        assert_eq!(126, plugin.gsize);
        // Each bacterium contributes two nodes to the signed graph.
        assert_eq!((2 * 126) * (2 * 126), plugin.orig_graph.len());
        assert_eq!(126, plugin.output.len());
    }

    /// Two bacteria joined by a single positive edge of weight 0.5.
    #[test]
    fn it_can_run() {
        let mut plugin = ATriaPlugin::default();

        plugin.gsize = 2;

        plugin.orig_graph = vec![
            1.0, 0.0, 0.5, 0.0,
            0.0, 1.0, 0.0, 0.5,
            0.5, 0.0, 1.0, 0.0,
            0.0, 0.5, 0.0, 1.0,
        ];

        plugin.bacteria = vec![
            String::from("Test Bac 1"),
            String::from("Test Bac 2"),
        ];

        plugin.output.resize(2, 0.0);

        plugin.run().expect("Failed to run ATria on the two-bacteria graph");

        // Both bacteria tie with pay 0.5; only the first tied node records it.
        assert_eq!(vec![0.5f32, 0.0], plugin.output);

        // Clearing the first node's edges removes the only edge, leaving the diagonal.
        let diagonal: Vec<f32> = (0..16)
            .map(|idx| if idx % 5 == 0 { 1.0 } else { 0.0 })
            .collect();
        assert_eq!(diagonal, plugin.orig_graph);
    }

    /// Hand-computed 3x3 case: both off-diagonal positives are driven to 0.
    #[test]
    fn cpu_floyd_hand_computed() {
        let mut g = vec![1.0f32, 2.0, 0.0, 0.0, 1.0, 3.0, 0.0, 0.0, 1.0];
        cpu_floyd(&mut g, 3);
        assert_eq!(vec![1.0f32, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], g);
    }

    /// End-to-end: input -> run -> output must reproduce the reference NOA file.
    #[test]
    fn it_works() {
        let mut plugin = ATriaPlugin::default();

        plugin
            .input("./tests/corrP.never.csv".to_string())
            .expect("Failed to read CSV file...");

        plugin.run().expect("Failed to run ATria...");

        plugin
            .output("./tests/corrP.never.noa".to_string())
            .expect("Failed to write NOA file...");

        let expected_content = std::fs::read_to_string("./tests/corrP.never.noa.expected")
            .expect("Failed to read expected output file.");
        let actual_content = std::fs::read_to_string("./tests/corrP.never.noa")
            .expect("Failed to read generated output file.");

        assert_eq!(expected_content, actual_content);
    }
}