use std::sync::Arc;
use toasty_core::driver::{Connection, Rows};
use toasty_core::stmt::Value;
use tokio::{
sync::{mpsc, oneshot},
task::JoinHandle,
};
use super::pool::SweepWaker;
use crate::engine::Engine;
/// A request sent over the channel to the background connection task.
///
/// Every variant carries a `oneshot::Sender` named `tx`; the task sends the
/// outcome of the request back on it once the work has been performed.
pub(crate) enum ConnectionOperation {
    /// Execute a statement through the `Engine` on this connection.
    ExecStatement {
        /// The statement to execute (boxed).
        stmt: Box<toasty_core::stmt::Statement>,
        /// Forwarded unchanged to `Engine::exec` alongside the statement.
        in_transaction: bool,
        /// Receives the execution result.
        tx: oneshot::Sender<crate::Result<toasty_core::driver::ExecResponse>>,
    },
    /// Execute a driver operation directly via `Connection::exec`, using the
    /// engine's schema.
    ExecOperation {
        /// The driver operation to execute (boxed).
        operation: Box<toasty_core::driver::operation::Operation>,
        /// Receives the execution result.
        tx: oneshot::Sender<crate::Result<toasty_core::driver::ExecResponse>>,
    },
    /// Call `Connection::push_schema` with the engine's schema.
    PushSchema {
        /// Receives the result of the schema push.
        tx: oneshot::Sender<crate::Result<()>>,
    },
    /// Call `Connection::ping` on the underlying connection.
    Ping {
        /// Receives the result of the ping.
        tx: oneshot::Sender<crate::Result<()>>,
    },
}
/// Handle to a spawned background connection task.
///
/// Created by `ConnectionHandle::spawn`. Requests are submitted by sending
/// `ConnectionOperation` values on `in_tx`.
pub(crate) struct ConnectionHandle {
    /// Sending half of the unbounded channel the task reads operations from.
    pub(crate) in_tx: mpsc::UnboundedSender<ConnectionOperation>,
    /// Join handle of the tokio task; used to observe whether it has finished.
    join_handle: JoinHandle<()>,
}
impl ConnectionHandle {
    /// Starts a background task that owns `connection` and returns a handle
    /// for talking to it.
    ///
    /// An unbounded channel is created: the receiving half is moved into the
    /// task together with `connection`, `engine` and `sweep_waker`, and the
    /// sending half is kept in the returned handle as `in_tx`. The task is
    /// spawned with `tokio::spawn`.
    pub(crate) fn spawn(
        connection: Box<dyn Connection>,
        engine: Engine,
        sweep_waker: Arc<SweepWaker>,
    ) -> Self {
        let (sender, receiver) = mpsc::unbounded_channel::<ConnectionOperation>();

        let worker = ConnectionTask {
            connection,
            engine,
            in_rx: receiver,
            sweep_waker,
        };

        Self {
            in_tx: sender,
            join_handle: tokio::spawn(worker.run()),
        }
    }

    /// Reports whether the background task has finished, as observed through
    /// its join handle.
    pub(crate) fn is_finished(&self) -> bool {
        let Self { join_handle, .. } = self;
        join_handle.is_finished()
    }
}
impl std::fmt::Debug for ConnectionHandle {
    /// Formats the handle as `ConnectionHandle { channel_closed, task_finished }`.
    ///
    /// Only the two status flags are printed; the channel and join handle
    /// themselves are not.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let channel_closed = self.in_tx.is_closed();
        let task_finished = self.join_handle.is_finished();

