---
source: crates/nautilus-codegen/tests/snapshot_tests.rs
assertion_line: 414
expression: async_code
---
"""User model and related types."""
from __future__ import annotations

from dataclasses import asdict, is_dataclass
from datetime import datetime
from decimal import Decimal
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, ClassVar
from typing import TypedDict, NotRequired, Required, Literal, Union
from uuid import UUID

from pydantic import BaseModel, Field

from .._internal.descriptors import classproperty
if TYPE_CHECKING:
from .._internal.client import NautilusClient
class User(BaseModel):
    """User model mapped to 'User' table."""

    # Scalar columns.  Both default to None so partial projections (select=)
    # can still hydrate an instance.
    id: Optional[int] = None
    name: Optional[str] = None

    # ClassVar keeps pydantic from treating 'nautilus' as a model field.
    # NOTE(review): the annotation is restated inside the TYPE_CHECKING branch
    # below; the duplication appears intentional (IDE autocomplete) — confirm.
    nautilus: ClassVar["UserDelegate"]
    if TYPE_CHECKING:
        # Type hint for IDE autocomplete - at runtime this is a classproperty descriptor
        nautilus: "UserDelegate"
    else:
        @classproperty
        def nautilus(cls) -> "UserDelegate":
            """Access the auto-registered client delegate for User.

            This property allows calling model operations directly on the class:

                users = await User.nautilus.find_many(where={"role": "ADMIN"}, take=10)
                user = await User.nautilus.create({"email": "alice@example.com"})

            Requires a Nautilus instance with auto_register=True:

                async with Nautilus(auto_register=True):
                    await User.nautilus.find_many()

            Returns:
                The delegate instance for this model.

            Raises:
                RuntimeError: If no auto-registered Nautilus instance exists.
            """
            # Imported lazily to avoid a circular import at module load time.
            from .._internal.client import NautilusClient
            client = NautilusClient.get_global_instance()
            if client is None:
                raise RuntimeError(
                    f"No Nautilus instance with auto_register=True found. "
                    f"Create one with: Nautilus(auto_register=True)"
                )
            return client.get_delegate("user")
# Engine response key mapping: "table__db_column" -> Python field name
_User_row_map: Dict[str, str] = {
    "User__id": "id",
    "User__name": "name",
}
# All scalar Python field names declared on the User model.
_User_fields: frozenset = frozenset({
    "id",
    "name",
})
def _coerce_boolean_value(value: Any) -> Any:
if isinstance(value, bool):
return value
if isinstance(value, int) and value in (0, 1):
return bool(value)
if isinstance(value, str):
lowered = value.lower()
if lowered in ("true", "1"):
return True
if lowered in ("false", "0"):
return False
return value
# Per-field scalar coercers applied while hydrating rows; empty for this model.
_User_scalar_coercers: Dict[str, Any] = {
}
# Python field name -> DB column name mapping (handles @map renames)
_User_py_to_db: Dict[str, str] = {
    "id": "id",
    "name": "name",
}
# Python field name -> logical (schema-level) field name mapping.
_User_py_to_logical: Dict[str, str] = {
    "id": "id",
    "name": "name",
}
# Sentinel distinguishing "key absent from row" from an explicit None value.
_MISSING = object()
# ---------------------------------------------------------------------------
# Scalar filter objects.  Trailing-underscore keys (``not_``, ``in_``) escape
# Python keywords; _process_where_filters converts them back to wire operator
# names before the payload is sent.
# ---------------------------------------------------------------------------
class StringFilter(TypedDict, total=False):
    """String field filter operations."""
    equals: NotRequired[str]
    not_: NotRequired[str]
    contains: NotRequired[str]
    startswith: NotRequired[str]
    endswith: NotRequired[str]
    in_: NotRequired[List[str]]
    not_in: NotRequired[List[str]]


class IntFilter(TypedDict, total=False):
    """Integer field filter operations."""
    equals: NotRequired[int]
    not_: NotRequired[int]
    lt: NotRequired[int]
    lte: NotRequired[int]
    gt: NotRequired[int]
    gte: NotRequired[int]
    in_: NotRequired[List[int]]
    not_in: NotRequired[List[int]]


