mod config;
pub use config::WithdrawalWorkerConfig;
use rand::seq::SliceRandom;
use std::sync::Arc;
use std::time::Duration;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use zeroize::Zeroize;
use crate::errors::AppError;
use crate::repositories::{
DepositRepository, DepositSessionEntity, DepositStatus, WithdrawalHistoryEntry,
WithdrawalHistoryRepository,
};
use crate::services::{
decrypt_base64_payload, AdminNotification, NoteEncryptionService, NotificationService,
NotificationSeverity, PrivacySidecarClient, SettingsService,
};
/// Fallback values used whenever a withdrawal setting is absent or cannot be read.
mod defaults {
    /// Number of lamports in one SOL; used to express the SOL-denominated thresholds below.
    const LAMPORTS_PER_SOL: u64 = 1_000_000_000;

    /// Seconds between polling runs (one hour).
    pub const POLL_INTERVAL_SECS: u64 = 60 * 60;
    /// Maximum number of sessions claimed per batch.
    pub const BATCH_SIZE: u32 = 10;
    /// Timeout for a single sidecar withdrawal call, in seconds (two minutes).
    pub const TIMEOUT_SECS: u64 = 2 * 60;
    /// Attempt count at which a failing session is marked failed instead of retried.
    pub const MAX_RETRIES: u32 = 3;
    /// Percentage of ready sessions processed in each batch.
    pub const WITHDRAWAL_PERCENTAGE: u8 = 100;
    /// Number of intentional partial withdrawals per batch (zero disables them).
    pub const PARTIAL_COUNT: u8 = 0;
    /// Minimum remaining balance for a session to receive a partial withdrawal (0.5 SOL).
    pub const PARTIAL_MIN_LAMPORTS: u64 = LAMPORTS_PER_SOL / 2;
    /// Minimum remaining balance for a session to be withdrawn at all (1 SOL).
    pub const MIN_LAMPORTS: u64 = LAMPORTS_PER_SOL;
}
/// Narrows a `partial_withdrawal_count` setting to `u8`.
///
/// Values that do not fit saturate at `u8::MAX` (255) instead of wrapping.
fn clamp_partial_withdrawal_count(value: u32) -> u8 {
    u8::try_from(value).unwrap_or(u8::MAX)
}
/// Background worker that periodically claims deposit sessions ready for withdrawal and
/// withdraws their funds through the privacy sidecar.
///
/// Tunables (poll interval, batch size, timeout, retry limit, partial-withdrawal knobs) are
/// read from the settings service each time they are needed, with hard-coded fallbacks.
#[derive(Clone)]
pub struct WithdrawalWorker {
    // Source of claimable sessions and sink for status / withdrawal bookkeeping.
    deposit_repo: Arc<dyn DepositRepository>,
    // Audit trail: one entry is written per successful withdrawal.
    withdrawal_history_repo: Arc<dyn WithdrawalHistoryRepository>,
    // Client that performs the actual withdrawal call.
    sidecar: Arc<PrivacySidecarClient>,
    // Decrypts each session's stored (encrypted) private key.
    note_encryption: Arc<NoteEncryptionService>,
    // Admin alerts, e.g. for unexpected partial withdrawals.
    notification_service: Arc<dyn NotificationService>,
    // Runtime-tunable settings store.
    settings_service: Arc<SettingsService>,
    // Static configuration; in this file only `company_currency` is read from it.
    config: WithdrawalWorkerConfig,
}
impl WithdrawalWorker {
    /// Total lamports a withdrawal consumed: the amount the sidecar reports as withdrawn
    /// plus its fee.
    ///
    /// # Errors
    ///
    /// Returns [`AppError::Internal`] if either value is negative or if their sum does not
    /// fit in an `i64`.
    fn compute_spent_lamports(amount_lamports: i64, fee_lamports: i64) -> Result<i64, AppError> {
        if amount_lamports < 0 {
            return Err(AppError::Internal(anyhow::anyhow!(
                "Negative withdrawal amount from sidecar"
            )));
        }
        if fee_lamports < 0 {
            return Err(AppError::Internal(anyhow::anyhow!(
                "Negative withdrawal fee from sidecar"
            )));
        }
        let spent = (amount_lamports as i128)
            .checked_add(fee_lamports as i128)
            .ok_or_else(|| AppError::Internal(anyhow::anyhow!("Withdrawal amount overflow")))?;
        i64::try_from(spent)
            .map_err(|_| AppError::Internal(anyhow::anyhow!("Withdrawal amount overflow")))
    }

