use std::sync::Arc;
use async_trait::async_trait;
use dashmap::DashMap;
use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState};
use tokio::sync::Mutex;
use tonic::transport::Channel;
use tonic::{Request, Response, Status};
use tracing::{self, instrument};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use crate::envelope::{Envelope, EnvelopeRequest};
use crate::error::ClusterError;
use crate::message::ReplyReceiver;
use crate::reply::{fallback_reply_id, ExitResult, Reply, ReplyWithExit, EXIT_SEQUENCE};
use crate::runner_health::RunnerHealth;
use crate::runners::Runners;
use crate::sharding::Sharding;
use crate::types::RunnerAddress;
/// Code generated by `tonic` for the `cluster` protobuf package.
pub mod proto {
tonic::include_proto!("cluster");
}
use proto::runner_service_client::RunnerServiceClient;
use proto::runner_service_server::{RunnerService, RunnerServiceServer};
/// gRPC-backed runner client that caches one connected client per runner address.
pub struct GrpcRunners {
/// Connected clients, keyed by the string form of the runner address.
channels: DashMap<String, RunnerServiceClient<Channel>>,
/// Per-address async mutexes taken while a connection to that address is being established.
connect_locks: DashMap<String, Arc<Mutex<()>>>,
/// Connect timeout applied to every new endpoint.
connect_timeout: std::time::Duration,
}
impl GrpcRunners {
pub fn new() -> Self {
Self {
channels: DashMap::new(),
connect_locks: DashMap::new(),
connect_timeout: std::time::Duration::from_secs(5),
}
}
/// Creates a client pool with empty caches that uses `connect_timeout` for new connections.
pub fn with_connect_timeout(connect_timeout: std::time::Duration) -> Self {
    let channels = DashMap::new();
    let connect_locks = DashMap::new();
    Self {
        channels,
        connect_locks,
        connect_timeout,
    }
}
#[instrument(level = "debug", skip(self), fields(runner_address = %address))]
async fn client_for(
&self,
address: &RunnerAddress,
) -> Result<RunnerServiceClient<Channel>, ClusterError> {
let key = address.to_string();
if let Some(client) = self.channels.get(&key) {
return Ok(client.clone());
}
let lock = self
.connect_locks
.entry(key.clone())
.or_insert_with(|| Arc::new(Mutex::new(())))
.clone();
let _guard = lock.lock().await;
if let Some(client) = self.channels.get(&key) {
return Ok(client.clone());
}
let endpoint = format!("http://{}:{}", address.host, address.port);
let channel = Channel::from_shared(endpoint)
.map_err(|e| ClusterError::RunnerUnavailable {
address: address.clone(),
source: Some(Box::new(e)),
})?
.connect_timeout(self.connect_timeout)
.connect()
.await
.map_err(|e| ClusterError::RunnerUnavailable {
address: address.clone(),
source: Some(Box::new(e)),
})?;
let client = RunnerServiceClient::new(channel);
self.channels.insert(key, client.clone());
Ok(client)
}
/// Drops both the cached client and the connect lock stored for `address`.
fn remove_connection(&self, address: &RunnerAddress) {
    let key = address.to_string();
    let _ = self.channels.remove(&key);
    let _ = self.connect_locks.remove(&key);
}
pub fn cleanup_orphaned_locks(&self) {
self.connect_locks
.retain(|key, _| self.channels.contains_key(key));
}
}
impl Default for GrpcRunners {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Runners for GrpcRunners {
/// Sends a `PingRequest` to the runner at `address`.
///
/// # Errors
///
/// A failed connection or a failed RPC is reported as
/// [`ClusterError::RunnerUnavailable`].
#[instrument(level = "debug", skip(self), fields(runner_address = %address))]
async fn ping(&self, address: &RunnerAddress) -> Result<(), ClusterError> {
    let mut client = self.client_for(address).await?;
    let outcome = client.ping(Request::new(proto::PingRequest {})).await;
    match outcome {
        Ok(_) => Ok(()),
        Err(status) => Err(ClusterError::RunnerUnavailable {
            address: address.clone(),
            source: Some(Box::new(status)),
        }),
    }
}
#[instrument(name = "deliver", skip(self, envelope), fields(runner_address = %address))]
async fn send(
&self,
address: &RunnerAddress,
mut envelope: EnvelopeRequest,
) -> Result<ReplyReceiver, ClusterError> {
let mut client = self.client_for(address).await?;
let request_id = envelope.request_id;
{
let context = tracing::Span::current().context();
let span_ref = context.span();
let sc = span_ref.span_context();
if sc.is_valid() {
envelope.trace_id = Some(sc.trace_id().to_string());
envelope.span_id = Some(sc.span_id().to_string());
envelope.sampled = Some(sc.trace_flags().is_sampled());
}
}
let envelope_bytes = rmp_serde::to_vec(&Envelope::Request(envelope)).map_err(|e| {
ClusterError::MalformedMessage {
reason: "failed to serialize envelope".into(),
source: Some(Box::new(e)),
}
})?;
let response = client
.send(Request::new(proto::SendRequest {
envelope: envelope_bytes,
}))
.await
.map_err(|e| ClusterError::RunnerUnavailable {
address: address.clone(),
source: Some(Box::new(e)),
})?;
let mut stream = response.into_inner();
let (tx, rx) = tokio::sync::mpsc::channel::<Reply>(16);
tokio::spawn(async move {
loop {
tokio::select! {
biased;
_ = tx.closed() => {
break;
}
result = stream.message() => {
match result {
Ok(Some(msg)) => {
match rmp_serde::from_slice::<Reply>(&msg.reply) {
Ok(reply) => {
if tx.send(reply).await.is_err() {
break; }
}
Err(e) => {
tracing::warn!(error = %e, "failed to deserialize reply from gRPC stream");
let failure = Reply::WithExit(ReplyWithExit {
request_id,
id: fallback_reply_id(request_id, EXIT_SEQUENCE),
exit: ExitResult::Failure(format!(
"gRPC reply deserialization failed: {e}"
)),
});
let _ = tx.send(failure).await;
break;
}
}
}
Ok(None) => break, Err(e) => {
tracing::warn!(error = %e, "gRPC reply stream error");
let failure = Reply::WithExit(ReplyWithExit {
request_id,
id: fallback_reply_id(request_id, EXIT_SEQUENCE),
exit: ExitResult::Failure(format!(
"gRPC reply stream error: {e}"
)),
});
let _ = tx.send(failure).await;
break;
}
}
}
}
}
});
Ok(rx)
}
/// Sends `envelope` to the runner at `address` through the `Notify` RPC and discards the
/// response body.
///
/// # Errors
///
/// * [`ClusterError::RunnerUnavailable`] if connecting or the RPC fails.
/// * [`ClusterError::MalformedMessage`] if the envelope cannot be serialized.
#[instrument(skip(self, envelope), fields(runner_address = %address))]
async fn notify(
    &self,
    address: &RunnerAddress,
    envelope: Envelope,
) -> Result<(), ClusterError> {
    let mut client = self.client_for(address).await?;

    let payload = match rmp_serde::to_vec(&envelope) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(ClusterError::MalformedMessage {
                reason: "failed to serialize envelope".into(),
                source: Some(Box::new(e)),
            });
        }
    };

