pub mod maybe_tls;
mod socket_addr;
use std::{
fmt::Write,
future::Future,
io,
net::{IpAddr, SocketAddr, SocketAddrV4},
};
use flume::{Receiver, Sender};
use futures_core::future::BoxFuture;
use hyper::server::conn::http1::Connection;
use semver::Version;
use sqlx_core::{
net::WithSocket,
sql_str::{AssertSqlSafe, SqlSafeStr},
};
use crate::{
connection::websocket::{
future::{ExaRoundtrip, GetHosts, WebSocketFuture},
request::Execute,
socket::{ExaSocket, WithExaSocket},
},
etl::{
job::{maybe_tls::WithMaybeTlsSocketMaker, socket_addr::WithSocketAddr},
query::ExecuteEtl,
server::{OneShotService, WithHttpServer},
EtlQuery,
},
ExaConnection, SqlxResult,
};
pub type ServerBootstrap = BoxFuture<'static, io::Result<()>>;
pub type OneShotServer<S> = Connection<ExaSocket, S>;
/// An ETL job that exchanges data with the database through locally driven HTTP(S) servers.
///
/// The default methods implement the shared plumbing:
/// - [`EtlJob::build_job`] resolves the database hosts and derives the TLS, compression and
///   public-key settings from the connection.
/// - [`EtlJob::connect`] opens one TCP connection per worker and wraps each in a one-shot HTTP
///   server.
/// - [`EtlJob::build_job`] then creates the future that executes the SQL produced by
///   [`EtlJob::query`].
///
/// Implementors supply the job-specific pieces: the worker type, the HTTP service, the data
/// pipe handed between them and the SQL text.
pub trait EtlJob: Sized + Send + Sync {
    /// Default buffer size of 64 KiB.
    ///
    /// Not used in this file; presumably a byte count for worker I/O (confirm with
    /// implementors).
    const DEFAULT_BUF_SIZE: usize = 65536;

    /// File extension used by [`EtlJob::append_files`] when compression is enabled.
    const GZ_FILE_EXT: &'static str = "gz";

    /// File extension used by [`EtlJob::append_files`] when compression is disabled.
    const CSV_FILE_EXT: &'static str = "csv";

    /// URL scheme used by [`EtlJob::append_files`] when TLS is disabled.
    const HTTP_SCHEME: &'static str = "http";

    /// URL scheme used by [`EtlJob::append_files`] when TLS is enabled.
    const HTTPS_SCHEME: &'static str = "https";

    /// Minimum database release version for which a public key is requested.
    ///
    /// [`EtlJob::build_job`] compares the session's release version against it with `>=`. The
    /// key is then emitted as a `PUBLIC KEY` clause.
    const PK_FP_VER: Version = Version::new(8, 32, 0);

    /// Job-type name used as the prefix of every generated file name (`{JOB_TYPE}_00000.<ext>`).
    const JOB_TYPE: &'static str;

    /// Per-connection worker returned from [`EtlJob::build_job`].
    type Worker: Send;

    /// One-shot HTTP service served on each connection.
    type Service: OneShotService;

    /// Value passed from the HTTP side to a worker.
    ///
    /// Services are given the sending half of a `DataPipe` channel. Workers receive
    /// `(DataPipe, OneShotServer)` pairs. The pairing is presumably done by `WithHttpServer`,
    /// which holds both the data-pipe receiver and the parts sender.
    type DataPipe: Send + 'static;

    /// Requested compression setting.
    ///
    /// `None` defers to the connection's `compression_enabled()` attribute.
    fn use_compression(&self) -> Option<bool>;

    /// Requested number of workers.
    ///
    /// `0` means one worker per database host. Larger values are effectively capped at the
    /// number of hosts, because [`EtlJob::connect`] opens at most one connection per host.
    fn num_workers(&self) -> usize;

    /// Creates a worker that receives `(DataPipe, OneShotServer)` pairs from `parts_rx`.
    ///
    /// `with_compression` is the resolved compression setting for the job.
    fn create_worker(
        &self,
        parts_rx: Receiver<(Self::DataPipe, OneShotServer<Self::Service>)>,
        with_compression: bool,
    ) -> Self::Worker;

    /// Creates the HTTP service for one connection.
    ///
    /// `chan_tx` is the sending half of the data-pipe channel.
    fn create_service(&self, chan_tx: Sender<Self::DataPipe>) -> Self::Service;

