gwp 0.2.1

A standalone, pure Rust gRPC wire protocol for GQL (ISO/IEC 39075)
Documentation
//! `GqlService` gRPC implementation.
//!
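//! All GQL-domain errors are returned as GQLSTATUS codes in the
//! response payload, with an OK gRPC status. A non-OK gRPC status is
//! returned only for transport-level failures, for requests that name
//! an unknown session (`NOT_FOUND`), and, in `execute`, for a
//! `transaction_id` that fails validation against the session.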

use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;

use tokio_stream::Stream;
use tonic::{Request, Response, Status};

use crate::proto;
use crate::proto::gql_service_server::GqlService;
use crate::status as gql_status;
use crate::types::Value;

use super::backend::{GqlBackend, ResultFrame, ResultStream};
use super::{SessionHandle, SessionManager, TransactionHandle, TransactionManager};

/// Implementation of the `GqlService` gRPC service.
///
/// Statement execution and transaction work are delegated to the backend `B`;
/// session and transaction bookkeeping is kept in the two managers below.
pub struct GqlServiceImpl<B: GqlBackend> {
    // Backend that executes statements and begins/commits/rolls back transactions.
    backend: Arc<B>,
    // Session registry: checked (and touched) on every request.
    sessions: SessionManager,
    // Transaction registry: records which transaction ids belong to which session.
    transactions: TransactionManager,
}

impl<B: GqlBackend> GqlServiceImpl<B> {
    /// Build the service from a backend plus the session and transaction
    /// registries it consults on each request.
    pub fn new(
        backend: Arc<B>,
        sessions: SessionManager,
        transactions: TransactionManager,
    ) -> Self {
        Self {
            sessions,
            transactions,
            backend,
        }
    }

    /// Check that `session_id` names a known session and, if so, refresh its
    /// activity timestamp.
    ///
    /// # Errors
    ///
    /// Returns a gRPC `NOT_FOUND` status when the session does not exist.
    async fn validate_session(&self, session_id: &str) -> Result<(), Status> {
        if !self.sessions.exists(session_id).await {
            return Err(Status::not_found(format!("session {session_id} not found")));
        }

        self.sessions.touch(session_id).await;
        Ok(())
    }
}

#[tonic::async_trait]
impl<B: GqlBackend> GqlService for GqlServiceImpl<B> {
    type ExecuteStream = Pin<Box<dyn Stream<Item = Result<proto::ExecuteResponse, Status>> + Send>>;

    #[tracing::instrument(skip(self, request), fields(session_id, statement))]
    async fn execute(
        &self,
        request: Request<proto::ExecuteRequest>,
    ) -> Result<Response<Self::ExecuteStream>, Status> {
        let req = request.into_inner();
        let span = tracing::Span::current();
        span.record("session_id", &req.session_id);
        span.record(
            "statement",
            tracing::field::display(if req.statement.len() > 100 {
                &req.statement[..100]
            } else {
                &req.statement
            }),
        );

        self.validate_session(&req.session_id).await?;

        let session = SessionHandle(req.session_id.clone());
        let transaction = if let Some(ref tx_id) = req.transaction_id {
            // Validate the transaction belongs to this session
            self.transactions
                .validate(tx_id, &req.session_id)
                .await
                .map_err(|e| e.to_grpc_status())?;
            Some(TransactionHandle(tx_id.clone()))
        } else {
            None
        };

        let parameters: HashMap<String, Value> = req
            .parameters
            .into_iter()
            .map(|(k, v)| (k, Value::from(v)))
            .collect();

        let result_stream = self
            .backend
            .execute(&session, &req.statement, &parameters, transaction.as_ref())
            .await;

        match result_stream {
            Ok(stream) => {
                let output = ResultStreamAdapter { inner: stream };
                Ok(Response::new(Box::pin(output)))
            }
            Err(err) => {
                tracing::warn!(error = %err, "execute failed");
                // GQL errors go in the response payload, not gRPC status
                let status = match err.gql_status() {
                    Some(s) => s.clone(),
                    None => gql_status::error(gql_status::DATA_EXCEPTION, err.to_string()),
                };

                let summary_stream = futures_single_response(proto::ExecuteResponse {
                    frame: Some(proto::execute_response::Frame::Summary(
                        proto::ResultSummary {
                            status: Some(status),
                            warnings: Vec::new(),
                            rows_affected: 0,
                            counters: HashMap::new(),
                        },
                    )),
                });

                Ok(Response::new(Box::pin(summary_stream)))
            }
        }
    }

    #[tracing::instrument(skip(self, request), fields(session_id))]
    async fn begin_transaction(
        &self,
        request: Request<proto::BeginRequest>,
    ) -> Result<Response<proto::BeginResponse>, Status> {
        let req = request.into_inner();
        tracing::Span::current().record("session_id", &req.session_id);
        self.validate_session(&req.session_id).await?;

        let session = SessionHandle(req.session_id.clone());
        let mode =
            proto::TransactionMode::try_from(req.mode).unwrap_or(proto::TransactionMode::ReadWrite);

        match self.backend.begin_transaction(&session, mode).await {
            Ok(handle) => {
                let tx_id = handle.0.clone();

                if let Err(e) = self
                    .transactions
                    .register(&tx_id, &req.session_id, mode)
                    .await
                {
                    // Roll back the backend transaction if we can't register it
                    let _ = self.backend.rollback(&session, &handle).await;
                    tracing::warn!(session_id = %req.session_id, "double begin rejected");
                    return Ok(Response::new(proto::BeginResponse {
                        transaction_id: String::new(),
                        status: Some(gql_status::error(
                            gql_status::ACTIVE_TRANSACTION,
                            e.to_string(),
                        )),
                    }));
                }

                self.sessions
                    .set_active_transaction(&req.session_id, Some(tx_id.clone()))
                    .await
                    .ok();

                tracing::info!(session_id = %req.session_id, transaction_id = %tx_id, "transaction started");

                Ok(Response::new(proto::BeginResponse {
                    transaction_id: tx_id,
                    status: Some(gql_status::success()),
                }))
            }
            Err(err) => {
                let status = match err.gql_status() {
                    Some(s) => s.clone(),
                    None => gql_status::error(gql_status::ACTIVE_TRANSACTION, err.to_string()),
                };
                Ok(Response::new(proto::BeginResponse {
                    transaction_id: String::new(),
                    status: Some(status),
                }))
            }
        }
    }

    #[tracing::instrument(skip(self, request), fields(session_id, transaction_id))]
    async fn commit(
        &self,
        request: Request<proto::CommitRequest>,
    ) -> Result<Response<proto::CommitResponse>, Status> {
        let req = request.into_inner();
        let span = tracing::Span::current();
        span.record("session_id", &req.session_id);
        span.record("transaction_id", &req.transaction_id);
        self.validate_session(&req.session_id).await?;

        if let Err(e) = self
            .transactions
            .validate(&req.transaction_id, &req.session_id)
            .await
        {
            return Ok(Response::new(proto::CommitResponse {
                status: Some(gql_status::error(
                    gql_status::INVALID_TRANSACTION_STATE,
                    e.to_string(),
                )),
            }));
        }

        let session = SessionHandle(req.session_id.clone());
        let transaction = TransactionHandle(req.transaction_id.clone());

        match self.backend.commit(&session, &transaction).await {
            Ok(()) => {
                self.transactions.remove(&req.transaction_id).await.ok();
                self.sessions
                    .set_active_transaction(&req.session_id, None)
                    .await
                    .ok();

                tracing::info!("transaction committed");

                Ok(Response::new(proto::CommitResponse {
                    status: Some(gql_status::success()),
                }))
            }
            Err(err) => {
                tracing::warn!(error = %err, "commit failed");
                let status = match err.gql_status() {
                    Some(s) => s.clone(),
                    None => gql_status::error(gql_status::TRANSACTION_ROLLBACK, err.to_string()),
                };
                Ok(Response::new(proto::CommitResponse {
                    status: Some(status),
                }))
            }
        }
    }

    #[tracing::instrument(skip(self, request), fields(session_id, transaction_id))]
    async fn rollback(
        &self,
        request: Request<proto::RollbackRequest>,
    ) -> Result<Response<proto::RollbackResponse>, Status> {
        let req = request.into_inner();
        let span = tracing::Span::current();
        span.record("session_id", &req.session_id);
        span.record("transaction_id", &req.transaction_id);
        self.validate_session(&req.session_id).await?;

        if let Err(e) = self
            .transactions
            .validate(&req.transaction_id, &req.session_id)
            .await
        {
            return Ok(Response::new(proto::RollbackResponse {
                status: Some(gql_status::error(
                    gql_status::INVALID_TRANSACTION_STATE,
                    e.to_string(),
                )),
            }));
        }

        let session = SessionHandle(req.session_id.clone());
        let transaction = TransactionHandle(req.transaction_id.clone());

        match self.backend.rollback(&session, &transaction).await {
            Ok(()) => {
                self.transactions.remove(&req.transaction_id).await.ok();
                self.sessions
                    .set_active_transaction(&req.session_id, None)
                    .await
                    .ok();

                tracing::info!("transaction rolled back");

                Ok(Response::new(proto::RollbackResponse {
                    status: Some(gql_status::success()),
                }))
            }
            Err(err) => {
                tracing::warn!(error = %err, "rollback failed");
                let status = match err.gql_status() {
                    Some(s) => s.clone(),
                    None => gql_status::error(gql_status::TRANSACTION_ROLLBACK, err.to_string()),
                };
                Ok(Response::new(proto::RollbackResponse {
                    status: Some(status),
                }))
            }
        }
    }
}

// ============================================================================
// Stream adapters
// ============================================================================

/// Adapts a `ResultStream` into a tonic-compatible `Stream`.
struct ResultStreamAdapter {
    // Backend result stream whose items are translated into `ExecuteResponse` frames.
    inner: Pin<Box<dyn ResultStream>>,
}

impl Stream for ResultStreamAdapter {
    type Item = Result<proto::ExecuteResponse, Status>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        match self.inner.as_mut().poll_next(cx) {
            std::task::Poll::Ready(Some(Ok(frame))) => {
                let response = match frame {
                    ResultFrame::Header(h) => proto::ExecuteResponse {
                        frame: Some(proto::execute_response::Frame::Header(h)),
                    },
                    ResultFrame::Batch(b) => proto::ExecuteResponse {
                        frame: Some(proto::execute_response::Frame::RowBatch(b)),
                    },
                    ResultFrame::Summary(s) => proto::ExecuteResponse {
                        frame: Some(proto::execute_response::Frame::Summary(s)),
                    },
                };
                std::task::Poll::Ready(Some(Ok(response)))
            }
            std::task::Poll::Ready(Some(Err(err))) => {
                // Convert backend error to a summary frame with GQLSTATUS
                let status = match err.gql_status() {
                    Some(s) => s.clone(),
                    None => gql_status::error(gql_status::DATA_EXCEPTION, err.to_string()),
                };
                let response = proto::ExecuteResponse {
                    frame: Some(proto::execute_response::Frame::Summary(
                        proto::ResultSummary {
                            status: Some(status),
                            warnings: Vec::new(),
                            rows_affected: 0,
                            counters: HashMap::new(),
                        },
                    )),
                };
                std::task::Poll::Ready(Some(Ok(response)))
            }
            std::task::Poll::Ready(None) => std::task::Poll::Ready(None),
            std::task::Poll::Pending => std::task::Poll::Pending,
        }
    }
}

/// Wrap `response` in a one-shot stream.
///
/// The stream yields `Ok(response)` exactly once and then reports
/// end-of-stream.
fn futures_single_response(
    response: proto::ExecuteResponse,
) -> impl Stream<Item = Result<proto::ExecuteResponse, Status>> {
    let only_item: Result<proto::ExecuteResponse, Status> = Ok(response);
    tokio_stream::iter(std::iter::once(only_item))
}