use crate::error::{TarantoolError, TarantoolErrorCode};
use crate::ffi::has_fiber_id;
use crate::ffi::tarantool::fiber_sleep;
use crate::ffi::{lua, tarantool as ffi};
use crate::static_assert;
use crate::time::Instant;
use crate::tlua::{self as tlua, AsLua};
use crate::unwrap_ok_or;
use crate::{c_ptr, set_error};
use ::va_list::VaList;
pub use channel::Channel;
pub use channel::RecvError;
pub use channel::RecvTimeout;
pub use channel::SendError;
pub use channel::SendTimeout;
pub use channel::TryRecvError;
pub use channel::TrySendError;
pub use csw::check_yield;
pub use csw::YieldResult;
pub use mutex::Mutex;
pub use r#async::block_on;
use std::cell::UnsafeCell;
use std::ffi::CString;
use std::future::Future;
use std::marker::PhantomData;
use std::mem::{align_of, size_of};
use std::os::raw::c_void;
use std::ptr::NonNull;
use std::rc::Rc;
use std::time::Duration;
pub mod r#async;
pub mod safety;
pub use safety::*;
pub mod channel;
mod csw;
pub mod mutex;
/// Numeric identifier of a fiber.
pub type FiberId = u64;
/// Id value `0`; used in this file as the initial `fiber_id` of a fresh [`Context`].
pub const FIBER_ID_INVALID: FiberId = 0;
/// NOTE(review): presumably the id of the scheduler fiber — confirm against tarantool.
pub const FIBER_ID_SCHED: FiberId = 1;
/// NOTE(review): presumably the largest id reserved for tarantool's own fibers — confirm.
pub const FIBER_ID_MAX_RESERVED: FiberId = 100;
/// Deprecated low-level fiber handle.
///
/// Stores the raw fiber pointer returned by `fiber_new`/`fiber_new_ex` and a
/// type-erased pointer to the user callback which is passed to `fiber_start`.
#[deprecated = "use fiber::start, fiber::defer or fiber::Builder"]
pub struct Fiber<'a, T: 'a> {
/// Raw fiber pointer obtained from the FFI constructor.
inner: *mut ffi::Fiber,
/// Type-erased pointer to the callback given to `new`/`new_with_attr`.
callback: *mut c_void,
/// Ties the handle to the lifetime and argument type of the callback.
phantom: PhantomData<&'a T>,
}
#[allow(deprecated)]
impl<T> ::std::fmt::Debug for Fiber<'_, T> {
    /// Prints the type name only; no fields are listed.
    fn fmt(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        let mut repr = formatter.debug_struct("Fiber");
        repr.finish_non_exhaustive()
    }
}
#[allow(deprecated)]
impl<T> Fiber<'_, T> {
    /// Creates a fiber called `name` which will run `callback` once started.
    ///
    /// Only a raw pointer to `callback` is kept, so the callback must stay
    /// alive for as long as the fiber may run it.
    ///
    /// # Panics
    /// Panics if `name` contains a nul byte.
    pub fn new<F>(name: &str, callback: &mut F) -> Self
    where
        F: FnMut(Box<T>) -> i32,
    {
        let (callback, trampoline) = unsafe { unpack_callback(callback) };
        let c_name = CString::new(name).expect("fiber name should not contain nul bytes");
        let inner = unsafe { ffi::fiber_new(c_name.as_ptr(), trampoline) };
        Self {
            inner,
            callback,
            phantom: PhantomData,
        }
    }

    /// Same as [`Fiber::new`], but the fiber is created with the given attributes.
    ///
    /// # Panics
    /// Panics if `name` contains a nul byte.
    pub fn new_with_attr<F>(name: &str, attr: &FiberAttr, callback: &mut F) -> Self
    where
        F: FnMut(Box<T>) -> i32,
    {
        let (callback, trampoline) = unsafe { unpack_callback(callback) };
        let c_name = CString::new(name).expect("fiber name should not contain nul bytes");
        let inner = unsafe { ffi::fiber_new_ex(c_name.as_ptr(), attr.inner, trampoline) };
        Self {
            inner,
            callback,
            phantom: PhantomData,
        }
    }

    /// Starts the fiber, handing it `arg` moved onto the heap.
    ///
    /// The trampoline re-boxes the argument before invoking the callback.
    pub fn start(&mut self, arg: T) {
        let arg_ptr = Box::into_raw(Box::new(arg));
        unsafe {
            ffi::fiber_start(self.inner, self.callback, arg_ptr);
        }
    }

    /// Calls `fiber_wakeup` on the underlying fiber.
    pub fn wakeup(&self) {
        let fiber = self.inner;
        unsafe { ffi::fiber_wakeup(fiber) }
    }

    /// Calls `fiber_join` on the underlying fiber and returns its result code.
    pub fn join(&self) -> i32 {
        let fiber = self.inner;
        unsafe { ffi::fiber_join(fiber) }
    }

    /// Marks the fiber as joinable or not.
    pub fn set_joinable(&mut self, is_joinable: bool) {
        let fiber = self.inner;
        unsafe { ffi::fiber_set_joinable(fiber, is_joinable) }
    }

    /// Calls `fiber_cancel` on the underlying fiber.
    pub fn cancel(&mut self) {
        let fiber = self.inner;
        unsafe { ffi::fiber_cancel(fiber) }
    }

    /// Returns the fiber id.
    ///
    /// # Panics
    /// Panics if the `fiber_id` api is not available, see [`Fiber::id_checked`].
    #[inline(always)]
    #[track_caller]
    pub fn id(&self) -> FiberId {
        self.id_checked().expect("fiber_id api is not supported")
    }

    /// Returns the fiber id, or `None` if `has_fiber_id` reports that the
    /// api is not available.
    pub fn id_checked(&self) -> Option<FiberId> {
        let supported = unsafe { has_fiber_id() };
        supported.then(|| unsafe { ffi::fiber_id(self.inner) })
    }
}
/// Fiber configuration gathered before spawning: an optional name, optional
/// attributes and the function to run.
pub struct Builder<F> {
/// Fiber name; `into_fiber_args` substitutes `"<rust>"` when it is `None`.
name: Option<String>,
/// Fiber attributes; filled in by `stack_size`.
attr: Option<FiberAttr>,
/// The fiber function (the `NoFunc` placeholder until one is provided).
f: F,
}
impl<T> ::std::fmt::Debug for Builder<T> {
    /// Prints the type name only; no fields are listed.
    fn fmt(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        let mut repr = formatter.debug_struct("Builder");
        repr.finish_non_exhaustive()
    }
}
impl Builder<NoFunc> {
    /// Creates a builder with no name, no attributes and no function.
    #[inline(always)]
    pub fn new() -> Self {
        Self {
            f: NoFunc,
            name: None,
            attr: None,
        }
    }

    /// Sets the function the fiber will execute, keeping name and attributes.
    #[inline(always)]
    pub fn func<'f, F, T>(self, f: F) -> Builder<F>
    where
        F: FnOnce() -> T,
        F: 'f,
    {
        let Builder { name, attr, .. } = self;
        Builder { name, attr, f }
    }

    /// Sets the fiber function to one that drives the future `f` to
    /// completion with [`block_on`].
    #[inline(always)]
    pub fn func_async<'f, F, T>(self, f: F) -> Builder<impl FnOnce() -> T + 'f>
    where
        F: Future<Output = T> + 'f,
        T: 'f,
    {
        let run_to_completion = move || block_on(f);
        self.func(run_to_completion)
    }

    /// Deprecated alias of [`Builder::func`] for functions returning `()`.
    #[deprecated = "Use `Builder::func` instead"]
    #[inline(always)]
    pub fn proc<'f, F>(self, f: F) -> Builder<F>
    where
        F: FnOnce(),
        F: 'f,
    {
        Self::func(self, f)
    }

