//! HTTP handlers for model deployment endpoints.
use sqlx_pool_router::PoolProvider;
use crate::api::models::deployments::{ModelFacets, ModelListResponse, TrafficRoutingAction, TrafficRoutingRule};
use crate::db::models::deployments::{
MODEL_CATALOG_METADATA_MAX_BYTES, MODEL_CATALOG_METADATA_MAX_EXTRA_KEYS, ModelCatalogMetadata, TrafficRuleAction,
};
use crate::db::models::tariffs::TariffCreateDBRequest;
use crate::{
AppState,
api::models::{
deployments::{
ComponentEndpointSummary, ComponentModelSummary, DeployedModelCreate, DeployedModelResponse, DeployedModelUpdate,
GetModelQuery, ListModelsQuery, ModelComponentResponse, enrichment::DeployedModelEnricher,
},
users::CurrentUser,
},
auth::permissions::{RequiresPermission, can_read_all_resources, has_permission, operation, resource},
db::{
handlers::{Deployments, InferenceEndpoints, Repository, Tariffs, deployments::DeploymentFilter},
models::{
api_keys::ApiKeyPurpose,
deployments::{DeploymentComponentDBResponse, DeploymentCreateDBRequest, DeploymentUpdateDBRequest, ModelStatus, ModelType},
},
},
errors::{Error, Result},
types::{DeploymentId, Resource},
};
use axum::{
extract::{Path, Query, State},
response::Json,
};
use sqlx::Acquire;
/// Validate that model catalog metadata is within size and key count limits.
fn validate_metadata(metadata: &ModelCatalogMetadata) -> Result<()> {
let size = serde_json::to_vec(metadata).map(|v| v.len()).unwrap_or(0);
if size > MODEL_CATALOG_METADATA_MAX_BYTES {
return Err(Error::BadRequest {
message: format!(
"metadata exceeds maximum size ({} bytes, limit is {} bytes)",
size, MODEL_CATALOG_METADATA_MAX_BYTES
),
});
}
if let Some(serde_json::Value::Object(map)) = &metadata.extra
&& map.len() > MODEL_CATALOG_METADATA_MAX_EXTRA_KEYS
{
return Err(Error::BadRequest {
message: format!(
"metadata.extra has too many keys ({}, limit is {})",
map.len(),
MODEL_CATALOG_METADATA_MAX_EXTRA_KEYS
),
});
}
Ok(())
}
/// Translate API-level traffic routing rules into their DB-layer form.
///
/// `Deny` rules map straight across. `Redirect` rules carry a target alias,
/// which is looked up through `repo` and replaced by the target's ID.
///
/// # Errors
/// Returns [`Error::BadRequest`] when a redirect target is empty, equals
/// `model_alias` (a self-redirect), or does not resolve to an existing model.
/// Errors from the alias lookup itself are propagated unchanged.
async fn resolve_traffic_rules(
    rules: &[TrafficRoutingRule],
    model_alias: &str,
    repo: &mut Deployments<'_>,
) -> Result<Vec<(ApiKeyPurpose, TrafficRuleAction)>> {
    let mut db_rules = Vec::with_capacity(rules.len());
    for rule in rules {
        let db_action = match &rule.action {
            TrafficRoutingAction::Redirect { target } => {
                // Cheap, purely local checks run before the database lookup.
                if target.is_empty() {
                    return Err(Error::BadRequest {
                        message: "Redirect target must not be empty".to_string(),
                    });
                }
                if target == model_alias {
                    return Err(Error::BadRequest {
                        message: format!("Traffic routing rule cannot redirect model '{}' to itself", model_alias),
                    });
                }
                let Some(target_id) = repo.resolve_alias_to_id(target).await? else {
                    return Err(Error::BadRequest {
                        message: format!("Redirect target model '{}' does not exist", target),
                    });
                };
                TrafficRuleAction::Redirect(target_id)
            }
            TrafficRoutingAction::Deny => TrafficRuleAction::Deny,
        };
        db_rules.push((rule.api_key_purpose.clone(), db_action));
    }
    Ok(db_rules)
}
/// Build the API representation of a composite-model component from its DB row.
///
/// The textual model type is mapped onto [`ModelType`]; any value other than
/// `"CHAT"`, `"EMBEDDINGS"` or `"RERANKER"` becomes `None`. An endpoint summary
/// is produced only when the row has an endpoint ID, with a missing endpoint
/// name falling back to the default value.
fn db_component_to_response(c: DeploymentComponentDBResponse) -> ModelComponentResponse {
    let model_type = c.model_type.and_then(|raw| match raw.as_str() {
        "CHAT" => Some(ModelType::Chat),
        "EMBEDDINGS" => Some(ModelType::Embeddings),
        "RERANKER" => Some(ModelType::Reranker),
        _ => None,
    });

    let endpoint_name = c.endpoint_name;
    let endpoint = c.endpoint_id.map(|id| ComponentEndpointSummary {
        id,
        name: endpoint_name.unwrap_or_default(),
    });

    let model = ComponentModelSummary {
        id: c.deployed_model_id,
        alias: c.model_alias,
        model_name: c.model_name,
        description: c.model_description,
        model_type,
        endpoint,
        trusted: c.model_trusted,
        open_responses_adapter: c.model_open_responses_adapter,
    };

    ModelComponentResponse {
        weight: c.weight,
        enabled: c.enabled,
        sort_order: c.sort_order,
        created_at: c.created_at,
        model,
    }
}
#[utoipa::path(
get,
path = "/models",
tag = "models",
summary = "List deployed models",
description = "List all deployed models, optionally filtered by endpoint",
params(
("endpoint" = Option<i32>, Query, description = "Filter by inference endpoint ID"),
("accessible" = Option<bool>, Query, description = "Filter to only models the current user can access (defaults to false for admins, true for users)"),
("include" = Option<String>, Query, description = "Include additional data (comma-separated: 'groups', 'metrics', 'status', 'pricing', 'endpoints', 'facets'). Only platform managers can include groups. Status shows probe monitoring information. Pricing shows simple customer rates for regular users, full pricing structure including current active tariffs for users with Pricing::ReadAll permission. Endpoints includes full inference endpoint details. Facets returns distinct providers, capabilities, and model types for filter dropdowns."),
("provider" = Option<String>, Query, description = "Filter by provider name (case-insensitive exact match against metadata.provider)"),
("model_type" = Option<String>, Query, description = "Filter by model type (CHAT, EMBEDDINGS, RERANKER)"),
("capability" = Option<String>, Query, description = "Filter by capability (returns models that have this capability)"),
("sort" = Option<String>, Query, description = "Sort field: created_at (default), alias, intelligence_index, released_at, context_window, provider, price_from"),
("sort_direction" = Option<String>, Query, description = "Sort direction: asc or desc (default depends on sort field)"),
("deleted" = Option<bool>, Query, description = "Show deleted models when true (admin only), non-deleted models when false, and all models when not specified"),
("inactive" = Option<bool>, Query, description = "Show inactive models when true (admin only)"),
("limit" = Option<i64>, Query, description = "Maximum number of items to return (default: 10, max: 100)"),
("skip" = Option<i64>, Query, description = "Number of items to skip (default: 0)"),
("search" = Option<String>, Query, description = "Search query to filter models by alias, model_name, or endpoint name (case-insensitive substring match)"),
("is_composite" = Option<bool>, Query, description = "Filter by composite/virtual model status (true = virtual models only, false = hosted models only)"),
),
responses(
(status = 200, description = "Paginated list of deployed models", body = ModelListResponse),
(status = 401, description = "Unauthorized"),
(status = 404, description = "Inference endpoint not found"),
(status = 500, description = "Internal server error"),
),
security(
("BearerAuth" = []),
("CookieAuth" = []),
("X-Doubleword-User" = [])
)
)]
#[tracing::instrument(skip_all)]
pub async fn list_deployed_models<P: PoolProvider>(
State(state): State<AppState<P>>,
Query(query): Query<ListModelsQuery>,
// Lots of conditional logic here, so no logic in extractor
current_user: CurrentUser,
) -> Result<Json<ModelListResponse>> {
let has_system_access = has_permission(¤t_user, resource::Models.into(), operation::SystemAccess.into());
let can_read_all_models = can_read_all_resources(¤t_user, Resource::Models);
let can_read_groups = can_read_all_resources(¤t_user, Resource::Groups);
let can_read_users = can_read_all_resources(¤t_user, Resource::Users);
let can_read_pricing = can_read_all_resources(¤t_user, Resource::Pricing);
let can_read_rate_limits = can_read_all_resources(¤t_user, Resource::ModelRateLimits);
let can_read_metrics = can_read_all_resources(¤t_user, Resource::Analytics);
// Use read replica for this read-only operation
let mut conn = state.db.read().acquire().await.map_err(|e| Error::Database(e.into()))?;
// Validate endpoint exists if specified
if let Some(endpoint_id) = query.endpoint {
let mut endpoints_repo = InferenceEndpoints::new(&mut conn);
if endpoints_repo.get_by_id(endpoint_id).await?.is_none() {
return Err(Error::NotFound {
resource: "endpoint".to_string(),
id: endpoint_id.to_string(),
});
}
}
// Get deployments with the filter
let mut repo = Deployments::new(&mut conn);
// Build the filter with pagination parameters
let (skip, limit) = query.pagination.params();
let mut filter = DeploymentFilter::new(skip, limit);
if let Some(endpoint_id) = query.endpoint {
filter = filter.with_endpoint(endpoint_id);
};
// Parse comma-separated group IDs if specified
if let Some(ref group_str) = query.group {
let group_ids: std::result::Result<Vec<_>, _> = group_str
.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.map(|s| s.parse::<crate::types::GroupId>())
.collect();
match group_ids {
Ok(ids) if !ids.is_empty() => {
filter = filter.with_groups(ids);
}
Ok(_) => {
// Empty list after filtering, ignore
}
Err(_) => {
return Err(Error::BadRequest {
message: "Invalid group ID format. Expected comma-separated UUIDs.".to_string(),
});
}
}
};
// Handle deleted models - admins can supply query parameter
if has_system_access {
if query.deleted.unwrap_or(false) {
// Admins can see deleted models if requested (default from the repo is all models,
// inc. deleted), so no filter to add here.
} else {
// Admins see non-deleted models by default
filter = filter.with_deleted(false);
}
} else {
// users can only see non-deleted models
filter = filter.with_deleted(false);
};
// Handle inactive models - admins can supply query parameter
if has_system_access {
if query.inactive.unwrap_or(false) {
// Admins can see inactive models if requested
filter = filter.with_statuses(vec![ModelStatus::Inactive]);
} else {
// Admins see active models by default
filter = filter.with_statuses(vec![ModelStatus::Active]);
}
} else {
// users can only see active models
filter = filter.with_statuses(vec![ModelStatus::Active]);
};
// Apply accessibility filtering based if user doesn't have PlatformManager role
// When an organization is active, filter by the org's group memberships instead
if !can_read_all_models || query.accessible.unwrap_or(false) {
let target_user_id = current_user.active_organization.unwrap_or(current_user.id);
filter = filter.with_accessible_to(target_user_id);
}
// Apply search filter if specified
if let Some(search) = query.search.as_ref()
&& !search.trim().is_empty()
{
filter = filter.with_search(search.trim().to_string());
}
// Apply is_composite filter if specified
if let Some(is_composite) = query.is_composite {
filter = filter.with_composite(is_composite);
}
// Apply provider filter if specified
if let Some(ref provider) = query.provider
&& !provider.trim().is_empty()
{
filter = filter.with_provider(provider.trim().to_string());
}
// Apply model_type filter if specified
if let Some(model_type) = query.model_type {
filter = filter.with_model_type(model_type);
}
// Apply capability filter if specified
if let Some(ref capability) = query.capability
&& !capability.trim().is_empty()
{
filter = filter.with_capability(capability.trim().to_string());
}
// Apply sort if specified
if let Some(sort_field) = query.sort {
filter = filter.with_sort(sort_field, query.sort_direction);
}
// Parse include parameter
let all_includes: Vec<&str> = query
.include
.as_deref()
.unwrap_or("")
.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.collect();
// Filter includes based on permissions
let mut includes: Vec<&str> = Vec::new();
for &include in &all_includes {
match include {
"groups" => {
// Only users with Groups::ReadAll can include groups
if can_read_groups {
includes.push(include);
}
}
"metrics" => {
// Only users with Analytics::ReadAll can include metrics
if can_read_metrics {
includes.push(include);
}
}
"endpoints" => {
// Model endpoints is priviliged information for admins
if can_read_all_models {
includes.push(include);
}
}
_ => {
// Other includes (like pricing, status) are allowed for all users
includes.push(include);
}
}
}
// Get total count before applying pagination
let total_count = repo.count(&filter).await?;
let filtered_models = repo.list(&filter).await?;
// Convert to API responses and add provider_pricing based on permissions and includes
let mut models: Vec<DeployedModelResponse> = filtered_models
.into_iter()
.map(|model| {
let provider_pricing = if can_read_pricing { model.provider_pricing.clone() } else { None };
DeployedModelResponse::from(model).with_provider_pricing(provider_pricing)
})
.collect();
// Fetch and attach traffic rules in bulk
{
let model_ids: Vec<DeploymentId> = models.iter().map(|m| m.id).collect();
let mut traffic_rules_map = repo.get_traffic_rules_bulk(&model_ids).await?;
models = models
.into_iter()
.map(|model| {
let rules = traffic_rules_map.remove(&model.id).unwrap_or_default();
model.with_traffic_rules(rules)
})
.collect();
}
// Configure enrichment based on includes and permissions
let include_groups = includes.contains(&"groups");
let include_metrics = includes.contains(&"metrics");
let include_status = includes.contains(&"status");
let include_pricing = includes.contains(&"pricing");
let include_endpoints = includes.contains(&"endpoints");
let include_components = includes.contains(&"components");
// Use ModelEnricher to add requested data
let enricher = DeployedModelEnricher {
db: state.db.read(),
include_groups,
include_metrics,
include_status,
include_pricing,
include_endpoints,
include_components,
can_read_pricing,
can_read_rate_limits,
can_read_users,
can_read_composite_info: can_read_all_models,
};
let response = enricher.enrich_many(models).await?;
// Fetch facets if requested (reuse existing repo/connection to avoid
// acquiring a second read connection which could self-deadlock under pool
// saturation). The filter ensures facets respect the same access-control
// as the model list.
let include_facets = includes.contains(&"facets");
let facets = if include_facets {
let (providers, capabilities, model_types) = repo.facets(&filter).await?;
Some(ModelFacets {
providers,
capabilities,
model_types,
})
} else {
None
};
Ok(Json(ModelListResponse {
data: response,
total_count,
skip,
limit,
facets,
}))
}
/// Create a deployed model (standard or composite) together with its optional
/// traffic routing rules and tariffs.
///
/// Request-level validation (non-empty name/alias, positive throughput, batch
/// completion windows, metadata limits) runs before any database work. The
/// endpoint check, rule resolution, model insert, rule insert and tariff
/// inserts then all go through one transaction on the write pool, which is
/// committed only after every step succeeded.
///
/// # Errors
/// Returns `Error::BadRequest` for invalid input, `Error::NotFound` when a
/// standard model names an unknown endpoint, and propagates database errors.
// NOTE(review): the OpenAPI block below advertises status 201, but the handler
// returns a plain `Json<...>` with no explicit status code — confirm which
// status clients actually receive.
#[utoipa::path(
post,
path = "/models",
tag = "models",
summary = "Create a new deployed model",
description = "Create a new deployed model. Admin only.",
request_body = DeployedModelCreate,
responses(
(status = 201, description = "Model created successfully", body = DeployedModelResponse),
(status = 400, description = "Bad request - invalid model data or duplicate alias/model name"),
(status = 401, description = "Unauthorized"),
(status = 403, description = "Forbidden - admin access required"),
(status = 404, description = "Inference endpoint not found"),
(status = 500, description = "Internal server error"),
),
security(
("BearerAuth" = []),
("CookieAuth" = []),
("X-Doubleword-User" = [])
)
)]
#[tracing::instrument(skip_all)]
pub async fn create_deployed_model<P: PoolProvider>(
State(state): State<AppState<P>>,
current_user: RequiresPermission<resource::Models, operation::CreateAll>,
Json(create): Json<DeployedModelCreate>,
) -> Result<Json<DeployedModelResponse>> {
// Extract common fields and variant-specific data.
// The alias falls back to the (trimmed) model name when none is supplied;
// only standard models carry an endpoint (`hosted_on`).
let (model_name, alias, hosted_on, tariffs, throughput) = match &create {
DeployedModelCreate::Standard(s) => (
s.model_name.trim(),
s.alias.as_deref().unwrap_or(s.model_name.trim()).trim(),
Some(s.hosted_on),
s.tariffs.clone(),
s.throughput,
),
DeployedModelCreate::Composite(c) => (
c.model_name.trim(),
c.alias.as_deref().unwrap_or(c.model_name.trim()).trim(),
None,
c.tariffs.clone(),
c.throughput,
),
};
if model_name.is_empty() || alias.is_empty() {
return Err(Error::BadRequest {
message: "Model name and alias must not be empty or whitespace".to_string(),
});
}
// Validate throughput is positive if provided
if let Some(t) = throughput
&& t <= 0.0
{
return Err(Error::BadRequest {
message: format!("throughput must be positive (> 0), got {}", t),
});
}
// Validate allowed batch completion windows against global config
let batch_windows = match &create {
DeployedModelCreate::Standard(s) => &s.allowed_batch_completion_windows,
DeployedModelCreate::Composite(c) => &c.allowed_batch_completion_windows,
};
if let Some(windows) = batch_windows {
let config = state.current_config();
let allowed = &config.batches.allowed_completion_windows;
for window in windows {
if !allowed.contains(window) {
return Err(Error::BadRequest {
message: format!(
"Invalid batch completion window '{}'. Configured windows: {}",
window,
allowed.join(", ")
),
});
}
}
}
// Validate metadata size
let metadata = match &create {
DeployedModelCreate::Standard(s) => &s.metadata,
DeployedModelCreate::Composite(c) => &c.metadata,
};
if let Some(m) = metadata {
validate_metadata(m)?;
}
// Everything from here to `commit` runs inside a single write transaction.
// NOTE(review): early returns below drop `tx` without committing; presumably
// sqlx rolls the transaction back on drop — confirm.
let mut tx = state.db.write().begin().await.map_err(|e| Error::Database(e.into()))?;
// Validate endpoint exists (only for standard models)
if let Some(endpoint_id) = hosted_on {
let mut endpoints_repo = InferenceEndpoints::new(tx.acquire().await.map_err(|e| Error::Database(e.into()))?);
if endpoints_repo.get_by_id(endpoint_id).await?.is_none() {
return Err(Error::NotFound {
resource: "endpoint".to_string(),
id: endpoint_id.to_string(),
});
}
}
// Resolve traffic routing rules (alias → UUID) before creating
let traffic_rules_input = match &create {
DeployedModelCreate::Standard(s) => &s.traffic_routing_rules,
DeployedModelCreate::Composite(c) => &c.traffic_routing_rules,
};
let resolved_rules = if let Some(rules) = traffic_rules_input {
let mut repo = Deployments::new(tx.acquire().await.map_err(|e| Error::Database(e.into()))?);
Some(resolve_traffic_rules(rules, alias, &mut repo).await?)
} else {
None
};
// Create the deployment - let database constraints handle uniqueness.
// `create` is moved into the DB request here, so every value borrowed from
// it above (name, alias, windows, metadata, rules) must no longer be used.
let mut repo = Deployments::new(tx.acquire().await.map_err(|e| Error::Database(e.into()))?);
let db_request = DeploymentCreateDBRequest::from_api_create(current_user.id, create);
let model = repo.create(&db_request).await?;
// Set traffic routing rules if provided
if let Some(rules) = &resolved_rules {
let mut repo = Deployments::new(tx.acquire().await.map_err(|e| Error::Database(e.into()))?);
repo.set_traffic_rules(model.id, rules).await?;
}
// Create tariffs if provided
if let Some(tariff_defs) = tariffs {
let mut tariffs_repo = Tariffs::new(tx.acquire().await.map_err(|e| Error::Database(e.into()))?);
for tariff_def in tariff_defs {
let tariff_request = TariffCreateDBRequest {
deployed_model_id: model.id,
name: tariff_def.name,
input_price_per_token: tariff_def.input_price_per_token,
output_price_per_token: tariff_def.output_price_per_token,
api_key_purpose: tariff_def.api_key_purpose,
completion_window: tariff_def.completion_window,
valid_from: None, // Use NOW()
};
tariffs_repo.create(&tariff_request).await?;
}
}
tx.commit().await.map_err(|e| Error::Database(e.into()))?;
// Fetch and attach traffic rules for the response.
// The rules are re-read through the write pool (not the read pool) and only
// when rules were part of the request.
let mut response = DeployedModelResponse::from(model);
if resolved_rules.is_some() {
let mut conn = state.db.write().acquire().await.map_err(|e| Error::Database(e.into()))?;
let mut repo = Deployments::new(&mut conn);
let rules = repo.get_traffic_rules(response.id).await?;
response = response.with_traffic_rules(rules);
}
Ok(Json(response))
}
#[utoipa::path(
patch,
path = "/models/{id}",
tag = "models",
summary = "Update deployed model",
description = "Update a deployed model",
params(
("id" = uuid::Uuid, Path, description = "Deployment ID to update"),
),
responses(
(status = 200, description = "Deployed model updated successfully", body = DeployedModelResponse),
(status = 400, description = "Bad request - invalid model data"),
(status = 401, description = "Unauthorized"),
(status = 404, description = "Inference endpoint or deployment not found"),
(status = 500, description = "Internal server error"),
),
security(
("BearerAuth" = []),
("CookieAuth" = []),
("X-Doubleword-User" = [])
)
)]
#[tracing::instrument(skip_all)]
pub async fn update_deployed_model<P: PoolProvider>(
State(state): State<AppState<P>>,
Path(deployment_id): Path<DeploymentId>,
current_user: RequiresPermission<resource::Models, operation::UpdateAll>,
Json(update): Json<DeployedModelUpdate>,
) -> Result<Json<DeployedModelResponse>> {
let has_system_access = has_permission(¤t_user, resource::Models.into(), operation::SystemAccess.into());
if let Some(Some(t)) = update.throughput
&& t <= 0.0
{
return Err(Error::BadRequest {
message: format!("throughput must be positive (> 0), got {}", t),
});
}
// Validate allowed batch completion windows against global config
if let Some(Some(windows)) = &update.allowed_batch_completion_windows {
let config = state.current_config();
let allowed = &config.batches.allowed_completion_windows;
for window in windows {
if !allowed.contains(window) {
return Err(Error::BadRequest {
message: format!(
"Invalid batch completion window '{}'. Configured windows: {}",
window,
allowed.join(", ")
),
});
}
}
}
// Validate metadata size
if let Some(m) = &update.metadata {
validate_metadata(m)?;
}
let mut tx = state.db.write().begin().await.map_err(|e| Error::Database(e.into()))?;
// Verify deployment exists and check access based on permissions
let model_alias = {
let mut repo = Deployments::new(tx.acquire().await.map_err(|e| Error::Database(e.into()))?);
match repo.get_by_id(deployment_id).await {
Ok(Some(model)) => {
if model.deleted && !has_system_access {
return Err(Error::NotFound {
resource: "Deployment".to_string(),
id: deployment_id.to_string(),
});
}
model.alias.clone()
}
Ok(None) => {
return Err(Error::NotFound {
resource: "Deployment".to_string(),
id: deployment_id.to_string(),
});
}
Err(e) => return Err(e.into()),
}
};
// Resolve traffic routing rules if provided
let resolved_rules = match &update.traffic_routing_rules {
Some(Some(rules)) => {
let effective_alias = update.alias.as_deref().unwrap_or(&model_alias);
let mut repo = Deployments::new(tx.acquire().await.map_err(|e| Error::Database(e.into()))?);
Some(Some(resolve_traffic_rules(rules, effective_alias, &mut repo).await?))
}
Some(None) => Some(None), // Clear all rules
None => None, // No change
};
let tariffs = update.tariffs.clone();
let db_request = DeploymentUpdateDBRequest::from(update);
let mut repo = Deployments::new(tx.acquire().await.map_err(|e| Error::Database(e.into()))?);
let model = repo.update(deployment_id, &db_request).await?;
// Apply traffic rule changes
match &resolved_rules {
Some(Some(rules)) => {
let mut repo = Deployments::new(tx.acquire().await.map_err(|e| Error::Database(e.into()))?);
repo.set_traffic_rules(deployment_id, rules).await?;
}
Some(None) => {
// Clear all rules (pass empty slice)
let mut repo = Deployments::new(tx.acquire().await.map_err(|e| Error::Database(e.into()))?);
repo.set_traffic_rules(deployment_id, &[]).await?;
}
None => {} // No change
}
// Handle tariff replacement if provided
if let Some(tariff_defs) = tariffs {
let tariff_conn = tx.acquire().await.map_err(|e| Error::Database(e.into()))?;
let mut tariffs_repo = Tariffs::new(tariff_conn);
// Fetch current tariffs to compare
let current_tariffs = tariffs_repo.list_current_by_model(deployment_id).await?;
// Helper function to check if a tariff matches the definition
let tariff_matches = |existing: &crate::db::models::tariffs::ModelTariff,
def: &crate::api::models::deployments::TariffDefinition| {
existing.name == def.name
&& existing.input_price_per_token == def.input_price_per_token
&& existing.output_price_per_token == def.output_price_per_token
&& existing.api_key_purpose == def.api_key_purpose
&& existing.completion_window == def.completion_window
};
// Collect IDs of tariffs to close (those not in the new set or have changed)
let tariffs_to_close: Vec<uuid::Uuid> = current_tariffs
.iter()
.filter(|existing| !tariff_defs.iter().any(|def| tariff_matches(existing, def)))
.map(|t| t.id)
.collect();
// Batch close tariffs in a single query
if !tariffs_to_close.is_empty() {
tariffs_repo.close_tariffs_batch(&tariffs_to_close).await?;
}
// Create new or changed tariffs (skip those that already exist unchanged)
for tariff_def in tariff_defs {
// Skip if this tariff already exists with the same values
if current_tariffs.iter().any(|existing| tariff_matches(existing, &tariff_def)) {
continue;
}
let tariff_request = TariffCreateDBRequest {
deployed_model_id: deployment_id,
name: tariff_def.name,
input_price_per_token: tariff_def.input_price_per_token,
output_price_per_token: tariff_def.output_price_per_token,
api_key_purpose: tariff_def.api_key_purpose,
completion_window: tariff_def.completion_window,
valid_from: None, // Use NOW()
};
tariffs_repo.create(&tariff_request).await?;
}
}
tx.commit().await.map_err(|e| Error::Database(e.into()))?;
// Fetch and attach traffic rules for the response
let mut response = DeployedModelResponse::from(model);
let mut conn = state.db.write().acquire().await.map_err(|e| Error::Database(e.into()))?;
let mut repo = Deployments::new(&mut conn);
let rules = repo.get_traffic_rules(deployment_id).await?;
response = response.with_traffic_rules(rules);
Ok(Json(response))
}
#[utoipa::path(
get,
path = "/models/{id}",
tag = "models",
summary = "Get deployed model",
description = "Get a specific deployed model",
params(
("id" = uuid::Uuid, Path, description = "Deployment ID to retrieve"),
GetModelQuery
),
responses(
(status = 200, description = "Deployed model information", body = DeployedModelResponse),
(status = 401, description = "Unauthorized"),
(status = 404, description = "Inference endpoint or deployment not found"),
(status = 500, description = "Internal server error"),
),
security(
("BearerAuth" = []),
("CookieAuth" = []),
("X-Doubleword-User" = [])
)
)]
#[tracing::instrument(skip_all)]
pub async fn get_deployed_model<P: PoolProvider>(
State(state): State<AppState<P>>,
Path(deployment_id): Path<DeploymentId>,
Query(query): Query<GetModelQuery>,
current_user: CurrentUser,
) -> Result<Json<DeployedModelResponse>> {
let has_system_access = has_permission(¤t_user, resource::Models.into(), operation::SystemAccess.into());
let can_read_all_models = can_read_all_resources(¤t_user, Resource::Models);
let can_read_groups = can_read_all_resources(¤t_user, Resource::Groups);
let can_read_users = can_read_all_resources(¤t_user, Resource::Users);
let can_read_rate_limits = can_read_all_resources(¤t_user, Resource::ModelRateLimits);
let can_read_pricing = can_read_all_resources(¤t_user, Resource::Pricing);
let can_read_metrics = can_read_all_resources(¤t_user, Resource::Analytics);
let mut pool_conn = state.db.read().acquire().await.map_err(|e| Error::Database(e.into()))?;
let mut repo = Deployments::new(&mut pool_conn);
let model = match repo.get_by_id(deployment_id).await {
Ok(Some(model)) => model,
Ok(None) => {
return Err(Error::NotFound {
resource: "Deployment".to_string(),
id: deployment_id.to_string(),
});
}
Err(e) => return Err(e.into()),
};
// Check visibility rules based on model state and user permissions
match (model.deleted, &model.status) {
// Deleted models: only show to admins who explicitly request them
(true, _) => {
if !has_system_access || !query.deleted.unwrap_or(false) {
return Err(Error::NotFound {
resource: "Deployment".to_string(),
id: deployment_id.to_string(),
});
}
}
// Inactive models: only show to admins who explicitly request them
(false, ModelStatus::Inactive) => {
if !has_system_access || !query.inactive.unwrap_or(false) {
return Err(Error::NotFound {
resource: "Deployment".to_string(),
id: deployment_id.to_string(),
});
}
}
// Active models (or other statuses): always visible if not deleted
(false, _) => {
// Model is visible, continue
}
}
// Check group-based access control for non-admin users
if !can_read_all_models {
let has_access = repo.check_user_access(&model.alias, current_user.id).await?;
if has_access.is_none() {
return Err(Error::NotFound {
resource: "Deployment".to_string(),
id: deployment_id.to_string(),
});
}
}
// Parse include parameters and filter based on permissions (same logic as list)
let include_params = query.include.as_deref().unwrap_or("");
let all_includes: Vec<&str> = include_params.split(',').map(|s| s.trim()).filter(|s| !s.is_empty()).collect();
// Filter includes based on permissions
let mut include_groups = false;
let mut include_metrics = false;
let mut include_status = false;
let mut include_pricing = false;
let mut include_endpoints = false;
let mut include_components = false;
for &include in &all_includes {
match include {
"groups"
// Only users with Groups::ReadAll can include groups
if can_read_groups => {
include_groups = true;
}
"metrics"
// Only users with Analytics::ReadAll can include metrics
if can_read_metrics => {
include_metrics = true;
}
"endpoints"
// Model endpoints is priviliged information for admins
if can_read_all_models => {
include_endpoints = true;
}
"status" => {
// Status is allowed for all users
include_status = true;
}
"pricing" => {
// Pricing is allowed for all users (enricher handles ReadAll permission)
include_pricing = true;
}
"components" => {
// Components is allowed for all users (enricher handles composite info permission)
include_components = true;
}
_ => {
// Unknown includes are ignored
}
}
}
// Build base response with provider_pricing based on permissions and includes
let provider_pricing = if include_pricing && can_read_pricing {
model.provider_pricing.clone()
} else {
None
};
let mut response = DeployedModelResponse::from(model).with_provider_pricing(provider_pricing);
// Fetch and attach traffic rules
{
let traffic_rules = repo.get_traffic_rules(deployment_id).await?;
response = response.with_traffic_rules(traffic_rules);
}
// Use ModelEnricher to add related data
let enricher = DeployedModelEnricher {
db: state.db.read(),
include_groups,
include_metrics,
include_status,
include_pricing,
include_endpoints,
include_components,
can_read_pricing,
can_read_rate_limits,
can_read_users,
can_read_composite_info: can_read_all_models,
};
response = enricher.enrich_one(response).await?;
Ok(Json(response))
}
#[utoipa::path(
delete,
path = "/models/{id}",
tag = "models",
summary = "Delete deployed model",
description = "Delete a deployed model",
params(
("id" = uuid::Uuid, Path, description = "Deployment ID to delete"),
),
responses(
(status = 200, description = "Deployed model deleted successfully"),
(status = 401, description = "Unauthorized"),
(status = 404, description = "Inference endpoint or deployment not found"),
(status = 500, description = "Internal server error"),
),
security(
("BearerAuth" = []),
("CookieAuth" = []),
("X-Doubleword-User" = [])
)
)]
#[tracing::instrument(skip_all)]
pub async fn delete_deployed_model<P: PoolProvider>(
State(state): State<AppState<P>>,
Path(deployment_id): Path<DeploymentId>,
_: RequiresPermission<resource::Models, operation::DeleteAll>,
) -> Result<Json<String>> {
let mut pool_conn = state.db.write().acquire().await.map_err(|e| Error::Database(e.into()))?;
let mut repo = Deployments::new(&mut pool_conn);
// Hide model by setting deleted flag
let update_request = DeploymentUpdateDBRequest::visibility_update(true);
repo.update(deployment_id, &update_request).await?;
Ok(Json(deployment_id.to_string()))
}
// ===== Composite Model Component Handlers =====
use crate::api::models::deployments::{ModelComponentCreate, ModelComponentUpdate};
use crate::db::models::deployments::DeploymentComponentCreateDBRequest;
/// List the underlying component models of a composite model.
///
/// The model is looked up first; a missing model and a non-composite model
/// are both rejected before any components are read. Both queries run on the
/// same read connection, so the handler never holds or waits for a second
/// pool connection.
///
/// # Errors
/// Returns [`Error::NotFound`] when no model has the given ID,
/// [`Error::BadRequest`] when the model is not composite, and propagates
/// database errors.
#[utoipa::path(
    get,
    path = "/models/{id}/components",
    tag = "models",
    summary = "Get composite model components",
    description = "Get the list of underlying models that make up a composite model",
    params(
        ("id" = String, Path, description = "The composite model ID", format = "uuid"),
    ),
    responses(
        (status = 200, description = "List of components", body = Vec<ModelComponentResponse>),
        (status = 400, description = "Model is not a composite model"),
        (status = 401, description = "Unauthorized"),
        (status = 404, description = "Composite model not found"),
        (status = 500, description = "Internal server error"),
    ),
    security(
        ("BearerAuth" = []),
        ("CookieAuth" = []),
        ("X-Doubleword-User" = [])
    )
)]
#[tracing::instrument(skip_all)]
pub async fn get_model_components<P: PoolProvider>(
    State(state): State<AppState<P>>,
    Path(id): Path<DeploymentId>,
    _: RequiresPermission<resource::CompositeModels, operation::ReadAll>,
) -> Result<Json<Vec<ModelComponentResponse>>> {
    // One read connection serves both the existence check and the component query.
    let mut conn = state.db.read().acquire().await.map_err(|e| Error::Database(e.into()))?;
    let mut repo = Deployments::new(&mut conn);

    // Verify the model exists and is composite
    let deployment = repo.get_by_id(id).await?.ok_or_else(|| Error::NotFound {
        resource: "model".to_string(),
        id: id.to_string(),
    })?;
    if !deployment.is_composite {
        return Err(Error::BadRequest {
            message: "Model is not a composite model".to_string(),
        });
    }

    // Get components
    let components = repo.get_components(id).await?;
    let response: Vec<ModelComponentResponse> = components.into_iter().map(db_component_to_response).collect();
    Ok(Json(response))
}
#[utoipa::path(
post,
path = "/models/{id}/components/{component_id}",
tag = "models",
summary = "Add component to composite model",
description = "Add an underlying model as a component of a composite model",
params(
("id" = String, Path, description = "The composite model ID", format = "uuid"),
("component_id" = String, Path, description = "The deployed model ID to add as a component", format = "uuid"),
),
request_body = ModelComponentCreate,
responses(
(status = 200, description = "Component added", body = ModelComponentResponse),
(status = 400, description = "Model is not a composite model or component is not valid"),
(status = 401, description = "Unauthorized"),
(status = 404, description = "Composite model or component model not found"),
(status = 500, description = "Internal server error"),
),
security(
("BearerAuth" = []),
("CookieAuth" = []),
("X-Doubleword-User" = [])
)
)]
#[tracing::instrument(skip_all)]
pub async fn add_model_component<P: PoolProvider>(
State(state): State<AppState<P>>,
Path((id, component_id)): Path<(DeploymentId, DeploymentId)>,
_: RequiresPermission<resource::CompositeModels, operation::UpdateAll>,
Json(body): Json<ModelComponentCreate>,
) -> Result<Json<ModelComponentResponse>> {
// Validate weight
if !(1..=100).contains(&body.weight) {
return Err(Error::BadRequest {
message: "Weight must be between 1 and 100".to_string(),
});
}
let mut tx = state.db.write().begin().await.map_err(|e| Error::Database(e.into()))?;
// Verify both models exist and constraints are met
{
let mut repo = Deployments::new(tx.acquire().await.map_err(|e| Error::Database(e.into()))?);
// Check composite model exists and is composite
let composite = repo.get_by_id(id).await?.ok_or_else(|| Error::NotFound {
resource: "composite model".to_string(),
id: id.to_string(),
})?;
if !composite.is_composite {
return Err(Error::BadRequest {
message: "Model is not a composite model".to_string(),
});
}
// Check component model exists and is NOT composite
let component = repo.get_by_id(component_id).await?.ok_or_else(|| Error::NotFound {
resource: "component model".to_string(),
id: component_id.to_string(),
})?;
if component.is_composite {
return Err(Error::BadRequest {
message: "Cannot add a composite model as a component".to_string(),
});
}
}
// Add the component
let mut repo = Deployments::new(tx.acquire().await.map_err(|e| Error::Database(e.into()))?);
let request = DeploymentComponentCreateDBRequest {
composite_model_id: id,
deployed_model_id: component_id,
weight: body.weight,
enabled: body.enabled,
sort_order: body.sort_order,
};
let component = repo.add_component(&request).await?;
tx.commit().await.map_err(|e| Error::Database(e.into()))?;
Ok(Json(db_component_to_response(component)))
}
#[utoipa::path(
patch,
path = "/models/{id}/components/{component_id}",
tag = "models",
summary = "Update component in composite model",
description = "Update the weight or enabled status of a component",
params(
("id" = String, Path, description = "The composite model ID", format = "uuid"),
("component_id" = String, Path, description = "The deployed model ID of the component", format = "uuid"),
),
request_body = ModelComponentUpdate,
responses(
(status = 200, description = "Component updated", body = ModelComponentResponse),
(status = 400, description = "Invalid weight"),
(status = 401, description = "Unauthorized"),
(status = 404, description = "Composite model or component not found"),
(status = 500, description = "Internal server error"),
),
security(
("BearerAuth" = []),
("CookieAuth" = []),
("X-Doubleword-User" = [])
)
)]
#[tracing::instrument(skip_all)]
pub async fn update_model_component<P: PoolProvider>(
State(state): State<AppState<P>>,
Path((id, component_id)): Path<(DeploymentId, DeploymentId)>,
_: RequiresPermission<resource::CompositeModels, operation::UpdateAll>,
Json(body): Json<ModelComponentUpdate>,
) -> Result<Json<ModelComponentResponse>> {
// Validate weight if provided
if let Some(weight) = body.weight
&& !(1..=100).contains(&weight)
{
return Err(Error::BadRequest {
message: "Weight must be between 1 and 100".to_string(),
});
}
let mut pool_conn = state.db.write().acquire().await.map_err(|e| Error::Database(e.into()))?;
let mut repo = Deployments::new(&mut pool_conn);
let component = repo
.update_component(id, component_id, body.weight, body.enabled, body.sort_order)
.await?
.ok_or_else(|| Error::NotFound {
resource: "component".to_string(),
id: format!("{}/{}", id, component_id),
})?;
Ok(Json(db_component_to_response(component)))
}
#[utoipa::path(
delete,
path = "/models/{id}/components/{component_id}",
tag = "models",
summary = "Remove component from composite model",
description = "Remove an underlying model from a composite model",
params(
("id" = String, Path, description = "The composite model ID", format = "uuid"),
("component_id" = String, Path, description = "The deployed model ID of the component to remove", format = "uuid"),
),
responses(
(status = 200, description = "Component removed"),
(status = 401, description = "Unauthorized"),
(status = 404, description = "Composite model or component not found"),
(status = 500, description = "Internal server error"),
),
security(
("BearerAuth" = []),
("CookieAuth" = []),
("X-Doubleword-User" = [])
)
)]
#[tracing::instrument(skip_all)]
pub async fn remove_model_component<P: PoolProvider>(
State(state): State<AppState<P>>,
Path((id, component_id)): Path<(DeploymentId, DeploymentId)>,
_: RequiresPermission<resource::CompositeModels, operation::UpdateAll>,
) -> Result<Json<String>> {
let mut pool_conn = state.db.write().acquire().await.map_err(|e| Error::Database(e.into()))?;
let mut repo = Deployments::new(&mut pool_conn);
let removed = repo.remove_component(id, component_id).await?;
if !removed {
return Err(Error::NotFound {
resource: "component".to_string(),
id: format!("{}/{}", id, component_id),
});
}
Ok(Json("Component removed".to_string()))
}
#[cfg(test)]
mod tests {
use crate::{
api::{
handlers::deployments::DeployedModelResponse,
models::{pagination::PaginatedResponse, users::Role},
},
db::{
handlers::{Groups, Repository},
models::groups::GroupCreateDBRequest,
},
test::utils::*,
types::DeploymentId,
};
use serde_json::json;
use sqlx::PgPool;
/// Look up the first model in a paginated response whose `id` equals `id`.
///
/// Returns `None` when no entry in `response.data` matches.
fn get_model_by_id(
    id: DeploymentId,
    response: &PaginatedResponse<DeployedModelResponse>,
) -> Option<&DeployedModelResponse> {
    let models = response.data.iter();
    let mut matching = models.skip_while(|candidate| candidate.id != id);
    matching.next()
}
// A standard user can list models filtered by the test endpoint, and the response
// body has the paginated model-list shape.
#[sqlx::test]
#[test_log::test]
async fn test_list_deployments(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let user = create_test_user(&pool, Role::StandardUser).await;
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let auth = add_auth_headers(&user);

    let response = app
        .get(&format!("/admin/api/v1/models?endpoint={test_endpoint_id}"))
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .await;
    response.assert_status_ok();

    // Deserializing into the typed paginated response is the structural check.
    // The list contents are deliberately not asserted: the previous
    // `is_empty() || !is_empty()` assertion was a tautology that could never fail.
    let _response_body: PaginatedResponse<DeployedModelResponse> = response.json();
}
// Filtering the model list by an endpoint ID that does not exist yields 404.
#[sqlx::test]
#[test_log::test]
async fn test_deployments_with_nonexistent_endpoint(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let user = create_test_user(&pool, Role::StandardUser).await;
    let auth = add_auth_headers(&user);

    // A freshly generated random UUID cannot match any stored endpoint.
    let missing_endpoint = uuid::Uuid::new_v4();
    let url = format!("/admin/api/v1/models?endpoint={missing_endpoint}");

    let response = app
        .get(&url)
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .await;
    response.assert_status_not_found();
}
#[sqlx::test]
#[test_log::test]
async fn test_model_operations(pool: PgPool) {
let (app, _bg_services) = create_test_app(pool.clone(), false).await;
let user = create_test_admin_user(&pool, Role::PlatformManager).await;
// Create a deployment on the test endpoint
let created = create_test_deployment(&pool, user.id, "test-model", "test-alias").await;
let deployment_id = created.id;
// Get the deployment
let response = app
.get(&format!("/admin/api/v1/models/{deployment_id}"))
.add_header(&add_auth_headers(&user)[0].0, &add_auth_headers(&user)[0].1)
.add_header(&add_auth_headers(&user)[1].0, &add_auth_headers(&user)[1].1)
.await;
response.assert_status_ok();
let model: DeployedModelResponse = response.json();
assert_eq!(model.id, deployment_id);
assert_eq!(model.model_name, "test-model");
assert_eq!(model.alias, "test-alias");
// Update the deployment
let update = json!({
"alias": "new-alias"
});
let response = app
.patch(&format!("/admin/api/v1/models/{deployment_id}"))
.add_header(&add_auth_headers(&user)[0].0, &add_auth_headers(&user)[0].1)
.add_header(&add_auth_headers(&user)[1].0, &add_auth_headers(&user)[1].1)
.json(&update)
.await;
response.assert_status_ok();
let updated_model: DeployedModelResponse = response.json();
assert_eq!(updated_model.alias, "new-alias");
// List models with endpoint filter
let test_endpoint_id = get_test_endpoint_id(&pool).await;
let response = app
.get(&format!("/admin/api/v1/models?endpoint={test_endpoint_id}"))
.add_header(&add_auth_headers(&user)[0].0, &add_auth_headers(&user)[0].1)
.add_header(&add_auth_headers(&user)[1].0, &add_auth_headers(&user)[1].1)
.await;
response.assert_status_ok();
let response_body: PaginatedResponse<DeployedModelResponse> = response.json();
assert!(response_body.data.iter().any(|it| it.id == deployment_id));
// Delete the deployment
let response = app
.delete(&format!("/admin/api/v1/models/{deployment_id}"))
.add_header(&add_auth_headers(&user)[0].0, &add_auth_headers(&user)[0].1)
.add_header(&add_auth_headers(&user)[1].0, &add_auth_headers(&user)[1].1)
.await;
response.assert_status_ok();
// Verify it's deleted - should return 404 without deleted=true parameter
let response = app
.get(&format!("/admin/api/v1/models/{deployment_id}"))
.add_header(&add_auth_headers(&user)[0].0, &add_auth_headers(&user)[0].1)
.add_header(&add_auth_headers(&user)[1].0, &add_auth_headers(&user)[1].1)
.await;
response.assert_status_not_found(); // Returns 404 without deleted=true
// But admin should be able to see it with deleted=true
let response = app
.get(&format!("/admin/api/v1/models/{deployment_id}?deleted=true"))
.add_header(&add_auth_headers(&user)[0].0, &add_auth_headers(&user)[0].1)
.add_header(&add_auth_headers(&user)[1].0, &add_auth_headers(&user)[1].1)
.await;
response.assert_status_ok(); // Admin can see deleted model with deleted=true
}
// `include=groups` controls whether group membership is embedded in listed models:
// absent by default, present when requested, and still present with an endpoint filter.
#[sqlx::test]
#[test_log::test]
async fn test_list_deployments_with_groups_include(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
    let auth = add_auth_headers(&admin_user);

    // Seed a deployment.
    let deployment = create_test_deployment(&pool, admin_user.id, "test-model", "test-alias").await;
    assert!(deployment.last_sync.is_none());

    // Seed a group and put the deployment in it.
    let mut pool_conn = pool.acquire().await.unwrap();
    let mut group_repo = Groups::new(&mut pool_conn);
    let group = group_repo
        .create(&GroupCreateDBRequest {
            name: "Test Group".to_string(),
            description: Some("Test group for deployment".to_string()),
            created_by: admin_user.id,
        })
        .await
        .expect("Failed to create test group");
    group_repo
        .add_deployment_to_group(deployment.id, group.id, admin_user.id)
        .await
        .expect("Failed to add deployment to group");

    // No include parameter: groups must be absent.
    let response = app
        .get("/admin/api/v1/models")
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .await;
    response.assert_status_ok();
    let listing: PaginatedResponse<DeployedModelResponse> = response.json();
    let listed_without_groups = listing
        .data
        .iter()
        .any(|entry| entry.id == deployment.id && entry.groups.is_none());
    assert!(listed_without_groups);

    // include=groups: exactly our one group must be attached.
    let response = app
        .get("/admin/api/v1/models?include=groups")
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .await;
    response.assert_status_ok();
    let listing: PaginatedResponse<DeployedModelResponse> = response.json();
    let listed_with_single_group = listing.data.iter().any(|entry| {
        entry.id == deployment.id
            && entry
                .groups
                .as_deref()
                .is_some_and(|groups| groups.len() == 1 && groups[0].id == group.id)
    });
    assert!(listed_with_single_group);

    // include=groups combined with an endpoint filter: our group must be among those attached.
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let response = app
        .get(&format!("/admin/api/v1/models?endpoint={test_endpoint_id}&include=groups"))
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .await;
    response.assert_status_ok();
    let listing: PaginatedResponse<DeployedModelResponse> = response.json();
    let listed_with_group = listing.data.iter().any(|entry| {
        entry.id == deployment.id
            && entry
                .groups
                .as_deref()
                .is_some_and(|groups| groups.iter().any(|candidate| candidate.id == group.id))
    });
    assert!(listed_with_group);
}
#[sqlx::test]
#[test_log::test]
async fn test_role_based_visibility_for_deleted_models(pool: PgPool) {
let (app, _bg_services) = create_test_app(pool.clone(), false).await;
let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
let regular_user = create_test_user(&pool, Role::StandardUser).await;
// Create a deployment
let deployment = create_test_deployment(&pool, admin_user.id, "test-model", "test-alias").await;
let deployment_id = deployment.id;
let everyone_group_id = uuid::Uuid::nil();
add_deployment_to_group(&pool, deployment_id, everyone_group_id, admin_user.id).await;
// Both users should initially see the model
let response = app
.get(&format!("/admin/api/v1/models/{deployment_id}"))
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let response = app
.get(&format!("/admin/api/v1/models/{deployment_id}"))
.add_header(&add_auth_headers(®ular_user)[0].0, &add_auth_headers(®ular_user)[0].1)
.add_header(&add_auth_headers(®ular_user)[1].0, &add_auth_headers(®ular_user)[1].1)
.await;
response.assert_status_ok();
// Admin hides the model (soft delete)
let response = app
.delete(&format!("/admin/api/v1/models/{deployment_id}"))
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
// Admin should still be able to see the deleted model with deleted=true
let response = app
.get(&format!("/admin/api/v1/models/{deployment_id}?deleted=true"))
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let model: DeployedModelResponse = response.json();
assert_eq!(model.id, deployment_id);
// Regular user should NOT see the deleted model (404)
let response = app
.get(&format!("/admin/api/v1/models/{deployment_id}"))
.add_header(&add_auth_headers(®ular_user)[0].0, &add_auth_headers(®ular_user)[0].1)
.add_header(&add_auth_headers(®ular_user)[1].0, &add_auth_headers(®ular_user)[1].1)
.await;
response.assert_status_not_found();
// Verify the API behavior is consistent with soft deletion
}
#[sqlx::test]
#[test_log::test]
async fn test_role_based_list_filtering(pool: PgPool) {
let (app, _bg_services) = create_test_app(pool.clone(), false).await;
let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
let regular_user = create_test_user(&pool, Role::StandardUser).await;
// Create multiple deployments
let deployment1 = create_test_deployment(&pool, admin_user.id, "active-model", "active-alias").await;
let deployment2 = create_test_deployment(&pool, admin_user.id, "hidden-model", "hidden-alias").await;
// Create a group and add regular user to it so they can see deployment1
let mut pool_conn = pool.acquire().await.unwrap();
let mut group_repo = Groups::new(&mut pool_conn);
let group_create = GroupCreateDBRequest {
name: "List Filter Test Group".to_string(),
description: Some("Test group for list filtering".to_string()),
created_by: admin_user.id,
};
let group = group_repo.create(&group_create).await.unwrap();
group_repo.add_user_to_group(regular_user.id, group.id).await.unwrap();
// Add deployment1 to the group (regular user should see this)
group_repo
.add_deployment_to_group(deployment1.id, group.id, admin_user.id)
.await
.unwrap();
// Don't add deployment2 to any group (regular user shouldn't see it)
// Hide the second deployment
let response = app
.delete(&format!("/admin/api/v1/models/{}", deployment2.id))
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
// Admin should see both models in list when requesting deleted=true (include deleted)
let response = app
.get("/admin/api/v1/models?deleted=true")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let admin_all_models: PaginatedResponse<DeployedModelResponse> = response.json();
assert!(admin_all_models.data.iter().any(|it| it.id == deployment1.id));
assert!(admin_all_models.data.iter().any(|it| it.id == deployment2.id));
// Admin should see only non-deleted models by default
let response = app
.get("/admin/api/v1/models")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let admin_models: PaginatedResponse<DeployedModelResponse> = response.json();
assert!(admin_models.data.iter().any(|it| it.id == deployment1.id));
assert!(!admin_models.data.iter().any(|it| it.id == deployment2.id));
// Regular user should only see the active model
let response = app
.get("/admin/api/v1/models")
.add_header(&add_auth_headers(®ular_user)[0].0, &add_auth_headers(®ular_user)[0].1)
.add_header(&add_auth_headers(®ular_user)[1].0, &add_auth_headers(®ular_user)[1].1)
.await;
response.assert_status_ok();
let user_models: PaginatedResponse<DeployedModelResponse> = response.json();
assert!(user_models.data.iter().any(|it| it.id == deployment1.id));
assert!(!user_models.data.iter().any(|it| it.id == deployment2.id));
}
// A soft-deleted model can still be updated by an admin, while a standard user's
// update attempt is rejected with 403 Forbidden.
#[sqlx::test]
#[test_log::test]
async fn test_role_based_update_access_for_deleted_models(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
    let regular_user = create_test_user(&pool, Role::StandardUser).await;
    let admin_auth = add_auth_headers(&admin_user);
    let user_auth = add_auth_headers(&regular_user);

    // Create and hide a deployment.
    let deployment = create_test_deployment(&pool, admin_user.id, "update-test-model", "update-test-alias").await;
    let deployment_id = deployment.id;
    let response = app
        .delete(&format!("/admin/api/v1/models/{deployment_id}"))
        .add_header(&admin_auth[0].0, &admin_auth[0].1)
        .add_header(&admin_auth[1].0, &admin_auth[1].1)
        .await;
    response.assert_status_ok();

    // Admin should be able to update the deleted model.
    let update = json!({
        "alias": "admin-updated-alias"
    });
    let response = app
        .patch(&format!("/admin/api/v1/models/{deployment_id}"))
        .add_header(&admin_auth[0].0, &admin_auth[0].1)
        .add_header(&admin_auth[1].0, &admin_auth[1].1)
        .json(&update)
        .await;
    response.assert_status_ok();
    let updated_model: DeployedModelResponse = response.json();
    assert_eq!(updated_model.alias, "admin-updated-alias");

    // Regular user should NOT be able to update the deleted model (403 Forbidden).
    let update = json!({
        "alias": "user-attempted-update"
    });
    let response = app
        .patch(&format!("/admin/api/v1/models/{deployment_id}"))
        .add_header(&user_auth[0].0, &user_auth[0].1)
        .add_header(&user_auth[1].0, &user_auth[1].1)
        .json(&update)
        .await;
    response.assert_status_forbidden();
}
#[sqlx::test]
#[test_log::test]
async fn test_soft_delete_preserves_model_accessibility_for_admin(pool: PgPool) {
let (app, _bg_services) = create_test_app(pool.clone(), false).await;
let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
let regular_user = create_test_user(&pool, Role::StandardUser).await;
// Create a deployment via API
let deployment = create_test_deployment(&pool, admin_user.id, "preserve-test-model", "preserve-test-alias").await;
let deployment_id = deployment.id;
// Add to Everyone group so regular users can access it
let everyone_group_id = uuid::Uuid::nil();
add_deployment_to_group(&pool, deployment_id, everyone_group_id, admin_user.id).await;
// Verify both users can initially access the model
let response = app
.get(&format!("/admin/api/v1/models/{deployment_id}"))
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let response = app
.get(&format!("/admin/api/v1/models/{deployment_id}"))
.add_header(&add_auth_headers(®ular_user)[0].0, &add_auth_headers(®ular_user)[0].1)
.add_header(&add_auth_headers(®ular_user)[1].0, &add_auth_headers(®ular_user)[1].1)
.await;
response.assert_status_ok();
// Admin soft deletes the model
let response = app
.delete(&format!("/admin/api/v1/models/{deployment_id}"))
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
// Admin can still access the model after soft deletion with deleted=true
let response = app
.get(&format!("/admin/api/v1/models/{deployment_id}?deleted=true"))
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let model: DeployedModelResponse = response.json();
assert_eq!(model.model_name, "preserve-test-model");
assert_eq!(model.alias, "preserve-test-alias");
// Regular user can no longer access the model
let response = app
.get(&format!("/admin/api/v1/models/{deployment_id}"))
.add_header(&add_auth_headers(®ular_user)[0].0, &add_auth_headers(®ular_user)[0].1)
.add_header(&add_auth_headers(®ular_user)[1].0, &add_auth_headers(®ular_user)[1].1)
.await;
response.assert_status_not_found();
// Admin can still update the soft-deleted model
let update = json!({
"alias": "updated-after-deletion"
});
let response = app
.patch(&format!("/admin/api/v1/models/{deployment_id}"))
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.json(&update)
.await;
response.assert_status_ok();
let updated_model: DeployedModelResponse = response.json();
assert_eq!(updated_model.alias, "updated-after-deletion");
}
// Creating a standard model with every optional field supplied echoes those
// fields back and records the calling admin as the creator.
#[sqlx::test]
#[test_log::test]
async fn test_create_deployed_model(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
    let auth = add_auth_headers(&admin_user);
    let test_endpoint_id = get_test_endpoint_id(&pool).await;

    // Full creation payload.
    let payload = json!({
        "type": "standard",
        "model_name": "test-new-model",
        "alias": "Test New Model",
        "hosted_on": test_endpoint_id.to_string(),
        "description": "A test model created via API",
        "model_type": "CHAT",
        "capabilities": ["text-generation", "streaming"]
    });

    let response = app
        .post("/admin/api/v1/models")
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .json(&payload)
        .await;
    response.assert_status_ok();

    let created: DeployedModelResponse = response.json();
    assert_eq!(created.model_name, "test-new-model");
    assert_eq!(created.alias, "Test New Model");
    assert_eq!(created.hosted_on, Some(test_endpoint_id));
    assert_eq!(created.description, Some("A test model created via API".to_string()));
    assert_eq!(created.created_by, Some(admin_user.id));
}
// Creating a model from a minimal payload: the alias falls back to the model name
// and the description stays unset.
#[sqlx::test]
#[test_log::test]
async fn test_create_deployed_model_with_defaults(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
    let auth = add_auth_headers(&admin_user);
    let test_endpoint_id = get_test_endpoint_id(&pool).await;

    // Only the required fields; no alias and no description.
    let payload = json!({
        "type": "standard",
        "model_name": "simple-model",
        "hosted_on": test_endpoint_id.to_string()
    });

    let response = app
        .post("/admin/api/v1/models")
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .json(&payload)
        .await;
    response.assert_status_ok();

    let created: DeployedModelResponse = response.json();
    assert_eq!(created.model_name, "simple-model");
    // The alias is expected to default to the model name.
    assert_eq!(created.alias, "simple-model");
    assert_eq!(created.hosted_on, Some(test_endpoint_id));
    assert_eq!(created.description, None);
    assert_eq!(created.created_by, Some(admin_user.id));
}
// A standard user is not allowed to create models: the request is rejected with 403.
#[sqlx::test]
#[test_log::test]
async fn test_create_deployed_model_non_admin_forbidden(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let regular_user = create_test_user(&pool, Role::StandardUser).await;
    let user_auth = add_auth_headers(&regular_user);
    let test_endpoint_id = get_test_endpoint_id(&pool).await;

    let create_request = json!({
        "type": "standard",
        "model_name": "forbidden-model",
        "hosted_on": test_endpoint_id.to_string()
    });

    let response = app
        .post("/admin/api/v1/models")
        .add_header(&user_auth[0].0, &user_auth[0].1)
        .add_header(&user_auth[1].0, &user_auth[1].1)
        .json(&create_request)
        .await;
    response.assert_status_forbidden();
}
// Creating a model that points at an endpoint ID which does not exist yields 404.
#[sqlx::test]
#[test_log::test]
async fn test_create_deployed_model_nonexistent_endpoint(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
    let auth = add_auth_headers(&admin_user);

    let payload = json!({
        "type": "standard",
        "model_name": "test-model",
        "hosted_on": "99999999-9999-9999-9999-999999999999" // Non-existent endpoint
    });

    let response = app
        .post("/admin/api/v1/models")
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .json(&payload)
        .await;
    response.assert_status_not_found();
}
#[sqlx::test]
#[test_log::test]
async fn test_include_groups_admin_only(pool: PgPool) {
let (app, _bg_services) = create_test_app(pool.clone(), false).await;
let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
let regular_user = create_test_user(&pool, Role::StandardUser).await;
// Create a deployment
let deployment = create_test_deployment(&pool, admin_user.id, "groups-test-model", "groups-test-alias").await;
// Create a group and add the deployment to it
let mut pool_conn = pool.acquire().await.unwrap();
let mut groups_repo = Groups::new(&mut pool_conn);
let group_create = GroupCreateDBRequest {
name: "Test Group".to_string(),
description: Some("Test group for include test".to_string()),
created_by: admin_user.id,
};
let group = groups_repo.create(&group_create).await.expect("Failed to create group");
groups_repo
.add_deployment_to_group(deployment.id, group.id, admin_user.id)
.await
.expect("Failed to add deployment to group");
// Add regular user to the group so they can see the deployment
groups_repo
.add_user_to_group(regular_user.id, group.id)
.await
.expect("Failed to add regular user to group");
// Admin should be able to include groups and see them
let response = app
.get("/admin/api/v1/models?include=groups")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let models: PaginatedResponse<DeployedModelResponse> = response.json();
// Find our test deployment by ID and verify it has groups included
let test_model = get_model_by_id(deployment.id, &models).unwrap_or_else(|| {
panic!(
"Test model not found. Available models: {:?}",
models.data.iter().map(|m| &m.id).collect::<Vec<_>>()
)
});
assert!(test_model.groups.is_some(), "Admin should see groups included");
let groups = test_model.groups.as_ref().unwrap();
assert_eq!(groups.len(), 1, "Should have exactly one group");
assert_eq!(groups[0].name, "Test Group");
// Regular user should NOT be able to include groups (groups should be None)
let response = app
.get("/admin/api/v1/models?include=groups")
.add_header(&add_auth_headers(®ular_user)[0].0, &add_auth_headers(®ular_user)[0].1)
.add_header(&add_auth_headers(®ular_user)[1].0, &add_auth_headers(®ular_user)[1].1)
.await;
response.assert_status_ok();
let models: PaginatedResponse<DeployedModelResponse> = response.json();
// Find our test deployment by ID and verify groups are NOT included
let test_model = get_model_by_id(deployment.id, &models).unwrap_or_else(|| {
panic!(
"Test model not found. Available models: {:?}",
models.data.iter().map(|m| &m.id).collect::<Vec<_>>()
)
});
assert!(test_model.groups.is_none(), "Regular user should NOT see groups included");
}
#[sqlx::test]
#[test_log::test]
async fn test_accessible_parameter_filtering(pool: PgPool) {
let (app, _bg_services) = create_test_app(pool.clone(), false).await;
let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
let regular_user = create_test_user(&pool, Role::StandardUser).await;
// Create deployments
let deployment1 = create_test_deployment(&pool, admin_user.id, "test-model-1", "test-alias-1").await;
let deployment2 = create_test_deployment(&pool, admin_user.id, "test-model-2", "test-alias-2").await;
// Create a group and add regular user to it
let mut pool_conn = pool.acquire().await.unwrap();
let mut group_repo = Groups::new(&mut pool_conn);
let group_create = GroupCreateDBRequest {
name: "Access Test Group".to_string(),
description: Some("Test group for accessible filtering".to_string()),
created_by: admin_user.id,
};
let group = group_repo.create(&group_create).await.unwrap();
group_repo.add_user_to_group(regular_user.id, group.id).await.unwrap();
// Add only deployment1 to the group (regular user should only access this one)
group_repo
.add_deployment_to_group(deployment1.id, group.id, admin_user.id)
.await
.unwrap();
// Don't add deployment2 to any group
// Test 1: Regular user without accessible=true should still get filtered (default behavior)
let response = app
.get("/admin/api/v1/models")
.add_header(&add_auth_headers(®ular_user)[0].0, &add_auth_headers(®ular_user)[0].1)
.add_header(&add_auth_headers(®ular_user)[1].0, &add_auth_headers(®ular_user)[1].1)
.await;
response.assert_status_ok();
let user_models: PaginatedResponse<DeployedModelResponse> = response.json();
assert_eq!(user_models.data.len(), 1, "Regular user should only see 1 accessible model");
assert!(get_model_by_id(deployment1.id, &user_models).is_some());
assert!(get_model_by_id(deployment2.id, &user_models).is_none());
// Test 2: Regular user with accessible=true should get same result (explicit filtering)
let response = app
.get("/admin/api/v1/models?accessible=true")
.add_header(&add_auth_headers(®ular_user)[0].0, &add_auth_headers(®ular_user)[0].1)
.add_header(&add_auth_headers(®ular_user)[1].0, &add_auth_headers(®ular_user)[1].1)
.await;
response.assert_status_ok();
let user_models_explicit: PaginatedResponse<DeployedModelResponse> = response.json();
assert_eq!(user_models_explicit.data.len(), 1);
assert!(get_model_by_id(deployment1.id, &user_models_explicit).is_some());
// Test 3: Admin user without accessible parameter should see all models (default)
let response = app
.get("/admin/api/v1/models")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let admin_models: PaginatedResponse<DeployedModelResponse> = response.json();
assert_eq!(admin_models.data.len(), 2, "Admin should see all models by default");
assert!(get_model_by_id(deployment1.id, &admin_models).is_some());
assert!(get_model_by_id(deployment2.id, &admin_models).is_some());
// Test 4: Admin user with accessible=false should see all models (explicit no filtering)
let response = app
.get("/admin/api/v1/models?accessible=false")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let admin_models_explicit: PaginatedResponse<DeployedModelResponse> = response.json();
assert_eq!(admin_models_explicit.data.len(), 2);
// Test 5: Admin user with accessible=true should get filtered results (only their accessible models)
// First add admin to a group and that group to deployment1
group_repo.add_user_to_group(admin_user.id, group.id).await.unwrap();
let response = app
.get("/admin/api/v1/models?accessible=true")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let admin_accessible: PaginatedResponse<DeployedModelResponse> = response.json();
assert_eq!(
admin_accessible.data.len(),
1,
"Admin with accessible=true should only see their accessible models"
);
assert!(get_model_by_id(deployment1.id, &admin_accessible).is_some());
}
#[sqlx::test]
#[test_log::test]
async fn test_include_metrics_parameter(pool: PgPool) {
let (app, _bg_services) = create_test_app(pool.clone(), false).await;
let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
let regular_user = create_test_user(&pool, Role::StandardUser).await;
// Create a deployment
let deployment = create_test_deployment(&pool, admin_user.id, "metrics-test-model", "metrics-test-alias").await;
// Test without include parameter - should not include metrics
let response = app
.get("/admin/api/v1/models")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let models: PaginatedResponse<DeployedModelResponse> = response.json();
let test_model = get_model_by_id(deployment.id, &models).unwrap();
assert!(test_model.metrics.is_none(), "Should not include metrics by default");
// Test with include=metrics - should include metrics
let response = app
.get("/admin/api/v1/models?include=metrics")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let models: PaginatedResponse<DeployedModelResponse> = response.json();
let test_model = get_model_by_id(deployment.id, &models).unwrap();
assert!(test_model.metrics.is_some(), "Admin should see metrics when requested");
let metrics = test_model.metrics.as_ref().unwrap();
assert_eq!(metrics.total_requests, 0); // No requests yet, so should be 0
// Test that regular users CANNOT include metrics (no Analytics::ReadAll permission)
let response = app
.get("/admin/api/v1/models?include=metrics")
.add_header(&add_auth_headers(®ular_user)[0].0, &add_auth_headers(®ular_user)[0].1)
.add_header(&add_auth_headers(®ular_user)[1].0, &add_auth_headers(®ular_user)[1].1)
.await;
response.assert_status_ok();
let models: PaginatedResponse<DeployedModelResponse> = response.json();
if let Some(test_model) = get_model_by_id(deployment.id, &models) {
assert!(
test_model.metrics.is_none(),
"Regular user should NOT see metrics (no Analytics::ReadAll permission)"
);
}
// Test with include=groups,metrics - should include both for admin
let response = app
.get("/admin/api/v1/models?include=groups,metrics")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let models: PaginatedResponse<DeployedModelResponse> = response.json();
let test_model = get_model_by_id(deployment.id, &models).unwrap();
assert!(test_model.groups.is_some(), "Admin should see groups when requested");
assert!(test_model.metrics.is_some(), "Admin should see metrics when requested");
// Test that regular users cannot include groups or metrics (no permissions)
let response = app
.get("/admin/api/v1/models?include=groups,metrics")
.add_header(&add_auth_headers(®ular_user)[0].0, &add_auth_headers(®ular_user)[0].1)
.add_header(&add_auth_headers(®ular_user)[1].0, &add_auth_headers(®ular_user)[1].1)
.await;
response.assert_status_ok();
let models: PaginatedResponse<DeployedModelResponse> = response.json();
if let Some(test_model) = get_model_by_id(deployment.id, &models) {
assert!(test_model.groups.is_none(), "Regular user should NOT see groups");
assert!(
test_model.metrics.is_none(),
"Regular user should NOT see metrics (no Analytics::ReadAll permission)"
);
}
}
/// A platform manager lists every deployment by default, while a standard user only
/// lists deployments shared with a group they belong to.
#[sqlx::test]
#[test_log::test]
async fn test_platform_manager_sees_all_models_by_default(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let platform_manager = create_test_admin_user(&pool, Role::PlatformManager).await;
    let standard_user = create_test_user(&pool, Role::StandardUser).await;
    let pm_headers = add_auth_headers(&platform_manager);
    let user_headers = add_auth_headers(&standard_user);

    // Two deployments; only the first one is shared with the standard user's group below.
    let deployment1 = create_test_deployment(&pool, platform_manager.id, "pm-model-1", "pm-alias-1").await;
    let deployment2 = create_test_deployment(&pool, platform_manager.id, "pm-model-2", "pm-alias-2").await;

    // Group containing only the standard user and only deployment1.
    let mut pool_conn = pool.acquire().await.unwrap();
    let mut group_repo = Groups::new(&mut pool_conn);
    let group = group_repo
        .create(&GroupCreateDBRequest {
            name: "Standard User Group".to_string(),
            description: Some("Group for standard user only".to_string()),
            created_by: platform_manager.id,
        })
        .await
        .unwrap();
    group_repo.add_user_to_group(standard_user.id, group.id).await.unwrap();
    group_repo
        .add_deployment_to_group(deployment1.id, group.id, platform_manager.id)
        .await
        .unwrap();

    // Platform manager: both deployments are listed, including the ungrouped deployment2.
    let pm_response = app
        .get("/admin/api/v1/models")
        .add_header(&pm_headers[0].0, &pm_headers[0].1)
        .add_header(&pm_headers[1].0, &pm_headers[1].1)
        .await;
    pm_response.assert_status_ok();
    let pm_models = pm_response.json::<PaginatedResponse<DeployedModelResponse>>();
    assert!(pm_models.data.iter().any(|m| m.id == deployment1.id), "PM should see deployment1");
    assert!(
        pm_models.data.iter().any(|m| m.id == deployment2.id),
        "PM should see deployment2 even without group access"
    );

    // Standard user: only the grouped deployment1 is listed.
    let user_response = app
        .get("/admin/api/v1/models")
        .add_header(&user_headers[0].0, &user_headers[0].1)
        .add_header(&user_headers[1].0, &user_headers[1].1)
        .await;
    user_response.assert_status_ok();
    let user_models = user_response.json::<PaginatedResponse<DeployedModelResponse>>();
    let user_accessible_count = user_models
        .data
        .iter()
        .filter(|m| m.id == deployment1.id || m.id == deployment2.id)
        .count();
    assert_eq!(user_accessible_count, 1, "Standard user should only see 1 accessible model");
    assert!(
        user_models.data.iter().any(|m| m.id == deployment1.id),
        "User should see deployment1"
    );
    assert!(
        user_models.data.iter().all(|m| m.id != deployment2.id),
        "User should NOT see deployment2"
    );
}
/// A platform manager can opt in to group-based filtering with `accessible=true`;
/// with `accessible=false` they keep seeing every deployment.
#[sqlx::test]
#[test_log::test]
async fn test_platform_manager_can_request_accessible_filtering(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let platform_manager = create_test_admin_user(&pool, Role::PlatformManager).await;
    let pm_headers = add_auth_headers(&platform_manager);

    // Two deployments; only the first one is placed in the platform manager's group.
    let deployment1 = create_test_deployment(&pool, platform_manager.id, "pm-access-1", "pm-access-alias-1").await;
    let deployment2 = create_test_deployment(&pool, platform_manager.id, "pm-access-2", "pm-access-alias-2").await;

    let mut pool_conn = pool.acquire().await.unwrap();
    let mut group_repo = Groups::new(&mut pool_conn);
    let group = group_repo
        .create(&GroupCreateDBRequest {
            name: "PM Access Group".to_string(),
            description: Some("Group for platform manager accessibility test".to_string()),
            created_by: platform_manager.id,
        })
        .await
        .unwrap();
    group_repo.add_user_to_group(platform_manager.id, group.id).await.unwrap();
    group_repo
        .add_deployment_to_group(deployment1.id, group.id, platform_manager.id)
        .await
        .unwrap();

    // accessible=false: no filtering, both deployments are listed.
    let unfiltered = app
        .get("/admin/api/v1/models?accessible=false")
        .add_header(&pm_headers[0].0, &pm_headers[0].1)
        .add_header(&pm_headers[1].0, &pm_headers[1].1)
        .await;
    unfiltered.assert_status_ok();
    let all_models = unfiltered.json::<PaginatedResponse<DeployedModelResponse>>();
    assert!(all_models.data.iter().any(|m| m.id == deployment1.id));
    assert!(all_models.data.iter().any(|m| m.id == deployment2.id));

    // accessible=true: only the deployment reachable through the group is listed.
    let filtered = app
        .get("/admin/api/v1/models?accessible=true")
        .add_header(&pm_headers[0].0, &pm_headers[0].1)
        .add_header(&pm_headers[1].0, &pm_headers[1].1)
        .await;
    filtered.assert_status_ok();
    let accessible_models = filtered.json::<PaginatedResponse<DeployedModelResponse>>();
    let accessible_count = accessible_models
        .data
        .iter()
        .filter(|m| m.id == deployment1.id || m.id == deployment2.id)
        .count();
    assert_eq!(accessible_count, 1, "PM with accessible=true should see only 1 accessible model");
    assert!(
        accessible_models.data.iter().any(|m| m.id == deployment1.id),
        "Should see accessible deployment"
    );
    assert!(
        accessible_models.data.iter().all(|m| m.id != deployment2.id),
        "Should NOT see non-accessible deployment"
    );
}
/// A `RequestViewer` gets the same group-based filtering of the model list as a
/// standard user; the platform manager is used as the unfiltered comparison.
#[sqlx::test]
#[test_log::test]
async fn test_request_viewer_role_gets_filtered(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let platform_manager = create_test_admin_user(&pool, Role::PlatformManager).await;
    let request_viewer = create_test_user(&pool, Role::RequestViewer).await;
    let pm_headers = add_auth_headers(&platform_manager);
    let rv_headers = add_auth_headers(&request_viewer);

    // Two deployments; only the first one is shared with the request viewer's group.
    let deployment1 = create_test_deployment(&pool, platform_manager.id, "rv-model-1", "rv-alias-1").await;
    let deployment2 = create_test_deployment(&pool, platform_manager.id, "rv-model-2", "rv-alias-2").await;

    let mut pool_conn = pool.acquire().await.unwrap();
    let mut group_repo = Groups::new(&mut pool_conn);
    let group = group_repo
        .create(&GroupCreateDBRequest {
            name: "Request Viewer Group".to_string(),
            description: Some("Group for request viewer test".to_string()),
            created_by: platform_manager.id,
        })
        .await
        .unwrap();
    group_repo.add_user_to_group(request_viewer.id, group.id).await.unwrap();
    group_repo
        .add_deployment_to_group(deployment1.id, group.id, platform_manager.id)
        .await
        .unwrap();

    // Request viewer: only the grouped deployment is listed.
    let rv_response = app
        .get("/admin/api/v1/models")
        .add_header(&rv_headers[0].0, &rv_headers[0].1)
        .add_header(&rv_headers[1].0, &rv_headers[1].1)
        .await;
    rv_response.assert_status_ok();
    let rv_models = rv_response.json::<PaginatedResponse<DeployedModelResponse>>();
    let rv_accessible_count = rv_models
        .data
        .iter()
        .filter(|m| m.id == deployment1.id || m.id == deployment2.id)
        .count();
    assert_eq!(rv_accessible_count, 1, "RequestViewer should only see 1 accessible model");
    assert!(
        rv_models.data.iter().any(|m| m.id == deployment1.id),
        "Should see accessible deployment"
    );
    assert!(
        rv_models.data.iter().all(|m| m.id != deployment2.id),
        "Should NOT see non-accessible deployment"
    );

    // Platform manager: both deployments are listed.
    let pm_response = app
        .get("/admin/api/v1/models")
        .add_header(&pm_headers[0].0, &pm_headers[0].1)
        .add_header(&pm_headers[1].0, &pm_headers[1].1)
        .await;
    pm_response.assert_status_ok();
    let pm_models = pm_response.json::<PaginatedResponse<DeployedModelResponse>>();
    assert!(pm_models.data.iter().any(|m| m.id == deployment1.id));
    assert!(pm_models.data.iter().any(|m| m.id == deployment2.id));
}
/// A model created through `POST /admin/api/v1/models` shows up in the platform
/// manager's model list right away, with the name and alias that were submitted.
#[sqlx::test]
#[test_log::test]
async fn test_platform_manager_can_see_newly_created_models(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let platform_manager = create_test_admin_user(&pool, Role::PlatformManager).await;
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let pm_headers = add_auth_headers(&platform_manager);

    // Create the model through the API.
    let create_request = json!({
        "type": "standard",
        "model_name": "pm-new-model",
        "alias": "Platform Manager New Model",
        "hosted_on": test_endpoint_id.to_string(),
        "description": "A model created by platform manager"
    });
    let create_response = app
        .post("/admin/api/v1/models")
        .add_header(&pm_headers[0].0, &pm_headers[0].1)
        .add_header(&pm_headers[1].0, &pm_headers[1].1)
        .json(&create_request)
        .await;
    create_response.assert_status_ok();
    let created_model: DeployedModelResponse = create_response.json();
    let deployment_id = created_model.id;

    // List models and look the new one up by id.
    let list_response = app
        .get("/admin/api/v1/models")
        .add_header(&pm_headers[0].0, &pm_headers[0].1)
        .add_header(&pm_headers[1].0, &pm_headers[1].1)
        .await;
    list_response.assert_status_ok();
    let models = list_response.json::<PaginatedResponse<DeployedModelResponse>>();
    let listed = models.data.iter().find(|m| m.id == deployment_id);
    assert!(
        listed.is_some(),
        "Platform manager should see newly created model immediately"
    );

    // Verify the model details
    let found_model = listed.unwrap();
    assert_eq!(found_model.model_name, "pm-new-model");
    assert_eq!(found_model.alias, "Platform Manager New Model");
}
/// A deployment that belongs to no group is listed for the platform manager but
/// hidden from a standard user.
#[sqlx::test]
#[test_log::test]
async fn test_standard_user_cannot_see_ungrouped_models(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let platform_manager = create_test_admin_user(&pool, Role::PlatformManager).await;
    let standard_user = create_test_user(&pool, Role::StandardUser).await;
    let pm_headers = add_auth_headers(&platform_manager);
    let user_headers = add_auth_headers(&standard_user);

    // The deployment is created but intentionally never attached to a group.
    let deployment = create_test_deployment(&pool, platform_manager.id, "ungrouped-model", "ungrouped-alias").await;

    // Platform manager: the ungrouped model is listed.
    let pm_response = app
        .get("/admin/api/v1/models")
        .add_header(&pm_headers[0].0, &pm_headers[0].1)
        .add_header(&pm_headers[1].0, &pm_headers[1].1)
        .await;
    pm_response.assert_status_ok();
    let pm_models = pm_response.json::<PaginatedResponse<DeployedModelResponse>>();
    assert!(
        pm_models.data.iter().any(|m| m.id == deployment.id),
        "Platform manager should see ungrouped model"
    );

    // Standard user: the ungrouped model is absent.
    let user_response = app
        .get("/admin/api/v1/models")
        .add_header(&user_headers[0].0, &user_headers[0].1)
        .add_header(&user_headers[1].0, &user_headers[1].1)
        .await;
    user_response.assert_status_ok();
    let user_models = user_response.json::<PaginatedResponse<DeployedModelResponse>>();
    assert!(
        user_models.data.iter().all(|m| m.id != deployment.id),
        "Standard user should NOT see ungrouped model"
    );
}
/// A `RequestViewer` is answered with 403 Forbidden when trying to create, update or
/// delete a model.
#[sqlx::test]
#[test_log::test]
async fn test_request_viewer_cannot_modify_models(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let platform_manager = create_test_admin_user(&pool, Role::PlatformManager).await;
    let request_viewer = create_test_user(&pool, Role::RequestViewer).await;
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let rv_headers = add_auth_headers(&request_viewer);

    // Existing deployment (owned by the platform manager) used as the update/delete target.
    let deployment = create_test_deployment(&pool, platform_manager.id, "rv-test-model", "rv-test-alias").await;
    let model_url = format!("/admin/api/v1/models/{}", deployment.id);

    // Create must be rejected.
    let create_request = json!({
        "model_name": "rv-forbidden-create",
        "hosted_on": test_endpoint_id.to_string()
    });
    let create_attempt = app
        .post("/admin/api/v1/models")
        .add_header(&rv_headers[0].0, &rv_headers[0].1)
        .add_header(&rv_headers[1].0, &rv_headers[1].1)
        .json(&create_request)
        .await;
    create_attempt.assert_status_forbidden();

    // Update must be rejected.
    let update = json!({"alias": "rv-forbidden-update"});
    let update_attempt = app
        .patch(&model_url)
        .add_header(&rv_headers[0].0, &rv_headers[0].1)
        .add_header(&rv_headers[1].0, &rv_headers[1].1)
        .json(&update)
        .await;
    update_attempt.assert_status_forbidden();

    // Delete must be rejected.
    let delete_attempt = app
        .delete(&model_url)
        .add_header(&rv_headers[0].0, &rv_headers[0].1)
        .add_header(&rv_headers[1].0, &rv_headers[1].1)
        .await;
    delete_attempt.assert_status_forbidden();
}
/// A `StandardUser` is answered with 403 Forbidden when trying to create, update or
/// delete a model.
#[sqlx::test]
#[test_log::test]
async fn test_standard_user_cannot_modify_models(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let platform_manager = create_test_admin_user(&pool, Role::PlatformManager).await;
    let standard_user = create_test_user(&pool, Role::StandardUser).await;
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let user_headers = add_auth_headers(&standard_user);

    // Existing deployment (owned by the platform manager) used as the update/delete target.
    let deployment = create_test_deployment(&pool, platform_manager.id, "su-test-model", "su-test-alias").await;
    let model_url = format!("/admin/api/v1/models/{}", deployment.id);

    // Create must be rejected.
    let create_request = json!({
        "model_name": "su-forbidden-create",
        "hosted_on": test_endpoint_id.to_string()
    });
    let create_attempt = app
        .post("/admin/api/v1/models")
        .add_header(&user_headers[0].0, &user_headers[0].1)
        .add_header(&user_headers[1].0, &user_headers[1].1)
        .json(&create_request)
        .await;
    create_attempt.assert_status_forbidden();

    // Update must be rejected.
    let update = json!({"alias": "su-forbidden-update"});
    let update_attempt = app
        .patch(&model_url)
        .add_header(&user_headers[0].0, &user_headers[0].1)
        .add_header(&user_headers[1].0, &user_headers[1].1)
        .json(&update)
        .await;
    update_attempt.assert_status_forbidden();

    // Delete must be rejected.
    let delete_attempt = app
        .delete(&model_url)
        .add_header(&user_headers[0].0, &user_headers[0].1)
        .add_header(&user_headers[1].0, &user_headers[1].1)
        .await;
    delete_attempt.assert_status_forbidden();
}
/// Holding both `StandardUser` and `RequestViewer` does not grant model write access:
/// create, update and delete are all answered with 403 Forbidden.
#[sqlx::test]
#[test_log::test]
async fn test_multi_role_user_cannot_modify_models_without_platform_manager(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let platform_manager = create_test_admin_user(&pool, Role::PlatformManager).await;
    // User with StandardUser + RequestViewer, but without PlatformManager.
    let multi_role_user = create_test_user_with_roles(&pool, vec![Role::StandardUser, Role::RequestViewer]).await;
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let multi_headers = add_auth_headers(&multi_role_user);

    // Create must be rejected (needs the PlatformManager role).
    let create_request = json!({
        "model_name": "multi-forbidden-create",
        "hosted_on": test_endpoint_id.to_string()
    });
    let create_attempt = app
        .post("/admin/api/v1/models")
        .add_header(&multi_headers[0].0, &multi_headers[0].1)
        .add_header(&multi_headers[1].0, &multi_headers[1].1)
        .json(&create_request)
        .await;
    create_attempt.assert_status_forbidden();

    // Deployment created by the platform manager, used as the update/delete target.
    let deployment = create_test_deployment(&pool, platform_manager.id, "multi-test-model", "multi-test-alias").await;
    let model_url = format!("/admin/api/v1/models/{}", deployment.id);

    // Update must be rejected.
    let update = json!({"alias": "multi-forbidden-update"});
    let update_attempt = app
        .patch(&model_url)
        .add_header(&multi_headers[0].0, &multi_headers[0].1)
        .add_header(&multi_headers[1].0, &multi_headers[1].1)
        .json(&update)
        .await;
    update_attempt.assert_status_forbidden();

    // Delete must be rejected.
    let delete_attempt = app
        .delete(&model_url)
        .add_header(&multi_headers[0].0, &multi_headers[0].1)
        .add_header(&multi_headers[1].0, &multi_headers[1].1)
        .await;
    delete_attempt.assert_status_forbidden();
}
/// A user holding both `PlatformManager` and `StandardUser` can create, update and
/// delete a model; the update response must carry the new alias.
#[sqlx::test]
#[test_log::test]
async fn test_platform_manager_plus_standard_user_can_modify_models(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    // User with both PlatformManager and StandardUser roles.
    let platform_user = create_test_user_with_roles(&pool, vec![Role::PlatformManager, Role::StandardUser]).await;
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let auth_headers = add_auth_headers(&platform_user);

    // Create succeeds.
    let create_request = json!({
        "type": "standard",
        "model_name": "pm-create-test",
        "hosted_on": test_endpoint_id.to_string(),
        "alias": "Platform Manager Created"
    });
    let create_response = app
        .post("/admin/api/v1/models")
        .add_header(&auth_headers[0].0, &auth_headers[0].1)
        .add_header(&auth_headers[1].0, &auth_headers[1].1)
        .json(&create_request)
        .await;
    create_response.assert_status_ok();
    let created_model: DeployedModelResponse = create_response.json();
    let model_url = format!("/admin/api/v1/models/{}", created_model.id);

    // Update succeeds and returns the new alias.
    let update = json!({"alias": "PM Updated Alias"});
    let update_response = app
        .patch(&model_url)
        .add_header(&auth_headers[0].0, &auth_headers[0].1)
        .add_header(&auth_headers[1].0, &auth_headers[1].1)
        .json(&update)
        .await;
    update_response.assert_status_ok();
    let updated_model: DeployedModelResponse = update_response.json();
    assert_eq!(updated_model.alias, "PM Updated Alias");

    // Delete succeeds.
    let delete_response = app
        .delete(&model_url)
        .add_header(&auth_headers[0].0, &auth_headers[0].1)
        .add_header(&auth_headers[1].0, &auth_headers[1].1)
        .await;
    delete_response.assert_status_ok();
}
#[sqlx::test]
#[test_log::test]
async fn test_accessibility_filtering_permissions(pool: PgPool) {
let (app, _bg_services) = create_test_app(pool.clone(), false).await;
let platform_manager = create_test_admin_user(&pool, Role::PlatformManager).await;
let standard_user = create_test_user(&pool, Role::StandardUser).await;
let request_viewer = create_test_user(&pool, Role::RequestViewer).await;
// Create deployments
let accessible_deployment = create_test_deployment(&pool, platform_manager.id, "accessible-model", "accessible-alias").await;
let inaccessible_deployment = create_test_deployment(&pool, platform_manager.id, "inaccessible-model", "inaccessible-alias").await;
// Create a group and add standard_user to it
let mut pool_conn = pool.acquire().await.unwrap();
let mut group_repo = Groups::new(&mut pool_conn);
let group_create = GroupCreateDBRequest {
name: "Access Test Group".to_string(),
description: Some("Group for accessibility testing".to_string()),
created_by: platform_manager.id,
};
let group = group_repo.create(&group_create).await.unwrap();
group_repo.add_user_to_group(standard_user.id, group.id).await.unwrap();
group_repo.add_user_to_group(request_viewer.id, group.id).await.unwrap();
// Add only accessible_deployment to the group
group_repo
.add_deployment_to_group(accessible_deployment.id, group.id, platform_manager.id)
.await
.unwrap();
// StandardUser should only see accessible models (default behavior)
let response = app
.get("/admin/api/v1/models")
.add_header(&add_auth_headers(&standard_user)[0].0, &add_auth_headers(&standard_user)[0].1)
.add_header(&add_auth_headers(&standard_user)[1].0, &add_auth_headers(&standard_user)[1].1)
.await;
response.assert_status_ok();
let standard_models: PaginatedResponse<DeployedModelResponse> = response.json();
assert!(standard_models.data.iter().any(|m| m.id == accessible_deployment.id));
assert!(!standard_models.data.iter().any(|m| m.id == inaccessible_deployment.id));
// RequestViewer should have same accessibility filtering as StandardUser
let response = app
.get("/admin/api/v1/models")
.add_header(&add_auth_headers(&request_viewer)[0].0, &add_auth_headers(&request_viewer)[0].1)
.add_header(&add_auth_headers(&request_viewer)[1].0, &add_auth_headers(&request_viewer)[1].1)
.await;
response.assert_status_ok();
let rv_models: PaginatedResponse<DeployedModelResponse> = response.json();
assert!(rv_models.data.iter().any(|m| m.id == accessible_deployment.id));
assert!(!rv_models.data.iter().any(|m| m.id == inaccessible_deployment.id));
// PlatformManager should see all models by default
let response = app
.get("/admin/api/v1/models")
.add_header(&add_auth_headers(&platform_manager)[0].0, &add_auth_headers(&platform_manager)[0].1)
.add_header(&add_auth_headers(&platform_manager)[1].0, &add_auth_headers(&platform_manager)[1].1)
.await;
response.assert_status_ok();
let pm_models: PaginatedResponse<DeployedModelResponse> = response.json();
assert!(pm_models.data.iter().any(|m| m.id == accessible_deployment.id));
assert!(pm_models.data.iter().any(|m| m.id == inaccessible_deployment.id));
}
#[sqlx::test]
#[test_log::test]
async fn test_groups_include_permission_enforcement(pool: PgPool) {
let (app, _bg_services) = create_test_app(pool.clone(), false).await;
let platform_manager = create_test_admin_user(&pool, Role::PlatformManager).await;
let standard_user = create_test_user(&pool, Role::StandardUser).await;
let request_viewer = create_test_user(&pool, Role::RequestViewer).await;
// Create a deployment and add it to a group
let deployment = create_test_deployment(&pool, platform_manager.id, "groups-perm-model", "groups-perm-alias").await;
let mut pool_conn = pool.acquire().await.unwrap();
let mut group_repo = Groups::new(&mut pool_conn);
let group_create = GroupCreateDBRequest {
name: "Groups Permission Test".to_string(),
description: Some("Test group for groups include permission".to_string()),
created_by: platform_manager.id,
};
let group = group_repo.create(&group_create).await.unwrap();
// Add all users to the group so they can see the deployment
group_repo.add_user_to_group(standard_user.id, group.id).await.unwrap();
group_repo.add_user_to_group(request_viewer.id, group.id).await.unwrap();
group_repo
.add_deployment_to_group(deployment.id, group.id, platform_manager.id)
.await
.unwrap();
// PlatformManager should be able to include groups
let response = app
.get("/admin/api/v1/models?include=groups")
.add_header(&add_auth_headers(&platform_manager)[0].0, &add_auth_headers(&platform_manager)[0].1)
.add_header(&add_auth_headers(&platform_manager)[1].0, &add_auth_headers(&platform_manager)[1].1)
.await;
response.assert_status_ok();
let pm_models: PaginatedResponse<DeployedModelResponse> = response.json();
let pm_model = pm_models.data.iter().find(|m| m.id == deployment.id).unwrap();
assert!(pm_model.groups.is_some(), "PlatformManager should see groups when included");
// StandardUser should NOT be able to include groups (groups should be None)
let response = app
.get("/admin/api/v1/models?include=groups")
.add_header(&add_auth_headers(&standard_user)[0].0, &add_auth_headers(&standard_user)[0].1)
.add_header(&add_auth_headers(&standard_user)[1].0, &add_auth_headers(&standard_user)[1].1)
.await;
response.assert_status_ok();
let models: PaginatedResponse<DeployedModelResponse> = response.json();
let test_model = get_model_by_id(deployment.id, &models).unwrap();
assert!(test_model.groups.is_none(), "Regular user should NOT see groups included");
// RequestViewer should NOT be able to include groups
let response = app
.get("/admin/api/v1/models?include=groups")
.add_header(&add_auth_headers(&request_viewer)[0].0, &add_auth_headers(&request_viewer)[0].1)
.add_header(&add_auth_headers(&request_viewer)[1].0, &add_auth_headers(&request_viewer)[1].1)
.await;
response.assert_status_ok();
let rv_models: PaginatedResponse<DeployedModelResponse> = response.json();
let rv_model = rv_models.data.iter().find(|m| m.id == deployment.id).unwrap();
assert!(rv_model.groups.is_none(), "RequestViewer should NOT see groups even when requested");
}
#[sqlx::test]
#[test_log::test]
async fn test_rate_limits_permission_gating(pool: PgPool) {
let (app, _bg_services) = create_test_app(pool.clone(), false).await;
let platform_manager = create_test_admin_user(&pool, Role::PlatformManager).await;
let standard_user = create_test_user(&pool, Role::StandardUser).await;
let request_viewer = create_test_user(&pool, Role::RequestViewer).await;
// Create a deployment with rate limits
let deployment = create_test_deployment(&pool, platform_manager.id, "rate-limit-test", "rate-limit-alias").await;
// Set rate limits on the deployment
let update = json!({
"requests_per_second": 100.0,
"burst_size": 200
});
let response = app
.patch(&format!("/admin/api/v1/models/{}", deployment.id))
.add_header(&add_auth_headers(&platform_manager)[0].0, &add_auth_headers(&platform_manager)[0].1)
.add_header(&add_auth_headers(&platform_manager)[1].0, &add_auth_headers(&platform_manager)[1].1)
.json(&update)
.await;
response.assert_status_ok();
// Create a group and add users to it so they can see the deployment
let mut pool_conn = pool.acquire().await.unwrap();
let mut group_repo = Groups::new(&mut pool_conn);
let group_create = GroupCreateDBRequest {
name: "Rate Limit Test Group".to_string(),
description: Some("Test group for rate limit permissions".to_string()),
created_by: platform_manager.id,
};
let group = group_repo.create(&group_create).await.unwrap();
group_repo.add_user_to_group(standard_user.id, group.id).await.unwrap();
group_repo.add_user_to_group(request_viewer.id, group.id).await.unwrap();
group_repo
.add_deployment_to_group(deployment.id, group.id, platform_manager.id)
.await
.unwrap();
// PlatformManager should see rate limits (has ModelRateLimits::ReadAll)
let response = app
.get(&format!("/admin/api/v1/models/{}", deployment.id))
.add_header(&add_auth_headers(&platform_manager)[0].0, &add_auth_headers(&platform_manager)[0].1)
.add_header(&add_auth_headers(&platform_manager)[1].0, &add_auth_headers(&platform_manager)[1].1)
.await;
response.assert_status_ok();
let pm_model: DeployedModelResponse = response.json();
assert_eq!(pm_model.requests_per_second, Some(100.0), "PlatformManager should see rate limits");
assert_eq!(pm_model.burst_size, Some(200), "PlatformManager should see burst size");
// StandardUser should NOT see rate limits (masked)
let response = app
.get(&format!("/admin/api/v1/models/{}", deployment.id))
.add_header(&add_auth_headers(&standard_user)[0].0, &add_auth_headers(&standard_user)[0].1)
.add_header(&add_auth_headers(&standard_user)[1].0, &add_auth_headers(&standard_user)[1].1)
.await;
response.assert_status_ok();
let user_model: DeployedModelResponse = response.json();
assert_eq!(user_model.requests_per_second, None, "StandardUser should NOT see rate limits");
assert_eq!(user_model.burst_size, None, "StandardUser should NOT see burst size");
// RequestViewer should NOT see rate limits (masked)
let response = app
.get(&format!("/admin/api/v1/models/{}", deployment.id))
.add_header(&add_auth_headers(&request_viewer)[0].0, &add_auth_headers(&request_viewer)[0].1)
.add_header(&add_auth_headers(&request_viewer)[1].0, &add_auth_headers(&request_viewer)[1].1)
.await;
response.assert_status_ok();
let rv_model: DeployedModelResponse = response.json();
assert_eq!(rv_model.requests_per_second, None, "RequestViewer should NOT see rate limits");
assert_eq!(rv_model.burst_size, None, "RequestViewer should NOT see burst size");
}
#[sqlx::test]
#[test_log::test]
async fn test_metrics_permission_gating(pool: PgPool) {
let (app, _bg_services) = create_test_app(pool.clone(), false).await;
let platform_manager = create_test_admin_user(&pool, Role::PlatformManager).await;
let standard_user = create_test_user(&pool, Role::StandardUser).await;
let request_viewer = create_test_user(&pool, Role::RequestViewer).await;
// Create a deployment
let deployment = create_test_deployment(&pool, platform_manager.id, "metrics-perm-test", "metrics-perm-alias").await;
// Create a group and add users to it so they can see the deployment
let mut pool_conn = pool.acquire().await.unwrap();
let mut group_repo = Groups::new(&mut pool_conn);
let group_create = GroupCreateDBRequest {
name: "Metrics Permission Test Group".to_string(),
description: Some("Test group for metrics permissions".to_string()),
created_by: platform_manager.id,
};
let group = group_repo.create(&group_create).await.unwrap();
group_repo.add_user_to_group(standard_user.id, group.id).await.unwrap();
group_repo.add_user_to_group(request_viewer.id, group.id).await.unwrap();
group_repo
.add_deployment_to_group(deployment.id, group.id, platform_manager.id)
.await
.unwrap();
// PlatformManager should be able to include metrics (has Analytics::ReadAll)
let response = app
.get("/admin/api/v1/models?include=metrics")
.add_header(&add_auth_headers(&platform_manager)[0].0, &add_auth_headers(&platform_manager)[0].1)
.add_header(&add_auth_headers(&platform_manager)[1].0, &add_auth_headers(&platform_manager)[1].1)
.await;
response.assert_status_ok();
let models: PaginatedResponse<DeployedModelResponse> = response.json();
let pm_model = get_model_by_id(deployment.id, &models).unwrap();
assert!(pm_model.metrics.is_some(), "PlatformManager should see metrics when requested");
// StandardUser should NOT be able to include metrics (no Analytics::ReadAll)
let response = app
.get("/admin/api/v1/models?include=metrics")
.add_header(&add_auth_headers(&standard_user)[0].0, &add_auth_headers(&standard_user)[0].1)
.add_header(&add_auth_headers(&standard_user)[1].0, &add_auth_headers(&standard_user)[1].1)
.await;
response.assert_status_ok();
let models: PaginatedResponse<DeployedModelResponse> = response.json();
let user_model = get_model_by_id(deployment.id, &models).unwrap();
assert!(
user_model.metrics.is_none(),
"StandardUser should NOT see metrics even when requested"
);
// RequestViewer should be able to include metrics (has Analytics::ReadAll)
let response = app
.get("/admin/api/v1/models?include=metrics")
.add_header(&add_auth_headers(&request_viewer)[0].0, &add_auth_headers(&request_viewer)[0].1)
.add_header(&add_auth_headers(&request_viewer)[1].0, &add_auth_headers(&request_viewer)[1].1)
.await;
response.assert_status_ok();
let models: PaginatedResponse<DeployedModelResponse> = response.json();
let rv_model = get_model_by_id(deployment.id, &models).unwrap();
assert!(rv_model.metrics.is_some(), "RequestViewer should see metrics when requested");
}
#[sqlx::test]
#[test_log::test]
async fn test_models_pagination(pool: PgPool) {
let (app, _bg_services) = create_test_app(pool.clone(), false).await;
let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
// Create 5 test deployments
let deployment1 = create_test_deployment(&pool, admin_user.id, "model-1", "alias-1").await;
let deployment2 = create_test_deployment(&pool, admin_user.id, "model-2", "alias-2").await;
let deployment3 = create_test_deployment(&pool, admin_user.id, "model-3", "alias-3").await;
let deployment4 = create_test_deployment(&pool, admin_user.id, "model-4", "alias-4").await;
let deployment5 = create_test_deployment(&pool, admin_user.id, "model-5", "alias-5").await;
// Test 1: Get first page with limit=2
let response = app
.get("/admin/api/v1/models?limit=2&skip=0")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let page1: PaginatedResponse<DeployedModelResponse> = response.json();
assert_eq!(page1.data.len(), 2, "First page should have 2 models");
// Test 2: Get second page with limit=2, skip=2
let response = app
.get("/admin/api/v1/models?limit=2&skip=2")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let page2: PaginatedResponse<DeployedModelResponse> = response.json();
assert_eq!(page2.data.len(), 2, "Second page should have 2 models");
// Test 3: Get third page with limit=2, skip=4 (should have 1 model)
let response = app
.get("/admin/api/v1/models?limit=2&skip=4")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let page3: PaginatedResponse<DeployedModelResponse> = response.json();
assert_eq!(page3.data.len(), 1, "Third page should have 1 model");
// Test 4: Verify no duplicate models across pages
let all_page_ids: Vec<DeploymentId> = page1
.data
.iter()
.chain(page2.data.iter())
.chain(page3.data.iter())
.map(|m| m.id)
.collect();
let unique_ids: std::collections::HashSet<DeploymentId> = all_page_ids.iter().copied().collect();
assert_eq!(all_page_ids.len(), unique_ids.len(), "Pages should not have duplicate models");
assert_eq!(unique_ids.len(), 5, "Should have all 5 models across pages");
// Test 5: Verify all created models are present
assert!(
unique_ids.contains(&deployment1.id)
&& unique_ids.contains(&deployment2.id)
&& unique_ids.contains(&deployment3.id)
&& unique_ids.contains(&deployment4.id)
&& unique_ids.contains(&deployment5.id),
"All created deployments should be present in paginated results"
);
// Test 6: Default limit (should get all 5 models)
let response = app
.get("/admin/api/v1/models")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let all_models: PaginatedResponse<DeployedModelResponse> = response.json();
assert_eq!(all_models.data.len(), 5, "Without pagination params, should get all models");
// Test 7: Offset (skip) beyond available models (should return empty)
let response = app
.get("/admin/api/v1/models?limit=10&skip=100")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let empty_page: PaginatedResponse<DeployedModelResponse> = response.json();
assert_eq!(empty_page.data.len(), 0, "Offset beyond available models should return empty array");
}
#[sqlx::test]
#[test_log::test]
async fn test_list_models_with_group_filter(pool: PgPool) {
let (app, _bg_services) = create_test_app(pool.clone(), false).await;
let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
// Create three deployments
let deployment1 = create_test_deployment(&pool, admin_user.id, "model-1", "alias-1").await;
let deployment2 = create_test_deployment(&pool, admin_user.id, "model-2", "alias-2").await;
let deployment3 = create_test_deployment(&pool, admin_user.id, "model-3", "alias-3").await;
// Create two groups
let mut group_conn = pool.acquire().await.unwrap();
let mut group_repo = Groups::new(&mut group_conn);
let group1_create = GroupCreateDBRequest {
name: "Production".to_string(),
description: Some("Production group".to_string()),
created_by: admin_user.id,
};
let group1 = group_repo.create(&group1_create).await.unwrap();
let group2_create = GroupCreateDBRequest {
name: "Staging".to_string(),
description: Some("Staging group".to_string()),
created_by: admin_user.id,
};
let group2 = group_repo.create(&group2_create).await.unwrap();
// Add deployment1 to group1
group_repo
.add_deployment_to_group(deployment1.id, group1.id, admin_user.id)
.await
.unwrap();
// Add deployment2 to group2
group_repo
.add_deployment_to_group(deployment2.id, group2.id, admin_user.id)
.await
.unwrap();
// deployment3 has no groups
// Test 1: Filter by single group
let response = app
.get(&format!("/admin/api/v1/models?group={}", group1.id))
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let single_group: PaginatedResponse<DeployedModelResponse> = response.json();
assert_eq!(single_group.data.len(), 1, "Should only return models from group1");
assert!(get_model_by_id(deployment1.id, &single_group).is_some());
assert!(get_model_by_id(deployment2.id, &single_group).is_none());
assert!(get_model_by_id(deployment3.id, &single_group).is_none());
// Test 2: Filter by multiple groups (comma-separated)
let response = app
.get(&format!("/admin/api/v1/models?group={},{}", group1.id, group2.id))
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let multi_group: PaginatedResponse<DeployedModelResponse> = response.json();
assert_eq!(multi_group.data.len(), 2, "Should return models from both groups");
assert!(get_model_by_id(deployment1.id, &multi_group).is_some());
assert!(get_model_by_id(deployment2.id, &multi_group).is_some());
assert!(get_model_by_id(deployment3.id, &multi_group).is_none());
// Test 3: Invalid group ID format should return 400
let response = app
.get("/admin/api/v1/models?group=invalid-uuid")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status(axum::http::StatusCode::BAD_REQUEST);
// Test 4: Mix of valid and invalid UUIDs should return 400
let response = app
.get(&format!("/admin/api/v1/models?group={},invalid-uuid", group1.id))
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status(axum::http::StatusCode::BAD_REQUEST);
// Test 5: Empty group parameter (just commas) should work without filtering
let response = app
.get("/admin/api/v1/models?group=,,,")
.add_header(&add_auth_headers(&admin_user)[0].0, &add_auth_headers(&admin_user)[0].1)
.add_header(&add_auth_headers(&admin_user)[1].0, &add_auth_headers(&admin_user)[1].1)
.await;
response.assert_status_ok();
let all_models: PaginatedResponse<DeployedModelResponse> = response.json();
assert!(all_models.data.len() >= 3, "Empty group list should return all models");
}
// ===== Traffic Routing Rules Tests =====
/// Creates a model with one deny rule and one redirect rule, then checks that the
/// create response echoes both rules, with the redirect target given as its alias.
#[sqlx::test]
#[test_log::test]
async fn test_create_model_with_traffic_rules(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let auth = add_auth_headers(&admin_user);

    // The redirect target has to exist before the source model is created.
    create_test_deployment(&pool, admin_user.id, "traffic-target", "traffic-target-alias").await;

    // Source model carrying both kinds of traffic routing rule.
    let payload = json!({
        "type": "standard",
        "model_name": "traffic-source",
        "alias": "traffic-source-alias",
        "hosted_on": test_endpoint_id,
        "traffic_routing_rules": [
            { "api_key_purpose": "batch", "action": { "type": "deny" } },
            { "api_key_purpose": "realtime", "action": { "type": "redirect", "target": "traffic-target-alias" } }
        ]
    });
    let response = app
        .post("/admin/api/v1/models")
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .json(&payload)
        .await;
    response.assert_status_ok();

    let model: DeployedModelResponse = response.json();
    assert_eq!(model.alias, "traffic-source-alias");
    let rules = model.traffic_routing_rules.expect("expected traffic_routing_rules in response");
    assert_eq!(rules.len(), 2);

    // The batch rule must be a deny.
    let batch_rule = rules
        .iter()
        .find(|r| r.api_key_purpose == crate::db::models::api_keys::ApiKeyPurpose::Batch)
        .unwrap();
    assert!(matches!(
        batch_rule.action,
        crate::api::models::deployments::TrafficRoutingAction::Deny
    ));

    // The realtime rule must be a redirect to the target's alias.
    let realtime_rule = rules
        .iter()
        .find(|r| r.api_key_purpose == crate::db::models::api_keys::ApiKeyPurpose::Realtime)
        .unwrap();
    let crate::api::models::deployments::TrafficRoutingAction::Redirect { target } = &realtime_rule.action else {
        panic!("expected redirect action");
    };
    assert_eq!(target, "traffic-target-alias");
}
/// Creates a model with a single deny rule through the API and checks that fetching
/// the model by ID returns that rule.
#[sqlx::test]
#[test_log::test]
async fn test_get_model_includes_traffic_rules(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let auth = add_auth_headers(&admin_user);

    // A second deployment is created alongside the source; the deny rule below
    // does not reference it.
    create_test_deployment(&pool, admin_user.id, "get-tr-target", "get-tr-target-alias").await;

    // Source model with one deny rule, created via the API.
    let payload = json!({
        "type": "standard",
        "model_name": "get-tr-source",
        "alias": "get-tr-source-alias",
        "hosted_on": test_endpoint_id,
        "traffic_routing_rules": [
            { "api_key_purpose": "batch", "action": { "type": "deny" } }
        ]
    });
    let create_resp = app
        .post("/admin/api/v1/models")
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .json(&payload)
        .await;
    create_resp.assert_status_ok();
    let created: DeployedModelResponse = create_resp.json();

    // Fetch the model back by ID.
    let model_url = format!("/admin/api/v1/models/{}", created.id);
    let get_resp = app
        .get(&model_url)
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .await;
    get_resp.assert_status_ok();

    let fetched: DeployedModelResponse = get_resp.json();
    let rules = fetched.traffic_routing_rules.expect("expected traffic_routing_rules");
    assert_eq!(rules.len(), 1);
    assert!(matches!(
        rules[0].action,
        crate::api::models::deployments::TrafficRoutingAction::Deny
    ));
}
/// Creates a model with a playground deny rule and checks that the rule is present
/// on that model's entry in the list endpoint response.
#[sqlx::test]
#[test_log::test]
async fn test_list_models_includes_traffic_rules(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let auth = add_auth_headers(&admin_user);

    // Source model with one rule, created via the API.
    let payload = json!({
        "type": "standard",
        "model_name": "list-tr-source",
        "alias": "list-tr-source-alias",
        "hosted_on": test_endpoint_id,
        "traffic_routing_rules": [
            { "api_key_purpose": "playground", "action": { "type": "deny" } }
        ]
    });
    let create_resp = app
        .post("/admin/api/v1/models")
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .json(&payload)
        .await;
    create_resp.assert_status_ok();
    let created: DeployedModelResponse = create_resp.json();

    // List models and locate the one just created.
    let list_resp = app
        .get("/admin/api/v1/models?limit=100")
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .await;
    list_resp.assert_status_ok();
    let listing: PaginatedResponse<DeployedModelResponse> = list_resp.json();

    let listed = get_model_by_id(created.id, &listing).expect("model should be in list");
    let rules = listed.traffic_routing_rules.as_ref().expect("traffic rules should be present");
    assert_eq!(rules.len(), 1);
    assert_eq!(rules[0].api_key_purpose, crate::db::models::api_keys::ApiKeyPurpose::Playground);
}
/// Adds a traffic rule to a rule-less model via PATCH and checks that the rule is
/// returned both in the PATCH response and by a subsequent GET.
#[sqlx::test]
#[test_log::test]
async fn test_update_model_set_traffic_rules(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
    let auth = add_auth_headers(&admin_user);

    // Start from a model that has no rules.
    let model = create_test_deployment(&pool, admin_user.id, "update-tr-model", "update-tr-alias").await;
    let model_url = format!("/admin/api/v1/models/{}", model.id);

    // PATCH in a single deny rule.
    let patch_body = json!({
        "traffic_routing_rules": [
            { "api_key_purpose": "batch", "action": { "type": "deny" } }
        ]
    });
    let patch_resp = app
        .patch(&model_url)
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .json(&patch_body)
        .await;
    patch_resp.assert_status_ok();
    let updated: DeployedModelResponse = patch_resp.json();
    let rules = updated.traffic_routing_rules.expect("traffic rules should be set");
    assert_eq!(rules.len(), 1);

    // Verify via GET.
    let get_resp = app
        .get(&model_url)
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .await;
    get_resp.assert_status_ok();
    let fetched: DeployedModelResponse = get_resp.json();
    assert_eq!(fetched.traffic_routing_rules.unwrap().len(), 1);
}
/// Checks that PATCHing `traffic_routing_rules: null` removes a model's existing
/// rules, as seen in both the PATCH response and a subsequent GET.
#[sqlx::test]
#[test_log::test]
async fn test_update_model_clear_traffic_rules(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let auth = add_auth_headers(&admin_user);

    // Create a model that starts out with one rule.
    let create_body = json!({
        "type": "standard",
        "model_name": "clear-tr-model",
        "alias": "clear-tr-alias",
        "hosted_on": test_endpoint_id,
        "traffic_routing_rules": [
            { "api_key_purpose": "batch", "action": { "type": "deny" } }
        ]
    });
    let create_resp = app
        .post("/admin/api/v1/models")
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .json(&create_body)
        .await;
    create_resp.assert_status_ok();
    let created: DeployedModelResponse = create_resp.json();
    assert!(created.traffic_routing_rules.is_some());

    let model_url = format!("/admin/api/v1/models/{}", created.id);

    // PATCH with traffic_routing_rules: null → clear rules.
    let clear_body = json!({
        "traffic_routing_rules": null
    });
    let patch_resp = app
        .patch(&model_url)
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .json(&clear_body)
        .await;
    patch_resp.assert_status_ok();
    let updated: DeployedModelResponse = patch_resp.json();
    assert!(updated.traffic_routing_rules.is_none());

    // Verify via GET.
    let get_resp = app
        .get(&model_url)
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .await;
    get_resp.assert_status_ok();
    let fetched: DeployedModelResponse = get_resp.json();
    assert!(fetched.traffic_routing_rules.is_none());
}
/// Checks that a PATCH which omits `traffic_routing_rules` leaves the model's
/// existing rules in place.
#[sqlx::test]
#[test_log::test]
async fn test_update_model_no_change_to_traffic_rules(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let auth = add_auth_headers(&admin_user);

    // Create a model that starts out with one rule.
    let create_body = json!({
        "type": "standard",
        "model_name": "nochange-tr-model",
        "alias": "nochange-tr-alias",
        "hosted_on": test_endpoint_id,
        "traffic_routing_rules": [
            { "api_key_purpose": "batch", "action": { "type": "deny" } }
        ]
    });
    let create_resp = app
        .post("/admin/api/v1/models")
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .json(&create_body)
        .await;
    create_resp.assert_status_ok();
    let created: DeployedModelResponse = create_resp.json();

    // PATCH an unrelated field; traffic_routing_rules is not mentioned at all.
    let model_url = format!("/admin/api/v1/models/{}", created.id);
    let patch_body = json!({
        "description": "updated description"
    });
    let patch_resp = app
        .patch(&model_url)
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .json(&patch_body)
        .await;
    patch_resp.assert_status_ok();
    let updated: DeployedModelResponse = patch_resp.json();

    // The rule from creation should still be there.
    let rules = updated.traffic_routing_rules.expect("traffic rules should be unchanged");
    assert_eq!(rules.len(), 1);
}
/// Checks that creating a model whose redirect rule targets its own alias is
/// rejected with 400 Bad Request.
#[sqlx::test]
#[test_log::test]
async fn test_create_model_self_redirect_rejected(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let auth = add_auth_headers(&admin_user);

    // The redirect target is the same alias the model itself is being created with.
    let payload = json!({
        "type": "standard",
        "model_name": "self-redirect",
        "alias": "self-redirect-alias",
        "hosted_on": test_endpoint_id,
        "traffic_routing_rules": [
            { "api_key_purpose": "realtime", "action": { "type": "redirect", "target": "self-redirect-alias" } }
        ]
    });
    let response = app
        .post("/admin/api/v1/models")
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .json(&payload)
        .await;

    response.assert_status(axum::http::StatusCode::BAD_REQUEST);
}
/// Checks that creating a model whose redirect rule names a target that was never
/// created is rejected with 400 Bad Request.
#[sqlx::test]
#[test_log::test]
async fn test_create_model_nonexistent_redirect_target(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let auth = add_auth_headers(&admin_user);

    // No deployment called "nonexistent-model" is created in this test.
    let payload = json!({
        "type": "standard",
        "model_name": "bad-redirect",
        "alias": "bad-redirect-alias",
        "hosted_on": test_endpoint_id,
        "traffic_routing_rules": [
            { "api_key_purpose": "realtime", "action": { "type": "redirect", "target": "nonexistent-model" } }
        ]
    });
    let response = app
        .post("/admin/api/v1/models")
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .json(&payload)
        .await;

    response.assert_status(axum::http::StatusCode::BAD_REQUEST);
}
/// Checks that creating a model whose redirect rule has an empty-string target is
/// rejected with 400 Bad Request.
#[sqlx::test]
#[test_log::test]
async fn test_create_model_empty_redirect_target(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let auth = add_auth_headers(&admin_user);

    // The redirect action carries "target": "".
    let payload = json!({
        "type": "standard",
        "model_name": "empty-redirect",
        "alias": "empty-redirect-alias",
        "hosted_on": test_endpoint_id,
        "traffic_routing_rules": [
            { "api_key_purpose": "realtime", "action": { "type": "redirect", "target": "" } }
        ]
    });
    let response = app
        .post("/admin/api/v1/models")
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .json(&payload)
        .await;

    response.assert_status(axum::http::StatusCode::BAD_REQUEST);
}
/// Checks that hard-deleting the row of a redirect target removes the redirect rule
/// from the model that pointed at it, so a later GET reports no traffic rules.
#[sqlx::test]
#[test_log::test]
async fn test_delete_redirect_target_cascades(pool: PgPool) {
    let (app, _bg_services) = create_test_app(pool.clone(), false).await;
    let admin_user = create_test_admin_user(&pool, Role::PlatformManager).await;
    let test_endpoint_id = get_test_endpoint_id(&pool).await;
    let auth = add_auth_headers(&admin_user);

    // Model B: the redirect target.
    let target = create_test_deployment(&pool, admin_user.id, "cascade-target", "cascade-target-alias").await;

    // Model A: the source, created via the API with a redirect to model B.
    let payload = json!({
        "type": "standard",
        "model_name": "cascade-source",
        "alias": "cascade-source-alias",
        "hosted_on": test_endpoint_id,
        "traffic_routing_rules": [
            { "api_key_purpose": "realtime", "action": { "type": "redirect", "target": "cascade-target-alias" } }
        ]
    });
    let create_resp = app
        .post("/admin/api/v1/models")
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .json(&payload)
        .await;
    create_resp.assert_status_ok();
    let source: DeployedModelResponse = create_resp.json();
    assert!(source.traffic_routing_rules.is_some());

    // Hard-delete model B to trigger CASCADE (API uses soft delete which won't trigger CASCADE).
    sqlx::query!("DELETE FROM deployed_models WHERE id = $1", target.id)
        .execute(&pool)
        .await
        .unwrap();

    // GET model A: the redirect rule should have gone with its target.
    let source_url = format!("/admin/api/v1/models/{}", source.id);
    let get_resp = app
        .get(&source_url)
        .add_header(&auth[0].0, &auth[0].1)
        .add_header(&auth[1].0, &auth[1].1)
        .await;
    get_resp.assert_status_ok();
    let fetched: DeployedModelResponse = get_resp.json();
    assert!(
        fetched.traffic_routing_rules.is_none(),
        "traffic rules should be cleared after cascade delete of redirect target"
    );
}
}