use crate::client::DgraphClient;
use crate::types::{TxnContext, Request, Extensions, Mutation, Assigned, DgraphError};
use std::error::Error;
use std::fmt;
use std::collections::HashMap;
use serde_json::Value;
use std::time::SystemTime;
/// Error returned when an operation is attempted on a transaction that has
/// already been committed or discarded.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TxnFinished {}

impl fmt::Display for TxnFinished {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "ERR_FINISHED")
    }
}

impl Error for TxnFinished {
    // Kept for callers that still use the deprecated `Error::description` API
    // (it is also a required method on toolchains older than Rust 1.42).
    // `cause`/`source` are intentionally not overridden: the defaults already
    // report "no underlying cause".
    fn description(&self) -> &str {
        "Transaction has already been committed or discarded"
    }
}

/// Builds the [`TxnFinished`] error.
fn finished_error() -> TxnFinished {
    TxnFinished {}
}
/// Error returned when the server reports a transaction start timestamp that
/// differs from the one the transaction already holds.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct StartTxnMismatch {}

impl fmt::Display for StartTxnMismatch {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "ERR_START_TXN_MISMATCH")
    }
}

impl Error for StartTxnMismatch {
    // Kept for callers that still use the deprecated `Error::description` API
    // (it is also a required method on toolchains older than Rust 1.42).
    // `cause`/`source` are intentionally not overridden: the defaults already
    // report "no underlying cause".
    fn description(&self) -> &str {
        "StartTs mismatch"
    }
}

