1use chrono::{DateTime, Utc};
2use proto::SourceTransformResponse;
3use std::collections::HashMap;
4
5use std::sync::Arc;
6
7use tokio::sync::{mpsc, oneshot};
8use tokio::task::JoinHandle;
9use tokio_stream::wrappers::ReceiverStream;
10use tokio_util::sync::CancellationToken;
11use tonic::{Request, Response, Status, Streaming, async_trait};
12use tracing::{error, info};
13
14use crate::error::{Error, ErrorKind};
15use crate::proto::metadata as metadata_pb;
16use crate::proto::source_transformer as proto;
17use crate::shared;
18
19use shared::{
20 ContainerType, DROP, build_panic_status, get_panic_info, prost_timestamp_from_utc,
21 utc_from_timestamp,
22};
23
/// Default Unix domain socket path the source transformer gRPC server listens on.
///
/// Can be overridden per server via the `ServerExtras` builder (`with_socket_file`).
pub const SOCK_ADDR: &str = "/var/run/numaflow/sourcetransform.sock";

/// Default path of the server-info file handed to the shared server implementation.
pub const SERVER_INFO_FILE: &str = "/var/run/numaflow/sourcetransformer-server-info";

/// Capacity of the channel buffering responses for a single `source_transform_fn` stream.
const CHANNEL_SIZE: usize = 1000;
32
/// System metadata attached to an incoming message, organised as `group -> key -> bytes`.
///
/// It is read-only for users: the data is filled from the `sys_metadata` section of the request
/// metadata, and only lookup methods are exposed. Lookups never fail; missing groups or keys
/// yield empty results.
#[derive(Debug, Clone, Default)]
pub struct SystemMetadata {
    data: HashMap<String, HashMap<String, Vec<u8>>>,
}

impl SystemMetadata {
    /// Creates an empty `SystemMetadata`.
    pub fn new() -> Self {
        Self {
            data: HashMap::new(),
        }
    }

    /// Returns every group name, in unspecified order.
    pub fn groups(&self) -> Vec<String> {
        self.data.keys().map(String::clone).collect()
    }

    /// Returns the keys stored under `group`, in unspecified order.
    ///
    /// An unknown group yields an empty vector.
    pub fn keys(&self, group: &str) -> Vec<String> {
        match self.data.get(group) {
            Some(entries) => entries.keys().cloned().collect(),
            None => Vec::new(),
        }
    }

    /// Returns a copy of the value stored under `group`/`key`.
    ///
    /// An unknown group or key yields an empty vector.
    pub fn value(&self, group: &str, key: &str) -> Vec<u8> {
        let Some(entries) = self.data.get(group) else {
            return Vec::new();
        };
        entries.get(key).cloned().unwrap_or_default()
    }
}
103
/// User-controlled metadata attached to a message, organised as `group -> key -> bytes`.
///
/// Lookups never fail: missing groups or keys yield empty results. Mutators create groups on
/// demand and silently ignore removals of entries that do not exist.
#[derive(Debug, Clone, Default)]
pub struct UserMetadata {
    data: HashMap<String, HashMap<String, Vec<u8>>>,
}

impl UserMetadata {
    /// Creates an empty `UserMetadata`.
    pub fn new() -> Self {
        Self {
            data: HashMap::new(),
        }
    }

    /// Returns every group name, in unspecified order.
    pub fn groups(&self) -> Vec<String> {
        self.data.keys().map(ToOwned::to_owned).collect()
    }

    /// Returns the keys stored under `group`, in unspecified order.
    ///
    /// An unknown group yields an empty vector.
    pub fn keys(&self, group: &str) -> Vec<String> {
        self.data
            .get(group)
            .map_or_else(Vec::new, |entries| entries.keys().cloned().collect())
    }