class FloatFilter(TypedDict, total=False):
    """Float field filter operations."""
    equals: NotRequired[float]
    not_: NotRequired[float]
    lt: NotRequired[float]
    lte: NotRequired[float]
    gt: NotRequired[float]
    gte: NotRequired[float]
    in_: NotRequired[List[float]]
    not_in: NotRequired[List[float]]


class DecimalFilter(TypedDict, total=False):
    """Decimal field filter operations."""
    equals: NotRequired[Decimal]
    not_: NotRequired[Decimal]
    lt: NotRequired[Decimal]
    lte: NotRequired[Decimal]
    gt: NotRequired[Decimal]
    gte: NotRequired[Decimal]
    in_: NotRequired[List[Decimal]]
    not_in: NotRequired[List[Decimal]]


class DateTimeFilter(TypedDict, total=False):
    """DateTime field filter operations."""
    equals: NotRequired[datetime]
    not_: NotRequired[datetime]
    lt: NotRequired[datetime]
    lte: NotRequired[datetime]
    gt: NotRequired[datetime]
    gte: NotRequired[datetime]
    in_: NotRequired[List[datetime]]
    not_in: NotRequired[List[datetime]]


class UuidFilter(TypedDict, total=False):
    """UUID field filter operations."""
    equals: NotRequired[UUID]
    not_: NotRequired[UUID]
    in_: NotRequired[List[UUID]]
    not_in: NotRequired[List[UUID]]


class BoolFilter(TypedDict, total=False):
    """Boolean field filter operations."""
    equals: NotRequired[bool]
    not_: NotRequired[bool]


class UserWhereInput(TypedDict, total=False):
    """Filter conditions for User queries.

    Scalar fields accept either a bare value (direct equality) or an
    operator object; AND/OR/NOT combinators accept a single filter or a
    list of filters and may nest recursively.
    """
    id: NotRequired[Union[int, IntFilter]]
    name: NotRequired[Union[str, StringFilter]]
    AND: NotRequired[Union[UserWhereInput, List[UserWhereInput]]]
    OR: NotRequired[Union[UserWhereInput, List[UserWhereInput]]]
    NOT: NotRequired[Union[UserWhereInput, List[UserWhereInput]]]
class UserCreateInput(TypedDict, total=False):
"""Input data for creating a User."""
id: NotRequired[int]
name: str
class UserUpdateInput(TypedDict, total=False):
    """Input data for updating a User."""
    name: NotRequired[str]


class UserOrderByInput(TypedDict, total=False):
    """Ordering options for User queries."""
    id: NotRequired[Literal["asc", "desc"]]
    name: NotRequired[Literal["asc", "desc"]]


class UserSelectInput(TypedDict, total=False):
    """Scalar field projection for User queries."""
    id: NotRequired[bool]
    name: NotRequired[bool]


class UserUpsertInput(TypedDict, total=False):
    """Input data for upserting a User.

    Contains both ``create`` (used when no matching record exists) and
    ``update`` (applied when a matching record is found).
    """
    create: NotRequired[UserCreateInput]
    update: NotRequired[UserUpdateInput]


# Closed set of scalar field names, usable as group_by keys.
UserScalarFieldKeys = Literal["id", "name"]
# Sort direction accepted by order/order_by parameters.
SortOrder = Literal["asc", "desc"]


class UserCountAggregateInput(TypedDict, total=False):
    """Select which fields to include in COUNT aggregation."""
    _all: NotRequired[bool]
    id: NotRequired[bool]
    name: NotRequired[bool]


class UserAvgAggregateInput(TypedDict, total=False):
    """Select which numeric fields to include in AVG aggregation."""
    id: NotRequired[bool]


class UserSumAggregateInput(TypedDict, total=False):
    """Select which numeric fields to include in SUM aggregation."""
    id: NotRequired[bool]