    /// Deprecated alias of [`Builder::func_async`] for futures resolving to `()`.
    #[deprecated = "Use `Builder::func_async` instead"]
    #[inline(always)]
    pub fn proc_async<'f, F>(self, f: F) -> Builder<impl FnOnce() + 'f>
    where
        F: Future<Output = ()> + 'f,
    {
        Self::func_async(self, f)
    }
}
impl Default for Builder<NoFunc> {
#[inline(always)]
fn default() -> Self {
Self::new()
}
}
impl<F> Builder<F> {
    /// Sets the name of the fiber to be spawned.
    #[inline(always)]
    pub fn name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Sets the stack size of the fiber to be spawned.
    ///
    /// A fresh [`FiberAttr`] replaces any previously configured attributes.
    ///
    /// # Errors
    /// Returns the error reported by [`FiberAttr::set_stack_size`].
    #[inline(always)]
    pub fn stack_size(self, stack_size: usize) -> crate::Result<Self> {
        let mut attr = FiberAttr::new();
        attr.set_stack_size(stack_size)?;
        Ok(Self {
            attr: Some(attr),
            ..self
        })
    }
}
impl<'f, F, T> Builder<F>
where
    F: FnOnce() -> T + 'f,
    T: 'f,
{
    /// Spawns a joinable fiber via [`Fyber::spawn_and_yield`] and returns
    /// its join handle.
    #[inline(always)]
    pub fn start(self) -> crate::Result<JoinHandle<'f, T>> {
        let (name, f, attr) = self.into_fiber_args();
        match Fyber::spawn_and_yield(name, f, true, attr.as_ref())? {
            Ok(jh) => Ok(jh),
            Err(_) => {
                unreachable!("spawn_and_yield returns the join handle when is_joinable = true")
            }
        }
    }

    /// Spawns a joinable deferred fiber.
    ///
    /// Uses [`Fyber::spawn_deferred`] when the `fiber_set_ctx` api is
    /// available and falls back to [`Fyber::spawn_lua`] otherwise.
    #[inline(always)]
    pub fn defer(self) -> crate::Result<JoinHandle<'f, T>> {
        let (name, f, attr) = self.into_fiber_args();
        let has_ctx_api = unsafe { crate::ffi::has_fiber_set_ctx() };
        if !has_ctx_api {
            return Fyber::spawn_lua(name, f, attr.as_ref());
        }
        match Fyber::spawn_deferred(name, f, true, attr.as_ref())? {
            Ok(jh) => Ok(jh),
            Err(_) => {
                unreachable!("spawn_deferred returns the join handle when is_joinable = true")
            }
        }
    }

    /// Spawns a joinable deferred fiber, always through
    /// [`Fyber::spawn_deferred`] (no availability check is made here).
    #[inline(always)]
    pub fn defer_ffi(self) -> crate::Result<JoinHandle<'f, T>> {
        let (name, f, attr) = self.into_fiber_args();
        match Fyber::spawn_deferred(name, f, true, attr.as_ref())? {
            Ok(jh) => Ok(jh),
            Err(_) => {
                unreachable!("spawn_deferred returns the join handle when is_joinable = true")
            }
        }
    }

    /// Spawns a joinable deferred fiber, always through [`Fyber::spawn_lua`].
    #[inline(always)]
    pub fn defer_lua(self) -> crate::Result<JoinHandle<'f, T>> {
        let (name, f, attr) = self.into_fiber_args();
        let attr = attr.as_ref();
        Fyber::spawn_lua(name, f, attr)
    }

    /// Splits the builder into `(name, function, attributes)`, using
    /// `"<rust>"` as the name when none was set.
    fn into_fiber_args(self) -> (String, F, Option<FiberAttr>) {
        let name = self.name.unwrap_or_else(|| "<rust>".into());
        (name, self.f, self.attr)
    }
}
impl<F, T> Builder<F>
where
    F: FnOnce() -> T + 'static,
    T: 'static,
{
    /// Spawns a non-joinable fiber via [`Fyber::spawn_and_yield`] and
    /// returns its id.
    #[inline(always)]
    pub fn start_non_joinable(self) -> crate::Result<FiberId> {
        let (name, f, attr) = self.into_fiber_args();
        match Fyber::spawn_and_yield(name, f, false, attr.as_ref())? {
            Err(id) => Ok(id),
            Ok(_) => {
                unreachable!("spawn_and_yield returns the fiber id when is_joinable = false")
            }
        }
    }

    /// Spawns a non-joinable deferred fiber via [`Fyber::spawn_deferred`].
    ///
    /// # Errors
    /// Returns an `Unsupported` error when the `fiber_set_ctx` api is not
    /// available.
    #[inline(always)]
    pub fn defer_non_joinable(self) -> crate::Result<Option<FiberId>> {
        let (name, f, attr) = self.into_fiber_args();
        let has_ctx_api = unsafe { crate::ffi::has_fiber_set_ctx() };
        if !has_ctx_api {
            #[rustfmt::skip]
            set_error!(TarantoolErrorCode::Unsupported, "deferred non-joinable fibers are not supported in current tarantool version (fiber_set_ctx API is required)");
            return Err(TarantoolError::last().into());
        }
        match Fyber::spawn_deferred(name, f, false, attr.as_ref())? {
            Err(id) => Ok(id),
            Ok(_) => {
                unreachable!("spawn_deferred returns the fiber id when is_joinable = false")
            }
        }
    }

    // NOTE(review): placeholder constant; by its name it looks like an anchor
    // for compile-fail checks on non-'static fiber functions — confirm.
    const _TEST_NON_STATIC_FIBER_FUNCS_DONT_COMPILE: () = ();
}
/// Data-less type whose associated functions spawn fibers running a closure
/// of type `F` returning `T`.
pub struct Fyber<F, T> {
/// Only carries the type parameters.
_marker: PhantomData<(F, T)>,
}
impl<F, T> ::std::fmt::Debug for Fyber<F, T> {
    /// Prints the type name only; no fields are listed.
    fn fmt(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        let mut repr = formatter.debug_struct("Fyber");
        repr.finish_non_exhaustive()
    }
}
impl<'f, F, T> Fyber<F, T>
where
F: FnOnce() -> T + 'f,
T: 'f,
{
/// Creates a fiber named `name` running `f` and transfers control to it
/// immediately via `fiber_start`.
///
/// Returns `Ok(Ok(join_handle))` when `is_joinable` is true, otherwise
/// `Ok(Err(fiber_id))` with the id read from the shared context after
/// `fiber_start` returns.
///
/// # Errors
/// - `Unsupported` if `is_joinable` is false but `T` needs returning;
/// - `IllegalParams` if `name` contains a nul byte;
/// - the last tarantool error if the fiber constructor returns null.
pub fn spawn_and_yield(
name: String,
f: F,
is_joinable: bool,
attr: Option<&FiberAttr>,
) -> crate::Result<Result<JoinHandle<'f, T>, FiberId>> {
// Without a join handle there is nobody to receive the returned value.
if !is_joinable && needs_returning::<T>() {
#[rustfmt::skip]
set_error!(TarantoolErrorCode::Unsupported, "non-joinable fibers which return a value are not supported");
return Err(TarantoolError::last().into());
}
let cname = unwrap_ok_or!(CString::new(name),
Err(e) => {
#[rustfmt::skip]
set_error!(TarantoolErrorCode::IllegalParams, "fiber name may not contain nul-bytes: {e}");
return Err(TarantoolError::last().into());
}
);
// `false`: the trampoline takes the context from its variadic argument.
let inner_raw = unsafe {
if let Some(attr) = attr {
ffi::fiber_new_ex(
cname.as_ptr(),
attr.inner,
Some(Self::trampoline_for_ffi::<false>),
)
} else {
ffi::fiber_new(cname.as_ptr(), Some(Self::trampoline_for_ffi::<false>))
}
};
let Some(inner) = NonNull::new(inner_raw) else {
return Err(TarantoolError::last().into());
};
unsafe {
ffi::fiber_set_joinable(inner.as_ptr(), is_joinable);
// The result cell is only allocated when `T` has to be passed back.
let result_cell = needs_returning::<T>().then(FiberResultCell::default);
let mut ctx = Context::default();
if let Some(result_cell) = &result_cell {
ctx.fiber_result_ptr = result_cell.get() as _;
}
// The closure is boxed and type-erased; the trampoline casts it back to `F`.
ctx.fiber_rust_closure = Box::into_raw(Box::new(f)) as _;
let ctx_rc: Rc<UnsafeCell<Context>> = Rc::new(UnsafeCell::new(ctx));
// One reference goes to the fiber (re-materialized by `Rc::from_raw` in the
// trampoline), the other stays here so `fiber_id` can be read below.
ffi::fiber_start(inner.as_ptr(), Rc::into_raw(ctx_rc.clone()));
if is_joinable {
Ok(Ok(JoinHandle::ffi(inner, result_cell)))
} else {
// The trampoline stores the fiber id in the context before running `f`.
let ctx = &*ctx_rc.get();
Ok(Err(ctx.fiber_id))
}
}
}
/// Creates a fiber named `name` running `f` and schedules it with
/// `fiber_wakeup` instead of starting it; the context is attached to the
/// fiber with `fiber_set_ctx`.
///
/// Returns `Ok(Ok(join_handle))` when `is_joinable` is true, otherwise
/// `Ok(Err(id))` where `id` is `Some` only if the `fiber_id` api is available.
///
/// # Errors
/// - `Unsupported` if `is_joinable` is false but `T` needs returning;
/// - `IllegalParams` if `name` contains a nul byte;
/// - the last tarantool error if the fiber constructor returns null.
pub fn spawn_deferred(
name: String,
f: F,
is_joinable: bool,
attr: Option<&FiberAttr>,
) -> crate::Result<Result<JoinHandle<'f, T>, Option<FiberId>>> {
// Without a join handle there is nobody to receive the returned value.
if !is_joinable && needs_returning::<T>() {
#[rustfmt::skip]
set_error!(TarantoolErrorCode::Unsupported, "non-joinable fibers which return a value are not supported");
return Err(TarantoolError::last().into());
}
let cname = unwrap_ok_or!(CString::new(name),
Err(e) => {
#[rustfmt::skip]
set_error!(TarantoolErrorCode::IllegalParams, "fiber name may not contain nul-bytes: {e}");
return Err(TarantoolError::last().into());
}
);
// `true`: the trampoline fetches the context with `fiber_get_ctx`.
let inner_raw = unsafe {
if let Some(attr) = attr {
ffi::fiber_new_ex(
cname.as_ptr(),
attr.inner,
Some(Self::trampoline_for_ffi::<true>),
)
} else {
ffi::fiber_new(cname.as_ptr(), Some(Self::trampoline_for_ffi::<true>))
}
};
let Some(inner) = NonNull::new(inner_raw) else {
return Err(TarantoolError::last().into());
};
unsafe {
ffi::fiber_set_joinable(inner.as_ptr(), is_joinable);
// The result cell is only allocated when `T` has to be passed back.
let result_cell = needs_returning::<T>().then(FiberResultCell::default);
let mut ctx = Context::default();
if let Some(result_cell) = &result_cell {
ctx.fiber_result_ptr = result_cell.get() as _;
}
// The closure is boxed and type-erased; the trampoline casts it back to `F`.
ctx.fiber_rust_closure = Box::into_raw(Box::new(f)) as _;
let ctx_rc: Rc<UnsafeCell<Context>> = Rc::new(UnsafeCell::new(ctx));
// The only reference is handed to the fiber; the trampoline reclaims it.
ffi::fiber_set_ctx(inner.as_ptr(), Rc::into_raw(ctx_rc) as _);
ffi::fiber_wakeup(inner.as_ptr());
if is_joinable {
Ok(Ok(JoinHandle::ffi(inner, result_cell)))
} else {
if has_fiber_id() {
Ok(Err(Some(ffi::fiber_id(inner.as_ptr()))))
} else {
Ok(Err(None))
}
}
}
}
/// Fiber entry point used by `spawn_and_yield` and `spawn_deferred`.
///
/// `VIA_CONTEXT` selects where the context pointer comes from: `true` reads it
/// with `fiber_get_ctx`, `false` takes it from the first variadic argument.
/// The closure stored in the context is then run and, if `T` needs returning,
/// its result is written through `fiber_result_ptr`. Always returns 0.
unsafe extern "C" fn trampoline_for_ffi<const VIA_CONTEXT: bool>(mut args: VaList) -> i32 {
ffi::fiber_set_cancellable(true);
let ctx;
if VIA_CONTEXT {
let fiber_self = ffi::fiber_self();
ctx = ffi::fiber_get_ctx(fiber_self).cast::<Context>();
} else {
ctx = args.get::<*const Context>() as _;
// Also attach the context to the fiber when the api allows it.
if crate::ffi::has_fiber_set_ctx() {
let fiber_self = ffi::fiber_self();
ffi::fiber_set_ctx(fiber_self, ctx as _);
}
}
debug_assert!(context_is_valid(ctx));
// Takes over the reference released with `Rc::into_raw` by the spawner.
let ctx_rc: Rc<UnsafeCell<Context>> = Rc::from_raw(ctx.cast());
let ctx = &mut *ctx_rc.get();
ctx.fiber_id = id();
// Null the pointer out so the closure can not be reclaimed twice.
let f = std::mem::replace(&mut ctx.fiber_rust_closure, std::ptr::null_mut());
let f = Box::from_raw(f.cast::<F>());
let t = (f)();
if needs_returning::<T>() {
assert!(!ctx.fiber_result_ptr.is_null());
// `write` does not drop the previous contents of the cell.
std::ptr::write(ctx.fiber_result_ptr.cast(), Some(t));
} else {
debug_assert!(ctx.fiber_result_ptr.is_null());
}
0
}
/// Creates a joinable fiber through the lua `fiber` module (`fiber.new`),
/// wrapping `f` into a lua C closure, and returns a lua-based join handle.
///
/// The `_attr` argument is ignored.
///
/// # Errors
/// - `IllegalParams` if `name` contains a nul byte;
/// - an error if `require('fiber')` or `fiber.new` fails.
///
/// # Panics
/// Panics if `set_joinable`, `name` or `id` fail on the new fiber object.
pub fn spawn_lua(
name: String,
f: F,
_attr: Option<&FiberAttr>,
) -> crate::Result<JoinHandle<'f, T>> {
if let Some(pos) = name.find('\0') {
#[rustfmt::skip]
set_error!(TarantoolErrorCode::IllegalParams, "fiber name may not contain nul-bytes: nul byte found in provided data at position: {pos}");
return Err(TarantoolError::last().into());
}
unsafe {
let l = ffi::luaT_state();
// stack: require('fiber')
lua::lua_getglobal(l, c_ptr!("require"));
lua::lua_pushstring(l, c_ptr!("fiber"));
impl_details::guarded_pcall(l, 1, 1)?;
// stack: fiber module, fiber.new(c_closure)
lua::lua_getfield(l, -1, c_ptr!("new"));
impl_details::push_userdata(l, f);
lua::lua_pushcclosure(l, Self::trampoline_for_lua, 1);
impl_details::guarded_pcall(l, 1, 1).inspect_err(|_| {
// Pop the fiber module which is still on the stack.
lua::lua_pop(l, 1);
})?;
// f:set_joinable(true), calling with a duplicate of the fiber object
lua::lua_getfield(l, -1, c_ptr!("set_joinable"));
lua::lua_pushvalue(l, -2); lua::lua_pushboolean(l, true as _);
impl_details::guarded_pcall(l, 2, 0) .map_err(|e| panic!("{}", e))
.unwrap();
// f:name(name), calling with a duplicate of the fiber object
lua::lua_getfield(l, -1, c_ptr!("name"));
lua::lua_pushvalue(l, -2); lua::lua_pushlstring(l, name.as_ptr() as _, name.len());
impl_details::guarded_pcall(l, 2, 0) .map_err(|e| panic!("{}", e))
.unwrap();
// f:id(); `lua_insert` moves the function below the fiber object
lua::lua_getfield(l, -1, c_ptr!("id"));
lua::lua_insert(l, -2); impl_details::guarded_pcall(l, 1, 1) .expect("lua error");
let fiber_id = lua::lua_tointeger(l, -1);
// Pop the fiber id and the fiber module.
lua::lua_pop(l, 2);
Ok(JoinHandle::lua(fiber_id as _))
}
}
/// Lua C closure body used by `spawn_lua`.
///
/// Takes the rust closure out of the userdata stored in upvalue 1, runs it
/// and, if `T` needs returning, pushes the result as userdata and returns 1;
/// otherwise returns 0. Raises a lua error when the upvalue is missing or the
/// closure has already been taken.
unsafe extern "C-unwind" fn trampoline_for_lua(l: *mut lua::lua_State) -> i32 {
let ud_ptr = lua::lua_touserdata(l, lua::lua_upvalueindex(1));
// The userdata holds an `Option<F>` (see `push_userdata`), hence `take`.
let f = (ud_ptr as *mut Option<F>)
.as_mut()
.unwrap_or_else(||
tlua::error!(l, "failed to extract upvalue"))
.take()
.unwrap_or_else(||
tlua::error!(l, "rust FnOnce callback was called more than once"));
let res = f();
if needs_returning::<T>() {
impl_details::push_userdata(l, res);
1
} else {
0
}
}
}
mod impl_details {
use super::*;
use crate::tlua::{AsLua, LuaError, PushGuard, StaticLua};
/// Builds a [`LuaError::ExecutionError`] from the value on top of the lua stack.
///
/// The message is copied into an owned string (invalid utf8 sequences are
/// replaced), so the returned error does not point into lua-owned memory and
/// stays valid after the caller pops the value from the stack.
///
/// # Safety
/// `l` must be a valid lua state with a value convertible to a string on top
/// of its stack.
///
/// # Panics
/// Panics if `lua_tolstring` returns a null pointer.
pub(super) unsafe fn lua_error_from_top(l: *mut lua::lua_State) -> LuaError {
    let mut len = std::mem::MaybeUninit::uninit();
    let data = lua::lua_tolstring(l, -1, len.as_mut_ptr());
    assert!(!data.is_null());
    let msg_bytes = std::slice::from_raw_parts(data as *mut u8, len.assume_init());
    // `from_utf8_lossy` borrows its input when the bytes are valid utf8, and
    // the slice above has an unbounded lifetime. Copy the text right away:
    // `guarded_pcall` pops the lua string immediately after this call, after
    // which a borrowed message would dangle.
    let msg = String::from_utf8_lossy(msg_bytes).into_owned();
    tlua::LuaError::ExecutionError(msg.into())
}
/// Runs `lua_pcall` and converts a runtime error into a rust error.
///
/// On `LUA_ERRRUN` the error message is converted with `lua_error_from_top`
/// and popped from the stack.
///
/// # Panics
/// Panics on any status other than `LUA_OK` and `LUA_ERRRUN`.
pub(super) unsafe fn guarded_pcall(
    lptr: *mut lua::lua_State,
    nargs: i32,
    nresults: i32,
) -> crate::Result<()> {
    let status = lua::lua_pcall(lptr, nargs, nresults, 0);
    if status == lua::LUA_OK {
        return Ok(());
    }
    if status != lua::LUA_ERRRUN {
        panic!("lua_pcall: Unrecoverable failure code: {}", status);
    }
    let failure = lua_error_from_top(lptr).into();
    lua::lua_pop(lptr, 1);
    Err(failure)
}
/// Calls `require('fiber').join` with `f_id` on the global lua state.
///
/// On success returns a guard over the three values left on the stack: the
/// fiber module and the two results of `join`.
///
/// # Errors
/// Returns an error if `require('fiber')` or the `join` call fails.
pub(super) unsafe fn lua_fiber_join(f_id: FiberId) -> crate::Result<PushGuard<StaticLua>> {
let lua = crate::global_lua();
let l = lua.as_lua();
// Remembered to verify the stack growth below.
let top_svp = lua::lua_gettop(l);
lua::lua_getglobal(l, c_ptr!("require"));
lua::lua_pushstring(l, c_ptr!("fiber"));
impl_details::guarded_pcall(l, 1, 1)?;
lua::lua_getfield(l, -1, c_ptr!("join"));
lua::lua_pushinteger(l, f_id as _);
guarded_pcall(l, 1, 2).inspect_err(|_| {
// Pop the fiber module which is still on the stack.
lua::lua_pop(l, 1);
})?;
let top = lua::lua_gettop(l);
debug_assert_eq!(top - top_svp, 3);
let guard = PushGuard::new(lua, 3);
// The first result of `join` is expected to be truthy.
debug_assert_ne!(lua::lua_toboolean(l, -2), 0);
Ok(guard)
}
/// Pushes `value` onto the lua stack as a userdata holding `Option<T>`.
///
/// If `T` needs dropping, a metatable with a `__gc` handler which drops the
/// value (if it is still there) is attached to the userdata.
pub(super) unsafe fn push_userdata<T>(lua: tlua::LuaState, value: T) {
use tlua::ffi;
// Wrapped in `Option` so that the value can be taken out exactly once.
type UDBox<T> = Option<T>;
let ud_ptr = ffi::lua_newuserdata(lua, std::mem::size_of::<UDBox<T>>());
std::ptr::write(ud_ptr.cast::<UDBox<T>>(), Some(value));
if std::mem::needs_drop::<T>() {
// metatable = { __gc = wrap_gc::<T> }
ffi::lua_newtable(lua);
ffi::lua_pushstring(lua, c_ptr!("__gc"));
ffi::lua_pushcfunction(lua, wrap_gc::<T>);
ffi::lua_settable(lua, -3);
ffi::lua_setmetatable(lua, -2);
}
// `__gc` handler: drops the contained value unless it was taken already.
unsafe extern "C-unwind" fn wrap_gc<T>(lua: *mut ffi::lua_State) -> i32 {
let ud_ptr = ffi::lua_touserdata(lua, 1);
let ud = ud_ptr
.cast::<UDBox<T>>()
.as_mut()
.expect("__gc called with userdata pointing to NULL");
drop(ud.take());
0
}
}
}
/// Placeholder for the function slot of a [`Builder`] which has no function yet.
pub struct NoFunc;
/// Handle to a joinable fiber producing a value of type `T`.
///
/// Must be consumed with `join`: dropping a handle which still holds its
/// inner state panics (see the `Drop` impl).
#[derive(PartialEq, Eq, Hash)]
pub struct JoinHandle<'f, T> {
/// `None` once the state has been taken by `join` or by `drop`.
inner: Option<JoinHandleImpl<T>>,
/// Carries the lifetime `'f` of the fiber function.
marker: PhantomData<&'f ()>,
}
impl<T> std::fmt::Debug for JoinHandle<'_, T> {
    /// Prints the type name only; no fields are listed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut repr = formatter.debug_struct("JoinHandle");
        repr.finish_non_exhaustive()
    }
}
/// Deprecated alias of `JoinHandle<'f, ()>`.
#[deprecated = "Use `fiber::JoinHandle<'f, ()>` instead"]
pub type UnitJoinHandle<'f> = JoinHandle<'f, ()>;
/// Deprecated alias of `JoinHandle<'f, T>`.
#[deprecated = "Use `fiber::JoinHandle<'f, T>` instead"]
pub type LuaJoinHandle<'f, T> = JoinHandle<'f, T>;
/// Deprecated alias of `JoinHandle<'f, ()>`.
#[deprecated = "Use `fiber::JoinHandle<'f, ()>` instead"]
pub type LuaUnitJoinHandle<'f> = JoinHandle<'f, ()>;
/// Internal state of a [`JoinHandle`], depending on how the fiber was spawned.
#[derive(Debug)]
enum JoinHandleImpl<T> {
/// Fiber created through the C api.
Ffi {
/// Raw pointer to the fiber.
fiber: NonNull<ffi::Fiber>,
/// Where the fiber stores its result; `None` when `T` needs no returning.
result_cell: Option<FiberResultCell<T>>,
},
/// Fiber created through the lua `fiber` module, identified by its id.
#[rustfmt::skip] Lua {
fiber_id: FiberId,
},
}
/// Heap-allocated slot through which a fiber passes its result to the joiner;
/// the spawn functions store its raw address in `Context::fiber_result_ptr`.
type FiberResultCell<T> = Box<UnsafeCell<Option<T>>>;
impl<T> JoinHandle<'_, T> {
#[inline(always)]
fn ffi(fiber: NonNull<ffi::Fiber>, result_cell: Option<FiberResultCell<T>>) -> Self {
Self {
inner: Some(JoinHandleImpl::Ffi { fiber, result_cell }),
marker: PhantomData,
}
}
#[inline(always)]
fn lua(fiber_id: FiberId) -> Self {
Self {
inner: Some(JoinHandleImpl::Lua { fiber_id }),
marker: PhantomData,
}
}
#[rustfmt::skip]
pub fn join(mut self) -> T {
let inner = self
.inner
.take()
.expect("after construction join is called at most once");
match inner {
JoinHandleImpl::Ffi { fiber, mut result_cell, .. } => {
let code = unsafe { ffi::fiber_join(fiber.as_ptr()) };
debug_assert_eq!(code, 0, "rust fiber functions always return 0");
if needs_returning::<T>() {
let mut result_cell = result_cell.take().expect("should not be None for non unit types");
let res = result_cell.get_mut().take().expect("should have been set by the fiber function");
return res;
}
debug_assert!(result_cell.is_none());
}
JoinHandleImpl::Lua { fiber_id } => unsafe {
let guard = impl_details::lua_fiber_join(fiber_id)
.map_err(|e| panic!("Unrecoverable lua failure: {}", e))
.unwrap();
if needs_returning::<T>() {
let ud_ptr = lua::lua_touserdata(guard.as_lua(), -1);
let res = (ud_ptr as *mut Option<T>)
.as_mut()
.expect("fiber:join must return correct userdata")
.take()
.expect("data can only be taken once from the UDBox");
return res;
}
debug_assert!(lua::lua_isnil(guard.as_lua(), -1));
},
}
#[allow(clippy::uninit_assumed_init)]
unsafe { std::mem::MaybeUninit::uninit().assume_init() }
}
#[inline(always)]
#[track_caller]
pub fn id(&self) -> FiberId {
self.id_checked().expect("fiber_id api is not supported")
}
pub fn id_checked(&self) -> Option<FiberId> {
match self.inner {
None => {
unreachable!("it has either been moved into JoinHandle::join, or been dropped")
}
Some(JoinHandleImpl::Ffi { fiber, .. }) => {
if unsafe { !has_fiber_id() } {
return None;
}
let res = unsafe { ffi::fiber_id(fiber.as_ptr()) };
return Some(res);
}
Some(JoinHandleImpl::Lua { fiber_id, .. }) => Some(fiber_id),
}
}
pub fn cancel(&self) {
match self.inner {
None => {
unreachable!("it has either been moved into JoinHandle::join, or been dropped")
}
Some(JoinHandleImpl::Ffi { fiber, .. }) => {
unsafe {
ffi::fiber_cancel(fiber.as_ptr());
}
}
Some(JoinHandleImpl::Lua { fiber_id, .. }) => {
let found = cancel(fiber_id);
debug_assert!(
found,
"non-joinable fiber has been recycled before being joined"
);
}
}
}
pub fn wakeup(&self) {
match self.inner {
None => {
unreachable!("it has either been moved into JoinHandle::join, or been dropped")
}
Some(JoinHandleImpl::Ffi { fiber, .. }) => {
unsafe {
ffi::fiber_wakeup(fiber.as_ptr());
}
}
Some(JoinHandleImpl::Lua { fiber_id, .. }) => {
let found = wakeup(fiber_id);
debug_assert!(
found,
"non-joinable fiber has been recycled before being joined"
);
}
}
}
}
impl<T> Drop for JoinHandle<'_, T> {
    /// Panics if the handle still owns its state, i.e. it was never joined.
    ///
    /// The result cell of a C api fiber is leaked first, so it is not freed
    /// here while the context may still hold its address.
    fn drop(&mut self) {
        let Some(mut state) = self.inner.take() else {
            return;
        };
        if let JoinHandleImpl::Ffi { result_cell, .. } = &mut state {
            let leaked = result_cell.take();
            std::mem::forget(leaked);
        }
        panic!("JoinHandle dropped before being joined")
    }
}
#[rustfmt::skip]
impl<T> ::std::cmp::PartialEq for JoinHandleImpl<T> {
    /// Two states are equal when they are of the same kind and refer to the
    /// same fiber (by pointer or by id); the result cell is not compared.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Ffi { fiber: lhs, .. }, Self::Ffi { fiber: rhs, .. }) => lhs == rhs,
            (Self::Lua { fiber_id: lhs }, Self::Lua { fiber_id: rhs }) => lhs == rhs,
            (Self::Ffi { .. }, Self::Lua { .. }) | (Self::Lua { .. }, Self::Ffi { .. }) => false,
        }
    }
}
// Marker impl; equality itself is defined by the `PartialEq` impl.
impl<T> ::std::cmp::Eq for JoinHandleImpl<T> {}
impl<T> ::std::hash::Hash for JoinHandleImpl<T> {
    /// Hashes the fiber pointer or the fiber id, the same data `eq` compares.
    fn hash<H>(&self, state: &mut H)
    where
        H: ::std::hash::Hasher,
    {
        match self {
            Self::Lua { fiber_id, .. } => ::std::hash::Hash::hash(fiber_id, state),
            Self::Ffi { fiber, .. } => ::std::hash::Hash::hash(fiber, state),
        }
    }
}
/// Spawns a joinable fiber running `f` with default settings, see [`Builder::start`].
///
/// # Panics
/// Panics if the fiber could not be spawned.
#[inline(always)]
pub fn start<'f, F, T>(f: F) -> JoinHandle<'f, T>
where
    F: FnOnce() -> T,
    F: 'f,
    T: 'f,
{
    let builder = Builder::new().func(f);
    builder.start().unwrap()
}
/// Like [`start`], but the fiber drives the future `f` to completion with [`block_on`].
#[inline(always)]
pub fn start_async<'f, F, T>(f: F) -> JoinHandle<'f, T>
where
    F: Future<Output = T> + 'f,
    T: 'f,
{
    let run_to_completion = move || block_on(f);
    start(run_to_completion)
}
/// Deprecated alias of [`start`] for functions returning `()`.
#[deprecated = "Use `fiber::start` instead"]
#[inline(always)]
pub fn start_proc<'f, F>(f: F) -> JoinHandle<'f, ()>
where
    F: FnOnce(),
    F: 'f,
{
    let handle: JoinHandle<'f, ()> = start(f);
    handle
}
/// Spawns a joinable deferred fiber running `f` with default settings, see [`Builder::defer`].
///
/// # Panics
/// Panics if the fiber could not be spawned.
#[inline(always)]
pub fn defer<'f, F, T>(f: F) -> JoinHandle<'f, T>
where
    F: FnOnce() -> T,
    F: 'f,
    T: 'f,
{
    let builder = Builder::new().func(f);
    builder.defer().unwrap()
}
/// Like [`defer`], but the fiber drives the future `f` to completion with [`block_on`].
#[inline(always)]
pub fn defer_async<'f, F, T>(f: F) -> JoinHandle<'f, T>
where
    F: Future<Output = T> + 'f,
    T: 'f,
{
    let run_to_completion = move || block_on(f);
    defer(run_to_completion)
}
/// Deprecated alias of [`defer`] for functions returning `()`.
#[deprecated = "Use `fiber::defer` instead"]
#[inline(always)]
pub fn defer_proc<'f, F>(f: F) -> JoinHandle<'f, ()>
where
    F: FnOnce(),
    F: 'f,
{
    let handle: JoinHandle<'f, ()> = defer(f);
    handle
}
/// Calls `fiber_set_cancellable` with the given flag and returns its result.
///
/// NOTE(review): the returned flag is presumably the previous state — confirm.
#[inline(always)]
pub fn set_cancellable(is_cancellable: bool) -> bool {
unsafe { ffi::fiber_set_cancellable(is_cancellable) }
}
/// Returns the result of `fiber_is_cancelled`.
///
/// NOTE(review): presumably checks the current fiber — confirm.
#[inline(always)]
pub fn is_cancelled() -> bool {
unsafe { ffi::fiber_is_cancelled() }
}
/// Cancels the fiber with the given id.
///
/// Returns `false` if no such fiber was found. Without the `fiber_id` api the
/// call goes through lua and the result is the status of the `pcall`.
///
/// # Panics
/// Panics if the lua fallback fails to evaluate.
#[inline(always)]
pub fn cancel(id: FiberId) -> bool {
    if !unsafe { has_fiber_id() } {
        let lua = crate::global_lua();
        let cancelled: bool = lua
            .eval_with("return pcall(require 'fiber'.cancel, ...)", id)
            .expect("lua error");
        return cancelled;
    }
    let target = unsafe { ffi::fiber_find(id) };
    if target.is_null() {
        false
    } else {
        unsafe { ffi::fiber_cancel(target) };
        true
    }
}
/// Wakes up the fiber with the given id.
///
/// Returns `false` if no such fiber was found. Without the `fiber_id` api the
/// call goes through lua and the result is the status of the `pcall`.
///
/// # Panics
/// Panics if the lua fallback fails to evaluate.
#[inline(always)]
pub fn wakeup(id: FiberId) -> bool {
    if !unsafe { has_fiber_id() } {
        let lua = crate::global_lua();
        let woken: bool = lua
            .eval_with("return pcall(require 'fiber'.wakeup, ...)", id)
            .expect("lua error");
        return woken;
    }
    let target = unsafe { ffi::fiber_find(id) };
    if target.is_null() {
        false
    } else {
        unsafe { ffi::fiber_wakeup(target) };
        true
    }
}
/// Calls `fiber_sleep` with `time` converted to fractional seconds.
#[inline(always)]
pub fn sleep(time: Duration) {
    let seconds = time.as_secs_f64();
    unsafe { ffi::fiber_sleep(seconds) }
}
/// Returns `Instant::now_fiber()`.
#[inline(always)]
pub fn clock() -> Instant {
Instant::now_fiber()
}
/// Calls `fiber_yield`.
///
/// NOTE(review): presumably returns control to the scheduler until the fiber
/// is woken up — confirm against tarantool docs.
#[inline(always)]
pub fn fiber_yield() {
unsafe { ffi::fiber_yield() }
}
/// Yields by sleeping for zero seconds, then reports whether the fiber was cancelled.
///
/// # Errors
/// Returns a `ProcLua` error ("fiber is cancelled") if the fiber is
/// cancelled after the sleep.
#[inline(always)]
pub fn r#yield() -> crate::Result<()> {
    unsafe { fiber_sleep(0f64) };
    if !is_cancelled() {
        return Ok(());
    }
    set_error!(TarantoolErrorCode::ProcLua, "fiber is cancelled");
    Err(TarantoolError::last().into())
}
/// Calls `fiber_reschedule`.
///
/// NOTE(review): presumably yields and puts the fiber at the end of the ready
/// queue — confirm against tarantool docs.
#[inline(always)]
pub fn reschedule() {
unsafe { ffi::fiber_reschedule() }
}
/// Returns `true` if a fiber with the given id can be found.
///
/// # Panics
/// Panics if the lua fallback fails to evaluate.
#[inline(always)]
pub fn exists(id: FiberId) -> bool {
    if unsafe { has_fiber_id() } {
        let found = unsafe { ffi::fiber_find(id) };
        !found.is_null()
    } else {
        crate::global_lua()
            .eval_with("return require'fiber'.find(...) ~= nil", id)
            .expect("lua error")
    }
}
/// Returns the id of the current fiber.
///
/// # Panics
/// Panics if the lua fallback fails to evaluate.
#[inline]
pub fn id() -> FiberId {
    if unsafe { has_fiber_id() } {
        // A null pointer selects the current fiber.
        unsafe { ffi::fiber_id(std::ptr::null_mut()) }
    } else {
        crate::global_lua()
            .eval("return require'fiber'.id()")
            .expect("lua error")
    }
}
/// Returns the context switch counter of the current fiber.
///
/// # Panics
/// Panics if the lua fallback yields no value.
#[inline]
pub fn csw() -> u64 {
    if !unsafe { has_fiber_id() } {
        return csw::csw_lua(None).expect("fiber.self() should always work");
    }
    // A null pointer selects the current fiber.
    unsafe { ffi::fiber_csw(std::ptr::null_mut()) }
}
/// Returns the context switch counter of the fiber with the given id, or
/// `None` if there is no such fiber.
#[inline]
pub fn csw_of(id: FiberId) -> Option<u64> {
    if !unsafe { has_fiber_id() } {
        return csw::csw_lua(Some(id));
    }
    let target = unsafe { ffi::fiber_find(id) };
    if target.is_null() {
        return None;
    }
    Some(unsafe { ffi::fiber_csw(target) })
}
/// Returns the name of the current fiber, replacing invalid utf8 sequences.
///
/// # Panics
/// Panics if the name of the current fiber can not be obtained.
#[inline]
pub fn name() -> String {
    let raw = unsafe { name_raw(None) };
    let bytes = raw.expect("fiber_self should always work");
    String::from_utf8_lossy(bytes).into_owned()
}
/// Returns the name of the fiber with the given id (invalid utf8 sequences
/// replaced), or `None` if there is no such fiber.
#[inline]
pub fn name_of(id: FiberId) -> Option<String> {
    let raw = unsafe { name_raw(Some(id)) };
    raw.map(|bytes| String::from_utf8_lossy(bytes).into_owned())
}
/// Returns the raw bytes of the name of the fiber with the given id, or of
/// the current fiber when `id` is `None`. Returns `None` if the fiber is not
/// found.
///
/// # Safety
/// The `'static` lifetime in the signature is not backed by anything visible
/// here: the bytes point either at memory returned by `fiber_name` or at a
/// lua string whose lifetime is erased with `transmute`. NOTE(review): callers
/// should presumably copy the data right away and not keep it across yields
/// or renames — confirm.
///
/// # Panics
/// Panics if the lua fallback fails to evaluate.
pub unsafe fn name_raw(id: Option<FiberId>) -> Option<&'static [u8]> {
if has_fiber_id() {
// A null pointer selects the current fiber.
let mut f = std::ptr::null_mut();
if let Some(id) = id {
f = ffi::fiber_find(id);
if f.is_null() {
return None;
}
}
let p = ffi::fiber_name(f);
let cstr = std::ffi::CStr::from_ptr(p as _);
Some(cstr.to_bytes())
} else {
let lua = crate::global_lua();
let s: Option<tlua::StringInLua<_>> = lua
.eval_with(
"local fiber = require'fiber'
local f = fiber.find(... or fiber.id())
return f and f:name()",
id,
)
.expect("lua error");
let s = s?;
// Erases the borrow of `s`, which is dropped when this function returns.
let res: &'static [u8] = std::mem::transmute(s.as_bytes());
Some(res)
}
}
/// Sets the name of the current fiber.
///
/// # Panics
/// Panics if the lua fallback fails.
#[inline]
pub fn set_name(name: &str) {
    if !unsafe { has_fiber_id() } {
        let lua = crate::global_lua();
        lua.exec_with("require'fiber'.name(...)", name)
            .expect("lua error");
        return;
    }
    // A null pointer selects the current fiber.
    unsafe { ffi::fiber_set_name_n(std::ptr::null_mut(), name.as_ptr(), name.len() as _) }
}
/// Sets the name of the fiber with the given id.
///
/// Returns `false` if there is no such fiber.
///
/// # Panics
/// Panics if the lua fallback fails to evaluate.
#[inline]
pub fn set_name_of(id: FiberId, name: &str) -> bool {
    if !unsafe { has_fiber_id() } {
        let lua = crate::global_lua();
        let renamed: bool = lua
            .eval_with(
                "local fiber = require'fiber'
local id, name = ...
local f = fiber.find(id)
if f == nil then
return false
end
f:name(name)
return true",
                (id, name),
            )
            .expect("lua error");
        return renamed;
    }
    let target = unsafe { ffi::fiber_find(id) };
    if target.is_null() {
        return false;
    }
    unsafe { ffi::fiber_set_name_n(target, name.as_ptr(), name.len() as _) };
    true
}
/// Owning wrapper around a fiber attributes object created with
/// `fiber_attr_new` and released with `fiber_attr_delete` on drop.
#[derive(Debug)]
pub struct FiberAttr {
/// Raw pointer to the attributes object.
inner: *mut ffi::FiberAttr,
}
impl FiberAttr {
    /// Creates a new attributes object with `fiber_attr_new`.
    #[inline(always)]
    pub fn new() -> Self {
        let inner = unsafe { ffi::fiber_attr_new() };
        Self { inner }
    }

