use std::sync::Arc;
use async_trait::async_trait;
use dataflow_rs::engine::functions::AsyncFunctionHandler;
use dataflow_rs::engine::task_context::TaskContext;
use dataflow_rs::engine::task_outcome::TaskOutcome;
use serde_json::Value;
use super::connector_helpers::{
apply_output, bind_json_params, extract_output_path, profile_handler, require_db_connector,
require_str_field, resolve_connector, timed_query, to_exec_error,
};
use crate::connector::ConnectorRegistry;
use crate::connector::pool_cache::SqlPoolCache;
/// Task handler that runs a single SQL write statement against a named database
/// connector and reports how many rows it affected.
pub struct DbWriteHandler {
    /// Shared cache of SQL connection pools; `execute` asks it for a pool by
    /// connector name plus that connector's DB config.
    pub pool_cache: Arc<SqlPoolCache>,
    /// Shared registry used to resolve a connector's configuration from its name.
    pub registry: Arc<ConnectorRegistry>,
}
#[async_trait]
impl AsyncFunctionHandler for DbWriteHandler {
type Input = Value;
async fn execute(
&self,
ctx: &mut TaskContext<'_>,
input: &Value,
) -> dataflow_rs::Result<TaskOutcome> {
profile_handler("db_write", input, async move {
let connector_name = require_str_field(input, "connector", "db_write")?;
let query = require_str_field(input, "query", "db_write")?;
let params = input.get("params").and_then(|v| v.as_array());
let connector_config = resolve_connector(&self.registry, connector_name).await?;
let db_config = require_db_connector(&connector_config, connector_name)?;
let pool = self
.pool_cache
.get_pool(connector_name, db_config)
.await
.map_err(to_exec_error)?;
let sqlx_query = sqlx::query(query);
let sqlx_query = if let Some(params) = params {
bind_json_params(sqlx_query, params)
} else {
sqlx_query
};
let result = timed_query(
db_config.query_timeout_ms,
"db_write",
sqlx_query.execute(&pool),
)
.await?;
let output = serde_json::json!({
"rows_affected": result.rows_affected(),
});
let output_path = extract_output_path(input);
apply_output(ctx, output_path, output);
Ok(TaskOutcome::Success)
})
.await
}
}