use crate::clock::{Clock, Scheduler, SchedulerRef};
use crate::core::{Duration, Instant, Listing};
use crate::hammerfest::{HammerfestStore, HammerfestStoreRef};
#[cfg(feature = "sqlx")]
use crate::pg_num::PgU16;
use crate::twinoid::store::TwinoidStoreRef;
use crate::twinoid::TwinoidStore;
use crate::types::{AnyError, WeakError};
use crate::user::{ShortUser, UserIdRef};
use async_trait::async_trait;
use auto_impl::auto_impl;
use core::fmt;
use core::pin::{pin, Pin};
use futures::future::select;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::any::{Any, TypeId};
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::convert::Infallible;
use std::future::Future;
use std::marker::PhantomData;
use std::sync::atomic::AtomicU16;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::Notify;
// Identifier of a stored job. The newtype is generated by `declare_new_uuid!`
// (defined elsewhere); `JobIdParseError` is its parse error and `job_id` is the
// name of its SQL type.
declare_new_uuid! {
pub struct JobId(Uuid);
pub type ParseError = JobIdParseError;
const SQL_NAME = "job_id";
}
// Name of a task kind. `JobRuntime::spawn` parses `Task::NAME` into this type and
// `JobRuntime::tick` uses its string form to look up the registered handler.
// Valid values match `^[A-Z][A-Za-z0-9]{0,31}$`: one uppercase ASCII letter
// followed by at most 31 ASCII letters or digits.
declare_new_string! {
pub struct TaskKind(String);
pub type ParseError = TaskKindParseError;
const PATTERN = r"^[A-Z][A-Za-z0-9]{0,31}$";
const SQL_NAME = "task_kind";
}
/// A job as persisted by a store: its identity, creation time and root task.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct StoredJob {
/// Unique identifier of the job.
pub id: JobId,
/// When the job was created.
pub created_at: Instant,
/// Task the job starts from.
pub root_task: TaskId,
}
// Identifier of a task (newtype generated by `declare_new_uuid!`; SQL type `task_id`).
declare_new_uuid! {
pub struct TaskId(Uuid);
pub type ParseError = TaskIdParseError;
const SQL_NAME = "task_id";
}
/// A task identifier paired with a revision number, naming one specific version
/// of a task.
///
/// `JobStore::update_task` receives the revision being updated and returns a new
/// `TaskRevId`; mismatching revisions are reported as
/// `StoreUpdateTaskError::Conflict`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TaskRevId {
/// Task being referenced.
pub id: TaskId,
/// Revision of the task.
pub rev: u32,
}
/// Task metadata without the task's state payload.
///
/// NOTE(review): this type (and `StoredTask`, `StoredTaskState`,
/// `UpdateTaskOptions` below) is not referenced by the runtime code in this file,
/// which uses `StoreTask`/`ShortStoreTask` instead — confirm whether it is still used.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ShortStoredTask {
/// Unique identifier of the task.
pub id: TaskId,
/// Job this task belongs to.
pub job_id: JobId,
/// Parent task, if any.
pub parent: Option<TaskId>,
/// Current scheduling status.
pub status: TaskStatus,
/// Optional message accompanying `status`.
pub status_message: Option<String>,
/// When the task was created.
pub created_at: Instant,
/// Presumably the last time the task made progress — TODO confirm.
pub advanced_at: Instant,
/// Number of steps executed so far (assumed from the name — confirm).
pub step_count: u32,
/// Presumably the cumulative time spent running the task — confirm.
pub running_time: Duration,
}
/// A stored task: its metadata plus its kind-specific state.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
pub struct StoredTask<Opaque> {
/// Task metadata; flattened into the same object when serialized with serde.
#[cfg_attr(feature = "serde", serde(flatten))]
pub short: ShortStoredTask,
/// Kind-specific state of the task.
pub state: StoredTaskState<Opaque>,
}
/// Kind-specific part of a stored task, kept in an opaque representation.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
pub struct StoredTaskState<Opaque> {
/// Name of the task kind.
pub kind: Cow<'static, str>,
/// Version of the data layout for this kind.
pub data_version: u32,
/// Opaque task options.
pub options: Opaque,
/// Opaque task state.
pub state: Opaque,
}
/// Parameters for updating a stored task after running a step.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
pub struct UpdateTaskOptions<'a, Opaque> {
/// Task to update.
pub id: TaskId,
/// Step counter the update applies to (assumed from the name — confirm).
pub current_step: u32,
/// Time spent on this step.
pub step_time: Duration,
/// New status.
pub status: TaskStatus,
/// Optional message accompanying the new status.
pub status_message: Option<&'a str>,
/// New opaque state.
pub state: Opaque,
}
/// Type-erased, reference-counted value, used as the opaque payload type when the
/// `serde` feature is disabled.
///
/// Cloning an `AnyBox` shares the same underlying value (it clones the `Arc`).
#[derive(Debug, Clone)]
pub struct AnyBox {
  inner: Arc<dyn Any + Send + Sync>,
}

