1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use super::codec::LengthBasedFrameCodec;
use futures::{SinkExt, StreamExt};
use rsocket_rust::error::RSocketError;
use rsocket_rust::frame::Frame;
use rsocket_rust::runtime::{DefaultSpawner, Spawner};
use rsocket_rust::transport::{ClientTransport, Rx, Tx, TxOnce};
use std::net::{AddrParseError, SocketAddr, TcpStream as StdTcpStream};
use std::str::FromStr;
use tokio::net::TcpStream;
use tokio_util::codec::Framed;
/// How a [`TcpClientTransport`] obtains its socket.
#[derive(Debug)]
enum Connector {
    /// An already-connected stream, handed back as-is when the transport is attached.
    Direct(TcpStream),
    /// An address that is dialled only when the transport is attached.
    Lazy(SocketAddr),
}

/// RSocket client transport over plain TCP.
///
/// Build one from a [`SocketAddr`], an address string (`"host:port"` or
/// `"tcp://host:port"`), or an existing tokio [`TcpStream`].
#[derive(Debug)]
pub struct TcpClientTransport {
    connector: Connector,
}
impl TcpClientTransport {
    /// Wraps a connection strategy in a transport.
    #[inline]
    fn new(connector: Connector) -> TcpClientTransport {
        TcpClientTransport { connector }
    }

    /// Produces a connected socket, consuming the transport.
    ///
    /// A pre-connected stream is returned unchanged; a lazy address is dialled
    /// with tokio's asynchronous connect.
    ///
    /// # Errors
    /// Returns the I/O error from establishing the connection, converted into
    /// [`RSocketError`].
    async fn connect(self) -> Result<TcpStream, RSocketError> {
        match self.connector {
            Connector::Direct(stream) => Ok(stream),
            // Use the async connect: `std::net::TcpStream::connect` would block the
            // executor thread for the whole handshake, and `TcpStream::from_std`
            // expects a socket that has already been switched to non-blocking mode.
            Connector::Lazy(addr) => TcpStream::connect(addr).await.map_err(RSocketError::from),
        }
    }
}
impl ClientTransport for TcpClientTransport {
    /// Connects on a spawned task, then pumps frames in both directions.
    ///
    /// * `incoming` receives every frame decoded from the socket (on a second
    ///   spawned task).
    /// * `sending` yields frames that are encoded and written to the socket.
    /// * `connected`, when present, is notified once with the outcome of the
    ///   connection attempt.
    ///
    /// Failures do not panic the spawned tasks. This covers an unobserved connect
    /// error, a decode or write error, and a closed channel. Each one ends the
    /// affected loop and is logged at debug level.
    fn attach(
        self,
        incoming: Tx<Frame>,
        mut sending: Rx<Frame>,
        connected: Option<TxOnce<Result<(), RSocketError>>>,
    ) {
        DefaultSpawner.spawn(async move {
            let socket = match self.connect().await {
                Ok(socket) => socket,
                Err(e) => {
                    match connected {
                        Some(sender) => {
                            // The waiter may have gone away; there is nobody else to tell.
                            let _ = sender.send(Err(e));
                        }
                        None => {
                            debug!("TCP connect failed: {:?}", e);
                        }
                    }
                    return;
                }
            };
            if let Some(sender) = connected {
                // A dropped receiver only means the caller stopped waiting.
                let _ = sender.send(Ok(()));
            }

            let (mut writer, mut reader) = Framed::new(socket, LengthBasedFrameCodec).split();

            DefaultSpawner.spawn(async move {
                while let Some(next) = reader.next().await {
                    match next {
                        Ok(frame) => {
                            if incoming.unbounded_send(frame).is_err() {
                                // Nobody is consuming incoming frames any more.
                                break;
                            }
                        }
                        Err(e) => {
                            debug!("failed to decode incoming frame: {:?}", e);
                            break;
                        }
                    }
                }
            });

            while let Some(frame) = sending.next().await {
                debug!("===> SND: {:?}", &frame);
                if let Err(e) = writer.send(frame).await {
                    debug!("failed to write frame: {:?}", e);
                    break;
                }
            }
        });
    }
}
impl FromStr for TcpClientTransport {
    type Err = AddrParseError;

    /// Parses `"host:port"` or `"tcp://host:port"` into a lazily-connecting transport.
    ///
    /// The `tcp://` scheme is matched ASCII-case-insensitively, so `tcp://`,
    /// `TCP://` and `Tcp://` are all accepted. The remainder must be a literal
    /// socket address, because it is parsed as a `SocketAddr`.
    ///
    /// # Errors
    /// Returns [`AddrParseError`] if the address part is not a valid `SocketAddr`.
    fn from_str(addr: &str) -> Result<Self, Self::Err> {
        const SCHEME: &str = "tcp://";
        // `get` returns None for short input or a non-char-boundary cut, so this
        // never panics; on success the slice below starts on a char boundary.
        let rest = match addr.get(..SCHEME.len()) {
            Some(prefix) if prefix.eq_ignore_ascii_case(SCHEME) => &addr[SCHEME.len()..],
            _ => addr,
        };
        Ok(TcpClientTransport::new(Connector::Lazy(rest.parse()?)))
    }
}
impl From<SocketAddr> for TcpClientTransport {
    /// Builds a transport that dials `addr` only once it is attached.
    fn from(addr: SocketAddr) -> Self {
        let lazy = Connector::Lazy(addr);
        Self::new(lazy)
    }
}
impl From<&str> for TcpClientTransport {
fn from(addr: &str) -> TcpClientTransport {
let socket_addr: SocketAddr = if addr.starts_with("tcp://") || addr.starts_with("TCP://") {
addr.chars().skip(6).collect::<String>().parse()
} else {
addr.parse()
}
.expect("Invalid transport string!");
TcpClientTransport::new(Connector::Lazy(socket_addr))
}
}
impl From<TcpStream> for TcpClientTransport {
    /// Adopts an already-connected tokio stream; attaching the transport uses it
    /// directly rather than opening a new connection.
    fn from(socket: TcpStream) -> Self {
        let direct = Connector::Direct(socket);
        Self::new(direct)
    }
}