use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use axum::Extension;
use axum::Json;
use axum::Router;
use axum::extract::{Path, Query, Request, State};
use axum::http::{HeaderMap, Method, StatusCode};
use axum::middleware::{self, Next};
use axum::response::Response;
use axum::routing::{delete, get, post};
use k2db::options::{FindOptions, ProjectionMode};
use k2db::results::VersionInfo as DbVersionInfo;
use k2db_api_contract::{
AggregateRequest, ApiKeyIdentity, CountRequest, CreateIndexesRequest, CreateResult,
HealthOk, MessageResponse, PatchCollectionRequest, ReadyNotOk, ReadyOk, RestoreRequest,
SearchRequest, UpdateResult, VersionInfo, VersionedUpdateRequest, VersionedUpdateResult,
};
use mongodb::bson::{self, Bson, Document, doc};
use mongodb::options::IndexOptions;
use serde::Deserialize;
use serde::Serialize;
use serde_json::{Value, json};
use crate::api_error::ApiError;
use crate::auth::{AuthProvider, require_permission};
use crate::bootstrap::BootstrapConfig;
use crate::cli::{Cli, Command, ConfigCommand, KeysCommand};
use crate::config::AppConfig;
use crate::control_plane::{
RecoverConfigAction, create_key as create_control_plane_key, get_config as get_control_plane_config,
init as init_control_plane, list_keys as list_control_plane_keys, recover as recover_control_plane,
revoke_key as revoke_control_plane_key, set_config as set_control_plane_config,
};
use crate::manager::DbManager;
use crate::telemetry::Telemetry;
#[cfg(test)]
use k2db_api_contract::ProblemDetailsPayload;
#[cfg(test)]
use crate::config::{ApiKeyConfig, ApiKeySection, HostEntry, K2DbSection, ServerConfig};
/// Boxed, `Send` future yielding the names of databases reported as not ready
/// (an empty list means everything is ready).
type ReadinessFuture = Pin<Box<dyn Future<Output = Vec<String>> + Send + 'static>>;
/// Shareable factory that starts a fresh readiness check every time it is called.
type ReadinessCheck = Arc<dyn Fn() -> ReadinessFuture + Send + Sync + 'static>;
/// Shared state handed to every handler and to the authentication middleware.
///
/// Cloned by axum (`#[derive(Clone)]`); `config` and `readiness` are `Arc`s.
#[derive(Clone)]
struct AppState {
    /// Loaded application configuration (server, k2db and API key sections).
    config: Arc<AppConfig>,
    /// Verifies `Authorization` headers for `/v1` routes.
    auth: AuthProvider,
    /// Hands out tenant database handles (`get_db` / `scoped_db`).
    manager: DbManager,
    /// Readiness check used by `/ready` and by the startup gate.
    readiness: ReadinessCheck,
}
/// Query string accepted by `GET /v1/{collection}/{id}/versions`.
#[derive(Debug, Deserialize)]
struct VersionsQuery {
    /// Number of versions to skip; `get_versions` uses 0 when absent.
    skip: Option<u64>,
    /// Maximum number of versions to return; `get_versions` uses 100 when absent.
    limit: Option<u64>,
}
pub async fn run(cli: Cli) -> Result<(), Box<dyn std::error::Error>> {
match cli.command {
Command::Serve(args) => serve(args).await?,
Command::Init(args) => {
let outcome = init_control_plane(&args).await?;
println!(
"k2db-api initialized {} using {}",
outcome.bootstrap.system_db_name,
outcome.bootstrap.redacted_mongo_uri()
);
println!("active server_config: {}", outcome.server_config_id);
if let Some(key) = outcome.issued_key {
println!("runtime key label: {}", key.name);
println!("runtime key id: {}", key.key_id);
println!("runtime key database: {}", key.database);
println!("runtime key permissions: {}", key.permissions.join(","));
println!("runtime key printable: {}", key.printable);
}
}
Command::Recover(args) => {
let outcome = recover_control_plane(&args).await?;
println!(
"k2db-api recovered {} using {}",
outcome.bootstrap.system_db_name,
outcome.bootstrap.redacted_mongo_uri()
);
println!(
"server_config action: {}",
match outcome.server_config_action {
RecoverConfigAction::Preserved => "preserved",
RecoverConfigAction::Replaced => "replaced",
}
);
println!("active server_config: {}", outcome.server_config_id);
if let Some(key) = outcome.issued_key {
println!("runtime key label: {}", key.name);
println!("runtime key id: {}", key.key_id);
println!("runtime key database: {}", key.database);
println!("runtime key permissions: {}", key.permissions.join(","));
println!("runtime key printable: {}", key.printable);
}
}
Command::Keys(args) => match args.command {
KeysCommand::Create(create) => {
let key = create_control_plane_key(&args.bootstrap, &create).await?;
println!("runtime key label: {}", key.name);
println!("runtime key id: {}", key.key_id);
println!("runtime key database: {}", key.database);
println!("runtime key permissions: {}", key.permissions.join(","));
println!("runtime key printable: {}", key.printable);
}
KeysCommand::Revoke(revoke) => {
let outcome = revoke_control_plane_key(&args.bootstrap, &revoke).await?;
println!("revoked key: {}", outcome.key_id);
println!("was active: {}", outcome.was_active);
println!("was already revoked: {}", outcome.was_revoked);
}
KeysCommand::List => {
for key in list_control_plane_keys(&args.bootstrap).await? {
println!(
"{}\t{}\t{}\t{}\t{}\t{}\t{}",
key.key_id,
key.name.unwrap_or_default(),
key.database,
key.active,
key.expires_at.map(|value| value.to_string()).unwrap_or_default(),
key.revoked_at.map(|value| value.to_string()).unwrap_or_default(),
key.permissions.join(",")
);
}
}
},
Command::Config(args) => match args.command {
ConfigCommand::Get => {
let config = get_control_plane_config(&args.bootstrap).await?;
println!("active server_config: {}", config.id);
println!("listen host: {}", config.listen_host);
println!("listen port: {}", config.listen_port);
println!(
"ownership mode: {}",
match config.ownership_mode {
k2db::OwnershipMode::Lax => "lax",
k2db::OwnershipMode::Strict => "strict",
}
);
println!(
"slow query ms: {}",
config
.slow_query_ms
.map(|value| value.to_string())
.unwrap_or_default()
);
}
ConfigCommand::Set(set) => {
let outcome = set_control_plane_config(&args.bootstrap, &set).await?;
if let Some(previous) = outcome.previous_config_id {
println!("previous server_config: {}", previous);
}
println!("active server_config: {}", outcome.active_config.id);
println!("listen host: {}", outcome.active_config.listen_host);
println!("listen port: {}", outcome.active_config.listen_port);
println!(
"ownership mode: {}",
match outcome.active_config.ownership_mode {
k2db::OwnershipMode::Lax => "lax",
k2db::OwnershipMode::Strict => "strict",
}
);
println!(
"slow query ms: {}",
outcome
.active_config
.slow_query_ms
.map(|value| value.to_string())
.unwrap_or_default()
);
}
},
}
Ok(())
}
/// Runs the HTTP API until a shutdown signal arrives.
///
/// The startup sequence is:
/// 1. build the shared state;
/// 2. warm the tenant databases;
/// 3. fail fast if any tenant database is not ready;
/// 4. bind the listener and announce it (stdout + telemetry);
/// 5. serve with graceful shutdown on Ctrl+C / SIGTERM.
///
/// # Errors
/// Returns an error if any startup step fails, if the listen address cannot
/// be parsed, if binding fails, or if the server itself fails.
async fn serve(args: crate::cli::BootstrapArgs) -> Result<(), Box<dyn std::error::Error>> {
    let state = build_state(&args).await?;
    state.manager.warm_databases().await.map_err(boxed_error)?;
    ensure_ready(&state).await.map_err(boxed_error)?;
    let address = listen_address(&state.config.server.host, state.config.server.port)?;
    let telemetry = Telemetry::from_env("k2db-api-rust", Some(address.to_string()), "k2db-api*");
    let listener = tokio::net::TcpListener::bind(address).await?;
    println!("k2db-api listening on {address}");
    telemetry.emit(
        "k2db-api:info",
        "k2db-api:service_up",
        json!({
            "listen": address.to_string(),
            "tenant_databases": state.config.tenant_databases(),
            "system_db_name": args.system_db_name,
        }),
    );
    axum::serve(listener, create_router(state))
        .with_graceful_shutdown(shutdown_signal())
        .await?;
    Ok(())
}

