stratum-apps 0.4.0

Complete Stratum V2 application development kit: all utilities in one crate
Documentation
#![allow(clippy::new_ret_no_self)]
use crate::network_helpers::{
    noise_stream::{NoiseTcpReadHalf, NoiseTcpStream, NoiseTcpWriteHalf},
    Error, NOISE_HANDSHAKE_TIMEOUT,
};
use async_channel::{unbounded, Receiver, Sender};
use std::sync::Arc;
use stratum_core::{
    binary_sv2::{Deserialize, GetSize, Serialize},
    codec_sv2::{HandshakeRole, StandardEitherFrame},
};
use tokio::{net::TcpStream, task};
use tracing::{debug, error};

/// Entry point for turning a `TcpStream` into a Noise-encrypted connection that is
/// driven through a pair of channels; see [`Connection::new`].
pub struct Connection;

/// Channel handles shared (behind an `Arc`) by a connection's reader and writer
/// tasks, so that whichever task stops first can close every channel.
struct ConnectionState<Message> {
    /// Sending end of the incoming channel; the reader task pushes peer frames here.
    sender_incoming: Sender<StandardEitherFrame<Message>>,
    /// Clone of the incoming receiver that `Connection::new` hands back to the caller.
    receiver_incoming: Receiver<StandardEitherFrame<Message>>,
    /// Clone of the outgoing sender that `Connection::new` hands back to the caller.
    sender_outgoing: Sender<StandardEitherFrame<Message>>,
    /// Receiving end of the outgoing channel; the writer task pulls frames from here.
    receiver_outgoing: Receiver<StandardEitherFrame<Message>>,
}

impl<Message> ConnectionState<Message> {
    /// Closes both the incoming and the outgoing channel.
    ///
    /// Each channel is closed through both of the ends stored here. Per the
    /// `async_channel` API, closing from either end closes the channel for every
    /// handle, so the second call of each pair should be a no-op; the boolean
    /// results are deliberately ignored.
    fn close_all(&self) {
        let Self {
            sender_incoming: incoming_tx,
            receiver_incoming: incoming_rx,
            sender_outgoing: outgoing_tx,
            receiver_outgoing: outgoing_rx,
        } = self;

        incoming_tx.close();
        incoming_rx.close();
        outgoing_tx.close();
        outgoing_rx.close();
    }
}

impl Connection {
    pub async fn new<Message>(
        stream: TcpStream,
        role: HandshakeRole,
    ) -> Result<
        (
            Receiver<StandardEitherFrame<Message>>,
            Sender<StandardEitherFrame<Message>>,
        ),
        Error,
    >
    where
        Message: Serialize + Deserialize<'static> + GetSize + Send + 'static,
    {
        let (sender_incoming, receiver_incoming) = unbounded();
        let (sender_outgoing, receiver_outgoing) = unbounded();

        let conn_state = Arc::new(ConnectionState {
            sender_incoming,
            receiver_incoming: receiver_incoming.clone(),
            sender_outgoing: sender_outgoing.clone(),
            receiver_outgoing,
        });

        let (read_half, write_half) =
            NoiseTcpStream::<Message>::new(stream, role, NOISE_HANDSHAKE_TIMEOUT)
                .await?
                .into_split();

        Self::spawn_reader(read_half, Arc::clone(&conn_state));
        Self::spawn_writer(write_half, conn_state);

        Ok((receiver_incoming, sender_outgoing))
    }
    fn spawn_reader<Message>(
        mut read_half: NoiseTcpReadHalf<Message>,
        conn_state: Arc<ConnectionState<Message>>,
    ) -> task::JoinHandle<()>
    where
        Message: Serialize + Deserialize<'static> + GetSize + Send + 'static,
    {
        let sender_incoming = conn_state.sender_incoming.clone();

        task::spawn(async move {
            loop {
                tokio::select! {
                    _ = tokio::signal::ctrl_c() => {
                        debug!("Reader received shutdown signal.");
                        break;
                    }
                    res = read_half.read_frame() => match res {
                        Ok(frame) => {
                            if sender_incoming.send(frame).await.is_err() {
                                error!("Reader: channel closed, shutting down.");
                                break;
                            }
                        }
                        Err(e) => {
                            error!("Reader: error while reading frame: {e:?}");
                            break;
                        }
                    }
                }
            }

            conn_state.close_all();
        })
    }

    fn spawn_writer<Message>(
        mut write_half: NoiseTcpWriteHalf<Message>,
        conn_state: Arc<ConnectionState<Message>>,
    ) -> task::JoinHandle<()>
    where
        Message: Serialize + Deserialize<'static> + GetSize + Send + 'static,
    {
        let receiver_outgoing = conn_state.receiver_outgoing.clone();

        task::spawn(async move {
            loop {
                tokio::select! {
                    _ = tokio::signal::ctrl_c() => {
                        debug!("Writer received shutdown signal.");
                        break;
                    }
                    res = receiver_outgoing.recv() => match res {
                        Ok(frame) => {
                            if let Err(e) = write_half.write_frame(frame).await {
                                error!("Writer: error while writing frame: {e:?}");
                                break;
                            }
                        }
                        Err(_) => {
                            debug!("Writer: channel closed, shutting down.");
                            break;
                        }
                    }
                }
            }

            if let Err(e) = write_half.shutdown().await {
                error!("Writer: error during shutdown: {e:?}");
            }

            conn_state.close_all();
        })
    }
}