pub mod config;
pub mod proto;
pub mod token;
pub use config::GrpcConfig;
pub use token::GrpcToken;
use super::error::{TransportError, TransportResult};
use super::traits::{TransportBase, TransportReceiver, TransportSender};
use super::types::{Message, PayloadFormat, SendResult};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use tokio::sync::{mpsc, oneshot};
use tonic::{Request, Response, Status};
/// gRPC-backed transport that can act as a client (sender), a server
/// (receiver), or both, depending on which parts of `GrpcConfig` are set.
pub struct GrpcTransport {
/// Outbound client; `Some` only when an endpoint was configured.
client: Option<proto::dfe_transport_client::DfeTransportClient<tonic::transport::Channel>>,
/// Receiving half of the channel fed by the gRPC server; `Some` only when a
/// listen address was configured.
receiver: Option<tokio::sync::Mutex<mpsc::Receiver<Message<GrpcToken>>>>,
/// One-shot signal fired on drop to stop the background server task.
shutdown_tx: Option<oneshot::Sender<()>>,
/// Handle of the spawned server task; stored but never awaited in this file.
_server_handle: Option<tokio::task::JoinHandle<Result<(), tonic::transport::Error>>>,
/// Set by `close()`; `send` and `recv` refuse to operate once it is true.
closed: AtomicBool,
/// Health flag, shared with the health registry callback when the `health`
/// feature is enabled.
healthy: Arc<AtomicBool>,
/// Maximum wait for the first message in `recv`; `0` makes `recv` non-blocking.
recv_timeout_ms: u64,
/// Number of `send` calls currently awaiting a response (metrics only).
#[cfg(feature = "metrics")]
inflight: AtomicU64,
/// Inbound and outbound payload filters applied by `recv` and `send`.
filter_engine: super::filter::TransportFilterEngine,
/// Inbound messages routed to the DLQ by the filters, held until
/// `take_filtered_dlq_entries` drains them.
filtered_dlq_buffer: parking_lot::Mutex<Vec<super::filter::FilteredDlqEntry>>,
}
impl GrpcTransport {
pub async fn new(config: &GrpcConfig) -> TransportResult<Self> {
let mut client = None;
let mut receiver = None;
let mut shutdown_tx = None;
let mut server_handle = None;
let sequence = Arc::new(AtomicU64::new(0));
if let Some(endpoint) = &config.endpoint {
let channel = tonic::transport::Channel::from_shared(endpoint.clone())
.map_err(|e| TransportError::Config(format!("invalid endpoint: {e}")))?
.connect_lazy();
let mut c = proto::dfe_transport_client::DfeTransportClient::new(channel)
.max_decoding_message_size(config.max_message_size)
.max_encoding_message_size(config.max_message_size);
if config.compression {
c = c
.send_compressed(tonic::codec::CompressionEncoding::Gzip)
.accept_compressed(tonic::codec::CompressionEncoding::Gzip);
}
client = Some(c);
}
if let Some(listen) = &config.listen {
let addr: std::net::SocketAddr = listen
.parse()
.map_err(|e| TransportError::Config(format!("invalid listen address: {e}")))?;
let (tx, rx) = mpsc::channel(config.recv_buffer_size);
let (sd_tx, sd_rx) = oneshot::channel();
let dfe_svc = DfeTransportServiceImpl {
sender: tx.clone(),
sequence: sequence.clone(),
};
let dfe_server = proto::dfe_transport_server::DfeTransportServer::new(dfe_svc)
.max_decoding_message_size(config.max_message_size)
.max_encoding_message_size(config.max_message_size)
.accept_compressed(tonic::codec::CompressionEncoding::Gzip)
.send_compressed(tonic::codec::CompressionEncoding::Gzip);
let mut builder = tonic::transport::Server::builder();
#[cfg(feature = "transport-grpc-vector-compat")]
let router = if config.vector_compat {
let vector_svc =
super::vector_compat::source::VectorCompatService::new(tx, sequence.clone());
let vector_server =
super::vector_compat::proto::vector::vector_server::VectorServer::new(
vector_svc,
)
.max_decoding_message_size(config.max_message_size)
.accept_compressed(tonic::codec::CompressionEncoding::Gzip)
.send_compressed(tonic::codec::CompressionEncoding::Gzip);
builder.add_service(dfe_server).add_service(vector_server)
} else {
builder.add_service(dfe_server)
};
#[cfg(not(feature = "transport-grpc-vector-compat"))]
let router = builder.add_service(dfe_server);
let handle = tokio::spawn(async move {
router
.serve_with_shutdown(addr, async {
sd_rx.await.ok();
})
.await
});
receiver = Some(tokio::sync::Mutex::new(rx));
shutdown_tx = Some(sd_tx);
server_handle = Some(handle);
}
let healthy = Arc::new(AtomicBool::new(true));
let filter_engine = super::filter::TransportFilterEngine::new(
&config.filters_in,
&config.filters_out,
&crate::transport::filter::TransportFilterTierConfig::default(),
)?;
#[cfg(feature = "health")]
{
let h = Arc::clone(&healthy);
crate::health::HealthRegistry::register("transport:grpc", move || {
if h.load(Ordering::Relaxed) {
crate::health::HealthStatus::Healthy
} else {
crate::health::HealthStatus::Unhealthy
}
});
}
Ok(Self {
client,
receiver,
shutdown_tx,
_server_handle: server_handle,
closed: AtomicBool::new(false),
healthy,
recv_timeout_ms: config.recv_timeout_ms,
#[cfg(feature = "metrics")]
inflight: AtomicU64::new(0),
filter_engine,
filtered_dlq_buffer: parking_lot::Mutex::new(Vec::new()),
})
}
}
impl TransportBase for GrpcTransport {
    /// Flags the transport as closed and unhealthy.
    ///
    /// Only the two atomic flags are updated here; the background server (if
    /// any) is signalled to stop when the transport is dropped.
    async fn close(&self) -> TransportResult<()> {
        let ordering = Ordering::Relaxed;
        self.closed.store(true, ordering);
        self.healthy.store(false, ordering);
        Ok(())
    }

