1mod stream_manager;
28mod task;
29
30use std::marker::PhantomData;
31use std::sync::{
32 Arc,
33 atomic::{AtomicBool, Ordering},
34};
35
36use parking_lot::Once;
37use tokio::{
38 io::{self, AsyncRead, AsyncWrite},
39 sync::{broadcast, mpsc, oneshot},
40};
41
42use crate::{
43 Config, Stream,
44 alloc::{EVEN_START_STREAM_ID, ODD_START_STREAM_ID, StreamId, StreamIdAllocator},
45 consts::{CLIENT_MODE, SERVER_MODE, SessionMode},
46 error::Error,
47 msg::{self, Message},
48 session::stream_manager::StreamManager,
49};
50
51pub struct Session<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> {
57 config: Config,
58 stream_id_allocator: StreamIdAllocator,
59 stream_manager: Arc<StreamManager>,
60
61 stream_creation_rx: tokio::sync::Mutex<mpsc::UnboundedReceiver<StreamId>>,
62
63 shutdown_tx: broadcast::Sender<()>,
64 shutdown_once: Once,
65 is_shutdown: AtomicBool,
66
67 msg_tx: mpsc::Sender<Message>,
69 close_tx: mpsc::UnboundedSender<StreamId>,
70
71 _phantom: PhantomData<T>,
72}
73
74impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> Session<T> {
75 fn new(conn: T, config: Config, mode: SessionMode) -> Self {
76 let (conn_reader, conn_writer) = io::split(conn);
77 let (msg_tx, msg_rx) = mpsc::channel(config.conn_send_window_size);
78 let (close_tx, close_rx) = mpsc::unbounded_channel();
79 let (stream_creation_tx, stream_creation_rx) = mpsc::unbounded_channel();
80 let (shutdown_tx, shutdown_rx1) = broadcast::channel(1);
81 let shutdown_rx2 = shutdown_tx.subscribe();
82 let shutdown_rx3 = shutdown_tx.subscribe();
83
84 let session = Self {
85 config,
86 stream_id_allocator: StreamIdAllocator::new(match mode {
87 SERVER_MODE => ODD_START_STREAM_ID,
88 CLIENT_MODE => EVEN_START_STREAM_ID,
89 }),
90 stream_manager: Arc::new(StreamManager::new(stream_creation_tx)),
91 stream_creation_rx: tokio::sync::Mutex::new(stream_creation_rx),
92 shutdown_tx,
93 shutdown_once: Once::new(),
94 is_shutdown: AtomicBool::new(false),
95 msg_tx,
96 close_tx,
97 _phantom: PhantomData,
98 };
99
100 tokio::spawn(task::start_msg_collect_loop(
101 msg_rx,
102 conn_writer,
103 shutdown_rx1,
104 ));
105 tokio::spawn(task::start_frame_dispatch_loop(
106 conn_reader,
107 session.stream_manager.clone(),
108 shutdown_rx2,
109 ));
110 tokio::spawn(task::start_stream_close_listen(
111 close_rx,
112 session.stream_manager.clone(),
113 shutdown_rx3,
114 ));
115
116 session
117 }
118
119 pub fn server(conn: T, config: Config) -> Self {
121 Self::new(conn, config, SERVER_MODE)
122 }
123
124 pub fn client(conn: T, config: Config) -> Self {
126 Self::new(conn, config, CLIENT_MODE)
127 }
128
129 pub async fn open(&self) -> Result<Stream, Error> {
135 if self.is_shutdown.load(Ordering::SeqCst) {
136 return Err(Error::SessionClosed);
137 }
138 let shutdown_rx = self.shutdown_tx.subscribe();
139 let close_tx = self.close_tx.clone();
140 let msg_tx = self.msg_tx.clone();
141 let stream_id = self.stream_id_allocator.allocate();
142 let (frame_tx, frame_rx) = mpsc::channel(self.config.stream_recv_window_size);
143 let (remote_fin_tx, remote_fin_rx) = oneshot::channel();
144
145 let stream = Stream::new(
146 stream_id,
147 shutdown_rx,
148 msg_tx.clone(),
149 frame_rx,
150 close_tx,
151 remote_fin_rx,
152 );
153 let (remote_ack_tx, remote_ack_rx) = oneshot::channel();
154 self.stream_manager
155 .add_stream(stream_id, frame_tx, remote_fin_tx, Some(remote_ack_tx))?;
156 msg::send_syn(msg_tx, stream_id).await?;
157 remote_ack_rx
158 .await
159 .map_err(|_| Error::Internal("remote ack rx not found".to_string()))?;
160
161 Ok(stream)
162 }
163
164 pub async fn accept(&self) -> Result<Stream, Error> {
170 let stream_id = self
171 .stream_creation_rx
172 .lock()
173 .await
174 .recv()
175 .await
176 .ok_or(Error::SessionClosed)?;
177
178 let shutdown_rx = self.shutdown_tx.subscribe();
179 let close_tx = self.close_tx.clone();
180 let msg_tx = self.msg_tx.clone();
181 let (frame_tx, frame_rx) = mpsc::channel(self.config.stream_recv_window_size);
182 let (remote_fin_tx, remote_fin_rx) = oneshot::channel();
183
184 let stream = Stream::new(
185 stream_id,
186 shutdown_rx,
187 msg_tx,
188 frame_rx,
189 close_tx,
190 remote_fin_rx,
191 );
192 self.stream_manager
193 .add_stream(stream_id, frame_tx, remote_fin_tx, None)?;
194 msg::send_ack(self.msg_tx.clone(), stream_id).await?;
195
196 Ok(stream)
197 }
198
199 pub fn close(self) {
208 self.shutdown_once.call_once(|| {
209 self.is_shutdown.store(true, Ordering::SeqCst);
210 let _ = self.shutdown_tx.send(());
211 });
212 }
213}