use crate::config::{
BatchConfig, DaemonConfig, DaemonEnabled, FileLimitsConfig, FilesConfig, LeaderElectionConfig, LimitsConfig, NativeAuthConfig,
OnwardsSyncConfig, PasswordConfig, PoolSettings, ProbeSchedulerConfig, ProxyHeaderAuthConfig, SecurityConfig,
};
use crate::db::handlers::inference_endpoints::{InferenceEndpointFilter, InferenceEndpoints};
use crate::db::handlers::repository::Repository;
use crate::db::models::api_keys::ApiKeyPurpose;
use crate::errors::Error;
use crate::types::{GroupId, Operation, Permission, Resource, UserId};
use crate::{
api::models::{
api_keys::ApiKeyCreate,
users::{CurrentUser, Role, UserResponse},
},
db::{
handlers::{Deployments, Groups, Users, api_keys::ApiKeys},
models::{
api_keys::{ApiKeyCreateDBRequest, ApiKeyDBResponse},
deployments::{DeploymentCreateDBRequest, DeploymentDBResponse, ModelType},
groups::{GroupCreateDBRequest, GroupDBResponse},
users::UserCreateDBRequest,
},
},
};
use axum_test::TestServer;
use sqlx::{PgConnection, PgPool};
use sqlx_pool_router::TestDbPools;
use uuid::Uuid;
/// Builds an [`crate::AppState`] backed by [`TestDbPools`] for integration tests,
/// using the caller-supplied `config`.
///
/// Runs the underway migrations on `pool` so the background task runner can be
/// constructed. Panics (via `expect`) on any setup failure — acceptable in test
/// helpers.
pub async fn create_test_app_state_with_config(pool: PgPool, config: crate::config::Config) -> crate::AppState<TestDbPools> {
    use std::sync::{Arc, OnceLock};

    // Two independent pool wrappers over the same Postgres pool: one for the
    // app state itself and one for the fusillade request manager.
    let app_pools = TestDbPools::new(pool.clone()).await.expect("Failed to create TestDbPools");
    let manager_pools = TestDbPools::new(pool.clone())
        .await
        .expect("Failed to create fusillade TestDbPools");
    let request_manager = Arc::new(fusillade::PostgresRequestManager::new(manager_pools, Default::default()));

    let limiters = crate::limits::Limiters::new(&config.limits);
    let shared_config = crate::SharedConfig::new(config);

    // The task runner requires underway's schema to exist.
    underway::run_migrations(&pool).await.expect("Failed to run underway migrations");
    let task_state = crate::tasks::TaskState {
        request_manager: request_manager.clone(),
        dwctl_pool: pool.clone(),
        config: shared_config.clone(),
        encryption_key: None,
        // Job handles are lazily initialised elsewhere; start them all empty.
        ingest_file_job: Arc::new(OnceLock::new()),
        activate_batch_job: Arc::new(OnceLock::new()),
        create_batch_job: Arc::new(OnceLock::new()),
    };
    let task_runner = Arc::new(
        crate::tasks::TaskRunner::new(pool, task_state)
            .await
            .expect("Failed to create task runner"),
    );

    crate::AppState::builder()
        .db(app_pools)
        .config(shared_config)
        .request_manager(request_manager)
        .task_runner(task_runner)
        .limiters(limiters)
        .build()
}
/// Builds an [`crate::AppState`] like `create_test_app_state_with_config`, but with
/// fusillade isolated in its own Postgres schema (`fusillade`) reached through a
/// dedicated connection pool whose `search_path` points at that schema.
///
/// Setup order matters: the schema must exist before the fusillade pool connects
/// with `search_path = fusillade`, and fusillade's migrations must run on that
/// pool before the request manager is constructed.
pub async fn create_test_app_state_with_fusillade(pool: PgPool, config: crate::config::Config) -> crate::AppState<TestDbPools> {
    use sqlx::Executor;
    use sqlx::postgres::PgConnectOptions;
    // 1. Make sure the dedicated schema exists on the shared test database.
    pool.execute("CREATE SCHEMA IF NOT EXISTS fusillade")
        .await
        .expect("Failed to create fusillade schema");
    // 2. Clone the main pool's connection options and override search_path so
    //    every connection in this pool resolves unqualified names to `fusillade`.
    let base_opts: PgConnectOptions = pool.connect_options().as_ref().clone();
    let fusillade_pool = sqlx::postgres::PgPoolOptions::new()
        .max_connections(4)
        .min_connections(0)
        .connect_with(base_opts.options([("search_path", "fusillade")]))
        .await
        .expect("Failed to create fusillade pool");
    // 3. Apply fusillade's migrations inside that schema.
    fusillade::migrator()
        .run(&fusillade_pool)
        .await
        .expect("Failed to run fusillade migrations");
    let test_pools = TestDbPools::new(pool.clone()).await.expect("Failed to create TestDbPools");
    let fusillade_test_pools = TestDbPools::new(fusillade_pool)
        .await
        .expect("Failed to create fusillade TestDbPools");
    let request_manager = std::sync::Arc::new(fusillade::PostgresRequestManager::new(fusillade_test_pools, Default::default()));
    let limiters = crate::limits::Limiters::new(&config.limits);
    let shared_config = crate::SharedConfig::new(config);
    // Task runner infrastructure lives in the main (public-schema) pool.
    underway::run_migrations(&pool).await.expect("Failed to run underway migrations");
    let task_state = crate::tasks::TaskState {
        request_manager: request_manager.clone(),
        dwctl_pool: pool.clone(),
        config: shared_config.clone(),
        encryption_key: None,
        // Job handles are filled in lazily elsewhere; start empty.
        ingest_file_job: std::sync::Arc::new(std::sync::OnceLock::new()),
        activate_batch_job: std::sync::Arc::new(std::sync::OnceLock::new()),
        create_batch_job: std::sync::Arc::new(std::sync::OnceLock::new()),
    };
    let task_runner = std::sync::Arc::new(
        crate::tasks::TaskRunner::new(pool, task_state)
            .await
            .expect("Failed to create task runner"),
    );
    crate::AppState::builder()
        .db(test_pools)
        .config(shared_config)
        .request_manager(request_manager)
        .task_runner(task_runner)
        .limiters(limiters)
        .build()
}
/// Spins up a full [`TestServer`] plus its background services using the default
/// test configuration. `_enable_sync` is accepted for signature compatibility
/// with existing call sites and is intentionally ignored.
pub async fn create_test_app(pool: PgPool, _enable_sync: bool) -> (TestServer, crate::BackgroundServices) {
    let application = crate::Application::new_with_pool(create_test_config(), Some(pool), None)
        .await
        .expect("Failed to create application");
    application.into_test_server()
}
/// Like [`create_test_app`], but with a caller-provided configuration.
/// `_enable_sync` is accepted for signature compatibility and ignored.
pub async fn create_test_app_with_config(
    pool: PgPool,
    config: crate::config::Config,
    _enable_sync: bool,
) -> (TestServer, crate::BackgroundServices) {
    let application = crate::Application::new_with_pool(config, Some(pool), None)
        .await
        .expect("Failed to create application");
    application.into_test_server()
}
pub fn create_test_config() -> crate::config::Config {
let temp_dir = std::env::temp_dir().join(format!("dwctl-test-emails-{}", std::process::id()));
crate::config::Config {
database_url: None,
database_replica_url: None,
database: crate::config::DatabaseConfig::External {
url: "Something".to_string(), replica_url: None,
pool: PoolSettings {
max_connections: 4,
min_connections: 1,
..Default::default()
},
replica_pool: None,
fusillade: crate::config::ComponentDb::Schema {
name: "fusillade".to_string(),
pool: PoolSettings {
max_connections: 4,
min_connections: 0,
..Default::default()
},
replica_pool: None,
},
outlet: crate::config::ComponentDb::Schema {
name: "outlet".to_string(),
pool: PoolSettings {
max_connections: 4,
min_connections: 0,
..Default::default()
},
replica_pool: None,
},
underway_pool: crate::config::default_underway_pool(),
},
slow_statement_threshold_ms: 1000,
host: "127.0.0.1".to_string(),
port: 0,
dashboard_url: "http://localhost:3001".to_string(),
admin_email: "admin@test.com".to_string(),
admin_password: None,
secret_key: Some("test-secret-key-for-testing-only".to_string()),
model_sources: vec![crate::config::ModelSource {
name: "test".to_string(),
url: "http://localhost:8081".parse().unwrap(),
api_key: None,
sync_interval: std::time::Duration::from_secs(60),
default_models: None,
}],
metadata: crate::config::Metadata::default(),
payment: None,
auth: crate::config::AuthConfig {
native: NativeAuthConfig {
enabled: false,
password: PasswordConfig {
min_length: 8,
max_length: 64,
argon2_memory_kib: 128, argon2_iterations: 1, argon2_parallelism: 1, },
..Default::default()
},
proxy_header: ProxyHeaderAuthConfig {
enabled: true,
..Default::default()
},
security: SecurityConfig::default(),
default_user_roles: vec![crate::api::models::users::Role::StandardUser],
},
enable_metrics: false,
enable_request_logging: false,
enable_analytics: true,
analytics: crate::config::AnalyticsConfig::default(),
enable_otel_export: false,
credits: crate::config::CreditsConfig::default(),
batches: BatchConfig {
enabled: true,
files: FilesConfig::default(),
..Default::default()
},
background_services: crate::config::BackgroundServicesConfig {
onwards_sync: OnwardsSyncConfig {
enabled: false,
fallback_interval_milliseconds: 10000,
},
probe_scheduler: ProbeSchedulerConfig { enabled: false },
batch_daemon: DaemonConfig {
enabled: DaemonEnabled::Never,
..Default::default()
},
leader_election: LeaderElectionConfig { enabled: false },
notifications: crate::config::NotificationsConfig {
webhooks: crate::config::WebhookConfig {
enabled: false,
..Default::default()
},
..Default::default()
},
sync_workers: crate::config::SyncWorkersConfig {
enabled: false,
..Default::default()
},
..Default::default()
},
email: crate::config::EmailConfig {
transport: crate::config::EmailTransportConfig::File {
path: temp_dir.to_string_lossy().to_string(),
},
..Default::default()
},
sample_files: crate::sample_files::SampleFilesConfig::default(),
limits: LimitsConfig {
files: FileLimitsConfig {
max_file_size: 1000 * 1024 * 1024, ..Default::default()
},
..Default::default()
},
onwards: crate::config::OnwardsConfig::default(),
onboarding_url: None,
support_email: "support@test.com".to_string(),
connections: Default::default(),
}
}
/// Inserts a fresh non-admin user with a unique `testuser_*` username and the
/// single given `role`, returning its API-facing representation.
pub async fn create_test_user(pool: &PgPool, role: Role) -> UserResponse {
    let mut conn = pool.acquire().await.expect("Failed to acquire connection");
    let mut users_repo = Users::new(&mut conn);

    // Uniqueness comes from a fresh UUID embedded in the username.
    let username = format!("testuser_{}", Uuid::new_v4().simple());
    let request = UserCreateDBRequest {
        username: username.clone(),
        email: format!("{username}@example.com"),
        display_name: Some("Test User".to_string()),
        avatar_url: None,
        is_admin: false,
        roles: vec![role],
        auth_source: "test".to_string(),
        password_hash: None,
        // Mirror the username so proxy-header auth can resolve this user.
        external_user_id: Some(username.clone()),
    };

    let created = users_repo.create(&request).await.expect("Failed to create test user");
    UserResponse::from(created)
}
/// Inserts a fresh admin user (unique `testadmin_*` username, `is_admin` set)
/// carrying the single given `role`, returning its API-facing representation.
pub async fn create_test_admin_user(pool: &PgPool, role: Role) -> UserResponse {
    let mut conn = pool.acquire().await.expect("Failed to acquire connection");
    let mut users_repo = Users::new(&mut conn);

    // Uniqueness comes from a fresh UUID embedded in the username.
    let username = format!("testadmin_{}", Uuid::new_v4().simple());
    let request = UserCreateDBRequest {
        username: username.clone(),
        email: format!("{username}@example.com"),
        display_name: Some("Test Admin User".to_string()),
        avatar_url: None,
        is_admin: true,
        roles: vec![role],
        auth_source: "test".to_string(),
        password_hash: None,
        // Mirror the username so proxy-header auth can resolve this user.
        external_user_id: Some(username.clone()),
    };

    let created = users_repo.create(&request).await.expect("Failed to create test admin user");
    UserResponse::from(created)
}
/// Inserts a fresh non-admin user carrying an arbitrary set of `roles`,
/// returning its API-facing representation.
pub async fn create_test_user_with_roles(pool: &PgPool, roles: Vec<Role>) -> UserResponse {
    let mut conn = pool.acquire().await.expect("Failed to acquire connection");
    let mut users_repo = Users::new(&mut conn);

    // Uniqueness comes from a fresh UUID embedded in the username.
    let username = format!("testuser_{}", Uuid::new_v4().simple());
    let request = UserCreateDBRequest {
        username: username.clone(),
        email: format!("{username}@example.com"),
        display_name: Some("Test Multi-Role User".to_string()),
        avatar_url: None,
        is_admin: false,
        roles,
        auth_source: "test".to_string(),
        password_hash: None,
        // Mirror the username so proxy-header auth can resolve this user.
        external_user_id: Some(username.clone()),
    };

    let created = users_repo
        .create(&request)
        .await
        .expect("Failed to create test user with roles");
    UserResponse::from(created)
}
/// Builds the (header-name, value) pairs the proxy-header auth backend expects
/// for `user`, using the backend's default header names. Falls back to the
/// username when the user has no `external_user_id`.
pub fn add_auth_headers(user: &UserResponse) -> Vec<(String, String)> {
    let cfg = ProxyHeaderAuthConfig::default();
    let identity = match user.external_user_id.as_ref() {
        Some(id) => id.clone(),
        None => user.username.clone(),
    };
    vec![
        (cfg.header_name, identity),
        (cfg.email_header_name, user.email.clone()),
    ]
}
/// Creates a uniquely-named test group owned by the system user and returns the
/// raw DB row.
pub async fn create_test_group(pool: &PgPool) -> GroupDBResponse {
    let mut conn = pool.acquire().await.expect("Failed to acquire connection");
    // Groups need a creator; use the well-known system user (nil UUID).
    let system_user = get_system_user(&mut conn).await;
    let mut groups = Groups::new(&mut conn);
    let request = GroupCreateDBRequest {
        name: format!("test_group_{}", Uuid::new_v4().simple()),
        description: Some("Test group".to_string()),
        created_by: system_user.id,
    };
    groups.create(&request).await.expect("Failed to create test group")
}
/// Loads the system user (well-known nil UUID) and its roles directly via SQL,
/// assembling a [`UserResponse`] with non-persisted fields defaulted.
///
/// Panics if the system user row does not exist — test databases are expected
/// to seed it.
pub async fn get_system_user(pool: &mut PgConnection) -> UserResponse {
    // The system user is identified by the all-zeros UUID.
    let user_id = Uuid::nil();
    let user = sqlx::query!(
        r#"
        SELECT id, username, email, display_name, avatar_url, is_admin, created_at, updated_at, auth_source
        FROM users
        WHERE users.id = $1
        "#,
        user_id
    )
    .fetch_one(&mut *pool)
    .await
    .expect("Failed to get system user");
    // Roles live in a separate table; the `"role: Role"` cast tells sqlx to
    // decode the column straight into the `Role` enum.
    let roles = sqlx::query!("SELECT role as \"role: Role\" FROM user_roles WHERE user_id = $1", user.id)
        .fetch_all(&mut *pool)
        .await
        .expect("Failed to get system user roles");
    let roles: Vec<Role> = roles.into_iter().map(|r| r.role).collect();
    // Fields not selected above (credits, org/billing state, …) are defaulted;
    // callers only need identity and role information from this helper.
    UserResponse {
        id: user.id,
        username: user.username,
        email: user.email,
        display_name: user.display_name,
        avatar_url: user.avatar_url,
        is_admin: user.is_admin,
        roles,
        created_at: user.created_at,
        updated_at: user.updated_at,
        last_login: None,
        auth_source: user.auth_source,
        external_user_id: None,
        groups: None,
        credit_balance: None,
        has_payment_provider_id: false,
        batch_notifications_enabled: false,
        low_balance_threshold: None,
        auto_topup_amount: None,
        auto_topup_threshold: None,
        has_auto_topup_payment_method: false,
        auto_topup_monthly_limit: None,
        user_type: "individual".to_string(),
        organizations: None,
        active_organization_id: None,
        onboarding_redirect_url: None,
    }
}
/// Adds `user_id` as a member of `group_id`, panicking on failure.
pub async fn add_user_to_group(pool: &PgPool, user_id: UserId, group_id: GroupId) {
    let mut conn = pool.acquire().await.expect("Failed to acquire connection");
    Groups::new(&mut conn)
        .add_user_to_group(user_id, group_id)
        .await
        .expect("Failed to add user to group");
}
/// Creates a realtime API key owned by (and created by) `user_id` with a unique
/// name, returning the raw DB row.
pub async fn create_test_api_key_for_user(pool: &PgPool, user_id: UserId) -> ApiKeyDBResponse {
    let mut conn = pool.acquire().await.expect("Failed to acquire connection");
    let mut keys = ApiKeys::new(&mut conn);
    // Owner and creator are the same user for test keys.
    let create = ApiKeyCreate {
        name: format!("Test API Key {}", Uuid::new_v4().simple()),
        description: Some("Test description".to_string()),
        purpose: ApiKeyPurpose::Realtime,
        requests_per_second: None,
        burst_size: None,
        member_id: None,
    };
    let request = ApiKeyCreateDBRequest::new(user_id, user_id, create);
    keys.create(&request).await.expect("Failed to create test API key")
}
/// Creates a deployment of `model_name` under `alias`, hosted on the endpoint
/// named "test" (which must already exist). The lookup and insert share one
/// transaction.
pub async fn create_test_deployment(pool: &PgPool, created_by: UserId, model_name: &str, alias: &str) -> DeploymentDBResponse {
    let mut tx = pool.begin().await.expect("Failed to begin transaction");

    // Resolve the id of the seeded "test" inference endpoint.
    let endpoints = InferenceEndpoints::new(&mut tx)
        .list(&InferenceEndpointFilter::new(0, 100))
        .await
        .expect("Failed to list endpoints");
    let host_id = endpoints
        .into_iter()
        .find(|e| e.name == "test")
        .expect("Test endpoint should exist")
        .id;

    let request = DeploymentCreateDBRequest::builder()
        .created_by(created_by)
        .model_name(model_name.to_string())
        .alias(alias.to_string())
        .hosted_on(host_id)
        .build();
    let deployment = Deployments::new(&mut tx)
        .create(&request)
        .await
        .expect("Failed to create test deployment");

    tx.commit().await.expect("Failed to commit transaction");
    deployment
}
/// Like `create_test_deployment`, but additionally sets an explicit
/// `model_type` on the deployment.
pub async fn create_test_deployment_with_model_type(
    pool: &PgPool,
    created_by: UserId,
    model_name: &str,
    alias: &str,
    model_type: ModelType,
) -> DeploymentDBResponse {
    let mut tx = pool.begin().await.expect("Failed to begin transaction");

    // Resolve the id of the seeded "test" inference endpoint.
    let endpoints = InferenceEndpoints::new(&mut tx)
        .list(&InferenceEndpointFilter::new(0, 100))
        .await
        .expect("Failed to list endpoints");
    let host_id = endpoints
        .into_iter()
        .find(|e| e.name == "test")
        .expect("Test endpoint should exist")
        .id;

    let request = DeploymentCreateDBRequest::builder()
        .created_by(created_by)
        .model_name(model_name.to_string())
        .alias(alias.to_string())
        .model_type(model_type)
        .hosted_on(host_id)
        .build();
    let deployment = Deployments::new(&mut tx)
        .create(&request)
        .await
        .expect("Failed to create test deployment");

    tx.commit().await.expect("Failed to commit transaction");
    deployment
}
/// Grants `group_id` access to `deployment_id`, recording `granted_by` as the
/// granting user. Panics on failure.
pub async fn add_deployment_to_group(pool: &PgPool, deployment_id: crate::types::DeploymentId, group_id: GroupId, granted_by: UserId) {
    let mut conn = pool.acquire().await.expect("Failed to acquire connection");
    Groups::new(&mut conn)
        .add_deployment_to_group(deployment_id, group_id, granted_by)
        .await
        .expect("Failed to add deployment to group");
}
pub async fn get_test_endpoint_id(pool: &PgPool) -> uuid::Uuid {
let mut conn = pool.acquire().await.expect("Failed to acquire connection");
let mut endpoints_repo = InferenceEndpoints::new(&mut conn);
let filter = crate::db::handlers::inference_endpoints::InferenceEndpointFilter::new(0, 100);
let endpoints = endpoints_repo.list(&filter).await.expect("Failed to list endpoints");
endpoints.iter().find(|e| e.name == "test").expect("Test endpoint should exist").id
}
pub fn require_admin(user: CurrentUser) -> Result<CurrentUser, Error> {
if user.is_admin {
Ok(user)
} else {
Err(Error::InsufficientPermissions {
required: Permission::Allow(Resource::Users, Operation::ReadAll),
action: Operation::ReadAll,
resource: "admin resource".to_string(),
})
}
}
/// Inserts an inference endpoint row named `name` (URL fixed to
/// `http://localhost:8080`, no API key) and returns its freshly minted id.
pub async fn create_test_endpoint(pool: &PgPool, name: &str, created_by: UserId) -> uuid::Uuid {
    // Generate the id client-side so we can return it without a RETURNING clause.
    let id = uuid::Uuid::new_v4();
    sqlx::query!(
        r#"
        INSERT INTO inference_endpoints (id, name, url, api_key, created_by)
        VALUES ($1, $2, 'http://localhost:8080', NULL, $3)
        "#,
        id,
        name,
        created_by
    )
    .execute(pool)
    .await
    .expect("Failed to create test endpoint");
    id
}
/// Creates an organization account with a unique `testorg_*` name, granting it
/// the StandardUser and BatchAPIUser roles, and returns it shaped as a
/// [`UserResponse`] with non-persisted fields defaulted.
pub async fn create_test_org(pool: &PgPool, created_by: UserId) -> UserResponse {
    use crate::db::handlers::Organizations;
    use crate::db::models::organizations::OrganizationCreateDBRequest;

    let mut conn = pool.acquire().await.expect("Failed to acquire connection");
    let mut orgs = Organizations::new(&mut conn);

    // Uniqueness comes from a fresh UUID embedded in the org name.
    let org_name = format!("testorg_{}", Uuid::new_v4().simple());
    let request = OrganizationCreateDBRequest {
        name: org_name.clone(),
        email: format!("{org_name}@example.com"),
        display_name: Some("Test Organization".to_string()),
        avatar_url: None,
        created_by,
    };
    let granted_roles = [
        crate::api::models::users::Role::StandardUser,
        crate::api::models::users::Role::BatchAPIUser,
    ];
    let org = orgs
        .create(&request, &granted_roles)
        .await
        .expect("Failed to create test organization");

    // Map the org row into the user-shaped response; fields the org row does
    // not carry are defaulted.
    UserResponse {
        id: org.id,
        username: org.username,
        email: org.email,
        display_name: org.display_name,
        avatar_url: org.avatar_url,
        is_admin: org.is_admin,
        roles: vec![],
        created_at: org.created_at,
        updated_at: org.updated_at,
        last_login: None,
        auth_source: org.auth_source,
        external_user_id: None,
        groups: None,
        credit_balance: None,
        has_payment_provider_id: false,
        batch_notifications_enabled: false,
        low_balance_threshold: None,
        auto_topup_amount: None,
        auto_topup_threshold: None,
        has_auto_topup_payment_method: false,
        auto_topup_monthly_limit: None,
        user_type: org.user_type,
        organizations: None,
        active_organization_id: None,
        onboarding_redirect_url: None,
    }
}
/// Adds `user_id` to organization `org_id` with the given role string,
/// panicking on failure.
pub async fn add_org_member(pool: &PgPool, org_id: UserId, user_id: UserId, role: &str) {
    use crate::db::handlers::Organizations;
    let mut conn = pool.acquire().await.expect("Failed to acquire connection");
    Organizations::new(&mut conn)
        .add_member(org_id, user_id, role)
        .await
        .expect("Failed to add org member");
}
/// Inserts a deployed-model row (`deleted = false`) hosted on `endpoint_id`
/// and returns its freshly minted deployment id.
pub async fn create_test_model(pool: &PgPool, model_name: &str, alias: &str, endpoint_id: uuid::Uuid, created_by: UserId) -> uuid::Uuid {
    // Generate the id client-side so we can return it without a RETURNING clause.
    let id = uuid::Uuid::new_v4();
    sqlx::query!(
        r#"
        INSERT INTO deployed_models (id, model_name, alias, hosted_on, created_by, deleted)
        VALUES ($1, $2, $3, $4, $5, false)
        "#,
        id,
        model_name,
        alias,
        endpoint_id,
        created_by
    )
    .execute(pool)
    .await
    .expect("Failed to create test model");
    id
}