swimos_runtime/downlink/mod.rs

1// Copyright 2015-2024 Swim Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::error::Error;
17use std::fmt::Debug;
18use std::num::NonZeroUsize;
19use std::time::Duration;
20
21use crate::downlink::failure::BadFrameResponse;
22use crate::timeout_coord::{VoteResult, Voter};
23use crate::Io;
24use backpressure::DownlinkBackpressure;
25use bitflags::bitflags;
26use bytes::{Bytes, BytesMut};
27use futures::future::{join, select, Either};
28use futures::stream::SelectAll;
29use futures::{Future, FutureExt, Sink, SinkExt, Stream, StreamExt};
30use interpretation::MapInterpretation;
31use swimos_agent_protocol::{
32    encoding::downlink::DownlinkNotificationEncoder, DownlinkNotification,
33};
34use swimos_api::address::RelativeAddress;
35use swimos_messages::protocol::{
36    Notification, Operation, RawRequestMessage, RawRequestMessageEncoder,
37    RawResponseMessageDecoder, ResponseMessage,
38};
39use swimos_model::Text;
40use swimos_utilities::byte_channel::{ByteReader, ByteWriter};
41use swimos_utilities::encoding::BytesStr;
42use swimos_utilities::future::{immediate_or_join, immediate_or_start, SecondaryResult};
43use swimos_utilities::trigger;
44use tokio::sync::mpsc;
45use tokio::time::{timeout, Instant};
46use tokio_stream::wrappers::ReceiverStream;
47use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite};
48use tracing::{error, info, info_span, trace, warn, Instrument};
49use uuid::Uuid;
50
51pub use interpretation::NoInterpretation;
52
53mod backpressure;
54/// Strategies for handling invalid envelopes.
55pub mod failure;
56mod interpretation;
57#[cfg(test)]
58mod tests;
59
60bitflags! {
61    /// Flags that a downlink consumer can set to instruct the downlink runtime how it wishes
62    /// to be driven.
63    #[derive(Debug, Copy, Clone)]
64    pub struct DownlinkOptions: u8 {
65        /// The consumer needs to be synchronized with the remote lane.
66        const SYNC = 0b01;
67        /// If the connection fails, it should be restarted and the consumer passed to the new
68        /// connection.
69        const KEEP_LINKED = 0b10;
70        /// By default, all options are enabled.
71        const DEFAULT = Self::SYNC.bits() | Self::KEEP_LINKED.bits();
72    }
73}
74
75bitflags! {
76    /// Internal flags for the downlink runtime write task.
77    struct WriteTaskState: u8 {
78        /// The outgoing channel has been flushed.
79        const FLUSHED = 0b01;
80        /// A new consumer that needs to be synced has joined why a write was pending.
81        const NEEDS_SYNC = 0b10;
82        /// When the task starts it does not need to be flushed.
83        const INIT = Self::FLUSHED.bits();
84    }
85}
86
87impl WriteTaskState {
88    /// If a new consumer needs to be synced, set the appropriate state bit.
89    ///
90    /// # Arguments
91    /// * `options` - The option flags.
92    pub fn set_needs_sync(&mut self, options: DownlinkOptions) {
93        if options.contains(DownlinkOptions::SYNC) {
94            *self |= WriteTaskState::NEEDS_SYNC;
95        }
96    }
97}
98
99/// A request to attach a new consumer to the downlink runtime.
100pub struct AttachAction {
101    io: Io,
102    options: DownlinkOptions,
103}
104
105impl AttachAction {
106    /// # Arguments
107    /// * `io` - Bidirectional channel to communicate with the downlink runtime.
108    /// * `options` - Option flags for the downlink.
109    pub fn new(io: Io, options: DownlinkOptions) -> Self {
110        AttachAction { io, options }
111    }
112}
113
114/// Configuration parameters for the downlink runtime.
115#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
116pub struct DownlinkRuntimeConfig {
117    /// If the runtime has no consumers for longer than this timeout, it will stop.
118    pub empty_timeout: Duration,
119    /// Size of the queue for accepting new subscribers to a downlink.
120    pub attachment_queue_size: NonZeroUsize,
121    /// Abort the downlink on receiving invalid frames.
122    pub abort_on_bad_frames: bool,
123    /// Size of the buffers to communicated with the socket.
124    pub remote_buffer_size: NonZeroUsize,
125    /// Size of the buffers to communicate with the downlink implementation.
126    pub downlink_buffer_size: NonZeroUsize,
127}
128
129impl Default for DownlinkRuntimeConfig {
130    fn default() -> Self {
131        DownlinkRuntimeConfig {
132            empty_timeout: Duration::from_secs(30),
133            attachment_queue_size: non_zero_usize!(16),
134            abort_on_bad_frames: true,
135            remote_buffer_size: non_zero_usize!(4096),
136            downlink_buffer_size: non_zero_usize!(4096),
137        }
138    }
139}
140
141/// The runtime component for a value type downlink (i.e. value downlink, event downlink, etc.).
142pub struct ValueDownlinkRuntime {
143    requests: mpsc::Receiver<AttachAction>,
144    input: ByteReader,
145    output: ByteWriter,
146    stopping: trigger::Receiver,
147    identity: Uuid,
148    path: RelativeAddress<Text>,
149    config: DownlinkRuntimeConfig,
150}
151
152/// The runtime component for a map type downlink.
153pub struct MapDownlinkRuntime<H, I> {
154    requests: mpsc::Receiver<AttachAction>,
155    input: ByteReader,
156    output: ByteWriter,
157    stopping: trigger::Receiver,
158    identity: Uuid,
159    path: RelativeAddress<Text>,
160    config: DownlinkRuntimeConfig,
161    failure_handler: H,
162    interpretation: I,
163}
164
165async fn await_io_tasks<F1, F2, E>(
166    read: F1,
167    write: F2,
168    kill_switch_tx: trigger::Sender,
169) -> Result<(), E>
170where
171    F1: Future<Output = Result<(), E>>,
172    F2: Future<Output = ()>,
173{
174    let read = pin!(read);
175    let write = pin!(write);
176    let first_finished = select(read, write).await;
177    kill_switch_tx.trigger();
178    match first_finished {
179        Either::Left((read_res, write_fut)) => {
180            write_fut.await;
181            read_res
182        }
183        Either::Right((_, read_fut)) => read_fut.await,
184    }
185}
186
187impl ValueDownlinkRuntime {
188    /// # Arguments
189    /// * `requests` - The channel through which new consumers connect to the runtime.
190    /// * `io` - Byte channels through which messages are received from and sent to the remote lane.
191    /// * `stopping` - Trigger to instruct the runtime to stop.
192    /// * `identity` - The routing ID of this runtime.
193    /// * `path` - The path to the remote lane.
194    /// * `config` - Configuration parameters for the runtime.
195    pub fn new(
196        requests: mpsc::Receiver<AttachAction>,
197        io: Io,
198        stopping: trigger::Receiver,
199        address: IdentifiedAddress,
200        config: DownlinkRuntimeConfig,
201    ) -> Self {
202        let (output, input) = io;
203        let IdentifiedAddress {
204            identity,
205            address: path,
206        } = address;
207        ValueDownlinkRuntime {
208            requests,
209            input,
210            output,
211            stopping,
212            identity,
213            path,
214            config,
215        }
216    }
217
218    /// Run the downlink task.
219    pub async fn run(self) {
220        let ValueDownlinkRuntime {
221            requests,
222            input,
223            output,
224            stopping,
225            identity,
226            path,
227            config,
228        } = self;
229
230        let (producer_tx, producer_rx) = mpsc::channel(config.attachment_queue_size.get());
231        let (consumer_tx, consumer_rx) = mpsc::channel(config.attachment_queue_size.get());
232        let (kill_switch_tx, kill_switch_rx) = trigger::trigger();
233        let (read_vote, write_vote, vote_rx) = crate::timeout_coord::downlink_timeout_coordinator();
234
235        let combined_stop = select(stopping, select(kill_switch_rx, vote_rx));
236        let att = attach_task(requests, producer_tx, consumer_tx, combined_stop)
237            .instrument(info_span!("Value Downlink Runtime Attachment Task", %path));
238        let read = read_task(
239            input,
240            consumer_rx,
241            config,
242            value_interpretation(),
243            InfallibleStrategy,
244            read_vote,
245        )
246        .instrument(info_span!("Value Downlink Runtime Read Task", %path));
247        let write = write_task(
248            output,
249            producer_rx,
250            identity,
251            RelativeAddress::new(path.node.clone(), path.lane.clone()),
252            config,
253            ValueBackpressure::default(),
254            write_vote,
255        )
256        .instrument(info_span!("Value Downlink Runtime Write Task", %path));
257        let io = async move {
258            let read_res = await_io_tasks(read, write, kill_switch_tx).await;
259            if let Err(e) = read_res {
260                match e {}
261            }
262        };
263        join(att, io).await;
264    }
265}
266
267impl<H> MapDownlinkRuntime<H, MapInterpretation> {
268    /// # Arguments
269    /// * `requests` - The channel through which new consumers connect to the runtime.
270    /// * `io` - Byte channels through which messages are received from and sent to the remote lane.
271    /// * `stopping` - Trigger to instruct the runtime to stop.
272    /// * `identity` - The routing ID of this runtime.
273    /// * `path` - The path to the remote lane.
274    /// * `config` - Configuration parameters for the runtime.
275    /// * `failure_handler` - Handler for event frames that do no contain valid map
276    ///    messages.
277    pub fn new(
278        requests: mpsc::Receiver<AttachAction>,
279        io: Io,
280        stopping: trigger::Receiver,
281        address: IdentifiedAddress,
282        config: DownlinkRuntimeConfig,
283        failure_handler: H,
284    ) -> Self {
285        let (output, input) = io;
286        let IdentifiedAddress {
287            identity,
288            address: path,
289        } = address;
290        MapDownlinkRuntime {
291            requests,
292            input,
293            output,
294            stopping,
295            identity,
296            path,
297            config,
298            failure_handler,
299            interpretation: MapInterpretation::default(),
300        }
301    }
302}
303
304impl<I, H> MapDownlinkRuntime<H, I> {
305    /// # Arguments
306    /// * `requests` - The channel through which new consumers connect to the runtime.
307    /// * `io` - Byte channels through which messages are received from and sent to the remote lane.
308    /// * `stopping` - Trigger to instruct the runtime to stop.
309    /// * `identity` - The routing ID of this runtime.
310    /// * `path` - The path to the remote lane.
311    /// * `config` - Configuration parameters for the runtime.
312    /// * `failure_handler` - Handler for event frames that do no contain valid map
313    ///    messages.
314    /// * `interpretation` - A transformation to apply to an incoming event body, before passing it
315    ///    on to the downlink implementation.
316    pub fn with_interpretation(
317        requests: mpsc::Receiver<AttachAction>,
318        io: Io,
319        stopping: trigger::Receiver,
320        address: IdentifiedAddress,
321        config: DownlinkRuntimeConfig,
322        failure_handler: H,
323        interpretation: I,
324    ) -> Self {
325        let (output, input) = io;
326        let IdentifiedAddress {
327            identity,
328            address: path,
329        } = address;
330        MapDownlinkRuntime {
331            requests,
332            input,
333            output,
334            stopping,
335            identity,
336            path,
337            config,
338            failure_handler,
339            interpretation,
340        }
341    }
342}
343
344/// Identity labels for a downlink runtime.
345pub struct IdentifiedAddress {
346    /// The unique routing ID of the downlink.
347    pub identity: Uuid,
348    /// The address to which the downlink is attached.
349    pub address: RelativeAddress<Text>,
350}
351
352impl<I, H> MapDownlinkRuntime<H, I>
353where
354    I: DownlinkInterpretation,
355    H: BadFrameStrategy<I::Error>,
356{
357    /// Run the downlink runtime task.
358    pub async fn run(self) {
359        let MapDownlinkRuntime {
360            requests,
361            input,
362            output,
363            stopping,
364            identity,
365            path,
366            config,
367            failure_handler,
368            interpretation,
369        } = self;
370
371        let (producer_tx, producer_rx) = mpsc::channel(config.attachment_queue_size.get());
372        let (consumer_tx, consumer_rx) = mpsc::channel(config.attachment_queue_size.get());
373        let (read_vote, write_vote, vote_rx) = crate::timeout_coord::downlink_timeout_coordinator();
374        let (kill_switch_tx, kill_switch_rx) = trigger::trigger();
375
376        let combined_stop = select(stopping, select(kill_switch_rx, vote_rx));
377        let att = attach_task(requests, producer_tx, consumer_tx, combined_stop)
378            .instrument(info_span!("Map Downlink Runtime Attachment Task", %path));
379        let read = read_task(
380            input,
381            consumer_rx,
382            config,
383            interpretation,
384            failure_handler,
385            read_vote,
386        )
387        .instrument(info_span!("Map Downlink Runtime Read Task", %path));
388        let write = write_task(
389            output,
390            producer_rx,
391            identity,
392            RelativeAddress::new(path.node.clone(), path.lane.clone()),
393            config,
394            MapBackpressure::default(),
395            write_vote,
396        )
397        .instrument(info_span!("Map Downlink Runtime Write Task", %path));
398        let io = async move {
399            let read_res = await_io_tasks(read, write, kill_switch_tx).await;
400            if let Err(e) = read_res {
401                error!(
402                    "Map downlink received invalid event messages: {problems}",
403                    problems = e
404                );
405            }
406        };
407        join(att, io).await;
408    }
409}
410
411/// Communicates with the read and write tasks to add new consumers.
412async fn attach_task<F>(
413    rx: mpsc::Receiver<AttachAction>,
414    producer_tx: mpsc::Sender<(ByteReader, DownlinkOptions)>,
415    consumer_tx: mpsc::Sender<(ByteWriter, DownlinkOptions)>,
416    combined_stop: F,
417) where
418    F: Future + Unpin,
419{
420    let mut stream = ReceiverStream::new(rx).take_until(combined_stop);
421    while let Some(AttachAction {
422        io: (output, input),
423        options,
424    }) = stream.next().await
425    {
426        if consumer_tx.send((output, options)).await.is_err() {
427            break;
428        }
429        if producer_tx.send((input, options)).await.is_err() {
430            break;
431        }
432    }
433    trace!("Attachment task stopping.");
434}
435
436#[derive(Debug, PartialEq, Eq)]
437enum ReadTaskDlState {
438    Init,
439    Linked,
440    Synced,
441}
442
443/// Sender to communicate with a subscriber to the downlink.
444#[derive(Debug)]
445struct DownlinkSender {
446    sender: FramedWrite<ByteWriter, DownlinkNotificationEncoder>,
447    options: DownlinkOptions,
448}
449
450impl DownlinkSender {
451    fn new(writer: ByteWriter, options: DownlinkOptions) -> Self {
452        DownlinkSender {
453            sender: FramedWrite::new(writer, DownlinkNotificationEncoder),
454            options,
455        }
456    }
457
458    async fn feed(
459        &mut self,
460        message: DownlinkNotification<&BytesMut>,
461    ) -> Result<(), std::io::Error> {
462        self.sender.feed(message).await
463    }
464
465    async fn send(
466        &mut self,
467        message: DownlinkNotification<&BytesMut>,
468    ) -> Result<(), std::io::Error> {
469        self.sender.send(message).await
470    }
471
472    async fn flush(&mut self) -> Result<(), std::io::Error> {
473        flush_sender_notification(&mut self.sender).await
474    }
475}
476
477async fn flush_sender_notification<T>(
478    sender: &mut FramedWrite<ByteWriter, T>,
479) -> Result<(), T::Error>
480where
481    T: Encoder<DownlinkNotification<&'static BytesMut>>,
482{
483    sender.flush().await
484}
485
486async fn flush_sender_req<T>(sender: &mut FramedWrite<ByteWriter, T>) -> Result<(), T::Error>
487where
488    T: Encoder<RawRequestMessage<'static, &'static str>>,
489{
490    sender.flush().await
491}
492
493/// Receiver to receive commands from downlink subscribers.
494struct DownlinkReceiver<D> {
495    receiver: FramedRead<ByteReader, D>,
496    id: u64,
497    terminated: bool,
498}
499
500impl<D: Decoder> DownlinkReceiver<D> {
501    fn new(reader: ByteReader, id: u64, decoder: D) -> Self {
502        DownlinkReceiver {
503            receiver: FramedRead::new(reader, decoder),
504            id,
505            terminated: false,
506        }
507    }
508
509    /// Stops a receiver so that it will be removed from a [`SelectAll`] collection.
510    fn terminate(&mut self) {
511        self.terminated = true;
512    }
513}
514
515struct Failed(u64, Box<dyn Error + 'static>);
516
517impl<D: Decoder> Stream for DownlinkReceiver<D>
518where
519    D::Error: Error + 'static,
520{
521    type Item = Result<D::Item, Failed>;
522
523    fn poll_next(
524        self: std::pin::Pin<&mut Self>,
525        cx: &mut std::task::Context<'_>,
526    ) -> std::task::Poll<Option<Self::Item>> {
527        let this = self.get_mut();
528        if this.terminated {
529            std::task::Poll::Ready(None)
530        } else {
531            this.receiver
532                .poll_next_unpin(cx)
533                .map_err(|e| Failed(this.id, Box::new(e)))
534        }
535    }
536}
537
538enum ReadTaskEvent {
539    Message(ResponseMessage<BytesStr, Bytes, Bytes>),
540    ReadFailed(std::io::Error),
541    MessagesStopped,
542    NewConsumer(ByteWriter, DownlinkOptions),
543    ConsumerChannelStopped,
544    ConsumersTimedOut,
545}
546
547/// Consumes incoming messages from the remote lane and passes them to the consumers.
548async fn read_task<I, H>(
549    input: ByteReader,
550    consumers: mpsc::Receiver<(ByteWriter, DownlinkOptions)>,
551    config: DownlinkRuntimeConfig,
552    mut interpretation: I,
553    mut failure_handler: H,
554    stop_voter: Voter,
555) -> Result<(), H::Report>
556where
557    I: DownlinkInterpretation,
558    H: BadFrameStrategy<I::Error>,
559{
560    let mut messages = FramedRead::new(input, RawResponseMessageDecoder);
561
562    let mut flushed = true;
563    let mut voted = false;
564
565    let mut consumer_stream = ReceiverStream::new(consumers);
566
567    let make_timeout = || tokio::time::sleep_until(Instant::now() + config.empty_timeout);
568    let mut task_state = pin!(Some(make_timeout()));
569    let mut dl_state = ReadTaskDlState::Init;
570    let mut current = BytesMut::new();
571    let mut awaiting_synced: Vec<DownlinkSender> = vec![];
572    let mut awaiting_linked: Vec<DownlinkSender> = vec![];
573    let mut registered: Vec<DownlinkSender> = vec![];
574    // Track whether any events have been received while syncing the downlink. While this isn't the
575    // nicest thing to have, it's required to distinguish between communicating with a stateless
576    // lane and a downlink syncing with a lane that has a type which may be optional.
577    let mut sync_event = false;
578
579    let result: Result<(), H::Report> = loop {
580        let (event, is_active) = match task_state.as_mut().as_pin_mut() {
581            Some(sleep) if !voted => (
582                tokio::select! {
583                    biased;
584                    _ = sleep => {
585                        task_state.set(None);
586                        ReadTaskEvent::ConsumersTimedOut
587                    },
588                    maybe_consumer = consumer_stream.next() => {
589                        if let Some((consumer, options)) = maybe_consumer {
590                            ReadTaskEvent::NewConsumer(consumer, options)
591                        } else {
592                            ReadTaskEvent::ConsumerChannelStopped
593                        }
594                    },
595                    maybe_message = messages.next() => {
596                        match maybe_message {
597                            Some(Ok(msg)) => ReadTaskEvent::Message(msg),
598                            Some(Err(err)) => ReadTaskEvent::ReadFailed(err),
599                            _ => ReadTaskEvent::MessagesStopped,
600                        }
601                    }
602                },
603                false,
604            ),
605            _ => {
606                let get_next = async {
607                    tokio::select! {
608                        maybe_consumer = consumer_stream.next() => {
609                            if let Some((consumer, options)) = maybe_consumer {
610                                ReadTaskEvent::NewConsumer(consumer, options)
611                            } else {
612                                ReadTaskEvent::ConsumerChannelStopped
613                            }
614                        },
615                        maybe_message = messages.next() => {
616                            match maybe_message {
617                                Some(Ok(msg)) => ReadTaskEvent::Message(msg),
618                                Some(Err(err)) => ReadTaskEvent::ReadFailed(err),
619                                _ => ReadTaskEvent::MessagesStopped,
620                            }
621                        }
622                    }
623                };
624                if flushed {
625                    trace!("Waiting without flush.");
626                    (get_next.await, true)
627                } else {
628                    trace!("Waiting with flush.");
629                    let flush = join(flush_all(&mut awaiting_synced), flush_all(&mut registered));
630                    let next_with_flush = immediate_or_join(get_next, flush);
631                    let (next, flush_result) = next_with_flush.await;
632                    let is_active = if flush_result.is_some() {
633                        trace!("Flush completed.");
634                        flushed = true;
635                        if registered.is_empty() && awaiting_synced.is_empty() {
636                            trace!("Number of subscribers dropped to 0.");
637                            task_state.set(Some(make_timeout()));
638                            false
639                        } else {
640                            true
641                        }
642                    } else {
643                        true
644                    };
645                    (next, is_active)
646                }
647            }
648        };
649
650        match event {
651            ReadTaskEvent::ConsumersTimedOut => {
652                info!("No consumers connected within the timeout period. Voting to stop.");
653                if stop_voter.vote() == VoteResult::Unanimous {
654                    // No consumers registered within the timeout and the write task has voted to stop.
655                    break Ok(());
656                } else {
657                    voted = true;
658                }
659            }
660            ReadTaskEvent::Message(ResponseMessage { envelope, .. }) => match envelope {
661                Notification::Linked => {
662                    trace!("Entering Linked state.");
663                    dl_state = ReadTaskDlState::Linked;
664                    if is_active {
665                        link(&mut awaiting_linked, &mut awaiting_synced, &mut registered).await;
666                        if awaiting_synced.is_empty() && registered.is_empty() {
667                            trace!("Number of subscribers dropped to 0.");
668                            task_state.set(Some(make_timeout()));
669                        }
670                    }
671                }
672                Notification::Synced => {
673                    trace!("Entering Synced state.");
674                    dl_state = ReadTaskDlState::Synced;
675                    if is_active {
676                        // `sync_event` will be false if we're communicating with a stateless lane
677                        // as no event envelope will have been sent. However, it's valid Recon for
678                        // an empty event envelope to be sent (consider Option::None) and this must
679                        // still be sent to the downlink task.
680                        //
681                        // If we're linked to a stateless lane, then `sync_current` cannot be used
682                        // as we will not have received an event envelope as it will dispatch one
683                        // with the empty buffer and this may cause the downlink task's decoder to
684                        // fail due to reading an extant read event. Therefore, delegate the operation to
685                        // `sync_only` which will not send an event notification.
686                        if I::SINGLE_FRAME_STATE && sync_event {
687                            sync_current(&mut awaiting_synced, &mut registered, &current).await;
688                        } else {
689                            sync_only(&mut awaiting_synced, &mut registered).await;
690                        }
691                        if registered.is_empty() {
692                            trace!("Number of subscribers dropped to 0.");
693                            task_state.set(Some(make_timeout()));
694                        }
695                    }
696                }
697                Notification::Unlinked(message) => {
698                    trace!("Stopping after unlinked: {msg:?}", msg = message);
699                    break Ok(());
700                }
701                Notification::Event(bytes) => {
702                    sync_event = true;
703
704                    trace!("Updating the current value.");
705                    current.clear();
706
707                    if let Err(e) = interpretation.interpret_frame_data(bytes, &mut current) {
708                        if let BadFrameResponse::Abort(report) = failure_handler.failed_with(e) {
709                            break Err(report);
710                        }
711                    }
712                    if is_active {
713                        send_current(&mut registered, &current).await;
714                        if !I::SINGLE_FRAME_STATE {
715                            send_current(&mut awaiting_synced, &current).await;
716                        }
717                        if registered.is_empty() && awaiting_synced.is_empty() {
718                            trace!("Number of subscribers dropped to 0.");
719                            task_state.set(Some(make_timeout()));
720                            flushed = true;
721                        } else {
722                            flushed = false;
723                        }
724                    }
725                }
726            },
727            ReadTaskEvent::ReadFailed(err) => {
728                error!(
729                    "Failed to read a frame from the input: {error}",
730                    error = err
731                );
732                break Ok(());
733            }
734            ReadTaskEvent::NewConsumer(writer, options) => {
735                let mut dl_writer = DownlinkSender::new(writer, options);
736                let added = if matches!(dl_state, ReadTaskDlState::Init) {
737                    trace!("Attaching a new subscriber to be linked.");
738                    awaiting_linked.push(dl_writer);
739                    true
740                } else if dl_writer.send(DownlinkNotification::Linked).await.is_ok() {
741                    trace!("Attaching a new subscriber to be synced.");
742                    awaiting_synced.push(dl_writer);
743                    true
744                } else {
745                    false
746                };
747                if added {
748                    task_state.set(None);
749                    if voted {
750                        if stop_voter.rescind() == VoteResult::Unanimous {
751                            info!(
752                                "Attempted to rescind stop vote but shutdown had already started."
753                            );
754                            break Ok(());
755                        } else {
756                            voted = false;
757                        }
758                    }
759                }
760            }
761            _ => {
762                trace!("Instructed to stop.");
763                break Ok(());
764            }
765        }
766    };
767
768    trace!("Read task stopping and unlinked all subscribers");
769    unlink(awaiting_linked).await;
770    unlink(awaiting_synced).await;
771    unlink(registered).await;
772    result
773}
774
775async fn sync_current(
776    awaiting_synced: &mut Vec<DownlinkSender>,
777    registered: &mut Vec<DownlinkSender>,
778    current: &BytesMut,
779) {
780    let event = DownlinkNotification::Event { body: current };
781    for mut tx in awaiting_synced.drain(..) {
782        if tx.feed(event).await.is_ok() && tx.send(DownlinkNotification::Synced).await.is_ok() {
783            registered.push(tx);
784        }
785    }
786}
787
788async fn sync_only(
789    awaiting_synced: &mut Vec<DownlinkSender>,
790    registered: &mut Vec<DownlinkSender>,
791) {
792    for mut tx in awaiting_synced.drain(..) {
793        if tx.send(DownlinkNotification::Synced).await.is_ok() {
794            registered.push(tx);
795        }
796    }
797}
798
799async fn send_current(senders: &mut Vec<DownlinkSender>, current: &BytesMut) {
800    let event = DownlinkNotification::Event { body: current };
801    let mut failed = HashSet::<usize>::default();
802    for (i, tx) in senders.iter_mut().enumerate() {
803        if tx.feed(event).await.is_err() {
804            failed.insert(i);
805        }
806    }
807    clear_failed(senders, &failed);
808}
809
810async fn link(
811    awaiting_linked: &mut Vec<DownlinkSender>,
812    awaiting_synced: &mut Vec<DownlinkSender>,
813    registered: &mut Vec<DownlinkSender>,
814) {
815    let event = DownlinkNotification::Linked;
816
817    for mut tx in awaiting_linked.drain(..) {
818        if tx.send(event).await.is_ok() {
819            if tx.options.contains(DownlinkOptions::SYNC) {
820                awaiting_synced.push(tx);
821            } else {
822                registered.push(tx);
823            }
824        }
825    }
826}
827
828async fn unlink(senders: Vec<DownlinkSender>) -> Vec<DownlinkSender> {
829    let mut to_keep = vec![];
830    for mut tx in senders.into_iter() {
831        let still_active = tx.send(DownlinkNotification::Unlinked).await.is_ok();
832        if tx.options.contains(DownlinkOptions::KEEP_LINKED) && still_active {
833            to_keep.push(tx);
834        }
835    }
836    to_keep
837}
838
839fn clear_failed(senders: &mut Vec<DownlinkSender>, failed: &HashSet<usize>) {
840    if !failed.is_empty() {
841        trace!(
842            "Clearing {num_failed} failed subscribers.",
843            num_failed = failed.len()
844        );
845        let mut i = 0;
846        senders.retain(|_| {
847            let keep = !failed.contains(&i);
848            i += 1;
849            keep
850        });
851    }
852}
853
854async fn flush_all(senders: &mut Vec<DownlinkSender>) {
855    let mut failed = HashSet::<usize>::default();
856    for (i, tx) in senders.iter_mut().enumerate() {
857        if tx.flush().await.is_err() {
858            failed.insert(i);
859        }
860    }
861    clear_failed(senders, &failed);
862}
863
/// Writes the outgoing requests (`Link`, `Sync` and commands) of a downlink to its output channel.
#[derive(Debug)]
struct RequestSender {
    /// Framed writer that encodes each request with `RawRequestMessageEncoder` onto the output.
    sender: FramedWrite<ByteWriter, RawRequestMessageEncoder>,
    /// Used as the `origin` of every request that is written.
    identity: Uuid,
    /// Used as the `path` of every request that is written.
    path: RelativeAddress<Text>,
}
870
871impl RequestSender {
872    fn new(writer: ByteWriter, identity: Uuid, path: RelativeAddress<Text>) -> Self {
873        RequestSender {
874            sender: FramedWrite::new(writer, RawRequestMessageEncoder),
875            identity,
876            path,
877        }
878    }
879
880    async fn send_link(&mut self) -> Result<(), std::io::Error> {
881        let RequestSender {
882            sender,
883            identity,
884            path,
885        } = self;
886        let message = RawRequestMessage {
887            origin: *identity,
888            path: path.clone(),
889            envelope: Operation::Link,
890        };
891        sender.send(message).await
892    }
893
894    async fn send_sync(&mut self) -> Result<(), std::io::Error> {
895        let RequestSender {
896            sender,
897            identity,
898            path,
899        } = self;
900        let message = RawRequestMessage {
901            origin: *identity,
902            path: path.clone(),
903            envelope: Operation::Sync,
904        };
905        sender.send(message).await
906    }
907
908    async fn feed_command(&mut self, body: &[u8]) -> Result<(), std::io::Error> {
909        let RequestSender {
910            sender,
911            identity,
912            path,
913        } = self;
914        let message = RawRequestMessage {
915            origin: *identity,
916            path: path.clone(),
917            envelope: Operation::Command(body),
918        };
919        sender.feed(message).await
920    }
921
922    async fn flush(&mut self) -> Result<(), std::io::Error> {
923        flush_sender_req(&mut self.sender).await
924    }
925
926    fn owning_flush(self) -> OwningFlush {
927        OwningFlush::new(self)
928    }
929}
930
/// The internal state of the write task.
enum WriteState<F> {
    /// No writes are currently pending.
    Idle {
        /// The writer for outgoing requests, owned here while no write is in progress.
        message_writer: RequestSender,
        /// Buffer for the body of the next command; it is moved into each pending write and
        /// handed back when that write completes.
        buffer: BytesMut,
    },
    /// A write (or flush) to the outgoing channel is pending. In this state, new commands are
    /// passed to the backpressure implementation, which may cause earlier, unsent updates to be
    /// overwritten.
    Writing(F),
}
942
/// Selects which request a suspended write in the write task issues.
enum WriteKind {
    /// Send a `Sync` request (via `RequestSender::send_sync`).
    Sync,
    /// Feed a command whose body is the contents of the write buffer
    /// (via `RequestSender::feed_command`).
    Data,
}
947
948async fn do_flush(
949    flush: OwningFlush,
950    buffer: BytesMut,
951) -> (Result<RequestSender, std::io::Error>, BytesMut) {
952    let result = flush.await;
953    (result, buffer)
954}
955
956/// Receives commands for the subscribers to the downlink and writes them to the outgoing channel.
957/// If commands are received faster than the channel can send them, some records will be dropped.
958async fn write_task<B: DownlinkBackpressure>(
959    output: ByteWriter,
960    producers: mpsc::Receiver<(ByteReader, DownlinkOptions)>,
961    identity: Uuid,
962    path: RelativeAddress<Text>,
963    config: DownlinkRuntimeConfig,
964    mut backpressure: B,
965    stop_voter: Voter,
966) where
967    <<B as DownlinkBackpressure>::Dec as Decoder>::Error: Error + 'static,
968{
969    let mut message_writer = RequestSender::new(output, identity, path);
970    if message_writer.send_link().await.is_err() {
971        return;
972    }
973
974    let mut state = WriteState::Idle {
975        message_writer,
976        buffer: BytesMut::new(),
977    };
978
979    let mut registered: SelectAll<DownlinkReceiver<B::Dec>> = SelectAll::new();
980    let mut reg_requests = ReceiverStream::new(producers);
981    // Consumers are given unique IDs to allow them to be removed when they fail.
982    let mut id: u64 = 0;
983    let mut next_id = move || {
984        let i = id;
985        id += 1;
986        i
987    };
988
989    let mut task_state = WriteTaskState::INIT;
990
991    let suspend_write = |mut message_writer: RequestSender, buffer: BytesMut, kind: WriteKind| async move {
992        let result = match kind {
993            WriteKind::Data => message_writer.feed_command(buffer.as_ref()).await,
994            WriteKind::Sync => message_writer.send_sync().await,
995        };
996        (result.map(move |_| message_writer), buffer)
997    };
998
999    let mut voted = false;
1000
1001    // The write task state machine has two states and three supplementary flags. The flags indicate the
1002    // following conditions:
1003    //
1004    // Flags
1005    // =====
1006    //
1007    // HAS_DATA:    While a write or flush was occurring, at least one new command has been received and
1008    //              another write needs to be scheduled. This is stored implicitly in the bakpressure
1009    //              implementation.
1010    // FLUSHED:     Indicates that all data written to the output has been flushed.
1011    // NEEDS_SYNC:  While a write or flush was occurring, a new subscriber was added that requested a SYNC
1012    //              message to be sent and this should be sent at the next opportunity.
1013    //
1014    // The state are as follows:
1015    //
1016    // States
1017    // ======
1018    //
1019    // Idle:    No writes are pending. In this state it will wait for new subscribers and outgoing commands
1020    //          from existing subscribers. If the FLUSHED flag is not set, it will attempt to flush the
1021    //          output channel simultaneously.
1022    //          1. If a new subscriber is received and:
1023    //              a. The flush did not start or completed successfully. Then subscriber is added and if it
1024    //                 requests a SYNC, we move to the Writing state for the SYNC message.
1025    //              b. The flush started and did not complete. The NEEDS_SYNC flag is set and we move to the
1026    //                 Flushing state.
1027    //          2. If a command is received and:
1028    //              a. The flush did not start or completed successfully. We move to the Writing state, writing
1029    //                 the command.
1030    //              b. The flush started and did not complete. The command is written into a buffer, the HAS_DATA
1031    //                 flag is set and we move into the Flushing state.
1032    // Writing: A write (of a command or SYNC message) is pending. In this state it will wait for the write to
1033    //          complete and for new subscribers and outgoing commands from existing subscribers.
1034    //          1. If the write completes and:
1035    //              a. The NEEDS_SYNC flag is set. A new write is scheduled for the SYNC and we remain in the
1036    //                 Writing state.
1037    //              b. The HAS_DATA flag is set. A new write is scheduled for the contents of the buffer and we
1038    //                 remain in the Writing state.
1039    //              c. Otherwise we move back to the Idle state.
1040    'outer: loop {
1041        match state {
1042            WriteState::Idle {
1043                mut message_writer,
1044                mut buffer,
1045            } => {
1046                if registered.is_empty() {
1047                    task_state.remove(WriteTaskState::NEEDS_SYNC);
1048                    let req_with_timeout = async {
1049                        if voted {
1050                            Ok(reg_requests.next().await)
1051                        } else {
1052                            timeout(config.empty_timeout, reg_requests.next()).await
1053                        }
1054                    };
1055                    let req_result = if task_state.contains(WriteTaskState::FLUSHED) {
1056                        req_with_timeout.await
1057                    } else {
1058                        let (req_result, flush_result) =
1059                            join(req_with_timeout, message_writer.flush()).await;
1060                        if let Err(err) = flush_result {
1061                            warn!(error = %err, "Flushing the output failed.");
1062                            break 'outer;
1063                        } else {
1064                            task_state |= WriteTaskState::FLUSHED;
1065                        }
1066                        req_result
1067                    };
1068                    match req_result {
1069                        Ok(Some((reader, options))) => {
1070                            if voted {
1071                                if stop_voter.rescind() == VoteResult::Unanimous {
1072                                    info!("Attempted to rescind vote to stop but shutdown had already started.");
1073                                    break 'outer;
1074                                } else {
1075                                    voted = false;
1076                                }
1077                            }
1078                            trace!("Registering new subscriber.");
1079                            let receiver =
1080                                DownlinkReceiver::new(reader, next_id(), B::make_decoder());
1081                            registered.push(receiver);
1082                            if options.contains(DownlinkOptions::SYNC) {
1083                                trace!("Sending a Sync message.");
1084                                let write = suspend_write(message_writer, buffer, WriteKind::Sync);
1085                                state = WriteState::Writing(Either::Left(write));
1086                            } else {
1087                                state = WriteState::Idle {
1088                                    message_writer,
1089                                    buffer,
1090                                };
1091                            }
1092                        }
1093                        Err(_) => {
1094                            if stop_voter.vote() == VoteResult::Unanimous {
1095                                info!("Stopping as no subscribers attached within the timeout and read task voted to stop.");
1096                                break 'outer;
1097                            } else {
1098                                voted = true;
1099                                state = WriteState::Idle {
1100                                    message_writer,
1101                                    buffer,
1102                                };
1103                            }
1104                        }
1105                        _ => {
1106                            info!("Instructed to stop.");
1107                            break 'outer;
1108                        }
1109                    }
1110                } else {
1111                    let next = select(reg_requests.next(), registered.next());
1112                    let (next_op, flush_outcome) = if task_state.contains(WriteTaskState::FLUSHED) {
1113                        (discard(next.await), Either::Left(message_writer))
1114                    } else {
1115                        let (next_op, flush_result) =
1116                            immediate_or_start(next, message_writer.owning_flush()).await;
1117                        let flush_outcome = match flush_result {
1118                            SecondaryResult::NotStarted(of) => Either::Left(of.reclaim()),
1119                            SecondaryResult::Pending(of) => Either::Right(of),
1120                            SecondaryResult::Completed(Ok(sender)) => {
1121                                task_state |= WriteTaskState::FLUSHED;
1122                                Either::Left(sender)
1123                            }
1124                            SecondaryResult::Completed(Err(_)) => {
1125                                warn!("Flushing the output failed.");
1126                                break 'outer;
1127                            }
1128                        };
1129                        (discard(next_op), flush_outcome)
1130                    };
1131                    match next_op {
1132                        Either::Left(Some((reader, options))) => {
1133                            let receiver =
1134                                DownlinkReceiver::new(reader, next_id(), B::make_decoder());
1135                            registered.push(receiver);
1136                            match flush_outcome {
1137                                Either::Left(message_writer) => {
1138                                    if options.contains(DownlinkOptions::SYNC) {
1139                                        trace!("Sending a Sync message.");
1140                                        let write =
1141                                            suspend_write(message_writer, buffer, WriteKind::Sync);
1142                                        state = WriteState::Writing(Either::Left(write));
1143                                    } else {
1144                                        state = WriteState::Idle {
1145                                            message_writer,
1146                                            buffer,
1147                                        };
1148                                    }
1149                                }
1150                                Either::Right(flush) => {
1151                                    trace!("Waiting on the completion of a flush.");
1152                                    task_state.set_needs_sync(options);
1153                                    state =
1154                                        WriteState::Writing(Either::Right(do_flush(flush, buffer)));
1155                                }
1156                            }
1157                        }
1158                        Either::Left(_) => {
1159                            info!("Instructed to stop.");
1160                            break 'outer;
1161                        }
1162                        Either::Right(Some(Ok(op))) => match flush_outcome {
1163                            Either::Left(message_writer) => {
1164                                trace!("Dispatching an event.");
1165                                backpressure.write_direct(op, &mut buffer);
1166                                let write_fut =
1167                                    suspend_write(message_writer, buffer, WriteKind::Data);
1168                                task_state.remove(WriteTaskState::FLUSHED);
1169                                state = WriteState::Writing(Either::Left(write_fut));
1170                            }
1171                            Either::Right(flush) => {
1172                                trace!("Storing an event in the buffer and waiting on a flush to compelte.");
1173                                if let Err(err) = backpressure.push_operation(op) {
1174                                    error!(
1175                                        "Failed to process downlink operaton: {error}",
1176                                        error = err
1177                                    );
1178                                };
1179                                state = WriteState::Writing(Either::Right(do_flush(flush, buffer)));
1180                            }
1181                        },
1182                        Either::Right(ow) => {
1183                            if let Some(Err(Failed(id, error))) = ow {
1184                                trace!(?error, "Removing a failed subscriber");
1185                                if let Some(rx) = registered.iter_mut().find(|rx| rx.id == id) {
1186                                    rx.terminate();
1187                                }
1188                            }
1189                            state = match flush_outcome {
1190                                Either::Left(message_writer) => WriteState::Idle {
1191                                    message_writer,
1192                                    buffer,
1193                                },
1194                                Either::Right(flush) => {
1195                                    trace!("Waiting on the completion of a flush.");
1196                                    WriteState::Writing(Either::Right(do_flush(flush, buffer)))
1197                                }
1198                            };
1199                        }
1200                    }
1201                }
1202            }
1203            WriteState::Writing(write_fut) => {
1204                // It is necessary for the write to be pinned. Rather than putting it into a heap
1205                // allocation, it is instead pinned to the stack and another, inner loop is started
1206                // until the write completes.
1207                let mut write_fut = pin!(write_fut);
1208                'inner: loop {
1209                    let result = if registered.is_empty() {
1210                        task_state.remove(WriteTaskState::NEEDS_SYNC);
1211                        match select(&mut write_fut, reg_requests.next()).await {
1212                            Either::Left((write_result, _)) => {
1213                                SuspendedResult::SuspendedCompleted(write_result)
1214                            }
1215                            Either::Right((request, _)) => {
1216                                SuspendedResult::NewRegistration(request)
1217                            }
1218                        }
1219                    } else {
1220                        await_suspended(&mut write_fut, registered.next(), reg_requests.next())
1221                            .await
1222                    };
1223
1224                    match result {
1225                        SuspendedResult::SuspendedCompleted((result, mut buffer)) => {
1226                            let message_writer = if let Ok(mw) = result {
1227                                mw
1228                            } else {
1229                                warn!("Writing to the output failed.");
1230                                break 'outer;
1231                            };
1232                            if task_state.contains(WriteTaskState::NEEDS_SYNC) {
1233                                trace!("Sending a Sync message.");
1234                                task_state
1235                                    .remove(WriteTaskState::FLUSHED | WriteTaskState::NEEDS_SYNC);
1236                                let write = suspend_write(message_writer, buffer, WriteKind::Sync);
1237                                state = WriteState::Writing(Either::Left(write));
1238                            } else if backpressure.has_data() {
1239                                trace!("Dispatching the updated buffer.");
1240                                backpressure.prepare_write(&mut buffer);
1241                                let write_fut =
1242                                    suspend_write(message_writer, buffer, WriteKind::Data);
1243                                task_state.remove(WriteTaskState::FLUSHED);
1244                                state = WriteState::Writing(Either::Left(write_fut));
1245                            } else {
1246                                trace!("Task has become idle.");
1247                                state = WriteState::Idle {
1248                                    message_writer,
1249                                    buffer,
1250                                };
1251                            }
1252                            // The write has completed so we can return to the outer loop.
1253                            break 'inner;
1254                        }
1255                        SuspendedResult::NextRecord(Some(Ok(op))) => {
1256                            // Writing is currently blocked so overwrite the next value to be sent.
1257                            trace!("Over-writing the current event buffer.");
1258                            if let Err(err) = backpressure.push_operation(op) {
1259                                error!(
1260                                    "Failed to process downlink operation: {error}",
1261                                    error = err
1262                                );
1263                            };
1264                        }
1265                        SuspendedResult::NextRecord(Some(Err(Failed(id, error)))) => {
1266                            trace!(?error, "Removing a failed subscriber.");
1267                            if let Some(rx) = registered.iter_mut().find(|rx| rx.id == id) {
1268                                rx.terminate();
1269                            }
1270                        }
1271                        SuspendedResult::NewRegistration(Some((reader, options))) => {
1272                            trace!("Registering a new subscriber.");
1273                            let receiver =
1274                                DownlinkReceiver::new(reader, next_id(), B::make_decoder());
1275                            registered.push(receiver);
1276                            task_state.set_needs_sync(options);
1277                        }
1278                        SuspendedResult::NewRegistration(_) => {
1279                            info!("Instructed to stop.");
1280                            break 'outer;
1281                        }
1282                        _ => {}
1283                    }
1284                }
1285            }
1286        }
1287    }
1288}
1289
1290use futures::ready;
1291use std::pin::{pin, Pin};
1292use std::task::{Context, Poll};
1293use swimos_utilities::non_zero_usize;
1294
1295use self::failure::{BadFrameStrategy, InfallibleStrategy};
1296use self::interpretation::{value_interpretation, DownlinkInterpretation};
1297use crate::backpressure::{MapBackpressure, ValueBackpressure};
1298
/// A future that flushes a sender and then returns it. This is necessary as we need
/// an [`Unpin`] future so an equivalent async block would not work.
struct OwningFlush {
    /// The sender being flushed. It is taken when the flush completes successfully; after that
    /// the future must not be polled or reclaimed again (both panic). Before completion, the
    /// sender can be recovered with `reclaim`.
    inner: Option<RequestSender>,
}
1304
1305impl OwningFlush {
1306    fn new(sender: RequestSender) -> Self {
1307        OwningFlush {
1308            inner: Some(sender),
1309        }
1310    }
1311
1312    fn reclaim(self) -> RequestSender {
1313        if let Some(sender) = self.inner {
1314            sender
1315        } else {
1316            panic!("OwningFlush reclaimed after complete.");
1317        }
1318    }
1319}
1320
1321impl Future for OwningFlush {
1322    type Output = Result<RequestSender, std::io::Error>;
1323
1324    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1325        let OwningFlush { inner } = self.get_mut();
1326        let result = if let Some(tx) = inner {
1327            ready!(sender_poll_flush(&mut tx.sender, cx))
1328        } else {
1329            panic!("OwningFlush polled after complete.");
1330        };
1331        Poll::Ready(result.map(move |_| inner.take().unwrap()))
1332    }
1333}
1334
1335fn sender_poll_flush<Snk>(sink: &mut Snk, cx: &mut Context<'_>) -> Poll<Result<(), Snk::Error>>
1336where
1337    Snk: Sink<RawRequestMessage<'static, &'static str>> + Unpin,
1338{
1339    sink.poll_flush_unpin(cx)
1340}
1341
1342fn discard<A1, A2, B1, B2>(either: Either<(A1, A2), (B1, B2)>) -> Either<A1, B1> {
1343    match either {
1344        Either::Left((a1, _)) => Either::Left(a1),
1345        Either::Right((b1, _)) => Either::Right(b1),
1346    }
1347}
1348
/// This enum is for clarity only to avoid having nested [`Either`]s in match statements
/// after nesting [`select`] calls.
///
/// It records which of the raced futures completed first (see `await_suspended`).
enum SuspendedResult<A, B, C> {
    /// The suspended future (a pending write or flush in the write task) completed.
    SuspendedCompleted(A),
    /// The next item arrived from the registered subscribers.
    NextRecord(B),
    /// A registration request arrived (in the write task, `None` means the channel closed).
    NewRegistration(C),
}
1356
1357async fn await_suspended<F1, F2, F3>(
1358    suspended: F1,
1359    next_rec: F2,
1360    next_reg: F3,
1361) -> SuspendedResult<F1::Output, F2::Output, F3::Output>
1362where
1363    F1: Future + Unpin,
1364    F2: Future + Unpin,
1365    F3: Future + Unpin,
1366{
1367    let alternatives = select(next_rec, next_reg).map(discard);
1368    match select(suspended, alternatives).await {
1369        Either::Left((r, _)) => SuspendedResult::SuspendedCompleted(r),
1370        Either::Right((Either::Left(r), _)) => SuspendedResult::NextRecord(r),
1371        Either::Right((Either::Right(r), _)) => SuspendedResult::NewRegistration(r),
1372    }
1373}