impl AnyBox {
  /// Moves `v` into a new shared, type-erased box.
  pub fn new<T: Any + Send + Sync>(v: T) -> Self {
    let inner: Arc<dyn Any + Send + Sync> = Arc::new(v);
    AnyBox { inner }
  }
}
/// Conversion of a value into an opaque representation `Opaque` (for example
/// [`AnyBox`] or, with the `serde` feature, `serde_json::Value`).
pub trait WriteOpaque<Opaque> {
/// Error returned when the value cannot be converted.
type WriteError: std::error::Error;
/// Converts `self` into its opaque representation.
fn write_opaque(&self) -> Result<Opaque, Self::WriteError>;
}
impl<T> WriteOpaque<AnyBox> for T
where
T: Any + Send + Sync + Clone,
{
type WriteError = Infallible;
fn write_opaque(&self) -> Result<AnyBox, Self::WriteError> {
Ok(AnyBox {
inner: Arc::new(self.clone()),
})
}
}
/// Serializes any serde-serializable value into a `serde_json::Value`.
///
/// # Errors
///
/// Returns the `serde_json::Error` produced by `serde_json::to_value`.
#[cfg(feature = "serde")]
impl<T> WriteOpaque<serde_json::Value> for T
where
  // `Clone` is no longer needed by the body; it is kept so the set of types covered by
  // this blanket impl does not change.
  T: Serialize + Clone,
{
  type WriteError = serde_json::Error;

  fn write_opaque(&self) -> Result<serde_json::Value, Self::WriteError> {
    // `&T: Serialize` whenever `T: Serialize`, so serialize through the reference
    // instead of cloning the whole value first.
    serde_json::to_value(self)
  }
}
/// Reconstruction of a value from an opaque representation `Opaque`; the inverse
/// of [`WriteOpaque`].
pub trait ReadOpaque<Opaque> {
/// Error returned when `opaque` does not hold a valid `Self`.
type ReadError: std::error::Error;
/// Builds a `Self` from `opaque` without consuming it.
fn read_opaque(opaque: &Opaque) -> Result<Self, Self::ReadError>
where
Self: Sized;
}
/// Error returned when an [`AnyBox`] does not contain the requested type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Error)]
#[error("failed to read `AnyBox` of {expected_type_name:?}: expected `TypeId` = {expected_type_id:?}, actual `TypeId` = {actual_type_id:?}")]
pub struct ReadAnyBoxError {
/// `core::any::type_name` of the requested type (diagnostic only).
expected_type_name: &'static str,
/// `TypeId` of the requested type.
expected_type_id: TypeId,
/// `TypeId` of the value actually stored in the box.
actual_type_id: TypeId,
}
/// Reads a `T` back out of an [`AnyBox`] by cloning the boxed value.
///
/// # Errors
///
/// Returns [`ReadAnyBoxError`] when the box holds a value of another type. The error
/// records the requested type's name and `TypeId`, plus the `TypeId` of the value
/// actually stored.
impl<T> ReadOpaque<AnyBox> for T
where
  T: Any + Clone,
{
  type ReadError = ReadAnyBoxError;

  fn read_opaque(opaque: &AnyBox) -> Result<T, Self::ReadError> {
    match opaque.inner.downcast_ref::<T>() {
      Some(r) => Ok(r.clone()),
      None => {
        let expected_type_name: &'static str = core::any::type_name::<T>();
        let expected_type_id = TypeId::of::<T>();
        // Dereference the `Arc` before asking for the `TypeId`. Calling `type_id()`
        // directly on the `Arc` resolves to `<Arc<dyn Any + Send + Sync> as Any>` and
        // reports the smart pointer's type rather than the stored value's.
        let actual_type_id = (*opaque.inner).type_id();
        Err(ReadAnyBoxError {
          expected_type_name,
          expected_type_id,
          actual_type_id,
        })
      }
    }
  }
}
/// Deserializes any owned-deserializable type from a `serde_json::Value`.
///
/// # Errors
///
/// Returns the `serde_json::Error` produced when `opaque` does not match `T`.
#[cfg(feature = "serde")]
impl<T> ReadOpaque<serde_json::Value> for T
where
  T: for<'de> Deserialize<'de>,
{
  type ReadError = serde_json::Error;

  fn read_opaque(opaque: &serde_json::Value) -> Result<T, Self::ReadError> {
    // `&serde_json::Value` implements `Deserializer`, so deserialize directly from the
    // borrowed value. This avoids deep-cloning the JSON tree on every read.
    T::deserialize(opaque)
  }
}
/// Result of polling a task once: either finished with a value, or not yet done.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum TaskPoll<T> {
/// The task completed with this output.
Ready(T),
/// The task needs to be polled again later.
Pending,
}
/// Converts the payload of a `TaskPoll` to its opaque form, keeping `Pending` as is.
///
/// # Errors
///
/// Propagates the inner value's `WriteError` for `Ready`.
impl<T, O> WriteOpaque<TaskPoll<O>> for TaskPoll<T>
where
  T: WriteOpaque<O>,
{
  type WriteError = T::WriteError;

  fn write_opaque(&self) -> Result<TaskPoll<O>, Self::WriteError> {
    match self {
      TaskPoll::Pending => Ok(TaskPoll::Pending),
      TaskPoll::Ready(inner) => {
        let written: O = inner.write_opaque()?;
        Ok(TaskPoll::Ready(written))
      }
    }
  }
}
/// A resumable unit of work that the runtime polls until it is ready.
///
/// `Arg` is the per-poll argument; `JobRuntime` passes a `&mut JobContext`.
pub trait Task<Arg> {
/// Registry name of the task kind; must be a valid `TaskKind`.
const NAME: &'static str;
/// Version of the task's stored representation (stored as `kind_version`).
const VERSION: u32;
/// Value produced when the task completes.
type Output;
/// Advances the task once, returning `Ready` when finished or `Pending` otherwise.
#[must_use]
fn poll<'afn, 'fut>(&'afn mut self, arg: Arg) -> Pin<Box<dyn Future<Output = TaskPoll<Self::Output>> + Send + 'fut>>
where
'afn: 'fut,
Arg: 'fut,
Self: 'fut;
}
/// Every [`Task`] is an [`AsyncFnMut`]: calling it polls the task once.
impl<TH, Arg> AsyncFnMut<Arg> for TH
where
  TH: Task<Arg>,
{
  type Output = TaskPoll<TH::Output>;

  fn call_mut<'afn, 'fut>(&'afn mut self, arg: Arg) -> Pin<Box<dyn Future<Output = Self::Output> + Send + 'fut>>
  where
    'afn: 'fut,
    Arg: 'fut,
    Self: 'fut,
  {
    <TH as Task<Arg>>::poll(self, arg)
  }
}
/// An object-safe async function taking `&mut self` and one argument, returning a
/// boxed `Send` future.
pub trait AsyncFnMut<Arg> {
/// Value the returned future resolves to.
type Output;
/// Calls the function; the future may borrow both `self` and `arg`.
#[must_use]
fn call_mut<'afn, 'fut>(&'afn mut self, arg: Arg) -> Pin<Box<dyn Future<Output = Self::Output> + Send + 'fut>>
where
'afn: 'fut,
Arg: 'fut,
Self: 'fut;
}
/// An object-safe async function taking `&self` and two arguments, returning a
/// boxed `Send` future. `JobRuntime` stores its task handlers as this trait.
pub trait AsyncFn2<Arg0, Arg1> {
/// Value the returned future resolves to.
type Output;
/// Calls the function; the future may borrow `self` and both arguments.
#[must_use]
fn call2<'afn, 'fut>(&'afn self, arg0: Arg0, arg1: Arg1) -> Pin<Box<dyn Future<Output = Self::Output> + Send + 'fut>>
where
'afn: 'fut,
Arg0: 'fut,
Arg1: 'fut,
Self: 'fut;
}
/// Zero-sized adapter that calls a handler of type `F` stored in opaque form.
///
/// The `AsyncFn2` impl decodes `F` from the opaque value, calls it, then writes both
/// the updated handler and its output back to opaque form.
///
/// NOTE(review): the derives add `F: Trait` / `OpaqueOut: Trait` bounds even though
/// only a `PhantomData` is stored; hand-written impls would avoid that if needed.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct OpaqueAsyncFnMutCaller<F, OpaqueOut> {
phantom: PhantomData<fn(F) -> OpaqueOut>,
}
impl<F, OpaqueOut> fmt::Debug for OpaqueAsyncFnMutCaller<F, OpaqueOut> {
  /// Formats the caller as a struct whose `phantom` field spells out the phantom
  /// function signature with the `type_name`s of `F` and `OpaqueOut`.
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    let signature = format!(
      "PhantomData<fn({}) -> {}>",
      core::any::type_name::<F>(),
      core::any::type_name::<OpaqueOut>()
    );
    let mut builder = f.debug_struct("OpaqueAsyncFnMutCaller");
    builder.field("phantom", &signature);
    builder.finish()
  }
}
impl<F, OpaqueOut> OpaqueAsyncFnMutCaller<F, OpaqueOut> {
  /// Creates a caller for handler type `F`; the value itself carries no data.
  pub fn new() -> Self {
    let phantom = PhantomData;
    Self { phantom }
  }
}
/// Failure while calling a handler through [`OpaqueAsyncFnMutCaller`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Error)]
pub enum OpaqueAsyncFnMutCallError {
  /// The handler could not be decoded from its opaque form.
  #[error("failed to read opaque function")]
  ReadFn(#[source] WeakError),
  /// The updated handler could not be encoded back to its opaque form.
  #[error("failed to write opaque function")]
  WriteFn(#[source] WeakError),
  /// The handler's output could not be encoded to its opaque form.
  #[error("failed to write opaque output")]
  WriteOut(#[source] WeakError),
}

impl OpaqueAsyncFnMutCallError {
  /// Builds a [`Self::ReadFn`] from any error.
  pub fn read_fn<E: std::error::Error>(e: E) -> Self {
    let cause = WeakError::wrap(e);
    Self::ReadFn(cause)
  }

  /// Builds a [`Self::WriteFn`] from any error.
  pub fn write_fn<E: std::error::Error>(e: E) -> Self {
    let cause = WeakError::wrap(e);
    Self::WriteFn(cause)
  }

  /// Builds a [`Self::WriteOut`] from any error.
  pub fn write_out<E: std::error::Error>(e: E) -> Self {
    let cause = WeakError::wrap(e);
    Self::WriteOut(cause)
  }
}
/// Calls an opaque handler: decodes `F` from `opaque_afn`, runs it with `arg`, stores
/// the updated handler back into `opaque_afn`, then returns the encoded output.
///
/// If encoding the updated handler fails, `opaque_afn` is left unchanged and the
/// output is discarded.
impl<'of, F, OpaqueF, Arg, OpaqueOut> AsyncFn2<&'of mut OpaqueF, Arg> for OpaqueAsyncFnMutCaller<F, OpaqueOut>
where
  F: AsyncFnMut<Arg> + ReadOpaque<OpaqueF> + WriteOpaque<OpaqueF> + Send,
  F::Output: WriteOpaque<OpaqueOut>,
  OpaqueF: Send,
  Arg: Send,
{
  type Output = Result<OpaqueOut, OpaqueAsyncFnMutCallError>;

  #[must_use]
  fn call2<'afn, 'fut>(
    &'afn self,
    opaque_afn: &'of mut OpaqueF,
    arg: Arg,
  ) -> Pin<Box<dyn Future<Output = Self::Output> + Send + 'fut>>
  where
    'afn: 'fut,
    &'of OpaqueF: 'fut,
    Arg: 'fut,
    Self: 'fut,
  {
    Box::pin(async move {
      let mut handler: F = F::read_opaque(&*opaque_afn).map_err(OpaqueAsyncFnMutCallError::read_fn)?;
      let result = handler.call_mut(arg).await;
      *opaque_afn = handler.write_opaque().map_err(OpaqueAsyncFnMutCallError::write_fn)?;
      result.write_opaque().map_err(OpaqueAsyncFnMutCallError::write_out)
    })
  }
}
// Scheduling status of a task (enum generated by `declare_new_enum!`; each variant
// is spelled as its `#[str(..)]` string; SQL type `task_status`).
// - `Complete`: the task produced its output (set by `JobRuntime::tick` on `Ready`).
// - `Available`: the task can be polled (queried by `JobRuntime::tick`).
// - `Blocked`: the task registered timers and is waiting for them.
declare_new_enum! {
pub enum TaskStatus {
#[str("Complete")]
Complete,
#[str("Available")]
Available,
#[str("Blocked")]
Blocked,
}
pub type ParseError = TaskStatusParseError;
const SQL_NAME = "task_status";
}
impl TaskStatus {
pub fn can_transition_to(self, other: Self) -> bool {
use TaskStatus::*;
match self {
Available | Blocked => true,
Complete => matches!(other, Complete),
}
}
}
// Salt distinguishing ticks that share the same time (newtype over `u16` generated
// by `declare_new_int!`, accepting the full `0..=65535` range; SQL type `tick_salt`).
declare_new_int! {
pub struct TickSalt(u16);
pub type RangeError = TickSaltRangeError;
const BOUNDS = 0..=65535;
type SqlType = PgU16;
const SQL_NAME = "tick_salt";
}
/// One scheduling round of `JobRuntime::tick`: the time it started plus a salt taken
/// from a per-runtime counter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Tick {
/// Clock time at which the tick started polling tasks.
pub time: Instant,
/// Counter value distinguishing ticks with the same `time`.
pub salt: TickSalt,
}
/// Partial ordering of ticks by `time`.
///
/// Ticks at different times are ordered chronologically. Ticks at the same time are
/// `Equal` when they are identical (same `salt`) and incomparable (`None`) otherwise:
/// the salt disambiguates ticks but carries no ordering.
///
/// Returning `Some(Equal)` for identical ticks keeps this impl consistent with the
/// derived `PartialEq`, as the `PartialOrd` contract requires
/// (`a == b` iff `partial_cmp(a, b) == Some(Equal)`).
impl PartialOrd for Tick {
  fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
    use core::cmp::Ordering::*;
    match Ord::cmp(&self.time, &other.time) {
      Equal => (self.salt == other.salt).then_some(Equal),
      ordering @ (Greater | Less) => Some(ordering),
    }
  }
}
/// Public view of a job, with its creator resolved to a `ShortUser`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Job {
/// Unique identifier of the job.
pub id: JobId,
/// When the job was created.
pub created_at: Instant,
/// User who created the job, if any.
pub created_by: Option<ShortUser>,
/// Root task of the job.
pub task: ShortTask,
}
/// Public view of a task including its opaque state and output.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(PartialEq, Eq, Serialize, Deserialize))]
pub struct ApiTask {
/// Unique identifier of the task.
pub id: TaskId,
/// Current revision of the task.
pub revision: u32,
/// Time of the last poll, if the task was ever polled.
pub polled_at: Option<Instant>,
/// Current scheduling status.
pub status: TaskStatus,
/// Starvation counter maintained through `TaskCx`.
pub starvation: i32,
/// Task kind (registry name).
pub kind: TaskKind,
/// Version of the task kind's representation.
pub kind_version: u32,
/// Opaque task state.
pub state: OpaqueTask,
/// Opaque output, present once the task has completed.
pub output: Option<OpaqueValue>,
}
/// Public view of a task without its state and output.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ShortTask {
/// Unique identifier of the task.
pub id: TaskId,
/// Current revision of the task.
pub revision: u32,
/// Time of the last poll, if the task was ever polled.
pub polled_at: Option<Instant>,
/// Current scheduling status.
pub status: TaskStatus,
/// Starvation counter maintained through `TaskCx`.
pub starvation: i32,
/// Task kind (registry name).
pub kind: TaskKind,
/// Version of the task kind's representation.
pub kind_version: u32,
}
/// A task as returned by a [`JobStore`], with its opaque state and output.
///
/// `OpaqueValue` defaults to the same type as `OpaqueTask`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct StoreTask<OpaqueTask, OpaqueValue = OpaqueTask> {
/// Unique identifier of the task.
pub id: TaskId,
/// Current revision; paired with `id` by `rev_id` for updates.
pub revision: u32,
/// Task kind, used by `JobRuntime::tick` to find the handler.
pub kind: TaskKind,
/// Version of the task kind's representation.
pub kind_version: u32,
/// When the task was created.
pub created_at: Instant,
/// Tick of the last poll, if any (written via `StoreUpdateTask::tick`).
pub polled_at: Option<Tick>,
/// Current scheduling status.
pub status: TaskStatus,
/// Starvation counter maintained through `TaskCx`.
pub starvation: i32,
/// Opaque task state, fed to the handler on each poll.
pub state: OpaqueTask,
/// Opaque output, set when the task completes (read by `JobRuntime::try_join`).
pub output: Option<OpaqueValue>,
}
impl<OpaqueTask, OpaqueValue> StoreTask<OpaqueTask, OpaqueValue> {
  /// Identifier of this exact revision of the task.
  pub const fn rev_id(&self) -> TaskRevId {
    let id = self.id;
    let rev = self.revision;
    TaskRevId { id, rev }
  }

