easyfix_session/
io.rs

1use std::{
2    cell::RefCell,
3    collections::{HashMap, hash_map::Entry},
4    rc::Rc,
5    sync::Mutex,
6    time::Duration,
7};
8
9use easyfix_messages::{
10    fields::{FixString, SessionStatus},
11    messages::{FixtMessage, Message},
12};
13use futures_util::{Stream, pin_mut};
14use tokio::{
15    self,
16    io::{AsyncRead, AsyncWrite, AsyncWriteExt},
17    net::TcpStream,
18    sync::mpsc,
19};
20use tokio_stream::StreamExt;
21use tracing::{Instrument, Span, debug, error, info, info_span};
22
23use crate::{
24    DisconnectReason, Error, NO_INBOUND_TIMEOUT_PADDING, Sender, SessionError,
25    TEST_REQUEST_THRESHOLD,
26    acceptor::{ActiveSessionsMap, SessionsMap},
27    application::{Emitter, FixEventInternal},
28    messages_storage::MessagesStorage,
29    session::Session,
30    session_id::SessionId,
31    session_state::State,
32    settings::{SessionSettings, Settings},
33};
34
35mod input_stream;
36pub use input_stream::{InputEvent, InputStream, input_stream};
37
38mod output_stream;
39use output_stream::{OutputEvent, output_stream};
40
41pub mod time;
42use time::{timeout, timeout_at, timeout_stream};
43
/// Process-wide registry of outbound [`Sender`]s keyed by [`SessionId`].
///
/// The `Option` is `None` until the map is first needed; accessors create it
/// on demand with `get_or_insert_with(HashMap::new)`. Entries are added by
/// [`register_sender`] and removed by [`unregister_sender`].
static SENDERS: Mutex<Option<HashMap<SessionId, Sender>>> = Mutex::new(None);
45
46pub fn register_sender(session_id: SessionId, sender: Sender) {
47    if let Entry::Vacant(entry) = SENDERS
48        .lock()
49        .unwrap()
50        .get_or_insert_with(HashMap::new)
51        .entry(session_id)
52    {
53        entry.insert(sender);
54    }
55}
56
57pub fn unregister_sender(session_id: &SessionId) {
58    if SENDERS
59        .lock()
60        .unwrap()
61        .get_or_insert_with(HashMap::new)
62        .remove(session_id)
63        .is_none()
64    {
65        // TODO: ERROR?
66    }
67}
68
69pub fn sender(session_id: &SessionId) -> Option<Sender> {
70    SENDERS
71        .lock()
72        .unwrap()
73        .get_or_insert_with(HashMap::new)
74        .get(session_id)
75        .cloned()
76}
77
78// TODO: Remove?
79pub fn send(session_id: &SessionId, msg: Box<Message>) -> Result<(), Box<Message>> {
80    if let Some(sender) = sender(session_id) {
81        sender.send(msg).map_err(|msg| msg.body)
82    } else {
83        Err(msg)
84    }
85}
86
87pub fn send_raw(msg: Box<FixtMessage>) -> Result<(), Box<FixtMessage>> {
88    if let Some(sender) = sender(&SessionId::from_input_msg(&msg)) {
89        sender.send_raw(msg)
90    } else {
91        Err(msg)
92    }
93}
94
95async fn first_msg(
96    stream: &mut (impl Stream<Item = InputEvent> + Unpin),
97    logon_timeout: Duration,
98) -> Result<Box<FixtMessage>, Error> {
99    match timeout(logon_timeout, stream.next()).await {
100        Ok(Some(InputEvent::Message(msg))) => Ok(msg),
101        Ok(Some(InputEvent::IoError(error))) => Err(error.into()),
102        Ok(Some(InputEvent::DeserializeError(error))) => {
103            error!("failed to deserialize first message: {error}");
104            Err(Error::SessionError(SessionError::LogonNeverReceived))
105        }
106        _ => Err(Error::SessionError(SessionError::LogonNeverReceived)),
107    }
108}
109
/// Drives the I/O of one established connection for a single session.
///
/// The input and output loops of a connection both run against the same
/// `Connection` value (they are joined with `tokio::join!`).
#[derive(Debug)]
struct Connection<S> {
    /// Shared handle to the session this connection carries traffic for.
    session: Rc<Session<S>>,
}
114
115pub(crate) async fn acceptor_connection<S>(
116    reader: impl AsyncRead + Unpin,
117    writer: impl AsyncWrite + Unpin,
118    settings: Settings,
119    sessions: Rc<RefCell<SessionsMap<S>>>,
120    active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
121    emitter: Emitter,
122) where
123    S: MessagesStorage,
124{
125    let stream = input_stream(reader);
126    let logon_timeout =
127        settings.auto_disconnect_after_no_logon_received + NO_INBOUND_TIMEOUT_PADDING;
128    pin_mut!(stream);
129    let msg = match first_msg(&mut stream, logon_timeout).await {
130        Ok(msg) => msg,
131        Err(err) => {
132            error!("failed to establish new session: {err}");
133            return;
134        }
135    };
136    let session_id = SessionId::from_input_msg(&msg);
137    debug!("first_msg: {msg:?}");
138
139    let (sender, receiver) = mpsc::unbounded_channel();
140    let sender = Sender::new(sender);
141
142    let Some((session_settings, session_state)) = sessions.borrow().get_session(&session_id) else {
143        error!("failed to establish new session: unknown session id {session_id}");
144        return;
145    };
146    if !session_state.borrow_mut().disconnected()
147        || active_sessions.borrow().contains_key(&session_id)
148    {
149        error!(%session_id, "Session already active");
150        return;
151    }
152    session_state.borrow_mut().set_disconnected(false);
153    register_sender(session_id.clone(), sender.clone());
154    let session = Rc::new(Session::new(
155        settings,
156        session_settings,
157        session_state,
158        sender,
159        emitter.clone(),
160    ));
161    active_sessions
162        .borrow_mut()
163        .insert(session_id.clone(), session.clone());
164
165    let session_span = info_span!(
166        parent: None,
167        "session",
168        id = %session_id
169    );
170    session_span.follows_from(Span::current());
171
172    let input_loop_span = info_span!(parent: &session_span, "in");
173    let output_loop_span = info_span!(parent: &session_span, "out");
174
175    let force_disconnection_with_reason = session
176        .on_message_in(msg)
177        .instrument(input_loop_span.clone())
178        .await;
179
180    // TODO: Not here!, send this event when SessionState is created!
181    emitter
182        .send(FixEventInternal::Created(session_id.clone()))
183        .await;
184
185    let input_timeout_duration = session.heartbeat_interval().mul_f32(TEST_REQUEST_THRESHOLD);
186    let input_stream = timeout_stream(input_timeout_duration, stream)
187        .map(|res| res.unwrap_or(InputEvent::Timeout));
188    pin_mut!(input_stream);
189
190    let output_stream = output_stream(session.clone(), session.heartbeat_interval(), receiver);
191    pin_mut!(output_stream);
192
193    let connection = Connection::new(session);
194    let (input_closed_tx, input_closed_rx) = tokio::sync::oneshot::channel();
195
196    tokio::join!(
197        connection
198            .input_loop(
199                input_stream,
200                input_closed_tx,
201                force_disconnection_with_reason
202            )
203            .instrument(input_loop_span),
204        connection
205            .output_loop(writer, output_stream, input_closed_rx)
206            .instrument(output_loop_span),
207    );
208    session_span.in_scope(|| {
209        info!("connection closed");
210    });
211    unregister_sender(&session_id);
212    active_sessions.borrow_mut().remove(&session_id);
213}
214
215pub(crate) async fn initiator_connection<S>(
216    tcp_stream: TcpStream,
217    settings: Settings,
218    session_settings: SessionSettings,
219    state: Rc<RefCell<State<S>>>,
220    active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
221    emitter: Emitter,
222) where
223    S: MessagesStorage,
224{
225    let (source, sink) = tcp_stream.into_split();
226    state.borrow_mut().set_disconnected(false);
227    let session_id = session_settings.session_id.clone();
228
229    let (sender, receiver) = mpsc::unbounded_channel();
230    let sender = Sender::new(sender);
231
232    register_sender(session_id.clone(), sender.clone());
233    let session = Rc::new(Session::new(
234        settings,
235        session_settings,
236        state,
237        sender,
238        emitter.clone(),
239    ));
240    active_sessions
241        .borrow_mut()
242        .insert(session_id.clone(), session.clone());
243
244    let session_span = info_span!(
245        "session",
246        id = %session_id
247    );
248
249    let input_loop_span = info_span!(parent: &session_span, "in");
250    let output_loop_span = info_span!(parent: &session_span, "out");
251
252    // TODO: Not here!, send this event when SessionState is created!
253    emitter
254        .send(FixEventInternal::Created(session_id.clone()))
255        .await;
256
257    let input_timeout_duration = session.heartbeat_interval().mul_f32(TEST_REQUEST_THRESHOLD);
258    let input_stream = timeout_stream(input_timeout_duration, input_stream(source))
259        .map(|res| res.unwrap_or(InputEvent::Timeout));
260    pin_mut!(input_stream);
261
262    let output_stream = output_stream(session.clone(), session.heartbeat_interval(), receiver);
263    pin_mut!(output_stream);
264
265    // TODO: It's not so simple, add check if session time is within range,
266    //       if not schedule timer to send logon at proper time
267    session.send_logon_request(&mut session.state().borrow_mut());
268
269    let connection = Connection::new(session);
270    let (input_closed_tx, input_closed_rx) = tokio::sync::oneshot::channel();
271
272    tokio::join!(
273        connection
274            .input_loop(input_stream, input_closed_tx, None)
275            .instrument(input_loop_span),
276        connection
277            .output_loop(sink, output_stream, input_closed_rx)
278            .instrument(output_loop_span),
279    );
280    info!("connection closed");
281    unregister_sender(&session_id);
282    active_sessions.borrow_mut().remove(&session_id);
283}
284
285impl<S: MessagesStorage> Connection<S> {
286    fn new(session: Rc<Session<S>>) -> Connection<S> {
287        Connection { session }
288    }
289
290    async fn input_loop(
291        &self,
292        mut input_stream: impl Stream<Item = InputEvent> + Unpin,
293        input_closed_tx: tokio::sync::oneshot::Sender<()>,
294        force_disconnection_with_reason: Option<DisconnectReason>,
295    ) {
296        if let Some(disconnect_reason) = force_disconnection_with_reason {
297            self.session
298                .disconnect(&mut self.session.state().borrow_mut(), disconnect_reason);
299
300            // Notify output loop that all input is processed so output queue can
301            // be safely closed.
302            // See `fn send()` and `fn send_raw()` from session.rs.
303            input_closed_tx
304                .send(())
305                .expect("Failed to notify about closed inpuot");
306
307            return;
308        }
309
310        let mut disconnect_reason = DisconnectReason::Disconnected;
311        let mut logout_deadline = None;
312
313        let mut next_item = async || {
314            if logout_deadline.is_none() {
315                logout_deadline = self.session.logout_deadline();
316            }
317            if let Some(logout_deadline) = logout_deadline {
318                timeout_at(logout_deadline, input_stream.next())
319                    .await
320                    .unwrap_or(Some(InputEvent::LogoutTimeout))
321            } else {
322                input_stream.next().await
323            }
324        };
325
326        while let Some(event) = next_item().await {
327            // Don't accept new messages if session is disconnected.
328            if self.session.state().borrow().disconnected() {
329                info!("session disconnected, exit input processing");
330                // Notify output loop that all input is processed so output queue can
331                // be safely closed.
332                // See `fn send()` and `fn send_raw()` from session.rs.
333                input_closed_tx
334                    .send(())
335                    .expect("Failed to notify about closed inpout");
336                return;
337            }
338            match event {
339                InputEvent::Message(msg) => {
340                    if let Some(reason) = self.session.on_message_in(msg).await {
341                        info!(?reason, "disconnect, exit input processing");
342                        disconnect_reason = reason;
343                        break;
344                    }
345                }
346                InputEvent::DeserializeError(error) => {
347                    if let Some(reason) = self.session.on_deserialize_error(error).await {
348                        info!(?reason, "disconnect, exit input processing");
349                        disconnect_reason = reason;
350                        break;
351                    }
352                }
353                InputEvent::IoError(error) => {
354                    error!(%error, "Input error");
355                    disconnect_reason = DisconnectReason::IoError;
356                    break;
357                }
358                InputEvent::Timeout => {
359                    if self.session.on_in_timeout().await {
360                        self.session.send_logout(
361                            &mut self.session.state().borrow_mut(),
362                            Some(SessionStatus::SessionLogoutComplete),
363                            Some(FixString::from_ascii_lossy(
364                                b"Grace period is over".to_vec(),
365                            )),
366                        );
367                        break;
368                    }
369                }
370                InputEvent::LogoutTimeout => {
371                    info!("Logout timeout");
372                    disconnect_reason = DisconnectReason::LogoutTimeout;
373                    break;
374                }
375            }
376        }
377        self.session
378            .disconnect(&mut self.session.state().borrow_mut(), disconnect_reason);
379
380        // Notify output loop that all input is processed so output queue can
381        // be safely closed.
382        // See `fn send()` and `fn send_raw()` from session.rs.
383        input_closed_tx
384            .send(())
385            .expect("Failed to notify about closed inpout");
386    }
387
388    async fn output_loop(
389        &self,
390        mut sink: impl AsyncWrite + Unpin,
391        mut output_stream: impl Stream<Item = OutputEvent> + Unpin,
392        input_closed_rx: tokio::sync::oneshot::Receiver<()>,
393    ) {
394        let mut sink_closed = false;
395        let mut disconnect_reason = DisconnectReason::Disconnected;
396        while let Some(event) = output_stream.next().await {
397            match event {
398                OutputEvent::Message(msg) => {
399                    if sink_closed {
400                        // Sink is closed - ignore message, but do not break
401                        // the loop. Output stream has to process all enqueued
402                        // messages to made them available
403                        // for ResendRequest<2>.
404                        info!("Client disconnected, message will be stored for further resend");
405                    } else if let Err(error) = sink.write_all(&msg).await {
406                        sink_closed = true;
407                        error!(%error, "Output write error");
408                        // XXX: Don't disconnect now. If IO error happened
409                        //      here, it will aslo happen in input loop
410                        //      and input loop will trigger disconnection.
411                        //      Disonnection from here would lead to message
412                        //      loss when output queue would be closed
413                        //      and input handler would try to send something.
414                        //
415                        // self.session.disconnect(
416                        //     &mut self.session.state().borrow_mut(),
417                        //     DisconnectReason::IoError,
418                        // );
419                    }
420                }
421                OutputEvent::Timeout => self.session.on_out_timeout().await,
422                OutputEvent::Disconnect(reason) => {
423                    // Internal channel is closed in output stream
424                    // inplementation, at this point no new messages
425                    // can be send.
426                    info!("Client disconnected");
427                    if !sink_closed && let Err(error) = sink.flush().await {
428                        error!(%error, "final flush failed");
429                    }
430                    disconnect_reason = reason;
431                }
432            }
433        }
434        // XXX: Emit logout here instead of Session::disconnect, so `Logout`
435        //      event will be delivered after Logout message instead of
436        //      randomly before or after.
437        self.session.emit_logout(disconnect_reason).await;
438
439        // Don't wait for any specific value it's just notification that
440        // input_loop finished, so no more messages can be added to output
441        // queue.
442        let _ = input_closed_rx.await;
443        if let Err(error) = sink.shutdown().await {
444            error!(%error, "connection shutdown failed")
445        }
446        info!("disconnect, exit output processing");
447    }
448}