use crate::device_context::with_default_device_policy;
use crate::device_future::DeviceFuture;
use crate::error::{device_error, DeviceError};
use crate::scheduling_policies::SchedulingPolicy;
use cuda_core::{CudaContext, CudaStream};
use std::cell::{Cell, UnsafeCell};
use std::fmt::Debug;
use std::future::IntoFuture;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
// Per-thread flag recording whether a `DeviceOp` is currently executing on this
// thread. It starts out `false`, is set by `acquire_execution_lock`, and is
// cleared by `release_execution_lock`.
thread_local! {
    static DEVICE_OP_EXECUTING: Cell<bool> = const { Cell::new(false) };
}
/// Marks the current thread as executing a `DeviceOp`.
///
/// # Errors
///
/// Returns `DeviceError::Internal` if the thread-local flag is already set,
/// i.e. another `DeviceOp` is executing on this thread. The flag stays set in
/// that case.
pub(crate) fn acquire_execution_lock() -> Result<(), DeviceError> {
    // `replace` stores `true` and returns the previous state; when the flag was
    // already `true` the store is a no-op, so the error path leaves it untouched.
    let already_running = DEVICE_OP_EXECUTING.with(|flag| flag.replace(true));
    if !already_running {
        return Ok(());
    }
    Err(DeviceError::Internal(
        "DeviceOp execution is non-reentrant: another DeviceOp is already \
         executing on this thread. If this is intentional and you have \
         verified there are no cross-stream data races, use \
         `then_unchecked`."
            .into(),
    ))
}
/// Clears the thread-local "a `DeviceOp` is executing" flag unconditionally.
pub(crate) fn release_execution_lock() {
    DEVICE_OP_EXECUTING.with(|executing| executing.set(false));
}
/// Identifier of a device; `ExecutionContext::new` fills it from the CUDA
/// context's `ordinal()`.
pub type Device = usize;
/// The stream, owning CUDA context, and device id that a `DeviceOp` executes against.
#[derive(Debug, Clone)]
pub struct ExecutionContext {
    // Ordinal of `cuda_context`, cached at construction.
    device: Device,
    // Stream the operation is executed on.
    cuda_stream: Arc<CudaStream>,
    // Context obtained from `cuda_stream.context()`.
    cuda_context: Arc<CudaContext>,
}
impl ExecutionContext {
    /// Builds a context for `cuda_stream`, deriving the CUDA context and the
    /// device ordinal from the stream itself.
    pub fn new(cuda_stream: Arc<CudaStream>) -> Self {
        let cuda_context = cuda_stream.context().clone();
        Self {
            // Read the ordinal before `cuda_context` is moved into its field.
            device: cuda_context.ordinal(),
            cuda_context,
            cuda_stream,
        }
    }

    /// Stream this context executes on.
    pub fn get_cuda_stream(&self) -> &Arc<CudaStream> {
        &self.cuda_stream
    }

    /// CUDA context that owns the stream.
    pub fn get_cuda_context(&self) -> &Arc<CudaContext> {
        &self.cuda_context
    }

    /// Ordinal of the device behind this context.
    pub fn get_device_id(&self) -> Device {
        self.device
    }

