use super::{PendingConnection, StdbConnection, StdbConnectionConfig};
use crate::{
alias::{ReadStdbConnectErrorMessage, ReadStdbConnectedMessage, ReadStdbDisconnectedMessage},
set::StdbSet,
};
use bevy_app::{App, Plugin, PreUpdate};
use bevy_ecs::prelude::{
Commands, IntoScheduleConfigs, Res, ResMut, Resource, not, resource_exists,
};
use bevy_tasks::IoTaskPool;
use bevy_time::{Time, Timer, TimerMode};
use spacetimedb_sdk::{
__codegen::{DbConnection, SpacetimeModule},
DbContext,
};
use std::{marker::PhantomData, ops::Deref, time::Duration};
/// Tunable parameters for the automatic reconnect back-off.
#[derive(Clone, Debug)]
pub struct StdbReconnectOptions {
    /// Delay used for the first reconnect timer (and whenever the stored
    /// delay is still zero).
    pub initial_delay: Duration,
    /// Maximum number of reconnect attempts; `0` means no limit is enforced.
    pub max_attempts: u32,
    /// Multiplier applied to the delay after each attempt. Values below `1.0`
    /// are treated as `1.0` when the delay is grown.
    pub backoff_factor: f32,
    /// Upper bound the grown delay is clamped to.
    pub max_delay: Duration,
}

impl Default for StdbReconnectOptions {
    /// One second initial delay, unlimited attempts, 1.5x growth, 15 s cap.
    fn default() -> Self {
        let initial_delay = Duration::from_secs(1);
        let max_delay = Duration::from_secs(15);
        Self {
            initial_delay,
            max_attempts: 0,
            backoff_factor: 1.5,
            max_delay,
        }
    }
}
/// Resource wrapper around the reconnect options the plugin was built with.
#[derive(Resource, Clone)]
struct ReconnectConfig(pub StdbReconnectOptions);

impl Deref for ReconnectConfig {
    type Target = StdbReconnectOptions;

    /// Exposes the wrapped options so systems can read fields directly.
    fn deref(&self) -> &StdbReconnectOptions {
        let Self(options) = self;
        options
    }
}
/// Mutable back-off state shared by the reconnect systems in this file.
#[derive(Resource, Default)]
struct ReconnectBackoff {
/// Number of times the reconnect timer has fired since the last reset;
/// set back to 0 by `on_connect`.
attempts: u32,
/// Duration the next armed timer will use. Starts at zero (derived
/// `Default`), in which case `arm_reconnect_timer` substitutes the
/// configured initial delay; grown in `tick_reconnect_timer`.
current_delay: Duration,
/// The armed one-shot timer, or `None` when no reconnect is scheduled.
timer: Option<Timer>,
}
/// Plugin that registers the reconnect resources and systems for a given
/// connection type `C` and module type `M`.
pub(crate) struct ReconnectPlugin<
    C: DbConnection<Module = M> + DbContext + Send + Sync,
    M: SpacetimeModule<DbConnection = C>,
> {
    /// Options copied into the `ReconnectConfig` resource when the plugin is built.
    reconnect_options: StdbReconnectOptions,
    /// Ties the otherwise unused type parameters to the struct.
    _marker: PhantomData<(C, M)>,
}

impl<
    C: DbConnection<Module = M> + DbContext + Send + Sync,
    M: SpacetimeModule<DbConnection = C>,
