// athena_rs 0.75.4
//
// WIP Database API gateway — module documentation follows below.
//! Config-driven pipeline API: source → transform → sink.
//!
//! Pipelines are defined in the request (inline or by reference to a prebuilt name)
//! and reuse the gateway fetch/insert schemas and x-athena-client routing.

use actix_web::web::Data;
use actix_web::{post, HttpRequest, HttpResponse, Responder};
use actix_web::web::Json;
use serde::Deserialize;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use tracing::{error, info};

use crate::api::gateway::fetch::handle_fetch_data_route;
use crate::api::gateway::insert::insert;
use crate::api::headers::x_athena_client::x_athena_client;
use crate::AppState;
use crate::drivers::postgresql::sqlx_driver::insert_row;
use crate::utils::format::normalize_column_name;
use crate::utils::request_logging::log_request;

/// Source stage: same shape as gateway fetch (table/view, columns, conditions).
///
/// `#[serde(default)]` makes every field optional in the payload; a missing
/// `table_name` deserializes to `""` and is rejected later by
/// `resolve_definition`.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(rename_all = "snake_case", default)]
pub struct SourceConfig {
    // Base table to fetch from; required (validated in `resolve_definition`).
    pub table_name: String,
    // Optional view; when set, `build_fetch_body` uses it as the effective relation.
    pub view_name: Option<String>,
    // Columns to select; `None` leaves the choice to the gateway fetch.
    pub columns: Option<Vec<String>>,
    // Equality filters forwarded to the gateway fetch.
    pub conditions: Option<Vec<ConditionEntry>>,
    // Row limit forwarded to the gateway fetch.
    pub limit: Option<u64>,
}

/// Single equality condition for gateway-style queries.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ConditionEntry {
    // Column to compare; may be snake_case-normalized in `build_fetch_body`.
    pub eq_column: String,
    // Value to compare against. Accepts a JSON string, number, or bool in the
    // payload and coerces it to its string form (see `deserialize_eq_value`).
    #[serde(deserialize_with = "deserialize_eq_value")]
    pub eq_value: String,
}

/// Deserializes `eq_value` from any scalar JSON value into its string form,
/// so clients may send `"abc"`, `42`, or `true` interchangeably.
fn deserialize_eq_value<'de, D>(deserializer: D) -> Result<String, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let v = Value::deserialize(deserializer)?;
    Ok(match v {
        // A JSON string must be taken verbatim: `Value`'s Display impl would
        // re-serialize it with surrounding quotes.
        Value::String(s) => s,
        // Every other variant already renders its canonical JSON text via
        // Display ("true"/"false" for bools, the digits for numbers, "null",
        // ...), which is exactly what the previous per-variant branches built.
        other => other.to_string(),
    })
}

/// Transform stage: maps to PostProcessingConfig (group_by, aggregation, etc.).
///
/// All fields are optional; each field that is present is copied into the
/// gateway fetch body by `build_fetch_body`.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(rename_all = "snake_case", default)]
pub struct TransformConfig {
    // Column to group by; snake_case-normalized when the gateway forces it.
    pub group_by: Option<String>,
    // Time bucket for grouping, e.g. "day"; passed through verbatim.
    pub time_granularity: Option<String>,
    // Column to aggregate; snake_case-normalized when the gateway forces it.
    pub aggregation_column: Option<String>,
    // Aggregation strategy name, e.g. "cumulative_sum"; passed through verbatim.
    pub aggregation_strategy: Option<String>,
    // Whether to de-duplicate before aggregating; passed through verbatim.
    pub aggregation_dedup: Option<bool>,
}

/// Sink stage: target table for writes (same as gateway insert target).
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(rename_all = "snake_case", default)]
pub struct SinkConfig {
    // Destination table; required (validated in `resolve_definition`).
    pub table_name: String,
}

/// Full pipeline definition (registry entry or inline in request).
///
/// `transform` is optional: a pipeline may be a straight source → sink copy.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(rename_all = "snake_case", default)]
pub struct PipelineDefinition {
    pub source: SourceConfig,
    pub transform: Option<TransformConfig>,
    pub sink: SinkConfig,
}

/// Request body: reference a prebuilt pipeline and/or supply inline overrides.
///
/// Overrides are whole-stage: a `source`/`transform`/`sink` given here replaces
/// the corresponding stage of the named pipeline entirely (no field-level merge;
/// see `resolve_definition`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct PipelineRequest {
    /// Reference a prebuilt pipeline by name (merged with overrides below).
    pub pipeline: Option<String>,
    pub source: Option<SourceConfig>,
    pub transform: Option<TransformConfig>,
    pub sink: Option<SinkConfig>,
}

/// Entry in the pipelines YAML file (name + definition).
///
/// Same shape as `PipelineDefinition` plus the registry key `name`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
struct PipelineEntry {
    name: String,
    source: SourceConfig,
    transform: Option<TransformConfig>,
    sink: SinkConfig,
}