    /// Executes `op` directly against this context and returns its result.
    #[expect(
        dead_code,
        reason = "kept for direct synchronous execution in tests and future blocking APIs"
    )]
    fn execute<T: Send>(&self, op: impl DeviceOp<Output = T>) -> Result<T, DeviceError> {
        // NOTE(review): this safe wrapper forwards to the unsafe `DeviceOp::execute`
        // without taking the execution lock — confirm before making it reachable.
        let outcome = unsafe { op.execute(self) };
        outcome
    }
}
/// A unit of device work that can be composed with other operations, bound to
/// a stream, and executed.
///
/// Every `DeviceOp` is also `IntoFuture`, resolving to
/// `Result<Self::Output, DeviceError>`.
pub trait DeviceOp:
    Send + Sized + IntoFuture<Output = Result<<Self as DeviceOp>::Output, DeviceError>>
{
    /// Value produced when the operation succeeds.
    type Output: Send;

    /// Runs the operation against the stream/device carried by `context`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the precise contract is not spelled out in this file. The
    /// safe entry point here (`sync_on`) holds the thread-local execution lock
    /// around the call; callers bypassing that are presumably responsible for
    /// avoiding cross-stream data races — confirm.
    unsafe fn execute(
        self,
        context: &ExecutionContext,
    ) -> Result<<Self as DeviceOp>::Output, DeviceError>;

    /// Binds this operation to the next stream handed out by `policy` and
    /// returns the resulting future. Nothing is executed here.
    ///
    /// # Errors
    ///
    /// Propagates a failure of `policy.next_stream()`.
    fn schedule(
        self,
        policy: &Arc<dyn SchedulingPolicy>,
    ) -> Result<DeviceFuture<<Self as DeviceOp>::Output, Self>, DeviceError> {
        let stream = policy.next_stream()?;
        let mut future = DeviceFuture::new();
        future.device_operation = Some(self);
        future.execution_context = Some(ExecutionContext::new(stream));
        Ok(future)
    }

    /// Chains `f` after this operation: `f` receives this op's output and
    /// returns the next operation, which runs on the same context.
    fn then<O: Send, DO, F>(self, f: F) -> AndThen<<Self as DeviceOp>::Output, Self, O, DO, F>
    where
        DO: DeviceOp<Output = O>,
        F: FnOnce(<Self as DeviceOp>::Output) -> DO,
    {
        AndThen {
            op: self,
            closure: f,
        }
    }

    /// Builds the same `AndThen` as [`DeviceOp::then`].
    ///
    /// # Safety
    ///
    /// The non-reentrancy error raised by the execution lock points callers to
    /// this method once they have verified there are no cross-stream data
    /// races; upholding that is the caller's responsibility.
    unsafe fn then_unchecked<O: Send, DO, F>(
        self,
        f: F,
    ) -> AndThen<<Self as DeviceOp>::Output, Self, O, DO, F>
    where
        DO: DeviceOp<Output = O>,
        F: FnOnce(<Self as DeviceOp>::Output) -> DO,
    {
        AndThen {
            op: self,
            closure: f,
        }
    }

    /// Transforms the output with a plain function; the result is wrapped in
    /// a [`Value`] so it stays a `DeviceOp`.
    fn map<O: Send, F>(
        self,
        f: F,
    ) -> AndThen<
        <Self as DeviceOp>::Output,
        Self,
        O,
        Value<O>,
        impl FnOnce(<Self as DeviceOp>::Output) -> Value<O> + Send,
    >
    where
        F: FnOnce(<Self as DeviceOp>::Output) -> O + Send,
    {
        self.then(move |x| value(f(x)))
    }

    /// Calls `f` with a reference to the output and then passes the output
    /// through unchanged.
    fn inspect<F>(
        self,
        f: F,
    ) -> AndThen<
        <Self as DeviceOp>::Output,
        Self,
        <Self as DeviceOp>::Output,
        Value<<Self as DeviceOp>::Output>,
        impl FnOnce(<Self as DeviceOp>::Output) -> Value<<Self as DeviceOp>::Output> + Send,
    >
    where
        F: FnOnce(&<Self as DeviceOp>::Output) + Send,
    {
        self.map(move |x| {
            f(&x);
            x
        })
    }

    /// Like [`DeviceOp::then`], but `f` also receives the `ExecutionContext`
    /// the chain is running on.
    fn and_then_with_context<O: Send, DO, F>(
        self,
        f: F,
    ) -> AndThenWithContext<<Self as DeviceOp>::Output, Self, O, DO, F>
    where
        DO: DeviceOp<Output = O>,
        F: FnOnce(&ExecutionContext, <Self as DeviceOp>::Output) -> DO,
    {
        AndThenWithContext {
            op: self,
            closure: f,
        }
    }

    /// Type-erases this operation behind a boxed closure that calls `execute`.
    fn boxed(self) -> BoxedDeviceOp<<Self as DeviceOp>::Output>
    where
        Self: 'static,
    {
        BoxedDeviceOp {
            inner: Box::new(move |ctx| unsafe { self.execute(ctx) }),
        }
    }

    /// Wraps this operation in a cloneable [`SharedDeviceOp`]: the first
    /// execution stores the result and every execution yields it as an `Arc`.
    fn shared(self) -> SharedDeviceOp<<Self as DeviceOp>::Output>
    where
        Self: 'static,
        <Self as DeviceOp>::Output: Sync,
    {
        SharedDeviceOp {
            inner: Arc::new(SharedExec {
                computed: AtomicBool::new(false),
                op: UnsafeCell::new(Some(Box::new(move |ctx: &ExecutionContext| unsafe {
                    self.execute(ctx)
                }))),
                result: UnsafeCell::new(None),
            }),
        }
    }

    /// Hands this operation to `CudaGraph::capture` on the next stream of the
    /// default device policy.
    fn graph(
        self,
    ) -> Result<crate::cuda_graph::CudaGraph<<Self as DeviceOp>::Output>, DeviceError> {
        with_default_device_policy(|policy| {
            let stream = policy.next_stream()?;
            self.graph_on(stream)
        })?
    }

    /// Hands this operation to `CudaGraph::capture` on the given `stream`.
    fn graph_on(
        self,
        stream: Arc<CudaStream>,
    ) -> Result<crate::cuda_graph::CudaGraph<<Self as DeviceOp>::Output>, DeviceError> {
        crate::cuda_graph::CudaGraph::capture(stream, self)
    }

    /// Executes on the next stream of the default device policy and waits for
    /// that stream to synchronize (see [`DeviceOp::sync_on`]).
    fn sync(self) -> Result<<Self as DeviceOp>::Output, DeviceError> {
        with_default_device_policy(|policy| {
            let stream = policy.next_stream()?;
            self.sync_on(&stream)
        })?
    }

    /// Executes on `stream` without taking the execution lock and without
    /// synchronizing the stream afterwards.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller presumably must ensure the result is not used
    /// before the stream's work completes and that no conflicting op runs
    /// concurrently — confirm.
    unsafe fn async_on(
        self,
        stream: &Arc<CudaStream>,
    ) -> Result<<Self as DeviceOp>::Output, DeviceError> {
        let ctx = ExecutionContext::new(stream.clone());
        unsafe { self.execute(&ctx) }
    }

    /// Executes on `stream` while holding the thread-local execution lock,
    /// then synchronizes the stream.
    ///
    /// # Errors
    ///
    /// Fails if another `DeviceOp` is already executing on this thread, if
    /// stream synchronization fails (this takes precedence), or if the
    /// operation itself fails.
    fn sync_on(self, stream: &Arc<CudaStream>) -> Result<<Self as DeviceOp>::Output, DeviceError> {
        // Releases the execution lock when dropped, so an unwinding panic in
        // `execute` or `synchronize` cannot leave the thread-local flag stuck
        // at `true` (which would make every later `sync_on` on this thread fail).
        struct ReleaseOnDrop;
        impl Drop for ReleaseOnDrop {
            fn drop(&mut self) {
                release_execution_lock();
            }
        }

        acquire_execution_lock()?;
        let lock = ReleaseOnDrop;
        let ctx = ExecutionContext::new(stream.clone());
        // SAFETY: the execution lock is held for the duration of the call.
        let res = unsafe { self.execute(&ctx) };
        let sync_res = stream.synchronize();
        // Release before inspecting the results, matching the original ordering.
        drop(lock);
        sync_res?;
        res
    }
}
/// Marker trait on top of [`DeviceOp`]; it declares no items of its own.
// NOTE(review): presumably tags operations that may be used as CUDA graph nodes — confirm.
pub trait GraphNode: DeviceOp {}
/// A type-erased `DeviceOp` producing `T`, stored as a boxed one-shot closure
/// that is called with the `ExecutionContext`.
pub struct BoxedDeviceOp<T: Send> {
    inner: Box<dyn FnOnce(&ExecutionContext) -> Result<T, DeviceError> + Send>,
}
impl<T: Send> DeviceOp for BoxedDeviceOp<T> {
    type Output = T;

    /// Invokes the boxed closure with `context` and returns its result.
    unsafe fn execute(self, context: &ExecutionContext) -> Result<T, DeviceError> {
        let Self { inner: run } = self;
        run(context)
    }
}
impl<T: Send> IntoFuture for BoxedDeviceOp<T> {
    type Output = Result<T, DeviceError>;
    type IntoFuture = DeviceFuture<T, Self>;

    /// Binds the operation to the next stream of the default device policy;
    /// any failure while doing so yields an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| {
            let stream = policy.next_stream()?;
            let mut future = DeviceFuture::new();
            future.device_operation = Some(self);
            future.execution_context = Some(ExecutionContext::new(stream));
            Ok(future)
        });
        match scheduled {
            Ok(Ok(future)) => future,
            // Stream selection failed inside the closure.
            Ok(Err(err)) => DeviceFuture::failed(err),
            // The default policy itself could not be obtained/used.
            Err(err) => DeviceFuture::failed(err),
        }
    }
}
/// State shared by all clones of a `SharedDeviceOp`.
struct SharedExec<T: Send + Sync> {
    // Set (with `Release`) once `result` has been written.
    computed: AtomicBool,
    // The pending operation; taken out by the first execution.
    op: UnsafeCell<Option<Box<dyn FnOnce(&ExecutionContext) -> Result<T, DeviceError> + Send>>>,
    // The cached result, filled in by the first successful execution.
    result: UnsafeCell<Option<Arc<T>>>,
}
// NOTE(review): these impls assert thread-safety for the `UnsafeCell` fields.
// Nothing visible in this file serializes two clones executing concurrently on
// different threads (the execution lock is thread-local) — confirm that callers
// guarantee this.
unsafe impl<T: Send + Sync> Send for SharedExec<T> {}
unsafe impl<T: Send + Sync> Sync for SharedExec<T> {}
/// A cloneable `DeviceOp` whose clones share one `SharedExec`; executing it
/// yields the result as `Arc<T>`.
pub struct SharedDeviceOp<T: Send + Sync> {
    inner: Arc<SharedExec<T>>,
}
impl<T: Send + Sync> Clone for SharedDeviceOp<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<T: Send + Sync> DeviceOp for SharedDeviceOp<T> {
    type Output = Arc<T>;

