use chrono::{DateTime, Utc};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::CancellationToken;
use tonic::{Request, Status, Streaming};
use crate::error::{Error, ErrorKind};
use tracing::{debug, error, info};
use crate::proto::metadata as metadata_pb;
use crate::proto::sink::{self as sink_pb, SinkResponse};
use crate::shared;
use shared::{ContainerType, ENV_CONTAINER_TYPE, build_panic_status, get_panic_info};
/// Default socket path used by `Server::new` when `ENV_CONTAINER_TYPE` matches neither
/// `FB_CONTAINER_TYPE` nor `ONS_CONTAINER_TYPE`.
pub const SOCK_ADDR: &str = "/var/run/numaflow/sink.sock";
/// Default server-info file path, paired with [`SOCK_ADDR`].
pub const SERVER_INFO_FILE: &str = "/var/run/numaflow/sinker-server-info";
/// Socket path used when `ENV_CONTAINER_TYPE` equals `FB_CONTAINER_TYPE`
/// (presumably the fallback sink container — confirm against the platform docs).
pub const FB_SOCK_ADDR: &str = "/var/run/numaflow/fb-sink.sock";
/// Server-info file path paired with [`FB_SOCK_ADDR`].
pub const FB_SERVER_INFO_FILE: &str = "/var/run/numaflow/fb-sinker-server-info";
// Value of `ENV_CONTAINER_TYPE` that selects the `FB_*` paths.
const FB_CONTAINER_TYPE: &str = "fb-udsink";
/// Socket path used when `ENV_CONTAINER_TYPE` equals `ONS_CONTAINER_TYPE`
/// (presumably the on-success sink container — confirm against the platform docs).
pub const ONS_SOCK_ADDR: &str = "/var/run/numaflow/ons-sink.sock";
/// Server-info file path paired with [`ONS_SOCK_ADDR`].
pub const ONS_SERVER_INFO_FILE: &str = "/var/run/numaflow/ons-sinker-server-info";
// Value of `ENV_CONTAINER_TYPE` that selects the `ONS_*` paths.
const ONS_CONTAINER_TYPE: &str = "ons-udsink";
// Capacity of both the per-stream response channel and the per-batch request channel.
const CHANNEL_SIZE: usize = 1000;
struct SinkService<T: Sinker> {
handler: Arc<T>,
shutdown_tx: mpsc::Sender<()>,
cancellation_token: CancellationToken,
}
/// User-implemented sink logic.
///
/// `sink` is invoked once per batch. The requests of the batch are delivered through
/// `input`, whose sending side is dropped once the batch's end-of-transmission (EOT)
/// message arrives or the client closes the stream. The returned responses are sent back
/// to the client in a single message, followed by an EOT marker.
///
/// If the call panics, the stream is failed with an error status and the server is asked
/// to shut down.
#[tonic::async_trait]
pub trait Sinker {
    /// Consumes the requests of one batch from `input` and returns the responses to send
    /// back for that batch.
    async fn sink(&self, input: mpsc::Receiver<SinkRequest>) -> Vec<Response>;
}
/// A single datum delivered to a [`Sinker`].
///
/// Every field type implements `Debug` and `Clone`, so the request derives both; this lets
/// handlers log requests or keep copies without manual field-by-field cloning.
#[derive(Debug, Clone)]
pub struct SinkRequest {
    /// Keys attached to the datum.
    pub keys: Vec<String>,
    /// Raw payload bytes.
    pub value: Vec<u8>,
    /// Watermark, converted from the wire timestamp.
    pub watermark: DateTime<Utc>,
    /// Event time, converted from the wire timestamp.
    pub event_time: DateTime<Utc>,
    /// Identifier of the datum; responses refer back to it via `Response::id`.
    pub id: String,
    /// Headers carried with the datum.
    pub headers: HashMap<String, String>,
    /// User-defined metadata groups (empty when the request carried no metadata).
    pub user_metadata: UserMetadata,
    /// System metadata groups (empty when the request carried no metadata).
    pub system_metadata: SystemMetadata,
}
/// Read-only view over user-defined metadata, organised as `group -> key -> bytes`.
#[derive(Debug, Clone, Default)]
pub struct UserMetadata {
    data: HashMap<String, HashMap<String, Vec<u8>>>,
}