/// Builds the [`StartTxnMismatch`] error.
fn start_txn_mismatch_error() -> StartTxnMismatch {
    StartTxnMismatch {}
}
#[derive(Debug, Clone)]
pub struct MutationErrors {
errors: Vec<DgraphError>,
}
impl fmt::Display for MutationErrors {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "MUTATION_ERRORS. {:?}", self.errors)
}
}
impl Error for MutationErrors {
fn description(&self) -> &str {
"Mutation data contains errors"
}
fn cause(&self) -> Option<&Error> {
None
}
}
fn mutation_error(errors: Vec<DgraphError>) -> MutationErrors {
let err = MutationErrors { errors };
err
}
/// Returns the union of `a` and `b`, sorted ascending with duplicates removed.
fn merge_vec(a: Vec<String>, b: Vec<String>) -> Vec<String> {
    let mut union: Vec<String> = a.into_iter().chain(b).collect();
    union.sort_unstable();
    union.dedup();
    union
}
/// Folds the transaction context reported by the server (`extensions.txn`) into
/// the client-side context `ctx` and returns the merged context.
///
/// * If the response carries no transaction context, `ctx` is returned unchanged.
/// * A local `start_ts` of `0` (or unset) is replaced by the server's value; any
///   other local value must equal the server's.
/// * `keys` and `preds` are unioned with what `ctx` already holds (sorted and
///   de-duplicated). A response that omits them leaves the accumulated lists as
///   they were.
/// * `aborted` is reset to `Some(false)`.
///
/// # Errors
///
/// Returns a boxed [`StartTxnMismatch`] when the server's start timestamp differs
/// from the non-zero one already stored in `ctx`.
fn merge_context(ctx: TxnContext, extensions: Extensions) -> Result<TxnContext, Box<Error>> {
    /// Unions `incoming` into `existing`; a missing `incoming` list keeps `existing`.
    fn union(incoming: Option<Vec<String>>, existing: Option<Vec<String>>) -> Option<Vec<String>> {
        match incoming {
            Some(incoming) => Some(merge_vec(incoming, existing.unwrap_or_default())),
            // Previously this dropped the accumulated list, so a query response
            // without keys followed by a mutation response with keys panicked.
            None => existing,
        }
    }

    let src = match extensions.txn {
        Some(src) => src,
        None => return Ok(ctx),
    };

    let start_ts = if ctx.start_ts.unwrap_or(0) == 0 {
        src.start_ts
    } else if ctx.start_ts != src.start_ts {
        return Err(Box::new(start_txn_mismatch_error()));
    } else {
        ctx.start_ts
    };

    Ok(TxnContext {
        start_ts,
        keys: union(src.keys, ctx.keys),
        preds: union(src.preds, ctx.preds),
        aborted: Some(false),
    })
}
/// A transaction against Dgraph, issued through a borrowed [`DgraphClient`].
///
/// Accumulates the transaction context (start timestamp, keys and predicates)
/// reported by the server across queries and mutations until the transaction is
/// committed or discarded.
pub struct Txn<'a> {
    // Client whose `any_client()` supplies the connection for each request.
    dc: &'a DgraphClient,
    // Transaction context, merged with the context returned by each response.
    ctx: TxnContext,
    // Becomes `Some(true)` after commit, discard, or a `commit_now` mutation;
    // further queries/mutations are then rejected.
    finished: Option<bool>,
    // Becomes `Some(true)` once a mutation has been sent; `discard` skips the
    // abort request while this is still `Some(false)`.
    mutated: Option<bool>,
}
impl<'a> Txn<'a> {
pub fn new(dc: &DgraphClient) -> Txn {
let ctx = TxnContext {
start_ts: Some(0),
keys: Some(Vec::new()),
preds: Some(Vec::new()),
aborted: None,
};
Txn {
dc,
ctx,
finished: Some(false),
mutated: Some(false),
}
}
pub fn query(&mut self, q: String) -> Result<Value, Box<Error>> {
self.query_with_vars(q, None)
}
pub fn query_with_vars(&mut self, q: String, vars: Option<HashMap<String, String>>)
-> Result<Value, Box<Error>> {
let finished = self.finished.unwrap();
if finished == true {
self.dc.debug(format!("Query request (ERR_FINISHED):\nquery = {}\nvars = {:?}",
q, vars));
return Err(Box::new(finished_error()));
}
let query = q.clone();
let req = Request {
query: Some(q),
start_ts: self.ctx.start_ts,
vars,
};
self.dc.debug(format!("Query request:\n{}\nvars:{}",
query,
if let Some(vars) = &req.vars {
serde_json::to_string(vars)?
} else {
String::from("")
}));
let client = self.dc.any_client();
let start_time = SystemTime::now();
match client.query(req) {
Ok((value, mut extensions)) => {
let ctx = self.ctx.clone();
let latency = &mut extensions.server_latency;
latency.network_ns = Some(start_time.elapsed()?.subsec_nanos());
let value_json = serde_json::to_string(&value)?;
let debug_msg = format!("Query response:\n{}\nQuery {:?}", value_json, latency);
match merge_context(ctx, extensions) {
Ok(context) => {
self.ctx = context;
self.dc.debug(debug_msg);
Ok(value)
}
Err(e) => Err(e),
}
}
Err(e) => Err(e),
}
}
pub fn mutation(&mut self, mut mu: Mutation) -> Result<Assigned, Box<Error>> {
if self.finished.unwrap() == true {
self.dc.debug(format!("Mutate request (ERR_FINISHED):\nmutation = {:?}", mu));
return Err(Box::new(finished_error()));
}
self.mutated = Some(true);
mu.start_ts = self.ctx.start_ts;
self.dc.debug(format!("Mutate request:\n{:?}", mu));
if let Some(commit_now) = mu.commit_now {
if commit_now == true {
self.finished = Some(true);
}
};
let client = self.dc.any_client();
let start_time = SystemTime::now();
match client.mutate(mu) {
Ok(assigned) => {
let result = assigned.clone();
if let Some(mut extensions) = assigned.extensions {
let ctx = self.ctx.clone();
let latency = &mut extensions.server_latency;
latency.network_ns = Some(start_time.elapsed()?.subsec_nanos());
let debug_msg = format!("Mutate response:\n{:?}\nMutate {:?}",
result, latency);
match merge_context(ctx, extensions) {
Ok(context) => {
self.ctx = context;
self.dc.debug(debug_msg);
Ok(result)
}
Err(e) => Err(e),
}
} else if let Some(errors) = assigned.errors {
Err(Box::new(mutation_error(errors)))
} else {
Ok(result)
}
}
Err(e) => {
match self.discard() {
Ok(_) => Err(e),
Err(_) => Err(e),
}
}
}
}
pub fn commit(&mut self) -> Result<(), Box<Error>> {
if self.finished.unwrap() == true {
return Err(Box::new(finished_error()));
}
self.finished = Some(true);
let client = self.dc.any_client();
match client.commit(self.ctx.clone()) {
Ok(_ctx) => {
Ok(())
}
Err(e) => Err(e),
}
}
pub fn discard(&mut self) -> Result<(), Box<Error>> {
if self.finished.unwrap() == true {
return Ok(());
}
self.finished = Some(true);
if self.mutated.unwrap() == false {
return Ok(());
}
let ctx = &mut self.ctx;
ctx.aborted = Some(true);
let client = self.dc.any_client();
match client.abort(ctx) {
Ok(_ctx) => {
Ok(())
}
Err(e) => Err(e),
}
}
}