use std::{net::SocketAddr, sync::Arc};
use async_trait::async_trait;
use thiserror::Error;
/// Permission an extension can declare in [`ExtensionManifest::capabilities`].
///
/// NOTE(review): the per-variant notes below are read off the variant names and the
/// resource specs in this module; confirm how the runtime actually enforces each one.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum ExtensionCapability {
    /// Bind a UDP socket (presumably gates [`ExtensionResourceSpec::UdpListener`]).
    BindUdp,
    /// Bind a TCP socket (presumably gates [`ExtensionResourceSpec::TcpListener`]).
    BindTcp,
    /// Open outbound TCP connections (presumably gates [`ExtensionResourceSpec::TcpConnector`]).
    ConnectTcp,
    /// Open WebSocket connections (presumably gates [`ExtensionResourceSpec::WsConnector`]).
    ConnectWebSocket,
    /// Observe events whose source kind is `RuntimePacketSourceKind::ObserverIngress`.
    ObserveObserverIngress,
    /// Observe streams other extensions published as `ExtensionStreamVisibility::Shared`.
    ObserveSharedExtensionStream,
}

impl ExtensionCapability {
    /// Returns every capability exactly once, in declaration order.
    ///
    /// Being a `const fn`, the result can initialise constants. The array length is
    /// fixed, so this list must be extended by hand whenever a variant is added.
    #[must_use]
    pub const fn all() -> [Self; 6] {
        use ExtensionCapability::{
            BindTcp, BindUdp, ConnectTcp, ConnectWebSocket, ObserveObserverIngress,
            ObserveSharedExtensionStream,
        };

        [
            BindUdp,
            BindTcp,
            ConnectTcp,
            ConnectWebSocket,
            ObserveObserverIngress,
            ObserveSharedExtensionStream,
        ]
    }
}
/// Visibility of the packet stream produced by an extension resource
/// (the `visibility` field of every resource spec).
///
/// Defaults to [`ExtensionStreamVisibility::Private`].
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub enum ExtensionStreamVisibility {
/// Default visibility. Judging by the name, the stream is not exposed to other
/// extensions — confirm against the runtime.
#[default]
Private,
/// Stream published under `tag`.
/// NOTE(review): presumably this is the value that appears in
/// `RuntimePacketSource::shared_tag` and that `PacketSubscription::shared_tag`
/// filters on — confirm.
Shared {
tag: String,
},
}
/// Declaration of a UDP listener resource, carried as
/// [`ExtensionResourceSpec::UdpListener`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct UdpListenerSpec {
/// Identifier for this resource.
/// NOTE(review): presumably echoed back as `RuntimePacketSource::resource_id`;
/// no uniqueness check happens here.
pub resource_id: String,
/// Socket address the listener is to bind.
pub bind_addr: SocketAddr,
/// Whether the resulting stream is private or shared under a tag.
pub visibility: ExtensionStreamVisibility,
/// Read buffer size in bytes, judging by the name — TODO confirm how the runtime applies it.
pub read_buffer_bytes: usize,
}
/// Declaration of a TCP listener resource, carried as
/// [`ExtensionResourceSpec::TcpListener`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct TcpListenerSpec {
/// Identifier for this resource.
/// NOTE(review): presumably echoed back as `RuntimePacketSource::resource_id`;
/// no uniqueness check happens here.
pub resource_id: String,
/// Socket address the listener is to bind.
pub bind_addr: SocketAddr,
/// Whether the resulting stream is private or shared under a tag.
pub visibility: ExtensionStreamVisibility,
/// Read buffer size in bytes, judging by the name — TODO confirm whether it is per connection.
pub read_buffer_bytes: usize,
}
/// Declaration of an outbound TCP connection resource, carried as
/// [`ExtensionResourceSpec::TcpConnector`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct TcpConnectorSpec {
/// Identifier for this resource.
/// NOTE(review): presumably echoed back as `RuntimePacketSource::resource_id`;
/// no uniqueness check happens here.
pub resource_id: String,
/// Socket address of the peer to connect to.
pub remote_addr: SocketAddr,
/// Whether the resulting stream is private or shared under a tag.
pub visibility: ExtensionStreamVisibility,
/// Read buffer size in bytes, judging by the name — TODO confirm how the runtime applies it.
pub read_buffer_bytes: usize,
}
/// Declaration of an outbound WebSocket connection resource, carried as
/// [`ExtensionResourceSpec::WsConnector`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct WsConnectorSpec {
/// Identifier for this resource.
/// NOTE(review): presumably echoed back as `RuntimePacketSource::resource_id`;
/// no uniqueness check happens here.
pub resource_id: String,
/// WebSocket URL to connect to. Stored as a plain `String`; nothing here validates it.
pub url: String,
/// Whether the resulting stream is private or shared under a tag.
pub visibility: ExtensionStreamVisibility,
/// Read buffer size in bytes, judging by the name — TODO confirm how the runtime applies it.
pub read_buffer_bytes: usize,
}
/// One resource declaration in [`ExtensionManifest::resources`].
///
/// NOTE(review): each variant presumably requires the matching `ExtensionCapability`
/// (`BindUdp`, `BindTcp`, `ConnectTcp`, `ConnectWebSocket`) — confirm where that is checked.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ExtensionResourceSpec {
/// UDP listener; see [`UdpListenerSpec`].
UdpListener(UdpListenerSpec),
/// TCP listener; see [`TcpListenerSpec`].
TcpListener(TcpListenerSpec),
/// Outbound TCP connection; see [`TcpConnectorSpec`].
TcpConnector(TcpConnectorSpec),
/// Outbound WebSocket connection; see [`WsConnectorSpec`].
WsConnector(WsConnectorSpec),
}
/// Origin category of a packet event (`RuntimePacketSource::kind`).
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum RuntimePacketSourceKind {
/// Event from the observer's own ingress path (not an extension resource, per the name).
ObserverIngress,
/// Event from a resource declared in some extension's manifest.
ExtensionResource,
}
/// Transport an event arrived on (`RuntimePacketSource::transport`).
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum RuntimePacketTransport {
/// UDP datagrams.
Udp,
/// TCP stream.
Tcp,
/// WebSocket connection; the frame kind is in `RuntimePacketSource::websocket_frame_type`.
WebSocket,
}
/// What kind of occurrence an event reports (`RuntimePacketSource::event_class`).
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum RuntimePacketEventClass {
/// A data packet; the payload is in `RuntimePacketEvent::bytes`.
Packet,
/// The underlying connection closed.
/// NOTE(review): what `RuntimePacketEvent::bytes` holds for this class is not shown here — confirm.
ConnectionClosed,
}
/// WebSocket frame kind reported in `RuntimePacketSource::websocket_frame_type`.
///
/// There is no variant for close frames.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum RuntimeWebSocketFrameType {
/// Text data frame.
Text,
/// Binary data frame.
Binary,
/// Ping control frame.
Ping,
/// Pong control frame.
Pong,
}
/// Metadata describing where a [`RuntimePacketEvent`] came from.
///
/// Every `Option` field here is a target for the matching filter in `PacketSubscription`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct RuntimePacketSource {
/// Observer ingress or extension resource.
pub kind: RuntimePacketSourceKind,
/// Transport the event arrived on.
pub transport: RuntimePacketTransport,
/// Packet vs. connection-closed.
pub event_class: RuntimePacketEventClass,
/// Extension owning the resource, if any.
/// NOTE(review): presumably `None` for `ObserverIngress` events — confirm.
pub owner_extension: Option<String>,
/// Identifier of the originating resource, if any.
pub resource_id: Option<String>,
/// Tag of a shared stream, if the resource was shared.
pub shared_tag: Option<String>,
/// WebSocket frame kind; presumably only set for `RuntimePacketTransport::WebSocket`.
pub websocket_frame_type: Option<RuntimeWebSocketFrameType>,
/// Local socket address, when known.
pub local_addr: Option<SocketAddr>,
/// Peer socket address, when known.
pub remote_addr: Option<SocketAddr>,
}
/// One event handed to `RuntimeExtension::on_packet_received`.
#[derive(Debug, Clone)]
pub struct RuntimePacketEvent {
/// Where the event came from.
pub source: RuntimePacketSource,
/// Payload. Held in an `Arc`, so cloning the event shares the bytes rather than copying them.
pub bytes: Arc<[u8]>,
/// Observation timestamp; the name suggests milliseconds since the Unix epoch — confirm the clock source.
pub observed_unix_ms: u64,
}
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub struct PacketSubscription {
pub source_kind: Option<RuntimePacketSourceKind>,
pub transport: Option<RuntimePacketTransport>,
pub event_class: Option<RuntimePacketEventClass>,
pub local_addr: Option<SocketAddr>,
pub local_port: Option<u16>,
pub remote_addr: Option<SocketAddr>,
pub remote_port: Option<u16>,
pub owner_extension: Option<String>,
pub resource_id: Option<String>,
pub shared_tag: Option<String>,
pub websocket_frame_type: Option<RuntimeWebSocketFrameType>,
}
impl PacketSubscription {
#[must_use]
pub fn matches(&self, event: &RuntimePacketEvent) -> bool {
if let Some(source_kind) = self.source_kind
&& source_kind != event.source.kind
{
return false;
}
if let Some(transport) = self.transport
&& transport != event.source.transport
{
return false;
}
if let Some(event_class) = self.event_class
&& event_class != event.source.event_class
{
return false;
}
if let Some(local_addr) = self.local_addr
&& event.source.local_addr != Some(local_addr)
{
return false;
}
if let Some(local_port) = self.local_port
&& event.source.local_addr.map(|addr| addr.port()) != Some(local_port)
{
return false;
}
if let Some(remote_addr) = self.remote_addr
&& event.source.remote_addr != Some(remote_addr)
{
return false;
}
if let Some(remote_port) = self.remote_port
&& event.source.remote_addr.map(|addr| addr.port()) != Some(remote_port)
{
return false;
}
if let Some(owner_extension) = self.owner_extension.as_ref()
&& event.source.owner_extension.as_ref() != Some(owner_extension)
{
return false;
}
if let Some(resource_id) = self.resource_id.as_ref()
&& event.source.resource_id.as_ref() != Some(resource_id)
{
return false;
}
if let Some(shared_tag) = self.shared_tag.as_ref()
&& event.source.shared_tag.as_ref() != Some(shared_tag)
{
return false;
}
if let Some(websocket_frame_type) = self.websocket_frame_type
&& event.source.websocket_frame_type != Some(websocket_frame_type)
{
return false;
}
true
}
}
/// What an extension declares from `RuntimeExtension::setup`.
///
/// `Default` is the empty manifest: no capabilities, resources or subscriptions.
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub struct ExtensionManifest {
/// Capabilities the extension requests.
pub capabilities: Vec<ExtensionCapability>,
/// Resources (listeners/connectors) the extension declares.
pub resources: Vec<ExtensionResourceSpec>,
/// Filters selecting which packet events the extension wants; see `PacketSubscription::matches`.
pub subscriptions: Vec<PacketSubscription>,
}
/// Context passed by value to `RuntimeExtension::setup` and `RuntimeExtension::shutdown`.
#[derive(Debug, Clone)]
pub struct ExtensionContext {
/// Name of the extension being called.
/// NOTE(review): presumably the value of `RuntimeExtension::name` — confirm at the call site.
pub extension_name: &'static str,
}
/// Error an extension returns from `RuntimeExtension::setup`.
///
/// Its `Display` output is exactly `reason` (thiserror `#[error("{reason}")]`).
/// The field is private; construct values with `ExtensionSetupError::new`.
#[derive(Debug, Clone, Error, Eq, PartialEq)]
#[error("{reason}")]
pub struct ExtensionSetupError {
/// Human-readable failure description, also used as the `Display` text.
reason: String,
}
impl ExtensionSetupError {
    /// Builds a setup error carrying `reason` verbatim.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, ...).
    #[must_use]
    pub fn new(reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        Self { reason }
    }
}
/// Interface a runtime extension implements.
///
/// Every method has a default body, so an implementor overrides only what it needs.
/// The `async fn` methods go through `#[async_trait]`; implementors must be
/// `Send + Sync + 'static`.
#[async_trait]
pub trait RuntimeExtension: Send + Sync + 'static {
    /// Name of this extension.
    ///
    /// Falls back to `std::any::type_name::<Self>()`. That string is meant for
    /// diagnostics: the standard library does not guarantee it is unique or stable
    /// across compiler versions, so override this when a durable name matters.
    fn name(&self) -> &'static str {
        std::any::type_name::<Self>()
    }

    /// Whether `name` returns something other than the type-name fallback.
    ///
    /// An override that happens to return exactly the type name is
    /// indistinguishable from the default and reports `false`.
    fn has_explicit_name(&self) -> bool {
        let fallback = std::any::type_name::<Self>();
        fallback != self.name()
    }

    /// Declares the extension's capabilities, resources and subscriptions.
    ///
    /// The default returns an empty [`ExtensionManifest`].
    ///
    /// # Errors
    ///
    /// Implementations return [`ExtensionSetupError`] to report a failed setup;
    /// the default never fails.
    async fn setup(
        &self,
        _ctx: ExtensionContext,
    ) -> Result<ExtensionManifest, ExtensionSetupError> {
        let manifest = ExtensionManifest::default();
        Ok(manifest)
    }

    /// Receives one packet event by value. The default ignores it.
    async fn on_packet_received(&self, _event: RuntimePacketEvent) {}

    /// Shutdown hook. The default does nothing.
    async fn shutdown(&self, _ctx: ExtensionContext) {}
}