use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering::*};
use crate::loom_shim::Mutex;
use super::g::G;
use super::m::M;
/// `P::status`: the P is idle. `P::new` creates Ps in this state.
pub(crate) const PIDLE: u32 = 0;
/// `P::status`: the P is running. A thief stealing `runnext` from a P in this
/// state backs off briefly first (see `P::runqgrab`).
pub(crate) const PRUNNING: u32 = 1;
/// `P::status` value 2. NOTE(review): presumably "in a syscall" — confirm.
pub(crate) const PSYSCALL: u32 = 2;
/// `P::status` value 3. NOTE(review): presumably "stopped for GC" — confirm.
pub(crate) const PGCSTOP: u32 = 3;
/// `P::status` value 4. NOTE(review): presumably "no longer in use" — confirm.
pub(crate) const PDEAD: u32 = 4;

/// Capacity of each P's local run-queue ring.
///
/// Ring positions are free-running `u32` counters reduced with `% RUNQ_CAP`,
/// so the capacity must divide 2^32 (a power of two) for slot positions to stay
/// consistent when the counters wrap. It must also fit in `u32` (the ring length
/// is compared as `t - h < RUNQ_CAP as u32`). On overflow exactly half of a full
/// ring is moved out, which requires it to be at least 2.
const RUNQ_CAP: usize = 256;

// Compile-time guards for the invariants documented on `RUNQ_CAP`.
const _: () = assert!(
    RUNQ_CAP.is_power_of_two() && RUNQ_CAP > 1,
    "RUNQ_CAP must be a power of two >= 2"
);
const _: () = assert!(RUNQ_CAP <= 1 << 31, "RUNQ_CAP must fit the u32 ring indices");
/// Scheduler-wide FIFO of runnable `G`s.
///
/// The queue is an intrusive singly linked list threaded through each G's
/// `schedlink` field. One mutex guards the head and tail pointers and the count.
pub(crate) struct GlobalRunQueue {
    inner: Mutex<GlobalRunQueueInner>,
}

/// Mutex-protected state of [`GlobalRunQueue`].
struct GlobalRunQueueInner {
    /// Oldest queued G, or null when the queue is empty.
    head: *mut G,
    /// Newest queued G, or null when the queue is empty.
    tail: *mut G,
    /// Number of Gs currently queued, as reported by the `count` passed to `push_batch`.
    count: u32,
}

// SAFETY: the raw G pointers are only read or written while the mutex is held.
// NOTE(review): this also assumes the queued Gs may be accessed from whichever
// thread next takes the lock — confirm.
unsafe impl Send for GlobalRunQueueInner {}

impl GlobalRunQueue {
    /// Creates an empty queue.
    pub(crate) fn new() -> Self {
        let empty = GlobalRunQueueInner {
            head: std::ptr::null_mut(),
            tail: std::ptr::null_mut(),
            count: 0,
        };
        Self {
            inner: Mutex::new(empty),
        }
    }

    /// Appends the pre-linked chain `head -> ... -> tail` to the end of the queue.
    ///
    /// `tail.schedlink` is overwritten with null before the chain is published.
    /// `count` is added to the queue's length unchecked.
    ///
    /// # Safety
    /// `head` and `tail` must point to live Gs. `tail` must be reachable from
    /// `head` through `schedlink`, and `count` must be the chain's length.
    /// Ownership of the whole chain passes to the queue.
    pub(crate) unsafe fn push_batch(&self, head: *mut G, tail: *mut G, count: u32) {
        // Null-terminate the incoming chain so the list stays well formed.
        // SAFETY: the caller guarantees `tail` is a live G.
        unsafe { (*tail).schedlink = std::ptr::null_mut() };

        let mut state = self.inner.lock().unwrap();
        let old_tail = std::mem::replace(&mut state.tail, tail);
        if old_tail.is_null() {
            // The queue was empty, so the batch becomes the entire list.
            state.head = head;
        } else {
            // SAFETY: `old_tail` was enqueued earlier and is still owned by the queue.
            unsafe { (*old_tail).schedlink = head };
        }
        state.count += count;
    }

