use async_channel::{unbounded, Receiver, Sender};
use futures::StreamExt;
use stratum_core::sv1_api::json_rpc;
use tokio::{
io::{AsyncWriteExt, BufReader, BufWriter},
net::TcpStream,
};
use tokio_util::{
codec::{FramedRead, LinesCodec},
sync::CancellationToken,
};
use tracing::{error, trace, warn};
/// Handle to a newline-delimited JSON-RPC connection running over a TCP stream.
///
/// The socket halves are moved into a background task spawned by `ConnectionSV1::new`;
/// this handle only keeps the channel endpoints used to exchange messages with that task.
#[derive(Debug)]
pub struct ConnectionSV1 {
// Receiving end of the channel fed by the reader loop with messages parsed from the socket.
receiver: Receiver<json_rpc::Message>,
// Sending end of the channel drained by the writer loop, which serializes each message
// and writes it to the socket.
sender: Sender<json_rpc::Message>,
}
/// Clones of both endpoints of the incoming and outgoing message channels.
///
/// The background task created in `ConnectionSV1::new` keeps one of these so it can close
/// both channels when the connection ends.
struct ConnectionState {
// Outgoing channel (messages queued for writing to the socket).
receiver_outgoing: Receiver<json_rpc::Message>,
sender_outgoing: Sender<json_rpc::Message>,
// Incoming channel (messages parsed from the socket).
receiver_incoming: Receiver<json_rpc::Message>,
sender_incoming: Sender<json_rpc::Message>,
}
impl ConnectionState {
    /// Bundles the four channel endpoints into a `ConnectionState`.
    ///
    /// The arguments are stored as given; nothing is cloned or closed here.
    fn new(
        receiver_outgoing: Receiver<json_rpc::Message>,
        sender_outgoing: Sender<json_rpc::Message>,
        receiver_incoming: Receiver<json_rpc::Message>,
        sender_incoming: Sender<json_rpc::Message>,
    ) -> Self {
        Self {
            receiver_outgoing,
            sender_outgoing,
            receiver_incoming,
            sender_incoming,
        }
    }