    /// Runs the wrapped operation on first use and caches its result; later
    /// executions return a clone of the cached `Arc`.
    ///
    /// # Errors
    ///
    /// Returns `DeviceError::Internal` if the result is not yet cached but the
    /// operation has already been taken (e.g. a previous execution failed),
    /// or propagates the wrapped operation's own error.
    unsafe fn execute(self, context: &ExecutionContext) -> Result<Arc<T>, DeviceError> {
        let state = &*self.inner;
        if !state.computed.load(Ordering::Acquire) {
            // NOTE(review): exclusive access to `op` is assumed here, not enforced.
            let pending = unsafe { (&mut *state.op.get()).take() };
            let op = pending.ok_or_else(|| {
                DeviceError::Internal("SharedDeviceOp: operation already taken".to_string())
            })?;
            let produced = Arc::new(op(context)?);
            unsafe {
                *state.result.get() = Some(produced);
            }
            // Publish the write to `result` to later `Acquire` loads.
            state.computed.store(true, Ordering::Release);
        }
        let cached = unsafe { (&*state.result.get()).as_ref() };
        Ok(cached.unwrap().clone())
    }
}
impl<T: Send + Sync> IntoFuture for SharedDeviceOp<T> {
    type Output = Result<Arc<T>, DeviceError>;
    type IntoFuture = DeviceFuture<Arc<T>, SharedDeviceOp<T>>;

    /// Binds the operation to the next stream of the default device policy;
    /// any failure while doing so yields an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| {
            let stream = policy.next_stream()?;
            let mut future = DeviceFuture::new();
            future.device_operation = Some(self);
            future.execution_context = Some(ExecutionContext::new(stream));
            Ok(future)
        });
        match scheduled {
            Ok(Ok(future)) => future,
            Ok(Err(err)) => DeviceFuture::failed(err),
            Err(err) => DeviceFuture::failed(err),
        }
    }
}
/// Wraps an already-available `Arc<T>` as a `SharedDeviceOp` that is marked
/// computed from the start, so executing it just returns a clone of `val`.
pub fn shared<T: Send + Sync>(val: Arc<T>) -> SharedDeviceOp<T> {
    let ready = SharedExec {
        // No operation to run: the result is present up front.
        computed: AtomicBool::new(true),
        op: UnsafeCell::new(None),
        result: UnsafeCell::new(Some(val)),
    };
    SharedDeviceOp {
        inner: Arc::new(ready),
    }
}
/// Conversion of a value into a `DeviceOp` that produces `T`.
pub trait IntoDeviceOp<T: Send> {
    /// The concrete operation type produced by the conversion.
    type Op: DeviceOp<Output = T>;
    /// Consumes `self` and returns the operation.
    fn into_op(self) -> Self::Op;
}
/// Every `DeviceOp` converts into itself.
impl<T: Send, DO: DeviceOp<Output = T>> IntoDeviceOp<T> for DO {
    type Op = DO;
    fn into_op(self) -> DO {
        self
    }
}
impl<T: Send + Sync + 'static> IntoDeviceOp<Arc<T>> for Arc<T> {
type Op = Value<Arc<T>>;
fn into_op(self) -> Value<Arc<T>> {
value(self)
}
}
impl<T: Send + Sync + 'static> IntoDeviceOp<Arc<T>> for &Arc<T> {
type Op = Value<Arc<T>>;
fn into_op(self) -> Value<Arc<T>> {
value(self.clone())
}
}
// Generates `IntoDeviceOp<$ty> for $ty` for each listed type, converting the
// value into a `Value<$ty>` via `value(self)`.
macro_rules! impl_into_device_op_scalar {
    ($($ty:ty),*) => {
        $(
            impl IntoDeviceOp<$ty> for $ty {
                type Op = Value<$ty>;
                fn into_op(self) -> Value<$ty> { value(self) }
            }
        )*
    };
}
// Scalar types that can be passed directly where a `DeviceOp` is expected.
impl_into_device_op_scalar!(
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    usize,
    bool,
    half::f16,
    half::bf16
);
impl<T: cuda_core::DType + Send> IntoDeviceOp<crate::device_buffer::DevicePointer<T>>
for crate::device_buffer::DevicePointer<T>
{
type Op = Value<crate::device_buffer::DevicePointer<T>>;
fn into_op(self) -> Value<crate::device_buffer::DevicePointer<T>> {
value(self)
}
}
/// Extension for operations that produce an `Arc<T>`.
pub trait DeviceOpUnwrapArc<T: Send + Sync>: DeviceOp<Output = Arc<T>> + Sized {
    /// Chains a step that takes the `T` out of the produced `Arc`.
    ///
    /// # Panics
    ///
    /// The chained step panics at execution time if the `Arc` has other owners.
    fn unwrap_arc(
        self,
    ) -> AndThen<Arc<T>, Self, T, Value<T>, impl FnOnce(Arc<T>) -> Value<T> + Send> {
        self.then(|shared| match Arc::try_unwrap(shared) {
            Ok(inner) => value(inner),
            Err(_) => panic!("unwrap_arc: Arc has multiple owners"),
        })
    }
}
/// Makes `unwrap_arc` available on every `DeviceOp` whose output is an `Arc<T>`.
impl<T: Send + Sync, DI: DeviceOp<Output = Arc<T>>> DeviceOpUnwrapArc<T> for DI {}
/// Sequential composition: run `op`, pass its output to `closure`, then run
/// the operation the closure returns.
pub struct AndThen<I: Send, DI, O: Send, DO, F>
where
    DI: DeviceOp<Output = I>,
    DO: DeviceOp<Output = O>,
    F: FnOnce(I) -> DO,
{
    // First operation in the chain.
    op: DI,
    // Builds the follow-up operation from the first one's output.
    closure: F,
}
// Both stored fields are `Send` under these bounds (`DI: DeviceOp` implies
// `Send`, and `F: Send` is required explicitly).
unsafe impl<I: Send, DI, O: Send, DO, F> Send for AndThen<I, DI, O, DO, F>
where
    DI: DeviceOp<Output = I>,
    DO: DeviceOp<Output = O>,
    F: FnOnce(I) -> DO + Send,
{
}
impl<I: Send, DI, O: Send, DO, F> DeviceOp for AndThen<I, DI, O, DO, F>
where
    DI: DeviceOp<Output = I>,
    DO: DeviceOp<Output = O>,
    F: FnOnce(I) -> DO + Send,
{
    type Output = O;

    /// Executes the first operation, feeds its output to the closure, and
    /// executes the resulting operation on the same context. An error from
    /// the first operation short-circuits before the closure is called.
    unsafe fn execute(
        self,
        context: &ExecutionContext,
    ) -> Result<<Self as DeviceOp>::Output, DeviceError> {
        let Self { op: first, closure: build_next } = self;
        // SAFETY: forwarded under the contract our own caller accepted.
        let intermediate: I = unsafe { first.execute(context)? };
        let next: DO = build_next(intermediate);
        // SAFETY: same context and contract as above.
        unsafe { next.execute(context) }
    }
}
impl<I: Send, DI, O: Send, DO, F> IntoFuture for AndThen<I, DI, O, DO, F>
where
    DI: DeviceOp<Output = I>,
    DO: DeviceOp<Output = O>,
    F: FnOnce(I) -> DO + Send,
{
    type Output = Result<O, DeviceError>;
    type IntoFuture = DeviceFuture<O, AndThen<I, DI, O, DO, F>>;

    /// Binds the chain to the next stream of the default device policy;
    /// any failure while doing so yields an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| {
            let stream = policy.next_stream()?;
            let mut future = DeviceFuture::new();
            future.device_operation = Some(self);
            future.execution_context = Some(ExecutionContext::new(stream));
            Ok(future)
        });
        match scheduled {
            Ok(Ok(future)) => future,
            Ok(Err(err)) => DeviceFuture::failed(err),
            Err(err) => DeviceFuture::failed(err),
        }
    }
}
/// A trivially-ready operation that just carries an already-computed value.
///
/// `Value<T>` is `Send` exactly when `T` is `Send`: this is left to the
/// compiler's auto-trait rules. A blanket `unsafe impl<T> Send` would be
/// unsound because [`Value::new`] accepts any `T`, including thread-bound
/// types such as `Rc`.
pub struct Value<T>(T);