  /// Copies the task's metadata, dropping its state and output.
  pub fn to_short(&self) -> ShortStoreTask {
    let Self {
      id,
      revision,
      kind,
      kind_version,
      created_at,
      polled_at,
      status,
      starvation,
      ..
    } = self;
    ShortStoreTask {
      id: *id,
      revision: *revision,
      kind: kind.clone(),
      kind_version: *kind_version,
      created_at: *created_at,
      polled_at: *polled_at,
      status: *status,
      starvation: *starvation,
    }
  }
}
/// Metadata of a [`StoreTask`] without its state and output.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct ShortStoreTask {
/// Unique identifier of the task.
pub id: TaskId,
/// Current revision.
pub revision: u32,
/// Task kind.
pub kind: TaskKind,
/// Version of the task kind's representation.
pub kind_version: u32,
/// When the task was created.
pub created_at: Instant,
/// Tick of the last poll, if any.
pub polled_at: Option<Tick>,
/// Current scheduling status.
pub status: TaskStatus,
/// Starvation counter.
pub starvation: i32,
}
/// A job as returned by a [`JobStore`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct StoreJob {
/// Unique identifier of the job.
pub id: JobId,
/// When the job was created.
pub created_at: Instant,
/// User who created the job, if any.
pub created_by: Option<UserIdRef>,
/// Root task of the job (its id backs the handle returned by `JobRuntime::spawn`).
pub task: ShortStoreTask,
}
/// Command for [`JobStore::create_job`]: a new job with its root task.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct StoreCreateJob<OpaqueTask> {
/// Creation time.
pub now: Instant,
/// Creating user, if any.
pub user: Option<UserIdRef>,
/// Kind of the root task.
pub kind: TaskKind,
/// Version of the root task kind's representation.
pub kind_version: u32,
/// Initial opaque state of the root task.
pub task: OpaqueTask,
}
/// Failure of [`JobStore::create_job`].
#[derive(Error, Debug)]
pub enum StoreCreateJobError {
  /// Any failure reported by the store implementation.
  #[error(transparent)]
  Other(AnyError),
}