    /// Opens one connection per worker and builds the matching servers and workers.
    ///
    /// For each of the first `n` entries of `ips` it connects to `(ip, port)`, where `n` is
    /// [`EtlJob::num_workers`] capped at `ips.len()` (or `ips.len()` when that returns 0).
    /// Each socket is wrapped through `wsm` (plain or TLS) and then in a one-shot HTTP server,
    /// and a worker is created for it. All services and workers share the same two zero-capacity
    /// (rendezvous) channels.
    ///
    /// Returns, per connection, the reported socket address, the worker and the server future.
    ///
    /// # Errors
    /// Fails if a TCP connection cannot be established or its socket setup returns an error.
    /// Everything created before the failure is dropped.
    fn connect(
        &self,
        wsm: WithMaybeTlsSocketMaker,
        ips: Vec<IpAddr>,
        port: u16,
        with_compression: bool,
    ) -> impl Future<Output = SqlxResult<JobComponents<Self::Worker>>> + Send {
        async move {
            // At most one connection per host is opened. Capping here keeps the
            // preallocations below in line with what the loop actually pushes.
            let num = match self.num_workers() {
                0 => ips.len(),
                n => n.min(ips.len()),
            };

            let (parts_tx, parts_rx) = flume::bounded(0);
            let (chan_tx, chan_rx) = flume::bounded(0);

            let mut addrs = Vec::with_capacity(num);
            let mut workers = Vec::with_capacity(num);
            let mut conn_futures = Vec::with_capacity(num);

            for ip in ips.into_iter().take(num) {
                let service = self.create_service(chan_tx.clone());
                let with_exa_socket = WithExaSocket(SocketAddr::new(ip, port));
                let with_maybe_tls_socket = wsm.make_with_socket(with_exa_socket);
                let with_http_server = WithHttpServer::new(
                    with_maybe_tls_socket,
                    service,
                    chan_rx.clone(),
                    parts_tx.clone(),
                );
                let with_socket = WithSocketAddr(with_http_server);

                // Outer `?`: TCP connect failure. Inner `?`: failure of the socket adapter.
                let (addr, conn_future) =
                    sqlx_core::net::connect_tcp(&ip.to_string(), port, with_socket).await??;
                let worker = self.create_worker(parts_rx.clone(), with_compression);

                addrs.push(addr);
                workers.push(worker);
                conn_futures.push(conn_future);
            }

            Ok(JobComponents {
                addrs,
                workers,
                conn_futures,
            })
        }
    }

    /// Builds the SQL statement for this job.
    ///
    /// `addrs` are the addresses returned by [`EtlJob::connect`]. `with_tls` and
    /// `with_compression` are the resolved settings. `public_key` is set only when the server
    /// release supports it. Implementations may use [`EtlJob::append_files`] to emit the
    /// per-address file clauses.
    fn query(
        &self,
        addrs: Vec<SocketAddrV4>,
        with_tls: bool,
        with_compression: bool,
        public_key: Option<String>,
    ) -> String;

