// easyfix_session/acceptor.rs

1use std::{
2    cell::{Cell, RefCell},
3    collections::HashMap,
4    future::Future,
5    io,
6    net::SocketAddr,
7    pin::Pin,
8    rc::Rc,
9    task::{Context, Poll},
10};
11
12use easyfix_messages::fields::{FixString, SeqNum, SessionStatus};
13use futures::{self, Stream};
14use pin_project::pin_project;
15use tokio::{
16    io::{AsyncRead, AsyncWrite},
17    net::TcpListener,
18    task::JoinHandle,
19};
20use tracing::{Instrument, error, info, info_span, instrument, warn};
21
22use crate::{
23    DisconnectReason, Settings,
24    application::{AsEvent, Emitter, EventStream, events_channel},
25    io::acceptor_connection,
26    messages_storage::MessagesStorage,
27    session::Session,
28    session_id::SessionId,
29    session_state::State as SessionState,
30    settings::SessionSettings,
31};
32
33#[derive(Debug, thiserror::Error)]
34pub enum AcceptorError {
35    #[error("Unknown session")]
36    UnknownSession,
37    #[error("Session active")]
38    SessionActive,
39}
40
41#[allow(async_fn_in_trait)]
42pub trait Connection {
43    async fn accept(
44        &mut self,
45    ) -> Result<
46        (
47            impl AsyncRead + Unpin + 'static,
48            impl AsyncWrite + Unpin + 'static,
49            SocketAddr,
50        ),
51        io::Error,
52    >;
53}
54
55pub struct TcpConnection {
56    listener: TcpListener,
57}
58
59impl TcpConnection {
60    pub async fn new(socket_addr: impl Into<SocketAddr>) -> Result<TcpConnection, io::Error> {
61        let socket_addr = socket_addr.into();
62        let listener = TcpListener::bind(&socket_addr).await?;
63        Ok(TcpConnection { listener })
64    }
65}
66
67impl Connection for TcpConnection {
68    async fn accept(
69        &mut self,
70    ) -> Result<
71        (
72            impl AsyncRead + Unpin + 'static,
73            impl AsyncWrite + Unpin + 'static,
74            SocketAddr,
75        ),
76        io::Error,
77    > {
78        let (tcp_stream, peer_addr) = self.listener.accept().await?;
79        tcp_stream.set_nodelay(true)?;
80        let (reader, writer) = tcp_stream.into_split();
81        Ok((reader, writer, peer_addr))
82    }
83}
84
85type SessionMapInternal<S> = HashMap<SessionId, (SessionSettings, Rc<RefCell<SessionState<S>>>)>;
86
87pub struct SessionsMap<S> {
88    map: SessionMapInternal<S>,
89    message_storage_builder: Box<dyn Fn(&SessionId) -> S>,
90}
91
92impl<S: MessagesStorage> SessionsMap<S> {
93    fn new(message_storage_builder: Box<dyn Fn(&SessionId) -> S>) -> SessionsMap<S> {
94        SessionsMap {
95            map: HashMap::new(),
96            message_storage_builder,
97        }
98    }
99
100    pub fn register_session(&mut self, session_id: SessionId, session_settings: SessionSettings) {
101        let storage = (self.message_storage_builder)(&session_id);
102        self.map.insert(
103            session_id.clone(),
104            (
105                session_settings,
106                Rc::new(RefCell::new(SessionState::new(storage))),
107            ),
108        );
109    }
110
111    pub(crate) fn get_session(
112        &self,
113        session_id: &SessionId,
114    ) -> Option<(SessionSettings, Rc<RefCell<SessionState<S>>>)> {
115        self.map.get(session_id).cloned()
116    }
117
118    fn contains(&self, session_id: &SessionId) -> bool {
119        self.map.contains_key(session_id)
120    }
121}
122
123pub struct SessionTask<S> {
124    settings: Settings,
125    sessions: Rc<RefCell<SessionsMap<S>>>,
126    active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
127    emitter: Emitter,
128    enabled: Rc<Cell<bool>>,
129}
130
131impl<S> Clone for SessionTask<S> {
132    fn clone(&self) -> Self {
133        Self {
134            settings: self.settings.clone(),
135            sessions: self.sessions.clone(),
136            active_sessions: self.active_sessions.clone(),
137            emitter: self.emitter.clone(),
138            enabled: self.enabled.clone(),
139        }
140    }
141}
142
143impl<S: MessagesStorage + 'static> SessionTask<S> {
144    fn new(
145        settings: Settings,
146        sessions: Rc<RefCell<SessionsMap<S>>>,
147        active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
148        emitter: Emitter,
149        enabled: Rc<Cell<bool>>,
150    ) -> SessionTask<S> {
151        SessionTask {
152            settings,
153            sessions,
154            active_sessions,
155            emitter,
156            enabled,
157        }
158    }
159
160    pub async fn run(
161        self,
162        peer_addr: SocketAddr,
163        reader: impl AsyncRead + Unpin + 'static,
164        writer: impl AsyncWrite + Unpin + 'static,
165    ) {
166        let span = info_span!("connection", %peer_addr);
167
168        span.in_scope(|| {
169            info!("New connection");
170        });
171
172        if self.enabled.get() {
173            acceptor_connection(
174                reader,
175                writer,
176                self.settings,
177                self.sessions,
178                self.active_sessions,
179                self.emitter,
180                self.enabled,
181            )
182            .instrument(span.clone())
183            .await;
184        } else {
185            span.in_scope(|| warn!("Acceptor is disabled"))
186        }
187
188        span.in_scope(|| {
189            info!("Connection closed");
190        });
191    }
192}
193
194pub(crate) type ActiveSessionsMap<S> = HashMap<SessionId, Rc<Session<S>>>;
195
196#[pin_project]
197pub struct Acceptor<S> {
198    sessions: Rc<RefCell<SessionsMap<S>>>,
199    active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
200    session_task: SessionTask<S>,
201    #[pin]
202    event_stream: EventStream,
203    enabled: Rc<Cell<bool>>,
204}
205
206impl<S: MessagesStorage + 'static> Acceptor<S> {
207    pub fn new(
208        settings: Settings,
209        message_storage_builder: Box<dyn Fn(&SessionId) -> S>,
210    ) -> Acceptor<S> {
211        let (emitter, event_stream) = events_channel();
212        let sessions = Rc::new(RefCell::new(SessionsMap::new(message_storage_builder)));
213        let active_sessions = Rc::new(RefCell::new(HashMap::new()));
214        let enabled = Rc::new(Cell::new(true));
215        let session_task = SessionTask::new(
216            settings,
217            sessions.clone(),
218            active_sessions.clone(),
219            emitter,
220            enabled.clone(),
221        );
222
223        Acceptor {
224            sessions,
225            active_sessions,
226            session_task,
227            event_stream,
228            enabled,
229        }
230    }
231
232    pub fn enable(&self) {
233        info!("acceptor enabled");
234        self.enabled.set(true);
235    }
236
237    pub fn disable(&self) {
238        info!("acceptor disabled");
239        self.enabled.set(false);
240        for (_, session) in self.active_sessions.borrow_mut().drain() {
241            session.disconnect(
242                &mut session.state().borrow_mut(),
243                DisconnectReason::ApplicationForcedDisconnect,
244            );
245        }
246    }
247
248    pub fn disable_with_logout(
249        &self,
250        session_status: Option<SessionStatus>,
251        reason: Option<FixString>,
252    ) {
253        info!("acceptor disabled with logout");
254        self.enabled.set(false);
255        for (_, session) in self.active_sessions.borrow_mut().drain() {
256            let mut state = session.state().borrow_mut();
257            session.send_logout(&mut state, session_status, reason.clone());
258            session.disconnect(&mut state, DisconnectReason::ApplicationForcedDisconnect);
259        }
260    }
261
262    pub fn register_session(&mut self, session_id: SessionId, session_settings: SessionSettings) {
263        self.sessions
264            .borrow_mut()
265            .register_session(session_id, session_settings);
266    }
267
268    pub fn sessions_map(&self) -> Rc<RefCell<SessionsMap<S>>> {
269        self.sessions.clone()
270    }
271
272    pub fn start(&self, connection: impl Connection + 'static) -> JoinHandle<()> {
273        tokio::task::spawn_local(Self::server_task(connection, self.session_task.clone()))
274    }
275
276    pub fn is_session_active(&self, session_id: &SessionId) -> Result<bool, AcceptorError> {
277        if self.active_sessions.borrow().contains_key(session_id) {
278            Ok(true)
279        } else if self.sessions.borrow().contains(session_id) {
280            Ok(false)
281        } else {
282            Err(AcceptorError::UnknownSession)
283        }
284    }
285
286    pub fn logout(
287        &self,
288        session_id: &SessionId,
289        session_status: Option<SessionStatus>,
290        reason: Option<FixString>,
291    ) -> Result<(), AcceptorError> {
292        if let Some(session) = self.active_sessions.borrow().get(session_id) {
293            session.send_logout(&mut session.state().borrow_mut(), session_status, reason);
294            Ok(())
295        } else if self.sessions.borrow().contains(session_id) {
296            // Already logged out
297            Ok(())
298        } else {
299            Err(AcceptorError::UnknownSession)
300        }
301    }
302
303    pub fn disconnect(&self, session_id: &SessionId) -> Result<(), AcceptorError> {
304        if let Some(session) = self.active_sessions.borrow_mut().remove(session_id) {
305            session.disconnect(
306                &mut session.state().borrow_mut(),
307                DisconnectReason::ApplicationForcedDisconnect,
308            );
309            Ok(())
310        } else if self.sessions.borrow().contains(session_id) {
311            // Already disconnected
312            Ok(())
313        } else {
314            Err(AcceptorError::UnknownSession)
315        }
316    }
317
318    pub fn disconnect_with_logout(
319        &self,
320        session_id: &SessionId,
321        session_status: Option<SessionStatus>,
322        reason: Option<FixString>,
323    ) -> Result<(), AcceptorError> {
324        if let Some(session) = self.active_sessions.borrow().get(session_id) {
325            session.send_logout(&mut session.state().borrow_mut(), session_status, reason);
326            session.disconnect(
327                &mut session.state().borrow_mut(),
328                DisconnectReason::ApplicationForcedDisconnect,
329            );
330            Ok(())
331        } else if self.sessions.borrow().contains(session_id) {
332            // Already logged out
333            Ok(())
334        } else {
335            Err(AcceptorError::UnknownSession)
336        }
337    }
338
339    /// Force reset of the session
340    ///
341    /// Functionally equivalent to `reset_on_logon/logout/disconnect` settings,
342    /// but triggered manually.
343    ///
344    /// Returns [`AcceptorError::SessionActive`] if the session is still active.
345    /// In that case, call [Self::disconnect] or [Self::logout] first and wait
346    /// for the session to fully terminate before retrying.
347    #[instrument(skip_all, fields(session_id=%session_id) ret)]
348    pub fn reset(&self, session_id: &SessionId) -> Result<(), AcceptorError> {
349        if self.active_sessions.borrow().contains_key(session_id) {
350            Err(AcceptorError::SessionActive)
351        } else if let Some((_, session_state)) = self.sessions.borrow().get_session(session_id) {
352            session_state.borrow_mut().reset();
353            Ok(())
354        } else {
355            Err(AcceptorError::UnknownSession)
356        }
357    }
358
359    // TODO: temporary solution, remove when diconnect will be synchronized
360    #[instrument(skip_all, fields(session_id=%session_id) ret)]
361    pub fn force_reset(&self, session_id: &SessionId) -> Result<(), AcceptorError> {
362        if let Some(session) = self.active_sessions.borrow().get(session_id) {
363            session.state().borrow_mut().reset();
364            Ok(())
365        } else if let Some((_, session_state)) = self.sessions.borrow().get_session(session_id) {
366            session_state.borrow_mut().reset();
367            Ok(())
368        } else {
369            Err(AcceptorError::UnknownSession)
370        }
371    }
372
373    /// Sender seq_num getter
374    #[instrument(skip_all, fields(session_id=%session_id) ret)]
375    pub fn next_sender_msg_seq_num(&self, session_id: &SessionId) -> Result<SeqNum, AcceptorError> {
376        if let Some(session) = self.active_sessions.borrow().get(session_id) {
377            Ok(session.state().borrow().next_sender_msg_seq_num())
378        } else if let Some((_, session_state)) = self.sessions.borrow().get_session(session_id) {
379            Ok(session_state.borrow().next_sender_msg_seq_num())
380        } else {
381            Err(AcceptorError::UnknownSession)
382        }
383    }
384
385    /// Override sender's next seq_num
386    #[instrument(skip_all, fields(session_id=%session_id, seq_num) ret)]
387    pub fn set_next_sender_msg_seq_num(
388        &self,
389        session_id: &SessionId,
390        seq_num: SeqNum,
391    ) -> Result<(), AcceptorError> {
392        if let Some(session) = self.active_sessions.borrow().get(session_id) {
393            session
394                .state()
395                .borrow_mut()
396                .set_next_sender_msg_seq_num(seq_num);
397            Ok(())
398        } else if let Some((_, session_state)) = self.sessions.borrow().get_session(session_id) {
399            session_state
400                .borrow_mut()
401                .set_next_sender_msg_seq_num(seq_num);
402            Ok(())
403        } else {
404            Err(AcceptorError::UnknownSession)
405        }
406    }
407
408    async fn server_task(mut connection: impl Connection, session_task: SessionTask<S>) {
409        info!("Acceptor started");
410        loop {
411            match connection.accept().await {
412                Ok((reader, writer, peer_addr)) => {
413                    tokio::task::spawn_local(session_task.clone().run(peer_addr, reader, writer));
414                }
415                Err(err) => error!("server task failed to accept incoming connection: {err}"),
416            }
417        }
418    }
419
420    pub fn session_task(&self) -> SessionTask<S> {
421        self.session_task.clone()
422    }
423
424    pub fn run_session_task(
425        &self,
426        peer_addr: SocketAddr,
427        reader: impl AsyncRead + Unpin + 'static,
428        writer: impl AsyncWrite + Unpin + 'static,
429    ) -> impl Future<Output = ()> {
430        self.session_task.clone().run(peer_addr, reader, writer)
431    }
432}
433
434impl<S: MessagesStorage> Stream for Acceptor<S> {
435    type Item = impl AsEvent;
436
437    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
438        Pin::new(&mut self.event_stream).poll_next(cx)
439    }
440
441    fn size_hint(&self) -> (usize, Option<usize>) {
442        self.event_stream.size_hint()
443    }
444}