    /// Returns the current health flag and, with the `metrics` feature,
    /// publishes it as the `dfe_transport_healthy` gauge (1.0 / 0.0).
    fn is_healthy(&self) -> bool {
        let state = self.healthy.load(Ordering::Relaxed);
        #[cfg(feature = "metrics")]
        {
            let gauge_value: f64 = if state { 1.0 } else { 0.0 };
            metrics::gauge!("dfe_transport_healthy", "transport" => "grpc").set(gauge_value);
        }
        state
    }

    /// Static identifier of this transport implementation.
    fn name(&self) -> &'static str {
        "grpc"
    }
}
/// Keeps the in-flight counter balanced: the counter is decremented and the
/// `dfe_transport_inflight` gauge republished when the guard is dropped, which
/// also happens if the surrounding `send` future is cancelled mid-request.
#[cfg(feature = "metrics")]
struct InflightGuard<'a>(&'a AtomicU64);

#[cfg(feature = "metrics")]
impl Drop for InflightGuard<'_> {
    fn drop(&mut self) {
        // `fetch_sub` returns the previous value; subtract one to get the new one.
        let remaining = self.0.fetch_sub(1, Ordering::Relaxed).saturating_sub(1);
        metrics::gauge!("dfe_transport_inflight", "transport" => "grpc").set(remaining as f64);
    }
}