    /// Prepares the whole job on `conn`.
    ///
    /// Steps:
    /// - Resolves the database hosts over the websocket.
    /// - Reads TLS, compression and public-key support from the connection.
    /// - Connects the workers.
    /// - Creates the future executing the SQL from [`EtlJob::query`].
    ///
    /// Returns an [`EtlQuery`] combining the query future with the per-connection server
    /// futures, plus the workers.
    ///
    /// # Errors
    /// Fails if host resolution, socket-maker creation or any connection attempt fails.
    fn build_job<'a, 'c>(
        &'a self,
        conn: &'c mut ExaConnection,
    ) -> impl Future<Output = SqlxResult<(EtlQuery<'c>, Vec<Self::Worker>)>> + Send
    where
        'c: 'a,
    {
        async {
            let socket_addr = conn.server();
            let port = socket_addr.port();
            let ips = GetHosts::new(socket_addr.ip())
                .future(&mut conn.ws)
                .await?
                .into();

            let with_pub_key = conn.session_info().release_version() >= &Self::PK_FP_VER;
            let with_tls = conn.attributes().encryption_enabled();
            let with_compression = self
                .use_compression()
                .unwrap_or(conn.attributes().compression_enabled());

            let (wsm, public_key) = WithMaybeTlsSocketMaker::new(with_tls, with_pub_key)?;
            let JobComponents {
                addrs,
                workers,
                conn_futures,
            } = self.connect(wsm, ips, port, with_compression).await?;

            let query = AssertSqlSafe(self.query(addrs, with_tls, with_compression, public_key))
                .into_sql_str();
            let query_future = ExecuteEtl(ExaRoundtrip::new(Execute(query))).future(&mut conn.ws);

            Ok((EtlQuery::new(query_future, conn_futures), workers))
        }
    }

    /// Appends one file clause per address to `query`.
    ///
    /// Each line has the form
    /// `AT '<scheme>://<addr>' <public key clause> FILE '<JOB_TYPE>_<idx>.<ext>'`.
    /// - The scheme follows `with_tls`.
    /// - The extension follows `with_compression`.
    /// - `idx` is the address position, zero-padded to five digits.
    /// - The public key clause is `PUBLIC KEY 'sha256//<key>'` when `public_key` is set and
    ///   empty otherwise.
    fn append_files(
        query: &mut String,
        addrs: Vec<SocketAddrV4>,
        with_tls: bool,
        with_compression: bool,
        public_key: Option<String>,
    ) {
        let prefix = if with_tls {
            Self::HTTPS_SCHEME
        } else {
            Self::HTTP_SCHEME
        };

        let ext = if with_compression {
            Self::GZ_FILE_EXT
        } else {
            Self::CSV_FILE_EXT
        };

        let public_key = public_key
            .map(|pk| format!("PUBLIC KEY 'sha256//{pk}'"))
            .unwrap_or_default();

        for (idx, addr) in addrs.into_iter().enumerate() {
            writeln!(
                query,
                "AT '{prefix}://{addr}' {public_key} FILE '{job_type}_{idx:0>5}.{ext}'",
                job_type = Self::JOB_TYPE
            )
            .expect("writing to a String cannot fail");
        }
    }

    /// Appends `comment` as a `/* ... */` block comment followed by a newline.
    ///
    /// Any `*/` inside `comment` is rewritten to `* /`. Otherwise it would close the comment
    /// early and the remaining text would be parsed as SQL.
    fn push_comment(query: &mut String, comment: &str) {
        query.push_str("/*\n");
        query.push_str(&comment.replace("*/", "* /"));
        query.push_str("*/\n");
    }

    /// Appends `ident` as a double-quoted (delimited) identifier.
    ///
    /// Embedded `"` characters are doubled so they cannot terminate the identifier early.
    fn push_ident(query: &mut String, ident: &str) {
        query.push('"');
        query.push_str(&ident.replace('"', "\"\""));
        query.push('"');
    }

    /// Appends `lit` as a single-quoted SQL string literal followed by one space.
    ///
    /// Embedded `'` characters are doubled (`''`) so the value cannot terminate the literal
    /// early and inject SQL.
    fn push_literal(query: &mut String, lit: &str) {
        query.push('\'');
        query.push_str(&lit.replace('\'', "''"));
        query.push_str("' ");
    }

    /// Appends `key = '<value>' `. The value is escaped as described in [`EtlJob::push_literal`].
    fn push_key_value(query: &mut String, key: &str, value: &str) {
        query.push_str(key);
        query.push_str(" = ");
        Self::push_literal(query, value);
    }
}
/// Output of `EtlJob::connect`: the three vectors hold one entry per opened connection,
/// in the same order.
struct JobComponents<Wrk> {
    /// Socket address reported for each connection; later passed to `EtlJob::query`.
    addrs: Vec<SocketAddrV4>,
    /// Worker created for each connection.
    workers: Vec<Wrk>,
    /// Server future returned for each connection.
    conn_futures: Vec<ServerBootstrap>,
}
/// Factory that wraps a [`WithExaSocket`] into another `WithSocket` adapter whose output is an
/// `io::Result<ExaSocket>`.
///
/// Presumably implemented by the maybe-TLS socket maker to optionally add TLS on top of the
/// plain socket; confirm in `maybe_tls`.
pub trait WithSocketMaker: Send + Sync {
    /// The adapter type produced by [`WithSocketMaker::make_with_socket`].
    type WithSocket: Send + WithSocket<Output = io::Result<ExaSocket>>;

    /// Wraps `with_socket` in this maker's adapter.
    fn make_with_socket(&self, with_socket: WithExaSocket) -> Self::WithSocket;
}