foctet_mux/
session.rs

1use crate::stream::{LogicalStream, StreamEvent, StreamState};
2use anyhow::Result;
3use foctet_core::{
4    codec::FrameCodec, connection::SessionId, frame::{Frame, FrameFlags}, stream::StreamId
5};
6use futures::{SinkExt, StreamExt};
7use nohash_hasher::IntMap;
8use std::{marker::PhantomData, net::SocketAddr};
9use tokio::{
10    io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf},
11    sync::{
12        mpsc::{self, Receiver, Sender},
13        oneshot,
14    },
15    time::Interval,
16};
17use tokio_util::{
18    codec::{FramedRead, FramedWrite},
19    sync::CancellationToken,
20    task::AbortOnDropHandle,
21};
22
/// Which end of the underlying connection a session represents.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum SessionSide {
    /// The session is a client: it initiated the underlying connection.
    Client,
    /// The session is a server (typical low level stream is an accepted TcpStream).
    Server,
}

impl SessionSide {
    /// Returns `true` for the client side, i.e. the end that initiated the
    /// (outbound) connection.
    pub fn is_client(self) -> bool {
        matches!(self, SessionSide::Client)
    }

    /// Returns `true` for the server side, i.e. the end that accepted the
    /// (inbound) connection.
    pub fn is_server(self) -> bool {
        matches!(self, SessionSide::Server)
    }
}
43
44#[derive(Debug)]
45pub enum Command {
46    OpenStream(oneshot::Sender<Result<LogicalStream>>),
47    Shutdown(oneshot::Sender<()>),
48}
49
50pub struct SessionActor<T> {
51    /// Framed low level raw stream writer
52    framed_writer: FramedWrite<WriteHalf<T>, FrameCodec>,
53    /// Framed low level raw stream reader
54    framed_reader: FramedRead<ReadHalf<T>, FrameCodec>,
55    /// Session ID
56    session_id: SessionId,
57    /// next_stream_id is the next stream we should
58    /// send. This depends if we are a client/server.
59    next_stream_id: StreamId,
60    /// remote_closed indicates the remote side does
61    /// not want further connections. Must be first for alignment.
62    remote_closed: bool,
63    /// local_closed indicates that we should stop
64    /// accepting further connections. Must be first for alignment.
65    local_closed: bool,
66    /// pending_streams maps a stream id to a sender of logical-stream.
67    /// waiting for connection_response
68    pending_streams: IntMap<StreamId, oneshot::Sender<Result<LogicalStream>>>,
69    /// streams maps a stream id to a sender of stream,
70    streams: IntMap<StreamId, Sender<Frame>>,
71    /// For receive events from sub streams (for clone to new stream)
72    event_sender: Sender<StreamEvent>,
73    /// For receive events from sub streams
74    event_receiver: Receiver<StreamEvent>,
75    /// Receive control command from session
76    control_receiver: Receiver<Command>,
77    /// Send the new incoming logical-stream to the session
78    stream_sender: Sender<LogicalStream>,
79    /// Keepalive interval
80    keepalive: Option<Interval>,
81    /// Cancel token
82    cancel: CancellationToken,
83}
84
85impl<T> SessionActor<T>
86where
87    T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
88{
89    pub async fn run(mut self) {
90        loop {
91            tokio::select! {
92                _ = self.cancel.cancelled() => {
93                    tracing::info!("SessionActor loop cancelled, closing loop");
94                    break;
95                }
96                Some(frame_result) = self.framed_reader.next() => {
97                    match frame_result {
98                        Ok(frame) => {
99                            if let Err(e) = self.handle_incoming_frame(frame).await {
100                                tracing::error!("Error handling incoming frame: {:?}", e);
101                                break;
102                            }
103                        }
104                        Err(e) => {
105                            tracing::error!("Framed reader error: {:?}", e);
106                            self.remote_closed = true;
107                            break;
108                        }
109                    }
110                }
111                Some(cmd) = self.control_receiver.recv() => {
112                    if let Err(e) = self.handle_control_command(cmd).await {
113                        tracing::error!("Error handling control command: {:?}", e);
114                        break;
115                    }
116                }
117                Some(event) = self.event_receiver.recv() => {
118                    if let Err(e) = self.handle_stream_event(event).await {
119                        tracing::error!("Error handling stream event: {:?}", e);
120                        break;
121                    }
122                }
123            }
124        }
125
126        // Shutdown session
127        self.shutdown().await;
128    }
129
130    async fn handle_incoming_frame(&mut self, frame: Frame) -> Result<(), anyhow::Error> {
131        let stream_id = frame.header.stream_id;
132
133        if let Some(sender) = self.streams.get(&stream_id) {
134            // Send the frame to the logical stream
135            if let Err(e) = sender.send(frame).await {
136                tracing::error!("Failed to send frame to stream {}: {:?}", stream_id.0, e);
137                self.streams.remove(&stream_id);
138            }
139        } else {
140            if frame.header.flags.is_open_request() {
141                // Create and send new LogicalStream
142                let (stream_sender, stream_receiver) = tokio::sync::mpsc::channel(32);
143
144                let logical_stream = LogicalStream::new(
145                    self.session_id, 
146                    stream_id,
147                    StreamState::Established, 
148                    self.event_sender.clone(), 
149                    stream_receiver
150                );
151
152                // Regist new stream to map
153                self.streams.insert(stream_id, stream_sender);
154
155                // Send new LogicalStream to session
156                if let Err(e) = self.stream_sender.send(logical_stream).await {
157                    tracing::error!("Failed to send new stream to session: {:?}", e);
158                    self.streams.remove(&stream_id);
159                }
160
161                // Send open response to remote
162                let open_response_frame = Frame::builder()
163                    .with_stream_id(stream_id)
164                    .with_flags(FrameFlags::open_response())
165                    .build();
166                
167                if let Err(e) = self.framed_writer.send(open_response_frame).await {
168                    tracing::error!("Failed to send open response: {:?}", e);
169                }
170
171                tracing::debug!("New stream accepted: {}", stream_id.0);
172            } else if frame.header.flags.is_open_response() {
173                // Open response received. Send to pending stream.
174                if let Some(sender) = self.pending_streams.remove(&stream_id) {
175                    let (stream_sender, stream_receiver) = tokio::sync::mpsc::channel(32);
176                    // Regist new stream to map
177                    self.streams.insert(stream_id, stream_sender);
178                    // Send new LogicalStream to waiting channel
179                    if let Err(e) = sender.send(Ok(LogicalStream::new(
180                        self.session_id, 
181                        stream_id,
182                        StreamState::Established, 
183                        self.event_sender.clone(), 
184                        stream_receiver
185                    ))) {
186                        tracing::error!("Failed to send new LogicalStream: {:?}", e);
187                    }
188                } else {
189                    tracing::error!("Received open response for unknown stream {}", stream_id.0);
190                }
191            } else if frame.header.flags.is_open_reset() {
192                // Open reset received: stream was rejected by remote
193                if let Some(sender) = self.pending_streams.remove(&stream_id) {
194                    let _ = sender.send(Err(anyhow::anyhow!(
195                        "Stream {} rejected by remote", stream_id.0
196                    )));
197                    tracing::debug!("Stream {} was rejected by remote", stream_id.0);
198                } else {
199                    tracing::warn!("Received open_reset for unknown pending stream {}", stream_id.0);
200                }
201            } else {
202                // Unknown stream and NOT a open request
203                tracing::error!("Received frame for unknown stream {} without open request", stream_id.0);
204                // TODO: Should send RESET?
205            }
206        }
207
208        Ok(())
209    }
210
211    async fn handle_control_command(&mut self, cmd: Command) -> Result<(), anyhow::Error> {
212        match cmd {
213            Command::OpenStream(reply_tx) => {
214                // Get new stream ID
215                let stream_id = self.next_stream_id.fetch_add(1);
216                let (resp_tx, resp_rx) = oneshot::channel();
217
218                // Regist new stream responder to map
219                self.pending_streams.insert(stream_id, resp_tx);
220
221                // Send open request to remote
222                let open_frame = Frame::builder()
223                .with_stream_id(stream_id)
224                .with_flags(FrameFlags::open_request())
225                .build();
226
227                self.framed_writer.send(open_frame).await?;
228
229                // Wait for open response
230                tokio::spawn(async move {
231                    match resp_rx.await {
232                        Ok(Ok(stream)) => {
233                            let _ = reply_tx.send(Ok(stream));
234                        }
235                        Ok(Err(e)) => {
236                            let _ = reply_tx.send(Err(e));
237                        }
238                        Err(_) => {
239                            let _ = reply_tx.send(Err(anyhow::anyhow!("No response received")));
240                        }
241                    }
242                });
243
244                tracing::debug!("New stream opened: {}", stream_id.0);
245            }
246            Command::Shutdown(reply_tx) => {
247                // Set local closed
248                self.local_closed = true;
249                let _ = reply_tx.send(());
250            }
251        }
252
253        Ok(())
254    }
255
256    async fn handle_stream_event(&mut self, event: StreamEvent) -> Result<(), anyhow::Error> {
257        match event {
258            StreamEvent::Frame(frame) => {
259                // Send the frame via RAW stream
260                self.framed_writer.send(frame).await.map_err(|e| {
261                    anyhow::anyhow!("Failed to send frame to writer: {:?}", e)
262                })?;
263            }
264            StreamEvent::Closed(stream_id) => {
265                // logical-stream closed. Remove from map.
266                self.streams.remove(&stream_id);
267                tracing::debug!("Stream {} closed and removed", stream_id);
268            }
269            StreamEvent::Error => {
270                // Error from logical-stream.
271                // Currently omitted. Only log.
272                tracing::warn!("Stream event error received");
273                // TODO: handle error
274            }
275        }
276    
277        Ok(())
278    }
279
280    async fn shutdown(&mut self) {
281        tracing::info!("Session {} shutting down", self.session_id);
282
283        // Close all stream
284        self.streams.clear();
285        tracing::debug!("All logical streams closed");
286
287        // Close framed_writer
288        if let Err(e) = self.framed_writer.close().await {
289            tracing::warn!("Error while closing framed writer: {:?}", e);
290        } else {
291            tracing::info!("Framed writer closed successfully");
292        }
293    }
294
295    pub async fn keepalive_tick(&mut self) {
296        if let Some(_keepalive) = &mut self.keepalive {
297            // TODO: implement keepalive
298        }
299    }
300
301}
302
303/// The session
304pub struct Session<T> {
305    _marker: PhantomData<T>,
306    /// Session ID
307    session_id: SessionId,
308    /// Client or Server
309    side: SessionSide,
310    /// SessionActor handle
311    handle: AbortOnDropHandle<()>,
312    /// Send control command to SessionActor
313    control_sender: Sender<Command>,
314    /// Receive new incoming logical-stream from SessionActor
315    stream_receiver: Receiver<LogicalStream>,
316    /// Cancel token
317    cancel: CancellationToken,
318    /// Local socket address
319    local_addr: Option<SocketAddr>,
320    /// Remote socket address
321    remote_addr: Option<SocketAddr>,
322}
323
324impl<T> Session<T> 
325where
326    T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
327{
328    pub async fn spawn(
329        stream: T,
330        side: SessionSide,
331        session_id: SessionId,
332        local_addr: Option<SocketAddr>,
333        remote_addr: Option<SocketAddr>,
334    ) -> Self {    
335        let (read_half, write_half) = tokio::io::split(stream);
336        let framed_reader = FramedRead::new(read_half, FrameCodec::new());
337        let framed_writer = FramedWrite::new(write_half, FrameCodec::new());
338    
339        let (event_sender, event_receiver) = mpsc::channel(32);
340        let (control_sender, control_receiver) = mpsc::channel(8);
341        let (stream_sender, stream_receiver) = mpsc::channel(32);
342    
343        let next_stream_id = match side {
344            SessionSide::Client => StreamId(1),
345            SessionSide::Server => StreamId(2),
346        };
347
348        let cancel = CancellationToken::new();
349        let cancel_clone = cancel.clone();
350    
351        let actor = SessionActor {
352            framed_reader,
353            framed_writer,
354            session_id,
355            next_stream_id,
356            remote_closed: false,
357            local_closed: false,
358            pending_streams: IntMap::default(),
359            streams: IntMap::default(),
360            event_sender: event_sender.clone(),
361            event_receiver,
362            control_receiver,
363            stream_sender: stream_sender.clone(),
364            keepalive: None,
365            cancel: cancel_clone,
366        };
367    
368        let handle = tokio::spawn(async move {
369            actor.run().await;
370        });
371    
372        let handle = AbortOnDropHandle::new(handle);
373    
374        Session {
375            _marker: PhantomData,
376            session_id,
377            side,
378            handle,
379            control_sender,
380            stream_receiver,
381            cancel,
382            local_addr,
383            remote_addr,
384        }
385    }
386    pub async fn new_client(raw_stream: T, session_id: SessionId) -> Self {
387        Self::spawn(raw_stream, SessionSide::Client, session_id, None, None).await
388    }
389    pub async fn new_server(raw_stream: T, session_id: SessionId) -> Self {
390        Self::spawn(raw_stream, SessionSide::Server, session_id, None, None).await
391    }
392    pub async fn open_stream(&self) -> Result<LogicalStream, anyhow::Error> {
393        let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
394
395        self.control_sender.send(Command::OpenStream(reply_tx)).await.map_err(|e| {
396            anyhow::anyhow!("Failed to send OpenStream command: {:?}", e)
397        })?;
398
399        match reply_rx.await {
400            Ok(Ok(stream)) => Ok(stream),
401            Ok(Err(e)) => Err(anyhow::anyhow!("Stream open failed: {:?}", e)),
402            Err(e) => Err(anyhow::anyhow!("Stream open response failed: {:?}", e)),
403        }
404    }
405
406    pub async fn accept_stream(&mut self) -> Result<LogicalStream, anyhow::Error> {
407        match self.stream_receiver.recv().await {
408            Some(stream) => Ok(stream),
409            None => Err(anyhow::anyhow!("Session closed")),
410        }
411    }
412
413    pub async fn shutdown(&self) -> Result<(), anyhow::Error> {
414        let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
415
416        self.control_sender.send(Command::Shutdown(reply_tx)).await.map_err(|e| {
417            anyhow::anyhow!("Failed to send Shutdown command: {:?}", e)
418        })?;
419
420        let _ = reply_rx.await;
421        self.cancel.cancel();
422        Ok(())
423    }
424
425    pub fn session_id(&self) -> SessionId {
426        self.session_id
427    }
428    pub fn side(&self) -> SessionSide {
429        self.side
430    }
431
432    pub fn is_active(&self) -> bool {
433        !self.handle.is_finished()
434    }
435
436    pub fn set_local_addr(&mut self, addr: SocketAddr) {
437        self.local_addr = Some(addr);
438    }
439
440    pub fn set_remote_addr(&mut self, addr: SocketAddr) {
441        self.remote_addr = Some(addr);
442    }
443
444    pub fn local_addr(&self) -> Option<SocketAddr> {
445        self.local_addr
446    }
447
448    pub fn remote_addr(&self) -> Option<SocketAddr> {
449        self.remote_addr
450    }
451
452}