impl<T> Value<T> {
    /// Wraps `value` without any further processing.
    pub fn new(value: T) -> Self {
        Self(value)
    }
}
impl<T: Send> DeviceOp for Value<T> {
    type Output = T;

    /// Returns the wrapped value; the context is not used.
    unsafe fn execute(
        self,
        _context: &ExecutionContext,
    ) -> Result<<Self as DeviceOp>::Output, DeviceError> {
        let Self(payload) = self;
        Ok(payload)
    }
}
/// `Value` carries the `GraphNode` marker.
impl<T: Send> GraphNode for Value<T> {}
impl<T: Send> IntoFuture for Value<T> {
    type Output = Result<T, DeviceError>;
    type IntoFuture = DeviceFuture<T, Value<T>>;

    /// Binds the value to the next stream of the default device policy;
    /// any failure while doing so yields an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| {
            let stream = policy.next_stream()?;
            let mut future = DeviceFuture::new();
            future.device_operation = Some(self);
            future.execution_context = Some(ExecutionContext::new(stream));
            Ok(future)
        });
        match scheduled {
            Ok(Ok(future)) => future,
            Ok(Err(err)) => DeviceFuture::failed(err),
            Err(err) => DeviceFuture::failed(err),
        }
    }
}
pub fn value<T: Send>(x: T) -> Value<T> {
Value::new(x)
}
impl From<f32> for Value<f32> {
fn from(val: f32) -> Self {
Value::new(val)
}
}
/// An operation with no input: at execution time it calls `closure` to build
/// the real operation and then runs that.
pub struct Empty<O: Send, DO: DeviceOp<Output = O>, F: FnOnce() -> DO> {
    closure: F,
}
/// Creates an [`Empty`] operation that defers building its inner operation to `closure`.
pub fn empty<O: Send, DO: DeviceOp<Output = O>, F: FnOnce() -> DO>(closure: F) -> Empty<O, DO, F> {
    Empty { closure }
}
// NOTE(review): unlike the other combinators in this file, `F` is not bounded
// by `Send` here, so an `Empty` whose closure captures non-`Send` state is
// still treated as `Send`. Confirm this is intentional; tightening the bound
// would be a breaking change for such callers.
unsafe impl<O: Send, DO, F> Send for Empty<O, DO, F>
where
    DO: DeviceOp<Output = O>,
    F: FnOnce() -> DO,
{
}
impl<O: Send, DO, F> DeviceOp for Empty<O, DO, F>
where
    DO: DeviceOp<Output = O>,
    F: FnOnce() -> DO,
{
    type Output = O;

    /// Calls the stored closure to obtain the inner operation and executes it
    /// on `context`.
    unsafe fn execute(
        self,
        context: &ExecutionContext,
    ) -> Result<<Self as DeviceOp>::Output, DeviceError> {
        let Self { closure: build } = self;
        let inner = build();
        // SAFETY: forwarded under the contract our own caller accepted.
        unsafe { inner.execute(context) }
    }
}
impl<O: Send, DO: DeviceOp<Output = O>, F: FnOnce() -> DO> IntoFuture for Empty<O, DO, F> {
    type Output = Result<O, DeviceError>;
    type IntoFuture = DeviceFuture<O, Empty<O, DO, F>>;