    /// Removes and returns the oldest G, or null if the queue is empty.
    ///
    /// The returned G's `schedlink` is reset to null.
    ///
    /// # Safety
    /// Every G in the queue must still be live (see `push_batch`).
    pub(crate) unsafe fn pop(&self) -> *mut G {
        let mut state = self.inner.lock().unwrap();
        let front = state.head;
        if front.is_null() {
            return front;
        }

        // SAFETY: non-null list entries were handed over via `push_batch`.
        let after = unsafe { (*front).schedlink };
        state.head = after;
        if after.is_null() {
            // That was the last element; the queue is now empty.
            state.tail = after;
        }
        // Detach the G from the list before handing it out.
        unsafe { (*front).schedlink = std::ptr::null_mut() };
        state.count -= 1;
        front
    }

    /// Number of Gs currently queued.
    pub(crate) fn len(&self) -> u32 {
        let state = self.inner.lock().unwrap();
        state.count
    }
}
/// A scheduling context that owns a local run queue of `G`s.
///
/// Runnable Gs are held in two places: the single-entry `runnext` slot and the
/// fixed-capacity ring `runq`. Ring occupancy is `runqtail - runqhead` in
/// wrapping `u32` arithmetic. Slots and `runnext` store `*mut G` values encoded
/// as `usize`; `runnext == 0` means that slot is empty.
pub(crate) struct P {
    /// Identifier supplied to `P::new`.
    pub id: i32,
    /// Current state, one of the `P*` status constants; starts at `PIDLE`.
    pub status: AtomicU32,
    /// Attached `M`; null on a freshly created P.
    pub m: *mut M,
    /// Link to another `P`; null on a freshly created P.
    /// NOTE(review): presumably used to chain Ps into a list elsewhere — confirm.
    pub link: *mut P,
    /// Scheduling tick counter; starts at 0.
    pub schedtick: AtomicU32,
    /// Syscall tick counter; starts at 0.
    pub syscalltick: AtomicU32,

    /// Single-entry "run next" slot (0 = empty).
    runnext: AtomicUsize,
    /// Free-running index of the oldest ring entry; consumers advance it by CAS.
    runqhead: AtomicU32,
    /// Free-running index one past the newest ring entry; advanced by this P's producer side.
    runqtail: AtomicU32,
    /// Ring storage, addressed as `index % RUNQ_CAP`.
    runq: [AtomicUsize; RUNQ_CAP],
}

// SAFETY: all state that can change through `&P` is atomic. The raw pointers
// `m` and `link` are plain fields, so they can only be mutated through `&mut P`.
// Code that dereferences them is responsible for the pointee's thread safety.
// NOTE(review): confirm the pointed-to M/P are valid to access from any thread.
unsafe impl Sync for P {}
unsafe impl Send for P {}
impl P {
    /// Allocates a fresh `P` with the given `id`.
    ///
    /// The new P has status `PIDLE`, null `m` and `link`, an empty ring
    /// (`runqhead == runqtail == 0`, every slot zeroed), an empty `runnext`,
    /// and zeroed tick counters.
    pub(crate) fn new(id: i32) -> Box<P> {
        Box::new(P {
            id,
            status: AtomicU32::new(PIDLE),
            m: std::ptr::null_mut(),
            runqhead: AtomicU32::new(0),
            runqtail: AtomicU32::new(0),
            runq: std::array::from_fn(|_| AtomicUsize::new(0)),
            runnext: AtomicUsize::new(0),
            schedtick: AtomicU32::new(0),
            syscalltick: AtomicU32::new(0),
            link: std::ptr::null_mut(),
        })
    }