impl UserMetadata {
    /// Returns an empty metadata view.
    pub fn new() -> Self {
        Self::default()
    }

    /// Lists every group name present (in the map's iteration order).
    pub fn groups(&self) -> Vec<String> {
        self.data.keys().map(ToOwned::to_owned).collect()
    }

    /// Lists the keys stored under `group`; an unknown group yields an empty list.
    pub fn keys(&self, group: &str) -> Vec<String> {
        self.data
            .get(group)
            .map_or_else(Vec::new, |entries| entries.keys().map(ToOwned::to_owned).collect())
    }

    /// Returns a copy of the bytes stored at `group`/`key`, or an empty vector when either
    /// the group or the key is missing.
    pub fn value(&self, group: &str, key: &str) -> Vec<u8> {
        self.data
            .get(group)
            .and_then(|entries| entries.get(key))
            .cloned()
            .unwrap_or_default()
    }
}
/// Read-only view over system metadata, organised as `group -> key -> bytes`.
#[derive(Debug, Clone, Default)]
pub struct SystemMetadata {
    data: HashMap<String, HashMap<String, Vec<u8>>>,
}

impl SystemMetadata {
    /// Returns an empty metadata view.
    pub fn new() -> Self {
        Self::default()
    }

    /// Lists every group name present (in the map's iteration order).
    pub fn groups(&self) -> Vec<String> {
        self.data.keys().map(String::clone).collect()
    }

    /// Lists the keys stored under `group`; an unknown group yields an empty list.
    pub fn keys(&self, group: &str) -> Vec<String> {
        match self.data.get(group) {
            Some(entries) => entries.keys().map(String::clone).collect(),
            None => Vec::new(),
        }
    }

