from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any, Callable, Generic, Iterator, Optional, TypeVar, overload
from net import Net, IngestResult, StoredEvent, PollResponse, Stats
from net_sdk.stream import EventStream, SubscribeOpts, TypedEventStream
from net_sdk.channel import TypedChannel
T = TypeVar("T")
@dataclass
class Receipt:
shard_id: int
timestamp: int
class NetNode:
def __init__(
self,
shards: Optional[int] = None,
buffer_capacity: Optional[int] = None,
backpressure: Optional[str] = None,
*,
redis_url: Optional[str] = None,
redis_prefix: Optional[str] = None,
redis_pipeline_size: Optional[int] = None,
redis_pool_size: Optional[int] = None,
redis_connect_timeout_ms: Optional[int] = None,
redis_command_timeout_ms: Optional[int] = None,
redis_max_stream_len: Optional[int] = None,
jetstream_url: Optional[str] = None,
jetstream_prefix: Optional[str] = None,
jetstream_connect_timeout_ms: Optional[int] = None,
jetstream_request_timeout_ms: Optional[int] = None,
jetstream_max_messages: Optional[int] = None,
jetstream_max_bytes: Optional[int] = None,
jetstream_max_age_ms: Optional[int] = None,
jetstream_replicas: Optional[int] = None,
mesh_bind: Optional[str] = None,
mesh_peer: Optional[str] = None,
mesh_psk: Optional[str] = None,
mesh_role: Optional[str] = None,
mesh_peer_public_key: Optional[str] = None,
mesh_secret_key: Optional[str] = None,
mesh_public_key: Optional[str] = None,
mesh_reliability: Optional[str] = None,
mesh_heartbeat_interval_ms: Optional[int] = None,
mesh_session_timeout_ms: Optional[int] = None,
mesh_batched_io: Optional[bool] = None,
mesh_packet_pool_size: Optional[int] = None,
) -> None:
self._bus = Net(
num_shards=shards,
ring_buffer_capacity=buffer_capacity,
backpressure_mode=backpressure,
redis_url=redis_url,
redis_prefix=redis_prefix,
redis_pipeline_size=redis_pipeline_size,
redis_pool_size=redis_pool_size,
redis_connect_timeout_ms=redis_connect_timeout_ms,
redis_command_timeout_ms=redis_command_timeout_ms,
redis_max_stream_len=redis_max_stream_len,
jetstream_url=jetstream_url,
jetstream_prefix=jetstream_prefix,
jetstream_connect_timeout_ms=jetstream_connect_timeout_ms,
jetstream_request_timeout_ms=jetstream_request_timeout_ms,
jetstream_max_messages=jetstream_max_messages,
jetstream_max_bytes=jetstream_max_bytes,
jetstream_max_age_ms=jetstream_max_age_ms,
jetstream_replicas=jetstream_replicas,
net_bind_addr=mesh_bind,
net_peer_addr=mesh_peer,
net_psk=mesh_psk,
net_role=mesh_role,
net_peer_public_key=mesh_peer_public_key,
net_secret_key=mesh_secret_key,
net_public_key=mesh_public_key,
net_reliability=mesh_reliability,
net_heartbeat_interval_ms=mesh_heartbeat_interval_ms,
net_session_timeout_ms=mesh_session_timeout_ms,
net_batched_io=mesh_batched_io,
net_packet_pool_size=mesh_packet_pool_size,
)
@classmethod
def from_bus(cls, bus: Net) -> NetNode:
node = cls.__new__(cls)
node._bus = bus
return node
def emit(self, event: Any) -> Receipt:
if hasattr(event, "model_dump"):
data = event.model_dump()
elif hasattr(event, "__dict__") and not isinstance(event, dict):
data = event.__dict__
else:
data = event
result = self._bus.ingest(data) if isinstance(data, dict) else self._bus.ingest_raw(json.dumps(data))
return Receipt(shard_id=result.shard_id, timestamp=result.timestamp)
def emit_raw(self, json_str: str) -> Receipt:
result = self._bus.ingest_raw(json_str)
return Receipt(shard_id=result.shard_id, timestamp=result.timestamp)
def emit_batch(self, events: list[Any]) -> int:
payloads: list[str] = []
for event in events:
if hasattr(event, "model_dump"):
payloads.append(json.dumps(event.model_dump()))
elif isinstance(event, dict):
payloads.append(json.dumps(event))
elif hasattr(event, "__dict__"):
payloads.append(json.dumps(event.__dict__))
else:
payloads.append(json.dumps(event))
return self._bus.ingest_raw_batch(payloads)
def emit_raw_batch(self, json_strs: list[str]) -> int:
return self._bus.ingest_raw_batch(json_strs)
def fire(self, json_str: str) -> None:
self._bus.ingest_raw(json_str)
def poll(
self,
limit: int = 100,
cursor: Optional[str] = None,
filter: Optional[str] = None,
ordering: Optional[str] = None,
) -> PollResponse:
return self._bus.poll(limit=limit, cursor=cursor, filter=filter, ordering=ordering)
def poll_one(self) -> Optional[StoredEvent]:
response = self.poll(limit=1)
if len(response) > 0:
return response.events[0]
return None
def subscribe(
self,
limit: int = 100,
filter: Optional[str] = None,
ordering: Optional[str] = None,
timeout: Optional[float] = None,
poll_interval: float = 0.001,
max_backoff: float = 0.1,
) -> EventStream:
opts = SubscribeOpts(
limit=limit,
filter=filter,
ordering=ordering,
poll_interval=poll_interval,
max_backoff=max_backoff,
timeout=timeout,
)
return EventStream(self._bus, opts)
def subscribe_typed(
self,
model: type[T],
limit: int = 100,
filter: Optional[str] = None,
ordering: Optional[str] = None,
timeout: Optional[float] = None,
) -> TypedEventStream[T]:
opts = SubscribeOpts(
limit=limit,
filter=filter,
ordering=ordering,
timeout=timeout,
)
if hasattr(model, "model_validate_json"):
def parse(raw: str) -> T:
return model.model_validate_json(raw) else:
def parse(raw: str) -> T:
return model(**json.loads(raw))
return TypedEventStream(self._bus, parse, opts)
def channel(
self,
name: str,
model: Optional[type[T]] = None,
) -> TypedChannel[T]:
return TypedChannel(self._bus, name, model=model)
def stats(self) -> Stats:
return self._bus.stats()
def shards(self) -> int:
return self._bus.num_shards()
def shutdown(self) -> None:
self._bus.shutdown()
@property
def bus(self) -> Net:
return self._bus
def __enter__(self) -> NetNode:
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
self.shutdown()
return False
def __repr__(self) -> str:
return f"NetNode(shards={self.shards()})"