    /// Binds the operation to the next stream of the default device policy;
    /// any failure while doing so yields an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| {
            let stream = policy.next_stream()?;
            let mut future = DeviceFuture::new();
            future.device_operation = Some(self);
            future.execution_context = Some(ExecutionContext::new(stream));
            Ok(future)
        });
        match scheduled {
            Ok(Ok(future)) => future,
            Ok(Err(err)) => DeviceFuture::failed(err),
            Err(err) => DeviceFuture::failed(err),
        }
    }
}
/// Pairs two operations; executing it runs `a` and then `b` on the same
/// context and yields both outputs as a tuple.
pub struct Zip<T1: Send, T2: Send, A: DeviceOp<Output = T1>, B: DeviceOp<Output = T2>> {
    // Marker tying the output types to the struct; holds no data.
    phantom: PhantomData<(T1, T2)>,
    a: A,
    b: B,
}
// `A` and `B` are `Send` through the `DeviceOp` supertrait, and `T1`/`T2` are
// bounded by `Send` directly.
unsafe impl<T1: Send, T2: Send, A: DeviceOp<Output = T1>, B: DeviceOp<Output = T2>> Send
    for Zip<T1, T2, A, B>
{
}
/// Internal constructor for [`Zip`]; the tuple `Zippable` impls build on it.
fn _zip<T1: Send, T2: Send, A: DeviceOp<Output = T1>, B: DeviceOp<Output = T2>>(
    a: A,
    b: B,
) -> Zip<T1, T2, A, B> {
    let phantom = PhantomData;
    Zip { a, b, phantom }
}
impl<T1: Send, T2: Send, A: DeviceOp<Output = T1>, B: DeviceOp<Output = T2>> DeviceOp
    for Zip<T1, T2, A, B>
{
    type Output = (T1, T2);

    /// Executes `a`, then `b`, on the same context. If `a` fails, `b` is
    /// never executed.
    unsafe fn execute(
        self,
        context: &ExecutionContext,
    ) -> Result<<Self as DeviceOp>::Output, DeviceError> {
        let Self { a: lhs, b: rhs, .. } = self;
        // SAFETY: forwarded under the contract our own caller accepted.
        let left: T1 = unsafe { lhs.execute(context)? };
        let right: T2 = unsafe { rhs.execute(context)? };
        Ok((left, right))
    }
}
impl<T1: Send, T2: Send, A: DeviceOp<Output = T1>, B: DeviceOp<Output = T2>> IntoFuture
    for Zip<T1, T2, A, B>
{
    type Output = Result<(T1, T2), DeviceError>;
    type IntoFuture = DeviceFuture<(T1, T2), Zip<T1, T2, A, B>>;

    /// Binds the pair to the next stream of the default device policy;
    /// any failure while doing so yields an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| {
            let stream = policy.next_stream()?;
            let mut future = DeviceFuture::new();
            future.device_operation = Some(self);
            future.execution_context = Some(ExecutionContext::new(stream));
            Ok(future)
        });
        match scheduled {
            Ok(Ok(future)) => future,
            Ok(Err(err)) => DeviceFuture::failed(err),
            Err(err) => DeviceFuture::failed(err),
        }
    }
}
/// Implemented for tuples of operations: combines them into one operation
/// whose output `O` is the tuple of the individual outputs.
pub trait Zippable<I, O: Send> {
    /// Consumes the tuple and returns the combined operation.
    fn zip(self) -> impl DeviceOp<Output = O>;
}
/// Two operations zip directly into a [`Zip`].
impl<T0: Send, T1: Send, DI0: DeviceOp<Output = T0>, DI1: DeviceOp<Output = T1>>
    Zippable<(DI0, DI1), (T0, T1)> for (DI0, DI1)
{
    fn zip(self) -> impl DeviceOp<Output = (T0, T1)> {
        let (first, second) = self;
        _zip(first, second)
    }
}
/// Three operations: zipped as right-nested pairs, then flattened to a 3-tuple.
impl<
        T0: Send,
        T1: Send,
        T2: Send,
        DI0: DeviceOp<Output = T0>,
        DI1: DeviceOp<Output = T1>,
        DI2: DeviceOp<Output = T2>,
    > Zippable<(DI0, DI1, DI2), (T0, T1, T2)> for (DI0, DI1, DI2)
{
    fn zip(self) -> impl DeviceOp<Output = (T0, T1, T2)> {
        let (op0, op1, op2) = self;
        let nested = _zip(op0, _zip(op1, op2));
        nested.then(|(v0, (v1, v2))| value((v0, v1, v2)))
    }
}
/// Four operations: zipped as right-nested pairs, then flattened to a 4-tuple.
impl<
        T0: Send,
        T1: Send,
        T2: Send,
        T3: Send,
        DI0: DeviceOp<Output = T0>,
        DI1: DeviceOp<Output = T1>,
        DI2: DeviceOp<Output = T2>,
        DI3: DeviceOp<Output = T3>,
    > Zippable<(DI0, DI1, DI2, DI3), (T0, T1, T2, T3)> for (DI0, DI1, DI2, DI3)
{
    fn zip(self) -> impl DeviceOp<Output = (T0, T1, T2, T3)> {
        let (op0, op1, op2, op3) = self;
        let nested = _zip(op0, _zip(op1, _zip(op2, op3)));
        nested.then(|(v0, (v1, (v2, v3)))| value((v0, v1, v2, v3)))
    }
}
/// Five operations: zipped as right-nested pairs, then flattened to a 5-tuple.
impl<
        T0: Send,
        T1: Send,
        T2: Send,
        T3: Send,
        T4: Send,
        DI0: DeviceOp<Output = T0>,
        DI1: DeviceOp<Output = T1>,
        DI2: DeviceOp<Output = T2>,
        DI3: DeviceOp<Output = T3>,
        DI4: DeviceOp<Output = T4>,
    > Zippable<(DI0, DI1, DI2, DI3, DI4), (T0, T1, T2, T3, T4)> for (DI0, DI1, DI2, DI3, DI4)
{
    fn zip(self) -> impl DeviceOp<Output = (T0, T1, T2, T3, T4)> {
        let (op0, op1, op2, op3, op4) = self;
        let nested = _zip(op0, _zip(op1, _zip(op2, _zip(op3, op4))));
        nested.then(|(v0, (v1, (v2, (v3, v4))))| value((v0, v1, v2, v3, v4)))
    }
}
/// Six operations: zipped as right-nested pairs, then flattened to a 6-tuple.
impl<
        T0: Send,
        T1: Send,
        T2: Send,
        T3: Send,
        T4: Send,
        T5: Send,
        DI0: DeviceOp<Output = T0>,
        DI1: DeviceOp<Output = T1>,
        DI2: DeviceOp<Output = T2>,
        DI3: DeviceOp<Output = T3>,
        DI4: DeviceOp<Output = T4>,
        DI5: DeviceOp<Output = T5>,
    > Zippable<(DI0, DI1, DI2, DI3, DI4, DI5), (T0, T1, T2, T3, T4, T5)>
    for (DI0, DI1, DI2, DI3, DI4, DI5)
{
    fn zip(self) -> impl DeviceOp<Output = (T0, T1, T2, T3, T4, T5)> {
        let (op0, op1, op2, op3, op4, op5) = self;
        let nested = _zip(op0, _zip(op1, _zip(op2, _zip(op3, _zip(op4, op5)))));
        nested.then(|(v0, (v1, (v2, (v3, (v4, v5)))))| {
            value((v0, v1, v2, v3, v4, v5))
        })
    }
}
/// Combines one to six operations into a single operation.
///
/// With one argument the expression is returned unchanged; with two to six
/// arguments the arguments are packed into a tuple and `.zip()` is called on
/// it, so the `Zippable` trait must be in scope at the call site.
#[macro_export]
macro_rules! zip {
    ($arg0:expr) => {
        $arg0
    };
    ($arg0:expr, $arg1:expr) => {
        ($arg0, $arg1).zip()
    };
    ($arg0:expr, $arg1:expr, $arg2:expr) => {
        ($arg0, $arg1, $arg2).zip()
    };
    ($arg0:expr, $arg1:expr, $arg2:expr, $arg3:expr) => {
        ($arg0, $arg1, $arg2, $arg3).zip()
    };
    ($arg0:expr, $arg1:expr, $arg2:expr, $arg3:expr, $arg4:expr) => {
        ($arg0, $arg1, $arg2, $arg3, $arg4).zip()
    };
    ($arg0:expr, $arg1:expr, $arg2:expr, $arg3:expr, $arg4:expr, $arg5:expr) => {
        ($arg0, $arg1, $arg2, $arg3, $arg4, $arg5).zip()
    };
}
pub use zip;
/// Splits a pair-producing operation into two operations that share a single
/// [`Select`]: the underlying `input` is stored once and each half later
/// pulls its own element out of the shared state.
fn _unzip<T1: Send, T2: Send, DI>(input: DI) -> (SelectLeft<T1, T2, DI>, SelectRight<T1, T2, DI>)
where
    DI: DeviceOp<Output = (T1, T2)>,
{
    let shared = Arc::new(Select {
        computed: AtomicBool::new(false),
        input: UnsafeCell::new(Some(input)),
        left: UnsafeCell::new(None),
        right: UnsafeCell::new(None),
    });
    let left_half = SelectLeft {
        select: Arc::clone(&shared),
    };
    let right_half = SelectRight { select: shared };
    (left_half, right_half)
}
/// Shared state behind a `SelectLeft`/`SelectRight` pair created by `_unzip`.
pub struct Select<T1: Send, T2: Send, DI>
where
    DI: DeviceOp<Output = (T1, T2)>,
{
    // Set once `left` and `right` have been filled in.
    computed: AtomicBool,
    // The pair-producing operation; taken out by the first execution.
    input: UnsafeCell<Option<DI>>,
    // First element of the pair, until `left()` takes it.
    left: UnsafeCell<Option<T1>>,
    // Second element of the pair, until `right()` takes it.
    right: UnsafeCell<Option<T2>>,
}
impl<T1: Send, T2: Send, DI> Select<T1, T2, DI>
where
    DI: DeviceOp<Output = (T1, T2)>,
{
    /// Ensures the underlying operation has run and both halves are stored.
    /// Does nothing when the results were already computed.
    ///
    /// # Errors
    ///
    /// Fails with a device error if the results are not computed but the
    /// input has already been taken, or propagates the input's own error.
    unsafe fn execute(self: &Arc<Self>, context: &ExecutionContext) -> Result<(), DeviceError> {
        if self.computed.load(Ordering::Acquire) {
            return Ok(());
        }
        // NOTE(review): exclusive access to the cells is assumed, not enforced.
        let pending = unsafe { (&mut *self.input.get()).take() };
        let missing = device_error(context.get_device_id(), "Select operation failed.");
        let input = pending.ok_or(missing)?;
        let (first, second) = unsafe { input.execute(context)? };
        unsafe {
            *self.left.get() = Some(first);
            *self.right.get() = Some(second);
        }
        // Publish the stored halves to later `Acquire` loads.
        self.computed.store(true, Ordering::Release);
        Ok(())
    }

    /// Takes the first element out of the shared state.
    ///
    /// Panics if it is absent (not computed yet, or already taken).
    unsafe fn left(&self) -> T1 {
        let slot = unsafe { &mut *self.left.get() };
        slot.take().unwrap()
    }

    /// Takes the second element out of the shared state.
    ///
    /// Panics if it is absent (not computed yet, or already taken).
    unsafe fn right(&self) -> T2 {
        let slot = unsafe { &mut *self.right.get() };
        slot.take().unwrap()
    }
}
/// The half of an unzipped pair that yields the first element (`T1`).
pub struct SelectLeft<T1: Send, T2: Send, DI>
where
    DI: DeviceOp<Output = (T1, T2)>,
{
    select: Arc<Select<T1, T2, DI>>,
}
// NOTE(review): `Select` holds `UnsafeCell`s shared with the matching
// `SelectRight`; this impl relies on the two halves never executing
// concurrently on different threads — confirm.
unsafe impl<T1: Send, T2: Send, DI: DeviceOp<Output = (T1, T2)>> Send for SelectLeft<T1, T2, DI> {}
impl<T1: Send, T2: Send, DI> IntoFuture for SelectLeft<T1, T2, DI>
where
    DI: DeviceOp<Output = (T1, T2)>,
{
    type Output = Result<T1, DeviceError>;
    type IntoFuture = DeviceFuture<T1, SelectLeft<T1, T2, DI>>;

    /// Binds the operation to the next stream of the default device policy;
    /// any failure while doing so yields an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| {
            let stream = policy.next_stream()?;
            let mut future = DeviceFuture::new();
            future.device_operation = Some(self);
            future.execution_context = Some(ExecutionContext::new(stream));
            Ok(future)
        });
        match scheduled {
            Ok(Ok(future)) => future,
            Ok(Err(err)) => DeviceFuture::failed(err),
            Err(err) => DeviceFuture::failed(err),
        }
    }
}
impl<T1: Send, T2: Send, DI> DeviceOp for SelectLeft<T1, T2, DI>
where
    DI: DeviceOp<Output = (T1, T2)>,
{
    type Output = T1;

    /// Makes sure the shared pair has been computed, then takes its first element.
    unsafe fn execute(
        self,
        context: &ExecutionContext,
    ) -> Result<<Self as DeviceOp>::Output, DeviceError> {
        let shared = &self.select;
        // SAFETY: forwarded under the contract our own caller accepted.
        unsafe { shared.execute(context)? };
        let first = unsafe { shared.left() };
        Ok(first)
    }
}
/// The half of an unzipped pair that yields the second element (`T2`).
pub struct SelectRight<T1: Send, T2: Send, DI>
where
    DI: DeviceOp<Output = (T1, T2)>,
{
    select: Arc<Select<T1, T2, DI>>,
}
// NOTE(review): `Select` holds `UnsafeCell`s shared with the matching
// `SelectLeft`; this impl relies on the two halves never executing
// concurrently on different threads — confirm.
unsafe impl<T1: Send, T2: Send, DI: DeviceOp<Output = (T1, T2)>> Send for SelectRight<T1, T2, DI> {}
impl<T1: Send, T2: Send, DI> IntoFuture for SelectRight<T1, T2, DI>
where
    DI: DeviceOp<Output = (T1, T2)>,
{
    type Output = Result<T2, DeviceError>;
    type IntoFuture = DeviceFuture<T2, SelectRight<T1, T2, DI>>;

    /// Binds the operation to the next stream of the default device policy;
    /// any failure while doing so yields an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| {
            let stream = policy.next_stream()?;
            let mut future = DeviceFuture::new();
            future.device_operation = Some(self);
            future.execution_context = Some(ExecutionContext::new(stream));
            Ok(future)
        });
        match scheduled {
            Ok(Ok(future)) => future,
            Ok(Err(err)) => DeviceFuture::failed(err),
            Err(err) => DeviceFuture::failed(err),
        }
    }
}
impl<T1: Send, T2: Send, DI> DeviceOp for SelectRight<T1, T2, DI>
where
    DI: DeviceOp<Output = (T1, T2)>,
{
    type Output = T2;

    /// Makes sure the shared pair has been computed, then takes its second element.
    unsafe fn execute(
        self,
        context: &ExecutionContext,
    ) -> Result<<Self as DeviceOp>::Output, DeviceError> {
        let shared = &self.select;
        // SAFETY: forwarded under the contract our own caller accepted.
        unsafe { shared.execute(context)? };
        let second = unsafe { shared.right() };
        Ok(second)
    }
}
/// `unzip` for operations producing a 1-tuple.
pub trait Unzippable1<T0: Send>
where
    Self: DeviceOp<Output = (T0,)>,
{
    /// Returns a 1-tuple holding an operation that yields the single element.
    fn unzip(self) -> (impl DeviceOp<Output = T0>,) {
        let only = self.then(|(elem,)| value(elem));
        (only,)
    }
}
/// Available on every operation whose output is a 1-tuple.
impl<T0: Send, DI: DeviceOp<Output = (T0,)>> Unzippable1<T0> for DI {}
/// Tuple helpers for operations producing a pair.
pub trait Unzippable2<T0: Send, T1: Send>
where
    Self: DeviceOp<Output = (T0, T1)>,
{
    /// Splits into one operation per element; both share the underlying pair.
    fn unzip(self) -> (impl DeviceOp<Output = T0>, impl DeviceOp<Output = T1>) {
        let halves = _unzip(self);
        halves
    }

    /// Keeps only the first element of the output.
    fn first(self) -> impl DeviceOp<Output = T0>
    where
        Self: Sized,
    {
        self.then(|(head, _)| value(head))
    }

    /// Keeps only the last element of the output.
    fn last(self) -> impl DeviceOp<Output = T1>
    where
        Self: Sized,
    {
        self.then(|(_, tail)| value(tail))
    }
}
/// Available on every operation whose output is a pair.
impl<T0: Send, T1: Send, DI: DeviceOp<Output = (T0, T1)>> Unzippable2<T0, T1> for DI {}
/// Tuple helpers for operations producing a 3-tuple.
pub trait Unzippable3<T0: Send, T1: Send, T2: Send>
where
    Self: DeviceOp<Output = (T0, T1, T2)>,
{
    /// Splits into one operation per element by re-nesting the tuple as
    /// right-nested pairs and peeling one element off at a time.
    fn unzip(
        self,
    ) -> (
        impl DeviceOp<Output = T0>,
        impl DeviceOp<Output = T1>,
        impl DeviceOp<Output = T2>,
    ) {
        let nested = self.then(|(v0, v1, v2)| value((v0, (v1, v2))));
        let (out0, rest) = _unzip(nested);
        let (out1, out2) = _unzip(rest);
        (out0, out1, out2)
    }

    /// Keeps only the first element of the output.
    fn first(self) -> impl DeviceOp<Output = T0>
    where
        Self: Sized,
    {
        self.then(|(head, _, _)| value(head))
    }

    /// Keeps only the last element of the output.
    fn last(self) -> impl DeviceOp<Output = T2>
    where
        Self: Sized,
    {
        self.then(|(_, _, tail)| value(tail))
    }
}
/// Available on every operation whose output is a 3-tuple.
impl<T0: Send, T1: Send, T2: Send, DI: DeviceOp<Output = (T0, T1, T2)>> Unzippable3<T0, T1, T2>
    for DI
{
}
/// Tuple helpers for operations producing a 4-tuple.
pub trait Unzippable4<T0: Send, T1: Send, T2: Send, T3: Send>
where
    Self: DeviceOp<Output = (T0, T1, T2, T3)>,
{
    /// Splits into one operation per element by re-nesting the tuple as
    /// right-nested pairs and peeling one element off at a time.
    fn unzip(
        self,
    ) -> (
        impl DeviceOp<Output = T0>,
        impl DeviceOp<Output = T1>,
        impl DeviceOp<Output = T2>,
        impl DeviceOp<Output = T3>,
    ) {
        let nested = self.then(|(v0, v1, v2, v3)| value((v0, (v1, (v2, v3)))));
        let (out0, rest) = _unzip(nested);
        let (out1, rest) = _unzip(rest);
        let (out2, out3) = _unzip(rest);
        (out0, out1, out2, out3)
    }

    /// Keeps only the first element of the output.
    fn first(self) -> impl DeviceOp<Output = T0>
    where
        Self: Sized,
    {
        self.then(|(head, _, _, _)| value(head))
    }

    /// Keeps only the last element of the output.
    fn last(self) -> impl DeviceOp<Output = T3>
    where
        Self: Sized,
    {
        self.then(|(_, _, _, tail)| value(tail))
    }
}
/// Available on every operation whose output is a 4-tuple.
impl<T0: Send, T1: Send, T2: Send, T3: Send, DI: DeviceOp<Output = (T0, T1, T2, T3)>>
    Unzippable4<T0, T1, T2, T3> for DI
{
}
/// Tuple helpers for operations producing a 5-tuple.
pub trait Unzippable5<T0: Send, T1: Send, T2: Send, T3: Send, T4: Send>
where
    Self: DeviceOp<Output = (T0, T1, T2, T3, T4)>,
{
    /// Splits into one operation per element by re-nesting the tuple as
    /// right-nested pairs and peeling one element off at a time.
    fn unzip(
        self,
    ) -> (
        impl DeviceOp<Output = T0>,
        impl DeviceOp<Output = T1>,
        impl DeviceOp<Output = T2>,
        impl DeviceOp<Output = T3>,
        impl DeviceOp<Output = T4>,
    ) {
        let nested = self.then(|(v0, v1, v2, v3, v4)| value((v0, (v1, (v2, (v3, v4))))));
        let (out0, rest) = _unzip(nested);
        let (out1, rest) = _unzip(rest);
        let (out2, rest) = _unzip(rest);
        let (out3, out4) = _unzip(rest);
        (out0, out1, out2, out3, out4)
    }

    /// Keeps only the first element of the output.
    fn first(self) -> impl DeviceOp<Output = T0>
    where
        Self: Sized,
    {
        self.then(|(head, _, _, _, _)| value(head))
    }

    /// Keeps only the last element of the output.
    fn last(self) -> impl DeviceOp<Output = T4>
    where
        Self: Sized,
    {
        self.then(|(_, _, _, _, tail)| value(tail))
    }
}
/// Available on every operation whose output is a 5-tuple.
impl<
        T0: Send,
        T1: Send,
        T2: Send,
        T3: Send,
        T4: Send,
        DI: DeviceOp<Output = (T0, T1, T2, T3, T4)>,
    > Unzippable5<T0, T1, T2, T3, T4> for DI
{
}
/// Tuple helpers for operations producing a 6-tuple.
pub trait Unzippable6<T0: Send, T1: Send, T2: Send, T3: Send, T4: Send, T5: Send>
where
    Self: DeviceOp<Output = (T0, T1, T2, T3, T4, T5)>,
{
    /// Splits into one operation per element by re-nesting the tuple as
    /// right-nested pairs and peeling one element off at a time.
    fn unzip(
        self,
    ) -> (
        impl DeviceOp<Output = T0>,
        impl DeviceOp<Output = T1>,
        impl DeviceOp<Output = T2>,
        impl DeviceOp<Output = T3>,
        impl DeviceOp<Output = T4>,
        impl DeviceOp<Output = T5>,
    ) {
        let nested =
            self.then(|(v0, v1, v2, v3, v4, v5)| value((v0, (v1, (v2, (v3, (v4, v5)))))));
        let (out0, rest) = _unzip(nested);
        let (out1, rest) = _unzip(rest);
        let (out2, rest) = _unzip(rest);
        let (out3, rest) = _unzip(rest);
        let (out4, out5) = _unzip(rest);
        (out0, out1, out2, out3, out4, out5)
    }

    /// Keeps only the first element of the output.
    fn first(self) -> impl DeviceOp<Output = T0>
    where
        Self: Sized,
    {
        self.then(|(head, _, _, _, _, _)| value(head))
    }

    /// Keeps only the last element of the output.
    fn last(self) -> impl DeviceOp<Output = T5>
    where
        Self: Sized,
    {
        self.then(|(_, _, _, _, _, tail)| value(tail))
    }
}
/// Available on every operation whose output is a 6-tuple.
impl<
        T0: Send,
        T1: Send,
        T2: Send,
        T3: Send,
        T4: Send,
        T5: Send,
        DI: DeviceOp<Output = (T0, T1, T2, T3, T4, T5)>,
    > Unzippable6<T0, T1, T2, T3, T4, T5> for DI
{
}
/// Calls `.unzip()` on the given expression; the matching `UnzippableN` trait
/// must be in scope at the call site.
#[macro_export]
macro_rules! unzip {
    ($arg0:expr) => {
        $arg0.unzip()
    };
}
pub use unzip;
/// An operation built lazily from the `ExecutionContext`: `f` is called with
/// the context at execution time and the operation it returns is then run.
pub struct StreamOperation<
    O: Send,
    DO: DeviceOp<Output = O>,
    F: FnOnce(&ExecutionContext) -> DO + Send,
