easyfix_session/
io.rs

1use std::{
2    cell::{Cell, RefCell},
3    collections::{HashMap, hash_map::Entry},
4    rc::Rc,
5    sync::Mutex,
6    time::Duration,
7};
8
9use easyfix_messages::{
10    fields::{FixString, SessionStatus},
11    messages::{FixtMessage, Message},
12};
13use futures_util::{Stream, pin_mut};
14use tokio::{
15    self,
16    io::{AsyncRead, AsyncWrite, AsyncWriteExt},
17    net::TcpStream,
18    sync::{mpsc, oneshot},
19};
20use tokio_stream::StreamExt;
21use tracing::{Instrument, Span, debug, error, info, info_span, warn};
22
23use crate::{
24    DisconnectReason, Error, NO_INBOUND_TIMEOUT_PADDING, Sender, SessionError,
25    TEST_REQUEST_THRESHOLD,
26    acceptor::{ActiveSessionsMap, SessionsMap},
27    application::{Emitter, FixEventInternal},
28    messages_storage::MessagesStorage,
29    session::Session,
30    session_id::SessionId,
31    session_state::State,
32    settings::{SessionSettings, Settings},
33};
34
35mod input_stream;
36pub use input_stream::{InputEvent, InputStream, input_stream};
37
38mod output_stream;
39use output_stream::{OutputEvent, output_stream};
40
41pub mod time;
42use time::{timeout, timeout_at, timeout_stream};
43
44static SENDERS: Mutex<Option<HashMap<SessionId, Sender>>> = Mutex::new(None);
45
46pub fn register_sender(session_id: SessionId, sender: Sender) {
47    if let Entry::Vacant(entry) = SENDERS
48        .lock()
49        .unwrap()
50        .get_or_insert_with(HashMap::new)
51        .entry(session_id)
52    {
53        entry.insert(sender);
54    }
55}
56
57pub fn unregister_sender(session_id: &SessionId) {
58    if SENDERS
59        .lock()
60        .unwrap()
61        .get_or_insert_with(HashMap::new)
62        .remove(session_id)
63        .is_none()
64    {
65        // TODO: ERROR?
66    }
67}
68
69pub fn sender(session_id: &SessionId) -> Option<Sender> {
70    SENDERS
71        .lock()
72        .unwrap()
73        .get_or_insert_with(HashMap::new)
74        .get(session_id)
75        .cloned()
76}
77
78// TODO: Remove?
79pub fn send(session_id: &SessionId, msg: Box<Message>) -> Result<(), Box<Message>> {
80    if let Some(sender) = sender(session_id) {
81        sender.send(msg).map_err(|msg| msg.body)
82    } else {
83        Err(msg)
84    }
85}
86
87pub fn send_raw(msg: Box<FixtMessage>) -> Result<(), Box<FixtMessage>> {
88    if let Some(sender) = sender(&SessionId::from_input_msg(&msg)) {
89        sender.send_raw(msg)
90    } else {
91        Err(msg)
92    }
93}
94
95async fn first_msg(
96    stream: &mut (impl Stream<Item = InputEvent> + Unpin),
97    logon_timeout: Duration,
98) -> Result<Box<FixtMessage>, Error> {
99    match timeout(logon_timeout, stream.next()).await {
100        Ok(Some(InputEvent::Message(msg))) => Ok(msg),
101        Ok(Some(InputEvent::IoError(error))) => Err(error.into()),
102        Ok(Some(InputEvent::DeserializeError(error))) => {
103            error!("failed to deserialize first message: {error}");
104            Err(Error::SessionError(SessionError::LogonNeverReceived))
105        }
106        _ => Err(Error::SessionError(SessionError::LogonNeverReceived)),
107    }
108}
109
/// A single network connection bound to one [`Session`].
///
/// It owns the shared session handle used by the connection's input and
/// output loops.
#[derive(Debug)]
struct Connection<S> {
    /// Session this connection reads messages for and writes messages from.
    session: Rc<Session<S>>,
}
114
115pub(crate) async fn acceptor_connection<S>(
116    reader: impl AsyncRead + Unpin,
117    writer: impl AsyncWrite + Unpin,
118    settings: Settings,
119    sessions: Rc<RefCell<SessionsMap<S>>>,
120    active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
121    emitter: Emitter,
122    enabled: Rc<Cell<bool>>,
123) where
124    S: MessagesStorage,
125{
126    let stream = input_stream(reader);
127    let logon_timeout =
128        settings.auto_disconnect_after_no_logon_received + NO_INBOUND_TIMEOUT_PADDING;
129    pin_mut!(stream);
130    let msg = match first_msg(&mut stream, logon_timeout).await {
131        Ok(msg) => msg,
132        Err(err) => {
133            error!(%err, "failed to establish new session");
134            return;
135        }
136    };
137
138    let session_id = SessionId::from_input_msg(&msg);
139    debug!(first_msg = ?msg);
140
141    // XXX: there should be no await point between active_sessions.insert below
142    if !enabled.get() {
143        warn!("Acceptor is disabled, drop connection");
144        return;
145    }
146
147    let (sender, receiver) = mpsc::unbounded_channel();
148    let sender = Sender::new(sender);
149
150    let Some((session_settings, session_state)) = sessions.borrow().get_session(&session_id) else {
151        error!(%session_id, "failed to establish new session: unknown session id");
152        return;
153    };
154    if !session_state.borrow_mut().disconnected()
155        || active_sessions.borrow().contains_key(&session_id)
156    {
157        error!(%session_id, "Session already active");
158        return;
159    }
160    session_state.borrow_mut().set_disconnected(false);
161    register_sender(session_id.clone(), sender.clone());
162
163    let (disconnect_tx, disconnect_rx) = oneshot::channel();
164
165    let session = Rc::new(Session::new(
166        settings,
167        session_settings,
168        session_state,
169        sender,
170        emitter.clone(),
171        disconnect_tx,
172    ));
173
174    active_sessions
175        .borrow_mut()
176        .insert(session_id.clone(), session.clone());
177
178    let session_span = info_span!(
179        parent: None,
180        "session",
181        id = %session_id
182    );
183    session_span.follows_from(Span::current());
184
185    let input_loop_span = info_span!(parent: &session_span, "in");
186    let output_loop_span = info_span!(parent: &session_span, "out");
187
188    let force_disconnection_with_reason = session
189        .on_message_in(msg)
190        .instrument(input_loop_span.clone())
191        .await;
192
193    // TODO: Not here!, send this event when SessionState is created!
194    emitter
195        .send(FixEventInternal::Created(session_id.clone()))
196        .await;
197
198    let input_timeout_duration = session.heartbeat_interval().mul_f32(TEST_REQUEST_THRESHOLD);
199    let input_stream = timeout_stream(input_timeout_duration, stream)
200        .map(|res| res.unwrap_or(InputEvent::Timeout));
201    pin_mut!(input_stream);
202
203    let output_stream = output_stream(session.clone(), session.heartbeat_interval(), receiver);
204    pin_mut!(output_stream);
205
206    let connection = Connection::new(session);
207    let (input_closed_tx, input_closed_rx) = oneshot::channel();
208
209    tokio::join!(
210        connection
211            .input_loop(
212                input_stream,
213                input_closed_tx,
214                force_disconnection_with_reason,
215                disconnect_rx,
216            )
217            .instrument(input_loop_span),
218        connection
219            .output_loop(writer, output_stream, input_closed_rx)
220            .instrument(output_loop_span),
221    );
222    session_span.in_scope(|| {
223        info!("connection closed");
224    });
225    unregister_sender(&session_id);
226    active_sessions.borrow_mut().remove(&session_id);
227}
228
229pub(crate) async fn initiator_connection<S>(
230    tcp_stream: TcpStream,
231    settings: Settings,
232    session_settings: SessionSettings,
233    state: Rc<RefCell<State<S>>>,
234    active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
235    emitter: Emitter,
236) where
237    S: MessagesStorage,
238{
239    let (source, sink) = tcp_stream.into_split();
240    state.borrow_mut().set_disconnected(false);
241    let session_id = session_settings.session_id.clone();
242
243    let (sender, receiver) = mpsc::unbounded_channel();
244    let sender = Sender::new(sender);
245
246    let (disconnect_tx, disconnect_rx) = oneshot::channel();
247
248    register_sender(session_id.clone(), sender.clone());
249    let session = Rc::new(Session::new(
250        settings,
251        session_settings,
252        state,
253        sender,
254        emitter.clone(),
255        disconnect_tx,
256    ));
257    active_sessions
258        .borrow_mut()
259        .insert(session_id.clone(), session.clone());
260
261    let session_span = info_span!(
262        "session",
263        id = %session_id
264    );
265
266    let input_loop_span = info_span!(parent: &session_span, "in");
267    let output_loop_span = info_span!(parent: &session_span, "out");
268
269    // TODO: Not here!, send this event when SessionState is created!
270    emitter
271        .send(FixEventInternal::Created(session_id.clone()))
272        .await;
273
274    let input_timeout_duration = session.heartbeat_interval().mul_f32(TEST_REQUEST_THRESHOLD);
275    let input_stream = timeout_stream(input_timeout_duration, input_stream(source))
276        .map(|res| res.unwrap_or(InputEvent::Timeout));
277    pin_mut!(input_stream);
278
279    let output_stream = output_stream(session.clone(), session.heartbeat_interval(), receiver);
280    pin_mut!(output_stream);
281
282    // TODO: It's not so simple, add check if session time is within range,
283    //       if not schedule timer to send logon at proper time
284    session.send_logon_request(&mut session.state().borrow_mut());
285
286    let connection = Connection::new(session);
287    let (input_closed_tx, input_closed_rx) = oneshot::channel();
288
289    tokio::join!(
290        connection
291            .input_loop(input_stream, input_closed_tx, None, disconnect_rx)
292            .instrument(input_loop_span),
293        connection
294            .output_loop(sink, output_stream, input_closed_rx)
295            .instrument(output_loop_span),
296    );
297    info!("connection closed");
298    unregister_sender(&session_id);
299    active_sessions.borrow_mut().remove(&session_id);
300}
301
302impl<S: MessagesStorage> Connection<S> {
303    fn new(session: Rc<Session<S>>) -> Connection<S> {
304        Connection { session }
305    }
306
307    async fn input_loop(
308        &self,
309        mut input_stream: impl Stream<Item = InputEvent> + Unpin,
310        input_closed_tx: oneshot::Sender<()>,
311        force_disconnection_with_reason: Option<DisconnectReason>,
312        mut disconnect_rx: oneshot::Receiver<()>,
313    ) {
314        if let Some(disconnect_reason) = force_disconnection_with_reason {
315            self.session
316                .disconnect(&mut self.session.state().borrow_mut(), disconnect_reason);
317
318            // Notify output loop that all input is processed so output queue can
319            // be safely closed.
320            // See `fn send()` and `fn send_raw()` from session.rs.
321            input_closed_tx
322                .send(())
323                .expect("Failed to notify about closed inpuot");
324
325            return;
326        }
327
328        let mut disconnect_reason = DisconnectReason::Disconnected;
329        let mut logout_deadline = None;
330
331        let mut next_item = async || {
332            if logout_deadline.is_none() {
333                logout_deadline = self.session.logout_deadline();
334            }
335            if let Some(logout_deadline) = logout_deadline {
336                timeout_at(logout_deadline, input_stream.next())
337                    .await
338                    .unwrap_or(Some(InputEvent::LogoutTimeout))
339            } else {
340                input_stream.next().await
341            }
342        };
343
344        loop {
345            let event = tokio::select! {
346                // Wait for network input
347                event = next_item() => {
348                    if let Some(event) = event {
349                        event
350                    } else {
351                        break
352                    }
353                }
354
355                // Wait for disconnect signal from Session::disconnect()
356                _ = &mut disconnect_rx => {
357                    info!("Disconnect signaled, exiting input loop");
358                    disconnect_reason = DisconnectReason::ApplicationForcedDisconnect;
359                    break;
360                }
361            };
362
363            // Don't accept new messages if session is disconnected.
364            if self.session.state().borrow().disconnected() {
365                info!("session disconnected, exit input processing");
366                // Notify output loop that all input is processed so output queue can
367                // be safely closed.
368                // See `fn send()` and `fn send_raw()` from session.rs.
369                input_closed_tx
370                    .send(())
371                    .expect("Failed to notify about closed input");
372                return;
373            }
374
375            match event {
376                InputEvent::Message(msg) => {
377                    if let Some(reason) = self.session.on_message_in(msg).await {
378                        info!(?reason, "disconnect, exit input processing");
379                        disconnect_reason = reason;
380                        break;
381                    }
382                }
383                InputEvent::DeserializeError(error) => {
384                    if let Some(reason) = self.session.on_deserialize_error(error).await {
385                        info!(?reason, "disconnect, exit input processing");
386                        disconnect_reason = reason;
387                        break;
388                    }
389                }
390                InputEvent::IoError(error) => {
391                    error!(%error, "Input error");
392                    disconnect_reason = DisconnectReason::IoError;
393                    break;
394                }
395                InputEvent::Timeout => {
396                    if self.session.on_in_timeout().await {
397                        self.session.send_logout(
398                            &mut self.session.state().borrow_mut(),
399                            Some(SessionStatus::SessionLogoutComplete),
400                            Some(FixString::from_ascii_lossy(
401                                b"Grace period is over".to_vec(),
402                            )),
403                        );
404                        break;
405                    }
406                }
407                InputEvent::LogoutTimeout => {
408                    info!("Logout timeout");
409                    disconnect_reason = DisconnectReason::LogoutTimeout;
410                    break;
411                }
412            }
413        }
414        self.session
415            .disconnect(&mut self.session.state().borrow_mut(), disconnect_reason);
416
417        // Notify output loop that all input is processed so output queue can
418        // be safely closed.
419        // See `fn send()` and `fn send_raw()` from session.rs.
420        input_closed_tx
421            .send(())
422            .expect("Failed to notify about closed inpout");
423    }
424
425    async fn output_loop(
426        &self,
427        mut sink: impl AsyncWrite + Unpin,
428        mut output_stream: impl Stream<Item = OutputEvent> + Unpin,
429        input_closed_rx: oneshot::Receiver<()>,
430    ) {
431        let mut sink_closed = false;
432        let mut disconnect_reason = DisconnectReason::Disconnected;
433        while let Some(event) = output_stream.next().await {
434            match event {
435                OutputEvent::Message(msg) => {
436                    if sink_closed {
437                        // Sink is closed - ignore message, but do not break
438                        // the loop. Output stream has to process all enqueued
439                        // messages to made them available
440                        // for ResendRequest<2>.
441                        info!("Client disconnected, message will be stored for further resend");
442                    } else if let Err(error) = sink.write_all(&msg).await {
443                        sink_closed = true;
444                        error!(%error, "Output write error");
445                        // XXX: Don't disconnect now. If IO error happened
446                        //      here, it will aslo happen in input loop
447                        //      and input loop will trigger disconnection.
448                        //      Disonnection from here would lead to message
449                        //      loss when output queue would be closed
450                        //      and input handler would try to send something.
451                        //
452                        // self.session.disconnect(
453                        //     &mut self.session.state().borrow_mut(),
454                        //     DisconnectReason::IoError,
455                        // );
456                    }
457                }
458                OutputEvent::Timeout => self.session.on_out_timeout().await,
459                OutputEvent::Disconnect(reason) => {
460                    // Internal channel is closed in output stream
461                    // inplementation, at this point no new messages
462                    // can be send.
463                    info!("Client disconnected");
464                    if !sink_closed && let Err(error) = sink.flush().await {
465                        error!(%error, "final flush failed");
466                    }
467                    disconnect_reason = reason;
468                }
469            }
470        }
471        // XXX: Emit logout here instead of Session::disconnect, so `Logout`
472        //      event will be delivered after Logout message instead of
473        //      randomly before or after.
474        self.session.emit_logout(disconnect_reason).await;
475
476        // Don't wait for any specific value it's just notification that
477        // input_loop finished, so no more messages can be added to output
478        // queue.
479        let _ = input_closed_rx.await;
480        if let Err(error) = sink.shutdown().await {
481            error!(%error, "connection shutdown failed")
482        }
483        info!("disconnect, exit output processing");
484    }
485}