impl TransportSender for GrpcTransport {
    /// Pushes `payload` to the configured endpoint.
    ///
    /// `key`, when non-empty, is forwarded as the `topic` metadata entry. With
    /// the `otel` feature the current traceparent is attached as well.
    ///
    /// Returns:
    /// * `SendResult::Fatal(TransportError::Closed)` after `close()`.
    /// * `SendResult::Ok` / `SendResult::FilteredDlq` when an outbound filter
    ///   drops the payload or routes it to the DLQ (nothing is sent).
    /// * `SendResult::Fatal(TransportError::Config(..))` when no endpoint is
    ///   configured.
    /// * `SendResult::Backpressured` when the peer answers `Unavailable` or
    ///   `ResourceExhausted`.
    /// * `SendResult::Fatal(TransportError::Send(..))` for any other status.
    async fn send(&self, key: &str, payload: &[u8]) -> SendResult {
        if self.closed.load(Ordering::Relaxed) {
            return SendResult::Fatal(TransportError::Closed);
        }
        if self.filter_engine.has_outbound_filters() {
            match self.filter_engine.apply_outbound(payload) {
                super::filter::FilterDisposition::Pass => {}
                super::filter::FilterDisposition::Drop => return SendResult::Ok,
                super::filter::FilterDisposition::Dlq => return SendResult::FilteredDlq,
            }
        }
        let Some(client) = &self.client else {
            return SendResult::Fatal(TransportError::Config(
                "no endpoint configured for sending".into(),
            ));
        };

        let mut metadata = HashMap::new();
        if !key.is_empty() {
            metadata.insert("topic".to_string(), key.to_string());
        }
        #[cfg(feature = "otel")]
        if let Some(tp) = super::propagation::current_traceparent() {
            metadata.insert(super::propagation::TRACEPARENT_HEADER.to_string(), tp);
        }
        let request = proto::PushRequest {
            payload: payload.to_vec(),
            format: proto::Format::Auto.into(),
            metadata,
        };

        #[cfg(feature = "metrics")]
        let start = std::time::Instant::now();
        // Publish the gauge on the way in as well, otherwise it only ever
        // shows the value after a request has finished. The guard performs the
        // matching decrement even when this future is dropped before the push
        // completes.
        #[cfg(feature = "metrics")]
        let _inflight = {
            let current = self
                .inflight
                .fetch_add(1, Ordering::Relaxed)
                .saturating_add(1);
            metrics::gauge!("dfe_transport_inflight", "transport" => "grpc").set(current as f64);
            InflightGuard(&self.inflight)
        };

        // The client is cloned per call because `push` needs `&mut self`.
        let result = match client.clone().push(request).await {
            Ok(_) => {
                #[cfg(feature = "metrics")]
                metrics::counter!("dfe_transport_sent_total", "transport" => "grpc").increment(1);
                SendResult::Ok
            }
            Err(status) => match status.code() {
                tonic::Code::Unavailable | tonic::Code::ResourceExhausted => {
                    #[cfg(feature = "metrics")]
                    metrics::counter!(
                        "dfe_transport_backpressured_total",
                        "transport" => "grpc"
                    )
                    .increment(1);
                    SendResult::Backpressured
                }
                _ => {
                    #[cfg(feature = "metrics")]
                    metrics::counter!(
                        "dfe_transport_send_errors_total",
                        "transport" => "grpc"
                    )
                    .increment(1);
                    SendResult::Fatal(TransportError::Send(status.message().to_string()))
                }
            },
        };

        #[cfg(feature = "metrics")]
        metrics::histogram!(
            "dfe_transport_send_duration_seconds",
            "transport" => "grpc"
        )
        .record(start.elapsed().as_secs_f64());

        result
    }
}
impl TransportReceiver for GrpcTransport {
    type Token = GrpcToken;

