use argon2::{
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, SaltString},
Argon2, PasswordVerifier,
};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use hyperinfer_core::{
ApiKey, ConfigStore, CreateDeploymentRequest, Database, DbError, Deployment, ModelAlias,
PolicyUpdate, Quota, RoutingConfig, Team, UpdateRoutingConfigRequest, UsageLog, User,
};
use serde::Serialize;
use sqlx::PgPool;
pub fn hash_password(password: &str) -> Result<String, DbError> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let password_hash = argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| DbError::Sqlx(sqlx::Error::Protocol(e.to_string())))?
.to_string();
Ok(password_hash)
}
/// Checks `password` against a stored hash string.
///
/// Returns `false` both when `hash` cannot be parsed as a password hash and
/// when verification with `Argon2::default()` fails; `true` only on a match.
pub fn verify_password(password: &str, hash: &str) -> bool {
    match PasswordHash::new(hash) {
        Ok(parsed) => Argon2::default()
            .verify_password(password.as_bytes(), &parsed)
            .is_ok(),
        Err(_) => false,
    }
}
/// `Database` backend that runs its queries through a sqlx `PgPool`.
#[derive(Clone)]
pub struct SqlxDb {
    /// Connection pool every query in this type is executed against.
    pool: PgPool,
}

impl SqlxDb {
    /// Wraps an existing pool; no connection is opened or checked here.
    pub fn new(pool: PgPool) -> Self {
        SqlxDb { pool }
    }
}
#[async_trait]
impl Database for SqlxDb {
/// Looks up one team by its UUID string.
///
/// Returns `Ok(None)` when no row matches.
///
/// # Errors
/// `DbError::InvalidUuid` if `id` is not a valid UUID; query failures are
/// converted into `DbError` by `?`.
async fn get_team(&self, id: &str) -> Result<Option<Team>, DbError> {
    let team_id = uuid::Uuid::parse_str(id).map_err(|_| DbError::InvalidUuid)?;
    let found = sqlx::query_as::<_, TeamRow>(
        "SELECT id, name, budget_cents, created_at, updated_at FROM teams WHERE id = $1",
    )
    .bind(team_id)
    .fetch_optional(&self.pool)
    .await?;
    Ok(found.map(Into::into))
}
async fn create_team(&self, name: &str, budget_cents: i64) -> Result<Team, DbError> {
let result: TeamRow = match sqlx::query_as(
"INSERT INTO teams (name, budget_cents) VALUES ($1, $2) RETURNING id, name, budget_cents, created_at, updated_at"
)
.bind(name)
.bind(budget_cents)
.fetch_one(&self.pool)
.await
{
Ok(row) => row,
Err(e) => {
if e.as_database_error().map(|db| db.is_unique_violation()).unwrap_or(false) {
return Err(DbError::UniqueViolation("Team name already exists".to_string()));
}
return Err(DbError::Sqlx(e));
}
};
Ok(Team::from(result))
}
/// Looks up one user by its UUID string.
///
/// Returns `Ok(None)` when no row matches.
///
/// # Errors
/// `DbError::InvalidUuid` if `id` is not a valid UUID; query failures are
/// converted into `DbError` by `?`.
async fn get_user(&self, id: &str) -> Result<Option<User>, DbError> {
    let user_id = uuid::Uuid::parse_str(id).map_err(|_| DbError::InvalidUuid)?;
    let found = sqlx::query_as::<_, UserRow>(
        "SELECT id, team_id, email, role, password_hash, created_at FROM users WHERE id = $1",
    )
    .bind(user_id)
    .fetch_optional(&self.pool)
    .await?;
    Ok(found.map(Into::into))
}
/// Looks up one user by exact `email` match.
///
/// Returns `Ok(None)` when no row matches. Query failures are converted into
/// `DbError` by `?`.
async fn get_user_by_email(&self, email: &str) -> Result<Option<User>, DbError> {
    let found = sqlx::query_as::<_, UserRow>(
        "SELECT id, team_id, email, role, password_hash, created_at FROM users WHERE email = $1",
    )
    .bind(email)
    .fetch_optional(&self.pool)
    .await?;
    Ok(found.map(Into::into))
}
/// Inserts a user belonging to `team_id` and returns the stored record.
///
/// `password_hash` is stored as given (`None` is bound as SQL NULL).
///
/// # Errors
/// `DbError::InvalidUuid` if `team_id` is not a valid UUID; query failures
/// are converted into `DbError` by `?`.
async fn create_user(
    &self,
    team_id: &str,
    email: &str,
    role: &str,
    password_hash: Option<String>,
) -> Result<User, DbError> {
    let team = uuid::Uuid::parse_str(team_id).map_err(|_| DbError::InvalidUuid)?;
    let inserted = sqlx::query_as::<_, UserRow>(
        "INSERT INTO users (team_id, email, role, password_hash) VALUES ($1, $2, $3, $4) RETURNING id, team_id, email, role, password_hash, created_at"
    )
    .bind(team)
    .bind(email)
    .bind(role)
    .bind(password_hash)
    .fetch_one(&self.pool)
    .await?;
    Ok(inserted.into())
}
/// Looks up one API key by its UUID string, regardless of its `is_active` flag.
///
/// Returns `Ok(None)` when no row matches.
///
/// # Errors
/// `DbError::InvalidUuid` if `id` is not a valid UUID; query failures are
/// converted into `DbError` by `?`.
async fn get_api_key(&self, id: &str) -> Result<Option<ApiKey>, DbError> {
    let key_id = uuid::Uuid::parse_str(id).map_err(|_| DbError::InvalidUuid)?;
    let found = sqlx::query_as::<_, ApiKeyRow>(
        "SELECT id, key_hash, user_id, team_id, name, is_active, created_at, expires_at FROM api_keys WHERE id = $1"
    )
    .bind(key_id)
    .fetch_optional(&self.pool)
    .await?;
    Ok(found.map(Into::into))
}
/// Looks up an *active* API key by its stored hash.
///
/// Only rows with `is_active = true` are considered; `expires_at` is returned
/// but not filtered on here.
/// NOTE(review): presumably callers check expiry themselves — confirm.
///
/// Returns `Ok(None)` when no active row matches. Query failures are converted
/// into `DbError` by `?`.
async fn get_api_key_by_hash(&self, key_hash: &str) -> Result<Option<ApiKey>, DbError> {
    let found = sqlx::query_as::<_, ApiKeyRow>(
        "SELECT id, key_hash, user_id, team_id, name, is_active, created_at, expires_at FROM api_keys WHERE key_hash = $1 AND is_active = true"
    )
    .bind(key_hash)
    .fetch_optional(&self.pool)
    .await?;
    Ok(found.map(Into::into))
}
/// Inserts an API key for the given user and team and returns the stored record.
///
/// Only `key_hash`, `user_id`, `team_id` and `name` are supplied; the other
/// returned columns come from whatever the table provides on insert.
///
/// # Errors
/// `DbError::InvalidUuid` if `user_id` or `team_id` is not a valid UUID;
/// query failures are converted into `DbError` by `?`.
async fn create_api_key(
    &self,
    key_hash: &str,
    user_id: &str,
    team_id: &str,
    name: Option<String>,
) -> Result<ApiKey, DbError> {
    let owner = uuid::Uuid::parse_str(user_id).map_err(|_| DbError::InvalidUuid)?;
    let team = uuid::Uuid::parse_str(team_id).map_err(|_| DbError::InvalidUuid)?;
    let inserted = sqlx::query_as::<_, ApiKeyRow>(
        "INSERT INTO api_keys (key_hash, user_id, team_id, name) VALUES ($1, $2, $3, $4) RETURNING id, key_hash, user_id, team_id, name, is_active, created_at, expires_at"
    )
    .bind(key_hash)
    .bind(owner)
    .bind(team)
    .bind(name.as_deref())
    .fetch_one(&self.pool)
    .await?;
    Ok(inserted.into())
}
/// Marks an API key as inactive and returns the updated record.
///
/// # Errors
/// * `DbError::InvalidUuid` if `id` is not a valid UUID.
/// * `DbError::NotFound` if no key with that id exists (same convention as
///   `delete_deployment`), instead of leaking a row-not-found driver error.
/// * Other query failures are converted into `DbError` by `?`.
async fn deactivate_api_key(&self, id: &str) -> Result<ApiKey, DbError> {
    let key_id = uuid::Uuid::parse_str(id).map_err(|_| DbError::InvalidUuid)?;
    // `fetch_optional` lets "no such key" be told apart from real database errors.
    let updated = sqlx::query_as::<_, ApiKeyRow>(
        "UPDATE api_keys SET is_active = false WHERE id = $1 RETURNING id, key_hash, user_id, team_id, name, is_active, created_at, expires_at"
    )
    .bind(key_id)
    .fetch_optional(&self.pool)
    .await?;
    updated.map(ApiKey::from).ok_or(DbError::NotFound)
}
/// Looks up one model alias by its UUID string.
///
/// Returns `Ok(None)` when no row matches.
///
/// # Errors
/// `DbError::InvalidUuid` if `id` is not a valid UUID; query failures are
/// converted into `DbError` by `?`.
async fn get_model_alias(&self, id: &str) -> Result<Option<ModelAlias>, DbError> {
    let alias_id = uuid::Uuid::parse_str(id).map_err(|_| DbError::InvalidUuid)?;
    let found = sqlx::query_as::<_, ModelAliasRow>(
        "SELECT id, team_id, alias, target_model, provider, created_at FROM model_aliases WHERE id = $1"
    )
    .bind(alias_id)
    .fetch_optional(&self.pool)
    .await?;
    Ok(found.map(Into::into))
}
/// Inserts a model alias for `team_id` and returns the stored record.
///
/// # Errors
/// `DbError::InvalidUuid` if `team_id` is not a valid UUID; query failures
/// are converted into `DbError` by `?`.
async fn create_model_alias(
    &self,
    team_id: &str,
    alias: &str,
    target_model: &str,
    provider: &str,
) -> Result<ModelAlias, DbError> {
    let team = uuid::Uuid::parse_str(team_id).map_err(|_| DbError::InvalidUuid)?;
    let inserted = sqlx::query_as::<_, ModelAliasRow>(
        "INSERT INTO model_aliases (team_id, alias, target_model, provider) VALUES ($1, $2, $3, $4) RETURNING id, team_id, alias, target_model, provider, created_at"
    )
    .bind(team)
    .bind(alias)
    .bind(target_model)
    .bind(provider)
    .fetch_one(&self.pool)
    .await?;
    Ok(inserted.into())
}
/// Looks up the quota row whose `team_id` matches the given UUID string.
///
/// Returns `Ok(None)` when the team has no quota row.
///
/// # Errors
/// `DbError::InvalidUuid` if `team_id` is not a valid UUID; query failures
/// are converted into `DbError` by `?`.
async fn get_quota(&self, team_id: &str) -> Result<Option<Quota>, DbError> {
    let team = uuid::Uuid::parse_str(team_id).map_err(|_| DbError::InvalidUuid)?;
    let found = sqlx::query_as::<_, QuotaRow>(
        "SELECT id, team_id, rpm_limit, tpm_limit, updated_at FROM quotas WHERE team_id = $1",
    )
    .bind(team)
    .fetch_optional(&self.pool)
    .await?;
    Ok(found.map(Into::into))
}
/// Inserts a quota row for `team_id` and returns the stored record.
///
/// # Errors
/// `DbError::InvalidUuid` if `team_id` is not a valid UUID; query failures
/// are converted into `DbError` by `?`.
async fn create_quota(
    &self,
    team_id: &str,
    rpm_limit: i32,
    tpm_limit: i32,
) -> Result<Quota, DbError> {
    let team = uuid::Uuid::parse_str(team_id).map_err(|_| DbError::InvalidUuid)?;
    let inserted = sqlx::query_as::<_, QuotaRow>(
        "INSERT INTO quotas (team_id, rpm_limit, tpm_limit) VALUES ($1, $2, $3) RETURNING id, team_id, rpm_limit, tpm_limit, updated_at"
    )
    .bind(team)
    .bind(rpm_limit)
    .bind(tpm_limit)
    .fetch_one(&self.pool)
    .await?;
    Ok(inserted.into())
}
/// Appends one row to `usage_logs` and returns the stored record.
///
/// # Errors
/// `DbError::InvalidUuid` if `team_id` or `api_key_id` is not a valid UUID;
/// query failures are converted into `DbError` by `?`.
async fn record_usage(
    &self,
    team_id: &str,
    api_key_id: &str,
    model: &str,
    input_tokens: i32,
    output_tokens: i32,
    response_time_ms: i64,
) -> Result<UsageLog, DbError> {
    let team = uuid::Uuid::parse_str(team_id).map_err(|_| DbError::InvalidUuid)?;
    let api_key = uuid::Uuid::parse_str(api_key_id).map_err(|_| DbError::InvalidUuid)?;
    let logged = sqlx::query_as::<_, UsageLogRow>(
        "INSERT INTO usage_logs (team_id, api_key_id, model, input_tokens, output_tokens, response_time_ms) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id, team_id, api_key_id, model, input_tokens, output_tokens, response_time_ms, recorded_at"
    )
    .bind(team)
    .bind(api_key)
    .bind(model)
    .bind(input_tokens)
    .bind(output_tokens)
    .bind(response_time_ms)
    .fetch_one(&self.pool)
    .await?;
    Ok(logged.into())
}
/// Counts the users whose `role` column equals `role`.
///
/// Query failures are converted into `DbError` by `?`.
async fn count_users_by_role(&self, role: &str) -> Result<i64, DbError> {
    // The single-column result is decoded as a one-element tuple.
    let (total,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users WHERE role = $1")
        .bind(role)
        .fetch_one(&self.pool)
        .await?;
    Ok(total)
}
/// Replaces the stored password hash of the user identified by `user_id`.
///
/// # Errors
/// * `DbError::InvalidUuid` if `user_id` is not a valid UUID.
/// * `DbError::NotFound` if the UPDATE matched no row. Previously this case
///   returned `Ok(())`, so a caller could believe a password had been changed
///   for a user that does not exist.
/// * Other query failures are converted into `DbError` by `?`.
async fn update_password_hash(
    &self,
    user_id: &str,
    password_hash: &str,
) -> Result<(), DbError> {
    let user_uuid = uuid::Uuid::parse_str(user_id).map_err(|_| DbError::InvalidUuid)?;
    let outcome = sqlx::query("UPDATE users SET password_hash = $1 WHERE id = $2")
        .bind(password_hash)
        .bind(user_uuid)
        .execute(&self.pool)
        .await?;
    // Mirror `delete_deployment`: zero affected rows means the target is missing.
    if outcome.rows_affected() == 0 {
        return Err(DbError::NotFound);
    }
    Ok(())
}
/// Lists the deployments for `model`, ordered by `sort_order` then `created_at`.
///
/// With `is_active = Some(flag)` only rows whose `is_active` equals `flag` are
/// returned; with `None` the flag is not filtered on.
///
/// Query failures are converted into `DbError` by `?`.
async fn list_deployments(
    &self,
    model: &str,
    is_active: Option<bool>,
) -> Result<Vec<Deployment>, DbError> {
    // Pick the statement first; the `$2` placeholder only exists in the filtered variant.
    let sql = match is_active {
        Some(_) => "SELECT id, name, provider, model, api_key_ref, base_url, is_active, weight, priority, max_tpm, max_rpm, cost_per_1k_input_tokens, cost_per_1k_output_tokens, metadata, sort_order, created_at, updated_at FROM deployments WHERE model = $1 AND is_active = $2 ORDER BY sort_order ASC, created_at ASC",
        None => "SELECT id, name, provider, model, api_key_ref, base_url, is_active, weight, priority, max_tpm, max_rpm, cost_per_1k_input_tokens, cost_per_1k_output_tokens, metadata, sort_order, created_at, updated_at FROM deployments WHERE model = $1 ORDER BY sort_order ASC, created_at ASC",
    };

    let mut query = sqlx::query_as::<sqlx::Postgres, DeploymentRow>(sql).bind(model);
    if let Some(active) = is_active {
        query = query.bind(active);
    }

    let rows = query.fetch_all(&self.pool).await?;
    Ok(rows.into_iter().map(Deployment::from).collect())
}
/// Looks up one deployment by its UUID string.
///
/// Returns `Ok(None)` when no row matches.
///
/// # Errors
/// `DbError::InvalidUuid` if `id` is not a valid UUID; query failures are
/// converted into `DbError` by `?`.
async fn get_deployment(&self, id: &str) -> Result<Option<Deployment>, DbError> {
    let deployment_id = uuid::Uuid::parse_str(id).map_err(|_| DbError::InvalidUuid)?;
    let found = sqlx::query_as::<_, DeploymentRow>(
        "SELECT id, name, provider, model, api_key_ref, base_url, is_active, weight, priority, max_tpm, max_rpm, cost_per_1k_input_tokens, cost_per_1k_output_tokens, metadata, sort_order, created_at, updated_at FROM deployments WHERE id = $1",
    )
    .bind(deployment_id)
    .fetch_optional(&self.pool)
    .await?;
    Ok(found.map(Into::into))
}
/// Inserts a deployment described by `req` and returns the stored record.
///
/// Missing optional fields are stored with defaults: `api_key_ref` as `""`,
/// `metadata` as `{}`, `sort_order` as `0`.
///
/// # Errors
/// A unique-constraint violation is reported as
/// `DbError::UniqueViolation("Deployment with this name already exists")`;
/// every other failure is wrapped in `DbError::Sqlx`.
async fn create_deployment(&self, req: CreateDeploymentRequest) -> Result<Deployment, DbError> {
    // Resolve the defaults up front so the bind chain below stays uniform.
    let metadata = req.metadata.unwrap_or_else(|| serde_json::json!({}));
    let sort_order = req.sort_order.unwrap_or(0);
    let api_key_ref = req.api_key_ref.as_deref().unwrap_or("");

    let outcome = sqlx::query_as::<_, DeploymentRow>(
        "INSERT INTO deployments (name, provider, model, api_key_ref, base_url, is_active, weight, priority, max_tpm, max_rpm, cost_per_1k_input_tokens, cost_per_1k_output_tokens, metadata, sort_order) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING id, name, provider, model, api_key_ref, base_url, is_active, weight, priority, max_tpm, max_rpm, cost_per_1k_input_tokens, cost_per_1k_output_tokens, metadata, sort_order, created_at, updated_at"
    )
    .bind(&req.name)
    .bind(&req.provider)
    .bind(&req.model)
    .bind(api_key_ref)
    .bind(&req.base_url)
    .bind(req.is_active)
    .bind(req.weight)
    .bind(req.priority)
    .bind(req.max_tpm)
    .bind(req.max_rpm)
    .bind(req.cost_per_1k_input_tokens)
    .bind(req.cost_per_1k_output_tokens)
    .bind(metadata)
    .bind(sort_order)
    .fetch_one(&self.pool)
    .await;

    match outcome {
        Ok(row) => Ok(Deployment::from(row)),
        Err(err)
            if err
                .as_database_error()
                .map_or(false, |db| db.is_unique_violation()) =>
        {
            Err(DbError::UniqueViolation(
                "Deployment with this name already exists".to_string(),
            ))
        }
        Err(err) => Err(DbError::Sqlx(err)),
    }
}
/// Overwrites every editable column of the deployment `id` with the values in
/// `req` and returns the updated record.
///
/// Missing optional fields are stored with defaults: `api_key_ref` as `""`,
/// `metadata` as `{}`, `sort_order` as `0`.
///
/// # Errors
/// * `DbError::InvalidUuid` if `id` is not a valid UUID.
/// * `DbError::NotFound` if no deployment with that id exists (same
///   convention as `delete_deployment`), instead of a wrapped row-not-found
///   driver error.
/// * `DbError::UniqueViolation` if the update hits a unique constraint.
/// * `DbError::Sqlx` for every other failure.
async fn update_deployment(
    &self,
    id: &str,
    req: CreateDeploymentRequest,
) -> Result<Deployment, DbError> {
    let uuid = uuid::Uuid::parse_str(id).map_err(|_| DbError::InvalidUuid)?;
    // `fetch_optional` lets "no such deployment" be told apart from real database errors.
    let updated: Option<DeploymentRow> = sqlx::query_as(
        "UPDATE deployments SET name = $1, provider = $2, model = $3, api_key_ref = $4, base_url = $5, is_active = $6, weight = $7, priority = $8, max_tpm = $9, max_rpm = $10, cost_per_1k_input_tokens = $11, cost_per_1k_output_tokens = $12, metadata = $13, sort_order = $14 WHERE id = $15 RETURNING id, name, provider, model, api_key_ref, base_url, is_active, weight, priority, max_tpm, max_rpm, cost_per_1k_input_tokens, cost_per_1k_output_tokens, metadata, sort_order, created_at, updated_at"
    )
    .bind(&req.name)
    .bind(&req.provider)
    .bind(&req.model)
    .bind(req.api_key_ref.as_deref().unwrap_or(""))
    .bind(&req.base_url)
    .bind(req.is_active)
    .bind(req.weight)
    .bind(req.priority)
    .bind(req.max_tpm)
    .bind(req.max_rpm)
    .bind(req.cost_per_1k_input_tokens)
    .bind(req.cost_per_1k_output_tokens)
    .bind(req.metadata.unwrap_or_else(|| serde_json::json!({})))
    .bind(req.sort_order.unwrap_or(0))
    .bind(uuid)
    .fetch_optional(&self.pool)
    .await
    .map_err(|e| {
        if e.as_database_error()
            .map(|db| db.is_unique_violation())
            .unwrap_or(false)
        {
            DbError::UniqueViolation("Deployment with this name already exists".to_string())
        } else {
            DbError::Sqlx(e)
        }
    })?;
    updated.map(Deployment::from).ok_or(DbError::NotFound)
}
/// Deletes the deployment identified by `id`.
///
/// # Errors
/// * `DbError::InvalidUuid` if `id` is not a valid UUID.
/// * `DbError::NotFound` if the DELETE removed no row.
/// * Other query failures are converted into `DbError` by `?`.
async fn delete_deployment(&self, id: &str) -> Result<(), DbError> {
    let target = uuid::Uuid::parse_str(id).map_err(|_| DbError::InvalidUuid)?;
    let outcome = sqlx::query("DELETE FROM deployments WHERE id = $1")
        .bind(target)
        .execute(&self.pool)
        .await?;
    match outcome.rows_affected() {
        0 => Err(DbError::NotFound),
        _ => Ok(()),
    }
}
/// Reads the routing configuration stored in the row with `id = 1`.
///
/// Returns `Ok(None)` when that row does not exist. Query failures are
/// converted into `DbError` by `?`.
async fn get_routing_config(&self) -> Result<Option<RoutingConfig>, DbError> {
    let stored = sqlx::query_as::<_, RoutingConfigRow>(
        "SELECT id, strategy, strategy_params, fallback_config, routing_groups, updated_at FROM routing_config WHERE id = 1",
    )
    .fetch_optional(&self.pool)
    .await?;
    Ok(stored.map(Into::into))
}
/// Applies a partial update to the routing configuration (row `id = 1`) and
/// returns the resulting configuration.
///
/// Fields left as `None` in `req` keep their currently stored value. When no
/// configuration row exists yet, unspecified fields fall back to the built-in
/// defaults (`"weighted-shuffle"` strategy, empty params/fallback, no groups)
/// and the row is created.
///
/// Previously the defaults were computed but the following
/// `UPDATE ... WHERE id = 1` matched nothing, so the call always failed with a
/// row-not-found error on an empty table.
///
/// # Errors
/// Query failures are converted into `DbError` by `?`.
async fn update_routing_config(
    &self,
    req: UpdateRoutingConfigRequest,
) -> Result<RoutingConfig, DbError> {
    let existing = self
        .get_routing_config()
        .await?
        .unwrap_or_else(|| RoutingConfig {
            strategy: "weighted-shuffle".to_string(),
            strategy_params: serde_json::json!({}),
            fallback_config: serde_json::json!({}),
            routing_groups: serde_json::json!([]),
            updated_at: chrono::Utc::now(),
        });
    let strategy = req.strategy.unwrap_or(existing.strategy);
    let strategy_params = req.strategy_params.unwrap_or(existing.strategy_params);
    let fallback_config = req.fallback_config.unwrap_or(existing.fallback_config);
    let routing_groups = req.routing_groups.unwrap_or(existing.routing_groups);

    let updated = sqlx::query_as::<_, RoutingConfigRow>(
        "UPDATE routing_config SET strategy = $1, strategy_params = $2, fallback_config = $3, routing_groups = $4, updated_at = NOW() WHERE id = 1 RETURNING id, strategy, strategy_params, fallback_config, routing_groups, updated_at",
    )
    .bind(&strategy)
    .bind(&strategy_params)
    .bind(&fallback_config)
    .bind(&routing_groups)
    .fetch_optional(&self.pool)
    .await?;

    let row = match updated {
        Some(row) => row,
        // No singleton row yet: create it. A separate INSERT (rather than
        // `ON CONFLICT`) is used so this does not depend on a unique
        // constraint on `id` that is not visible from this file.
        None => {
            sqlx::query_as::<_, RoutingConfigRow>(
                "INSERT INTO routing_config (id, strategy, strategy_params, fallback_config, routing_groups, updated_at) VALUES (1, $1, $2, $3, $4, NOW()) RETURNING id, strategy, strategy_params, fallback_config, routing_groups, updated_at",
            )
            .bind(&strategy)
            .bind(&strategy_params)
            .bind(&fallback_config)
            .bind(&routing_groups)
            .fetch_one(&self.pool)
            .await?
        }
    };
    Ok(RoutingConfig::from(row))
}
/// Runs `SELECT 1` against the pool as a connectivity check.
///
/// Any failure is converted into `DbError`.
async fn ping(&self) -> Result<(), DbError> {
    match sqlx::query("SELECT 1").execute(&self.pool).await {
        Ok(_) => Ok(()),
        Err(err) => Err(err.into()),
    }
}
}
/// Column-for-column image of the `teams` columns selected by the queries in
/// this file, before conversion into the core `Team` type.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
struct TeamRow {
    id: uuid::Uuid,
    name: String,
    budget_cents: i64,
    created_at: DateTime<Utc>,
    updated_at: DateTime<Utc>,
}

impl From<TeamRow> for Team {
    /// Converts the row, rendering the UUID `id` as a string.
    fn from(row: TeamRow) -> Self {
        let TeamRow {
            id,
            name,
            budget_cents,
            created_at,
            updated_at,
        } = row;
        Self {
            id: id.to_string(),
            name,
            budget_cents,
            created_at,
            updated_at,
        }
    }
}
/// Column-for-column image of the `users` columns selected by the queries in
/// this file, before conversion into the core `User` type.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
struct UserRow {
    id: uuid::Uuid,
    team_id: uuid::Uuid,
    email: String,
    role: String,
    password_hash: Option<String>,
    created_at: DateTime<Utc>,
}

impl From<UserRow> for User {
    /// Converts the row, rendering the UUID columns as strings.
    fn from(row: UserRow) -> Self {
        let UserRow {
            id,
            team_id,
            email,
            role,
            password_hash,
            created_at,
        } = row;
        Self {
            id: id.to_string(),
            team_id: team_id.to_string(),
            email,
            role,
            password_hash,
            created_at,
        }
    }
}
/// Column-for-column image of the `api_keys` columns selected by the queries
/// in this file, before conversion into the core `ApiKey` type.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
struct ApiKeyRow {
    id: uuid::Uuid,
    key_hash: String,
    user_id: uuid::Uuid,
    team_id: uuid::Uuid,
    name: Option<String>,
    is_active: bool,
    created_at: DateTime<Utc>,
    expires_at: Option<DateTime<Utc>>,
}

impl From<ApiKeyRow> for ApiKey {
    /// Converts the row, rendering the UUID columns as strings.
    fn from(row: ApiKeyRow) -> Self {
        let ApiKeyRow {
            id,
            key_hash,
            user_id,
            team_id,
            name,
            is_active,
            created_at,
            expires_at,
        } = row;
        Self {
            id: id.to_string(),
            key_hash,
            user_id: user_id.to_string(),
            team_id: team_id.to_string(),
            name,
            is_active,
            created_at,
            expires_at,
        }
    }
}
/// Column-for-column image of the `model_aliases` columns selected by the
/// queries in this file, before conversion into the core `ModelAlias` type.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
struct ModelAliasRow {
    id: uuid::Uuid,
    team_id: uuid::Uuid,
    alias: String,
    target_model: String,
    provider: String,
    created_at: DateTime<Utc>,
}

impl From<ModelAliasRow> for ModelAlias {
    /// Converts the row, rendering the UUID columns as strings.
    fn from(row: ModelAliasRow) -> Self {
        let ModelAliasRow {
            id,
            team_id,
            alias,
            target_model,
            provider,
            created_at,
        } = row;
        Self {
            id: id.to_string(),
            team_id: team_id.to_string(),
            alias,
            target_model,
            provider,
            created_at,
        }
    }
}
/// Column-for-column image of the `quotas` columns selected by the queries in
/// this file, before conversion into the core `Quota` type.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
struct QuotaRow {
    id: uuid::Uuid,
    team_id: uuid::Uuid,
    rpm_limit: i32,
    tpm_limit: i32,
    updated_at: DateTime<Utc>,
}

impl From<QuotaRow> for Quota {
    /// Converts the row, rendering the UUID columns as strings.
    fn from(row: QuotaRow) -> Self {
        let QuotaRow {
            id,
            team_id,
            rpm_limit,
            tpm_limit,
            updated_at,
        } = row;
        Self {
            id: id.to_string(),
            team_id: team_id.to_string(),
            rpm_limit,
            tpm_limit,
            updated_at,
        }
    }
}
/// Column-for-column image of the `usage_logs` columns returned by the insert
/// in this file, before conversion into the core `UsageLog` type.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
struct UsageLogRow {
    id: uuid::Uuid,
    team_id: uuid::Uuid,
    api_key_id: uuid::Uuid,
    model: String,
    input_tokens: i32,
    output_tokens: i32,
    response_time_ms: i64,
    recorded_at: DateTime<Utc>,
}

impl From<UsageLogRow> for UsageLog {
    /// Converts the row, rendering the UUID columns as strings.
    fn from(row: UsageLogRow) -> Self {
        let UsageLogRow {
            id,
            team_id,
            api_key_id,
            model,
            input_tokens,
            output_tokens,
            response_time_ms,
            recorded_at,
        } = row;
        Self {
            id: id.to_string(),
            team_id: team_id.to_string(),
            api_key_id: api_key_id.to_string(),
            model,
            input_tokens,
            output_tokens,
            response_time_ms,
            recorded_at,
        }
    }
}
/// Column-for-column image of the `deployments` columns selected by the
/// queries in this file, before conversion into the core `Deployment` type.
///
/// Integer columns are decoded as signed `i32`; the conversion below maps
/// them onto the unsigned fields of `Deployment`.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
struct DeploymentRow {
    id: uuid::Uuid,
    name: String,
    provider: String,
    model: String,
    api_key_ref: String,
    base_url: String,
    is_active: bool,
    weight: i32,
    priority: i32,
    max_tpm: Option<i32>,
    max_rpm: Option<i32>,
    cost_per_1k_input_tokens: Option<f64>,
    cost_per_1k_output_tokens: Option<f64>,
    metadata: serde_json::Value,
    sort_order: i32,
    created_at: DateTime<Utc>,
    updated_at: DateTime<Utc>,
}

impl From<DeploymentRow> for Deployment {
    /// Converts the row, rendering `id` as a string and mapping the signed
    /// integer columns to `u32`.
    ///
    /// Negative values are clamped to `0`. A bare `as u32` cast would instead
    /// reinterpret them as very large numbers (`-1` becomes `4294967295`),
    /// silently turning a bad row into an enormous weight or limit.
    fn from(row: DeploymentRow) -> Self {
        /// Lossless for non-negative input; negatives become 0.
        fn non_negative(value: i32) -> u32 {
            value.max(0) as u32
        }

        Deployment {
            id: row.id.to_string(),
            name: row.name,
            provider: row.provider,
            model: row.model,
            api_key_ref: row.api_key_ref,
            base_url: row.base_url,
            is_active: row.is_active,
            weight: non_negative(row.weight),
            priority: non_negative(row.priority),
            max_tpm: row.max_tpm.map(non_negative),
            max_rpm: row.max_rpm.map(non_negative),
            cost_per_1k_input_tokens: row.cost_per_1k_input_tokens,
            cost_per_1k_output_tokens: row.cost_per_1k_output_tokens,
            metadata: row.metadata,
            sort_order: non_negative(row.sort_order),
            created_at: row.created_at,
            updated_at: row.updated_at,
        }
    }
}
/// Column-for-column image of the `routing_config` columns selected by the
/// queries in this file, before conversion into the core `RoutingConfig` type.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
struct RoutingConfigRow {
    id: i32,
    strategy: String,
    strategy_params: serde_json::Value,
    fallback_config: serde_json::Value,
    routing_groups: serde_json::Value,
    updated_at: DateTime<Utc>,
}

impl From<RoutingConfigRow> for RoutingConfig {
    /// Converts the row; the `id` column is not carried over.
    fn from(row: RoutingConfigRow) -> Self {
        let RoutingConfigRow {
            id: _,
            strategy,
            strategy_params,
            fallback_config,
            routing_groups,
            updated_at,
        } = row;
        Self {
            strategy,
            strategy_params,
            fallback_config,
            routing_groups,
            updated_at,
        }
    }
}
/// Thin wrapper around `hyperinfer_core::redis::ConfigManager`; every method
/// forwards to the wrapped manager.
#[derive(Clone)]
pub struct RedisConfigStore {
    /// The manager all operations are delegated to.
    manager: hyperinfer_core::redis::ConfigManager,
}

impl RedisConfigStore {
    /// Builds a store by constructing a `ConfigManager` from `redis_url`.
    ///
    /// # Errors
    /// Propagates the error returned by `ConfigManager::new`.
    pub async fn new(redis_url: &str) -> Result<Self, hyperinfer_core::ConfigError> {
        match hyperinfer_core::redis::ConfigManager::new(redis_url).await {
            Ok(manager) => Ok(Self { manager }),
            Err(err) => Err(err.into()),
        }
    }

