from __future__ import annotations
from typing import Any
from gwp_py._convert import value_to_proto
from gwp_py._generated import gql_service_pb2 as gql_pb2
from gwp_py._generated import gql_types_pb2 as types_pb2
from gwp_py.errors import GqlStatusError, TransactionError
from gwp_py.result import ResultCursor
from gwp_py.status import is_exception
class Transaction:
    """Client-side handle for an explicit transaction inside a GQL session.

    Instances are normally created with :meth:`begin`. The object can also be
    used as an async context manager: leaving the ``async with`` block
    normally commits the transaction (unless it was already committed or
    rolled back inside the block), and leaving it via an exception rolls it
    back.
    """

    def __init__(
        self,
        session_id: str,
        transaction_id: str,
        gql_stub: Any,
    ):
        """Wrap a transaction that has already been begun on the server.

        Args:
            session_id: Session the transaction belongs to; sent with every
                request this object makes.
            transaction_id: Server-assigned transaction identifier.
            gql_stub: RPC stub exposing ``Execute`` and awaitable ``Commit``
                and ``Rollback`` calls.
        """
        self._session_id = session_id
        self._transaction_id = transaction_id
        self._gql_stub = gql_stub
        # Local lifecycle flags; each is set once the matching RPC returns.
        self._committed = False
        self._rolled_back = False

    @classmethod
    async def begin(
        cls,
        session_id: str,
        gql_stub: Any,
        mode: int,
    ) -> Transaction:
        """Start a new transaction on the server and return a handle to it.

        Args:
            session_id: Session in which to begin the transaction.
            gql_stub: RPC stub whose ``BeginTransaction`` call is awaited; it
                is also stored on the returned handle.
            mode: Passed through as ``BeginRequest.mode`` (presumably a
                transaction-mode enum value from the generated protos --
                TODO confirm).

        Returns:
            A new :class:`Transaction` bound to the server-issued ID.

        Raises:
            GqlStatusError: The server reported an exception status.
            TransactionError: The server returned an empty transaction ID.
        """
        resp = await gql_stub.BeginTransaction(
            gql_pb2.BeginRequest(session_id=session_id, mode=mode)
        )
        if resp.status and is_exception(resp.status.code):
            raise GqlStatusError(resp.status.code, resp.status.message)
        # Without an ID, later requests could not be bound to this transaction.
        if not resp.transaction_id:
            raise TransactionError("server returned empty transaction ID")
        return cls(session_id, resp.transaction_id, gql_stub)

    @property
    def transaction_id(self) -> str:
        """Server-assigned identifier of this transaction."""
        return self._transaction_id

    async def execute(
        self,
        statement: str,
        parameters: dict[str, Any] | None = None,
    ) -> ResultCursor:
        """Run *statement* inside this transaction.

        Args:
            statement: GQL statement text.
            parameters: Optional mapping of parameter names to Python values;
                each value is converted with ``value_to_proto``.

        Returns:
            A :class:`ResultCursor` wrapping the return value of the
            ``Execute`` call.
        """
        proto_params = {
            name: value_to_proto(value, types_pb2)
            for name, value in (parameters or {}).items()
        }
        # Execute is not awaited here; its return value (the response
        # stream) is handed to ResultCursor as-is.
        stream = self._gql_stub.Execute(
            gql_pb2.ExecuteRequest(
                session_id=self._session_id,
                statement=statement,
                parameters=proto_params,
                transaction_id=self._transaction_id,
            )
        )
        return ResultCursor(stream)

    async def commit(self) -> None:
        """Commit the transaction.

        Raises:
            GqlStatusError: The server reported an exception status.
        """
        resp = await self._gql_stub.Commit(
            gql_pb2.CommitRequest(
                session_id=self._session_id,
                transaction_id=self._transaction_id,
            )
        )
        # Marked committed as soon as the RPC returns, even when the status
        # below is an exception; after this, rollback() is a no-op.
        # NOTE(review): assumes the server ends the transaction when commit
        # fails -- confirm.
        self._committed = True
        if resp.status and is_exception(resp.status.code):
            raise GqlStatusError(resp.status.code, resp.status.message)

    async def rollback(self) -> None:
        """Roll back the transaction.

        Does nothing if the transaction was already committed or rolled back.

        Raises:
            GqlStatusError: The server reported an exception status.
        """
        if self._committed or self._rolled_back:
            return
        resp = await self._gql_stub.Rollback(
            gql_pb2.RollbackRequest(
                session_id=self._session_id,
                transaction_id=self._transaction_id,
            )
        )
        # As in commit(), the flag is set before the status check.
        self._rolled_back = True
        if resp.status and is_exception(resp.status.code):
            raise GqlStatusError(resp.status.code, resp.status.message)

    async def __aenter__(self) -> Transaction:
        """Return this transaction as the ``async with`` target."""
        return self

    async def __aexit__(self, exc_type: type | None, *args: object) -> None:
        """Finish the transaction when leaving an ``async with`` block.

        On an exception, the transaction is rolled back (a no-op if it has
        already finished). On normal exit, it is committed only if it is
        still open. A transaction the caller explicitly committed or rolled
        back inside the block is left alone. The exception, if any, is not
        suppressed.
        """
        if exc_type is not None:
            await self.rollback()
        elif not (self._committed or self._rolled_back):
            # Checking _rolled_back here avoids sending a Commit for a
            # transaction the caller already rolled back explicitly.
            await self.commit()