    /// Enqueues `gp` on this P's local run queue.
    ///
    /// * `next == true`: `gp` is swapped into `runnext`. If that slot already held
    ///   a G, the displaced G is appended to the ring tail in place of `gp`.
    /// * `next == false`: `gp` is appended to the ring tail.
    ///
    /// When the ring is full, `runqputslow` moves the older half of the ring plus
    /// the G being appended to `global_q`.
    ///
    /// # Safety
    /// NOTE(review): the contract is not stated in this file. From the code:
    /// - On overflow, every moved G (including the one being appended) is
    ///   dereferenced to write `schedlink`, so all of them must be live.
    /// - `runqtail` is updated with a plain load/store rather than a CAS. That is
    ///   only sound if a single thread (presumably the P's owner) ever enqueues
    ///   onto this P — confirm.
    pub(crate) unsafe fn runqput(
        &self,
        mut gp: *mut G,
        next: bool,
        global_q: &GlobalRunQueue,
    ) {
        if next {
            // Swap `gp` into `runnext`, retrying until the CAS succeeds.
            let mut old = self.runnext.load(Relaxed);
            loop {
                match self.runnext.compare_exchange_weak(old, gp as usize, AcqRel, Relaxed) {
                    Ok(_) => {
                        // The slot was empty, so nothing was displaced.
                        if old == 0 {
                            return;
                        }
                        // Move the previous `runnext` G to the ring tail instead.
                        gp = old as *mut G;
                        break;
                    }
                    Err(cur) => old = cur,
                }
            }
        }
        'retry: loop {
            // Acquire on the head pairs with the Release head CAS in
            // runqget/runqgrab/runqputslow, so slots those consumers vacated may be
            // overwritten. The tail is read Relaxed because in this file it is only
            // written by runqput and runqsteal on this same P (presumably the owner
            // thread).
            let h = self.runqhead.load(Acquire);
            let t = self.runqtail.load(Relaxed);
            if t.wrapping_sub(h) < RUNQ_CAP as u32 {
                self.runq[(t as usize) % RUNQ_CAP].store(gp as usize, Relaxed);
                // Release publishes the slot (and the G behind it) to consumers
                // that Acquire-load `runqtail`.
                self.runqtail.store(t.wrapping_add(1), Release);
                return;
            }
            // Ring full: spill half of it to the global queue. `false` means a
            // consumer advanced `runqhead` first, which frees space, so retry.
            if unsafe { self.runqputslow(gp, h, t, global_q) } {
                return;
            }
            continue 'retry;
        }
    }

    /// Moves the oldest half of a full ring, followed by `gp`, to `global_q` as one batch.
    ///
    /// `h` and `t` are the head and tail the caller just observed. The ring is
    /// expected to be full (`t - h == RUNQ_CAP`); this is only checked by a
    /// `debug_assert`.
    ///
    /// Returns `false` without touching `global_q` if `runqhead` no longer equals
    /// `h`. Otherwise pushes the `n + 1` Gs (`n = (t - h) / 2`) and returns `true`.
    ///
    /// # Safety
    /// Every moved G and `gp` are dereferenced to chain them through `schedlink`.
    unsafe fn runqputslow(
        &self,
        gp: *mut G,
        h: u32,
        t: u32,
        global_q: &GlobalRunQueue,
    ) -> bool {
        let n = t.wrapping_sub(h) / 2;
        debug_assert_eq!(
            n,
            (RUNQ_CAP / 2) as u32,
            "runqputslow: queue not full (n={n})"
        );
        // Claim the oldest `n` slots. They are read *after* this CAS. That relies
        // on nobody writing this P's slots in the meantime: in this file, slots are
        // written only by `runqput` and, as a steal target, by `runqgrab` via
        // `runqsteal` on this P. NOTE(review): presumably both run only on the
        // owner thread — confirm.
        if self
            .runqhead
            .compare_exchange(h, h.wrapping_add(n), Release, Relaxed)
            .is_err()
        {
            // A consumer moved the head first; the caller retries the fast path.
            return false;
        }
        let n_usize = n as usize;
        // One slot more than half the ring: `gp` goes last.
        let mut batch = [std::ptr::null_mut::<G>(); RUNQ_CAP / 2 + 1];
        for (i, b) in batch.iter_mut().enumerate().take(n_usize) {
            let slot = self.runq[(h.wrapping_add(i as u32) as usize) % RUNQ_CAP].load(Relaxed);
            *b = slot as *mut G;
        }
        batch[n_usize] = gp;
        // Chain batch[0] -> ... -> batch[n] through `schedlink`; `push_batch`
        // null-terminates the tail.
        for i in 0..n_usize {
            unsafe { (*batch[i]).schedlink = batch[i + 1] };
        }
        let head_g = batch[0];
        let tail_g = batch[n_usize];
        unsafe { global_q.push_batch(head_g, tail_g, (n_usize + 1) as u32) };
        true
    }

    /// Dequeues the next G for this P to run.
    ///
    /// Prefers `runnext` and returns `(gp, true)` when the G came from there. The
    /// tests in this file call this flag "inherit time". Otherwise pops the ring
    /// head and returns `(gp, false)`. Returns `(null, false)` when both are empty.
    pub(crate) fn runqget(&self) -> (*mut G, bool) {
        let next = self.runnext.load(Relaxed);
        if next != 0 {
            // A thief (runqgrab) may clear `runnext` concurrently. If the CAS
            // loses, fall through to the ring.
            if self
                .runnext
                .compare_exchange(next, 0, AcqRel, Relaxed)
                .is_ok()
            {
                return (next as *mut G, true);
            }
        }
        loop {
            // The tail is read Relaxed: this P's own code is its only writer.
            let h = self.runqhead.load(Acquire);
            let t = self.runqtail.load(Relaxed);
            if t == h {
                return (std::ptr::null_mut(), false);
            }
            let gp = self.runq[(h as usize) % RUNQ_CAP].load(Relaxed) as *mut G;
            // Stealers may advance the head concurrently, so take the slot by CAS.
            // On failure, reload and retry.
            if self
                .runqhead
                .compare_exchange(h, h.wrapping_add(1), Release, Relaxed)
                .is_ok()
            {
                return (gp, false);
            }
        }
    }

    /// Steals Gs from `victim` into this P's ring and returns one of them to run.
    ///
    /// Takes `ceil(len / 2)` Gs from the head of `victim`'s ring (via `runqgrab`).
    /// If that ring is empty and `steal_run_next` is set, it takes `victim.runnext`
    /// instead. The stolen Gs are written into `self.runq` starting at this P's
    /// current tail. The last one is returned directly; the rest become visible
    /// when `runqtail` is advanced. Returns null when nothing was stolen.
    pub(crate) fn runqsteal(&self, victim: &P, steal_run_next: bool) -> *mut G {
        let t = self.runqtail.load(Relaxed);
        let n = victim.runqgrab(&self.runq, t, steal_run_next);
        if n == 0 {
            return std::ptr::null_mut();
        }
        // Hand the last stolen G straight to the caller. It stays outside the
        // published range [head, tail).
        let n = n - 1;
        let gp = self.runq[(t.wrapping_add(n) as usize) % RUNQ_CAP].load(Relaxed) as *mut G;
        if n == 0 {
            return gp;
        }
        let h = self.runqhead.load(Acquire);
        // Free space is only checked here, after the slots were already written.
        // NOTE(review): callers presumably only steal into a P with at least
        // RUNQ_CAP / 2 free slots — confirm.
        debug_assert!(
            t.wrapping_sub(h).wrapping_add(n) < RUNQ_CAP as u32,
            "runqsteal: runq overflow"
        );
        // Release publishes the copied slots to this P's consumers.
        self.runqtail.store(t.wrapping_add(n), Release);
        gp
    }

    /// Copies up to half (rounded up) of this P's ring into `batch` and commits the removal.
    ///
    /// Entries are written to `batch[(batch_head + i) % RUNQ_CAP]`. If the ring is
    /// empty and `steal_run_next` is set, `runnext` is taken instead and written to
    /// `batch[batch_head % RUNQ_CAP]`. Returns the number of Gs grabbed, or 0 if
    /// nothing was available.
    fn runqgrab(
        &self,
        batch: &[AtomicUsize; RUNQ_CAP],
        batch_head: u32,
        steal_run_next: bool,
    ) -> u32 {
        loop {
            // Acquire on the tail pairs with the producer's Release store, which
            // makes the slot contents visible.
            let h = self.runqhead.load(Acquire);
            let t = self.runqtail.load(Acquire);
            let n_full = t.wrapping_sub(h);
            // Take ceil(len / 2).
            let n = n_full - n_full / 2;
            if n == 0 {
                if steal_run_next {
                    let next = self.runnext.load(Relaxed);
                    if next != 0 {
                        // NOTE(review): this looks modeled on Go's runtime.runqgrab,
                        // which sleeps briefly so a running victim can pick up its
                        // own runnext before it is stolen — confirm intent.
                        if self.status.load(Acquire) == PRUNNING {
                            std::thread::sleep(std::time::Duration::from_micros(3));
                        }
                        if self
                            .runnext
                            .compare_exchange(next, 0, AcqRel, Relaxed)
                            .is_ok()
                        {
                            batch[(batch_head as usize) % RUNQ_CAP].store(next, Relaxed);
                            return 1;
                        }
                        // `runnext` changed under us; start over with a fresh snapshot.
                        continue;
                    }
                }
                return 0;
            }
            // A consistent snapshot yields at most RUNQ_CAP / 2. More means the
            // head moved between the two loads, so retry.
            if n > (RUNQ_CAP / 2) as u32 {
                continue;
            }
            // Copy before the CAS: once `runqhead` advances, the owner may reuse
            // these slots.
            for i in 0..n {
                let slot = self.runq[(h.wrapping_add(i) as usize) % RUNQ_CAP].load(Relaxed);
                batch[(batch_head.wrapping_add(i) as usize) % RUNQ_CAP].store(slot, Relaxed);
            }
            // Commit. If another consumer won the race, the copies in `batch` are
            // not published: `runqsteal` only advances its tail after we return n.
            if self
                .runqhead
                .compare_exchange(h, h.wrapping_add(n), Release, Relaxed)
                .is_ok()
            {
                return n;
            }
        }
    }

    /// Number of queued Gs: ring occupancy plus 1 if `runnext` is occupied.
    ///
    /// The three atomics are read at different moments. Under concurrent
    /// modification the result is therefore only an estimate.
    pub(crate) fn runq_size(&self) -> u32 {
        let h = self.runqhead.load(Acquire);
        let t = self.runqtail.load(Acquire);
        let ring = t.wrapping_sub(h);
        let rn = if self.runnext.load(Relaxed) != 0 { 1 } else { 0 };
        ring + rn
    }
}
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;
    use crate::runtime::g::{Stack, G};

    /// Builds a boxed `G` whose 64 KiB stack range is unique per `id`.
    /// Ranges start 1 MiB apart, so they never overlap.
    fn make_g(id: u64) -> Box<G> {
        let lo = (id as usize + 1) << 20;
        G::new(Stack { lo, hi: lo + 65536 }, id)
    }

    /// Allocates `count` Gs (ids `0..count`) and returns the owning boxes plus raw
    /// pointers to them. Keep the boxes alive for as long as the pointers are used;
    /// dropping the `Vec` frees every G.
    fn make_gs(count: usize) -> (Vec<Box<G>>, Vec<*mut G>) {
        let mut owned: Vec<Box<G>> = (0..count).map(|i| make_g(i as u64)).collect();
        let ptrs = owned.iter_mut().map(|g| &mut **g as *mut G).collect();
        (owned, ptrs)
    }

    #[test]
    fn p_new_initial_state() {
        let p = P::new(1);
        assert_eq!(p.id, 1);
        assert_eq!(p.status.load(Relaxed), PIDLE);
        assert!(p.m.is_null());
        assert_eq!(p.runqhead.load(Relaxed), 0);
        assert_eq!(p.runqtail.load(Relaxed), 0);
        assert_eq!(p.runnext.load(Relaxed), 0);
        assert!(p.link.is_null());
        assert_eq!(p.runq_size(), 0);
    }

    #[test]
    fn global_queue_push_pop() {
        let gq = GlobalRunQueue::new();
        let g1_ptr = Box::into_raw(make_g(1));
        let g2_ptr = Box::into_raw(make_g(2));
        unsafe {
            (*g1_ptr).schedlink = g2_ptr;
            gq.push_batch(g1_ptr, g2_ptr, 2);
            assert_eq!(gq.len(), 2);
            assert_eq!(gq.pop(), g1_ptr);
            assert_eq!(gq.len(), 1);
            assert_eq!(gq.pop(), g2_ptr);
            assert_eq!(gq.len(), 0);
            assert!(gq.pop().is_null());
        }
        drop(unsafe { Box::from_raw(g1_ptr) });
        drop(unsafe { Box::from_raw(g2_ptr) });
    }

    #[test]
    fn global_queue_pop_empty() {
        let gq = GlobalRunQueue::new();
        let got = unsafe { gq.pop() };
        assert!(got.is_null());
        assert_eq!(gq.len(), 0);
    }

    #[test]
    fn runqput_runqget_fifo() {
        let p = P::new(0);
        let gq = GlobalRunQueue::new();
        let (_goroutines, ptrs) = make_gs(10);
        for &ptr in &ptrs {
            unsafe { p.runqput(ptr, false, &gq) };
        }
        assert_eq!(p.runq_size(), 10);
        for (i, &expected_ptr) in ptrs.iter().enumerate() {
            let (got, inherit) = p.runqget();
            assert_eq!(got, expected_ptr, "mismatch at position {i}");
            assert!(!inherit, "should not inherit time for normal enqueue");
        }
        let (got, _) = p.runqget();
        assert!(got.is_null());
        assert_eq!(gq.len(), 0, "a non-full local queue must not spill");
    }

    #[test]
    fn runqput_next_installs_runnext() {
        let p = P::new(0);
        let gq = GlobalRunQueue::new();
        let g1_ptr = Box::into_raw(make_g(1));
        unsafe { p.runqput(g1_ptr, true, &gq) };
        let (got, inherit) = p.runqget();
        assert_eq!(got, g1_ptr);
        assert!(inherit);
        assert!(p.runqget().0.is_null());
        drop(unsafe { Box::from_raw(g1_ptr) });
    }

    #[test]
    fn runqput_next_displaces_old_runnext() {
        let p = P::new(0);
        let gq = GlobalRunQueue::new();
        let g1_ptr = Box::into_raw(make_g(1));
        let g2_ptr = Box::into_raw(make_g(2));
        unsafe { p.runqput(g1_ptr, true, &gq) };
        unsafe { p.runqput(g2_ptr, true, &gq) };
        // The newest `next` G wins the slot; the displaced one moves to the ring.
        let (got, inherit) = p.runqget();
        assert_eq!(got, g2_ptr);
        assert!(inherit);
        let (got, inherit) = p.runqget();
        assert_eq!(got, g1_ptr);
        assert!(!inherit);
        drop(unsafe { Box::from_raw(g1_ptr) });
        drop(unsafe { Box::from_raw(g2_ptr) });
    }

    #[test]
    fn runqput_overflow_to_global() {
        let p = P::new(0);
        let gq = GlobalRunQueue::new();
        let (_goroutines, ptrs) = make_gs(RUNQ_CAP);
        for &ptr in &ptrs {
            unsafe { p.runqput(ptr, false, &gq) };
        }
        // A full ring stays local; nothing spills until the next put.
        assert_eq!(gq.len(), 0);
        assert_eq!(p.runq_size(), RUNQ_CAP as u32);

        let g_extra = Box::into_raw(make_g(RUNQ_CAP as u64 + 1));
        unsafe { p.runqput(g_extra, false, &gq) };

        // The older half of the ring plus the new G move to the global queue, in order ...
        let half = RUNQ_CAP / 2;
        assert_eq!(gq.len(), (half + 1) as u32);
        assert_eq!(p.runq_size(), half as u32);
        let expected_global = ptrs[..half].iter().copied().chain(std::iter::once(g_extra));
        for (i, want) in expected_global.enumerate() {
            let got = unsafe { gq.pop() };
            assert_eq!(got, want, "global queue mismatch at position {i}");
        }
        assert!(unsafe { gq.pop() }.is_null());

        // ... while the newer half stays local, still in FIFO order.
        for (i, &want) in ptrs[half..].iter().enumerate() {
            let (got, inherit) = p.runqget();
            assert_eq!(got, want, "local queue mismatch at position {i}");
            assert!(!inherit);
        }
        assert!(p.runqget().0.is_null());

        drop(unsafe { Box::from_raw(g_extra) });
    }

    #[test]
    fn runq_size_counts_runnext() {
        let p = P::new(0);
        let gq = GlobalRunQueue::new();
        let g1_ptr = Box::into_raw(make_g(1));
        assert_eq!(p.runq_size(), 0);
        unsafe { p.runqput(g1_ptr, true, &gq) };
        assert_eq!(p.runq_size(), 1);
        drop(unsafe { Box::from_raw(g1_ptr) });
    }

    #[test]
    fn runqsteal_basic() {
        let victim = P::new(0);
        let stealer = P::new(1);
        let gq = GlobalRunQueue::new();
        let (_goroutines, ptrs) = make_gs(10);
        for &ptr in &ptrs {
            unsafe { victim.runqput(ptr, false, &gq) };
        }

        // ceil(10 / 2) = 5 Gs are stolen: the last one is returned directly and
        // the other four land in the stealer's ring.
        let stolen = stealer.runqsteal(&victim, false);
        assert_eq!(stolen, ptrs[4], "stealer should return the last stolen G");
        assert_eq!(victim.runq_size(), 5);
        assert_eq!(stealer.runq_size(), 4);

        for (i, &want) in ptrs[..4].iter().enumerate() {
            let (got, inherit) = stealer.runqget();
            assert_eq!(got, want, "stealer mismatch at position {i}");
            assert!(!inherit);
        }
        assert!(stealer.runqget().0.is_null());

        for (i, &want) in ptrs[5..].iter().enumerate() {
            let (got, _) = victim.runqget();
            assert_eq!(got, want, "victim mismatch at position {i}");
        }
        assert!(victim.runqget().0.is_null());
    }

    #[test]
    fn runqsteal_from_empty_victim_returns_null() {
        let victim = P::new(0);
        let stealer = P::new(1);
        assert!(stealer.runqsteal(&victim, false).is_null());
        assert!(stealer.runqsteal(&victim, true).is_null());
        assert_eq!(stealer.runq_size(), 0);
    }

    #[test]
    fn runqsteal_takes_runnext_only_when_allowed() {
        let victim = P::new(0);
        let stealer = P::new(1);
        let gq = GlobalRunQueue::new();
        let g1_ptr = Box::into_raw(make_g(1));
        unsafe { victim.runqput(g1_ptr, true, &gq) };

        assert!(stealer.runqsteal(&victim, false).is_null());
        assert_eq!(victim.runq_size(), 1);

        // The victim is PIDLE, so no back-off sleep happens here.
        assert_eq!(stealer.runqsteal(&victim, true), g1_ptr);
        assert_eq!(victim.runq_size(), 0);
        assert_eq!(stealer.runq_size(), 0, "a single stolen G is returned, not queued");

        drop(unsafe { Box::from_raw(g1_ptr) });
    }
}
#[cfg(all(test, loom))]
mod loom_tests {
    use super::*;
    use crate::runtime::g::{Stack, G};
    use loom::sync::Arc;

