numaflow/
sourcetransform.rs

1use chrono::{DateTime, Utc};
2use proto::SourceTransformResponse;
3use std::collections::HashMap;
4
5use std::sync::Arc;
6
7use tokio::sync::{mpsc, oneshot};
8use tokio::task::JoinHandle;
9use tokio_stream::wrappers::ReceiverStream;
10use tokio_util::sync::CancellationToken;
11use tonic::{Request, Response, Status, Streaming, async_trait};
12use tracing::{error, info};
13
14use crate::error::{Error, ErrorKind};
15use crate::proto::metadata as metadata_pb;
16use crate::proto::source_transformer as proto;
17use crate::shared;
18
19use shared::{
20    ContainerType, DROP, build_panic_status, get_panic_info, prost_timestamp_from_utc,
21    utc_from_timestamp,
22};
23
/// Default Unix domain socket path for the source transformer service.
/// [`Server::new`] hands this path to the shared server.
pub const SOCK_ADDR: &str = "/var/run/numaflow/sourcetransform.sock";

/// Default server info file path for the source transformer service.
/// [`Server::new`] hands this path to the shared server.
pub const SERVER_INFO_FILE: &str = "/var/run/numaflow/sourcetransformer-server-info";

/// Capacity of the bounded channel that carries responses (and errors) to the
/// gRPC response stream of a single `source_transform_fn` call.
const CHANNEL_SIZE: usize = 1000;
32
/// SystemMetadata holds the system-generated metadata of a message, organised
/// as a mapping of group name to key-value pairs.
/// It is read-only to UDFs.
#[derive(Debug, Clone, Default)]
pub struct SystemMetadata {
    data: HashMap<String, HashMap<String, Vec<u8>>>,
}

impl SystemMetadata {
    /// Builds an empty SystemMetadata.
    /// Meant for internal use and for tests only.
    pub fn new() -> Self {
        SystemMetadata {
            data: HashMap::new(),
        }
    }

    /// Lists the group names present in the system metadata.
    /// The list is empty when there are no groups.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use numaflow::sourcetransform::SourceTransformRequest;
    /// # let input: SourceTransformRequest = unimplemented!();
    /// let smd = input.system_metadata;
    /// let groups = smd.groups();
    /// println!("System metadata groups: {:?}", groups);
    /// ```
    pub fn groups(&self) -> Vec<String> {
        let mut names = Vec::with_capacity(self.data.len());
        names.extend(self.data.keys().cloned());
        names
    }

    /// Lists the keys stored under `group`.
    /// The list is empty when the group is absent or holds no keys.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use numaflow::sourcetransform::SourceTransformRequest;
    /// # let input: SourceTransformRequest = unimplemented!();
    /// let smd = input.system_metadata;
    /// let keys = smd.keys("system-group");
    /// println!("Keys in system-group: {:?}", keys);
    /// ```
    pub fn keys(&self, group: &str) -> Vec<String> {
        match self.data.get(group) {
            Some(entries) => entries.keys().cloned().collect(),
            None => Vec::new(),
        }
    }

    /// Returns a copy of the value stored under `group` / `key`.
    /// The result is an empty vector when the group or the key is absent.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use numaflow::sourcetransform::SourceTransformRequest;
    /// # let input: SourceTransformRequest = unimplemented!();
    /// let smd = input.system_metadata;
    /// let value = smd.value("system-group", "system-key");
    /// println!("Value: {:?}", value);
    /// ```
    pub fn value(&self, group: &str, key: &str) -> Vec<u8> {
        let found = self.data.get(group).and_then(|entries| entries.get(key));
        match found {
            Some(bytes) => bytes.clone(),
            None => Vec::new(),
        }
    }
}
103
/// UserMetadata wraps user-defined metadata groups per message.
/// It maps a group name to the key-value pairs stored in that group.
#[derive(Debug, Clone, Default)]
pub struct UserMetadata {
    data: HashMap<String, HashMap<String, Vec<u8>>>,
}

impl UserMetadata {
    /// Create a new, empty UserMetadata instance.
    pub fn new() -> Self {
        Self::default()
    }

    /// groups returns the groups of the user metadata.
    /// If there are no groups, it returns an empty vector.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use numaflow::sourcetransform::SourceTransformRequest;
    /// # let input: SourceTransformRequest = unimplemented!();
    /// let umd = input.user_metadata;
    /// let groups = umd.groups();
    /// println!("User metadata groups: {:?}", groups);
    /// ```
    pub fn groups(&self) -> Vec<String> {
        self.data.keys().cloned().collect()
    }

