use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use futures::StreamExt;
use rand::{Rng, thread_rng};
use tokio::sync::{Mutex, RwLock, Semaphore, mpsc};
use tokio::time::{MissedTickBehavior, interval};
use crate::auth::OttProvider;
use crate::connector::{BaseConnector, ConnectorConfig};
use crate::error::{ConnectorError, Result};
use crate::logger::Logger;
use crate::multi::MultiTransportOptions;
use crate::multi::RegistrationKey;
use crate::multi::shared_channel::{SharedChannel, SharedStream};
use crate::types::ConnectorMetrics;
use crate::types::{ExecuteRequest as SdkExecuteRequest, ExecuteResponse, PayloadEncoding};
use crate::utils::{deserialize_payload, error_response, sanitize_identifier, serialize_payload};
use strike48_proto::proto::{
self, ConnectorCapabilities, CredentialsIssued, HeartbeatRequest, HeartbeatResponse,
InstanceMetadata, RegisterConnectorRequest, StreamMessage, stream_message,
};
/// Drives one connector registration (tenant/type/instance) over a shared
/// gRPC channel: registers, heartbeats, dispatches execute requests, and
/// reconnects with backoff until `shutdown` is signalled.
pub(crate) struct RegistrationRunner {
    // Identity of this registration (tenant_id / connector_type / instance_id).
    pub key: RegistrationKey,
    // Connector configuration; written when a new JWT is issued in-stream.
    pub config: Arc<RwLock<ConnectorConfig>>,
    // The connector implementation that handles execute requests.
    pub connector: Arc<dyn BaseConnector>,
    // Channel shared across registrations; used to open the bidi stream.
    pub shared_channel: Arc<SharedChannel>,
    // Cooperative shutdown flag polled throughout the run loop.
    pub shutdown: Arc<AtomicBool>,
    // Connection/request counters updated on every state change.
    pub metrics: Arc<Mutex<ConnectorMetrics>>,
    // Transport tuning knobs (reconnect, backoff, heartbeat overrides).
    pub opts: MultiTransportOptions,
    // Bounds the number of concurrently executing requests.
    pub request_semaphore: Arc<Semaphore>,
    // Server-issued session token, echoed back on (re-)registration.
    pub session_token: Arc<RwLock<Option<String>>>,
}
// Default heartbeat send interval; overridable via MultiTransportOptions.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
// Default watchdog: stream is presumed dead if no heartbeat response arrives
// within this window; overridable via MultiTransportOptions.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(45);
// How often the select loop wakes to re-check the shutdown flag.
const SHUTDOWN_POLL: Duration = Duration::from_millis(100);
// Polling granularity for shutdown checks while sleeping between reconnects.
const RECONNECT_POLL: Duration = Duration::from_millis(50);
/// Why a bidi stream ended; recorded in metrics and used to decide logging.
#[derive(Debug)]
enum StreamOutcome {
    // Local shutdown flag was set.
    Shutdown,
    // Server ended the stream cleanly (or rejected the registration).
    ServerClosed,
    // The inbound stream yielded a transport/status error (message attached).
    StreamError(String),
    // No heartbeat response within the configured timeout window.
    HeartbeatTimeout,
    // The outbound mpsc channel was closed; nothing more can be sent.
    OutboundClosed,
}
impl StreamOutcome {
fn summary(&self) -> &'static str {
match self {
StreamOutcome::Shutdown => "shutdown",
StreamOutcome::ServerClosed => "server-closed",
StreamOutcome::StreamError(_) => "stream-error",
StreamOutcome::HeartbeatTimeout => "heartbeat-timeout",
StreamOutcome::OutboundClosed => "outbound-closed",
}
}
}
impl std::fmt::Display for StreamOutcome {
    /// Human-readable form: same as `summary()`, but `StreamError`
    /// additionally carries its detail string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let StreamOutcome::StreamError(detail) = self {
            write!(f, "stream-error: {detail}")
        } else {
            f.write_str(self.summary())
        }
    }
}
/// Compute the reconnect delay for the given 1-based `attempt`:
/// exponential growth from `reconnect_delay_ms`, plus uniform random
/// jitter up to `reconnect_jitter_ms`, hard-capped at
/// `max_backoff_delay_ms` (jitter included).
fn compute_backoff(opts: &MultiTransportOptions, attempt: u32) -> Duration {
    // Clamp the exponent so the shift can never overflow a u64.
    let shift = attempt.saturating_sub(1).min(20);
    let exponential = opts.reconnect_delay_ms.saturating_mul(1u64 << shift);
    let jitter = match opts.reconnect_jitter_ms {
        0 => 0,
        upper => thread_rng().gen_range(0..=upper),
    };
    let total_ms = exponential
        .saturating_add(jitter)
        .min(opts.max_backoff_delay_ms);
    Duration::from_millis(total_ms)
}
/// Reset the heartbeat watchdog timestamp to the current instant.
fn bump_heartbeat(last_heartbeat_response: &mut Instant) {
    let now = Instant::now();
    *last_heartbeat_response = now;
}
/// Sleep for `total`, waking every `RECONNECT_POLL` to check the shutdown
/// flag. Returns `true` when the full duration elapsed, `false` when
/// shutdown was signalled first.
async fn sleep_with_shutdown(total: Duration, shutdown: &Arc<AtomicBool>) -> bool {
    let deadline = Instant::now() + total;
    while !shutdown.load(Ordering::SeqCst) {
        let left = deadline.saturating_duration_since(Instant::now());
        if left.is_zero() {
            // Slept the whole requested duration without interruption.
            return true;
        }
        tokio::time::sleep(left.min(RECONNECT_POLL)).await;
    }
    // Shutdown flag flipped while we were sleeping.
    false
}
impl RegistrationRunner {
pub async fn run(self) -> Result<()> {
let logger = Logger::new("multi/registration");
if self.shutdown.load(Ordering::SeqCst) {
logger.debug(&format!(
"registration {} skipped: shutdown signalled before run",
self.key
));
return Ok(());
}
let mut attempt: u32 = 0;
let mut ever_registered = false;
loop {
if self.shutdown.load(Ordering::SeqCst) {
logger.debug(&format!(
"registration {}: shutdown signalled, exiting reconnect loop",
self.key
));
return Ok(());
}
let register = {
let cfg = self.config.read().await;
let token = self.session_token.read().await.clone().unwrap_or_default();
build_register_message_with_token(&cfg, self.connector.as_ref(), &token)
};
let stream = match self.shared_channel.open_stream(register, 64).await {
Ok(s) => s,
Err(e) => {
logger.warn(&format!(
"registration {}: open_stream failed: {e}",
self.key
));
if !self.opts.reconnect_enabled {
return Ok(());
}
attempt = attempt.saturating_add(1);
let backoff = compute_backoff(&self.opts, attempt);
{
let mut m = self.metrics.lock().await;
m.reconnection_attempts += 1;
m.current_backoff_ms = backoff.as_millis() as u64;
}
if !sleep_with_shutdown(backoff, &self.shutdown).await {
return Ok(());
}
continue;
}
};
let mut stream = stream;
let response_deadline =
Duration::from_millis(self.opts.connect_timeout_ms.max(1)).saturating_mul(3);
match wait_for_register_response(&mut stream.inbound, &self.shutdown, response_deadline)
.await
{
Ok(accepted) => {
let arn = &accepted.connector_arn;
logger.info(&format!("registration {} registered (arn={arn})", self.key));
if !accepted.session_token.is_empty() {
*self.session_token.write().await = Some(accepted.session_token.clone());
}
{
let mut m = self.metrics.lock().await;
m.last_connected_at_ms = Some(now_ms());
m.current_backoff_ms = 0;
if ever_registered {
m.successful_reconnects += 1;
}
}
ever_registered = true;
attempt = 0;
}
Err(e) => {
logger.warn(&format!("registration {}: register failed: {e}", self.key));
if !self.opts.reconnect_enabled {
return Ok(());
}
attempt = attempt.saturating_add(1);
let backoff = compute_backoff(&self.opts, attempt);
{
let mut m = self.metrics.lock().await;
m.reconnection_attempts += 1;
m.current_backoff_ms = backoff.as_millis() as u64;
}
if !sleep_with_shutdown(backoff, &self.shutdown).await {
return Ok(());
}
continue;
}
}
let outcome = self.drive_stream(stream, &logger).await;
{
let mut m = self.metrics.lock().await;
m.total_disconnects += 1;
m.last_disconnected_at_ms = Some(now_ms());
m.last_disconnect_reason = Some(outcome.summary().to_string());
}
if self.shutdown.load(Ordering::SeqCst) {
return Ok(());
}
if !self.opts.reconnect_enabled {
logger.debug(&format!(
"registration {}: reconnect disabled, exiting after {outcome}",
self.key
));
return Ok(());
}
attempt = attempt.saturating_add(1);
let backoff = compute_backoff(&self.opts, attempt);
logger.warn(&format!(
"registration {}: stream ended ({outcome}); reconnecting in {}ms (attempt {})",
self.key,
backoff.as_millis(),
attempt
));
{
let mut m = self.metrics.lock().await;
m.reconnection_attempts += 1;
m.current_backoff_ms = backoff.as_millis() as u64;
}
if !sleep_with_shutdown(backoff, &self.shutdown).await {
return Ok(());
}
}
}
async fn drive_stream(&self, stream: SharedStream, logger: &Logger) -> StreamOutcome {
let SharedStream {
tx, mut inbound, ..
} = stream;
let mut last_heartbeat_response = Instant::now();
let hb_interval_dur = self.opts.heartbeat_interval.unwrap_or(HEARTBEAT_INTERVAL);
let hb_timeout_dur = self.opts.heartbeat_timeout.unwrap_or(HEARTBEAT_TIMEOUT);
let mut hb_interval = interval(hb_interval_dur);
hb_interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
hb_interval.tick().await;
loop {
if self.shutdown.load(Ordering::SeqCst) {
logger.debug(&format!(
"registration {}: shutting down on signal",
self.key
));
return StreamOutcome::Shutdown;
}
tokio::select! {
msg_opt = inbound.next() => {
match msg_opt {
Some(Ok(msg)) => {
if let Some(outcome) = self
.dispatch_inbound(msg, &tx, &mut last_heartbeat_response, logger)
.await
{
return outcome;
}
}
Some(Err(status)) => {
logger.warn(&format!(
"registration {}: stream error: {status}",
self.key
));
return StreamOutcome::StreamError(status.to_string());
}
None => {
logger.debug(&format!(
"registration {}: server closed stream",
self.key
));
return StreamOutcome::ServerClosed;
}
}
}
_ = hb_interval.tick() => {
if last_heartbeat_response.elapsed() > hb_timeout_dur {
logger.warn(&format!(
"registration {}: no heartbeat response for {}s, presumed dead",
self.key,
last_heartbeat_response.elapsed().as_secs()
));
return StreamOutcome::HeartbeatTimeout;
}
let hb = StreamMessage {
message: Some(stream_message::Message::HeartbeatRequest(HeartbeatRequest {
gateway_id: String::new(),
timestamp_ms: now_ms() as i64,
})),
};
if tx.send(hb).await.is_err() {
logger.debug(&format!(
"registration {}: outbound channel closed; exiting",
self.key
));
return StreamOutcome::OutboundClosed;
}
}
_ = tokio::time::sleep(SHUTDOWN_POLL) => {
}
}
}
}
async fn dispatch_inbound(
&self,
msg: StreamMessage,
tx: &mpsc::Sender<StreamMessage>,
last_heartbeat_response: &mut Instant,
logger: &Logger,
) -> Option<StreamOutcome> {
match msg.message {
Some(stream_message::Message::ExecuteRequest(req)) => {
let request = SdkExecuteRequest {
request_id: req.request_id.clone(),
payload: req.payload.clone(),
payload_encoding: PayloadEncoding::from(req.payload_encoding),
context: req.context.clone(),
capability_id: if req.capability_id.is_empty() {
None
} else {
Some(req.capability_id.clone())
},
};
let connector = self.connector.clone();
let metrics = self.metrics.clone();
let tx = tx.clone();
let key = self.key.clone();
let semaphore = self.request_semaphore.clone();
let logger = Logger::new("multi/registration/execute");
tokio::spawn(async move {
let permit = match semaphore.acquire_owned().await {
Ok(p) => p,
Err(_) => {
logger.debug(&format!(
"registration {key}: request semaphore closed, dropping execute"
));
return;
}
};
if let Err(e) =
handle_execute(connector, request, tx, metrics, &logger, &key).await
{
logger.error(
&format!("registration {key}: execute dispatch failed"),
&e.to_string(),
);
}
drop(permit);
});
None
}
Some(stream_message::Message::HeartbeatRequest(_)) => {
let resp = StreamMessage {
message: Some(stream_message::Message::HeartbeatResponse(
HeartbeatResponse {
gateway_id: String::new(),
timestamp_ms: now_ms() as i64,
should_reconnect: false,
},
)),
};
let _ = tx.send(resp).await;
None
}
Some(stream_message::Message::HeartbeatResponse(_)) => {
*last_heartbeat_response = Instant::now();
None
}
Some(stream_message::Message::RegisterResponse(resp)) => {
if resp.success {
logger.info(&format!(
"registration {}: in-stream re-register succeeded (arn={})",
self.key, resp.connector_arn
));
if !resp.session_token.is_empty() {
*self.session_token.write().await = Some(resp.session_token.clone());
}
let mut m = self.metrics.lock().await;
m.last_connected_at_ms = Some(now_ms());
} else {
logger.warn(&format!(
"registration {}: in-stream re-register failed: status='{}' error='{}'",
self.key, resp.status, resp.error
));
}
None
}
Some(stream_message::Message::CredentialsIssued(creds)) => {
self.handle_credentials_issued(creds, tx, last_heartbeat_response, logger)
.await
}
Some(stream_message::Message::ApprovalNotification(notif)) => {
let status = proto::RegistrationStatus::try_from(notif.status);
match status {
Ok(proto::RegistrationStatus::Approved) => {
logger.info(&format!(
"registration {}: approved by admin (CredentialsIssued imminent)",
self.key
));
None
}
Ok(proto::RegistrationStatus::Pending) => {
logger.info(&format!(
"registration {}: pending approval — {}",
self.key,
if notif.message.is_empty() {
"awaiting admin"
} else {
¬if.message
}
));
None
}
Ok(proto::RegistrationStatus::Rejected) => {
logger.warn(&format!(
"registration {}: REJECTED by admin — {}",
self.key,
if notif.message.is_empty() {
"no reason given"
} else {
¬if.message
}
));
Some(StreamOutcome::ServerClosed)
}
_ => {
logger.debug(&format!(
"registration {}: ApprovalNotification status={} message={}",
self.key, notif.status, notif.message
));
None
}
}
}
Some(other) => {
logger.debug(&format!(
"registration {}: ignoring inbound variant {:?}",
self.key,
std::mem::discriminant(&other)
));
None
}
None => {
logger.debug(&format!("registration {}: empty inbound message", self.key));
None
}
}
}
async fn handle_credentials_issued(
&self,
creds: CredentialsIssued,
tx: &mpsc::Sender<StreamMessage>,
last_heartbeat_response: &mut Instant,
logger: &Logger,
) -> Option<StreamOutcome> {
if creds.ott.is_empty() {
logger.error(
&format!("registration {}: CredentialsIssued without OTT", self.key),
"",
);
return None;
}
if creds.matrix_api_url.is_empty() {
logger.error(
&format!(
"registration {}: CredentialsIssued without matrix_api_url",
self.key
),
"",
);
return None;
}
let (instance_id, connector_type) = (
self.config.read().await.instance_id.clone(),
self.connector.connector_type().to_string(),
);
let mut provider =
OttProvider::new(Some(connector_type.clone()), Some(instance_id.clone()));
match provider
.register_public_key_with_ott_data(
&creds.ott,
&creds.matrix_api_url,
&creds.register_url,
&connector_type,
Some(&instance_id),
)
.await
{
Ok(response) => {
logger.debug(&format!(
"registration {}: registered public key with OTT (client_id={})",
self.key, response.client_id
));
}
Err(e) => {
logger.error(
&format!(
"registration {}: failed to complete OTT registration",
self.key
),
&e.to_string(),
);
return None;
}
}
let jwt_token = match provider.get_token().await {
Ok(t) => t,
Err(e) => {
logger.error(
&format!(
"registration {}: failed to fetch JWT after OTT exchange",
self.key
),
&e.to_string(),
);
return None;
}
};
self.config.write().await.auth_token = jwt_token.clone();
drop(provider);
let register = {
let cfg = self.config.read().await;
let token = self.session_token.read().await.clone().unwrap_or_default();
build_register_message_with_token(&cfg, self.connector.as_ref(), &token)
};
match tx.send(register).await {
Ok(()) => {
bump_heartbeat(last_heartbeat_response);
logger.info(&format!(
"registration {}: sent JWT re-registration on existing stream",
self.key
));
None
}
Err(e) => {
logger.warn(&format!(
"registration {}: in-stream re-register send failed: {e}; reconnecting",
self.key
));
Some(StreamOutcome::OutboundClosed)
}
}
}
}
async fn handle_execute(
connector: Arc<dyn BaseConnector>,
request: SdkExecuteRequest,
tx: mpsc::Sender<StreamMessage>,
metrics: Arc<Mutex<ConnectorMetrics>>,
logger: &Logger,
key: &RegistrationKey,
) -> Result<()> {
let start = Instant::now();
{
let mut m = metrics.lock().await;
m.requests_received += 1;
m.bytes_received += request.payload.len() as u64;
m.last_request_at_ms = chrono::Utc::now().timestamp_millis().max(0) as u64;
}
let response = match deserialize_payload::<serde_json::Value>(
&request.payload,
request.payload_encoding,
) {
Ok(req_data) => match connector
.execute_with_context(req_data, request.capability_id.as_deref(), &request.context)
.await
{
Ok(resp_data) => match serialize_payload(&resp_data, PayloadEncoding::Json) {
Ok(payload) => {
let duration_ms = start.elapsed().as_millis() as u64;
{
let mut m = metrics.lock().await;
m.requests_processed += 1;
m.bytes_sent += payload.len() as u64;
m.total_duration_ms += duration_ms;
}
ExecuteResponse {
request_id: request.request_id,
success: true,
payload,
payload_encoding: PayloadEncoding::Json,
error: String::new(),
duration_ms,
}
}
Err(e) => {
logger.error(
&format!("registration {key}: serialization failed"),
&e.to_string(),
);
{
let mut m = metrics.lock().await;
m.requests_failed += 1;
}
ExecuteResponse {
request_id: request.request_id,
success: false,
payload: error_response(&e.to_string()).unwrap_or_default(),
payload_encoding: PayloadEncoding::Json,
error: e.to_string(),
duration_ms: start.elapsed().as_millis() as u64,
}
}
},
Err(e) => {
logger.error(
&format!("registration {key}: execute failed"),
&e.to_string(),
);
{
let mut m = metrics.lock().await;
m.requests_failed += 1;
}
ExecuteResponse {
request_id: request.request_id,
success: false,
payload: error_response(&e.to_string()).unwrap_or_default(),
payload_encoding: PayloadEncoding::Json,
error: e.to_string(),
duration_ms: start.elapsed().as_millis() as u64,
}
}
},
Err(e) => {
logger.error(
&format!("registration {key}: deserialization failed"),
&e.to_string(),
);
{
let mut m = metrics.lock().await;
m.requests_failed += 1;
}
ExecuteResponse {
request_id: request.request_id,
success: false,
payload: error_response(&e.to_string()).unwrap_or_default(),
payload_encoding: PayloadEncoding::Json,
error: e.to_string(),
duration_ms: start.elapsed().as_millis() as u64,
}
}
};
let message = StreamMessage {
message: Some(stream_message::Message::ExecuteResponse(
proto::ExecuteResponse {
request_id: response.request_id,
success: response.success,
payload: response.payload,
payload_encoding: response.payload_encoding as i32,
error: response.error,
duration_ms: response.duration_ms as i64,
},
)),
};
tx.send(message).await.map_err(|e| {
ConnectorError::StreamError(format!("failed to send execute response: {e}"))
})?;
Ok(())
}
/// Milliseconds since the Unix epoch; 0 if the system clock reads before
/// the epoch (`duration_since` failing is treated as the zero duration).
fn now_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    since_epoch.as_millis() as u64
}
/// Build the register message and stamp the given session token into the
/// `RegisterRequest` variant (a no-op for any other variant shape).
fn build_register_message_with_token(
    config: &ConnectorConfig,
    connector: &dyn BaseConnector,
    session_token: &str,
) -> StreamMessage {
    let mut msg = build_register_message(config, connector);
    if let Some(stream_message::Message::RegisterRequest(req)) = msg.message.as_mut() {
        req.session_token = session_token.to_owned();
    }
    msg
}
/// Assemble the full `RegisterRequest` stream message from the connector's
/// advertised capabilities plus the (sanitized) config identity. The
/// session token is left empty here; callers that have one use
/// `build_register_message_with_token`.
fn build_register_message(
    config: &ConnectorConfig,
    connector: &dyn BaseConnector,
) -> StreamMessage {
    // Snapshot everything the connector advertises: encodings, behaviors,
    // registration metadata, and per-task-type schemas.
    let capabilities = ConnectorCapabilities {
        connector_type: connector.connector_type().to_string(),
        version: connector.version().to_string(),
        supported_encodings: connector
            .supported_encodings()
            .iter()
            .map(|e| *e as i32)
            .collect(),
        behaviors: connector.behaviors().iter().map(|b| *b as i32).collect(),
        metadata: crate::connector::build_registration_metadata(connector),
        task_types: connector
            .capabilities()
            .iter()
            .map(|tt| proto::TaskTypeSchema {
                task_type_id: tt.task_type_id.clone(),
                name: tt.name.clone(),
                description: tt.description.clone(),
                category: tt.category.clone(),
                icon: tt.icon.clone(),
                input_schema_json: tt.input_schema_json.clone(),
                output_schema_json: tt.output_schema_json.clone(),
            })
            .collect(),
    };
    let sanitized_instance = sanitize_identifier(&config.instance_id);
    // Merge SDK-level metadata (transport type, TLS flag) into the
    // user-provided metadata map.
    let mut metadata = config.metadata.clone();
    crate::sdk_metadata::merge_into(
        &mut metadata,
        &config.transport_type.to_string(),
        config.use_tls,
    );
    let instance_metadata = Some(InstanceMetadata {
        // Fall back to the sanitized instance id when no display name is set.
        display_name: config
            .display_name
            .clone()
            .unwrap_or_else(|| sanitized_instance.clone()),
        tags: config.tags.clone(),
        metadata,
    });
    let request = RegisterConnectorRequest {
        tenant_id: sanitize_identifier(&config.tenant_id),
        connector_type: sanitize_identifier(connector.connector_type()),
        instance_id: sanitized_instance,
        capabilities: Some(capabilities),
        jwt_token: config.auth_token.clone(),
        session_token: String::new(),
        scope: 0,
        instance_metadata,
    };
    StreamMessage {
        message: Some(stream_message::Message::RegisterRequest(request)),
    }
}
/// Successful outcome of `wait_for_register_response`: the identity the
/// server assigned plus the session token to echo on future registrations.
#[derive(Debug, Clone)]
pub(crate) struct RegisterAccepted {
    // Server-assigned connector ARN (logged for operators).
    pub connector_arn: String,
    // May be empty; callers only persist it when non-empty.
    pub session_token: String,
}
/// Wait for the server's `RegisterResponse` on a freshly opened stream.
///
/// Skips unrelated messages, and fails on: overall `deadline` elapsing
/// (`Timeout`), shutdown being signalled (`StreamError`), the stream
/// closing (`StreamError`), a transport error (`Grpc`), or an explicit
/// rejection (`RegistrationError`).
async fn wait_for_register_response(
    inbound: &mut tonic::Streaming<StreamMessage>,
    shutdown: &Arc<AtomicBool>,
    deadline: Duration,
) -> Result<RegisterAccepted> {
    let started = Instant::now();
    loop {
        // Recompute the remaining budget each pass: skipped messages and
        // shutdown polls must not extend the overall deadline.
        let remaining = deadline.saturating_sub(started.elapsed());
        if remaining.is_zero() {
            return Err(ConnectorError::Timeout(
                "register response not received within deadline".into(),
            ));
        }
        tokio::select! {
            // `biased` makes polling order deterministic: shutdown is
            // observed before the timeout, which is observed before data.
            biased;
            _ = wait_for_shutdown(shutdown) => {
                return Err(ConnectorError::StreamError(
                    "shutdown signalled while waiting for register response".into(),
                ));
            }
            _ = tokio::time::sleep(remaining) => {
                return Err(ConnectorError::Timeout(
                    "register response not received within deadline".into(),
                ));
            }
            inbound_result = inbound.next() => match inbound_result {
                None => {
                    return Err(ConnectorError::StreamError(
                        "stream closed before register response".into(),
                    ));
                }
                Some(Err(status)) => {
                    return Err(ConnectorError::Grpc(Box::new(status)));
                }
                Some(Ok(msg)) => match msg.message {
                    Some(stream_message::Message::RegisterResponse(resp)) => {
                        if !resp.success {
                            return Err(ConnectorError::RegistrationError(format!(
                                "register failed: status='{}' error='{}'",
                                resp.status, resp.error
                            )));
                        }
                        return Ok(RegisterAccepted {
                            connector_arn: resp.connector_arn,
                            session_token: resp.session_token,
                        });
                    }
                    // Anything else (heartbeats, notifications) is skipped
                    // until the register response arrives.
                    Some(_) => continue,
                    None => continue,
                },
            },
        }
    }
}
/// Resolve once the shared shutdown flag flips to `true`, polling at
/// `RECONNECT_POLL` granularity (the flag has no async notifier).
async fn wait_for_shutdown(shutdown: &Arc<AtomicBool>) {
    loop {
        if shutdown.load(Ordering::SeqCst) {
            return;
        }
        tokio::time::sleep(RECONNECT_POLL).await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multi::shared_channel::SharedChannel;
    use proto::ApprovalNotification;
    use std::pin::Pin;
    use tokio::sync::Semaphore;
    // Minimal connector whose execute returns an empty JSON object.
    struct NoopConn;
    impl BaseConnector for NoopConn {
        fn connector_type(&self) -> &str {
            "test_conn"
        }
        fn version(&self) -> &str {
            "0.0.0"
        }
        fn execute(
            &self,
            _: serde_json::Value,
            _: Option<&str>,
        ) -> Pin<Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + '_>>
        {
            Box::pin(async { Ok(serde_json::json!({})) })
        }
    }
    // Builds a runner with default options suitable for dispatch_inbound tests
    // (no live channel is ever opened).
    fn make_runner_for_dispatch_tests() -> RegistrationRunner {
        let key = RegistrationKey {
            tenant_id: "t".into(),
            connector_type: "test_conn".into(),
            instance_id: "i".into(),
        };
        let opts = MultiTransportOptions::default();
        RegistrationRunner {
            key,
            config: Arc::new(RwLock::new(crate::connector::ConnectorConfig::default())),
            connector: Arc::new(NoopConn),
            shared_channel: Arc::new(SharedChannel::new(opts.clone())),
            shutdown: Arc::new(AtomicBool::new(false)),
            metrics: Arc::new(Mutex::new(ConnectorMetrics::default())),
            opts,
            request_semaphore: Arc::new(Semaphore::new(8)),
            session_token: Arc::new(RwLock::new(None)),
        }
    }
    // Options with only the backoff-relevant knobs set.
    fn opts_for_backoff(base: u64, max: u64, jitter: u64) -> MultiTransportOptions {
        MultiTransportOptions {
            reconnect_delay_ms: base,
            max_backoff_delay_ms: max,
            reconnect_jitter_ms: jitter,
            ..MultiTransportOptions::default()
        }
    }
    // The cap must hold even when jitter alone exceeds it (sampled repeatedly
    // because jitter is random).
    #[test]
    fn backoff_never_exceeds_max_even_with_jitter() {
        let opts = opts_for_backoff(500, 1_000, 10_000);
        for attempt in 1u32..=64 {
            for _ in 0..32 {
                let d = compute_backoff(&opts, attempt);
                assert!(
                    d.as_millis() as u64 <= opts.max_backoff_delay_ms,
                    "compute_backoff(attempt={attempt}) = {}ms exceeds cap {}ms",
                    d.as_millis(),
                    opts.max_backoff_delay_ms
                );
            }
        }
    }
    // With zero jitter the curve is deterministic: base, then capped.
    #[test]
    fn backoff_zero_jitter_is_capped_exactly() {
        let opts = opts_for_backoff(500, 1_000, 0);
        assert_eq!(compute_backoff(&opts, 1).as_millis(), 500);
        assert_eq!(compute_backoff(&opts, 2).as_millis(), 1_000);
        assert_eq!(compute_backoff(&opts, 8).as_millis(), 1_000);
    }
    // bump_heartbeat must reset a stale watchdog timestamp to ~now.
    #[test]
    fn bump_heartbeat_moves_timestamp_forward() {
        let original = Instant::now() - Duration::from_secs(120);
        let mut ts = original;
        assert!(original.elapsed() >= Duration::from_secs(120));
        bump_heartbeat(&mut ts);
        assert!(
            ts > original,
            "bump_heartbeat must advance the watchdog timestamp"
        );
        assert!(
            ts.elapsed() < Duration::from_secs(5),
            "bump_heartbeat must reset to ~now"
        );
    }
    // An inbound HeartbeatRequest gets a HeartbeatResponse and keeps the loop alive.
    #[tokio::test]
    async fn dispatch_heartbeat_request_replies_with_response() {
        let runner = make_runner_for_dispatch_tests();
        let (tx, mut rx) = mpsc::channel::<StreamMessage>(8);
        let mut last_hb = Instant::now();
        let logger = Logger::new("test");
        let msg = StreamMessage {
            message: Some(stream_message::Message::HeartbeatRequest(
                HeartbeatRequest {
                    gateway_id: "gw1".into(),
                    timestamp_ms: 1234,
                },
            )),
        };
        let outcome = runner
            .dispatch_inbound(msg, &tx, &mut last_hb, &logger)
            .await;
        assert!(
            outcome.is_none(),
            "HeartbeatRequest must NOT terminate the dispatch loop"
        );
        let resp = rx
            .try_recv()
            .expect("must reply to inbound HeartbeatRequest");
        match resp.message {
            Some(stream_message::Message::HeartbeatResponse(_)) => {}
            other => panic!("expected HeartbeatResponse, got {other:?}"),
        }
    }
    // An inbound HeartbeatResponse must reset the watchdog timestamp.
    #[tokio::test]
    async fn dispatch_heartbeat_response_advances_watchdog() {
        let runner = make_runner_for_dispatch_tests();
        let (tx, _rx) = mpsc::channel::<StreamMessage>(8);
        let mut last_hb = Instant::now() - Duration::from_secs(60);
        let before = last_hb;
        let logger = Logger::new("test");
        let msg = StreamMessage {
            message: Some(stream_message::Message::HeartbeatResponse(
                HeartbeatResponse {
                    gateway_id: String::new(),
                    timestamp_ms: 0,
                    should_reconnect: false,
                },
            )),
        };
        let outcome = runner
            .dispatch_inbound(msg, &tx, &mut last_hb, &logger)
            .await;
        assert!(outcome.is_none());
        assert!(
            last_hb > before,
            "HeartbeatResponse must reset the watchdog"
        );
    }
    // Pending approval is informational only — the stream stays open.
    #[tokio::test]
    async fn dispatch_approval_pending_does_not_close_stream() {
        let runner = make_runner_for_dispatch_tests();
        let (tx, _rx) = mpsc::channel::<StreamMessage>(8);
        let mut last_hb = Instant::now();
        let logger = Logger::new("test");
        let msg = StreamMessage {
            message: Some(stream_message::Message::ApprovalNotification(
                ApprovalNotification {
                    status: proto::RegistrationStatus::Pending as i32,
                    message: "awaiting admin".into(),
                    ..Default::default()
                },
            )),
        };
        let outcome = runner
            .dispatch_inbound(msg, &tx, &mut last_hb, &logger)
            .await;
        assert!(
            outcome.is_none(),
            "Pending approval must not terminate the stream"
        );
    }
    // Approved keeps the stream open: CredentialsIssued arrives next.
    #[tokio::test]
    async fn dispatch_approval_approved_does_not_close_stream() {
        let runner = make_runner_for_dispatch_tests();
        let (tx, _rx) = mpsc::channel::<StreamMessage>(8);
        let mut last_hb = Instant::now();
        let logger = Logger::new("test");
        let msg = StreamMessage {
            message: Some(stream_message::Message::ApprovalNotification(
                ApprovalNotification {
                    status: proto::RegistrationStatus::Approved as i32,
                    ..Default::default()
                },
            )),
        };
        let outcome = runner
            .dispatch_inbound(msg, &tx, &mut last_hb, &logger)
            .await;
        assert!(
            outcome.is_none(),
            "Approved must not terminate the stream — CredentialsIssued follows"
        );
    }
    // Rejection must end the stream so the runner restarts through pending.
    #[tokio::test]
    async fn dispatch_approval_rejected_returns_server_closed() {
        let runner = make_runner_for_dispatch_tests();
        let (tx, _rx) = mpsc::channel::<StreamMessage>(8);
        let mut last_hb = Instant::now();
        let logger = Logger::new("test");
        let msg = StreamMessage {
            message: Some(stream_message::Message::ApprovalNotification(
                ApprovalNotification {
                    status: proto::RegistrationStatus::Rejected as i32,
                    message: "not allowed".into(),
                    ..Default::default()
                },
            )),
        };
        let outcome = runner
            .dispatch_inbound(msg, &tx, &mut last_hb, &logger)
            .await;
        assert!(
            matches!(outcome, Some(StreamOutcome::ServerClosed)),
            "Rejected approval must end the stream so we restart through pending"
        );
    }
    // A StreamMessage with no payload is ignored, not treated as an error.
    #[tokio::test]
    async fn dispatch_empty_or_unknown_message_is_no_op() {
        let runner = make_runner_for_dispatch_tests();
        let (tx, _rx) = mpsc::channel::<StreamMessage>(8);
        let mut last_hb = Instant::now();
        let logger = Logger::new("test");
        let none_msg = StreamMessage { message: None };
        assert!(
            runner
                .dispatch_inbound(none_msg, &tx, &mut last_hb, &logger)
                .await
                .is_none()
        );
    }
    // wait_for_shutdown must resolve once the flag flips (flipped from a task).
    #[tokio::test]
    async fn wait_for_register_response_observes_shutdown() {
        let shutdown = Arc::new(AtomicBool::new(false));
        let shutdown_for_signal = shutdown.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            shutdown_for_signal.store(true, Ordering::SeqCst);
        });
        wait_for_shutdown(&shutdown).await;
        assert!(shutdown.load(Ordering::SeqCst));
    }
    // Custom heartbeat settings from the builder must reach the runner and
    // the compile-time defaults must stay at 30s/45s.
    #[test]
    fn runner_uses_custom_heartbeat_interval_and_timeout_from_opts() {
        let opts = MultiTransportOptions::builder()
            .heartbeat_interval(Duration::from_secs(5))
            .heartbeat_timeout(Duration::from_secs(15))
            .build();
        let runner = RegistrationRunner {
            key: RegistrationKey {
                tenant_id: "t".into(),
                connector_type: "c".into(),
                instance_id: "i".into(),
            },
            config: Arc::new(RwLock::new(crate::connector::ConnectorConfig::default())),
            connector: Arc::new(NoopConn),
            shared_channel: Arc::new(SharedChannel::new(opts.clone())),
            shutdown: Arc::new(AtomicBool::new(false)),
            metrics: Arc::new(Mutex::new(ConnectorMetrics::default())),
            opts: opts.clone(),
            request_semaphore: Arc::new(Semaphore::new(8)),
            session_token: Arc::new(RwLock::new(None)),
        };
        assert_eq!(
            runner.opts.heartbeat_interval,
            Some(Duration::from_secs(5)),
            "custom heartbeat_interval must be wired through to the runner"
        );
        assert_eq!(
            runner.opts.heartbeat_timeout,
            Some(Duration::from_secs(15)),
            "custom heartbeat_timeout must be wired through to the runner"
        );
        assert_eq!(HEARTBEAT_INTERVAL, Duration::from_secs(30));
        assert_eq!(HEARTBEAT_TIMEOUT, Duration::from_secs(45));
    }
    // Extreme base/jitter values must not overflow past the cap.
    #[test]
    fn backoff_huge_attempt_with_huge_jitter_still_capped() {
        let opts = opts_for_backoff(u64::MAX / 4, 60_000, u64::MAX / 4);
        for attempt in 1u32..=128 {
            let d = compute_backoff(&opts, attempt);
            assert!(d.as_millis() as u64 <= opts.max_backoff_delay_ms);
        }
    }
    use std::collections::HashMap as StdHashMap;
    use std::sync::Mutex as StdMutex;
    // Connector that records the context map it was invoked with.
    struct CapturingConn {
        seen: Arc<StdMutex<Option<StdHashMap<String, String>>>>,
    }
    impl BaseConnector for CapturingConn {
        fn connector_type(&self) -> &str {
            "capturing"
        }
        fn version(&self) -> &str {
            "0.0.0"
        }
        fn execute(
            &self,
            _request: serde_json::Value,
            _capability_id: Option<&str>,
        ) -> Pin<Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + '_>>
        {
            Box::pin(async {
                unreachable!("SDK must dispatch through execute_with_context, not bare execute")
            })
        }
        fn execute_with_context<'a>(
            &'a self,
            _request: serde_json::Value,
            _capability_id: Option<&'a str>,
            context: &'a StdHashMap<String, String>,
        ) -> Pin<Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + 'a>>
        {
            let captured = context.clone();
            let slot = self.seen.clone();
            Box::pin(async move {
                *slot.lock().unwrap() = Some(captured);
                Ok(serde_json::json!({ "ok": true }))
            })
        }
    }
    // Connector that only overrides bare execute, to exercise the trait's
    // default execute_with_context delegation.
    struct LegacyConn {
        called: Arc<AtomicBool>,
    }
    impl BaseConnector for LegacyConn {
        fn connector_type(&self) -> &str {
            "legacy"
        }
        fn version(&self) -> &str {
            "0.0.0"
        }
        fn execute(
            &self,
            _request: serde_json::Value,
            _capability_id: Option<&str>,
        ) -> Pin<Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + '_>>
        {
            let flag = self.called.clone();
            Box::pin(async move {
                flag.store(true, Ordering::SeqCst);
                Ok(serde_json::json!({ "legacy": true }))
            })
        }
    }
    fn key_for_context_tests() -> RegistrationKey {
        RegistrationKey {
            tenant_id: "t".into(),
            connector_type: "ctx".into(),
            instance_id: "i".into(),
        }
    }
    // The request's context map must arrive at the connector unchanged, and
    // a successful ExecuteResponse must come back on the channel.
    #[tokio::test]
    async fn handle_execute_forwards_request_context_to_connector() {
        let seen = Arc::new(StdMutex::new(None));
        let connector: Arc<dyn BaseConnector> = Arc::new(CapturingConn { seen: seen.clone() });
        let metrics = Arc::new(Mutex::new(ConnectorMetrics::default()));
        let logger = Logger::new("test/ctx");
        let key = key_for_context_tests();
        let mut context = StdHashMap::new();
        context.insert("tenant_id".to_string(), "tenant-acme".to_string());
        context.insert("user_id".to_string(), "user-42".to_string());
        context.insert("strike48.attrs.region".to_string(), "us-east-1".to_string());
        let payload = serialize_payload(&serde_json::json!({}), PayloadEncoding::Json)
            .expect("serialize empty payload");
        let request = SdkExecuteRequest {
            request_id: "req-ctx-1".into(),
            payload,
            payload_encoding: PayloadEncoding::Json,
            context: context.clone(),
            capability_id: None,
        };
        let (tx, mut rx) = mpsc::channel::<StreamMessage>(8);
        handle_execute(connector, request, tx, metrics, &logger, &key)
            .await
            .expect("handle_execute should succeed");
        let captured = seen
            .lock()
            .unwrap()
            .clone()
            .expect("execute_with_context must have been invoked");
        assert_eq!(
            captured, context,
            "context map must round-trip from request → connector"
        );
        let resp = rx.try_recv().expect("an ExecuteResponse must be sent back");
        match resp.message {
            Some(stream_message::Message::ExecuteResponse(r)) => {
                assert!(r.success, "successful execute must produce success=true");
                assert_eq!(r.request_id, "req-ctx-1");
            }
            other => panic!("expected ExecuteResponse, got {other:?}"),
        }
    }
    // A connector that only implements execute must still work via the
    // default execute_with_context delegation.
    #[tokio::test]
    async fn handle_execute_keeps_working_for_legacy_execute_only_connector() {
        let called = Arc::new(AtomicBool::new(false));
        let connector: Arc<dyn BaseConnector> = Arc::new(LegacyConn {
            called: called.clone(),
        });
        let metrics = Arc::new(Mutex::new(ConnectorMetrics::default()));
        let logger = Logger::new("test/ctx");
        let key = key_for_context_tests();
        let payload = serialize_payload(&serde_json::json!({}), PayloadEncoding::Json)
            .expect("serialize empty payload");
        let request = SdkExecuteRequest {
            request_id: "req-legacy-1".into(),
            payload,
            payload_encoding: PayloadEncoding::Json,
            context: {
                let mut m = StdHashMap::new();
                m.insert("tenant_id".to_string(), "ignored".to_string());
                m
            },
            capability_id: None,
        };
        let (tx, mut rx) = mpsc::channel::<StreamMessage>(8);
        handle_execute(connector, request, tx, metrics, &logger, &key)
            .await
            .expect("handle_execute should succeed");
        assert!(
            called.load(Ordering::SeqCst),
            "default execute_with_context must delegate to execute"
        );
        let resp = rx.try_recv().expect("must produce an ExecuteResponse");
        match resp.message {
            Some(stream_message::Message::ExecuteResponse(r)) => {
                assert!(r.success);
                assert_eq!(r.request_id, "req-legacy-1");
            }
            other => panic!("expected ExecuteResponse, got {other:?}"),
        }
    }
}