easyfix_session/
acceptor.rs

1use std::{
2    cell::RefCell,
3    collections::HashMap,
4    future::Future,
5    io,
6    net::SocketAddr,
7    pin::Pin,
8    rc::Rc,
9    task::{Context, Poll},
10};
11
12use easyfix_messages::fields::{FixString, SeqNum, SessionStatus};
13use futures::{self, Stream};
14use pin_project::pin_project;
15use tokio::{
16    io::{AsyncRead, AsyncWrite},
17    net::TcpListener,
18    task::JoinHandle,
19};
20use tracing::{Instrument, error, info, info_span, instrument, warn};
21
22use crate::{
23    DisconnectReason, Settings,
24    application::{AsEvent, Emitter, EventStream, events_channel},
25    io::acceptor_connection,
26    messages_storage::MessagesStorage,
27    session::Session,
28    session_id::SessionId,
29    session_state::State as SessionState,
30    settings::SessionSettings,
31};
32
33#[allow(async_fn_in_trait)]
34pub trait Connection {
35    async fn accept(
36        &mut self,
37    ) -> Result<
38        (
39            impl AsyncRead + Unpin + 'static,
40            impl AsyncWrite + Unpin + 'static,
41            SocketAddr,
42        ),
43        io::Error,
44    >;
45}
46
47pub struct TcpConnection {
48    listener: TcpListener,
49}
50
51impl TcpConnection {
52    pub async fn new(socket_addr: impl Into<SocketAddr>) -> Result<TcpConnection, io::Error> {
53        let socket_addr = socket_addr.into();
54        let listener = TcpListener::bind(&socket_addr).await?;
55        Ok(TcpConnection { listener })
56    }
57}
58
59impl Connection for TcpConnection {
60    async fn accept(
61        &mut self,
62    ) -> Result<
63        (
64            impl AsyncRead + Unpin + 'static,
65            impl AsyncWrite + Unpin + 'static,
66            SocketAddr,
67        ),
68        io::Error,
69    > {
70        let (tcp_stream, peer_addr) = self.listener.accept().await?;
71        tcp_stream.set_nodelay(true)?;
72        let (reader, writer) = tcp_stream.into_split();
73        Ok((reader, writer, peer_addr))
74    }
75}
76
77type SessionMapInternal<S> = HashMap<SessionId, (SessionSettings, Rc<RefCell<SessionState<S>>>)>;
78
79pub struct SessionsMap<S> {
80    map: SessionMapInternal<S>,
81    message_storage_builder: Box<dyn Fn(&SessionId) -> S>,
82}
83
84impl<S: MessagesStorage> SessionsMap<S> {
85    fn new(message_storage_builder: Box<dyn Fn(&SessionId) -> S>) -> SessionsMap<S> {
86        SessionsMap {
87            map: HashMap::new(),
88            message_storage_builder,
89        }
90    }
91
92    #[rustfmt::skip]
93    pub fn register_session(&mut self, session_id: SessionId, session_settings: SessionSettings) {
94        self.map.insert(
95            session_id.clone(),
96            (
97                session_settings,
98                Rc::new(RefCell::new(SessionState::new(
99                    (self.message_storage_builder)(&session_id),
100                ))),
101            ),
102        );
103    }
104
105    pub(crate) fn get_session(
106        &self,
107        session_id: &SessionId,
108    ) -> Option<(SessionSettings, Rc<RefCell<SessionState<S>>>)> {
109        self.map.get(session_id).cloned()
110    }
111}
112
113pub struct SessionTask<S> {
114    settings: Settings,
115    sessions: Rc<RefCell<SessionsMap<S>>>,
116    active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
117    emitter: Emitter,
118}
119
120impl<S> Clone for SessionTask<S> {
121    fn clone(&self) -> Self {
122        Self {
123            settings: self.settings.clone(),
124            sessions: self.sessions.clone(),
125            active_sessions: self.active_sessions.clone(),
126            emitter: self.emitter.clone(),
127        }
128    }
129}
130
131impl<S: MessagesStorage + 'static> SessionTask<S> {
132    fn new(
133        settings: Settings,
134        sessions: Rc<RefCell<SessionsMap<S>>>,
135        active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
136        emitter: Emitter,
137    ) -> SessionTask<S> {
138        SessionTask {
139            settings,
140            sessions,
141            active_sessions,
142            emitter,
143        }
144    }
145
146    pub async fn run(
147        self,
148        peer_addr: SocketAddr,
149        reader: impl AsyncRead + Unpin + 'static,
150        writer: impl AsyncWrite + Unpin + 'static,
151    ) {
152        let span = info_span!("connection", %peer_addr);
153
154        span.in_scope(|| {
155            info!("New connection");
156        });
157
158        acceptor_connection(
159            reader,
160            writer,
161            self.settings,
162            self.sessions,
163            self.active_sessions,
164            self.emitter,
165        )
166        .instrument(span.clone())
167        .await;
168
169        span.in_scope(|| {
170            info!("Connection closed");
171        });
172    }
173}
174
175pub(crate) type ActiveSessionsMap<S> = HashMap<SessionId, Rc<Session<S>>>;
176
177#[pin_project]
178pub struct Acceptor<S> {
179    sessions: Rc<RefCell<SessionsMap<S>>>,
180    active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
181    session_task: SessionTask<S>,
182    #[pin]
183    event_stream: EventStream,
184}
185
186impl<S: MessagesStorage + 'static> Acceptor<S> {
187    pub fn new(
188        settings: Settings,
189        message_storage_builder: Box<dyn Fn(&SessionId) -> S>,
190    ) -> Acceptor<S> {
191        let (emitter, event_stream) = events_channel();
192        let sessions = Rc::new(RefCell::new(SessionsMap::new(message_storage_builder)));
193        let active_sessions = Rc::new(RefCell::new(HashMap::new()));
194        let session_task_builder =
195            SessionTask::new(settings, sessions.clone(), active_sessions.clone(), emitter);
196
197        Acceptor {
198            sessions,
199            active_sessions,
200            session_task: session_task_builder,
201            event_stream,
202        }
203    }
204
205    pub fn register_session(&mut self, session_id: SessionId, session_settings: SessionSettings) {
206        self.sessions
207            .borrow_mut()
208            .register_session(session_id, session_settings);
209    }
210
211    pub fn sessions_map(&self) -> Rc<RefCell<SessionsMap<S>>> {
212        self.sessions.clone()
213    }
214
215    pub fn start(&self, connection: impl Connection + 'static) -> JoinHandle<()> {
216        tokio::task::spawn_local(Self::server_task(connection, self.session_task.clone()))
217    }
218
219    pub fn logout(
220        &self,
221        session_id: &SessionId,
222        session_status: Option<SessionStatus>,
223        reason: Option<FixString>,
224    ) {
225        let active_sessions = self.active_sessions.borrow();
226        let Some(session) = active_sessions.get(session_id) else {
227            warn!("logout: session {session_id} not found");
228            return;
229        };
230
231        session.send_logout(&mut session.state().borrow_mut(), session_status, reason);
232    }
233
234    pub fn disconnect(&self, session_id: &SessionId) {
235        let mut active_sessions = self.active_sessions.borrow_mut();
236        let Some(session) = active_sessions.remove(session_id) else {
237            warn!("logout: session {session_id} not found");
238            return;
239        };
240
241        session.disconnect(
242            &mut session.state().borrow_mut(),
243            DisconnectReason::ApplicationForcedDisconnect,
244        );
245    }
246
247    /// Force reset of the session
248    ///
249    /// Functionally equivalent to `reset_on_logon/logout/disconnect` settings,
250    /// but triggered manually.
251    ///
252    /// You may call this after [Self::disconnect] if you want to manually reset the connection
253    pub fn reset(&self, session_id: &SessionId) {
254        let active_sessions = self.active_sessions.borrow();
255        let Some(session) = active_sessions.get(session_id) else {
256            warn!("reset: session {session_id} not found");
257            return;
258        };
259
260        session.reset(&mut session.state().borrow_mut());
261    }
262
263    /// Sender seq_num getter
264    #[instrument(skip(self))]
265    pub fn next_sender_msg_seq_num(&self, session_id: &SessionId) -> SeqNum {
266        let active_sessions = self.active_sessions.borrow();
267        let Some(session) = active_sessions.get(session_id) else {
268            warn!("session not found");
269            return 0;
270        };
271
272        let state = session.state().borrow_mut();
273        state.next_sender_msg_seq_num()
274    }
275
276    /// Override sender's next seq_num
277    #[instrument(skip(self))]
278    pub fn set_next_sender_msg_seq_num(&self, session_id: &SessionId, seq_num: SeqNum) {
279        let active_sessions = self.active_sessions.borrow();
280        let Some(session) = active_sessions.get(session_id) else {
281            warn!("session not found");
282            return;
283        };
284
285        session
286            .state()
287            .borrow_mut()
288            .set_next_sender_msg_seq_num(seq_num);
289    }
290
291    async fn server_task(mut connection: impl Connection, session_task: SessionTask<S>) {
292        info!("Acceptor started");
293        loop {
294            match connection.accept().await {
295                Ok((reader, writer, peer_addr)) => {
296                    tokio::task::spawn_local(session_task.clone().run(peer_addr, reader, writer));
297                }
298                Err(err) => error!("server task failed to accept incoming connection: {err}"),
299            }
300        }
301    }
302
303    pub fn session_task(&self) -> SessionTask<S> {
304        self.session_task.clone()
305    }
306
307    pub fn run_session_task(
308        &self,
309        peer_addr: SocketAddr,
310        reader: impl AsyncRead + Unpin + 'static,
311        writer: impl AsyncWrite + Unpin + 'static,
312    ) -> impl Future<Output = ()> {
313        self.session_task.clone().run(peer_addr, reader, writer)
314    }
315
316    pub fn drop_all_connections(&self) {}
317}
318
319impl<S: MessagesStorage> Stream for Acceptor<S> {
320    type Item = impl AsEvent;
321
322    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
323        Pin::new(&mut self.event_stream).poll_next(cx)
324    }
325
326    fn size_hint(&self) -> (usize, Option<usize>) {
327        self.event_stream.size_hint()
328    }
329}