    /// Converts a lamport threshold read from settings (`u64`) into the signed domain used by
    /// session balances. It saturates at `i64::MAX` rather than wrapping to a negative value,
    /// which would silently disable the threshold.
    fn lamports_setting_to_i64(value: u64) -> i64 {
        i64::try_from(value).unwrap_or(i64::MAX)
    }

    /// Amount for an intentional partial withdrawal: `pct`% of `remaining`, raised to at least
    /// `min_lamports` and capped at `remaining`.
    ///
    /// The multiplication is done in `i128` so large balances cannot overflow. Division
    /// truncates toward zero, the same as the previous `i64` arithmetic.
    fn partial_withdrawal_amount(remaining: i64, pct: u8, min_lamports: i64) -> i64 {
        let scaled = i128::from(remaining) * i128::from(pct) / 100;
        // For 0 <= pct <= 100 the result lies between 0 and `remaining`, so it fits in i64.
        let scaled = i64::try_from(scaled).unwrap_or(remaining);
        scaled.max(min_lamports).min(remaining)
    }

    /// Share (0..=100) of the balance that was left before this withdrawal that `withdrawn`
    /// represents. A non-positive `remaining_before` reports 100.
    ///
    /// `withdrawn` includes the fee, so the raw ratio can exceed 100. The result is clamped
    /// to 0..=100 instead of letting an `as i16` cast wrap.
    fn withdrawal_percentage_of(withdrawn: i64, remaining_before: i64) -> i16 {
        if remaining_before <= 0 {
            return 100;
        }
        let pct = i128::from(withdrawn) * 100 / i128::from(remaining_before);
        pct.clamp(0, 100) as i16
    }

    /// Converts the poll interval setting to a timer period of at least one second.
    /// `tokio::time::interval` panics on a zero period, so a `0` setting must not reach it.
    fn poll_period(secs: u64) -> Duration {
        Duration::from_secs(secs.max(1))
    }

    /// Creates an interval whose first tick is one `period` from now.
    ///
    /// A plain `tokio::time::interval` completes its first tick immediately. Recreating the
    /// timer with one would run another batch right after the batch that noticed the setting
    /// change. If `now + period` is not representable, the first tick is immediate instead.
    fn delayed_interval(period: Duration) -> tokio::time::Interval {
        let now = tokio::time::Instant::now();
        let start = now.checked_add(period).unwrap_or(now);
        tokio::time::interval_at(start, period)
    }

    /// Builds a worker from its collaborators.
    pub fn new(
        deposit_repo: Arc<dyn DepositRepository>,
        withdrawal_history_repo: Arc<dyn WithdrawalHistoryRepository>,
        sidecar: Arc<PrivacySidecarClient>,
        note_encryption: Arc<NoteEncryptionService>,
        notification_service: Arc<dyn NotificationService>,
        settings_service: Arc<SettingsService>,
        config: WithdrawalWorkerConfig,
    ) -> Self {
        Self {
            deposit_repo,
            withdrawal_history_repo,
            sidecar,
            note_encryption,
            notification_service,
            settings_service,
            config,
        }
    }

    /// Setting `withdrawal_poll_interval_secs`, or the default when missing or unreadable.
    async fn get_poll_interval(&self) -> u64 {
        self.settings_service
            .get_u64("withdrawal_poll_interval_secs")
            .await
            .ok()
            .flatten()
            .unwrap_or(defaults::POLL_INTERVAL_SECS)
    }

    /// Setting `withdrawal_batch_size`, or the default when missing or unreadable.
    async fn get_batch_size(&self) -> u32 {
        self.settings_service
            .get_u32("withdrawal_batch_size")
            .await
            .ok()
            .flatten()
            .unwrap_or(defaults::BATCH_SIZE)
    }

