1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
use async_channel::{bounded, Receiver, Sender};
use binary_sv2::{Deserialize, Serialize};
use core::convert::TryInto;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
task,
};
use binary_sv2::GetSize;
use codec_sv2::{Error::MissingBytes, StandardDecoder, StandardEitherFrame};
use tracing::{error, trace};
#[derive(Debug)]
pub struct PlainConnection {}
impl PlainConnection {
///
///
/// # Arguments
///
/// * `strict` - true - will disconnect a connection that sends a message that can't be translated, false - will ignore messages that can't be translated
///
#[allow(clippy::new_ret_no_self)]
pub async fn new<'a, Message: Serialize + Deserialize<'a> + GetSize + Send + 'static>(
stream: TcpStream,
) -> (
Receiver<StandardEitherFrame<Message>>,
Sender<StandardEitherFrame<Message>>,
) {
const NOISE_HANDSHAKE_SIZE_HINT: usize = 3363412;
let (mut reader, mut writer) = stream.into_split();
let (sender_incoming, receiver_incoming): (
Sender<StandardEitherFrame<Message>>,
Receiver<StandardEitherFrame<Message>>,
) = bounded(10); // TODO caller should provide this param
let (sender_outgoing, receiver_outgoing): (
Sender<StandardEitherFrame<Message>>,
Receiver<StandardEitherFrame<Message>>,
) = bounded(10); // TODO caller should provide this param
// RECEIVE AND PARSE INCOMING MESSAGES FROM TCP STREAM
task::spawn(async move {
let mut decoder = StandardDecoder::<Message>::new();
loop {
let writable = decoder.writable();
match reader.read_exact(writable).await {
Ok(_) => {
match decoder.next_frame() {
Ok(frame) => {
if let Err(e) = sender_incoming.send(frame.into()).await {
error!("Failed to send incoming message: {}", e);
task::yield_now().await;
break;
}
}
Err(MissingBytes(size)) => {
// Only disconnect if we get noise handshake message - this shouldn't
// happen in plain_connection
if size == NOISE_HANDSHAKE_SIZE_HINT {
error!("Got noise message on unencrypted connection - disconnecting");
break;
} else {
trace!("MissingBytes({}) on incoming message - ignoring", size);
}
}
Err(e) => {
error!("Failed to read from stream: {}", e);
sender_incoming.close();
task::yield_now().await;
break;
}
}
}
Err(e) => {
// Just fail and force to reinitialize everything
error!("Failed to read from stream: {}", e);
sender_incoming.close();
task::yield_now().await;
break;
}
}
}
});
// ENCODE AND SEND INCOMING MESSAGES TO TCP STREAM
task::spawn(async move {
let mut encoder = codec_sv2::Encoder::<Message>::new();
loop {
let received = receiver_outgoing.recv().await;
match received {
Ok(frame) => {
let b = encoder.encode(frame.try_into().unwrap()).unwrap();
match (writer).write_all(b).await {
Ok(_) => (),
Err(_) => {
let _ = writer.shutdown().await;
}
}
}
Err(_) => {
// Just fail and force to reinitilize everything
let _ = writer.shutdown().await;
error!("Failed to read from stream - terminating connection");
task::yield_now().await;
break;
}
};
}
});
(receiver_incoming, sender_outgoing)
}
}
pub async fn plain_listen(address: &str, sender: Sender<TcpStream>) {
let listener = TcpListener::bind(address).await.unwrap();
loop {
if let Ok((stream, _)) = listener.accept().await {
let _ = sender.send(stream).await;
}
}
}
/// Opens a plain TCP connection to `address`.
///
/// Any connection error is collapsed into `Err(())`; the underlying
/// `std::io::Error` is discarded.
pub async fn plain_connect(address: &str) -> Result<TcpStream, ()> {
    match TcpStream::connect(address).await {
        Ok(connected) => Ok(connected),
        Err(_) => Err(()),
    }
}