use async_trait::async_trait;
use chrono::{DateTime, Utc};
use sqlx::PgPool;
use uuid::Uuid;
use crate::errors::AppError;
use crate::repositories::{normalize_email, InviteEntity, InviteRepository, OrgRole};
fn map_invite_error(e: sqlx::Error) -> AppError {
match &e {
sqlx::Error::Database(db_err) => {
if db_err.code().map(|c| c == "23505").unwrap_or(false) {
if let Some(constraint) = db_err.constraint() {
if constraint.contains("org_email") {
return AppError::Validation(
"An invite already exists for this email in this organization".into(),
);
} else if constraint.contains("org_wallet") {
return AppError::Validation(
"An invite already exists for this wallet in this organization".into(),
);
} else if constraint.contains("token") {
return AppError::Validation(
"Invite token collision - please retry".into(),
);
}
}
return AppError::Validation("Duplicate invite".into());
}
AppError::Database(e.to_string())
}
sqlx::Error::RowNotFound => AppError::NotFound("Invite not found".into()),
_ => AppError::Database(e.to_string()),
}
}
pub struct PostgresInviteRepository {
pool: PgPool,
}
impl PostgresInviteRepository {
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
}
#[derive(sqlx::FromRow)]
struct InviteRow {
id: Uuid,
org_id: Uuid,
email: Option<String>,
wallet_address: Option<String>,
role: String,
token_hash: String,
invited_by: Uuid,
created_at: DateTime<Utc>,
expires_at: DateTime<Utc>,
accepted_at: Option<DateTime<Utc>>,
}
impl TryFrom<InviteRow> for InviteEntity {
type Error = AppError;
fn try_from(row: InviteRow) -> Result<Self, Self::Error> {
let role = OrgRole::from_str(&row.role)
.ok_or_else(|| AppError::Database(format!("Invalid role: {}", row.role)))?;
Ok(Self {
id: row.id,
org_id: row.org_id,
email: row.email,
wallet_address: row.wallet_address,
role,
token_hash: row.token_hash,
invited_by: row.invited_by,
created_at: row.created_at,
expires_at: row.expires_at,
accepted_at: row.accepted_at,
})
}
}
#[async_trait]
impl InviteRepository for PostgresInviteRepository {
async fn find_by_id(&self, id: Uuid) -> Result<Option<InviteEntity>, AppError> {
let row: Option<InviteRow> = sqlx::query_as(
r#"
SELECT id, org_id, email, wallet_address, role, token_hash, invited_by, created_at, expires_at, accepted_at
FROM invites WHERE id = $1
"#,
)
.bind(id)
.fetch_optional(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
match row {
Some(r) => Ok(Some(r.try_into()?)),
None => Ok(None),
}
}
async fn find_by_token_hash(&self, token_hash: &str) -> Result<Option<InviteEntity>, AppError> {
let row: Option<InviteRow> = sqlx::query_as(
r#"
SELECT id, org_id, email, wallet_address, role, token_hash, invited_by, created_at, expires_at, accepted_at
FROM invites WHERE token_hash = $1
"#,
)
.bind(token_hash)
.fetch_optional(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
match row {
Some(r) => Ok(Some(r.try_into()?)),
None => Ok(None),
}
}
async fn find_by_org_and_email(
&self,
org_id: Uuid,
email: &str,
) -> Result<Option<InviteEntity>, AppError> {
let email_normalized = normalize_email(email);
let row: Option<InviteRow> = sqlx::query_as(
r#"
SELECT id, org_id, email, wallet_address, role, token_hash, invited_by, created_at, expires_at, accepted_at
FROM invites WHERE org_id = $1 AND email = $2
"#,
)
.bind(org_id)
.bind(&email_normalized)
.fetch_optional(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
match row {
Some(r) => Ok(Some(r.try_into()?)),
None => Ok(None),
}
}
async fn find_by_org_and_wallet(
&self,
org_id: Uuid,
wallet_address: &str,
) -> Result<Option<InviteEntity>, AppError> {
let row: Option<InviteRow> = sqlx::query_as(
r#"
SELECT id, org_id, email, wallet_address, role, token_hash, invited_by, created_at, expires_at, accepted_at
FROM invites WHERE org_id = $1 AND wallet_address = $2
"#,
)
.bind(org_id)
.bind(wallet_address)
.fetch_optional(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
match row {
Some(r) => Ok(Some(r.try_into()?)),
None => Ok(None),
}
}
async fn find_pending_by_org(&self, org_id: Uuid) -> Result<Vec<InviteEntity>, AppError> {
const PENDING_LIMIT: usize = 1000;
let rows: Vec<InviteRow> = sqlx::query_as(
r#"
SELECT id, org_id, email, wallet_address, role, token_hash, invited_by, created_at, expires_at, accepted_at
FROM invites
WHERE org_id = $1 AND accepted_at IS NULL AND expires_at > NOW()
ORDER BY created_at DESC
LIMIT 1000
"#,
)
.bind(org_id)
.fetch_all(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
if rows.len() >= PENDING_LIMIT {
tracing::warn!(
org_id = %org_id,
count = rows.len(),
limit = PENDING_LIMIT,
"find_pending_by_org hit limit - results may be truncated, use paged variant"
);
}
rows.into_iter().map(TryInto::try_into).collect()
}
async fn find_pending_by_org_paged(
&self,
org_id: Uuid,
limit: u32,
offset: u32,
) -> Result<Vec<InviteEntity>, AppError> {
const MAX_PAGE_SIZE: u32 = 100;
const MAX_OFFSET: u32 = 1_000_000;
let capped_limit = limit.min(MAX_PAGE_SIZE);
let capped_offset = offset.min(MAX_OFFSET);
let rows: Vec<InviteRow> = sqlx::query_as(
r#"
SELECT id, org_id, email, wallet_address, role, token_hash, invited_by, created_at, expires_at, accepted_at
FROM invites
WHERE org_id = $1 AND accepted_at IS NULL AND expires_at > NOW()
ORDER BY created_at DESC
LIMIT $2 OFFSET $3
"#,
)
.bind(org_id)
.bind(capped_limit as i64)
.bind(capped_offset as i64)
.fetch_all(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
rows.into_iter().map(TryInto::try_into).collect()
}
async fn find_pending_by_email(&self, email: &str) -> Result<Vec<InviteEntity>, AppError> {
const PENDING_LIMIT: usize = 1000;
let email_normalized = normalize_email(email);
let rows: Vec<InviteRow> = sqlx::query_as(
r#"
SELECT id, org_id, email, wallet_address, role, token_hash, invited_by, created_at, expires_at, accepted_at
FROM invites
WHERE email = $1 AND accepted_at IS NULL AND expires_at > NOW()
ORDER BY created_at DESC
LIMIT 1000
"#,
)
.bind(&email_normalized)
.fetch_all(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
if rows.len() >= PENDING_LIMIT {
tracing::warn!(
email = %email,
count = rows.len(),
limit = PENDING_LIMIT,
"find_pending_by_email hit limit - results may be truncated"
);
}
rows.into_iter().map(TryInto::try_into).collect()
}
async fn find_pending_by_wallet(
&self,
wallet_address: &str,
) -> Result<Vec<InviteEntity>, AppError> {
const PENDING_LIMIT: usize = 1000;
let rows: Vec<InviteRow> = sqlx::query_as(
r#"
SELECT id, org_id, email, wallet_address, role, token_hash, invited_by, created_at, expires_at, accepted_at
FROM invites
WHERE wallet_address = $1 AND accepted_at IS NULL AND expires_at > NOW()
ORDER BY created_at DESC
LIMIT 1000
"#,
)
.bind(wallet_address)
.fetch_all(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
if rows.len() >= PENDING_LIMIT {
tracing::warn!(
wallet_address = %wallet_address,
count = rows.len(),
limit = PENDING_LIMIT,
"find_pending_by_wallet hit limit - results may be truncated"
);
}
rows.into_iter().map(TryInto::try_into).collect()
}
async fn create(&self, invite: InviteEntity) -> Result<InviteEntity, AppError> {
let row: InviteRow = sqlx::query_as(
r#"
INSERT INTO invites (id, org_id, email, wallet_address, role, token_hash, invited_by, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, org_id, email, wallet_address, role, token_hash, invited_by, created_at, expires_at, accepted_at
"#,
)
.bind(invite.id)
.bind(invite.org_id)
.bind(&invite.email)
.bind(&invite.wallet_address)
.bind(invite.role.as_str())
.bind(&invite.token_hash)
.bind(invite.invited_by)
.bind(invite.expires_at)
.fetch_one(&self.pool)
.await
.map_err(map_invite_error)?;
row.try_into()
}
async fn mark_accepted(&self, id: Uuid) -> Result<(), AppError> {
let result = sqlx::query(
"UPDATE invites SET accepted_at = NOW() WHERE id = $1 AND accepted_at IS NULL",
)
.bind(id)
.execute(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
if result.rows_affected() == 0 {
return Err(AppError::NotFound(
"Invite not found or already accepted".into(),
));
}
Ok(())
}
async fn mark_accepted_if_valid(&self, id: Uuid) -> Result<Option<InviteEntity>, AppError> {
let row: Option<InviteRow> = sqlx::query_as(
r#"
UPDATE invites
SET accepted_at = NOW()
WHERE id = $1 AND accepted_at IS NULL AND expires_at > NOW()
RETURNING id, org_id, email, wallet_address, role, token_hash, invited_by, created_at, expires_at, accepted_at
"#,
)
.bind(id)
.fetch_optional(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
match row {
Some(r) => Ok(Some(r.try_into()?)),
None => Ok(None),
}
}
async fn unmark_accepted(&self, id: Uuid) -> Result<(), AppError> {
sqlx::query("UPDATE invites SET accepted_at = NULL WHERE id = $1")
.bind(id)
.execute(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
Ok(())
}
async fn delete(&self, id: Uuid) -> Result<(), AppError> {
let result = sqlx::query("DELETE FROM invites WHERE id = $1")
.bind(id)
.execute(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
if result.rows_affected() == 0 {
return Err(AppError::NotFound("Invite not found".into()));
}
Ok(())
}
async fn delete_by_org(&self, org_id: Uuid) -> Result<u64, AppError> {
let result = sqlx::query("DELETE FROM invites WHERE org_id = $1")
.bind(org_id)
.execute(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
Ok(result.rows_affected())
}
async fn delete_expired(&self) -> Result<u64, AppError> {
let result = sqlx::query("DELETE FROM invites WHERE expires_at < NOW()")
.execute(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
Ok(result.rows_affected())
}
async fn count_pending_by_org(&self, org_id: Uuid) -> Result<u64, AppError> {
let count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM invites WHERE org_id = $1 AND accepted_at IS NULL AND expires_at > NOW()",
)
.bind(org_id)
.fetch_one(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
Ok(count.max(0) as u64)
}
}