    /// Returns the stack size stored in the attributes.
    #[inline(always)]
    pub fn stack_size(&self) -> usize {
        let attr = self.inner;
        unsafe { ffi::fiber_attr_getstacksize(attr) }
    }

    /// Sets the stack size in the attributes.
    ///
    /// # Errors
    /// Returns the last tarantool error if `fiber_attr_setstacksize` reports
    /// a negative result.
    #[inline(always)]
    pub fn set_stack_size(&mut self, stack_size: usize) -> crate::Result<()> {
        let rc = unsafe { ffi::fiber_attr_setstacksize(self.inner, stack_size) };
        if rc < 0 {
            return Err(TarantoolError::last().into());
        }
        Ok(())
    }
}
impl Default for FiberAttr {
#[inline(always)]
fn default() -> Self {
Self::new()
}
}
impl Drop for FiberAttr {
/// Releases the attributes object with `fiber_attr_delete`.
#[inline(always)]
fn drop(&mut self) {
unsafe { ffi::fiber_attr_delete(self.inner) }
}
}
/// Owning wrapper around a fiber condition variable created with
/// `fiber_cond_new` and released with `fiber_cond_delete` on drop.
#[derive(Debug)]
pub struct Cond {
/// Raw pointer to the condition variable.
inner: *mut ffi::FiberCond,
}
impl Cond {
    /// Creates a new condition variable with `fiber_cond_new`.
    #[inline(always)]
    pub fn new() -> Self {
        let inner = unsafe { ffi::fiber_cond_new() };
        Self { inner }
    }

