pub mod inmemory;
pub mod wrappers;
pub use inmemory::{InMemory, InMemoryBuilder, InMemoryGraph};
pub use wrappers::{GraphblasMatrix, GraphblasVector, LagraphGraph, load_mm_file};
pub(crate) use wrappers::{ThreadScope, compute_outer_inner, ensure_grb_init};
use std::marker::PhantomData;
use std::sync::Arc;
use crate::lagraph_sys::GrB_Info;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum GraphError {
#[error("GraphBLAS error: info code {0}")]
GraphBlas(GrB_Info),
#[error("GraphBLAS error: info code {0}; msg: {1}")]
LAGraph(GrB_Info, String),
#[error("LAGraph initialization failed")]
InitFailed,
#[error("Label not found: '{0}'")]
LabelNotFound(String),
#[error("Format error: {0}")]
Format(#[from] crate::formats::FormatError),
}
/// A labeled edge from `source` to `target`, with nodes identified by their
/// string ids.
///
/// Derives `PartialEq`, `Eq` and `Hash` so edges can be compared, deduplicated
/// or used as set/map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Edge {
    /// String id of the node the edge starts at.
    pub source: String,
    /// String id of the node the edge points to.
    pub target: String,
    /// Label of the relation this edge represents.
    pub label: String,
}
pub trait GraphSource<B: GraphBuilder> {
fn apply_to(self, builder: B) -> Result<B, B::Error>;
}
/// Incrementally assembles a graph from one or more [`GraphSource`]s.
///
/// A builder starts from [`Default`], absorbs sources through
/// [`GraphBuilder::load`], and is finally consumed by [`GraphBuilder::build`].
pub trait GraphBuilder: Sized + Default {
    /// The graph type produced by [`GraphBuilder::build`].
    type Graph: GraphDecomposition;
    /// Error type for loading and building; `Send + Sync + 'static` so it can
    /// be moved or shared across threads.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Consumes `source` and applies it to this builder.
    ///
    /// The default implementation just delegates to [`GraphSource::apply_to`].
    fn load<S>(self, source: S) -> Result<Self, Self::Error>
    where
        S: GraphSource<Self>,
    {
        <S as GraphSource<Self>>::apply_to(source, self)
    }

    /// Finishes construction and returns the assembled graph.
    fn build(self) -> Result<Self::Graph, Self::Error>;
}
/// Read-only access to a graph decomposed into one [`LagraphGraph`] per label,
/// plus a mapping between string node ids and numeric (mapped) ids.
pub trait GraphDecomposition {
    /// Total number of nodes in the graph.
    fn num_nodes(&self) -> usize;

    /// Looks up the numeric id assigned to the node named `string_id`,
    /// or `None` if no such node exists.
    fn get_node_id(&self, string_id: &str) -> Option<usize>;

    /// Reverse lookup: the string name of the node with numeric id
    /// `mapped_id`, or `None` if the id is unknown.
    fn get_node_name(&self, mapped_id: usize) -> Option<String>;

    /// Returns the shared graph for `label`.
    ///
    /// # Errors
    /// Returns a [`GraphError`] on failure (implementations presumably use
    /// `GraphError::LabelNotFound` for unknown labels — confirm per backend).
    fn get_graph(&self, label: &str) -> Result<Arc<LagraphGraph>, GraphError>;
}
/// Pairs a concrete graph type with the builder that produces it.
///
/// Serves as the type parameter of [`Graph`] to select an implementation.
pub trait Backend {
    /// Builder whose output type is exactly this backend's graph type.
    type Builder: GraphBuilder<Graph = <Self as Backend>::Graph>;
    /// The decomposed graph type this backend yields.
    type Graph: GraphDecomposition;
}
/// Data-less entry point for constructing graphs with a chosen [`Backend`].
///
/// It carries no data; it only namespaces the constructors in its `impl`.
pub struct Graph<B: Backend> {
    _marker: PhantomData<B>,
}

impl<B: Backend> Graph<B> {
    /// Returns a fresh, default-initialized builder for backend `B`.
    pub fn builder() -> B::Builder {
        <B::Builder as Default>::default()
    }

    /// Builds a `B::Graph` in one step from a single `source`.
    ///
    /// # Errors
    /// Propagates any error raised while loading `source` or building.
    pub fn try_from<S>(source: S) -> Result<B::Graph, <B::Builder as GraphBuilder>::Error>
    where
        S: GraphSource<B::Builder>,
    {
        Self::builder()
            .load(source)
            .and_then(<B::Builder as GraphBuilder>::build)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::{CountOutput, CountingBuilder, VecSource};
    use std::convert::Infallible;

    /// Converts `(source, target, label)` string triples into owned [`Edge`]s.
    fn edges(triples: &[(&str, &str, &str)]) -> Vec<Edge> {
        let mut out = Vec::with_capacity(triples.len());
        for (s, t, l) in triples.iter().copied() {
            out.push(Edge {
                source: s.to_owned(),
                target: t.to_owned(),
                label: l.to_owned(),
            });
        }
        out
    }

    #[test]
    fn test_load_and_build() {
        let triples = [("A", "B", "knows"), ("B", "C", "knows"), ("A", "C", "likes")];
        let source = VecSource(edges(&triples));
        let builder = CountingBuilder::default().load(source).unwrap();
        let output: CountOutput<Infallible> = builder.build().unwrap();
        assert_eq!(output.num_nodes(), 3);
    }

    #[test]
    fn test_compute_outer_inner_product_bounded_by_cores() {
        let cores = std::thread::available_parallelism().map_or(1, |n| n.get());
        let limit = cores.max(1);
        let task_counts: [usize; 8] = [0, 1, 2, 4, 8, 16, 64, 1024];
        for &num_tasks in task_counts.iter() {
            let (outer, inner) = compute_outer_inner(num_tasks);
            assert!(outer >= 1, "outer must be >= 1 for num_tasks={num_tasks}");
            assert!(inner >= 1, "inner must be >= 1 for num_tasks={num_tasks}");
            let product = (outer as usize) * (inner as usize);
            assert!(
                product <= limit,
                "outer*inner ({outer}*{inner}={product}) must not exceed cores ({cores}) for num_tasks={num_tasks}"
            );
        }
    }

    #[test]
    fn test_compute_outer_inner_caps_outer_at_tasks() {
        let (single_outer, _) = compute_outer_inner(1);
        assert_eq!(single_outer, 1);
        let (pair_outer, _) = compute_outer_inner(2);
        assert!(pair_outer <= 2);
    }

    #[test]
    fn test_graph_try_from() {
        struct TestBackend;
        impl Backend for TestBackend {
            type Graph = CountOutput<Infallible>;
            type Builder = CountingBuilder<Infallible>;
        }
        // NOTE(review): these edges touch 3 distinct nodes (X, Y, Z) yet the
        // expected count is 2, so CountOutput::num_nodes appears to count
        // edges rather than nodes — confirm against crate::utils.
        let pairs = [("X", "Y", "rel"), ("Y", "Z", "rel")];
        let graph = Graph::<TestBackend>::try_from(VecSource(edges(&pairs))).unwrap();
        assert_eq!(graph.num_nodes(), 2);
    }
}