// stratum-server 3.0.0-beta-2
//
// The server code for the Rust Stratum (v1) implementation.
pub use crate::ConnectionList;
use crate::{
    connection::Connection,
    id_manager::IDManager,
    router::Router,
    server::{UpstreamConfig, VarDiffConfig},
    types::{ExMessageGeneric, MessageValue},
    BanManager, Error, Result, EX_MAGIC_NUMBER,
};
use async_std::{net::TcpStream, prelude::FutureExt, sync::Arc};
use extended_primitives::Buffer;
use futures::{
    channel::mpsc::{unbounded, UnboundedReceiver},
    io::{AsyncBufReadExt, AsyncReadExt, BufReader, ReadHalf, WriteHalf},
    AsyncWriteExt, SinkExt, StreamExt,
};
use log::{trace, warn};
use serde_json::{Map, Value};
use std::{net::SocketAddr, time::Duration};
use stop_token::future::FutureExt as stopFutureExt;

// @todo it might make sense to wrap many of these parameters into a single "ConnectionConfig"
// and pass that along instead, but we'll see.
/// Drives a single accepted miner connection from handshake to teardown.
///
/// Steps, in order:
/// 1. If `proxy` is set, reads a PROXY protocol v1 header line first and replaces `addr` with
///    the client address it carries. The connection is dropped (with a warning) if the header
///    cannot be read, is malformed, or was not addressed to `expected_port`.
/// 2. Drops the connection if `ban_manager` reports `addr` as banned.
/// 3. Allocates a session id (dropping the connection if none are free) and builds the shared
///    `Connection`.
/// 4. If `upstream_config.enabled`, connects to `upstream_config.url` and spawns one task that
///    forwards queued upstream writes and one task that reads upstream messages: `"result"`
///    messages carrying a Stratum V1 map are forwarded to the connection over a channel, every
///    other message is dispatched through `upstream_router`.
/// 5. Spawns the miner send loop, registers the connection in `connection_list`, then reads
///    miner messages and dispatches them through `router` until the connection reports itself
///    disconnected, the stream closes or fails, or the connection's stop token fires.
/// 6. Releases the session id, unregisters the connection, bans `addr` if the connection asks
///    for it, and shuts the connection down.
#[allow(clippy::too_many_arguments)] // One parameter per piece of server state; see @todo above.
pub async fn handle_connection<
    State: Clone + Send + Sync + 'static,
    CState: Clone + Send + Sync + 'static,