    /// Calls `fiber_cond_signal`.
    #[inline(always)]
    pub fn signal(&self) {
        let cond = self.inner;
        unsafe { ffi::fiber_cond_signal(cond) }
    }

    /// Calls `fiber_cond_broadcast`.
    #[inline(always)]
    pub fn broadcast(&self) {
        let cond = self.inner;
        unsafe { ffi::fiber_cond_broadcast(cond) }
    }

    /// Waits on the condition for at most `timeout`.
    ///
    /// Returns `true` if `fiber_cond_wait_timeout` reports a non-negative result.
    #[inline(always)]
    pub fn wait_timeout(&self, timeout: Duration) -> bool {
        let seconds = timeout.as_secs_f64();
        let rc = unsafe { ffi::fiber_cond_wait_timeout(self.inner, seconds) };
        rc >= 0
    }

    /// Waits on the condition until `deadline`, measured against [`clock`].
    ///
    /// Returns `true` if `fiber_cond_wait_timeout` reports a non-negative result.
    #[inline(always)]
    pub fn wait_deadline(&self, deadline: Instant) -> bool {
        let seconds = deadline.duration_since(clock()).as_secs_f64();
        let rc = unsafe { ffi::fiber_cond_wait_timeout(self.inner, seconds) };
        rc >= 0
    }

