use axum::{extract::State, http::HeaderMap, Json};
use axum::extract::Query;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::callback::AuthCallback;
use crate::errors::AppError;
use crate::services::EmailService;
use crate::utils::authenticate;
use crate::AppState;
/// JSON response for `get_referral`: the caller's referral code plus summary stats.
/// Keys are serialized in camelCase (`referralCode`, `referralCount`, `directPayoutEnabled`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ReferralResponse {
/// The user's current referral code.
pub referral_code: String,
/// Number of referred users, as returned by `user_repo.count_referrals`.
pub referral_count: u64,
/// `true` only when the `referral_reward_type` setting equals `"direct_payout"`.
pub direct_payout_enabled: bool,
}
/// JSON response carrying a single referral code.
/// Returned by both `regenerate_referral` and `set_referral_code`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RegenerateReferralResponse {
/// The referral code now assigned to the caller (serialized as `referralCode`).
pub referral_code: String,
}
/// Returns the authenticated user's referral code, how many users they have
/// referred, and whether referral rewards are configured as direct payouts.
///
/// # Errors
/// - `AppError::NotFound` when the `feature_referrals_enabled` setting is
///   missing, `false`, or cannot be read.
/// - `AppError::InvalidToken` when the authenticated user record no longer exists.
/// - Authentication and user-repository errors are propagated unchanged.
pub async fn get_referral<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
) -> Result<Json<ReferralResponse>, AppError> {
    let auth = authenticate(&state, &headers).await?;

    // Only an explicit `true` enables the feature; read failures count as "off".
    let referrals_enabled = matches!(
        state
            .settings_service
            .get_bool("feature_referrals_enabled")
            .await,
        Ok(Some(true))
    );
    if !referrals_enabled {
        return Err(AppError::NotFound("Referrals not enabled".into()));
    }

    let user = match state.user_repo.find_by_id(auth.user_id).await? {
        Some(found) => found,
        None => return Err(AppError::InvalidToken),
    };

    let referral_count = state.user_repo.count_referrals(auth.user_id).await?;

    // Missing or unreadable setting means "not direct payout".
    let direct_payout_enabled = state
        .settings_service
        .get("referral_reward_type")
        .await
        .ok()
        .flatten()
        .map_or(false, |reward_type| reward_type == "direct_payout");

    Ok(Json(ReferralResponse {
        referral_code: user.referral_code,
        referral_count,
        direct_payout_enabled,
    }))
}
/// Replaces the authenticated user's referral code with a newly generated one
/// and returns it.
///
/// # Errors
/// - `AppError::NotFound` when the `feature_referrals_enabled` setting is
///   missing, `false`, or cannot be read.
/// - Authentication and user-repository errors are propagated unchanged.
pub async fn regenerate_referral<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
) -> Result<Json<RegenerateReferralResponse>, AppError> {
    let auth = authenticate(&state, &headers).await?;

    let referrals_enabled = matches!(
        state
            .settings_service
            .get_bool("feature_referrals_enabled")
            .await,
        Ok(Some(true))
    );
    if !referrals_enabled {
        return Err(AppError::NotFound("Referrals not enabled".into()));
    }

    let referral_code = state
        .user_repo
        .regenerate_referral_code(auth.user_id)
        .await?;
    Ok(Json(RegenerateReferralResponse { referral_code }))
}
/// Request body for `set_referral_code`.
#[derive(Debug, Deserialize)]
pub struct SetReferralCodeRequest {
/// Desired vanity code; the handler upper-cases it before validating it.
pub code: String,
}
/// Assigns a caller-chosen vanity referral code to the authenticated user.
///
/// The requested code is upper-cased and validated. If the caller already owns
/// that exact code, it is returned without writing anything.
///
/// # Errors
/// - `AppError::NotFound` when the `feature_referrals_enabled` setting is
///   missing, `false`, or cannot be read.
/// - `AppError::Validation` when the code breaks the format rules or belongs
///   to a different user.
/// - Authentication and user-repository errors are propagated unchanged.
pub async fn set_referral_code<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
    Json(body): Json<SetReferralCodeRequest>,
) -> Result<Json<RegenerateReferralResponse>, AppError> {
    let auth = authenticate(&state, &headers).await?;

    let referrals_enabled = matches!(
        state
            .settings_service
            .get_bool("feature_referrals_enabled")
            .await,
        Ok(Some(true))
    );
    if !referrals_enabled {
        return Err(AppError::NotFound("Referrals not enabled".into()));
    }

    let requested = body.code.to_uppercase();
    validate_vanity_code(&requested)?;

    if let Some(owner) = state.user_repo.find_by_referral_code(&requested).await? {
        // The code is already assigned; that is only acceptable when the
        // caller is the one holding it (idempotent re-submit, no write needed).
        return if owner.id == auth.user_id {
            Ok(Json(RegenerateReferralResponse {
                referral_code: requested,
            }))
        } else {
            Err(AppError::Validation("Referral code is already taken".into()))
        };
    }

    state
        .user_repo
        .set_referral_code(auth.user_id, &requested)
        .await?;
    Ok(Json(RegenerateReferralResponse {
        referral_code: requested,
    }))
}
/// Validates a vanity referral code.
///
/// A valid code is 4 to 16 characters long and consists solely of ASCII
/// uppercase letters (`A`–`Z`) and digits (`0`–`9`). Lowercase letters are
/// rejected, so callers should upper-case user input first.
///
/// # Errors
/// Returns `AppError::Validation` naming the first rule the code breaks:
/// the length rule is checked before the character-set rule.
fn validate_vanity_code(code: &str) -> Result<(), AppError> {
    // Count characters, not bytes. The message promises a *character* limit,
    // and a byte count would misreport multi-byte input (e.g. nine accented
    // letters = 18 bytes) as "too long" instead of reporting the real problem:
    // the disallowed characters.
    let len = code.chars().count();
    if !(4..=16).contains(&len) {
        return Err(AppError::Validation(
            "Referral code must be 4–16 characters".into(),
        ));
    }
    if !code
        .chars()
        .all(|c| c.is_ascii_uppercase() || c.is_ascii_digit())
    {
        return Err(AppError::Validation(
            "Referral code must contain only uppercase letters and digits (A-Z, 0-9)".into(),
        ));
    }
    Ok(())
}
/// JSON response for `get_rewards_info`: a summary of the caller's referral rewards.
/// Keys are serialized in camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RewardsInfoResponse {
/// Total rewards earned. The source is referral payouts in direct-payout mode,
/// otherwise referral credit adjustments (units as stored by the repos — TODO confirm).
pub total_earned: i64,
/// Sum of payouts with status `"pending"`; always 0 outside direct-payout mode.
pub pending_amount: i64,
/// Number of payouts with status `"pending"`; always 0 outside direct-payout mode.
pub pending_count: u64,
/// Value of the `referral_reward_currency` setting (default `"USD"`).
pub currency: String,
/// Value of the `referral_reward_type` setting (default `"credits"`).
pub reward_type: String,
/// The user's stored payout wallet address, if any.
pub payout_wallet_address: Option<String>,
/// Number of referred users, as returned by `user_repo.count_referrals`.
pub referral_count: u64,
}
/// Summarises the authenticated user's referral rewards.
///
/// The totals depend on the `referral_reward_type` setting (default `"credits"`):
/// - `"direct_payout"`: totals come from the referral payout repository, and
///   `pending_amount` / `pending_count` describe payouts with status `"pending"`.
/// - anything else: `total_earned` comes from
///   `credit_repo.sum_adjustments_by_reference_type_prefix` with the `"referral_"`
///   prefix in the configured currency (default `"USD"`), and the pending
///   fields are always zero.
///
/// # Errors
/// - `AppError::NotFound` when the `feature_referrals_enabled` setting is
///   missing, `false`, or cannot be read.
/// - `AppError::InvalidToken` when the authenticated user record no longer exists.
/// - Authentication and repository errors are propagated. The exception is
///   the pending-sum lookup, which falls back to `0`.
pub async fn get_rewards_info<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
) -> Result<Json<RewardsInfoResponse>, AppError> {
    let auth = authenticate(&state, &headers).await?;
    // A failed settings read is treated the same as "disabled".
    let enabled = state
        .settings_service
        .get_bool("feature_referrals_enabled")
        .await
        .ok()
        .flatten()
        .unwrap_or(false);
    if !enabled {
        return Err(AppError::NotFound("Referrals not enabled".into()));
    }
    let user = state
        .user_repo
        .find_by_id(auth.user_id)
        .await?
        .ok_or(AppError::InvalidToken)?;
    let reward_type = state
        .settings_service
        .get("referral_reward_type")
        .await
        .ok()
        .flatten()
        .unwrap_or_else(|| "credits".to_string());
    let currency = state
        .settings_service
        .get("referral_reward_currency")
        .await
        .ok()
        .flatten()
        .unwrap_or_else(|| "USD".to_string());
    let referral_count = state.user_repo.count_referrals(auth.user_id).await?;
    let (total_earned, pending_amount, pending_count) = if reward_type == "direct_payout" {
        let total = state
            .referral_payout_repo
            .sum_for_referrer(auth.user_id)
            .await?;
        let pending_sum = state
            .referral_payout_repo
            .sum_by_status_for_referrer(auth.user_id, "pending")
            .await
            // NOTE(review): unlike the neighbouring lookups, an error here is
            // swallowed and reported as 0 — confirm this is intentional.
            .unwrap_or(0);
        let p_count = state
            .referral_payout_repo
            .count_by_referrer(auth.user_id, Some("pending"))
            .await?;
        (total, pending_sum, p_count)
    } else {
        // Credits mode reports no pending amounts. `currency` is only borrowed
        // here because it is moved into the response below.
        let total = state
            .credit_repo
            .sum_adjustments_by_reference_type_prefix(auth.user_id, &currency, "referral_")
            .await?;
        (total, 0i64, 0u64)
    };
    Ok(Json(RewardsInfoResponse {
        total_earned,
        pending_amount,
        pending_count,
        currency,
        reward_type,
        payout_wallet_address: user.payout_wallet_address,
        referral_count,
    }))
}
/// One entry in the referral reward history returned by `get_rewards_history`.
/// Built from either a referral payout or a referral credit transaction.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RewardHistoryItem {
/// Identifier of the underlying payout or credit transaction, as a string.
pub id: String,
/// What triggered the reward. For credits, this is the reference type with its
/// `"referral_"` prefix removed, or `"unknown"`.
pub trigger_type: String,
/// Reward amount as stored by the repository (units — TODO confirm).
pub amount: i64,
/// Currency of `amount`.
pub currency: String,
/// Payout status, or the fixed value `"credited"` for credit transactions.
pub status: String,
/// Transaction signature; only ever set for payouts, always `None` for credits.
pub tx_signature: Option<String>,
/// Creation time, formatted as RFC 3339.
pub created_at: String,
/// Completion time as RFC 3339; equal to `created_at` for credit transactions.
pub completed_at: Option<String>,
}
/// JSON response for `get_rewards_history`: one page of items plus the overall count.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RewardsHistoryResponse {
/// The requested page of history entries.
pub items: Vec<RewardHistoryItem>,
/// Total number of history entries across all pages.
pub total: u64,
}
/// Pagination query parameters for `get_rewards_history` (`?limit=&offset=`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RewardsHistoryQuery {
/// Page size; defaults to `default_limit()` (20). The handler caps it at 100.
#[serde(default = "default_limit")]
pub limit: u32,
/// Number of entries to skip; defaults to 0.
#[serde(default)]
pub offset: u32,
}
/// Page size used when the `limit` query parameter is omitted.
fn default_limit() -> u32 {
    const DEFAULT_PAGE_SIZE: u32 = 20;
    DEFAULT_PAGE_SIZE
}
/// Returns one page of the authenticated user's referral reward history.
///
/// `limit` is capped at 100 (its default comes from `default_limit`), and
/// `offset` defaults to 0. The source depends on `referral_reward_type`:
/// - `"direct_payout"`: the page lists the user's referral payouts.
/// - otherwise: the page lists credit transactions whose reference type has the
///   `"referral_"` prefix, in the configured reward currency (default `"USD"`).
///   Each is reported with status `"credited"`.
///
/// # Errors
/// - `AppError::NotFound` when the `feature_referrals_enabled` setting is
///   missing, `false`, or cannot be read.
/// - Authentication and repository errors are propagated unchanged.
pub async fn get_rewards_history<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
    Query(query): Query<RewardsHistoryQuery>,
) -> Result<Json<RewardsHistoryResponse>, AppError> {
    let auth = authenticate(&state, &headers).await?;
    // A failed settings read is treated the same as "disabled".
    let enabled = state
        .settings_service
        .get_bool("feature_referrals_enabled")
        .await
        .ok()
        .flatten()
        .unwrap_or(false);
    if !enabled {
        return Err(AppError::NotFound("Referrals not enabled".into()));
    }
    let reward_type = state
        .settings_service
        .get("referral_reward_type")
        .await
        .ok()
        .flatten()
        .unwrap_or_else(|| "credits".to_string());
    // Cap client-supplied page size so one request can't pull unbounded rows.
    let limit = query.limit.min(100);
    if reward_type == "direct_payout" {
        let payouts = state
            .referral_payout_repo
            .list_by_referrer(auth.user_id, None, limit, query.offset)
            .await?;
        let total = state
            .referral_payout_repo
            .count_by_referrer(auth.user_id, None)
            .await?;
        let items = payouts
            .into_iter()
            .map(|p| RewardHistoryItem {
                id: p.id.to_string(),
                trigger_type: p.trigger_type,
                amount: p.amount,
                currency: p.currency,
                status: p.status,
                tx_signature: p.tx_signature,
                created_at: p.created_at.to_rfc3339(),
                completed_at: p.completed_at.map(|t| t.to_rfc3339()),
            })
            .collect();
        return Ok(Json(RewardsHistoryResponse { items, total }));
    }
    let currency = state
        .settings_service
        .get("referral_reward_currency")
        .await
        .ok()
        .flatten()
        .unwrap_or_else(|| "USD".to_string());
    let txs = state
        .credit_repo
        .list_by_reference_type_prefix(auth.user_id, &currency, "referral_", limit, query.offset)
        .await?;
    let total = state
        .credit_repo
        .count_by_reference_type_prefix(auth.user_id, &currency, "referral_")
        .await?;
    let items = txs
        .into_iter()
        .map(|t| {
            // Derive the trigger from the reference type, e.g. "referral_signup" -> "signup".
            let trigger_type = t
                .reference_type
                .as_deref()
                .and_then(|rt| rt.strip_prefix("referral_"))
                .unwrap_or("unknown")
                .to_string();
            RewardHistoryItem {
                id: t.id.to_string(),
                trigger_type,
                amount: t.amount,
                currency: t.currency,
                status: "credited".to_string(),
                tx_signature: None,
                created_at: t.created_at.to_rfc3339(),
                // Credits have no separate completion step; reuse the creation time.
                completed_at: Some(t.created_at.to_rfc3339()),
            }
        })
        .collect();
    Ok(Json(RewardsHistoryResponse { items, total }))
}
/// Request body for `set_payout_wallet` (JSON key `walletAddress`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SetPayoutWalletRequest {
/// New payout address. When present, it is validated as a Solana public key.
/// `None` is forwarded to the repository as-is (presumably clears the address — confirm).
pub wallet_address: Option<String>,
}
/// JSON response for `set_payout_wallet`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SetPayoutWalletResponse {
/// Always `true` on success; failures are reported as errors instead.
pub ok: bool,
/// Echo of the address that was stored.
pub wallet_address: Option<String>,
}
/// Validates that `address` is a base58-encoded 32-byte Solana public key.
///
/// # Errors
/// Returns `AppError::Validation` when the string is too long to encode a
/// 32-byte key, is not valid base58, or does not decode to exactly 32 bytes.
fn validate_solana_wallet(address: &str) -> Result<(), AppError> {
    // Any 32-byte value encodes to at most 44 base58 digits (58^44 > 2^256).
    // Reject longer input before decoding. Big-number base58 decoding costs
    // time roughly quadratic in the input length, and `address` comes straight
    // from a request body. Valid-base58 input this long could never decode to
    // 32 bytes, so the message matches what the decode path would report.
    const MAX_ENCODED_LEN: usize = 44;
    if address.len() > MAX_ENCODED_LEN {
        return Err(AppError::Validation(
            "Invalid wallet address: must be a 32-byte Solana public key".into(),
        ));
    }
    let decoded = bs58::decode(address)
        .into_vec()
        .map_err(|_| AppError::Validation("Invalid wallet address: not valid base58".into()))?;
    if decoded.len() != 32 {
        return Err(AppError::Validation(
            "Invalid wallet address: must be a 32-byte Solana public key".into(),
        ));
    }
    Ok(())
}
/// Stores (or, with `None`, passes through) the authenticated user's payout
/// wallet address and echoes it back.
///
/// # Errors
/// - `AppError::NotFound` when the `feature_referrals_enabled` setting is
///   missing, `false`, or cannot be read.
/// - `AppError::Validation` when a provided address is not a valid Solana
///   public key.
/// - Authentication and user-repository errors are propagated unchanged.
pub async fn set_payout_wallet<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
    Json(body): Json<SetPayoutWalletRequest>,
) -> Result<Json<SetPayoutWalletResponse>, AppError> {
    let auth = authenticate(&state, &headers).await?;

    let referrals_enabled = matches!(
        state
            .settings_service
            .get_bool("feature_referrals_enabled")
            .await,
        Ok(Some(true))
    );
    if !referrals_enabled {
        return Err(AppError::NotFound("Referrals not enabled".into()));
    }

    let wallet_address = body.wallet_address;
    // Only a supplied address needs checking; `None` goes straight through.
    if let Some(candidate) = wallet_address.as_deref() {
        validate_solana_wallet(candidate)?;
    }

    state
        .user_repo
        .set_payout_wallet_address(auth.user_id, wallet_address.as_deref())
        .await?;

    Ok(Json(SetPayoutWalletResponse {
        ok: true,
        wallet_address,
    }))
}