1use core::panic;
17use socket2::{Domain, SockAddr, Socket, Type};
18use std::{
19 collections::HashMap,
20 net::{SocketAddr, TcpListener},
21 os::fd::{AsFd, FromRawFd},
22 sync::Arc,
23};
24use tokio::sync::Mutex;
25
26use bytes::Bytes;
27use derive_builder::Builder;
28use futures::{SinkExt, StreamExt};
29use local_ip_address::{list_afinet_netifas, local_ip};
30use serde::{Deserialize, Serialize};
31use tokio::{
32 io::AsyncWriteExt,
33 sync::{mpsc, oneshot},
34 time,
35};
36use tokio_util::codec::{FramedRead, FramedWrite};
37
38use super::{
39 CallHomeHandshake, ControlMessage, PendingConnections, RegisteredStream, StreamOptions,
40 StreamReceiver, StreamSender, TcpStreamConnectionInfo, TwoPartCodec,
41};
42use crate::engine::AsyncEngineContext;
43use crate::pipeline::{
44 network::{
45 codec::{TwoPartMessage, TwoPartMessageType},
46 tcp::StreamType,
47 ResponseService, ResponseStreamPrologue,
48 },
49 PipelineError,
50};
51use crate::{error, ErrorContext, Result};
52
53#[allow(dead_code)]
54type ResponseType = TwoPartMessage;
55
56#[derive(Debug, Serialize, Deserialize, Clone, Builder, Default)]
57pub struct ServerOptions {
58 #[builder(default = "0")]
59 pub port: u16,
60
61 #[builder(default)]
62 pub interface: Option<String>,
63}
64
65impl ServerOptions {
66 pub fn builder() -> ServerOptionsBuilder {
67 ServerOptionsBuilder::default()
68 }
69}
70
71pub struct TcpStreamServer {
75 local_ip: String,
76 local_port: u16,
77 state: Arc<Mutex<State>>,
78}
79
80#[allow(dead_code)]
87struct RequestedSendConnection {
88 context: Arc<dyn AsyncEngineContext>,
89 connection: oneshot::Sender<Result<StreamSender, String>>,
90}
91
92struct RequestedRecvConnection {
93 context: Arc<dyn AsyncEngineContext>,
94 connection: oneshot::Sender<Result<StreamReceiver, String>>,
95}
96
97#[derive(Default)]
114struct State {
115 tx_subjects: HashMap<String, RequestedSendConnection>,
116 rx_subjects: HashMap<String, RequestedRecvConnection>,
117 handle: Option<tokio::task::JoinHandle<Result<()>>>,
118}
119
120impl TcpStreamServer {
121 pub fn options_builder() -> ServerOptionsBuilder {
122 ServerOptionsBuilder::default()
123 }
124
125 pub async fn new(options: ServerOptions) -> Result<Arc<Self>, PipelineError> {
126 let local_ip = match options.interface {
127 Some(interface) => {
128 let interfaces: HashMap<String, std::net::IpAddr> =
129 list_afinet_netifas()?.into_iter().collect();
130
131 interfaces
132 .get(&interface)
133 .ok_or(PipelineError::Generic(format!(
134 "Interface not found: {}",
135 interface
136 )))?
137 .to_string()
138 }
139 None => local_ip().unwrap().to_string(),
140 };
141
142 let state = Arc::new(Mutex::new(State::default()));
143
144 let local_port = Self::start(local_ip.clone(), options.port, state.clone())
145 .await
146 .map_err(|e| {
147 PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e))
148 })?;
149
150 tracing::info!("tcp transport service on {}:{}", local_ip, local_port);
151
152 Ok(Arc::new(Self {
153 local_ip,
154 local_port,
155 state,
156 }))
157 }
158
159 #[allow(clippy::await_holding_lock)]
160 async fn start(local_ip: String, local_port: u16, state: Arc<Mutex<State>>) -> Result<u16> {
161 let addr = format!("{}:{}", local_ip, local_port);
162 let state_clone = state.clone();
163 let mut guard = state.lock().await;
164 if guard.handle.is_some() {
165 panic!("TcpStreamServer already started");
166 }
167 let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<Result<u16>>();
168 let handle = tokio::spawn(tcp_listener(addr, state_clone, ready_tx));
169 guard.handle = Some(handle);
170 drop(guard);
171 let local_port = ready_rx.await??;
172 Ok(local_port)
173 }
174}
175
176#[async_trait::async_trait]
178impl ResponseService for TcpStreamServer {
179 async fn register(&self, options: StreamOptions) -> PendingConnections {
200 let address = format!("{}:{}", self.local_ip, self.local_port);
203 tracing::debug!("Registering new TcpStream on {}", address);
204
205 let send_stream = if options.enable_request_stream {
206 let sender_subject = uuid::Uuid::new_v4().to_string();
207
208 let (pending_sender_tx, pending_sender_rx) = oneshot::channel();
209
210 let connection_info = RequestedSendConnection {
211 context: options.context.clone(),
212 connection: pending_sender_tx,
213 };
214
215 let mut state = self.state.lock().await;
216 state
217 .tx_subjects
218 .insert(sender_subject.clone(), connection_info);
219
220 let registered_stream = RegisteredStream {
221 connection_info: TcpStreamConnectionInfo {
222 address: address.clone(),
223 subject: sender_subject.clone(),
224 context: options.context.id().to_string(),
225 stream_type: StreamType::Request,
226 }
227 .into(),
228 stream_provider: pending_sender_rx,
229 };
230
231 Some(registered_stream)
232 } else {
233 None
234 };
235
236 let recv_stream = if options.enable_response_stream {
237 let (pending_recver_tx, pending_recver_rx) = oneshot::channel();
238 let receiver_subject = uuid::Uuid::new_v4().to_string();
239
240 let connection_info = RequestedRecvConnection {
241 context: options.context.clone(),
242 connection: pending_recver_tx,
243 };
244
245 let mut state = self.state.lock().await;
246 state
247 .rx_subjects
248 .insert(receiver_subject.clone(), connection_info);
249
250 let registered_stream = RegisteredStream {
251 connection_info: TcpStreamConnectionInfo {
252 address: address.clone(),
253 subject: receiver_subject.clone(),
254 context: options.context.id().to_string(),
255 stream_type: StreamType::Response,
256 }
257 .into(),
258 stream_provider: pending_recver_rx,
259 };
260
261 Some(registered_stream)
262 } else {
263 None
264 };
265
266 PendingConnections {
267 send_stream,
268 recv_stream,
269 }
270 }
271}
272
273async fn tcp_listener(
280 addr: String,
281 state: Arc<Mutex<State>>,
282 read_tx: tokio::sync::oneshot::Sender<Result<u16>>,
283) -> Result<()> {
284 let listener = tokio::net::TcpListener::bind(&addr)
285 .await
286 .map_err(|e| anyhow::anyhow!("Failed to start TcpListender on {}: {}", addr, e));
287
288 let listener = match listener {
289 Ok(listener) => {
290 let addr = listener
291 .local_addr()
292 .map_err(|e| anyhow::anyhow!("Failed get SocketAddr: {:?}", e))
293 .unwrap();
294
295 read_tx
296 .send(Ok(addr.port()))
297 .expect("Failed to send ready signal");
298
299 listener
300 }
301 Err(e) => {
302 read_tx.send(Err(e)).expect("Failed to send ready signal");
303 return Err(anyhow::anyhow!("Failed to start TcpListender on {}", addr));
304 }
305 };
306
307 loop {
308 let (stream, _addr) = match listener.accept().await {
314 Ok((stream, _addr)) => (stream, _addr),
315 Err(e) => {
316 tracing::warn!("failed to accept tcp connection: {}", e);
318 eprintln!("failed to accept tcp connection: {}", e);
319 continue;
320 }
321 };
322
323 match stream.set_nodelay(true) {
324 Ok(_) => (),
325 Err(e) => {
326 tracing::warn!("failed to set tcp stream to nodelay: {}", e);
327 }
328 }
329
330 match stream.set_linger(Some(std::time::Duration::from_secs(0))) {
331 Ok(_) => (),
332 Err(e) => {
333 tracing::warn!("failed to set tcp stream to linger: {}", e);
334 }
335 }
336
337 tokio::spawn(handle_connection(stream, state.clone()));
338 }
339
340 async fn handle_connection(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) {
343 let result = process_stream(stream, state).await;
344 match result {
345 Ok(_) => tracing::trace!("successfully processed tcp connection"),
346 Err(e) => {
347 tracing::warn!("failed to handle tcp connection: {}", e);
348 #[cfg(debug_assertions)]
349 eprintln!("failed to handle tcp connection: {}", e);
350 }
351 }
352 }
353
354 async fn process_stream(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) -> Result<()> {
357 let (read_half, write_half) = tokio::io::split(stream);
359
360 let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
362 let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
363
364 let first_message = framed_reader
367 .next()
368 .await
369 .ok_or(error!("Connection closed without a ControlMessage"))??;
370
371 let handshake: CallHomeHandshake = match first_message.header() {
374 Some(header) => serde_json::from_slice(header).map_err(|e| {
375 error!(
376 "Failed to deserialize the first message as a valid `CallHomeHandshake`: {e}",
377 )
378 })?,
379 None => {
380 return Err(error!("Expected ControlMessage, got DataMessage"));
381 }
382 };
383
384 match handshake.stream_type {
386 StreamType::Request => process_request_stream().await,
387 StreamType::Response => {
388 process_response_stream(handshake.subject, state, framed_reader, framed_writer)
389 .await
390 }
391 }
392 }
393
394 async fn process_request_stream() -> Result<()> {
395 Ok(())
396 }
397
398 async fn process_response_stream(
399 subject: String,
400 state: Arc<Mutex<State>>,
401 mut reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
402 writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
403 ) -> Result<()> {
404 let response_stream = state
405 .lock().await
406 .rx_subjects
407 .remove(&subject)
408 .ok_or(error!("Subject not found: {}; upstream publisher specified a subject unknown to the downsteam subscriber", subject))?;
409
410 let RequestedRecvConnection {
412 context,
413 connection,
414 } = response_stream;
415
416 let prologue = reader
419 .next()
420 .await
421 .ok_or(error!("Connection closed without a ControlMessge"))??;
422
423 let prologue = match prologue.into_message_type() {
425 TwoPartMessageType::HeaderOnly(header) => {
426 let prologue: ResponseStreamPrologue = serde_json::from_slice(&header)
427 .map_err(|e| error!("Failed to deserialize ControlMessage: {}", e))?;
428 prologue
429 }
430 _ => {
431 panic!("Expected HeaderOnly ControlMessage; internally logic error")
432 }
433 };
434
435 if let Some(error) = &prologue.error {
442 let _ = connection.send(Err(error.clone()));
443 return Err(error!("Received error prologue: {}", error));
444 }
445
446 let (response_tx, response_rx) = mpsc::channel(64);
448
449 if connection
450 .send(Ok(crate::pipeline::network::StreamReceiver {
451 rx: response_rx,
452 }))
453 .is_err()
454 {
455 return Err(error!("The requester of the stream has been dropped before the connection was established"));
456 }
457
458 let (control_tx, control_rx) = mpsc::channel::<ControlMessage>(1);
459
460 let send_task = tokio::spawn(network_send_handler(writer, control_rx));
464
465 let recv_task = tokio::spawn(network_receive_handler(
467 reader,
468 response_tx,
469 control_tx,
470 context.clone(),
471 ));
472
473 let (monitor_result, forward_result) = tokio::join!(send_task, recv_task);
475
476 monitor_result?;
477 forward_result?;
478
479 Ok(())
480 }
481
482 async fn network_receive_handler(
483 mut framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
484 response_tx: mpsc::Sender<Bytes>,
485 control_tx: mpsc::Sender<ControlMessage>,
486 context: Arc<dyn AsyncEngineContext>,
487 ) {
488 let mut can_stop = true;
490 loop {
491 tokio::select! {
492 biased;
493
494 _ = response_tx.closed() => {
495 tracing::trace!("response channel closed before the client finished writing data");
496 control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
497 break;
498 }
499
500 _ = context.killed() => {
501 tracing::trace!("context kill signal received; shutting down");
502 control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
503 break;
504 }
505
506 _ = context.stopped(), if can_stop => {
507 tracing::trace!("context stop signal received; shutting down");
508 can_stop = false;
509 control_tx.send(ControlMessage::Stop).await.expect("the control channel should not be closed");
510 }
511
512 msg = framed_reader.next() => {
513 match msg {
514 Some(Ok(msg)) => {
515 let (header, data) = msg.into_parts();
516
517 if !header.is_empty() {
519 match process_control_message(header) {
520 Ok(ControlAction::Continue) => {}
521 Ok(ControlAction::Shutdown) => {
522 assert!(data.is_empty(), "received sentinel message with data; this should never happen");
523 tracing::trace!("received sentinel message; shutting down");
524 break;
525 }
526 Err(e) => {
527 panic!("{:?}", e);
529 }
530 }
531 }
532
533 if !data.is_empty() {
534 if let Err(err) = response_tx.send(data).await {
535 tracing::debug!("forwarding body/data message to response channel failed: {}", err);
536 control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
537 break;
538 };
539 }
540 }
541 Some(Err(_)) => {
542 panic!("invalid message issued over socket; this should never happen");
544 }
545 None => {
546 tracing::trace!("tcp stream was closed by client");
552 break;
553 }
554 }
555 }
556
557 }
558 }
559 }
560
561 async fn network_send_handler(
562 socket_tx: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
563 control_rx: mpsc::Receiver<ControlMessage>,
564 ) {
565 let mut socket_tx = socket_tx;
566 let mut control_rx = control_rx;
567
568 while let Some(control_msg) = control_rx.recv().await {
569 assert_ne!(
570 control_msg,
571 ControlMessage::Sentinel,
572 "received sentinel message; this should never happen"
573 );
574 let bytes =
575 serde_json::to_vec(&control_msg).expect("failed to serialize control message");
576 let message = TwoPartMessage::from_header(bytes.into());
577 match socket_tx.send(message).await {
578 Ok(_) => tracing::debug!("issued control message {control_msg:?} to sender"),
579 Err(_) => {
580 tracing::debug!("failed to send control message {control_msg:?} to sender")
581 }
582 }
583 }
584
585 let mut inner = socket_tx.into_inner();
586 if let Err(e) = inner.flush().await {
587 tracing::debug!("failed to flush socket: {}", e);
588 }
589 if let Err(e) = inner.shutdown().await {
590 tracing::debug!("failed to shutdown socket: {}", e);
591 }
592 }
593}
594
/// What the receive loop should do after interpreting a control header from the peer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ControlAction {
    /// Keep reading the stream.
    Continue,
    /// The peer sent its end-of-stream sentinel; stop reading.
    Shutdown,
}
599
600fn process_control_message(message: Bytes) -> Result<ControlAction> {
601 match serde_json::from_slice::<ControlMessage>(&message)? {
602 ControlMessage::Sentinel => {
603 tracing::trace!("sentinel received; shutting down");
606 Ok(ControlAction::Shutdown)
607 }
608 ControlMessage::Kill | ControlMessage::Stop => {
609 anyhow::bail!(
611 "fatal error - unexpected control message received - this should never happen"
612 );
613 }
614 }
615}