    /// Waits on the condition without a timeout.
    ///
    /// Returns `true` if `fiber_cond_wait` reports a non-negative result.
    #[inline(always)]
    pub fn wait(&self) -> bool {
        let rc = unsafe { ffi::fiber_cond_wait(self.inner) };
        rc >= 0
    }
}
impl Default for Cond {
#[inline(always)]
fn default() -> Self {
Self::new()
}
}
impl Drop for Cond {
/// Releases the condition variable with `fiber_cond_delete`.
#[inline(always)]
fn drop(&mut self) {
unsafe { ffi::fiber_cond_delete(self.inner) }
}
}
/// Owning wrapper around a latch created with `box_latch_new` and released
/// with `box_latch_delete` on drop.
#[derive(Debug)]
pub struct Latch {
/// Raw pointer to the latch.
inner: *mut ffi::Latch,
}
impl Latch {
    /// Creates a new latch with `box_latch_new`.
    #[inline(always)]
    pub fn new() -> Self {
        let inner = unsafe { ffi::box_latch_new() };
        Self { inner }
    }

    /// Locks the latch with `box_latch_lock` and returns a guard which
    /// unlocks it when dropped.
    #[inline(always)]
    pub fn lock(&self) -> LatchGuard {
        let latch_inner = self.inner;
        unsafe { ffi::box_latch_lock(latch_inner) };
        LatchGuard { latch_inner }
    }