    /// Returns a copy of the bytes stored at `group`/`key`, or an empty vector when either
    /// the group or the key is missing.
    pub fn value(&self, group: &str, key: &str) -> Vec<u8> {
        match self.data.get(group).and_then(|entries| entries.get(key)) {
            Some(bytes) => bytes.clone(),
            None => Vec::new(),
        }
    }
}
fn user_metadata_from_proto(proto: Option<&metadata_pb::Metadata>) -> UserMetadata {
let proto = match proto {
Some(p) => p,
None => return UserMetadata::new(),
};
let mut user_map = HashMap::new();
for (group, kv_group) in &proto.user_metadata {
user_map.insert(group.clone(), kv_group.key_value.clone());
}
UserMetadata { data: user_map }
}
fn system_metadata_from_proto(proto: Option<&metadata_pb::Metadata>) -> SystemMetadata {
let proto = match proto {
Some(p) => p,
None => return SystemMetadata::new(),
};
let mut sys_map = HashMap::new();
for (group, kv_group) in &proto.sys_metadata {
sys_map.insert(group.clone(), kv_group.key_value.clone());
}
SystemMetadata { data: sys_map }
}
impl From<sink_pb::sink_request::Request> for SinkRequest {
    /// Converts a wire-format request into the user-facing [`SinkRequest`].
    ///
    /// Both timestamps go through `shared::utc_from_timestamp`. The user and system metadata
    /// views are built from the same optional proto metadata; when it is absent both views
    /// are empty.
    fn from(sr: sink_pb::sink_request::Request) -> Self {
        let sink_pb::sink_request::Request {
            keys,
            value,
            watermark,
            event_time,
            id,
            headers,
            metadata,
        } = sr;
        let metadata = metadata.as_ref();
        Self {
            user_metadata: user_metadata_from_proto(metadata),
            system_metadata: system_metadata_from_proto(metadata),
            watermark: shared::utc_from_timestamp(watermark),
            event_time: shared::utc_from_timestamp(event_time),
            keys,
            value,
            id,
            headers,
        }
    }
}
/// Outcome a [`Sinker`] reports for a single datum.
///
/// Each variant maps one-to-one onto the wire-level sink status when a `Response` is
/// converted for transmission. The enum is fieldless, so it derives `Copy`, equality and
/// `Debug` to make it comparable and loggable by handlers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseType {
    /// The datum was written successfully.
    Success,
    /// Writing the datum failed; the accompanying error message explains why.
    Failure,
    /// The datum should be routed to the fallback sink.
    FallBack,
    /// The datum carries a serve payload.
    Serve,
    /// The datum succeeded and may carry an on-success message.
    OnSuccess,
}
/// A group of user-metadata entries (`key -> bytes`) attached to an outgoing [`Message`].
///
/// Derives `Debug`, `Clone` and equality (all supported by the inner map) so callers can
/// log, copy and compare groups.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct KeyValueGroup {
    /// The entries of this group.
    pub key_value: HashMap<String, Vec<u8>>,
}
impl From<KeyValueGroup> for metadata_pb::KeyValueGroup {
fn from(kv: KeyValueGroup) -> Self {
Self {
key_value: kv.key_value,
}
}
}
impl From<HashMap<String, Vec<u8>>> for KeyValueGroup {
fn from(hm: HashMap<String, Vec<u8>>) -> Self {
Self { key_value: hm }
}
}
impl From<HashMap<String, String>> for KeyValueGroup {
fn from(hm: HashMap<String, String>) -> Self {
Self {
key_value: hm.into_iter().map(|(k, v)| (k, v.into())).collect(),
}
}
}
/// Message a sink can attach to an on-success response (see `Response::on_success`).
#[derive(Default)]
pub struct Message {
    /// Optional keys; `None` is sent as an empty key list.
    pub keys: Option<Vec<String>>,
    /// Payload bytes.
    pub value: Vec<u8>,
    /// Optional user-metadata groups; `None` means no metadata is sent.
    pub user_metadata: Option<HashMap<String, KeyValueGroup>>,
}
impl Message {
    /// Starts a message with the given payload and neither keys nor user metadata.
    pub fn new(value: Vec<u8>) -> Self {
        Self {
            value,
            ..Self::default()
        }
    }

    /// Sets the keys, replacing any previously set.
    pub fn with_keys(self, keys: Vec<String>) -> Self {
        Self {
            keys: Some(keys),
            ..self
        }
    }

    /// Sets the user-metadata groups, replacing any previously set.
    pub fn with_user_metadata(self, user_metadata: HashMap<String, KeyValueGroup>) -> Self {
        Self {
            user_metadata: Some(user_metadata),
            ..self
        }
    }

    /// Finishes the builder chain; always returns `Some`, matching the `Option<Message>`
    /// parameter of `Response::on_success`.
    pub fn build(self) -> Option<Self> {
        Option::Some(self)
    }
}
impl From<Message> for sink_pb::sink_response::result::Message {
fn from(msg: Message) -> Self {
Self {
keys: msg.keys.map_or(vec![], |keys| keys),
value: msg.value,
metadata: msg
.user_metadata
.map(|user_metadata| metadata_pb::Metadata {
user_metadata: user_metadata
.into_iter()
.map(|(k, v)| (k, metadata_pb::KeyValueGroup::from(v)))
.collect(),
..Default::default()
}),
}
}
}
/// Result a [`Sinker`] returns for one datum.
///
/// Prefer the constructors (`ok`, `failure`, `fallback`, `serve`, `on_success`), which keep
/// the optional fields consistent with `response_type`.
pub struct Response {
    /// Id of the datum this response answers (`SinkRequest::id`).
    pub id: String,
    /// Outcome for the datum.
    pub response_type: ResponseType,
    /// Error description; sent as an empty string when `None`.
    pub err: Option<String>,
    /// Payload for `ResponseType::Serve` responses.
    pub serve_response: Option<Vec<u8>>,
    /// Optional message for `ResponseType::OnSuccess` responses.
    pub on_success_msg: Option<Message>,
}
impl Response {
    /// Marks datum `id` as successfully written.
    pub fn ok(id: String) -> Self {
        Self {
            id,
            response_type: ResponseType::Success,
            err: None,
            serve_response: None,
            on_success_msg: None,
        }
    }