impl StoreCreateJobError {
  /// Boxes an arbitrary error into [`StoreCreateJobError::Other`].
  pub fn other<E: 'static + std::error::Error + Send + Sync>(e: E) -> Self {
    let boxed: AnyError = Box::new(e);
    Self::Other(boxed)
  }
}
/// Query for [`JobStore::get_jobs`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub struct StoreGetJobs {
  /// Only return jobs whose root task has this status, when set
  /// (NOTE(review): which task the status applies to is store-defined — confirm).
  pub status: Option<TaskStatus>,
  /// Creator filter: presumably `None` means "any creator" and `Some(None)` means
  /// "jobs without a creator" — confirm against the store implementation.
  pub creator: Option<Option<UserIdRef>>,
  /// Number of matching jobs to skip.
  pub offset: u32,
  /// Maximum number of jobs to return.
  pub limit: u32,
}

/// Failure of [`JobStore::get_jobs`].
#[derive(Error, Debug)]
pub enum StoreGetJobsError {
  /// Any failure reported by the store implementation.
  #[error(transparent)]
  Other(AnyError),
}

impl StoreGetJobsError {
  /// Boxes an arbitrary error into [`StoreGetJobsError::Other`].
  pub fn other<E: 'static + std::error::Error + Send + Sync>(e: E) -> Self {
    let boxed: AnyError = Box::new(e);
    Self::Other(boxed)
  }
}
/// Query for [`JobStore::get_job`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub struct StoreGetJob {
  /// Job to fetch.
  pub id: JobId,
}

/// Failure of [`JobStore::get_job`].
#[derive(Error, Debug)]
pub enum StoreGetJobError {
  /// No job has this id.
  #[error("job {0} not found")]
  NotFound(JobId),
  /// Any other failure reported by the store implementation.
  #[error(transparent)]
  Other(AnyError),
}

impl StoreGetJobError {
  /// Boxes an arbitrary error into [`StoreGetJobError::Other`].
  pub fn other<E: 'static + std::error::Error + Send + Sync>(e: E) -> Self {
    let boxed: AnyError = Box::new(e);
    Self::Other(boxed)
  }
}
/// Query for [`JobStore::get_task`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct StoreGetTask {
  /// Task to fetch.
  pub id: TaskId,
}