class UserMinAggregateInput(TypedDict, total=False):
    """Select which fields to include in MIN aggregation."""
    id: NotRequired[bool]
    name: NotRequired[bool]


class UserMaxAggregateInput(TypedDict, total=False):
    """Select which fields to include in MAX aggregation."""
    id: NotRequired[bool]
    name: NotRequired[bool]


class UserGroupByOutput(TypedDict, total=False):
    """Result row from a group_by query on User."""
    id: NotRequired[Optional[int]]
    name: NotRequired[str]
    # Aggregate buckets keyed by field name (only present when requested).
    _count: NotRequired[Dict[str, int]]
    _avg: NotRequired[Dict[str, float]]
    _sum: NotRequired[Dict[str, Any]]
    _min: NotRequired[Dict[str, Any]]
    _max: NotRequired[Dict[str, Any]]
def _serialize_wire_value(value: Any) -> Any:
"""Convert generated client values into JSON-RPC-safe payloads."""
if isinstance(value, Enum):
return value.value
if is_dataclass(value):
return {k: _serialize_wire_value(v) for k, v in asdict(value).items()}
if isinstance(value, list):
return [_serialize_wire_value(item) for item in value]
if isinstance(value, dict):
return {k: _serialize_wire_value(v) for k, v in value.items()}
return value
def _process_create_data(data: Dict[str, Any], py_to_db: Dict[str, str]) -> Dict[str, Any]:
"""Convert CreateInput/UpdateInput dict to DB column names."""
result = {}
for key, value in data.items():
db_key = py_to_db.get(key, key)
result[db_key] = _serialize_wire_value(value)
return result
def _process_where_filters(where: Dict[str, Any], py_to_db: Dict[str, str]) -> Dict[str, Any]:
"""Convert WhereInput format to internal filter format."""
result = {}
for key, value in where.items():
if key in ('AND', 'OR', 'NOT'):
if isinstance(value, list):
result[key] = [_process_where_filters(item, py_to_db) for item in value]
else:
result[key] = _process_where_filters(value, py_to_db)
else:
db_key = py_to_db.get(key, key)
if isinstance(value, dict):
# Filter object with operators – keep nested {field: {op: value}} format
# so the engine's parse_field_condition receives {"id": {"in": [...]}}.
ops = {}
for op, op_value in value.items():
# Convert Python-safe names back (in_ -> in, not_ -> not)
actual_op = op.rstrip('_') if op.endswith('_') and op[:-1] in ('in', 'not') else op
coerced = _serialize_wire_value(op_value)
if actual_op == 'equals':
# 'equals' shorthand maps to direct equality
result[db_key] = coerced
else:
ops[actual_op] = coerced
if ops:
result[db_key] = ops
else:
result[db_key] = _serialize_wire_value(value)
return result
def _process_select_fields(select: Dict[str, bool], py_to_logical: Dict[str, str]) -> Dict[str, bool]:
"""Convert Python field names to logical field names for query projection."""
result = {}
for key, value in select.items():
logical_key = py_to_logical.get(key, key)
result[logical_key] = value
return result
def _serialize_include(include: Any) -> Any:
"""Serialize include dict to the engine wire format.
Converts Python snake_case keys (``order_by``) to camelCase (``orderBy``)
and recursively processes nested includes.
"""
if include is None or isinstance(include, bool):
return include
if not isinstance(include, dict):
return include
result: Dict[str, Any] = {}
for field, spec in include.items():
if isinstance(spec, bool):
result[field] = spec
elif isinstance(spec, dict):
node: Dict[str, Any] = {}
for k, v in spec.items():
if k == "order_by":
if isinstance(v, list):
node["orderBy"] = v
elif isinstance(v, dict):
node["orderBy"] = [{fk: fv} for fk, fv in v.items()]
else:
node["orderBy"] = v
elif k == "include":
node["include"] = _serialize_include(v)
else:
node[k] = v
result[field] = node
else:
result[field] = spec
return result
def _get_wire_value(row: Dict[str, Any], *keys: str) -> Any:
    """Return the value for the first key present in ``row``.

    Falls back to the module-level ``_MISSING`` sentinel when none of the
    candidate keys exist, so callers can distinguish "absent" from None.
    """
    return next((row[key] for key in keys if key in row), _MISSING)