>(
    id_manager: Arc<IDManager>,
    ban_manager: Arc<BanManager>,
    mut addr: SocketAddr,
    connection_list: Arc<ConnectionList<CState>>,
    router: Arc<Router<State, CState>>,
    upstream_router: Arc<Router<State, CState>>,
    upstream_config: UpstreamConfig,
    state: State,
    stream: TcpStream,
    var_diff_config: VarDiffConfig,
    initial_difficulty: f64,
    connection_state: CState,
    proxy: bool,
    expected_port: u16,
) {
    let (rh, wh) = stream.split();

    let mut buffer_stream = BufReader::new(rh);

    if proxy {
        let mut header = String::new();
        match buffer_stream.read_line(&mut header).await {
            Ok(bytes_read) if bytes_read > 0 => {}
            _ => {
                warn!("Failed to read PROXY header from {}", addr);
                return;
            }
        }

        match parse_proxy_header(&header, expected_port) {
            Ok(client_addr) => addr = client_addr,
            Err(reason) => {
                warn!("Rejecting proxied connection from {}: {}", addr, reason);
                return;
            }
        }
    }

    if ban_manager.check_banned(&addr).await {
        warn!(
            "Banned connection attempting to connect: {}. Connected closed",
            addr
        );
        return;
    }

    let (tx, rx) = unbounded();
    let (utx, urx) = unbounded();
    let (mut urtx, urrx) = unbounded();

    //@todo we should be printing the number of sessions issued out of the total supported.
    //Currently have 24 sessions connected out of 15,000 total. <1% capacity.
    let connection_id = match id_manager.allocate_session_id().await {
        Some(id) => id,
        None => {
            warn!("Sessions full");
            return;
        }
    };

    let connection = Arc::new(Connection::new(
        connection_id,
        tx,
        utx,
        urrx,
        initial_difficulty,
        var_diff_config,
        connection_state,
    ));

    let stop_token = connection.get_stop_token();

    if upstream_config.enabled {
        let upstream = match TcpStream::connect(upstream_config.url).await {
            Ok(upstream) => upstream,
            Err(e) => {
                warn!(
                    "Failed to connect upstream for connection: {}, Reason: {}",
                    connection_id, e
                );
                // Nothing has been spawned or registered yet; just give the session id back.
                id_manager.remove_session_id(connection_id).await;
                return;
            }
        };

        let (urh, uwh) = upstream.split();
        let mut upstream_buffer_stream = BufReader::new(urh);

        async_std::task::spawn(async move {
            match upstream_send_loop(urx, uwh).await {
                Ok(_) => trace!("Upstream Send Loop is closing for connection"),
                Err(e) => warn!(
                    "Upstream Send loop is closed for connection: {}, Reason: {}",
                    connection_id, e
                ),
            }
        });

        async_std::task::spawn({
            let state = state.clone();
            let connection = connection.clone();
            let stop_token = stop_token.clone();

            async move {
                loop {
                    // Bounding every read by the stop token lets a shutdown interrupt a read
                    // that would otherwise wait forever on a silent upstream.
                    let next = next_message(&mut upstream_buffer_stream)
                        .timeout_at(stop_token.clone())
                        .await;

                    let (method, values) = match next {
                        Ok(Ok(message)) => message,
                        // Upstream closed/failed, or the stop token fired.
                        _ => break,
                    };

                    if method == "result" {
                        //@todo non-StratumV1 results are silently dropped here; review whether
                        //that scenario needs explicit handling.
                        if let MessageValue::StratumV1(map) = values {
                            if urtx.send(map).await.is_err() {
                                // The connection dropped its receiver; nobody is left to read
                                // upstream results.
                                break;
                            }
                        }
                        continue;
                    }

                    upstream_router
                        .call(&method, values, state.clone(), connection.clone())
                        .await;
                }
            }
        });
    }

    //@todo use stop token inside this send loop.
    async_std::task::spawn(async move {
        match send_loop(rx, wh).await {
            Ok(_) => trace!("Send Loop is closing for connection"),
            Err(e) => warn!(
                "Send loop is closed for connection: {}, Reason: {}",
                connection_id, e
            ),
        }
    });

    if connection_list
        .add_miner(addr, connection.clone())
        .await
        .is_err()
    {
        warn!(
            "Failed to register connection: {} from {}",
            connection_id, addr
        );
        id_manager.remove_session_id(connection_id).await;
        // NOTE(review): assumes `shutdown` is what stops the already-spawned tasks; confirm.
        connection.shutdown().await;
        return;
    }

    trace!("Accepting stream from: {}", addr);

    loop {
        if connection.is_disconnected().await {
            break;
        }

        //@todo write a test for this.
        // NOTE(review): the 60s timeout cancels an in-flight read; a cancellation between the
        // header and body of an ExMessage would desynchronise framing — confirm acceptable.
        let next = next_message(&mut buffer_stream)
            .timeout(Duration::from_secs(60))
            .timeout_at(stop_token.clone())
            .await;

        match next {
            // The stop token fired: the connection was told to shut down.
            Err(_) => break,
            // Idle for 60s: loop around to re-check `is_disconnected`.
            Ok(Err(_)) => continue,
            Ok(Ok(Ok((method, values)))) => {
                router
                    .call(&method, values, state.clone(), connection.clone())
                    .await;
            }
            // A complete line whose method field is unusable; the stream itself is still fine.
            Ok(Ok(Err(Error::MethodDoesntExist))) => {
                trace!("Ignoring message without a valid method from: {}", addr);
            }
            // Stream closed or broken. Continuing here would spin forever, because every
            // further read returns the same error immediately.
            Ok(Ok(Err(e))) => {
                trace!("Read loop for {} is ending: {}", addr, e);
                break;
            }
        }
    }

    trace!("Closing stream from: {}", addr);

    id_manager.remove_session_id(connection_id).await;
    connection_list.remove_miner(addr).await;

    if connection.needs_ban().await {
        ban_manager.add_ban(&addr).await;
    }

    connection.shutdown().await;
}

/// Parses a PROXY protocol v1 header line and returns the original client address.
///
/// Expects `"PROXY <family> <src ip> <dst ip> <src port> <dst port>"`, optionally followed by
/// `\r\n`, e.g. `"PROXY TCP4 92.118.161.17 172.20.42.228 55867 8080\r\n"`.
///
/// The source IP and port are combined with `SocketAddr::new` rather than re-parsed from a
/// string, so IPv6 (`TCP6`) sources work without bracket syntax.
///
/// # Errors
/// Returns a short reason if the line does not have exactly six space-separated fields
/// starting with `PROXY`, if a port or the source IP fails to parse, or if the destination
/// port differs from `expected_port` (the client was not trying to reach us).
fn parse_proxy_header(
    line: &str,
    expected_port: u16,
) -> std::result::Result<SocketAddr, &'static str> {
    let pieces: Vec<&str> = line.trim().split(' ').collect();

    match pieces.as_slice() {
        ["PROXY", _family, src_ip, _dst_ip, src_port, dst_port] => {
            let dst_port: u16 = dst_port
                .parse()
                .map_err(|_| "invalid destination port")?;
            if dst_port != expected_port {
                return Err("destination port does not match the expected port");
            }

            let ip: std::net::IpAddr = src_ip.parse().map_err(|_| "invalid source address")?;
            let port: u16 = src_port.parse().map_err(|_| "invalid source port")?;
            Ok(SocketAddr::new(ip, port))
        }
        _ => Err("malformed PROXY header"),
    }
}