        let mut repr = f.debug_struct("ConnectionHandle");
        repr.field("channel_closed", &channel_closed);
        repr.field("task_finished", &task_finished);
        repr.finish()
    }
}
/// State owned by the background task that serves one connection.
struct ConnectionTask {
    /// The driver connection every operation is executed against.
    connection: Box<dyn Connection>,
    /// Engine used to execute statements; also supplies the schema passed to
    /// the connection for `exec` and `push_schema`.
    engine: Engine,
    /// Receiving half of the operation channel.
    in_rx: mpsc::UnboundedReceiver<ConnectionOperation>,
    /// Woken when the connection reports itself invalid and the task is
    /// about to exit.
    sweep_waker: Arc<SweepWaker>,
}
impl ConnectionTask {
    /// Main loop of the task.
    ///
    /// Operations are received and handled one at a time. The loop ends when
    /// `recv()` yields `None` or when handling an operation reports that the
    /// task should stop.
    async fn run(mut self) {
        loop {
            let Some(op) = self.in_rx.recv().await else {
                break;
            };

            let keep_going = self.handle(op).await;
            if !keep_going {
                break;
            }
        }
    }

    /// Performs a single operation and sends its result back on the
    /// operation's `tx` channel.
    ///
    /// Returns the value produced by `respond`: `true` to keep serving
    /// operations, `false` when the task should exit.
    async fn handle(&mut self, op: ConnectionOperation) -> bool {
        match op {
            ConnectionOperation::Ping { tx: reply } => {
                let outcome = self.connection.ping().await;
                self.respond(reply, outcome)
            }
            ConnectionOperation::PushSchema { tx: reply } => {
                let outcome = self.connection.push_schema(&self.engine.schema).await;
                self.respond(reply, outcome)
            }
            ConnectionOperation::ExecOperation {
                operation,
                tx: reply,
            } => {
                let outcome = self.connection.exec(&self.engine.schema, *operation).await;
                self.respond(reply, outcome)
            }
            ConnectionOperation::ExecStatement {
                stmt,
                in_transaction,
                tx: reply,
            } => {
                let outcome = self.exec_statement(*stmt, in_transaction).await;
                self.respond(reply, outcome)
            }
        }
    }

    /// Executes `stmt` through the engine and fully buffers the response
    /// values before returning.
    ///
    /// When `stmt.is_single()` is true, the buffered list is collapsed to its
    /// only element, or to `Value::Null` when the list is empty.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Engine::exec` or from buffering the values.
    ///
    /// # Panics
    ///
    /// For a single statement, panics if the buffered values are not a
    /// `Rows::Value(Value::List(..))` or if the list holds more than one item.
    async fn exec_statement(
        &mut self,
        stmt: toasty_core::stmt::Statement,
        in_transaction: bool,
    ) -> crate::Result<toasty_core::driver::ExecResponse> {
        // Must be read up front: `stmt` is moved into `exec` below.
        let wants_single_row = stmt.is_single();

        let mut response = self
            .engine
            .exec(&mut *self.connection, stmt, in_transaction)
            .await?;
        response.values.buffer().await?;

        if !wants_single_row {
            return Ok(response);
        }

        let Rows::Value(Value::List(mut items)) = response.values else {
            unreachable!()
        };
        let row_count = items.len();
        assert!(
            row_count <= 1,
            "expected at most 1 row for single statement, got {}",
            row_count
        );

        let only_row = match items.pop() {
            Some(row) => row,
            None => Value::Null,
        };
        response.values = Rows::Value(only_row);

        Ok(response)
    }

    /// Sends `result` on `tx` and reports whether the task should continue.
    ///
    /// If the connection no longer reports itself valid, the operation
    /// channel is closed and the sweep waker is woken *before* the result is
    /// sent. A failed send (receiver already dropped) is ignored.
    ///
    /// Returns `true` while the connection is valid, `false` otherwise.
    fn respond<T>(&mut self, tx: oneshot::Sender<T>, result: T) -> bool {
        let still_valid = self.connection.is_valid();

        if !still_valid {
            tracing::debug!("connection reported invalid; closing channel and exiting");
            self.in_rx.close();
            self.sweep_waker.wake();
        }

        // The caller may have stopped waiting; that is not an error here.
        let _ = tx.send(result);
        still_valid
    }
}