Skip to main content

hyperi_rustlib/transport/grpc/
mod.rs

1// Project:   hyperi-rustlib
2// File:      src/transport/grpc/mod.rs
3// Purpose:   gRPC transport backend
4// Language:  Rust
5//
6// License:   BUSL-1.1
7// Copyright: (c) 2026 HYPERI PTY LIMITED
8
9//! # gRPC Transport
10//!
11//! DFE native gRPC transport using tonic. Supports client mode (sending),
12//! server mode (receiving), or both.
13//!
14//! ## DFE Native Protocol
15//!
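//! Lightweight bulk bytes transfer via `dfe.transport.v1.DfeTransport/Push`
//! (one message per RPC) and `DfeTransport/RouteBatch` (a whole batch of
//! records per RPC, accepted or rejected as a unit).
//! Payload is opaque bytes (JSON, MsgPack, or Arrow IPC) with a format hint.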
18//!
19//! ## Vector Wire Protocol Compatibility (optional)
20//!
21//! When the `transport-grpc-vector-compat` feature is enabled and
22//! `GrpcConfig::vector_compat` is true, the server also accepts
23//! `vector.Vector/PushEvents` RPCs from legacy Vector sinks.
24//!
25//! ## Example
26//!
27//! ```rust,ignore
28//! use hyperi_rustlib::transport::{GrpcTransport, GrpcConfig, TransportReceiver};
29//!
30//! // Server mode (receive from remote senders)
31//! let config = GrpcConfig::server("0.0.0.0:6000");
32//! let transport = GrpcTransport::new(&config).await?;
33//!
34//! let records = transport.recv(100).await?.records;
35//! // commit is a no-op for gRPC (no persistence)
36//! transport.commit(&[]).await?;
37//! ```
38
39pub mod batch;
40pub mod config;
41pub mod proto;
42pub mod token;
43
44pub use config::GrpcConfig;
45pub use token::GrpcToken;
46
47use super::error::{TransportError, TransportResult};
48use super::traits::{RecvBatch, TransportBase, TransportReceiver, TransportSender};
49use super::types::{Message, PayloadFormat, SendResult};
50use super::work_batch::{Record, WorkBatch};
51use std::collections::HashMap;
52use std::sync::Arc;
53use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
54use tokio::sync::{mpsc, oneshot};
55use tonic::{Request, Response, Status};
56
57/// gRPC transport for DFE inter-service communication.
58///
59/// Implements both `TransportSender` and `TransportReceiver`, so it also
60/// satisfies the unified `Transport` trait via blanket impl.
61pub struct GrpcTransport {
62    /// Client for sending (None if server-only mode).
63    client: Option<proto::dfe_transport_client::DfeTransportClient<tonic::transport::Channel>>,
64
65    /// Receiver channel (None if client-only mode).
66    receiver: Option<tokio::sync::Mutex<mpsc::Receiver<Message<GrpcToken>>>>,
67
68    /// Shutdown signal for the server task. Behind a `Mutex<Option<..>>` so
69    /// `close(&self)` (not just `Drop`) can take and fire it.
70    shutdown_tx: parking_lot::Mutex<Option<oneshot::Sender<()>>>,
71
72    /// Server background task handle (kept alive, aborted on drop).
73    _server_handle: Option<tokio::task::JoinHandle<Result<(), tonic::transport::Error>>>,
74
75    /// Whether the transport is closed.
76    closed: AtomicBool,
77
78    /// Shared healthy flag -- read by health registry closure, written by close().
79    healthy: Arc<AtomicBool>,
80
81    /// Receive timeout (milliseconds).
82    recv_timeout_ms: u64,
83
84    /// Per-RPC send deadline (milliseconds, 0 = none).
85    send_timeout_ms: u64,
86
87    /// In-flight send count (for metrics).
88    #[cfg(feature = "metrics")]
89    inflight: AtomicU64,
90
91    /// Transport-level message filter engine.
92    filter_engine: super::filter::TransportFilterEngine,
93}
94
95/// Build a tonic `ClientTlsConfig` from the unified TLS fields on `GrpcConfig`.
96///
97/// tonic owns its TLS stack (like librdkafka), so this maps the unified
98/// `TlsTrust` vocabulary onto `ClientTlsConfig`: private-CA PEM (else OS native
99/// roots), optional SNI domain override, and optional mTLS client identity.
100fn build_grpc_client_tls(
101    config: &GrpcConfig,
102) -> TransportResult<tonic::transport::ClientTlsConfig> {
103    use tonic::transport::{Certificate, ClientTlsConfig, Identity};
104
105    let mut tls = ClientTlsConfig::new();
106
107    if let Some(ref ca) = config.tls_ca_path {
108        let pem = std::fs::read(ca)
109            .map_err(|e| TransportError::Config(format!("gRPC TLS: cannot read ca {ca}: {e}")))?;
110        tls = tls.ca_certificate(Certificate::from_pem(pem));
111    } else {
112        // No private CA -> trust the OS native roots.
113        tls = tls.with_native_roots();
114    }
115
116    if let Some(ref domain) = config.tls_domain {
117        tls = tls.domain_name(domain.clone());
118    }
119
120    // mTLS identity -- both cert and key, or neither.
121    match (&config.tls_client_cert_path, &config.tls_client_key_path) {
122        (Some(cert), Some(key)) => {
123            let cert_pem = std::fs::read(cert).map_err(|e| {
124                TransportError::Config(format!("gRPC TLS: cannot read client cert {cert}: {e}"))
125            })?;
126            let key_pem = std::fs::read(key).map_err(|e| {
127                TransportError::Config(format!("gRPC TLS: cannot read client key {key}: {e}"))
128            })?;
129            tls = tls.identity(Identity::from_pem(cert_pem, key_pem));
130        }
131        (None, None) => {}
132        _ => {
133            return Err(TransportError::Config(
134                "gRPC TLS: mTLS requires BOTH tls_client_cert_path and tls_client_key_path"
135                    .to_string(),
136            ));
137        }
138    }
139
140    Ok(tls)
141}
142
143impl GrpcTransport {
144    /// Create a new gRPC transport.
145    ///
146    /// # Configuration
147    ///
148    /// - Set `config.listen` to start a gRPC server (receive mode).
149    /// - Set `config.endpoint` to connect to a remote server (send mode).
150    /// - Set both for bidirectional communication.
151    ///
152    /// # Errors
153    ///
154    /// Returns error if the listen address is invalid or the server fails to start.
155    pub async fn new(config: &GrpcConfig) -> TransportResult<Self> {
156        Self::new_inner(
157            config,
158            #[cfg(feature = "governor")]
159            None,
160        )
161        .await
162    }
163
164    /// Create a gRPC transport bound to a pressure governor (G3, `governor`
165    /// feature).
166    ///
167    /// Identical to [`new`](Self::new) except the receive server consults
168    /// `pressure` BEFORE enqueuing each inbound Push / batch record: while
169    /// [`UnifiedPressure::should_hold`](crate::governor::UnifiedPressure::should_hold)
170    /// holds, the RPC is rejected with `Status::unavailable` (the gRPC analogue
171    /// of HTTP 503, matching the existing channel-full backpressure mapping).
172    /// Passing `None` is exactly equivalent to [`new`](Self::new).
173    ///
174    /// # Errors
175    ///
176    /// Same as [`new`](Self::new).
177    #[cfg(feature = "governor")]
178    pub async fn with_pressure(
179        config: &GrpcConfig,
180        pressure: Option<Arc<crate::governor::UnifiedPressure>>,
181    ) -> TransportResult<Self> {
182        Self::new_inner(config, pressure).await
183    }
184
185    async fn new_inner(
186        config: &GrpcConfig,
187        #[cfg(feature = "governor")] pressure: Option<Arc<crate::governor::UnifiedPressure>>,
188    ) -> TransportResult<Self> {
189        let mut client = None;
190        let mut receiver = None;
191        let mut shutdown_tx = None;
192        let mut server_handle = None;
193        let sequence = Arc::new(AtomicU64::new(0));
194
195        // Set up client (lazy connection -- doesn't fail until first RPC)
196        if let Some(endpoint) = &config.endpoint {
197            let mut ep = tonic::transport::Channel::from_shared(endpoint.clone())
198                .map_err(|e| TransportError::Config(format!("invalid endpoint: {e}")))?;
199
200            // Client TLS. tonic owns its TLS stack, so we map the unified
201            // vocabulary onto ClientTlsConfig (private CA, mTLS identity, SNI).
202            if config.tls_enabled {
203                ep = ep
204                    .tls_config(build_grpc_client_tls(config)?)
205                    .map_err(|e| TransportError::Config(format!("gRPC TLS config: {e}")))?;
206            }
207
208            let channel = ep.connect_lazy();
209
210            let mut c = proto::dfe_transport_client::DfeTransportClient::new(channel)
211                .max_decoding_message_size(config.max_message_size)
212                .max_encoding_message_size(config.max_message_size);
213
214            if config.compression {
215                c = c
216                    .send_compressed(tonic::codec::CompressionEncoding::Gzip)
217                    .accept_compressed(tonic::codec::CompressionEncoding::Gzip);
218            }
219
220            client = Some(c);
221        }
222
223        // Set up server
224        if let Some(listen) = &config.listen {
225            let addr: std::net::SocketAddr = listen
226                .parse()
227                .map_err(|e| TransportError::Config(format!("invalid listen address: {e}")))?;
228
229            let (tx, rx) = mpsc::channel(config.recv_buffer_size);
230            let (sd_tx, sd_rx) = oneshot::channel();
231
232            // DFE native service
233            let dfe_svc = DfeTransportServiceImpl {
234                sender: tx.clone(),
235                sequence: sequence.clone(),
236                #[cfg(feature = "governor")]
237                pressure: pressure.clone(),
238            };
239
240            let dfe_server = proto::dfe_transport_server::DfeTransportServer::new(dfe_svc)
241                .max_decoding_message_size(config.max_message_size)
242                .max_encoding_message_size(config.max_message_size)
243                .accept_compressed(tonic::codec::CompressionEncoding::Gzip)
244                .send_compressed(tonic::codec::CompressionEncoding::Gzip);
245
246            // Build server with optional Vector compat
247            let mut builder = tonic::transport::Server::builder();
248
249            #[cfg(feature = "transport-grpc-vector-compat")]
250            let router = if config.vector_compat {
251                let vector_svc =
252                    super::vector_compat::source::VectorCompatService::new(tx, sequence.clone());
253                let vector_server =
254                    super::vector_compat::proto::vector::vector_server::VectorServer::new(
255                        vector_svc,
256                    )
257                    .max_decoding_message_size(config.max_message_size)
258                    .accept_compressed(tonic::codec::CompressionEncoding::Gzip)
259                    .send_compressed(tonic::codec::CompressionEncoding::Gzip);
260
261                builder.add_service(dfe_server).add_service(vector_server)
262            } else {
263                builder.add_service(dfe_server)
264            };
265
266            #[cfg(not(feature = "transport-grpc-vector-compat"))]
267            let router = builder.add_service(dfe_server);
268
269            // Bind the listener synchronously BEFORE spawning the serve task.
270            // Once `TcpListener::bind` returns the OS socket is listening and
271            // queues incoming connections, so `new()` returning is a true
272            // readiness signal -- callers (and their tests) can connect
273            // immediately with no polling. `serve_with_shutdown(addr, ..)`
274            // bound inside the spawned task, which made `new()` return before
275            // the socket existed and forced every consumer to poll the port.
276            let listener = tokio::net::TcpListener::bind(addr)
277                .await
278                .map_err(|e| TransportError::Config(format!("failed to bind {addr}: {e}")))?;
279            let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
280
281            let handle = tokio::spawn(async move {
282                router
283                    .serve_with_incoming_shutdown(incoming, async {
284                        sd_rx.await.ok();
285                    })
286                    .await
287            });
288
289            receiver = Some(tokio::sync::Mutex::new(rx));
290            shutdown_tx = Some(sd_tx);
291            server_handle = Some(handle);
292        } else {
293            // No receive server -> nothing to attach the governor to. Consume
294            // it so the param stays uniform with no unused-variable warning.
295            #[cfg(feature = "governor")]
296            let _ = pressure;
297        }
298
299        let healthy = Arc::new(AtomicBool::new(true));
300
301        let filter_engine = super::filter::TransportFilterEngine::new(
302            &config.filters_in,
303            &config.filters_out,
304            &crate::transport::filter::TransportFilterTierConfig::from_cascade(),
305        )?;
306
307        #[cfg(feature = "health")]
308        {
309            let h = Arc::clone(&healthy);
310            crate::health::HealthRegistry::register("transport:grpc", move || {
311                if h.load(Ordering::Relaxed) {
312                    crate::health::HealthStatus::Healthy
313                } else {
314                    crate::health::HealthStatus::Unhealthy
315                }
316            });
317        }
318
319        Ok(Self {
320            client,
321            receiver,
322            shutdown_tx: parking_lot::Mutex::new(shutdown_tx),
323            _server_handle: server_handle,
324            closed: AtomicBool::new(false),
325            healthy,
326            recv_timeout_ms: config.recv_timeout_ms,
327            send_timeout_ms: config.send_timeout_ms,
328            #[cfg(feature = "metrics")]
329            inflight: AtomicU64::new(0),
330            filter_engine,
331        })
332    }
333}
334
335impl TransportSender for GrpcTransport {
336    async fn send(&self, key: &str, payload: bytes::Bytes) -> SendResult {
337        if self.closed.load(Ordering::Relaxed) {
338            return SendResult::Fatal(TransportError::Closed);
339        }
340
341        // Outbound filter check
342        if self.filter_engine.has_outbound_filters() {
343            match self.filter_engine.apply_outbound(&payload) {
344                super::filter::FilterDisposition::Pass => {}
345                super::filter::FilterDisposition::Drop => return SendResult::Ok,
346                super::filter::FilterDisposition::Dlq => return SendResult::FilteredDlq,
347            }
348        }
349
350        let Some(client) = &self.client else {
351            return SendResult::Fatal(TransportError::Config(
352                "no endpoint configured for sending".into(),
353            ));
354        };
355
356        let mut metadata = HashMap::new();
357        if !key.is_empty() {
358            metadata.insert("topic".to_string(), key.to_string());
359        }
360
361        // Inject W3C traceparent into gRPC metadata for distributed tracing
362        #[cfg(feature = "transport-trace")]
363        if let Some(tp) = super::propagation::current_traceparent() {
364            metadata.insert(super::propagation::TRACEPARENT_HEADER.to_string(), tp);
365        }
366
367        let mut request = tonic::Request::new(proto::PushRequest {
368            // `payload` is already `bytes::Bytes` and the proto field is now
369            // `Bytes` too (`.bytes(".")` in build.rs) -- move the handle, no copy.
370            payload,
371            format: proto::Format::Auto.into(),
372            metadata,
373        });
374
375        // Bound the RPC so a hung/black-holing server cannot wedge the sender
376        // task forever. Sent as the grpc-timeout header; the server aborts and
377        // the client surfaces Code::DeadlineExceeded when it elapses.
378        if self.send_timeout_ms > 0 {
379            request.set_timeout(std::time::Duration::from_millis(self.send_timeout_ms));
380        }
381
382        #[cfg(feature = "metrics")]
383        let start = std::time::Instant::now();
384
385        #[cfg(feature = "metrics")]
386        self.inflight.fetch_add(1, Ordering::Relaxed);
387
388        // tonic clients are cheaply cloneable (shared channel)
389        let result = match client.clone().push(request).await {
390            Ok(_) => {
391                #[cfg(feature = "metrics")]
392                metrics::counter!("dfe_transport_sent_total", "transport" => "grpc").increment(1);
393                SendResult::Ok
394            }
395            Err(status) => match status.code() {
396                // DeadlineExceeded = our send_timeout_ms fired (slow/hung server).
397                // Transient -- treat as backpressure so the caller retries rather
398                // than dropping the message.
399                tonic::Code::Unavailable
400                | tonic::Code::ResourceExhausted
401                | tonic::Code::DeadlineExceeded => {
402                    #[cfg(feature = "metrics")]
403                    metrics::counter!(
404                        "dfe_transport_backpressured_total",
405                        "transport" => "grpc"
406                    )
407                    .increment(1);
408                    SendResult::Backpressured
409                }
410                _ => {
411                    #[cfg(feature = "metrics")]
412                    metrics::counter!(
413                        "dfe_transport_send_errors_total",
414                        "transport" => "grpc"
415                    )
416                    .increment(1);
417                    SendResult::Fatal(TransportError::Send(status.message().to_string()))
418                }
419            },
420        };
421
422        #[cfg(feature = "metrics")]
423        {
424            self.inflight.fetch_sub(1, Ordering::Relaxed);
425            metrics::gauge!("dfe_transport_inflight", "transport" => "grpc")
426                .set(self.inflight.load(Ordering::Relaxed) as f64);
427            metrics::histogram!(
428                "dfe_transport_send_duration_seconds",
429                "transport" => "grpc"
430            )
431            .record(start.elapsed().as_secs_f64());
432        }
433
434        result
435    }
436
437    /// Send a whole batch of records in ONE `RouteBatch` RPC (Task 0.6).
438    ///
439    /// The native batch override of [`TransportSender::send_batch`]: serde-less
440    /// rustlib<->rustlib transfer. The records map to a proto
441    /// [`Batch`](proto::Batch) via [`batch::records_to_proto`] -- payloads travel
442    /// as OPAQUE `bytes` and the JSON / MsgPack codec is NEVER invoked in
443    /// transit. The whole batch goes in a single call (batch-at-a-time, NOT
444    /// record-by-record streaming), so unlike the trait's per-record default
445    /// there is no partial-send window: the block is accepted or not as a unit.
446    ///
447    /// Commit tokens and inline-DLQ entries are NOT sent -- they are the
448    /// SENDER's local concern. Pass the records (e.g. `&workbatch.records`); the
449    /// caller fires its commit tokens locally after this returns `Ok`.
450    ///
451    /// ## Atomic (all-or-nothing) acceptance
452    ///
453    /// The server handler reserves receiver-channel capacity for the WHOLE
454    /// batch (one `try_reserve_many`) BEFORE enqueuing any record, so the block
455    /// is accepted or rejected as a unit -- there is genuinely no partial-send
456    /// window. A `Backpressured` result means ZERO records were admitted, so the
457    /// caller safely retries the whole block (at-least-once) with no risk of the
458    /// receiver having kept a prefix on the prior attempt (no duplicate prefix).
459    ///
460    /// # Errors / result
461    ///
462    /// Returns a [`SendResult`]. `Backpressured` maps the same transient gRPC
463    /// codes as [`send`](TransportSender::send) so the caller retries the whole
464    /// block rather than dropping it (at-least-once).
465    async fn send_batch(&self, records: &[Record]) -> SendResult {
466        if self.closed.load(Ordering::Relaxed) {
467            return SendResult::Fatal(TransportError::Closed);
468        }
469
470        let Some(client) = &self.client else {
471            return SendResult::Fatal(TransportError::Config(
472                "no endpoint configured for sending".into(),
473            ));
474        };
475
476        // Map records -> proto Batch. Payloads are MOVED (Bytes handle), opaque.
477        let proto_batch = batch::records_to_proto(records.to_vec());
478
479        let mut request = tonic::Request::new(proto_batch);
480
481        // Inject W3C traceparent into gRPC metadata for distributed tracing.
482        #[cfg(feature = "transport-trace")]
483        if let Some(tp) = super::propagation::current_traceparent()
484            && let Ok(val) = tp.parse()
485        {
486            request
487                .metadata_mut()
488                .insert(super::propagation::TRACEPARENT_HEADER, val);
489        }
490
491        if self.send_timeout_ms > 0 {
492            request.set_timeout(std::time::Duration::from_millis(self.send_timeout_ms));
493        }
494
495        #[cfg(feature = "metrics")]
496        let start = std::time::Instant::now();
497        #[cfg(feature = "metrics")]
498        self.inflight.fetch_add(1, Ordering::Relaxed);
499
500        let result = match client.clone().route_batch(request).await {
501            Ok(_) => {
502                #[cfg(feature = "metrics")]
503                metrics::counter!(
504                    "dfe_transport_sent_total",
505                    "transport" => "grpc",
506                    "path" => "batch"
507                )
508                .increment(records.len() as u64);
509                SendResult::Ok
510            }
511            Err(status) => match status.code() {
512                tonic::Code::Unavailable
513                | tonic::Code::ResourceExhausted
514                | tonic::Code::DeadlineExceeded => {
515                    #[cfg(feature = "metrics")]
516                    metrics::counter!(
517                        "dfe_transport_backpressured_total",
518                        "transport" => "grpc"
519                    )
520                    .increment(1);
521                    SendResult::Backpressured
522                }
523                _ => {
524                    #[cfg(feature = "metrics")]
525                    metrics::counter!(
526                        "dfe_transport_send_errors_total",
527                        "transport" => "grpc"
528                    )
529                    .increment(1);
530                    SendResult::Fatal(TransportError::Send(status.message().to_string()))
531                }
532            },
533        };
534
535        #[cfg(feature = "metrics")]
536        {
537            self.inflight.fetch_sub(1, Ordering::Relaxed);
538            metrics::histogram!(
539                "dfe_transport_send_duration_seconds",
540                "transport" => "grpc"
541            )
542            .record(start.elapsed().as_secs_f64());
543        }
544
545        result
546    }
547}
548
549impl TransportBase for GrpcTransport {
550    async fn close(&self) -> TransportResult<()> {
551        self.closed.store(true, Ordering::Relaxed);
552        self.healthy.store(false, Ordering::Relaxed);
553
554        // Actually stop the server: fire the shutdown oneshot so
555        // serve_with_incoming_shutdown completes and the listener is freed.
556        // Idempotent -- a second close() (or Drop) finds None.
557        if let Some(tx) = self.shutdown_tx.lock().take() {
558            let _ = tx.send(());
559        }
560        Ok(())
561    }
562
563    fn is_healthy(&self) -> bool {
564        let healthy = self.healthy.load(Ordering::Relaxed);
565        #[cfg(feature = "metrics")]
566        metrics::gauge!("dfe_transport_healthy", "transport" => "grpc").set(if healthy {
567            1.0
568        } else {
569            0.0
570        });
571        healthy
572    }
573
574    fn name(&self) -> &'static str {
575        "grpc"
576    }
577}
578
579impl TransportReceiver for GrpcTransport {
580    type Token = GrpcToken;
581
582    async fn recv(&self, max: usize) -> TransportResult<WorkBatch<Self::Token>> {
583        if self.closed.load(Ordering::Relaxed) {
584            return Err(TransportError::Closed);
585        }
586
587        let Some(receiver) = &self.receiver else {
588            return Err(TransportError::Config(
589                "no listen address configured for receiving".into(),
590            ));
591        };
592
593        let mut rx = receiver.lock().await;
594        let mut messages = Vec::with_capacity(max.min(100));
595
596        for _ in 0..max {
597            let result = if self.recv_timeout_ms == 0 {
598                // Non-blocking
599                match rx.try_recv() {
600                    Ok(msg) => Some(msg),
601                    Err(mpsc::error::TryRecvError::Empty) => break,
602                    Err(mpsc::error::TryRecvError::Disconnected) => {
603                        return Err(TransportError::Closed);
604                    }
605                }
606            } else if messages.is_empty() {
607                // First message: wait with timeout
608                match tokio::time::timeout(
609                    std::time::Duration::from_millis(self.recv_timeout_ms),
610                    rx.recv(),
611                )
612                .await
613                {
614                    Ok(Some(msg)) => Some(msg),
615                    Ok(None) => return Err(TransportError::Closed),
616                    Err(_) => break, // Timeout
617                }
618            } else {
619                // Subsequent: non-blocking drain
620                match rx.try_recv() {
621                    Ok(msg) => Some(msg),
622                    Err(_) => break,
623                }
624            };
625
626            if let Some(msg) = result {
627                messages.push(msg);
628            }
629        }
630
631        // Apply inbound filters via the shared partition helper; DLQ entries
632        // are returned in the RecvBatch for the caller to route onward.
633        let batch =
634            self.filter_engine
635                .partition_batch(messages, |m| m.payload.as_ref(), |m| m.key.clone());
636        let messages = batch.messages;
637        let dlq_entries = batch.dlq_entries;
638
639        Ok(RecvBatch {
640            messages,
641            dlq_entries,
642        }
643        .into())
644    }
645
646    async fn commit(&self, _tokens: &[Self::Token]) -> TransportResult<()> {
647        // gRPC has no broker-side persistence -- commit is a no-op.
648        // Acknowledgement is implicit in the Push RPC response.
649        Ok(())
650    }
651}
652
653impl Drop for GrpcTransport {
654    fn drop(&mut self) {
655        // Fire the shutdown signal if close() didn't already (idempotent).
656        if let Some(tx) = self.shutdown_tx.lock().take() {
657            let _ = tx.send(());
658        }
659        // Server handle will be dropped, which aborts the task
660    }
661}
662
663// --- DFE Transport gRPC service implementation ---
664
665/// Internal service implementation that receives Push RPCs
666/// and forwards messages into the transport's mpsc channel.
667struct DfeTransportServiceImpl {
668    sender: mpsc::Sender<Message<GrpcToken>>,
669    sequence: Arc<AtomicU64>,
670    /// Optional pressure governor (G3, `governor` feature). `None` by default
671    /// -> the handlers never consult it and behaviour is byte-identical. When
672    /// `Some`, an inbound Push / batch record is rejected with
673    /// `Status::unavailable` while [`UnifiedPressure::should_hold`] holds --
674    /// pressure-driven shedding ON TOP of the existing channel-full rejection.
675    #[cfg(feature = "governor")]
676    pressure: Option<Arc<crate::governor::UnifiedPressure>>,
677}
678
679#[tonic::async_trait]
680impl proto::dfe_transport_server::DfeTransport for DfeTransportServiceImpl {
681    async fn push(
682        &self,
683        request: Request<proto::PushRequest>,
684    ) -> Result<Response<proto::PushResponse>, Status> {
685        // G3 pressure-driven shedding (governor feature, opt-in). BEFORE doing
686        // any work, if a governor is wired and it says hold, reject with
687        // `unavailable` -- the gRPC analogue of HTTP 503, mirroring the
688        // channel-full rejection below. Default `None` -> skipped, unchanged.
689        #[cfg(feature = "governor")]
690        if let Some(pressure) = &self.pressure
691            && pressure.should_hold()
692        {
693            #[cfg(feature = "metrics")]
694            metrics::counter!(
695                "dfe_transport_backpressured_total",
696                "transport" => "grpc",
697                "reason" => "pressure"
698            )
699            .increment(1);
700            return Err(Status::unavailable("under pressure -- inbound held"));
701        }
702
703        let req = request.into_inner();
704        let seq = self.sequence.fetch_add(1, Ordering::Relaxed);
705
706        // Extract W3C traceparent from incoming gRPC metadata for distributed tracing
707        #[cfg(feature = "transport-trace")]
708        if let Some(tp) = req.metadata.get(super::propagation::TRACEPARENT_HEADER)
709            && super::propagation::is_valid_traceparent(tp)
710        {
711            tracing::Span::current().record("traceparent", tp.as_str());
712        }
713
714        let format = PayloadFormat::detect(&req.payload);
715        let key = req.metadata.get("topic").map(|s| Arc::from(s.as_str()));
716
717        // `req.payload` is already prost `Bytes` (`.bytes(".")` in build.rs) --
718        // the decode was zero-copy, so this is a move, not a copy.
719        let msg = Message {
720            key,
721            payload: req.payload,
722            token: GrpcToken::new(seq),
723            timestamp_ms: None,
724            format,
725        };
726
727        match self.sender.try_send(msg) {
728            Ok(()) => {
729                #[cfg(feature = "metrics")]
730                {
731                    metrics::counter!("dfe_transport_sent_total", "transport" => "grpc")
732                        .increment(1);
733                    metrics::gauge!("dfe_transport_queue_size", "transport" => "grpc").set(
734                        self.sender
735                            .max_capacity()
736                            .saturating_sub(self.sender.capacity()) as f64,
737                    );
738                }
739                Ok(Response::new(proto::PushResponse { accepted: 1 }))
740            }
741            Err(mpsc::error::TrySendError::Full(_)) => {
742                #[cfg(feature = "metrics")]
743                metrics::counter!(
744                    "dfe_transport_backpressured_total",
745                    "transport" => "grpc"
746                )
747                .increment(1);
748                Err(Status::resource_exhausted("receiver buffer full"))
749            }
750            Err(mpsc::error::TrySendError::Closed(_)) => {
751                #[cfg(feature = "metrics")]
752                metrics::counter!(
753                    "dfe_transport_refused_total",
754                    "transport" => "grpc"
755                )
756                .increment(1);
757                Err(Status::unavailable("receiver closed"))
758            }
759        }
760    }
761
762    async fn route_batch(
763        &self,
764        request: Request<proto::Batch>,
765    ) -> Result<Response<proto::BatchAck>, Status> {
766        // G3 pressure-driven shedding (governor feature, opt-in): reject the
767        // whole batch with `unavailable` while pressure holds. Default `None`
768        // -> skipped, byte-identical.
769        #[cfg(feature = "governor")]
770        if let Some(pressure) = &self.pressure
771            && pressure.should_hold()
772        {
773            #[cfg(feature = "metrics")]
774            metrics::counter!(
775                "dfe_transport_backpressured_total",
776                "transport" => "grpc",
777                "reason" => "pressure"
778            )
779            .increment(1);
780            return Err(Status::unavailable("under pressure -- inbound held"));
781        }
782
783        // Extract W3C traceparent from incoming gRPC metadata for distributed
784        // tracing, BEFORE consuming the request body.
785        #[cfg(feature = "transport-trace")]
786        if let Some(tp) = request
787            .metadata()
788            .get(super::propagation::TRACEPARENT_HEADER)
789            .and_then(|v| v.to_str().ok())
790            && super::propagation::is_valid_traceparent(tp)
791        {
792            tracing::Span::current().record("traceparent", tp);
793        }
794
795        let proto_batch = request.into_inner();
796
797        // Decode the proto Batch back into rustlib Records (payloads are
798        // zero-copy `Bytes`; the codec is NOT invoked here). Each record fans
799        // into the SAME mpsc channel the single-message Push path uses, so the
800        // existing recv() path delivers them unchanged.
801        let records = batch::proto_batch_to_records(proto_batch);
802        let accepted = records.len() as u64;
803
804        // ATOMICITY (Phase 4): reserve channel capacity for the WHOLE batch up
805        // front via `try_reserve_many`, BEFORE assigning any sequence number or
806        // enqueuing ANY record. If the channel cannot fit the whole block we
807        // reject all-or-nothing -- no record is admitted, so a retry re-sends
808        // the full block with no partial-acceptance / duplicate window. This is
809        // the contract `send_batch`'s doc claims ("the block is accepted or not
810        // as a unit"). The previous per-record `try_send` loop could enqueue
811        // some records then fail mid-batch, stranding a prefix in the channel.
812        //
813        // An empty batch reserves zero permits (a harmless no-op) and the loop
814        // below does not run, matching the prior empty-batch behaviour.
815        let permits = match self.sender.try_reserve_many(records.len()) {
816            Ok(permits) => permits,
817            Err(mpsc::error::TrySendError::Full(())) => {
818                #[cfg(feature = "metrics")]
819                metrics::counter!(
820                    "dfe_transport_backpressured_total",
821                    "transport" => "grpc"
822                )
823                .increment(1);
824                return Err(Status::resource_exhausted("receiver buffer full"));
825            }
826            Err(mpsc::error::TrySendError::Closed(())) => {
827                #[cfg(feature = "metrics")]
828                metrics::counter!(
829                    "dfe_transport_refused_total",
830                    "transport" => "grpc"
831                )
832                .increment(1);
833                return Err(Status::unavailable("receiver closed"));
834            }
835        };
836
837        // Capacity is now held for every record -- enqueuing is infallible. Pair
838        // each reserved permit with a record and send.
839        for (permit, record) in permits.zip(records) {
840            let seq = self.sequence.fetch_add(1, Ordering::Relaxed);
841            let format = record.metadata.format;
842            // A record carrying Auto means the sender did not pin a format
843            // (e.g. it framed but did not classify). Detect from the bytes so
844            // the receiver still gets a concrete hint -- this inspects the lead
845            // byte only, it does NOT parse/decode the payload.
846            let format = if format == PayloadFormat::Auto {
847                PayloadFormat::detect(&record.payload)
848            } else {
849                format
850            };
851
852            permit.send(Message {
853                key: record.key,
854                payload: record.payload,
855                token: GrpcToken::new(seq),
856                timestamp_ms: record.metadata.timestamp_ms,
857                format,
858            });
859        }
860
861        #[cfg(feature = "metrics")]
862        metrics::counter!(
863            "dfe_transport_sent_total",
864            "transport" => "grpc",
865            "path" => "batch"
866        )
867        .increment(accepted);
868
869        Ok(Response::new(proto::BatchAck { accepted }))
870    }
871
872    async fn health_check(
873        &self,
874        _request: Request<proto::HealthCheckRequest>,
875    ) -> Result<Response<proto::HealthCheckResponse>, Status> {
876        Ok(Response::new(proto::HealthCheckResponse {
877            status: proto::ServingStatus::Serving.into(),
878        }))
879    }
880}
881
#[cfg(test)]
mod tests {
    // Unit tests for tokens and configuration, plus loopback tests that start
    // real transports (servers bind fixed 127.0.0.1 ports).
    use super::*;

    #[test]
    fn grpc_token_display() {
        let plain = GrpcToken::new(42);
        assert_eq!(format!("{plain}"), "grpc:42");

        let sourced = GrpcToken::with_source(7, Arc::from("peer-1"));
        assert_eq!(format!("{sourced}"), "grpc:peer-1:7");
    }

    #[test]
    fn grpc_config_defaults() {
        let defaults = GrpcConfig::default();

        // No mode is selected until a listen address or endpoint is given.
        assert!(defaults.listen.is_none());
        assert!(defaults.endpoint.is_none());

        // Buffering, timeouts and message-size ceiling.
        assert_eq!(defaults.recv_buffer_size, 10_000);
        assert_eq!(defaults.recv_timeout_ms, 100);
        assert_eq!(defaults.send_timeout_ms, 30_000);
        assert_eq!(defaults.max_message_size, 16 * 1024 * 1024);

        // Compression and TLS are opt-in.
        assert!(!defaults.compression);
        assert!(!defaults.tls_enabled);
        assert!(defaults.tls_ca_path.is_none());
    }

    #[test]
    fn grpc_client_tls_builds_with_private_ca_and_rejects_half_mtls() {
        use std::io::Write;

        // A self-signed certificate doubles as the private CA bundle.
        let self_signed =
            rcgen::generate_simple_self_signed(vec!["grpc.test".to_string()]).unwrap();
        let mut ca_file = tempfile::NamedTempFile::new().unwrap();
        ca_file.write_all(self_signed.cert.pem().as_bytes()).unwrap();
        ca_file.flush().unwrap();
        let ca_path = ca_file.path().to_string_lossy().into_owned();

        // Private CA plus SNI override: the TLS config builds.
        let with_private_ca = GrpcConfig {
            endpoint: Some("https://peer:6000".to_string()),
            tls_enabled: true,
            tls_ca_path: Some(ca_path.clone()),
            tls_domain: Some("grpc.test".to_string()),
            ..Default::default()
        };
        assert!(build_grpc_client_tls(&with_private_ca).is_ok());

        // Client certificate without its key (half-configured mTLS): rejected.
        let half_mtls = GrpcConfig {
            tls_enabled: true,
            tls_client_cert_path: Some(ca_path),
            tls_client_key_path: None,
            ..Default::default()
        };
        assert!(build_grpc_client_tls(&half_mtls).is_err());
    }

    #[test]
    fn grpc_config_server() {
        let server = GrpcConfig::server("0.0.0.0:6000");
        assert_eq!(server.listen.as_deref(), Some("0.0.0.0:6000"));
        assert!(server.endpoint.is_none(), "server config must not set an endpoint");
    }

    #[test]
    fn grpc_config_client() {
        let client = GrpcConfig::client("http://loader:6000");
        assert!(client.listen.is_none(), "client config must not set a listen address");
        assert_eq!(client.endpoint.as_deref(), Some("http://loader:6000"));
    }

    #[test]
    fn grpc_config_with_compression() {
        let compressed = GrpcConfig::server("0.0.0.0:6000").with_compression();
        assert!(compressed.compression);
    }

    #[tokio::test]
    async fn grpc_transport_client_only() {
        // Client-only transport: the connection is lazy, so no server is needed.
        let client_cfg = GrpcConfig::client("http://localhost:16000");
        let transport = GrpcTransport::new(&client_cfg).await.unwrap();

        assert!(transport.client.is_some());
        assert!(transport.receiver.is_none());
        assert!(transport.is_healthy());
        assert_eq!(transport.name(), "grpc");

        // Without a server side there is nothing to receive from.
        assert!(transport.recv(10).await.is_err(), "recv without a server must fail");

        // gRPC commit is a no-op and always succeeds.
        transport.commit(&[]).await.unwrap();
    }

    /// G3: with a pressure governor pinned HIGH, the gRPC Push handler rejects
    /// with `Status::unavailable` (the gRPC analogue of 503), which the client
    /// surfaces as `SendResult::Backpressured`.
    #[cfg(feature = "governor")]
    #[tokio::test]
    async fn grpc_pressure_high_rejects_unavailable() {
        use crate::governor::{Hysteresis, MemoryPressureSource, PressureSource, UnifiedPressure};
        use crate::memory::{MemoryGuard, MemoryGuardConfig};

        // 1000-byte budget filled to 95%, above the 0.80 high-water mark.
        let guard = Arc::new(MemoryGuard::new(MemoryGuardConfig {
            limit_bytes: 1000,
            pressure_threshold: 0.80,
            ..Default::default()
        }));
        guard.add_bytes(950);

        let memory_source: Arc<dyn PressureSource> =
            Arc::new(MemoryPressureSource::new(Arc::clone(&guard)));
        let band = Hysteresis::new(0.80, 0.65).expect("valid band");
        let pressure = Arc::new(UnifiedPressure::new(vec![memory_source], band));
        assert!(pressure.should_hold(), "pinned-high governor must hold");

        // Server wired to the governor; give it a moment to start listening.
        let server_cfg = GrpcConfig::server("127.0.0.1:16077");
        let server = GrpcTransport::with_pressure(&server_cfg, Some(Arc::clone(&pressure)))
            .await
            .unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        // A client push must come back as backpressure.
        let client_cfg = GrpcConfig::client("http://127.0.0.1:16077");
        let client = GrpcTransport::new(&client_cfg).await.unwrap();
        let payload = bytes::Bytes::from_static(b"{\"x\":1}");
        let result = client.send("events", payload).await;
        assert!(
            matches!(result, SendResult::Backpressured),
            "push under pressure must surface as backpressure, got {result:?}"
        );

        client.close().await.unwrap();
        server.close().await.unwrap();
    }

    #[tokio::test]
    async fn grpc_transport_server_only() {
        // Server-only transport (no client for sending).
        // Note: port 0 may not work with tonic parse, so a fixed port is used.
        let server_cfg = GrpcConfig::server("127.0.0.1:16001");
        let transport = GrpcTransport::new(&server_cfg).await.unwrap();

        assert!(transport.client.is_none());
        assert!(transport.receiver.is_some());
        assert!(transport.is_healthy());

        // Without a client side, sending is a fatal error.
        let payload = bytes::Bytes::from_static(b"payload");
        let outcome = transport.send("test", payload).await;
        assert!(outcome.is_fatal(), "send without a client must be fatal");

        // After close the transport reports unhealthy.
        transport.close().await.unwrap();
        assert!(!transport.is_healthy());
    }
}