1pub mod codec;
12pub mod egress;
13pub mod ingress;
14pub mod manager;
15pub mod tcp;
16
17use crate::SystemHealth;
18use std::sync::{Arc, OnceLock};
19
20use anyhow::Result;
21use async_trait::async_trait;
22use bytes::Bytes;
23use codec::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
24use derive_builder::Builder;
25use futures::StreamExt;
26use super::{AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, ResponseStream};
28use serde::{Deserialize, Serialize};
29
30use super::{
31 AsyncTransportEngine, Context, Data, Error, ManyOut, PipelineError, PipelineIO, SegmentSource,
32 ServiceBackend, ServiceEngine, SingleIn, Source, context,
33};
34use crate::metrics::MetricsHierarchy;
35use ingress::push_handler::WorkHandlerMetrics;
36use prometheus::{CounterVec, Histogram, IntCounter, IntCounterVec, IntGauge};
37
/// Fallback TCP message-size limit: `32 * 1024 * 1024` (32 Mi; presumably bytes).
pub(crate) const DEFAULT_TCP_MAX_MESSAGE_SIZE: usize = 32 * 1024 * 1024;

/// Process-wide cache of the resolved limit; written once by the first lookup.
static TCP_MAX_MESSAGE_SIZE: OnceLock<usize> = OnceLock::new();

