//! athena_rs 3.3.0 — Database gateway API documentation.
//!
//! Accumulates Postgres `/gateway/insert` work and flushes on deadlines for batching.

use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::OnceLock;

use actix_web::web::Data;
use serde_json::Value;
use tokio::sync::{Mutex, Notify, oneshot};

use crate::AppState;
use crate::drivers::postgresql::sqlx_driver::insert_rows_bulk;

/// One insert request waiting for window flush.
pub struct WindowInsertJob {
    // Caller identity / scoping fields; consumed by the single-insert path
    // in the parent module, not inspected in this module.
    pub trace_id: String,
    pub user_id: String,
    pub company_id: String,
    pub organization_id: String,
    // Optional metadata-derived overrides of the identity fields above —
    // NOTE(review): semantics not visible here; confirm against the
    // `/gateway/insert` handler.
    pub metadata_user_id: Option<String>,
    pub metadata_company_id: Option<String>,
    pub metadata_organization_id: Option<String>,
    /// Full request body; inspected for an `update_body` key when deciding
    /// merge eligibility (see `MergeLaneKey::for_job`).
    pub body: Value,
    /// Target table; part of the merge-lane key and the denylist check.
    pub table_name: String,
    /// Row payload actually inserted; its sorted key set forms the lane's
    /// column fingerprint, and it is sent verbatim to the bulk insert.
    pub insert_body: Value,
    pub resource_id_key: String,
    /// Selects the Postgres pool and scopes the merge lane per client.
    pub client_name: String,
    /// When false, never merged with other rows (e.g. `update_body` present).
    pub merge_eligible: bool,
    // Request-logging context carried through to the per-job insert path.
    pub logged_request_id: String,
    pub logged_client_name: String,
    pub logged_method: String,
    pub logged_path: String,
    pub operation_start: std::time::Instant,
    pub verbose_logging: bool,
    pub ansi_enabled: bool,
    pub x_publish_event: bool,
    pub resolved_company_for_event: Option<String>,
    /// Deadline at (or after) which the worker must flush this job.
    pub due: tokio::time::Instant,
}

/// Grouping key: due jobs that share a lane may be flushed together in one
/// bulk insert (when `merge_eligible` is true).
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
struct MergeLaneKey {
    client_name: String,
    table_name: String,
    /// Sorted insert-body column names joined with `\u{1f}` (unit separator);
    /// empty when the body is not a JSON object.
    column_fingerprint: String,
    /// False forces per-job flushing even when the other lane fields match.
    merge_eligible: bool,
}

impl MergeLaneKey {
    fn for_job(job: &WindowInsertJob, deny_tables: &HashSet<String>) -> Self {
        let merge_eligible = job.merge_eligible
            && !super::insert_request_has_update_body(&job.body)
            && !table_in_denylist(&job.table_name, deny_tables);
        Self {
            client_name: job.client_name.clone(),
            table_name: job.table_name.clone(),
            column_fingerprint: insert_body_fingerprint(&job.insert_body),
            merge_eligible,
        }
    }
}

/// Produces a stable fingerprint of an insert body's column set: the JSON
/// object's keys, sorted and joined with the ASCII unit separator. Any
/// non-object body fingerprints to the empty string.
fn insert_body_fingerprint(body: &Value) -> String {
    match body.as_object() {
        Some(obj) => {
            let mut columns: Vec<&str> = obj.keys().map(String::as_str).collect();
            columns.sort_unstable();
            columns.join("\u{001f}")
        }
        None => String::new(),
    }
}

/// Returns whether `table` is blocked from merging. The trimmed name matches
/// either as-is or, for schema-qualified names, by the segment after the
/// final `.` (so `public.items` matches a denylist entry of `items`).
fn table_in_denylist(table: &str, deny: &HashSet<String>) -> bool {
    if deny.is_empty() {
        return false;
    }
    let name = table.trim();
    if deny.contains(name) {
        return true;
    }
    match name.rsplit_once('.') {
        Some((_, unqualified)) => deny.contains(unqualified),
        None => false,
    }
}

/// A submitted job plus the oneshot channel used to answer the waiting
/// request handler once the job is flushed.
struct Queued {
    /// Lane computed at submit time (so it is hashed only once per job).
    lane: MergeLaneKey,
    job: WindowInsertJob,
    response_tx: oneshot::Sender<super::WindowInsertOutcome>,
}

/// Tunables for the insert window (from config).
#[derive(Clone)]
pub struct InsertWindowSettings {
    /// Maximum rows per bulk INSERT chunk.
    pub max_batch: usize,
    /// Queue capacity; submissions beyond this bypass the window and run
    /// as ordinary single inserts.
    pub max_queued: usize,
    /// Tables that must never be merged; matched by short or
    /// schema-qualified name (see `table_in_denylist`).
    pub deny_tables: HashSet<String>,
}