def _coerce_user_scalar(py_key: str, value: Any) -> Any:
    """Apply the registered scalar coercer for ``py_key``, if one exists."""
    coercer = _User_scalar_coercers.get(py_key)
    return value if coercer is None else coercer(value)
def _hydrate_user_relation(field_name: str, value: Any) -> Any:
    """Relation-hydration hook for User; this model declares no relations,
    so the payload is returned unchanged."""
    return value
def _user_from_wire(row: Dict[str, Any]) -> "User":
    """Convert engine rows and nested include payloads to User.

    Accepts both the prefixed wire key ("User__id") and the plain field
    name; absent or None values are skipped so the model defaults apply.
    """
    kwargs: Dict[str, Any] = {}
    # (wire key, plain key, python field) for each scalar column.
    for wire_key, plain_key, py_field in (
        ("User__id", "id", "id"),
        ("User__name", "name", "name"),
    ):
        value = _get_wire_value(row, wire_key, plain_key)
        if value is not _MISSING and value is not None:
            kwargs[py_field] = _coerce_user_scalar(py_field, value)
    return User(**kwargs)
def _User_from_row(row: Dict[str, Any]) -> "User":
    """Convert an engine row (table__column keys) to a User model instance."""
    return _user_from_wire(row)


def _user_from_row(row: Dict[str, Any]) -> "User":
    """Compatibility wrapper keyed by model snake_case."""
    return _user_from_wire(row)