    /// Setting `withdrawal_timeout_secs`, or the default when missing or unreadable.
    async fn get_timeout(&self) -> u64 {
        self.settings_service
            .get_u64("withdrawal_timeout_secs")
            .await
            .ok()
            .flatten()
            .unwrap_or(defaults::TIMEOUT_SECS)
    }

    /// Setting `withdrawal_max_retries`, or the default when missing or unreadable.
    async fn get_max_retries(&self) -> u32 {
        self.settings_service
            .get_u32("withdrawal_max_retries")
            .await
            .ok()
            .flatten()
            .unwrap_or(defaults::MAX_RETRIES)
    }

    /// Setting `withdrawal_percentage` clamped to 1..=100, or the default.
    async fn get_withdrawal_percentage(&self) -> u8 {
        self.settings_service
            .get_u32("withdrawal_percentage")
            .await
            .ok()
            .flatten()
            .map(|v| v.clamp(1, 100) as u8)
            .unwrap_or(defaults::WITHDRAWAL_PERCENTAGE)
    }

    /// Setting `partial_withdrawal_count` saturated to `u8`, or the default.
    async fn get_partial_count(&self) -> u8 {
        self.settings_service
            .get_u32("partial_withdrawal_count")
            .await
            .ok()
            .flatten()
            .map(clamp_partial_withdrawal_count)
            .unwrap_or(defaults::PARTIAL_COUNT)
    }

    /// Setting `partial_withdrawal_min_lamports`, or the default.
    async fn get_partial_min_lamports(&self) -> u64 {
        self.settings_service
            .get_u64("partial_withdrawal_min_lamports")
            .await
            .ok()
            .flatten()
            .unwrap_or(defaults::PARTIAL_MIN_LAMPORTS)
    }

    /// Setting `withdrawal_min_lamports`, or the default.
    async fn get_min_lamports(&self) -> u64 {
        self.settings_service
            .get_u64("withdrawal_min_lamports")
            .await
            .ok()
            .flatten()
            .unwrap_or(defaults::MIN_LAMPORTS)
    }

    /// Spawns the polling loop and returns its join handle.
    ///
    /// One batch runs right after startup; later batches run every poll interval. The
    /// interval is re-read on every tick, and the timer is rebuilt when it changes. The loop
    /// exits when `cancel_token` is cancelled.
    pub fn start(self, cancel_token: CancellationToken) -> JoinHandle<()> {
        tokio::spawn(async move {
            let mut current_poll_interval = self.get_poll_interval().await;
            info!(
                poll_interval = current_poll_interval,
                "Withdrawal worker started (settings from DB)"
            );
            // The first tick of a fresh `interval` completes immediately, so a batch runs at
            // startup.
            let mut interval = tokio::time::interval(Self::poll_period(current_poll_interval));
            loop {
                tokio::select! {
                    _ = cancel_token.cancelled() => {
                        info!("Withdrawal worker shutting down gracefully");
                        break;
                    }
                    _ = interval.tick() => {
                        let new_poll_interval = self.get_poll_interval().await;
                        if new_poll_interval != current_poll_interval {
                            info!(
                                old_interval = current_poll_interval,
                                new_interval = new_poll_interval,
                                "Poll interval changed, updating timer"
                            );
                            interval =
                                Self::delayed_interval(Self::poll_period(new_poll_interval));
                            current_poll_interval = new_poll_interval;
                        }
                        if let Err(e) = self.process_batch().await {
                            error!(error = %e, "Failed to process withdrawal batch");
                        }
                    }
                }
            }
            info!("Withdrawal worker stopped");
        })
    }

