1use std::{
2 cell::{Cell, RefCell},
3 collections::HashMap,
4 future::Future,
5 io,
6 net::SocketAddr,
7 pin::Pin,
8 rc::Rc,
9 task::{Context, Poll},
10};
11
12use easyfix_messages::fields::{FixString, SeqNum, SessionStatus};
13use futures::{self, Stream};
14use pin_project::pin_project;
15use tokio::{
16 io::{AsyncRead, AsyncWrite},
17 net::TcpListener,
18 task::JoinHandle,
19};
20use tracing::{Instrument, error, info, info_span, instrument, warn};
21
22use crate::{
23 DisconnectReason, Settings,
24 application::{AsEvent, Emitter, EventStream, events_channel},
25 io::acceptor_connection,
26 messages_storage::MessagesStorage,
27 session::Session,
28 session_id::SessionId,
29 session_state::State as SessionState,
30 settings::SessionSettings,
31};
32
33#[derive(Debug, thiserror::Error)]
34pub enum AcceptorError {
35 #[error("Unknown session")]
36 UnknownSession,
37 #[error("Session active")]
38 SessionActive,
39}
40
41#[allow(async_fn_in_trait)]
42pub trait Connection {
43 async fn accept(
44 &mut self,
45 ) -> Result<
46 (
47 impl AsyncRead + Unpin + 'static,
48 impl AsyncWrite + Unpin + 'static,
49 SocketAddr,
50 ),
51 io::Error,
52 >;
53}
54
55pub struct TcpConnection {
56 listener: TcpListener,
57}
58
59impl TcpConnection {
60 pub async fn new(socket_addr: impl Into<SocketAddr>) -> Result<TcpConnection, io::Error> {
61 let socket_addr = socket_addr.into();
62 let listener = TcpListener::bind(&socket_addr).await?;
63 Ok(TcpConnection { listener })
64 }
65}
66
67impl Connection for TcpConnection {
68 async fn accept(
69 &mut self,
70 ) -> Result<
71 (
72 impl AsyncRead + Unpin + 'static,
73 impl AsyncWrite + Unpin + 'static,
74 SocketAddr,
75 ),
76 io::Error,
77 > {
78 let (tcp_stream, peer_addr) = self.listener.accept().await?;
79 tcp_stream.set_nodelay(true)?;
80 let (reader, writer) = tcp_stream.into_split();
81 Ok((reader, writer, peer_addr))
82 }
83}
84
85type SessionMapInternal<S> = HashMap<SessionId, (SessionSettings, Rc<RefCell<SessionState<S>>>)>;
86
87pub struct SessionsMap<S> {
88 map: SessionMapInternal<S>,
89 message_storage_builder: Box<dyn Fn(&SessionId) -> S>,
90}
91
92impl<S: MessagesStorage> SessionsMap<S> {
93 fn new(message_storage_builder: Box<dyn Fn(&SessionId) -> S>) -> SessionsMap<S> {
94 SessionsMap {
95 map: HashMap::new(),
96 message_storage_builder,
97 }
98 }
99
100 pub fn register_session(&mut self, session_id: SessionId, session_settings: SessionSettings) {
101 let storage = (self.message_storage_builder)(&session_id);
102 self.map.insert(
103 session_id.clone(),
104 (
105 session_settings,
106 Rc::new(RefCell::new(SessionState::new(storage))),
107 ),
108 );
109 }
110
111 pub(crate) fn get_session(
112 &self,
113 session_id: &SessionId,
114 ) -> Option<(SessionSettings, Rc<RefCell<SessionState<S>>>)> {
115 self.map.get(session_id).cloned()
116 }
117
118 fn contains(&self, session_id: &SessionId) -> bool {
119 self.map.contains_key(session_id)
120 }
121}
122
123pub struct SessionTask<S> {
124 settings: Settings,
125 sessions: Rc<RefCell<SessionsMap<S>>>,
126 active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
127 emitter: Emitter,
128 enabled: Rc<Cell<bool>>,
129}
130
131impl<S> Clone for SessionTask<S> {
132 fn clone(&self) -> Self {
133 Self {
134 settings: self.settings.clone(),
135 sessions: self.sessions.clone(),
136 active_sessions: self.active_sessions.clone(),
137 emitter: self.emitter.clone(),
138 enabled: self.enabled.clone(),
139 }
140 }
141}
142
143impl<S: MessagesStorage + 'static> SessionTask<S> {
144 fn new(
145 settings: Settings,
146 sessions: Rc<RefCell<SessionsMap<S>>>,
147 active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
148 emitter: Emitter,
149 enabled: Rc<Cell<bool>>,
150 ) -> SessionTask<S> {
151 SessionTask {
152 settings,
153 sessions,
154 active_sessions,
155 emitter,
156 enabled,
157 }
158 }
159
160 pub async fn run(
161 self,
162 peer_addr: SocketAddr,
163 reader: impl AsyncRead + Unpin + 'static,
164 writer: impl AsyncWrite + Unpin + 'static,
165 ) {
166 let span = info_span!("connection", %peer_addr);
167
168 span.in_scope(|| {
169 info!("New connection");
170 });
171
172 if self.enabled.get() {
173 acceptor_connection(
174 reader,
175 writer,
176 self.settings,
177 self.sessions,
178 self.active_sessions,
179 self.emitter,
180 self.enabled,
181 )
182 .instrument(span.clone())
183 .await;
184 } else {
185 span.in_scope(|| warn!("Acceptor is disabled"))
186 }
187
188 span.in_scope(|| {
189 info!("Connection closed");
190 });
191 }
192}
193
/// Sessions that currently have a live connection, keyed by session id.
pub(crate) type ActiveSessionsMap<S> = HashMap<SessionId, Rc<Session<S>>>;