/// Builds the listen [`SocketAddr`] from the configured host and port.
///
/// A bare IPv6 literal (e.g. `::1`) is wrapped in brackets first. Plain
/// `"{host}:{port}"` formatting would give `::1:3000`, which does not parse
/// as a socket address. IPv4 literals and already-bracketed IPv6 hosts are
/// formatted as before. Hostnames are not resolved; the host must be an IP
/// literal.
fn listen_address(
    host: &str,
    port: impl std::fmt::Display,
) -> Result<SocketAddr, std::net::AddrParseError> {
    let candidate = if host.parse::<std::net::Ipv6Addr>().is_ok() {
        format!("[{host}]:{port}")
    } else {
        format!("{host}:{port}")
    };
    candidate.parse()
}
async fn build_state(
args: &crate::cli::BootstrapArgs,
) -> Result<AppState, Box<dyn std::error::Error>> {
let bootstrap = BootstrapConfig::resolve_runtime(args)?;
let config = Arc::new(AppConfig::load(&bootstrap).await?);
let manager = DbManager::new(config.clone());
let auth = AuthProvider::new(config.apikey.keys.clone()).map_err(|error| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("invalid API key configuration: {error:?}"),
)
})?;
let readiness = readiness_probe(config.clone(), manager.clone());
Ok(AppState {
config,
auth,
manager,
readiness,
})
}
fn readiness_probe(config: Arc<AppConfig>, manager: DbManager) -> ReadinessCheck {
Arc::new(move || {
let config = config.clone();
let manager = manager.clone();
Box::pin(async move {
let mut down = Vec::new();
for database in config.tenant_databases() {
if !manager.is_db_healthy(&database).await {
down.push(database);
}
}
down
})
})
}
fn create_router(state: AppState) -> Router {
let api_state = state.clone();
let api_router = Router::new()
.route("/admin/{collection}", delete(admin_delete_collection))
.route("/admin/{collection}/{id}", delete(admin_delete_by_id))
.route("/admin/{collection}/indexes", post(admin_create_indexes))
.route(
"/admin/{collection}/history-indexes",
post(admin_create_history_indexes),
)
.route("/{collection}/search", post(search_collection))
.route("/{collection}/aggregate", post(aggregate_collection))
.route("/{collection}/count", post(count_collection))
.route("/{collection}/restore", post(restore_collection))
.route("/{collection}/{id}/versions", get(get_versions).patch(patch_versions))
.route(
"/{collection}/{id}/versions/{version}/revert",
post(revert_version),
)
.route("/{collection}/{id}", get(get_by_id).patch(patch_by_id).delete(delete_by_id))
.route("/{collection}", post(create_document).patch(patch_collection))
.route_layer(middleware::from_fn_with_state(
api_state,
authenticate_request,
));
Router::new()
.route("/health", get(health))
.route("/ready", get(ready))
.nest("/v1", api_router)
.fallback(not_found)
.with_state(state)
}
/// Authentication middleware for the `/v1` routes.
///
/// Passes the `Authorization` header to [`AuthProvider::verify`]. A missing
/// header, or one that is not visible ASCII, is passed as `None`. On success
/// the resulting identity is stored in the request extensions and the request
/// continues. On failure the verifier's [`ApiError`] becomes the response.
async fn authenticate_request(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> Result<Response, ApiError> {
    let identity = {
        let authorization = request
            .headers()
            .get(axum::http::header::AUTHORIZATION)
            .and_then(|raw| raw.to_str().ok());
        state.auth.verify(authorization).await?
    };
    request.extensions_mut().insert(identity);
    let response = next.run(request).await;
    Ok(response)
}
/// Liveness probe: always answers `{"status":"ok"}` without touching any database.
async fn health() -> Json<HealthOk> {
    let payload = HealthOk {
        status: String::from("ok"),
    };
    Json(payload)
}
/// Readiness probe backed by `AppState::readiness`.
///
/// Responds `200` with [`ReadyOk`] when the check reports no databases down.
/// Otherwise it responds `503` with [`ReadyNotOk`] listing those databases.
async fn ready(State(state): State<AppState>) -> (StatusCode, Json<serde_json::Value>) {
    let down = (state.readiness)().await;
    let (status, body) = if down.is_empty() {
        let payload = ReadyOk {
            status: "ready".to_owned(),
        };
        (
            StatusCode::OK,
            serde_json::to_value(payload).expect("serialize ready response"),
        )
    } else {
        let payload = ReadyNotOk {
            status: "not-ready".to_owned(),
            databases: down,
        };
        (
            StatusCode::SERVICE_UNAVAILABLE,
            serde_json::to_value(payload).expect("serialize not-ready response"),
        )
    };
    (status, Json(body))
}
/// `POST /v1/{collection}`: creates one document under the `x-scope` scope.
///
/// Requires `collections.write` and a non-blank `x-scope` header. The JSON
/// body must be an object. Responds `201 Created` with the new document id.
async fn create_document(
    State(state): State<AppState>,
    Extension(identity): Extension<ApiKeyIdentity>,
    Path(collection): Path<String>,
    headers: HeaderMap,
    Json(body): Json<Value>,
) -> Result<(StatusCode, Json<CreateResult>), ApiError> {
    require_permission(&identity, "collections.write")?;
    let scope = require_scope_header(&headers)?;
    let document = value_to_document(body, "Document data is required", "t-create-document-001")?;

    let tenant_db = state.manager.get_db(&identity.database).await?;
    let created = tenant_db
        .create(&collection, &scope, document)
        .await
        .map_err(ApiError::from_k2db)?;
    let payload = CreateResult { id: created.id };
    Ok((StatusCode::CREATED, Json(payload)))
}
async fn get_by_id(
State(state): State<AppState>,
Extension(identity): Extension<ApiKeyIdentity>,
Path((collection, id)): Path<(String, String)>,
headers: HeaderMap,
) -> Result<Json<Value>, ApiError> {
require_permission(&identity, "collections.read")?;
let scoped = state
.manager
.scoped_db(&identity.database, &require_scope_header(&headers)?)
.await?;
let document = scoped.get(&collection, &id).await.map_err(ApiError::from_k2db)?;
Ok(Json(to_json_value(document)?))
}
/// `PATCH /v1/{collection}/{id}`: updates one document.
///
/// Requires `collections.write`. The body is validated as a JSON object before
/// the `x-scope` header is checked. The update is issued with `false` as its
/// fourth argument (presumably "do not replace"; confirm against k2db).
/// Responds `404` when no document was updated.
async fn patch_by_id(
    State(state): State<AppState>,
    Extension(identity): Extension<ApiKeyIdentity>,
    Path((collection, id)): Path<(String, String)>,
    headers: HeaderMap,
    Json(body): Json<Value>,
) -> Result<Json<UpdateResult>, ApiError> {
    require_permission(&identity, "collections.write")?;
    let updates = value_to_document(body, "Update data is required", "t-patchbyid-updates-003")?;
    let scope = require_scope_header(&headers)?;
    let scoped = state.manager.scoped_db(&identity.database, &scope).await?;
    let outcome = scoped
        .update(&collection, &id, updates, false)
        .await
        .map_err(ApiError::from_k2db)?;
    match outcome.updated {
        0 => Err(ApiError::new(
            StatusCode::NOT_FOUND,
            "not_found",
            "Document not found",
            Some("t-patchbyid-notfound-005".to_owned()),
        )),
        updated => Ok(Json(UpdateResult { updated })),
    }
}
/// `PATCH /v1/{collection}`: calls `update_all` with the request's `criteria`
/// and `values`.
///
/// Requires `collections.write`. Both `criteria` and `values` must be JSON
/// objects; they are validated before the `x-scope` header is checked.
/// Responds with the updated count.
async fn patch_collection(
    State(state): State<AppState>,
    Extension(identity): Extension<ApiKeyIdentity>,
    Path(collection): Path<String>,
    headers: HeaderMap,
    Json(body): Json<PatchCollectionRequest>,
) -> Result<Json<UpdateResult>, ApiError> {
    require_permission(&identity, "collections.write")?;
    let criteria = value_to_document(
        body.criteria,
        "Update criteria is required",
        "t-patchcollection-criteria-001",
    )?;
    let values = value_to_document(
        body.values,
        "Update values are required",
        "t-patchcollection-values-002",
    )?;
    let scope = require_scope_header(&headers)?;
    let scoped = state.manager.scoped_db(&identity.database, &scope).await?;
    let outcome = scoped
        .update_all(&collection, criteria, values)
        .await
        .map_err(ApiError::from_k2db)?;
    Ok(Json(UpdateResult {
        updated: outcome.updated,
    }))
}
/// `DELETE /v1/{collection}/{id}`: calls `delete` on the scoped handle.
///
/// Requires `collections.write` and a non-blank `x-scope` header. Responds
/// `204 No Content`, or `404` when nothing was deleted.
async fn delete_by_id(
    State(state): State<AppState>,
    Extension(identity): Extension<ApiKeyIdentity>,
    Path((collection, id)): Path<(String, String)>,
    headers: HeaderMap,
) -> Result<StatusCode, ApiError> {
    require_permission(&identity, "collections.write")?;
    let scope = require_scope_header(&headers)?;
    let scoped = state.manager.scoped_db(&identity.database, &scope).await?;
    let removed = scoped.delete(&collection, &id).await.map_err(ApiError::from_k2db)?;
    if removed.deleted != 0 {
        return Ok(StatusCode::NO_CONTENT);
    }
    Err(ApiError::new(
        StatusCode::NOT_FOUND,
        "not_found",
        "Document not found",
        Some("t-deletebyid-notfound-004".to_owned()),
    ))
}
async fn search_collection(
State(state): State<AppState>,
Extension(identity): Extension<ApiKeyIdentity>,
Path(collection): Path<String>,
headers: HeaderMap,
Json(body): Json<SearchRequest>,
) -> Result<Json<Value>, ApiError> {
require_permission(&identity, "collections.read")?;
let scoped = state
.manager
.scoped_db(&identity.database, &require_scope_header(&headers)?)
.await?;
let filter = value_to_document(body.filter, "Filter must be an object", "t-search-filter-001")?;
let options = find_options_from_request(body.params, body.skip, body.limit)?;
let documents = scoped.find(&collection, filter, options).await.map_err(ApiError::from_k2db)?;
Ok(Json(to_json_value(documents)?))
}
async fn aggregate_collection(
State(state): State<AppState>,
Extension(identity): Extension<ApiKeyIdentity>,
Path(collection): Path<String>,
headers: HeaderMap,
Json(body): Json<AggregateRequest>,
) -> Result<Json<Value>, ApiError> {
require_permission(&identity, "collections.read")?;
let scoped = state
.manager
.scoped_db(&identity.database, &require_scope_header(&headers)?)
.await?;
let pipeline = body
.criteria
.into_iter()
.map(|stage| value_to_document(stage, "Criteria is required", "t-aggregate-criteria-001"))
.collect::<Result<Vec<_>, _>>()?;
let documents = scoped
.aggregate(&collection, pipeline, body.skip.unwrap_or(0), body.limit.unwrap_or(100))
.await
.map_err(ApiError::from_k2db)?;
Ok(Json(to_json_value(documents)?))
}
async fn count_collection(
State(state): State<AppState>,
Extension(identity): Extension<ApiKeyIdentity>,
Path(collection): Path<String>,
headers: HeaderMap,
Json(body): Json<CountRequest>,
) -> Result<Json<k2db_api_contract::CountResult>, ApiError> {
require_permission(&identity, "collections.read")?;
let scoped = state
.manager
.scoped_db(&identity.database, &require_scope_header(&headers)?)
.await?;
let criteria = value_to_document(
body.criteria.unwrap_or_else(|| serde_json::json!({})),
"Criteria must be an object",
"t-count-criteria-001",
)?;
let result = scoped.count(&collection, criteria).await.map_err(ApiError::from_k2db)?;
Ok(Json(k2db_api_contract::CountResult { count: result.count }))
}
/// `POST /v1/{collection}/restore`: calls `restore` with the request criteria.
///
/// Requires `collections.write` and a non-blank `x-scope` header. Responds
/// with the status and modified count reported by k2db.
async fn restore_collection(
    State(state): State<AppState>,
    Extension(identity): Extension<ApiKeyIdentity>,
    Path(collection): Path<String>,
    headers: HeaderMap,
    Json(body): Json<RestoreRequest>,
) -> Result<Json<k2db_api_contract::RestoreResult>, ApiError> {
    require_permission(&identity, "collections.write")?;
    let scope = require_scope_header(&headers)?;
    let scoped = state.manager.scoped_db(&identity.database, &scope).await?;
    let criteria = value_to_document(
        body.criteria,
        "Restore criteria is required",
        "t-restore-criteria-001",
    )?;
    let restored = scoped.restore(&collection, criteria).await.map_err(ApiError::from_k2db)?;
    let payload = k2db_api_contract::RestoreResult {
        status: restored.status,
        modified: restored.modified,
    };
    Ok(Json(payload))
}
/// `GET /v1/{collection}/{id}/versions`: lists stored versions of a document.
///
/// Requires `collections.read` and a non-blank `x-scope` header. The `skip`
/// and `limit` query parameters default to 0 and 100.
async fn get_versions(
    State(state): State<AppState>,
    Extension(identity): Extension<ApiKeyIdentity>,
    Path((collection, id)): Path<(String, String)>,
    headers: HeaderMap,
    Query(query): Query<VersionsQuery>,
) -> Result<Json<Vec<VersionInfo>>, ApiError> {
    require_permission(&identity, "collections.read")?;
    let scope = require_scope_header(&headers)?;
    let scoped = state.manager.scoped_db(&identity.database, &scope).await?;
    let skip = query.skip.unwrap_or(0);
    let limit = query.limit.unwrap_or(100);
    let stored = scoped
        .list_versions(&collection, &id, skip, limit)
        .await
        .map_err(ApiError::from_k2db)?;
    let mapped: Vec<VersionInfo> = stored.into_iter().map(map_version_info).collect();
    Ok(Json(mapped))
}
/// `PATCH /v1/{collection}/{id}/versions`: versioned update of one document.
///
/// Requires `collections.write` and a non-blank `x-scope` header. `data` must
/// be a JSON object. `replace` defaults to `false`; `max_versions` is passed
/// through unchanged. Responds with one entry per item k2db reports.
async fn patch_versions(
    State(state): State<AppState>,
    Extension(identity): Extension<ApiKeyIdentity>,
    Path((collection, id)): Path<(String, String)>,
    headers: HeaderMap,
    Json(body): Json<VersionedUpdateRequest>,
) -> Result<Json<Vec<VersionedUpdateResult>>, ApiError> {
    require_permission(&identity, "collections.write")?;
    let scope = require_scope_header(&headers)?;
    let scoped = state.manager.scoped_db(&identity.database, &scope).await?;
    let data = value_to_document(body.data, "Update data is required", "t-patchversions-data-003")?;
    let replace = body.replace.unwrap_or(false);
    let outcomes = scoped
        .update_versioned(&collection, &id, data, replace, body.max_versions)
        .await
        .map_err(ApiError::from_k2db)?;

    let mut results = Vec::new();
    for outcome in outcomes {
        results.push(VersionedUpdateResult {
            updated: outcome.updated,
            version_saved: outcome.version_saved,
        });
    }
    Ok(Json(results))
}
/// `POST /v1/{collection}/{id}/versions/{version}/revert`: reverts a document
/// to the given version number.
///
/// Requires `collections.write` and a non-blank `x-scope` header. Responds
/// with the updated count reported by k2db.
async fn revert_version(
    State(state): State<AppState>,
    Extension(identity): Extension<ApiKeyIdentity>,
    Path((collection, id, version)): Path<(String, String, u64)>,
    headers: HeaderMap,
) -> Result<Json<UpdateResult>, ApiError> {
    require_permission(&identity, "collections.write")?;
    let scope = require_scope_header(&headers)?;
    let scoped = state.manager.scoped_db(&identity.database, &scope).await?;
    let reverted = scoped
        .revert_to_version(&collection, &id, version)
        .await
        .map_err(ApiError::from_k2db)?;
    let payload = UpdateResult {
        updated: reverted.updated,
    };
    Ok(Json(payload))
}
async fn admin_delete_collection(
State(state): State<AppState>,
Extension(identity): Extension<ApiKeyIdentity>,
Path(collection): Path<String>,
headers: HeaderMap,
) -> Result<StatusCode, ApiError> {
require_permission(&identity, "admin.delete")?;
let scoped = state
.manager
.scoped_db(&identity.database, &require_scope_header(&headers)?)
.await?;
scoped
.drop_collection(&collection)
.await
.map_err(ApiError::from_k2db)?;
Ok(StatusCode::NO_CONTENT)
}
async fn admin_delete_by_id(
State(state): State<AppState>,
Extension(identity): Extension<ApiKeyIdentity>,
Path((collection, id)): Path<(String, String)>,
headers: HeaderMap,
) -> Result<StatusCode, ApiError> {
require_permission(&identity, "admin.delete")?;
let scoped = state
.manager
.scoped_db(&identity.database, &require_scope_header(&headers)?)
.await?;
scoped.purge(&collection, &id).await.map_err(ApiError::from_k2db)?;
Ok(StatusCode::NO_CONTENT)
}
/// `POST /v1/admin/{collection}/indexes`: creates one index on a collection.
///
/// Requires `admin.indexes`. The `x-scope` header must be present, but it is
/// not used: the index is created through the unscoped `get_db` handle.
/// `index_spec` must be a JSON object. The optional `options` must be an
/// object that deserializes into the MongoDB driver's `IndexOptions`.
///
/// NOTE(review): both option-related errors reuse the trace id
/// `t-createindexes-options-001`.
async fn admin_create_indexes(
    State(state): State<AppState>,
    Extension(identity): Extension<ApiKeyIdentity>,
    Path(collection): Path<String>,
    headers: HeaderMap,
    Json(body): Json<CreateIndexesRequest>,
) -> Result<Json<MessageResponse>, ApiError> {
    require_permission(&identity, "admin.indexes")?;
    let _scope = require_scope_header(&headers)?;
    let tenant_db = state.manager.get_db(&identity.database).await?;
    let index_spec = value_to_document(
        body.index_spec,
        "Index specification is required",
        "t-createindexes-indexspec-001",
    )?;
    let options = match body.options {
        None => None,
        Some(raw) => {
            let document = value_to_document(
                raw,
                "Index options must be an object",
                "t-createindexes-options-001",
            )?;
            let parsed = bson::from_bson::<IndexOptions>(Bson::Document(document)).map_err(|_| {
                ApiError::bad_request("Invalid index options", "t-createindexes-options-001")
            })?;
            Some(parsed)
        }
    };
    tenant_db
        .create_index(&collection, index_spec, options)
        .await
        .map_err(ApiError::from_k2db)?;
    Ok(Json(MessageResponse {
        message: "Index created successfully".to_owned(),
    }))
}
async fn admin_create_history_indexes(
State(state): State<AppState>,
Extension(identity): Extension<ApiKeyIdentity>,
Path(collection): Path<String>,
headers: HeaderMap,
) -> Result<Json<MessageResponse>, ApiError> {
require_permission(&identity, "admin.indexes")?;
let scoped = state
.manager
.scoped_db(&identity.database, &require_scope_header(&headers)?)
.await?;
scoped
.ensure_history_indexes(&collection)
.await
.map_err(ApiError::from_k2db)?;
Ok(Json(MessageResponse {
message: "History indexes ensured successfully".to_owned(),
}))
}
/// Router fallback: a `404` problem response naming the unmatched method and URI.
///
/// The trace id keeps its historical `fastify` wording, so clients that key on
/// trace ids are not affected.
async fn not_found(method: Method, uri: axum::http::Uri) -> ApiError {
    let message = format!("Route {method}:{uri} not found");
    let trace = String::from("t-fastify-route-notfound-001");
    ApiError::new(StatusCode::NOT_FOUND, "not_found", message, Some(trace))
}
/// Returns the trimmed `x-scope` header value.
///
/// # Errors
/// Returns a bad-request error (`t-scope-required-001`) when the header is
/// missing, is not visible ASCII, or is blank after trimming.
fn require_scope_header(headers: &HeaderMap) -> Result<String, ApiError> {
    let scope = headers
        .get("x-scope")
        .and_then(|raw| raw.to_str().ok())
        .map(str::trim)
        .unwrap_or_default();
    if scope.is_empty() {
        return Err(ApiError::bad_request("Scope header is required", "t-scope-required-001"));
    }
    Ok(scope.to_owned())
}
/// Converts a JSON object into a BSON [`Document`].
///
/// # Errors
/// Returns a bad-request error built from `message`/`trace` when `value` is
/// not a JSON object or cannot be encoded as BSON.
fn value_to_document(value: Value, message: &str, trace: &'static str) -> Result<Document, ApiError> {
    let rejected = || ApiError::bad_request(message.to_owned(), trace);
    if !value.is_object() {
        return Err(rejected());
    }
    bson::to_document(&value).map_err(|_| rejected())
}
/// Translates a search request's `skip`, `limit` and optional `params` object
/// into [`FindOptions`].
///
/// `skip`/`limit` default to 0/100. Recognised `params` keys:
/// * `includeDeleted` (bool) sets `include_deleted`;
/// * `deleted` (bool) sets `deleted_only`, but only when `includeDeleted` is not true;
/// * `order` sets `sort` (see [`order_to_document`]);
/// * `filter` is `"all"` or an array of field names to include;
/// * `exclude` is an array of field names to exclude, consulted only when
///   `filter` is absent.
///
/// Non-boolean flag values count as `false`.
///
/// # Errors
/// Returns a bad-request error when `params` is not an object, or when
/// `order`, `filter` or `exclude` is malformed.
fn find_options_from_request(
    params: Option<Value>,
    skip: Option<u64>,
    limit: Option<u64>,
) -> Result<FindOptions, ApiError> {
    let mut options = FindOptions {
        skip: skip.unwrap_or(0),
        limit: limit.unwrap_or(100),
        ..FindOptions::default()
    };
    let map = match params {
        None => return Ok(options),
        Some(Value::Object(map)) => map,
        Some(_) => {
            return Err(ApiError::bad_request("Search params must be an object", "t-search-params-001"));
        }
    };

    let flag = |name: &str| map.get(name).and_then(Value::as_bool).unwrap_or(false);
    options.include_deleted = flag("includeDeleted");
    options.deleted_only = flag("deleted") && !options.include_deleted;

    if let Some(order) = map.get("order") {
        options.sort = Some(order_to_document(order)?);
    }

    match (map.get("filter"), map.get("exclude")) {
        (Some(Value::String(mode)), _) if mode == "all" => {
            options.projection = ProjectionMode::All;
        }
        (Some(Value::Array(fields)), _) => {
            options.projection = ProjectionMode::Include(string_array(fields, "t-search-filter-002")?);
        }
        (Some(_), _) => {
            return Err(ApiError::bad_request("Search params.filter is invalid", "t-search-filter-002"));
        }
        (None, Some(Value::Array(fields))) => {
            options.projection = ProjectionMode::Exclude(string_array(fields, "t-search-exclude-001")?);
        }
        (None, Some(_)) => {
            return Err(ApiError::bad_request("Search params.exclude is invalid", "t-search-exclude-001"));
        }
        (None, None) => {}
    }
    Ok(options)
}
/// Collects a JSON array of field names into trimmed `String`s.
///
/// # Errors
/// Returns a bad-request error tagged with `trace` on the first element that
/// is not a string, or that is blank after trimming.
fn string_array(items: &[Value], trace: &'static str) -> Result<Vec<String>, ApiError> {
    let mut fields = Vec::with_capacity(items.len());
    for item in items {
        match item.as_str().map(str::trim) {
            Some(field) if !field.is_empty() => fields.push(field.to_owned()),
            _ => return Err(ApiError::bad_request("Expected a string array", trace)),
        }
    }
    Ok(fields)
}
/// Converts `params.order` into a MongoDB sort document.
///
/// Each entry maps a field name to a direction. Ascending (`1`) is chosen by
/// the string `"asc"` or `"ascending"` (any letter case), or by a positive
/// number. Previously only the exact string `"asc"` was ascending, so a
/// Mongo-style `1` or `"ASC"` silently sorted descending. Every other value,
/// including `"desc"` and negative numbers, still sorts descending (`-1`).
///
/// NOTE(review): the key order of the resulting document follows the
/// iteration order of `serde_json::Map`. That order is alphabetical unless
/// serde_json's `preserve_order` feature is enabled. Confirm this, because
/// multi-field sort precedence depends on it.
///
/// # Errors
/// Returns a bad-request error (`t-search-order-001`) when `value` is not a
/// JSON object.
fn order_to_document(value: &Value) -> Result<Document, ApiError> {
    let Value::Object(map) = value else {
        return Err(ApiError::bad_request("Search params.order is invalid", "t-search-order-001"));
    };
    let mut order = doc! {};
    for (field, direction) in map {
        order.insert(field, sort_direction(direction));
    }
    Ok(order)
}

/// Maps a single `params.order` value to `1` (ascending) or `-1` (descending).
fn sort_direction(direction: &Value) -> i32 {
    let ascending = match direction {
        Value::String(text) => {
            text.eq_ignore_ascii_case("asc") || text.eq_ignore_ascii_case("ascending")
        }
        Value::Number(number) => number.as_f64().is_some_and(|n| n > 0.0),
        _ => false,
    };
    if ascending { 1 } else { -1 }
}
fn map_version_info(value: DbVersionInfo) -> VersionInfo {
VersionInfo {
uuid: value.uuid,
version: value.version,
at: value.at,
}
}
/// Serializes a response payload into a JSON [`Value`].
///
/// # Errors
/// Returns an internal error (`t-response-serialize-001`) if serialization fails.
fn to_json_value<T: Serialize>(value: T) -> Result<Value, ApiError> {
    match serde_json::to_value(value) {
        Ok(json) => Ok(json),
        Err(_) => Err(ApiError::internal(
            "Failed to serialize response",
            "t-response-serialize-001",
        )),
    }
}
fn boxed_error(error: ApiError) -> Box<dyn std::error::Error> {
Box::new(std::io::Error::other(format!("{:?}", error)))
}
/// Startup gate: runs the readiness check once.
///
/// # Errors
/// Returns a service-unavailable error (`t-ready-not-ready-001`) listing the
/// databases that are not ready.
async fn ensure_ready(state: &AppState) -> Result<(), ApiError> {
    let down = (state.readiness)().await;
    if down.is_empty() {
        return Ok(());
    }
    let message = format!("Databases not ready: {}", down.join(", "));
    Err(ApiError::service_unavailable(message, "t-ready-not-ready-001"))
}
async fn shutdown_signal() {
let ctrl_c = async {
let _ = tokio::signal::ctrl_c().await;
};
#[cfg(unix)]
let terminate = async {
use tokio::signal::unix::{SignalKind, signal};
match signal(SignalKind::terminate()) {
Ok(mut stream) => {
let _ = stream.recv().await;
}
Err(_) => std::future::pending::<()>().await,
}
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {}
_ = terminate => {}
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use super::*;
    use axum::body::{Body, to_bytes};
    use axum::http::Request;
    use base64::Engine;
    use scrypt::{Params, scrypt};
    use tower::ServiceExt;

    /// Builds an [`AppState`] with the supplied readiness check and a single
    /// active key (`worker`, key id `demo`). The key is bound to database
    /// `project_a` and holds only `collections.read`.
    fn test_state(readiness: ReadinessCheck) -> AppState {
        let config = Arc::new(AppConfig {
            server: ServerConfig {
                host: "0.0.0.0".to_owned(),
                port: 3000,
            },
            k2db: K2DbSection {
                hosts: vec![HostEntry {
                    host: "localhost".to_owned(),
                    port: Some(27017),
                }],
                user: None,
                password: None,
                auth_source: None,
                ownership_mode: None,
                replica_set: None,
                slow_query_ms: None,
            },
            apikey: ApiKeySection {
                keys: HashMap::from([(
                    "worker".to_owned(),
                    ApiKeyConfig {
                        key_id: "demo".to_owned(),
                        secret_hash: test_secret_hash("secret-value"),
                        database: "project_a".to_owned(),
                        permissions: vec!["collections.read".to_owned()],
                        active: true,
                        expires_at: None,
                    },
                )]),
            },
        });
        let manager = DbManager::new(config.clone());
        let auth = AuthProvider::new(config.apikey.keys.clone()).expect("auth provider");
        AppState {
            config,
            auth,
            manager,
            readiness,
        }
    }

    /// Produces a `scrypt$N$r$p$salt$hash` string for `secret`.
    ///
    /// Uses a fixed salt and log2(N) = 14 (N = 16384), r = 8, p = 1, with a
    /// 32-byte output. Salt and hash are URL-safe base64 without padding.
    fn test_secret_hash(secret: &str) -> String {
        let salt = b"0123456789abcdef";
        let params = Params::new(14, 8, 1, 32).expect("params");
        let mut out = vec![0_u8; 32];
        scrypt(secret.as_bytes(), salt, &params, &mut out).expect("scrypt");
        format!(
            "scrypt$16384$8$1${}${}",
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(salt),
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(out)
        )
    }

    #[tokio::test]
    async fn health_route_returns_ok_payload() {
        let state = test_state(Arc::new(|| Box::pin(async { Vec::new() })));
        let response = create_router(state)
            .oneshot(Request::builder().uri("/health").body(Body::empty()).expect("request"))
            .await
            .expect("response");
        assert_eq!(response.status(), StatusCode::OK);
        let body = to_bytes(response.into_body(), usize::MAX).await.expect("body");
        let value: Value = serde_json::from_slice(&body).expect("json");
        assert_eq!(value, serde_json::json!({ "status": "ok" }));
    }

    #[tokio::test]
    async fn ready_route_returns_not_ready_with_database_list() {
        let state = test_state(Arc::new(|| Box::pin(async { vec!["project_a".to_owned()] })));
        let response = create_router(state)
            .oneshot(Request::builder().uri("/ready").body(Body::empty()).expect("request"))
            .await
            .expect("response");
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
        let body = to_bytes(response.into_body(), usize::MAX).await.expect("body");
        let value: Value = serde_json::from_slice(&body).expect("json");
        assert_eq!(value, serde_json::json!({ "status": "not-ready", "databases": ["project_a"] }));
    }

    #[tokio::test]
    async fn authenticated_routes_require_authorization_header() {
        let state = test_state(Arc::new(|| Box::pin(async { Vec::new() })));
        let response = create_router(state)
            .oneshot(
                Request::builder()
                    .uri("/v1/widgets/ABC123")
                    .body(Body::empty())
                    .expect("request"),
            )
            .await
            .expect("response");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        let body = to_bytes(response.into_body(), usize::MAX).await.expect("body");
        let value: ProblemDetailsPayload = serde_json::from_slice(&body).expect("json");
        assert_eq!(value.title.as_deref(), Some("unauthorized"));
    }
}