1use std::collections::HashSet;
16use std::error::Error;
17use std::fmt::Debug;
18use std::num::NonZeroUsize;
19use std::time::Duration;
20
21use crate::downlink::failure::BadFrameResponse;
22use crate::timeout_coord::{VoteResult, Voter};
23use crate::Io;
24use backpressure::DownlinkBackpressure;
25use bitflags::bitflags;
26use bytes::{Bytes, BytesMut};
27use futures::future::{join, select, Either};
28use futures::stream::SelectAll;
29use futures::{Future, FutureExt, Sink, SinkExt, Stream, StreamExt};
30use interpretation::MapInterpretation;
31use swimos_agent_protocol::{
32 encoding::downlink::DownlinkNotificationEncoder, DownlinkNotification,
33};
34use swimos_api::address::RelativeAddress;
35use swimos_messages::protocol::{
36 Notification, Operation, RawRequestMessage, RawRequestMessageEncoder,
37 RawResponseMessageDecoder, ResponseMessage,
38};
39use swimos_model::Text;
40use swimos_utilities::byte_channel::{ByteReader, ByteWriter};
41use swimos_utilities::encoding::BytesStr;
42use swimos_utilities::future::{immediate_or_join, immediate_or_start, SecondaryResult};
43use swimos_utilities::trigger;
44use tokio::sync::mpsc;
45use tokio::time::{timeout, Instant};
46use tokio_stream::wrappers::ReceiverStream;
47use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite};
48use tracing::{error, info, info_span, trace, warn, Instrument};
49use uuid::Uuid;
50
51pub use interpretation::NoInterpretation;
52
53mod backpressure;
54pub mod failure;
56mod interpretation;
57#[cfg(test)]
58mod tests;
59
bitflags! {
    /// Per-consumer options supplied when a consumer is attached to a downlink runtime.
    #[derive(Debug, Copy, Clone)]
    pub struct DownlinkOptions: u8 {
        /// The consumer wants to be synced. When set, the consumer waits for a `Synced`
        /// notification after `Linked`, and the write task issues a `Sync` request on its behalf.
        const SYNC = 0b01;
        /// Consumers with this flag that are still reachable are retained by `unlink` (in its
        /// returned list) rather than dropped when the runtime unlinks.
        const KEEP_LINKED = 0b10;
        /// Both `SYNC` and `KEEP_LINKED`.
        const DEFAULT = Self::SYNC.bits() | Self::KEEP_LINKED.bits();
    }
}
74
bitflags! {
    /// Internal bookkeeping flags for the write task.
    struct WriteTaskState: u8 {
        /// No data has been fed to the output since it was last flushed.
        const FLUSHED = 0b01;
        /// A `Sync` request must be sent once the in-progress write completes (set when a
        /// consumer requesting `SYNC` attaches while the output is busy).
        const NEEDS_SYNC = 0b10;
        /// The initial state: flushed, with no sync pending.
        const INIT = Self::FLUSHED.bits();
    }
}
86
87impl WriteTaskState {
88 pub fn set_needs_sync(&mut self, options: DownlinkOptions) {
93 if options.contains(DownlinkOptions::SYNC) {
94 *self |= WriteTaskState::NEEDS_SYNC;
95 }
96 }
97}
98
/// A request to attach a new consumer to a running downlink runtime.
pub struct AttachAction {
    /// The consumer's byte channels. The writer half is handed to the read task (which writes
    /// notifications to it); the reader half is handed to the write task (which reads
    /// commands from it).
    io: Io,
    /// Options controlling how the consumer is linked/synced.
    options: DownlinkOptions,
}
104
105impl AttachAction {
106 pub fn new(io: Io, options: DownlinkOptions) -> Self {
110 AttachAction { io, options }
111 }
112}
113
/// Configuration parameters for a downlink runtime.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct DownlinkRuntimeConfig {
    /// How long the read and write tasks wait with no consumers attached before voting to stop.
    pub empty_timeout: Duration,
    /// Capacity of the internal channels used to hand new consumers to the read/write tasks.
    pub attachment_queue_size: NonZeroUsize,
    /// NOTE(review): not read in this section; presumably selects whether bad frames abort
    /// the downlink — confirm against the code that builds the failure strategy.
    pub abort_on_bad_frames: bool,
    /// NOTE(review): not read in this section; presumably the buffer size for the channel to
    /// the remote lane — confirm.
    pub remote_buffer_size: NonZeroUsize,
    /// NOTE(review): not read in this section; presumably the buffer size for channels to
    /// attached consumers — confirm.
    pub downlink_buffer_size: NonZeroUsize,
}
128
129impl Default for DownlinkRuntimeConfig {
130 fn default() -> Self {
131 DownlinkRuntimeConfig {
132 empty_timeout: Duration::from_secs(30),
133 attachment_queue_size: non_zero_usize!(16),
134 abort_on_bad_frames: true,
135 remote_buffer_size: non_zero_usize!(4096),
136 downlink_buffer_size: non_zero_usize!(4096),
137 }
138 }
139}
140
/// The runtime task for a value downlink: multiplexes any number of attached consumers onto a
/// single link to a remote lane.
pub struct ValueDownlinkRuntime {
    /// Requests to attach new consumers.
    requests: mpsc::Receiver<AttachAction>,
    /// Responses arriving from the remote lane.
    input: ByteReader,
    /// Requests sent to the remote lane.
    output: ByteWriter,
    /// Fires when the runtime is instructed to stop.
    stopping: trigger::Receiver,
    /// Used as the `origin` of every request sent to the remote lane.
    identity: Uuid,
    /// Address of the remote lane.
    path: RelativeAddress<Text>,
    config: DownlinkRuntimeConfig,
}
151
/// The runtime task for a map downlink: multiplexes any number of attached consumers onto a
/// single link to a remote lane.
///
/// * `H` - Strategy for handling frames that fail interpretation.
/// * `I` - Interpretation applied to incoming event frames.
pub struct MapDownlinkRuntime<H, I> {
    /// Requests to attach new consumers.
    requests: mpsc::Receiver<AttachAction>,
    /// Responses arriving from the remote lane.
    input: ByteReader,
    /// Requests sent to the remote lane.
    output: ByteWriter,
    /// Fires when the runtime is instructed to stop.
    stopping: trigger::Receiver,
    /// Used as the `origin` of every request sent to the remote lane.
    identity: Uuid,
    /// Address of the remote lane.
    path: RelativeAddress<Text>,
    config: DownlinkRuntimeConfig,
    /// Consulted by the read task when an event frame cannot be interpreted.
    failure_handler: H,
    /// Converts event frame bodies before they are forwarded to consumers.
    interpretation: I,
}
164
165async fn await_io_tasks<F1, F2, E>(
166 read: F1,
167 write: F2,
168 kill_switch_tx: trigger::Sender,
169) -> Result<(), E>
170where
171 F1: Future<Output = Result<(), E>>,
172 F2: Future<Output = ()>,
173{
174 let read = pin!(read);
175 let write = pin!(write);
176 let first_finished = select(read, write).await;
177 kill_switch_tx.trigger();
178 match first_finished {
179 Either::Left((read_res, write_fut)) => {
180 write_fut.await;
181 read_res
182 }
183 Either::Right((_, read_fut)) => read_fut.await,
184 }
185}
186
187impl ValueDownlinkRuntime {
188 pub fn new(
196 requests: mpsc::Receiver<AttachAction>,
197 io: Io,
198 stopping: trigger::Receiver,
199 address: IdentifiedAddress,
200 config: DownlinkRuntimeConfig,
201 ) -> Self {
202 let (output, input) = io;
203 let IdentifiedAddress {
204 identity,
205 address: path,
206 } = address;
207 ValueDownlinkRuntime {
208 requests,
209 input,
210 output,
211 stopping,
212 identity,
213 path,
214 config,
215 }
216 }
217
218 pub async fn run(self) {
220 let ValueDownlinkRuntime {
221 requests,
222 input,
223 output,
224 stopping,
225 identity,
226 path,
227 config,
228 } = self;
229
230 let (producer_tx, producer_rx) = mpsc::channel(config.attachment_queue_size.get());
231 let (consumer_tx, consumer_rx) = mpsc::channel(config.attachment_queue_size.get());
232 let (kill_switch_tx, kill_switch_rx) = trigger::trigger();
233 let (read_vote, write_vote, vote_rx) = crate::timeout_coord::downlink_timeout_coordinator();
234
235 let combined_stop = select(stopping, select(kill_switch_rx, vote_rx));
236 let att = attach_task(requests, producer_tx, consumer_tx, combined_stop)
237 .instrument(info_span!("Value Downlink Runtime Attachment Task", %path));
238 let read = read_task(
239 input,
240 consumer_rx,
241 config,
242 value_interpretation(),
243 InfallibleStrategy,
244 read_vote,
245 )
246 .instrument(info_span!("Value Downlink Runtime Read Task", %path));
247 let write = write_task(
248 output,
249 producer_rx,
250 identity,
251 RelativeAddress::new(path.node.clone(), path.lane.clone()),
252 config,
253 ValueBackpressure::default(),
254 write_vote,
255 )
256 .instrument(info_span!("Value Downlink Runtime Write Task", %path));
257 let io = async move {
258 let read_res = await_io_tasks(read, write, kill_switch_tx).await;
259 if let Err(e) = read_res {
260 match e {}
261 }
262 };
263 join(att, io).await;
264 }
265}
266
267impl<H> MapDownlinkRuntime<H, MapInterpretation> {
268 pub fn new(
278 requests: mpsc::Receiver<AttachAction>,
279 io: Io,
280 stopping: trigger::Receiver,
281 address: IdentifiedAddress,
282 config: DownlinkRuntimeConfig,
283 failure_handler: H,
284 ) -> Self {
285 let (output, input) = io;
286 let IdentifiedAddress {
287 identity,
288 address: path,
289 } = address;
290 MapDownlinkRuntime {
291 requests,
292 input,
293 output,
294 stopping,
295 identity,
296 path,
297 config,
298 failure_handler,
299 interpretation: MapInterpretation::default(),
300 }
301 }
302}
303
304impl<I, H> MapDownlinkRuntime<H, I> {
305 pub fn with_interpretation(
317 requests: mpsc::Receiver<AttachAction>,
318 io: Io,
319 stopping: trigger::Receiver,
320 address: IdentifiedAddress,
321 config: DownlinkRuntimeConfig,
322 failure_handler: H,
323 interpretation: I,
324 ) -> Self {
325 let (output, input) = io;
326 let IdentifiedAddress {
327 identity,
328 address: path,
329 } = address;
330 MapDownlinkRuntime {
331 requests,
332 input,
333 output,
334 stopping,
335 identity,
336 path,
337 config,
338 failure_handler,
339 interpretation,
340 }
341 }
342}
343
/// The identity of a downlink runtime together with the address of the lane it links to.
pub struct IdentifiedAddress {
    /// Used as the `origin` of requests the runtime sends.
    pub identity: Uuid,
    /// The remote lane address.
    pub address: RelativeAddress<Text>,
}
351
352impl<I, H> MapDownlinkRuntime<H, I>
353where
354 I: DownlinkInterpretation,
355 H: BadFrameStrategy<I::Error>,
356{
357 pub async fn run(self) {
359 let MapDownlinkRuntime {
360 requests,
361 input,
362 output,
363 stopping,
364 identity,
365 path,
366 config,
367 failure_handler,
368 interpretation,
369 } = self;
370
371 let (producer_tx, producer_rx) = mpsc::channel(config.attachment_queue_size.get());
372 let (consumer_tx, consumer_rx) = mpsc::channel(config.attachment_queue_size.get());
373 let (read_vote, write_vote, vote_rx) = crate::timeout_coord::downlink_timeout_coordinator();
374 let (kill_switch_tx, kill_switch_rx) = trigger::trigger();
375
376 let combined_stop = select(stopping, select(kill_switch_rx, vote_rx));
377 let att = attach_task(requests, producer_tx, consumer_tx, combined_stop)
378 .instrument(info_span!("Map Downlink Runtime Attachment Task", %path));
379 let read = read_task(
380 input,
381 consumer_rx,
382 config,
383 interpretation,
384 failure_handler,
385 read_vote,
386 )
387 .instrument(info_span!("Map Downlink Runtime Read Task", %path));
388 let write = write_task(
389 output,
390 producer_rx,
391 identity,
392 RelativeAddress::new(path.node.clone(), path.lane.clone()),
393 config,
394 MapBackpressure::default(),
395 write_vote,
396 )
397 .instrument(info_span!("Map Downlink Runtime Write Task", %path));
398 let io = async move {
399 let read_res = await_io_tasks(read, write, kill_switch_tx).await;
400 if let Err(e) = read_res {
401 error!(
402 "Map downlink received invalid event messages: {problems}",
403 problems = e
404 );
405 }
406 };
407 join(att, io).await;
408 }
409}
410
411async fn attach_task<F>(
413 rx: mpsc::Receiver<AttachAction>,
414 producer_tx: mpsc::Sender<(ByteReader, DownlinkOptions)>,
415 consumer_tx: mpsc::Sender<(ByteWriter, DownlinkOptions)>,
416 combined_stop: F,
417) where
418 F: Future + Unpin,
419{
420 let mut stream = ReceiverStream::new(rx).take_until(combined_stop);
421 while let Some(AttachAction {
422 io: (output, input),
423 options,
424 }) = stream.next().await
425 {
426 if consumer_tx.send((output, options)).await.is_err() {
427 break;
428 }
429 if producer_tx.send((input, options)).await.is_err() {
430 break;
431 }
432 }
433 trace!("Attachment task stopping.");
434}
435
/// The link state of the read task, as reported by the remote lane.
#[derive(Debug, PartialEq, Eq)]
enum ReadTaskDlState {
    /// No `Linked` notification has been received yet.
    Init,
    /// A `Linked` notification has been received.
    Linked,
    /// A `Synced` notification has been received.
    Synced,
}
442
#[derive(Debug)]
/// The read task's handle for writing notifications to one attached consumer.
struct DownlinkSender {
    sender: FramedWrite<ByteWriter, DownlinkNotificationEncoder>,
    /// The options the consumer attached with.
    options: DownlinkOptions,
}
449
450impl DownlinkSender {
451 fn new(writer: ByteWriter, options: DownlinkOptions) -> Self {
452 DownlinkSender {
453 sender: FramedWrite::new(writer, DownlinkNotificationEncoder),
454 options,
455 }
456 }
457
458 async fn feed(
459 &mut self,
460 message: DownlinkNotification<&BytesMut>,
461 ) -> Result<(), std::io::Error> {
462 self.sender.feed(message).await
463 }
464
465 async fn send(
466 &mut self,
467 message: DownlinkNotification<&BytesMut>,
468 ) -> Result<(), std::io::Error> {
469 self.sender.send(message).await
470 }
471
472 async fn flush(&mut self) -> Result<(), std::io::Error> {
473 flush_sender_notification(&mut self.sender).await
474 }
475}
476
477async fn flush_sender_notification<T>(
478 sender: &mut FramedWrite<ByteWriter, T>,
479) -> Result<(), T::Error>
480where
481 T: Encoder<DownlinkNotification<&'static BytesMut>>,
482{
483 sender.flush().await
484}
485
486async fn flush_sender_req<T>(sender: &mut FramedWrite<ByteWriter, T>) -> Result<(), T::Error>
487where
488 T: Encoder<RawRequestMessage<'static, &'static str>>,
489{
490 sender.flush().await
491}
492
/// The write task's handle for reading commands from one attached consumer.
struct DownlinkReceiver<D> {
    receiver: FramedRead<ByteReader, D>,
    /// Unique ID used to find this receiver again if it fails.
    id: u64,
    /// When set, the stream yields `None` without polling `receiver`.
    terminated: bool,
}
499
500impl<D: Decoder> DownlinkReceiver<D> {
501 fn new(reader: ByteReader, id: u64, decoder: D) -> Self {
502 DownlinkReceiver {
503 receiver: FramedRead::new(reader, decoder),
504 id,
505 terminated: false,
506 }
507 }
508
509 fn terminate(&mut self) {
511 self.terminated = true;
512 }
513}
514
/// A decoding failure from the `DownlinkReceiver` whose ID is the first field.
struct Failed(u64, Box<dyn Error + 'static>);
516
517impl<D: Decoder> Stream for DownlinkReceiver<D>
518where
519 D::Error: Error + 'static,
520{
521 type Item = Result<D::Item, Failed>;
522
523 fn poll_next(
524 self: std::pin::Pin<&mut Self>,
525 cx: &mut std::task::Context<'_>,
526 ) -> std::task::Poll<Option<Self::Item>> {
527 let this = self.get_mut();
528 if this.terminated {
529 std::task::Poll::Ready(None)
530 } else {
531 this.receiver
532 .poll_next_unpin(cx)
533 .map_err(|e| Failed(this.id, Box::new(e)))
534 }
535 }
536}
537
/// Events the read task reacts to.
enum ReadTaskEvent {
    /// A response from the remote lane.
    Message(ResponseMessage<BytesStr, Bytes, Bytes>),
    /// Reading a frame from the remote lane failed.
    ReadFailed(std::io::Error),
    /// The stream of responses from the remote lane ended.
    MessagesStopped,
    /// A new consumer was attached.
    NewConsumer(ByteWriter, DownlinkOptions),
    /// The channel supplying new consumers was closed.
    ConsumerChannelStopped,
    /// No consumers were attached before the empty timeout elapsed.
    ConsumersTimedOut,
}
546
547async fn read_task<I, H>(
549 input: ByteReader,
550 consumers: mpsc::Receiver<(ByteWriter, DownlinkOptions)>,
551 config: DownlinkRuntimeConfig,
552 mut interpretation: I,
553 mut failure_handler: H,
554 stop_voter: Voter,
555) -> Result<(), H::Report>
556where
557 I: DownlinkInterpretation,
558 H: BadFrameStrategy<I::Error>,
559{
560 let mut messages = FramedRead::new(input, RawResponseMessageDecoder);
561
562 let mut flushed = true;
563 let mut voted = false;
564
565 let mut consumer_stream = ReceiverStream::new(consumers);
566
567 let make_timeout = || tokio::time::sleep_until(Instant::now() + config.empty_timeout);
568 let mut task_state = pin!(Some(make_timeout()));
569 let mut dl_state = ReadTaskDlState::Init;
570 let mut current = BytesMut::new();
571 let mut awaiting_synced: Vec<DownlinkSender> = vec![];
572 let mut awaiting_linked: Vec<DownlinkSender> = vec![];
573 let mut registered: Vec<DownlinkSender> = vec![];
574 let mut sync_event = false;
578
579 let result: Result<(), H::Report> = loop {
580 let (event, is_active) = match task_state.as_mut().as_pin_mut() {
581 Some(sleep) if !voted => (
582 tokio::select! {
583 biased;
584 _ = sleep => {
585 task_state.set(None);
586 ReadTaskEvent::ConsumersTimedOut
587 },
588 maybe_consumer = consumer_stream.next() => {
589 if let Some((consumer, options)) = maybe_consumer {
590 ReadTaskEvent::NewConsumer(consumer, options)
591 } else {
592 ReadTaskEvent::ConsumerChannelStopped
593 }
594 },
595 maybe_message = messages.next() => {
596 match maybe_message {
597 Some(Ok(msg)) => ReadTaskEvent::Message(msg),
598 Some(Err(err)) => ReadTaskEvent::ReadFailed(err),
599 _ => ReadTaskEvent::MessagesStopped,
600 }
601 }
602 },
603 false,
604 ),
605 _ => {
606 let get_next = async {
607 tokio::select! {
608 maybe_consumer = consumer_stream.next() => {
609 if let Some((consumer, options)) = maybe_consumer {
610 ReadTaskEvent::NewConsumer(consumer, options)
611 } else {
612 ReadTaskEvent::ConsumerChannelStopped
613 }
614 },
615 maybe_message = messages.next() => {
616 match maybe_message {
617 Some(Ok(msg)) => ReadTaskEvent::Message(msg),
618 Some(Err(err)) => ReadTaskEvent::ReadFailed(err),
619 _ => ReadTaskEvent::MessagesStopped,
620 }
621 }
622 }
623 };
624 if flushed {
625 trace!("Waiting without flush.");
626 (get_next.await, true)
627 } else {
628 trace!("Waiting with flush.");
629 let flush = join(flush_all(&mut awaiting_synced), flush_all(&mut registered));
630 let next_with_flush = immediate_or_join(get_next, flush);
631 let (next, flush_result) = next_with_flush.await;
632 let is_active = if flush_result.is_some() {
633 trace!("Flush completed.");
634 flushed = true;
635 if registered.is_empty() && awaiting_synced.is_empty() {
636 trace!("Number of subscribers dropped to 0.");
637 task_state.set(Some(make_timeout()));
638 false
639 } else {
640 true
641 }
642 } else {
643 true
644 };
645 (next, is_active)
646 }
647 }
648 };
649
650 match event {
651 ReadTaskEvent::ConsumersTimedOut => {
652 info!("No consumers connected within the timeout period. Voting to stop.");
653 if stop_voter.vote() == VoteResult::Unanimous {
654 break Ok(());
656 } else {
657 voted = true;
658 }
659 }
660 ReadTaskEvent::Message(ResponseMessage { envelope, .. }) => match envelope {
661 Notification::Linked => {
662 trace!("Entering Linked state.");
663 dl_state = ReadTaskDlState::Linked;
664 if is_active {
665 link(&mut awaiting_linked, &mut awaiting_synced, &mut registered).await;
666 if awaiting_synced.is_empty() && registered.is_empty() {
667 trace!("Number of subscribers dropped to 0.");
668 task_state.set(Some(make_timeout()));
669 }
670 }
671 }
672 Notification::Synced => {
673 trace!("Entering Synced state.");
674 dl_state = ReadTaskDlState::Synced;
675 if is_active {
676 if I::SINGLE_FRAME_STATE && sync_event {
687 sync_current(&mut awaiting_synced, &mut registered, ¤t).await;
688 } else {
689 sync_only(&mut awaiting_synced, &mut registered).await;
690 }
691 if registered.is_empty() {
692 trace!("Number of subscribers dropped to 0.");
693 task_state.set(Some(make_timeout()));
694 }
695 }
696 }
697 Notification::Unlinked(message) => {
698 trace!("Stopping after unlinked: {msg:?}", msg = message);
699 break Ok(());
700 }
701 Notification::Event(bytes) => {
702 sync_event = true;
703
704 trace!("Updating the current value.");
705 current.clear();
706
707 if let Err(e) = interpretation.interpret_frame_data(bytes, &mut current) {
708 if let BadFrameResponse::Abort(report) = failure_handler.failed_with(e) {
709 break Err(report);
710 }
711 }
712 if is_active {
713 send_current(&mut registered, ¤t).await;
714 if !I::SINGLE_FRAME_STATE {
715 send_current(&mut awaiting_synced, ¤t).await;
716 }
717 if registered.is_empty() && awaiting_synced.is_empty() {
718 trace!("Number of subscribers dropped to 0.");
719 task_state.set(Some(make_timeout()));
720 flushed = true;
721 } else {
722 flushed = false;
723 }
724 }
725 }
726 },
727 ReadTaskEvent::ReadFailed(err) => {
728 error!(
729 "Failed to read a frame from the input: {error}",
730 error = err
731 );
732 break Ok(());
733 }
734 ReadTaskEvent::NewConsumer(writer, options) => {
735 let mut dl_writer = DownlinkSender::new(writer, options);
736 let added = if matches!(dl_state, ReadTaskDlState::Init) {
737 trace!("Attaching a new subscriber to be linked.");
738 awaiting_linked.push(dl_writer);
739 true
740 } else if dl_writer.send(DownlinkNotification::Linked).await.is_ok() {
741 trace!("Attaching a new subscriber to be synced.");
742 awaiting_synced.push(dl_writer);
743 true
744 } else {
745 false
746 };
747 if added {
748 task_state.set(None);
749 if voted {
750 if stop_voter.rescind() == VoteResult::Unanimous {
751 info!(
752 "Attempted to rescind stop vote but shutdown had already started."
753 );
754 break Ok(());
755 } else {
756 voted = false;
757 }
758 }
759 }
760 }
761 _ => {
762 trace!("Instructed to stop.");
763 break Ok(());
764 }
765 }
766 };
767
768 trace!("Read task stopping and unlinked all subscribers");
769 unlink(awaiting_linked).await;
770 unlink(awaiting_synced).await;
771 unlink(registered).await;
772 result
773}
774
775async fn sync_current(
776 awaiting_synced: &mut Vec<DownlinkSender>,
777 registered: &mut Vec<DownlinkSender>,
778 current: &BytesMut,
779) {
780 let event = DownlinkNotification::Event { body: current };
781 for mut tx in awaiting_synced.drain(..) {
782 if tx.feed(event).await.is_ok() && tx.send(DownlinkNotification::Synced).await.is_ok() {
783 registered.push(tx);
784 }
785 }
786}
787
788async fn sync_only(
789 awaiting_synced: &mut Vec<DownlinkSender>,
790 registered: &mut Vec<DownlinkSender>,
791) {
792 for mut tx in awaiting_synced.drain(..) {
793 if tx.send(DownlinkNotification::Synced).await.is_ok() {
794 registered.push(tx);
795 }
796 }
797}
798
799async fn send_current(senders: &mut Vec<DownlinkSender>, current: &BytesMut) {
800 let event = DownlinkNotification::Event { body: current };
801 let mut failed = HashSet::<usize>::default();
802 for (i, tx) in senders.iter_mut().enumerate() {
803 if tx.feed(event).await.is_err() {
804 failed.insert(i);
805 }
806 }
807 clear_failed(senders, &failed);
808}
809
810async fn link(
811 awaiting_linked: &mut Vec<DownlinkSender>,
812 awaiting_synced: &mut Vec<DownlinkSender>,
813 registered: &mut Vec<DownlinkSender>,
814) {
815 let event = DownlinkNotification::Linked;
816
817 for mut tx in awaiting_linked.drain(..) {
818 if tx.send(event).await.is_ok() {
819 if tx.options.contains(DownlinkOptions::SYNC) {
820 awaiting_synced.push(tx);
821 } else {
822 registered.push(tx);
823 }
824 }
825 }
826}
827
828async fn unlink(senders: Vec<DownlinkSender>) -> Vec<DownlinkSender> {
829 let mut to_keep = vec![];
830 for mut tx in senders.into_iter() {
831 let still_active = tx.send(DownlinkNotification::Unlinked).await.is_ok();
832 if tx.options.contains(DownlinkOptions::KEEP_LINKED) && still_active {
833 to_keep.push(tx);
834 }
835 }
836 to_keep
837}
838
839fn clear_failed(senders: &mut Vec<DownlinkSender>, failed: &HashSet<usize>) {
840 if !failed.is_empty() {
841 trace!(
842 "Clearing {num_failed} failed subscribers.",
843 num_failed = failed.len()
844 );
845 let mut i = 0;
846 senders.retain(|_| {
847 let keep = !failed.contains(&i);
848 i += 1;
849 keep
850 });
851 }
852}
853
854async fn flush_all(senders: &mut Vec<DownlinkSender>) {
855 let mut failed = HashSet::<usize>::default();
856 for (i, tx) in senders.iter_mut().enumerate() {
857 if tx.flush().await.is_err() {
858 failed.insert(i);
859 }
860 }
861 clear_failed(senders, &failed);
862}
863
/// Writes requests to the remote lane on behalf of the write task.
#[derive(Debug)]
struct RequestSender {
    sender: FramedWrite<ByteWriter, RawRequestMessageEncoder>,
    /// Placed in the `origin` of every request.
    identity: Uuid,
    /// Placed in the `path` of every request.
    path: RelativeAddress<Text>,
}
870
871impl RequestSender {
872 fn new(writer: ByteWriter, identity: Uuid, path: RelativeAddress<Text>) -> Self {
873 RequestSender {
874 sender: FramedWrite::new(writer, RawRequestMessageEncoder),
875 identity,
876 path,
877 }
878 }
879
880 async fn send_link(&mut self) -> Result<(), std::io::Error> {
881 let RequestSender {
882 sender,
883 identity,
884 path,
885 } = self;
886 let message = RawRequestMessage {
887 origin: *identity,
888 path: path.clone(),
889 envelope: Operation::Link,
890 };
891 sender.send(message).await
892 }
893
894 async fn send_sync(&mut self) -> Result<(), std::io::Error> {
895 let RequestSender {
896 sender,
897 identity,
898 path,
899 } = self;
900 let message = RawRequestMessage {
901 origin: *identity,
902 path: path.clone(),
903 envelope: Operation::Sync,
904 };
905 sender.send(message).await
906 }
907
908 async fn feed_command(&mut self, body: &[u8]) -> Result<(), std::io::Error> {
909 let RequestSender {
910 sender,
911 identity,
912 path,
913 } = self;
914 let message = RawRequestMessage {
915 origin: *identity,
916 path: path.clone(),
917 envelope: Operation::Command(body),
918 };
919 sender.feed(message).await
920 }
921
922 async fn flush(&mut self) -> Result<(), std::io::Error> {
923 flush_sender_req(&mut self.sender).await
924 }
925
926 fn owning_flush(self) -> OwningFlush {
927 OwningFlush::new(self)
928 }
929}
930
/// State of the write task's output.
enum WriteState<F> {
    /// No write is in progress; the writer and the reusable buffer are available.
    Idle {
        message_writer: RequestSender,
        buffer: BytesMut,
    },
    /// A write or flush (`F`) is in progress; it yields the writer and buffer back.
    Writing(F),
}
942
/// The kind of request a suspended write sends.
enum WriteKind {
    /// A `Sync` request.
    Sync,
    /// A `Command` carrying the contents of the buffer.
    Data,
}
947
948async fn do_flush(
949 flush: OwningFlush,
950 buffer: BytesMut,
951) -> (Result<RequestSender, std::io::Error>, BytesMut) {
952 let result = flush.await;
953 (result, buffer)
954}
955
956async fn write_task<B: DownlinkBackpressure>(
959 output: ByteWriter,
960 producers: mpsc::Receiver<(ByteReader, DownlinkOptions)>,
961 identity: Uuid,
962 path: RelativeAddress<Text>,
963 config: DownlinkRuntimeConfig,
964 mut backpressure: B,
965 stop_voter: Voter,
966) where
967 <<B as DownlinkBackpressure>::Dec as Decoder>::Error: Error + 'static,
968{
969 let mut message_writer = RequestSender::new(output, identity, path);
970 if message_writer.send_link().await.is_err() {
971 return;
972 }
973
974 let mut state = WriteState::Idle {
975 message_writer,
976 buffer: BytesMut::new(),
977 };
978
979 let mut registered: SelectAll<DownlinkReceiver<B::Dec>> = SelectAll::new();
980 let mut reg_requests = ReceiverStream::new(producers);
981 let mut id: u64 = 0;
983 let mut next_id = move || {
984 let i = id;
985 id += 1;
986 i
987 };
988
989 let mut task_state = WriteTaskState::INIT;
990
991 let suspend_write = |mut message_writer: RequestSender, buffer: BytesMut, kind: WriteKind| async move {
992 let result = match kind {
993 WriteKind::Data => message_writer.feed_command(buffer.as_ref()).await,
994 WriteKind::Sync => message_writer.send_sync().await,
995 };
996 (result.map(move |_| message_writer), buffer)
997 };
998
999 let mut voted = false;
1000
1001 'outer: loop {
1041 match state {
1042 WriteState::Idle {
1043 mut message_writer,
1044 mut buffer,
1045 } => {
1046 if registered.is_empty() {
1047 task_state.remove(WriteTaskState::NEEDS_SYNC);
1048 let req_with_timeout = async {
1049 if voted {
1050 Ok(reg_requests.next().await)
1051 } else {
1052 timeout(config.empty_timeout, reg_requests.next()).await
1053 }
1054 };
1055 let req_result = if task_state.contains(WriteTaskState::FLUSHED) {
1056 req_with_timeout.await
1057 } else {
1058 let (req_result, flush_result) =
1059 join(req_with_timeout, message_writer.flush()).await;
1060 if let Err(err) = flush_result {
1061 warn!(error = %err, "Flushing the output failed.");
1062 break 'outer;
1063 } else {
1064 task_state |= WriteTaskState::FLUSHED;
1065 }
1066 req_result
1067 };
1068 match req_result {
1069 Ok(Some((reader, options))) => {
1070 if voted {
1071 if stop_voter.rescind() == VoteResult::Unanimous {
1072 info!("Attempted to rescind vote to stop but shutdown had already started.");
1073 break 'outer;
1074 } else {
1075 voted = false;
1076 }
1077 }
1078 trace!("Registering new subscriber.");
1079 let receiver =
1080 DownlinkReceiver::new(reader, next_id(), B::make_decoder());
1081 registered.push(receiver);
1082 if options.contains(DownlinkOptions::SYNC) {
1083 trace!("Sending a Sync message.");
1084 let write = suspend_write(message_writer, buffer, WriteKind::Sync);
1085 state = WriteState::Writing(Either::Left(write));
1086 } else {
1087 state = WriteState::Idle {
1088 message_writer,
1089 buffer,
1090 };
1091 }
1092 }
1093 Err(_) => {
1094 if stop_voter.vote() == VoteResult::Unanimous {
1095 info!("Stopping as no subscribers attached within the timeout and read task voted to stop.");
1096 break 'outer;
1097 } else {
1098 voted = true;
1099 state = WriteState::Idle {
1100 message_writer,
1101 buffer,
1102 };
1103 }
1104 }
1105 _ => {
1106 info!("Instructed to stop.");
1107 break 'outer;
1108 }
1109 }
1110 } else {
1111 let next = select(reg_requests.next(), registered.next());
1112 let (next_op, flush_outcome) = if task_state.contains(WriteTaskState::FLUSHED) {
1113 (discard(next.await), Either::Left(message_writer))
1114 } else {
1115 let (next_op, flush_result) =
1116 immediate_or_start(next, message_writer.owning_flush()).await;
1117 let flush_outcome = match flush_result {
1118 SecondaryResult::NotStarted(of) => Either::Left(of.reclaim()),
1119 SecondaryResult::Pending(of) => Either::Right(of),
1120 SecondaryResult::Completed(Ok(sender)) => {
1121 task_state |= WriteTaskState::FLUSHED;
1122 Either::Left(sender)
1123 }
1124 SecondaryResult::Completed(Err(_)) => {
1125 warn!("Flushing the output failed.");
1126 break 'outer;
1127 }
1128 };
1129 (discard(next_op), flush_outcome)
1130 };
1131 match next_op {
1132 Either::Left(Some((reader, options))) => {
1133 let receiver =
1134 DownlinkReceiver::new(reader, next_id(), B::make_decoder());
1135 registered.push(receiver);
1136 match flush_outcome {
1137 Either::Left(message_writer) => {
1138 if options.contains(DownlinkOptions::SYNC) {
1139 trace!("Sending a Sync message.");
1140 let write =
1141 suspend_write(message_writer, buffer, WriteKind::Sync);
1142 state = WriteState::Writing(Either::Left(write));
1143 } else {
1144 state = WriteState::Idle {
1145 message_writer,
1146 buffer,
1147 };
1148 }
1149 }
1150 Either::Right(flush) => {
1151 trace!("Waiting on the completion of a flush.");
1152 task_state.set_needs_sync(options);
1153 state =
1154 WriteState::Writing(Either::Right(do_flush(flush, buffer)));
1155 }
1156 }
1157 }
1158 Either::Left(_) => {
1159 info!("Instructed to stop.");
1160 break 'outer;
1161 }
1162 Either::Right(Some(Ok(op))) => match flush_outcome {
1163 Either::Left(message_writer) => {
1164 trace!("Dispatching an event.");
1165 backpressure.write_direct(op, &mut buffer);
1166 let write_fut =
1167 suspend_write(message_writer, buffer, WriteKind::Data);
1168 task_state.remove(WriteTaskState::FLUSHED);
1169 state = WriteState::Writing(Either::Left(write_fut));
1170 }
1171 Either::Right(flush) => {
1172 trace!("Storing an event in the buffer and waiting on a flush to compelte.");
1173 if let Err(err) = backpressure.push_operation(op) {
1174 error!(
1175 "Failed to process downlink operaton: {error}",
1176 error = err
1177 );
1178 };
1179 state = WriteState::Writing(Either::Right(do_flush(flush, buffer)));
1180 }
1181 },
1182 Either::Right(ow) => {
1183 if let Some(Err(Failed(id, error))) = ow {
1184 trace!(?error, "Removing a failed subscriber");
1185 if let Some(rx) = registered.iter_mut().find(|rx| rx.id == id) {
1186 rx.terminate();
1187 }
1188 }
1189 state = match flush_outcome {
1190 Either::Left(message_writer) => WriteState::Idle {
1191 message_writer,
1192 buffer,
1193 },
1194 Either::Right(flush) => {
1195 trace!("Waiting on the completion of a flush.");
1196 WriteState::Writing(Either::Right(do_flush(flush, buffer)))
1197 }
1198 };
1199 }
1200 }
1201 }
1202 }
1203 WriteState::Writing(write_fut) => {
1204 let mut write_fut = pin!(write_fut);
1208 'inner: loop {
1209 let result = if registered.is_empty() {
1210 task_state.remove(WriteTaskState::NEEDS_SYNC);
1211 match select(&mut write_fut, reg_requests.next()).await {
1212 Either::Left((write_result, _)) => {
1213 SuspendedResult::SuspendedCompleted(write_result)
1214 }
1215 Either::Right((request, _)) => {
1216 SuspendedResult::NewRegistration(request)
1217 }
1218 }
1219 } else {
1220 await_suspended(&mut write_fut, registered.next(), reg_requests.next())
1221 .await
1222 };
1223
1224 match result {
1225 SuspendedResult::SuspendedCompleted((result, mut buffer)) => {
1226 let message_writer = if let Ok(mw) = result {
1227 mw
1228 } else {
1229 warn!("Writing to the output failed.");
1230 break 'outer;
1231 };
1232 if task_state.contains(WriteTaskState::NEEDS_SYNC) {
1233 trace!("Sending a Sync message.");
1234 task_state
1235 .remove(WriteTaskState::FLUSHED | WriteTaskState::NEEDS_SYNC);
1236 let write = suspend_write(message_writer, buffer, WriteKind::Sync);
1237 state = WriteState::Writing(Either::Left(write));
1238 } else if backpressure.has_data() {
1239 trace!("Dispatching the updated buffer.");
1240 backpressure.prepare_write(&mut buffer);
1241 let write_fut =
1242 suspend_write(message_writer, buffer, WriteKind::Data);
1243 task_state.remove(WriteTaskState::FLUSHED);
1244 state = WriteState::Writing(Either::Left(write_fut));
1245 } else {
1246 trace!("Task has become idle.");
1247 state = WriteState::Idle {
1248 message_writer,
1249 buffer,
1250 };
1251 }
1252 break 'inner;
1254 }
1255 SuspendedResult::NextRecord(Some(Ok(op))) => {
1256 trace!("Over-writing the current event buffer.");
1258 if let Err(err) = backpressure.push_operation(op) {
1259 error!(
1260 "Failed to process downlink operation: {error}",
1261 error = err
1262 );
1263 };
1264 }
1265 SuspendedResult::NextRecord(Some(Err(Failed(id, error)))) => {
1266 trace!(?error, "Removing a failed subscriber.");
1267 if let Some(rx) = registered.iter_mut().find(|rx| rx.id == id) {
1268 rx.terminate();
1269 }
1270 }
1271 SuspendedResult::NewRegistration(Some((reader, options))) => {
1272 trace!("Registering a new subscriber.");
1273 let receiver =
1274 DownlinkReceiver::new(reader, next_id(), B::make_decoder());
1275 registered.push(receiver);
1276 task_state.set_needs_sync(options);
1277 }
1278 SuspendedResult::NewRegistration(_) => {
1279 info!("Instructed to stop.");
1280 break 'outer;
1281 }
1282 _ => {}
1283 }
1284 }
1285 }
1286 }
1287 }
1288}
1289
1290use futures::ready;
1291use std::pin::{pin, Pin};
1292use std::task::{Context, Poll};
1293use swimos_utilities::non_zero_usize;
1294
1295use self::failure::{BadFrameStrategy, InfallibleStrategy};
1296use self::interpretation::{value_interpretation, DownlinkInterpretation};
1297use crate::backpressure::{MapBackpressure, ValueBackpressure};
1298
/// A future that flushes a `RequestSender` it owns and returns the sender on success.
struct OwningFlush {
    /// `Some` until the flush completes successfully.
    inner: Option<RequestSender>,
}
1304
1305impl OwningFlush {
1306 fn new(sender: RequestSender) -> Self {
1307 OwningFlush {
1308 inner: Some(sender),
1309 }
1310 }
1311
1312 fn reclaim(self) -> RequestSender {
1313 if let Some(sender) = self.inner {
1314 sender
1315 } else {
1316 panic!("OwningFlush reclaimed after complete.");
1317 }
1318 }
1319}
1320
1321impl Future for OwningFlush {
1322 type Output = Result<RequestSender, std::io::Error>;
1323
1324 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1325 let OwningFlush { inner } = self.get_mut();
1326 let result = if let Some(tx) = inner {
1327 ready!(sender_poll_flush(&mut tx.sender, cx))
1328 } else {
1329 panic!("OwningFlush polled after complete.");
1330 };
1331 Poll::Ready(result.map(move |_| inner.take().unwrap()))
1332 }
1333}
1334
1335fn sender_poll_flush<Snk>(sink: &mut Snk, cx: &mut Context<'_>) -> Poll<Result<(), Snk::Error>>
1336where
1337 Snk: Sink<RawRequestMessage<'static, &'static str>> + Unpin,
1338{
1339 sink.poll_flush_unpin(cx)
1340}
1341
1342fn discard<A1, A2, B1, B2>(either: Either<(A1, A2), (B1, B2)>) -> Either<A1, B1> {
1343 match either {
1344 Either::Left((a1, _)) => Either::Left(a1),
1345 Either::Right((b1, _)) => Either::Right(b1),
1346 }
1347}
1348
/// The first event to occur while a write is suspended (see `await_suspended`).
enum SuspendedResult<A, B, C> {
    /// The suspended write/flush completed.
    SuspendedCompleted(A),
    /// A consumer produced a record (or its stream ended/failed).
    NextRecord(B),
    /// A new registration request arrived (or the channel closed).
    NewRegistration(C),
}
1356
1357async fn await_suspended<F1, F2, F3>(
1358 suspended: F1,
1359 next_rec: F2,
1360 next_reg: F3,
1361) -> SuspendedResult<F1::Output, F2::Output, F3::Output>
1362where
1363 F1: Future + Unpin,
1364 F2: Future + Unpin,
1365 F3: Future + Unpin,
1366{
1367 let alternatives = select(next_rec, next_reg).map(discard);
1368 match select(suspended, alternatives).await {
1369 Either::Left((r, _)) => SuspendedResult::SuspendedCompleted(r),
1370 Either::Right((Either::Left(r), _)) => SuspendedResult::NextRecord(r),
1371 Either::Right((Either::Right(r), _)) => SuspendedResult::NewRegistration(r),
1372 }
1373}