1use std::{
2 cell::RefCell,
3 collections::{HashMap, hash_map::Entry},
4 rc::Rc,
5 sync::Mutex,
6 time::Duration,
7};
8
9use easyfix_messages::{
10 fields::{FixString, SessionStatus},
11 messages::{FixtMessage, Message},
12};
13use futures_util::{Stream, pin_mut};
14use tokio::{
15 self,
16 io::{AsyncRead, AsyncWrite, AsyncWriteExt},
17 net::TcpStream,
18 sync::mpsc,
19};
20use tokio_stream::StreamExt;
21use tracing::{Instrument, Span, debug, error, info, info_span};
22
23use crate::{
24 DisconnectReason, Error, NO_INBOUND_TIMEOUT_PADDING, Sender, SessionError,
25 TEST_REQUEST_THRESHOLD,
26 acceptor::{ActiveSessionsMap, SessionsMap},
27 application::{Emitter, FixEventInternal},
28 messages_storage::MessagesStorage,
29 session::Session,
30 session_id::SessionId,
31 session_state::State,
32 settings::{SessionSettings, Settings},
33};
34
35mod input_stream;
36pub use input_stream::{InputEvent, InputStream, input_stream};
37
38mod output_stream;
39use output_stream::{OutputEvent, output_stream};
40
41pub mod time;
42use time::{timeout, timeout_at, timeout_stream};
43
/// Process-wide registry of [`Sender`] handles keyed by [`SessionId`].
///
/// The slot starts as `None`; the map itself is created lazily on first use.
/// Access is serialized through the surrounding [`Mutex`].
static SENDERS: Mutex<Option<HashMap<SessionId, Sender>>> = Mutex::new(None);
45
46pub fn register_sender(session_id: SessionId, sender: Sender) {
47 if let Entry::Vacant(entry) = SENDERS
48 .lock()
49 .unwrap()
50 .get_or_insert_with(HashMap::new)
51 .entry(session_id)
52 {
53 entry.insert(sender);
54 }
55}
56
57pub fn unregister_sender(session_id: &SessionId) {
58 if SENDERS
59 .lock()
60 .unwrap()
61 .get_or_insert_with(HashMap::new)
62 .remove(session_id)
63 .is_none()
64 {
65 }
67}
68
69pub fn sender(session_id: &SessionId) -> Option<Sender> {
70 SENDERS
71 .lock()
72 .unwrap()
73 .get_or_insert_with(HashMap::new)
74 .get(session_id)
75 .cloned()
76}
77
78pub fn send(session_id: &SessionId, msg: Box<Message>) -> Result<(), Box<Message>> {
80 if let Some(sender) = sender(session_id) {
81 sender.send(msg).map_err(|msg| msg.body)
82 } else {
83 Err(msg)
84 }
85}
86
87pub fn send_raw(msg: Box<FixtMessage>) -> Result<(), Box<FixtMessage>> {
88 if let Some(sender) = sender(&SessionId::from_input_msg(&msg)) {
89 sender.send_raw(msg)
90 } else {
91 Err(msg)
92 }
93}
94
95async fn first_msg(
96 stream: &mut (impl Stream<Item = InputEvent> + Unpin),
97 logon_timeout: Duration,
98) -> Result<Box<FixtMessage>, Error> {
99 match timeout(logon_timeout, stream.next()).await {
100 Ok(Some(InputEvent::Message(msg))) => Ok(msg),
101 Ok(Some(InputEvent::IoError(error))) => Err(error.into()),
102 Ok(Some(InputEvent::DeserializeError(error))) => {
103 error!("failed to deserialize first message: {error}");
104 Err(Error::SessionError(SessionError::LogonNeverReceived))
105 }
106 _ => Err(Error::SessionError(SessionError::LogonNeverReceived)),
107 }
108}
109
/// Per-connection handle whose methods run the input and output loops for a
/// single session.
#[derive(Debug)]
struct Connection<S> {
    /// Session that both loops operate on; reference-counted so it can also be
    /// held by other owners.
    session: Rc<Session<S>>,
}
114
115pub(crate) async fn acceptor_connection<S>(
116 reader: impl AsyncRead + Unpin,
117 writer: impl AsyncWrite + Unpin,
118 settings: Settings,
119 sessions: Rc<RefCell<SessionsMap<S>>>,
120 active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
121 emitter: Emitter,
122) where
123 S: MessagesStorage,
124{
125 let stream = input_stream(reader);
126 let logon_timeout =
127 settings.auto_disconnect_after_no_logon_received + NO_INBOUND_TIMEOUT_PADDING;
128 pin_mut!(stream);
129 let msg = match first_msg(&mut stream, logon_timeout).await {
130 Ok(msg) => msg,
131 Err(err) => {
132 error!("failed to establish new session: {err}");
133 return;
134 }
135 };
136 let session_id = SessionId::from_input_msg(&msg);
137 debug!("first_msg: {msg:?}");
138
139 let (sender, receiver) = mpsc::unbounded_channel();
140 let sender = Sender::new(sender);
141
142 let Some((session_settings, session_state)) = sessions.borrow().get_session(&session_id) else {
143 error!("failed to establish new session: unknown session id {session_id}");
144 return;
145 };
146 if !session_state.borrow_mut().disconnected()
147 || active_sessions.borrow().contains_key(&session_id)
148 {
149 error!(%session_id, "Session already active");
150 return;
151 }
152 session_state.borrow_mut().set_disconnected(false);
153 register_sender(session_id.clone(), sender.clone());
154 let session = Rc::new(Session::new(
155 settings,
156 session_settings,
157 session_state,
158 sender,
159 emitter.clone(),
160 ));
161 active_sessions
162 .borrow_mut()
163 .insert(session_id.clone(), session.clone());
164
165 let session_span = info_span!(
166 parent: None,
167 "session",
168 id = %session_id
169 );
170 session_span.follows_from(Span::current());
171
172 let input_loop_span = info_span!(parent: &session_span, "in");
173 let output_loop_span = info_span!(parent: &session_span, "out");
174
175 let force_disconnection_with_reason = session
176 .on_message_in(msg)
177 .instrument(input_loop_span.clone())
178 .await;
179
180 emitter
182 .send(FixEventInternal::Created(session_id.clone()))
183 .await;
184
185 let input_timeout_duration = session.heartbeat_interval().mul_f32(TEST_REQUEST_THRESHOLD);
186 let input_stream = timeout_stream(input_timeout_duration, stream)
187 .map(|res| res.unwrap_or(InputEvent::Timeout));
188 pin_mut!(input_stream);
189
190 let output_stream = output_stream(session.clone(), session.heartbeat_interval(), receiver);
191 pin_mut!(output_stream);
192
193 let connection = Connection::new(session);
194 let (input_closed_tx, input_closed_rx) = tokio::sync::oneshot::channel();
195
196 tokio::join!(
197 connection
198 .input_loop(
199 input_stream,
200 input_closed_tx,
201 force_disconnection_with_reason
202 )
203 .instrument(input_loop_span),
204 connection
205 .output_loop(writer, output_stream, input_closed_rx)
206 .instrument(output_loop_span),
207 );
208 session_span.in_scope(|| {
209 info!("connection closed");
210 });
211 unregister_sender(&session_id);
212 active_sessions.borrow_mut().remove(&session_id);
213}
214
215pub(crate) async fn initiator_connection<S>(
216 tcp_stream: TcpStream,
217 settings: Settings,
218 session_settings: SessionSettings,
219 state: Rc<RefCell<State<S>>>,
220 active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
221 emitter: Emitter,
222) where
223 S: MessagesStorage,
224{
225 let (source, sink) = tcp_stream.into_split();
226 state.borrow_mut().set_disconnected(false);
227 let session_id = session_settings.session_id.clone();
228
229 let (sender, receiver) = mpsc::unbounded_channel();
230 let sender = Sender::new(sender);
231
232 register_sender(session_id.clone(), sender.clone());
233 let session = Rc::new(Session::new(
234 settings,
235 session_settings,
236 state,
237 sender,
238 emitter.clone(),
239 ));
240 active_sessions
241 .borrow_mut()
242 .insert(session_id.clone(), session.clone());
243
244 let session_span = info_span!(
245 "session",
246 id = %session_id
247 );
248
249 let input_loop_span = info_span!(parent: &session_span, "in");
250 let output_loop_span = info_span!(parent: &session_span, "out");
251
252 emitter
254 .send(FixEventInternal::Created(session_id.clone()))
255 .await;
256
257 let input_timeout_duration = session.heartbeat_interval().mul_f32(TEST_REQUEST_THRESHOLD);
258 let input_stream = timeout_stream(input_timeout_duration, input_stream(source))
259 .map(|res| res.unwrap_or(InputEvent::Timeout));
260 pin_mut!(input_stream);
261
262 let output_stream = output_stream(session.clone(), session.heartbeat_interval(), receiver);
263 pin_mut!(output_stream);
264
265 session.send_logon_request(&mut session.state().borrow_mut());
268
269 let connection = Connection::new(session);
270 let (input_closed_tx, input_closed_rx) = tokio::sync::oneshot::channel();
271
272 tokio::join!(
273 connection
274 .input_loop(input_stream, input_closed_tx, None)
275 .instrument(input_loop_span),
276 connection
277 .output_loop(sink, output_stream, input_closed_rx)
278 .instrument(output_loop_span),
279 );
280 info!("connection closed");
281 unregister_sender(&session_id);
282 active_sessions.borrow_mut().remove(&session_id);
283}
284
285impl<S: MessagesStorage> Connection<S> {
286 fn new(session: Rc<Session<S>>) -> Connection<S> {
287 Connection { session }
288 }
289
290 async fn input_loop(
291 &self,
292 mut input_stream: impl Stream<Item = InputEvent> + Unpin,
293 input_closed_tx: tokio::sync::oneshot::Sender<()>,
294 force_disconnection_with_reason: Option<DisconnectReason>,
295 ) {
296 if let Some(disconnect_reason) = force_disconnection_with_reason {
297 self.session
298 .disconnect(&mut self.session.state().borrow_mut(), disconnect_reason);
299
300 input_closed_tx
304 .send(())
305 .expect("Failed to notify about closed inpuot");
306
307 return;
308 }
309
310 let mut disconnect_reason = DisconnectReason::Disconnected;
311 let mut logout_deadline = None;
312
313 let mut next_item = async || {
314 if logout_deadline.is_none() {
315 logout_deadline = self.session.logout_deadline();
316 }
317 if let Some(logout_deadline) = logout_deadline {
318 timeout_at(logout_deadline, input_stream.next())
319 .await
320 .unwrap_or(Some(InputEvent::LogoutTimeout))
321 } else {
322 input_stream.next().await
323 }
324 };
325
326 while let Some(event) = next_item().await {
327 if self.session.state().borrow().disconnected() {
329 info!("session disconnected, exit input processing");
330 input_closed_tx
334 .send(())
335 .expect("Failed to notify about closed inpout");
336 return;
337 }
338 match event {
339 InputEvent::Message(msg) => {
340 if let Some(reason) = self.session.on_message_in(msg).await {
341 info!(?reason, "disconnect, exit input processing");
342 disconnect_reason = reason;
343 break;
344 }
345 }
346 InputEvent::DeserializeError(error) => {
347 if let Some(reason) = self.session.on_deserialize_error(error).await {
348 info!(?reason, "disconnect, exit input processing");
349 disconnect_reason = reason;
350 break;
351 }
352 }
353 InputEvent::IoError(error) => {
354 error!(%error, "Input error");
355 disconnect_reason = DisconnectReason::IoError;
356 break;
357 }
358 InputEvent::Timeout => {
359 if self.session.on_in_timeout().await {
360 self.session.send_logout(
361 &mut self.session.state().borrow_mut(),
362 Some(SessionStatus::SessionLogoutComplete),
363 Some(FixString::from_ascii_lossy(
364 b"Grace period is over".to_vec(),
365 )),
366 );
367 break;
368 }
369 }
370 InputEvent::LogoutTimeout => {
371 info!("Logout timeout");
372 disconnect_reason = DisconnectReason::LogoutTimeout;
373 break;
374 }
375 }
376 }
377 self.session
378 .disconnect(&mut self.session.state().borrow_mut(), disconnect_reason);
379
380 input_closed_tx
384 .send(())
385 .expect("Failed to notify about closed inpout");
386 }
387
388 async fn output_loop(
389 &self,
390 mut sink: impl AsyncWrite + Unpin,
391 mut output_stream: impl Stream<Item = OutputEvent> + Unpin,
392 input_closed_rx: tokio::sync::oneshot::Receiver<()>,
393 ) {
394 let mut sink_closed = false;
395 let mut disconnect_reason = DisconnectReason::Disconnected;
396 while let Some(event) = output_stream.next().await {
397 match event {
398 OutputEvent::Message(msg) => {
399 if sink_closed {
400 info!("Client disconnected, message will be stored for further resend");
405 } else if let Err(error) = sink.write_all(&msg).await {
406 sink_closed = true;
407 error!(%error, "Output write error");
408 }
420 }
421 OutputEvent::Timeout => self.session.on_out_timeout().await,
422 OutputEvent::Disconnect(reason) => {
423 info!("Client disconnected");
427 if !sink_closed && let Err(error) = sink.flush().await {
428 error!(%error, "final flush failed");
429 }
430 disconnect_reason = reason;
431 }
432 }
433 }
434 self.session.emit_logout(disconnect_reason).await;
438
439 let _ = input_closed_rx.await;
443 if let Err(error) = sink.shutdown().await {
444 error!(%error, "connection shutdown failed")
445 }
446 info!("disconnect, exit output processing");
447 }
448}