/// Root of the pipelines config file.
#[derive(Debug, Deserialize)]
struct PipelinesFile {
    // Top-level `pipelines:` list in the YAML file.
    pipelines: Vec<PipelineEntry>,
}

/// Loads pipeline definitions from a YAML file into a map by name.
///
/// File format: `pipelines: [ { name: "x", source: {...}, transform?: {...}, sink: {...} }, ... ]`
///
/// Returns an I/O error if the file cannot be read, or a parse error if the
/// YAML does not match the expected shape.
pub fn load_registry_from_path<P: AsRef<Path>>(
    path: P,
) -> Result<HashMap<String, PipelineDefinition>, Box<dyn std::error::Error + Send + Sync>> {
    let raw = fs::read_to_string(path.as_ref())?;
    let parsed: PipelinesFile = serde_yaml::from_str(&raw)?;
    // Index entries by name; a later duplicate name overwrites an earlier one,
    // exactly as repeated `insert` calls would.
    Ok(parsed
        .pipelines
        .into_iter()
        .map(|entry| {
            (
                entry.name,
                PipelineDefinition {
                    source: entry.source,
                    transform: entry.transform,
                    sink: entry.sink,
                },
            )
        })
        .collect())
}

/// Builds a gateway fetch JSON body from source + transform (and optional force_snake).
fn build_fetch_body(
    source: &SourceConfig,
    transform: Option<&TransformConfig>,
    force_snake: bool,
) -> Value {
    let table_name = source
        .view_name
        .as_deref()
        .unwrap_or(source.table_name.as_str());
    let mut body = json!({
        "table_name": table_name,
    });
    if let Some(ref view_name) = source.view_name {
        body["view_name"] = json!(view_name);
    }
    if let Some(ref cols) = source.columns {
        body["columns"] = json!(cols);
    }
    if let Some(ref conditions) = source.conditions {
        body["conditions"] = json!(conditions
            .iter()
            .map(|c| {
                let eq_column = if force_snake {
                    normalize_column_name(&c.eq_column, true)
                } else {
                    c.eq_column.clone()
                };
                json!({ "eq_column": eq_column, "eq_value": c.eq_value.clone() })
            })
            .collect::<Vec<_>>());
    }
    if let Some(limit) = source.limit {
        body["limit"] = json!(limit);
    }
    if let Some(t) = transform {
        if let Some(ref g) = t.group_by {
            body["group_by"] = json!(if force_snake {
                normalize_column_name(g, true)
            } else {
                g.clone()
            });
        }
        if let Some(ref tg) = t.time_granularity {
            body["time_granularity"] = json!(tg);
        }
        if let Some(ref ac) = t.aggregation_column {
            body["aggregation_column"] = json!(if force_snake {
                normalize_column_name(ac, true)
            } else {
                ac.clone()
            });
        }
        if let Some(ref as_) = t.aggregation_strategy {
            body["aggregation_strategy"] = json!(as_);
        }
        if let Some(dedup) = t.aggregation_dedup {
            body["aggregation_dedup"] = json!(dedup);
        }
    }
    body
}

/// Resolves the effective pipeline definition from request and registry.
///
/// Precedence: inline request stages replace the corresponding stage of the
/// prebuilt definition wholesale. With no `pipeline` name, the base is an
/// empty default, so the request itself must carry valid stages.
///
/// Errors with a human-readable message on an unknown pipeline name or a
/// missing source/sink table.
fn resolve_definition(
    req: &PipelineRequest,
    registry: &HashMap<String, PipelineDefinition>,
) -> Result<PipelineDefinition, String> {
    let base = match req.pipeline.as_ref() {
        Some(name) => registry
            .get(name)
            .cloned()
            .ok_or_else(|| format!("unknown pipeline '{}'", name))?,
        None => PipelineDefinition::default(),
    };

    let resolved = PipelineDefinition {
        source: req.source.clone().unwrap_or(base.source),
        transform: req.transform.clone().or(base.transform),
        sink: req.sink.clone().unwrap_or(base.sink),
    };

    // Both endpoints of the pipeline are mandatory.
    if resolved.source.table_name.is_empty() {
        return Err("source.table_name is required".to_string());
    }
    if resolved.sink.table_name.is_empty() {
        return Err("sink.table_name is required".to_string());
    }
    Ok(resolved)
}

/// Inserts one row into the sink using the same client routing as the gateway.
///
/// A client with a registered Postgres pool takes the direct `insert_row`
/// path; any other client falls back to the generic gateway `insert`.
/// Errors are flattened to their Debug text so the caller can report them
/// per-row.
async fn insert_one_row(
    app_state: &Data<AppState>,
    client_name: &str,
    table_name: &str,
    row: &Value,
) -> Result<Value, String> {
    match app_state.pg_registry.get_pool(client_name) {
        Some(pool) => insert_row(&pool, table_name, row)
            .await
            .map_err(|e| format!("{:?}", e)),
        None => insert(table_name.to_string(), row.clone(), client_name)
            .await
            .map(|(value, _status)| value)
            .map_err(|e| format!("{:?}", e)),
    }
}

