ftth_rsipstack/transport/
stream.rs

1use crate::{
2    transport::{
3        connection::{TransportSender, KEEPALIVE_REQUEST, KEEPALIVE_RESPONSE},
4        SipAddr, SipConnection, TransportEvent,
5    },
6    Result,
7};
8use crate::rsip;
9use bytes::{Buf, BytesMut};
10use rsip::SipMessage;
11use tokio::{
12    io::{AsyncRead, AsyncWrite, AsyncWriteExt},
13    sync::Mutex,
14};
15use tokio_util::codec::{Decoder, Encoder};
16use tracing::{debug, info, warn};
17
/// Size limit, in bytes, used both to size the stream read buffers and to
/// reject buffered data that grows past it without completing a SIP message.
pub(super) const MAX_SIP_MESSAGE_SIZE: usize = 65535;
19
/// Stateless codec for SIP traffic carried over a byte stream.
///
/// The type has no fields; any data still waiting to be framed lives in the
/// buffer the caller passes to the codec, so instances are freely
/// interchangeable.
#[derive(Debug, Clone, Default)]
pub struct SipCodec {}

impl SipCodec {
    /// Creates a new codec. Equivalent to [`SipCodec::default`].
    pub fn new() -> Self {
        Self {}
    }
}
33
/// Item produced by [`SipCodec`] when it decodes data from a stream.
#[derive(Debug, Clone)]
pub enum SipCodecType {
    /// A successfully parsed SIP message.
    Message(SipMessage),
    /// The buffer started with the `KEEPALIVE_REQUEST` byte sequence.
    KeepaliveRequest,
    /// The buffer started with the `KEEPALIVE_RESPONSE` byte sequence.
    KeepaliveResponse,
}
40
41impl std::fmt::Display for SipCodecType {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        match self {
44            SipCodecType::Message(msg) => write!(f, "{}", msg),
45            SipCodecType::KeepaliveRequest => write!(f, "Keepalive Request"),
46            SipCodecType::KeepaliveResponse => write!(f, "Keepalive Response"),
47        }
48    }
49}
50
51impl Decoder for SipCodec {
52    type Item = SipCodecType;
53    type Error = crate::Error;
54
55    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
56        if src.len() >= 4 && &src[0..4] == KEEPALIVE_REQUEST {
57            src.advance(4);
58            return Ok(Some(SipCodecType::KeepaliveRequest));
59        }
60
61        if src.len() >= 2 && &src[0..2] == KEEPALIVE_RESPONSE {
62            src.advance(2);
63            return Ok(Some(SipCodecType::KeepaliveResponse));
64        }
65
66        if let Some(end_pos) = src
67            .windows(KEEPALIVE_REQUEST.len())
68            .position(|window| window == KEEPALIVE_REQUEST)
69        {
70            let msg_end = end_pos + KEEPALIVE_REQUEST.len();
71            let msg_data = &src[..msg_end];
72            match SipMessage::try_from(msg_data) {
73                Ok(msg) => {
74                    src.advance(msg_end);
75                    Ok(Some(SipCodecType::Message(msg)))
76                }
77                Err(e) => {
78                    src.advance(msg_end);
79                    Err(crate::Error::Error(format!(
80                        "Failed to parse SIP message: {}",
81                        e
82                    )))
83                }
84            }
85        } else {
86            if src.len() > MAX_SIP_MESSAGE_SIZE {
87                return Err(crate::Error::Error("SIP message too large".to_string()));
88            }
89            Ok(None)
90        }
91    }
92}
93
94impl Encoder<SipMessage> for SipCodec {
95    type Error = crate::Error;
96
97    fn encode(&mut self, item: SipMessage, dst: &mut BytesMut) -> Result<()> {
98        let data = item.to_string();
99        dst.extend_from_slice(data.as_bytes());
100        Ok(())
101    }
102}
103
/// Shared state of one stream-based connection: the two endpoint addresses
/// plus the read and write halves of the underlying stream.
pub struct StreamConnectionInner<R, W>
where
    R: AsyncRead + Unpin + Send,
    W: AsyncWrite + Unpin + Send,
{
    /// Address recorded as the local side of the connection.
    pub local_addr: SipAddr,
    /// Address of the peer; `serve_loop` attaches it to every incoming message.
    pub remote_addr: SipAddr,
    /// Read half, wrapped in `Option` so `serve_loop` can take ownership of it
    /// once; later calls find `None`.
    pub read_half: Mutex<Option<R>>,
    /// Write half; the mutex is held for each write/flush and for shutdown.
    pub write_half: Mutex<W>,
}
114
115impl<R, W> StreamConnectionInner<R, W>
116where
117    R: AsyncRead + Unpin + Send,
118    W: AsyncWrite + Unpin + Send,
119{
120    pub fn new(local_addr: SipAddr, remote_addr: SipAddr, read_half: R, write_half: W) -> Self {
121        Self {
122            local_addr,
123            remote_addr,
124            read_half: Mutex::new(Some(read_half)),
125            write_half: Mutex::new(write_half),
126        }
127    }
128
129    pub async fn send_message(&self, msg: SipMessage) -> Result<()> {
130        send_to_stream(&self.write_half, msg).await
131    }
132
133    pub async fn send_raw(&self, data: &[u8]) -> Result<()> {
134        send_raw_to_stream(&self.write_half, data).await
135    }
136
137    pub async fn serve_loop(
138        &self,
139        sender: TransportSender,
140        connection: SipConnection,
141    ) -> Result<()> {
142        let mut read_half = match self.read_half.lock().await.take() {
143            Some(read_half) => read_half,
144            None => {
145                warn!("Connection closed");
146                return Ok(());
147            }
148        };
149
150        let remote_addr = self.remote_addr.clone();
151
152        let mut codec = SipCodec::new();
153        let mut buffer = BytesMut::with_capacity(MAX_SIP_MESSAGE_SIZE);
154        let mut read_buf = [0u8; MAX_SIP_MESSAGE_SIZE];
155
156        loop {
157            use tokio::io::AsyncReadExt;
158            match read_half.read(&mut read_buf).await {
159                Ok(0) => {
160                    info!("Connection closed: {}", self.local_addr);
161                    break;
162                }
163                Ok(n) => {
164                    buffer.extend_from_slice(&read_buf[0..n]);
165
166                    loop {
167                        match codec.decode(&mut buffer) {
168                            Ok(Some(msg)) => match msg {
169                                SipCodecType::Message(sip_msg) => {
170                                    debug!("Received message from {}: {}", remote_addr, sip_msg);
171                                    let remote_socket_addr = remote_addr.get_socketaddr()?;
172                                    let sip_msg = SipConnection::update_msg_received(
173                                        sip_msg,
174                                        remote_socket_addr,
175                                        remote_addr.r#type.unwrap_or_default(),
176                                    )?;
177
178                                    if let Err(e) = sender.send(TransportEvent::Incoming(
179                                        sip_msg,
180                                        connection.clone(),
181                                        remote_addr.clone(),
182                                    )) {
183                                        warn!("Error sending incoming message: {:?}", e);
184                                        return Err(e.into());
185                                    }
186                                }
187                                SipCodecType::KeepaliveRequest => {
188                                    self.send_raw(KEEPALIVE_RESPONSE).await?;
189                                }
190                                SipCodecType::KeepaliveResponse => {}
191                            },
192                            Ok(None) => {
193                                // Need more data
194                                break;
195                            }
196                            Err(e) => {
197                                warn!("Error decoding message from {}: {:?}", remote_addr, e);
198                                // Continue processing despite decode errors
199                            }
200                        }
201                    }
202                }
203                Err(e) => {
204                    warn!("Error reading from stream: {}", e);
205                    break;
206                }
207            }
208        }
209        Ok(())
210    }
211
212    pub async fn close(&self) -> Result<()> {
213        let mut write_half = self.write_half.lock().await;
214        write_half
215            .shutdown()
216            .await
217            .map_err(|e| crate::Error::Error(format!("Failed to shutdown write half: {}", e)))?;
218        Ok(())
219    }
220}
221
/// Common interface of stream-based SIP connections.
#[async_trait::async_trait]
pub trait StreamConnection: Send + Sync + 'static {
    /// Returns an address associated with this connection.
    // NOTE(review): whether this is the local or the remote address is up to
    // each implementor; confirm against the implementations.
    fn get_addr(&self) -> &SipAddr;
    /// Sends a SIP message over the connection.
    async fn send_message(&self, msg: SipMessage) -> Result<()>;
    /// Sends raw bytes over the connection.
    async fn send_raw(&self, data: &[u8]) -> Result<()>;
    /// Runs the connection's receive loop, reporting events through `sender`.
    async fn serve_loop(&self, sender: TransportSender) -> Result<()>;
    /// Closes the connection.
    async fn close(&self) -> Result<()>;
}
230
231pub async fn send_to_stream<W>(write_half: &Mutex<W>, msg: SipMessage) -> Result<()>
232where
233    W: AsyncWrite + Unpin + Send,
234{
235    send_raw_to_stream(write_half, msg.to_string().as_bytes()).await
236}
237
238pub async fn send_raw_to_stream<W>(write_half: &Mutex<W>, data: &[u8]) -> Result<()>
239where
240    W: AsyncWrite + Unpin + Send,
241{
242    let mut lock = write_half.lock().await;
243    lock.write_all(data).await?;
244    lock.flush().await?;
245    Ok(())
246}