    /// Forwards to `ConfigManager::subscribe_to_config_updates`, handing it the
    /// shared `config` and returning the join handle it yields.
    pub async fn subscribe_to_config_updates(
        &self,
        config: std::sync::Arc<tokio::sync::RwLock<hyperinfer_core::Config>>,
    ) -> Result<tokio::task::JoinHandle<()>, hyperinfer_core::ConfigError> {
        let manager = &self.manager;
        manager.subscribe_to_config_updates(config).await
    }

    /// Forwards to `ConfigManager::ping`.
    pub async fn ping(&self) -> Result<(), hyperinfer_core::ConfigError> {
        let manager = &self.manager;
        manager.ping().await
    }
}
/// `ConfigStore` implementation that forwards each call unchanged to the
/// wrapped `ConfigManager`; results and errors are returned as-is.
#[async_trait]
impl ConfigStore for RedisConfigStore {
/// Returns whatever `ConfigManager::fetch_config` yields.
async fn fetch_config(&self) -> Result<hyperinfer_core::Config, hyperinfer_core::ConfigError> {
self.manager.fetch_config().await
}
/// Passes `config` to `ConfigManager::publish_config_update`.
async fn publish_config_update(
&self,
config: &hyperinfer_core::Config,
) -> Result<(), hyperinfer_core::ConfigError> {
self.manager.publish_config_update(config).await
}
/// Passes `update` to `ConfigManager::publish_policy_update`.
async fn publish_policy_update(
&self,
update: &PolicyUpdate,
) -> Result<(), hyperinfer_core::ConfigError> {
self.manager.publish_policy_update(update).await
}
/// Forwards to `ConfigManager::ping`.
async fn ping(&self) -> Result<(), hyperinfer_core::ConfigError> {
self.manager.ping().await
}
}