    /// Returns a copy of the value stored under `group`/`key`.
    ///
    /// An unknown group or key yields an empty vector.
    pub fn value(&self, group: &str, key: &str) -> Vec<u8> {
        self.data
            .get(group)
            .and_then(|entries| entries.get(key))
            .map_or_else(Vec::new, Clone::clone)
    }

    /// Ensures `group` exists; an already-present group keeps its entries.
    pub fn create_group(&mut self, group: String) {
        self.data.entry(group).or_insert_with(HashMap::new);
    }

    /// Stores `value` under `group`/`key`, creating the group if needed and
    /// overwriting any previous value for that key.
    pub fn add_kv(&mut self, group: String, key: String, value: Vec<u8>) {
        let entries = self.data.entry(group).or_insert_with(HashMap::new);
        entries.insert(key, value);
    }

    /// Removes `key` from `group`; the group itself is kept even if it becomes empty.
    ///
    /// Does nothing when the group or key is absent.
    pub fn remove_key(&mut self, group: &str, key: &str) {
        let Some(entries) = self.data.get_mut(group) else {
            return;
        };
        entries.remove(key);
    }

    /// Removes `group` together with all its keys; does nothing if it is absent.
    pub fn remove_group(&mut self, group: &str) {
        let _ = self.data.remove(group);
    }
}
236
/// tonic service implementation that forwards each stream's requests to the user's
/// [`SourceTransformer`].
struct SourceTransformerService<T> {
    /// User transformer, shared (via `Arc`) with the per-stream and per-request tasks.
    handler: Arc<T>,
    /// Cloned into every stream's manager task, which sends `()` after a fatal stream error.
    shutdown_tx: mpsc::Sender<()>,
    /// Parent token; each stream receives a child token and stops reading requests once it is
    /// cancelled.
    cancellation_token: CancellationToken,
}
242
/// User-implemented transformation applied to messages read from a source.
///
/// The server invokes [`transform`](SourceTransformer::transform) once per incoming request,
/// each in its own spawned task, so calls may run concurrently.
#[async_trait]
pub trait SourceTransformer {
    /// Transforms one request into zero or more output messages.
    ///
    /// All returned messages are sent back together in a single response carrying the request's
    /// id. A panic inside this method is reported as a fatal stream error, which triggers a
    /// server shutdown.
    async fn transform(&self, input: SourceTransformRequest) -> Vec<Message>;
}
281
/// A message returned by [`SourceTransformer::transform`].
#[derive(Debug)]
pub struct Message {
    /// Optional keys; `None` is sent to the platform as an empty key list.
    pub keys: Option<Vec<String>>,
    /// Payload bytes.
    pub value: Vec<u8>,
    /// Event time assigned to the transformed message.
    pub event_time: DateTime<Utc>,
    /// Optional tags; `None` is sent as an empty tag list. [`Message::message_to_drop`] sets the
    /// `DROP` tag here.
    pub tags: Option<Vec<String>>,
    /// Optional user metadata; `None` is sent as metadata with empty maps.
    pub user_metadata: Option<UserMetadata>,
}
298
299impl Message {
301 pub fn new(value: Vec<u8>, event_time: DateTime<Utc>) -> Self {
319 Self {
320 value,
321 event_time,
322 keys: None,
323 tags: None,
324 user_metadata: None,
325 }
326 }
327 pub fn message_to_drop(event_time: DateTime<Utc>) -> Message {
343 Message {
344 keys: None,
345 value: vec![],
346 event_time,
347 tags: Some(vec![DROP.to_string()]),
348 user_metadata: None,
349 }
350 }
351
352 pub fn with_keys(mut self, keys: Vec<String>) -> Self {
367 self.keys = Some(keys);
368 self
369 }
370 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
385 self.tags = Some(tags);
386 self
387 }
388
389 pub fn with_user_metadata(mut self, user_metadata: UserMetadata) -> Self {
391 self.user_metadata = Some(user_metadata);
392 self
393 }
394}
395
396pub struct SourceTransformRequest {
398 pub keys: Vec<String>,
400 pub value: Vec<u8>,
402 pub watermark: DateTime<Utc>,
405 pub eventtime: DateTime<Utc>,
407 pub headers: HashMap<String, String>,
409 pub user_metadata: UserMetadata,
411 pub system_metadata: SystemMetadata,
413}
414
415fn to_proto(user_metadata: Option<&UserMetadata>) -> metadata_pb::Metadata {
419 let mut user = HashMap::new();
420
421 if let Some(umd) = user_metadata {
422 for group in umd.groups() {
423 let mut kv = HashMap::new();
424 for key in umd.keys(&group) {
425 kv.insert(key.clone(), umd.value(&group, &key));
426 }
427 user.insert(group, metadata_pb::KeyValueGroup { key_value: kv });
428 }
429 }
430
431 metadata_pb::Metadata {
432 previous_vertex: String::new(),
433 sys_metadata: HashMap::new(),
434 user_metadata: user,
435 }
436}
437
438impl From<Message> for proto::source_transform_response::Result {
439 fn from(value: Message) -> Self {
440 proto::source_transform_response::Result {
441 keys: value.keys.unwrap_or_default(),
442 value: value.value,
443 event_time: prost_timestamp_from_utc(value.event_time),
444 tags: value.tags.unwrap_or_default(),
445 metadata: Some(to_proto(value.user_metadata.as_ref())),
446 }
447 }
448}
449
450fn user_metadata_from_proto(proto: Option<&metadata_pb::Metadata>) -> UserMetadata {
452 let proto = match proto {
453 Some(p) => p,
454 None => return UserMetadata::new(),
455 };
456
457 let mut user_map = HashMap::new();
458 for (group, kv_group) in &proto.user_metadata {
459 user_map.insert(group.clone(), kv_group.key_value.clone());
460 }
461
462 UserMetadata { data: user_map }
463}
464
465fn system_metadata_from_proto(proto: Option<&metadata_pb::Metadata>) -> SystemMetadata {
467 let proto = match proto {
468 Some(p) => p,
469 None => return SystemMetadata::new(),
470 };
471
472 let mut sys_map = HashMap::new();
473 for (group, kv_group) in &proto.sys_metadata {
474 sys_map.insert(group.clone(), kv_group.key_value.clone());
475 }
476
477 SystemMetadata { data: sys_map }
478}
479
480impl From<proto::source_transform_request::Request> for SourceTransformRequest {
481 fn from(request: proto::source_transform_request::Request) -> Self {
482 let user_metadata = user_metadata_from_proto(request.metadata.as_ref());
483 let system_metadata = system_metadata_from_proto(request.metadata.as_ref());
484
485 SourceTransformRequest {
486 keys: request.keys,
487 value: request.value,
488 watermark: utc_from_timestamp(request.watermark),
489 eventtime: utc_from_timestamp(request.event_time),
490 headers: request.headers,
491 user_metadata,
492 system_metadata,
493 }
494 }
495}
496
497#[async_trait]
498impl<T> proto::source_transform_server::SourceTransform for SourceTransformerService<T>
499where
500 T: SourceTransformer + Send + Sync + 'static,
501{
502 type SourceTransformFnStream = ReceiverStream<Result<SourceTransformResponse, Status>>;
503
504 async fn source_transform_fn(
505 &self,
506 request: Request<Streaming<proto::SourceTransformRequest>>,
507 ) -> Result<Response<Self::SourceTransformFnStream>, Status> {
508 let mut stream = request.into_inner();
509 let handler = Arc::clone(&self.handler);
510
511 let (stream_response_tx, stream_response_rx) =
512 mpsc::channel::<Result<SourceTransformResponse, Status>>(CHANNEL_SIZE);
513
514 perform_handshake(&mut stream, &stream_response_tx).await?;
516
517 let (error_tx, error_rx) = mpsc::channel::<Error>(1);
518
519 let handle: JoinHandle<()> = tokio::spawn(handle_stream_requests(
522 handler.clone(),
523 stream,
524 stream_response_tx.clone(),
525 error_tx,
526 self.cancellation_token.child_token(),
527 ));
528
529 tokio::spawn(manage_grpc_stream(
530 handle,
531 stream_response_tx,
532 error_rx,
533 self.shutdown_tx.clone(),
534 ));
535
536 Ok(Response::new(ReceiverStream::new(stream_response_rx)))
537 }
538
539 async fn is_ready(&self, _: Request<()>) -> Result<Response<proto::ReadyResponse>, Status> {
540 Ok(Response::new(proto::ReadyResponse { ready: true }))
541 }
542}
543
544async fn perform_handshake(
545 stream: &mut Streaming<proto::SourceTransformRequest>,
546 stream_response_tx: &mpsc::Sender<Result<SourceTransformResponse, Status>>,
547) -> Result<(), Status> {
548 let handshake_request = stream
549 .message()
550 .await
551 .map_err(|e| Status::internal(format!("Handshake failed: {}", e)))?
552 .ok_or_else(|| Status::internal("Stream closed before handshake"))?;
553
554 if let Some(handshake) = handshake_request.handshake {
555 stream_response_tx
556 .send(Ok(SourceTransformResponse {
557 results: vec![],
558 id: "".to_string(),
559 handshake: Some(handshake),
560 }))
561 .await
562 .map_err(|e| Status::internal(format!("Failed to send handshake response: {}", e)))?;
563 Ok(())
564 } else {
565 Err(Status::invalid_argument("Handshake not present"))
566 }
567}
568
569async fn manage_grpc_stream(
571 request_handler: JoinHandle<()>,
572 stream_response_tx: mpsc::Sender<Result<SourceTransformResponse, Status>>,
573 mut error_rx: mpsc::Receiver<Error>,
574 server_shutdown_tx: mpsc::Sender<()>,
575) {
576 let err = match error_rx.recv().await {
577 Some(err) => err,
578 None => match request_handler.await {
579 Ok(_) => return,
580 Err(e) => Error::SourceTransformerError(ErrorKind::InternalError(format!(
581 "Source transformer request handler aborted: {e:?}"
582 ))),
583 },
584 };
585
586 error!("Shutting down gRPC channel: {err:?}");
587 stream_response_tx
588 .send(Err(err.into_status()))
589 .await
590 .expect("Sending error message to gRPC response channel");
591 server_shutdown_tx
592 .send(())
593 .await
594 .expect("Writing to shutdown channel");
595}
596
597async fn handle_stream_requests<T>(
600 handler: Arc<T>,
601 mut stream: Streaming<proto::SourceTransformRequest>,
602 stream_response_tx: mpsc::Sender<Result<SourceTransformResponse, Status>>,
603 error_tx: mpsc::Sender<Error>,
604 token: CancellationToken,
605) where
606 T: SourceTransformer + Send + Sync + 'static,
607{
608 let mut stream_open = true;
609 while stream_open {
610 stream_open = tokio::select! {
611 transform_request = stream.message() => handle_request(
612 handler.clone(),
613 transform_request,
614 stream_response_tx.clone(),
615 error_tx.clone(),
616 ).await,
617 _ = token.cancelled() => {
618 info!("Cancellation token is cancelled, shutting down");
619 break;
620 }
621 }
622 }
623}
624
625async fn handle_request<T>(
628 handler: Arc<T>,
629 transform_request: Result<Option<proto::SourceTransformRequest>, Status>,
630 stream_response_tx: mpsc::Sender<Result<SourceTransformResponse, Status>>,
631 error_tx: mpsc::Sender<Error>,
632) -> bool
633where
634 T: SourceTransformer + Send + Sync + 'static,
635{
636 let transform_request = match transform_request {
637 Ok(None) => return false,
638 Ok(Some(val)) => val,
639 Err(val) => {
640 error!("Received gRPC error from sender: {val:?}");
641 return false;
642 }
643 };
644 tokio::spawn(run_transform(
645 handler,
646 transform_request,
647 stream_response_tx,
648 error_tx,
649 ));
650 true
651}
652
653async fn run_transform<T>(
655 handler: Arc<T>,
656 transform_request: proto::SourceTransformRequest,
657 stream_response_tx: mpsc::Sender<Result<SourceTransformResponse, Status>>,
658 error_tx: mpsc::Sender<Error>,
659) where
660 T: SourceTransformer + Send + Sync + 'static,
661{
662 let request = transform_request.request.expect("request can not be none");
663 let message_id = request.id.clone();
664
665 let udf_transform_task = tokio::spawn({
667 let handler = handler.clone();
668 async move { handler.transform(request.into()).await }
669 });
670
671 let messages = match udf_transform_task.await {
672 Ok(messages) => messages,
673 Err(e) => {
674 error!("Failed to run transform function: {e:?}");
675
676 if let Some(panic_info) = get_panic_info() {
678 let status = build_panic_status(&panic_info);
680 let _ = error_tx.send(Error::GrpcStatus(status)).await;
681 } else {
682 let _ = error_tx
684 .send(Error::SourceTransformerError(ErrorKind::InternalError(
685 format!("Transform task execution failed: {e:?}"),
686 )))
687 .await;
688 }
689 return;
690 }
691 };
692
693 let send_response_result = stream_response_tx
694 .send(Ok(SourceTransformResponse {
695 results: messages.into_iter().map(|msg| msg.into()).collect(),
696 id: message_id,
697 handshake: None,
698 }))
699 .await;
700
701 let Err(e) = send_response_result else {
702 return;
703 };
704
705 let _ = error_tx
706 .send(Error::SourceTransformerError(ErrorKind::InternalError(
707 format!("sending source transform response over gRPC stream: {e:?}"),
708 )))
709 .await;
710}
711
/// gRPC server hosting a [`SourceTransformer`] implementation.
///
/// Wraps the shared server implementation. Socket path, server-info file and message size can be
/// configured through the `shared::ServerExtras` builder methods.
#[derive(Debug)]
pub struct Server<T> {
    inner: shared::Server<T>,
}
717
718impl<T> shared::ServerExtras<T> for Server<T> {
719 fn transform_inner<F>(self, f: F) -> Self
720 where
721 F: FnOnce(shared::Server<T>) -> shared::Server<T>,
722 {
723 Self {
724 inner: f(self.inner),
725 }
726 }
727
728 fn inner_ref(&self) -> &shared::Server<T> {
729 &self.inner
730 }
731}
732
733impl<T> Server<T> {
734 pub fn new(sourcetransformer_svc: T) -> Self {
735 Self {
736 inner: shared::Server::new(
737 sourcetransformer_svc,
738 ContainerType::SourceTransformer,
739 SOCK_ADDR,
740 SERVER_INFO_FILE,
741 ),
742 }
743 }
744
745 pub async fn start_with_shutdown(
747 self,
748 shutdown_rx: oneshot::Receiver<()>,
749 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
750 where
751 T: SourceTransformer + Send + Sync + 'static,
752 {
753 self.inner
754 .start_with_shutdown(
755 shutdown_rx,
756 |handler, max_message_size, shutdown_tx, cln_token| {
757 let sourcetrf_svc = SourceTransformerService {
758 handler: Arc::new(handler),
759 shutdown_tx,
760 cancellation_token: cln_token,
761 };
762
763 let sourcetrf_svc =
764 proto::source_transform_server::SourceTransformServer::new(sourcetrf_svc)
765 .max_encoding_message_size(max_message_size)
766 .max_decoding_message_size(max_message_size);
767
768 tonic::transport::Server::builder().add_service(sourcetrf_svc)
769 },
770 )
771 .await
772 }
773
774 pub async fn start(self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
776 where
777 T: SourceTransformer + Send + Sync + 'static,
778 {
779 self.inner
780 .start(|handler, max_message_size, shutdown_tx, cln_token| {
781 let sourcetrf_svc = SourceTransformerService {
782 handler: Arc::new(handler),
783 shutdown_tx,
784 cancellation_token: cln_token,
785 };
786
787 let sourcetrf_svc =
788 proto::source_transform_server::SourceTransformServer::new(sourcetrf_svc)
789 .max_encoding_message_size(max_message_size)
790 .max_decoding_message_size(max_message_size);
791
792 tonic::transport::Server::builder().add_service(sourcetrf_svc)
793 })
794 .await
795 }
796}
797
#[cfg(test)]
mod tests {
    use crate::shared::ServerExtras;
    use chrono::Utc;
    use std::{error::Error, time::Duration};
    use tempfile::TempDir;
    use tokio::net::UnixStream;
    use tokio::sync::{mpsc, oneshot};
    use tokio_stream::wrappers::ReceiverStream;
    use tonic::transport::Uri;
    use tower::service_fn;

    use crate::proto::source_transformer::{
        self as proto, source_transform_client::SourceTransformClient,
    };
    use crate::sourcetransform::{self};

    /// Upper bound on how long a test waits for the server task to exit. The tests await the
    /// task instead of asserting after a fixed short sleep, which is flaky on slow machines.
    const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(2);

    #[tokio::test]
    async fn source_transformer_server() -> Result<(), Box<dyn Error>> {
        struct NowCat;
        #[tonic::async_trait]
        impl sourcetransform::SourceTransformer for NowCat {
            async fn transform(
                &self,
                input: sourcetransform::SourceTransformRequest,
            ) -> Vec<sourcetransform::Message> {
                vec![sourcetransform::Message {
                    keys: Some(input.keys),
                    value: input.value,
                    tags: Some(vec![]),
                    event_time: Utc::now(),
                    user_metadata: None,
                }]
            }
        }

        let tmp_dir = TempDir::new()?;
        let sock_file = tmp_dir.path().join("sourcetransform.sock");
        let server_info_file = tmp_dir.path().join("sourcetransformer-server-info");

        let server = sourcetransform::Server::new(NowCat)
            .with_server_info_file(&server_info_file)
            .with_socket_file(&sock_file)
            .with_max_message_size(10240);

        assert_eq!(server.max_message_size(), 10240);
        assert_eq!(server.server_info_file(), server_info_file);
        assert_eq!(server.socket_file(), sock_file);

        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await });

        tokio::time::sleep(Duration::from_millis(50)).await;

        // The URI is a placeholder; the connector below always dials the Unix socket.
        let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")?
            .connect_with_connector(service_fn(move |_: Uri| {
                let sock_file = sock_file.clone();
                async move {
                    Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new(
                        UnixStream::connect(sock_file).await?,
                    ))
                }
            }))
            .await?;

        let mut client = SourceTransformClient::new(channel);

        let (tx, rx) = mpsc::channel(2);

        let handshake_request = proto::SourceTransformRequest {
            request: None,
            handshake: Some(proto::Handshake { sot: true }),
        };
        tx.send(handshake_request).await.unwrap();

        let mut stream = tokio::time::timeout(
            Duration::from_secs(2),
            client.source_transform_fn(ReceiverStream::new(rx)),
        )
        .await
        .map_err(|_| "timeout while getting stream for source_transform_fn")??
        .into_inner();

        let handshake_resp = stream.message().await?.unwrap();
        assert!(
            handshake_resp.results.is_empty(),
            "The handshake response should not contain any messages"
        );
        assert!(
            handshake_resp.id.is_empty(),
            "The message id of the handshake response should be empty"
        );
        assert!(
            handshake_resp.handshake.is_some(),
            "Not a valid response for handshake request"
        );

        let request = proto::SourceTransformRequest {
            request: Some(proto::source_transform_request::Request {
                id: "1".to_string(),
                keys: vec!["first".into(), "second".into()],
                value: "hello".into(),
                watermark: Some(prost_types::Timestamp::default()),
                event_time: Some(prost_types::Timestamp::default()),
                headers: Default::default(),
                metadata: None,
            }),
            handshake: None,
        };

        tx.send(request).await.unwrap();

        let resp = stream.message().await?.unwrap();
        assert_eq!(resp.results.len(), 1, "Expected single message from server");
        let msg = &resp.results[0];
        assert_eq!(msg.keys.first(), Some(&"first".to_owned()));
        assert_eq!(msg.value, "hello".as_bytes());

        drop(tx);

        shutdown_tx
            .send(())
            .expect("Sending shutdown signal to gRPC server");
        let finished = tokio::time::timeout(SHUTDOWN_TIMEOUT, task).await;
        assert!(finished.is_ok(), "gRPC server is still running");
        Ok(())
    }

    #[cfg(feature = "test-panic")]
    #[tokio::test]
    async fn source_transformer_panic() -> Result<(), Box<dyn Error>> {
        struct PanicTransformer;
        #[tonic::async_trait]
        impl sourcetransform::SourceTransformer for PanicTransformer {
            async fn transform(
                &self,
                _: sourcetransform::SourceTransformRequest,
            ) -> Vec<sourcetransform::Message> {
                panic!("Panic in transformer");
            }
        }

        let tmp_dir = TempDir::new()?;
        let sock_file = tmp_dir.path().join("sourcetransform.sock");
        let server_info_file = tmp_dir.path().join("sourcetransformer-server-info");

        let server = sourcetransform::Server::new(PanicTransformer)
            .with_server_info_file(&server_info_file)
            .with_socket_file(&sock_file)
            .with_max_message_size(10240);

        assert_eq!(server.max_message_size(), 10240);
        assert_eq!(server.server_info_file(), server_info_file);
        assert_eq!(server.socket_file(), sock_file);

        // Keep the sender alive: the server must shut down because of the panic, not because the
        // shutdown channel was dropped.
        let (_shutdown_tx, shutdown_rx) = oneshot::channel();
        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await });

        tokio::time::sleep(Duration::from_millis(50)).await;

        let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")?
            .connect_with_connector(service_fn(move |_: Uri| {
                let sock_file = sock_file.clone();
                async move {
                    Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new(
                        UnixStream::connect(sock_file).await?,
                    ))
                }
            }))
            .await?;

        let mut client = SourceTransformClient::new(channel);

        let (tx, rx) = mpsc::channel(2);
        let handshake_request = proto::SourceTransformRequest {
            request: None,
            handshake: Some(proto::Handshake { sot: true }),
        };
        tx.send(handshake_request).await.unwrap();

        let mut stream = tokio::time::timeout(
            Duration::from_secs(2),
            client.source_transform_fn(ReceiverStream::new(rx)),
        )
        .await
        .map_err(|_| "timeout while getting stream for source_transform_fn")??
        .into_inner();

        let handshake_resp = stream.message().await?.unwrap();
        assert!(
            handshake_resp.handshake.is_some(),
            "Not a valid response for handshake request"
        );

        let request = proto::SourceTransformRequest {
            request: Some(proto::source_transform_request::Request {
                id: "2".to_string(),
                keys: vec!["first".into(), "second".into()],
                value: "hello".into(),
                watermark: Some(prost_types::Timestamp::default()),
                event_time: Some(prost_types::Timestamp::default()),
                headers: Default::default(),
                metadata: None,
            }),
            handshake: None,
        };
        tx.send(request).await.unwrap();
        drop(tx);

        let finished = tokio::time::timeout(SHUTDOWN_TIMEOUT, task).await;
        assert!(finished.is_ok(), "gRPC server is still running");
        Ok(())
    }
}