    /// Claims one batch of ready sessions and withdraws a random subset of them.
    ///
    /// Processing steps:
    /// 1. Sessions whose remaining balance is below `withdrawal_min_lamports` are released
    ///    back to `Completed`.
    /// 2. The remaining sessions are shuffled and only `withdrawal_percentage`% of them
    ///    (rounded up, at least one) are kept. The rest are released back to `Completed`.
    /// 3. Up to `partial_withdrawal_count` eligible sessions get a random 30–70% partial
    ///    withdrawal.
    ///
    /// Failures releasing or withdrawing individual sessions are logged and do not abort the
    /// batch.
    ///
    /// # Errors
    ///
    /// Returns the repository error if claiming the batch fails.
    async fn process_batch(&self) -> Result<(), AppError> {
        let batch_size = self.get_batch_size().await;
        let withdrawal_percentage = self.get_withdrawal_percentage().await;
        let partial_count = usize::from(self.get_partial_count().await);
        let partial_min_lamports =
            Self::lamports_setting_to_i64(self.get_partial_min_lamports().await);
        let min_lamports = Self::lamports_setting_to_i64(self.get_min_lamports().await);
        let now = chrono::Utc::now();
        let mut sessions = self
            .deposit_repo
            .claim_ready_for_withdrawal(now, batch_size)
            .await?;
        if sessions.is_empty() {
            return Ok(());
        }

        let mut skipped_sessions = Vec::new();
        sessions.retain(|s| {
            if s.remaining_lamports() >= min_lamports {
                true
            } else {
                skipped_sessions.push(s.id);
                false
            }
        });
        for id in &skipped_sessions {
            if let Err(e) = self
                .deposit_repo
                .update_status(*id, DepositStatus::Completed, None)
                .await
            {
                warn!(session_id = %id, error = %e, "Failed to reset skipped session status");
            }
        }
        if !skipped_sessions.is_empty() {
            debug!(
                skipped = skipped_sessions.len(),
                min_lamports = min_lamports,
                "Skipped and reset withdrawals below minimum amount"
            );
        }
        if sessions.is_empty() {
            return Ok(());
        }

        let pct = withdrawal_percentage.clamp(1, 100) as usize;
        let total_ready = sessions.len();
        let target_count = if pct == 100 {
            sessions.len()
        } else {
            let target = (sessions.len() * pct).div_ceil(100);
            target.clamp(1, sessions.len())
        };
        let partial_percentages: Vec<u8> = {
            use rand::Rng;
            let mut rng = rand::thread_rng();
            (0..partial_count).map(|_| rng.gen_range(30..=70)).collect()
        };
        {
            let mut rng = rand::thread_rng();
            sessions.shuffle(&mut rng);
        }
        for session in sessions.iter().skip(target_count) {
            if let Err(e) = self
                .deposit_repo
                .update_status(session.id, DepositStatus::Completed, None)
                .await
            {
                warn!(session_id = %session.id, error = %e, "Failed to reset excess session status");
            }
        }
        sessions.truncate(target_count);

        // Each plan is (session, lamports to request, intentional partial?).
        let mut partial_assigned = 0usize;
        let mut withdrawal_plans = Vec::with_capacity(sessions.len());
        for session in &sessions {
            let remaining = session.remaining_lamports();
            // `get` yields a percentage only while fewer than `partial_count` partials have
            // been assigned.
            let plan = match partial_percentages.get(partial_assigned) {
                Some(&partial_pct) if remaining >= partial_min_lamports => {
                    partial_assigned += 1;
                    let amount =
                        Self::partial_withdrawal_amount(remaining, partial_pct, min_lamports);
                    (session, amount, true)
                }
                _ => (session, remaining, false),
            };
            withdrawal_plans.push(plan);
        }
        info!(
            ready_count = total_ready,
            processing = withdrawal_plans.len(),
            partial_count = partial_assigned,
            percentage = pct,
            "Processing pending withdrawals"
        );
        for (session, amount, is_partial) in withdrawal_plans {
            // Non-negative: retained sessions have remaining >= min_lamports >= 0, and partial
            // amounts are clamped into [min_lamports, remaining].
            if let Err(e) = self
                .process_withdrawal(session, amount as u64, is_partial)
                .await
            {
                warn!(
                    session_id = %session.id,
                    user_id = %session.user_id,
                    error = %e,
                    "Failed to process withdrawal"
                );
            }
        }
        Ok(())
    }

