#![allow(dead_code)]
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use async_trait::async_trait;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContext, Data, ResponseStream};
use dynamo_runtime::pipeline::{
Error, ManyOut, PipelineError, PipelineIO, SegmentSource, SingleIn,
context::{Context, StreamContext},
};
/// How much artificial latency the mock network should model for a message.
#[allow(dead_code)] // NOTE(review): latency values are not read by the transport code shown here.
#[derive(Clone, Debug)]
pub enum LatencyModel {
    /// Deliver immediately.
    NoDelay,
    /// Fixed delay, in nanoseconds.
    ConstantDelayInNanos(u64),
    /// Normally distributed delay, in nanoseconds.
    /// NOTE(review): presumably `(mean, std_dev)`; confirm the field order.
    NormalDistributionInNanos(u64, u64),
}

/// Latency configuration shared by both endpoints of a mock transport.
#[allow(dead_code)] // NOTE(review): neither field is read by the transport code shown here.
#[derive(Clone, Debug)]
pub struct MockNetworkOptions {
    /// Latency model for egress -> ingress traffic.
    request_latency: LatencyModel,
    /// Latency model for ingress -> egress traffic.
    response_latency: LatencyModel,
}

impl Default for MockNetworkOptions {
    /// No artificial latency in either direction.
    fn default() -> Self {
        let request_latency = LatencyModel::NoDelay;
        let response_latency = LatencyModel::NoDelay;
        MockNetworkOptions {
            request_latency,
            response_latency,
        }
    }
}
/// A request sent from the egress to the ingress over the mock control plane.
#[derive(Clone, Debug)]
struct ControlPlaneRequest {
    /// Request id, taken from the originating request context.
    id: String,
    /// JSON-serialized request payload.
    request: Vec<u8>,
    /// Per-request data-plane sender; the ingress replies on it with a handshake first and
    /// then with response frames.
    resp_tx: mpsc::Sender<DataPlaneMessage>,
}

/// Events carried on the control-plane channel from egress to ingress.
enum MockNetworkControlEvents {
    /// Start serving a new request.
    ControlPlaneRequest(ControlPlaneRequest),
    /// Stop generation for the request with this id.
    Cancel(String),
}

/// Optional header of a data-plane frame. Frames without a header carry a response payload.
#[derive(Clone, Debug, Deserialize, Serialize)]
enum MockNetworkDataPlaneHeaders {
    /// First frame of every data-plane conversation; reports whether the request was accepted.
    Handshake(Handshake),
    /// Error reported by the remote side.
    Error(String),
    /// Explicit end-of-stream marker.
    Sentinel,
    /// Keep-alive frame with no payload.
    HeartBeat,
}

/// Outcome reported in a [`Handshake`].
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
enum Status {
    Ok,
    Error(String),
}

/// Reply sent by the ingress before any response payload.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct Handshake {
    /// Id of the request being answered.
    request_id: String,
    /// Identifier of the ingress worker that handled the request, if known.
    worker_id: Option<String>,
    /// Whether the remote segment accepted the request.
    status: Status,
}

/// A single frame on the data plane: an optional control header plus a (possibly empty)
/// body.
struct DataPlaneMessage {
    pub headers: Option<MockNetworkDataPlaneHeaders>,
    pub body: Vec<u8>,
}
/// Factory for a connected pair of mock egress/ingress endpoints.
///
/// Holds no data; the type parameters only select the request/response types of the
/// endpoints it creates.
pub struct MockNetworkTransport<T: PipelineIO, U: PipelineIO> {
    req: std::marker::PhantomData<T>,
    resp: std::marker::PhantomData<U>,
}

impl<Req: PipelineIO, Resp: PipelineIO> MockNetworkTransport<Req, Resp> {
    /// Creates an egress/ingress pair joined by a control-plane channel that buffers up to 8
    /// events.
    ///
    /// Each endpoint receives a copy of `options`. The egress is returned inside an `Arc`; the
    /// ingress is returned by value so the caller can attach a segment and drive it.
    pub fn new_egress_ingress(
        options: MockNetworkOptions,
    ) -> (
        Arc<MockNetworkEgress<Req, Resp>>,
        MockNetworkIngress<Req, Resp>,
    ) {
        let (control_sender, control_receiver) = mpsc::channel::<MockNetworkControlEvents>(8);
        let ingress_side = MockNetworkIngress::<Req, Resp>::new(options.clone(), control_receiver);
        let egress_side = MockNetworkEgress::<Req, Resp>::new(options, control_sender);
        (Arc::new(egress_side), ingress_side)
    }
}
/// Client side of the mock transport.
///
/// Requests go onto the control plane, and responses are read back from a per-request
/// data-plane channel.
#[allow(dead_code)] // `options` is stored but not read by the code shown here.
pub struct MockNetworkEgress<Req: PipelineIO, Resp: PipelineIO> {
    options: MockNetworkOptions,
    /// Sending half of the control-plane channel shared with the paired ingress.
    ctrl_tx: mpsc::Sender<MockNetworkControlEvents>,
    req: std::marker::PhantomData<Req>,
    resp: std::marker::PhantomData<Resp>,
}

