use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tonic::{Request, Response, Status};
use crate::proto;
use crate::proto::session_service_server::SessionService;
use super::auth::{AuthInfo, AuthValidator};
use super::backend::{GqlBackend, ResetTarget, SessionConfig, SessionProperty};
use super::{SessionManager, TransactionManager};
/// gRPC implementation of the GQL `SessionService`.
///
/// Ties a [`GqlBackend`] to the in-process [`SessionManager`] and
/// [`TransactionManager`], optionally checking client credentials with an
/// [`AuthValidator`] during the handshake.
///
/// The `GqlBackend` bound lives on the `impl` blocks rather than on the struct
/// definition, so it is not repeated at every mention of the type.
pub struct SessionServiceImpl<B> {
    /// Backend through which sessions are created, configured, reset and closed.
    backend: Arc<B>,
    /// Registry of live session ids used for existence checks.
    sessions: SessionManager,
    /// Tracks open transactions per session so they can be rolled back on close.
    transactions: TransactionManager,
    /// Optional credential validator; when `None`, handshakes are not authenticated.
    auth: Option<Arc<dyn AuthValidator>>,
}
impl<B: GqlBackend> SessionServiceImpl<B> {
    /// Assembles the service from its collaborators.
    ///
    /// Passing `auth = None` disables credential checking: `handshake` then
    /// accepts any client and attaches no auth info to the new session.
    pub fn new(
        backend: Arc<B>,
        sessions: SessionManager,
        transactions: TransactionManager,
        auth: Option<Arc<dyn AuthValidator>>,
    ) -> Self {
        Self {
            auth,
            backend,
            transactions,
            sessions,
        }
    }
}
#[tonic::async_trait]
impl<B: GqlBackend> SessionService for SessionServiceImpl<B> {
#[tracing::instrument(skip(self, request))]
async fn handshake(
&self,
request: Request<proto::HandshakeRequest>,
) -> Result<Response<proto::HandshakeResponse>, Status> {
let req = request.into_inner();
let auth_info: Option<AuthInfo> = if let Some(ref auth) = self.auth {
if let Some(ref creds) = req.credentials {
let info = auth.validate(creds).await.map_err(|_| {
tracing::warn!("authentication failed");
Status::unauthenticated("invalid credentials")
})?;
Some(info)
} else {
tracing::warn!("handshake missing credentials");
return Err(Status::unauthenticated("credentials required"));
}
} else {
None
};
let config = SessionConfig {
protocol_version: req.protocol_version,
client_info: req.client_info,
auth_info,
};
let handle = self
.backend
.create_session(&config)
.await
.map_err(|e| e.to_grpc_status())?;
if let Err(e) = self.sessions.register(&handle.0).await {
let _ = self.backend.close_session(&handle).await;
tracing::warn!("session limit reached");
return Err(Status::resource_exhausted(e.to_string()));
}
tracing::info!(session_id = %handle.0, "session created");
Ok(Response::new(proto::HandshakeResponse {
protocol_version: 1,
session_id: handle.0,
server_info: Some(proto::ServerInfo {
name: "gql-wire-protocol".to_owned(),
version: env!("CARGO_PKG_VERSION").to_owned(),
features: Vec::new(),
}),
limits: std::collections::HashMap::new(),
}))
}
#[tracing::instrument(skip(self, request), fields(session_id))]
async fn configure(
&self,
request: Request<proto::ConfigureRequest>,
) -> Result<Response<proto::ConfigureResponse>, Status> {
let req = request.into_inner();
let session_id = &req.session_id;
tracing::Span::current().record("session_id", session_id);
if !self.sessions.exists(session_id).await {
return Err(Status::not_found(format!("session {session_id} not found")));
}
self.sessions.touch(session_id).await;
let property = match req.property {
Some(proto::configure_request::Property::Schema(s)) => SessionProperty::Schema(s),
Some(proto::configure_request::Property::Graph(g)) => SessionProperty::Graph(g),
Some(proto::configure_request::Property::TimeZoneOffsetMinutes(tz)) => {
SessionProperty::TimeZone(tz)
}
Some(proto::configure_request::Property::Parameter(p)) => SessionProperty::Parameter {
name: p.name,
value: p
.value
.map_or(crate::types::Value::Null, crate::types::Value::from),
},
None => return Err(Status::invalid_argument("no property specified")),
};
self.backend
.configure_session(&super::SessionHandle(session_id.clone()), property.clone())
.await
.map_err(|e| e.to_grpc_status())?;
self.sessions
.configure(session_id, &property)
.await
.map_err(|e| e.to_grpc_status())?;
Ok(Response::new(proto::ConfigureResponse {}))
}
#[tracing::instrument(skip(self, request), fields(session_id))]
async fn reset(
&self,
request: Request<proto::ResetRequest>,
) -> Result<Response<proto::ResetResponse>, Status> {
let req = request.into_inner();
let session_id = &req.session_id;
tracing::Span::current().record("session_id", session_id);
if !self.sessions.exists(session_id).await {
return Err(Status::not_found(format!("session {session_id} not found")));
}
self.sessions.touch(session_id).await;
let target = match proto::ResetTarget::try_from(req.target) {
Ok(proto::ResetTarget::ResetAll) => ResetTarget::All,
Ok(proto::ResetTarget::ResetSchema) => ResetTarget::Schema,
Ok(proto::ResetTarget::ResetGraph) => ResetTarget::Graph,
Ok(proto::ResetTarget::ResetTimeZone) => ResetTarget::TimeZone,
Ok(proto::ResetTarget::ResetParameters) => ResetTarget::Parameters,
Err(_) => return Err(Status::invalid_argument("invalid reset target")),
};
self.backend
.reset_session(&super::SessionHandle(session_id.clone()), target)
.await
.map_err(|e| e.to_grpc_status())?;
self.sessions
.reset(session_id, target)
.await
.map_err(|e| e.to_grpc_status())?;
Ok(Response::new(proto::ResetResponse {}))
}
#[tracing::instrument(skip(self, request), fields(session_id))]
async fn close_session(
&self,
request: Request<proto::CloseSessionRequest>,
) -> Result<Response<proto::CloseSessionResponse>, Status> {
let req = request.into_inner();
let session_id = &req.session_id;
tracing::Span::current().record("session_id", session_id);
if !self.sessions.exists(session_id).await {
return Err(Status::not_found(format!("session {session_id} not found")));
}
let active_txns = self.transactions.remove_for_session(session_id).await;
for tx_id in &active_txns {
tracing::info!(session_id, transaction_id = %tx_id, "rolling back transaction on close");
let _ = self
.backend
.rollback(
&super::SessionHandle(session_id.clone()),
&super::TransactionHandle(tx_id.clone()),
)
.await;
}
self.backend
.close_session(&super::SessionHandle(session_id.clone()))
.await
.map_err(|e| e.to_grpc_status())?;
self.sessions.remove(session_id).await;
tracing::info!(session_id, "session closed");
Ok(Response::new(proto::CloseSessionResponse {}))
}
#[tracing::instrument(skip(self, request), fields(session_id))]
async fn ping(
&self,
request: Request<proto::PingRequest>,
) -> Result<Response<proto::PongResponse>, Status> {
let req = request.into_inner();
tracing::Span::current().record("session_id", &req.session_id);
if !self.sessions.exists(&req.session_id).await {
return Err(Status::not_found(format!(
"session {} not found",
req.session_id
)));
}
self.sessions.touch(&req.session_id).await;
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_or(0, |d| i64::try_from(d.as_millis()).unwrap_or(i64::MAX));
Ok(Response::new(proto::PongResponse { timestamp }))
}
}