use actix_web::body::BoxBody;
use actix_web::http::StatusCode;
use actix_web::test::TestRequest;
use actix_web::web::Bytes;
use actix_web::web::Data;
use actix_web::web::Json;
use actix_web::{HttpRequest, HttpResponse, Responder, get, post};
use serde::Deserialize;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::time::Instant;
use tracing::{error, info};
use crate::AppState;
use crate::api::gateway::auth::{
authorize_gateway_request, read_right_for_resource, write_right_for_resource,
};
use crate::api::gateway::fetch::handle_fetch_data_route;
use crate::api::gateway::insert::insert;
use crate::api::headers::x_athena_client::x_athena_client;
use crate::api::response::{bad_request, internal_error};
use crate::drivers::postgresql::sqlx_driver::insert_row;
use crate::utils::format::normalize_column_name;
use crate::utils::request_logging::LoggedRequest;
use crate::utils::request_logging::{log_operation_event, log_request};
/// Source half of a pipeline: where rows are fetched from.
///
/// All fields are optional at the serde level (`default`); `table_name` is
/// validated as non-empty in `resolve_definition`.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(rename_all = "snake_case", default)]
pub struct SourceConfig {
    // Table rows are read from; required (checked in `resolve_definition`).
    pub table_name: String,
    // Optional view; when set it takes precedence as the fetch target
    // (see `build_fetch_body`).
    pub view_name: Option<String>,
    // Optional column projection forwarded to the fetch route.
    pub columns: Option<Vec<String>>,
    // Optional equality filters forwarded to the fetch route.
    pub conditions: Option<Vec<ConditionEntry>>,
    // Optional row cap forwarded to the fetch route.
    pub limit: Option<u64>,
    // Optional client override; falls back to the caller's
    // X-Athena-Client header (see `resolve_client_name`).
    pub client: Option<String>,
}
/// A single `column = value` equality filter in a source configuration.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ConditionEntry {
    // Column the equality filter applies to.
    pub eq_column: String,
    // Comparison value; scalars (string/bool/number) are accepted in the
    // payload and coerced to their string form (see `deserialize_eq_value`).
    #[serde(deserialize_with = "deserialize_eq_value")]
    pub eq_value: String,
}
/// Deserialize `ConditionEntry::eq_value` from any JSON value into a string.
///
/// Strings pass through unchanged; booleans and numbers are stringified
/// (`true`/`false`, decimal form). Any other value falls back to `Value`'s
/// `Display`, which renders compact JSON (e.g. `null`, arrays, objects).
fn deserialize_eq_value<'de, D>(deserializer: D) -> Result<String, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let v: Value = Value::deserialize(deserializer)?;
    Ok(match v {
        Value::String(s) => s,
        // bool::to_string yields exactly "true"/"false".
        Value::Bool(b) => b.to_string(),
        Value::Number(n) => n.to_string(),
        other => other.to_string(),
    })
}
/// Optional transform step applied between fetch and insert.
///
/// All fields except `column_aliases` are forwarded verbatim to the fetch
/// route (see `build_fetch_body`); `column_aliases` is applied locally to
/// the fetched rows (see `apply_column_aliases`).
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(rename_all = "snake_case", default)]
pub struct TransformConfig {
    // Column to group by, forwarded to the fetch route.
    pub group_by: Option<String>,
    // Time bucket granularity, forwarded to the fetch route.
    pub time_granularity: Option<String>,
    // Column the aggregation operates on, forwarded to the fetch route.
    pub aggregation_column: Option<String>,
    // Aggregation strategy name, forwarded to the fetch route.
    pub aggregation_strategy: Option<String>,
    // Whether the aggregation deduplicates, forwarded to the fetch route.
    pub aggregation_dedup: Option<bool>,
    // old-key -> new-key renames applied to each fetched row before insert.
    pub column_aliases: Option<HashMap<String, String>>,
}
/// Sink half of a pipeline: where rows are inserted.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(rename_all = "snake_case", default)]
pub struct SinkConfig {
    // Table rows are inserted into; required (checked in `resolve_definition`).
    pub table_name: String,
    // Optional client override; falls back to the caller's
    // X-Athena-Client header (see `resolve_client_name`).
    pub client: Option<String>,
}
/// A fully resolved pipeline: source, optional transform, and sink.
///
/// Built either from a registry template, inline request fields, or a merge
/// of both (see `resolve_definition`).
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(rename_all = "snake_case", default)]
pub struct PipelineDefinition {
    pub source: SourceConfig,
    pub transform: Option<TransformConfig>,
    pub sink: SinkConfig,
}
/// Inbound body for the pipeline routes.
///
/// A request may name a registered template (`pipeline`), provide
/// `source`/`transform`/`sink` inline, or both — inline sections override
/// the template's (see `resolve_definition`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct PipelineRequest {
    // Name of a template in the pipeline registry, if any.
    pub pipeline: Option<String>,
    // Inline override for the template's source section.
    pub source: Option<SourceConfig>,
    // Inline override for the template's transform section.
    pub transform: Option<TransformConfig>,
    // Inline override for the template's sink section.
    pub sink: Option<SinkConfig>,
    // When true, rows are fetched and transformed but not inserted.
    pub dry_run: Option<bool>,
}
/// One named pipeline as it appears in the YAML registry file.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
struct PipelineEntry {
    // Registry key used to look the pipeline up by name.
    name: String,
    source: SourceConfig,
    transform: Option<TransformConfig>,
    sink: SinkConfig,
}
/// Top-level shape of the YAML registry file: a `pipelines:` list.
#[derive(Debug, Deserialize)]
struct PipelinesFile {
    pipelines: Vec<PipelineEntry>,
}
/// Load the pipeline registry from a YAML file.
///
/// Reads the file at `path`, parses it as a `PipelinesFile`, and keys each
/// definition by its `name`. Duplicate names keep the last entry.
///
/// # Errors
/// Returns the underlying I/O or YAML parse error.
pub fn load_registry_from_path<P: AsRef<Path>>(
    path: P,
) -> Result<HashMap<String, PipelineDefinition>, Box<dyn std::error::Error + Send + Sync>> {
    let content: String = fs::read_to_string(path.as_ref())?;
    let file: PipelinesFile = serde_yaml::from_str(&content)?;
    Ok(file
        .pipelines
        .into_iter()
        .map(|entry| {
            (
                entry.name,
                PipelineDefinition {
                    source: entry.source,
                    transform: entry.transform,
                    sink: entry.sink,
                },
            )
        })
        .collect())
}
/// Lightweight view of a registered pipeline returned by
/// `GET /pipelines/templates`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct PipelineTemplateSummary {
    pub name: String,
    pub source_table: String,
    pub sink_table: String,
    // True when the template defines a transform section.
    pub has_transform: bool,
}
fn build_fetch_body(
source: &SourceConfig,
transform: Option<&TransformConfig>,
force_snake: bool,
) -> Value {
let table_name = source
.view_name
.as_deref()
.unwrap_or(source.table_name.as_str());
let mut body = json!({
"table_name": table_name,
});
if let Some(ref view_name) = source.view_name {
body["view_name"] = json!(view_name);
}
if let Some(ref cols) = source.columns {
body["columns"] = json!(cols);
}
if let Some(ref conditions) = source.conditions {
body["conditions"] = json!(
conditions
.iter()
.map(|c| {
let eq_column = if force_snake {
normalize_column_name(&c.eq_column, true)
} else {
c.eq_column.clone()
};
json!({ "eq_column": eq_column, "eq_value": c.eq_value.clone() })
})
.collect::<Vec<_>>()
);
}
if let Some(limit) = source.limit {
body["limit"] = json!(limit);
}
if let Some(t) = transform {
if let Some(ref g) = t.group_by {
body["group_by"] = json!(if force_snake {
normalize_column_name(g, true)
} else {
g.clone()
});
}
if let Some(ref tg) = t.time_granularity {
body["time_granularity"] = json!(tg);
}
if let Some(ref ac) = t.aggregation_column {
body["aggregation_column"] = json!(if force_snake {
normalize_column_name(ac, true)
} else {
ac.clone()
});
}
if let Some(ref as_) = t.aggregation_strategy {
body["aggregation_strategy"] = json!(as_);
}
if let Some(dedup) = t.aggregation_dedup {
body["aggregation_dedup"] = json!(dedup);
}
}
body
}
/// Merge a pipeline request with the registry into one concrete definition.
///
/// A named template (when requested) forms the base; inline `source`,
/// `transform`, and `sink` sections from the request override it wholesale.
///
/// # Errors
/// Returns a message when the named pipeline is unknown, or when the merged
/// result is missing a source or sink table name.
fn resolve_definition(
    req: &PipelineRequest,
    registry: &HashMap<String, PipelineDefinition>,
) -> Result<PipelineDefinition, String> {
    let base = match req.pipeline.as_deref() {
        Some(name) => registry
            .get(name)
            .cloned()
            .ok_or_else(|| format!("unknown pipeline '{}'", name))?,
        None => PipelineDefinition::default(),
    };
    // Request sections win over the template's.
    let merged = PipelineDefinition {
        source: req.source.clone().unwrap_or(base.source),
        transform: req.transform.clone().or(base.transform),
        sink: req.sink.clone().unwrap_or(base.sink),
    };
    if merged.source.table_name.is_empty() {
        return Err("source.table_name is required".to_string());
    }
    if merged.sink.table_name.is_empty() {
        return Err("sink.table_name is required".to_string());
    }
    Ok(merged)
}
/// Pick the effective client name: a non-blank `override_client` (trimmed)
/// wins, otherwise `fallback` is used.
fn resolve_client_name(override_client: Option<&str>, fallback: &str) -> String {
    match override_client.map(str::trim) {
        Some(value) if !value.is_empty() => value.to_string(),
        _ => fallback.to_string(),
    }
}
fn apply_column_aliases(rows: Vec<Value>, aliases: Option<&HashMap<String, String>>) -> Vec<Value> {
let Some(alias_map) = aliases else {
return rows;
};
if alias_map.is_empty() {
return rows;
}
rows.into_iter()
.map(|row| {
let Some(obj) = row.as_object() else {
return row;
};
let mut rewritten = serde_json::Map::with_capacity(obj.len());
for (key, value) in obj {
let new_key = alias_map.get(key).cloned().unwrap_or_else(|| key.clone());
rewritten.insert(new_key, value.clone());
}
Value::Object(rewritten)
})
.collect()
}
/// Rebuild the incoming request with its `X-Athena-Client` header replaced
/// by `client_name`, copying the method, path, and all other headers.
///
/// NOTE(review): this uses `actix_web::test::TestRequest` on a non-test code
/// path to synthesize the request — unusual; confirm it is intended for
/// production use.
/// NOTE(review): `req.uri().path()` drops any query string from the original
/// URI — verify downstream handlers never read query parameters here.
fn request_for_client(req: &HttpRequest, client_name: &str) -> HttpRequest {
    let mut builder = TestRequest::default()
        .method(req.method().clone())
        .uri(req.uri().path());
    for (name, value) in req.headers() {
        // Skip the caller's client header; the override is inserted below.
        if name.as_str().eq_ignore_ascii_case("x-athena-client") {
            continue;
        }
        builder = builder.insert_header((name.clone(), value.clone()));
    }
    builder
        .insert_header(("X-Athena-Client", client_name))
        .to_http_request()
}
/// Insert a single row into `table_name` on behalf of `client_name`.
///
/// When a PostgreSQL pool is registered for the client, the row goes through
/// the direct sqlx driver; otherwise it falls back to the generic gateway
/// insert path. Errors are rendered to their debug string.
async fn insert_one_row(
    app_state: &Data<AppState>,
    client_name: &str,
    table_name: &str,
    row: &Value,
) -> Result<Value, String> {
    match app_state.pg_registry.get_pool(client_name) {
        Some(pool) => insert_row(&pool, table_name, row)
            .await
            .map_err(|e| format!("{:?}", e)),
        None => insert(table_name.to_string(), row.clone(), client_name)
            .await
            .map(|(value, _)| value)
            .map_err(|e| format!("{:?}", e)),
    }
}
/// Shared executor for the pipeline routes.
///
/// Authorizes the caller (read on source, write on sink), fetches rows from
/// the configured source via the gateway fetch route, applies the optional
/// transform's column aliases, and — unless `dry_run` — inserts the rows
/// into the sink table, logging the operation throughout.
///
/// Responds 400 on missing header/body or an invalid definition, propagates
/// the upstream status when the fetch fails, 500 when the fetch response is
/// not JSON, and 200 with a run summary otherwise. Dry runs echo up to 100
/// would-be rows instead of inserting.
async fn execute_pipeline(
    req: HttpRequest,
    body: Option<Json<Value>>,
    app_state: Data<AppState>,
) -> HttpResponse {
    let operation_start: Instant = Instant::now();
    // The caller's identity comes from the X-Athena-Client header.
    let client_name: String = x_athena_client(&req);
    if client_name.is_empty() {
        return bad_request(
            "Missing required header",
            "X-Athena-Client header is required",
        );
    }
    let body_value: Value = match body {
        Some(b) => b.into_inner(),
        None => {
            return bad_request("Bad request", "Request body is required");
        }
    };
    // `body_value` is not used past this point, so move it (no clone needed).
    let pipeline_req: PipelineRequest = match serde_json::from_value(body_value) {
        Ok(r) => r,
        Err(e) => {
            return bad_request("Invalid pipeline request", e.to_string());
        }
    };
    // Bind the empty fallback map before the match: `None => &HashMap::new()`
    // inside the arm borrows a temporary that is dropped at the end of the
    // statement (match arms get no temporary lifetime extension), which fails
    // to compile with E0716.
    let empty_registry: HashMap<String, PipelineDefinition> = HashMap::new();
    let registry: &HashMap<String, PipelineDefinition> = match &app_state.pipeline_registry {
        Some(r) => r.as_ref(),
        None => &empty_registry,
    };
    let def = match resolve_definition(&pipeline_req, registry) {
        Ok(d) => d,
        Err(e) => return bad_request("Invalid pipeline definition", e),
    };
    // Source and sink may each run as a different client than the caller.
    let source_client: String = resolve_client_name(def.source.client.as_deref(), &client_name);
    let sink_client: String = resolve_client_name(def.sink.client.as_deref(), &client_name);
    let dry_run: bool = pipeline_req.dry_run.unwrap_or(false);
    // The caller needs read rights on the source and write rights on the sink.
    let auth = authorize_gateway_request(
        &req,
        app_state.get_ref(),
        Some(&client_name),
        vec![
            read_right_for_resource(Some(&def.source.table_name)),
            write_right_for_resource(Some(&def.sink.table_name)),
        ],
    )
    .await;
    // Log the request before bailing out, so denied attempts are recorded too.
    let logged_request: LoggedRequest = log_request(
        req.clone(),
        Some(app_state.get_ref()),
        Some(auth.request_id.clone()),
        Some(&auth.log_context),
    );
    if let Some(resp) = auth.response {
        return resp;
    }
    let force_snake: bool = app_state.gateway_force_camel_case_to_snake_case;
    let fetch_body: Value = build_fetch_body(&def.source, def.transform.as_ref(), force_snake);
    // Re-issue the fetch as the (possibly overridden) source client.
    let source_req: HttpRequest = request_for_client(&req, &source_client);
    let fetch_response: HttpResponse =
        handle_fetch_data_route(source_req, Some(Json(fetch_body)), app_state.clone()).await;
    let status: StatusCode = fetch_response.status();
    let body: BoxBody = fetch_response.into_body();
    let response_bytes: Bytes = actix_web::body::to_bytes(body).await.unwrap_or_default();
    if !status.is_success() {
        // Propagate the upstream failure (status and body) to the caller.
        let err_json: Value = serde_json::from_slice(&response_bytes)
            .unwrap_or_else(|_| json!({ "error": "fetch failed" }));
        log_operation_event(
            Some(app_state.get_ref()),
            &logged_request,
            "pipeline_fetch",
            Some(&def.sink.table_name),
            operation_start.elapsed().as_millis(),
            status,
            Some(json!({
                "status": status.as_u16(),
                "error": err_json,
            })),
        );
        return HttpResponse::build(status).json(err_json);
    }
    let response_json: Value = match serde_json::from_slice(&response_bytes) {
        Ok(v) => v,
        Err(e) => {
            error!("Failed to parse fetch response as JSON: {}", e);
            log_operation_event(
                Some(app_state.get_ref()),
                &logged_request,
                "pipeline_parse",
                Some(&def.sink.table_name),
                operation_start.elapsed().as_millis(),
                StatusCode::INTERNAL_SERVER_ERROR,
                Some(json!({
                    "error": e.to_string(),
                })),
            );
            return internal_error("Pipeline fetch response was not valid JSON", e.to_string());
        }
    };
    // Rows live under "data"; a missing/non-array field yields an empty run.
    let rows = response_json
        .get("data")
        .and_then(Value::as_array)
        .cloned()
        .unwrap_or_default();
    let transformed_rows: Vec<Value> = apply_column_aliases(
        rows,
        def.transform
            .as_ref()
            .and_then(|t| t.column_aliases.as_ref()),
    );
    let sink_table: String = def.sink.table_name.clone();
    let mut inserted: Vec<Value> = Vec::with_capacity(transformed_rows.len());
    let mut errors: Vec<Value> = Vec::new();
    if !dry_run {
        // Insert row-by-row; individual failures are collected, not fatal.
        for (i, row) in transformed_rows.iter().enumerate() {
            match insert_one_row(&app_state, &sink_client, &sink_table, row).await {
                Ok(value) => inserted.push(value),
                Err(e) => {
                    errors.push(json!({ "index": i, "error": e }));
                    info!(
                        pipeline_sink = %sink_table,
                        index = i,
                        error = %e,
                        "pipeline insert row failed"
                    );
                }
            }
        }
    }
    info!(
        client = %client_name,
        source_client = %source_client,
        sink_client = %sink_client,
        source = %def.source.table_name,
        sink = %sink_table,
        rows_fetched = transformed_rows.len(),
        rows_inserted = inserted.len(),
        errors = errors.len(),
        dry_run = dry_run,
        "pipeline run finished"
    );
    log_operation_event(
        Some(app_state.get_ref()),
        &logged_request,
        "pipeline",
        Some(&sink_table),
        operation_start.elapsed().as_millis(),
        StatusCode::OK,
        Some(json!({
            "rows_fetched": transformed_rows.len(),
            "rows_inserted": inserted.len(),
            "source_client": source_client.clone(),
            "sink_client": sink_client.clone(),
            "dry_run": dry_run
        })),
    );
    // Dry runs echo a preview (up to 100 rows) instead of inserted values.
    let response_data = if dry_run {
        transformed_rows
            .iter()
            .take(100)
            .cloned()
            .collect::<Vec<_>>()
    } else {
        inserted.clone()
    };
    HttpResponse::Ok().json(json!({
        "data": response_data,
        "rows_fetched": transformed_rows.len(),
        "rows_inserted": inserted.len(),
        "errors": errors,
        "would_insert": if dry_run { json!(transformed_rows.len()) } else { Value::Null },
        "dry_run": dry_run,
        "source_client": source_client,
        "sink_client": sink_client,
    }))
}
/// POST /pipelines — run a pipeline (fetch, transform, insert).
/// Thin route wrapper; all logic lives in `execute_pipeline`.
#[post("/pipelines")]
pub async fn run_pipeline(
    req: HttpRequest,
    body: Option<Json<Value>>,
    app_state: Data<AppState>,
) -> impl Responder {
    execute_pipeline(req, body, app_state).await
}
#[post("/pipelines/simulate")]
pub async fn simulate_pipeline(
req: HttpRequest,
body: Option<Json<Value>>,
app_state: Data<AppState>,
) -> impl Responder {
let mut payload = body.map(|b| b.into_inner()).unwrap_or_else(|| json!({}));
payload["dry_run"] = json!(true);
execute_pipeline(req, Some(Json(payload)), app_state).await
}
/// GET /pipelines/templates — list registered pipeline templates, sorted by
/// name, as lightweight summaries.
#[get("/pipelines/templates")]
pub async fn list_pipeline_templates(app_state: Data<AppState>) -> impl Responder {
    // Bind the empty fallback map before the match: `None => &HashMap::new()`
    // inside the arm borrows a temporary that is dropped at the end of the
    // statement (match arms get no temporary lifetime extension), which fails
    // to compile with E0716.
    let empty_registry: HashMap<String, PipelineDefinition> = HashMap::new();
    let registry: &HashMap<String, PipelineDefinition> = match &app_state.pipeline_registry {
        Some(r) => r.as_ref(),
        None => &empty_registry,
    };
    let mut templates: Vec<PipelineTemplateSummary> = registry
        .iter()
        .map(|(name, def)| PipelineTemplateSummary {
            name: name.clone(),
            source_table: def.source.table_name.clone(),
            sink_table: def.sink.table_name.clone(),
            has_transform: def.transform.is_some(),
        })
        .collect();
    templates.sort_by(|a, b| a.name.cmp(&b.name));
    HttpResponse::Ok().json(json!({ "templates": templates }))
}
#[cfg(test)]
mod tests {
    use super::*;

    // A non-blank override wins over the fallback client.
    #[test]
    fn resolve_client_name_prefers_override() {
        assert_eq!(
            resolve_client_name(Some("analytics_client"), "default_client"),
            "analytics_client"
        );
    }

    // A whitespace-only override is treated as absent.
    #[test]
    fn resolve_client_name_falls_back_to_header_client() {
        assert_eq!(
            resolve_client_name(Some(" "), "default_client"),
            "default_client"
        );
    }

    // Aliased keys are renamed; the original keys disappear.
    #[test]
    fn apply_column_aliases_rewrites_object_keys() {
        let aliases = HashMap::from([
            ("email".to_string(), "user_email".to_string()),
            ("created_at".to_string(), "createdAt".to_string()),
        ]);
        let input = vec![json!({
            "email": "test@example.com",
            "created_at": "2026-01-01"
        })];
        let output = apply_column_aliases(input, Some(&aliases));
        let row = output.first().and_then(Value::as_object).unwrap();
        for renamed in ["user_email", "createdAt"] {
            assert!(row.contains_key(renamed));
        }
        for original in ["email", "created_at"] {
            assert!(!row.contains_key(original));
        }
    }
}