use std::net::IpAddr;
use chrono::{DateTime, Utc};
use ipnetwork::IpNetwork;
use serde::Serialize;
use sqlx::postgres::{PgPool, PgRow};
use sqlx::{Postgres, Row, Transaction};
/// One IP allow/deny rule row, with its id and address stringified for
/// serialization (see `map_ip_rule`).
#[derive(Debug, Clone, Serialize)]
pub struct IpRule {
/// Rule id: the row's UUID rendered as a string.
pub id: String,
/// Address or network in `IpNetwork`'s `Display` form (e.g. `10.0.0.0/8`).
pub addr: String,
/// Optional free-form label stored with the rule.
pub label: Option<String>,
/// Row creation timestamp.
pub created_at: DateTime<Utc>,
}
/// IP-access policy of a single API key: its virgin-mode columns from
/// `api_keys` plus its per-key whitelist and blacklist rules.
///
/// Built by `load_api_key_ip_policy`.
#[derive(Debug, Clone, Serialize)]
pub struct ApiKeyIpPolicy {
/// `api_keys.id` of the key this policy belongs to.
pub api_key_id: i64,
/// Loaded from `api_keys.virgin_mode`. NOTE(review): presumably enables
/// learning the whitelist from observed traffic — confirm with the caller.
pub virgin_mode: bool,
/// Set to `true` by `mark_virgin_resolved`, back to `false` by
/// `reset_virgin_state`.
pub virgin_resolved: bool,
/// Time of the first resolution (kept across repeated
/// `mark_virgin_resolved` calls); cleared by `reset_virgin_state`.
pub virgin_resolved_at: Option<DateTime<Utc>>,
/// Loaded from `api_keys.virgin_until_n_requests`. NOTE(review): presumably
/// the request count at which learning ends — confirm with the caller.
pub virgin_until_n_requests: i32,
/// Counter bumped by `increment_virgin_request_count` and reset to 0 by
/// `reset_virgin_state`.
pub virgin_request_count: i32,
/// Loaded from `api_keys.max_whitelist_ips`; no function shown here
/// enforces it.
pub max_whitelist_ips: i32,
/// Per-key whitelist rules (ordered by `created_at` when loaded).
pub whitelist: Vec<IpRule>,
/// Per-key blacklist rules (ordered by `created_at` when loaded).
pub blacklist: Vec<IpRule>,
}
/// One row of `api_key_ip_seen`: an address observed for an API key.
#[derive(Debug, Clone, Serialize)]
pub struct SeenIpRecord {
/// `api_key_ip_seen_id` UUID rendered as a string.
pub id: String,
/// Observed address in `IpNetwork`'s `Display` form.
pub addr: String,
/// When `record_seen_ip` first inserted this address.
pub first_seen_at: DateTime<Utc>,
/// When `record_seen_ip` most recently recorded this address.
pub last_seen_at: DateTime<Utc>,
/// Number of times `record_seen_ip` has recorded this address.
pub hit_count: i64,
/// Set by `mark_virgin_resolved`; cleared by `reset_virgin_state` when it
/// keeps the seen rows.
pub locked_in: bool,
}
/// Outcome of one `record_seen_ip` call.
#[derive(Debug, Clone)]
pub struct SeenIpUpdate {
/// Distinct addresses recorded for the key, including this one.
pub distinct_ip_count: i64,
/// Hit count of this address after the increment.
pub hit_count_for_ip: i64,
/// `true` when the call inserted a new row, i.e. the address was not seen before.
pub was_new: bool,
}
fn map_ip_rule(row: &PgRow, id_column: &str) -> Result<IpRule, sqlx::Error> {
let id: uuid::Uuid = row.try_get(id_column)?;
let addr: IpNetwork = row.try_get("addr")?;
Ok(IpRule {
id: id.to_string(),
addr: addr.to_string(),
label: row.try_get("label")?,
created_at: row.try_get("created_at")?,
})
}
pub async fn load_api_key_ip_policy(
pool: &PgPool,
api_key_id: i64,
) -> Result<Option<ApiKeyIpPolicy>, sqlx::Error> {
let key_row: Option<PgRow> = sqlx::query(
r#"
SELECT
id,
virgin_mode,
virgin_resolved,
virgin_resolved_at,
virgin_until_n_requests,
virgin_request_count,
max_whitelist_ips
FROM api_keys
WHERE id = $1
"#,
)
.bind(api_key_id)
.fetch_optional(pool)
.await?;
let Some(key_row) = key_row else {
return Ok(None);
};
let whitelist_rows: Vec<PgRow> = sqlx::query(
r#"
SELECT api_key_ip_whitelist_id AS id, addr, label, created_at
FROM api_key_ip_whitelist
WHERE api_key_id = $1
ORDER BY created_at
"#,
)
.bind(api_key_id)
.fetch_all(pool)
.await?;
let blacklist_rows: Vec<PgRow> = sqlx::query(
r#"
SELECT api_key_ip_blacklist_id AS id, addr, label, created_at
FROM api_key_ip_blacklist
WHERE api_key_id = $1
ORDER BY created_at
"#,
)
.bind(api_key_id)
.fetch_all(pool)
.await?;
let whitelist: Vec<IpRule> = whitelist_rows
.iter()
.map(|row| map_ip_rule(row, "id"))
.collect::<Result<_, _>>()?;
let blacklist: Vec<IpRule> = blacklist_rows
.iter()
.map(|row| map_ip_rule(row, "id"))
.collect::<Result<_, _>>()?;
Ok(Some(ApiKeyIpPolicy {
api_key_id: key_row.try_get("id")?,
virgin_mode: key_row.try_get("virgin_mode")?,
virgin_resolved: key_row.try_get("virgin_resolved")?,
virgin_resolved_at: key_row.try_get("virgin_resolved_at")?,
virgin_until_n_requests: key_row.try_get("virgin_until_n_requests")?,
virgin_request_count: key_row.try_get("virgin_request_count")?,
max_whitelist_ips: key_row.try_get("max_whitelist_ips")?,
whitelist,
blacklist,
}))
}
pub async fn load_global_ip_rules(
pool: &PgPool,
client_name: Option<&str>,
) -> Result<(Vec<IpRule>, Vec<IpRule>), sqlx::Error> {
let whitelist_rows: Vec<PgRow> = sqlx::query(
r#"
SELECT api_key_ip_global_whitelist_id AS id, addr, label, created_at
FROM api_key_ip_global_whitelist
WHERE client_name IS NULL OR client_name = $1
ORDER BY created_at
"#,
)
.bind(client_name)
.fetch_all(pool)
.await?;
let blacklist_rows: Vec<PgRow> = sqlx::query(
r#"
SELECT api_key_ip_global_blacklist_id AS id, addr, label, created_at
FROM api_key_ip_global_blacklist
WHERE client_name IS NULL OR client_name = $1
ORDER BY created_at
"#,
)
.bind(client_name)
.fetch_all(pool)
.await?;
let whitelist: Vec<IpRule> = whitelist_rows
.iter()
.map(|row| map_ip_rule(row, "id"))
.collect::<Result<_, _>>()?;
let blacklist: Vec<IpRule> = blacklist_rows
.iter()
.map(|row| map_ip_rule(row, "id"))
.collect::<Result<_, _>>()?;
Ok((whitelist, blacklist))
}
/// Records one request from `ip` against an API key and reports the outcome.
///
/// Within a single transaction, the function upserts the `(api_key_id, addr)`
/// row of `api_key_ip_seen`. A new address is inserted with `hit_count = 1`;
/// an existing one gets `hit_count` incremented and `last_seen_at` refreshed.
/// The function then counts how many distinct addresses the key now has.
///
/// `was_new` is the `(xmax = 0)` column from `RETURNING`. Postgres leaves
/// `xmax` at 0 on a tuple that was just inserted, so the flag is presumably
/// `true` only on the insert path.
///
/// # Errors
/// Propagates any transaction, query or decoding error. The transaction is
/// not committed in that case.
pub async fn record_seen_ip(
    pool: &PgPool,
    api_key_id: i64,
    ip: IpAddr,
) -> Result<SeenIpUpdate, sqlx::Error> {
    const UPSERT_SQL: &str = r#"
INSERT INTO api_key_ip_seen (api_key_id, addr, ipv4_address, first_seen_at, last_seen_at, hit_count)
VALUES ($1, $2, $3, now(), now(), 1)
ON CONFLICT (api_key_id, addr)
DO UPDATE SET
ipv4_address = COALESCE(api_key_ip_seen.ipv4_address, EXCLUDED.ipv4_address),
last_seen_at = now(),
hit_count = api_key_ip_seen.hit_count + 1
RETURNING hit_count, (xmax = 0) AS was_new
"#;
    const COUNT_SQL: &str = r#"
SELECT COUNT(*)::bigint
FROM api_key_ip_seen
WHERE api_key_id = $1
"#;

    // NOTE(review): `ipv4_address` also receives the text of IPv6 addresses.
    // Confirm the column is meant to accept them.
    let textual_ip = ip.to_string();
    let mut tx: Transaction<'_, Postgres> = pool.begin().await?;

    let upserted = sqlx::query(UPSERT_SQL)
        .bind(api_key_id)
        .bind(IpNetwork::from(ip))
        .bind(textual_ip)
        .fetch_one(&mut *tx)
        .await?;
    let hit_count_for_ip: i64 = upserted.try_get("hit_count")?;
    let was_new: bool = upserted.try_get("was_new")?;

    let distinct_ip_count = sqlx::query_scalar::<_, i64>(COUNT_SQL)
        .bind(api_key_id)
        .fetch_one(&mut *tx)
        .await?;

    tx.commit().await?;
    Ok(SeenIpUpdate {
        distinct_ip_count,
        hit_count_for_ip,
        was_new,
    })
}
/// Increments `api_keys.virgin_request_count` for a key and returns the new
/// count.
///
/// The increment and the `updated_at` refresh happen in a single `UPDATE`
/// statement.
///
/// # Errors
/// Fails with `sqlx::Error::RowNotFound` when no key has this id, because the
/// `UPDATE` returns no row and `fetch_one` requires one. Also fails on any
/// query or decoding error.
pub async fn increment_virgin_request_count(
    pool: &PgPool,
    api_key_id: i64,
) -> Result<i32, sqlx::Error> {
    const INCREMENT_SQL: &str = r#"
UPDATE api_keys
SET
virgin_request_count = virgin_request_count + 1,
updated_at = now()
WHERE id = $1
RETURNING virgin_request_count
"#;

    sqlx::query(INCREMENT_SQL)
        .bind(api_key_id)
        .fetch_one(pool)
        .await
        .and_then(|updated| updated.try_get("virgin_request_count"))
}
/// Copies the earliest-seen addresses of an API key into its whitelist.
///
/// Candidates are taken from `api_key_ip_seen` in `first_seen_at` order.
/// Ties are broken by `api_key_ip_seen_id`, so the subset chosen under a cap
/// is deterministic. Each candidate is cast to `cidr` and inserted into
/// `api_key_ip_whitelist` with the label `virgin_learned`. Addresses that are
/// already whitelisted are skipped via `ON CONFLICT DO NOTHING`.
///
/// `max_take` caps how many candidates are considered. `None`, and also any
/// `Some(n)` with `n <= 0`, means no cap. NOTE(review): a caller passing
/// `Some(0)` to mean "no slots left" will instead promote every seen address.
/// Confirm this is intended.
///
/// Returns the number of rows actually inserted. This can be less than the
/// number of candidates when some were already whitelisted.
///
/// # Errors
/// Propagates any query or decoding error.
pub async fn promote_seen_ips_to_whitelist(
    pool: &PgPool,
    api_key_id: i64,
    max_take: Option<i32>,
) -> Result<i64, sqlx::Error> {
    // Non-positive caps are treated as "no cap". A NULL `$2` falls back to
    // i32::MAX in the SQL, which is effectively unlimited.
    let limit: Option<i32> = max_take.filter(|value| *value > 0);
    let row: PgRow = sqlx::query(
        r#"
WITH candidates AS (
SELECT addr
FROM api_key_ip_seen
WHERE api_key_id = $1
ORDER BY first_seen_at, api_key_ip_seen_id
LIMIT COALESCE($2, 2147483647)
),
inserted AS (
INSERT INTO api_key_ip_whitelist (api_key_id, addr, label)
SELECT $1, addr::cidr, 'virgin_learned'
FROM candidates
ON CONFLICT (api_key_id, addr) DO NOTHING
RETURNING 1
)
SELECT COUNT(*)::bigint AS promoted FROM inserted
"#,
    )
    .bind(api_key_id)
    .bind(limit)
    .fetch_one(pool)
    .await?;
    row.try_get("promoted")
}
/// Marks an API key's virgin phase as resolved and locks in its seen IPs.
///
/// Within one transaction, the function:
/// - sets `api_keys.virgin_resolved = true`;
/// - stamps `virgin_resolved_at` only if it is still NULL, so the first
///   resolution time is kept;
/// - bumps `updated_at`;
/// - sets `locked_in = true` on every `api_key_ip_seen` row of the key.
///
/// Succeeds without changes when no row matches `api_key_id`.
///
/// # Errors
/// Propagates any transaction or query error. Nothing is committed then.
pub async fn mark_virgin_resolved(pool: &PgPool, api_key_id: i64) -> Result<(), sqlx::Error> {
    const RESOLVE_KEY_SQL: &str = r#"
UPDATE api_keys
SET
virgin_resolved = true,
virgin_resolved_at = COALESCE(virgin_resolved_at, now()),
updated_at = now()
WHERE id = $1
"#;
    const LOCK_SEEN_SQL: &str = r#"
UPDATE api_key_ip_seen
SET locked_in = true
WHERE api_key_id = $1
"#;

    let mut tx: Transaction<'_, Postgres> = pool.begin().await?;
    // Both statements take the key id as their only parameter.
    for statement in [RESOLVE_KEY_SQL, LOCK_SEEN_SQL] {
        sqlx::query(statement)
            .bind(api_key_id)
            .execute(&mut *tx)
            .await?;
    }
    tx.commit().await
}
/// Puts an API key back into its unresolved virgin state.
///
/// Within one transaction, the function first updates the `api_keys` row: it
/// clears `virgin_resolved` and `virgin_resolved_at`, resets
/// `virgin_request_count` to 0 and bumps `updated_at`. It then handles the
/// key's `api_key_ip_seen` rows:
/// - `clear_seen = true` deletes all of them;
/// - `clear_seen = false` keeps them and sets `locked_in = false`.
///
/// # Errors
/// Propagates any transaction or query error. Nothing is committed then.
pub async fn reset_virgin_state(
    pool: &PgPool,
    api_key_id: i64,
    clear_seen: bool,
) -> Result<(), sqlx::Error> {
    const RESET_KEY_SQL: &str = r#"
UPDATE api_keys
SET
virgin_resolved = false,
virgin_resolved_at = NULL,
virgin_request_count = 0,
updated_at = now()
WHERE id = $1
"#;
    const DELETE_SEEN_SQL: &str = r#"
DELETE FROM api_key_ip_seen
WHERE api_key_id = $1
"#;
    const UNLOCK_SEEN_SQL: &str = r#"
UPDATE api_key_ip_seen
SET locked_in = false
WHERE api_key_id = $1
"#;

    let seen_sql = if clear_seen {
        DELETE_SEEN_SQL
    } else {
        UNLOCK_SEEN_SQL
    };
    let mut tx: Transaction<'_, Postgres> = pool.begin().await?;
    // The key row is reset first, then the seen rows; both take only the key id.
    for statement in [RESET_KEY_SQL, seen_sql] {
        sqlx::query(statement)
            .bind(api_key_id)
            .execute(&mut *tx)
            .await?;
    }
    tx.commit().await
}
/// Returns one page of the addresses recorded for an API key, most recently
/// seen first.
///
/// `limit` is clamped to `1..=1000`, and a negative `offset` is treated as 0.
/// Rows with equal `last_seen_at` are ordered by their id. The ordering is
/// therefore total, so LIMIT/OFFSET pages do not repeat or skip rows because
/// of tie ordering.
///
/// # Errors
/// Propagates any query or row-decoding error.
pub async fn list_seen_ips(
    pool: &PgPool,
    api_key_id: i64,
    limit: i32,
    offset: i32,
) -> Result<Vec<SeenIpRecord>, sqlx::Error> {
    // Cap the page size at 1000 rows.
    let limit: i32 = limit.clamp(1, 1000);
    let offset: i32 = offset.max(0);
    let rows: Vec<PgRow> = sqlx::query(
        r#"
SELECT api_key_ip_seen_id AS id, addr, first_seen_at, last_seen_at, hit_count, locked_in
FROM api_key_ip_seen
WHERE api_key_id = $1
ORDER BY last_seen_at DESC, api_key_ip_seen_id
LIMIT $2 OFFSET $3
"#,
    )
    .bind(api_key_id)
    .bind(limit)
    .bind(offset)
    .fetch_all(pool)
    .await?;
    rows.into_iter()
        .map(|row| {
            let id: uuid::Uuid = row.try_get("id")?;
            let addr: IpNetwork = row.try_get("addr")?;
            Ok(SeenIpRecord {
                id: id.to_string(),
                addr: addr.to_string(),
                first_seen_at: row.try_get("first_seen_at")?,
                last_seen_at: row.try_get("last_seen_at")?,
                hit_count: row.try_get("hit_count")?,
                locked_in: row.try_get("locked_in")?,
            })
        })
        .collect()
}
/// Adds `addrs` to the per-key whitelist of `api_key_id`, all with `label`.
///
/// Addresses that are already present are left untouched; their label is not
/// updated. Returns the number of newly inserted rows.
pub async fn upsert_whitelist_entries(
    pool: &PgPool,
    api_key_id: i64,
    addrs: &[IpNetwork],
    label: Option<&str>,
) -> Result<i64, sqlx::Error> {
    const TABLE: &str = "api_key_ip_whitelist";
    insert_ip_entries(pool, TABLE, api_key_id, addrs, label).await
}
/// Adds `addrs` to the per-key blacklist of `api_key_id`, all with `label`.
///
/// Addresses that are already present are left untouched; their label is not
/// updated. Returns the number of newly inserted rows.
pub async fn upsert_blacklist_entries(
    pool: &PgPool,
    api_key_id: i64,
    addrs: &[IpNetwork],
    label: Option<&str>,
) -> Result<i64, sqlx::Error> {
    const TABLE: &str = "api_key_ip_blacklist";
    insert_ip_entries(pool, TABLE, api_key_id, addrs, label).await
}
/// Removes `addrs` from the per-key whitelist of `api_key_id`.
///
/// Returns the number of rows deleted.
pub async fn delete_whitelist_entries(
    pool: &PgPool,
    api_key_id: i64,
    addrs: &[IpNetwork],
) -> Result<i64, sqlx::Error> {
    const TABLE: &str = "api_key_ip_whitelist";
    delete_ip_entries(pool, TABLE, api_key_id, addrs).await
}
/// Removes `addrs` from the per-key blacklist of `api_key_id`.
///
/// Returns the number of rows deleted.
pub async fn delete_blacklist_entries(
    pool: &PgPool,
    api_key_id: i64,
    addrs: &[IpNetwork],
) -> Result<i64, sqlx::Error> {
    const TABLE: &str = "api_key_ip_blacklist";
    delete_ip_entries(pool, TABLE, api_key_id, addrs).await
}
/// Bulk-inserts `addrs` as rules of one API key into `table`.
///
/// `table` is spliced into the SQL text, so it must be a trusted, hard-coded
/// table name; the callers shown here pass literals. The addresses are sent
/// as one array parameter, cast to `cidr[]` and unnested. Rows that hit the
/// `(api_key_id, addr)` conflict target are skipped. Every inserted row gets
/// the same `label`.
///
/// Returns the number of rows actually inserted. An empty `addrs` returns 0
/// without touching the database.
async fn insert_ip_entries(
    pool: &PgPool,
    table: &str,
    api_key_id: i64,
    addrs: &[IpNetwork],
    label: Option<&str>,
) -> Result<i64, sqlx::Error> {
    if addrs.is_empty() {
        return Ok(0);
    }
    let statement = format!(
        r#"
INSERT INTO {table} (api_key_id, addr, label)
SELECT $1, addr, $3
FROM UNNEST($2::cidr[]) AS t(addr)
ON CONFLICT (api_key_id, addr) DO NOTHING
"#,
    );
    sqlx::query(&statement)
        .bind(api_key_id)
        .bind(addrs)
        .bind(label)
        .execute(pool)
        .await
        .map(|outcome| outcome.rows_affected() as i64)
}
/// Deletes the rules of one API key in `table` whose address is in `addrs`.
///
/// `table` is spliced into the SQL text, so it must be a trusted, hard-coded
/// table name. `addrs` is sent as one array parameter and cast to `cidr[]`
/// for the comparison.
///
/// Returns the number of rows deleted. An empty `addrs` returns 0 without
/// touching the database.
async fn delete_ip_entries(
    pool: &PgPool,
    table: &str,
    api_key_id: i64,
    addrs: &[IpNetwork],
) -> Result<i64, sqlx::Error> {
    if addrs.is_empty() {
        return Ok(0);
    }
    let statement = format!(
        r#"
DELETE FROM {table}
WHERE api_key_id = $1
AND addr = ANY($2::cidr[])
"#,
    );
    sqlx::query(&statement)
        .bind(api_key_id)
        .bind(addrs)
        .execute(pool)
        .await
        .map(|outcome| outcome.rows_affected() as i64)
}
/// One global (not per-key) IP rule row, with its id and address stringified.
#[derive(Debug, Clone, Serialize)]
pub struct GlobalIpRule {
/// Row UUID (the table's `<table>_id` column) rendered as a string.
pub id: String,
/// Address or network in `IpNetwork`'s `Display` form.
pub addr: String,
/// Client the rule is scoped to. `None` (NULL) rules are returned by
/// `load_global_ip_rules` for every client.
pub client_name: Option<String>,
/// Optional free-form label stored with the rule.
pub label: Option<String>,
/// Row creation timestamp.
pub created_at: DateTime<Utc>,
}
/// Lists every global whitelist rule, for all clients, ordered by `created_at`.
pub async fn list_global_whitelist(pool: &PgPool) -> Result<Vec<GlobalIpRule>, sqlx::Error> {
    const TABLE: &str = "api_key_ip_global_whitelist";
    list_global_rules(pool, TABLE).await
}
/// Lists every global blacklist rule, for all clients, ordered by `created_at`.
pub async fn list_global_blacklist(pool: &PgPool) -> Result<Vec<GlobalIpRule>, sqlx::Error> {
    const TABLE: &str = "api_key_ip_global_blacklist";
    list_global_rules(pool, TABLE).await
}
async fn list_global_rules(pool: &PgPool, table: &str) -> Result<Vec<GlobalIpRule>, sqlx::Error> {
let id_column = format!("{}_id", table);
let sql: String = format!(
r#"
SELECT {id_column} AS id, addr, client_name, label, created_at
FROM {table}
ORDER BY created_at
"#,
);
let rows: Vec<PgRow> = sqlx::query(&sql).fetch_all(pool).await?;
rows.into_iter()
.map(|row| {
let id: uuid::Uuid = row.try_get("id")?;
let addr: IpNetwork = row.try_get("addr")?;
Ok(GlobalIpRule {
id: id.to_string(),
addr: addr.to_string(),
client_name: row.try_get("client_name")?,
label: row.try_get("label")?,
created_at: row.try_get("created_at")?,
})
})
.collect()
}
/// Inserts a global whitelist rule. If a rule for the same `(addr, client_name)`
/// already exists, only its label is updated.
///
/// Returns the stored row.
pub async fn insert_global_whitelist_entry(
    pool: &PgPool,
    addr: IpNetwork,
    client_name: Option<&str>,
    label: Option<&str>,
) -> Result<GlobalIpRule, sqlx::Error> {
    const TABLE: &str = "api_key_ip_global_whitelist";
    insert_global_entry(pool, TABLE, addr, client_name, label).await
}
/// Inserts a global blacklist rule. If a rule for the same `(addr, client_name)`
/// already exists, only its label is updated.
///
/// Returns the stored row.
pub async fn insert_global_blacklist_entry(
    pool: &PgPool,
    addr: IpNetwork,
    client_name: Option<&str>,
    label: Option<&str>,
) -> Result<GlobalIpRule, sqlx::Error> {
    const TABLE: &str = "api_key_ip_global_blacklist";
    insert_global_entry(pool, TABLE, addr, client_name, label).await
}
/// Upserts one rule into a global rule table and returns the stored row.
///
/// On an `(addr, client_name)` conflict, only `label` is overwritten.
/// `table` is spliced into the SQL text, so it must be a trusted, hard-coded
/// name. Its primary key is expected to be named `<table>_id`.
///
/// NOTE(review): with an ordinary unique index, NULLs are distinct. A rule
/// with `client_name = None` may therefore never hit the conflict target, and
/// repeated inserts could create duplicates. Confirm the index is
/// `NULLS NOT DISTINCT` or otherwise handles this.
///
/// # Errors
/// Propagates any query or decoding error.
async fn insert_global_entry(
    pool: &PgPool,
    table: &str,
    addr: IpNetwork,
    client_name: Option<&str>,
    label: Option<&str>,
) -> Result<GlobalIpRule, sqlx::Error> {
    let id_column = [table, "_id"].concat();
    let statement = format!(
        r#"
INSERT INTO {table} (addr, client_name, label)
VALUES ($1, $2, $3)
ON CONFLICT (addr, client_name)
DO UPDATE SET label = EXCLUDED.label
RETURNING {id_column} AS id, addr, client_name, label, created_at
"#,
    );
    let stored = sqlx::query(&statement)
        .bind(addr)
        .bind(client_name)
        .bind(label)
        .fetch_one(pool)
        .await?;

    let rule_id = stored.try_get::<uuid::Uuid, _>("id")?;
    let stored_network = stored.try_get::<IpNetwork, _>("addr")?;
    Ok(GlobalIpRule {
        id: rule_id.to_string(),
        addr: stored_network.to_string(),
        client_name: stored.try_get("client_name")?,
        label: stored.try_get("label")?,
        created_at: stored.try_get("created_at")?,
    })
}
/// Deletes the global whitelist rule for `(addr, client_name)`.
///
/// `None` targets only the NULL-scoped rule. Returns `true` if a row was deleted.
pub async fn delete_global_whitelist_entry(
    pool: &PgPool,
    addr: IpNetwork,
    client_name: Option<&str>,
) -> Result<bool, sqlx::Error> {
    const TABLE: &str = "api_key_ip_global_whitelist";
    delete_global_entry(pool, TABLE, addr, client_name).await
}
/// Deletes the global blacklist rule for `(addr, client_name)`.
///
/// `None` targets only the NULL-scoped rule. Returns `true` if a row was deleted.
pub async fn delete_global_blacklist_entry(
    pool: &PgPool,
    addr: IpNetwork,
    client_name: Option<&str>,
) -> Result<bool, sqlx::Error> {
    const TABLE: &str = "api_key_ip_global_blacklist";
    delete_global_entry(pool, TABLE, addr, client_name).await
}
/// Deletes the rule in a global rule table that matches `addr` and
/// `client_name`.
///
/// `client_name = None` matches only rows whose `client_name` is NULL, and
/// `Some(name)` matches only rows for that client. `table` is spliced into
/// the SQL text, so it must be a trusted, hard-coded name.
///
/// Returns `true` if at least one row was deleted.
async fn delete_global_entry(
    pool: &PgPool,
    table: &str,
    addr: IpNetwork,
    client_name: Option<&str>,
) -> Result<bool, sqlx::Error> {
    let statement = format!(
        r#"
DELETE FROM {table}
WHERE addr = $1
AND ((client_name IS NULL AND $2::text IS NULL) OR client_name = $2)
"#,
    );
    let outcome = sqlx::query(&statement)
        .bind(addr)
        .bind(client_name)
        .execute(pool)
        .await?;
    Ok(outcome.rows_affected() != 0)
}
/// Returns `true` when `ip` falls inside the network of at least one rule.
///
/// Each rule's `addr` string is parsed as an [`IpNetwork`]. A rule whose
/// address fails to parse is treated as non-matching rather than failing the
/// whole check.
///
/// An IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) is also tested in its plain
/// IPv4 form. Clients reaching a dual-stack IPv6 listener may be reported in
/// the mapped form, and without this step IPv4 rules would never apply to
/// them. The original address is still tested as well, so IPv6 rules written
/// against the mapped range keep matching.
pub fn ip_matches_any(rules: &[IpRule], ip: IpAddr) -> bool {
    // `to_ipv4_mapped` (unlike `to_ipv4`) only unwraps `::ffff:0:0/96`, not
    // the deprecated IPv4-compatible `::a.b.c.d` form.
    let mapped_v4: Option<IpAddr> = match ip {
        IpAddr::V6(v6) => v6.to_ipv4_mapped().map(IpAddr::V4),
        IpAddr::V4(_) => None,
    };
    rules
        .iter()
        .any(|rule| match rule.addr.parse::<IpNetwork>() {
            Ok(network) => {
                network.contains(ip)
                    || matches!(mapped_v4, Some(v4) if network.contains(v4))
            }
            Err(_) => false,
        })
}