/// Failure of [`JobStore::get_task`].
#[derive(Error, Debug)]
pub enum StoreGetTaskError {
  /// No task has this id.
  #[error("task {0} not found")]
  NotFound(TaskId),
  /// Any other failure reported by the store implementation.
  #[error(transparent)]
  Other(AnyError),
}

impl StoreGetTaskError {
  /// Boxes an arbitrary error into [`StoreGetTaskError::Other`].
  pub fn other<E: 'static + std::error::Error + Send + Sync>(e: E) -> Self {
    let boxed: AnyError = Box::new(e);
    Self::Other(boxed)
  }
}
/// Query for [`JobStore::get_tasks`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub struct StoreGetTasks {
  /// Only return tasks with this status, when set.
  pub status: Option<TaskStatus>,
  /// `JobRuntime::tick` sets this to the current tick so already-polled tasks are not
  /// returned again; the exact filter is store-defined (presumably excludes tasks whose
  /// `polled_at` is this tick — confirm).
  pub skip_polled: Option<Tick>,
  /// Number of matching tasks to skip.
  pub offset: u32,
  /// Maximum number of tasks to return.
  pub limit: u32,
}

/// Failure of [`JobStore::get_tasks`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Error)]
pub enum StoreGetTasksError {
  /// Any failure reported by the store implementation.
  #[error(transparent)]
  Other(WeakError),
}

impl StoreGetTasksError {
  /// Wraps an arbitrary error into [`StoreGetTasksError::Other`].
  pub fn other<E: std::error::Error>(e: E) -> Self {
    let wrapped = WeakError::wrap(e);
    Self::Other(wrapped)
  }
}
/// Command for [`JobStore::update_task`]: the result of polling one task revision.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct StoreUpdateTask<OpaqueTask, OpaqueValue> {
  /// Task and revision being updated.
  pub rev_id: TaskRevId,
  /// Tick during which the task was polled.
  pub tick: Tick,
  /// New status.
  pub status: TaskStatus,
  /// New starvation counter.
  pub starvation: i32,
  /// New opaque state.
  pub state: OpaqueTask,
  /// Output, set when the task completed.
  pub output: Option<OpaqueValue>,
}

/// Failure of [`JobStore::update_task`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Error)]
pub enum StoreUpdateTaskError {
  /// No task has this id.
  #[error("task {0} not found")]
  NotFound(TaskId),
  /// The stored revision is not the one the update was based on.
  #[error("cannot update task {task} due to revision conflict: expected={expected}, actual={actual}")]
  Conflict { task: TaskId, expected: u32, actual: u32 },
  /// A referenced dependency task does not exist.
  #[error("task dependency {0} not found")]
  DependencyNotFound(TaskId),
  /// The update would create a dependency cycle.
  #[error("detected circular dependency from task {0}")]
  CircularDependency(TaskId),
  /// The revision counter cannot be incremented further.
  #[error("updating task {0} leads to overflow of the revision")]
  RevisionOverflow(TaskId),
  /// Any other failure reported by the store implementation.
  #[error(transparent)]
  Other(WeakError),
}

impl StoreUpdateTaskError {
  /// Wraps an arbitrary error into [`StoreUpdateTaskError::Other`].
  pub fn other<E: std::error::Error>(e: E) -> Self {
    let wrapped = WeakError::wrap(e);
    Self::Other(wrapped)
  }
}
/// Command for [`JobStore::create_timer`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct StoreCreateTimer {
  /// Task waiting on the timer.
  pub task_id: TaskId,
  /// When the timer fires.
  pub deadline: Instant,
}

/// Failure of [`JobStore::create_timer`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Error)]
pub enum StoreCreateTimerError {
  /// No task has this id.
  #[error("task {0} not found")]
  NotFound(TaskId),
  /// Any other failure reported by the store implementation.
  #[error(transparent)]
  Other(WeakError),
}

impl StoreCreateTimerError {
  /// Wraps an arbitrary error into [`StoreCreateTimerError::Other`].
  pub fn other<E: std::error::Error>(e: E) -> Self {
    let wrapped = WeakError::wrap(e);
    Self::Other(wrapped)
  }
}
/// Failure of [`JobStore::next_timer`].
#[derive(Error, Debug)]
pub enum StoreNextTimerError {
  /// Any failure reported by the store implementation.
  #[error(transparent)]
  Other(WeakError),
}

impl StoreNextTimerError {
  /// Wraps an arbitrary error into [`StoreNextTimerError::Other`].
  pub fn other<E: std::error::Error>(e: E) -> Self {
    let wrapped = WeakError::wrap(e);
    Self::Other(wrapped)
  }
}

/// Failure of [`JobStore::on_timer`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Error)]
pub enum StoreOnTimerError {
  /// Any failure reported by the store implementation.
  #[error(transparent)]
  Other(WeakError),
}

impl StoreOnTimerError {
  /// Wraps an arbitrary error into [`StoreOnTimerError::Other`].
  pub fn other<E: std::error::Error>(e: E) -> Self {
    let wrapped = WeakError::wrap(e);
    Self::Other(wrapped)
  }
}
/// Persistence backend for jobs, tasks and timers used by `JobRuntime`.
///
/// Implemented automatically for `&S` and `Arc<S>` where `S: JobStore` (via `auto_impl`).
#[async_trait]
#[auto_impl(&, Arc)]
pub trait JobStore<OpaqueTask, OpaqueValue>: Send + Sync {
/// Persists a new job together with its root task.
async fn create_job(&self, cmd: StoreCreateJob<OpaqueTask>) -> Result<StoreJob, StoreCreateJobError>;
/// Lists jobs matching `query`.
async fn get_jobs(&self, query: StoreGetJobs) -> Result<Listing<StoreJob>, StoreGetJobsError>;
/// Fetches one job by id.
async fn get_job(&self, query: StoreGetJob) -> Result<StoreJob, StoreGetJobError>;
/// Fetches one task by id, including its state and output.
async fn get_task(&self, query: StoreGetTask) -> Result<StoreTask<OpaqueTask, OpaqueValue>, StoreGetTaskError>;
/// Lists tasks matching `query`, including their state and output.
async fn get_tasks(
&self,
query: StoreGetTasks,
) -> Result<Listing<StoreTask<OpaqueTask, OpaqueValue>>, StoreGetTasksError>;
/// Writes the result of polling revision `cmd.rev_id` and returns the new revision id.
async fn update_task(&self, cmd: StoreUpdateTask<OpaqueTask, OpaqueValue>)
-> Result<TaskRevId, StoreUpdateTaskError>;
/// Records a timer for a task.
async fn create_timer(&self, cmd: StoreCreateTimer) -> Result<(), StoreCreateTimerError>;
/// Returns the deadline of the next timer, if any; `JobRuntime::wait_for_available`
/// sleeps until it.
async fn next_timer(&self) -> Result<Option<Instant>, StoreNextTimerError>;
/// Informs the store that `time` was reached; `JobRuntime::tick` calls it first with
/// the current time (presumably to fire due timers — confirm in implementations).
async fn on_timer(&self, time: Instant) -> Result<(), StoreOnTimerError>;
}
/// Opaque representation of task state: `serde_json::Value` with the `serde` feature,
/// otherwise [`AnyBox`].
#[cfg(feature = "serde")]
pub type OpaqueTask = serde_json::Value;
#[cfg(not(feature = "serde"))]
pub type OpaqueTask = AnyBox;
/// Opaque representation of task outputs: `serde_json::Value` with the `serde`
/// feature, otherwise [`AnyBox`].
#[cfg(feature = "serde")]
pub type OpaqueValue = serde_json::Value;
#[cfg(not(feature = "serde"))]
pub type OpaqueValue = AnyBox;
/// Typed reference to a task whose output has type `T`.
///
/// Returned by `JobRuntime::spawn` and consumed by `JobRuntime::try_join`.
///
/// NOTE(review): the derives require `T` to implement the derived traits (e.g. the
/// handle is `Copy` only when `T: Copy`) even though only a `PhantomData` is stored.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TaskHandle<T> {
  id: TaskId,
  phantom: PhantomData<fn() -> T>,
}

