1use std::{
2 cell::{Cell, RefCell},
3 collections::{HashMap, hash_map::Entry},
4 rc::Rc,
5 sync::Mutex,
6 time::Duration,
7};
8
9use easyfix_messages::{
10 fields::{FixString, SessionStatus},
11 messages::{FixtMessage, Message},
12};
13use futures_util::{Stream, pin_mut};
14use tokio::{
15 self,
16 io::{AsyncRead, AsyncWrite, AsyncWriteExt},
17 net::TcpStream,
18 sync::{mpsc, oneshot},
19};
20use tokio_stream::StreamExt;
21use tracing::{Instrument, Span, debug, error, info, info_span, warn};
22
23use crate::{
24 DisconnectReason, Error, NO_INBOUND_TIMEOUT_PADDING, Sender, SessionError,
25 TEST_REQUEST_THRESHOLD,
26 acceptor::{ActiveSessionsMap, SessionsMap},
27 application::{Emitter, FixEventInternal},
28 messages_storage::MessagesStorage,
29 session::Session,
30 session_id::SessionId,
31 session_state::State,
32 settings::{SessionSettings, Settings},
33};
34
35mod input_stream;
36pub use input_stream::{InputEvent, InputStream, input_stream};
37
38mod output_stream;
39use output_stream::{OutputEvent, output_stream};
40
41pub mod time;
42use time::{timeout, timeout_at, timeout_stream};
43
/// Process-wide registry of per-session `Sender` handles, keyed by `SessionId`.
///
/// The map starts out as `None`; every accessor in this file creates it lazily
/// through `get_or_insert_with(HashMap::new)` while holding the mutex.
static SENDERS: Mutex<Option<HashMap<SessionId, Sender>>> = Mutex::new(None);
45
46pub fn register_sender(session_id: SessionId, sender: Sender) {
47 if let Entry::Vacant(entry) = SENDERS
48 .lock()
49 .unwrap()
50 .get_or_insert_with(HashMap::new)
51 .entry(session_id)
52 {
53 entry.insert(sender);
54 }
55}
56
57pub fn unregister_sender(session_id: &SessionId) {
58 if SENDERS
59 .lock()
60 .unwrap()
61 .get_or_insert_with(HashMap::new)
62 .remove(session_id)
63 .is_none()
64 {
65 }
67}
68
69pub fn sender(session_id: &SessionId) -> Option<Sender> {
70 SENDERS
71 .lock()
72 .unwrap()
73 .get_or_insert_with(HashMap::new)
74 .get(session_id)
75 .cloned()
76}
77
78pub fn send(session_id: &SessionId, msg: Box<Message>) -> Result<(), Box<Message>> {
80 if let Some(sender) = sender(session_id) {
81 sender.send(msg).map_err(|msg| msg.body)
82 } else {
83 Err(msg)
84 }
85}
86
87pub fn send_raw(msg: Box<FixtMessage>) -> Result<(), Box<FixtMessage>> {
88 if let Some(sender) = sender(&SessionId::from_input_msg(&msg)) {
89 sender.send_raw(msg)
90 } else {
91 Err(msg)
92 }
93}
94
95async fn first_msg(
96 stream: &mut (impl Stream<Item = InputEvent> + Unpin),
97 logon_timeout: Duration,
98) -> Result<Box<FixtMessage>, Error> {
99 match timeout(logon_timeout, stream.next()).await {
100 Ok(Some(InputEvent::Message(msg))) => Ok(msg),
101 Ok(Some(InputEvent::IoError(error))) => Err(error.into()),
102 Ok(Some(InputEvent::DeserializeError(error))) => {
103 error!("failed to deserialize first message: {error}");
104 Err(Error::SessionError(SessionError::LogonNeverReceived))
105 }
106 _ => Err(Error::SessionError(SessionError::LogonNeverReceived)),
107 }
108}
109
/// A single connection bound to one shared [`Session`].
#[derive(Debug)]
struct Connection<S> {
    /// Session that this connection's input and output loops operate on.
    session: Rc<Session<S>>,
}
114
115struct SessionCleanupGuard<S: MessagesStorage> {
121 session_id: SessionId,
122 state: Rc<RefCell<State<S>>>,
123 active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
124 reset_on_disconnect: bool,
125}
126
127impl<S: MessagesStorage> Drop for SessionCleanupGuard<S> {
128 fn drop(&mut self) {
129 unregister_sender(&self.session_id);
130
131 match self.active_sessions.try_borrow_mut() {
133 Ok(mut active_sessions) => {
134 active_sessions.remove(&self.session_id);
135 }
136 Err(_) => error!(
137 session_id = %self.session_id,
138 "session cleanup failed: active sessions map already borrowed"
139 ),
140 }
141
142 match self.state.try_borrow_mut() {
143 Ok(mut state) => {
144 state.set_logon_received(false);
148 state.set_logon_sent(false);
149 if !state.disconnected() {
150 warn!(
151 session_id = %self.session_id,
152 "connection task finished without disconnecting, forcing disconnected state"
153 );
154 state.disconnect(self.reset_on_disconnect);
155 }
156 }
157 Err(_) => error!(
158 session_id = %self.session_id,
159 "session cleanup failed: session state already borrowed"
160 ),
161 }
162 }
163}
164
165pub(crate) async fn acceptor_connection<S>(
166 reader: impl AsyncRead + Unpin,
167 writer: impl AsyncWrite + Unpin,
168 settings: Settings,
169 sessions: Rc<RefCell<SessionsMap<S>>>,
170 active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
171 emitter: Emitter,
172 enabled: Rc<Cell<bool>>,
173) where
174 S: MessagesStorage,
175{
176 let stream = input_stream(reader);
177 let logon_timeout =
178 settings.auto_disconnect_after_no_logon_received + NO_INBOUND_TIMEOUT_PADDING;
179 pin_mut!(stream);
180 let msg = match first_msg(&mut stream, logon_timeout).await {
181 Ok(msg) => msg,
182 Err(err) => {
183 error!(%err, "failed to establish new session");
184 return;
185 }
186 };
187
188 let session_id = SessionId::from_input_msg(&msg);
189 debug!(first_msg = ?msg);
190
191 if !enabled.get() {
193 warn!("Acceptor is disabled, drop connection");
194 return;
195 }
196
197 let (sender, receiver) = mpsc::unbounded_channel();
198 let sender = Sender::new(sender);
199
200 let Some((session_settings, session_state)) = sessions.borrow().get_session(&session_id) else {
201 error!(%session_id, "failed to establish new session: unknown session id");
202 return;
203 };
204 if !session_state.borrow_mut().disconnected()
205 || active_sessions.borrow().contains_key(&session_id)
206 {
207 error!(%session_id, "Session already active");
208 return;
209 }
210 session_state.borrow_mut().set_disconnected(false);
211 register_sender(session_id.clone(), sender.clone());
212
213 let _cleanup_guard = SessionCleanupGuard {
214 session_id: session_id.clone(),
215 state: session_state.clone(),
216 active_sessions: active_sessions.clone(),
217 reset_on_disconnect: session_settings.reset_on_disconnect,
218 };
219
220 let (disconnect_tx, disconnect_rx) = oneshot::channel();
221
222 let session = Rc::new(Session::new(
223 settings,
224 session_settings,
225 session_state,
226 sender,
227 emitter.clone(),
228 disconnect_tx,
229 ));
230
231 active_sessions
232 .borrow_mut()
233 .insert(session_id.clone(), session.clone());
234
235 let session_span = info_span!(
236 parent: None,
237 "session",
238 id = %session_id
239 );
240 session_span.follows_from(Span::current());
241
242 let input_loop_span = info_span!(parent: &session_span, "in");
243 let output_loop_span = info_span!(parent: &session_span, "out");
244
245 let force_disconnection_with_reason = session
246 .on_message_in(msg)
247 .instrument(input_loop_span.clone())
248 .await;
249
250 emitter
252 .send(FixEventInternal::Created(session_id.clone()))
253 .await;
254
255 let input_timeout_duration = session.heartbeat_interval().mul_f32(TEST_REQUEST_THRESHOLD);
256 let input_stream = timeout_stream(input_timeout_duration, stream)
257 .map(|res| res.unwrap_or(InputEvent::Timeout));
258 pin_mut!(input_stream);
259
260 let output_stream = output_stream(session.clone(), session.heartbeat_interval(), receiver);
261 pin_mut!(output_stream);
262
263 let connection = Connection::new(session);
264 let (input_closed_tx, input_closed_rx) = oneshot::channel();
265
266 tokio::join!(
267 connection
268 .input_loop(
269 input_stream,
270 input_closed_tx,
271 force_disconnection_with_reason,
272 disconnect_rx,
273 )
274 .instrument(input_loop_span),
275 connection
276 .output_loop(writer, output_stream, input_closed_rx)
277 .instrument(output_loop_span),
278 );
279 session_span.in_scope(|| {
280 info!("connection closed");
281 });
282}
283
284pub(crate) async fn initiator_connection<S>(
285 tcp_stream: TcpStream,
286 settings: Settings,
287 session_settings: SessionSettings,
288 state: Rc<RefCell<State<S>>>,
289 active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
290 emitter: Emitter,
291) where
292 S: MessagesStorage,
293{
294 let (source, sink) = tcp_stream.into_split();
295 state.borrow_mut().set_disconnected(false);
296 let session_id = session_settings.session_id.clone();
297
298 let (sender, receiver) = mpsc::unbounded_channel();
299 let sender = Sender::new(sender);
300
301 let (disconnect_tx, disconnect_rx) = oneshot::channel();
302
303 register_sender(session_id.clone(), sender.clone());
304
305 let _cleanup_guard = SessionCleanupGuard {
306 session_id: session_id.clone(),
307 state: state.clone(),
308 active_sessions: active_sessions.clone(),
309 reset_on_disconnect: session_settings.reset_on_disconnect,
310 };
311
312 let session = Rc::new(Session::new(
313 settings,
314 session_settings,
315 state,
316 sender,
317 emitter.clone(),
318 disconnect_tx,
319 ));
320 active_sessions
321 .borrow_mut()
322 .insert(session_id.clone(), session.clone());
323
324 let session_span = info_span!(
325 "session",
326 id = %session_id
327 );
328
329 let input_loop_span = info_span!(parent: &session_span, "in");
330 let output_loop_span = info_span!(parent: &session_span, "out");
331
332 emitter
334 .send(FixEventInternal::Created(session_id.clone()))
335 .await;
336
337 let input_timeout_duration = session.heartbeat_interval().mul_f32(TEST_REQUEST_THRESHOLD);
338 let input_stream = timeout_stream(input_timeout_duration, input_stream(source))
339 .map(|res| res.unwrap_or(InputEvent::Timeout));
340 pin_mut!(input_stream);
341
342 let output_stream = output_stream(session.clone(), session.heartbeat_interval(), receiver);
343 pin_mut!(output_stream);
344
345 session.send_logon_request(&mut session.state().borrow_mut());
348
349 let connection = Connection::new(session);
350 let (input_closed_tx, input_closed_rx) = oneshot::channel();
351
352 tokio::join!(
353 connection
354 .input_loop(input_stream, input_closed_tx, None, disconnect_rx)
355 .instrument(input_loop_span),
356 connection
357 .output_loop(sink, output_stream, input_closed_rx)
358 .instrument(output_loop_span),
359 );
360 info!("connection closed");
361}
362
363impl<S: MessagesStorage> Connection<S> {
364 fn new(session: Rc<Session<S>>) -> Connection<S> {
365 Connection { session }
366 }
367
368 async fn input_loop(
369 &self,
370 mut input_stream: impl Stream<Item = InputEvent> + Unpin,
371 input_closed_tx: oneshot::Sender<()>,
372 force_disconnection_with_reason: Option<DisconnectReason>,
373 mut disconnect_rx: oneshot::Receiver<()>,
374 ) {
375 if let Some(disconnect_reason) = force_disconnection_with_reason {
376 self.session
377 .disconnect(&mut self.session.state().borrow_mut(), disconnect_reason);
378
379 input_closed_tx
383 .send(())
384 .expect("Failed to notify about closed inpuot");
385
386 return;
387 }
388
389 let mut disconnect_reason = DisconnectReason::Disconnected;
390 let mut logout_deadline = None;
391
392 let mut next_item = async || {
393 if logout_deadline.is_none() {
394 logout_deadline = self.session.logout_deadline();
395 }
396 if let Some(logout_deadline) = logout_deadline {
397 timeout_at(logout_deadline, input_stream.next())
398 .await
399 .unwrap_or(Some(InputEvent::LogoutTimeout))
400 } else {
401 input_stream.next().await
402 }
403 };
404
405 loop {
406 let event = tokio::select! {
407 event = next_item() => {
409 if let Some(event) = event {
410 event
412 } else {
413 break
414 }
415 }
416
417 _ = &mut disconnect_rx => {
419 info!("Disconnect signaled, exiting input loop");
420 disconnect_reason = DisconnectReason::ApplicationForcedDisconnect;
421 break;
422 }
423 };
424
425 if self.session.state().borrow().disconnected() {
427 info!("session disconnected, exit input processing");
428 input_closed_tx
432 .send(())
433 .expect("Failed to notify about closed input");
434 return;
435 }
436
437 match event {
438 InputEvent::Message(msg) => {
439 if let Some(reason) = self.session.on_message_in(msg).await {
440 info!(?reason, "disconnect, exit input processing");
441 disconnect_reason = reason;
442 break;
443 }
444 }
445 InputEvent::DeserializeError(error) => {
446 if let Some(reason) = self.session.on_deserialize_error(error).await {
447 info!(?reason, "disconnect, exit input processing");
448 disconnect_reason = reason;
449 break;
450 }
451 }
452 InputEvent::IoError(error) => {
453 error!(%error, "Input error");
454 disconnect_reason = DisconnectReason::IoError;
455 break;
456 }
457 InputEvent::Timeout => {
458 if self.session.on_in_timeout().await {
459 self.session.send_logout(
460 &mut self.session.state().borrow_mut(),
461 Some(SessionStatus::SessionLogoutComplete),
462 Some(FixString::from_ascii_lossy(
463 b"Grace period is over".to_vec(),
464 )),
465 );
466 break;
467 }
468 }
469 InputEvent::LogoutTimeout => {
470 info!("Logout timeout");
471 disconnect_reason = DisconnectReason::LogoutTimeout;
472 break;
473 }
474 }
475 }
476 self.session
477 .disconnect(&mut self.session.state().borrow_mut(), disconnect_reason);
478
479 input_closed_tx
483 .send(())
484 .expect("Failed to notify about closed inpout");
485 }
486
487 async fn output_loop(
488 &self,
489 mut sink: impl AsyncWrite + Unpin,
490 mut output_stream: impl Stream<Item = OutputEvent> + Unpin,
491 input_closed_rx: oneshot::Receiver<()>,
492 ) {
493 let mut sink_closed = false;
494 let mut disconnect_reason = DisconnectReason::Disconnected;
495 while let Some(event) = output_stream.next().await {
496 match event {
497 OutputEvent::Message(msg) => {
498 if sink_closed {
499 info!("Client disconnected, message will be stored for further resend");
504 } else if let Err(error) = sink.write_all(&msg).await {
505 sink_closed = true;
506 error!(%error, "Output write error");
507 }
519 }
520 OutputEvent::Timeout => self.session.on_out_timeout().await,
521 OutputEvent::Disconnect(reason) => {
522 info!("Client disconnected");
526 if !sink_closed && let Err(error) = sink.flush().await {
527 error!(%error, "final flush failed");
528 }
529 disconnect_reason = reason;
530 }
531 }
532 }
533 self.session.emit_logout(disconnect_reason).await;
537
538 let _ = input_closed_rx.await;
542 if let Err(error) = sink.shutdown().await {
543 error!(%error, "connection shutdown failed")
544 }
545 info!("disconnect, exit output processing");
546 }
547}