    /// Marks datum `id` as failed, with `err` describing the failure.
    pub fn failure(id: String, err: String) -> Self {
        Self {
            response_type: ResponseType::Failure,
            err: Some(err),
            ..Self::ok(id)
        }
    }

    /// Routes datum `id` to the fallback sink.
    pub fn fallback(id: String) -> Self {
        Self {
            response_type: ResponseType::FallBack,
            ..Self::ok(id)
        }
    }

    /// Answers datum `id` with a serve `payload`.
    pub fn serve(id: String, payload: Vec<u8>) -> Self {
        Self {
            response_type: ResponseType::Serve,
            serve_response: Some(payload),
            ..Self::ok(id)
        }
    }

    /// Marks datum `id` as successful, optionally attaching an on-success message.
    pub fn on_success(id: String, payload: Option<Message>) -> Self {
        Self {
            response_type: ResponseType::OnSuccess,
            on_success_msg: payload,
            ..Self::ok(id)
        }
    }
}
impl From<Response> for sink_pb::sink_response::Result {
fn from(r: Response) -> Self {
Self {
id: r.id,
status: match r.response_type {
ResponseType::Success => sink_pb::Status::Success as i32,
ResponseType::Failure => sink_pb::Status::Failure as i32,
ResponseType::FallBack => sink_pb::Status::Fallback as i32,
ResponseType::Serve => sink_pb::Status::Serve as i32,
ResponseType::OnSuccess => sink_pb::Status::OnSuccess as i32,
},
err_msg: r.err.unwrap_or_default(),
serve_response: r.serve_response,
on_success_msg: r.on_success_msg.map(|msg| msg.into()),
}
}
}
#[tonic::async_trait]
impl<T> sink_pb::sink_server::Sink for SinkService<T>
where
T: Sinker + Send + Sync + 'static,
{
type SinkFnStream = ReceiverStream<Result<SinkResponse, Status>>;
async fn sink_fn(
&self,
request: Request<Streaming<sink_pb::SinkRequest>>,
) -> Result<tonic::Response<Self::SinkFnStream>, Status> {
let mut sink_stream = request.into_inner();
let sink_handle = self.handler.clone();
let shutdown_tx = self.shutdown_tx.clone();
let cln_token = self.cancellation_token.clone();
let (resp_tx, resp_rx) = mpsc::channel::<Result<SinkResponse, Status>>(CHANNEL_SIZE);
self.perform_handshake(&mut sink_stream, &resp_tx).await?;
let grpc_resp_tx = resp_tx.clone();
let handle: JoinHandle<Result<(), Error>> = tokio::spawn(async move {
Self::process_sink_stream(sink_handle, sink_stream, grpc_resp_tx).await
});
tokio::spawn(Self::handle_sink_errors(
handle,
resp_tx,
shutdown_tx,
cln_token,
));
Ok(tonic::Response::new(ReceiverStream::new(resp_rx)))
}
async fn is_ready(
&self,
_: Request<()>,
) -> Result<tonic::Response<sink_pb::ReadyResponse>, Status> {
Ok(tonic::Response::new(sink_pb::ReadyResponse { ready: true }))
}
}
impl<T> SinkService<T>
where
T: Sinker + Send + Sync + 'static,
{
async fn process_sink_stream(
sink_handle: Arc<T>,
mut sink_stream: Streaming<sink_pb::SinkRequest>,
grpc_resp_tx: mpsc::Sender<Result<SinkResponse, Status>>,
) -> Result<(), Error> {
let mut global_stream_ended = false;
while !global_stream_ended {
global_stream_ended = Self::process_sink_batch(
sink_handle.clone(),
&mut sink_stream,
grpc_resp_tx.clone(),
)
.await?;
}
Ok(())
}
async fn process_sink_batch(
sink_handle: Arc<T>,
sink_stream: &mut Streaming<sink_pb::SinkRequest>,
grpc_resp_tx: mpsc::Sender<Result<SinkResponse, Status>>,
) -> Result<bool, Error> {
let (tx, rx) = mpsc::channel::<SinkRequest>(CHANNEL_SIZE);
let resp_tx = grpc_resp_tx.clone();
let sink_handle = sink_handle.clone();
let sinker_handle = tokio::spawn(async move {
let responses = sink_handle.sink(rx).await;
if resp_tx
.send(Ok(SinkResponse {
results: responses.into_iter().map(|r| r.into()).collect(),
handshake: None,
status: None,
}))
.await
.is_err()
{
return;
}
if resp_tx
.send(Ok(SinkResponse {
results: vec![],
handshake: None,
status: Some(sink_pb::TransmissionStatus { eot: true }),
}))
.await
.is_err()
{}
});
let mut global_stream_ended = false;
loop {
let message = match sink_stream.message().await {
Ok(Some(m)) => m,
Ok(None) => {
info!("global bidi stream ended");
global_stream_ended = true;
break; }
Err(e) => {
error!("Error reading message from stream: {}", e);
global_stream_ended = true;
return Ok(global_stream_ended);
}
};
if message.status.is_some_and(|status| status.eot) {
debug!("Batch Ended, received an EOT message");
break;
}
let request = message.request.ok_or_else(|| {
Error::SinkError(ErrorKind::InternalError(
"Invalid argument, request can't be None".to_string(),
))
})?;
tx.send(request.into()).await.map_err(|e| {
Error::SinkError(ErrorKind::InternalError(format!(
"Error sending message to sink handler: {}",
e
)))
})?;
}
drop(tx);
match sinker_handle.await {
Ok(_) => {
}
Err(e) => {
if let Some(panic_info) = get_panic_info() {
let status = build_panic_status(&panic_info);
return Err(Error::GrpcStatus(status));
} else {
return Err(Error::SinkError(ErrorKind::UserDefinedError(e.to_string())));
}
}
}
Ok(global_stream_ended)
}
async fn handle_sink_errors(
handle: JoinHandle<Result<(), Error>>,
resp_tx: mpsc::Sender<Result<SinkResponse, Status>>,
shutdown_tx: mpsc::Sender<()>,
cln_token: CancellationToken,
) {
tokio::select! {
resp = handle => {
match resp {
Ok(Ok(_)) => {},
Ok(Err(e)) => {
resp_tx.send(Err(e.into_status())).await
.inspect_err(|send_err| error!("Failed to send error to response channel (receiver likely dropped): {}", send_err))
.ok();
shutdown_tx.send(()).await
.inspect_err(|send_err| error!("Failed to send shutdown signal: {}", send_err))
.ok();
}
Err(e) => {
resp_tx
.send(Err(Status::internal(format!(
"Sink handler aborted: {}",
e
))))
.await
.inspect_err(|send_err| error!("Failed to send error to response channel (receiver likely dropped): {}", send_err))
.ok();
shutdown_tx.send(()).await
.inspect_err(|send_err| error!("Failed to send shutdown signal: {}", send_err))
.ok();
}
}
},
_ = cln_token.cancelled() => {
resp_tx
.send(Err(Status::cancelled("Sink handler cancelled")))
.await
.inspect_err(|send_err| error!("Token cancelled: Failed to send error to response channel: {}", send_err))
.ok();
}
}
}
async fn perform_handshake(
&self,
sink_stream: &mut Streaming<sink_pb::SinkRequest>,
resp_tx: &mpsc::Sender<Result<SinkResponse, Status>>,
) -> Result<(), Status> {
let handshake_request = sink_stream
.message()
.await
.map_err(|e| Status::internal(format!("handshake failed {}", e)))?
.ok_or_else(|| Status::internal("stream closed before handshake"))?;
if let Some(handshake) = handshake_request.handshake {
resp_tx
.send(Ok(SinkResponse {
results: vec![],
handshake: Some(handshake),
status: None,
}))
.await
.map_err(|e| {
Status::internal(format!("failed to send handshake response {}", e))
})?;
Ok(())
} else {
Err(Status::invalid_argument("Handshake not present"))
}
}
}
/// gRPC server hosting a user-defined [`Sinker`].
///
/// A thin wrapper around `shared::Server`; socket and server-info paths are selected in
/// `Server::new`, and further configuration goes through `shared::ServerExtras`.
#[derive(Debug)]
pub struct Server<T> {
    inner: shared::Server<T>,
}
impl<T> shared::ServerExtras<T> for Server<T> {
    /// Applies `f` to the wrapped shared server and re-wraps the result.
    fn transform_inner<F>(self, f: F) -> Self
    where
        F: FnOnce(shared::Server<T>) -> shared::Server<T>,
    {
        let Self { inner } = self;
        let transformed = f(inner);
        Self { inner: transformed }
    }

