use crate::error::{ConnectorError as Error, Result};
use crate::single_row::{fetch_single_row, SingleRowExpectation};
use crate::{ConnectorPoolOptions, Executor, MysqlRowStream, Row};
use futures::future::BoxFuture;
use nautilus_core::Value;
use nautilus_dialect::Sql;
use sqlx::mysql::{MySqlConnectOptions, MySqlPool, MySqlPoolOptions};
/// Executor that runs SQL statements against MySQL through a `sqlx` connection pool.
pub struct MysqlExecutor {
/// Connection pool used by every operation of this executor.
pool: MySqlPool,
}
impl MysqlExecutor {
pub async fn new(url: &str) -> Result<Self> {
Self::new_with_options(url, ConnectorPoolOptions::default()).await
}
pub async fn new_with_options(url: &str, pool_options: ConnectorPoolOptions) -> Result<Self> {
let connect_options = pool_options.apply_to_mysql_connect_options(
url.parse::<MySqlConnectOptions>()
.map_err(|e| Error::connection(e, "Invalid MySQL connection options"))?,
);
let pool = pool_options
.apply_to(MySqlPoolOptions::new().max_connections(5))
.connect_with(connect_options)
.await
.map_err(|e| Error::connection(e, "Failed to connect to MySQL"))?;
Ok(Self { pool })
}
pub fn pool(&self) -> &MySqlPool {
&self.pool
}
pub async fn execute_raw(&self, sql: &str) -> Result<()> {
sqlx::query(sql)
.persistent(false)
.execute(&self.pool)
.await
.map(|_| ())
.map_err(|e| Error::database(e, "DDL error"))
}
fn execute_collect_internal<'conn>(
&'conn self,
sql: &'conn Sql,
) -> BoxFuture<'conn, Result<Vec<Row>>> {
Box::pin(async move {
let mut conn = self
.pool
.acquire()
.await
.map_err(|e| Error::connection(e, "Failed to acquire connection"))?;
let mut query = sqlx::query(&sql.text);
for param in &sql.params {
query = bind_value(query, param)?;
}
let mysql_rows = query
.fetch_all(&mut *conn)
.await
.map_err(|e| Error::database(e, "Query execution failed"))?;
drop(conn);
crate::mysql_stream::decode_rows(&mysql_rows)
})
}
fn execute_and_fetch_collect_internal<'conn>(
&'conn self,
mutation: &'conn Sql,
fetch: &'conn Sql,
) -> BoxFuture<'conn, Result<Vec<Row>>> {
Box::pin(async move {
use sqlx::Executor as _;
let mut conn = self
.pool
.acquire()
.await
.map_err(|e| Error::connection(e, "Failed to acquire connection"))?;
let mut mutation_query = sqlx::query(&mutation.text);
for param in &mutation.params {
mutation_query = bind_value(mutation_query, param)?;
}
(&mut *conn)
.execute(mutation_query)
.await
.map_err(|e| Error::database(e, "Mutation failed"))?;
let mut fetch_query = sqlx::query(&fetch.text);
for param in &fetch.params {
fetch_query = bind_value(fetch_query, param)?;
}
let mysql_rows = fetch_query
.fetch_all(&mut *conn)
.await
.map_err(|e| Error::database(e, "Fetch failed"))?;
drop(conn);
crate::mysql_stream::decode_rows(&mysql_rows)
})
}
impl_execute_affected!();
}
impl Executor for MysqlExecutor {
type Row<'conn>
= Row
where
Self: 'conn;
type RowStream<'conn>
= MysqlRowStream<'conn>
where
Self: 'conn;
fn execute<'conn>(&'conn self, sql: &'conn Sql) -> Self::RowStream<'conn> {
crate::streaming::spawn_streaming_query(crate::streaming::StreamingQuery::<
sqlx::MySql,
_,
_,
> {
pool: self.pool.clone(),
sql_text: sql.text.clone(),
params: sql.params.clone(),
bind: bind_value,
decode: crate::mysql_stream::streaming_decoder(),
query_context: "Query execution failed",
persistent: true,
})
}
fn execute_owned(&self, sql: Sql) -> crate::row_stream::RowStream<'static> {
crate::streaming::spawn_streaming_query(crate::streaming::StreamingQuery::<
sqlx::MySql,
_,
_,
> {
pool: self.pool.clone(),
sql_text: sql.text,
params: sql.params,
bind: bind_value,
decode: crate::mysql_stream::streaming_decoder(),
query_context: "Query execution failed",
persistent: true,
})
}
fn execute_and_fetch<'conn>(
&'conn self,
mutation: &'conn Sql,
fetch: &'conn Sql,
) -> Self::RowStream<'conn> {
MysqlRowStream::from_rows_future(self.execute_and_fetch_collect_internal(mutation, fetch))
}
fn execute_collect<'conn>(
&'conn self,
sql: &'conn Sql,
) -> BoxFuture<'conn, Result<Vec<Self::Row<'conn>>>>
where
Self: 'conn,
{
self.execute_collect_internal(sql)
}
fn execute_one<'conn>(
&'conn self,
sql: &'conn Sql,
) -> BoxFuture<'conn, Result<Self::Row<'conn>>>
where
Self: 'conn,
{
Box::pin(async move {
let mut conn = self
.pool
.acquire()
.await
.map_err(|e| Error::connection(e, "Failed to acquire connection"))?;
let row = fetch_single_row::<sqlx::MySql, _, _, _>(
&mut *conn,
&sql.text,
&sql.params,
bind_value,
crate::mysql_stream::decode_row_internal,
"Query execution failed",
SingleRowExpectation::ExactlyOne,
)
.await?;
drop(conn);
row.ok_or_else(|| Error::database_msg("Expected exactly one row, got 0"))
})
}
fn execute_optional<'conn>(
&'conn self,
sql: &'conn Sql,
) -> BoxFuture<'conn, Result<Option<Self::Row<'conn>>>>
where
Self: 'conn,
{
Box::pin(async move {
let mut conn = self
.pool
.acquire()
.await
.map_err(|e| Error::connection(e, "Failed to acquire connection"))?;
let row = fetch_single_row::<sqlx::MySql, _, _, _>(
&mut *conn,
&sql.text,
&sql.params,
bind_value,
crate::mysql_stream::decode_row_internal,
"Query execution failed",
SingleRowExpectation::ZeroOrOne,
)
.await?;
drop(conn);
Ok(row)
})
}
}
pub(crate) fn bind_value<'q>(
query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
value: &'q Value,
) -> Result<sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>> {
match value {
Value::Null => Ok(query.bind(None::<String>)),
Value::Bool(b) => Ok(query.bind(b)),
Value::I32(i) => Ok(query.bind(i)),
Value::I64(i) => Ok(query.bind(i)),
Value::F64(f) => Ok(query.bind(f)),
Value::Decimal(d) => Ok(query.bind(d.to_string())),
Value::DateTime(dt) => Ok(query.bind(dt.format("%Y-%m-%dT%H:%M:%S%.f").to_string())),
Value::Uuid(u) => Ok(query.bind(u.to_string())),
Value::String(s) => Ok(query.bind(s.as_str())),
Value::Geometry(raw) | Value::Geography(raw) => Ok(query.bind(raw.as_str())),
Value::Hstore(_) => Err(Error::database_msg(
"HSTORE values are only supported on PostgreSQL",
)),
Value::Vector(_) => Err(Error::database_msg(
"VECTOR values are only supported on PostgreSQL",
)),
Value::Bytes(b) => Ok(query.bind(b.as_slice())),
Value::Json(j) => Ok(query.bind(j.to_string())),
Value::Array(_) => Ok(query.bind(crate::utils::value_to_json(value).to_string())),
Value::Enum { value, .. } => Ok(query.bind(value.as_str())),
Value::Array2D(_) => Ok(query.bind(crate::utils::value_to_json(value).to_string())),
Value::Composite { .. } => Err(Error::database_msg(
"native composite-type values are only supported on PostgreSQL",
)),
}
}