impl<T> TaskHandle<T> {
  /// Wraps a raw task id.
  fn new(id: TaskId) -> Self {
    TaskHandle {
      phantom: PhantomData,
      id,
    }
  }
}
/// Handle to a timer identified by its deadline.
///
/// NOTE(review): not used by the visible runtime code — confirm it is still needed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TimerHandle {
deadline: Instant,
}
/// What a task is waiting on: a deadline or another task.
///
/// NOTE(review): not used by the visible runtime code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum TaskBlock {
/// Waiting until this instant.
Deadline(Instant),
/// Waiting for this task.
Task(TaskId),
}
/// Event that can unblock a task: a deadline being reached or a task completing
/// with an output.
///
/// NOTE(review): not used by the visible runtime code.
pub enum TaskEvent {
/// The given deadline was reached.
DeadlineReady(Instant),
/// The given task completed with this opaque output.
TaskComplete(TaskId, OpaqueValue),
}
/// Built-in task that completes once the clock reaches `deadline`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Sleep {
  deadline: Instant,
}

impl Sleep {
  /// Creates a sleep task that completes at `deadline`.
  pub const fn until(deadline: Instant) -> Self {
    Sleep { deadline }
  }
}
/// `Sleep` as a task: ready once `cx.now()` has reached the deadline. Until then, it
/// registers a timer for the deadline on the context and returns `Pending`.
impl<'cx, Cx> Task<&'cx mut Cx> for Sleep
where
  Cx: TaskCx,
{
  const NAME: &'static str = "Sleep";
  const VERSION: u32 = 1;
  type Output = ();

  #[must_use]
  fn poll<'afn, 'fut>(
    &'afn mut self,
    cx: &'cx mut Cx,
  ) -> Pin<Box<dyn Future<Output = TaskPoll<Self::Output>> + Send + 'fut>>
  where
    'afn: 'fut,
    &'cx mut Cx: 'fut,
    Self: 'fut,
  {
    let deadline = self.deadline;
    Box::pin(async move {
      if cx.now() < deadline {
        cx.register_timer(deadline);
        return TaskPoll::Pending;
      }
      TaskPoll::Ready(())
    })
  }
}
/// Failure of [`JobRuntime::tick`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Error)]
pub enum TickError {
/// `JobStore::on_timer` failed.
#[error("failed to update job store timers")]
UpdateTimers(#[from] StoreOnTimerError),
/// `JobStore::get_tasks` failed for the given query.
#[error("failed to retrieve tasks with query {1:?}")]
GetTasks(#[source] StoreGetTasksError, StoreGetTasks),
/// A task's kind has no registered handler.
#[error("no handler registered for task {0:?}")]
MissingHandler(ShortStoreTask),
/// Decoding, running or encoding the task's handler failed.
#[error("failed to call task {1:?}")]
CallTask(#[source] OpaqueAsyncFnMutCallError, TaskId),
/// Persisting the poll result failed.
#[error("failed to update store value for task {1:?}")]
UpdateTask(#[source] StoreUpdateTaskError, TaskId),
/// Persisting a timer registered by the task failed.
#[error("failed to create timer for task {1:?} and deadline {2:?}")]
CreateTimer(#[source] StoreCreateTimerError, TaskId, Instant),
/// Tasks were still available after the maximum number of polling rounds.
#[error("reached max iteration while draining ready tasks")]
Stuck,
}
/// `JobRuntime` with every dependency behind an `Arc<dyn ...>` trait object.
pub type DynJobRuntime<'reg> = JobRuntime<
'reg,
Arc<dyn Scheduler<Timer = Pin<Box<dyn Future<Output = ()> + Send>>>>,
Arc<dyn JobStore<OpaqueTask, OpaqueValue>>,
Arc<dyn HammerfestStore>,
Arc<dyn TwinoidStore>,
>;
/// Runtime that spawns jobs, polls their tasks and tracks timers.
///
/// Task handlers are registered by name with [`JobRuntime::register`]. Tasks are
/// persisted in `job_store` and polled by [`JobRuntime::tick`].
pub struct JobRuntime<'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore> {
/// Clock/scheduler used for current time and wake-ups.
pub clock: TyClock,
/// Persistence of jobs, tasks and timers.
job_store: TyJobStore,
/// Hammerfest store exposed to tasks through `JobContext`.
hammerfest_store: TyHammerfestStore,
/// Twinoid store exposed to tasks through `JobContext`.
twinoid_store: TyTwinoidStore,
/// Type-erased handlers keyed by `Task::NAME`; each decodes the opaque state, polls
/// it with a `JobContext`, and re-encodes state and output.
#[allow(clippy::type_complexity)] registry: HashMap<
&'static str,
Box<
dyn 'reg
+ for<'cx> AsyncFn2<
&'cx mut OpaqueTask,
&'cx mut JobContext<'cx, 'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>,
Output = Result<TaskPoll<OpaqueValue>, OpaqueAsyncFnMutCallError>,
>
+ Send
+ Sync,
>,
>,
/// Counter supplying `Tick::salt`; wraps around on overflow (`fetch_add`).
tick_salt: AtomicU16,
/// Signalled by `spawn` to wake `wait_for_available`.
job_created: Notify,
}
impl<'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>
  JobRuntime<'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>
{
  /// Creates a runtime with no registered handlers and a tick salt counter at zero.
  pub fn new(
    clock: TyClock,
    job_store: TyJobStore,
    hammerfest_store: TyHammerfestStore,
    twinoid_store: TyTwinoidStore,
  ) -> Self {
    let registry = HashMap::new();
    let tick_salt = AtomicU16::new(0);
    let job_created = Notify::new();
    Self {
      clock,
      job_store,
      hammerfest_store,
      twinoid_store,
      registry,
      tick_salt,
      job_created,
    }
  }
}
impl<'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>
  JobRuntime<'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>
{
  /// Registers `Handler` under `Handler::NAME`, so tasks of that kind can be spawned
  /// and polled.
  ///
  /// # Panics
  ///
  /// Panics if a handler is already registered under `Handler::NAME`.
  pub fn register<Handler>(&mut self)
  where
    for<'cx> Handler: 'reg
      + Task<&'cx mut JobContext<'cx, 'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>>
      + ReadOpaque<OpaqueTask>
      + WriteOpaque<OpaqueTask>
      + Send,
    for<'cx> <Handler as Task<&'cx mut JobContext<'cx, 'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>>>::Output:
      WriteOpaque<OpaqueValue>,
    TyClock: Sync,
    TyJobStore: Sync,
    TyHammerfestStore: Sync,
    TyTwinoidStore: Sync,
  {
    let name: &'static str = Handler::NAME;
    let caller = Box::new(OpaqueAsyncFnMutCaller::<Handler, TaskPoll<OpaqueValue>>::new());
    let previous = self.registry.insert(name, caller);
    assert!(previous.is_none(), "duplicate task register for name {:?}", name);
  }
}
impl<'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>
  JobRuntime<'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>
where
  TyClock: SchedulerRef,
  TyJobStore: JobStore<OpaqueTask, OpaqueValue>,
{
  /// Creates a job whose root task is `task` and wakes `wait_for_available`.
  ///
  /// Returns the stored job and a typed handle to its root task.
  ///
  /// # Errors
  ///
  /// Fails if no handler is registered under `Handler::NAME`, if `task` cannot be
  /// encoded to its opaque form, or if the store rejects the job.
  ///
  /// # Panics
  ///
  /// Panics if `Handler::NAME` is not a valid [`TaskKind`].
  pub async fn spawn<Handler>(
    &self,
    task: Handler,
    user: Option<UserIdRef>,
  ) -> Result<
    (
      StoreJob,
      TaskHandle<
        <Handler as Task<&mut JobContext<'_, 'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>>>::Output,
      >,
    ),
    AnyError,
  >
  where
    for<'cx> Handler: 'reg
      + Task<&'cx mut JobContext<'cx, 'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>>
      + WriteOpaque<OpaqueTask>
      + Send
      + Sync,
  {
    if !self.registry.contains_key(Handler::NAME) {
      return Err(format!("failed to spawn task non-registered handler {:?}", Handler::NAME).into());
    }
    let now = self.clock.clock().now();
    let store_job: StoreJob = self
      .job_store
      .create_job(StoreCreateJob {
        now,
        user,
        kind: Handler::NAME.parse().expect("invalid task kind"),
        kind_version: Handler::VERSION,
        task: task.write_opaque().map_err(WeakError::wrap)?,
      })
      .await?;
    self.job_created.notify_one();
    let handle = TaskHandle::new(store_job.task.id);
    Ok((store_job, handle))
  }

  /// Returns the output of the task behind `handle` if it has one, otherwise
  /// `TaskPoll::Pending`.
  ///
  /// # Errors
  ///
  /// Fails if the task cannot be fetched or its output cannot be decoded as `T`.
  pub async fn try_join<T>(&self, handle: TaskHandle<T>) -> Result<TaskPoll<T>, AnyError>
  where
    T: ReadOpaque<OpaqueValue>,
  {
    let store_task: StoreTask<OpaqueTask, OpaqueValue> =
      self.job_store.get_task(StoreGetTask { id: handle.id }).await?;
    Ok(match store_task.output {
      Some(out) => TaskPoll::Ready(T::read_opaque(&out).map_err(WeakError::wrap)?),
      None => TaskPoll::Pending,
    })
  }

  /// Runs one scheduling round.
  ///
  /// It first calls `JobStore::on_timer` with the current time. Then it repeatedly
  /// fetches up to `MAX_COUNT` `Available` tasks, excluding those already polled
  /// during this tick (via `skip_polled`). Each fetched task is polled once with a
  /// fresh `JobContext`, and the resulting state, output, status and timers are
  /// persisted. Returns once a fetch yields no task.
  ///
  /// # Errors
  ///
  /// Returns the first store or handler error encountered (the remaining tasks of the
  /// round are not polled). Returns [`TickError::Stuck`] if tasks are still available
  /// after `MAX_ITERATION` rounds.
  pub async fn tick(&self) -> Result<(), TickError> {
    const MAX_ITERATION: usize = 1000;
    const MAX_COUNT: u32 = 1000;
    let mut tick: Option<Tick> = None;
    self.job_store.on_timer(self.clock.clock().scheduler().now()).await?;
    for _ in 0..MAX_ITERATION {
      let tasks: Listing<StoreTask<_>> = {
        let query = StoreGetTasks {
          status: Some(TaskStatus::Available),
          skip_polled: tick,
          offset: 0,
          limit: MAX_COUNT,
        };
        self
          .job_store
          .get_tasks(query)
          .await
          .map_err(|e| TickError::GetTasks(e, query))?
      };
      if tasks.count == 0 {
        return Ok(());
      }
      // All tasks polled during this call share one tick, created lazily once there is
      // work to do. The salt distinguishes ticks that start at the same instant.
      let tick = *tick.get_or_insert_with(|| Tick {
        time: self.clock.clock().scheduler().now(),
        salt: TickSalt::new(self.tick_salt.fetch_add(1, core::sync::atomic::Ordering::SeqCst))
          .expect("`TickSalt` accepts all `u16` values, the constructor never fails"),
      });
      for store_task in tasks.items {
        let task_rev_id = store_task.rev_id();
        let task_id = task_rev_id.id;
        let handler = self
          .registry
          .get(store_task.kind.as_str())
          .ok_or_else(|| TickError::MissingHandler(store_task.to_short()))?;
        let handler = &**handler;
        let mut task = store_task.state;
        let mut starvation = store_task.starvation;
        let mut timers = HashSet::new();
        let mut context = JobContext {
          timers: &mut timers,
          starvation: &mut starvation,
          runtime: self,
        };
        let poll = handler.call2(&mut task, &mut context).await;
        let cmd: StoreUpdateTask<_, _> = match poll {
          Ok(TaskPoll::Ready(value)) => StoreUpdateTask {
            rev_id: task_rev_id,
            tick,
            status: TaskStatus::Complete,
            starvation,
            state: task,
            output: Some(value),
          },
          Ok(TaskPoll::Pending) => StoreUpdateTask {
            rev_id: task_rev_id,
            tick,
            // A pending task that registered timers is blocked on them; otherwise it
            // remains available.
            status: if timers.is_empty() {
              TaskStatus::Available
            } else {
              TaskStatus::Blocked
            },
            starvation,
            state: task,
            output: None,
          },
          Err(e) => return Err(TickError::CallTask(e, task_id)),
        };
        self
          .job_store
          .update_task(cmd)
          .await
          .map_err(|e| TickError::UpdateTask(e, task_id))?;
        for timer in timers {
          self
            .register_timer(task_id, timer)
            .await
            .map_err(|e| TickError::CreateTimer(e, task_id, timer))?;
        }
      }
    }
    Err(TickError::Stuck)
  }

  /// Waits until there may be work for [`JobRuntime::tick`].
  ///
  /// Returns immediately if an `Available` task exists. Otherwise it waits for the
  /// next store timer (if any) or for `spawn` to signal a new job, whichever comes
  /// first.
  ///
  /// # Errors
  ///
  /// Fails if the store cannot be queried.
  pub async fn wait_for_available(&self) -> Result<(), AnyError> {
    // Obtained before checking the store. `spawn` signals with `notify_one`, which
    // stores a permit when nobody is waiting yet, so a job created after the check
    // still completes the wait below.
    let job_created = self.job_created.notified();
    {
      let query = StoreGetTasks {
        status: Some(TaskStatus::Available),
        skip_polled: None,
        offset: 0,
        limit: 1,
      };
      let available = self.job_store.get_tasks(query).await?;
      if available.count > 0 {
        return Ok(());
      }
    }
    let next_timer = self.job_store.next_timer().await?;
    if let Some(next_timer) = next_timer {
      let next_timer = self.clock.scheduler().schedule(next_timer);
      let next_timer = pin!(next_timer);
      let job_created = pin!(job_created);
      select(next_timer, job_created).await;
    } else {
      job_created.await;
    }
    Ok(())
  }

  /// Persists a timer firing at `deadline` for task `task_id`.
  pub(crate) async fn register_timer(&self, task_id: TaskId, deadline: Instant) -> Result<(), StoreCreateTimerError> {
    self
      .job_store
      .create_timer(StoreCreateTimer { task_id, deadline })
      .await
  }

  /// Returns a [`Sleep`] task completing `duration` after the current clock time.
  pub fn sleep(&self, duration: Duration) -> Sleep {
    Sleep::until(self.clock.clock().now() + duration)
  }
}
// Sealing module: it is private, so `Sealed` (and therefore `TaskCx`) cannot be
// implemented outside this module and its submodules.
mod private {
pub trait Sealed {}
}
/// Capabilities available to a task while it is being polled.
pub trait TaskCx: private::Sealed + Sync + Send {
/// Current time according to the context's clock.
#[must_use]
fn now(&self) -> Instant;
/// Returns a [`Sleep`] task completing at `deadline`.
#[must_use]
fn sleep_until(&self, deadline: Instant) -> Sleep;
/// Returns a [`Sleep`] task completing `duration` after [`TaskCx::now`].
#[must_use]
fn sleep(&self, duration: Duration) -> Sleep {
self.sleep_until(self.now() + duration)
}
/// Registers a timer at `deadline`. In `JobContext`, the timers are collected and,
/// after the poll, `JobRuntime::tick` marks a pending task `Blocked` and persists them.
fn register_timer(&mut self, deadline: Instant);
/// Resets the current task's starvation counter to zero.
fn reset_starvation(&mut self);
/// Increments the current task's starvation counter.
fn inc_starvation(&mut self);
}
/// Per-poll context handed to task handlers by `JobRuntime::tick`.
///
/// It collects the timers registered during the poll and edits the task's starvation
/// counter in place; `tick` persists both after the poll.
pub struct JobContext<'cx, 'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore> {
/// Timers registered by the task during the current poll.
timers: &'cx mut HashSet<Instant>,
/// Starvation counter of the task being polled.
starvation: &'cx mut i32,
/// Runtime giving access to the clock and the stores.
runtime: &'cx JobRuntime<'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>,
}
// Allows `JobContext` to implement the sealed `TaskCx` trait.
impl<'cx, 'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore> private::Sealed
for JobContext<'cx, 'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>
{
}
/// [`TaskCx`] backed by the runtime clock. Timers and the starvation counter are
/// buffered for `JobRuntime::tick`.
impl<'cx, 'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore> TaskCx
  for JobContext<'cx, 'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>
where
  TyClock: SchedulerRef,
  TyJobStore: JobStore<OpaqueTask, OpaqueValue> + Send + Sync,
  TyHammerfestStore: Send + Sync,
  Self: Sync + Send,
{
  fn now(&self) -> Instant {
    self.runtime.clock.clock().now()
  }

  fn sleep_until(&self, deadline: Instant) -> Sleep {
    Sleep::until(deadline)
  }

  fn register_timer(&mut self, deadline: Instant) {
    // Deduplicated: registering the same deadline twice yields a single timer.
    self.timers.insert(deadline);
  }

  fn reset_starvation(&mut self) {
    *self.starvation = 0;
  }

  fn inc_starvation(&mut self) {
    // Saturate instead of overflowing. A plain `+= 1` panics in debug builds and
    // wraps to a negative count in release builds once `i32::MAX` is reached.
    *self.starvation = self.starvation.saturating_add(1);
  }
}
impl<'cx, 'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore> HammerfestStoreRef
for JobContext<'cx, 'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>
where
TyClock: Send + Sync,
TyJobStore: Send + Sync,
TyHammerfestStore: HammerfestStoreRef,
TyTwinoidStore: Send + Sync,
{
type HammerfestStore = TyHammerfestStore::HammerfestStore;
fn hammerfest_store(&self) -> &Self::HammerfestStore {
self.runtime.hammerfest_store.hammerfest_store()
}
}
impl<'cx, 'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore> TwinoidStoreRef
for JobContext<'cx, 'reg, TyClock, TyJobStore, TyHammerfestStore, TyTwinoidStore>
where
TyClock: Send + Sync,
TyJobStore: Send + Sync,
TyHammerfestStore: Send + Sync,
TyTwinoidStore: TwinoidStoreRef,
{
type TwinoidStore = TyTwinoidStore::TwinoidStore;
fn twinoid_store(&self) -> &Self::TwinoidStore {
self.runtime.twinoid_store.twinoid_store()
}
}