class UserDelegate:
    """Operations for User.

    Every method executes immediately when awaited — no extra ``.exec()`` call
    is needed::

        users = await client.User.find_many(where={"role": "ADMIN"}, take=10)
        user = await client.User.create({"email": "alice@example.com"})
    """

    def __init__(self, client: NautilusClient) -> None:
        # The delegate only holds the RPC client; all query state is per-call.
        self._client = client

    async def find_many(
        self,
        *,
        where: Optional[UserWhereInput] = None,
        order_by: Optional[UserOrderByInput] = None,
        take: Optional[int] = None,
        skip: Optional[int] = None,
        select: Optional[UserSelectInput] = None,
        include: Optional[Dict[str, Any]] = None,
        cursor: Optional[Dict[str, Any]] = None,
        distinct: Optional[List[str]] = None,
        chunk_size: Optional[int] = None,
    ) -> List[User]:
        """Find all User records matching the given filters.

        Args:
            where: Filter conditions (optional).
            order_by: Sorting directions, e.g. ``{"created_at": "desc"}``.
                A pre-shaped list of single-key dicts is also accepted.
            take: Maximum number of rows to return (LIMIT).
                Positive values paginate **forward**; negative values
                paginate **backward** from the cursor position.
            skip: Number of rows to skip (OFFSET).
            select: Scalar field projection, e.g. ``{"id": True, "display_name": True}``.
                Primary-key fields are always returned by the engine.
            include: Relations to eager-load, e.g. ``{"posts": True}``.
            cursor: Stable keyset cursor — a dict of primary-key field name
                to value. All PK fields of the model must be present.
                The cursor record is included in the result (inclusive).
            distinct: List of field names to deduplicate on.
                Postgres: ``SELECT DISTINCT ON (col, ...)`` with the
                distinct columns auto-prepended to ORDER BY.
                SQLite / MySQL: plain ``SELECT DISTINCT`` (full-row
                dedup — most effective combined with ``select``).
            chunk_size: Optional protocol-level row chunk size for large result
                sets. The client still returns one fully merged list.

        Returns:
            A list of matching ``User`` instances.
        """
        # Only the parameters the caller actually supplied go on the wire.
        args: Dict[str, Any] = {}
        if where is not None:
            args["where"] = _process_where_filters(where, _User_py_to_db)
        if order_by is not None:
            # Lists are assumed wire-shaped already; a mapping is expanded
            # into the engine's list-of-single-key-dicts form.
            if isinstance(order_by, list):
                args["orderBy"] = order_by
            else:
                args["orderBy"] = [{field: direction} for field, direction in order_by.items()]
        if take is not None:
            args["take"] = take
        if skip is not None:
            args["skip"] = skip
        if select is not None:
            args["select"] = _process_select_fields(select, _User_py_to_logical)
        if include is not None:
            args["include"] = _serialize_include(include)
        if cursor is not None:
            args["cursor"] = cursor
        if distinct is not None:
            args["distinct"] = distinct
        payload: Dict[str, Any] = {
            "protocolVersion": 1,
            "model": "User",
            "args": args,
        }
        if chunk_size is not None:
            payload["chunkSize"] = chunk_size
        result = await self._client._rpc("query.findMany", payload)
        return [_User_from_row(row) for row in result.get("data", [])]

    async def find_first(
        self,
        *,
        where: Optional[UserWhereInput] = None,
        order_by: Optional[UserOrderByInput] = None,
        select: Optional[UserSelectInput] = None,
        include: Optional[Dict[str, Any]] = None,
    ) -> Optional[User]:
        """Find the first User record matching the given filters.

        Returns:
            The first matching ``User`` instance, or ``None``.
        """
        # Implemented as find_many with take=1.
        rows = await self.find_many(where=where, order_by=order_by, take=1, select=select, include=include)
        return rows[0] if rows else None

    async def find_unique(
        self,
        *,
        where: UserWhereInput,
        select: Optional[UserSelectInput] = None,
        include: Optional[Dict[str, Any]] = None,
    ) -> Optional[User]:
        """Find a single User record by a unique filter.

        Returns:
            The matching ``User`` instance, or ``None``.
        """
        rows = await self.find_many(where=where, take=1, select=select, include=include)
        return rows[0] if rows else None

    async def find_unique_or_throw(
        self,
        *,
        where: UserWhereInput,
        select: Optional[UserSelectInput] = None,
        include: Optional[Dict[str, Any]] = None,
    ) -> User:
        """Find a single User record by a unique filter, or raise if not found.

        Returns:
            The matching ``User`` instance.

        Raises:
            NotFoundError: If no record matches the given filter.
        """
        # Imported lazily to avoid a circular import at module load time.
        from ..errors import NotFoundError
        record = await self.find_unique(where=where, select=select, include=include)
        if record is None:
            raise NotFoundError(
                f"findUniqueOrThrow: no User record found matching the given filter"
            )
        return record

    async def find_first_or_throw(
        self,
        *,
        where: Optional[UserWhereInput] = None,
        order_by: Optional[UserOrderByInput] = None,
        select: Optional[UserSelectInput] = None,
        include: Optional[Dict[str, Any]] = None,
    ) -> User:
        """Find the first User record matching ``where``, or raise if not found.

        Returns:
            The first matching ``User`` instance.

        Raises:
            NotFoundError: If no record matches the given filters.
        """
        from ..errors import NotFoundError
        record = await self.find_first(where=where, order_by=order_by, select=select, include=include)
        if record is None:
            raise NotFoundError(
                f"findFirstOrThrow: no User record found matching the given filter"
            )
        return record

    async def create(self, data: UserCreateInput, *, return_data: bool = True) -> Optional[User]:
        """Create a new User record.

        Args:
            data: Field values for the new record.
            return_data: When ``True`` (default) the created record is returned.
                Set to ``False`` for fire-and-forget inserts to avoid the
                round-trip overhead of ``RETURNING``.

        Returns:
            The newly created ``User`` instance, or ``None`` when
            ``return_data=False``.

        Raises:
            RuntimeError: If ``return_data=True`` but the engine sent no row back.
        """
        result = await self._client._rpc("query.create", {
            "protocolVersion": 1,
            "model": "User",
            "data": _process_create_data(data, _User_py_to_db),
            "returnData": return_data,
        })
        if not return_data:
            return None
        data_rows = result.get("data", [])
        if not data_rows:
            raise RuntimeError("Create operation returned no data")
        return _User_from_row(data_rows[0])

    async def create_many(
        self,
        data: List[UserCreateInput],
        *,
        return_data: bool = True,
    ) -> List[User]:
        """Create multiple User records in a single batch.

        Args:
            data: List of field-value dicts for the records to create.
            return_data: When ``True`` (default) the created records are returned.
                Set to ``False`` for bulk inserts to skip ``RETURNING`` overhead.

        Returns:
            The list of newly created ``User`` instances, or an empty
            list when ``return_data=False``.
        """
        result = await self._client._rpc("query.createMany", {
            "protocolVersion": 1,
            "model": "User",
            "data": [_process_create_data(_entry, _User_py_to_db) for _entry in data],
            "returnData": return_data,
        })
        if not return_data:
            return []
        return [_User_from_row(row) for row in result.get("data", [])]

    async def update(
        self,
        *,
        where: Optional[UserWhereInput] = None,
        data: UserUpdateInput,
        return_data: bool = True,
    ):
        """Update User records matching ``where``.

        Args:
            where: Filter conditions. If ``None``, all records are updated.
            data: Fields to update.
            return_data: When ``True`` (default) returns the updated records.
                Set to ``False`` for bulk updates to skip ``RETURNING`` overhead;
                the affected-row count is returned instead.

        Returns:
            ``List[User]`` when ``return_data=True``, or ``int``
            (affected-row count) when ``return_data=False``.
        """
        result = await self._client._rpc("query.update", {
            "protocolVersion": 1,
            "model": "User",
            # An absent filter becomes {} — which matches every row.
            "filter": _process_where_filters(where or {}, _User_py_to_db),
            "data": _process_create_data(data, _User_py_to_db),
            "returnData": return_data,
        })
        if not return_data:
            return result.get("count", 0)
        return [_User_from_row(row) for row in result.get("data", [])]

    async def delete(
        self,
        *,
        where: UserWhereInput,
        return_data: bool = True,
    ) -> Optional["User"]:
        """Delete the first User record matching ``where``.

        Finds the first matching record and deletes it by its primary key.

        Args:
            where: Filter conditions. Required.
            return_data: When ``True`` (default) returns the deleted record.
                Set to ``False`` to skip ``RETURNING`` overhead; ``None`` is
                returned in that case.

        Returns:
            The deleted ``User`` instance when ``return_data=True``,
            or ``None`` when ``return_data=False``.

        Raises:
            NotFoundError: If no record matches the given filter.
        """
        from ..errors import NotFoundError
        # NOTE(review): this is a two-step find + delete-by-PK, so it is not
        # atomic — a concurrent writer can touch the row in between.
        record = await self.find_first(where=where)
        if record is None:
            raise NotFoundError(
                f"delete: no User record found matching the given filter"
            )
        pk_filter = {
            "id": getattr(record, "id"),
        }
        result = await self._client._rpc("query.delete", {
            "protocolVersion": 1,
            "model": "User",
            "filter": _process_where_filters(pk_filter, _User_py_to_db),
            "returnData": return_data,
        })
        if return_data:
            data_rows = result.get("data", [])
            if data_rows:
                return _User_from_row(data_rows[0])
            # Engine returned no rows; fall back to the record looked up above.
            return record
        return None

    async def delete_many(
        self,
        *,
        where: Optional[UserWhereInput] = None,
        return_data: bool = False,
    ):
        """Delete all User records matching ``where``.

        Args:
            where: Filter conditions. If ``None``, all records are deleted.
            return_data: When ``True`` returns the deleted records (uses
                ``RETURNING``). Defaults to ``False`` for efficiency.

        Returns:
            ``int`` (affected-row count) when ``return_data=False`` (default),
            or ``List[User]`` when ``return_data=True``.
        """
        result = await self._client._rpc("query.delete", {
            "protocolVersion": 1,
            "model": "User",
            "filter": _process_where_filters(where or {}, _User_py_to_db),
            "returnData": return_data,
        })
        if return_data:
            return [_User_from_row(row) for row in result.get("data", [])]
        return result.get("count", 0)

    async def count(
        self,
        *,
        where: Optional[UserWhereInput] = None,
        take: Optional[int] = None,
        skip: Optional[int] = None,
        cursor: Optional[Dict[str, Any]] = None,
    ) -> int:
        """Count the number of User records matching the given filters.

        Args:
            where: Filter conditions (optional).
            take: Limit the count to a forward-pagination window of at most this
                many rows. When combined with ``skip``, counts only the rows
                in that slice (mirrors ``findMany`` semantics).
            skip: Number of rows to skip before counting.
            cursor: Keyset cursor — combined with ``take``/``skip`` to count
                within a paginated window.

        Returns:
            The integer count of matching records.
        """
        args: Dict[str, Any] = {}
        if where is not None:
            args["where"] = _process_where_filters(where, _User_py_to_db)
        if take is not None:
            args["take"] = take
        if skip is not None:
            args["skip"] = skip
        if cursor is not None:
            args["cursor"] = cursor
        result = await self._client._rpc("query.count", {
            "protocolVersion": 1,
            "model": "User",
            # The engine expects null (not {}) when no args were supplied.
            "args": args if args else None,
        })
        return result.get("count", 0)

    async def group_by(
        self,
        by: List[UserScalarFieldKeys],
        *,
        where: Optional[UserWhereInput] = None,
        having: Optional[Dict[str, Any]] = None,
        take: Optional[int] = None,
        skip: Optional[int] = None,
        count: Optional[Union[bool, UserCountAggregateInput]] = None,
        avg: Optional[UserAvgAggregateInput] = None,
        sum: Optional[UserSumAggregateInput] = None,
        min: Optional[UserMinAggregateInput] = None,
        max: Optional[UserMaxAggregateInput] = None,
        order: Optional[Union[
            Dict[UserScalarFieldKeys, SortOrder],
            List[Dict[UserScalarFieldKeys, SortOrder]],
        ]] = None,
    ) -> List[UserGroupByOutput]:
        """Group User records and compute aggregates.

        Note: the ``count``/``sum``/``min``/``max`` parameter names shadow
        builtins by design — they mirror the wire API and are keyword-only.

        Args:
            by: List of field names to group by. Required.
            where: Pre-group filter (applied before grouping).
            having: Post-group filter (applied after grouping), e.g.
                ``{"_count": {"_all": {"gt": 5}}}``.
            take: Maximum number of groups to return.
            skip: Number of groups to skip.
            count: Compute COUNT. Pass ``True`` for ``COUNT(*)`` or a
                ``UserCountAggregateInput`` dict to count specific fields.
            avg: Compute AVG for the specified numeric fields.
            sum: Compute SUM for the specified numeric fields.
            min: Compute MIN for the specified fields.
            max: Compute MAX for the specified fields.
            order: Sort order — a mapping or list of mappings from field name
                to ``"asc"``/``"desc"``. Aggregate keys like ``_count``,
                ``_sum``, etc. are also accepted.

        Returns:
            A list of ``UserGroupByOutput`` dicts, one per group.
        """
        args: Dict[str, Any] = {"by": list(by)}
        if where is not None:
            args["where"] = _process_where_filters(where, _User_py_to_db)
        if having is not None:
            args["having"] = having
        if take is not None:
            args["take"] = take
        if skip is not None:
            args["skip"] = skip
        if count is not None:
            args["count"] = count
        if avg is not None:
            args["avg"] = avg
        if sum is not None:
            args["sum"] = sum
        if min is not None:
            args["min"] = min
        if max is not None:
            args["max"] = max
        if order is not None:
            if isinstance(order, list):
                args["orderBy"] = order
            else:
                args["orderBy"] = [{field: direction} for field, direction in order.items()]
        result = await self._client._rpc("query.groupBy", {
            "protocolVersion": 1,
            "model": "User",
            "args": args,
        })
        # Group rows are returned as raw dicts, not hydrated models.
        return result.get("data", [])

    async def raw_query(
        self,
        sql: str,
    ) -> List[Dict[str, Any]]:
        """Execute a raw SQL string and return the result rows as dicts.

        The SQL is sent to the database as-is with no parameter binding.
        Prefer :meth:`raw_stmt_query` when user-supplied values are involved to
        avoid SQL injection.

        Args:
            sql: Raw SQL string to execute.

        Returns:
            A list of dicts mapping column name to value for each result row.
        """
        result = await self._client._rpc("query.rawQuery", {
            "protocolVersion": 1,
            "sql": sql,
        })
        return result.get("data", [])

    async def raw_stmt_query(
        self,
        sql: str,
        params: Optional[List[Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Execute a raw prepared-statement query with bound parameters.

        Use ``$1``, ``$2``, … (PostgreSQL) or ``?`` (MySQL / SQLite) as
        placeholders. Parameters are bound in the order they appear in *params*.

        Args:
            sql: SQL string containing parameter placeholders.
            params: Ordered list of values to bind to the placeholders.

        Returns:
            A list of dicts mapping column name to value for each result row.
        """
        result = await self._client._rpc("query.rawStmtQuery", {
            "protocolVersion": 1,
            "sql": sql,
            "params": params or [],
        })
        return result.get("data", [])

    async def upsert(
        self,
        *,
        where: UserWhereInput,
        data: UserUpsertInput,
        include: Optional[Dict[str, Any]] = None,
        return_data: bool = True,
    ) -> Optional[User]:
        """Create or update a User record.

        Looks up an existing record using ``where``. If found, applies
        ``data["update"]``; otherwise creates a new record with ``data["create"]``.

        Args:
            where: Filter to locate the existing record (should target a unique field).
            data: Dict with ``create`` and ``update`` keys, each holding the
                respective field values.
            include: Relations to eager-load on the returned record.
            return_data: When ``True`` (default) returns the record.
                Set to ``False`` to skip ``RETURNING`` overhead; ``None`` is
                returned in that case.

        Returns:
            The created or updated ``User`` instance, or ``None`` when
            ``return_data=False``.
        """
        create_data: UserCreateInput = data.get("create", {})  # type: ignore[assignment]
        update_data: UserUpdateInput = data.get("update", {})  # type: ignore[assignment]
        # NOTE(review): find-then-create/update is not atomic; concurrent
        # upserts on the same key can race.
        existing = await self.find_first(where=where)
        if existing is not None:
            result = await self.update(where=where, data=update_data, return_data=return_data)
            if not return_data:
                return None
            if isinstance(result, list):
                record = result[0] if result else existing
            else:
                record = existing
            if include:
                # Re-fetch so the returned record carries the requested relations.
                rows = await self.find_many(where=where, take=1, include=include)
                return rows[0] if rows else record
            return record
        created = await self.create(create_data, return_data=return_data)
        if not return_data:
            return None
        if created is None:
            raise RuntimeError("upsert: create operation returned no data")
        if include:
            rows = await self.find_many(where=where, take=1, include=include)
            return rows[0] if rows else created
        return created
# Public API of this generated module.
__all__ = [
    "User",
    "UserDelegate",
    "UserWhereInput",
    "UserCreateInput",
    "UserUpdateInput",
    "UserUpsertInput",
    "UserOrderByInput",
    "UserSelectInput",
    "UserScalarFieldKeys",
    "UserCountAggregateInput",
    "UserAvgAggregateInput",
    "UserSumAggregateInput",
    "UserMinAggregateInput",
    "UserMaxAggregateInput",
    "UserGroupByOutput",
    "SortOrder",
    "StringFilter",
    "IntFilter",
    "FloatFilter",
    "DecimalFilter",
    "DateTimeFilter",
    "UuidFilter",
    "BoolFilter",
]