Skip to main content

borderless_runtime/rt/agent/
tasks.rs

1use std::{
2    sync::Arc,
3    time::{Duration, Instant},
4};
5
6use borderless::{
7    agents::{Schedule, WsConfig},
8    events::Events,
9    AgentId,
10};
11use borderless_kv_store::Db;
12use futures_util::{SinkExt, StreamExt};
13use thiserror::Error;
14use tokio::{
15    sync::{mpsc, Mutex},
16    task::JoinSet,
17    time::{interval, sleep, MissedTickBehavior},
18};
19use tokio_tungstenite::connect_async;
20use tokio_tungstenite::tungstenite::{Bytes, Message};
21
22use crate::log_shim::*;
23
24use super::Runtime;
25
/// Error returned by `handle_schedules` when one of its spawned schedule tasks fails.
///
/// The type carries no payload: the underlying join error is only logged before this
/// error is returned.
26#[derive(Debug, Error)]
27#[error("Critical error in schedule task - forced to shutdown")]
28pub struct ScheduleError;
29
30/// Function to handle all schedules of a single sw-agent
31#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
32pub async fn handle_schedules<S>(
33    rt: Arc<Mutex<Runtime<S>>>,
34    aid: AgentId,
35    schedules: Vec<Schedule>,
36    out_tx: mpsc::Sender<Events>,
37) -> Result<(), ScheduleError>
38where
39    S: Db + 'static,
40{
41    let mut join_set = JoinSet::new();
42    for sched in schedules {
43        let rt = rt.clone();
44        let action = sched.get_action();
45        let out_tx = out_tx.clone();
46        let action_name = action.print_method();
47
48        join_set.spawn(async move {
49            if sched.delay > 0 {
50                sleep(Duration::from_millis(sched.delay)).await;
51            }
52
53            let mut interval = interval(Duration::from_millis(sched.interval));
54            interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
55
56            loop {
57                interval.tick().await;
58                // Dispatch output events
59                let now = Instant::now();
60                let result = rt.lock().await.process_action(&aid, action.clone()).await;
61                match result {
62                    Ok(Some(events)) => {
63                        // NOTE: We panic here to shutdown the entire task in case the receiver is closed
64                        out_tx
65                            .send(events)
66                            .await
67                            .expect("receiver dropped or closed");
68                    }
69                    Ok(None) => (),
70                    Err(e) => error!("failure while executing schedule {action_name}: {e}"),
71                }
72                info!(
73                    "executed schedule {action_name}, time elapsed: {:?}",
74                    now.elapsed()
75                );
76            }
77        });
78    }
79
80    // This loop will run forever unless the outer task is cancelled.
81    // If the outer task is aborted, all spawned tasks inside JoinSet are also dropped.
82    while let Some(res) = join_set.join_next().await {
83        // Catch panics here and shutdown the entire task
84        if let Err(e) = res {
85            error!("A schedule task failed: {e}");
86            // Gracefully shut down all other tasks
87            join_set.abort_all();
88            // Return error
89            return Err(ScheduleError);
90        }
91    }
92    Ok(())
93}
94
95#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
96pub async fn handle_ws_connection<S>(
97    rt: Arc<Mutex<Runtime<S>>>,
98    aid: AgentId,
99    ws_config: WsConfig,
100    out_tx: mpsc::Sender<Events>,
101) -> crate::Result<()>
102where
103    S: Db + 'static,
104{
105    // Register the websocket at the runtime
106    let mut msg_rx = rt.lock().await.register_ws(aid)?;
107
108    let mut failure_cnt = 1;
109    #[allow(clippy::while_immutable_condition)]
110    while ws_config.reconnect {
111        match handle_ws_inner(
112            rt.clone(),
113            aid,
114            ws_config.clone(),
115            out_tx.clone(),
116            &mut msg_rx,
117        )
118        .await
119        {
120            Ok(()) => failure_cnt = 1,
121            Err(e) => {
122                warn!("cnt={failure_cnt}, agent-id={aid}, {e}");
123                failure_cnt = (failure_cnt * 2).min(60);
124                sleep(Duration::from_secs(failure_cnt)).await;
125            }
126        }
127    }
128    Ok(())
129}
130
131async fn handle_ws_inner<S>(
132    rt: Arc<Mutex<Runtime<S>>>,
133    aid: AgentId,
134    ws_config: WsConfig,
135    out_tx: mpsc::Sender<Events>,
136    msg_rx: &mut mpsc::Receiver<Vec<u8>>,
137) -> Result<(), String>
138where
139    S: Db + 'static,
140{
141    info!("opening websocket connection to '{}'", ws_config.url);
142    let result = connect_async(&ws_config.url)
143        .await
144        .map_err(|e| format!("failed to open ws-connection - {e}"))?;
145
146    let (stream, response) = result;
147    if response.status().is_client_error() || response.status().is_server_error() {
148        return Err(format!(
149            "failed to open ws-connection - status={}",
150            response.status()
151        ));
152    }
153
154    // Call "on-open"
155    handle_events(rt.lock().await.on_ws_open(&aid).await, &out_tx).await;
156
157    // Set heartbeat timer
158    let mut heartbeat_timer = interval(Duration::from_secs(ws_config.ping_interval.max(10)));
159
160    // Now start receiving messages from the websocket
161    let (mut tx, mut rx) = stream.split();
162
163    info!("successfully opened ws-connection to '{}'", ws_config.url);
164
165    // main loop
166    loop {
167        tokio::select! {
168            biased;
169            // Check heartbeat timer
170            _ = heartbeat_timer.tick() => {
171                let msg = Message::Ping(Vec::new().into());
172                tx.send(msg).await.map_err(|e| format!("failed to send heartbeat: {e}"))?;
173            }
174            result = msg_rx.recv() => {
175                let payload = result.ok_or("Websocket message receiver closed.")?;
176                let msg = if ws_config.binary {
177                    Message::Binary(payload.into())
178                } else {
179                    Message::Text(payload.try_into().unwrap())
180                };
181                // Send message
182                if let Err(e) = tx.send(msg).await {
183                    warn!("failed to send ws-msg: {e}");
184                }
185            }
186            // Check incoming messages
187            result = rx.next() => {
188                if result.is_none() {
189                    warn!("Websocket receiver closed.");
190                    break;
191                }
192                let msg = result.unwrap();
193                if msg.is_err() {
194                    // TODO: Forward error message to wasm ?
195                    warn!("Websocket-msg failure: {}", msg.unwrap_err());
196                    // Call "on-error"
197                    handle_events(rt.lock().await.on_ws_error(&aid).await, &out_tx).await;
198                    break;
199                }
200                let data = match msg.unwrap() {
201                    Message::Text(text) => {
202                        // TODO: Remove this log line, once everything is up and running
203                        info!("incoming text ws msg");
204                        let bytes: Bytes = text.into();
205                        bytes.into()
206                    }
207                    Message::Binary(b) => {
208                        // TODO: Remove this log line, once everything is up and running
209                        info!("incoming binary ws msg");
210                        b.into()
211                    }
212                    Message::Pong(_) => continue,
213                    Message::Close(frame) => {
214                        // TODO: Forward closing frame to wasm ?
215                        info!("Received closing frame: {frame:#?}");
216                        // Call "on-close"
217                        handle_events(rt.lock().await.on_ws_close(&aid).await, &out_tx).await;
218                        break;
219                    }
220                    other => {
221                        info!("receive other websocket msg: {other:#?}");
222                        continue
223                    }
224                };
225
226                // Apply message and dispatch output events
227                handle_events(rt.lock().await.process_ws_msg(&aid, data).await, &out_tx).await;
228            }
229        }
230    }
231    Ok(())
232}
233
234async fn handle_events(result: crate::Result<Option<Events>>, out_tx: &mpsc::Sender<Events>) {
235    match result {
236        Ok(Some(events)) => {
237            // NOTE: We panic here to shutdown the entire task in case the receiver is closed
238            out_tx
239                .send(events)
240                .await
241                .expect("receiver dropped or closed");
242        }
243        Ok(None) => (),
244        Err(e) => error!("failure while executing on-ws-msg: {e}"),
245    }
246}