    /// Drains up to `max` queued messages.
    ///
    /// * With `recv_timeout_ms == 0` the call never waits: it returns whatever
    ///   is queued right now (possibly nothing).
    /// * Otherwise it waits up to `recv_timeout_ms` for the first message and
    ///   then takes any further messages that are immediately available.
    ///
    /// Messages already taken from the queue are always returned, even if the
    /// channel is found disconnected while draining; the disconnect is then
    /// reported as `TransportError::Closed` by the next call.
    ///
    /// Inbound filters are applied to the batch afterwards: dropped messages
    /// are discarded and DLQ-routed ones are moved to the buffer read by
    /// `take_filtered_dlq_entries`.
    ///
    /// # Errors
    ///
    /// * `TransportError::Closed` after `close()`, or when the channel is
    ///   disconnected and nothing was received in this call.
    /// * `TransportError::Config` when no listen address was configured.
    async fn recv(&self, max: usize) -> TransportResult<Vec<Message<Self::Token>>> {
        if self.closed.load(Ordering::Relaxed) {
            return Err(TransportError::Closed);
        }
        let Some(receiver) = &self.receiver else {
            return Err(TransportError::Config(
                "no listen address configured for receiving".into(),
            ));
        };
        let mut rx = receiver.lock().await;
        let mut messages = Vec::with_capacity(max.min(100));

        while messages.len() < max {
            // Only the first message of a batch may block, and only when a
            // timeout is configured; everything else is a non-blocking poll.
            let wait_for_first = self.recv_timeout_ms != 0 && messages.is_empty();
            let msg = if wait_for_first {
                match tokio::time::timeout(
                    std::time::Duration::from_millis(self.recv_timeout_ms),
                    rx.recv(),
                )
                .await
                {
                    Ok(Some(msg)) => msg,
                    Ok(None) => return Err(TransportError::Closed),
                    // Timed out without a message: return the empty batch.
                    Err(_) => break,
                }
            } else {
                match rx.try_recv() {
                    Ok(msg) => msg,
                    Err(mpsc::error::TryRecvError::Empty) => break,
                    Err(mpsc::error::TryRecvError::Disconnected) => {
                        // Returning an error here would discard messages that
                        // were already removed from the queue, so hand them
                        // back first and let the next call report the close.
                        if messages.is_empty() {
                            return Err(TransportError::Closed);
                        }
                        break;
                    }
                }
            };
            messages.push(msg);
        }

        if self.filter_engine.has_inbound_filters() {
            let mut staged_dlq: Vec<super::filter::FilteredDlqEntry> = Vec::new();
            messages.retain(|msg| match self.filter_engine.apply_inbound(&msg.payload) {
                super::filter::FilterDisposition::Pass => true,
                super::filter::FilterDisposition::Drop => false,
                super::filter::FilterDisposition::Dlq => {
                    staged_dlq.push(super::filter::FilteredDlqEntry {
                        payload: msg.payload.clone(),
                        key: msg.key.clone(),
                        reason: "transport filter".to_string(),
                    });
                    false
                }
            });
            // Take the buffer lock once, and only when there is something to add.
            if !staged_dlq.is_empty() {
                self.filtered_dlq_buffer.lock().extend(staged_dlq);
            }
        }
        Ok(messages)
    }

    /// Removes and returns every DLQ entry accumulated by inbound filtering,
    /// leaving the internal buffer empty.
    fn take_filtered_dlq_entries(&self) -> Vec<super::filter::FilteredDlqEntry> {
        std::mem::take(&mut *self.filtered_dlq_buffer.lock())
    }

