crowdstrike_cloudproto/framing/socket.rs

1use crate::framing::packet::CloudProtoPacket;
2use crate::framing::CloudProtoError;
3use bytes::Bytes;
4use futures_util::{Sink, SinkExt, Stream, StreamExt};
5use std::pin::Pin;
6use std::task::{ready, Context, Poll};
7use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
8use tokio_util::codec::{BytesCodec, FramedRead, FramedWrite, LengthDelimitedCodec};
9use tracing::{error, trace};
10
/// Default limit on the size of an incoming [`CloudProtoPacket`](super::CloudProtoPacket)
/// (32 MiB).
///
/// The limit is enforced by the framing codec on the packet's 32-bit length field, before any
/// adjustment. The first 8 bytes of the packet (everything up to and including the length
/// field) are therefore not counted toward it. Outgoing packets are not checked against it.
pub const DEFAULT_MAX_FRAME_LENGTH: usize = 32 * 1024 * 1024;
13
14/// The common socket that carries framing-layer [`packets`](super::CloudProtoPacket) used by higher level protocols
15pub struct CloudProtoSocket<IO: AsyncRead + AsyncWrite> {
16    read: FramedRead<ReadHalf<IO>, LengthDelimitedCodec>,
17    write: FramedWrite<WriteHalf<IO>, BytesCodec>,
18}
19
20impl<IO> CloudProtoSocket<IO>
21where
22    IO: AsyncRead + AsyncWrite,
23{
24    /// CloudProtoSocket is usually layered over a TLS session over TCP port 443,
25    /// so in practice `IO` should usually be `TlsStream<TcpStream>`.
26    ///
27    /// The socket buffers individual packets, and has a default maximum packet size of
28    /// [`DEFAULT_MAX_FRAME_LENGTH`](DEFAULT_MAX_FRAME_LENGTH).
29    /// See [`with_max_frame_length`](Self::with_max_frame_length) to adjust this limit.
30    pub fn new(io: IO) -> Self {
31        Self::with_max_frame_length(io, DEFAULT_MAX_FRAME_LENGTH)
32    }
33
34    /// CloudProtoSocket is usually layered over a TLS session over TCP port 443,
35    /// so in practice `IO` should usually be `TlsStream<TcpStream>`.
36    ///
37    /// The socket buffers individual packets, `max_frame_length` will be the maximum accepted size
38    /// of [`CloudProtoPacket`](super::CloudProtoPacket)s, including header.
39    pub fn with_max_frame_length(io: IO, max_frame_length: usize) -> Self {
40        let (read, write) = tokio::io::split(io);
41        let read = LengthDelimitedCodec::builder()
42            .big_endian()
43            .max_frame_length(max_frame_length)
44            .length_field_type::<u32>()
45            .length_adjustment(0)
46            .length_field_offset(4)
47            .num_skip(0)
48            .new_read(read);
49        let write = FramedWrite::new(write, BytesCodec::new());
50        Self { read, write }
51    }
52}
53
54impl<IO> Stream for CloudProtoSocket<IO>
55where
56    IO: AsyncRead + AsyncWrite,
57{
58    type Item = Result<CloudProtoPacket, CloudProtoError>;
59
60    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
61        let this = self.get_mut();
62        let pkt = match ready!(this.read.poll_next_unpin(cx)) {
63            Some(Ok(frame)) => CloudProtoPacket::from_buf(&frame),
64            Some(Err(e)) => {
65                return Poll::Ready(Some(Err(CloudProtoError::Io { source: e })));
66            }
67            None => return Poll::Ready(None),
68        };
69        match pkt {
70            Ok(pkt) => {
71                trace!(
72                    "Received kind 0x{:x} packet with 0x{:x} bytes payload: {}",
73                    pkt.kind,
74                    pkt.payload.len(),
75                    hex::encode(&pkt.payload),
76                );
77                Poll::Ready(Some(Ok(pkt)))
78            }
79            Err(e) => {
80                error!("Received bad cloudproto packet: {}", e);
81                Poll::Ready(Some(Err(e)))
82            }
83        }
84    }
85}
86
87impl<IO> Sink<CloudProtoPacket> for CloudProtoSocket<IO>
88where
89    IO: AsyncRead + AsyncWrite,
90{
91    type Error = std::io::Error;
92
93    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
94        SinkExt::<Bytes>::poll_ready_unpin(&mut self.get_mut().write, cx)
95    }
96
97    fn start_send(self: Pin<&mut Self>, pkt: CloudProtoPacket) -> Result<(), Self::Error> {
98        let this = self.get_mut();
99        let buf = Bytes::from(pkt.to_buf());
100        trace!(
101            "Sending kind 0x{:x} packet with 0x{:x} bytes payload: {}",
102            pkt.kind,
103            pkt.payload.len(),
104            hex::encode(&pkt.payload),
105        );
106        this.write.start_send_unpin(buf)
107    }
108
109    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110        SinkExt::<Bytes>::poll_flush_unpin(&mut self.get_mut().write, cx)
111    }
112
113    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
114        SinkExt::<Bytes>::poll_close_unpin(&mut self.get_mut().write, cx)
115    }
116}
117
#[cfg(test)]
mod test {
    use crate::framing::{CloudProtoPacket, CloudProtoSocket, CloudProtoVersion};
    use crate::services::CloudProtoMagic;
    use anyhow::Result;
    use futures_util::{SinkExt, StreamExt};
    use rand::Rng;

    // A single packet with a random-length payload must survive a round trip between two
    // sockets joined by an in-memory duplex pipe.
    #[test_log::test(tokio::test)]
    async fn single_send_recv() -> Result<()> {
        let (client_io, server_io) = tokio::io::duplex(100 * 1024);
        let mut client = CloudProtoSocket::new(client_io);
        let mut server = CloudProtoSocket::new(server_io);

        let payload_len = usize::from(rand::thread_rng().gen::<u16>());
        let sent = CloudProtoPacket {
            magic: CloudProtoMagic::TS,
            kind: 0,
            version: CloudProtoVersion::Normal,
            // Every byte is the (truncated) payload length, as a cheap recognizable pattern.
            payload: vec![payload_len as u8; payload_len],
        };

        client.send(sent.clone()).await?;
        let received = server.next().await.unwrap()?;
        assert_eq!(sent, received);

        Ok(())
    }
}