use crate::channel::CONNECTION_FROM_BASE;
use crate::data::*;
use crate::internal::*;
use crate::LocalOrNot;
use crate::{Tagged, Tagger};
use async_bincode::{AsyncBincodeStream, AsyncDestination};
use futures_util::{
future, future::TryFutureExt, ready, stream::futures_unordered::FuturesUnordered,
stream::TryStreamExt,
};
use nom_sql::CreateTableStatement;
use petgraph::graph::NodeIndex;
use std::collections::HashMap;
use std::future::Future;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::{fmt, io};
use tokio::io::AsyncWriteExt;
use tokio_tower::multiplex;
use tower_balance::p2c::Balance;
use tower_buffer::Buffer;
use tower_discover::ServiceStream;
use tower_limit::concurrency::ConcurrencyLimit;
use tower_service::Service;
use vec_map::VecMap;
type Transport = AsyncBincodeStream<
tokio::net::TcpStream,
Tagged<()>,
Tagged<LocalOrNot<Input>>,
AsyncDestination,
>;
#[macro_export]
macro_rules! row {
($tbl:ident, $($k:expr => $v:expr),+ $(,)*) => {
$crate::row!(@step $tbl, $($k => $v),+)
};
(@replace_expr ($_t:expr, $sub:expr)) => {$sub};
(@count_tts ($($e:expr),*)) => {<[()]>::len(&[$($crate::row!(@replace_expr ($e, ()))),*])};
(@step $tbl:ident, $(@$idx:expr; $ik:expr => $iv:expr,)* $ck:expr => $cv:expr $(, $k:expr => $v:expr)*) => {
$crate::row!(@step $tbl, $(@$idx; $ik => $iv,)* @$crate::row!(@count_tts ($($ik),*)); $ck => $cv $(, $k => $v)*)
};
(@step $tbl:ident, $(@$idx:expr; $k:expr => $v:expr),+) => {{
let mut row = vec![$crate::DataType::None; $tbl.columns().len()];
let mut vals = [$(Some(Into::<$crate::DataType>::into($v))),+];
let schema = $tbl.schema();
for (coli, col) in $tbl.columns().iter().enumerate() {
match &**col {
$($k => {
row[coli] = vals[$idx].take().expect("field name appears twice -- should be caught by match");
if let Some(ref schema) = schema {
if schema.fields[coli].constraints.iter().any(|c| c == &$crate::ColumnConstraint::NotNull) {
assert!(!row[coli].is_none(), "Attempted to set NOT NULL column '{}' to DataType::None", col);
}
}
},)|+
cname if schema.is_some() => {
let schema = schema.as_ref().unwrap();
let mut allow_null = true;
let spec = &schema.fields[coli];
for c in &spec.constraints {
use $crate::ColumnConstraint;
match c {
ColumnConstraint::NotNull => {
allow_null = false;
}
ColumnConstraint::DefaultValue(ref literal) => {
row[coli] = Into::<$crate::DataType>::into(literal);
}
ColumnConstraint::AutoIncrement => {
}
_ => {}
}
}
if !allow_null && row[coli].is_none() {
panic!("Column {} is declared NOT NULL, has no default, and was not provided", cname);
}
}
_ => { }
}
}
row
}};
}
#[cfg(test)]
#[allow(dead_code)]
async fn add_user(users: &mut Table) -> Result<(), TableError> {
let s = String::from("non copy");
let user = row!(users,
"username" => "jonhoo",
"password" => "hunter2",
"created_at" => chrono::Local::now().naive_local(),
"not an ident" => s,
"logins" => 0,
);
users.insert(user).await
}
#[macro_export]
macro_rules! update {
($tbl:ident, $($k:expr => $v:expr),+ $(,)*) => { $crate::update!(@step $tbl, $($k => $v),+) };
(@replace_expr ($_t:expr, $sub:expr)) => {$sub};
(@count_tts ($($e:expr),*)) => {<[()]>::len(&[$($crate::update!(@replace_expr ($e, ()))),*])};
(@step $tbl:ident, $(@$idx:expr; $ik:expr => $iv:expr,)* $ck:expr => $cv:expr $(, $k:expr => $v:expr)*) => {
$crate::update!(@step $tbl, $(@$idx; $ik => $iv,)* @$crate::update!(@count_tts ($($ik),*)); $ck => $cv $(, $k => $v)*)
};
(@step $tbl:ident, $(@$idx:expr; $k:expr => $v:expr),+) => {{
let mut set = vec![$((0, Into::<$crate::Modification>::into($v))),+];
for (coli, col) in $tbl.columns().iter().enumerate() {
match &**col {
$($k => {
set[$idx].0 = coli;
},)|+
_ => { }
}
}
set
}};
}
#[cfg(test)]
#[allow(dead_code)]
async fn update_user(users: &mut Table) -> Result<(), TableError> {
let user = update!(users,
"password" => "hunter3",
"logins" => crate::Modification::Apply(crate::Operation::Add, 1.into()),
);
users.update(vec!["jonhoo".into()], user).await
}
#[derive(Debug)]
struct Endpoint(SocketAddr);
type InnerService = multiplex::Client<
multiplex::MultiplexTransport<Transport, Tagger>,
tokio_tower::Error<multiplex::MultiplexTransport<Transport, Tagger>, Tagged<LocalOrNot<Input>>>,
Tagged<LocalOrNot<Input>>,
>;
impl Service<()> for Endpoint {
type Response = InnerService;
type Error = tokio::io::Error;
#[cfg(not(doc))]
type Future = impl Future<Output = Result<Self::Response, Self::Error>>;
#[cfg(doc)]
type Future = crate::doc_mock::Future<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _: ()) -> Self::Future {
let f = tokio::net::TcpStream::connect(self.0);
async move {
let mut s = f.await?;
s.set_nodelay(true)?;
s.write_all(&[CONNECTION_FROM_BASE]).await.unwrap();
s.flush().await.unwrap();
let s = AsyncBincodeStream::from(s).for_async();
let t = multiplex::MultiplexTransport::new(s, Tagger::default());
Ok(multiplex::Client::with_error_handler(t, |e| {
eprintln!("table server went away: {}", e)
}))
}
}
}
fn make_table_stream(
addr: SocketAddr,
) -> impl futures_util::stream::TryStream<
Ok = tower_discover::Change<usize, InnerService>,
Error = tokio::io::Error,
> {
(0..crate::TABLE_POOL_SIZE)
.map(|i| async move {
let svc = Endpoint(addr).call(()).await?;
Ok(tower_discover::Change::Insert(i, svc))
})
.collect::<futures_util::stream::FuturesUnordered<_>>()
}
fn make_table_discover(addr: SocketAddr) -> Discover {
ServiceStream::new(make_table_stream(addr))
}
#[cfg(not(doc))]
type Discover = impl tower_discover::Discover<Key = usize, Service = InnerService, Error = tokio::io::Error>
+ Unpin
+ Send;
#[cfg(doc)]
type Discover = crate::doc_mock::Discover<InnerService>;
pub(crate) type TableRpc = Buffer<
ConcurrencyLimit<Balance<Discover, Tagged<LocalOrNot<Input>>>>,
Tagged<LocalOrNot<Input>>,
>;
#[derive(Debug, Fail)]
pub enum TableError {
#[fail(
display = "wrong number of columns specified: expected {}, got {}",
_0, _1
)]
WrongColumnCount(usize, usize),
#[fail(
display = "wrong number of key columns used: expected {}, got {}",
_0, _1
)]
WrongKeyColumnCount(usize, usize),
#[fail(display = "{}", _0)]
TransportError(#[cause] failure::Error),
}
impl From<Box<dyn std::error::Error + Send + Sync>> for TableError {
fn from(e: Box<dyn std::error::Error + Send + Sync>) -> Self {
TableError::TransportError(failure::Error::from_boxed_compat(e))
}
}
#[doc(hidden)]
#[derive(Clone, Serialize, Deserialize)]
pub struct Input {
pub dst: LocalNodeIndex,
pub data: Vec<TableOperation>,
}
impl fmt::Debug for Input {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Input")
.field("dst", &self.dst)
.field("data", &self.data)
.finish()
}
}
#[doc(hidden)]
#[derive(Clone, Serialize, Deserialize)]
pub struct TableBuilder {
pub txs: Vec<SocketAddr>,
pub ni: NodeIndex,
pub addr: LocalNodeIndex,
pub key_is_primary: bool,
pub key: Vec<usize>,
pub dropped: VecMap<DataType>,
pub table_name: String,
pub columns: Vec<String>,
pub schema: Option<CreateTableStatement>,
}
impl TableBuilder {
pub(crate) fn build(
self,
rpcs: Arc<Mutex<HashMap<(SocketAddr, usize), TableRpc>>>,
) -> Result<Table, io::Error> {
let mut addrs = Vec::with_capacity(self.txs.len());
let mut conns = Vec::with_capacity(self.txs.len());
for (shardi, &addr) in self.txs.iter().enumerate() {
use std::collections::hash_map::Entry;
addrs.push(addr);
let mut rpcs = rpcs.lock().unwrap();
let s = match rpcs.entry((addr, shardi)) {
Entry::Occupied(e) => e.get().clone(),
Entry::Vacant(h) => {
let (c, w) = Buffer::pair(
ConcurrencyLimit::new(
Balance::from_entropy(make_table_discover(addr)),
crate::PENDING_LIMIT,
),
crate::BUFFER_TO_POOL,
);
use tracing_futures::Instrument;
tokio::spawn(w.instrument(tracing::debug_span!(
"table_worker",
addr = %addr,
shard = shardi
)));
h.insert(c.clone());
c
}
};
conns.push(s);
}
let dispatch = tracing::dispatcher::get_default(|d| d.clone());
Ok(Table {
ni: self.ni,
node: self.addr,
key: self.key,
key_is_primary: self.key_is_primary,
columns: self.columns,
dropped: self.dropped,
table_name: self.table_name,
schema: self.schema,
dst_is_local: false,
shard_addrs: addrs,
shards: conns,
dispatch,
})
}
}
#[derive(Clone)]
pub struct Table {
ni: NodeIndex,
node: LocalNodeIndex,
key_is_primary: bool,
key: Vec<usize>,
columns: Vec<String>,
dropped: VecMap<DataType>,
table_name: String,
schema: Option<CreateTableStatement>,
dst_is_local: bool,
shards: Vec<TableRpc>,
shard_addrs: Vec<SocketAddr>,
dispatch: tracing::Dispatch,
}
impl fmt::Debug for Table {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Table")
.field("ni", &self.ni)
.field("node", &self.node)
.field("key_is_primary", &self.key_is_primary)
.field("key", &self.key)
.field("columns", &self.columns)
.field("dropped", &self.dropped)
.field("table_name", &self.table_name)
.field("schema", &self.schema)
.field("dst_is_local", &self.dst_is_local)
.field("shard_addrs", &self.shard_addrs)
.finish()
}
}
impl Table {
#[allow(clippy::cognitive_complexity)]
fn input(
&mut self,
mut i: Input,
) -> impl Future<Output = Result<Tagged<()>, TableError>> + Send {
let span = if crate::trace_next_op() {
Some(tracing::trace_span!(
"table-request",
base = self.ni.index()
))
} else {
None
};
let immediate_err = || {
let ncols = self.columns.len() + self.dropped.len();
for op in &i.data {
match op {
TableOperation::Insert(ref row) => {
if row.len() != ncols {
return Err(TableError::WrongColumnCount(ncols, row.len()));
}
}
TableOperation::Delete { ref key } => {
if key.len() != self.key.len() {
return Err(TableError::WrongKeyColumnCount(self.key.len(), key.len()));
}
}
TableOperation::InsertOrUpdate {
ref row,
ref update,
} => {
if row.len() != ncols {
return Err(TableError::WrongColumnCount(ncols, row.len()));
}
if update.len() > self.columns.len() {
return Err(TableError::WrongColumnCount(
self.columns.len(),
update.len(),
));
}
}
TableOperation::Update { ref set, ref key } => {
if key.len() != self.key.len() {
return Err(TableError::WrongKeyColumnCount(self.key.len(), key.len()));
}
if set.len() > self.columns.len() {
return Err(TableError::WrongColumnCount(
self.columns.len(),
set.len(),
));
}
}
}
}
Ok(())
};
if let Err(e) = immediate_err() {
return future::Either::Left(async move { Err(e) });
}
if self.shards.len() == 1 {
let request = Tagged::from(if self.dst_is_local {
unsafe { LocalOrNot::for_local_transfer(i) }
} else {
LocalOrNot::new(i)
});
let _guard = span.as_ref().map(tracing::Span::enter);
tracing::trace!("submit request");
future::Either::Right(future::Either::Left(
self.shards[0].call(request).map_err(TableError::from),
))
} else {
if self.key.is_empty() {
unreachable!("sharded base without a key?");
}
if self.key.len() != 1 {
unimplemented!();
}
let key_col = self.key[0];
let _guard = span.as_ref().map(tracing::Span::enter);
tracing::trace!("shard request");
let mut shard_writes = vec![Vec::new(); self.shards.len()];
for r in i.data.drain(..) {
let shard = {
let key = match r {
TableOperation::Insert(ref r) => &r[key_col],
TableOperation::Delete { ref key } => &key[0],
TableOperation::Update { ref key, .. } => &key[0],
TableOperation::InsertOrUpdate { ref row, .. } => &row[key_col],
};
crate::shard_by(key, self.shards.len())
};
shard_writes[shard].push(r);
}
let wait_for = FuturesUnordered::new();
for (s, rs) in shard_writes.drain(..).enumerate() {
if !rs.is_empty() {
let p = if self.dst_is_local {
unsafe {
LocalOrNot::for_local_transfer(Input {
dst: i.dst,
data: rs,
})
}
} else {
LocalOrNot::new(Input {
dst: i.dst,
data: rs,
})
};
let request = Tagged::from(p);
let span = if span.is_some() {
Some(tracing::trace_span!("table-shard", s))
} else {
None
};
let _guard = span.as_ref().map(tracing::Span::enter);
tracing::trace!("submit request shard");
wait_for.push(self.shards[s].call(request));
} else {
self.shards[s] = self.shards[s].clone()
}
}
future::Either::Right(future::Either::Right(
wait_for
.try_for_each(|_| async { Ok(()) })
.map_err(TableError::from)
.map_ok(Tagged::from),
))
}
}
}
impl Service<Vec<TableOperation>> for Table {
type Error = TableError;
type Response = <TableRpc as Service<Tagged<LocalOrNot<Input>>>>::Response;
#[cfg(not(doc))]
type Future = impl Future<Output = Result<Tagged<()>, TableError>> + Send;
#[cfg(doc)]
type Future = crate::doc_mock::Future<Result<Tagged<()>, TableError>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
for s in &mut self.shards {
ready!(s.poll_ready(cx)).map_err(TableError::from)?;
}
Poll::Ready(Ok(()))
}
fn call(&mut self, ops: Vec<TableOperation>) -> Self::Future {
let i = self.prep_records(ops);
self.input(i)
}
}
impl Table {
pub fn table_name(&self) -> &str {
&self.table_name
}
#[doc(hidden)]
pub fn i_promise_dst_is_same_process(&mut self) {
self.dst_is_local = true;
}
pub fn columns(&self) -> &[String] {
&self.columns
}
pub fn schema(&self) -> Option<&CreateTableStatement> {
self.schema.as_ref()
}
fn inject_dropped_cols(&self, r: &mut TableOperation) {
use std::mem;
let ndropped = self.dropped.len();
if ndropped != 0 {
let dropped = self.dropped.iter().rev();
let r = match *r {
TableOperation::Insert(ref mut row)
| TableOperation::InsertOrUpdate { ref mut row, .. } => row,
_ => unimplemented!("we need to shift the update/delete cols!"),
};
let n = r.len() + ndropped;
let mut hole = n;
let mut last_unmoved = r.len() - 1;
r.resize(n, DataType::None);
for (next_insert, default) in dropped {
hole -= 1;
while hole != next_insert {
r.swap(last_unmoved, hole);
hole -= 1;
if last_unmoved == 0 {
debug_assert_eq!(next_insert, 0);
debug_assert_eq!(hole, 0);
break;
}
last_unmoved -= 1;
}
let current = &mut r[next_insert];
let old = mem::replace(current, default.clone());
debug_assert_eq!(old, DataType::None);
}
}
}
fn prep_records(&self, mut ops: Vec<TableOperation>) -> Input {
for r in &mut ops {
self.inject_dropped_cols(r);
}
Input {
dst: self.node,
data: ops,
}
}
async fn quick_n_dirty<Request, R>(
&mut self,
r: Request,
) -> Result<R, <Self as Service<Request>>::Error>
where
Request: Send + 'static,
Self: Service<Request, Response = Tagged<R>>,
{
future::poll_fn(|cx| self.poll_ready(cx)).await?;
Ok(self.call(r).await?.v)
}
pub async fn insert<V>(&mut self, u: V) -> Result<(), TableError>
where
V: Into<Vec<DataType>>,
{
self.quick_n_dirty(vec![TableOperation::Insert(u.into())])
.await
}
pub async fn perform_all<I, V>(&mut self, i: I) -> Result<(), TableError>
where
I: IntoIterator<Item = V>,
V: Into<TableOperation>,
{
self.quick_n_dirty(i.into_iter().map(Into::into).collect::<Vec<_>>())
.await
}
pub async fn delete<I>(&mut self, key: I) -> Result<(), TableError>
where
I: Into<Vec<DataType>>,
{
self.quick_n_dirty(vec![TableOperation::Delete { key: key.into() }])
.await
}
pub async fn update<V>(&mut self, key: Vec<DataType>, u: V) -> Result<(), TableError>
where
V: IntoIterator<Item = (usize, Modification)>,
{
assert!(
!self.key.is_empty() && self.key_is_primary,
"update operations can only be applied to base nodes with key columns"
);
let mut set = vec![Modification::None; self.columns.len()];
for (coli, m) in u {
if coli >= self.columns.len() {
return Err(TableError::WrongColumnCount(self.columns.len(), coli + 1));
}
set[coli] = m;
}
self.quick_n_dirty(vec![TableOperation::Update { key, set }])
.await
}
pub async fn insert_or_update<V>(
&mut self,
insert: Vec<DataType>,
update: V,
) -> Result<(), TableError>
where
V: IntoIterator<Item = (usize, Modification)>,
{
assert!(
!self.key.is_empty() && self.key_is_primary,
"update operations can only be applied to base nodes with key columns"
);
let mut set = vec![Modification::None; self.columns.len()];
for (coli, m) in update {
if coli >= self.columns.len() {
return Err(TableError::WrongColumnCount(self.columns.len(), coli + 1));
}
set[coli] = m;
}
self.quick_n_dirty(vec![TableOperation::InsertOrUpdate {
row: insert,
update: set,
}])
.await
}
}