/// Accepts incoming connections and manages the sessions served over them.
///
/// Also a `Stream` of the events produced by those connections.
#[pin_project]
pub struct Acceptor<S> {
    /// Every session registered via `register_session`, connected or not.
    sessions: Rc<RefCell<SessionsMap<S>>>,
    /// Sessions with a live connection; the same map is shared with every
    /// connection task through `session_task`.
    active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
    /// Template that is cloned for every accepted connection.
    session_task: SessionTask<S>,
    /// Receiving side of the events channel; structurally pinned so that
    /// `Stream::poll_next` can poll it through the pin projection.
    #[pin]
    event_stream: EventStream,
    /// When `false`, newly accepted connections are not served; shared with
    /// `session_task`.
    enabled: Rc<Cell<bool>>,
}
205
206impl<S: MessagesStorage + 'static> Acceptor<S> {
207 pub fn new(
208 settings: Settings,
209 message_storage_builder: Box<dyn Fn(&SessionId) -> S>,
210 ) -> Acceptor<S> {
211 let (emitter, event_stream) = events_channel();
212 let sessions = Rc::new(RefCell::new(SessionsMap::new(message_storage_builder)));
213 let active_sessions = Rc::new(RefCell::new(HashMap::new()));
214 let enabled = Rc::new(Cell::new(true));
215 let session_task = SessionTask::new(
216 settings,
217 sessions.clone(),
218 active_sessions.clone(),
219 emitter,
220 enabled.clone(),
221 );
222
223 Acceptor {
224 sessions,
225 active_sessions,
226 session_task,
227 event_stream,
228 enabled,
229 }
230 }
231
232 pub fn enable(&self) {
233 info!("acceptor enabled");
234 self.enabled.set(true);
235 }
236
237 pub fn disable(&self) {
238 info!("acceptor disabled");
239 self.enabled.set(false);
240 for (_, session) in self.active_sessions.borrow_mut().drain() {
241 session.disconnect(
242 &mut session.state().borrow_mut(),
243 DisconnectReason::ApplicationForcedDisconnect,
244 );
245 }
246 }
247
248 pub fn disable_with_logout(
249 &self,
250 session_status: Option<SessionStatus>,
251 reason: Option<FixString>,
252 ) {
253 info!("acceptor disabled with logout");
254 self.enabled.set(false);
255 for (_, session) in self.active_sessions.borrow_mut().drain() {
256 let mut state = session.state().borrow_mut();
257 session.send_logout(&mut state, session_status, reason.clone());
258 session.disconnect(&mut state, DisconnectReason::ApplicationForcedDisconnect);
259 }
260 }
261
262 pub fn register_session(&mut self, session_id: SessionId, session_settings: SessionSettings) {
263 self.sessions
264 .borrow_mut()
265 .register_session(session_id, session_settings);
266 }
267
268 pub fn sessions_map(&self) -> Rc<RefCell<SessionsMap<S>>> {
269 self.sessions.clone()
270 }
271
272 pub fn start(&self, connection: impl Connection + 'static) -> JoinHandle<()> {
273 tokio::task::spawn_local(Self::server_task(connection, self.session_task.clone()))
274 }
275
276 pub fn is_session_active(&self, session_id: &SessionId) -> Result<bool, AcceptorError> {
277 if self.active_sessions.borrow().contains_key(session_id) {
278 Ok(true)
279 } else if self.sessions.borrow().contains(session_id) {
280 Ok(false)
281 } else {
282 Err(AcceptorError::UnknownSession)
283 }
284 }
285
286 pub fn logout(
287 &self,
288 session_id: &SessionId,
289 session_status: Option<SessionStatus>,
290 reason: Option<FixString>,
291 ) -> Result<(), AcceptorError> {
292 if let Some(session) = self.active_sessions.borrow().get(session_id) {
293 session.send_logout(&mut session.state().borrow_mut(), session_status, reason);
294 Ok(())
295 } else if self.sessions.borrow().contains(session_id) {
296 Ok(())
298 } else {
299 Err(AcceptorError::UnknownSession)
300 }
301 }
302
303 pub fn disconnect(&self, session_id: &SessionId) -> Result<(), AcceptorError> {
304 if let Some(session) = self.active_sessions.borrow_mut().remove(session_id) {
305 session.disconnect(
306 &mut session.state().borrow_mut(),
307 DisconnectReason::ApplicationForcedDisconnect,
308 );
309 Ok(())
310 } else if self.sessions.borrow().contains(session_id) {
311 Ok(())
313 } else {
314 Err(AcceptorError::UnknownSession)
315 }
316 }
317
318 pub fn disconnect_with_logout(
319 &self,
320 session_id: &SessionId,
321 session_status: Option<SessionStatus>,
322 reason: Option<FixString>,
323 ) -> Result<(), AcceptorError> {
324 if let Some(session) = self.active_sessions.borrow().get(session_id) {
325 session.send_logout(&mut session.state().borrow_mut(), session_status, reason);
326 session.disconnect(
327 &mut session.state().borrow_mut(),
328 DisconnectReason::ApplicationForcedDisconnect,
329 );
330 Ok(())
331 } else if self.sessions.borrow().contains(session_id) {
332 Ok(())
334 } else {
335 Err(AcceptorError::UnknownSession)
336 }
337 }
338
339 #[instrument(skip_all, fields(session_id=%session_id) ret)]
348 pub fn reset(&self, session_id: &SessionId) -> Result<(), AcceptorError> {
349 if self.active_sessions.borrow().contains_key(session_id) {
350 Err(AcceptorError::SessionActive)
351 } else if let Some((_, session_state)) = self.sessions.borrow().get_session(session_id) {
352 session_state.borrow_mut().reset();
353 Ok(())
354 } else {
355 Err(AcceptorError::UnknownSession)
356 }
357 }
358
359 #[instrument(skip_all, fields(session_id=%session_id) ret)]
361 pub fn force_reset(&self, session_id: &SessionId) -> Result<(), AcceptorError> {
362 if let Some(session) = self.active_sessions.borrow().get(session_id) {
363 session.state().borrow_mut().reset();
364 Ok(())
365 } else if let Some((_, session_state)) = self.sessions.borrow().get_session(session_id) {
366 session_state.borrow_mut().reset();
367 Ok(())
368 } else {
369 Err(AcceptorError::UnknownSession)
370 }
371 }
372
373 #[instrument(skip_all, fields(session_id=%session_id) ret)]
375 pub fn next_sender_msg_seq_num(&self, session_id: &SessionId) -> Result<SeqNum, AcceptorError> {
376 if let Some(session) = self.active_sessions.borrow().get(session_id) {
377 Ok(session.state().borrow().next_sender_msg_seq_num())
378 } else if let Some((_, session_state)) = self.sessions.borrow().get_session(session_id) {
379 Ok(session_state.borrow().next_sender_msg_seq_num())
380 } else {
381 Err(AcceptorError::UnknownSession)
382 }
383 }
384
385 #[instrument(skip_all, fields(session_id=%session_id, seq_num) ret)]
387 pub fn set_next_sender_msg_seq_num(
388 &self,
389 session_id: &SessionId,
390 seq_num: SeqNum,
391 ) -> Result<(), AcceptorError> {
392 if let Some(session) = self.active_sessions.borrow().get(session_id) {
393 session
394 .state()
395 .borrow_mut()
396 .set_next_sender_msg_seq_num(seq_num);
397 Ok(())
398 } else if let Some((_, session_state)) = self.sessions.borrow().get_session(session_id) {
399 session_state
400 .borrow_mut()
401 .set_next_sender_msg_seq_num(seq_num);
402 Ok(())
403 } else {
404 Err(AcceptorError::UnknownSession)
405 }
406 }
407
408 async fn server_task(mut connection: impl Connection, session_task: SessionTask<S>) {
409 info!("Acceptor started");
410 loop {
411 match connection.accept().await {
412 Ok((reader, writer, peer_addr)) => {
413 tokio::task::spawn_local(session_task.clone().run(peer_addr, reader, writer));
414 }
415 Err(err) => error!("server task failed to accept incoming connection: {err}"),
416 }
417 }
418 }
419
420 pub fn session_task(&self) -> SessionTask<S> {
421 self.session_task.clone()
422 }
423
424 pub fn run_session_task(
425 &self,
426 peer_addr: SocketAddr,
427 reader: impl AsyncRead + Unpin + 'static,
428 writer: impl AsyncWrite + Unpin + 'static,
429 ) -> impl Future<Output = ()> {
430 self.session_task.clone().run(peer_addr, reader, writer)
431 }
432}
433
434impl<S: MessagesStorage> Stream for Acceptor<S> {
435 type Item = impl AsEvent;
436
437 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
438 Pin::new(&mut self.event_stream).poll_next(cx)
439 }
440
441 fn size_hint(&self) -> (usize, Option<usize>) {
442 self.event_stream.size_hint()
443 }
444}