1use std::{
2 cell::RefCell,
3 collections::HashMap,
4 future::Future,
5 io,
6 net::SocketAddr,
7 pin::Pin,
8 rc::Rc,
9 task::{Context, Poll},
10};
11
12use easyfix_messages::fields::{FixString, SeqNum, SessionStatus};
13use futures::{self, Stream};
14use pin_project::pin_project;
15use tokio::{
16 io::{AsyncRead, AsyncWrite},
17 net::TcpListener,
18 task::JoinHandle,
19};
20use tracing::{Instrument, error, info, info_span, instrument, warn};
21
22use crate::{
23 DisconnectReason, Settings,
24 application::{AsEvent, Emitter, EventStream, events_channel},
25 io::acceptor_connection,
26 messages_storage::MessagesStorage,
27 session::Session,
28 session_id::SessionId,
29 session_state::State as SessionState,
30 settings::SessionSettings,
31};
32
33#[allow(async_fn_in_trait)]
34pub trait Connection {
35 async fn accept(
36 &mut self,
37 ) -> Result<
38 (
39 impl AsyncRead + Unpin + 'static,
40 impl AsyncWrite + Unpin + 'static,
41 SocketAddr,
42 ),
43 io::Error,
44 >;
45}
46
47pub struct TcpConnection {
48 listener: TcpListener,
49}
50
51impl TcpConnection {
52 pub async fn new(socket_addr: impl Into<SocketAddr>) -> Result<TcpConnection, io::Error> {
53 let socket_addr = socket_addr.into();
54 let listener = TcpListener::bind(&socket_addr).await?;
55 Ok(TcpConnection { listener })
56 }
57}
58
59impl Connection for TcpConnection {
60 async fn accept(
61 &mut self,
62 ) -> Result<
63 (
64 impl AsyncRead + Unpin + 'static,
65 impl AsyncWrite + Unpin + 'static,
66 SocketAddr,
67 ),
68 io::Error,
69 > {
70 let (tcp_stream, peer_addr) = self.listener.accept().await?;
71 tcp_stream.set_nodelay(true)?;
72 let (reader, writer) = tcp_stream.into_split();
73 Ok((reader, writer, peer_addr))
74 }
75}
76
77type SessionMapInternal<S> = HashMap<SessionId, (SessionSettings, Rc<RefCell<SessionState<S>>>)>;
78
79pub struct SessionsMap<S> {
80 map: SessionMapInternal<S>,
81 message_storage_builder: Box<dyn Fn(&SessionId) -> S>,
82}
83
84impl<S: MessagesStorage> SessionsMap<S> {
85 fn new(message_storage_builder: Box<dyn Fn(&SessionId) -> S>) -> SessionsMap<S> {
86 SessionsMap {
87 map: HashMap::new(),
88 message_storage_builder,
89 }
90 }
91
92 #[rustfmt::skip]
93 pub fn register_session(&mut self, session_id: SessionId, session_settings: SessionSettings) {
94 self.map.insert(
95 session_id.clone(),
96 (
97 session_settings,
98 Rc::new(RefCell::new(SessionState::new(
99 (self.message_storage_builder)(&session_id),
100 ))),
101 ),
102 );
103 }
104
105 pub(crate) fn get_session(
106 &self,
107 session_id: &SessionId,
108 ) -> Option<(SessionSettings, Rc<RefCell<SessionState<S>>>)> {
109 self.map.get(session_id).cloned()
110 }
111}
112
113pub struct SessionTask<S> {
114 settings: Settings,
115 sessions: Rc<RefCell<SessionsMap<S>>>,
116 active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
117 emitter: Emitter,
118}
119
120impl<S> Clone for SessionTask<S> {
121 fn clone(&self) -> Self {
122 Self {
123 settings: self.settings.clone(),
124 sessions: self.sessions.clone(),
125 active_sessions: self.active_sessions.clone(),
126 emitter: self.emitter.clone(),
127 }
128 }
129}
130
131impl<S: MessagesStorage + 'static> SessionTask<S> {
132 fn new(
133 settings: Settings,
134 sessions: Rc<RefCell<SessionsMap<S>>>,
135 active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
136 emitter: Emitter,
137 ) -> SessionTask<S> {
138 SessionTask {
139 settings,
140 sessions,
141 active_sessions,
142 emitter,
143 }
144 }
145
146 pub async fn run(
147 self,
148 peer_addr: SocketAddr,
149 reader: impl AsyncRead + Unpin + 'static,
150 writer: impl AsyncWrite + Unpin + 'static,
151 ) {
152 let span = info_span!("connection", %peer_addr);
153
154 span.in_scope(|| {
155 info!("New connection");
156 });
157
158 acceptor_connection(
159 reader,
160 writer,
161 self.settings,
162 self.sessions,
163 self.active_sessions,
164 self.emitter,
165 )
166 .instrument(span.clone())
167 .await;
168
169 span.in_scope(|| {
170 info!("Connection closed");
171 });
172 }
173}
174
175pub(crate) type ActiveSessionsMap<S> = HashMap<SessionId, Rc<Session<S>>>;
176
177#[pin_project]
178pub struct Acceptor<S> {
179 sessions: Rc<RefCell<SessionsMap<S>>>,
180 active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
181 session_task: SessionTask<S>,
182 #[pin]
183 event_stream: EventStream,
184}
185
186impl<S: MessagesStorage + 'static> Acceptor<S> {
187 pub fn new(
188 settings: Settings,
189 message_storage_builder: Box<dyn Fn(&SessionId) -> S>,
190 ) -> Acceptor<S> {
191 let (emitter, event_stream) = events_channel();
192 let sessions = Rc::new(RefCell::new(SessionsMap::new(message_storage_builder)));
193 let active_sessions = Rc::new(RefCell::new(HashMap::new()));
194 let session_task_builder =
195 SessionTask::new(settings, sessions.clone(), active_sessions.clone(), emitter);
196
197 Acceptor {
198 sessions,
199 active_sessions,
200 session_task: session_task_builder,
201 event_stream,
202 }
203 }
204
205 pub fn register_session(&mut self, session_id: SessionId, session_settings: SessionSettings) {
206 self.sessions
207 .borrow_mut()
208 .register_session(session_id, session_settings);
209 }
210
211 pub fn sessions_map(&self) -> Rc<RefCell<SessionsMap<S>>> {
212 self.sessions.clone()
213 }
214
215 pub fn start(&self, connection: impl Connection + 'static) -> JoinHandle<()> {
216 tokio::task::spawn_local(Self::server_task(connection, self.session_task.clone()))
217 }
218
219 pub fn logout(
220 &self,
221 session_id: &SessionId,
222 session_status: Option<SessionStatus>,
223 reason: Option<FixString>,
224 ) {
225 let active_sessions = self.active_sessions.borrow();
226 let Some(session) = active_sessions.get(session_id) else {
227 warn!("logout: session {session_id} not found");
228 return;
229 };
230
231 session.send_logout(&mut session.state().borrow_mut(), session_status, reason);
232 }
233
234 pub fn disconnect(&self, session_id: &SessionId) {
235 let mut active_sessions = self.active_sessions.borrow_mut();
236 let Some(session) = active_sessions.remove(session_id) else {
237 warn!("logout: session {session_id} not found");
238 return;
239 };
240
241 session.disconnect(
242 &mut session.state().borrow_mut(),
243 DisconnectReason::ApplicationForcedDisconnect,
244 );
245 }
246
247 pub fn reset(&self, session_id: &SessionId) {
254 let active_sessions = self.active_sessions.borrow();
255 let Some(session) = active_sessions.get(session_id) else {
256 warn!("reset: session {session_id} not found");
257 return;
258 };
259
260 session.reset(&mut session.state().borrow_mut());
261 }
262
263 #[instrument(skip(self))]
265 pub fn next_sender_msg_seq_num(&self, session_id: &SessionId) -> SeqNum {
266 let active_sessions = self.active_sessions.borrow();
267 let Some(session) = active_sessions.get(session_id) else {
268 warn!("session not found");
269 return 0;
270 };
271
272 let state = session.state().borrow_mut();
273 state.next_sender_msg_seq_num()
274 }
275
276 #[instrument(skip(self))]
278 pub fn set_next_sender_msg_seq_num(&self, session_id: &SessionId, seq_num: SeqNum) {
279 let active_sessions = self.active_sessions.borrow();
280 let Some(session) = active_sessions.get(session_id) else {
281 warn!("session not found");
282 return;
283 };
284
285 session
286 .state()
287 .borrow_mut()
288 .set_next_sender_msg_seq_num(seq_num);
289 }
290
291 async fn server_task(mut connection: impl Connection, session_task: SessionTask<S>) {
292 info!("Acceptor started");
293 loop {
294 match connection.accept().await {
295 Ok((reader, writer, peer_addr)) => {
296 tokio::task::spawn_local(session_task.clone().run(peer_addr, reader, writer));
297 }
298 Err(err) => error!("server task failed to accept incoming connection: {err}"),
299 }
300 }
301 }
302
303 pub fn session_task(&self) -> SessionTask<S> {
304 self.session_task.clone()
305 }
306
307 pub fn run_session_task(
308 &self,
309 peer_addr: SocketAddr,
310 reader: impl AsyncRead + Unpin + 'static,
311 writer: impl AsyncWrite + Unpin + 'static,
312 ) -> impl Future<Output = ()> {
313 self.session_task.clone().run(peer_addr, reader, writer)
314 }
315
316 pub fn drop_all_connections(&self) {}
317}
318
319impl<S: MessagesStorage> Stream for Acceptor<S> {
320 type Item = impl AsEvent;
321
322 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
323 Pin::new(&mut self.event_stream).poll_next(cx)
324 }
325
326 fn size_hint(&self) -> (usize, Option<usize>) {
327 self.event_stream.size_hint()
328 }
329}