    /// Tries to lock the latch with `box_latch_trylock`.
    ///
    /// Returns a guard if the call reports `0`, `None` otherwise.
    #[inline(always)]
    pub fn try_lock(&self) -> Option<LatchGuard> {
        let latch_inner = self.inner;
        let rc = unsafe { ffi::box_latch_trylock(latch_inner) };
        if rc != 0 {
            return None;
        }
        Some(LatchGuard { latch_inner })
    }
}
impl Default for Latch {
#[inline(always)]
fn default() -> Self {
Self::new()
}
}
impl Drop for Latch {
/// Releases the latch with `box_latch_delete`.
#[inline(always)]
fn drop(&mut self) {
unsafe { ffi::box_latch_delete(self.inner) }
}
}
/// Guard returned by `Latch::lock` and `Latch::try_lock`; unlocks the latch
/// when dropped.
#[derive(Debug)]
pub struct LatchGuard {
/// Raw pointer to the locked latch.
latch_inner: *mut ffi::Latch,
}
impl Drop for LatchGuard {
/// Unlocks the latch with `box_latch_unlock`.
#[inline(always)]
fn drop(&mut self) {
unsafe { ffi::box_latch_unlock(self.latch_inner) }
}
}
/// Checks whether `context` looks like a pointer to a [`Context`] of this version.
///
/// The pointer must be non-null and aligned to `CONTEXT_ALIGNMENT`; then the
/// `magic`, `size` and `version` fields are compared, in that order, with the
/// expected constants. Later fields are only read if the earlier ones matched.
///
/// # Safety
/// If `context` is non-null and properly aligned, it must be valid for
/// reading the checked fields.
#[inline]
pub unsafe fn context_is_valid(context: *mut Context) -> bool {
    let address = context as usize;
    if address == 0 || address % CONTEXT_ALIGNMENT != 0 {
        return false;
    }
    *std::ptr::addr_of!((*context).magic) == CONTEXT_MAGIC
        && *std::ptr::addr_of!((*context).size) == CONTEXT_SIZE
        && *std::ptr::addr_of!((*context).version) == CONTEXT_VERSION
}
/// Value expected in `Context::magic`; checked by `context_is_valid`.
pub const CONTEXT_MAGIC: u64 = 0x69F1BE5C047E8769;
/// Size of [`Context`] in bytes; checked against `Context::size`.
pub const CONTEXT_SIZE: u64 = size_of::<Context>() as _;
/// Alignment of [`Context`]; checked against the pointer address.
pub const CONTEXT_ALIGNMENT: usize = align_of::<Context>() as _;
static_assert!(CONTEXT_ALIGNMENT == 8, "this should never change");
/// Value expected in `Context::version`; checked by `context_is_valid`.
pub const CONTEXT_VERSION: u64 = 2;
/// Per-fiber data shared between the spawning code and the fiber trampoline.
///
/// The first three fields form a header which `context_is_valid` checks.
#[repr(C)]
pub struct Context {
/// Always `CONTEXT_MAGIC` for a valid context.
magic: u64,
/// Always `CONTEXT_SIZE` for a valid context.
size: u64,
/// Always `CONTEXT_VERSION` for a valid context.
version: u64,
/// Id of the fiber; set by the trampoline, `FIBER_ID_INVALID` before that.
fiber_id: FiberId,
/// Type-erased pointer to the boxed fiber closure; nulled once taken.
fiber_rust_closure: *mut (),
/// Type-erased pointer to the result cell, or null when nothing is returned.
fiber_result_ptr: *mut (),
}
impl std::fmt::Debug for Context {
    /// Prints the header fields (`magic`, `size`, `version`) only.
    fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut repr = formatter.debug_struct("Context");
        repr.field("magic", &self.magic);
        repr.field("size", &self.size);
        repr.field("version", &self.version);
        repr.finish_non_exhaustive()
    }
}
impl Default for Context {
/// Creates a context with a valid header, an invalid fiber id and null
/// closure and result pointers.
#[inline(always)]
fn default() -> Self {
Self {
magic: CONTEXT_MAGIC,
size: CONTEXT_SIZE,
version: CONTEXT_VERSION,
fiber_id: FIBER_ID_INVALID,
fiber_rust_closure: std::ptr::null_mut(),
fiber_result_ptr: std::ptr::null_mut(),
}
}
}
/// Splits `callback` into a type-erased data pointer and a C-compatible
/// trampoline which knows how to call it.
///
/// The trampoline reads two pointers from its variadic arguments: first the
/// closure, then the boxed argument of type `T`.
///
/// # Safety
/// The returned pointer borrows `callback` without a lifetime; the closure
/// must outlive every call of the trampoline.
pub(crate) unsafe fn unpack_callback<F, T>(callback: &mut F) -> (*mut c_void, ffi::FiberFunc)
where
F: FnMut(Box<T>) -> i32,
{
unsafe extern "C" fn trampoline<F, T>(mut args: VaList) -> i32
where
F: FnMut(Box<T>) -> i32,
{
ffi::fiber_set_cancellable(true);
// The order of the reads must match the order of arguments given to `fiber_start`.
let closure: &mut F = &mut *(args.get::<*const c_void>() as *mut F);
let boxed_arg = Box::from_raw(args.get::<*const c_void>() as *mut T);
(*closure)(boxed_arg)
}
(callback as *mut F as *mut c_void, Some(trampoline::<F, T>))
}
/// Returns `true` if a value of type `T` has to be passed back from a fiber,
/// i.e. unless `T` is zero sized and has no drop glue.
const fn needs_returning<T>() -> bool {
    let is_zero_sized = std::mem::size_of::<T>() == 0;
    let has_drop_glue = std::mem::needs_drop::<T>();
    !is_zero_sized || has_drop_glue
}
// Compile-time checks of `needs_returning`.
const _: () = {
// Non-zero sized types need returning.
assert!(needs_returning::<i32>());
assert!(needs_returning::<bool>());
// Zero sized types without drop glue do not.
assert!(!needs_returning::<()>());
struct UnitStruct;
assert!(!needs_returning::<UnitStruct>());
// A zero sized type with a `Drop` impl does.
struct DroppableUnitStruct;
impl Drop for DroppableUnitStruct {
fn drop(&mut self) {}
}
assert!(needs_returning::<DroppableUnitStruct>());
};
#[cfg(feature = "internal_test")]
mod tests {
use super::*;
use crate::fiber;
use crate::test::util::LuaStackIntegrityGuard;
use std::cell::Cell;
use std::cell::RefCell;
use std::rc::Rc;
// A fiber built from an async block returns the block's output on join.
#[crate::test(tarantool = "crate")]
fn builder_async_func() {
    let builder = Builder::new().func_async(async { 69 });
    let jh = builder.start().unwrap();
    assert_eq!(jh.join(), 69);
}
// The deprecated `proc_async` runs the future; its side effect is visible
// after join.
#[crate::test(tarantool = "crate")]
#[allow(deprecated)]
fn builder_async_proc() {
    let shared = Rc::new(RefCell::new(0u32));
    let writer = shared.clone();
    let builder = Builder::new().proc_async(async move {
        *writer.borrow_mut() = 1;
    });
    let jh = builder.start().unwrap();
    jh.join();
    assert_eq!(*shared.borrow(), 1);
}
// `sleep` must take at least the requested time as measured by `clock`.
#[crate::test(tarantool = "crate")]
fn fiber_sleep_and_clock() {
let before_sleep = clock();
let sleep_for = Duration::from_millis(100);
sleep(sleep_for);
assert!(before_sleep.elapsed() >= sleep_for);
assert!(clock() >= before_sleep);
assert!(clock() - before_sleep >= sleep_for);
}
// Dropping a handle without joining is expected to panic (`should_panic`);
// the fiber is still suspended in `reschedule` at that point.
#[crate::test(tarantool = "crate", should_panic)]
fn start_dont_join_no_use_after_free() {
let f = start(move || {
reschedule();
[0xaa; 4096]
});
drop(f);
}
// `JoinHandle::id_checked` yields an id exactly when the `fiber_id` api is
// available.
#[crate::test(tarantool = "crate")]
fn fiber_id() {
// Must not panic regardless of the api availability.
fiber::id();
let jh = fiber::start(fiber::reschedule);
if unsafe { has_fiber_id() } {
assert!(jh.id_checked().is_some());
} else {
assert!(jh.id_checked().is_none());
}
jh.join();
}
// Exercises `name`, `name_of`, `set_name`, `set_name_of` and `exists`, both
// with and without the `fiber_id` api.
#[crate::test(tarantool = "crate")]
fn fiber_name() {
const NAME1: &str = "test_fiber_name_1";
const NAME2: &str = "test_fiber_name_2";
if unsafe { has_fiber_id() } {
let jh = fiber::start(|| {
fiber::set_name(NAME1);
assert_eq!(fiber::name(), NAME1);
// Gives control back to the parent, which renames this fiber to NAME2.
fiber::reschedule();
assert_eq!(fiber::name(), NAME2);
});
let f_id = jh.id();
assert_eq!(fiber::name_of(f_id).unwrap(), NAME1);
assert!(fiber::set_name_of(f_id, NAME2));
assert_eq!(fiber::name_of(f_id).unwrap(), NAME2);
assert!(fiber::exists(f_id));
jh.join();
// After join the fiber can no longer be found.
assert!(!fiber::exists(f_id));
assert!(fiber::name_of(f_id).is_none());
assert!(!fiber::set_name_of(f_id, "foo"));
} else {
// Without the api the handle has no id, so the fiber reports it itself.
let f_id = Cell::new(None);
let jh = fiber::start(|| {
f_id.set(Some(fiber::id()));
fiber::set_name(NAME1);
assert_eq!(fiber::name(), NAME1);
assert!(fiber::set_name_of(fiber::id(), NAME2));
assert_eq!(fiber::name_of(fiber::id()).unwrap(), NAME2);
// An id which is expected not to belong to any fiber.
assert!(!fiber::set_name_of(0xCAFE_BABE_DEAD_F00D, "foo"));
assert!(fiber::name_of(0xCAFE_BABE_DEAD_F00D).is_none());
});
let f_id = f_id.get().unwrap();
assert!(fiber::exists(f_id));
jh.join();
assert!(!fiber::exists(f_id));
}
}
// Checks that `csw`/`csw_of` grow by one per `reschedule`/`join` of the
// corresponding fiber.
#[allow(clippy::unusual_byte_groupings)]
#[crate::test(tarantool = "crate")]
fn fiber_csw() {
if unsafe { has_fiber_id() } {
let csw_parent_0 = fiber::csw();
let jh = fiber::defer(|| {
fiber::reschedule();
1337
});
// Deferring a fiber does not switch the parent out.
assert_eq!(fiber::csw(), csw_parent_0);
let child_id = jh.id();
let csw_child_0 = fiber::csw_of(child_id).unwrap();
fiber::reschedule();
assert_eq!(fiber::csw(), csw_parent_0 + 1);
assert_eq!(fiber::csw_of(child_id).unwrap(), csw_child_0 + 1);
assert_eq!(jh.join(), 1337);
assert_eq!(fiber::csw(), csw_parent_0 + 2);
// The joined fiber can no longer be found.
assert!(fiber::csw_of(child_id).is_none());
} else {
let csw_parent_0 = fiber::csw();
let jh = fiber::defer(|| {
let csw_0 = fiber::csw_of(fiber::id()).unwrap();
fiber::reschedule();
assert_eq!(fiber::csw_of(fiber::id()).unwrap(), csw_0 + 1);
1337
});
// Deferring a fiber does not switch the parent out.
assert_eq!(fiber::csw(), csw_parent_0);
fiber::reschedule();
assert_eq!(fiber::csw(), csw_parent_0 + 1);
assert_eq!(jh.join(), 1337);
assert_eq!(fiber::csw(), csw_parent_0 + 2);
// An id which is expected not to belong to any fiber.
assert!(fiber::csw_of(0xFACE_BEEF_BAD_DEED5).is_none());
}
}
// Covers `Builder::start_non_joinable`: rejection of closures that return a
// value, a fiber that finishes immediately, and a long-lived fiber that is
// driven with `wakeup` and finally stopped with `cancel`.
#[crate::test(tarantool = "crate")]
fn start_non_joinable() {
    struct ZeroSizedType;

    // A closure returning a value is refused.
    let err = fiber::Builder::new()
        .func(|| 10569)
        .start_non_joinable()
        .unwrap_err();
    #[rustfmt::skip]
    assert_eq!(err.to_string(), "box error: Unsupported: non-joinable fibers which return a value are not supported");

    // A zero-sized return type is accepted; the fiber has already finished
    // by the time its id is returned.
    let finished_id = fiber::Builder::new()
        .func(|| ZeroSizedType)
        .start_non_joinable()
        .unwrap();
    assert!(!fiber::exists(finished_id));

    // A fiber that keeps yielding until it is cancelled.
    let looping_id = fiber::Builder::new()
        .func(|| {
            while !fiber::is_cancelled() {
                fiber::fiber_yield();
            }
        })
        .start_non_joinable()
        .unwrap();
    let base_csw = fiber::csw_of(looping_id).unwrap();

    // Each wakeup followed by a yield of this fiber advances the counter by one.
    assert!(fiber::wakeup(looping_id));
    fiber::reschedule();
    assert_eq!(fiber::csw_of(looping_id).unwrap(), base_csw + 1);

    assert!(fiber::wakeup(looping_id));
    fiber::reschedule();
    assert_eq!(fiber::csw_of(looping_id).unwrap(), base_csw + 2);

    // After cancellation and one more yield the fiber is gone, and every
    // id-based operation reports failure.
    assert!(fiber::cancel(looping_id));
    fiber::reschedule();
    assert!(!fiber::exists(looping_id));
    assert!(fiber::csw_of(looping_id).is_none());
    assert!(!fiber::wakeup(looping_id));
    assert!(!fiber::cancel(looping_id));
}
// Covers `Builder::defer_non_joinable`: the unsupported configurations, a
// fiber that finishes on its first run, a fiber cancelled before its first
// run, and a long-lived fiber driven with `wakeup` and stopped with `cancel`.
#[crate::test(tarantool = "crate")]
fn defer_non_joinable() {
// When `has_fiber_set_ctx()` reports false the call must fail with an
// explanatory error, and there is nothing further to test.
if unsafe { !crate::ffi::has_fiber_set_ctx() } {
let e = fiber::Builder::new()
.func(|| {})
.defer_non_joinable()
.unwrap_err();
assert_eq!(e.to_string(), "box error: Unsupported: deferred non-joinable fibers are not supported in current tarantool version (fiber_set_ctx API is required)");
return;
}
// A closure returning a value is refused.
let e = fiber::Builder::new()
.func(|| 10569)
.defer_non_joinable()
.unwrap_err();
#[rustfmt::skip]
assert_eq!(e.to_string(), "box error: Unsupported: non-joinable fibers which return a value are not supported");
// These two scenarios need the fiber's id before it has run, so they are
// only exercised when `has_fiber_id()` reports true.
if unsafe { has_fiber_id() } {
// A deferred fiber exists right after creation and is gone after the
// parent yields once.
struct ZeroSizedType; let id = fiber::Builder::new()
.func(|| ZeroSizedType)
.defer_non_joinable()
.unwrap()
.unwrap();
assert!(fiber::exists(id));
fiber::reschedule();
assert!(!fiber::exists(id));
// A fiber cancelled before its first run still executes its closure and
// observes `is_cancelled() == true`.
let is_cancelled = Rc::new(Cell::new(None));
let is_cancelled_tx = is_cancelled.clone();
let id = fiber::Builder::new()
.func(move || is_cancelled_tx.set(Some(fiber::is_cancelled())))
.defer_non_joinable()
.unwrap()
.unwrap();
assert!(fiber::cancel(id));
fiber::reschedule();
assert!(!fiber::exists(id));
assert_eq!(is_cancelled.get(), Some(true));
}
// Obtain the id of a fiber that keeps yielding until it is cancelled.
let id = if unsafe { has_fiber_id() } {
// The id is returned directly by `defer_non_joinable`.
fiber::Builder::new()
.func(|| {
while !fiber::is_cancelled() {
fiber::fiber_yield();
}
})
.defer_non_joinable()
.unwrap()
.unwrap()
} else {
// `defer_non_joinable` returns `None` here, so the fiber publishes its own
// id through a shared cell once it has run.
let id = Rc::new(Cell::new(None));
let id_tx = id.clone();
let maybe_id = fiber::Builder::new()
.func(move || {
id_tx.set(Some(fiber::id()));
while !fiber::is_cancelled() {
fiber::fiber_yield();
}
})
.defer_non_joinable()
.unwrap();
assert_eq!(maybe_id, None);
// The deferred fiber has not run yet, so the cell is still empty.
assert_eq!(id.get(), None);
fiber::reschedule();
id.get().unwrap()
};
// Each wakeup followed by a yield of this fiber advances the counter
// reported by `csw_of` by one.
let csw0 = fiber::csw_of(id).unwrap();
assert!(fiber::wakeup(id));
fiber::reschedule();
assert_eq!(fiber::csw_of(id).unwrap(), csw0 + 1);
assert!(fiber::wakeup(id));
fiber::reschedule();
assert_eq!(fiber::csw_of(id).unwrap(), csw0 + 2);
// After cancellation and one more yield the fiber is gone, and every
// id-based operation reports failure.
assert!(fiber::cancel(id));
fiber::reschedule();
assert!(!fiber::exists(id));
assert!(fiber::csw_of(id).is_none());
assert!(!fiber::wakeup(id));
assert!(!fiber::cancel(id));
}
// Fibers created with `Builder::defer_lua` can be joined, both when the closure
// returns a value and when it returns unit. A `LuaStackIntegrityGuard` is held
// for the whole test.
#[crate::test(tarantool = "crate")]
fn defer_lua() {
    let _stack_guard = LuaStackIntegrityGuard::global("defer_lua");

    // Closure with a result: the joined value is what the closure returned.
    let valued = Builder::new().func(|| 42).defer_lua().unwrap();
    assert_eq!(valued.join(), 42);

    // Closure without a result.
    let unit = Builder::new().func(|| ()).defer_lua().unwrap();
    unit.join();
}
// A fiber name containing a nul byte is rejected with the same error by both
// `Builder::start` and `Builder::defer`.
#[crate::test(tarantool = "crate")]
fn illegal_fiber_name() {
    const BAD_NAME: &str = "nul\0byte";
    #[rustfmt::skip]
    const EXPECTED: &str = "box error: IllegalParams: fiber name may not contain nul-bytes: nul byte found in provided data at position: 3";

    let start_err = Builder::new()
        .name(BAD_NAME)
        .func(|| {})
        .start()
        .unwrap_err();
    assert_eq!(start_err.to_string(), EXPECTED);

    let defer_err = Builder::new()
        .name(BAD_NAME)
        .func(|| {})
        .defer()
        .unwrap_err();
    assert_eq!(defer_err.to_string(), EXPECTED);
}
// Checks what `Cond::wait`, `Cond::wait_timeout` and `Cond::wait_deadline`
// return when the waiting fiber is resumed via `fiber::wakeup`, before and
// after the fiber has been cancelled. The child forwards each return value to
// the parent through a channel.
#[rustfmt::skip]
#[crate::test(tarantool = "crate")]
fn wakeup_or_cancel_while_waiting_on_cond() {
let cond = Cond::new();
let ch = Channel::new(1);
let fiber_id = Cell::new(None);
let jh = fiber::start(|| {
// Publish the child's id so the parent can wake and cancel it.
fiber_id.set(Some(fiber::id()));
// First round: the three waits performed before cancellation.
ch.send(cond.wait()).unwrap();
ch.send(cond.wait_timeout(crate::clock::INFINITY)).unwrap();
ch.send(cond.wait_deadline(fiber::clock().saturating_add(crate::clock::INFINITY))).unwrap();
// Second round: the same three waits, performed after cancellation.
ch.send(cond.wait()).unwrap();
ch.send(cond.wait_timeout(crate::clock::INFINITY)).unwrap();
ch.send(cond.wait_deadline(fiber::clock().saturating_add(crate::clock::INFINITY))).unwrap();
});
let fiber_id = fiber_id.get().unwrap();
// Before cancellation every wait interrupted by a wakeup is expected to
// report `true`.
fiber::wakeup(fiber_id);
assert_eq!(ch.recv().unwrap(), true);
fiber::wakeup(fiber_id);
assert_eq!(ch.recv().unwrap(), true);
fiber::wakeup(fiber_id);
assert_eq!(ch.recv().unwrap(), true);
// Right after `cancel` nothing has been sent yet: the channel is still empty.
fiber::cancel(fiber_id);
assert_eq!(ch.try_recv(), Err(TryRecvError::Empty));
// After cancellation every wait is expected to report `false`.
fiber::wakeup(fiber_id);
assert_eq!(ch.recv().unwrap(), false);
fiber::wakeup(fiber_id);
assert_eq!(ch.recv().unwrap(), false);
fiber::wakeup(fiber_id);
assert_eq!(ch.recv().unwrap(), false);
jh.join();
}
}