    let outcome = client
        .notify(Request::new(proto::NotifyRequest { envelope: payload }))
        .await;
    match outcome {
        Ok(_) => Ok(()),
        Err(status) => Err(ClusterError::RunnerUnavailable {
            address: address.clone(),
            source: Some(Box::new(status)),
        }),
    }
}
async fn on_runner_unavailable(&self, address: &RunnerAddress) -> Result<(), ClusterError> {
self.remove_connection(address);
tracing::info!(%address, "removed cached gRPC connection for unavailable runner");
Ok(())
}
}
/// Server side of the runner gRPC service; forwards incoming calls to a [`Sharding`] instance.
pub struct GrpcRunnerServer {
/// Target that receives the decoded envelopes.
sharding: Arc<dyn Sharding>,
}
impl GrpcRunnerServer {
pub fn new(sharding: Arc<dyn Sharding>) -> Self {
Self { sharding }
}
pub fn into_service(self) -> RunnerServiceServer<Self> {
RunnerServiceServer::new(self)
}
}
#[tonic::async_trait]
impl RunnerService for GrpcRunnerServer {
/// Answers every ping with an empty `PingResponse`.
#[instrument(level = "debug", skip_all)]
async fn ping(
    &self,
    _request: Request<proto::PingRequest>,
) -> Result<Response<proto::PingResponse>, Status> {
    let pong = proto::PingResponse {};
    Ok(Response::new(pong))
}
/// Response stream of the `send` RPC: the receiving half of a Tokio mpsc channel wrapped as a stream.
type SendStream = tokio_stream::wrappers::ReceiverStream<Result<proto::ReplyMessage, Status>>;
/// Handles the `Send` RPC: decodes the envelope, dispatches it through sharding, and streams
/// the replies back to the caller.
///
/// The payload is first decoded as an [`Envelope`]; if that fails, it is decoded again as a
/// bare [`EnvelopeRequest`]. When the envelope carries trace and span ids that parse as
/// hex, they are installed as the remote parent of the current span.
///
/// Replies are forwarded by a spawned task through a bounded channel (capacity 16). The
/// task stops when the reply source ends, when a reply fails to serialize (after trying to
/// send an `internal` status), or as soon as the response stream has been dropped. The
/// last case means an abandoned call does not keep the task and the reply receiver alive
/// until the next reply happens to arrive.
///
/// # Errors
///
/// * `invalid_argument` if the payload is a non-`Request` envelope or cannot be decoded.
/// * `internal` if dispatching through sharding fails.
#[instrument(name = "receive", skip_all)]
async fn send(
    &self,
    request: Request<proto::SendRequest>,
) -> Result<Response<Self::SendStream>, Status> {
    let envelope_bytes = request.into_inner().envelope;
    let envelope: EnvelopeRequest = match rmp_serde::from_slice::<Envelope>(&envelope_bytes) {
        Ok(Envelope::Request(envelope)) => envelope,
        Ok(_) => {
            return Err(Status::invalid_argument(
                "expected Request envelope for send",
            ))
        }
        // Not decodable as `Envelope`: retry the same bytes as a bare `EnvelopeRequest`.
        Err(_) => rmp_serde::from_slice(&envelope_bytes)
            .map_err(|e| Status::invalid_argument(format!("malformed envelope: {e}")))?,
    };

    if let (Some(trace_id_str), Some(span_id_str)) = (&envelope.trace_id, &envelope.span_id) {
        if let (Ok(tid), Ok(sid)) = (
            TraceId::from_hex(trace_id_str),
            SpanId::from_hex(span_id_str),
        ) {
            // A missing `sampled` flag is treated as not sampled.
            let flags = if envelope.sampled.unwrap_or(false) {
                TraceFlags::SAMPLED
            } else {
                TraceFlags::default()
            };
            let remote_context = SpanContext::new(tid, sid, flags, true, TraceState::default());
            let otel_context =
                opentelemetry::Context::current().with_remote_span_context(remote_context);
            tracing::Span::current().set_parent(otel_context);
        }
    }

    let mut reply_rx = self
        .sharding
        .send(envelope)
        .await
        .map_err(|e| Status::internal(e.to_string()))?;

    let (tx, rx) = tokio::sync::mpsc::channel::<Result<proto::ReplyMessage, Status>>(16);
    tokio::spawn(async move {
        loop {
            // `biased` checks for a dropped response stream first, so the task exits
            // promptly instead of waiting for the next reply to discover it.
            let reply = tokio::select! {
                biased;
                _ = tx.closed() => break,
                next = reply_rx.recv() => match next {
                    Some(reply) => reply,
                    None => break,
                },
            };

            match rmp_serde::to_vec(&reply) {
                Ok(bytes) => {
                    let msg = proto::ReplyMessage { reply: bytes };
                    if tx.send(Ok(msg)).await.is_err() {
                        break;
                    }
                }
                Err(e) => {
                    let _ = tx
                        .send(Err(Status::internal(format!(
                            "failed to serialize reply: {e}"
                        ))))
                        .await;
                    break;
                }
            }
        }
    });

    Ok(Response::new(tokio_stream::wrappers::ReceiverStream::new(
        rx,
    )))
}
/// Handles the `Notify` RPC by decoding the envelope and routing it by variant.
///
/// The payload is decoded as an [`Envelope`]; if that fails it is decoded again as a bare
/// [`EnvelopeRequest`] and treated as `Envelope::Request`. Requests go to
/// `Sharding::notify`, chunk acknowledgements to `Sharding::ack_chunk`, and interrupts to
/// `Sharding::interrupt`.
///
/// # Errors
///
/// * `invalid_argument` if the payload cannot be decoded either way.
/// * `internal` if the sharding call fails.
#[instrument(skip_all)]
async fn notify(
    &self,
    request: Request<proto::NotifyRequest>,
) -> Result<Response<proto::NotifyResponse>, Status> {
    /// Maps a sharding failure to an `internal` gRPC status carrying its message.
    fn internal(err: impl std::fmt::Display) -> Status {
        Status::internal(err.to_string())
    }

    let raw = request.into_inner().envelope;
    let envelope = if let Ok(decoded) = rmp_serde::from_slice::<Envelope>(&raw) {
        decoded
    } else {
        let bare: EnvelopeRequest = rmp_serde::from_slice(&raw)
            .map_err(|e| Status::invalid_argument(format!("malformed envelope: {e}")))?;
        Envelope::Request(bare)
    };

    match envelope {
        Envelope::Request(req) => {
            self.sharding.notify(req).await.map_err(internal)?;
        }
        Envelope::AckChunk(ack) => {
            self.sharding.ack_chunk(ack).await.map_err(internal)?;
        }
        Envelope::Interrupt(interrupt) => {
            self.sharding.interrupt(interrupt).await.map_err(internal)?;
        }
    }

    Ok(Response::new(proto::NotifyResponse {}))
}
}
/// Health checker that decides liveness by pinging a runner through [`GrpcRunners`].
pub struct GrpcRunnerHealth {
/// Shared client pool used to issue the pings.
runners: Arc<GrpcRunners>,
}
impl GrpcRunnerHealth {
pub fn new(runners: Arc<GrpcRunners>) -> Self {
Self { runners }
}
}
#[async_trait]
impl RunnerHealth for GrpcRunnerHealth {
    /// Pings `address` and translates the outcome into a liveness verdict.
    ///
    /// Returns `Ok(true)` when the ping succeeds and `Ok(false)` when it fails with
    /// [`ClusterError::RunnerUnavailable`]. Any other error is passed through unchanged.
    async fn is_alive(&self, address: &RunnerAddress) -> Result<bool, ClusterError> {
        let outcome = self.runners.ping(address).await;
        if let Err(ClusterError::RunnerUnavailable { .. }) = outcome {
            return Ok(false);
        }
        outcome.map(|()| true)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::ShardingConfig;
use crate::entity::{Entity, EntityContext, EntityHandler};
use crate::metrics::ClusterMetrics;
use crate::reply::{ExitResult, Reply};
use crate::sharding_impl::ShardingImpl;
use crate::snowflake::Snowflake;
use crate::storage::noop_runners::NoopRunners;
use crate::types::{EntityAddress, EntityId, EntityType};
use std::collections::HashMap;
use std::net::SocketAddr;
/// Test entity of type `"Echo"` that spawns an `EchoHandler`.
struct EchoEntity;

#[async_trait]
impl Entity for EchoEntity {
    fn entity_type(&self) -> EntityType {
        EntityType::new("Echo")
    }

    async fn spawn(&self, _ctx: EntityContext) -> Result<Box<dyn EntityHandler>, ClusterError> {
        let handler: Box<dyn EntityHandler> = Box::new(EchoHandler);
        Ok(handler)
    }
}
/// Test handler that answers every request with a copy of its payload.
struct EchoHandler;

#[async_trait]
impl EntityHandler for EchoHandler {
    async fn handle_request(
        &self,
        _tag: &str,
        payload: &[u8],
        _headers: &HashMap<String, String>,
    ) -> Result<Vec<u8>, ClusterError> {
        let echoed = Vec::from(payload);
        Ok(echoed)
    }
}
/// Serves a `GrpcRunnerServer` for `sharding` on an OS-assigned loopback port and returns
/// the bound address.
///
/// The server runs in a spawned task; the function waits 50 ms before returning.
async fn start_test_server(sharding: Arc<dyn Sharding>) -> SocketAddr {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let bound_addr = listener.local_addr().unwrap();
    let service = GrpcRunnerServer::new(sharding).into_service();
    let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);

    tokio::spawn(async move {
        let served = tonic::transport::Server::builder()
            .add_service(service)
            .serve_with_incoming(incoming)
            .await;
        served.unwrap();
    });

    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
    bound_addr
}
/// Builds a `ShardingImpl` for tests: one `"default"` shard group with 10 shards,
/// `NoopRunners` as the transport, and unregistered metrics.
fn make_sharding(addr: SocketAddr) -> Arc<ShardingImpl> {
    let runner_address = RunnerAddress::new(addr.ip().to_string(), addr.port());
    let config = ShardingConfig {
        runner_address,
        shard_groups: vec!["default".to_string()],
        shards_per_group: 10,
        ..Default::default()
    };
    let runners: Arc<dyn Runners> = Arc::new(NoopRunners);
    let metrics = Arc::new(ClusterMetrics::unregistered());
    ShardingImpl::new(Arc::new(config), runners, None, None, None, metrics).unwrap()
}
/// A ping sent through the client reaches a live test server and succeeds.
#[tokio::test]
async fn grpc_ping_roundtrip() {
    let sharding = make_sharding("127.0.0.1:0".parse().unwrap());
    let server_addr = start_test_server(sharding).await;

    let client = GrpcRunners::new();
    let runner = RunnerAddress::new(server_addr.ip().to_string(), server_addr.port());
    let outcome = client.ping(&runner).await;
    outcome.unwrap();
}
/// Sends a request to the echo entity over gRPC and expects the payload back as a
/// successful exit reply.
#[tokio::test]
async fn grpc_send_and_receive_reply() {
    let sharding = make_sharding("127.0.0.1:0".parse().unwrap());
    sharding.acquire_all_shards().await;
    sharding
        .register_entity(Arc::new(EchoEntity))
        .await
        .unwrap();
    let server_addr = start_test_server(sharding.clone()).await;

    let client = GrpcRunners::new();
    let runner = RunnerAddress::new(server_addr.ip().to_string(), server_addr.port());

    let entity_id = EntityId::new("e-1");
    let shard_id = sharding.get_shard_id(&EntityType::new("Echo"), &entity_id);
    let payload = rmp_serde::to_vec(&"hello").unwrap();
    let request_id = sharding.snowflake().next_async().await.unwrap();
    let request = EnvelopeRequest {
        request_id,
        address: EntityAddress {
            shard_id,
            entity_type: EntityType::new("Echo"),
            entity_id,
        },
        tag: "echo".into(),
        payload: payload.clone(),
        headers: HashMap::new(),
        span_id: None,
        trace_id: None,
        sampled: None,
        persisted: false,
        uninterruptible: Default::default(),
        deliver_at: None,
    };

    let mut replies = client.send(&runner, request).await.unwrap();
    match replies.recv().await.unwrap() {
        Reply::WithExit(done) => match done.exit {
            ExitResult::Success(bytes) => assert_eq!(bytes, payload),
            ExitResult::Failure(msg) => panic!("unexpected failure: {msg}"),
        },
        Reply::Chunk(_) => panic!("unexpected chunk"),
    }
}
/// A request envelope sent through `notify` is accepted by the server without error.
#[tokio::test]
async fn grpc_notify_fire_and_forget() {
    let sharding = make_sharding("127.0.0.1:0".parse().unwrap());
    sharding.acquire_all_shards().await;
    sharding
        .register_entity(Arc::new(EchoEntity))
        .await
        .unwrap();
    let server_addr = start_test_server(sharding.clone()).await;

    let client = GrpcRunners::new();
    let runner = RunnerAddress::new(server_addr.ip().to_string(), server_addr.port());

    let entity_id = EntityId::new("e-2");
    let shard_id = sharding.get_shard_id(&EntityType::new("Echo"), &entity_id);
    let request_id = sharding.snowflake().next_async().await.unwrap();
    let request = EnvelopeRequest {
        request_id,
        address: EntityAddress {
            shard_id,
            entity_type: EntityType::new("Echo"),
            entity_id,
        },
        tag: "ping".into(),
        payload: vec![],
        headers: HashMap::new(),
        span_id: None,
        trace_id: None,
        sampled: None,
        persisted: false,
        uninterruptible: Default::default(),
        deliver_at: None,
    };

    let outcome = client.notify(&runner, Envelope::Request(request)).await;
    outcome.unwrap();
}
#[tokio::test]
async fn grpc_ping_unavailable_runner() {
let runners = GrpcRunners::new();
let address = RunnerAddress::new("127.0.0.1", 1);
let result = runners.ping(&address).await;
assert!(result.is_err());
}
/// The health checker reports a running test server as alive.
#[tokio::test]
async fn grpc_health_check_alive() {
    let sharding = make_sharding("127.0.0.1:0".parse().unwrap());
    let server_addr = start_test_server(sharding).await;

    let pool = Arc::new(GrpcRunners::new());
    let checker = GrpcRunnerHealth::new(Arc::clone(&pool));
    let runner = RunnerAddress::new(server_addr.ip().to_string(), server_addr.port());

    let alive = checker.is_alive(&runner).await.unwrap();
    assert!(alive);
}
/// The health checker reports an address with no listener as not alive, without erroring.
#[tokio::test]
async fn grpc_health_check_dead() {
    let pool = Arc::new(GrpcRunners::new());
    let checker = GrpcRunnerHealth::new(Arc::clone(&pool));
    let runner = RunnerAddress::new("127.0.0.1", 1);

    let alive = checker.is_alive(&runner).await.unwrap();
    assert!(!alive);
}
/// Two pings to the same address leave exactly one cached client in the pool.
#[tokio::test]
async fn grpc_connection_pooling() {
    let sharding = make_sharding("127.0.0.1:0".parse().unwrap());
    let server_addr = start_test_server(sharding).await;

    let client = GrpcRunners::new();
    let runner = RunnerAddress::new(server_addr.ip().to_string(), server_addr.port());

    // The cache size is checked after each of the two pings.
    for _ in 0..2 {
        client.ping(&runner).await.unwrap();
        assert_eq!(client.channels.len(), 1);
    }
}
/// Populated trace fields on an `EnvelopeRequest` come back unchanged from a MessagePack
/// encode/decode cycle.
#[test]
fn trace_context_survives_envelope_serde_roundtrip() {
    const TRACE_ID: &str = "0af7651916cd43dd8448eb211c80319c";
    const SPAN_ID: &str = "e457b5a2e4d86bd1";

    let address = EntityAddress {
        shard_id: crate::types::ShardId::new("default", 0),
        entity_type: EntityType::new("Test"),
        entity_id: EntityId::new("t-1"),
    };
    let original = EnvelopeRequest {
        request_id: Snowflake(999),
        address,
        tag: "test".into(),
        payload: vec![],
        headers: HashMap::new(),
        span_id: Some(SPAN_ID.into()),
        trace_id: Some(TRACE_ID.into()),
        sampled: Some(true),
        persisted: false,
        uninterruptible: Default::default(),
        deliver_at: None,
    };

    let encoded = rmp_serde::to_vec(&original).unwrap();
    let decoded: EnvelopeRequest = rmp_serde::from_slice(&encoded).unwrap();

    assert_eq!(decoded.trace_id.as_deref(), Some(TRACE_ID));
    assert_eq!(decoded.span_id.as_deref(), Some(SPAN_ID));
    assert_eq!(decoded.sampled, Some(true));
}
/// Absent trace fields on an `EnvelopeRequest` are still absent after a MessagePack
/// encode/decode cycle.
#[test]
fn trace_context_none_survives_envelope_serde_roundtrip() {
    let address = EntityAddress {
        shard_id: crate::types::ShardId::new("default", 0),
        entity_type: EntityType::new("Test"),
        entity_id: EntityId::new("t-1"),
    };
    let original = EnvelopeRequest {
        request_id: Snowflake(999),
        address,
        tag: "test".into(),
        payload: vec![],
        headers: HashMap::new(),
        span_id: None,
        trace_id: None,
        sampled: None,
        persisted: false,
        uninterruptible: Default::default(),
        deliver_at: None,
    };

    let encoded = rmp_serde::to_vec(&original).unwrap();
    let decoded: EnvelopeRequest = rmp_serde::from_slice(&encoded).unwrap();

    assert_eq!(decoded.trace_id, None);
    assert_eq!(decoded.span_id, None);
    assert_eq!(decoded.sampled, None);
}
/// A custom connect timeout bounds how long a ping to an unreachable address takes.
///
/// The upper bound is kept below the five-second default used by `GrpcRunners::new`, so
/// the test fails if the 100 ms timeout passed to `with_connect_timeout` is ignored.
#[tokio::test]
async fn grpc_connect_timeout_is_applied() {
    let runners = GrpcRunners::with_connect_timeout(std::time::Duration::from_millis(100));
    // 192.0.2.0/24 is reserved for documentation (TEST-NET-1), so nothing should answer.
    let address = RunnerAddress::new("192.0.2.1", 9999);
    let start = std::time::Instant::now();
    let result = runners.ping(&address).await;
    let elapsed = start.elapsed();
    assert!(result.is_err());
    assert!(
        elapsed < std::time::Duration::from_secs(2),
        "connect should have timed out quickly, took {:?}",
        elapsed
    );
}
/// `on_runner_unavailable` evicts both the cached client and the connect lock for the
/// address.
#[tokio::test]
async fn grpc_on_runner_unavailable_removes_connection() {
    let sharding = make_sharding("127.0.0.1:0".parse().unwrap());
    let addr = start_test_server(sharding).await;
    let runners = GrpcRunners::new();
    let address = RunnerAddress::new(addr.ip().to_string(), addr.port());
    runners.ping(&address).await.unwrap();
    assert_eq!(runners.channels.len(), 1);
    // The first connect also leaves a per-address lock entry behind.
    assert_eq!(runners.connect_locks.len(), 1);
    runners.on_runner_unavailable(&address).await.unwrap();
    assert_eq!(runners.channels.len(), 0);
    assert_eq!(runners.connect_locks.len(), 0);
}
}