1use crate::stream::{LogicalStream, StreamEvent, StreamState};
2use anyhow::Result;
3use foctet_core::{
4 codec::FrameCodec, connection::SessionId, frame::{Frame, FrameFlags}, stream::StreamId
5};
6use futures::{SinkExt, StreamExt};
7use nohash_hasher::IntMap;
8use std::{marker::PhantomData, net::SocketAddr};
9use tokio::{
10 io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf},
11 sync::{
12 mpsc::{self, Receiver, Sender},
13 oneshot,
14 },
15 time::Interval,
16};
17use tokio_util::{
18 codec::{FramedRead, FramedWrite},
19 sync::CancellationToken,
20 task::AbortOnDropHandle,
21};
22
/// Which end of the connection a session represents.
///
/// Variant order matters for the derived `PartialOrd`/`Ord`: `Client` sorts before `Server`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum SessionSide {
    /// Client end; `Session::spawn` starts this side's locally opened stream IDs at 1.
    Client,
    /// Server end; `Session::spawn` starts this side's locally opened stream IDs at 2.
    Server,
}
31
32impl SessionSide {
33 pub fn is_client(self) -> bool {
35 self == SessionSide::Client
36 }
37
38 pub fn is_server(self) -> bool {
40 self == SessionSide::Server
41 }
42}
43
44#[derive(Debug)]
45pub enum Command {
46 OpenStream(oneshot::Sender<Result<LogicalStream>>),
47 Shutdown(oneshot::Sender<()>),
48}
49
/// Task-side state of a session: owns both halves of the framed transport and routes frames
/// between the wire and the session's logical streams.
///
/// Driven by `run`, which is spawned by `Session::spawn`.
pub struct SessionActor<T> {
    /// Outbound half of the transport; every frame the actor emits is written here.
    framed_writer: FramedWrite<WriteHalf<T>, FrameCodec>,
    /// Inbound half of the transport, polled by the `run` loop.
    framed_reader: FramedRead<ReadHalf<T>, FrameCodec>,
    /// Identifier handed to every `LogicalStream` created by this actor and used in logs.
    session_id: SessionId,
    /// ID used for the next locally opened stream (initialised to 1 for clients and 2 for
    /// servers in `Session::spawn`).
    next_stream_id: StreamId,
    /// Set when the framed reader yields an error; not read by any code shown here.
    remote_closed: bool,
    /// Set when a `Command::Shutdown` is handled; not read by any code shown here.
    local_closed: bool,
    /// Locally initiated opens waiting for the peer's open-response or open-reset frame,
    /// keyed by stream ID; each sender resolves the corresponding open attempt.
    pending_streams: IntMap<StreamId, oneshot::Sender<Result<LogicalStream>>>,
    /// Established streams: the sender that forwards inbound frames to each `LogicalStream`.
    streams: IntMap<StreamId, Sender<Frame>>,
    /// Cloned into every `LogicalStream`; events received on the paired receiver are
    /// processed by `handle_stream_event`. Because the actor keeps this clone, the event
    /// channel cannot close while the actor is alive.
    event_sender: Sender<StreamEvent>,
    /// Receives `StreamEvent`s from logical streams.
    event_receiver: Receiver<StreamEvent>,
    /// Receives `Command`s from the owning `Session`.
    control_receiver: Receiver<Command>,
    /// Delivers peer-initiated streams to `Session::accept_stream`.
    stream_sender: Sender<LogicalStream>,
    /// Always `None` when constructed by `Session::spawn`; only referenced by the no-op
    /// `keepalive_tick`.
    keepalive: Option<Interval>,
    /// Shared with the owning `Session`; cancelling it ends the `run` loop.
    cancel: CancellationToken,
}
84
85impl<T> SessionActor<T>
86where
87 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
88{
89 pub async fn run(mut self) {
90 loop {
91 tokio::select! {
92 _ = self.cancel.cancelled() => {
93 tracing::info!("SessionActor loop cancelled, closing loop");
94 break;
95 }
96 Some(frame_result) = self.framed_reader.next() => {
97 match frame_result {
98 Ok(frame) => {
99 if let Err(e) = self.handle_incoming_frame(frame).await {
100 tracing::error!("Error handling incoming frame: {:?}", e);
101 break;
102 }
103 }
104 Err(e) => {
105 tracing::error!("Framed reader error: {:?}", e);
106 self.remote_closed = true;
107 break;
108 }
109 }
110 }
111 Some(cmd) = self.control_receiver.recv() => {
112 if let Err(e) = self.handle_control_command(cmd).await {
113 tracing::error!("Error handling control command: {:?}", e);
114 break;
115 }
116 }
117 Some(event) = self.event_receiver.recv() => {
118 if let Err(e) = self.handle_stream_event(event).await {
119 tracing::error!("Error handling stream event: {:?}", e);
120 break;
121 }
122 }
123 }
124 }
125
126 self.shutdown().await;
128 }
129
130 async fn handle_incoming_frame(&mut self, frame: Frame) -> Result<(), anyhow::Error> {
131 let stream_id = frame.header.stream_id;
132
133 if let Some(sender) = self.streams.get(&stream_id) {
134 if let Err(e) = sender.send(frame).await {
136 tracing::error!("Failed to send frame to stream {}: {:?}", stream_id.0, e);
137 self.streams.remove(&stream_id);
138 }
139 } else {
140 if frame.header.flags.is_open_request() {
141 let (stream_sender, stream_receiver) = tokio::sync::mpsc::channel(32);
143
144 let logical_stream = LogicalStream::new(
145 self.session_id,
146 stream_id,
147 StreamState::Established,
148 self.event_sender.clone(),
149 stream_receiver
150 );
151
152 self.streams.insert(stream_id, stream_sender);
154
155 if let Err(e) = self.stream_sender.send(logical_stream).await {
157 tracing::error!("Failed to send new stream to session: {:?}", e);
158 self.streams.remove(&stream_id);
159 }
160
161 let open_response_frame = Frame::builder()
163 .with_stream_id(stream_id)
164 .with_flags(FrameFlags::open_response())
165 .build();
166
167 if let Err(e) = self.framed_writer.send(open_response_frame).await {
168 tracing::error!("Failed to send open response: {:?}", e);
169 }
170
171 tracing::debug!("New stream accepted: {}", stream_id.0);
172 } else if frame.header.flags.is_open_response() {
173 if let Some(sender) = self.pending_streams.remove(&stream_id) {
175 let (stream_sender, stream_receiver) = tokio::sync::mpsc::channel(32);
176 self.streams.insert(stream_id, stream_sender);
178 if let Err(e) = sender.send(Ok(LogicalStream::new(
180 self.session_id,
181 stream_id,
182 StreamState::Established,
183 self.event_sender.clone(),
184 stream_receiver
185 ))) {
186 tracing::error!("Failed to send new LogicalStream: {:?}", e);
187 }
188 } else {
189 tracing::error!("Received open response for unknown stream {}", stream_id.0);
190 }
191 } else if frame.header.flags.is_open_reset() {
192 if let Some(sender) = self.pending_streams.remove(&stream_id) {
194 let _ = sender.send(Err(anyhow::anyhow!(
195 "Stream {} rejected by remote", stream_id.0
196 )));
197 tracing::debug!("Stream {} was rejected by remote", stream_id.0);
198 } else {
199 tracing::warn!("Received open_reset for unknown pending stream {}", stream_id.0);
200 }
201 } else {
202 tracing::error!("Received frame for unknown stream {} without open request", stream_id.0);
204 }
206 }
207
208 Ok(())
209 }
210
211 async fn handle_control_command(&mut self, cmd: Command) -> Result<(), anyhow::Error> {
212 match cmd {
213 Command::OpenStream(reply_tx) => {
214 let stream_id = self.next_stream_id.fetch_add(1);
216 let (resp_tx, resp_rx) = oneshot::channel();
217
218 self.pending_streams.insert(stream_id, resp_tx);
220
221 let open_frame = Frame::builder()
223 .with_stream_id(stream_id)
224 .with_flags(FrameFlags::open_request())
225 .build();
226
227 self.framed_writer.send(open_frame).await?;
228
229 tokio::spawn(async move {
231 match resp_rx.await {
232 Ok(Ok(stream)) => {
233 let _ = reply_tx.send(Ok(stream));
234 }
235 Ok(Err(e)) => {
236 let _ = reply_tx.send(Err(e));
237 }
238 Err(_) => {
239 let _ = reply_tx.send(Err(anyhow::anyhow!("No response received")));
240 }
241 }
242 });
243
244 tracing::debug!("New stream opened: {}", stream_id.0);
245 }
246 Command::Shutdown(reply_tx) => {
247 self.local_closed = true;
249 let _ = reply_tx.send(());
250 }
251 }
252
253 Ok(())
254 }
255
256 async fn handle_stream_event(&mut self, event: StreamEvent) -> Result<(), anyhow::Error> {
257 match event {
258 StreamEvent::Frame(frame) => {
259 self.framed_writer.send(frame).await.map_err(|e| {
261 anyhow::anyhow!("Failed to send frame to writer: {:?}", e)
262 })?;
263 }
264 StreamEvent::Closed(stream_id) => {
265 self.streams.remove(&stream_id);
267 tracing::debug!("Stream {} closed and removed", stream_id);
268 }
269 StreamEvent::Error => {
270 tracing::warn!("Stream event error received");
273 }
275 }
276
277 Ok(())
278 }
279
280 async fn shutdown(&mut self) {
281 tracing::info!("Session {} shutting down", self.session_id);
282
283 self.streams.clear();
285 tracing::debug!("All logical streams closed");
286
287 if let Err(e) = self.framed_writer.close().await {
289 tracing::warn!("Error while closing framed writer: {:?}", e);
290 } else {
291 tracing::info!("Framed writer closed successfully");
292 }
293 }
294
295 pub async fn keepalive_tick(&mut self) {
296 if let Some(_keepalive) = &mut self.keepalive {
297 }
299 }
300
301}
302
/// User-facing handle to a multiplexed session.
///
/// The transport itself is owned by a spawned `SessionActor`; this handle talks to the actor
/// through channels. Dropping the handle drops `handle`, which aborts the actor task.
pub struct Session<T> {
    /// Ties the transport type `T` to the handle; the transport itself lives in the actor.
    _marker: PhantomData<T>,
    /// Identifier of this session.
    session_id: SessionId,
    /// Which end of the connection this session is.
    side: SessionSide,
    /// The spawned actor task; aborted when this handle is dropped.
    handle: AbortOnDropHandle<()>,
    /// Sends `Command`s to the actor.
    control_sender: Sender<Command>,
    /// Yields peer-initiated streams; drained by `accept_stream`.
    stream_receiver: Receiver<LogicalStream>,
    /// Shared with the actor; cancelled by `shutdown` to end the actor's loop.
    cancel: CancellationToken,
    /// Optional local address, set at spawn time or via `set_local_addr`; informational only.
    local_addr: Option<SocketAddr>,
    /// Optional remote address, set at spawn time or via `set_remote_addr`; informational only.
    remote_addr: Option<SocketAddr>,
}
323
324impl<T> Session<T>
325where
326 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
327{
328 pub async fn spawn(
329 stream: T,
330 side: SessionSide,
331 session_id: SessionId,
332 local_addr: Option<SocketAddr>,
333 remote_addr: Option<SocketAddr>,
334 ) -> Self {
335 let (read_half, write_half) = tokio::io::split(stream);
336 let framed_reader = FramedRead::new(read_half, FrameCodec::new());
337 let framed_writer = FramedWrite::new(write_half, FrameCodec::new());
338
339 let (event_sender, event_receiver) = mpsc::channel(32);
340 let (control_sender, control_receiver) = mpsc::channel(8);
341 let (stream_sender, stream_receiver) = mpsc::channel(32);
342
343 let next_stream_id = match side {
344 SessionSide::Client => StreamId(1),
345 SessionSide::Server => StreamId(2),
346 };
347
348 let cancel = CancellationToken::new();
349 let cancel_clone = cancel.clone();
350
351 let actor = SessionActor {
352 framed_reader,
353 framed_writer,
354 session_id,
355 next_stream_id,
356 remote_closed: false,
357 local_closed: false,
358 pending_streams: IntMap::default(),
359 streams: IntMap::default(),
360 event_sender: event_sender.clone(),
361 event_receiver,
362 control_receiver,
363 stream_sender: stream_sender.clone(),
364 keepalive: None,
365 cancel: cancel_clone,
366 };
367
368 let handle = tokio::spawn(async move {
369 actor.run().await;
370 });
371
372 let handle = AbortOnDropHandle::new(handle);
373
374 Session {
375 _marker: PhantomData,
376 session_id,
377 side,
378 handle,
379 control_sender,
380 stream_receiver,
381 cancel,
382 local_addr,
383 remote_addr,
384 }
385 }
386 pub async fn new_client(raw_stream: T, session_id: SessionId) -> Self {
387 Self::spawn(raw_stream, SessionSide::Client, session_id, None, None).await
388 }
389 pub async fn new_server(raw_stream: T, session_id: SessionId) -> Self {
390 Self::spawn(raw_stream, SessionSide::Server, session_id, None, None).await
391 }
392 pub async fn open_stream(&self) -> Result<LogicalStream, anyhow::Error> {
393 let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
394
395 self.control_sender.send(Command::OpenStream(reply_tx)).await.map_err(|e| {
396 anyhow::anyhow!("Failed to send OpenStream command: {:?}", e)
397 })?;
398
399 match reply_rx.await {
400 Ok(Ok(stream)) => Ok(stream),
401 Ok(Err(e)) => Err(anyhow::anyhow!("Stream open failed: {:?}", e)),
402 Err(e) => Err(anyhow::anyhow!("Stream open response failed: {:?}", e)),
403 }
404 }
405
406 pub async fn accept_stream(&mut self) -> Result<LogicalStream, anyhow::Error> {
407 match self.stream_receiver.recv().await {
408 Some(stream) => Ok(stream),
409 None => Err(anyhow::anyhow!("Session closed")),
410 }
411 }
412
413 pub async fn shutdown(&self) -> Result<(), anyhow::Error> {
414 let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
415
416 self.control_sender.send(Command::Shutdown(reply_tx)).await.map_err(|e| {
417 anyhow::anyhow!("Failed to send Shutdown command: {:?}", e)
418 })?;
419
420 let _ = reply_rx.await;
421 self.cancel.cancel();
422 Ok(())
423 }
424
425 pub fn session_id(&self) -> SessionId {
426 self.session_id
427 }
428 pub fn side(&self) -> SessionSide {
429 self.side
430 }
431
432 pub fn is_active(&self) -> bool {
433 !self.handle.is_finished()
434 }
435
436 pub fn set_local_addr(&mut self, addr: SocketAddr) {
437 self.local_addr = Some(addr);
438 }
439
440 pub fn set_remote_addr(&mut self, addr: SocketAddr) {
441 self.remote_addr = Some(addr);
442 }
443
444 pub fn local_addr(&self) -> Option<SocketAddr> {
445 self.local_addr
446 }
447
448 pub fn remote_addr(&self) -> Option<SocketAddr> {
449 self.remote_addr
450 }
451
452}