    /// Withdraws `amount_lamports` from `session` through the sidecar and records the result.
    ///
    /// If the sidecar call itself fails, the attempt is recorded and the session moves to
    /// `PendingRetry`. It moves to `Failed` instead once the retry limit is reached or if
    /// recording the attempt fails. On success, the partial withdrawal and a history entry
    /// are recorded. An admin alert is sent if the sidecar reports a partial withdrawal that
    /// was not requested.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// - the session has no stored key, or the key cannot be decrypted or decoded;
    /// - the sidecar call times out or fails;
    /// - the sidecar reports invalid amounts;
    /// - recording the partial withdrawal fails.
    async fn process_withdrawal(
        &self,
        session: &DepositSessionEntity,
        amount_lamports: u64,
        is_partial: bool,
    ) -> Result<(), AppError> {
        let session_id = session.id;
        let user_id = session.user_id;
        let timeout_secs = self.get_timeout().await;
        let max_retries = self.get_max_retries().await;
        debug!(
            session_id = %session_id,
            user_id = %user_id,
            amount_lamports = amount_lamports,
            is_partial = is_partial,
            "Processing withdrawal"
        );
        // NOTE(review): the early returns for key loading and for the timeout below leave
        // the claimed session's status untouched. Confirm that claimed-but-unfinished
        // sessions are recovered elsewhere.
        let encrypted_data = session.stored_share_b.as_ref().ok_or_else(|| {
            AppError::Internal(anyhow::anyhow!(
                "Session {} missing encrypted private key",
                session_id
            ))
        })?;
        let private_key_bytes = decrypt_base64_payload(
            self.note_encryption.as_ref(),
            encrypted_data,
            "Failed to decode encrypted private key",
            "Invalid encrypted private key format",
        )?;
        let mut private_key = String::from_utf8(private_key_bytes).map_err(|e| {
            // Wipe the rejected key material before it is dropped.
            let mut bytes = e.into_bytes();
            bytes.zeroize();
            AppError::Internal(anyhow::anyhow!("Invalid private key encoding"))
        })?;
        let company_currency = self.config.company_currency.to_uppercase();
        if company_currency != "SOL" {
            warn!(
                company_currency = %self.config.company_currency,
                "Company currency swap on withdrawal not supported; withdrawing SOL"
            );
        }
        let target_currency: Option<&str> = None;
        let withdrawal_outcome = tokio::time::timeout(
            Duration::from_secs(timeout_secs),
            self.sidecar
                .withdraw(&private_key, amount_lamports, target_currency),
        )
        .await;
        // Wipe the key before any early return. The timeout error below must not skip this.
        private_key.zeroize();
        let withdrawal_result = withdrawal_outcome.map_err(|_| {
            AppError::Internal(anyhow::anyhow!(
                "Withdrawal timed out after {}s",
                timeout_secs
            ))
        })?;

        let withdrawal_response = match withdrawal_result {
            Ok(response) => response,
            Err(e) => {
                let error_msg = e.to_string();
                let attempts = session.processing_attempts.saturating_add(1);
                let record_ok = match self
                    .deposit_repo
                    .record_processing_attempt(session_id, Some(&error_msg))
                    .await
                {
                    Ok(_) => true,
                    Err(update_err) => {
                        warn!(
                            session_id = %session_id,
                            error = %update_err,
                            "Failed to record withdrawal attempt"
                        );
                        false
                    }
                };
                // Saturate rather than wrap: a huge setting must not become negative.
                let max_attempts = i32::try_from(max_retries).unwrap_or(i32::MAX);
                // If the attempt could not be recorded, the counter is unreliable; fail
                // rather than risk retrying forever.
                let status = if !record_ok || attempts >= max_attempts {
                    DepositStatus::Failed
                } else {
                    DepositStatus::PendingRetry
                };
                if let Err(update_err) = self
                    .deposit_repo
                    .update_status(session_id, status, Some(error_msg))
                    .await
                {
                    warn!(
                        session_id = %session_id,
                        error = %update_err,
                        "Failed to update withdrawal status"
                    );
                }
                return Err(e);
            }
        };

        let withdrawn_amount = Self::compute_spent_lamports(
            withdrawal_response.amount_lamports,
            withdrawal_response.fee_lamports,
        )?;
        let deposit_amount = session.deposit_amount_lamports.unwrap_or(0);
        let previous_withdrawn = session.withdrawn_amount_lamports;
        let cumulative = previous_withdrawn.saturating_add(withdrawn_amount);
        let remaining = deposit_amount.saturating_sub(cumulative).max(0);
        let fully_withdrawn = self
            .deposit_repo
            .record_partial_withdrawal(
                session_id,
                withdrawn_amount,
                &withdrawal_response.tx_signature,
            )
            .await?;
        let withdrawal_pct: Option<i16> = Some(if is_partial {
            Self::withdrawal_percentage_of(
                withdrawn_amount,
                deposit_amount.saturating_sub(previous_withdrawn),
            )
        } else {
            100
        });
        let history_entry = WithdrawalHistoryEntry::new(
            session_id,
            user_id,
            withdrawn_amount,
            withdrawal_response.tx_signature.clone(),
            cumulative,
            remaining,
            fully_withdrawn,
            withdrawal_pct,
        );
        // The funds have already moved; a history write failure is logged, not propagated.
        if let Err(e) = self.withdrawal_history_repo.create(history_entry).await {
            warn!(
                session_id = %session_id,
                error = %e,
                "Failed to record withdrawal history entry"
            );
        }
        info!(
            session_id = %session_id,
            user_id = %user_id,
            tx_signature = %withdrawal_response.tx_signature,
            requested_lamports = %amount_lamports,
            actual_lamports = %withdrawal_response.amount_lamports,
            fee_lamports = %withdrawal_response.fee_lamports,
            spent_lamports = %withdrawn_amount,
            intentional_partial = %is_partial,
            fully_withdrawn = %fully_withdrawn,
            "Withdrawal completed successfully"
        );
        if withdrawal_response.is_partial && !is_partial {
            let alert = AdminNotification::new(
                NotificationSeverity::Warn,
                "Unexpected Partial Withdrawal",
                &format!(
                    "Session {} for user {} had unexpected partial withdrawal (insufficient balance): \
                     requested {} lamports, withdrew {} lamports",
                    session_id, user_id, amount_lamports, withdrawal_response.amount_lamports
                ),
            );
            // Best-effort alert; the outcome is deliberately ignored.
            let _ = self.notification_service.notify(alert).await;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::repositories::{
        InMemoryDepositRepository, InMemorySystemSettingsRepository,
        InMemoryWithdrawalHistoryRepository,
    };
    use crate::services::LogNotificationService;
    use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
    use uuid::Uuid;

    /// Worker backed by in-memory repositories. Its sidecar points at 127.0.0.1:1 with a
    /// 10 ms timeout, which the failure tests below rely on to make the withdrawal fail.
    fn build_worker(
        repo: Arc<InMemoryDepositRepository>,
        note_encryption: Arc<NoteEncryptionService>,
        settings_service: Arc<SettingsService>,
    ) -> WithdrawalWorker {
        let unreachable_sidecar =
            PrivacySidecarClient::new(crate::services::SidecarClientConfig {
                base_url: "http://127.0.0.1:1".to_string(),
                timeout_ms: 10,
                api_key: "test".to_string(),
            })
            .unwrap();
        WithdrawalWorker::new(
            repo,
            Arc::new(InMemoryWithdrawalHistoryRepository::new()),
            Arc::new(unreachable_sidecar),
            note_encryption,
            Arc::new(LogNotificationService::new()),
            settings_service,
            WithdrawalWorkerConfig::default(),
        )
    }

    /// Settings service seeded with the repository's built-in defaults.
    fn build_settings_service() -> Arc<SettingsService> {
        let defaults_repo = Arc::new(InMemorySystemSettingsRepository::with_defaults());
        Arc::new(SettingsService::new(defaults_repo))
    }

    /// Settings service seeded with defaults plus one overridden `withdrawal`-category key.
    async fn build_settings_service_with(key: &str, value: &str) -> Arc<SettingsService> {
        use crate::repositories::{SystemSetting, SystemSettingsRepository};
        let overridden_repo = Arc::new(InMemorySystemSettingsRepository::with_defaults());
        overridden_repo
            .upsert(SystemSetting {
                key: key.to_string(),
                value: value.to_string(),
                category: "withdrawal".to_string(),
                description: None,
                is_secret: false,
                encryption_version: None,
                updated_at: chrono::Utc::now(),
                updated_by: None,
            })
            .await
            .unwrap();
        Arc::new(SettingsService::new(overridden_repo))
    }

    /// A 1 SOL privacy deposit that became eligible an hour ago. Its stored key is the
    /// base64 of `nonce || ciphertext`.
    fn build_session(note_encryption: &NoteEncryptionService) -> DepositSessionEntity {
        let sealed = note_encryption.encrypt(b"test-private-key").unwrap();
        let mut packed = sealed.nonce;
        packed.extend(sealed.ciphertext);
        DepositSessionEntity::new_privacy_deposit(
            Uuid::new_v4(),
            Uuid::new_v4(),
            "Wallet".to_string(),
            1_000_000_000,
            "tx_sig".to_string(),
            BASE64.encode(packed),
            chrono::Utc::now() - chrono::Duration::hours(1),
        )
    }

    /// Stores a fresh session and attempts a full withdrawal, asserting that it fails.
    /// Returns the session as persisted afterwards.
    async fn withdraw_full_balance_expecting_error(
        repo: &Arc<InMemoryDepositRepository>,
        worker: &WithdrawalWorker,
        note_encryption: &NoteEncryptionService,
    ) -> DepositSessionEntity {
        let fresh = build_session(note_encryption);
        let id = fresh.id;
        repo.create(fresh).await.unwrap();
        let stored = repo.find_by_id(id).await.unwrap().unwrap();
        let full_amount = stored.remaining_lamports() as u64;
        let outcome = worker.process_withdrawal(&stored, full_amount, false).await;
        assert!(outcome.is_err());
        repo.find_by_id(id).await.unwrap().unwrap()
    }

    #[tokio::test]
    async fn test_settings_defaults() {
        let settings = build_settings_service();
        for (key, expected) in [
            ("withdrawal_poll_interval_secs", 3600u64),
            ("withdrawal_timeout_secs", 120u64),
        ] {
            assert_eq!(settings.get_u64(key).await.unwrap(), Some(expected), "{key}");
        }
        for (key, expected) in [("withdrawal_batch_size", 10u32), ("withdrawal_max_retries", 3u32)] {
            assert_eq!(settings.get_u32(key).await.unwrap(), Some(expected), "{key}");
        }
    }

    #[test]
    fn test_compute_spent_lamports_includes_fee() {
        assert_eq!(WithdrawalWorker::compute_spent_lamports(100, 7).unwrap(), 107);
    }

    #[test]
    fn test_compute_spent_lamports_rejects_negative_fee() {
        assert!(WithdrawalWorker::compute_spent_lamports(100, -1).is_err());
    }

    #[tokio::test]
    async fn test_withdrawal_failure_resets_to_completed_for_retry() {
        let repo = Arc::new(InMemoryDepositRepository::new());
        let note_encryption = Arc::new(NoteEncryptionService::new(&[7u8; 32], "v1").unwrap());
        let worker = build_worker(repo.clone(), note_encryption.clone(), build_settings_service());

        let after =
            withdraw_full_balance_expecting_error(&repo, &worker, note_encryption.as_ref()).await;

        assert_eq!(after.status, DepositStatus::PendingRetry);
        assert_eq!(after.processing_attempts, 1);
        assert!(after.last_processing_error.is_some());
    }

    #[tokio::test]
    async fn test_withdrawal_failure_marks_failed_after_max_retries() {
        let repo = Arc::new(InMemoryDepositRepository::new());
        let note_encryption = Arc::new(NoteEncryptionService::new(&[9u8; 32], "v1").unwrap());
        let settings = build_settings_service_with("withdrawal_max_retries", "1").await;
        let worker = build_worker(repo.clone(), note_encryption.clone(), settings);

        let after =
            withdraw_full_balance_expecting_error(&repo, &worker, note_encryption.as_ref()).await;

        assert_eq!(after.status, DepositStatus::Failed);
        assert_eq!(after.processing_attempts, 1);
        assert!(after.last_processing_error.is_some());
    }

    #[test]
    fn test_clamp_partial_withdrawal_count_bounds() {
        for value in [0u32, 12, 255] {
            assert_eq!(clamp_partial_withdrawal_count(value), value as u8);
        }
    }

    #[test]
    fn test_clamp_partial_withdrawal_count_caps_overflow_values() {
        for value in [256u32, 10_000, u32::MAX] {
            assert_eq!(clamp_partial_withdrawal_count(value), 255);
        }
    }
}