    /// keys returns the keys of the user metadata for the given group.
    /// If there are no keys or the group is not present, it returns an empty vector.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use numaflow::sourcetransform::SourceTransformRequest;
    /// # let input: SourceTransformRequest = unimplemented!();
    /// let umd = input.user_metadata;
    /// let keys = umd.keys("my-group");
    /// println!("Keys in my-group: {:?}", keys);
    /// ```
    pub fn keys(&self, group: &str) -> Vec<String> {
        self.data
            .get(group)
            .map(|kv| kv.keys().cloned().collect())
            .unwrap_or_default()
    }

    /// value returns the value of the user metadata for the given group and key.
    /// If there is no value or the group or key is not present, it returns an empty vector.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use numaflow::sourcetransform::SourceTransformRequest;
    /// # let input: SourceTransformRequest = unimplemented!();
    /// let umd = input.user_metadata;
    /// let value = umd.value("my-group", "my-key");
    /// println!("Value: {:?}", value);
    /// ```
    pub fn value(&self, group: &str, key: &str) -> Vec<u8> {
        self.data
            .get(group)
            .and_then(|kv| kv.get(key))
            .cloned()
            .unwrap_or_default()
    }

    /// create_group creates a new, empty group in the user metadata.
    /// If the group is already present, it's a no-op: the existing key-value
    /// pairs of that group are left untouched.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use numaflow::sourcetransform::UserMetadata;
    /// let mut umd = UserMetadata::new();
    /// umd.create_group("group1".to_string());
    /// println!("{:?}", umd);
    /// ```
    pub fn create_group(&mut self, group: String) {
        // `or_default` only inserts when the entry is vacant.
        self.data.entry(group).or_default();
    }

    /// add_kv adds a key-value pair to the user metadata.
    /// If the group is not present, it creates a new group.
    /// If the key is already present in the group, its value is replaced.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use numaflow::sourcetransform::UserMetadata;
    /// let mut umd = UserMetadata::new();
    /// umd.add_kv("group1".to_string(), "key1".to_string(), "value1".as_bytes().to_vec());
    /// println!("{:?}", umd);
    /// ```
    pub fn add_kv(&mut self, group: String, key: String, value: Vec<u8>) {
        self.data.entry(group).or_default().insert(key, value);
    }

    /// remove_key removes a key from a group in the user metadata.
    /// If the key or group is not present, it's a no-op.
    /// The group itself is kept, even when it becomes empty.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use numaflow::sourcetransform::UserMetadata;
    /// let mut umd = UserMetadata::new();
    /// umd.add_kv("group1".to_string(), "key1".to_string(), "value1".as_bytes().to_vec());
    /// umd.remove_key("group1", "key1");
    /// println!("{:?}", umd);
    /// ```
    pub fn remove_key(&mut self, group: &str, key: &str) {
        if let Some(kv) = self.data.get_mut(group) {
            kv.remove(key);
        }
    }

