ftth_rsipstack/transport/
stream.rs

1use crate::{
2    transport::{
3        connection::{TransportSender, KEEPALIVE_REQUEST, KEEPALIVE_RESPONSE},
4        SipAddr, SipConnection, TransportEvent,
5    },
6    Result,
7};
8use bytes::{Buf, BytesMut};
9use rsip::SipMessage;
10use tokio::{
11    io::{AsyncRead, AsyncWrite, AsyncWriteExt},
12    sync::Mutex,
13};
14use tokio_util::codec::{Decoder, Encoder};
15use tracing::{debug, info, warn};
16
17pub(super) const MAX_SIP_MESSAGE_SIZE: usize = 65535;
18
/// Stateless codec type; the framing logic lives in its `Decoder` and
/// `Encoder` implementations.
pub struct SipCodec {}

impl SipCodec {
    /// Creates a new codec. The type has no fields, so this cannot fail and
    /// allocates nothing.
    pub fn new() -> Self {
        SipCodec {}
    }
}

impl Default for SipCodec {
    /// Equivalent to [`SipCodec::new`].
    fn default() -> Self {
        SipCodec {}
    }
}
32
33#[derive(Debug, Clone)]
34pub enum SipCodecType {
35    Message(SipMessage),
36    KeepaliveRequest,
37    KeepaliveResponse,
38}
39
40impl std::fmt::Display for SipCodecType {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        match self {
43            SipCodecType::Message(msg) => write!(f, "{}", msg),
44            SipCodecType::KeepaliveRequest => write!(f, "Keepalive Request"),
45            SipCodecType::KeepaliveResponse => write!(f, "Keepalive Response"),
46        }
47    }
48}
49
50impl Decoder for SipCodec {
51    type Item = SipCodecType;
52    type Error = crate::Error;
53
54    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
55        if src.len() >= 4 && &src[0..4] == KEEPALIVE_REQUEST {
56            src.advance(4);
57            return Ok(Some(SipCodecType::KeepaliveRequest));
58        }
59
60        if src.len() >= 2 && &src[0..2] == KEEPALIVE_RESPONSE {
61            src.advance(2);
62            return Ok(Some(SipCodecType::KeepaliveResponse));
63        }
64
65        if let Some(end_pos) = src
66            .windows(KEEPALIVE_REQUEST.len())
67            .position(|window| window == KEEPALIVE_REQUEST)
68        {
69            let msg_end = end_pos + KEEPALIVE_REQUEST.len();
70            let msg_data = &src[..msg_end];
71            match SipMessage::try_from(msg_data) {
72                Ok(msg) => {
73                    src.advance(msg_end);
74                    Ok(Some(SipCodecType::Message(msg)))
75                }
76                Err(e) => {
77                    src.advance(msg_end);
78                    Err(crate::Error::Error(format!(
79                        "Failed to parse SIP message: {}",
80                        e
81                    )))
82                }
83            }
84        } else {
85            if src.len() > MAX_SIP_MESSAGE_SIZE {
86                return Err(crate::Error::Error("SIP message too large".to_string()));
87            }
88            Ok(None)
89        }
90    }
91}
92
93impl Encoder<SipMessage> for SipCodec {
94    type Error = crate::Error;
95
96    fn encode(&mut self, item: SipMessage, dst: &mut BytesMut) -> Result<()> {
97        let data = item.to_string();
98        dst.extend_from_slice(data.as_bytes());
99        Ok(())
100    }
101}
102
103pub struct StreamConnectionInner<R, W>
104where
105    R: AsyncRead + Unpin + Send,
106    W: AsyncWrite + Unpin + Send,
107{
108    pub local_addr: SipAddr,
109    pub remote_addr: SipAddr,
110    pub read_half: Mutex<Option<R>>,
111    pub write_half: Mutex<W>,
112}
113
114impl<R, W> StreamConnectionInner<R, W>
115where
116    R: AsyncRead + Unpin + Send,
117    W: AsyncWrite + Unpin + Send,
118{
119    pub fn new(local_addr: SipAddr, remote_addr: SipAddr, read_half: R, write_half: W) -> Self {
120        Self {
121            local_addr,
122            remote_addr,
123            read_half: Mutex::new(Some(read_half)),
124            write_half: Mutex::new(write_half),
125        }
126    }
127
128    pub async fn send_message(&self, msg: SipMessage) -> Result<()> {
129        send_to_stream(&self.write_half, msg).await
130    }
131
132    pub async fn send_raw(&self, data: &[u8]) -> Result<()> {
133        send_raw_to_stream(&self.write_half, data).await
134    }
135
136    pub async fn serve_loop(
137        &self,
138        sender: TransportSender,
139        connection: SipConnection,
140    ) -> Result<()> {
141        let mut read_half = match self.read_half.lock().await.take() {
142            Some(read_half) => read_half,
143            None => {
144                warn!("Connection closed");
145                return Ok(());
146            }
147        };
148
149        let remote_addr = self.remote_addr.clone();
150
151        let mut codec = SipCodec::new();
152        let mut buffer = BytesMut::with_capacity(MAX_SIP_MESSAGE_SIZE);
153        let mut read_buf = [0u8; MAX_SIP_MESSAGE_SIZE];
154
155        loop {
156            use tokio::io::AsyncReadExt;
157            match read_half.read(&mut read_buf).await {
158                Ok(0) => {
159                    info!("Connection closed: {}", self.local_addr);
160                    break;
161                }
162                Ok(n) => {
163                    buffer.extend_from_slice(&read_buf[0..n]);
164
165                    loop {
166                        match codec.decode(&mut buffer) {
167                            Ok(Some(msg)) => match msg {
168                                SipCodecType::Message(sip_msg) => {
169                                    debug!("Received message from {}: {}", remote_addr, sip_msg);
170                                    let remote_socket_addr = remote_addr.get_socketaddr()?;
171                                    let sip_msg = SipConnection::update_msg_received(
172                                        sip_msg,
173                                        remote_socket_addr,
174                                        remote_addr.r#type.unwrap_or_default(),
175                                    )?;
176
177                                    if let Err(e) = sender.send(TransportEvent::Incoming(
178                                        sip_msg,
179                                        connection.clone(),
180                                        remote_addr.clone(),
181                                    )) {
182                                        warn!("Error sending incoming message: {:?}", e);
183                                        return Err(e.into());
184                                    }
185                                }
186                                SipCodecType::KeepaliveRequest => {
187                                    self.send_raw(KEEPALIVE_RESPONSE).await?;
188                                }
189                                SipCodecType::KeepaliveResponse => {}
190                            },
191                            Ok(None) => {
192                                // Need more data
193                                break;
194                            }
195                            Err(e) => {
196                                warn!("Error decoding message from {}: {:?}", remote_addr, e);
197                                // Continue processing despite decode errors
198                            }
199                        }
200                    }
201                }
202                Err(e) => {
203                    warn!("Error reading from stream: {}", e);
204                    break;
205                }
206            }
207        }
208        Ok(())
209    }
210
211    pub async fn close(&self) -> Result<()> {
212        let mut write_half = self.write_half.lock().await;
213        write_half
214            .shutdown()
215            .await
216            .map_err(|e| crate::Error::Error(format!("Failed to shutdown write half: {}", e)))?;
217        Ok(())
218    }
219}
220
221#[async_trait::async_trait]
222pub trait StreamConnection: Send + Sync + 'static {
223    fn get_addr(&self) -> &SipAddr;
224    async fn send_message(&self, msg: SipMessage) -> Result<()>;
225    async fn send_raw(&self, data: &[u8]) -> Result<()>;
226    async fn serve_loop(&self, sender: TransportSender) -> Result<()>;
227    async fn close(&self) -> Result<()>;
228}
229
230pub async fn send_to_stream<W>(write_half: &Mutex<W>, msg: SipMessage) -> Result<()>
231where
232    W: AsyncWrite + Unpin + Send,
233{
234    send_raw_to_stream(write_half, msg.to_string().as_bytes()).await
235}
236
237pub async fn send_raw_to_stream<W>(write_half: &Mutex<W>, data: &[u8]) -> Result<()>
238where
239    W: AsyncWrite + Unpin + Send,
240{
241    let mut lock = write_half.lock().await;
242    lock.write_all(data).await?;
243    lock.flush().await?;
244    Ok(())
245}