1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use futures_util::{SinkExt, StreamExt};
4use serde::Serialize;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::TcpStream;
7use tokio_tungstenite::{WebSocketStream, tungstenite::protocol::Message};
8
9use crate::model::Packet;
10
/// Network transport kinds.
///
/// This module provides one connection type per variant: [`TcpConnection`] for
/// [`NetType::Tcp`] and [`WebSocketConnection`] for [`NetType::WebSocket`].
///
/// `PartialEq`/`Eq`/`Hash` are derived so a transport kind can be compared directly
/// (`kind == NetType::Tcp`) or used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NetType {
    /// Raw TCP socket transport.
    Tcp,
    /// WebSocket transport.
    WebSocket,
}
16
17#[async_trait]
18pub trait Connection: Send {
19 async fn send(&mut self, data: &[u8]) -> Result<()>;
20 async fn receive(&mut self) -> Result<Option<Vec<u8>>>;
21}
22
23pub struct TcpConnection {
24 stream: TcpStream,
25}
26
27impl TcpConnection {
28 pub async fn connect(addr: &str) -> Result<Self> {
29 let stream = TcpStream::connect(addr)
30 .await
31 .context("Failed to connect to TCP server")?;
32 Ok(Self { stream })
33 }
34}
35
36#[async_trait]
37impl Connection for TcpConnection {
38 async fn send(&mut self, data: &[u8]) -> Result<()> {
39 self.stream.write_all(data).await?;
40 Ok(())
41 }
42
43 async fn receive(&mut self) -> Result<Option<Vec<u8>>> {
44 let mut len_bytes = [0u8; 4];
46 self.stream.read_exact(&mut len_bytes).await?;
47
48 let length = u32::from_be_bytes(len_bytes) as usize;
49
50 let mut body = vec![0u8; length];
52 self.stream.read_exact(&mut body).await?;
53
54 Ok(Some(body))
55 }
56}
57
58pub struct WebSocketConnection {
59 stream: WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>,
60}
61
62impl WebSocketConnection {
63 pub async fn connect(addr: &str) -> Result<Self> {
64 let url = format!("ws://{}/ws", addr);
65 let (stream, _) = tokio_tungstenite::connect_async(&url)
66 .await
67 .context("Failed to connect to WebSocket server")?;
68 Ok(Self { stream })
69 }
70}
71
72#[async_trait]
73impl Connection for WebSocketConnection {
74 async fn send(&mut self, data: &[u8]) -> Result<()> {
75 self.stream.send(Message::Binary(data.to_vec())).await?;
76 Ok(())
77 }
78
79 async fn receive(&mut self) -> Result<Option<Vec<u8>>> {
80 match self.stream.next().await {
81 Some(Ok(msg)) => match msg {
82 Message::Binary(data) => {
83 if data.len() < 4 {
84 return Ok(None);
85 }
86 let length = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
88 if data.len() >= 4 + length {
89 Ok(Some(data[4..4 + length].to_vec()))
90 } else {
91 Ok(Some(data[4..].to_vec()))
92 }
93 }
94 Message::Text(text) => Ok(Some(text.into_bytes())),
95 Message::Close(_) => Ok(None),
96 _ => Ok(None),
97 },
98 Some(Err(e)) => Err(e.into()),
99 None => Ok(None),
100 }
101 }
102}
103
104pub fn serialize_packet<T: Serialize>(obj: &T) -> Result<Vec<u8>> {
105 let packet = Packet::new(serde_json::to_vec(obj)?);
106 Ok(serde_json::to_vec(&packet)?)
107}