use crate::future::Future;
use crate::loom::cell::UnsafeCell;
use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, SpawnLocation, Task};
use crate::util::linked_list::{Link, LinkedList};
use crate::util::sharded_list;
use crate::loom::sync::atomic::{AtomicBool, Ordering};
use std::marker::PhantomData;
use std::num::NonZeroU64;
cfg_has_atomic_u64! {
    use std::sync::atomic::AtomicU64;

    /// Counter backing `get_next_id`; it begins at 1 so the very first id
    /// handed out is already non-zero.
    static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);

    /// Allocates a fresh, non-zero id for an owned-task collection.
    ///
    /// `fetch_add` yields the counter's previous value, which can only be zero
    /// after the counter has wrapped around; such a value is skipped and the
    /// next one is drawn instead.
    fn get_next_id() -> NonZeroU64 {
        loop {
            let previous = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
            if let Some(id) = NonZeroU64::new(previous) {
                break id;
            }
        }
    }
}
cfg_not_has_atomic_u64! {
    use std::sync::atomic::AtomicU32;

    /// Fallback counter for targets without 64-bit atomics; it begins at 1 so
    /// the very first id handed out is already non-zero.
    static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);

    /// Allocates a fresh, non-zero id from the 32-bit counter, widened to `u64`.
    ///
    /// `fetch_add` wraps on overflow, so after `u32::MAX` allocations the
    /// counter passes through zero (which is skipped) and previously issued ids
    /// start being handed out again.
    fn get_next_id() -> NonZeroU64 {
        loop {
            let previous = u64::from(NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed));
            if let Some(id) = NonZeroU64::new(previous) {
                break id;
            }
        }
    }
}
/// Collection of tasks bound through `bind` / `bind_local`, built for use from
/// several threads: tasks live in a sharded list and the closed flag is atomic.
///
/// Every bound task has its header's owner id set to `id`; `remove` and
/// `assert_owner` compare against that value.
pub(crate) struct OwnedTasks<S: 'static> {
    // Tasks currently bound to this collection, spread across shards that are
    // locked individually (`lock_shard`, `pop_back(i)`).
    list: List<S>,
    // Id of this collection, drawn from the process-wide `get_next_id` counter.
    pub(crate) id: NonZeroU64,
    // Set by `close_and_shutdown_all`; once `bind_inner` observes it, newly
    // bound tasks are shut down instead of being added to `list`.
    closed: AtomicBool,
}
/// Sharded list of `Task<S>`, linked through the task's `Link` implementation.
type List<S> = sharded_list::ShardedList<Task<S>, <Task<S> as Link>::Target>;
/// Single-threaded counterpart of `OwnedTasks`.
///
/// The task list and closed flag sit in an `UnsafeCell` with no lock. The
/// `PhantomData<*const ()>` marker makes the type neither `Send` nor `Sync`,
/// which is what confines access to `inner` to a single thread.
pub(crate) struct LocalOwnedTasks<S: 'static> {
    // Mutable state; accessed through `with_inner`.
    inner: UnsafeCell<OwnedTasksInner<S>>,
    // Id of this collection, stamped onto each bound task's header.
    pub(crate) id: NonZeroU64,
    // Raw pointers are `!Send` and `!Sync`, so this opts the whole struct out.
    _not_send_or_sync: PhantomData<*const ()>,
}
/// The mutable state of a `LocalOwnedTasks`.
struct OwnedTasksInner<S: 'static> {
    // Tasks bound to the collection; pushed at the front, drained from the back.
    list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
    // Once true, `bind` shuts new tasks down instead of pushing them.
    closed: bool,
}
impl<S: 'static> OwnedTasks<S> {
    /// Creates an empty, open collection with a fresh id and a shard count
    /// derived from `num_cores` (see `gen_shared_list_size`).
    pub(crate) fn new(num_cores: usize) -> Self {
        let shard_size = Self::gen_shared_list_size(num_cores);
        Self {
            list: List::new(shard_size),
            closed: AtomicBool::new(false),
            id: get_next_id(),
        }
    }

    /// Creates a task for the `Send` future `task` and binds it to this
    /// collection.
    ///
    /// Returns the task's `JoinHandle` together with the `Notified` handle the
    /// caller should schedule. If the collection was already closed, the task
    /// has been shut down and `None` is returned in place of the handle.
    pub(crate) fn bind<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
        spawned_at: SpawnLocation,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + Send + 'static,
        T::Output: Send + 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);
        // SAFETY: NOTE(review): `bind_inner` documents no contract of its own;
        // presumably it is `unsafe` because `bind_local` feeds it `!Send`
        // tasks, which the `Send` bounds above rule out here. Confirm.
        let notified = unsafe { self.bind_inner(task, notified) };
        (join, notified)
    }

    /// Like `bind`, but the future and its output need not be `Send`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not spelled out in this file. Unlike
    /// `bind`, nothing here prevents a `!Send` task from entering this
    /// collection, so the caller presumably must guarantee the task is only
    /// ever polled and dropped on the thread that spawned it. Confirm with the
    /// call sites.
    pub(crate) unsafe fn bind_local<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
        spawned_at: SpawnLocation,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + 'static,
        T::Output: 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);
        // SAFETY: forwarded from this function's own (caller-upheld) contract.
        let notified = unsafe { self.bind_inner(task, notified) };
        (join, notified)
    }

    /// Shared tail of `bind` and `bind_local`.
    ///
    /// Stamps this collection's id onto the freshly created `task`, then either
    /// pushes it into its shard and returns `Some(notified)`, or, if the
    /// collection is closed, shuts it down and returns `None`.
    unsafe fn bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>>
    where
        S: Schedule,
    {
        unsafe {
            // SAFETY: both callers created `task` via `new_task` immediately
            // before calling this, and it has not been handed to anyone yet,
            // so nothing else can be touching its owner-id field.
            task.header().set_owner_id(self.id);
        }
        let shard = self.list.lock_shard(&task);
        // `closed` is read while the shard lock is held, and `Acquire` pairs
        // with the `Release` store in `close_and_shutdown_all`. Every task is
        // therefore either rejected here or still in its shard when that
        // shard is drained. (NOTE(review): this assumes `pop_back` serialises
        // on the same shard lock; confirm in `sharded_list`.)
        if self.closed.load(Ordering::Acquire) {
            // The shard guard is released before shutting the task down.
            // NOTE(review): presumably `shutdown` can re-enter this list
            // (e.g. through `remove`) and would otherwise contend on the lock.
            drop(shard);
            task.shutdown();
            return None;
        }
        shard.push(task);
        Some(notified)
    }

    /// Converts `task` into a `LocalNotified`, granting permission to poll it
    /// on the current thread.
    ///
    /// The ownership check is only a `debug_assert_eq!` here; release builds
    /// skip it. NOTE(review): this relies on tasks in this collection being
    /// pollable from any thread, which `bind` ensures via its `Send` bounds and
    /// which `bind_local` leaves to its caller.
    #[inline]
    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
        debug_assert_eq!(task.header().get_owner_id(), Some(self.id));
        LocalNotified {
            task: task.0,
            _not_send: PhantomData,
        }
    }

    /// Closes the collection and shuts down every task it currently holds.
    ///
    /// After `closed` is set, `bind_inner` refuses new tasks. Shards are then
    /// drained one by one, visiting `get_shard_size()` consecutive indices
    /// beginning at `start`, so different callers can start on different
    /// shards. NOTE(review): indices at or beyond the shard count rely on
    /// `ShardedList::pop_back` mapping them back into range; confirm.
    pub(crate) fn close_and_shutdown_all(&self, start: usize)
    where
        S: Schedule,
    {
        // Pairs with the `Acquire` load in `bind_inner`.
        self.closed.store(true, Ordering::Release);
        for i in start..self.get_shard_size() + start {
            // Each popped task is shut down before the next pop.
            while let Some(task) = self.list.pop_back(i) {
                task.shutdown();
            }
        }
    }

    /// Number of shards in the underlying list.
    #[inline]
    pub(crate) fn get_shard_size(&self) -> usize {
        self.list.shard_size()
    }

    /// Number of tasks currently held, as reported by the list's `len()`.
    pub(crate) fn num_alive_tasks(&self) -> usize {
        self.list.len()
    }

    cfg_unstable_metrics! {
        cfg_64bit_metrics! {
            /// The list's `added()` counter, exposed as the spawned-task metric.
            pub(crate) fn spawned_tasks_count(&self) -> u64 {
                self.list.added()
            }
        }
    }

    /// Unlinks `task` from this collection, returning what the list yields.
    ///
    /// Returns `None` straight away if the task has no owner id.
    ///
    /// # Panics
    ///
    /// Panics if the task's owner id belongs to a different collection.
    pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
        let task_id = task.header().get_owner_id()?;
        assert_eq!(task_id, self.id);
        // SAFETY: the owner id matches this collection, and the id is only
        // stamped by `bind_inner` before pushing into `self.list`, so the task
        // is not linked into some other list.
        unsafe { self.list.remove(task.header_ptr()) }
    }

    /// Returns `true` when the list holds no tasks.
    pub(crate) fn is_empty(&self) -> bool {
        self.list.is_empty()
    }

    /// Shard count: four shards per core (cores rounded up to a power of two),
    /// capped at 2^16.
    fn gen_shared_list_size(num_cores: usize) -> usize {
        const MAX_SHARED_LIST_SIZE: usize = 1 << 16;
        usize::min(MAX_SHARED_LIST_SIZE, num_cores.next_power_of_two() * 4)
    }
}
cfg_taskdump! {
    impl<S: 'static> OwnedTasks<S> {
        /// Calls `visit` on every task currently held in this collection by
        /// handing it straight to the underlying sharded list.
        pub(crate) fn for_each<F: FnMut(&Task<S>)>(&self, visit: F) {
            self.list.for_each(visit);
        }
    }
}
impl<S: 'static> LocalOwnedTasks<S> {
    /// Creates an empty, open collection with a fresh id.
    pub(crate) fn new() -> Self {
        Self {
            inner: UnsafeCell::new(OwnedTasksInner {
                list: LinkedList::new(),
                closed: false,
            }),
            id: get_next_id(),
            _not_send_or_sync: PhantomData,
        }
    }

    /// Creates a task for `task` and binds it to this collection.
    ///
    /// Unlike `OwnedTasks::bind`, neither the future nor its output needs to
    /// be `Send`. Returns the task's `JoinHandle` and, unless the collection is
    /// closed, the `Notified` handle to schedule. If the collection is closed,
    /// the task is shut down at once and `None` is returned in its place.
    pub(crate) fn bind<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
        spawned_at: SpawnLocation,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + 'static,
        T::Output: 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);
        unsafe {
            // SAFETY: the task was created on the line above and has not been
            // returned to anyone yet, so nothing else touches its owner id.
            task.header().set_owner_id(self.id);
        }
        if self.is_closed() {
            // The `Notified` handle is discarded before the task is shut down.
            drop(notified);
            task.shutdown();
            (join, None)
        } else {
            // Pushed at the front; `close_and_shutdown_all` pops from the back.
            self.with_inner(|inner| {
                inner.list.push_front(task);
            });
            (join, Some(notified))
        }
    }

    /// Marks the collection closed, then shuts down every task it holds.
    ///
    /// Each task is popped inside its own `with_inner` call and shut down
    /// outside it, so `shutdown` never runs while the inner state is borrowed.
    /// NOTE(review): presumably this matters because `shutdown` can call back
    /// into `remove`, which would re-enter `with_inner`; confirm.
    pub(crate) fn close_and_shutdown_all(&self)
    where
        S: Schedule,
    {
        self.with_inner(|inner| inner.closed = true);
        while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) {
            task.shutdown();
        }
    }

    /// Unlinks `task` from this collection, returning what the list yields.
    ///
    /// Returns `None` straight away if the task has no owner id.
    ///
    /// # Panics
    ///
    /// Panics if the task's owner id belongs to a different collection.
    pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
        let task_id = task.header().get_owner_id()?;
        assert_eq!(task_id, self.id);
        self.with_inner(|inner|
            // SAFETY: the owner id matches this collection, and `bind` only
            // stamps it just before pushing into this list, so the task is not
            // linked into some other list.
            unsafe { inner.list.remove(task.header_ptr()) })
    }

    /// Converts `task` into a `LocalNotified`, granting permission to poll it
    /// on the current thread.
    ///
    /// # Panics
    ///
    /// Panics in every build profile if `task` is not owned by this collection.
    /// This is a hard `assert_eq!`, unlike the debug-only check in
    /// `OwnedTasks::assert_owner`, because tasks bound here may be `!Send`.
    #[inline]
    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
        assert_eq!(task.header().get_owner_id(), Some(self.id));
        LocalNotified {
            task: task.0,
            _not_send: PhantomData,
        }
    }

    /// Runs `f` with mutable access to the inner state and returns its result.
    #[inline]
    fn with_inner<F, T>(&self, f: F) -> T
    where
        F: FnOnce(&mut OwnedTasksInner<S>) -> T,
    {
        // SAFETY: `LocalOwnedTasks` is `!Sync` (via `_not_send_or_sync`), so no
        // other thread can be in here at the same time. The `&mut` must also
        // be unique on this thread: no caller in this file invokes
        // `with_inner` again from inside `f`.
        self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) })
    }

    /// Returns whether `close_and_shutdown_all` has marked this collection closed.
    pub(crate) fn is_closed(&self) -> bool {
        self.with_inner(|inner| inner.closed)
    }

    /// Returns `true` when the list holds no tasks.
    pub(crate) fn is_empty(&self) -> bool {
        self.with_inner(|inner| inner.list.is_empty())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Ids drawn back-to-back from `get_next_id` must be strictly increasing.
    #[test]
    fn test_id_not_broken() {
        let ids: Vec<_> = (0..=1000).map(|_| get_next_id()).collect();
        for pair in ids.windows(2) {
            assert!(pair[0] < pair[1]);
        }
    }
}