use std::future::Future;
use std::sync::Arc;
use bytes::BytesMut;
use edgedb_protocol::QueryResult;
use edgedb_protocol::common::CompilationOptions;
use edgedb_protocol::common::{IoFormat, Capabilities, Cardinality};
use edgedb_protocol::model::Json;
use edgedb_protocol::query_arg::{QueryArgs, Encoder};
use tokio::sync::oneshot;
use tokio::time::sleep;
use crate::errors::{ClientError};
use crate::errors::{Error, ErrorKind, SHOULD_RETRY};
use crate::errors::{ProtocolEncodingError, NoResultExpected, NoDataError};
use crate::raw::{Pool, PoolConnection, Options, PoolState};
/// Handle passed to a transaction body by [`transaction`].
///
/// It owns the pooled connection while the body runs; when the handle is
/// dropped, the connection is sent back to the driver loop in [`transaction`].
#[derive(Debug)]
pub struct Transaction {
    // Zero-based attempt counter; `transaction` increments it on each retry.
    iteration: u32,
    // State passed along with every statement sent on the connection.
    state: Arc<PoolState>,
    // Live connection and bookkeeping; `None` once taken by `Drop`.
    inner: Option<Inner>,
}
/// Payload a dropped [`Transaction`] sends back to the driver loop.
#[derive(Debug)]
pub struct TransactionResult {
    // The connection the transaction body was using.
    conn: PoolConnection,
    // Whether `START TRANSACTION` was issued, i.e. whether COMMIT/ROLLBACK is needed.
    started: bool,
}
/// State owned by a live [`Transaction`]; moved out and sent back on drop.
#[derive(Debug)]
pub struct Inner {
    // Set once `START TRANSACTION` has been sent on `conn`.
    started: bool,
    conn: PoolConnection,
    // One-shot channel to the `transaction` loop waiting to get `conn` back.
    return_conn: oneshot::Sender<TransactionResult>,
}
// Compile-time check that `Transaction` is `Send`: the impl below stops
// compiling if `Transaction` ever loses its `Send` auto-trait.
trait Assert: Send {}
impl Assert for Transaction {}
impl Drop for Transaction {
    /// Hands the connection back to the [`transaction`] driver loop.
    ///
    /// The connection is sent together with the `started` flag so the loop
    /// knows whether a `COMMIT` or `ROLLBACK` has to be issued.
    fn drop(&mut self) {
        if let Some(Inner { started, conn, return_conn }) = self.inner.take() {
            // `send` fails only when the receiving side is gone; then there
            // is nobody to hand the connection to, so it is simply dropped.
            let _ = return_conn.send(TransactionResult { started, conn });
        }
    }
}
pub(crate) async fn transaction<T, B, F>(
pool: &Pool,
options: &Options,
mut body: B,
) -> Result<T, Error>
where
B: FnMut(Transaction) -> F,
F: Future<Output = Result<T, Error>>,
{
let mut iteration = 0;
'transaction: loop {
let conn = pool.acquire().await?;
let (tx, mut rx) = oneshot::channel();
let tran = Transaction {
iteration,
state: options.state.clone(),
inner: Some(Inner {
started: false,
conn,
return_conn: tx,
}),
};
let result = body(tran).await;
let TransactionResult { mut conn, started } = rx.try_recv().expect(
"Transaction object must \
be dropped by the time transaction body finishes.",
);
match result {
Ok(val) => {
log::debug!("Comitting transaction");
if started {
conn.statement("COMMIT", &options.state).await?;
}
return Ok(val);
}
Err(outer) => {
log::debug!("Rolling back transaction on error");
if started {
conn.statement("ROLLBACK", &options.state).await?;
}
let some_retry = outer.chain().find_map(|e| {
e.downcast_ref::<Error>().and_then(|e| {
if e.has_tag(SHOULD_RETRY) {
Some(e)
} else {
None
}
})
});
if some_retry.is_none() {
return Err(outer);
} else {
let e = some_retry.unwrap();
let rule = options.retry.get_rule(e);
if iteration >= rule.attempts {
return Err(outer);
} else {
log::info!("Retrying transaction on {:#}", e);
iteration += 1;
sleep((rule.backoff)(iteration)).await;
continue 'transaction;
}
}
}
}
}
}
impl Transaction {
pub fn iteration(&self) -> u32 {
self.iteration
}
async fn ensure_started(&mut self) -> anyhow::Result<(), Error> {
if let Some(inner) = &mut self.inner {
if !inner.started {
inner.conn.statement("START TRANSACTION", &self.state).await?;
inner.started = true;
}
return Ok(());
}
Err(ClientError::with_message("using transaction after drop"))
}
fn inner(&mut self) -> &mut Inner {
self.inner.as_mut().expect("transaction object is not dropped")
}
pub async fn query<R, A>(&mut self, query: &str, arguments: &A)
-> Result<Vec<R>, Error>
where A: QueryArgs,
R: QueryResult,
{
self.ensure_started().await?;
let flags = CompilationOptions {
implicit_limit: None,
implicit_typenames: false,
implicit_typeids: false,
explicit_objectids: true,
allow_capabilities: Capabilities::MODIFICATIONS,
io_format: IoFormat::Binary,
expected_cardinality: Cardinality::Many,
};
let state = self.state.clone(); let ref mut conn = self.inner().conn;
let desc = conn.parse(&flags, query, &state).await?;
let inp_desc = desc.input()
.map_err(ProtocolEncodingError::with_source)?;
let mut arg_buf = BytesMut::with_capacity(8);
arguments.encode(&mut Encoder::new(
&inp_desc.as_query_arg_context(),
&mut arg_buf,
))?;
let data = conn.execute(
&flags, query, &state, &desc, &arg_buf.freeze(),
).await?;
let out_desc = desc.output()
.map_err(ProtocolEncodingError::with_source)?;
match out_desc.root_pos() {
Some(root_pos) => {
let ctx = out_desc.as_queryable_context();
let mut state = R::prepare(&ctx, root_pos)?;
let rows = data.into_iter()
.flat_map(|chunk| chunk.data)
.map(|chunk| R::decode(&mut state, &chunk))
.collect::<Result<_, _>>()?;
Ok(rows)
}
None => Err(NoResultExpected::build()),
}
}
pub async fn query_single<R, A>(&mut self, query: &str, arguments: &A)
-> Result<Option<R>, Error>
where A: QueryArgs,
R: QueryResult,
{
self.ensure_started().await?;
let flags = CompilationOptions {
implicit_limit: None,
implicit_typenames: false,
implicit_typeids: false,
explicit_objectids: true,
allow_capabilities: Capabilities::MODIFICATIONS,
io_format: IoFormat::Binary,
expected_cardinality: Cardinality::AtMostOne,
};
let state = self.state.clone(); let ref mut conn = self.inner().conn;
let desc = conn.parse(&flags, query, &state).await?;
let inp_desc = desc.input()
.map_err(ProtocolEncodingError::with_source)?;
let mut arg_buf = BytesMut::with_capacity(8);
arguments.encode(&mut Encoder::new(
&inp_desc.as_query_arg_context(),
&mut arg_buf,
))?;
let data = conn.execute(
&flags, query, &state, &desc, &arg_buf.freeze(),
).await?;
let out_desc = desc.output()
.map_err(ProtocolEncodingError::with_source)?;
match out_desc.root_pos() {
Some(root_pos) => {
let ctx = out_desc.as_queryable_context();
let mut state = R::prepare(&ctx, root_pos)?;
let bytes = data.into_iter().next()
.and_then(|chunk| chunk.data.into_iter().next());
if let Some(bytes) = bytes {
Ok(Some(R::decode(&mut state, &bytes)?))
} else {
Ok(None)
}
}
None => Err(NoResultExpected::build()),
}
}
pub async fn query_required_single<R, A>(&mut self, query: &str, arguments: &A)
-> Result<R, Error>
where A: QueryArgs,
R: QueryResult,
{
self.query_single(query, arguments).await?
.ok_or_else(|| NoDataError::with_message(
"query row returned zero results"))
}
pub async fn query_json(&mut self, query: &str, arguments: &impl QueryArgs)
-> Result<Json, Error>
{
self.ensure_started().await?;
let flags = CompilationOptions {
implicit_limit: None,
implicit_typenames: false,
implicit_typeids: false,
explicit_objectids: true,
allow_capabilities: Capabilities::MODIFICATIONS,
io_format: IoFormat::Json,
expected_cardinality: Cardinality::Many,
};
let state = self.state.clone(); let ref mut conn = self.inner().conn;
let desc = conn.parse(&flags, query, &state).await?;
let inp_desc = desc.input()
.map_err(ProtocolEncodingError::with_source)?;
let mut arg_buf = BytesMut::with_capacity(8);
arguments.encode(&mut Encoder::new(
&inp_desc.as_query_arg_context(),
&mut arg_buf,
))?;
let data = conn.execute(
&flags, query, &state, &desc, &arg_buf.freeze(),
).await?;
let out_desc = desc.output()
.map_err(ProtocolEncodingError::with_source)?;
match out_desc.root_pos() {
Some(root_pos) => {
let ctx = out_desc.as_queryable_context();
let mut state = String::prepare(&ctx, root_pos)?;
let bytes = data.into_iter().next()
.and_then(|chunk| chunk.data.into_iter().next());
if let Some(bytes) = bytes {
let s = String::decode(&mut state, &bytes)?;
Ok(unsafe { Json::new_unchecked(s) })
} else {
Err(NoDataError::with_message(
"query row returned zero results"))
}
}
None => Err(NoResultExpected::build()),
}
}
pub async fn query_single_json(&mut self,
query: &str, arguments: &impl QueryArgs)
-> Result<Option<Json>, Error>
{
self.ensure_started().await?;
let flags = CompilationOptions {
implicit_limit: None,
implicit_typenames: false,
implicit_typeids: false,
explicit_objectids: true,
allow_capabilities: Capabilities::MODIFICATIONS,
io_format: IoFormat::Json,
expected_cardinality: Cardinality::AtMostOne,
};
let state = self.state.clone(); let ref mut conn = self.inner().conn;
let desc = conn.parse(&flags, query, &state).await?;
let inp_desc = desc.input()
.map_err(ProtocolEncodingError::with_source)?;
let mut arg_buf = BytesMut::with_capacity(8);
arguments.encode(&mut Encoder::new(
&inp_desc.as_query_arg_context(),
&mut arg_buf,
))?;
let data = conn.execute(
&flags, query, &state, &desc, &arg_buf.freeze(),
).await?;
let out_desc = desc.output()
.map_err(ProtocolEncodingError::with_source)?;
match out_desc.root_pos() {
Some(root_pos) => {
let ctx = out_desc.as_queryable_context();
let mut state = String::prepare(&ctx, root_pos)?;
let bytes = data.into_iter().next()
.and_then(|chunk| chunk.data.into_iter().next());
if let Some(bytes) = bytes {
let s = String::decode(&mut state, &bytes)?;
Ok(Some(unsafe { Json::new_unchecked(s) }))
} else {
Ok(None)
}
}
None => Err(NoResultExpected::build()),
}
}
pub async fn query_required_single_json(&mut self,
query: &str, arguments: &impl QueryArgs)
-> Result<Json, Error>
{
self.query_single_json(query, arguments).await?
.ok_or_else(|| NoDataError::with_message(
"query row returned zero results"))
}
pub async fn execute<A>(&mut self, query: &str, arguments: &A)
-> Result<(), Error>
where A: QueryArgs,
{
self.ensure_started().await?;
let flags = CompilationOptions {
implicit_limit: None,
implicit_typenames: false,
implicit_typeids: false,
explicit_objectids: true,
allow_capabilities: Capabilities::MODIFICATIONS,
io_format: IoFormat::Binary,
expected_cardinality: Cardinality::Many,
};
let state = self.state.clone(); let ref mut conn = self.inner().conn;
let desc = conn.parse(&flags, query, &state).await?;
let inp_desc = desc.input()
.map_err(ProtocolEncodingError::with_source)?;
let mut arg_buf = BytesMut::with_capacity(8);
arguments.encode(&mut Encoder::new(
&inp_desc.as_query_arg_context(),
&mut arg_buf,
))?;
conn.execute(&flags, query, &state, &desc, &arg_buf.freeze()).await?;
Ok(())
}
}
/// Compile-time-only check that the future returned by `Client::transaction`
/// is `Send`.
///
/// Never called: `unimplemented!()` diverges, so the rest of the body is
/// unreachable and exists purely to be type-checked. That is why both
/// `dead_code` and `unreachable_code` are allowed here.
#[allow(dead_code, unreachable_code)]
fn _transaction_assertions() {
    let _client: crate::Client = unimplemented!();
    assert_send(_client.transaction(|mut tx| async move {
        tx.query_json("SELECT 'hello'", &()).await
    }));
}
/// Compile-time helper: accepts (and immediately drops) any value whose
/// type is `Send`; calling it with a non-`Send` value fails to compile.
fn assert_send<T>(_value: T)
where
    T: Send,
{
}