    /// remove_group removes a group, and every key-value pair in it, from the
    /// user metadata.
    /// If the group is not present, it's a no-op.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use numaflow::sourcetransform::UserMetadata;
    /// let mut umd = UserMetadata::new();
    /// umd.create_group("group1".to_string());
    /// umd.remove_group("group1");
    /// println!("{:?}", umd);
    /// ```
    pub fn remove_group(&mut self, group: &str) {
        self.data.remove(group);
    }
}
236
/// gRPC service state backing the `SourceTransform` server implementation.
struct SourceTransformerService<T> {
    /// User-provided transformer; the `Arc` is cloned for every stream and
    /// every spawned transform task.
    handler: Arc<T>,
    /// Sender used to ask the server to shut down; a `()` is sent after the
    /// first error reported on a stream.
    shutdown_tx: mpsc::Sender<()>,
    /// Parent token; each stream's request loop watches a child of it and
    /// stops reading when it is cancelled.
    cancellation_token: CancellationToken,
}
242
/// SourceTransformer trait for implementing SourceTransform handler.
#[async_trait]
pub trait SourceTransformer {
    /// transform takes in an input element and can produce 0, 1, or more results. The input is a [`SourceTransformRequest`]
    /// and the output is a [`Vec`] of [`Message`]. In a `transform` each element is processed independently
    /// and there is no state associated with the elements. Source transformer can be used for transforming
    /// and assigning event time to input messages. More about source transformer can be read
    /// [here](https://numaflow.numaproj.io/user-guide/sources/transformer/overview/)
    ///
    /// # Example
    ///
    /// ```no_run
    /// use numaflow::sourcetransform;
    /// use std::error::Error;
    ///
    /// // A simple source transformer which assigns event time to the current time in utc.
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
    ///     sourcetransform::Server::new(NowCat).start().await
    /// }
    ///
    /// struct NowCat;
    ///
    /// #[tonic::async_trait]
    /// impl sourcetransform::SourceTransformer for NowCat {
    ///     async fn transform(
    ///         &self,
    ///         input: sourcetransform::SourceTransformRequest,
    ///     ) -> Vec<sourcetransform::Message> {
    ///         use numaflow::sourcetransform::Message;
    ///         let message = Message::new(input.value, chrono::offset::Utc::now())
    ///             .with_keys(input.keys)
    ///             .with_tags(vec![]);
    ///         vec![message]
    ///     }
    /// }
    /// ```
    async fn transform(&self, input: SourceTransformRequest) -> Vec<Message>;
}
281
/// Message is the response struct from the [`SourceTransformer::transform`].
///
/// Use [`Message::new`] and the `with_*` builders to construct one, or
/// [`Message::message_to_drop`] to mark an element as dropped.
#[derive(Debug)]
pub struct Message {
    /// Keys are a collection of strings which will be passed on to the next vertex as is. It can
    /// be an empty collection. `None` is sent as an empty collection.
    pub keys: Option<Vec<String>>,
    /// Value is the value passed to the next vertex.
    pub value: Vec<u8>,
    /// Time for the given event. This will be used for tracking watermarks. If it cannot be derived, set it to the incoming
    /// event time (`eventtime`) from the [`SourceTransformRequest`].
    pub event_time: DateTime<Utc>,
    /// Tags are used for [conditional forwarding](https://numaflow.numaproj.io/user-guide/reference/conditional-forwarding/).
    /// `None` is sent as an empty collection.
    pub tags: Option<Vec<String>>,
    /// User metadata for the message. `None` is sent as metadata with no user groups.
    pub user_metadata: Option<UserMetadata>,
}
298
299/// Represents a message that can be modified and forwarded.
300impl Message {
301    /// Creates a new message with the specified value and event time.
302    ///
303    /// This constructor initializes the message with no keys, tags.
304    ///
305    /// # Arguments
306    ///
307    /// * `value` - A vector of bytes representing the message's payload.
308    /// * `event_time` - The `DateTime<Utc>` that specifies when the event occurred.
309    ///
310    /// # Examples
311    ///
312    /// ```
313    /// use numaflow::sourcetransform::Message;
314    /// use chrono::Utc;
315    /// let now = Utc::now();
316    /// let message = Message::new(vec![1, 2, 3, 4], now);
317    /// ```
318    pub fn new(value: Vec<u8>, event_time: DateTime<Utc>) -> Self {
319        Self {
320            value,
321            event_time,
322            keys: None,
323            tags: None,
324            user_metadata: None,
325        }
326    }
327    /// Marks the message to be dropped by creating a new `Message` with an empty value, a special "DROP" tag, and the specified event time.
328    ///
329    /// # Arguments
330    ///
331    /// * `event_time` - The `DateTime<Utc>` that specifies when the event occurred. Event time is required because, even though a message is dropped,
332    ///   it is still considered as being processed, hence the watermark should be updated accordingly using the provided event time.
333    ///
334    /// # Examples
335    ///
336    /// ```
337    /// use numaflow::sourcetransform::Message;
338    /// use chrono::Utc;
339    /// let now = Utc::now();
340    /// let dropped_message = Message::message_to_drop(now);
341    /// ```
342    pub fn message_to_drop(event_time: DateTime<Utc>) -> Message {
343        Message {
344            keys: None,
345            value: vec![],
346            event_time,
347            tags: Some(vec![DROP.to_string()]),
348            user_metadata: None,
349        }
350    }
351
352    /// Sets or replaces the keys associated with this message.
353    ///
354    /// # Arguments
355    ///
356    /// * `keys` - A vector of strings representing the keys.
357    ///
358    /// # Examples
359    ///
360    /// ```
361    /// use numaflow::sourcetransform::Message;
362    /// use chrono::Utc;
363    /// let now = Utc::now();
364    /// let message = Message::new(vec![1, 2, 3], now).with_keys(vec!["key1".to_string(), "key2".to_string()]);
365    /// ```
366    pub fn with_keys(mut self, keys: Vec<String>) -> Self {
367        self.keys = Some(keys);
368        self
369    }
370    /// Sets or replaces the tags associated with this message.
371    ///
372    /// # Arguments
373    ///
374    /// * `tags` - A vector of strings representing the tags.
375    ///
376    /// # Examples
377    ///
378    /// ```
379    /// use numaflow::sourcetransform::Message;
380    /// use chrono::Utc;
381    /// let now = Utc::now();
382    /// let message = Message::new(vec![1, 2, 3], now).with_tags(vec!["tag1".to_string(), "tag2".to_string()]);
383    /// ```
384    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
385        self.tags = Some(tags);
386        self
387    }
388
389    /// Sets the user metadata for the message.
390    pub fn with_user_metadata(mut self, user_metadata: UserMetadata) -> Self {
391        self.user_metadata = Some(user_metadata);
392        self
393    }
394}
395
/// Incoming request to the Source Transformer.
///
/// Built from the protobuf request via the `From` conversion in this module.
pub struct SourceTransformRequest {
    /// keys are the keys in the (key, value) terminology of map/reduce paradigm.
    pub keys: Vec<String>,
    /// value is the value in (key, value) terminology of map/reduce paradigm.
    pub value: Vec<u8>,
    /// [watermark](https://numaflow.numaproj.io/core-concepts/watermarks/) represented by time is a guarantee that we will not see an element older than this
    /// time.
    pub watermark: DateTime<Utc>,
    /// eventtime is the time of the element as seen at source or aligned after a reduce operation.
    pub eventtime: DateTime<Utc>,
    /// Headers for the message.
    pub headers: HashMap<String, String>,
    /// User metadata for the message.
    pub user_metadata: UserMetadata,
    /// System metadata for the message.
    pub system_metadata: SystemMetadata,
}
414
415/// Converts Option<&UserMetadata> to proto Metadata.
416/// SDKs should always return non-nil metadata.
417/// If user metadata is None or empty, it returns a metadata with empty user_metadata map.
418fn to_proto(user_metadata: Option<&UserMetadata>) -> metadata_pb::Metadata {
419    let mut user = HashMap::new();
420
421    if let Some(umd) = user_metadata {
422        for group in umd.groups() {
423            let mut kv = HashMap::new();
424            for key in umd.keys(&group) {
425                kv.insert(key.clone(), umd.value(&group, &key));
426            }
427            user.insert(group, metadata_pb::KeyValueGroup { key_value: kv });
428        }
429    }
430
431    metadata_pb::Metadata {
432        previous_vertex: String::new(),
433        sys_metadata: HashMap::new(),
434        user_metadata: user,
435    }
436}
437
438impl From<Message> for proto::source_transform_response::Result {
439    fn from(value: Message) -> Self {
440        proto::source_transform_response::Result {
441            keys: value.keys.unwrap_or_default(),
442            value: value.value,
443            event_time: prost_timestamp_from_utc(value.event_time),
444            tags: value.tags.unwrap_or_default(),
445            metadata: Some(to_proto(value.user_metadata.as_ref())),
446        }
447    }
448}
449
450/// Get UserMetadata from proto Metadata
451fn user_metadata_from_proto(proto: Option<&metadata_pb::Metadata>) -> UserMetadata {
452    let proto = match proto {
453        Some(p) => p,
454        None => return UserMetadata::new(),
455    };
456
457    let mut user_map = HashMap::new();
458    for (group, kv_group) in &proto.user_metadata {
459        user_map.insert(group.clone(), kv_group.key_value.clone());
460    }
461
462    UserMetadata { data: user_map }
463}
464
465/// Get SystemMetadata from proto Metadata
466fn system_metadata_from_proto(proto: Option<&metadata_pb::Metadata>) -> SystemMetadata {
467    let proto = match proto {
468        Some(p) => p,
469        None => return SystemMetadata::new(),
470    };
471
472    let mut sys_map = HashMap::new();
473    for (group, kv_group) in &proto.sys_metadata {
474        sys_map.insert(group.clone(), kv_group.key_value.clone());
475    }
476
477    SystemMetadata { data: sys_map }
478}
479
480impl From<proto::source_transform_request::Request> for SourceTransformRequest {
481    fn from(request: proto::source_transform_request::Request) -> Self {
482        let user_metadata = user_metadata_from_proto(request.metadata.as_ref());
483        let system_metadata = system_metadata_from_proto(request.metadata.as_ref());
484
485        SourceTransformRequest {
486            keys: request.keys,
487            value: request.value,
488            watermark: utc_from_timestamp(request.watermark),
489            eventtime: utc_from_timestamp(request.event_time),
490            headers: request.headers,
491            user_metadata,
492            system_metadata,
493        }
494    }
495}
496
497#[async_trait]
498impl<T> proto::source_transform_server::SourceTransform for SourceTransformerService<T>
499where
500    T: SourceTransformer + Send + Sync + 'static,
501{
502    type SourceTransformFnStream = ReceiverStream<Result<SourceTransformResponse, Status>>;
503
504    async fn source_transform_fn(
505        &self,
506        request: Request<Streaming<proto::SourceTransformRequest>>,
507    ) -> Result<Response<Self::SourceTransformFnStream>, Status> {
508        let mut stream = request.into_inner();
509        let handler = Arc::clone(&self.handler);
510
511        let (stream_response_tx, stream_response_rx) =
512            mpsc::channel::<Result<SourceTransformResponse, Status>>(CHANNEL_SIZE);
513
514        // do the handshake first to let the client know that we are ready to receive transformation requests.
515        perform_handshake(&mut stream, &stream_response_tx).await?;
516
517        let (error_tx, error_rx) = mpsc::channel::<Error>(1);
518
519        // Spawn a task to continuously receive messages from the client over the gRPC stream.
520        // For each message received from the stream, a new task is spawned to call the transform function and send the response back to the client
521        let handle: JoinHandle<()> = tokio::spawn(handle_stream_requests(
522            handler.clone(),
523            stream,
524            stream_response_tx.clone(),
525            error_tx,
526            self.cancellation_token.child_token(),
527        ));
528
529        tokio::spawn(manage_grpc_stream(
530            handle,
531            stream_response_tx,
532            error_rx,
533            self.shutdown_tx.clone(),
534        ));
535
536        Ok(Response::new(ReceiverStream::new(stream_response_rx)))
537    }
538
539    async fn is_ready(&self, _: Request<()>) -> Result<Response<proto::ReadyResponse>, Status> {
540        Ok(Response::new(proto::ReadyResponse { ready: true }))
541    }
542}
543
544async fn perform_handshake(
545    stream: &mut Streaming<proto::SourceTransformRequest>,
546    stream_response_tx: &mpsc::Sender<Result<SourceTransformResponse, Status>>,
547) -> Result<(), Status> {
548    let handshake_request = stream
549        .message()
550        .await
551        .map_err(|e| Status::internal(format!("Handshake failed: {}", e)))?
552        .ok_or_else(|| Status::internal("Stream closed before handshake"))?;
553
554    if let Some(handshake) = handshake_request.handshake {
555        stream_response_tx
556            .send(Ok(SourceTransformResponse {
557                results: vec![],
558                id: "".to_string(),
559                handshake: Some(handshake),
560            }))
561            .await
562            .map_err(|e| Status::internal(format!("Failed to send handshake response: {}", e)))?;
563        Ok(())
564    } else {
565        Err(Status::invalid_argument("Handshake not present"))
566    }
567}
568
569// shutdown the gRPC server on first error
570async fn manage_grpc_stream(
571    request_handler: JoinHandle<()>,
572    stream_response_tx: mpsc::Sender<Result<SourceTransformResponse, Status>>,
573    mut error_rx: mpsc::Receiver<Error>,
574    server_shutdown_tx: mpsc::Sender<()>,
575) {
576    let err = match error_rx.recv().await {
577        Some(err) => err,
578        None => match request_handler.await {
579            Ok(_) => return,
580            Err(e) => Error::SourceTransformerError(ErrorKind::InternalError(format!(
581                "Source transformer request handler aborted: {e:?}"
582            ))),
583        },
584    };
585
586    error!("Shutting down gRPC channel: {err:?}");
587    stream_response_tx
588        .send(Err(err.into_status()))
589        .await
590        .expect("Sending error message to gRPC response channel");
591    server_shutdown_tx
592        .send(())
593        .await
594        .expect("Writing to shutdown channel");
595}
596
597// Receives messages from the stream. For each message received from the stream,
598// a new task is spawned to call the transform function and send the response back to the client
599async fn handle_stream_requests<T>(
600    handler: Arc<T>,
601    mut stream: Streaming<proto::SourceTransformRequest>,
602    stream_response_tx: mpsc::Sender<Result<SourceTransformResponse, Status>>,
603    error_tx: mpsc::Sender<Error>,
604    token: CancellationToken,
605) where
606    T: SourceTransformer + Send + Sync + 'static,
607{
608    let mut stream_open = true;
609    while stream_open {
610        stream_open = tokio::select! {
611            transform_request = stream.message() => handle_request(
612                handler.clone(),
613                transform_request,
614                stream_response_tx.clone(),
615                error_tx.clone(),
616            ).await,
617            _ = token.cancelled() => {
618                info!("Cancellation token is cancelled, shutting down");
619                break;
620            }
621        }
622    }
623}
624
625// The return boolean value indicates whether a task was created to handle the request.
626// If the return value is false, either client sent an error gRPC status or the stream was closed.
627async fn handle_request<T>(
628    handler: Arc<T>,
629    transform_request: Result<Option<proto::SourceTransformRequest>, Status>,
630    stream_response_tx: mpsc::Sender<Result<SourceTransformResponse, Status>>,
631    error_tx: mpsc::Sender<Error>,
632) -> bool
633where
634    T: SourceTransformer + Send + Sync + 'static,
635{
636    let transform_request = match transform_request {
637        Ok(None) => return false,
638        Ok(Some(val)) => val,
639        Err(val) => {
640            error!("Received gRPC error from sender: {val:?}");
641            return false;
642        }
643    };
644    tokio::spawn(run_transform(
645        handler,
646        transform_request,
647        stream_response_tx,
648        error_tx,
649    ));
650    true
651}
652
653// Calls the user implemented transform function on the request.
654async fn run_transform<T>(
655    handler: Arc<T>,
656    transform_request: proto::SourceTransformRequest,
657    stream_response_tx: mpsc::Sender<Result<SourceTransformResponse, Status>>,
658    error_tx: mpsc::Sender<Error>,
659) where
660    T: SourceTransformer + Send + Sync + 'static,
661{
662    let request = transform_request.request.expect("request can not be none");
663    let message_id = request.id.clone();
664
665    // A new task is spawned to catch the panic
666    let udf_transform_task = tokio::spawn({
667        let handler = handler.clone();
668        async move { handler.transform(request.into()).await }
669    });
670
671    let messages = match udf_transform_task.await {
672        Ok(messages) => messages,
673        Err(e) => {
674            error!("Failed to run transform function: {e:?}");
675
676            // Check if this is a panic or a regular error
677            if let Some(panic_info) = get_panic_info() {
678                // This is a panic - send detailed panic information
679                let status = build_panic_status(&panic_info);
680                let _ = error_tx.send(Error::GrpcStatus(status)).await;
681            } else {
682                // This is a non-panic error
683                let _ = error_tx
684                    .send(Error::SourceTransformerError(ErrorKind::InternalError(
685                        format!("Transform task execution failed: {e:?}"),
686                    )))
687                    .await;
688            }
689            return;
690        }
691    };
692
693    let send_response_result = stream_response_tx
694        .send(Ok(SourceTransformResponse {
695            results: messages.into_iter().map(|msg| msg.into()).collect(),
696            id: message_id,
697            handshake: None,
698        }))
699        .await;
700
701    let Err(e) = send_response_result else {
702        return;
703    };
704
705    let _ = error_tx
706        .send(Error::SourceTransformerError(ErrorKind::InternalError(
707            format!("sending source transform response over gRPC stream: {e:?}"),
708        )))
709        .await;
710}
711
/// gRPC server to start a sourcetransform service.
///
/// Thin wrapper around the shared server, configured for the source
/// transformer container type and its default socket and server info paths.
#[derive(Debug)]
pub struct Server<T> {
    /// Shared server that owns the user handler and the server configuration.
    inner: shared::Server<T>,
}
717
718impl<T> shared::ServerExtras<T> for Server<T> {
719    fn transform_inner<F>(self, f: F) -> Self
720    where
721        F: FnOnce(shared::Server<T>) -> shared::Server<T>,
722    {
723        Self {
724            inner: f(self.inner),
725        }
726    }
727
728    fn inner_ref(&self) -> &shared::Server<T> {
729        &self.inner
730    }
731}
732
733impl<T> Server<T> {
734    pub fn new(sourcetransformer_svc: T) -> Self {
735        Self {
736            inner: shared::Server::new(
737                sourcetransformer_svc,
738                ContainerType::SourceTransformer,
739                SOCK_ADDR,
740                SERVER_INFO_FILE,
741            ),
742        }
743    }
744
745    /// Starts the gRPC server. When message is received on the `shutdown` channel, graceful shutdown of the gRPC server will be initiated.
746    pub async fn start_with_shutdown(
747        self,
748        shutdown_rx: oneshot::Receiver<()>,
749    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
750    where
751        T: SourceTransformer + Send + Sync + 'static,
752    {
753        self.inner
754            .start_with_shutdown(
755                shutdown_rx,
756                |handler, max_message_size, shutdown_tx, cln_token| {
757                    let sourcetrf_svc = SourceTransformerService {
758                        handler: Arc::new(handler),
759                        shutdown_tx,
760                        cancellation_token: cln_token,
761                    };
762
763                    let sourcetrf_svc =
764                        proto::source_transform_server::SourceTransformServer::new(sourcetrf_svc)
765                            .max_encoding_message_size(max_message_size)
766                            .max_decoding_message_size(max_message_size);
767
768                    tonic::transport::Server::builder().add_service(sourcetrf_svc)
769                },
770            )
771            .await
772    }
773
774    /// Starts the gRPC server. Automatically registers singal handlers for SIGINT and SIGTERM and initiates graceful shutdown of gRPC server when either one of the singal arrives.
775    pub async fn start(self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
776    where
777        T: SourceTransformer + Send + Sync + 'static,
778    {
779        self.inner
780            .start(|handler, max_message_size, shutdown_tx, cln_token| {
781                let sourcetrf_svc = SourceTransformerService {
782                    handler: Arc::new(handler),
783                    shutdown_tx,
784                    cancellation_token: cln_token,
785                };
786
787                let sourcetrf_svc =
788                    proto::source_transform_server::SourceTransformServer::new(sourcetrf_svc)
789                        .max_encoding_message_size(max_message_size)
790                        .max_decoding_message_size(max_message_size);
791
792                tonic::transport::Server::builder().add_service(sourcetrf_svc)
793            })
794            .await
795    }
796}
797
798#[cfg(test)]
799mod tests {
800    use crate::shared::ServerExtras;
801    use chrono::Utc;
802    use std::{error::Error, time::Duration};
803    use tempfile::TempDir;
804    use tokio::net::UnixStream;
805    use tokio::sync::{mpsc, oneshot};
806    use tokio_stream::wrappers::ReceiverStream;
807    use tonic::transport::Uri;
808    use tower::service_fn;
809
810    use crate::proto::source_transformer::{
811        self as proto, source_transform_client::SourceTransformClient,
812    };
813    use crate::sourcetransform::{self};
814
815    #[tokio::test]
816    async fn source_transformer_server() -> Result<(), Box<dyn Error>> {
817        struct NowCat;
818        #[tonic::async_trait]
819        impl sourcetransform::SourceTransformer for NowCat {
820            async fn transform(
821                &self,
822                input: sourcetransform::SourceTransformRequest,
823            ) -> Vec<sourcetransform::Message> {
824                vec![sourcetransform::Message {
825                    keys: Some(input.keys),
826                    value: input.value,
827                    tags: Some(vec![]),
828                    event_time: Utc::now(),
829                    user_metadata: None,
830                }]
831            }
832        }
833
834        let tmp_dir = TempDir::new()?;
835        let sock_file = tmp_dir.path().join("sourcetransform.sock");
836        let server_info_file = tmp_dir.path().join("sourcetransformer-server-info");
837
838        let server = sourcetransform::Server::new(NowCat)
839            .with_server_info_file(&server_info_file)
840            .with_socket_file(&sock_file)
841            .with_max_message_size(10240);
842
843        assert_eq!(server.max_message_size(), 10240);
844        assert_eq!(server.server_info_file(), server_info_file);
845        assert_eq!(server.socket_file(), sock_file);
846
847        let (shutdown_tx, shutdown_rx) = oneshot::channel();
848        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await });
849
850        tokio::time::sleep(Duration::from_millis(50)).await;
851
852        // https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs
853        let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")?
854            .connect_with_connector(service_fn(move |_: Uri| {
855                // https://rust-lang.github.io/async-book/03_async_await/01_chapter.html#async-lifetimes
856                let sock_file = sock_file.clone();
857                async move {
858                    Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new(
859                        UnixStream::connect(sock_file).await?,
860                    ))
861                }
862            }))
863            .await?;
864
865        let mut client = SourceTransformClient::new(channel);
866
867        let (tx, rx) = mpsc::channel(2);
868
869        let handshake_request = proto::SourceTransformRequest {
870            request: None,
871            handshake: Some(proto::Handshake { sot: true }),
872        };
873        tx.send(handshake_request).await.unwrap();
874
875        let mut stream = tokio::time::timeout(
876            Duration::from_secs(2),
877            client.source_transform_fn(ReceiverStream::new(rx)),
878        )
879        .await
880        .map_err(|_| "timeout while getting stream for source_transform_fn")??
881        .into_inner();
882
883        let handshake_resp = stream.message().await?.unwrap();
884        assert!(
885            handshake_resp.results.is_empty(),
886            "The handshake response should not contain any messages"
887        );
888        assert!(
889            handshake_resp.id.is_empty(),
890            "The message id of the handshake response should be empty"
891        );
892        assert!(
893            handshake_resp.handshake.is_some(),
894            "Not a valid response for handshake request"
895        );
896
897        let request = proto::SourceTransformRequest {
898            request: Some(proto::source_transform_request::Request {
899                id: "1".to_string(),
900                keys: vec!["first".into(), "second".into()],
901                value: "hello".into(),
902                watermark: Some(prost_types::Timestamp::default()),
903                event_time: Some(prost_types::Timestamp::default()),
904                headers: Default::default(),
905                metadata: None,
906            }),
907            handshake: None,
908        };
909
910        tx.send(request).await.unwrap();
911
912        let resp = stream.message().await?.unwrap();
913        assert_eq!(resp.results.len(), 1, "Expected single message from server");
914        let msg = &resp.results[0];
915        assert_eq!(msg.keys.first(), Some(&"first".to_owned()));
916        assert_eq!(msg.value, "hello".as_bytes());
917
918        drop(tx);
919
920        shutdown_tx
921            .send(())
922            .expect("Sending shutdown signal to gRPC server");
923        tokio::time::sleep(Duration::from_millis(50)).await;
924        assert!(task.is_finished(), "gRPC server is still running");
925        Ok(())
926    }
927
928    #[cfg(feature = "test-panic")]
929    #[tokio::test]
930    async fn source_transformer_panic() -> Result<(), Box<dyn Error>> {
931        struct PanicTransformer;
932        #[tonic::async_trait]
933        impl sourcetransform::SourceTransformer for PanicTransformer {
934            async fn transform(
935                &self,
936                _: sourcetransform::SourceTransformRequest,
937            ) -> Vec<sourcetransform::Message> {
938                panic!("Panic in transformer");
939            }
940        }
941
942        let tmp_dir = TempDir::new()?;
943        let sock_file = tmp_dir.path().join("sourcetransform.sock");
944        let server_info_file = tmp_dir.path().join("sourcetransformer-server-info");
945
946        let server = sourcetransform::Server::new(PanicTransformer)
947            .with_server_info_file(&server_info_file)
948            .with_socket_file(&sock_file)
949            .with_max_message_size(10240);
950
951        assert_eq!(server.max_message_size(), 10240);
952        assert_eq!(server.server_info_file(), server_info_file);
953        assert_eq!(server.socket_file(), sock_file);
954
955        let (_shutdown_tx, shutdown_rx) = oneshot::channel();
956        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await });
957
958        tokio::time::sleep(Duration::from_millis(50)).await;
959
960        // https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs
961        let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")?
962            .connect_with_connector(service_fn(move |_: Uri| {
963                // https://rust-lang.github.io/async-book/03_async_await/01_chapter.html#async-lifetimes
964                let sock_file = sock_file.clone();
965                async move {
966                    Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new(
967                        UnixStream::connect(sock_file).await?,
968                    ))
969                }
970            }))
971            .await?;
972
973        let mut client = SourceTransformClient::new(channel);
974
975        let (tx, rx) = mpsc::channel(2);
976        let handshake_request = proto::SourceTransformRequest {
977            request: None,
978            handshake: Some(proto::Handshake { sot: true }),
979        };
980        tx.send(handshake_request).await.unwrap();
981
982        let mut stream = tokio::time::timeout(
983            Duration::from_secs(2),
984            client.source_transform_fn(ReceiverStream::new(rx)),
985        )
986        .await
987        .map_err(|_| "timeout while getting stream for source_transform_fn")??
988        .into_inner();
989
990        let handshake_resp = stream.message().await?.unwrap();
991        assert!(
992            handshake_resp.handshake.is_some(),
993            "Not a valid response for handshake request"
994        );
995
996        let request = proto::SourceTransformRequest {
997            request: Some(proto::source_transform_request::Request {
998                id: "2".to_string(),
999                keys: vec!["first".into(), "second".into()],
1000                value: "hello".into(),
1001                watermark: Some(prost_types::Timestamp::default()),
1002                event_time: Some(prost_types::Timestamp::default()),
1003                headers: Default::default(),
1004                metadata: None,
1005            }),
1006            handshake: None,
1007        };
1008        tx.send(request).await.unwrap();
1009        drop(tx);
1010
1011        // server should shut down gracefully because there was a panic in the handler.
1012        for _ in 0..10 {
1013            tokio::time::sleep(Duration::from_millis(10)).await;
1014            if task.is_finished() {
1015                break;
1016            }
1017        }
1018        assert!(task.is_finished(), "gRPC server is still running");
1019        Ok(())
1020    }
1021}