use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
use crate::error::{Error, ProtocolError, TransportError};
use crate::protocol::{Request, Response};
use crate::transport::BoxedTransport;
/// Capacity of the bounded channel that carries queued requests from
/// `DispatcherHandle`s to the dispatcher task.
const SEND_QUEUE_CAPACITY: usize = 64;
/// A serialized request queued for the dispatcher task, together with the
/// channel on which its outcome is delivered.
struct Outbound {
/// Request id; the dispatcher task keys its pending-reply map on this value.
id: String,
/// Serialized request payload handed to the transport.
bytes: Bytes,
/// One-shot channel that receives the matching response or an error.
reply: oneshot::Sender<Result<Response, Error>>,
}
/// Cloneable handle used to submit requests to the dispatcher task.
#[derive(Clone, Debug)]
pub(crate) struct DispatcherHandle {
/// Sending half of the bounded queue read by the dispatcher task.
tx: mpsc::Sender<Outbound>,
}
impl DispatcherHandle {
pub(crate) async fn send(&self, request: Request) -> crate::Result<Response> {
let id = request_id(&request).to_string();
let bytes = serde_json::to_vec(&request)
.map(Bytes::from)
.map_err(|e| Error::from(ProtocolError::Json(e)))?;
let (reply_tx, reply_rx) = oneshot::channel();
self.tx
.send(Outbound {
id,
bytes,
reply: reply_tx,
})
.await
.map_err(|_| Error::from(TransportError::Closed))?;
reply_rx
.await
.map_err(|_| Error::from(TransportError::Closed))?
}
}
/// Owner of the spawned dispatcher task; dropping it aborts that task.
pub(crate) struct Dispatcher {
/// Handle that is cloned out to callers via `handle()`.
handle: DispatcherHandle,
/// Join handle of the spawned task running `run`.
join: JoinHandle<()>,
}
impl Dispatcher {
pub(crate) fn handle(&self) -> DispatcherHandle {
self.handle.clone()
}
pub(crate) fn spawn(transport: BoxedTransport, in_flight: Arc<AtomicU32>) -> Self {
let (tx, rx) = mpsc::channel::<Outbound>(SEND_QUEUE_CAPACITY);
let handle = DispatcherHandle { tx };
let join = tokio::spawn(run(transport, rx, in_flight));
Self { handle, join }
}
}
impl Drop for Dispatcher {
    /// Aborts the spawned dispatcher task when its owner goes away.
    fn drop(&mut self) {
        let Self { join, .. } = self;
        join.abort();
    }
}
/// Returns the `id` carried by `request`, whichever variant it is.
///
/// The match lists every variant explicitly (no wildcard), so adding a
/// variant to `Request` is a compile error here until it is handled.
fn request_id(request: &Request) -> &str {
    match request {
        Request::Connect { id, .. }
        | Request::Sql { id, .. }
        | Request::PrepareSql { id, .. }
        | Request::PrepareSqlExecute { id, .. }
        | Request::Execute { id, .. }
        | Request::SqlMore { id, .. }
        | Request::SqlClose { id, .. }
        | Request::Cl { id, .. }
        | Request::GetVersion { id }
        | Request::GetDbJob { id }
        | Request::SetConfig { id, .. }
        | Request::GetTraceData { id }
        | Request::Dove { id, .. }
        | Request::Ping { id }
        | Request::Exit { id } => id,
    }
}
#[allow(clippy::match_same_arms)]
fn response_id(response: &Response) -> &str {
match response {
Response::Connected { id, .. } => id,
Response::Pong { id } => id,
Response::Exited { id } => id,
Response::QueryResult(q) => &q.id,
Response::PreparedStatement { id, .. } => id,
Response::SqlClosed { id, .. } => id,
Response::ClResult { id, .. } => id,
Response::Version { id, .. } => id,
Response::DbJob { id, .. } => id,
Response::ConfigSet { id, .. } => id,
Response::TraceData { id, .. } => id,
Response::DoveResult { id, .. } => id,
Response::Error(e) => &e.id,
}
}
/// Body of the dispatcher task.
///
/// Loops over two event sources:
///
/// * `rx` — queued [`Outbound`] requests. Each one is written to `transport`;
///   on success its reply channel is stored in a map keyed by request id.
/// * `transport` — incoming frames. Each frame is decoded as a JSON
///   [`Response`] and delivered to the reply channel stored under the
///   response's id. Responses whose id has no stored channel are ignored.
///
/// `in_flight` is kept equal to the number of stored reply channels: it is
/// incremented when a new id is stored and decremented whenever an entry is
/// answered or failed.
///
/// The task ends when the request queue is closed, when writing to the
/// transport fails, when the transport yields an error or ends, or when a
/// frame cannot be decoded. On every one of those paths all stored reply
/// channels are failed and the counter is brought back down.
async fn run(
    mut transport: BoxedTransport,
    mut rx: mpsc::Receiver<Outbound>,
    in_flight: Arc<AtomicU32>,
) {
    let mut pending: HashMap<String, oneshot::Sender<Result<Response, Error>>> = HashMap::new();
    loop {
        tokio::select! {
            outbound = rx.recv() => {
                match outbound {
                    Some(Outbound { id, bytes, reply }) => {
                        if let Err(e) = transport.send(bytes).await {
                            // The request that hit the failure gets the real
                            // error; everything already waiting is failed as
                            // closed.
                            let _ = reply.send(Err(Error::from(e)));
                            drain_pending(&mut pending, &in_flight, TransportError::Closed);
                            return;
                        }
                        if let Some(displaced) = pending.insert(id, reply) {
                            // Same id submitted twice: the earlier waiter can
                            // no longer be matched to a response. It took the
                            // slot already counted in `in_flight`, so fail it
                            // and leave the counter untouched instead of
                            // counting one map entry twice.
                            let _ = displaced.send(Err(Error::from(TransportError::Closed)));
                        } else {
                            in_flight.fetch_add(1, Ordering::Relaxed);
                        }
                    }
                    None => {
                        // Request queue closed. Fail whatever is still stored
                        // so `in_flight` is released, as on the other exits.
                        drain_pending(&mut pending, &in_flight, TransportError::Closed);
                        return;
                    }
                }
            }
            frame = transport.next() => {
                match frame {
                    Some(Ok(bytes)) => match serde_json::from_slice::<Response>(&bytes) {
                        Ok(response) => {
                            let id = response_id(&response).to_owned();
                            if let Some(reply) = pending.remove(&id) {
                                in_flight.fetch_sub(1, Ordering::Relaxed);
                                // The receiver may already be gone; ignore.
                                let _ = reply.send(Ok(response));
                            }
                        }
                        Err(e) => {
                            // An undecodable frame ends the task.
                            drain_pending_with_error(
                                &mut pending,
                                &in_flight,
                                Error::from(ProtocolError::Json(e)),
                            );
                            return;
                        }
                    },
                    Some(Err(e)) => {
                        drain_pending_with_error(&mut pending, &in_flight, Error::from(e));
                        return;
                    }
                    None => {
                        // The transport stream ended.
                        drain_pending(&mut pending, &in_flight, TransportError::Closed);
                        return;
                    }
                }
            }
        }
    }
}
fn drain_pending(
pending: &mut HashMap<String, oneshot::Sender<Result<Response, Error>>>,
in_flight: &Arc<AtomicU32>,
closed: TransportError,
) {
let mut iter = pending.drain();
if let Some((_id, reply)) = iter.next() {
in_flight.fetch_sub(1, Ordering::Relaxed);
let _ = reply.send(Err(Error::from(closed)));
}
for (_id, reply) in iter {
in_flight.fetch_sub(1, Ordering::Relaxed);
let _ = reply.send(Err(Error::from(TransportError::Closed)));
}
}
fn drain_pending_with_error(
pending: &mut HashMap<String, oneshot::Sender<Result<Response, Error>>>,
in_flight: &Arc<AtomicU32>,
err: Error,
) {
let mut iter = pending.drain();
if let Some((_id, reply)) = iter.next() {
in_flight.fetch_sub(1, Ordering::Relaxed);
let _ = reply.send(Err(err));
}
for (_id, reply) in iter {
in_flight.fetch_sub(1, Ordering::Relaxed);
let _ = reply.send(Err(Error::from(TransportError::Closed)));
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{ClMessage, ErrorResponse, QueryMetaData, QueryResult};

    /// Id stamped on every fixture; the extractors must return it unchanged.
    const EXPECTED_ID: &str = "test-id";

    /// Owned copy of [`EXPECTED_ID`] for a fixture's `id` field.
    fn owned_id() -> String {
        EXPECTED_ID.to_owned()
    }

    /// `request_id` must return the carried id for each `Request` variant.
    #[test]
    fn test_request_id_returns_carried_id_for_every_variant() {
        let requests = [
            Request::Connect {
                id: owned_id(),
                user: "u".into(),
                password: "p".into(),
            },
            Request::Sql {
                id: owned_id(),
                sql: "SELECT 1".into(),
                rows: None,
                parameters: None,
            },
            Request::PrepareSql {
                id: owned_id(),
                sql: "SELECT 1".into(),
            },
            Request::PrepareSqlExecute {
                id: owned_id(),
                sql: "SELECT 1".into(),
                parameters: None,
                rows: None,
            },
            Request::Execute {
                id: owned_id(),
                cont_id: "c".into(),
                parameters: None,
            },
            Request::SqlMore {
                id: owned_id(),
                cont_id: "c".into(),
                rows: 10,
            },
            Request::SqlClose {
                id: owned_id(),
                cont_id: "c".into(),
            },
            Request::Cl {
                id: owned_id(),
                cmd: "DSPLIB".into(),
            },
            Request::GetVersion { id: owned_id() },
            Request::GetDbJob { id: owned_id() },
            Request::SetConfig {
                id: owned_id(),
                tracedest: "FILE".into(),
                tracelevel: "ERRORS".into(),
            },
            Request::GetTraceData { id: owned_id() },
            Request::Dove {
                id: owned_id(),
                sql: "SELECT 1".into(),
            },
            Request::Ping { id: owned_id() },
            Request::Exit { id: owned_id() },
        ];

        for request in &requests {
            assert_eq!(request_id(request), EXPECTED_ID);
        }
    }

    /// `response_id` must return the carried id for each `Response` variant.
    #[test]
    fn test_response_id_returns_carried_id_for_every_variant() {
        let query_result = QueryResult {
            id: owned_id(),
            success: true,
            has_results: false,
            update_count: -1,
            cont_id: None,
            is_done: true,
            metadata: QueryMetaData {
                column_count: 0,
                columns: vec![],
            },
            data: vec![],
            execution_time: 0.0,
        };
        let error_response = ErrorResponse {
            id: owned_id(),
            success: false,
            sqlstate: None,
            sqlcode: None,
            error: None,
            job: None,
        };

        let responses = [
            Response::Connected {
                id: owned_id(),
                version: "1".into(),
                job: "J".into(),
            },
            Response::Pong { id: owned_id() },
            Response::Exited { id: owned_id() },
            Response::QueryResult(query_result),
            Response::PreparedStatement {
                id: owned_id(),
                success: true,
                cont_id: "c".into(),
                execution_time: 0.0,
            },
            Response::SqlClosed {
                id: owned_id(),
                success: true,
            },
            Response::ClResult {
                id: owned_id(),
                success: true,
                messages: vec![ClMessage {
                    id: None,
                    kind: None,
                    text: None,
                }],
            },
            Response::Version {
                id: owned_id(),
                success: true,
                version: "1".into(),
            },
            Response::DbJob {
                id: owned_id(),
                success: true,
                job: "J".into(),
            },
            Response::ConfigSet {
                id: owned_id(),
                success: true,
            },
            Response::TraceData {
                id: owned_id(),
                success: true,
                tracedata: String::new(),
            },
            Response::DoveResult {
                id: owned_id(),
                success: true,
                result: serde_json::json!({}),
            },
            Response::Error(error_response),
        ];

        for response in &responses {
            assert_eq!(response_id(response), EXPECTED_ID);
        }
    }
}