    /// Wrapper that lets a raw `G` pointer cross a `loom::thread::spawn` boundary.
    ///
    /// Unwrap it with [`GPtr::into_inner`] inside closures, never with `.0`. Under
    /// edition-2021 disjoint closure capture, `wrapper.0` in a `move` closure
    /// captures only the `*mut G` field, which bypasses this wrapper's `Send` impl.
    struct GPtr(*mut G);

    // SAFETY: each wrapped G is used by exactly one thread at a time in these tests.
    unsafe impl Send for GPtr {}

    impl GPtr {
        /// Consumes the wrapper. Taking `self` forces closures to capture the whole `GPtr`.
        fn into_inner(self) -> *mut G {
            self.0
        }
    }

    /// Heap-allocates a `G` with a per-id stack range and leaks it as a raw pointer.
    fn make_g(id: u64) -> *mut G {
        let lo = (id as usize + 1) << 20;
        Box::into_raw(G::new(Stack { lo, hi: lo + 65536 }, id))
    }

    /// Frees every non-null pointer; each must come from `make_g` and be freed once.
    fn free_all(ptrs: impl IntoIterator<Item = *mut G>) {
        for p in ptrs.into_iter().filter(|p| !p.is_null()) {
            drop(unsafe { Box::from_raw(p) });
        }
    }

