1use std::{
2 cell::RefCell,
3 collections::{hash_map::Entry, HashMap},
4 rc::Rc,
5 sync::Mutex,
6};
7
8use easyfix_messages::{
9 fields::{FixString, SessionStatus},
10 messages::{FixtMessage, Message},
11};
12use futures_util::{pin_mut, Stream};
13use tokio::{
14 self,
15 io::{AsyncRead, AsyncWrite, AsyncWriteExt},
16 net::TcpStream,
17 sync::mpsc,
18 time::Duration,
19};
20use tokio_stream::StreamExt;
21use tracing::{debug, error, info, info_span, Instrument};
22
23use crate::{
24 acceptor::{ActiveSessionsMap, SessionsMap},
25 application::{Emitter, FixEventInternal},
26 messages_storage::MessagesStorage,
27 session::Session,
28 session_id::SessionId,
29 session_state::State,
30 settings::{SessionSettings, Settings},
31 DisconnectReason, Error, Sender, SessionError, NO_INBOUND_TIMEOUT_PADDING,
32 TEST_REQUEST_THRESHOLD,
33};
34
35mod input_stream;
36pub use input_stream::{input_stream, InputEvent, InputStream};
37
38mod output_stream;
39use output_stream::{output_stream, OutputEvent};
40
41pub mod time;
42use time::{timeout, timeout_stream};
43
44static SENDERS: Mutex<Option<HashMap<SessionId, Sender>>> = Mutex::new(None);
45
46pub fn register_sender(session_id: SessionId, sender: Sender) {
47 if let Entry::Vacant(entry) = SENDERS
48 .lock()
49 .unwrap()
50 .get_or_insert_with(HashMap::new)
51 .entry(session_id)
52 {
53 entry.insert(sender);
54 }
55}
56
57pub fn unregister_sender(session_id: &SessionId) {
58 if SENDERS
59 .lock()
60 .unwrap()
61 .get_or_insert_with(HashMap::new)
62 .remove(session_id)
63 .is_none()
64 {
65 }
67}
68
69pub fn sender(session_id: &SessionId) -> Option<Sender> {
70 SENDERS
71 .lock()
72 .unwrap()
73 .get_or_insert_with(HashMap::new)
74 .get(session_id)
75 .cloned()
76}
77
78pub fn send(session_id: &SessionId, msg: Box<Message>) -> Result<(), Box<Message>> {
80 if let Some(sender) = sender(session_id) {
81 sender.send(msg).map_err(|msg| msg.body)
82 } else {
83 Err(msg)
84 }
85}
86
87pub fn send_raw(msg: Box<FixtMessage>) -> Result<(), Box<FixtMessage>> {
88 if let Some(sender) = sender(&SessionId::from_input_msg(&msg)) {
89 sender.send_raw(msg)
90 } else {
91 Err(msg)
92 }
93}
94
95async fn first_msg(
96 stream: &mut (impl Stream<Item = InputEvent> + Unpin),
97 logon_timeout: Duration,
98) -> Result<Box<FixtMessage>, Error> {
99 match timeout(logon_timeout, stream.next()).await {
100 Ok(Some(InputEvent::Message(msg))) => Ok(msg),
101 Ok(Some(InputEvent::IoError(error))) => Err(error.into()),
102 Ok(Some(InputEvent::DeserializeError(error))) => {
103 error!("failed to deserialize first message: {error}");
104 Err(Error::SessionError(SessionError::LogonNeverReceived))
105 }
106 _ => Err(Error::SessionError(SessionError::LogonNeverReceived)),
107 }
108}
109
110#[derive(Debug)]
111struct Connection<S> {
112 session: Rc<Session<S>>,
113}
114
115pub(crate) async fn acceptor_connection<S>(
116 reader: impl AsyncRead + Unpin,
117 writer: impl AsyncWrite + Unpin,
118 settings: Settings,
119 sessions: Rc<RefCell<SessionsMap<S>>>,
120 active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
121 emitter: Emitter,
122) where
123 S: MessagesStorage,
124{
125 let stream = input_stream(reader);
126 let logon_timeout =
127 settings.auto_disconnect_after_no_logon_received + NO_INBOUND_TIMEOUT_PADDING;
128 pin_mut!(stream);
129 let msg = match first_msg(&mut stream, logon_timeout).await {
130 Ok(msg) => msg,
131 Err(err) => {
132 error!("failed to establish new session: {err}");
133 return;
134 }
135 };
136 let session_id = SessionId::from_input_msg(&msg);
137 debug!("first_msg: {msg:?}");
138
139 let (sender, receiver) = mpsc::unbounded_channel();
140 let sender = Sender::new(sender);
141
142 let Some((session_settings, session_state)) = sessions.borrow().get_session(&session_id) else {
143 error!("failed to establish new session: unknown session id {session_id}");
144 return;
145 };
146 session_state.borrow_mut().set_disconnected(false);
147 register_sender(session_id.clone(), sender.clone());
148 let session = Rc::new(Session::new(
149 settings,
150 session_settings,
151 session_state,
152 sender,
153 emitter.clone(),
154 ));
155 active_sessions
156 .borrow_mut()
157 .insert(session_id.clone(), session.clone());
158
159 let session_span = info_span!(
160 "session",
161 id = %session_id
162 );
163
164 let input_loop_span = info_span!(parent: &session_span, "in");
165 let output_loop_span = info_span!(parent: &session_span, "out");
166
167 let force_disconnection_with_reason = session
168 .on_message_in(msg)
169 .instrument(input_loop_span.clone())
170 .await;
171
172 emitter
174 .send(FixEventInternal::Created(session_id.clone()))
175 .await;
176
177 let input_timeout_duration = session.heartbeat_interval().mul_f32(TEST_REQUEST_THRESHOLD);
178 let input_stream = timeout_stream(input_timeout_duration, stream)
179 .map(|res| res.unwrap_or(InputEvent::Timeout));
180 pin_mut!(input_stream);
181
182 let output_stream = output_stream(session.clone(), session.heartbeat_interval(), receiver);
183 pin_mut!(output_stream);
184
185 let connection = Connection::new(session);
186 let (input_closed_tx, input_closed_rx) = tokio::sync::oneshot::channel();
187
188 tokio::join!(
189 connection
190 .input_loop(
191 input_stream,
192 input_closed_tx,
193 force_disconnection_with_reason
194 )
195 .instrument(input_loop_span.clone()),
196 connection
197 .output_loop(writer, output_stream, input_closed_rx)
198 .instrument(output_loop_span),
199 );
200 session_span.in_scope(|| {
201 info!("connection closed");
202 });
203 unregister_sender(&session_id);
204 active_sessions.borrow_mut().remove(&session_id);
205}
206
207pub(crate) async fn initiator_connection<S>(
208 tcp_stream: TcpStream,
209 settings: Settings,
210 session_settings: SessionSettings,
211 state: Rc<RefCell<State<S>>>,
212 active_sessions: Rc<RefCell<ActiveSessionsMap<S>>>,
213 emitter: Emitter,
214) where
215 S: MessagesStorage,
216{
217 let (source, sink) = tcp_stream.into_split();
218 state.borrow_mut().set_disconnected(false);
219 let session_id = session_settings.session_id.clone();
220
221 let (sender, receiver) = mpsc::unbounded_channel();
222 let sender = Sender::new(sender);
223
224 register_sender(session_id.clone(), sender.clone());
225 let session = Rc::new(Session::new(
226 settings,
227 session_settings,
228 state,
229 sender,
230 emitter.clone(),
231 ));
232 active_sessions
233 .borrow_mut()
234 .insert(session_id.clone(), session.clone());
235
236 let session_span = info_span!(
237 "session",
238 id = %session_id
239 );
240
241 let input_loop_span = info_span!(parent: &session_span, "in");
242 let output_loop_span = info_span!(parent: &session_span, "out");
243
244 emitter
246 .send(FixEventInternal::Created(session_id.clone()))
247 .await;
248
249 let input_timeout_duration = session.heartbeat_interval().mul_f32(TEST_REQUEST_THRESHOLD);
250 let input_stream = timeout_stream(input_timeout_duration, input_stream(source))
251 .map(|res| res.unwrap_or(InputEvent::Timeout));
252 pin_mut!(input_stream);
253
254 let output_stream = output_stream(session.clone(), session.heartbeat_interval(), receiver);
255 pin_mut!(output_stream);
256
257 session.send_logon_request(&mut session.state().borrow_mut());
260
261 let connection = Connection::new(session);
262 let (input_closed_tx, input_closed_rx) = tokio::sync::oneshot::channel();
263
264 tokio::join!(
265 connection
266 .input_loop(input_stream, input_closed_tx, None)
267 .instrument(input_loop_span),
268 connection
269 .output_loop(sink, output_stream, input_closed_rx)
270 .instrument(output_loop_span),
271 );
272 info!("connection closed");
273 unregister_sender(&session_id);
274 active_sessions.borrow_mut().remove(&session_id);
275}
276
277impl<S: MessagesStorage> Connection<S> {
278 fn new(session: Rc<Session<S>>) -> Connection<S> {
279 Connection { session }
280 }
281
282 async fn input_loop(
283 &self,
284 mut input_stream: impl Stream<Item = InputEvent> + Unpin,
285 input_closed_tx: tokio::sync::oneshot::Sender<()>,
286 force_disconnection_with_reason: Option<DisconnectReason>,
287 ) {
288 if let Some(disconnect_reason) = force_disconnection_with_reason {
289 self.session
290 .disconnect(&mut self.session.state().borrow_mut(), disconnect_reason);
291
292 input_closed_tx
296 .send(())
297 .expect("Failed to notify about closed inpuot");
298
299 return;
300 }
301
302 let mut disconnect_reason = DisconnectReason::Disconnected;
303
304 while let Some(event) = input_stream.next().await {
305 if self.session.state().borrow().disconnected() {
307 info!("session disconnected, exit input processing");
308 input_closed_tx
312 .send(())
313 .expect("Failed to notify about closed inpout");
314 return;
315 }
316 match event {
317 InputEvent::Message(msg) => {
318 if let Some(dr) = self.session.on_message_in(msg).await {
319 info!("disconnect ({dr:?}), exit input processing");
320 disconnect_reason = dr;
321 break;
322 }
323 }
324 InputEvent::DeserializeError(error) => {
325 if let Some(dr) = self.session.on_deserialize_error(error).await {
326 info!("disconnect ({dr:?}), exit input processing");
327 disconnect_reason = dr;
328 break;
329 }
330 }
331 InputEvent::IoError(error) => {
332 error!("Input error: {error:?}");
333 disconnect_reason = DisconnectReason::IoError;
334 break;
335 }
336 InputEvent::Timeout => {
337 if self.session.on_in_timeout().await {
338 self.session.send_logout(
339 &mut self.session.state().borrow_mut(),
340 Some(SessionStatus::SessionLogoutComplete),
341 Some(FixString::from_ascii_lossy(
342 b"Grace period is over".to_vec(),
343 )),
344 );
345 break;
346 }
347 }
348 }
349 }
350 self.session
351 .disconnect(&mut self.session.state().borrow_mut(), disconnect_reason);
352
353 input_closed_tx
357 .send(())
358 .expect("Failed to notify about closed inpout");
359 }
360
361 async fn output_loop(
362 &self,
363 mut sink: impl AsyncWrite + Unpin,
364 mut output_stream: impl Stream<Item = OutputEvent> + Unpin,
365 input_closed_rx: tokio::sync::oneshot::Receiver<()>,
366 ) {
367 let mut sink_closed = false;
368 let mut disconnect_reason = DisconnectReason::Disconnected;
369 while let Some(event) = output_stream.next().await {
370 match event {
371 OutputEvent::Message(msg) => {
372 if sink_closed {
373 info!("Client disconnected, message will be stored for further resend");
378 } else if let Err(error) = sink.write_all(&msg).await {
379 sink_closed = true;
380 error!("Output write error: {error:?}");
381 }
393 }
394 OutputEvent::Timeout => self.session.on_out_timeout().await,
395 OutputEvent::Disconnect(reason) => {
396 info!("Client disconnected");
400 if !sink_closed {
401 if let Err(e) = sink.flush().await {
402 error!("final flush failed: {e}");
403 }
404 }
405 disconnect_reason = reason;
406 }
407 }
408 }
409 self.session.emit_logout(disconnect_reason).await;
413
414 let _ = input_closed_rx.await;
418 info!("disconnect, exit output processing");
419 }
420}