/// Returns the TCP message-size limit in effect for this process.
///
/// The first call reads the `DYN_TCP_MAX_MESSAGE_SIZE` environment variable.
/// When it is unset, not valid Unicode, or does not parse as a `usize`,
/// [`DEFAULT_TCP_MAX_MESSAGE_SIZE`] is used instead. The outcome is cached, so
/// changing the variable after the first call has no effect.
pub(crate) fn get_tcp_max_message_size() -> usize {
    let resolve = || match std::env::var("DYN_TCP_MAX_MESSAGE_SIZE") {
        Ok(raw) => raw
            .parse::<usize>()
            .unwrap_or(DEFAULT_TCP_MAX_MESSAGE_SIZE),
        Err(_) => DEFAULT_TCP_MAX_MESSAGE_SIZE,
    };
    *TCP_MAX_MESSAGE_SIZE.get_or_init(resolve)
}
53
54pub trait Codable: PipelineIO + Serialize + for<'de> Deserialize<'de> {}
55impl<T: PipelineIO + Serialize + for<'de> Deserialize<'de>> Codable for T {}
56
/// Consumer side of a work queue that hands out raw byte payloads.
#[async_trait]
pub trait WorkQueueConsumer {
    /// Asynchronously obtains the next payload.
    ///
    /// Returns the payload bytes on success, or a human-readable error string
    /// on failure.
    async fn dequeue(&self) -> Result<Bytes, String>;
}
62
/// Identifies which kind of stream is being referred to.
///
/// Serialized by serde in `snake_case` (`"request"` / `"response"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum StreamType {
    /// The request stream.
    Request,
    /// The response stream.
    Response,
}
69
/// Out-of-band control signal carried alongside stream data.
///
/// Serialized by serde in `snake_case` (`"stop"`, `"kill"`, `"sentinel"`).
/// `StreamSender::send_control` encodes a value of this type as JSON and
/// sends it as a header-only message.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ControlMessage {
    /// NOTE(review): presumably a graceful stop request; confirm with the receiver side.
    Stop,
    /// NOTE(review): presumably an immediate termination request; confirm with the receiver side.
    Kill,
    /// NOTE(review): presumably an end-of-stream marker; confirm with the receiver side.
    Sentinel,
}
77
/// First message of a response stream, sent by `StreamSender::send_prologue`
/// as a JSON-encoded header.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ResponseStreamPrologue {
    /// Error text supplied by the sender, if any.
    /// NOTE(review): presumably `None` means the stream was set up successfully; confirm.
    error: Option<String>,
}
88
/// One-shot receiver that eventually yields either a stream handle of type `T`
/// or an error message describing why it could not be provided.
pub type StreamProvider<T> = tokio::sync::oneshot::Receiver<Result<T, String>>;
90
/// Drop guard holding an optional one-shot closure.
///
/// When the guard is dropped while still holding a closure, the closure is
/// invoked exactly once. Setting the inner option to `None` disarms the guard.
struct Cleanup(Option<Box<dyn FnOnce() + Send + 'static>>);

impl Drop for Cleanup {
    /// Runs the stored closure, if one is still present.
    fn drop(&mut self) {
        let Some(action) = self.0.take() else {
            return;
        };
        action();
    }
}
102
/// A stream registration: the connection details for the stream, the
/// receiver that will deliver the stream handle, and a drop guard.
///
/// If a cleanup closure was attached with `with_cleanup`, dropping this value
/// runs it; `into_parts` disarms it.
pub struct RegisteredStream<T> {
    /// Connection details associated with this registration.
    pub connection_info: ConnectionInfo,
    /// One-shot receiver that yields the stream handle or an error string.
    pub stream_provider: StreamProvider<T>,
    /// Guard whose closure (if any) fires when this value is dropped.
    cleanup: Cleanup,
}
111
112impl<T> std::fmt::Debug for RegisteredStream<T> {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct("RegisteredStream")
115 .field("connection_info", &self.connection_info)
116 .finish_non_exhaustive()
117 }
118}
119
120impl<T> RegisteredStream<T> {
121 pub(crate) fn new(connection_info: ConnectionInfo, stream_provider: StreamProvider<T>) -> Self {
122 Self {
123 connection_info,
124 stream_provider,
125 cleanup: Cleanup(None),
126 }
127 }
128
129 pub(crate) fn with_cleanup<F>(mut self, cleanup: F) -> Self
130 where
131 F: FnOnce() + Send + 'static,
132 {
133 self.cleanup.0 = Some(Box::new(cleanup));
134 self
135 }
136
137 pub fn into_parts(self) -> (ConnectionInfo, StreamProvider<T>) {
140 let Self {
141 connection_info,
142 stream_provider,
143 mut cleanup,
144 } = self;
145 cleanup.0.take();
146 (connection_info, stream_provider)
147 }
148}
149
/// The registrations returned by `ResponseService::register`; either side may
/// be absent.
pub struct PendingConnections {
    /// Registration for the sending side, if present.
    pub send_stream: Option<RegisteredStream<StreamSender>>,
    /// Registration for the receiving side, if present.
    pub recv_stream: Option<RegisteredStream<StreamReceiver>>,
}
156
157impl PendingConnections {
158 pub fn into_parts(
159 self,
160 ) -> (
161 Option<RegisteredStream<StreamSender>>,
162 Option<RegisteredStream<StreamReceiver>>,
163 ) {
164 (self.send_stream, self.recv_stream)
165 }
166}
167
/// Service with which streams are registered.
#[async_trait::async_trait]
pub trait ResponseService {
    /// Registers streams according to `options` and returns the resulting
    /// pending connections.
    async fn register(&self, options: StreamOptions) -> PendingConnections;
}
174
#[cfg(test)]
mod registered_stream_tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Builds placeholder connection info for the tests below.
    fn dummy_conn_info() -> ConnectionInfo {
        let transport = String::from("test");
        let info = String::from("{}");
        ConnectionInfo { transport, info }
    }

    /// Dropping an armed `RegisteredStream` must invoke its cleanup closure.
    #[test]
    fn drop_runs_cleanup() {
        let fired = Arc::new(AtomicBool::new(false));

        let (_tx, rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
        let stream = RegisteredStream::new(dummy_conn_info(), rx).with_cleanup({
            let fired = Arc::clone(&fired);
            move || fired.store(true, Ordering::SeqCst)
        });

        drop(stream);
        assert!(
            fired.load(Ordering::SeqCst),
            "cleanup must fire when RegisteredStream is dropped"
        );
    }

    /// `into_parts` must disarm the closure so that dropping the parts does not run it.
    #[test]
    fn into_parts_disarms_cleanup() {
        let fired = Arc::new(AtomicBool::new(false));

        let (_tx, rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
        let stream = RegisteredStream::new(dummy_conn_info(), rx).with_cleanup({
            let fired = Arc::clone(&fired);
            move || fired.store(true, Ordering::SeqCst)
        });

        let (conn, provider) = stream.into_parts();
        drop(conn);
        drop(provider);

        assert!(
            !fired.load(Ordering::SeqCst),
            "into_parts() must disarm the cleanup closure"
        );
    }

    /// Dropping a stream that never had a cleanup attached must simply succeed.
    #[test]
    fn drop_without_cleanup_is_a_noop() {
        let (_tx, rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
        let stream: RegisteredStream<()> = RegisteredStream::new(dummy_conn_info(), rx);
        drop(stream);
    }
}
236
/// Sending half of a stream; pushes `TwoPartMessage`s into a bounded-or-unbounded
/// tokio mpsc channel (whichever the creator configured).
pub struct StreamSender {
    /// Channel into which outgoing messages are queued.
    tx: tokio::sync::mpsc::Sender<TwoPartMessage>,
    /// Pending prologue slot; `send_prologue` takes it and panics if it is
    /// already empty.
    prologue: Option<ResponseStreamPrologue>,
}
265
266impl StreamSender {
267 pub async fn send(&self, data: Bytes) -> Result<()> {
268 Ok(self.tx.send(TwoPartMessage::from_data(data)).await?)
269 }
270
271 pub async fn send_control(&self, control: ControlMessage) -> Result<()> {
272 let bytes = serde_json::to_vec(&control)?;
273 Ok(self
274 .tx
275 .send(TwoPartMessage::from_header(bytes.into()))
276 .await?)
277 }
278
279 #[allow(clippy::needless_update)]
280 pub async fn send_prologue(&mut self, error: Option<String>) -> Result<(), String> {
281 if let Some(_prologue) = self.prologue.take() {
285 let prologue = ResponseStreamPrologue { error };
287 let header_bytes: Bytes = match serde_json::to_vec(&prologue) {
288 Ok(b) => b.into(),
289 Err(err) => {
290 tracing::error!(%err, "send_prologue: ResponseStreamPrologue did not serialize to a JSON array");
291 return Err("Invalid prologue".to_string());
292 }
293 };
294 self.tx
295 .send(TwoPartMessage::from_header(header_bytes))
296 .await
297 .map_err(|e| e.to_string())?;
298 } else {
299 panic!("Prologue already sent; or not set; logic error");
300 }
301 Ok(())
302 }
303}
304
/// Receiving half of a stream; yields raw `Bytes` from a tokio mpsc channel.
pub struct StreamReceiver {
    /// Channel from which incoming payloads are read.
    rx: tokio::sync::mpsc::Receiver<Bytes>,
}
308
/// Serializable description of how to reach a stream endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionInfo {
    /// Name of the transport this info applies to.
    pub transport: String,
    /// Connection details as a string.
    /// NOTE(review): presumably a transport-specific serialized payload (the
    /// tests in this file use `"{}"`); confirm the expected format.
    pub info: String,
}
323
/// Options passed to `ResponseService::register`.
///
/// Construct via `StreamOptions::builder()`; the derive generates
/// `StreamOptionsBuilder`.
#[derive(Clone, Builder)]
pub struct StreamOptions {
    /// Engine context associated with the streams being registered.
    pub context: Arc<dyn AsyncEngineContext>,

    /// Whether a request stream should be set up.
    pub enable_request_stream: bool,

    /// Whether a response stream should be set up.
    pub enable_response_stream: bool,

    /// Buffer count for the send side; the builder defaults this to 8.
    #[builder(default = "8")]
    pub send_buffer_count: usize,

    /// Buffer count for the receive side; the builder defaults this to 8.
    #[builder(default = "8")]
    pub recv_buffer_count: usize,
}
353
354impl StreamOptions {
355 pub fn builder() -> StreamOptionsBuilder {
356 StreamOptionsBuilder::default()
357 }
358}
359
/// Outbound adapter that forwards requests to a transport engine.
pub struct Egress<Req: PipelineIO, Resp: PipelineIO> {
    /// Engine to which `generate` calls are delegated.
    transport_engine: Arc<dyn AsyncTransportEngine<Req, Resp>>,
}
363
364#[async_trait]
365impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error>
366 for Egress<SingleIn<T>, ManyOut<U>>
367where
368 T: Data + Serialize,
369 U: for<'de> Deserialize<'de> + Data,
370{
371 async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
372 self.transport_engine.generate(request).await
373 }
374}
375
/// Shape of the request side of an exchange.
///
/// Serialized by serde in `snake_case` (`"single_in"` / `"many_in"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum RequestType {
    /// A single request item.
    SingleIn,
    /// Multiple request items.
    ManyIn,
}
382
/// Shape of the response side of an exchange.
///
/// Serialized by serde in `snake_case` (`"single_out"` / `"many_out"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ResponseType {
    /// A single response item.
    SingleOut,
    /// Multiple response items.
    ManyOut,
}
389
/// Control metadata that accompanies a request.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RequestControlMessage {
    /// Identifier of the request.
    id: String,
    /// Shape of the request side.
    request_type: RequestType,
    /// Shape of the response side.
    response_type: ResponseType,
    /// Connection details carried with the request.
    /// NOTE(review): presumably where the response stream should connect; confirm.
    connection_info: ConnectionInfo,
    /// Optional timestamp; omitted from the serialized form when `None` and
    /// defaulted to `None` when absent on deserialization.
    /// NOTE(review): the name suggests nanoseconds at frontend send time; confirm the clock and epoch.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    frontend_send_ts_ns: Option<u64>,
}
402
/// Inbound adapter whose parts are each set at most once after construction.
pub struct Ingress<Req: PipelineIO, Resp: PipelineIO> {
    /// Pipeline segment; set once via `attach`.
    segment: OnceLock<Arc<SegmentSource<Req, Resp>>>,
    /// Work-handler metrics; set once via `add_metrics`.
    metrics: OnceLock<Arc<WorkHandlerMetrics>>,
    /// Notifier for endpoint health checks; starts unset.
    endpoint_health_check_notifier: OnceLock<Arc<tokio::sync::Notify>>,
}
409
410impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> {
411 pub fn new() -> Arc<Self> {
412 Arc::new(Self {
413 segment: OnceLock::new(),
414 metrics: OnceLock::new(),
415 endpoint_health_check_notifier: OnceLock::new(),
416 })
417 }
418
419 pub fn attach(&self, segment: Arc<SegmentSource<Req, Resp>>) -> Result<()> {
420 self.segment
421 .set(segment)
422 .map_err(|_| anyhow::anyhow!("Segment already set"))
423 }
424
425 pub fn add_metrics(
426 &self,
427 endpoint: &crate::component::Endpoint,
428 metrics_labels: Option<&[(&str, &str)]>,
429 ) -> Result<()> {
430 let metrics = WorkHandlerMetrics::from_endpoint(endpoint, metrics_labels)
431 .map_err(|e| anyhow::anyhow!("Failed to create work handler metrics: {}", e))?;
432
433 crate::metrics::work_handler_perf::ensure_work_handler_perf_metrics_registered(
435 endpoint.get_metrics_registry(),
436 );
437
438 crate::metrics::work_handler_pool::ensure_work_handler_pool_metrics_registered(
442 endpoint.get_metrics_registry(),
443 );
444
445 self.metrics
446 .set(Arc::new(metrics))
447 .map_err(|_| anyhow::anyhow!("Metrics already set"))
448 }
449
450 pub fn link(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
451 let ingress = Ingress::new();
452 ingress.attach(segment)?;
453 Ok(ingress)
454 }
455
456 pub fn for_pipeline(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
457 let ingress = Ingress::new();
458 ingress.attach(segment)?;
459 Ok(ingress)
460 }
461
462 pub fn for_engine(engine: ServiceEngine<Req, Resp>) -> Result<Arc<Self>> {
463 let frontend = SegmentSource::<Req, Resp>::new();
464 let backend = ServiceBackend::from_engine(engine);
465
466 let pipeline = frontend.link(backend)?.link(frontend)?;
468
469 let ingress = Ingress::new();
470 ingress.attach(pipeline)?;
471
472 Ok(ingress)
473 }
474
475 fn metrics(&self) -> Option<&Arc<WorkHandlerMetrics>> {
477 self.metrics.get()
478 }
479}
480
/// Handler to which incoming work payloads are pushed.
#[async_trait]
pub trait PushWorkHandler: Send + Sync {
    /// Processes one raw `payload`, optionally tagged with a `request_id`.
    ///
    /// Returns a `PipelineError` if handling fails.
    async fn handle_payload(
        &self,
        payload: Bytes,
        request_id: Option<String>,
    ) -> Result<(), PipelineError>;

    /// Sets up metrics for this handler on `endpoint`, optionally with extra
    /// label pairs.
    fn add_metrics(
        &self,
        endpoint: &crate::component::Endpoint,
        metrics_labels: Option<&[(&str, &str)]>,
    ) -> Result<()>;

    /// Supplies a notifier for endpoint health checks.
    ///
    /// The default implementation ignores the notifier and returns `Ok(())`;
    /// implementors that need it should override this method.
    fn set_endpoint_health_check_notifier(
        &self,
        _notifier: Arc<tokio::sync::Notify>,
    ) -> Result<()> {
        Ok(())
    }
}
505
/// Envelope around a single stream item of type `U`.
#[derive(Serialize, Deserialize, Debug)]
pub struct NetworkStreamWrapper<U> {
    /// The wrapped item; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<U>,
    /// NOTE(review): presumably marks the final message of a completed stream; confirm with the producer/consumer code.
    pub complete_final: bool,
}