ftth_rsipstack/transport/
stream.rs1use crate::rsip;
2use crate::{
3 transport::{
4 connection::{TransportSender, KEEPALIVE_REQUEST, KEEPALIVE_RESPONSE},
5 SipAddr, SipConnection, TransportEvent,
6 },
7 Result,
8};
9use bytes::{Buf, BytesMut};
10use rsip::SipMessage;
11use tokio::{
12 io::{AsyncRead, AsyncWrite, AsyncWriteExt},
13 sync::Mutex,
14};
15use tokio_util::codec::{Decoder, Encoder};
16use tracing::{debug, info, warn};
17
18pub(super) const MAX_SIP_MESSAGE_SIZE: usize = 65535;
19
/// Codec that turns a byte stream into SIP frames and SIP messages back into bytes.
///
/// It carries no state of its own (the struct has no fields).
#[derive(Default)]
pub struct SipCodec {}

impl SipCodec {
    /// Creates a codec; identical to `SipCodec::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
33
34#[derive(Debug, Clone)]
35pub enum SipCodecType {
36 Message(SipMessage),
37 KeepaliveRequest,
38 KeepaliveResponse,
39}
40
41impl std::fmt::Display for SipCodecType {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 match self {
44 SipCodecType::Message(msg) => write!(f, "{}", msg),
45 SipCodecType::KeepaliveRequest => write!(f, "Keepalive Request"),
46 SipCodecType::KeepaliveResponse => write!(f, "Keepalive Response"),
47 }
48 }
49}
50
51impl Decoder for SipCodec {
52 type Item = SipCodecType;
53 type Error = crate::Error;
54
55 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
56 if src.len() >= 4 && &src[0..4] == KEEPALIVE_REQUEST {
57 src.advance(4);
58 return Ok(Some(SipCodecType::KeepaliveRequest));
59 }
60
61 if src.len() >= 2 && &src[0..2] == KEEPALIVE_RESPONSE {
62 src.advance(2);
63 return Ok(Some(SipCodecType::KeepaliveResponse));
64 }
65
66 if let Some(end_pos) = src
67 .windows(KEEPALIVE_REQUEST.len())
68 .position(|window| window == KEEPALIVE_REQUEST)
69 {
70 let msg_end = end_pos + KEEPALIVE_REQUEST.len();
71 let msg_data = &src[..msg_end];
72 match SipMessage::try_from(msg_data) {
73 Ok(msg) => {
74 src.advance(msg_end);
75 Ok(Some(SipCodecType::Message(msg)))
76 }
77 Err(e) => {
78 src.advance(msg_end);
79 Err(crate::Error::Error(format!(
80 "Failed to parse SIP message: {}",
81 e
82 )))
83 }
84 }
85 } else {
86 if src.len() > MAX_SIP_MESSAGE_SIZE {
87 return Err(crate::Error::Error("SIP message too large".to_string()));
88 }
89 Ok(None)
90 }
91 }
92}
93
94impl Encoder<SipMessage> for SipCodec {
95 type Error = crate::Error;
96
97 fn encode(&mut self, item: SipMessage, dst: &mut BytesMut) -> Result<()> {
98 let data = item.to_string();
99 dst.extend_from_slice(data.as_bytes());
100 Ok(())
101 }
102}
103
104pub struct StreamConnectionInner<R, W>
105where
106 R: AsyncRead + Unpin + Send,
107 W: AsyncWrite + Unpin + Send,
108{
109 pub local_addr: SipAddr,
110 pub remote_addr: SipAddr,
111 pub read_half: Mutex<Option<R>>,
112 pub write_half: Mutex<W>,
113}
114
115impl<R, W> StreamConnectionInner<R, W>
116where
117 R: AsyncRead + Unpin + Send,
118 W: AsyncWrite + Unpin + Send,
119{
120 pub fn new(local_addr: SipAddr, remote_addr: SipAddr, read_half: R, write_half: W) -> Self {
121 Self {
122 local_addr,
123 remote_addr,
124 read_half: Mutex::new(Some(read_half)),
125 write_half: Mutex::new(write_half),
126 }
127 }
128
129 pub async fn send_message(&self, msg: SipMessage) -> Result<()> {
130 send_to_stream(&self.write_half, msg).await
131 }
132
133 pub async fn send_raw(&self, data: &[u8]) -> Result<()> {
134 send_raw_to_stream(&self.write_half, data).await
135 }
136
137 pub async fn serve_loop(
138 &self,
139 sender: TransportSender,
140 connection: SipConnection,
141 ) -> Result<()> {
142 let mut read_half = match self.read_half.lock().await.take() {
143 Some(read_half) => read_half,
144 None => {
145 warn!("Connection closed");
146 return Ok(());
147 }
148 };
149
150 let remote_addr = self.remote_addr.clone();
151
152 let mut codec = SipCodec::new();
153 let mut buffer = BytesMut::with_capacity(MAX_SIP_MESSAGE_SIZE);
154 let mut read_buf = [0u8; MAX_SIP_MESSAGE_SIZE];
155
156 loop {
157 use tokio::io::AsyncReadExt;
158 match read_half.read(&mut read_buf).await {
159 Ok(0) => {
160 info!("Connection closed: {}", self.local_addr);
161 break;
162 }
163 Ok(n) => {
164 buffer.extend_from_slice(&read_buf[0..n]);
165
166 loop {
167 match codec.decode(&mut buffer) {
168 Ok(Some(msg)) => match msg {
169 SipCodecType::Message(sip_msg) => {
170 debug!("Received message from {}: {}", remote_addr, sip_msg);
171 let remote_socket_addr = remote_addr.get_socketaddr()?;
172 let sip_msg = SipConnection::update_msg_received(
173 sip_msg,
174 remote_socket_addr,
175 remote_addr.r#type.unwrap_or_default(),
176 )?;
177
178 if let Err(e) = sender.send(TransportEvent::Incoming(
179 sip_msg,
180 connection.clone(),
181 remote_addr.clone(),
182 )) {
183 warn!("Error sending incoming message: {:?}", e);
184 return Err(e.into());
185 }
186 }
187 SipCodecType::KeepaliveRequest => {
188 self.send_raw(KEEPALIVE_RESPONSE).await?;
189 }
190 SipCodecType::KeepaliveResponse => {}
191 },
192 Ok(None) => {
193 break;
195 }
196 Err(e) => {
197 warn!("Error decoding message from {}: {:?}", remote_addr, e);
198 }
200 }
201 }
202 }
203 Err(e) => {
204 warn!("Error reading from stream: {}", e);
205 break;
206 }
207 }
208 }
209 Ok(())
210 }
211
212 pub async fn close(&self) -> Result<()> {
213 let mut write_half = self.write_half.lock().await;
214 write_half
215 .shutdown()
216 .await
217 .map_err(|e| crate::Error::Error(format!("Failed to shutdown write half: {}", e)))?;
218 Ok(())
219 }
220}
221
222#[async_trait::async_trait]
223pub trait StreamConnection: Send + Sync + 'static {
224 fn get_addr(&self) -> &SipAddr;
225 async fn send_message(&self, msg: SipMessage) -> Result<()>;
226 async fn send_raw(&self, data: &[u8]) -> Result<()>;
227 async fn serve_loop(&self, sender: TransportSender) -> Result<()>;
228 async fn close(&self) -> Result<()>;
229}
230
231pub async fn send_to_stream<W>(write_half: &Mutex<W>, msg: SipMessage) -> Result<()>
232where
233 W: AsyncWrite + Unpin + Send,
234{
235 send_raw_to_stream(write_half, msg.to_string().as_bytes()).await
236}
237
238pub async fn send_raw_to_stream<W>(write_half: &Mutex<W>, data: &[u8]) -> Result<()>
239where
240 W: AsyncWrite + Unpin + Send,
241{
242 let mut lock = write_half.lock().await;
243 lock.write_all(data).await?;
244 lock.flush().await?;
245 Ok(())
246}