use anyhow::Result;
use axum::{
Router,
extract::{Path, Query, State},
http::StatusCode,
middleware,
response::Json,
routing::get,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tower_http::{
compression::CompressionLayer,
cors::{Any, CorsLayer},
trace::TraceLayer,
};
use tracing::{Level, info};
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
mod audit;
mod auth;
mod config;
mod docs;
mod license;
mod metrics;
use config::Config;
/// State shared by the router, the metrics middleware and the background metrics task.
///
/// `main` wraps it in an `Arc`, so cloning a handle is a reference-count bump.
#[derive(Clone)]
struct AppState {
    /// HTTP client reused for every request sent to OpenSearch.
    client: reqwest::Client,
    /// Configuration loaded once at startup via `Config::from_env`.
    config: Arc<Config>,
}
/// Response body of `GET /health`, returned by `health_check` once OpenSearch answers with 2xx.
#[derive(Serialize)]
struct HealthResponse {
/// Always `"healthy"` when this struct is returned.
status: String,
/// Base URL of the OpenSearch cluster this API talks to.
opensearch_url: String,
/// API version string reported to callers.
version: String,
}
/// A caller-supplied JSON document; its top-level keys are (de)serialized inline through
/// `#[serde(flatten)]` rather than nested under a `fields` key.
/// NOTE(review): flattening into `serde_json::Value` presumably only accepts JSON objects —
/// confirm that array/scalar request bodies are meant to be rejected.
#[derive(Serialize, Deserialize)]
struct Document {
#[serde(flatten)]
fields: serde_json::Value,
}
/// Query-string parameters accepted by `search_documents`.
#[derive(Deserialize)]
struct SearchQuery {
/// Free-text query, sent to OpenSearch as a `multi_match` over every field (`"fields": ["*"]`).
q: String,
/// Number of hits to return; `default_size()` (10) when omitted. Not capped by this API.
#[serde(default = "default_size")]
size: usize,
/// Zero-based offset of the first hit, for pagination; 0 when omitted.
#[serde(default)]
from: usize,
}
/// Page size used for `SearchQuery::size` when the `size` query parameter is omitted.
const fn default_size() -> usize {
    const DEFAULT_PAGE_SIZE: usize = 10;
    DEFAULT_PAGE_SIZE
}
/// Result of `index_document`, built from OpenSearch's `_doc` response.
#[derive(Serialize)]
struct IndexResponse {
/// Document id assigned by OpenSearch (`_id`); empty string if the field was absent.
id: String,
/// Index the document was written to (echoes the request's path parameter).
index: String,
/// OpenSearch's `result` field; `"unknown"` if absent.
result: String,
}
/// Result of `search_documents`.
#[derive(Serialize)]
struct SearchResponse {
/// The `_source` of each hit, in the order OpenSearch returned them.
hits: Vec<serde_json::Value>,
/// Total hit count reported by OpenSearch; 0 if it could not be read.
total: usize,
/// OpenSearch's `took` value (presumably milliseconds — confirm); 0 if absent.
took: u64,
}
/// JSON error body returned by the typed handlers together with a non-2xx status.
#[derive(Serialize)]
struct ErrorResponse {
/// Machine-readable category, e.g. `"opensearch_error"` or `"parse_error"`.
error: String,
/// Human-readable detail, usually the underlying error's text.
message: String,
}
/// Process entry point.
///
/// Loads [`Config`] from the environment, installs a `tracing` subscriber (default level `INFO`,
/// overridable through the environment filter), creates the metrics index templates on a
/// best-effort basis, spawns the background system-metrics collector, and then serves the
/// router: `/health` is handled locally and every other path falls through to the
/// authenticated OpenSearch proxy.
#[tokio::main]
async fn main() -> Result<()> {
    let config = Config::from_env().expect("Failed to load configuration");

    let env_filter = EnvFilter::builder()
        .with_default_directive(Level::INFO.into())
        .from_env_lossy();
    tracing_subscriber::registry()
        .with(fmt::layer())
        .with(env_filter)
        .init();

    let http_client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(30))
        .build()?;
    let shared = Arc::new(AppState {
        config: Arc::new(config),
        client: http_client,
    });

    // Template creation failing must not prevent the API from starting.
    if let Err(err) = metrics::setup_metrics_indices(&shared).await {
        tracing::warn!("Não foi possível criar templates de métricas: {}", err);
    }

    let collector_state = Arc::clone(&shared);
    tokio::spawn(async move {
        metrics::collect_system_metrics(collector_state).await;
    });

    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);
    let metrics_layer =
        middleware::from_fn_with_state(Arc::clone(&shared), metrics::metrics_middleware);

    // Layers wrap outward: metrics is innermost, tracing is outermost.
    let app = Router::new()
        .route("/health", get(health_check))
        .fallback(opensearch_proxy)
        .layer(metrics_layer)
        .layer(cors)
        .layer(CompressionLayer::new())
        .layer(TraceLayer::new_for_http())
        .with_state(Arc::clone(&shared));

    let cfg = &shared.config;
    info!("OpenSearch API starting on http://{}", &cfg.addr);
    info!("Connected to OpenSearch at: {}", &cfg.opensearch_url);

    let listener = tokio::net::TcpListener::bind(&cfg.addr).await?;
    axum::serve(listener, app).await?;
    Ok(())
}
/// Returns the static API banner.
///
/// NOTE(review): not mounted on the router built in `main`; confirm whether it should be.
async fn root() -> &'static str {
    const BANNER: &str = "OpenSearch API v0.1.0";
    BANNER
}
async fn health_check(
State(state): State<Arc<AppState>>,
) -> Result<Json<HealthResponse>, StatusCode> {
let response = state
.client
.get(&state.config.opensearch_url)
.send()
.await
.map_err(|_| StatusCode::SERVICE_UNAVAILABLE)?;
if !response.status().is_success() {
return Err(StatusCode::SERVICE_UNAVAILABLE);
}
Ok(Json(HealthResponse {
status: "healthy".to_string(),
opensearch_url: state.config.opensearch_url.clone(),
version: "0.1.0".to_string(),
}))
}
/// Indexes a JSON document into `index_name` via OpenSearch's `POST /{index}/_doc` endpoint,
/// letting OpenSearch assign the document id.
///
/// # Errors
/// Every error is a `(status, Json<ErrorResponse>)` pair:
/// - `400 invalid_index` when `index_name` could escape its URL path segment (see
///   `is_safe_path_segment` below).
/// - `502 opensearch_error` when the request to OpenSearch fails.
/// - `500 parse_error` when OpenSearch's response body is not valid JSON.
/// - OpenSearch's own status with `opensearch_error` when it rejects the request.
async fn index_document(
    auth_user: auth::AuthUser,
    State(state): State<Arc<AppState>>,
    Path(index_name): Path<String>,
    Json(document): Json<Document>,
) -> Result<Json<IndexResponse>, (StatusCode, Json<ErrorResponse>)> {
    // Builds the `(status, body)` pair used for every error this handler returns.
    fn api_error(
        status: StatusCode,
        error: &str,
        message: String,
    ) -> (StatusCode, Json<ErrorResponse>) {
        (
            status,
            Json(ErrorResponse {
                error: error.to_string(),
                message,
            }),
        )
    }

    // True when `name` stays a single path segment once formatted into the upstream URL.
    // The path extractor has already percent-decoded it, so the following are all rejected,
    // because URL parsing would otherwise send the request to a different OpenSearch endpoint:
    // - `/` (and `\`, which WHATWG URL parsing treats as `/`) split segments;
    // - `?` and `#` start the query or fragment;
    // - `%` could smuggle encoded dot-segments such as `%2e%2e`;
    // - control characters are rejected because tab/newline are stripped during parsing;
    // - `.` and `..` are normalized away.
    fn is_safe_path_segment(name: &str) -> bool {
        !name.is_empty()
            && name != "."
            && name != ".."
            && !name
                .chars()
                .any(|c| matches!(c, '/' | '\\' | '?' | '#' | '%') || c.is_control())
    }

    // OpenSearch's `error` field, unquoted when it is a plain string (`Value::to_string`
    // would otherwise wrap it in JSON quotes).
    fn upstream_error_message(body: &serde_json::Value) -> String {
        match body.get("error") {
            Some(serde_json::Value::String(msg)) => msg.clone(),
            Some(other) => other.to_string(),
            None => "Unknown error".to_string(),
        }
    }

    tracing::info!(
        "User {} (role: {}) indexing document in {}",
        auth_user.id,
        auth_user.role,
        index_name
    );

    if !is_safe_path_segment(&index_name) {
        return Err(api_error(
            StatusCode::BAD_REQUEST,
            "invalid_index",
            "Index name must be a single path segment without '/', '\\', '?', '#', '%', \
             control characters, or dot-segments"
                .to_string(),
        ));
    }

    let url = format!("{}/{}/_doc", state.config.opensearch_url, index_name);
    let response = state
        .client
        .post(&url)
        .json(&document.fields)
        .send()
        .await
        .map_err(|e| api_error(StatusCode::BAD_GATEWAY, "opensearch_error", e.to_string()))?;

    let status = response.status();
    let body: serde_json::Value = response.json().await.map_err(|e| {
        api_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "parse_error",
            e.to_string(),
        )
    })?;

    if !status.is_success() {
        let code =
            StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
        return Err(api_error(
            code,
            "opensearch_error",
            upstream_error_message(&body),
        ));
    }

    Ok(Json(IndexResponse {
        id: body["_id"].as_str().unwrap_or("").to_string(),
        index: index_name,
        result: body["result"].as_str().unwrap_or("unknown").to_string(),
    }))
}
async fn search_documents(
auth_user: auth::AuthUser, State(state): State<Arc<AppState>>,
Path(index_name): Path<String>,
Query(params): Query<SearchQuery>,
) -> Result<Json<SearchResponse>, (StatusCode, Json<ErrorResponse>)> {
tracing::info!(
"User {} searching in {} for: {}",
auth_user.id,
index_name,
params.q
);
let query = serde_json::json!({
"query": {
"multi_match": {
"query": params.q,
"fields": ["*"]
}
},
"size": params.size,
"from": params.from
});
let url = format!("{}/{}/_search", state.config.opensearch_url, index_name);
let response = state
.client
.post(&url)
.json(&query)
.send()
.await
.map_err(|e| {
(
StatusCode::BAD_GATEWAY,
Json(ErrorResponse {
error: "opensearch_error".to_string(),
message: e.to_string(),
}),
)
})?;
let status = response.status();
let body: serde_json::Value = response.json().await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "parse_error".to_string(),
message: e.to_string(),
}),
)
})?;
if !status.is_success() {
return Err((
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
Json(ErrorResponse {
error: "opensearch_error".to_string(),
message: body
.get("error")
.map(|e| e.to_string())
.unwrap_or_else(|| "Unknown error".to_string()),
}),
));
}
let hits = body["hits"]["hits"]
.as_array()
.map(|arr| arr.iter().map(|hit| hit["_source"].clone()).collect())
.unwrap_or_default();
let total = body["hits"]["total"]["value"].as_u64().unwrap_or(0) as usize;
let took = body["took"].as_u64().unwrap_or(0);
Ok(Json(SearchResponse { hits, total, took }))
}
/// Lists all indices via OpenSearch's `GET /_cat/indices?format=json`, returning each row as a
/// JSON value.
///
/// # Errors
/// Every error is a `(status, Json<ErrorResponse>)` pair:
/// - `502 opensearch_error` when the request to OpenSearch fails.
/// - `500 parse_error` when the body is not valid JSON, or is JSON but not an array.
/// - OpenSearch's own status with `opensearch_error` when it answers with a non-2xx status.
///   The status is checked before the body is interpreted as a list, because OpenSearch
///   error bodies are objects.
async fn list_indices(
    auth_user: auth::AuthUser,
    State(state): State<Arc<AppState>>,
) -> Result<Json<Vec<serde_json::Value>>, (StatusCode, Json<ErrorResponse>)> {
    // Builds the `(status, body)` pair used for every error this handler returns.
    fn api_error(
        status: StatusCode,
        error: &str,
        message: String,
    ) -> (StatusCode, Json<ErrorResponse>) {
        (
            status,
            Json(ErrorResponse {
                error: error.to_string(),
                message,
            }),
        )
    }

    // OpenSearch's `error` field, unquoted when it is a plain string.
    fn upstream_error_message(body: &serde_json::Value) -> String {
        match body.get("error") {
            Some(serde_json::Value::String(msg)) => msg.clone(),
            Some(other) => other.to_string(),
            None => "Unknown error".to_string(),
        }
    }

    tracing::info!("User {} listing indices", auth_user.id);

    let url = format!("{}/_cat/indices?format=json", state.config.opensearch_url);
    let response = state
        .client
        .get(&url)
        .send()
        .await
        .map_err(|e| api_error(StatusCode::BAD_GATEWAY, "opensearch_error", e.to_string()))?;

    let status = response.status();
    let body: serde_json::Value = response.json().await.map_err(|e| {
        api_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "parse_error",
            e.to_string(),
        )
    })?;

    if !status.is_success() {
        let code =
            StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
        return Err(api_error(
            code,
            "opensearch_error",
            upstream_error_message(&body),
        ));
    }

    match body {
        serde_json::Value::Array(indices) => Ok(Json(indices)),
        other => Err(api_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "parse_error",
            format!("expected a JSON array from _cat/indices, got: {}", other),
        )),
    }
}
/// Returns OpenSearch's description of `index_name` (the raw JSON from `GET /{index}`).
///
/// # Errors
/// Every error is a `(status, Json<ErrorResponse>)` pair:
/// - `400 invalid_index` when `index_name` could escape its URL path segment (e.g. `..`,
///   which would otherwise hit the cluster root instead of an index).
/// - `502 opensearch_error` when the request to OpenSearch fails.
/// - `500 parse_error` when OpenSearch's response body is not valid JSON.
/// - OpenSearch's own status with `opensearch_error` when it answers with a non-2xx status.
async fn get_index_info(
    auth_user: auth::AuthUser,
    State(state): State<Arc<AppState>>,
    Path(index_name): Path<String>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
    // Builds the `(status, body)` pair used for every error this handler returns.
    fn api_error(
        status: StatusCode,
        error: &str,
        message: String,
    ) -> (StatusCode, Json<ErrorResponse>) {
        (
            status,
            Json(ErrorResponse {
                error: error.to_string(),
                message,
            }),
        )
    }

    // True when `name` stays a single path segment once formatted into the upstream URL.
    // The path extractor has already percent-decoded it, so the following are all rejected,
    // because URL parsing would otherwise send the request to a different OpenSearch endpoint:
    // - `/` (and `\`, which WHATWG URL parsing treats as `/`) split segments;
    // - `?` and `#` start the query or fragment;
    // - `%` could smuggle encoded dot-segments such as `%2e%2e`;
    // - control characters are rejected because tab/newline are stripped during parsing;
    // - `.` and `..` are normalized away.
    fn is_safe_path_segment(name: &str) -> bool {
        !name.is_empty()
            && name != "."
            && name != ".."
            && !name
                .chars()
                .any(|c| matches!(c, '/' | '\\' | '?' | '#' | '%') || c.is_control())
    }

    // OpenSearch's `error` field, unquoted when it is a plain string.
    fn upstream_error_message(body: &serde_json::Value) -> String {
        match body.get("error") {
            Some(serde_json::Value::String(msg)) => msg.clone(),
            Some(other) => other.to_string(),
            None => "Unknown error".to_string(),
        }
    }

    tracing::info!(
        "User {} getting info for index: {}",
        auth_user.id,
        index_name
    );

    if !is_safe_path_segment(&index_name) {
        return Err(api_error(
            StatusCode::BAD_REQUEST,
            "invalid_index",
            "Index name must be a single path segment without '/', '\\', '?', '#', '%', \
             control characters, or dot-segments"
                .to_string(),
        ));
    }

    let url = format!("{}/{}", state.config.opensearch_url, index_name);
    let response = state
        .client
        .get(&url)
        .send()
        .await
        .map_err(|e| api_error(StatusCode::BAD_GATEWAY, "opensearch_error", e.to_string()))?;

    let status = response.status();
    let body: serde_json::Value = response.json().await.map_err(|e| {
        api_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "parse_error",
            e.to_string(),
        )
    })?;

    if !status.is_success() {
        let code =
            StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
        return Err(api_error(
            code,
            "opensearch_error",
            upstream_error_message(&body),
        ));
    }

    Ok(Json(body))
}
async fn opensearch_proxy(
State(state): State<Arc<AppState>>,
headers: axum::http::HeaderMap,
method: axum::http::Method,
uri: axum::http::Uri,
body: axum::body::Bytes,
) -> impl axum::response::IntoResponse {
use axum::http::{StatusCode, header};
let api_key = headers
.get(header::AUTHORIZATION)
.and_then(|h| h.to_str().ok())
.and_then(|h| h.strip_prefix("Bearer "));
let is_valid = match api_key {
Some(key) => {
std::env::var("API_TOKENS")
.unwrap_or_default()
.split(',')
.any(|token| token.trim() == key)
}
None => false,
};
if !is_valid {
return axum::response::Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(axum::body::Body::from(
"Unauthorized: Invalid or missing API key",
))
.unwrap();
}
let path = uri.path();
let query = uri.query().unwrap_or("");
tracing::info!("Proxy request: {} {}", method, path);
let url = if query.is_empty() {
format!("{}{}", state.config.opensearch_url, path)
} else {
format!("{}{}?{}", state.config.opensearch_url, path, query)
};
let mut opensearch_req = state.client.request(method, &url);
for (key, value) in headers.iter() {
if key != header::AUTHORIZATION && key != header::HOST && key != header::CONNECTION {
opensearch_req = opensearch_req.header(key, value);
}
}
if !body.is_empty() {
opensearch_req = opensearch_req.body(body);
}
let response = match opensearch_req.send().await {
Ok(resp) => resp,
Err(e) => {
tracing::error!("OpenSearch error: {}", e);
return axum::response::Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(axum::body::Body::from(format!("OpenSearch error: {}", e)))
.unwrap();
}
};
let status = response.status();
let resp_headers = response.headers().clone();
let response_body = match response.bytes().await {
Ok(bytes) => bytes,
Err(e) => {
return axum::response::Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(axum::body::Body::from(format!(
"Error reading response: {}",
e
)))
.unwrap();
}
};
let mut final_response = axum::response::Response::builder().status(status.as_u16());
for (key, value) in resp_headers.iter() {
if key != header::CONNECTION && key != header::TRANSFER_ENCODING {
final_response = final_response.header(key, value);
}
}
final_response
.body(axum::body::Body::from(response_body))
.unwrap()
}