pub async fn next_message(
    stream: &mut BufReader<ReadHalf<TcpStream>>,
) -> Result<(String, MessageValue)> {
    //I don't actually think this has to loop here.
    loop {
        let peak = stream.fill_buf().await?;

        if peak.len() == 0 {
            return Err(Error::StreamClosed);
        }

        if peak[0] == EX_MAGIC_NUMBER {
            let mut header_bytes = vec![0u8; 4];
            stream.read_exact(&mut header_bytes).await?;
            let mut header_buffer = Buffer::from(header_bytes);
            let mut saved_header_buffer = header_buffer.clone();

            let _magic_number = header_buffer.read_u8().map_err(|_| Error::BrokenExHeader)?;
            let _cmd = header_buffer.read_u8().map_err(|_| Error::BrokenExHeader)?;
            let length = header_buffer
                .read_u16()
                .map_err(|_| Error::BrokenExHeader)?;

            let mut buf = vec![0u8; length as usize - 4];
            stream.read_exact(&mut buf).await?;

            let buffer = Buffer::from(buf);

            //Add the new buffer body (buffer) to the header_bytes that we had previously saved.
            saved_header_buffer.extend(buffer);

            let ex_message = ExMessageGeneric::from_buffer(&mut saved_header_buffer)?;
            return Ok((
                ex_message.cmd.to_string(),
                MessageValue::ExMessage(ex_message),
            ));
        }

        //If we have reached here, then we did not breat the "Peak test" searching for the magic
        //number of ExMessage.

        //@todo let's break this into 2 separate functions eh?
        let mut buf = String::new();
        let num_bytes = stream.read_line(&mut buf).await?;

        if num_bytes == 0 {
            return Err(Error::StreamClosed);
        }

        if !buf.is_empty() {
            //@smells
            buf = buf.trim().to_owned();

            trace!("Received Message: {}", &buf);

            if buf.is_empty() {
                continue;
            }

            let msg: Map<String, Value> = match serde_json::from_str(&buf) {
                Ok(msg) => msg,
                Err(_) => continue,
            };

            let method = if msg.contains_key("method") {
                match msg.get("method") {
                    Some(method) => method.as_str(),
                    //@todo need better stratum erroring here.
                    None => return Err(Error::MethodDoesntExist),
                }
            } else if msg.contains_key("messsage") {
                match msg.get("message") {
                    Some(method) => method.as_str(),
                    None => return Err(Error::MethodDoesntExist),
                }
            } else if msg.contains_key("result") {
                Some("result")
            } else {
                // return Err(Error::MethodDoesntExist);
                Some("")
            };

            if let Some(method_string) = method {
                //Mark the sender as active as we received a message.
                //We only mark them as active if the message/method was valid
                // self.stats.lock().await.last_active = Utc::now().naive_utc();
                // @todo maybe expose a function on the connection for this btw.

                return Ok((method_string.to_owned(), MessageValue::StratumV1(msg)));
            } else {
                //@todo improper format
                return Err(Error::MethodDoesntExist);
            }
        };
    }
}

/// Drains `rx`, writing each queued message to the socket followed by a `\n` terminator.
///
/// Resolves to `Ok(())` once every sender has been dropped and the queue is empty. The first
/// write failure stops the loop and is returned to the caller.
pub async fn send_loop(
    mut rx: UnboundedReceiver<String>,
    mut rh: WriteHalf<TcpStream>,
) -> Result<()> {
    loop {
        let outgoing = match rx.next().await {
            Some(outgoing) => outgoing,
            None => break,
        };

        rh.write_all(outgoing.as_bytes()).await?;
        // JSON-RPC messages are newline-terminated. This arguably belongs in the RPC library,
        // but it lives here rather than in `Connection::send`; only move it back there if
        // websockets ALSO require the newline.
        rh.write_all(b"\n").await?;
    }

    Ok(())
}

pub async fn upstream_send_loop(
    mut rx: UnboundedReceiver<String>,
    mut rh: WriteHalf<TcpStream>,
) -> Result<()> {
    while let Some(msg) = rx.next().await {
        rh.write_all(msg.as_bytes()).await?;
        rh.write_all(b"\n").await?;
    }

    Ok(())
}