/// State shared between the coordinator handle and the worker task.
struct Inner {
    settings: InsertWindowSettings,
    /// Pending jobs, guarded by an async mutex (held across `.await` only
    /// for the lock acquisition itself).
    state: Mutex<State>,
    /// Wakes the worker when new work arrives so it can recompute the
    /// earliest flush deadline.
    notify: Notify,
    /// Set from bootstrap after `Data<AppState>` exists (breaks circular init).
    app_state: OnceLock<Data<AppState>>,
}

/// Mutex-guarded queue contents.
struct State {
    /// Jobs not yet flushed, in submission order.
    pending: Vec<Queued>,
}

/// Background flush worker + bounded queue for insert windowing.
pub struct InsertWindowCoordinator {
    /// Shared with the spawned worker task via `Arc`.
    inner: Arc<Inner>,
}

impl InsertWindowCoordinator {
    /// Creates the coordinator with the given tunables. The flush worker is
    /// not started here; it is spawned by `bind_app_state` once an
    /// `AppState` exists.
    pub fn new(settings: InsertWindowSettings) -> Arc<Self> {
        let inner: Arc<Inner> = Arc::new(Inner {
            settings,
            state: Mutex::new(State {
                pending: Vec::new(),
            }),
            notify: Notify::new(),
            app_state: OnceLock::new(),
        });
        Arc::new(Self { inner })
    }

    /// Binds the app state exactly once and spawns the background flush
    /// worker. Later calls are no-ops, so at most one worker task ever runs.
    pub fn bind_app_state(&self, app_state: Data<AppState>) {
        if self.inner.app_state.set(app_state).is_err() {
            // Already bound: the first call spawned the worker.
            return;
        }
        let inner: Arc<Inner> = self.inner.clone();
        tokio::spawn(worker_loop(inner));
    }

    /// Enqueues a job for windowed flushing; the outcome arrives later on
    /// `response_tx`. When the queue is at capacity the job bypasses the
    /// window and runs inline (recorded as `fallback_queue_full`), or fails
    /// fast if the coordinator was never initialized.
    pub(crate) async fn submit(
        &self,
        job: WindowInsertJob,
        response_tx: oneshot::Sender<super::WindowInsertOutcome>,
    ) {
        let lane: MergeLaneKey = MergeLaneKey::for_job(&job, &self.inner.settings.deny_tables);
        {
            let mut st: tokio::sync::MutexGuard<'_, State> = self.inner.state.lock().await;
            if st.pending.len() < self.inner.settings.max_queued {
                st.pending.push(Queued {
                    lane,
                    job,
                    response_tx,
                });
                // Release the lock before waking the worker so it can take
                // the mutex immediately to recompute the earliest deadline.
                drop(st);
                self.inner.notify.notify_one();
                return;
            }
        }

        // Queue full: run the insert inline rather than dropping the request.
        let Some(app) = self.inner.app_state.get().cloned() else {
            let _ = response_tx.send(super::window_insert_internal_error(
                "Insert window coordinator is not initialized",
            ));
            return;
        };
        app.metrics_state
            .record_gateway_insert_window_event("fallback_queue_full");
        let outcome: super::WindowInsertOutcome =
            super::run_postgres_insert_to_outcome(app, job).await;
        let _ = response_tx.send(outcome);
    }
}

/// Background task: sleeps until the earliest pending deadline (or a queue
/// notification), then drains every due job and flushes the batch grouped
/// by merge lane.
async fn worker_loop(inner: Arc<Inner>) {
    loop {
        // The task is only spawned after `app_state` is bound, so this
        // normally succeeds immediately; the retry sleep is defensive.
        let Some(app) = inner.app_state.get().cloned() else {
            tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
            continue;
        };

        // Earliest deadline among pending jobs, if any.
        let next_deadline: Option<tokio::time::Instant> = {
            let st = inner.state.lock().await;
            st.pending.iter().map(|q| q.job.due).min()
        };
        let Some(deadline) = next_deadline else {
            // Nothing queued: park until `submit` signals new work. `Notify`
            // stores one permit, so a signal sent just before this await is
            // not lost.
            inner.notify.notified().await;
            continue;
        };
        tokio::select! {
            _ = tokio::time::sleep_until(deadline) => {}
            // New work may carry an earlier deadline; loop to recompute.
            _ = inner.notify.notified() => { continue; }
        }

        // Split pending into due-now and not-yet-due under one lock hold.
        let due_batch: Vec<Queued> = {
            let mut st = inner.state.lock().await;
            let now = tokio::time::Instant::now();
            let (ready, remain): (Vec<Queued>, Vec<Queued>) =
                st.pending.drain(..).partition(|q| q.job.due <= now);
            st.pending = remain;
            ready
        };

        if due_batch.is_empty() {
            continue;
        }

        // Group due jobs by merge lane so compatible rows batch together.
        let mut by_lane: HashMap<MergeLaneKey, Vec<Queued>> = HashMap::new();
        for q in due_batch {
            by_lane.entry(q.lane.clone()).or_default().push(q);
        }

        // Borrow the settings instead of cloning the deny-table set on
        // every flush iteration.
        for (lane, group) in by_lane {
            process_lane_group(app.clone(), lane, group, &inner.settings).await;
        }
    }
}