    /// No-op: this transport keeps no acknowledgement state, so every commit
    /// succeeds immediately.
    async fn commit(&self, _tokens: &[Self::Token]) -> TransportResult<()> {
        Ok(())
    }
}
impl Drop for GrpcTransport {
    /// Signals the background server task, if one was started, to shut down.
    fn drop(&mut self) {
        let Some(signal) = self.shutdown_tx.take() else {
            return;
        };
        // A failed send means the receiving side is already gone; nothing to do.
        let _ = signal.send(());
    }
}
/// Server-side gRPC service that turns incoming push requests into
/// `Message`s and queues them for `GrpcTransport::recv`.
struct DfeTransportServiceImpl {
/// Sending half of the bounded channel drained by the transport's receiver.
sender: mpsc::Sender<Message<GrpcToken>>,
/// Counter used to assign each incoming request a sequence number for its token.
sequence: Arc<AtomicU64>,
}
#[tonic::async_trait]
impl proto::dfe_transport_server::DfeTransport for DfeTransportServiceImpl {
async fn push(
&self,
request: Request<proto::PushRequest>,
) -> Result<Response<proto::PushResponse>, Status> {
let req = request.into_inner();
let seq = self.sequence.fetch_add(1, Ordering::Relaxed);
#[cfg(feature = "otel")]
if let Some(tp) = req.metadata.get(super::propagation::TRACEPARENT_HEADER)
&& super::propagation::is_valid_traceparent(tp)
{
tracing::Span::current().record("traceparent", tp.as_str());
}
let format = PayloadFormat::detect(&req.payload);
let key = req.metadata.get("topic").map(|s| Arc::from(s.as_str()));
let msg = Message {
key,
payload: req.payload,
token: GrpcToken::new(seq),
timestamp_ms: None,
format,
};
match self.sender.try_send(msg) {
Ok(()) => {
#[cfg(feature = "metrics")]
{
metrics::counter!("dfe_transport_sent_total", "transport" => "grpc")
.increment(1);
metrics::gauge!("dfe_transport_queue_size", "transport" => "grpc").set(
self.sender
.max_capacity()
.saturating_sub(self.sender.capacity()) as f64,
);
}
Ok(Response::new(proto::PushResponse { accepted: 1 }))
}
Err(mpsc::error::TrySendError::Full(_)) => {
#[cfg(feature = "metrics")]
metrics::counter!(
"dfe_transport_backpressured_total",
"transport" => "grpc"
)
.increment(1);
Err(Status::resource_exhausted("receiver buffer full"))
}
Err(mpsc::error::TrySendError::Closed(_)) => {
#[cfg(feature = "metrics")]
metrics::counter!(
"dfe_transport_refused_total",
"transport" => "grpc"
)
.increment(1);
Err(Status::unavailable("receiver closed"))
}
}
}
async fn health_check(
&self,
_request: Request<proto::HealthCheckRequest>,
) -> Result<Response<proto::HealthCheckResponse>, Status> {
Ok(Response::new(proto::HealthCheckResponse {
status: proto::ServingStatus::Serving.into(),
}))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- GrpcToken ---------------------------------------------------------

    /// Tokens render as `grpc:<seq>` or `grpc:<source>:<seq>`.
    #[test]
    fn grpc_token_display() {
        let plain = GrpcToken::new(42);
        assert_eq!(plain.to_string(), "grpc:42");

        let sourced = GrpcToken::with_source(7, Arc::from("peer-1"));
        assert_eq!(sourced.to_string(), "grpc:peer-1:7");
    }

    // ---- GrpcConfig --------------------------------------------------------

    /// The default configuration is neither client nor server.
    #[test]
    fn grpc_config_defaults() {
        let defaults = GrpcConfig::default();
        assert!(defaults.listen.is_none());
        assert!(defaults.endpoint.is_none());
        assert_eq!(defaults.recv_buffer_size, 10_000);
        assert_eq!(defaults.recv_timeout_ms, 100);
        assert_eq!(defaults.max_message_size, 16 * 1024 * 1024);
        assert!(!defaults.compression);
    }

    /// `server` sets only the listen address.
    #[test]
    fn grpc_config_server() {
        let cfg = GrpcConfig::server("0.0.0.0:6000");
        assert_eq!(cfg.listen.as_deref(), Some("0.0.0.0:6000"));
        assert!(cfg.endpoint.is_none());
    }

    /// `client` sets only the endpoint.
    #[test]
    fn grpc_config_client() {
        let cfg = GrpcConfig::client("http://loader:6000");
        assert!(cfg.listen.is_none());
        assert_eq!(cfg.endpoint.as_deref(), Some("http://loader:6000"));
    }

    /// `with_compression` turns the compression flag on.
    #[test]
    fn grpc_config_with_compression() {
        let cfg = GrpcConfig::server("0.0.0.0:6000").with_compression();
        assert!(cfg.compression);
    }

    // ---- GrpcTransport -----------------------------------------------------

    /// A client-only transport can be built without a reachable peer, cannot
    /// receive, and accepts empty commits.
    #[tokio::test]
    async fn grpc_transport_client_only() {
        let cfg = GrpcConfig::client("http://localhost:16000");
        let sender = GrpcTransport::new(&cfg).await.unwrap();

        assert!(sender.client.is_some());
        assert!(sender.receiver.is_none());
        assert!(sender.is_healthy());
        assert_eq!(sender.name(), "grpc");

        assert!(sender.recv(10).await.is_err());
        sender.commit(&[]).await.unwrap();
    }

    /// A server-only transport cannot send and turns unhealthy once closed.
    #[tokio::test]
    async fn grpc_transport_server_only() {
        let cfg = GrpcConfig::server("127.0.0.1:16001");
        let listener = GrpcTransport::new(&cfg).await.unwrap();

        assert!(listener.client.is_none());
        assert!(listener.receiver.is_some());
        assert!(listener.is_healthy());

        let outcome = listener.send("test", b"payload").await;
        assert!(outcome.is_fatal());

        listener.close().await.unwrap();
        assert!(!listener.is_healthy());
    }
}