Skip to main content

over_there/core/transport/wire/
tcp.rs

1use super::{
2    auth, crypto, Authenticator, Bicrypter, Decrypter, Encrypter, InboundWire,
3    InboundWireError, OutboundWire, OutboundWireError, Signer, Verifier, Wire,
4};
5use std::net::SocketAddr;
6use tokio::{
7    io::{self, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
8    net::TcpStream,
9};
10
11pub struct TcpStreamWire<A, B>
12where
13    A: Authenticator,
14    B: Bicrypter,
15{
16    wire: Wire<A, B>,
17    stream: TcpStream,
18    remote_addr: SocketAddr,
19}
20
21impl<A, B> TcpStreamWire<A, B>
22where
23    A: Authenticator,
24    B: Bicrypter,
25{
26    pub fn new(
27        wire: Wire<A, B>,
28        stream: TcpStream,
29        remote_addr: SocketAddr,
30    ) -> Self {
31        Self {
32            wire,
33            stream,
34            remote_addr,
35        }
36    }
37
38    pub fn arc_split(
39        self,
40    ) -> (
41        TcpStreamInboundWire<
42            auth::split::VerifierHalf<A>,
43            crypto::split::DecrypterHalf<B>,
44        >,
45        TcpStreamOutboundWire<
46            auth::split::SignerHalf<A>,
47            crypto::split::EncrypterHalf<B>,
48        >,
49    ) {
50        let Self {
51            wire,
52            stream,
53            remote_addr,
54        } = self;
55        let (r, w) = io::split(stream);
56        let (iw, ow) = wire.arc_split();
57
58        (iw.with_tcp_stream(r, remote_addr), ow.with_tcp_stream(w))
59    }
60}
61
62impl<A, B> TcpStreamWire<A, B>
63where
64    A: Authenticator + Clone,
65    B: Bicrypter + Clone,
66{
67    pub fn clone_split(
68        self,
69    ) -> (TcpStreamInboundWire<A, B>, TcpStreamOutboundWire<A, B>) {
70        let Self {
71            wire,
72            stream,
73            remote_addr,
74        } = self;
75        let (r, w) = io::split(stream);
76        let (iw, ow) = wire.clone_split();
77        (iw.with_tcp_stream(r, remote_addr), ow.with_tcp_stream(w))
78    }
79}
80
81pub struct TcpStreamInboundWire<V, D>
82where
83    V: Verifier,
84    D: Decrypter,
85{
86    inbound_wire: InboundWire<V, D>,
87    stream: ReadHalf<TcpStream>,
88    remote_addr: SocketAddr,
89}
90
91impl<V, D> TcpStreamInboundWire<V, D>
92where
93    V: Verifier,
94    D: Decrypter,
95{
96    pub fn new(
97        inbound_wire: InboundWire<V, D>,
98        stream: ReadHalf<TcpStream>,
99        remote_addr: SocketAddr,
100    ) -> Self {
101        Self {
102            inbound_wire,
103            stream,
104            remote_addr,
105        }
106    }
107
108    pub async fn read(
109        &mut self,
110    ) -> Result<(Option<Vec<u8>>, SocketAddr), InboundWireError> {
111        let mut buf =
112            vec![0; self.inbound_wire.transmission_size()].into_boxed_slice();
113        let size = self
114            .stream
115            .read(&mut buf)
116            .await
117            .map_err(InboundWireError::IO)?;
118        let data = self.inbound_wire.process(&buf[..size])?;
119
120        Ok((data, self.remote_addr))
121    }
122}
123
124pub struct TcpStreamOutboundWire<S, E>
125where
126    S: Signer,
127    E: Encrypter,
128{
129    outbound_wire: OutboundWire<S, E>,
130    stream: WriteHalf<TcpStream>,
131}
132
133impl<S, E> TcpStreamOutboundWire<S, E>
134where
135    S: Signer,
136    E: Encrypter,
137{
138    pub fn new(
139        outbound_wire: OutboundWire<S, E>,
140        stream: WriteHalf<TcpStream>,
141    ) -> Self {
142        Self {
143            outbound_wire,
144            stream,
145        }
146    }
147
148    pub async fn write(&mut self, buf: &[u8]) -> Result<(), OutboundWireError> {
149        let data = self.outbound_wire.process(buf)?;
150
151        for packet_bytes in data.iter() {
152            let size = self
153                .stream
154                .write(packet_bytes)
155                .await
156                .map_err(OutboundWireError::IO)?;
157            if size < packet_bytes.len() {
158                return Err(OutboundWireError::IncompleteSend);
159            }
160        }
161
162        Ok(())
163    }
164}