ora_common/
task.rs

//! Task definitions, task statuses, and related types.
2
3use std::{
4    borrow::Cow,
5    collections::{BTreeMap, HashMap},
6    fmt::Display,
7    marker::PhantomData,
8    str::FromStr,
9    time::{SystemTime, UNIX_EPOCH},
10};
11
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use time::{Duration, OffsetDateTime};
15
16use crate::{timeout::TimeoutPolicy, UnixNanos};
17
18/// A selector that is used to connect tasks
19/// with workers.
20/// Currently a task can be executed by a worker only if
21/// the worker selector of the task and the worker are equal.
22#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Hash)]
23#[must_use]
24pub struct WorkerSelector {
25    /// A common shared name between the tasks and workers,
26    pub kind: Cow<'static, str>,
27}
28
29impl<T> From<T> for WorkerSelector
30where
31    T: Into<String>,
32{
33    fn from(value: T) -> Self {
34        Self {
35            kind: value.into().into(),
36        }
37    }
38}
39
40/// An untyped complete task definition
41/// that can be added to the queue.
42#[derive(Debug, Serialize, Deserialize)]
43#[must_use]
44pub struct TaskDefinition<T = ()> {
45    /// The target time of the task execution.
46    pub target: UnixNanos,
47    /// The worker selector of the task.
48    pub worker_selector: WorkerSelector,
49    /// Arbitrary task data that is passed to the workers.
50    pub data: Vec<u8>,
51    /// The input data format.
52    pub data_format: TaskDataFormat,
53    /// Arbitrary task labels.
54    #[serde(default)]
55    pub labels: HashMap<String, Value>,
56    /// An optional timeout policy.
57    #[serde(default)]
58    pub timeout: TimeoutPolicy,
59    #[doc(hidden)]
60    #[serde(default, skip)]
61    pub _task_type: PhantomData<T>,
62}
63
64impl<T> Clone for TaskDefinition<T> {
65    fn clone(&self) -> Self {
66        Self {
67            target: self.target,
68            worker_selector: self.worker_selector.clone(),
69            data: self.data.clone(),
70            data_format: self.data_format,
71            labels: self.labels.clone(),
72            timeout: self.timeout,
73            _task_type: PhantomData,
74        }
75    }
76}
77
78impl<T> TaskDefinition<T> {
79    /// Set a timeout policy.
80    pub fn with_timeout(mut self, timeout: impl Into<TimeoutPolicy>) -> Self {
81        self.timeout = timeout.into();
82        self
83    }
84
85    /// Schedule the task immediately.
86    pub fn immediate(mut self) -> Self {
87        self.target = UnixNanos(0);
88        self
89    }
90
91    /// Set the target execution time of the new task.
92    pub fn at(mut self, target: OffsetDateTime) -> Self {
93        let nanos = target.unix_timestamp_nanos();
94
95        self.target = if nanos.is_negative() {
96            UnixNanos(0)
97        } else {
98            UnixNanos(nanos.unsigned_abs().try_into().unwrap_or(u64::MAX))
99        };
100        self
101    }
102
103    /// Schedule the task at the given unix nanosecond duration.
104    pub fn at_unix(mut self, target: UnixNanos) -> Self {
105        self.target = target;
106        self
107    }
108
109    /// Schedule the task with the current time as target.
110    pub fn now(mut self) -> Self {
111        self.target = UnixNanos::now();
112        self
113    }
114
115    /// Set the target execution time of the new task to
116    /// be after the given duration.
117    ///
118    /// # Panics
119    ///
120    /// Panics if the system time is before UNIX epoch.
121    #[allow(clippy::cast_possible_truncation)]
122    pub fn after(mut self, duration: Duration) -> Self {
123        let nanos = duration.whole_nanoseconds();
124        let nanos = if nanos.is_negative() {
125            0
126        } else {
127            nanos.unsigned_abs().try_into().unwrap_or(u64::MAX)
128        };
129
130        self.target = UnixNanos(
131            SystemTime::now()
132                .duration_since(UNIX_EPOCH)
133                .unwrap()
134                .saturating_add(std::time::Duration::from_nanos(nanos))
135                .as_nanos() as u64,
136        );
137
138        self
139    }
140
141    /// Set the worker selector for the given task.
142    pub fn with_worker_selector(mut self, selector: impl Into<WorkerSelector>) -> Self {
143        self.worker_selector = selector.into();
144        self
145    }
146
147    /// Set a label value.
148    ///
149    /// # Panics
150    ///
151    /// Panics if the value is not JSON-serializable.
152    pub fn with_label(mut self, name: &str, value: impl Serialize) -> Self {
153        self.labels
154            .insert(name.into(), serde_json::to_value(value).unwrap());
155        self
156    }
157
158    /// Cast the task to a different task type,
159    /// or to erase the task type replacing it with `()`.
160    ///
161    /// This is not required in most circumstances.
162    pub fn cast<U>(self) -> TaskDefinition<U> {
163        TaskDefinition {
164            target: self.target,
165            worker_selector: self.worker_selector,
166            data: self.data,
167            data_format: self.data_format,
168            labels: self.labels,
169            timeout: self.timeout,
170            _task_type: PhantomData,
171        }
172    }
173}
174
/// Every status a task can be in.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TaskStatus {
    /// Not yet ready to run.
    Pending,
    /// Waiting for a worker.
    Ready,
    /// Currently running.
    Started,
    /// Finished successfully.
    Succeeded,
    /// Finished with a failure.
    Failed,
    /// Was cancelled.
    Cancelled,
}
191
192impl FromStr for TaskStatus {
193    type Err = UnexpectedValueError;
194
195    fn from_str(s: &str) -> Result<Self, Self::Err> {
196        Ok(match s {
197            "pending" => TaskStatus::Pending,
198            "ready" => TaskStatus::Ready,
199            "started" => TaskStatus::Started,
200            "succeeded" => TaskStatus::Succeeded,
201            "failed" => TaskStatus::Failed,
202            "cancelled" => TaskStatus::Cancelled,
203            _ => Err(UnexpectedValueError(s.to_string()))?,
204        })
205    }
206}
207
208impl TaskStatus {
209    /// Return a string representation.
210    #[must_use]
211    pub fn as_str(&self) -> &'static str {
212        match self {
213            TaskStatus::Pending => "pending",
214            TaskStatus::Ready => "ready",
215            TaskStatus::Started => "started",
216            TaskStatus::Succeeded => "succeeded",
217            TaskStatus::Failed => "failed",
218            TaskStatus::Cancelled => "cancelled",
219        }
220    }
221
222    /// Return whether the status is considered final,
223    /// as in there is no possible other status value
224    /// to change to under normal circumstances.
225    ///
226    /// In other words a task is finished if either:
227    ///
228    /// - it succeeded
229    /// - it failed
230    /// - it was cancelled
231    #[must_use]
232    pub fn is_finished(&self) -> bool {
233        matches!(
234            self,
235            TaskStatus::Succeeded | TaskStatus::Failed | TaskStatus::Cancelled
236        )
237    }
238}
239
240/// Extra metadata about a task or task type.
241#[derive(Debug, Default, Clone, Serialize, Deserialize)]
242pub struct TaskMetadata {
243    /// Task description.
244    pub description: Option<String>,
245    /// Task input JSON schema.
246    pub input_schema: Option<Value>,
247    /// Task output JSON schema.
248    pub output_schema: Option<Value>,
249    /// Arbitrary fields.
250    #[serde(flatten)]
251    pub other: BTreeMap<String, Value>,
252}
253
254/// The format of the task input or output data.
255#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
256#[serde(rename_all = "snake_case")]
257pub enum TaskDataFormat {
258    /// Arbitrary bytes.
259    #[default]
260    Unknown,
261    /// The data can be interpreted as self-describing MessagePack.
262    MessagePack,
263    /// The data can be interpreted as JSON.
264    Json,
265}
266
267impl TaskDataFormat {
268    /// Return a string representation.
269    #[must_use]
270    pub fn as_str(&self) -> &'static str {
271        match self {
272            TaskDataFormat::Unknown => "unknown",
273            TaskDataFormat::MessagePack => "message_pack",
274            TaskDataFormat::Json => "json",
275        }
276    }
277}
278
279impl FromStr for TaskDataFormat {
280    type Err = UnexpectedValueError;
281
282    fn from_str(s: &str) -> Result<Self, Self::Err> {
283        Ok(match s {
284            "unknown" => TaskDataFormat::Unknown,
285            "message_pack" => TaskDataFormat::MessagePack,
286            "json" => TaskDataFormat::Json,
287            _ => Err(UnexpectedValueError(s.to_string()))?,
288        })
289    }
290}
291
/// Error returned when a received value is not one of the expected ones.
///
/// Wraps the offending value as a string.
#[derive(Debug)]
pub struct UnexpectedValueError(String);

impl Display for UnexpectedValueError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(value) = self;
        write!(f, "unexpected value: {value}")
    }
}

// No underlying source error; the default trait methods are sufficient.
impl std::error::Error for UnexpectedValueError {}