easyfix_session/
io.rs

1use std::{
2    cell::RefCell,
3    collections::{hash_map::Entry, HashMap},
4    rc::Rc,
5    sync::Mutex,
6};
7
8use easyfix_messages::{
9    fields::{FixString, SessionStatus},
10    messages::{FixtMessage, Message},
11};
12use futures_util::{pin_mut, Stream};
13use tokio::{
14    self,
15    io::{AsyncRead, AsyncWrite, AsyncWriteExt},
16    net::TcpStream,
17    sync::mpsc,
18    time::Duration,
19};
20use tokio_stream::StreamExt;
21use tracing::{debug, error, info, info_span, Instrument};
22
23use crate::{
24    acceptor::{ActiveSessionsMap, SessionsMap},
25    application::{Emitter, FixEventInternal},
26    messages_storage::MessagesStorage,
27    session::Session,
28    session_id::SessionId,
29    session_state::State,
30    settings::{SessionSettings, Settings},
31    DisconnectReason, Error, Sender, SessionError, NO_INBOUND_TIMEOUT_PADDING,
32    TEST_REQUEST_THRESHOLD,
33};
34
35mod input_stream;
36pub use input_stream::{input_stream, InputEvent, InputStream};
37
38mod output_stream;
39use output_stream::{output_stream, OutputEvent};
40
41pub mod time;
42use time::{timeout, timeout_stream};
43
44static SENDERS: Mutex<Option<HashMap<SessionId, Sender>>> = Mutex::new(None);
45
46pub fn register_sender(session_id: SessionId, sender: Sender) {
47    if let Entry::Vacant(entry) = SENDERS
48        .lock()
49        .unwrap()
50        .get_or_insert_with(HashMap::new)
51        .entry(session_id)
52    {
53        entry.insert(sender);
54    }
55}
56
57pub fn unregister_sender(session_id: &SessionId) {
58    if SENDERS
59        .lock()
60        .unwrap()
61        .get_or_insert_with(HashMap::new)
62        .remove(session_id)
63        .is_none()
64    {
65        // TODO: ERROR?
66    }
67}
68
69pub fn sender(session_id: &SessionId) -> Option<Sender> {
70    SENDERS
71        .lock()
72        .unwrap()
73        .get_or_insert_with(HashMap::new)
74        .get(session_id)
75        .cloned()
76}
77
78// TODO: Remove?
79pub fn send(session_id: &SessionId, msg: Box<Message>) -> Result<(), Box<Message>> {
80    if let Some(sender) = sender(session_id) {
81        sender.send(msg).map_err(|msg| msg.body)
82    } else {
83        Err(msg)
84    }
85}
86
87pub fn send_raw(msg: Box<FixtMessage>) -> Result<(), Box<FixtMessage>> {
88    if let Some(sender) = sender(&SessionId::from_input_msg(&msg)) {
89        sender.send_raw(msg)
90    } else {
91        Err(msg)
92    }
93}
94
95async fn first_msg(
96    stream: &mut (impl Stream<Item = InputEvent> + Unpin),
97    logon_timeout: Duration,
98) -> Result<Box<FixtMessage>, Error> {
99    match timeout(logon_timeout, stream.next()).await {
100        Ok(Some(InputEvent::Message(msg))) => Ok(msg),
101        Ok(Some(InputEvent::IoError(error))) => Err(error.into()),
102        Ok(Some(InputEvent::DeserializeError(error))) => {
103            error!("failed to deserialize first message: {error}");
104            Err(Error::SessionError(SessionError::LogonNeverReceived))
105        }
106        _ => Err(Error::SessionError(SessionError::LogonNeverReceived)),
107    }
108}
109
/// Drives the input and output loops of a single connection for one
/// `Session` (see `input_loop` / `output_loop`).
#[derive(Debug)]
struct Connection<S> {
    /// The session served by this connection. Clones of the same `Rc` are
    /// also stored in the active-sessions map and handed to the output
    /// stream by the `*_connection` functions.
    session: Rc<Session<S>>,
}
114
115pub(crate) async fn acceptor_connection<S>(
116    reader: impl AsyncRead + Unpin,
117    writer: impl AsyncWrite + Unpin,
118    settings: Settings,
119    sessions: Rc<RefCell<SessionsMap<S>>>,
120    active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
121    emitter: Emitter,
122) where
123    S: MessagesStorage,
124{
125    let stream = input_stream(reader);
126    let logon_timeout =
127        settings.auto_disconnect_after_no_logon_received + NO_INBOUND_TIMEOUT_PADDING;
128    pin_mut!(stream);
129    let msg = match first_msg(&mut stream, logon_timeout).await {
130        Ok(msg) => msg,
131        Err(err) => {
132            error!("failed to establish new session: {err}");
133            return;
134        }
135    };
136    let session_id = SessionId::from_input_msg(&msg);
137    debug!("first_msg: {msg:?}");
138
139    let (sender, receiver) = mpsc::unbounded_channel();
140    let sender = Sender::new(sender);
141
142    let Some((session_settings, session_state)) = sessions.borrow().get_session(&session_id) else {
143        error!("failed to establish new session: unknown session id {session_id}");
144        return;
145    };
146    session_state.borrow_mut().set_disconnected(false);
147    register_sender(session_id.clone(), sender.clone());
148    let session = Rc::new(Session::new(
149        settings,
150        session_settings,
151        session_state,
152        sender,
153        emitter.clone(),
154    ));
155    active_sessions
156        .borrow_mut()
157        .insert(session_id.clone(), session.clone());
158
159    let session_span = info_span!(
160        "session",
161        id = %session_id
162    );
163
164    let input_loop_span = info_span!(parent: &session_span, "in");
165    let output_loop_span = info_span!(parent: &session_span, "out");
166
167    let force_disconnection_with_reason = session
168        .on_message_in(msg)
169        .instrument(input_loop_span.clone())
170        .await;
171
172    // TODO: Not here!, send this event when SessionState is created!
173    emitter
174        .send(FixEventInternal::Created(session_id.clone()))
175        .await;
176
177    let input_timeout_duration = session.heartbeat_interval().mul_f32(TEST_REQUEST_THRESHOLD);
178    let input_stream = timeout_stream(input_timeout_duration, stream)
179        .map(|res| res.unwrap_or(InputEvent::Timeout));
180    pin_mut!(input_stream);
181
182    let output_stream = output_stream(session.clone(), session.heartbeat_interval(), receiver);
183    pin_mut!(output_stream);
184
185    let connection = Connection::new(session);
186    let (input_closed_tx, input_closed_rx) = tokio::sync::oneshot::channel();
187
188    tokio::join!(
189        connection
190            .input_loop(
191                input_stream,
192                input_closed_tx,
193                force_disconnection_with_reason
194            )
195            .instrument(input_loop_span.clone()),
196        connection
197            .output_loop(writer, output_stream, input_closed_rx)
198            .instrument(output_loop_span),
199    );
200    session_span.in_scope(|| {
201        info!("connection closed");
202    });
203    unregister_sender(&session_id);
204    active_sessions.borrow_mut().remove(&session_id);
205}
206
207pub(crate) async fn initiator_connection<S>(
208    tcp_stream: TcpStream,
209    settings: Settings,
210    session_settings: SessionSettings,
211    state: Rc<RefCell<State<S>>>,
212    active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
213    emitter: Emitter,
214) where
215    S: MessagesStorage,
216{
217    let (source, sink) = tcp_stream.into_split();
218    state.borrow_mut().set_disconnected(false);
219    let session_id = session_settings.session_id.clone();
220
221    let (sender, receiver) = mpsc::unbounded_channel();
222    let sender = Sender::new(sender);
223
224    register_sender(session_id.clone(), sender.clone());
225    let session = Rc::new(Session::new(
226        settings,
227        session_settings,
228        state,
229        sender,
230        emitter.clone(),
231    ));
232    active_sessions
233        .borrow_mut()
234        .insert(session_id.clone(), session.clone());
235
236    let session_span = info_span!(
237        "session",
238        id = %session_id
239    );
240
241    let input_loop_span = info_span!(parent: &session_span, "in");
242    let output_loop_span = info_span!(parent: &session_span, "out");
243
244    // TODO: Not here!, send this event when SessionState is created!
245    emitter
246        .send(FixEventInternal::Created(session_id.clone()))
247        .await;
248
249    let input_timeout_duration = session.heartbeat_interval().mul_f32(TEST_REQUEST_THRESHOLD);
250    let input_stream = timeout_stream(input_timeout_duration, input_stream(source))
251        .map(|res| res.unwrap_or(InputEvent::Timeout));
252    pin_mut!(input_stream);
253
254    let output_stream = output_stream(session.clone(), session.heartbeat_interval(), receiver);
255    pin_mut!(output_stream);
256
257    // TODO: It's not so simple, add check if session time is within range,
258    //       if not schedule timer to send logon at proper time
259    session.send_logon_request(&mut session.state().borrow_mut());
260
261    let connection = Connection::new(session);
262    let (input_closed_tx, input_closed_rx) = tokio::sync::oneshot::channel();
263
264    tokio::join!(
265        connection
266            .input_loop(input_stream, input_closed_tx, None)
267            .instrument(input_loop_span),
268        connection
269            .output_loop(sink, output_stream, input_closed_rx)
270            .instrument(output_loop_span),
271    );
272    info!("connection closed");
273    unregister_sender(&session_id);
274    active_sessions.borrow_mut().remove(&session_id);
275}
276
277impl<S: MessagesStorage> Connection<S> {
278    fn new(session: Rc<Session<S>>) -> Connection<S> {
279        Connection { session }
280    }
281
282    async fn input_loop(
283        &self,
284        mut input_stream: impl Stream<Item = InputEvent> + Unpin,
285        input_closed_tx: tokio::sync::oneshot::Sender<()>,
286        force_disconnection_with_reason: Option<DisconnectReason>,
287    ) {
288        if let Some(disconnect_reason) = force_disconnection_with_reason {
289            self.session
290                .disconnect(&mut self.session.state().borrow_mut(), disconnect_reason);
291
292            // Notify output loop that all input is processed so output queue can
293            // be safely closed.
294            // See `fn send()` and `fn send_raw()` from session.rs.
295            input_closed_tx
296                .send(())
297                .expect("Failed to notify about closed inpuot");
298
299            return;
300        }
301
302        let mut disconnect_reason = DisconnectReason::Disconnected;
303
304        while let Some(event) = input_stream.next().await {
305            // Don't accept new messages if session is disconnected.
306            if self.session.state().borrow().disconnected() {
307                info!("session disconnected, exit input processing");
308                // Notify output loop that all input is processed so output queue can
309                // be safely closed.
310                // See `fn send()` and `fn send_raw()` from session.rs.
311                input_closed_tx
312                    .send(())
313                    .expect("Failed to notify about closed inpout");
314                return;
315            }
316            match event {
317                InputEvent::Message(msg) => {
318                    if let Some(dr) = self.session.on_message_in(msg).await {
319                        info!("disconnect ({dr:?}), exit input processing");
320                        disconnect_reason = dr;
321                        break;
322                    }
323                }
324                InputEvent::DeserializeError(error) => {
325                    if let Some(dr) = self.session.on_deserialize_error(error).await {
326                        info!("disconnect ({dr:?}), exit input processing");
327                        disconnect_reason = dr;
328                        break;
329                    }
330                }
331                InputEvent::IoError(error) => {
332                    error!("Input error: {error:?}");
333                    disconnect_reason = DisconnectReason::IoError;
334                    break;
335                }
336                InputEvent::Timeout => {
337                    if self.session.on_in_timeout().await {
338                        self.session.send_logout(
339                            &mut self.session.state().borrow_mut(),
340                            Some(SessionStatus::SessionLogoutComplete),
341                            Some(FixString::from_ascii_lossy(
342                                b"Grace period is over".to_vec(),
343                            )),
344                        );
345                        break;
346                    }
347                }
348            }
349        }
350        self.session
351            .disconnect(&mut self.session.state().borrow_mut(), disconnect_reason);
352
353        // Notify output loop that all input is processed so output queue can
354        // be safely closed.
355        // See `fn send()` and `fn send_raw()` from session.rs.
356        input_closed_tx
357            .send(())
358            .expect("Failed to notify about closed inpout");
359    }
360
361    async fn output_loop(
362        &self,
363        mut sink: impl AsyncWrite + Unpin,
364        mut output_stream: impl Stream<Item = OutputEvent> + Unpin,
365        input_closed_rx: tokio::sync::oneshot::Receiver<()>,
366    ) {
367        let mut sink_closed = false;
368        let mut disconnect_reason = DisconnectReason::Disconnected;
369        while let Some(event) = output_stream.next().await {
370            match event {
371                OutputEvent::Message(msg) => {
372                    if sink_closed {
373                        // Sink is closed - ignore message, but do not break
374                        // the loop. Output stream has to process all enqueued
375                        // messages to made them available
376                        // for ResendRequest<2>.
377                        info!("Client disconnected, message will be stored for further resend");
378                    } else if let Err(error) = sink.write_all(&msg).await {
379                        sink_closed = true;
380                        error!("Output write error: {error:?}");
381                        // XXX: Don't disconnect now. If IO error happened
382                        //      here, it will aslo happen in input loop
383                        //      and input loop will trigger disconnection.
384                        //      Disonnection from here would lead to message
385                        //      loss when output queue would be closed
386                        //      and input handler would try to send something.
387                        //
388                        // self.session.disconnect(
389                        //     &mut self.session.state().borrow_mut(),
390                        //     DisconnectReason::IoError,
391                        // );
392                    }
393                }
394                OutputEvent::Timeout => self.session.on_out_timeout().await,
395                OutputEvent::Disconnect(reason) => {
396                    // Internal channel is closed in output stream
397                    // inplementation, at this point no new messages
398                    // can be send.
399                    info!("Client disconnected");
400                    if !sink_closed {
401                        if let Err(e) = sink.flush().await {
402                            error!("final flush failed: {e}");
403                        }
404                    }
405                    disconnect_reason = reason;
406                }
407            }
408        }
409        // XXX: Emit logout here instead of Session::disconnect, so `Logout`
410        //      event will be delivered after Logout message instead of
411        //      randomly before or after.
412        self.session.emit_logout(disconnect_reason).await;
413
414        // Don't wait for any specific value it's just notification that
415        // input_loop finished, so no more messages can be added to output
416        // queue.
417        let _ = input_closed_rx.await;
418        info!("disconnect, exit output processing");
419    }
420}