use std::{
ops::ControlFlow,
sync::{Arc, Mutex},
time::{Duration, Instant},
};
use bestool_postgres::pool::PgConnection;
use rustyline::ExternalPrinter;
use tracing::{debug, error, warn};
use super::{ReplState, state::ReplContext, transaction::TransactionState};
use crate::ots;
/// How long write mode may go without recorded activity before
/// `watch_write_mode_idle_timeout` considers reverting the session to read-only (10 minutes).
pub const WRITE_MODE_IDLE_TIMEOUT: Duration = Duration::from_secs(10 * 60);
/// How long `watch_write_mode_idle_timeout` sleeps between idle checks (30 seconds).
pub const WRITE_MODE_TIMEOUT_CHECK_INTERVAL: Duration = Duration::from_secs(30);
/// Stamps the shared state with "write mode was used just now".
///
/// Does nothing while write mode is off, so the activity timestamp is only ever
/// refreshed for a session that is currently in write mode.
///
/// # Panics
///
/// Panics if the `repl_state` mutex is poisoned.
pub fn mark_write_mode_active(repl_state: &Arc<Mutex<ReplState>>) {
    let mut guard = repl_state.lock().unwrap();
    if !guard.write_mode {
        return;
    }
    guard.write_mode_active_at = Some(Instant::now());
}
/// Decides whether an idle write-mode session should be reverted to read-only.
///
/// Returns `true` only when all of the following hold:
///
/// * `in_write_mode` is set,
/// * an activity timestamp (`active_at`) has been recorded,
/// * at least `idle_threshold` has elapsed between `active_at` and `now`, and
/// * `tx_state` is `TransactionState::Idle` or `TransactionState::None`.
///
/// `now` is passed in rather than read from the clock so the decision can be
/// exercised with synthetic instants.
fn should_time_out_write_mode(
    in_write_mode: bool,
    active_at: Option<Instant>,
    now: Instant,
    tx_state: TransactionState,
    idle_threshold: Duration,
) -> bool {
    // Without write mode or a recorded activity time there is nothing to time out.
    let idle_for = match active_at {
        Some(last_active) if in_write_mode => now.duration_since(last_active),
        _ => return false,
    };

    idle_for >= idle_threshold
        && matches!(tx_state, TransactionState::Idle | TransactionState::None)
}
/// Toggles the session between read-only and write mode.
///
/// * **Write mode currently on:** it is switched off, but only when
///   `TransactionState::check` reports `Idle` or `None`. Otherwise the user is told to
///   `COMMIT` or `ROLLBACK` first and nothing changes.
/// * **Write mode currently off:** it is switched on. `ots` is used when it is non-empty
///   after trimming; otherwise the value is requested via `ots::prompt_for_ots`.
///
/// The shared `ReplState` is modified only after the SQL that flips the session has
/// succeeded, and only its write-mode fields (`write_mode`, `ots`,
/// `write_mode_active_at`) are written. The state is deliberately not snapshotted and
/// written back as a whole: the prompt and the database round-trips take time, and a
/// whole-struct write-back would discard any change made to the shared state meanwhile.
///
/// Always returns `ControlFlow::Continue(())`; failures are logged rather than propagated.
///
/// # Panics
///
/// Panics if the `repl_state` mutex is poisoned.
pub async fn handle_write_mode_toggle(
    ctx: &mut ReplContext<'_>,
    ots: Option<String>,
) -> ControlFlow<()> {
    // Copy the flag out so the mutex guard is released before any `.await`.
    let in_write_mode = ctx.repl_state.lock().unwrap().write_mode;

    if in_write_mode {
        let tx_state = TransactionState::check(ctx.monitor_client, ctx.backend_pid).await;
        if !matches!(tx_state, TransactionState::Idle | TransactionState::None) {
            eprintln!(
                "Cannot disable write mode while in a transaction. COMMIT or ROLLBACK first."
            );
            return ControlFlow::Continue(());
        }

        match ctx
            .client
            .batch_execute("ROLLBACK; SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY")
            .await
        {
            Ok(_) => {
                debug!("Write mode disabled");
                eprintln!("SESSION IS NOW READ ONLY");
                // Touch only the write-mode fields; everything else stays as it is now.
                let mut state = ctx.repl_state.lock().unwrap();
                state.write_mode = false;
                state.ots = None;
                state.write_mode_active_at = None;
            }
            Err(e) => {
                error!("Failed to disable write mode: {e}");
            }
        }
        return ControlFlow::Continue(());
    }

    // A blank or whitespace-only argument counts as "not supplied" and triggers the prompt.
    let supplied_ots = ots.map(|s| s.trim().to_string()).filter(|s| !s.is_empty());
    let new_ots = match supplied_ots {
        Some(ots) => Ok(ots),
        None => ots::prompt_for_ots(ctx.rl.history()),
    };

    match new_ots {
        Ok(new_ots) => {
            match ctx
                .client
                .batch_execute(
                    "SET SESSION CHARACTERISTICS AS TRANSACTION READ WRITE; COMMIT; BEGIN",
                )
                .await
            {
                Ok(_) => {
                    debug!("Write mode enabled");
                    eprintln!("AUTOCOMMIT IS OFF -- REMEMBER TO `COMMIT;` YOUR WRITES");
                    let mut state = ctx.repl_state.lock().unwrap();
                    state.write_mode = true;
                    state.ots = Some(new_ots);
                    // Start the idle clock once write mode is actually active.
                    state.write_mode_active_at = Some(Instant::now());
                }
                Err(e) => {
                    error!("Failed to enable write mode: {e}");
                }
            }
        }
        Err(e) => {
            error!("Failed to enable write mode: {e}");
        }
    }

    ControlFlow::Continue(())
}
/// Background loop that reverts an idle write-mode session to read-only.
///
/// Every `WRITE_MODE_TIMEOUT_CHECK_INTERVAL` the loop:
///
/// 1. snapshots the write-mode flag and last-activity instant from `repl_state`;
/// 2. skips the round unless write mode is on and has been idle for at least
///    `WRITE_MODE_IDLE_TIMEOUT`;
/// 3. queries the transaction state through `monitor_client` and lets
///    `should_time_out_write_mode` make the final decision;
/// 4. issues `ROLLBACK` and switches the session to read-only via `client`;
/// 5. on success, clears the write-mode fields in `repl_state` and notifies the user
///    through `printer`.
///
/// The loop has no exit condition, so the future never completes on its own.
///
/// # Panics
///
/// Panics if the `repl_state` mutex is poisoned.
pub async fn watch_write_mode_idle_timeout<P: ExternalPrinter + Send + 'static>(
    client: Arc<PgConnection>,
    monitor_client: Arc<PgConnection>,
    backend_pid: i32,
    repl_state: Arc<Mutex<ReplState>>,
    mut printer: P,
) {
    loop {
        tokio::time::sleep(WRITE_MODE_TIMEOUT_CHECK_INTERVAL).await;

        // Snapshot under the lock; the guard is released before any `.await`.
        // Yields a value only when write mode is on and an activity time exists.
        let last_active = {
            let guard = repl_state.lock().unwrap();
            if guard.write_mode {
                guard.write_mode_active_at
            } else {
                None
            }
        };
        let Some(last_active) = last_active else {
            continue;
        };

        // Local check first, so the transaction state is only queried once the
        // idle threshold has been reached.
        if last_active.elapsed() < WRITE_MODE_IDLE_TIMEOUT {
            continue;
        }

        let tx_state = TransactionState::check(&monitor_client, backend_pid).await;
        // Write mode was on in the snapshot above, hence the literal `true`.
        let timed_out = should_time_out_write_mode(
            true,
            Some(last_active),
            Instant::now(),
            tx_state,
            WRITE_MODE_IDLE_TIMEOUT,
        );
        if !timed_out {
            continue;
        }

        let reverted = client
            .batch_execute("ROLLBACK; SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY")
            .await;
        if let Err(e) = reverted {
            warn!("failed to revert write mode on timeout: {e}");
            continue;
        }

        // Write mode may have been switched off while the statements ran; in that
        // case leave the state alone and stay silent.
        let was_still_on = {
            let mut guard = repl_state.lock().unwrap();
            let was_on = guard.write_mode;
            if was_on {
                guard.write_mode = false;
                guard.ots = None;
                guard.write_mode_active_at = None;
            }
            was_on
        };
        if !was_still_on {
            continue;
        }

        let minutes = WRITE_MODE_IDLE_TIMEOUT.as_secs() / 60;
        let msg = format!(
            "\nWrite mode idle for {minutes} minutes — session reverted to read-only."
        );
        if let Err(e) = printer.print(msg) {
            warn!("failed to print write-mode timeout notice: {e}");
        }
        debug!("write mode timed out due to inactivity");
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Idle threshold shared by every timeout scenario below.
    const THRESHOLD: Duration = Duration::from_secs(600);

    /// Returns a reference instant together with one `secs` seconds after it.
    fn instants_apart(secs: u64) -> (Instant, Instant) {
        let start = Instant::now();
        (start, start + Duration::from_secs(secs))
    }

    #[test]
    fn does_not_time_out_when_write_mode_off() {
        let (active_at, now) = instants_apart(3600);
        let timed_out = should_time_out_write_mode(
            false,
            Some(active_at),
            now,
            TransactionState::None,
            THRESHOLD,
        );
        assert!(!timed_out);
    }

    #[test]
    fn does_not_time_out_when_no_activity_recorded() {
        let (_, now) = instants_apart(3600);
        let timed_out =
            should_time_out_write_mode(true, None, now, TransactionState::None, THRESHOLD);
        assert!(!timed_out);
    }

    #[test]
    fn does_not_time_out_before_threshold() {
        let (active_at, now) = instants_apart(60);
        let timed_out = should_time_out_write_mode(
            true,
            Some(active_at),
            now,
            TransactionState::None,
            THRESHOLD,
        );
        assert!(!timed_out);
    }

    #[test]
    fn does_not_time_out_with_uncommitted_writes() {
        let (active_at, now) = instants_apart(3600);

        let while_active = should_time_out_write_mode(
            true,
            Some(active_at),
            now,
            TransactionState::Active,
            THRESHOLD,
        );
        assert!(!while_active);

        let while_errored = should_time_out_write_mode(
            true,
            Some(active_at),
            now,
            TransactionState::Error,
            THRESHOLD,
        );
        assert!(!while_errored);
    }

    #[test]
    fn times_out_when_idle_and_past_threshold() {
        let (active_at, now) = instants_apart(601);

        let with_idle_tx = should_time_out_write_mode(
            true,
            Some(active_at),
            now,
            TransactionState::Idle,
            THRESHOLD,
        );
        assert!(with_idle_tx);

        let with_no_tx = should_time_out_write_mode(
            true,
            Some(active_at),
            now,
            TransactionState::None,
            THRESHOLD,
        );
        assert!(with_no_tx);
    }

    #[test]
    fn mark_write_mode_active_only_runs_when_in_write_mode() {
        let shared = Arc::new(Mutex::new(ReplState::new()));

        // Write mode is off: marking must leave the timestamp unset.
        mark_write_mode_active(&shared);
        assert!(shared.lock().unwrap().write_mode_active_at.is_none());

        // Write mode is on: marking records a timestamp.
        shared.lock().unwrap().write_mode = true;
        mark_write_mode_active(&shared);
        assert!(shared.lock().unwrap().write_mode_active_at.is_some());
    }
}