use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use tokio_stream::Stream;
use tonic::{Request, Response, Status};
use crate::proto;
use crate::proto::gql_service_server::GqlService;
use crate::status as gql_status;
use crate::types::Value;
use super::backend::{GqlBackend, ResultFrame, ResultStream};
use super::{SessionHandle, SessionManager, TransactionHandle, TransactionManager};
/// gRPC front end for the GQL service.
///
/// Holds the storage backend together with the session and transaction
/// bookkeeping that every request is validated against.
pub struct GqlServiceImpl<B>
where
    B: GqlBackend,
{
    /// Backend that actually runs statements and transactions.
    backend: Arc<B>,
    /// Registry of known sessions.
    sessions: SessionManager,
    /// Registry of open transactions and the sessions that own them.
    transactions: TransactionManager,
}
impl<B> GqlServiceImpl<B>
where
    B: GqlBackend,
{
    /// Builds the service from a shared backend plus the session and
    /// transaction managers it should consult on every request.
    pub fn new(
        backend: Arc<B>,
        sessions: SessionManager,
        transactions: TransactionManager,
    ) -> Self {
        Self {
            backend,
            sessions,
            transactions,
        }
    }

    /// Ensures `session_id` refers to a known session and touches it.
    ///
    /// Touching presumably refreshes the session's last-used time. NOTE(review):
    /// confirm this against `SessionManager::touch`.
    ///
    /// # Errors
    ///
    /// Returns `Status::not_found` when the session does not exist.
    async fn validate_session(&self, session_id: &str) -> Result<(), Status> {
        let known = self.sessions.exists(session_id).await;
        if !known {
            return Err(Status::not_found(format!("session {session_id} not found")));
        }
        self.sessions.touch(session_id).await;
        Ok(())
    }
}
#[tonic::async_trait]
impl<B: GqlBackend> GqlService for GqlServiceImpl<B> {
type ExecuteStream = Pin<Box<dyn Stream<Item = Result<proto::ExecuteResponse, Status>> + Send>>;
#[tracing::instrument(skip(self, request), fields(session_id, statement))]
async fn execute(
&self,
request: Request<proto::ExecuteRequest>,
) -> Result<Response<Self::ExecuteStream>, Status> {
let req = request.into_inner();
let span = tracing::Span::current();
span.record("session_id", &req.session_id);
span.record(
"statement",
tracing::field::display(if req.statement.len() > 100 {
&req.statement[..100]
} else {
&req.statement
}),
);
self.validate_session(&req.session_id).await?;
let session = SessionHandle(req.session_id.clone());
let transaction = if let Some(ref tx_id) = req.transaction_id {
self.transactions
.validate(tx_id, &req.session_id)
.await
.map_err(|e| e.to_grpc_status())?;
Some(TransactionHandle(tx_id.clone()))
} else {
None
};
let parameters: HashMap<String, Value> = req
.parameters
.into_iter()
.map(|(k, v)| (k, Value::from(v)))
.collect();
let result_stream = self
.backend
.execute(&session, &req.statement, ¶meters, transaction.as_ref())
.await;
match result_stream {
Ok(stream) => {
let output = ResultStreamAdapter { inner: stream };
Ok(Response::new(Box::pin(output)))
}
Err(err) => {
tracing::warn!(error = %err, "execute failed");
let status = match err.gql_status() {
Some(s) => s.clone(),
None => gql_status::error(gql_status::DATA_EXCEPTION, err.to_string()),
};
let summary_stream = futures_single_response(proto::ExecuteResponse {
frame: Some(proto::execute_response::Frame::Summary(
proto::ResultSummary {
status: Some(status),
warnings: Vec::new(),
rows_affected: 0,
counters: HashMap::new(),
},
)),
});
Ok(Response::new(Box::pin(summary_stream)))
}
}
}
#[tracing::instrument(skip(self, request), fields(session_id))]
async fn begin_transaction(
&self,
request: Request<proto::BeginRequest>,
) -> Result<Response<proto::BeginResponse>, Status> {
let req = request.into_inner();
tracing::Span::current().record("session_id", &req.session_id);
self.validate_session(&req.session_id).await?;
let session = SessionHandle(req.session_id.clone());
let mode =
proto::TransactionMode::try_from(req.mode).unwrap_or(proto::TransactionMode::ReadWrite);
match self.backend.begin_transaction(&session, mode).await {
Ok(handle) => {
let tx_id = handle.0.clone();
if let Err(e) = self
.transactions
.register(&tx_id, &req.session_id, mode)
.await
{
let _ = self.backend.rollback(&session, &handle).await;
tracing::warn!(session_id = %req.session_id, "double begin rejected");
return Ok(Response::new(proto::BeginResponse {
transaction_id: String::new(),
status: Some(gql_status::error(
gql_status::ACTIVE_TRANSACTION,
e.to_string(),
)),
}));
}
self.sessions
.set_active_transaction(&req.session_id, Some(tx_id.clone()))
.await
.ok();
tracing::info!(session_id = %req.session_id, transaction_id = %tx_id, "transaction started");
Ok(Response::new(proto::BeginResponse {
transaction_id: tx_id,
status: Some(gql_status::success()),
}))
}
Err(err) => {
let status = match err.gql_status() {
Some(s) => s.clone(),
None => gql_status::error(gql_status::ACTIVE_TRANSACTION, err.to_string()),
};
Ok(Response::new(proto::BeginResponse {
transaction_id: String::new(),
status: Some(status),
}))
}
}
}
#[tracing::instrument(skip(self, request), fields(session_id, transaction_id))]
async fn commit(
&self,
request: Request<proto::CommitRequest>,
) -> Result<Response<proto::CommitResponse>, Status> {
let req = request.into_inner();
let span = tracing::Span::current();
span.record("session_id", &req.session_id);
span.record("transaction_id", &req.transaction_id);
self.validate_session(&req.session_id).await?;
if let Err(e) = self
.transactions
.validate(&req.transaction_id, &req.session_id)
.await
{
return Ok(Response::new(proto::CommitResponse {
status: Some(gql_status::error(
gql_status::INVALID_TRANSACTION_STATE,
e.to_string(),
)),
}));
}
let session = SessionHandle(req.session_id.clone());
let transaction = TransactionHandle(req.transaction_id.clone());
match self.backend.commit(&session, &transaction).await {
Ok(()) => {
self.transactions.remove(&req.transaction_id).await.ok();
self.sessions
.set_active_transaction(&req.session_id, None)
.await
.ok();
tracing::info!("transaction committed");
Ok(Response::new(proto::CommitResponse {
status: Some(gql_status::success()),
}))
}
Err(err) => {
tracing::warn!(error = %err, "commit failed");
let status = match err.gql_status() {
Some(s) => s.clone(),
None => gql_status::error(gql_status::TRANSACTION_ROLLBACK, err.to_string()),
};
Ok(Response::new(proto::CommitResponse {
status: Some(status),
}))
}
}
}
#[tracing::instrument(skip(self, request), fields(session_id, transaction_id))]
async fn rollback(
&self,
request: Request<proto::RollbackRequest>,
) -> Result<Response<proto::RollbackResponse>, Status> {
let req = request.into_inner();
let span = tracing::Span::current();
span.record("session_id", &req.session_id);
span.record("transaction_id", &req.transaction_id);
self.validate_session(&req.session_id).await?;
if let Err(e) = self
.transactions
.validate(&req.transaction_id, &req.session_id)
.await
{
return Ok(Response::new(proto::RollbackResponse {
status: Some(gql_status::error(
gql_status::INVALID_TRANSACTION_STATE,
e.to_string(),
)),
}));
}
let session = SessionHandle(req.session_id.clone());
let transaction = TransactionHandle(req.transaction_id.clone());
match self.backend.rollback(&session, &transaction).await {
Ok(()) => {
self.transactions.remove(&req.transaction_id).await.ok();
self.sessions
.set_active_transaction(&req.session_id, None)
.await
.ok();
tracing::info!("transaction rolled back");
Ok(Response::new(proto::RollbackResponse {
status: Some(gql_status::success()),
}))
}
Err(err) => {
tracing::warn!(error = %err, "rollback failed");
let status = match err.gql_status() {
Some(s) => s.clone(),
None => gql_status::error(gql_status::TRANSACTION_ROLLBACK, err.to_string()),
};
Ok(Response::new(proto::RollbackResponse {
status: Some(status),
}))
}
}
}
}
/// Wraps a backend result stream and exposes it as a stream of
/// `proto::ExecuteResponse` items suitable for the gRPC `execute` response.
struct ResultStreamAdapter {
    /// The backend stream being translated.
    inner: Pin<Box<dyn ResultStream>>,
}
impl Stream for ResultStreamAdapter {
type Item = Result<proto::ExecuteResponse, Status>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
match self.inner.as_mut().poll_next(cx) {
std::task::Poll::Ready(Some(Ok(frame))) => {
let response = match frame {
ResultFrame::Header(h) => proto::ExecuteResponse {
frame: Some(proto::execute_response::Frame::Header(h)),
},
ResultFrame::Batch(b) => proto::ExecuteResponse {
frame: Some(proto::execute_response::Frame::RowBatch(b)),
},
ResultFrame::Summary(s) => proto::ExecuteResponse {
frame: Some(proto::execute_response::Frame::Summary(s)),
},
};
std::task::Poll::Ready(Some(Ok(response)))
}
std::task::Poll::Ready(Some(Err(err))) => {
let status = match err.gql_status() {
Some(s) => s.clone(),
None => gql_status::error(gql_status::DATA_EXCEPTION, err.to_string()),
};
let response = proto::ExecuteResponse {
frame: Some(proto::execute_response::Frame::Summary(
proto::ResultSummary {
status: Some(status),
warnings: Vec::new(),
rows_affected: 0,
counters: HashMap::new(),
},
)),
};
std::task::Poll::Ready(Some(Ok(response)))
}
std::task::Poll::Ready(None) => std::task::Poll::Ready(None),
std::task::Poll::Pending => std::task::Poll::Pending,
}
}
}
/// Returns a stream that yields `response` once, wrapped in `Ok`, and then ends.
fn futures_single_response(
    response: proto::ExecuteResponse,
) -> impl Stream<Item = Result<proto::ExecuteResponse, Status>> {
    let only_item: Result<proto::ExecuteResponse, Status> = Ok(response);
    tokio_stream::once(only_item)
}