homestar_runtime/runner.rs

//! General [Runner] interface for working across multiple workers
//! and executing workflows.

#[cfg(feature = "ipfs")]
use crate::network::IpfsCli;
use crate::{
    channel::{AsyncChannel, AsyncChannelReceiver, AsyncChannelSender},
    db::Database,
    event_handler::{Event, EventHandler},
    network::{rpc, swarm, webserver},
    settings,
    tasks::Fetch,
    worker::WorkerMessage,
    workflow::{self, Resource},
    Db, Receipt, Settings, Worker,
};
use anyhow::{anyhow, Context, Result};
use atomic_refcell::AtomicRefCell;
use chrono::NaiveDateTime;
use dashmap::DashMap;
use faststr::FastStr;
use fnv::FnvHashSet;
use futures::{future::poll_fn, FutureExt};
use homestar_invocation::Pointer;
use homestar_wasm::io::Arg;
use homestar_workflow::Workflow;
use jsonrpsee::server::ServerHandle;
use libipld::Cid;
use metrics_exporter_prometheus::PrometheusHandle;
#[cfg(not(test))]
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{collections::HashMap, ops::ControlFlow, rc::Rc, sync::Arc, task::Poll};
#[cfg(not(windows))]
use tokio::signal::unix::{signal, SignalKind};
#[cfg(windows)]
use tokio::signal::windows;
use tokio::{
    runtime, select,
    task::{AbortHandle, JoinHandle},
    time,
};
use tokio_util::time::{delay_queue, DelayQueue};
use tracing::{debug, error, info, info_span, instrument, warn, Instrument};

mod error;
pub(crate) mod file;
mod nodeinfo;
pub mod response;
pub(crate) use error::Error;
pub use nodeinfo::NodeInfo;
pub(crate) use nodeinfo::{DynamicNodeInfo, StaticNodeInfo};

/// Name of the thread used for the [Runner] / runtime.
#[cfg(not(test))]
const HOMESTAR_THREAD: &str = "homestar-runtime";

/// Type alias for a [DashMap] containing running worker [JoinHandle]s.
pub(crate) type RunningWorkerSet = DashMap<Cid, (JoinHandle<Result<()>>, delay_queue::Key)>;

/// Type alias for a [DashMap] containing running task [AbortHandle]s.
pub(crate) type RunningTaskSet = DashMap<Cid, Vec<AbortHandle>>;

/// Trait for managing a [DashMap] of running task information.
pub(crate) trait ModifiedSet {
    /// Append or insert a new [AbortHandle] into the [RunningTaskSet].
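    ///
    /// A minimal usage sketch (illustrative only; assumes a `cid` is already
    /// in hand and the [AbortHandle] comes from a spawned tokio task):
    ///
    /// ```ignore
    /// let set: RunningTaskSet = DashMap::new();
    /// let handle = tokio::spawn(async { /* task work */ }).abort_handle();
    /// set.append_or_insert(cid, vec![handle]);
    /// ```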
    fn append_or_insert(&self, cid: Cid, handles: Vec<AbortHandle>);
}

/// [AsyncChannelSender] for RPC server messages.
pub(crate) type RpcSender = AsyncChannelSender<(
    rpc::ServerMessage,
    Option<AsyncChannelSender<rpc::ServerMessage>>,
)>;

/// [AsyncChannelReceiver] for RPC server messages.
pub(crate) type RpcReceiver = AsyncChannelReceiver<(
    rpc::ServerMessage,
    Option<AsyncChannelSender<rpc::ServerMessage>>,
)>;

/// Type alias for a tuple containing a receipt Cid and associated `ran` and `instruction` values.
pub(crate) type WorkflowReceiptInfo = (Cid, Option<(String, Pointer)>);

/// [AsyncChannelSender] for sending messages to WebSocket server clients.
pub(crate) type WsSender = AsyncChannelSender<(
    webserver::Message,
    Option<AsyncChannelSender<webserver::Message>>,
)>;

/// [AsyncChannelReceiver] for receiving messages from WebSocket server clients.
pub(crate) type WsReceiver = AsyncChannelReceiver<(
    webserver::Message,
    Option<AsyncChannelSender<webserver::Message>>,
)>;

impl ModifiedSet for RunningTaskSet {
    fn append_or_insert(&self, cid: Cid, mut handles: Vec<AbortHandle>) {
        self.entry(cid)
            .and_modify(|prev_handles| {
                prev_handles.append(&mut handles);
            })
            .or_insert_with(|| handles);
    }
}

/// Runner interface.
/// Used to manage workers and execute/run [Workflows].
///
/// [Workflows]: homestar_workflow::Workflow
#[derive(Debug)]
pub struct Runner {
    event_sender: Arc<AsyncChannelSender<Event>>,
    expiration_queue: Rc<AtomicRefCell<DelayQueue<Cid>>>,
    node_info: StaticNodeInfo,
    running_tasks: Arc<RunningTaskSet>,
    running_workers: RunningWorkerSet,
    pub(crate) runtime: tokio::runtime::Runtime,
    pub(crate) settings: Arc<Settings>,
    webserver: Arc<webserver::Server>,
}

impl Runner {
    /// Setup bounded, MPSC channel for top-level RPC communication.
    pub(crate) fn setup_rpc_channel(capacity: usize) -> (RpcSender, RpcReceiver) {
        AsyncChannel::with(capacity)
    }

    /// Setup bounded, MPSC channel for top-level Worker communication.
    pub(crate) fn setup_worker_channel(
        capacity: usize,
    ) -> (
        AsyncChannelSender<WorkerMessage>,
        AsyncChannelReceiver<WorkerMessage>,
    ) {
        AsyncChannel::with(capacity)
    }

    /// Setup bounded, MPSC channel for sending and receiving messages
    /// to/from WebSocket server clients.
    pub(crate) fn setup_ws_mpsc_channel(capacity: usize) -> (WsSender, WsReceiver) {
        AsyncChannel::with(capacity)
    }

    /// Initialize and start the Homestar [Runner] / runtime.
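    ///
    /// A hypothetical startup sketch (not from this file; `Settings::load`
    /// and the database setup are assumed helpers from the crate's settings
    /// and [Database] APIs):
    ///
    /// ```ignore
    /// let settings = Settings::load()?;
    /// let db = Db::setup_connection_pool(settings.node(), None)?;
    /// Runner::start(settings, db)?;
    /// ```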
    #[cfg(not(test))]
    pub fn start(settings: Settings, db: impl Database + 'static) -> Result<()> {
        let runtime = runtime::Builder::new_multi_thread()
            .enable_all()
            .thread_name_fn(|| {
                static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
                let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
                format!("{HOMESTAR_THREAD}-{id}")
            })
            .build()?;

        Self::init(settings, db.clone(), runtime)?.serve(db)
    }

    /// Initialize and start the Homestar [Runner] / runtime.
    #[cfg(test)]
    pub fn start(settings: Settings, db: impl Database + 'static) -> Result<Self> {
        let runtime = runtime::Builder::new_current_thread()
            .enable_all()
            .build()?;

        let runner = Self::init(settings, db, runtime)?;
        Ok(runner)
    }

    fn init(
        settings: Settings,
        db: impl Database + 'static,
        runtime: tokio::runtime::Runtime,
    ) -> Result<Self> {
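        // Build the libp2p swarm first so the local peer id can be captured
        // for static node info before wiring up the webserver and event handler.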
        let swarm = runtime.block_on(swarm::new(settings.node().network()))?;
        let peer_id = *swarm.local_peer_id();

        let webserver = webserver::Server::new(settings.node().network().webserver())?;

        #[cfg(feature = "websocket-notify")]
        let (ws_msg_tx, ws_evt_tx) = {
            let ws_msg_tx = webserver.workflow_msg_notifier();
            let ws_evt_tx = webserver.evt_notifier();

            (ws_msg_tx, ws_evt_tx)
        };

        #[cfg(feature = "websocket-notify")]
        let event_handler =
            EventHandler::new(swarm, db, settings.node().network(), ws_evt_tx, ws_msg_tx);
        #[cfg(not(feature = "websocket-notify"))]
        let event_handler = EventHandler::new(swarm, db, settings.node().network());

        let event_sender = event_handler.sender();

        #[cfg(feature = "ipfs")]
        let _event_handler_hdl = runtime.spawn({
            let ipfs = IpfsCli::new(settings.node.network.ipfs())?;
            event_handler.start(ipfs)
        });

        #[cfg(not(feature = "ipfs"))]
        let _event_handler_hdl = runtime.spawn(event_handler.start());

        Ok(Self {
            event_sender,
            expiration_queue: Rc::new(AtomicRefCell::new(DelayQueue::new())),
            node_info: StaticNodeInfo::new(peer_id),
            running_tasks: DashMap::new().into(),
            running_workers: DashMap::new(),
            runtime,
            settings: settings.into(),
            webserver: webserver.into(),
        })
    }

    /// Listen loop for [Runner] signals and messages.
    #[allow(dead_code)]
    fn serve(self, db: impl Database + 'static) -> Result<()> {
        let message_buffer_len = self.settings.node.network.events_buffer_len;

        #[cfg(feature = "monitoring")]
        let metrics_hdl: PrometheusHandle = self.runtime.block_on(crate::metrics::start(
            self.settings.node.monitoring(),
            self.settings.node.network(),
        ))?;

        #[cfg(not(feature = "monitoring"))]
        let metrics_hdl: PrometheusHandle = self
            .runtime
            .block_on(crate::metrics::start(self.settings.node.network()))?;

        let (ws_receiver, ws_hdl) = {
            let (mpsc_ws_tx, mpsc_ws_rx) = Self::setup_ws_mpsc_channel(message_buffer_len);
            let ws_hdl =
                self.runtime
                    .block_on(self.webserver.start(mpsc_ws_tx, metrics_hdl, db.clone()))?;
            (mpsc_ws_rx, ws_hdl)
        };

        let (rpc_tx, rpc_rx) = Self::setup_rpc_channel(message_buffer_len);
        let (runner_worker_tx, runner_worker_rx) = Self::setup_worker_channel(message_buffer_len);

        let shutdown_timeout = self.settings.node.shutdown_timeout;
        let rpc_server = rpc::Server::new(self.settings.node.network(), rpc_tx.into());
        let rpc_sender = rpc_server.sender();
        self.runtime.block_on(rpc_server.spawn())?;

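        // Main event loop: multiplex RPC messages, WebSocket client messages,
        // worker messages, GC ticks, workflow expirations, and shutdown
        // signals until a shutdown breaks the loop with the elapsed time.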
        let shutdown_time_left = self.runtime.block_on(async {
            let mut gc_interval = tokio::time::interval(self.settings.node.gc_interval);
            loop {
                select! {
                    // Handle RPC messages.
                    Ok((rpc_message, Some(oneshot_tx))) = rpc_rx.recv_async() => {
                        let now = time::Instant::now();
                        let handle = self.handle_command_message(
                            rpc_message,
                            Channels {
                                rpc: rpc_sender.clone(),
                                runner: runner_worker_tx.clone(),
                            },
                            ws_hdl.clone(),
                            db.clone(),
                            self.settings.node.network().libp2p().dht(),
                            now
                        ).await;

                        match handle {
                            Ok(ControlFlow::Break(())) => break now.elapsed(),
                            Ok(ControlFlow::Continue(rpc::ServerMessage::Skip)) => {},
                            Ok(ControlFlow::Continue(msg @ rpc::ServerMessage::NodeInfoAck(_))) => {
                                debug!(subject = "rpc.ack",
                                       category = "rpc",
                                       "sending node_info message to rpc server");
                                let _ = oneshot_tx.send_async(msg).await;
                            },
                            Ok(ControlFlow::Continue(msg @ rpc::ServerMessage::RunAck(_))) => {
                                debug!(subject = "rpc.ack",
                                       category = "rpc",
                                       "sending workflow_run message to rpc server");
                                let _ = oneshot_tx.send_async(msg).await;
                            },
                            Err(err) => {
                                error!(subject = "rpc.err",
                                       category = "rpc",
                                       err=?err,
                                       "error handling rpc message");
                                let _ = oneshot_tx.send_async(rpc::ServerMessage::RunErr(err.into())).await;
                            },
                            _ => {}
                        }
                    }
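                    // Handle messages from WebSocket server clients.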
                    Ok(msg) = ws_receiver.recv_async() => {
                        match msg {
                            (webserver::Message::RunWorkflow((name, workflow)), Some(oneshot_tx)) => {
                                info!(subject = "workflow",
                                      category = "workflow.run",
                                      "running workflow: {}", name);
                                // TODO: Parse this from the workflow data itself.
                                let workflow_settings = workflow::Settings::default();
                                match self.run_worker(
                                    workflow,
                                    workflow_settings,
                                    self.settings.node.network().libp2p().dht(),
                                    Some(name),
                                    runner_worker_tx.clone(),
                                    db.clone(),
                                ).await {
                                    Ok(data) => {
                                        debug!(subject = "jsonrpc.ack",
                                               category = "jsonrpc",
                                               "sending message to jsonrpc server");
                                        let _ = oneshot_tx.send_async(webserver::Message::AckWorkflow((data.info.cid, data.name))).await;
                                    }
                                    Err(err) => {
                                        error!(subject = "jsonrpc.err",
                                               category = "jsonrpc",
                                               err=?err,
                                               "error handling ws message");
                                        let _ = oneshot_tx.send_async(webserver::Message::RunErr(err.into())).await;
                                    }
                                }

                            }
                            (webserver::Message::GetNodeInfo, Some(oneshot_tx)) => {
                                debug!(subject = "jsonrpc.nodeinfo",
                                       category = "jsonrpc",
                                       "getting node info");
                                let (tx, rx) = AsyncChannel::oneshot();
                                let _ = self.event_sender.send_async(Event::GetNodeInfo(tx)).await;
                                let dyn_node_info = if let Ok(info) = rx.recv_async().await {
                                    info
                                } else {
                                    DynamicNodeInfo::default()
                                };
                                let _ = oneshot_tx.send_async(webserver::Message::AckNodeInfo((self.node_info.clone(), dyn_node_info))).await;
                            }
                            _ => ()
                        }
                    }

                    // Handle messages from the worker.
                    Ok(msg) = runner_worker_rx.recv_async() => {
                        match msg {
                            WorkerMessage::Dropped(cid) => {
                                let _ = self.abort_worker(cid);
                            },
                        }
                    }
                    // Handle GC interval tick.
                    _ = gc_interval.tick() => {
                        let _ = self.gc();
                    },
                    // Handle expired workflows.
                    Some(expired) = poll_fn(
                        |ctx| match self.expiration_queue.try_borrow_mut() {
                            Ok(mut queue) => queue.poll_expired(ctx),
                            Err(_) => Poll::Pending,
                        }
                    ) => {
                        info!(subject = "worker.expired",
                              category = "worker",
                              "worker expired, aborting");
                        let _ = self.abort_worker(*expired.get_ref());
                    },
                    // Handle shutdown signal.
                    _ = Self::shutdown_signal() => {
                        info!(subject = "shutdown",
                              category = "homestar.shutdown",
                              "gracefully shutting down runner");

                        let now = time::Instant::now();
                        let drain_timeout = now + shutdown_timeout;
                        select! {
                            // Graceful shutdown.
                            Ok(()) = self.shutdown(rpc_sender, ws_hdl) => {
                                break now.elapsed();
                            },
                            // Force shutdown upon drain timeout.
                            _ = time::sleep_until(drain_timeout) => {
                                info!(subject = "shutdown",
                                      category = "homestar.shutdown",
                                      "shutdown timeout reached, shutting down runner anyway");
                                break now.elapsed();
                            }
                        }

                    }
                }
            }
        });

        if shutdown_time_left < shutdown_timeout {
            self.runtime
                .shutdown_timeout(shutdown_timeout - shutdown_time_left);
            info!(
                subject = "shutdown",
                category = "homestar.shutdown",
                "runner shutdown complete"
            );
        }

        Ok(())
    }

    /// [AsyncChannelSender] of the event-handler.
    ///
    /// [EventHandler]: crate::EventHandler
    pub(crate) fn event_sender(&self) -> Arc<AsyncChannelSender<Event>> {
        self.event_sender.clone()
    }

    /// Getter for the [RunningTaskSet], cloned as an [Arc].
    pub(crate) fn running_tasks(&self) -> Arc<RunningTaskSet> {
        self.running_tasks.clone()
    }

    /// Garbage-collect task [AbortHandle]s in the [RunningTaskSet] and
    /// workers in the [RunningWorkerSet].
    #[allow(dead_code)]
    fn gc(&self) -> Result<()> {
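        // Drop finished task handles, clear the expiration-queue entries for
        // finished workers, and finally drop the finished workers themselves.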
        self.running_tasks.retain(|_cid, handles| {
            handles.retain(|handle| !handle.is_finished());
            !handles.is_empty()
        });

        let mut expiration_q = self
            .expiration_queue
            .try_borrow_mut()
            .map_err(|e| anyhow!("failed to borrow expiration queue: {e}"))?;

        for worker in self.running_workers.iter_mut() {
            let (handle, delay_key) = worker.value();
            if handle.is_finished() {
                let _ = expiration_q.try_remove(delay_key);
            }
        }

        self.running_workers
            .retain(|_cid, (handle, _delay_key)| !handle.is_finished());

        Ok(())
    }

    /// Abort and gc/cleanup all workers and tasks.
    #[allow(dead_code)]
    fn abort_and_cleanup_workers(&self) -> Result<()> {
        self.abort_workers();
        self.cleanup_workers()?;

        Ok(())
    }

    /// Abort all workers.
    #[allow(dead_code)]
    fn abort_workers(&self) {
        self.running_workers.iter_mut().for_each(|data| {
            let (handle, _delay_key) = data.value();
            handle.abort()
        });
        self.abort_tasks();
    }

    /// Cleanup all workers, tasks, and the expiration queue.
    #[allow(dead_code)]
    fn cleanup_workers(&self) -> Result<()> {
        self.running_workers.clear();
        self.expiration_queue
            .try_borrow_mut()
            .map_err(|e| anyhow!("failed to borrow expiration queue: {e}"))?
            .clear();
        self.cleanup_tasks();
        Ok(())
    }

    /// Cleanup all tasks in the [RunningTaskSet].
    #[allow(dead_code)]
    fn cleanup_tasks(&self) {
        self.running_tasks.clear();
    }

    /// Aborts the set of task [AbortHandle]s running for all
    /// workers.
    #[allow(dead_code)]
    fn abort_tasks(&self) {
        self.running_tasks.iter_mut().for_each(|handles| {
            for abort_handle in &*handles {
                abort_handle.abort();
            }
        });
    }

    /// Aborts and removes a specific worker's [JoinHandle] and
    /// set of task [AbortHandle]s given a Cid.
    #[allow(dead_code)]
    fn abort_worker(&self, cid: Cid) -> Result<()> {
        let mut expiration_q = self
            .expiration_queue
            .try_borrow_mut()
            .map_err(|e| anyhow!("failed to borrow expiration queue: {e}"))?;

        if let Some((cid, (handle, delay_key))) = self.running_workers.remove(&cid) {
            let _ = expiration_q.try_remove(&delay_key);
            handle.abort();
            self.abort_worker_tasks(cid);
        }

        Ok(())
    }

    /// Abort a specific worker's tasks given a Cid.
    fn abort_worker_tasks(&self, cid: Cid) {
        if let Some((_cid, handles)) = self.running_tasks.remove(&cid) {
            for abort_handle in &*handles {
                abort_handle.abort();
            }
        }
    }

    /// Captures shutdown signals for [Runner].
    #[allow(dead_code)]
    #[cfg(not(windows))]
    async fn shutdown_signal() -> Result<()> {
        let mut sigint = signal(SignalKind::interrupt())?;
        let mut sigterm = signal(SignalKind::terminate())?;

        select! {
            _ = tokio::signal::ctrl_c() =>
                info!(subject = "shutdown",
                      category = "homestar.shutdown",
                      "CTRL-C received, shutting down"),
            _ = sigint.recv() =>
                info!(subject = "shutdown",
                      category = "homestar.shutdown",
                      "SIGINT received, shutting down"),
            _ = sigterm.recv() =>
                info!(subject = "shutdown",
                      category = "homestar.shutdown",
                      "SIGTERM received, shutting down"),
        }
        Ok(())
    }

    #[allow(dead_code)]
    #[cfg(windows)]
    async fn shutdown_signal() -> Result<()> {
        let mut sigint = windows::ctrl_close()?;
        let mut sigterm = windows::ctrl_shutdown()?;
        let mut sighup = windows::ctrl_break()?;

        select! {
            _ = tokio::signal::ctrl_c() =>
                info!(subject = "shutdown",
                      category = "homestar.shutdown",
                      "CTRL-C received, shutting down"),
            _ = sigint.recv() =>
                info!(subject = "shutdown",
                      category = "homestar.shutdown",
                      "SIGINT received, shutting down"),
            _ = sigterm.recv() =>
                info!(subject = "shutdown",
                      category = "homestar.shutdown",
                      "SIGTERM received, shutting down"),
            _ = sighup.recv() =>
                info!(subject = "shutdown",
                      category = "homestar.shutdown",
                      "SIGHUP received, shutting down")
        }
        Ok(())
    }

    /// Sequence for shutting down a [Runner], including:
    /// a) RPC (CLI)
    /// b) Webserver
    /// c) Event-handler channels
    /// d) Running workers
    async fn shutdown(
        &self,
        rpc_sender: Arc<AsyncChannelSender<rpc::ServerMessage>>,
        ws_hdl: ServerHandle,
    ) -> Result<()> {
        let (shutdown_sender, shutdown_receiver) = AsyncChannel::oneshot();
        let _ = rpc_sender
            .send_async(rpc::ServerMessage::GracefulShutdown(shutdown_sender))
            .await;
        let _ = shutdown_receiver;

        info!(
            subject = "shutdown",
            category = "homestar.shutdown",
            "shutting down webserver"
        );

        let _ = ws_hdl.stop();
        ws_hdl.stopped().await;

        let (shutdown_sender, shutdown_receiver) = AsyncChannel::oneshot();
        let _ = self
            .event_sender
            .send_async(Event::Shutdown(shutdown_sender))
            .await;
        let _ = shutdown_receiver;

        // abort all workers
        self.abort_workers();

        Ok(())
    }

    async fn handle_command_message(
        &self,
        msg: rpc::ServerMessage,
        channels: Channels,
        ws_hdl: ServerHandle,
        db: impl Database + 'static,
        network_settings: &settings::Dht,
        now: time::Instant,
    ) -> Result<ControlFlow<(), rpc::ServerMessage>> {
        match msg {
            rpc::ServerMessage::NodeInfo => {
                info!(
                    subject = "rpc.command",
                    category = "rpc",
                    "RPC node command received, sending node info"
                );

                let (tx, rx) = AsyncChannel::oneshot();
                let _ = self.event_sender.send_async(Event::GetNodeInfo(tx)).await;

                let dyn_node_info = if let Ok(info) = rx.recv_async().await {
                    info
                } else {
                    DynamicNodeInfo::default()
                };

                Ok(ControlFlow::Continue(rpc::ServerMessage::NodeInfoAck(
                    response::AckNodeInfo::new(self.node_info.clone(), dyn_node_info),
                )))
            }
            rpc::ServerMessage::ShutdownCmd => {
                info!(
                    subject = "rpc.command",
                    category = "rpc",
                    "RPC shutdown signal received, shutting down runner"
                );
                let drain_timeout = now + self.settings.node.shutdown_timeout;
                select! {
                    // we can unwrap here b/c we know we have a sender based
                    // on the feature flag.
                    Ok(()) = self.shutdown(channels.rpc, ws_hdl) => {
                        Ok(ControlFlow::Break(()))
                    },
                    _ = time::sleep_until(drain_timeout) => {
                        info!(subject = "shutdown",
                              category = "homestar.shutdown",
                              "shutdown timeout reached, shutting down runner anyway");
                        Ok(ControlFlow::Break(()))
                    }
                }
            }
            rpc::ServerMessage::Run((name, workflow_file)) => {
                info!(
                    subject = "rpc.command",
                    category = "rpc",
                    "RPC run command received, running workflow"
                );
                let (workflow, workflow_settings) =
                    workflow_file.validate_and_parse().await.with_context(|| {
                        format!("failed to validate/parse workflow @ path: {workflow_file}",)
                    })?;

                let data = self
                    .run_worker(
                        workflow,
                        workflow_settings,
                        network_settings,
                        name,
                        channels.runner,
                        db.clone(),
                    )
                    .await?;

                Ok(ControlFlow::Continue(rpc::ServerMessage::RunAck(Box::new(
                    response::AckWorkflow::new(
                        data.info,
                        data.replayed_receipt_info,
                        data.name,
                        data.timestamp,
                    ),
                ))))
            }
            msg => {
                warn!(
                    subject = "rpc.command",
                    category = "rpc",
                    "received unexpected message: {:?}",
                    msg
                );
                Ok(ControlFlow::Continue(rpc::ServerMessage::Skip))
            }
        }
    }

    #[instrument(skip_all)]
    async fn run_worker<S: Into<FastStr>>(
        &self,
        workflow: Workflow<'static, Arg>,
        workflow_settings: workflow::Settings,
        network_settings: &settings::Dht,
        name: Option<S>,
        runner_sender: AsyncChannelSender<WorkerMessage>,
        db: impl Database + 'static,
    ) -> Result<WorkflowData> {
        let worker = {
            Worker::new(
                workflow,
                workflow_settings,
                network_settings.clone().to_owned(),
                name,
                self.event_sender(),
                runner_sender,
                db.clone(),
            )
            .await?
        };

        // Deliberate use of Arc::clone for readability, could just be
        // `clone`, as the underlying type is an `Arc`.
        let initial_info = Arc::clone(&worker.workflow_info);
        let workflow_timeout = worker.workflow_settings.timeout;
        let workflow_name = worker.workflow_name.clone();
        let workflow_settings = worker.workflow_settings.clone();
        let timestamp = worker.workflow_started;

        // Spawn the worker, which initializes the scheduler and runs
        // the workflow.
        info!(
            subject = "workflow.run",
            category = "workflow",
            cid = worker.workflow_info.cid.to_string(),
            "running workflow with settings: {:#?}",
            worker.workflow_settings
        );

        // Provide workflow to network.
        //
        // This essentially says, I'm running this workflow Cid.
        self.event_sender
            .send_async(Event::ProvideRecord(
                worker.workflow_info.cid,
                None,
                swarm::CapsuleTag::Workflow,
            ))
            .await?;

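        // The fetch function is handed to the worker so it can resolve
        // workflow resources, over IPFS when the `ipfs` feature is enabled.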
        #[cfg(feature = "ipfs")]
        let fetch_fn = {
            let settings = Arc::clone(&self.settings);
            let ipfs = IpfsCli::new(settings.node.network.ipfs())?;
            move |rscs: FnvHashSet<Resource>| {
                async move { Fetch::get_resources(rscs, workflow_settings, ipfs).await }.boxed()
            }
        };

        #[cfg(not(feature = "ipfs"))]
        let fetch_fn = |rscs: FnvHashSet<Resource>| {
            async move { Fetch::get_resources(rscs, workflow_settings).await }.boxed()
        };

        let handle = self.runtime.spawn(
            worker
                .run(self.running_tasks(), fetch_fn)
                .instrument(info_span!("run").or_current()),
        );

        // Add the Cid to the expiration timing wheel
        let delay_key = self
            .expiration_queue
            .try_borrow_mut()
            .map_err(|e| anyhow!("failed to borrow expiration queue: {e}"))?
            .insert(initial_info.cid, workflow_timeout);

        // Insert handle into running workers map
        self.running_workers
            .insert(initial_info.cid, (handle, delay_key));

        // Gather receipt info
        let receipt_pointers = initial_info
            .progress
            .iter()
            .map(|cid| Pointer::new(*cid))
            .collect();
        let replayed_receipt_info = find_receipt_info_by_pointers(&receipt_pointers, db)?;

        // Log replayed receipts if any
        if !replayed_receipt_info.is_empty() {
            info!(
                subject = "workflow.receipts",
                category = "workflow",
                receipt_cids = replayed_receipt_info
                    .iter()
                    .map(|info| info.0.to_string())
                    .collect::<Vec<String>>()
                    .join(","),
                "replaying receipts",
            );
        };

        Ok(WorkflowData {
            info: initial_info,
            name: workflow_name,
            timestamp,
            replayed_receipt_info,
        })
    }
}

/// Find receipts given a batch of [Receipt] [Pointer]s, and return them as [WorkflowReceiptInfo]s.
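///
/// Pointers with no matching receipt in the database are returned with `None`
/// info, so callers can distinguish replayed receipts from missing ones.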
fn find_receipt_info_by_pointers(
    pointers: &Vec<Pointer>,
    db: impl Database + 'static,
) -> Result<Vec<WorkflowReceiptInfo>> {
    let receipts: HashMap<Cid, Receipt> = Db::find_receipt_pointers(pointers, &mut db.conn()?)?
        .into_iter()
        .map(|receipt| (receipt.cid(), receipt))
        .collect();

    let receipt_info = pointers
        .iter()
        .map(|pointer| match receipts.get(&pointer.cid()) {
            Some(receipt) => (
                pointer.cid(),
                Some((receipt.ran(), receipt.instruction().clone())),
            ),
            None => (pointer.cid(), None),
        })
        .collect();

    Ok(receipt_info)
}

/// Internal workflow data wrapper returned from running a worker.
struct WorkflowData {
    info: Arc<workflow::Info>,
    name: FastStr,
    timestamp: NaiveDateTime,
    replayed_receipt_info: Vec<WorkflowReceiptInfo>,
}

/// Channels for sending messages to/from the RPC server and the runner.
#[derive(Debug)]
struct Channels {
    rpc: Arc<AsyncChannelSender<rpc::ServerMessage>>,
    runner: AsyncChannelSender<WorkerMessage>,
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        network::rpc::Client,
        test_utils::{db::MemoryDb, WorkerBuilder},
    };
    use metrics_exporter_prometheus::PrometheusBuilder;
    use rand::thread_rng;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use tarpc::context;
    use tokio::net::TcpStream;

    #[homestar_runtime_proc_macro::runner_test]
    fn shutdown() {
        let TestRunner { runner, settings } = TestRunner::start();
        let (tx, _rx) = Runner::setup_rpc_channel(1);
        let (runner_tx, _runner_rx) = Runner::setup_ws_mpsc_channel(1);
        let db = MemoryDb::setup_connection_pool(settings.node(), None).unwrap();
        let rpc_server = rpc::Server::new(settings.node.network(), Arc::new(tx));
        let rpc_sender = rpc_server.sender();

        let addr = SocketAddr::new(
            settings.node.network.rpc.host,
            settings.node.network.rpc.port,
        );

        let ws_hdl = runner.runtime.block_on(async {
            rpc_server.spawn().await.unwrap();

            let port = port_selector::random_free_tcp_port().unwrap();
            let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port);
            let (recorder, _exporter) = PrometheusBuilder::new()
                .with_http_listener(socket)
                .build()
                .expect("failed to install recorder/exporter");
            let metrics_hdl = recorder.handle();

            let ws_hdl = runner
                .webserver
                .start(runner_tx, metrics_hdl, db)
                .await
                .unwrap();
            let _stream = TcpStream::connect(addr).await.expect("Connection error");
            let _another_stream = TcpStream::connect(addr).await.expect("Connection error");

            ws_hdl
        });

        runner.runtime.block_on(async {
            match runner.shutdown(rpc_sender, ws_hdl).await {
                Ok(()) => {
                    // with shutdown, we should not be able to connect to the server(s)
                    let stream_error = TcpStream::connect(addr).await;
                    assert!(stream_error.is_err());
                    assert!(matches!(
                        stream_error.unwrap_err().kind(),
                        std::io::ErrorKind::ConnectionRefused
                    ));

                    let ws_error =
                        tokio_tungstenite::connect_async("ws://localhost:1337".to_string()).await;
                    assert!(ws_error.is_err());
                }
                _ => panic!("Shutdown failed."),
            }
        });
    }

    #[homestar_runtime_proc_macro::runner_test]
    fn spawn_rpc_server_and_ping() {
        let TestRunner { runner, settings } = TestRunner::start();

        let (tx, _rx) = Runner::setup_rpc_channel(1);
        let rpc_server = rpc::Server::new(settings.node.network(), tx.into());

        runner.runtime.block_on(rpc_server.spawn()).unwrap();

        runner.runtime.spawn(async move {
            let addr = SocketAddr::new(
                settings.node.network.rpc.host,
                settings.node.network.rpc.port,
            );

            let client = Client::new(addr, context::current()).await.unwrap();
            let response = client.ping().await.unwrap();
            assert_eq!(response, "pong".to_string());
        });
    }

    #[homestar_runtime_proc_macro::runner_test]
    fn abort_all_workers() {
        let TestRunner { runner, settings } = TestRunner::start();

        runner.runtime.block_on(async {
            let builder = WorkerBuilder::new(settings.node);
            let fetch_fn = builder.fetch_fn();
            let worker = builder.build().await;
            let workflow_cid = worker.workflow_info.cid;
            let workflow_timeout = worker.workflow_settings.timeout;
            let handle = runner
                .runtime
                .spawn(worker.run(runner.running_tasks(), fetch_fn));
            let delay_key = runner
                .expiration_queue
                .try_borrow_mut()
                .unwrap()
                .insert(workflow_cid, workflow_timeout);
            runner
                .running_workers
                .insert(workflow_cid, (handle, delay_key));
        });

        runner.abort_workers();
        runner.runtime.block_on(async {
            for (_, (handle, _)) in runner.running_workers {
                assert!(!handle.is_finished());
                assert!(handle.await.unwrap_err().is_cancelled());
            }
        });
        runner.running_tasks.iter().for_each(|handles| {
            for handle in &*handles {
                assert!(handle.is_finished());
            }
        });
    }

    #[homestar_runtime_proc_macro::runner_test]
    fn abort_and_cleanup_all_workers() {
        let TestRunner { runner, settings } = TestRunner::start();

        runner.runtime.block_on(async {
            let builder = WorkerBuilder::new(settings.node);
            let fetch_fn = builder.fetch_fn();
            let worker = builder.build().await;
            let workflow_cid = worker.workflow_info.cid;
            let workflow_timeout = worker.workflow_settings.timeout;
            let handle = runner
                .runtime
                .spawn(worker.run(runner.running_tasks(), fetch_fn));
            let delay_key = runner
                .expiration_queue
                .try_borrow_mut()
                .unwrap()
                .insert(workflow_cid, workflow_timeout);
            runner
                .running_workers
                .insert(workflow_cid, (handle, delay_key));
        });

        runner.abort_and_cleanup_workers().unwrap();
        assert!(runner.running_workers.is_empty());
        assert!(runner.running_tasks.is_empty());
    }

    #[homestar_runtime_proc_macro::runner_test]
    fn gc_while_workers_still_running() {
        let TestRunner { runner, settings } = TestRunner::start();

        runner.runtime.block_on(async {
            let builder = WorkerBuilder::new(settings.node);
            let fetch_fn = builder.fetch_fn();
            let worker = builder.build().await;
            let workflow_cid = worker.workflow_info.cid;
            let workflow_timeout = worker.workflow_settings.timeout;
            let handle = runner
                .runtime
                .spawn(worker.run(runner.running_tasks(), fetch_fn));
            let delay_key = runner
                .expiration_queue
                .try_borrow_mut()
                .unwrap()
                .insert(workflow_cid, workflow_timeout);

            runner
                .running_workers
                .insert(workflow_cid, (handle, delay_key));
        });

        runner.gc().unwrap();
        assert!(!runner.running_workers.is_empty());

        runner.runtime.block_on(async {
            for (_, (handle, _)) in runner.running_workers {
                assert!(!handle.is_finished());
                let _ = handle.await.unwrap();
            }
        });

        runner.running_tasks.iter().for_each(|handles| {
            for handle in &*handles {
                assert!(handle.is_finished());
            }
        });

        assert!(!runner.running_tasks.is_empty());
        assert!(!runner.expiration_queue.try_borrow_mut().unwrap().is_empty());
    }

    #[homestar_runtime_proc_macro::runner_test]
    fn gc_while_workers_finished() {
        let TestRunner { runner, settings } = TestRunner::start();
        runner.runtime.block_on(async {
            let builder = WorkerBuilder::new(settings.node);
            let fetch_fn = builder.fetch_fn();
            let worker = builder.build().await;
            let _ = worker.run(runner.running_tasks(), fetch_fn).await;
        });

        runner.running_tasks.iter().for_each(|handles| {
            for handle in &*handles {
                assert!(handle.is_finished());
            }
        });

        runner.gc().unwrap();
        assert!(runner.running_tasks.is_empty());
    }

    #[homestar_runtime_proc_macro::runner_test]
    fn abort_all_tasks() {
        let TestRunner { runner, .. } = TestRunner::start();
        let mut set = tokio::task::JoinSet::new();
        runner.runtime.block_on(async {
            for i in 0..3 {
                let handle = set.spawn(async move { i });
                runner.running_tasks.append_or_insert(
                    homestar_invocation::test_utils::cid::generate_cid(&mut thread_rng()),
                    vec![handle],
                );
            }

            while set.join_next().await.is_some() {}
        });

        runner.abort_tasks();
        runner.cleanup_tasks();
        assert!(runner.running_tasks.is_empty());
    }

    #[homestar_runtime_proc_macro::runner_test]
    fn abort_one_task() {
        let TestRunner { runner, .. } = TestRunner::start();
        let mut set = tokio::task::JoinSet::new();
        let mut cids = vec![];
        runner.runtime.block_on(async {
            for i in 0..3 {
                let handle = set.spawn(async move { i });
                let cid = homestar_invocation::test_utils::cid::generate_cid(&mut thread_rng());
                runner.running_tasks.append_or_insert(cid, vec![handle]);
                cids.push(cid);
            }

            while set.join_next().await.is_some() {}
        });

        assert!(runner.running_tasks.len() == 3);
        runner.abort_worker_tasks(cids[0]);
        assert!(runner.running_tasks.len() == 2);
    }
}