1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use std::sync::Arc;

use once_cell::sync::Lazy;
use rayon::{ThreadPool, ThreadPoolBuilder};

pub(crate) mod context;
pub mod eval_vertex;
pub mod task;
pub mod task_runner;
pub(crate) mod task_state;

/// Shared rayon thread pool, constructed lazily on first access.
///
/// The thread count is taken from the `DOCBROWN_MAX_THREADS` environment
/// variable when it can be read. Otherwise it comes from
/// [`std::thread::available_parallelism`], falling back to `1` if that
/// query fails.
///
/// # Panics
///
/// Initialisation panics if `DOCBROWN_MAX_THREADS` is readable but does not
/// parse as a `usize`, or if the pool builder returns an error.
pub static POOL: Lazy<Arc<ThreadPool>> = Lazy::new(|| {
    let num_threads = match std::env::var("DOCBROWN_MAX_THREADS") {
        // An explicit setting wins; a malformed value is treated as fatal.
        Ok(raw) => raw
            .parse::<usize>()
            .expect("DOCBROWN_MAX_THREADS must be a number"),
        // No usable setting: ask the platform, defaulting to a single thread.
        Err(_) => std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(1),
    };

    Arc::new(
        ThreadPoolBuilder::new()
            .num_threads(num_threads)
            .build()
            .unwrap(),
    )
});

/// Builds a fresh rayon thread pool configured with `n_threads` and returns
/// it behind an [`Arc`] so it can be shared.
///
/// # Panics
///
/// Panics if the pool builder returns an error.
pub fn custom_pool(n_threads: usize) -> Arc<ThreadPool> {
    let builder = ThreadPoolBuilder::new().num_threads(n_threads);

    builder.build().map(Arc::new).unwrap()
}

#[cfg(test)]
mod task_tests {
    use crate::{
        core::state::{self, compute_state::ComputeStateVec},
        db::graph::Graph,
    };

    use super::{
        context::Context,
        task::{ATask, Job, Step},
        task_runner::TaskRunner,
    };

    // Runs a single task that adds 1 to a global sum accumulator and checks
    // that the total equals the number of vertices in the graph (8).
    #[test]
    fn count_all_vertices_with_global_state() {
        let graph = Graph::new(2);

        // (source, destination, timestamp) triples; the endpoints are the
        // ids 1 through 8.
        let edges = vec![
            (1, 2, 1),
            (2, 3, 2),
            (3, 4, 3),
            (3, 5, 4),
            (6, 5, 5),
            (7, 8, 6),
            (8, 7, 7),
        ];
        for (src, dst, ts) in edges {
            graph.add_edge(ts, src, dst, &vec![], None).unwrap();
        }

        let mut context: Context<Graph, ComputeStateVec> = (&graph).into();

        // Register the sum accumulator as a global aggregate before running.
        let vertex_count = state::accumulator_id::accumulators::sum::<usize>(0);
        context.global_agg(vertex_count.clone());

        // The task contributes 1 to the global sum and finishes immediately.
        let count_step = ATask::new(move |vv| {
            vv.global_update(&vertex_count, 1);
            Step::Done
        });
        let jobs = vec![Job::new(count_step)];

        let mut runner = TaskRunner::new(context);
        let (_, global_state, _) = runner.run(vec![], jobs, Some(2), 1, None, None);

        assert_eq!(
            global_state.inner().read_global(0, &vertex_count),
            Some(8)
        );
    }
}