crowdstrike_cloudproto/framing/
socket.rs1use crate::framing::packet::CloudProtoPacket;
2use crate::framing::CloudProtoError;
3use bytes::Bytes;
4use futures_util::{Sink, SinkExt, Stream, StreamExt};
5use std::pin::Pin;
6use std::task::{ready, Context, Poll};
7use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
8use tokio_util::codec::{BytesCodec, FramedRead, FramedWrite, LengthDelimitedCodec};
9use tracing::{error, trace};
10
/// Default read-side limit handed to the length-delimited decoder by
/// [`CloudProtoSocket::new`]: 32 MiB (32 * 1024 * 1024 bytes).
pub const DEFAULT_MAX_FRAME_LENGTH: usize = 32 << 20;

14pub struct CloudProtoSocket<IO: AsyncRead + AsyncWrite> {
16 read: FramedRead<ReadHalf<IO>, LengthDelimitedCodec>,
17 write: FramedWrite<WriteHalf<IO>, BytesCodec>,
18}
19
20impl<IO> CloudProtoSocket<IO>
21where
22 IO: AsyncRead + AsyncWrite,
23{
24 pub fn new(io: IO) -> Self {
31 Self::with_max_frame_length(io, DEFAULT_MAX_FRAME_LENGTH)
32 }
33
34 pub fn with_max_frame_length(io: IO, max_frame_length: usize) -> Self {
40 let (read, write) = tokio::io::split(io);
41 let read = LengthDelimitedCodec::builder()
42 .big_endian()
43 .max_frame_length(max_frame_length)
44 .length_field_type::<u32>()
45 .length_adjustment(0)
46 .length_field_offset(4)
47 .num_skip(0)
48 .new_read(read);
49 let write = FramedWrite::new(write, BytesCodec::new());
50 Self { read, write }
51 }
52}
53
54impl<IO> Stream for CloudProtoSocket<IO>
55where
56 IO: AsyncRead + AsyncWrite,
57{
58 type Item = Result<CloudProtoPacket, CloudProtoError>;
59
60 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
61 let this = self.get_mut();
62 let pkt = match ready!(this.read.poll_next_unpin(cx)) {
63 Some(Ok(frame)) => CloudProtoPacket::from_buf(&frame),
64 Some(Err(e)) => {
65 return Poll::Ready(Some(Err(CloudProtoError::Io { source: e })));
66 }
67 None => return Poll::Ready(None),
68 };
69 match pkt {
70 Ok(pkt) => {
71 trace!(
72 "Received kind 0x{:x} packet with 0x{:x} bytes payload: {}",
73 pkt.kind,
74 pkt.payload.len(),
75 hex::encode(&pkt.payload),
76 );
77 Poll::Ready(Some(Ok(pkt)))
78 }
79 Err(e) => {
80 error!("Received bad cloudproto packet: {}", e);
81 Poll::Ready(Some(Err(e)))
82 }
83 }
84 }
85}
86
87impl<IO> Sink<CloudProtoPacket> for CloudProtoSocket<IO>
88where
89 IO: AsyncRead + AsyncWrite,
90{
91 type Error = std::io::Error;
92
93 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
94 SinkExt::<Bytes>::poll_ready_unpin(&mut self.get_mut().write, cx)
95 }
96
97 fn start_send(self: Pin<&mut Self>, pkt: CloudProtoPacket) -> Result<(), Self::Error> {
98 let this = self.get_mut();
99 let buf = Bytes::from(pkt.to_buf());
100 trace!(
101 "Sending kind 0x{:x} packet with 0x{:x} bytes payload: {}",
102 pkt.kind,
103 pkt.payload.len(),
104 hex::encode(&pkt.payload),
105 );
106 this.write.start_send_unpin(buf)
107 }
108
109 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110 SinkExt::<Bytes>::poll_flush_unpin(&mut self.get_mut().write, cx)
111 }
112
113 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
114 SinkExt::<Bytes>::poll_close_unpin(&mut self.get_mut().write, cx)
115 }
116}
117
#[cfg(test)]
mod test {
    use crate::framing::{CloudProtoPacket, CloudProtoSocket, CloudProtoVersion};
    use crate::services::CloudProtoMagic;
    use anyhow::Result;
    use futures_util::{SinkExt, StreamExt};
    use rand::Rng;

    /// Sends one randomly sized packet from a client socket to a server
    /// socket over an in-memory duplex pipe, and checks it arrives unchanged.
    #[test_log::test(tokio::test)]
    async fn single_send_recv() -> Result<()> {
        let (client_io, server_io) = tokio::io::duplex(100 * 1024);
        let mut client = CloudProtoSocket::new(client_io);
        let mut server = CloudProtoSocket::new(server_io);

        // A u16-sized payload stays below 64 KiB, smaller than the 100 KiB pipe.
        let len = usize::from(rand::thread_rng().gen::<u16>());
        let sent = CloudProtoPacket {
            magic: CloudProtoMagic::TS,
            kind: 0,
            version: CloudProtoVersion::Normal,
            payload: vec![len as u8; len],
        };

        client.send(sent.clone()).await?;
        let received = server.next().await.unwrap()?;
        assert_eq!(sent, received);

        Ok(())
    }
}