    /// Borrows the wrapped shared server.
    fn inner_ref(&self) -> &shared::Server<T> {
        let Self { inner } = self;
        inner
    }
}
impl<T> Server<T> {
pub fn new(svc: T) -> Self {
let container_type = env::var(ENV_CONTAINER_TYPE).unwrap_or_default();
let (sock_addr, server_info_file) = if container_type == FB_CONTAINER_TYPE {
(FB_SOCK_ADDR, FB_SERVER_INFO_FILE)
} else if container_type == ONS_CONTAINER_TYPE {
(ONS_SOCK_ADDR, ONS_SERVER_INFO_FILE)
} else {
(SOCK_ADDR, SERVER_INFO_FILE)
};
Self {
inner: shared::Server::new_with_custom_paths(
svc,
ContainerType::Sink,
sock_addr,
server_info_file,
),
}
}
pub async fn start_with_shutdown(
self,
shutdown_rx: oneshot::Receiver<()>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
T: Sinker + Send + Sync + 'static,
{
self.inner
.start_with_shutdown(
shutdown_rx,
|handler, max_message_size, shutdown_tx, cln_token| {
let svc = SinkService {
handler: Arc::new(handler),
shutdown_tx,
cancellation_token: cln_token,
};
let svc = sink_pb::sink_server::SinkServer::new(svc)
.max_encoding_message_size(max_message_size)
.max_decoding_message_size(max_message_size);
tonic::transport::Server::builder().add_service(svc)
},
)
.await
}
pub async fn start(self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
T: Sinker + Send + Sync + 'static,
{
self.inner
.start(|handler, max_message_size, shutdown_tx, cln_token| {
let svc = SinkService {
handler: Arc::new(handler),
shutdown_tx,
cancellation_token: cln_token,
};
let svc = sink_pb::sink_server::SinkServer::new(svc)
.max_encoding_message_size(max_message_size)
.max_decoding_message_size(max_message_size);
tonic::transport::Server::builder().add_service(svc)
})
.await
}
}
// End-to-end tests: each test starts a real sink server on a temporary Unix socket and
// drives it with the generated gRPC client.
#[cfg(test)]
mod tests {
    use crate::shared::ServerExtras;
    use std::{error::Error, time::Duration};
    use tempfile::TempDir;
    use tokio::net::UnixStream;
    use tokio::sync::oneshot;
    use tonic::transport::Uri;
    use tower::service_fn;
    use crate::proto::sink::TransmissionStatus;
    use crate::proto::sink::sink_client::SinkClient;
    use crate::proto::sink::sink_request::Request;
    use crate::proto::sink::{Handshake, SinkRequest};
    use crate::sink;

    // Sends a handshake followed by two EOT-terminated batches. Expected on the response
    // stream: the handshake ack, then for each batch a results message and an EOT marker.
    #[tokio::test]
    async fn sink_server() -> Result<(), Box<dyn Error>> {
        // Test sink: succeeds for UTF-8 payloads, fails otherwise.
        struct Logger;
        #[tonic::async_trait]
        impl sink::Sinker for Logger {
            async fn sink(
                &self,
                mut input: tokio::sync::mpsc::Receiver<sink::SinkRequest>,
            ) -> Vec<sink::Response> {
                let mut responses: Vec<sink::Response> = Vec::new();
                while let Some(datum) = input.recv().await {
                    let response = match std::str::from_utf8(&datum.value) {
                        Ok(_) => sink::Response::ok(datum.id),
                        Err(e) => sink::Response::failure(
                            datum.id,
                            format!("Invalid UTF-8 sequence: {}", e),
                        ),
                    };
                    responses.push(response);
                }
                responses
            }
        }
        let tmp_dir = TempDir::new()?;
        let sock_file = tmp_dir.path().join("sink.sock");
        let server_info_file = tmp_dir.path().join("sinker-server-info");
        let server = sink::Server::new(Logger)
            .with_server_info_file(&server_info_file)
            .with_socket_file(&sock_file)
            .with_max_message_size(10240);
        assert_eq!(server.max_message_size(), 10240);
        assert_eq!(server.server_info_file(), server_info_file);
        assert_eq!(server.socket_file(), sock_file);
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await });
        // Give the server time to bind the socket (timing-based; may be flaky on slow CI).
        tokio::time::sleep(Duration::from_millis(50)).await;
        // The URI is a placeholder: the custom connector always dials the Unix socket.
        let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")?
            .connect_with_connector(service_fn(move |_: Uri| {
                let sock_file = sock_file.clone();
                async move {
                    Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new(
                        UnixStream::connect(sock_file).await?,
                    ))
                }
            }))
            .await?;
        let mut client = SinkClient::new(channel);
        let handshake_request = SinkRequest {
            request: None,
            status: None,
            handshake: Some(Handshake { sot: true }),
        };
        let request = SinkRequest {
            request: Some(Request {
                keys: vec!["first".into(), "second".into()],
                value: "hello".into(),
                watermark: Some(prost_types::Timestamp::default()),
                event_time: Some(prost_types::Timestamp::default()),
                id: "1".to_string(),
                headers: Default::default(),
                metadata: None,
            }),
            status: None,
            handshake: None,
        };
        let eot_request = SinkRequest {
            request: None,
            status: Some(TransmissionStatus { eot: true }),
            handshake: None,
        };
        let request_two = SinkRequest {
            request: Some(Request {
                keys: vec!["first".into(), "second".into()],
                value: "hello".into(),
                watermark: Some(prost_types::Timestamp::default()),
                event_time: Some(prost_types::Timestamp::default()),
                id: "2".to_string(),
                headers: Default::default(),
                metadata: None,
            }),
            status: None,
            handshake: None,
        };
        // Handshake, batch 1 (id "1"), EOT, batch 2 (id "2"), EOT.
        let resp = client
            .sink_fn(tokio_stream::iter(vec![
                handshake_request,
                request,
                eot_request.clone(),
                request_two,
                eot_request,
            ]))
            .await?;
        let mut resp_stream = resp.into_inner();
        // 1) Handshake acknowledgement.
        let resp = resp_stream.message().await.unwrap().unwrap();
        assert!(resp.handshake.is_some());
        // 2) Results of batch 1.
        let resp = resp_stream.message().await.unwrap().unwrap();
        assert!(!resp.results.is_empty());
        let msg = resp.results.first().unwrap();
        assert_eq!(msg.err_msg, "");
        assert_eq!(msg.id, "1");
        // 3) EOT marker for batch 1.
        let resp = resp_stream.message().await.unwrap().unwrap();
        assert!(resp.results.is_empty());
        assert!(resp.handshake.is_none());
        let msg = &resp.status.unwrap();
        assert!(msg.eot);
        // 4) Results of batch 2.
        let resp = resp_stream.message().await.unwrap().unwrap();
        assert!(!resp.results.is_empty());
        assert!(resp.handshake.is_none());
        let msg = resp.results.first().unwrap();
        assert_eq!(msg.err_msg, "");
        assert_eq!(msg.id, "2");
        // 5) EOT marker for batch 2.
        let resp = resp_stream.message().await.unwrap().unwrap();
        assert!(resp.results.is_empty());
        assert!(resp.handshake.is_none());
        let msg = &resp.status.unwrap();
        assert!(msg.eot);
        shutdown_tx
            .send(())
            .expect("Sending shutdown signal to gRPC server");
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(task.is_finished(), "gRPC server is still running");
        Ok(())
    }

    // A handler that panics mid-batch must surface as an Internal status carrying the panic
    // message, and the server is expected to stop without an external shutdown signal.
    #[cfg(feature = "test-panic")]
    #[tokio::test]
    async fn sink_panic() -> Result<(), Box<dyn Error>> {
        // Test sink: panics once it has processed more than five items.
        struct PanicSink;
        #[tonic::async_trait]
        impl sink::Sinker for PanicSink {
            async fn sink(
                &self,
                mut input: tokio::sync::mpsc::Receiver<sink::SinkRequest>,
            ) -> Vec<sink::Response> {
                let mut responses: Vec<sink::Response> = Vec::new();
                let mut count = 0;
                while let Some(datum) = input.recv().await {
                    if count > 5 {
                        panic!("Should not cross 5");
                    }
                    count += 1;
                    responses.push(sink::Response::ok(datum.id));
                }
                responses
            }
        }
        let tmp_dir = TempDir::new()?;
        let sock_file = tmp_dir.path().join("sink.sock");
        let server_info_file = tmp_dir.path().join("sinker-server-info");
        let server = sink::Server::new(PanicSink)
            .with_server_info_file(&server_info_file)
            .with_socket_file(&sock_file)
            .with_max_message_size(10240);
        assert_eq!(server.max_message_size(), 10240);
        assert_eq!(server.server_info_file(), server_info_file);
        assert_eq!(server.socket_file(), sock_file);
        // `_shutdown_tx` is kept alive but never used: the server should shut itself down
        // after the panic (presumably via the internal shutdown signal) — see the final assert.
        let (_shutdown_tx, shutdown_rx) = oneshot::channel();
        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await });
        tokio::time::sleep(Duration::from_millis(50)).await;
        let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")?
            .connect_with_connector(service_fn(move |_: Uri| {
                let sock_file = sock_file.clone();
                async move {
                    Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new(
                        UnixStream::connect(sock_file).await?,
                    ))
                }
            }))
            .await?;
        let mut client = SinkClient::new(channel);
        let handshake_request = SinkRequest {
            request: None,
            status: None,
            handshake: Some(Handshake { sot: true }),
        };
        // One batch of 10 requests: enough to push the handler past its panic threshold.
        let mut requests = vec![handshake_request];
        for i in 0..10 {
            let request = SinkRequest {
                request: Some(Request {
                    keys: vec!["first".into(), "second".into()],
                    value: format!("hello {}", i).into(),
                    watermark: Some(prost_types::Timestamp::default()),
                    event_time: Some(prost_types::Timestamp::default()),
                    id: i.to_string(),
                    headers: Default::default(),
                    metadata: None,
                }),
                status: None,
                handshake: None,
            };
            requests.push(request);
        }
        requests.push(SinkRequest {
            request: None,
            status: Some(TransmissionStatus { eot: true }),
            handshake: None,
        });
        let mut resp_stream = client
            .sink_fn(tokio_stream::iter(requests))
            .await
            .unwrap()
            .into_inner();
        let resp = resp_stream.message().await.unwrap().unwrap();
        assert!(resp.results.is_empty());
        assert!(resp.handshake.is_some());
        // The next item must be the error status describing the panic.
        let err_resp = resp_stream.message().await;
        assert!(err_resp.is_err());
        if let Err(e) = err_resp {
            assert_eq!(e.code(), tonic::Code::Internal);
            assert!(e.message().contains("UDF_EXECUTION_ERROR"));
            assert!(e.message().contains("Should not cross 5"));
        }
        // Poll for up to ~100ms for the server task to exit on its own.
        for _ in 0..10 {
            tokio::time::sleep(Duration::from_millis(10)).await;
            if task.is_finished() {
                break;
            }
        }
        assert!(task.is_finished(), "gRPC server is still running");
        Ok(())
    }
}