    #[test]
    fn concurrent_push_pop() {
        loom::model(|| {
            let gq = Arc::new(GlobalRunQueue::new());
            let gq2 = Arc::clone(&gq);
            let g1 = GPtr(make_g(1));
            let pusher = loom::thread::spawn(move || {
                let gp = g1.into_inner();
                unsafe { gq2.push_batch(gp, gp, 1) };
            });
            let got = unsafe { gq.pop() };
            pusher.join().unwrap();
            let got2 = unsafe { gq.pop() };
            let retrieved = [got, got2].iter().filter(|p| !p.is_null()).count();
            assert_eq!(retrieved, 1, "expected exactly one G across both pops");
            assert_eq!(gq.len(), 0);
            free_all([got, got2]);
        });
    }

    #[test]
    fn concurrent_two_pops() {
        loom::model(|| {
            let gq = Arc::new(GlobalRunQueue::new());
            let gq2 = Arc::clone(&gq);
            let gq3 = Arc::clone(&gq);
            let g1 = make_g(1);
            let g2 = make_g(2);
            unsafe {
                (*g1).schedlink = g2;
                gq.push_batch(g1, g2, 2);
            }
            // Wrap the results too: a bare `*mut G` is not `Send`.
            let t1 = loom::thread::spawn(move || GPtr(unsafe { gq2.pop() }));
            let t2 = loom::thread::spawn(move || GPtr(unsafe { gq3.pop() }));
            let p1 = t1.join().unwrap().into_inner();
            let p2 = t2.join().unwrap().into_inner();
            let leftover = unsafe { gq.pop() };

            assert!(
                !p1.is_null() && !p2.is_null(),
                "both pops must succeed on a two-element queue"
            );
            assert_ne!(p1, p2, "each G must be returned to at most one thread");
            assert!(leftover.is_null(), "queue must be empty after two pops");
            assert_eq!(gq.len(), 0);
            free_all([p1, p2, leftover]);
        });
    }
}