impl<Req: PipelineIO, Resp: PipelineIO> MockNetworkEgress<Req, Resp> {
    /// Wraps the control-plane sender; only called by `MockNetworkTransport`.
    fn new(options: MockNetworkOptions, ctrl_tx: mpsc::Sender<MockNetworkControlEvents>) -> Self {
        let (req, resp) = (std::marker::PhantomData, std::marker::PhantomData);
        Self {
            options,
            ctrl_tx,
            req,
            resp,
        }
    }
}
#[async_trait]
impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error>
for MockNetworkEgress<SingleIn<T>, ManyOut<U>>
where
T: Data + Serialize,
U: for<'de> Deserialize<'de> + Data + Send + Sync + 'static,
Self: Send + Sync,
{
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
let ctrl_tx = self.ctrl_tx.clone();
let id = request.id().to_string();
let request = request.try_map(|req| serde_json::to_vec(&req))?;
let (data, context) = request.transfer(());
let context = Arc::new(StreamContext::from(context));
let (data_tx, data_rx) = mpsc::channel::<DataPlaneMessage>(16);
let mut byte_stream = tokio_stream::wrappers::ReceiverStream::new(data_rx);
let (finished_tx, finished_rx) = tokio::sync::oneshot::channel::<()>();
let stream_monitor = ResponseMonitor {
ctx: context.clone(),
finish_rx: finished_rx,
};
let request = ControlPlaneRequest {
id,
request: data,
resp_tx: data_tx,
};
ctrl_tx
.send(MockNetworkControlEvents::ControlPlaneRequest(request))
.await
.map_err(|e| PipelineError::ControlPlaneRequestError(e.to_string()))?;
match byte_stream.next().await {
Some(DataPlaneMessage { headers, body }) => {
if !body.is_empty() {
return Err(PipelineError::ControlPlaneRequestError(
"Expected an empty body for the handshake message".to_string(),
)
.into());
}
match headers {
Some(header) => match header {
MockNetworkDataPlaneHeaders::Handshake(handshake) => {
match handshake.status {
Status::Ok => {}
Status::Error(e) => {
return Err(PipelineError::ControlPlaneRequestError(format!(
"remote segment was unable to process request: {}",
e
))
.into());
}
}
}
_ => {
return Err(PipelineError::ControlPlaneRequestError(format!(
"Expected a handshake message; got: {:?}",
header
))
.into());
}
},
_ => {
return Err(PipelineError::ControlPlaneRequestError(
"Failed to receive properly formatted handshake on data plane"
.to_string(),
)
.into());
}
}
}
None => {
return Err(PipelineError::ControlPlaneRequestError(
"Failed data plane connection closed before receiving handshake".to_string(),
)
.into());
}
}
let decoded = byte_stream
.scan(Some(stream_monitor), move |_stream_monitor, item| {
if let Some(headers) = &item.headers {
match headers {
MockNetworkDataPlaneHeaders::HeartBeat => {
}
MockNetworkDataPlaneHeaders::Sentinel => {
return futures::future::ready(None);
}
_ => {}
}
}
futures::future::ready(Some(item))
})
.map(move |item| {
serde_json::from_slice::<U>(&item.body).expect("failed to deserialize response")
});
let cancellation_monitor = CancellationMonitor {
ctx: context.clone(),
ctrl_tx,
finish_tx: finished_tx,
};
tokio::spawn(cancellation_monitor.execute());
Ok(ResponseStream::new(Box::pin(decoded), context))
}
}
/// Server side of the mock transport.
///
/// Receives control-plane events from the paired egress and dispatches them to a segment
/// attached with [`MockNetworkIngress::segment`].
#[allow(dead_code)] // `options` is stored but not read by the code shown here.
pub struct MockNetworkIngress<Req: PipelineIO, Resp: PipelineIO> {
    options: MockNetworkOptions,
    /// Receiving half of the control-plane channel.
    ctrl_rx: mpsc::Receiver<MockNetworkControlEvents>,
    /// Downstream segment; can be set at most once.
    segment: OnceLock<Arc<SegmentSource<Req, Resp>>>,
}

impl<Req: PipelineIO, Resp: PipelineIO> MockNetworkIngress<Req, Resp> {
    /// Wraps the control-plane receiver; starts with no segment attached.
    fn new(options: MockNetworkOptions, ctrl_rx: mpsc::Receiver<MockNetworkControlEvents>) -> Self {
        let segment = OnceLock::new();
        Self {
            options,
            ctrl_rx,
            segment,
        }
    }

