dynamo_runtime/pipeline/network/tcp/
client.rs1use std::sync::Arc;
17
18use futures::{SinkExt, StreamExt};
19use tokio::io::{AsyncReadExt, ReadHalf, WriteHalf};
20use tokio::{
21 io::AsyncWriteExt,
22 net::TcpStream,
23 time::{self, Duration, Instant},
24};
25use tokio_util::codec::{FramedRead, FramedWrite};
26
27use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
28use crate::engine::AsyncEngineContext;
29use crate::pipeline::network::{
30 ConnectionInfo, ResponseStreamPrologue, StreamSender,
31 codec::{TwoPartCodec, TwoPartMessage},
32 tcp::StreamType,
33};
34use crate::{ErrorContext, Result, error}; #[allow(dead_code)]
/// Client side of the TCP response-stream transport.
///
/// The connection logic lives in associated functions; an instance only
/// carries the identifier of the worker it was created for.
#[derive(Debug, Clone)]
pub struct TcpClient {
    /// Identifier of the worker that owns this client.
    worker_id: String,
}
40
41impl Default for TcpClient {
42 fn default() -> Self {
43 TcpClient {
44 worker_id: uuid::Uuid::new_v4().to_string(),
45 }
46 }
47}
48
49impl TcpClient {
50 pub fn new(worker_id: String) -> Self {
51 TcpClient { worker_id }
52 }
53
54 async fn connect(address: &str) -> std::io::Result<TcpStream> {
55 let backoff = std::time::Duration::from_millis(200);
57 loop {
58 match TcpStream::connect(address).await {
59 Ok(socket) => {
60 socket.set_nodelay(true)?;
61 return Ok(socket);
62 }
63 Err(e) => {
64 if e.kind() == std::io::ErrorKind::AddrNotAvailable {
65 tracing::warn!("retry warning: failed to connect: {:?}", e);
66 tokio::time::sleep(backoff).await;
67 } else {
68 return Err(e);
69 }
70 }
71 }
72 }
73 }
74
75 pub async fn create_response_stream(
76 context: Arc<dyn AsyncEngineContext>,
77 info: ConnectionInfo,
78 ) -> Result<StreamSender> {
79 let info =
80 TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
81 tracing::trace!("Creating response stream for {:?}", info);
82
83 if info.stream_type != StreamType::Response {
84 return Err(error!(
85 "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
86 info.stream_type
87 ));
88 }
89
90 if info.context != context.id() {
91 return Err(error!(
92 "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
93 context.id(),
94 info.context
95 ));
96 }
97
98 let stream = TcpClient::connect(&info.address).await?;
99 let (read_half, write_half) = tokio::io::split(stream);
100
101 let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
102 let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
103
104 let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
110
111 let reader_task = tokio::spawn(handle_reader(framed_reader, context.clone(), alive_tx));
112
113 let handshake = CallHomeHandshake {
115 subject: info.subject,
116 stream_type: StreamType::Response,
117 };
118
119 let handshake_bytes = match serde_json::to_vec(&handshake) {
120 Ok(hb) => hb,
121 Err(err) => {
122 return Err(error!(
123 "create_response_stream: Error converting CallHomeHandshake to JSON array: {err:#}"
124 ));
125 }
126 };
127 let msg = TwoPartMessage::from_header(handshake_bytes.into());
128
129 framed_writer
131 .send(msg)
132 .await
133 .map_err(|e| error!("failed to send handshake: {:?}", e))?;
134
135 let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(64);
137
138 let writer_task = tokio::spawn(handle_writer(framed_writer, bytes_rx, alive_rx, context));
141
142 tokio::spawn(async move {
143 let (reader, writer) = tokio::join!(reader_task, writer_task);
145
146 match (reader, writer) {
147 (Ok(reader), Ok(writer)) => {
148 let reader = reader.into_inner();
149
150 let writer = match writer {
151 Ok(writer) => writer.into_inner(),
152 Err(e) => {
153 tracing::error!("failed to join writer task: {:?}", e);
154 return Err(e);
155 }
156 };
157
158 let mut stream = reader.unsplit(writer);
159
160 let mut buf = vec![0u8; 1024];
163 let deadline = Instant::now() + Duration::from_secs(10);
164 loop {
165 let n = time::timeout_at(deadline, stream.read(&mut buf))
166 .await
167 .inspect_err(|_| {
168 tracing::debug!("server did not close socket within the deadline");
169 })?
170 .inspect_err(|e| {
171 tracing::debug!("failed to read from stream: {:?}", e);
172 })?;
173 if n == 0 {
174 break;
176 }
177 }
178
179 Ok(())
180 }
181 _ => {
182 tracing::error!("failed to join reader and writer tasks");
183 anyhow::bail!("failed to join reader and writer tasks");
184 }
185 }
186 });
187
188 let prologue = Some(ResponseStreamPrologue { error: None });
191
192 let stream_sender = StreamSender {
194 tx: bytes_tx,
195 prologue,
196 };
197
198 Ok(stream_sender)
199 }
200}
201
202async fn handle_reader(
203 framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
204 context: Arc<dyn AsyncEngineContext>,
205 alive_tx: tokio::sync::oneshot::Sender<()>,
206) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
207 let mut framed_reader = framed_reader;
208 let mut alive_tx = alive_tx;
209 loop {
210 tokio::select! {
211 msg = framed_reader.next() => {
212 match msg {
213 Some(Ok(two_part_msg)) => {
214 match two_part_msg.optional_parts() {
215 (Some(bytes), None) => {
216 let msg = match serde_json::from_slice::<ControlMessage>(bytes) {
217 Ok(msg) => msg,
218 Err(_) => {
219 panic!("fatal error - invalid control message detected");
221 }
222 };
223
224 match msg {
225 ControlMessage::Stop => {
226 context.stop();
227 }
228 ControlMessage::Kill => {
229 context.kill();
230 }
231 ControlMessage::Sentinel => {
232 panic!("received a sentinel message; this should never happen");
234 }
235 }
236 }
237 _ => {
238 panic!("received a non-control message; this should never happen");
239 }
240 }
241 }
242 Some(Err(_)) => {
243 panic!("fatal error - failed to decode message from stream; invalid line protocol");
246 }
247 None => {
248 tracing::debug!("tcp stream closed by server");
249 break;
250 }
251 }
252 }
253 _ = alive_tx.closed() => {
254 break;
255 }
256 }
257 }
258 framed_reader
259}
260
261async fn handle_writer(
262 mut framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
263 mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
264 alive_rx: tokio::sync::oneshot::Receiver<()>,
265 context: Arc<dyn AsyncEngineContext>,
266) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> {
267 loop {
268 let msg = tokio::select! {
269 biased;
270
271 _ = context.killed() => {
272 tracing::trace!("context kill signal received; shutting down");
273 break;
274 }
275
276 msg = bytes_rx.recv() => {
277 match msg {
278 Some(msg) => msg,
279 None => {
280 tracing::trace!("response channel closed; shutting down");
281 break;
282 }
283 }
284 }
285 };
286
287 if let Err(e) = framed_writer.send(msg).await {
288 tracing::trace!(
289 "failed to send message to network; possible disconnect: {:?}",
290 e
291 );
292 break;
293 }
294 }
295
296 let message = serde_json::to_vec(&ControlMessage::Sentinel)?;
298 let msg = TwoPartMessage::from_header(message.into());
299 framed_writer.send(msg).await?;
300
301 drop(alive_rx);
302 Ok(framed_writer)
303}