    /// Calls `close()` on every stored endpoint, receivers first and then senders.
    ///
    /// The results of the individual `close()` calls are ignored.
    fn close(&self) {
        let Self {
            receiver_outgoing,
            sender_outgoing,
            receiver_incoming,
            sender_incoming,
        } = self;

        receiver_incoming.close();
        receiver_outgoing.close();
        sender_incoming.close();
        sender_outgoing.close();
    }
}
/// Maximum line length (`1 << 16` = 65 536) handed to `LinesCodec::new_with_max_length`
/// in `run_reader`; a longer incoming line makes the codec yield an error.
const MAX_LINE_LENGTH: usize = 1 << 16;
impl ConnectionSV1 {
pub async fn new(stream: TcpStream, cancellation_token: CancellationToken) -> Self {
let (read_half, write_half) = stream.into_split();
let (sender_incoming, receiver_incoming) = unbounded();
let (sender_outgoing, receiver_outgoing) = unbounded();
let buffer_read_half = BufReader::new(read_half);
let buffer_write_half = BufWriter::new(write_half);
let connection_state = ConnectionState::new(
receiver_outgoing.clone(),
sender_outgoing.clone(),
receiver_incoming.clone(),
sender_incoming.clone(),
);
tokio::spawn(async move {
tokio::select! {
_ = cancellation_token.cancelled() => {
trace!("ConnectionSV1: received cancellation signal.");
connection_state.close();
}
_ = Self::run_reader(buffer_read_half, sender_incoming.clone()) => {
trace!("Reader task exited. Closing writer sender.");
connection_state.close();
}
_ = Self::run_writer(buffer_write_half, receiver_outgoing.clone()) => {
trace!("Writer task exited. Closing reader sender.");
connection_state.close();
}
}
});
Self {
receiver: receiver_incoming,
sender: sender_outgoing,
}
}
async fn run_reader(
reader: BufReader<tokio::net::tcp::OwnedReadHalf>,
sender: Sender<json_rpc::Message>,
) {
let mut lines = FramedRead::new(reader, LinesCodec::new_with_max_length(MAX_LINE_LENGTH));
while let Some(result) = lines.next().await {
match result {
Ok(line) => match serde_json::from_str::<json_rpc::Message>(&line) {
Ok(msg) => {
if sender.send(msg).await.is_err() {
warn!("Receiver dropped, stopping reader");
break;
}
}
Err(e) => {
error!("Failed to deserialize message: {e:?}");
}
},
Err(e) => {
error!("Error reading from stream: {e:?}");
break;
}
}
}
}
async fn run_writer(
mut writer: BufWriter<tokio::net::tcp::OwnedWriteHalf>,
receiver: Receiver<json_rpc::Message>,
) {
while let Ok(msg) = receiver.recv().await {
match serde_json::to_string(&msg) {
Ok(line) => {
let data = format!("{line}\n");
if writer.write_all(data.as_bytes()).await.is_err() {
error!("Failed to write to stream");
break;
}
if writer.flush().await.is_err() {
error!("Failed to flush writer.");
break;
}
}
Err(e) => {
error!("Failed to serialize message: {e:?}");
break;
}
}
}
}
pub async fn send(&self, msg: json_rpc::Message) -> bool {
self.sender.send(msg).await.is_ok()
}
pub async fn receive(&self) -> Option<json_rpc::Message> {
self.receiver.recv().await.ok()
}
pub fn receiver(&self) -> Receiver<json_rpc::Message> {
self.receiver.clone()
}
pub fn sender(&self) -> Sender<json_rpc::Message> {
self.sender.clone()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;

    /// Unwraps a `StandardRequest`, panicking on any other message variant.
    fn expect_request(msg: json_rpc::Message) -> json_rpc::StandardRequest {
        match msg {
            json_rpc::Message::StandardRequest(request) => request,
            _ => panic!("Unexpected message type"),
        }
    }

    /// Unwraps an `OkResponse`, panicking on any other message variant.
    fn expect_ok_response(msg: json_rpc::Message) -> json_rpc::Response {
        match msg {
            json_rpc::Message::OkResponse(response) => response,
            _ => panic!("Unexpected message type"),
        }
    }

    /// Round-trips a request and its response between two `ConnectionSV1` instances
    /// joined by a loopback TCP connection.
    #[tokio::test]
    async fn test_sv1_connection() {
        // Bind to an OS-assigned port, then connect to it from the same process.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let client_socket = TcpStream::connect(local_addr).await.unwrap();
        let (server_socket, _) = listener.accept().await.unwrap();

        let upstream = ConnectionSV1::new(server_socket, CancellationToken::new()).await;
        let downstream = ConnectionSV1::new(client_socket, CancellationToken::new()).await;

        // Downstream -> upstream: a standard request.
        let request = json_rpc::StandardRequest {
            id: 1,
            method: "test".to_string(),
            params: serde_json::Value::Null,
        };
        let queued = downstream
            .send(json_rpc::Message::StandardRequest(request))
            .await;
        assert!(queued);

        let seen_request = expect_request(upstream.receive().await.unwrap());
        assert_eq!(seen_request.id, 1);
        assert_eq!(seen_request.method, "test".to_string());
        assert_eq!(seen_request.params, serde_json::Value::Null);

        // Upstream -> downstream: the matching response.
        let response = json_rpc::Response {
            id: 1,
            result: serde_json::Value::String("response".to_string()),
            error: None,
        };
        let queued = upstream.send(json_rpc::Message::OkResponse(response)).await;
        assert!(queued);

        let seen_response = expect_ok_response(downstream.receive().await.unwrap());
        assert_eq!(seen_response.id, 1);
        assert_eq!(
            seen_response.result,
            serde_json::Value::String("response".to_string())
        );
    }
}