> ReconnectPlugin<C, M>
{
    /// Creates the plugin with the given reconnect options.
    pub(crate) fn new(reconnect_options: StdbReconnectOptions) -> Self {
        let _marker = PhantomData;
        Self {
            reconnect_options,
            _marker,
        }
    }
}
impl<C, M> Plugin for ReconnectPlugin<C, M>
where
    C: DbConnection<Module = M> + DbContext + Send + Sync + 'static,
    M: SpacetimeModule<DbConnection = C> + 'static,
{
    /// Inserts the reconnect resources and schedules the three reconnect
    /// systems in `PreUpdate`, all inside `StdbSet::Connection`.
    fn build(&self, app: &mut App) {
        let options = self.reconnect_options.clone();
        app.insert_resource(ReconnectConfig(options))
            .init_resource::<ReconnectBackoff>()
            .add_systems(
                PreUpdate,
                (on_connect, arm_reconnect_timer).in_set(StdbSet::Connection),
            )
            // The timer only advances while no `StdbConnection<C>` resource exists.
            .add_systems(
                PreUpdate,
                tick_reconnect_timer::<C, M>
                    .run_if(not(resource_exists::<StdbConnection<C>>))
                    .in_set(StdbSet::Connection),
            );
    }
}
/// Resets the back-off state whenever at least one "connected" message was
/// received since this system last ran.
///
/// The attempt counter returns to 0, the delay returns to the configured
/// initial delay, and any armed reconnect timer is discarded.
fn on_connect(
    mut msgs: ReadStdbConnectedMessage,
    mut backoff: ResMut<ReconnectBackoff>,
    config: Res<ReconnectConfig>,
) {
    // Drain the reader completely instead of peeking at the first message:
    // anything left unread would be observed again on a later frame and could
    // wipe a timer that was armed in between.
    // NOTE(review): assumes the alias is a Bevy message reader whose cursor
    // only advances past items that were actually yielded — confirm.
    if msgs.read().count() == 0 {
        return;
    }

    backoff.attempts = 0;
    backoff.current_delay = config.initial_delay;
    backoff.timer = None;
}
/// Arms a one-shot reconnect timer after an unexpected disconnect (a
/// disconnect message carrying an error) or a connection error.
///
/// The timer uses the current back-off delay, falling back to the configured
/// initial delay while the stored delay is still zero.
fn arm_reconnect_timer(
    mut disconnect_msgs: ReadStdbDisconnectedMessage,
    mut error_msgs: ReadStdbConnectErrorMessage,
    mut backoff: ResMut<ReconnectBackoff>,
    config: Res<ReconnectConfig>,
) {
    // Both readers are drained in full (no short-circuiting): messages left
    // unread would be seen again on a later frame and restart the timer.
    // NOTE(review): assumes the aliases are Bevy message readers whose cursor
    // only advances past items that were actually yielded — confirm.
    let unexpected_disconnect = disconnect_msgs
        .read()
        .filter(|msg| msg.err.is_some())
        .count()
        > 0;
    let connect_error = error_msgs.read().count() > 0;

    if !(unexpected_disconnect || connect_error) {
        return;
    }

    // A zero delay means the back-off state has never been initialised.
    if backoff.current_delay.is_zero() {
        backoff.current_delay = config.initial_delay;
    }
    backoff.timer = Some(Timer::new(backoff.current_delay, TimerMode::Once));
}
fn tick_reconnect_timer<C, M>(
time: Res<Time>,
mut backoff: ResMut<ReconnectBackoff>,
config: Res<ReconnectConfig>,
conn_config: Res<StdbConnectionConfig<C, M>>,
pending: Option<Res<PendingConnection<C>>>,
mut commands: Commands,
) where
C: DbConnection<Module = M> + DbContext + Send + Sync + 'static,
M: SpacetimeModule<DbConnection = C> + 'static,
{
if backoff.timer.is_none() || pending.is_some() {
return;
}
let Some(timer) = backoff.timer.as_mut() else {
return;
};
timer.tick(time.delta());
if !timer.just_finished() {
return;
}
backoff.timer = None;
backoff.attempts += 1;
if config.max_attempts > 0 && backoff.attempts > config.max_attempts {
return;
}
let next_delay = backoff
.current_delay
.mul_f32(config.backoff_factor.max(1.0));
backoff.current_delay = next_delay.min(config.max_delay);
let conn_config = conn_config.clone();
let task = IoTaskPool::get().spawn(async move { conn_config.build_connection().await });
commands.insert_resource(PendingConnection::<C>(task));
}