> {
    f: F,
}
impl<O: Send, DO: DeviceOp<Output = O>, F: FnOnce(&ExecutionContext) -> DO + Send> DeviceOp
    for StreamOperation<O, DO, F>
{
    type Output = O;

    /// Builds the inner operation from `context` and executes it on that
    /// same context.
    unsafe fn execute(
        self,
        context: &ExecutionContext,
    ) -> Result<<Self as DeviceOp>::Output, DeviceError> {
        let Self { f: build } = self;
        let inner: DO = build(context);
        // SAFETY: forwarded under the contract our own caller accepted.
        unsafe { inner.execute(context) }
    }
}
/// Creates an operation whose body is produced by `f` from the
/// `ExecutionContext` it ends up running on (see `StreamOperation`).
pub fn with_context<
    O: Send,
    DO: DeviceOp<Output = O>,
    F: FnOnce(&ExecutionContext) -> DO + Send,
>(
    f: F,
) -> impl DeviceOp<Output = O> {
    StreamOperation { f }
}
impl<O: Send, DO: DeviceOp<Output = O>, F: FnOnce(&ExecutionContext) -> DO + Send> IntoFuture
    for StreamOperation<O, DO, F>
{
    type Output = Result<O, DeviceError>;
    type IntoFuture = DeviceFuture<O, StreamOperation<O, DO, F>>;

    /// Binds the operation to the next stream of the default device policy;
    /// any failure while doing so yields an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| {
            let stream = policy.next_stream()?;
            let mut future = DeviceFuture::new();
            future.device_operation = Some(self);
            future.execution_context = Some(ExecutionContext::new(stream));
            Ok(future)
        });
        match scheduled {
            Ok(Ok(future)) => future,
            Ok(Err(err)) => DeviceFuture::failed(err),
            Err(err) => DeviceFuture::failed(err),
        }
    }
}
/// Sequential composition like `AndThen`, except the closure also receives
/// the `ExecutionContext`.
pub struct AndThenWithContext<I: Send, DI, O: Send, DO, F>
where
    DI: DeviceOp<Output = I>,
    DO: DeviceOp<Output = O>,
    F: FnOnce(&ExecutionContext, I) -> DO,
{
    // First operation in the chain.
    op: DI,
    // Builds the follow-up operation from the context and the first output.
    closure: F,
}
// Both stored fields are `Send` under these bounds (`DI: DeviceOp` implies
// `Send`, and `F: Send` is required explicitly).
unsafe impl<I: Send, DI, O: Send, DO, F> Send for AndThenWithContext<I, DI, O, DO, F>
where
    DI: DeviceOp<Output = I>,
    DO: DeviceOp<Output = O>,
    F: FnOnce(&ExecutionContext, I) -> DO + Send,
{
}
impl<I: Send, DI, O: Send, DO, F> DeviceOp for AndThenWithContext<I, DI, O, DO, F>
where
    DI: DeviceOp<Output = I>,
    DO: DeviceOp<Output = O>,
    F: FnOnce(&ExecutionContext, I) -> DO + Send,
{
    type Output = O;

    /// Executes the first operation, passes the context and its output to the
    /// closure, and executes the resulting operation on the same context.
    unsafe fn execute(
        self,
        context: &ExecutionContext,
    ) -> Result<<Self as DeviceOp>::Output, DeviceError> {
        let Self { op: first, closure: build_next } = self;
        // SAFETY: forwarded under the contract our own caller accepted.
        let intermediate: I = unsafe { first.execute(context)? };
        let next: DO = build_next(context, intermediate);
        // SAFETY: same context and contract as above.
        unsafe { next.execute(context) }
    }
}
impl<I: Send, DI, O: Send, DO, F> IntoFuture for AndThenWithContext<I, DI, O, DO, F>
where
    DI: DeviceOp<Output = I>,
    DO: DeviceOp<Output = O>,
    F: FnOnce(&ExecutionContext, I) -> DO + Send,
{
    type Output = Result<O, DeviceError>;
    type IntoFuture = DeviceFuture<O, AndThenWithContext<I, DI, O, DO, F>>;

    /// Binds the chain to the next stream of the default device policy;
    /// any failure while doing so yields an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| {
            let stream = policy.next_stream()?;
            let mut future = DeviceFuture::new();
            future.device_operation = Some(self);
            future.execution_context = Some(ExecutionContext::new(stream));
            Ok(future)
        });
        match scheduled {
            Ok(Ok(future)) => future,
            Ok(Err(err)) => DeviceFuture::failed(err),
            Err(err) => DeviceFuture::failed(err),
        }
    }
}
/// An ordered collection of boxed operations that all produce `T`; as a
/// `DeviceOp` it yields a `Vec<T>` of their results.
pub struct DeviceOpVec<T: Send> {
    ops: Vec<BoxedDeviceOp<T>>,
}
impl<T: Send + 'static> DeviceOpVec<T> {
    /// Creates a collection with no operations.
    pub fn empty() -> Self {
        Self::new(Vec::new())
    }

    /// Creates an empty collection with room reserved for `capacity` operations.
    pub fn with_capacity(capacity: usize) -> Self {
        Self::new(Vec::with_capacity(capacity))
    }

    /// Wraps an existing list of boxed operations.
    pub fn new(ops: Vec<BoxedDeviceOp<T>>) -> Self {
        Self { ops }
    }

    /// Boxes `op` and appends it to the end of the collection.
    pub fn push<DO: DeviceOp<Output = T> + 'static>(&mut self, op: DO) {
        let erased = op.boxed();
        self.ops.push(erased);
    }

    /// Removes and returns the operation at `index`.
    ///
    /// Panics if `index` is out of bounds (same as `Vec::remove`).
    pub fn remove(&mut self, index: usize) -> BoxedDeviceOp<T> {
        self.ops.remove(index)
    }

    /// Returns the most recently added operation, if any.
    pub fn last(&self) -> Option<&BoxedDeviceOp<T>> {
        self.ops.last()
    }
}
impl<T: Send> DeviceOp for DeviceOpVec<T> {
    type Output = Vec<T>;

