1use std::{
2 cell::{Cell, RefCell},
3 collections::{HashMap, hash_map::Entry},
4 rc::Rc,
5 sync::Mutex,
6 time::Duration,
7};
8
9use easyfix_messages::{
10 fields::{FixString, SessionStatus},
11 messages::{FixtMessage, Message},
12};
13use futures_util::{Stream, pin_mut};
14use tokio::{
15 self,
16 io::{AsyncRead, AsyncWrite, AsyncWriteExt},
17 net::TcpStream,
18 sync::{mpsc, oneshot},
19};
20use tokio_stream::StreamExt;
21use tracing::{Instrument, Span, debug, error, info, info_span, warn};
22
23use crate::{
24 DisconnectReason, Error, NO_INBOUND_TIMEOUT_PADDING, Sender, SessionError,
25 TEST_REQUEST_THRESHOLD,
26 acceptor::{ActiveSessionsMap, SessionsMap},
27 application::{Emitter, FixEventInternal},
28 messages_storage::MessagesStorage,
29 session::Session,
30 session_id::SessionId,
31 session_state::State,
32 settings::{SessionSettings, Settings},
33};
34
35mod input_stream;
36pub use input_stream::{InputEvent, InputStream, input_stream};
37
38mod output_stream;
39use output_stream::{OutputEvent, output_stream};
40
41pub mod time;
42use time::{timeout, timeout_at, timeout_stream};
43
44static SENDERS: Mutex<Option<HashMap<SessionId, Sender>>> = Mutex::new(None);
45
46pub fn register_sender(session_id: SessionId, sender: Sender) {
47 if let Entry::Vacant(entry) = SENDERS
48 .lock()
49 .unwrap()
50 .get_or_insert_with(HashMap::new)
51 .entry(session_id)
52 {
53 entry.insert(sender);
54 }
55}
56
57pub fn unregister_sender(session_id: &SessionId) {
58 if SENDERS
59 .lock()
60 .unwrap()
61 .get_or_insert_with(HashMap::new)
62 .remove(session_id)
63 .is_none()
64 {
65 }
67}
68
69pub fn sender(session_id: &SessionId) -> Option<Sender> {
70 SENDERS
71 .lock()
72 .unwrap()
73 .get_or_insert_with(HashMap::new)
74 .get(session_id)
75 .cloned()
76}
77
78pub fn send(session_id: &SessionId, msg: Box<Message>) -> Result<(), Box<Message>> {
80 if let Some(sender) = sender(session_id) {
81 sender.send(msg).map_err(|msg| msg.body)
82 } else {
83 Err(msg)
84 }
85}
86
87pub fn send_raw(msg: Box<FixtMessage>) -> Result<(), Box<FixtMessage>> {
88 if let Some(sender) = sender(&SessionId::from_input_msg(&msg)) {
89 sender.send_raw(msg)
90 } else {
91 Err(msg)
92 }
93}
94
95async fn first_msg(
96 stream: &mut (impl Stream<Item = InputEvent> + Unpin),
97 logon_timeout: Duration,
98) -> Result<Box<FixtMessage>, Error> {
99 match timeout(logon_timeout, stream.next()).await {
100 Ok(Some(InputEvent::Message(msg))) => Ok(msg),
101 Ok(Some(InputEvent::IoError(error))) => Err(error.into()),
102 Ok(Some(InputEvent::DeserializeError(error))) => {
103 error!("failed to deserialize first message: {error}");
104 Err(Error::SessionError(SessionError::LogonNeverReceived))
105 }
106 _ => Err(Error::SessionError(SessionError::LogonNeverReceived)),
107 }
108}
109
110#[derive(Debug)]
111struct Connection<S> {
112 session: Rc<Session<S>>,
113}
114
115pub(crate) async fn acceptor_connection<S>(
116 reader: impl AsyncRead + Unpin,
117 writer: impl AsyncWrite + Unpin,
118 settings: Settings,
119 sessions: Rc<RefCell<SessionsMap<S>>>,
120 active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
121 emitter: Emitter,
122 enabled: Rc<Cell<bool>>,
123) where
124 S: MessagesStorage,
125{
126 let stream = input_stream(reader);
127 let logon_timeout =
128 settings.auto_disconnect_after_no_logon_received + NO_INBOUND_TIMEOUT_PADDING;
129 pin_mut!(stream);
130 let msg = match first_msg(&mut stream, logon_timeout).await {
131 Ok(msg) => msg,
132 Err(err) => {
133 error!(%err, "failed to establish new session");
134 return;
135 }
136 };
137
138 let session_id = SessionId::from_input_msg(&msg);
139 debug!(first_msg = ?msg);
140
141 if !enabled.get() {
143 warn!("Acceptor is disabled, drop connection");
144 return;
145 }
146
147 let (sender, receiver) = mpsc::unbounded_channel();
148 let sender = Sender::new(sender);
149
150 let Some((session_settings, session_state)) = sessions.borrow().get_session(&session_id) else {
151 error!(%session_id, "failed to establish new session: unknown session id");
152 return;
153 };
154 if !session_state.borrow_mut().disconnected()
155 || active_sessions.borrow().contains_key(&session_id)
156 {
157 error!(%session_id, "Session already active");
158 return;
159 }
160 session_state.borrow_mut().set_disconnected(false);
161 register_sender(session_id.clone(), sender.clone());
162
163 let (disconnect_tx, disconnect_rx) = oneshot::channel();
164
165 let session = Rc::new(Session::new(
166 settings,
167 session_settings,
168 session_state,
169 sender,
170 emitter.clone(),
171 disconnect_tx,
172 ));
173
174 active_sessions
175 .borrow_mut()
176 .insert(session_id.clone(), session.clone());
177
178 let session_span = info_span!(
179 parent: None,
180 "session",
181 id = %session_id
182 );
183 session_span.follows_from(Span::current());
184
185 let input_loop_span = info_span!(parent: &session_span, "in");
186 let output_loop_span = info_span!(parent: &session_span, "out");
187
188 let force_disconnection_with_reason = session
189 .on_message_in(msg)
190 .instrument(input_loop_span.clone())
191 .await;
192
193 emitter
195 .send(FixEventInternal::Created(session_id.clone()))
196 .await;
197
198 let input_timeout_duration = session.heartbeat_interval().mul_f32(TEST_REQUEST_THRESHOLD);
199 let input_stream = timeout_stream(input_timeout_duration, stream)
200 .map(|res| res.unwrap_or(InputEvent::Timeout));
201 pin_mut!(input_stream);
202
203 let output_stream = output_stream(session.clone(), session.heartbeat_interval(), receiver);
204 pin_mut!(output_stream);
205
206 let connection = Connection::new(session);
207 let (input_closed_tx, input_closed_rx) = oneshot::channel();
208
209 tokio::join!(
210 connection
211 .input_loop(
212 input_stream,
213 input_closed_tx,
214 force_disconnection_with_reason,
215 disconnect_rx,
216 )
217 .instrument(input_loop_span),
218 connection
219 .output_loop(writer, output_stream, input_closed_rx)
220 .instrument(output_loop_span),
221 );
222 session_span.in_scope(|| {
223 info!("connection closed");
224 });
225 unregister_sender(&session_id);
226 active_sessions.borrow_mut().remove(&session_id);
227}
228
229pub(crate) async fn initiator_connection<S>(
230 tcp_stream: TcpStream,
231 settings: Settings,
232 session_settings: SessionSettings,
233 state: Rc<RefCell<State<S>>>,
234 active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
235 emitter: Emitter,
236) where
237 S: MessagesStorage,
238{
239 let (source, sink) = tcp_stream.into_split();
240 state.borrow_mut().set_disconnected(false);
241 let session_id = session_settings.session_id.clone();
242
243 let (sender, receiver) = mpsc::unbounded_channel();
244 let sender = Sender::new(sender);
245
246 let (disconnect_tx, disconnect_rx) = oneshot::channel();
247
248 register_sender(session_id.clone(), sender.clone());
249 let session = Rc::new(Session::new(
250 settings,
251 session_settings,
252 state,
253 sender,
254 emitter.clone(),
255 disconnect_tx,
256 ));
257 active_sessions
258 .borrow_mut()
259 .insert(session_id.clone(), session.clone());
260
261 let session_span = info_span!(
262 "session",
263 id = %session_id
264 );
265
266 let input_loop_span = info_span!(parent: &session_span, "in");
267 let output_loop_span = info_span!(parent: &session_span, "out");
268
269 emitter
271 .send(FixEventInternal::Created(session_id.clone()))
272 .await;
273
274 let input_timeout_duration = session.heartbeat_interval().mul_f32(TEST_REQUEST_THRESHOLD);
275 let input_stream = timeout_stream(input_timeout_duration, input_stream(source))
276 .map(|res| res.unwrap_or(InputEvent::Timeout));
277 pin_mut!(input_stream);
278
279 let output_stream = output_stream(session.clone(), session.heartbeat_interval(), receiver);
280 pin_mut!(output_stream);
281
282 session.send_logon_request(&mut session.state().borrow_mut());
285
286 let connection = Connection::new(session);
287 let (input_closed_tx, input_closed_rx) = oneshot::channel();
288
289 tokio::join!(
290 connection
291 .input_loop(input_stream, input_closed_tx, None, disconnect_rx)
292 .instrument(input_loop_span),
293 connection
294 .output_loop(sink, output_stream, input_closed_rx)
295 .instrument(output_loop_span),
296 );
297 info!("connection closed");
298 unregister_sender(&session_id);
299 active_sessions.borrow_mut().remove(&session_id);
300}
301
302impl<S: MessagesStorage> Connection<S> {
303 fn new(session: Rc<Session<S>>) -> Connection<S> {
304 Connection { session }
305 }
306
307 async fn input_loop(
308 &self,
309 mut input_stream: impl Stream<Item = InputEvent> + Unpin,
310 input_closed_tx: oneshot::Sender<()>,
311 force_disconnection_with_reason: Option<DisconnectReason>,
312 mut disconnect_rx: oneshot::Receiver<()>,
313 ) {
314 if let Some(disconnect_reason) = force_disconnection_with_reason {
315 self.session
316 .disconnect(&mut self.session.state().borrow_mut(), disconnect_reason);
317
318 input_closed_tx
322 .send(())
323 .expect("Failed to notify about closed inpuot");
324
325 return;
326 }
327
328 let mut disconnect_reason = DisconnectReason::Disconnected;
329 let mut logout_deadline = None;
330
331 let mut next_item = async || {
332 if logout_deadline.is_none() {
333 logout_deadline = self.session.logout_deadline();
334 }
335 if let Some(logout_deadline) = logout_deadline {
336 timeout_at(logout_deadline, input_stream.next())
337 .await
338 .unwrap_or(Some(InputEvent::LogoutTimeout))
339 } else {
340 input_stream.next().await
341 }
342 };
343
344 loop {
345 let event = tokio::select! {
346 event = next_item() => {
348 if let Some(event) = event {
349 event
350 } else {
351 break
352 }
353 }
354
355 _ = &mut disconnect_rx => {
357 info!("Disconnect signaled, exiting input loop");
358 disconnect_reason = DisconnectReason::ApplicationForcedDisconnect;
359 break;
360 }
361 };
362
363 if self.session.state().borrow().disconnected() {
365 info!("session disconnected, exit input processing");
366 input_closed_tx
370 .send(())
371 .expect("Failed to notify about closed input");
372 return;
373 }
374
375 match event {
376 InputEvent::Message(msg) => {
377 if let Some(reason) = self.session.on_message_in(msg).await {
378 info!(?reason, "disconnect, exit input processing");
379 disconnect_reason = reason;
380 break;
381 }
382 }
383 InputEvent::DeserializeError(error) => {
384 if let Some(reason) = self.session.on_deserialize_error(error).await {
385 info!(?reason, "disconnect, exit input processing");
386 disconnect_reason = reason;
387 break;
388 }
389 }
390 InputEvent::IoError(error) => {
391 error!(%error, "Input error");
392 disconnect_reason = DisconnectReason::IoError;
393 break;
394 }
395 InputEvent::Timeout => {
396 if self.session.on_in_timeout().await {
397 self.session.send_logout(
398 &mut self.session.state().borrow_mut(),
399 Some(SessionStatus::SessionLogoutComplete),
400 Some(FixString::from_ascii_lossy(
401 b"Grace period is over".to_vec(),
402 )),
403 );
404 break;
405 }
406 }
407 InputEvent::LogoutTimeout => {
408 info!("Logout timeout");
409 disconnect_reason = DisconnectReason::LogoutTimeout;
410 break;
411 }
412 }
413 }
414 self.session
415 .disconnect(&mut self.session.state().borrow_mut(), disconnect_reason);
416
417 input_closed_tx
421 .send(())
422 .expect("Failed to notify about closed inpout");
423 }
424
425 async fn output_loop(
426 &self,
427 mut sink: impl AsyncWrite + Unpin,
428 mut output_stream: impl Stream<Item = OutputEvent> + Unpin,
429 input_closed_rx: oneshot::Receiver<()>,
430 ) {
431 let mut sink_closed = false;
432 let mut disconnect_reason = DisconnectReason::Disconnected;
433 while let Some(event) = output_stream.next().await {
434 match event {
435 OutputEvent::Message(msg) => {
436 if sink_closed {
437 info!("Client disconnected, message will be stored for further resend");
442 } else if let Err(error) = sink.write_all(&msg).await {
443 sink_closed = true;
444 error!(%error, "Output write error");
445 }
457 }
458 OutputEvent::Timeout => self.session.on_out_timeout().await,
459 OutputEvent::Disconnect(reason) => {
460 info!("Client disconnected");
464 if !sink_closed && let Err(error) = sink.flush().await {
465 error!(%error, "final flush failed");
466 }
467 disconnect_reason = reason;
468 }
469 }
470 }
471 self.session.emit_logout(disconnect_reason).await;
475
476 let _ = input_closed_rx.await;
480 if let Err(error) = sink.shutdown().await {
481 error!(%error, "connection shutdown failed")
482 }
483 info!("disconnect, exit output processing");
484 }
485}