1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
//! Purpose-built GEMV thread pool with barrier synchronization for quantized matmul dispatch.
// PMAT-310: Purpose-built GEMV thread pool with barrier synchronization.
// Replaces rayon for matmul dispatch. Zero work-stealing overhead.
use std::sync::{Arc, Barrier, Mutex};
/// Kernel signature dispatched to workers: `(start_row, end_row)`, a
/// half-open range of output rows the callee must compute.
type WorkFn = Box<dyn Fn(usize, usize) + Send + Sync>;
/// Single work slot shared between the dispatcher and all workers.
/// `None` doubles as the shutdown signal checked by each worker loop.
struct SharedWork {
work: Mutex<Option<(WorkFn, usize)>>, // (function, out_dim)
}
/// Fixed-size worker pool for GEMV row-parallel dispatch. Workers park on
/// `start_barrier` between jobs; the dispatcher joins both barriers, so each
/// barrier counts `n_threads + 1` participants.
pub(crate) struct GemvPool {
n_threads: usize, // worker count; also the denominator for row chunking
shared: Arc<SharedWork>,
start_barrier: Arc<Barrier>, // releases workers to start a job
end_barrier: Arc<Barrier>,   // dispatcher blocks here until all workers finish
_handles: Vec<std::thread::JoinHandle<()>>, // kept alive; threads exit on shutdown signal
}
impl GemvPool {
/// Spawn `n` worker threads parked on the start barrier.
///
/// Each worker loops: wait on `start_barrier`, read the shared work slot,
/// compute its contiguous slice of `0..out_dim`, then rendezvous on
/// `end_barrier`. A `None` work slot is the shutdown signal.
///
/// # Panics
/// Panics if `n == 0`. With zero workers the barriers would have a single
/// participant (the dispatcher), so `dispatch` would complete the barrier
/// handshake without ever invoking the closure — a silent no-op.
pub fn new(n: usize) -> Self {
    assert!(n > 0, "GemvPool requires at least one worker thread");
    let shared = Arc::new(SharedWork {
        work: Mutex::new(None),
    });
    // n workers + 1 dispatcher participate in every barrier round.
    let start_barrier = Arc::new(Barrier::new(n + 1));
    let end_barrier = Arc::new(Barrier::new(n + 1));
    let mut handles = Vec::with_capacity(n);
    for tid in 0..n {
        let s = Arc::clone(&shared);
        let sb = Arc::clone(&start_barrier);
        let eb = Arc::clone(&end_barrier);
        let nt = n;
        handles.push(
            std::thread::Builder::new()
                .name(format!("gemv-{tid}"))
                .spawn(move || {
                    loop {
                        sb.wait();
                        let (f, out_dim) = {
                            // Lock only long enough to read a pointer; the
                            // Box itself remains owned by the Mutex slot.
                            let guard = s.work.lock().unwrap();
                            match guard.as_ref() {
                                Some((f, od)) => {
                                    // Box<dyn Fn> cannot be cloned; share the
                                    // closure as a raw pointer instead.
                                    let f_ptr = f.as_ref()
                                        as *const (dyn Fn(usize, usize) + Send + Sync);
                                    (f_ptr, *od)
                                },
                                None => {
                                    // Shutdown signal from Drop.
                                    eb.wait();
                                    return;
                                },
                            }
                        };
                        // Static partition: ceil(out_dim / nt) rows per thread.
                        let rows_per = (out_dim + nt - 1) / nt;
                        let start = tid * rows_per;
                        let end = (start + rows_per).min(out_dim);
                        if start < out_dim {
                            // SAFETY: the closure behind `f` is valid for this
                            // call: dispatch() keeps the Box alive inside the
                            // Mutex and blocks on end_barrier before clearing
                            // the slot or returning.
                            unsafe { (*f)(start, end) };
                        }
                        eb.wait();
                    }
                })
                .expect("spawn gemv thread"),
        );
    }
    Self {
        n_threads: n,
        shared,
        start_barrier,
        end_barrier,
        _handles: handles,
    }
}
/// Run `f(start_row, end_row)` on every worker, partitioning `0..out_dim`
/// into contiguous chunks. Blocks until all workers have finished.
///
/// NOTE(review): `dispatch` takes `&self` but is NOT safe to call from two
/// threads concurrently — each barrier round assumes exactly one
/// dispatcher, and a second one would also race on the work slot. Callers
/// must serialize dispatches externally; verify call sites uphold this.
///
/// `f` must not panic: a worker panicking inside `f` never reaches the end
/// barrier, deadlocking this call and all future dispatches.
#[inline]
pub fn dispatch<F: Fn(usize, usize) + Send + Sync>(&self, out_dim: usize, f: &F) {
    let f_ref: &(dyn Fn(usize, usize) + Send + Sync) = f;
    // SAFETY: the 'static lifetime is a lie, but `f` outlives every use of
    // it: workers only invoke it between start_barrier and end_barrier, and
    // we block on end_barrier below before returning (hence before the
    // caller's `f` can be dropped).
    let f_static: &'static (dyn Fn(usize, usize) + Send + Sync) =
        unsafe { std::mem::transmute(f_ref) };
    let boxed: WorkFn = Box::new(move |a, b| (*f_static)(a, b));
    {
        let mut guard = self.shared.work.lock().unwrap();
        *guard = Some((boxed, out_dim));
    }
    // Release the workers, then wait for all of them to finish.
    self.start_barrier.wait();
    self.end_barrier.wait();
    // Drop the boxed trampoline so the next round (or Drop's shutdown
    // signal) sees a clean `None` slot.
    {
        let mut guard = self.shared.work.lock().unwrap();
        *guard = None;
    }
}
}
impl Drop for GemvPool {
    /// Shut the pool down: publish an empty (`None`) work slot, then run one
    /// final barrier round so every parked worker observes the shutdown
    /// signal and exits its loop.
    fn drop(&mut self) {
        // The temporary lock guard is released at the end of this statement,
        // before the workers are woken.
        *self.shared.work.lock().unwrap() = None;
        self.start_barrier.wait();
        self.end_barrier.wait();
    }
}
/// Process-wide pool, initialized lazily on first use.
static POOL: std::sync::OnceLock<GemvPool> = std::sync::OnceLock::new();
/// Return the shared GEMV pool, creating it on first call. The worker count
/// mirrors rayon's configured thread count so GEMV parallelism matches the
/// rest of the crate.
pub(crate) fn get_pool() -> &'static GemvPool {
    POOL.get_or_init(|| {
        let n_workers = rayon::current_num_threads();
        GemvPool::new(n_workers)
    })
}