/// POST /pipelines: run a pipeline (inline or by reference) with x-athena-client routing.
///
/// **Headers:** `X-Athena-Client` is required (same as gateway; routes to Postgres or Supabase).
///
/// **Request body (inline):**
/// ```json
/// {
///   "source": {
///     "table_name": "users",
///     "view_name": null,
///     "columns": ["id", "email"],
///     "conditions": [{ "eq_column": "workspace_id", "eq_value": "abc" }],
///     "limit": 100
///   },
///   "transform": {
///     "group_by": "created_at",
///     "time_granularity": "day",
///     "aggregation_column": "total",
///     "aggregation_strategy": "cumulative_sum",
///     "aggregation_dedup": false
///   },
///   "sink": { "table_name": "users_backup" }
/// }
/// ```
///
/// **Request body (prebuilt reference):** `{ "pipeline": "example_copy" }` with optional
/// `source`, `transform`, `sink` overrides. Prebuilt definitions are loaded from
/// `config/pipelines.yaml`.
///
/// **Response:** `{ "data": [...], "rows_fetched": N, "rows_inserted": N, "errors": [] }`.
#[post("/pipelines")]
pub async fn run_pipeline(
    req: HttpRequest,
    body: Option<Json<Value>>,
    app_state: Data<AppState>,
) -> impl Responder {
    log_request(req.clone());
    let client_name = x_athena_client(&req);
    if client_name.is_empty() {
        return HttpResponse::BadRequest().json(json!({
            "error": "X-Athena-Client header is required"
        }));
    }

    let body_value = match body {
        Some(b) => b.into_inner(),
        None => {
            return HttpResponse::BadRequest().json(json!({
                "error": "Request body is required"
            }));
        }
    };

    let pipeline_req: PipelineRequest = match serde_json::from_value(body_value.clone()) {
        Ok(r) => r,
        Err(e) => {
            return HttpResponse::BadRequest().json(json!({
                "error": "Invalid pipeline request",
                "details": e.to_string()
            }));
        }
    };

    let registry = match &app_state.pipeline_registry {
        Some(r) => r.as_ref(),
        None => &HashMap::new(),
    };
    let def = match resolve_definition(&pipeline_req, registry) {
        Ok(d) => d,
        Err(e) => {
            return HttpResponse::BadRequest().json(json!({
                "error": e
            }));
        }
    };

    let force_snake = app_state.gateway_force_camel_case_to_snake_case;
    let fetch_body = build_fetch_body(&def.source, def.transform.as_ref(), force_snake);
    let fetch_response =
        handle_fetch_data_route(req.clone(), Some(Json(fetch_body)), app_state.clone()).await;

    let status = fetch_response.status();
    let body = fetch_response.into_body();
    let response_bytes = actix_web::body::to_bytes(body)
        .await
        .unwrap_or_default();
    if !status.is_success() {
        let err_json: Value = serde_json::from_slice(&response_bytes)
            .unwrap_or_else(|_| json!({ "error": "fetch failed" }));
        return HttpResponse::build(status).json(err_json);
    }
    let response_json: Value = match serde_json::from_slice(&response_bytes) {
        Ok(v) => v,
        Err(e) => {
            error!("Failed to parse fetch response as JSON: {}", e);
            return HttpResponse::InternalServerError().json(json!({
                "error": "Pipeline fetch response was not valid JSON",
                "details": e.to_string()
            }));
        }
    };

    let rows = response_json
        .get("data")
        .and_then(Value::as_array)
        .cloned()
        .unwrap_or_default();

    let sink_table = def.sink.table_name.clone();
    let mut inserted = Vec::with_capacity(rows.len());
    let mut errors = Vec::new();

    for (i, row) in rows.iter().enumerate() {
        match insert_one_row(&app_state, &client_name, &sink_table, row).await {
            Ok(value) => inserted.push(value),
            Err(e) => {
                errors.push(json!({ "index": i, "error": e }));
                info!(
                    pipeline_sink = %sink_table,
                    index = i,
                    error = %e,
                    "pipeline insert row failed"
                );
            }
        }
    }

    info!(
        client = %client_name,
        source = %def.source.table_name,
        sink = %sink_table,
        rows_fetched = rows.len(),
        rows_inserted = inserted.len(),
        errors = errors.len(),
        "pipeline run finished"
    );

    HttpResponse::Ok().json(json!({
        "data": inserted,
        "rows_fetched": rows.len(),
        "rows_inserted": inserted.len(),
        "errors": errors,
    }))
}