1use std::{future::poll_fn, sync::Arc, task::Poll};
29
30use anyhow::{Result, anyhow};
31use beetry_core::{
32 AbortTask, ActionTask, ExecutorConcept, QueryTask, RegisterTask, TaskDescription, TaskStatus,
33};
34use bon::Builder;
35use futures::{StreamExt, stream::FuturesUnordered};
36use tokio::sync::{
37 Notify,
38 mpsc::{self, Receiver, Sender, error::TryRecvError},
39};
40use tracing::{debug, instrument};
41
42#[derive(Debug, Builder)]
44pub struct ExecutorConfig {
45 #[builder(default = 8)]
46 pub task_channel_capacity: usize,
47 #[builder(default = std::time::Duration::from_millis(10))]
48 pub abort_poll_interval: std::time::Duration,
49}
50
51impl Default for ExecutorConfig {
52 fn default() -> Self {
53 Self::builder().build()
54 }
55}
56
57pub struct WithRegistry {
59 registry: TaskRegistry,
60}
61
62pub struct ExecutionTask {
63 task: ActionTask,
64 status_sender: Sender<TaskStatus>,
65 abort_notifier: Arc<Notify>,
66}
67
68impl ExecutionTask {
69 fn new(
70 task: ActionTask,
71 status_sender: Sender<TaskStatus>,
72 abort_notifier: Arc<Notify>,
73 ) -> Self {
74 Self {
75 task,
76 status_sender,
77 abort_notifier,
78 }
79 }
80
81 #[instrument(skip(self), fields(task = %self.task.desc()))]
82 async fn execute(self) -> Result<()> {
83 tokio::select! {
84 () = self.abort_notifier.notified() => {
85 debug!("aborting execution task");
86 self.status_sender.send(TaskStatus::Aborted).await?;
87 }
88
89 status = self.task.execute() => {
90 debug!("task reached terminal state: {status:?}");
91 self.status_sender.send(status.into()).await?;
92 },
93 }
94 Ok(())
95 }
96}
97
98pub struct Executor<S> {
100 recv: Receiver<ExecutionTask>,
101 state: S,
102}
103
104pub struct Init;
106pub struct Ready;
108
109impl Executor<Init> {
110 #[expect(
111 clippy::needless_pass_by_value,
112 reason = "Config contains only copy types now, but not marked Copy for future extensions"
113 )]
114 pub fn new(config: ExecutorConfig) -> Executor<WithRegistry> {
118 let (sender, recv) = mpsc::channel(config.task_channel_capacity);
119 let registry = TaskRegistry::new(sender);
120
121 Executor {
122 recv,
123 state: WithRegistry { registry },
124 }
125 }
126}
127
128impl Executor<WithRegistry> {
129 pub fn into_ready_with_registry(self) -> (Executor<Ready>, TaskRegistry) {
138 (
139 Executor {
140 recv: self.recv,
141 state: Ready,
142 },
143 self.state.registry,
144 )
145 }
146}
147
148impl ExecutorConcept for Executor<Ready> {
149 #[instrument(skip(self), name = "Executor::run")]
150 async fn run(&mut self) -> Result<()> {
151 debug!("start running registered tasks");
152 let mut tasks = FuturesUnordered::new();
153
154 loop {
155 let execute_next_task_fut = poll_fn(|cx| {
156 if tasks.is_empty() {
157 Poll::Pending
158 } else {
159 tasks.poll_next_unpin(cx)
160 }
161 });
162
163 tokio::select! {
164 Some(exe_task) = self.recv.recv() => {
165 debug!("received new task to execute: {}", exe_task.task.desc());
166 tasks.push(exe_task.execute());
167 },
168 Some(result) = execute_next_task_fut => {
169 result?;
170 }
171
172 }
173 }
174 }
175}
176
177#[derive(Debug, Clone)]
178pub struct TaskRegistry {
185 sender: Sender<ExecutionTask>,
186}
187
188impl TaskRegistry {
189 fn new(sender: Sender<ExecutionTask>) -> Self {
190 Self { sender }
191 }
192}
193
194impl RegisterTask<TaskHandle> for TaskRegistry {
195 #[instrument(skip_all, fields(task = %task.desc()))]
196 fn register(&self, task: ActionTask) -> Result<TaskHandle> {
197 let (status_send, status_recv) = mpsc::channel(1);
198 let notify = Arc::new(Notify::new());
199 let exe_task = ExecutionTask::new(task, status_send, Arc::clone(¬ify));
200 let handle = TaskHandle::new(
201 StatusQuerier::new(status_recv),
202 TaskAborter::new(notify),
203 exe_task.task.desc().clone(),
204 );
205 self.sender.try_send(exe_task).map_err(|err| {
206 anyhow!(
207 "failed to send execution task: {}",
208 err.into_inner().task.desc()
209 )
210 })?;
211
212 Ok(handle)
213 }
214}
215
216#[derive(Debug)]
217struct StatusQuerier {
218 status_recv: Receiver<TaskStatus>,
219}
220
221impl StatusQuerier {
222 fn new(status_recv: Receiver<TaskStatus>) -> Self {
223 Self { status_recv }
224 }
225
226 fn query(&mut self) -> TaskStatus {
227 match self.status_recv.try_recv() {
228 Ok(status) => status,
229 Err(e) => match e {
230 TryRecvError::Empty => TaskStatus::Running,
231 TryRecvError::Disconnected => {
232 panic!(
235 "task status channel disconnected before a terminal status was observed -
236 this indicates an executor/task lifecycle bug"
237 );
238 }
239 },
240 }
241 }
242}
243
244#[derive(Debug)]
245struct TaskAborter {
246 abort_notifier: Arc<Notify>,
247}
248
249impl TaskAborter {
250 fn new(abort_notifier: Arc<Notify>) -> Self {
251 Self { abort_notifier }
252 }
253
254 fn abort(&self) {
255 self.abort_notifier.notify_one();
256 }
257}
258
259#[derive(Debug)]
260pub struct TaskHandle {
271 querier: StatusQuerier,
272 aborter: TaskAborter,
273 desc: TaskDescription,
274}
275
276impl TaskHandle {
277 fn new(querier: StatusQuerier, aborter: TaskAborter, desc: TaskDescription) -> Self {
278 Self {
279 querier,
280 aborter,
281 desc,
282 }
283 }
284}
285
286impl QueryTask for TaskHandle {
287 fn query(&mut self) -> TaskStatus {
288 self.querier.query()
289 }
290}
291
292impl AbortTask for TaskHandle {
293 #[instrument(skip_all, fields(desc=%self.desc))]
294 fn abort(&mut self) {
295 self.aborter.abort();
296 }
297}