/// Flushes one lane's worth of due jobs. Mergeable groups of two or more go
/// through the bulk-insert path in `max_batch`-sized chunks; singles,
/// non-mergeable lanes, missing pools, and every error path fall back to
/// per-job inserts via `flush_individually`.
async fn process_lane_group(
    app: Data<AppState>,
    lane: MergeLaneKey,
    mut group: Vec<Queued>,
    settings: &InsertWindowSettings,
) {
    // Singles and non-mergeable lanes take the per-job path directly.
    if group.len() == 1 || !lane.merge_eligible {
        app.metrics_state
            .record_gateway_insert_window_event("flush_single");
        flush_individually(&app, group).await;
        return;
    }

    let Some(pool) = app.pg_registry.get_pool(&lane.client_name) else {
        // No pool registered for this client: run each job on its own
        // (non-bulk) path, which performs its own connection handling.
        app.metrics_state
            .record_gateway_insert_window_event("flush_single_no_pool");
        flush_individually(&app, group).await;
        return;
    };

    while !group.is_empty() {
        let chunk_len = settings.max_batch.min(group.len());
        let chunk: Vec<Queued> = group.drain(..chunk_len).collect();
        let payloads: Vec<Value> = chunk.iter().map(|q| q.job.insert_body.clone()).collect();
        match insert_rows_bulk(&pool, &lane.table_name, &payloads).await {
            Ok(rows) if rows.len() != chunk.len() => {
                // NOTE(review): the bulk statement reported success here, so
                // re-running each job individually may insert rows twice —
                // confirm `insert_rows_bulk` semantics for partial results.
                app.metrics_state
                    .record_gateway_insert_window_event("bulk_row_count_mismatch");
                flush_individually(&app, chunk).await;
            }
            Ok(rows) => {
                app.metrics_state
                    .record_gateway_insert_window_event("flush_bulk");
                // One scoped invalidation covers the whole chunk; failures
                // are deliberately ignored (best-effort).
                if rows
                    .iter()
                    .any(|row| super::should_invalidate_cache_after_insert(row))
                {
                    let _ = crate::api::cache::invalidation::invalidate_scoped_gateway_cache(
                        app.clone(),
                        &lane.client_name,
                        &lane.table_name,
                    )
                    .await;
                }
                app.metrics_state
                    .record_gateway_postgres_backend("/gateway/insert", "sqlx");
                // `rows` came back in payload order, so zip pairs each job
                // with its inserted row.
                for (q, row) in chunk.into_iter().zip(rows) {
                    let body: Value =
                        super::finish_postgres_insert_success_json(app.clone(), &q.job, row).await;
                    let _ = q
                        .response_tx
                        .send(super::WindowInsertOutcome::Success(body));
                }
            }
            Err(_err) => {
                // Bulk SQL failed before any row was confirmed; retry each
                // job through the ordinary single-insert path.
                app.metrics_state
                    .record_gateway_insert_window_event("bulk_fallback_sql_error");
                flush_individually(&app, chunk).await;
            }
        }
    }
}

/// Runs each queued job through the ordinary single-insert path and replies
/// on its oneshot channel. Send errors mean the caller gave up waiting and
/// are ignored.
async fn flush_individually(app: &Data<AppState>, jobs: Vec<Queued>) {
    for q in jobs {
        let outcome: super::WindowInsertOutcome =
            super::run_postgres_insert_to_outcome(app.clone(), q.job).await;
        let _ = q.response_tx.send(outcome);
    }
}

#[cfg(test)]
mod tests {
    use super::{insert_body_fingerprint, table_in_denylist};
    use serde_json::json;
    use std::collections::HashSet;

    /// Keys of an object body come back sorted, `\u{1f}`-joined.
    #[test]
    fn insert_body_fingerprint_sorts_object_keys() {
        let fingerprint = insert_body_fingerprint(&json!({ "b": 1, "a": 2 }));
        assert_eq!(fingerprint, "a\u{001f}b");
    }

    /// Non-object bodies (here: an array) fingerprint to the empty string.
    #[test]
    fn insert_body_fingerprint_non_object_empty() {
        assert_eq!(insert_body_fingerprint(&json!([])), "");
    }

    /// Denylist entries match both bare and schema-qualified table names.
    #[test]
    fn denylist_matches_short_name_and_schema_qualified() {
        let deny: HashSet<String> = ["items".to_string()].into_iter().collect();
        assert!(table_in_denylist("items", &deny));
        assert!(table_in_denylist("public.items", &deny));
        assert!(!table_in_denylist("other", &deny));
    }
}