ratel_rust/
network.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use futures_util::{SinkExt, StreamExt};
4use serde::Serialize;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::TcpStream;
7use tokio_tungstenite::{WebSocketStream, tungstenite::protocol::Message};
8
9use crate::model::Packet;
10
/// Transport kind used for a client connection.
///
/// `PartialEq`/`Eq`/`Hash` are derived so callers can compare transports or use them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NetType {
    /// Plain TCP socket transport.
    Tcp,
    /// WebSocket transport.
    WebSocket,
}
16
17#[async_trait]
18pub trait Connection: Send {
19    async fn send(&mut self, data: &[u8]) -> Result<()>;
20    async fn receive(&mut self) -> Result<Option<Vec<u8>>>;
21}
22
23pub struct TcpConnection {
24    stream: TcpStream,
25}
26
27impl TcpConnection {
28    pub async fn connect(addr: &str) -> Result<Self> {
29        let stream = TcpStream::connect(addr)
30            .await
31            .context("Failed to connect to TCP server")?;
32        Ok(Self { stream })
33    }
34}
35
36#[async_trait]
37impl Connection for TcpConnection {
38    async fn send(&mut self, data: &[u8]) -> Result<()> {
39        self.stream.write_all(data).await?;
40        Ok(())
41    }
42
43    async fn receive(&mut self) -> Result<Option<Vec<u8>>> {
44        // Read 4-byte length prefix (big-endian)
45        let mut len_bytes = [0u8; 4];
46        self.stream.read_exact(&mut len_bytes).await?;
47
48        let length = u32::from_be_bytes(len_bytes) as usize;
49
50        // Read message body
51        let mut body = vec![0u8; length];
52        self.stream.read_exact(&mut body).await?;
53
54        Ok(Some(body))
55    }
56}
57
58pub struct WebSocketConnection {
59    stream: WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
60}
61
62impl WebSocketConnection {
63    pub async fn connect(addr: &str) -> Result<Self> {
64        let url = format!("ws://{}/ws", addr);
65        let (stream, _) = tokio_tungstenite::connect_async(&url)
66            .await
67            .context("Failed to connect to WebSocket server")?;
68        Ok(Self { stream })
69    }
70}
71
72#[async_trait]
73impl Connection for WebSocketConnection {
74    async fn send(&mut self, data: &[u8]) -> Result<()> {
75        self.stream.send(Message::Binary(data.to_vec())).await?;
76        Ok(())
77    }
78
79    async fn receive(&mut self) -> Result<Option<Vec<u8>>> {
80        match self.stream.next().await {
81            Some(Ok(msg)) => match msg {
82                Message::Binary(data) => {
83                    if data.len() < 4 {
84                        return Ok(None);
85                    }
86                    // Extract length and return body
87                    let length = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
88                    if data.len() >= 4 + length {
89                        Ok(Some(data[4..4 + length].to_vec()))
90                    } else {
91                        Ok(Some(data[4..].to_vec()))
92                    }
93                }
94                Message::Text(text) => Ok(Some(text.into_bytes())),
95                Message::Close(_) => Ok(None),
96                _ => Ok(None),
97            },
98            Some(Err(e)) => Err(e.into()),
99            None => Ok(None),
100        }
101    }
102}
103
104pub fn serialize_packet<T: Serialize>(obj: &T) -> Result<Vec<u8>> {
105    let packet = Packet::new(serde_json::to_vec(obj)?);
106    Ok(serde_json::to_vec(&packet)?)
107}