    /// Attaches the segment that will serve incoming requests.
    ///
    /// # Errors
    /// Returns `PipelineError::EdgeAlreadySet` if a segment was already attached.
    pub fn segment(&self, segment: Arc<SegmentSource<Req, Resp>>) -> Result<(), PipelineError> {
        match self.segment.set(segment) {
            Ok(()) => Ok(()),
            Err(_rejected) => Err(PipelineError::EdgeAlreadySet),
        }
    }
}
impl<T: Data, U: Data> MockNetworkIngress<SingleIn<T>, ManyOut<U>>
where
T: Data + for<'de> Deserialize<'de>,
U: Data + Serialize,
{
pub async fn execute(self) -> Result<(), PipelineError> {
let mut state = HashMap::<String, Arc<dyn AsyncEngineContext>>::new();
let worker_id = uuid::Uuid::new_v4().to_string();
let mut ctrl_rx = self.ctrl_rx;
let segment = self.segment.get().expect("segment not set").clone();
while let Some(event) = ctrl_rx.recv().await {
match event {
MockNetworkControlEvents::ControlPlaneRequest(req) => {
let id = req.id.clone();
tracing::debug!("[ingress] received request [id: {}]", id);
let request = serde_json::from_slice::<T>(&req.request)
.expect("failed to deserialize request");
let request = Context::<T>::with_id(request, req.id.clone());
let response = segment.generate(request).await;
let handshake = match &response {
Ok(_) => Handshake {
request_id: req.id,
worker_id: Some(worker_id.clone()),
status: Status::Ok,
},
Err(e) => Handshake {
request_id: req.id,
worker_id: Some(worker_id.clone()),
status: Status::Error(e.to_string()),
},
};
tracing::debug!("[ingress] sending handshake [id: {}]: {:?}", id, handshake);
let handshake = DataPlaneMessage {
headers: Some(MockNetworkDataPlaneHeaders::Handshake(handshake)),
body: vec![],
};
req.resp_tx
.send(handshake)
.await
.expect("failed to send handshake");
tracing::trace!("[ingress] handshake sent [id: {}]", id);
if let Ok(response) = response {
tracing::debug!("[ingress] processing response stream [id: {}]", id);
tokio::spawn(async move {
let mut response = response;
while let Some(resp) = response.next().await {
tracing::trace!("[ingress] received response [id: {}]", id);
let resp_bytes = serde_json::to_vec(&resp)
.expect("failed to serialize response");
let msg = DataPlaneMessage {
headers: None,
body: resp_bytes,
};
req.resp_tx
.send(msg)
.await
.expect("failed to send response");
tracing::trace!("[ingress] sent response [id: {}]", id);
}
tracing::debug!("response stream completed [id: {}]", id);
});
}
}
MockNetworkControlEvents::Cancel(id) => {
if let Some(tx) = state.remove(&id) {
tx.stop_generating();
}
}
}
}
Ok(())
}
}
/// Background task that forwards caller-side cancellation of one egress request to the
/// ingress as a `Cancel` control event.
struct CancellationMonitor {
    /// Context of the response stream handed back to the caller.
    ctx: Arc<StreamContext>,
    /// Control-plane sender used to deliver `Cancel`.
    ctrl_tx: tokio::sync::mpsc::Sender<MockNetworkControlEvents>,
    /// Paired with the receiver held by the response stream; the receiver being dropped
    /// ends this monitor.
    finish_tx: tokio::sync::oneshot::Sender<()>,
}

impl CancellationMonitor {
    /// Waits for whichever happens first:
    /// - the context is stopped: send `Cancel(ctx.id())` to the ingress, ignoring send
    ///   failures;
    /// - the paired oneshot receiver is dropped: return without doing anything.
    async fn execute(self) {
        let Self {
            ctx,
            ctrl_tx,
            mut finish_tx,
        } = self;
        tokio::select! {
            _ = ctx.stopped() => {
                let cancel = MockNetworkControlEvents::Cancel(ctx.id().to_string());
                let _ = ctrl_tx.send(cancel).await;
            }
            _ = finish_tx.closed() => {}
        }
    }
}
/// Rides inside the egress response stream so that it is dropped together with the stream.
///
/// Dropping `finish_rx` closes the oneshot sender held by `CancellationMonitor`, signalling
/// that the stream is gone. Neither field is read by the code shown here.
#[allow(dead_code)] // fields are held only for their drop side effect
struct ResponseMonitor {
    ctx: Arc<StreamContext>,
    finish_rx: tokio::sync::oneshot::Receiver<()>,
}