pub mod codec;
pub mod egress;
pub mod ingress;
pub mod manager;
pub mod tcp;
use crate::SystemHealth;
use std::sync::{Arc, OnceLock};
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use codec::{TwoPartCodec, TwoPartMessage, TwoPartMessageType};
use derive_builder::Builder;
use futures::StreamExt;
use super::{AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, ResponseStream};
use serde::{Deserialize, Serialize};
use super::{
AsyncTransportEngine, Context, Data, Error, ManyOut, PipelineError, PipelineIO, SegmentSource,
ServiceBackend, ServiceEngine, SingleIn, Source, context,
};
use crate::metrics::MetricsHierarchy;
use ingress::push_handler::WorkHandlerMetrics;
use prometheus::{CounterVec, Histogram, IntCounter, IntCounterVec, IntGauge};
pub(crate) const DEFAULT_TCP_MAX_MESSAGE_SIZE: usize = 32 * 1024 * 1024;
static TCP_MAX_MESSAGE_SIZE: OnceLock<usize> = OnceLock::new();
pub(crate) fn get_tcp_max_message_size() -> usize {
*TCP_MAX_MESSAGE_SIZE.get_or_init(|| {
std::env::var("DYN_TCP_MAX_MESSAGE_SIZE")
.ok()
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(DEFAULT_TCP_MAX_MESSAGE_SIZE)
})
}
pub trait Codable: PipelineIO + Serialize + for<'de> Deserialize<'de> {}
impl<T: PipelineIO + Serialize + for<'de> Deserialize<'de>> Codable for T {}
#[async_trait]
pub trait WorkQueueConsumer {
async fn dequeue(&self) -> Result<Bytes, String>;
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum StreamType {
Request,
Response,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ControlMessage {
Stop,
Kill,
Sentinel,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ResponseStreamPrologue {
error: Option<String>,
}
pub type StreamProvider<T> = tokio::sync::oneshot::Receiver<Result<T, String>>;
struct Cleanup(Option<Box<dyn FnOnce() + Send + 'static>>);
impl Drop for Cleanup {
fn drop(&mut self) {
if let Some(f) = self.0.take() {
f();
}
}
}
pub struct RegisteredStream<T> {
pub connection_info: ConnectionInfo,
pub stream_provider: StreamProvider<T>,
cleanup: Cleanup,
}
impl<T> std::fmt::Debug for RegisteredStream<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RegisteredStream")
.field("connection_info", &self.connection_info)
.finish_non_exhaustive()
}
}
impl<T> RegisteredStream<T> {
pub(crate) fn new(connection_info: ConnectionInfo, stream_provider: StreamProvider<T>) -> Self {
Self {
connection_info,
stream_provider,
cleanup: Cleanup(None),
}
}
pub(crate) fn with_cleanup<F>(mut self, cleanup: F) -> Self
where
F: FnOnce() + Send + 'static,
{
self.cleanup.0 = Some(Box::new(cleanup));
self
}
pub fn into_parts(self) -> (ConnectionInfo, StreamProvider<T>) {
let Self {
connection_info,
stream_provider,
mut cleanup,
} = self;
cleanup.0.take();
(connection_info, stream_provider)
}
}
pub struct PendingConnections {
pub send_stream: Option<RegisteredStream<StreamSender>>,
pub recv_stream: Option<RegisteredStream<StreamReceiver>>,
}
impl PendingConnections {
pub fn into_parts(
self,
) -> (
Option<RegisteredStream<StreamSender>>,
Option<RegisteredStream<StreamReceiver>>,
) {
(self.send_stream, self.recv_stream)
}
}
#[async_trait::async_trait]
pub trait ResponseService {
async fn register(&self, options: StreamOptions) -> PendingConnections;
}
#[cfg(test)]
mod registered_stream_tests {
use super::*;
use std::sync::atomic::{AtomicBool, Ordering};
fn dummy_conn_info() -> ConnectionInfo {
ConnectionInfo {
transport: "test".to_string(),
info: "{}".to_string(),
}
}
#[test]
fn drop_runs_cleanup() {
let flag = Arc::new(AtomicBool::new(false));
let flag_clone = flag.clone();
let (_tx, rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
let stream = RegisteredStream::new(dummy_conn_info(), rx).with_cleanup(move || {
flag_clone.store(true, Ordering::SeqCst);
});
drop(stream);
assert!(
flag.load(Ordering::SeqCst),
"cleanup must fire when RegisteredStream is dropped"
);
}
#[test]
fn into_parts_disarms_cleanup() {
let flag = Arc::new(AtomicBool::new(false));
let flag_clone = flag.clone();
let (_tx, rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
let stream = RegisteredStream::new(dummy_conn_info(), rx).with_cleanup(move || {
flag_clone.store(true, Ordering::SeqCst);
});
let (conn, provider) = stream.into_parts();
drop(conn);
drop(provider);
assert!(
!flag.load(Ordering::SeqCst),
"into_parts() must disarm the cleanup closure"
);
}
#[test]
fn drop_without_cleanup_is_a_noop() {
let (_tx, rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
let stream: RegisteredStream<()> = RegisteredStream::new(dummy_conn_info(), rx);
drop(stream); }
}
pub struct StreamSender {
tx: tokio::sync::mpsc::Sender<TwoPartMessage>,
prologue: Option<ResponseStreamPrologue>,
}
impl StreamSender {
pub async fn send(&self, data: Bytes) -> Result<()> {
Ok(self.tx.send(TwoPartMessage::from_data(data)).await?)
}
pub async fn send_control(&self, control: ControlMessage) -> Result<()> {
let bytes = serde_json::to_vec(&control)?;
Ok(self
.tx
.send(TwoPartMessage::from_header(bytes.into()))
.await?)
}
#[allow(clippy::needless_update)]
pub async fn send_prologue(&mut self, error: Option<String>) -> Result<(), String> {
if let Some(_prologue) = self.prologue.take() {
let prologue = ResponseStreamPrologue { error };
let header_bytes: Bytes = match serde_json::to_vec(&prologue) {
Ok(b) => b.into(),
Err(err) => {
tracing::error!(%err, "send_prologue: ResponseStreamPrologue did not serialize to a JSON array");
return Err("Invalid prologue".to_string());
}
};
self.tx
.send(TwoPartMessage::from_header(header_bytes))
.await
.map_err(|e| e.to_string())?;
} else {
panic!("Prologue already sent; or not set; logic error");
}
Ok(())
}
}
pub struct StreamReceiver {
rx: tokio::sync::mpsc::Receiver<Bytes>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionInfo {
pub transport: String,
pub info: String,
}
#[derive(Clone, Builder)]
pub struct StreamOptions {
pub context: Arc<dyn AsyncEngineContext>,
pub enable_request_stream: bool,
pub enable_response_stream: bool,
#[builder(default = "8")]
pub send_buffer_count: usize,
#[builder(default = "8")]
pub recv_buffer_count: usize,
}
impl StreamOptions {
pub fn builder() -> StreamOptionsBuilder {
StreamOptionsBuilder::default()
}
}
pub struct Egress<Req: PipelineIO, Resp: PipelineIO> {
transport_engine: Arc<dyn AsyncTransportEngine<Req, Resp>>,
}
#[async_trait]
impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error>
for Egress<SingleIn<T>, ManyOut<U>>
where
T: Data + Serialize,
U: for<'de> Deserialize<'de> + Data,
{
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
self.transport_engine.generate(request).await
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum RequestType {
SingleIn,
ManyIn,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ResponseType {
SingleOut,
ManyOut,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RequestControlMessage {
id: String,
request_type: RequestType,
response_type: ResponseType,
connection_info: ConnectionInfo,
#[serde(default, skip_serializing_if = "Option::is_none")]
frontend_send_ts_ns: Option<u64>,
}
pub struct Ingress<Req: PipelineIO, Resp: PipelineIO> {
segment: OnceLock<Arc<SegmentSource<Req, Resp>>>,
metrics: OnceLock<Arc<WorkHandlerMetrics>>,
endpoint_health_check_notifier: OnceLock<Arc<tokio::sync::Notify>>,
}
impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> {
pub fn new() -> Arc<Self> {
Arc::new(Self {
segment: OnceLock::new(),
metrics: OnceLock::new(),
endpoint_health_check_notifier: OnceLock::new(),
})
}
pub fn attach(&self, segment: Arc<SegmentSource<Req, Resp>>) -> Result<()> {
self.segment
.set(segment)
.map_err(|_| anyhow::anyhow!("Segment already set"))
}
pub fn add_metrics(
&self,
endpoint: &crate::component::Endpoint,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<()> {
let metrics = WorkHandlerMetrics::from_endpoint(endpoint, metrics_labels)
.map_err(|e| anyhow::anyhow!("Failed to create work handler metrics: {}", e))?;
crate::metrics::work_handler_perf::ensure_work_handler_perf_metrics_registered(
endpoint.get_metrics_registry(),
);
crate::metrics::work_handler_pool::ensure_work_handler_pool_metrics_registered(
endpoint.get_metrics_registry(),
);
self.metrics
.set(Arc::new(metrics))
.map_err(|_| anyhow::anyhow!("Metrics already set"))
}
pub fn link(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
let ingress = Ingress::new();
ingress.attach(segment)?;
Ok(ingress)
}
pub fn for_pipeline(segment: Arc<SegmentSource<Req, Resp>>) -> Result<Arc<Self>> {
let ingress = Ingress::new();
ingress.attach(segment)?;
Ok(ingress)
}
pub fn for_engine(engine: ServiceEngine<Req, Resp>) -> Result<Arc<Self>> {
let frontend = SegmentSource::<Req, Resp>::new();
let backend = ServiceBackend::from_engine(engine);
let pipeline = frontend.link(backend)?.link(frontend)?;
let ingress = Ingress::new();
ingress.attach(pipeline)?;
Ok(ingress)
}
fn metrics(&self) -> Option<&Arc<WorkHandlerMetrics>> {
self.metrics.get()
}
}
#[async_trait]
pub trait PushWorkHandler: Send + Sync {
async fn handle_payload(
&self,
payload: Bytes,
request_id: Option<String>,
) -> Result<(), PipelineError>;
fn add_metrics(
&self,
endpoint: &crate::component::Endpoint,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<()>;
fn set_endpoint_health_check_notifier(
&self,
_notifier: Arc<tokio::sync::Notify>,
) -> Result<()> {
Ok(())
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct NetworkStreamWrapper<U> {
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<U>,
pub complete_final: bool,
}