    /// Executes the operations in order on `context` and collects their
    /// results. Stops at the first error; later operations are not executed.
    unsafe fn execute(self, context: &ExecutionContext) -> Result<Vec<T>, DeviceError> {
        self.ops
            .into_iter()
            // SAFETY: forwarded under the contract our own caller accepted.
            .map(|op| unsafe { op.execute(context) })
            .collect()
    }
}
impl<T: Send> IntoFuture for DeviceOpVec<T> {
    type Output = Result<Vec<T>, DeviceError>;
    type IntoFuture = DeviceFuture<Vec<T>, Self>;

    /// Binds the collection to the next stream of the default device policy;
    /// any failure while doing so yields an already-failed future.
    fn into_future(self) -> Self::IntoFuture {
        let scheduled = with_default_device_policy(|policy| {
            let stream = policy.next_stream()?;
            let mut future = DeviceFuture::new();
            future.device_operation = Some(self);
            future.execution_context = Some(ExecutionContext::new(stream));
            Ok(future)
        });
        match scheduled {
            Ok(Ok(future)) => future,
            Ok(Err(err)) => DeviceFuture::failed(err),
            Err(err) => DeviceFuture::failed(err),
        }
    }
}
/// A `Vec` of boxed operations converts directly into a `DeviceOpVec`.
impl<T: Send + 'static> From<Vec<BoxedDeviceOp<T>>> for DeviceOpVec<T> {
    fn from(ops: Vec<BoxedDeviceOp<T>>) -> Self {
        Self { ops }
    }
}