#[cfg(test)]
use std::str::FromStr;
use std::{
any::{Any, type_name_of_val},
collections::HashSet,
io::ErrorKind,
net::SocketAddr,
panic::AssertUnwindSafe,
sync::{
Arc, Mutex, RwLock,
atomic::{AtomicU64, Ordering},
},
time::{Duration, Instant},
};
use futures_util::{FutureExt, SinkExt, StreamExt};
use sof_support::time_support::current_unix_ms;
use tokio::{
io::AsyncReadExt,
net::{TcpListener, TcpStream, UdpSocket},
sync::mpsc,
task::{JoinHandle, JoinSet},
time::timeout,
};
use tokio_tungstenite::{
MaybeTlsStream, WebSocketStream, connect_async_with_config,
tungstenite::{
Error as WebSocketError, Message,
client::IntoClientRequest,
protocol::{CloseFrame, WebSocketConfig},
},
};
use crate::framework::extension::{
ExtensionCapability, ExtensionContext, ExtensionManifest, ExtensionResourceSpec,
ExtensionStreamVisibility, PacketSubscription, RuntimeExtension, RuntimePacketEvent,
RuntimePacketEventClass, RuntimePacketSource, RuntimePacketSourceKind, RuntimePacketTransport,
RuntimeWebSocketFrameType, TcpConnectorSpec, TcpListenerSpec, UdpListenerSpec, WsConnectorSpec,
};
/// Default capacity (in events) of each extension's dispatch queue.
const DEFAULT_EVENT_QUEUE_CAPACITY: usize = 8 * 1024;
/// The first this-many dropped events per extension are always logged.
const INITIAL_DROP_LOG_LIMIT: u64 = 5;
/// Past the initial limit, only every N-th dropped event is logged.
const DROP_LOG_SAMPLE_EVERY: u64 = 1000;
/// Default bound on each extension `setup` hook, in seconds.
const DEFAULT_STARTUP_TIMEOUT_SECS: u64 = 5;
/// Default bound on each extension `shutdown` hook, in seconds.
const DEFAULT_SHUTDOWN_TIMEOUT_SECS: u64 = 3;
/// Floor applied to declared resource read-buffer sizes, in bytes.
const DEFAULT_RESOURCE_READ_BUFFER_BYTES: usize = 2 * 1024;
/// 1 MiB ceiling for resource read buffers.
/// NOTE(review): not referenced in the visible code; presumably enforced by manifest validation.
const MAX_RESOURCE_READ_BUFFER_BYTES: usize = 1 << 20;
/// Websocket max message/frame size = payload chunk size multiplied by this factor.
const EXTENSION_WEBSOCKET_MESSAGE_LIMIT_MULTIPLIER: usize = 64;
/// Substituted by the builder when a zero startup/shutdown timeout is requested.
const MIN_EXTENSION_HOST_TIMEOUT: Duration = Duration::from_micros(1_000);
/// Why one extension was rejected while the host was starting up.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct RuntimeExtensionStartupFailure {
    /// Name reported by the rejected extension.
    pub extension: &'static str,
    /// Human-readable rejection reason.
    pub reason: String,
}

/// Summary produced by a host startup pass.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct RuntimeExtensionStartupReport {
    /// Number of extensions registered with the host.
    pub discovered_extensions: usize,
    /// Number of extensions running after startup.
    pub active_extensions: usize,
    /// Number of extensions rejected during startup.
    pub failed_extensions: usize,
    /// One entry per rejected extension.
    pub failures: Vec<RuntimeExtensionStartupFailure>,
}

impl RuntimeExtensionStartupReport {
    /// Report that records only how many extensions were discovered; all other counters are zero
    /// and there are no failures yet.
    const fn empty(discovered_extensions: usize) -> Self {
        let failures = Vec::new();
        Self {
            discovered_extensions,
            active_extensions: 0,
            failed_extensions: 0,
            failures,
        }
    }
}
/// Point-in-time dispatch counters for one active extension.
///
/// Built by `ExtensionDispatcher::metrics_snapshot`. Each counter is loaded separately with
/// relaxed atomics, so under concurrent load the fields may not be mutually consistent.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct RuntimeExtensionDispatchMetrics {
    /// Name of the extension these counters belong to.
    pub extension: &'static str,
    /// Events rejected because the extension's queue was full or closed.
    pub dropped_events: u64,
    /// Approximate number of events currently waiting in the queue.
    pub queue_depth: u64,
    /// Highest queue depth recorded by the dispatcher.
    pub max_queue_depth: u64,
    /// Events dequeued by the worker and handed to the extension callback.
    pub dispatched_events: u64,
    /// Mean enqueue-to-dequeue delay in microseconds (0 if nothing was dispatched).
    pub avg_dispatch_lag_us: u64,
    /// Largest observed enqueue-to-dequeue delay in microseconds.
    pub max_dispatch_lag_us: u64,
}
/// Allow-list of capabilities that extensions may declare in their manifests.
#[derive(Debug, Clone)]
pub struct RuntimeExtensionCapabilityPolicy {
    allowed: HashSet<ExtensionCapability>,
}

impl Default for RuntimeExtensionCapabilityPolicy {
    /// Permits every capability listed by `ExtensionCapability::all()`.
    fn default() -> Self {
        let allowed: HashSet<ExtensionCapability> =
            ExtensionCapability::all().into_iter().collect();
        Self { allowed }
    }
}

impl RuntimeExtensionCapabilityPolicy {
    /// Policy that permits every known capability (same as `Default`).
    #[must_use]
    pub fn allow_all() -> Self {
        Self::default()
    }

    /// Policy that permits nothing; combine with [`Self::with`] to opt in selectively.
    #[must_use]
    pub fn deny_all() -> Self {
        let allowed = HashSet::new();
        Self { allowed }
    }

    /// Least-privilege preset: UDP/TCP binding plus observation of observer ingress and of
    /// shared extension streams.
    #[must_use]
    pub fn production_defaults() -> Self {
        let permitted = [
            ExtensionCapability::BindUdp,
            ExtensionCapability::BindTcp,
            ExtensionCapability::ObserveObserverIngress,
            ExtensionCapability::ObserveSharedExtensionStream,
        ];
        permitted.into_iter().fold(Self::deny_all(), Self::with)
    }

    /// Returns `true` when `capability` is on the allow-list.
    #[must_use]
    pub fn allows(&self, capability: ExtensionCapability) -> bool {
        self.allowed.contains(&capability)
    }

    /// Returns the policy with `capability` added to the allow-list.
    #[must_use]
    pub fn with(mut self, capability: ExtensionCapability) -> Self {
        self.allowed.insert(capability);
        self
    }

    /// Returns the policy with `capability` removed from the allow-list.
    #[must_use]
    pub fn without(mut self, capability: ExtensionCapability) -> Self {
        self.allowed.remove(&capability);
        self
    }
}
/// Fluent configuration for a `RuntimeExtensionHost`.
pub struct RuntimeExtensionHostBuilder {
    /// Registered extensions, in registration order.
    extensions: Vec<Arc<dyn RuntimeExtension>>,
    /// Per-extension dispatch queue capacity.
    event_queue_capacity: usize,
    /// Bound on each extension `setup` hook (and on connector establishment).
    startup_timeout: Duration,
    /// Bound on each extension `shutdown` hook.
    shutdown_timeout: Duration,
    /// Capabilities extensions are allowed to declare.
    capability_policy: RuntimeExtensionCapabilityPolicy,
    /// Reject extensions that do not override `RuntimeExtension::name`.
    require_explicit_extension_names: bool,
}

impl Default for RuntimeExtensionHostBuilder {
    /// Permissive defaults: allow-all capability policy, implicit names accepted, and the
    /// `DEFAULT_*` queue capacity and timeouts.
    fn default() -> Self {
        Self {
            require_explicit_extension_names: false,
            capability_policy: RuntimeExtensionCapabilityPolicy::allow_all(),
            shutdown_timeout: Duration::from_secs(DEFAULT_SHUTDOWN_TIMEOUT_SECS),
            startup_timeout: Duration::from_secs(DEFAULT_STARTUP_TIMEOUT_SECS),
            event_queue_capacity: DEFAULT_EVENT_QUEUE_CAPACITY,
            extensions: Vec::new(),
        }
    }
}

impl RuntimeExtensionHostBuilder {
    /// Same as [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Hardened preset: least-privilege capability policy and mandatory explicit names.
    #[must_use]
    pub fn production_defaults() -> Self {
        let mut builder = Self::default();
        builder.capability_policy = RuntimeExtensionCapabilityPolicy::production_defaults();
        builder.require_explicit_extension_names = true;
        builder
    }

    /// Sets the per-extension queue capacity; zero is bumped to one.
    #[must_use]
    pub fn with_event_queue_capacity(mut self, capacity: usize) -> Self {
        self.event_queue_capacity = if capacity == 0 { 1 } else { capacity };
        self
    }

    /// Sets the startup bound; a zero duration is replaced by `MIN_EXTENSION_HOST_TIMEOUT`.
    #[must_use]
    pub const fn with_startup_timeout(mut self, timeout: Duration) -> Self {
        self.startup_timeout = Self::at_least_min_timeout(timeout);
        self
    }

    /// Sets the shutdown bound; a zero duration is replaced by `MIN_EXTENSION_HOST_TIMEOUT`.
    #[must_use]
    pub const fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
        self.shutdown_timeout = Self::at_least_min_timeout(timeout);
        self
    }

    /// Replaces the capability policy.
    #[must_use]
    pub fn with_capability_policy(mut self, policy: RuntimeExtensionCapabilityPolicy) -> Self {
        self.capability_policy = policy;
        self
    }

    /// Toggles rejection of extensions that rely on the implicit type-name identifier.
    #[must_use]
    pub const fn with_require_explicit_extension_names(mut self, require: bool) -> Self {
        self.require_explicit_extension_names = require;
        self
    }

    /// Registers an owned extension.
    #[must_use]
    pub fn add_extension<E>(self, extension: E) -> Self
    where
        E: RuntimeExtension,
    {
        self.add_shared_extension(Arc::new(extension))
    }

    /// Registers an extension that is already behind an `Arc`.
    #[must_use]
    pub fn add_shared_extension(mut self, extension: Arc<dyn RuntimeExtension>) -> Self {
        self.extensions.push(extension);
        self
    }

    /// Freezes the configuration into a host; nothing is started until `startup` is awaited.
    #[must_use]
    pub fn build(self) -> RuntimeExtensionHost {
        let Self {
            extensions,
            event_queue_capacity,
            startup_timeout,
            shutdown_timeout,
            capability_policy,
            require_explicit_extension_names,
        } = self;
        let inner = RuntimeExtensionHostInner {
            extensions: Arc::from(extensions),
            event_queue_capacity: event_queue_capacity.max(1),
            startup_timeout,
            shutdown_timeout,
            capability_policy,
            require_explicit_extension_names,
            runtime_state: RwLock::new(None),
        };
        RuntimeExtensionHost {
            inner: Arc::new(inner),
        }
    }

    /// Maps a zero duration to the minimum host timeout so `tokio::time::timeout` never
    /// fires before the hook is polled at least once.
    const fn at_least_min_timeout(timeout: Duration) -> Duration {
        if timeout.is_zero() {
            MIN_EXTENSION_HOST_TIMEOUT
        } else {
            timeout
        }
    }
}
/// Shared state behind every clone of a `RuntimeExtensionHost`.
struct RuntimeExtensionHostInner {
    /// All registered extensions, in registration order.
    extensions: Arc<[Arc<dyn RuntimeExtension>]>,
    /// Per-extension dispatch queue capacity (the builder clamps it to at least 1).
    event_queue_capacity: usize,
    /// Bound on each `setup` hook and on connector establishment during startup.
    startup_timeout: Duration,
    /// Bound on each `shutdown` hook.
    shutdown_timeout: Duration,
    /// Capabilities extensions are permitted to declare.
    capability_policy: RuntimeExtensionCapabilityPolicy,
    /// When true, extensions without an explicit `name()` override are rejected at startup.
    require_explicit_extension_names: bool,
    /// Set to `Some` at the end of `startup`, taken back to `None` by `shutdown`.
    runtime_state: RwLock<Option<RuntimeExtensionRuntimeState>>,
}
/// The set of extensions that passed startup and currently receive events.
struct RuntimeExtensionRuntimeState {
    active_extensions: Arc<[Arc<ActiveRuntimeExtension>]>,
}
/// Bookkeeping for one extension that passed startup validation.
struct ActiveRuntimeExtension {
    /// The extension implementation (hooks are invoked on this).
    extension: Arc<dyn RuntimeExtension>,
    /// Stable name used in telemetry and for ownership checks.
    name: &'static str,
    /// Capabilities granted after policy validation.
    capabilities: HashSet<ExtensionCapability>,
    /// Subscriptions an event must match to be delivered.
    subscriptions: Arc<[PacketSubscription]>,
    /// Bounded queue + worker delivering events to the extension.
    dispatcher: ExtensionDispatcher,
    /// Tasks driving this extension's network resources; aborted on teardown.
    resource_handles: Mutex<Vec<JoinHandle<()>>>,
}

impl ActiveRuntimeExtension {
    /// Total events dropped by this extension's dispatcher.
    fn dropped_event_count(&self) -> u64 {
        self.dispatcher.dropped_count()
    }

    /// Current dispatch metrics labelled with this extension's name.
    fn dispatch_metrics_snapshot(&self) -> RuntimeExtensionDispatchMetrics {
        self.dispatcher.metrics_snapshot(self.name)
    }

    /// Records a resource task so it can be aborted on teardown.
    ///
    /// The handle is stored even if the lock was poisoned: dropping a `JoinHandle` detaches the
    /// task, which would then run forever with no way to abort it.
    fn push_resource_handle(&self, handle: JoinHandle<()>) {
        self.resource_handles_guard().push(handle);
    }

    /// Aborts and forgets every recorded resource task.
    fn abort_resource_handles(&self) {
        for handle in self.resource_handles_guard().drain(..) {
            handle.abort();
        }
    }

    /// Locks the handle list, recovering from poisoning. The `Vec` holds no invariant that a
    /// panicking holder could have broken, so continuing is safe and avoids leaking tasks.
    fn resource_handles_guard(&self) -> std::sync::MutexGuard<'_, Vec<JoinHandle<()>>> {
        self.resource_handles
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }
}
/// Bounded per-extension event queue, the worker task draining it, and shared counters.
///
/// Every field is an `Arc` or a channel sender, so clones share the same queue, counters and
/// worker handle.
#[derive(Clone)]
struct ExtensionDispatcher {
    /// Producer side of the bounded event queue.
    tx: mpsc::Sender<QueuedRuntimePacketEvent>,
    /// Events rejected because the queue was full or closed.
    dropped_events: Arc<AtomicU64>,
    /// Approximate number of events queued but not yet received by the worker.
    queue_depth: Arc<AtomicU64>,
    /// High-water mark of `queue_depth`.
    max_queue_depth: Arc<AtomicU64>,
    /// Events received by the worker.
    dispatched_events: Arc<AtomicU64>,
    /// Sum of enqueue-to-dequeue delays in microseconds (numerator of the average lag).
    total_dispatch_lag_us: Arc<AtomicU64>,
    /// Largest enqueue-to-dequeue delay in microseconds.
    max_dispatch_lag_us: Arc<AtomicU64>,
    /// Worker task handle; `None` once `abort_worker` has taken it.
    worker: Arc<Mutex<Option<JoinHandle<()>>>>,
}
/// A queued event plus the instant it was enqueued, used to measure dispatch lag.
struct QueuedRuntimePacketEvent {
    event: RuntimePacketEvent,
    queued_at: Instant,
}
impl ExtensionDispatcher {
fn spawn(
extension: &Arc<dyn RuntimeExtension>,
extension_name: &'static str,
queue_capacity: usize,
) -> Self {
let (tx, mut rx) = mpsc::channel::<QueuedRuntimePacketEvent>(queue_capacity.max(1));
let dropped_events = Arc::new(AtomicU64::new(0));
let queue_depth = Arc::new(AtomicU64::new(0));
let max_queue_depth = Arc::new(AtomicU64::new(0));
let dispatched_events = Arc::new(AtomicU64::new(0));
let total_dispatch_lag_us = Arc::new(AtomicU64::new(0));
let max_dispatch_lag_us = Arc::new(AtomicU64::new(0));
let worker_extension = Arc::clone(extension);
let worker_queue_depth = Arc::clone(&queue_depth);
let worker_dispatched_events = Arc::clone(&dispatched_events);
let worker_total_dispatch_lag_us = Arc::clone(&total_dispatch_lag_us);
let worker_max_dispatch_lag_us = Arc::clone(&max_dispatch_lag_us);
let worker = tokio::spawn(async move {
while let Some(queued_event) = rx.recv().await {
worker_queue_depth.fetch_sub(1, Ordering::Relaxed);
let queue_lag_us =
u64::try_from(queued_event.queued_at.elapsed().as_micros()).unwrap_or(u64::MAX);
worker_dispatched_events.fetch_add(1, Ordering::Relaxed);
worker_total_dispatch_lag_us.fetch_add(queue_lag_us, Ordering::Relaxed);
record_max_atomic(&worker_max_dispatch_lag_us, queue_lag_us);
let callback_extension = Arc::clone(&worker_extension);
if let Err(payload) =
invoke_extension_callback(callback_extension, queued_event.event).await
{
let panic_message = panic_payload_to_string(payload.as_ref());
tracing::error!(
extension = extension_name,
panic = %panic_message,
"runtime extension packet callback panicked; continuing runtime"
);
}
}
});
Self {
tx,
dropped_events,
queue_depth,
max_queue_depth,
dispatched_events,
total_dispatch_lag_us,
max_dispatch_lag_us,
worker: Arc::new(Mutex::new(Some(worker))),
}
}
fn dispatch(&self, extension_name: &'static str, event: RuntimePacketEvent) {
let queued_event = QueuedRuntimePacketEvent {
event,
queued_at: Instant::now(),
};
let queue_depth = self
.queue_depth
.fetch_add(1, Ordering::Relaxed)
.saturating_add(1);
record_max_atomic(&self.max_queue_depth, queue_depth);
match self.tx.try_send(queued_event) {
Ok(()) => {}
Err(mpsc::error::TrySendError::Full(_)) => {
self.queue_depth.fetch_sub(1, Ordering::Relaxed);
self.record_drop(extension_name, "queue full");
}
Err(mpsc::error::TrySendError::Closed(_)) => {
self.queue_depth.fetch_sub(1, Ordering::Relaxed);
self.record_drop(extension_name, "queue closed");
}
}
}
fn dropped_count(&self) -> u64 {
self.dropped_events.load(Ordering::Relaxed)
}
fn metrics_snapshot(&self, extension_name: &'static str) -> RuntimeExtensionDispatchMetrics {
let dispatched_events = self.dispatched_events.load(Ordering::Relaxed);
let total_dispatch_lag_us = self.total_dispatch_lag_us.load(Ordering::Relaxed);
let avg_dispatch_lag_us = if dispatched_events == 0 {
0
} else {
total_dispatch_lag_us
.checked_div(dispatched_events)
.unwrap_or_default()
};
RuntimeExtensionDispatchMetrics {
extension: extension_name,
dropped_events: self.dropped_events.load(Ordering::Relaxed),
queue_depth: self.queue_depth.load(Ordering::Relaxed),
max_queue_depth: self.max_queue_depth.load(Ordering::Relaxed),
dispatched_events,
avg_dispatch_lag_us,
max_dispatch_lag_us: self.max_dispatch_lag_us.load(Ordering::Relaxed),
}
}
fn abort_worker(&self) {
if let Ok(mut guard) = self.worker.lock()
&& let Some(handle) = guard.take()
{
handle.abort();
}
}
fn record_drop(&self, extension_name: &'static str, reason: &'static str) {
let dropped = self
.dropped_events
.fetch_add(1, Ordering::Relaxed)
.saturating_add(1);
if dropped <= INITIAL_DROP_LOG_LIMIT || dropped.is_multiple_of(DROP_LOG_SAMPLE_EVERY) {
tracing::warn!(
extension = extension_name,
reason,
dropped,
queue_depth = self.queue_depth.load(Ordering::Relaxed),
"dropping runtime extension packet event to protect ingest hot path"
);
}
}
}
/// Awaits the extension's packet callback and turns a panic into `Err(payload)` instead of
/// unwinding into the dispatcher worker.
async fn invoke_extension_callback(
    extension: Arc<dyn RuntimeExtension>,
    event: RuntimePacketEvent,
) -> Result<(), Box<dyn Any + Send>> {
    let callback = extension.on_packet_received(event);
    // `catch_unwind` needs `UnwindSafe`; the opaque callback future is asserted to be so.
    let guarded = AssertUnwindSafe(callback);
    guarded.catch_unwind().await
}
/// Handle to the extension runtime. Clones are cheap and share the same registered
/// extensions and runtime state through an `Arc`.
#[derive(Clone)]
pub struct RuntimeExtensionHost {
    inner: Arc<RuntimeExtensionHostInner>,
}

impl Default for RuntimeExtensionHost {
    /// A host with no extensions, built from the builder's default (allow-all) configuration.
    fn default() -> Self {
        RuntimeExtensionHostBuilder::default().build()
    }
}
impl RuntimeExtensionHost {
    /// Returns a fresh builder with the permissive defaults of `RuntimeExtensionHostBuilder::new`.
    #[must_use]
    pub fn builder() -> RuntimeExtensionHostBuilder {
        RuntimeExtensionHostBuilder::new()
    }
    /// Returns a builder preset by `RuntimeExtensionHostBuilder::production_defaults`.
    #[must_use]
    pub fn production_builder() -> RuntimeExtensionHostBuilder {
        RuntimeExtensionHostBuilder::production_defaults()
    }
    /// Returns `true` when no extensions are registered, regardless of startup outcome.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.extensions.is_empty()
    }
    /// Number of registered extensions, including ones that failed or have not started.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.extensions.len()
    }
    /// Names of every registered extension, in registration order.
    #[must_use]
    pub fn extension_names(&self) -> Vec<&'static str> {
        self.inner
            .extensions
            .iter()
            .map(|extension| extension.name())
            .collect()
    }
    /// Names of the extensions that passed startup.
    ///
    /// Empty before `startup`, after `shutdown`, or if the state lock is poisoned.
    #[must_use]
    pub fn active_extension_names(&self) -> Vec<&'static str> {
        self.inner
            .runtime_state
            .read()
            .ok()
            .and_then(|guard| {
                guard.as_ref().map(|state| {
                    state
                        .active_extensions
                        .iter()
                        .map(|extension| extension.name)
                        .collect()
                })
            })
            .unwrap_or_default()
    }
    /// Sum of dropped events across active extensions; 0 when the host is not running.
    #[must_use]
    pub fn dropped_event_count(&self) -> u64 {
        self.inner
            .runtime_state
            .read()
            .ok()
            .and_then(|guard| {
                guard.as_ref().map(|state| {
                    state
                        .active_extensions
                        .iter()
                        .map(|extension| extension.dropped_event_count())
                        .sum()
                })
            })
            .unwrap_or_default()
    }
    /// `(name, dropped events)` for each active extension; empty when not running.
    #[must_use]
    pub fn dropped_event_counts_by_extension(&self) -> Vec<(&'static str, u64)> {
        self.inner
            .runtime_state
            .read()
            .ok()
            .and_then(|guard| {
                guard.as_ref().map(|state| {
                    state
                        .active_extensions
                        .iter()
                        .map(|extension| (extension.name, extension.dropped_event_count()))
                        .collect()
                })
            })
            .unwrap_or_default()
    }
    /// Dispatch metrics snapshot for each active extension; empty when not running.
    #[must_use]
    pub fn dispatch_metrics_by_extension(&self) -> Vec<RuntimeExtensionDispatchMetrics> {
        self.inner
            .runtime_state
            .read()
            .ok()
            .and_then(|guard| {
                guard.as_ref().map(|state| {
                    state
                        .active_extensions
                        .iter()
                        .map(|extension| extension.dispatch_metrics_snapshot())
                        .collect()
                })
            })
            .unwrap_or_default()
    }
    /// Starts every registered extension and publishes the active set.
    ///
    /// Extensions are processed sequentially in registration order. For each one:
    /// 1. enforce the explicit-name policy;
    /// 2. reject duplicate names (the first registration wins);
    /// 3. run `setup` under `startup_timeout`;
    /// 4. validate the manifest against the capability policy;
    /// 5. spawn the dispatcher and provision resources.
    ///
    /// Any failing step records a `RuntimeExtensionStartupFailure` and skips that extension.
    /// If the host is already running, returns a report of the current active set without
    /// restarting anything.
    pub async fn startup(&self) -> RuntimeExtensionStartupReport {
        // Already running: report the current state instead of starting a second time.
        // NOTE(review): this check and the final state write are separate lock acquisitions, so
        // two concurrent `startup` calls could both proceed; confirm callers serialize startup.
        if let Ok(guard) = self.inner.runtime_state.read()
            && let Some(state) = guard.as_ref()
        {
            return RuntimeExtensionStartupReport {
                discovered_extensions: self.inner.extensions.len(),
                active_extensions: state.active_extensions.len(),
                failed_extensions: 0,
                failures: Vec::new(),
            };
        }
        let mut report = RuntimeExtensionStartupReport::empty(self.inner.extensions.len());
        let mut active_extensions: Vec<Arc<ActiveRuntimeExtension>> = Vec::new();
        let mut seen_extension_names = HashSet::<&'static str>::new();
        for extension in self.inner.extensions.iter() {
            let extension = Arc::clone(extension);
            let extension_name = extension.name();
            let has_explicit_name = extension.has_explicit_name();
            // Implicit (type-name) identifiers always warn; the production policy rejects them.
            if !has_explicit_name {
                let concrete_type_name = type_name_of_val(extension.as_ref());
                tracing::warn!(
                    extension = extension_name,
                    concrete_type = concrete_type_name,
                    "runtime extension uses implicit type-name identifier; override `name()` with a stable literal for telemetry/filter durability"
                );
                if self.inner.require_explicit_extension_names {
                    report.failures.push(RuntimeExtensionStartupFailure {
                        extension: extension_name,
                        reason:
                            "runtime policy requires explicit stable extension names; override RuntimeExtension::name"
                                .to_owned(),
                    });
                    continue;
                }
            }
            // Names are the ownership key for resource events, so they must be unique;
            // later registrations with an already-seen name are rejected.
            if !seen_extension_names.insert(extension_name) {
                report.failures.push(RuntimeExtensionStartupFailure {
                    extension: extension_name,
                    reason: format!(
                        "duplicate extension name `{extension_name}`; extension names must be unique"
                    ),
                });
                continue;
            }
            let startup_context = ExtensionContext { extension_name };
            let manifest_result =
                timeout(self.inner.startup_timeout, extension.setup(startup_context)).await;
            let manifest = match manifest_result {
                Ok(Ok(manifest)) => manifest,
                Ok(Err(error)) => {
                    report.failures.push(RuntimeExtensionStartupFailure {
                        extension: extension_name,
                        reason: error.to_string(),
                    });
                    continue;
                }
                Err(_elapsed) => {
                    report.failures.push(RuntimeExtensionStartupFailure {
                        extension: extension_name,
                        reason: format!(
                            "startup hook timed out after {}ms",
                            self.inner.startup_timeout.as_millis()
                        ),
                    });
                    continue;
                }
            };
            // NOTE(review): from here on, a failure skips the extension without calling its
            // `shutdown` hook even though `setup` succeeded; confirm that is intended.
            let validated =
                match validate_manifest(extension_name, &manifest, &self.inner.capability_policy) {
                    Ok(validated) => validated,
                    Err(reason) => {
                        report.failures.push(RuntimeExtensionStartupFailure {
                            extension: extension_name,
                            reason,
                        });
                        continue;
                    }
                };
            let active = Arc::new(ActiveRuntimeExtension {
                extension: Arc::clone(&extension),
                name: extension_name,
                capabilities: validated.capabilities,
                subscriptions: Arc::from(validated.subscriptions),
                dispatcher: ExtensionDispatcher::spawn(
                    &extension,
                    extension_name,
                    self.inner.event_queue_capacity,
                ),
                resource_handles: Mutex::new(Vec::new()),
            });
            // On a provisioning failure, tear down whatever was already spawned for this
            // extension: resource tasks pushed so far, plus the dispatcher worker.
            if let Err(reason) = self.provision_resources(&active, &manifest.resources).await {
                active.abort_resource_handles();
                active.dispatcher.abort_worker();
                report.failures.push(RuntimeExtensionStartupFailure {
                    extension: extension_name,
                    reason,
                });
                continue;
            }
            active_extensions.push(active);
        }
        report.active_extensions = active_extensions.len();
        report.failed_extensions = report.failures.len();
        // Publish the active set. Until this point `runtime_state` is `None`, so events that
        // resources emit during startup are discarded by `dispatch_runtime_packet`.
        if let Ok(mut guard) = self.inner.runtime_state.write() {
            *guard = Some(RuntimeExtensionRuntimeState {
                active_extensions: Arc::from(active_extensions),
            });
        }
        report
    }
    /// Copies `bytes` into a shared buffer and forwards it as an observer-ingress UDP packet.
    pub fn on_observer_packet(&self, source: SocketAddr, bytes: &[u8]) {
        self.on_observer_packet_shared(source, Arc::from(bytes));
    }
    /// Forwards an observer-ingress UDP packet without copying; empty payloads are ignored.
    pub fn on_observer_packet_shared(&self, source: SocketAddr, bytes: Arc<[u8]>) {
        if bytes.is_empty() {
            return;
        }
        let source_meta = RuntimePacketSource {
            kind: RuntimePacketSourceKind::ObserverIngress,
            transport: RuntimePacketTransport::Udp,
            event_class: RuntimePacketEventClass::Packet,
            owner_extension: None,
            resource_id: None,
            shared_tag: None,
            websocket_frame_type: None,
            local_addr: None,
            remote_addr: Some(source),
        };
        let event = RuntimePacketEvent {
            source: source_meta,
            bytes,
            observed_unix_ms: current_unix_ms(),
        };
        self.dispatch_runtime_packet(&event);
    }
    /// Forwards an event produced by an extension resource.
    ///
    /// Empty payloads are dropped only for `Packet` events; other classes (such as
    /// `ConnectionClosed`) are forwarded even when empty.
    pub fn emit_extension_packet(&self, source: RuntimePacketSource, bytes: Arc<[u8]>) {
        if bytes.is_empty() && source.event_class == RuntimePacketEventClass::Packet {
            return;
        }
        let event = RuntimePacketEvent {
            source,
            bytes,
            observed_unix_ms: current_unix_ms(),
        };
        self.dispatch_runtime_packet(&event);
    }
    /// Stops the runtime. Calls made while the host is not running are no-ops.
    ///
    /// Teardown order:
    /// 1. take the state, so no further events are dispatched;
    /// 2. abort all resource tasks;
    /// 3. run each extension's `shutdown` hook sequentially, each under `shutdown_timeout`;
    /// 4. abort the dispatcher workers. Events still queued at that point may not be delivered.
    pub async fn shutdown(&self) {
        let state = self
            .inner
            .runtime_state
            .write()
            .ok()
            .and_then(|mut guard| guard.take());
        let Some(state) = state else {
            return;
        };
        for extension in state.active_extensions.iter() {
            extension.abort_resource_handles();
        }
        for extension in state.active_extensions.iter() {
            let shutdown_context = ExtensionContext {
                extension_name: extension.name,
            };
            // Only a timeout is reported; the hook's own return value is not inspected here.
            let shutdown_result = timeout(
                self.inner.shutdown_timeout,
                extension.extension.shutdown(shutdown_context),
            )
            .await;
            if shutdown_result.is_err() {
                tracing::warn!(
                    extension = extension.name,
                    timeout_ms = self.inner.shutdown_timeout.as_millis(),
                    "runtime extension shutdown timed out; force-cancelling"
                );
            }
        }
        for extension in state.active_extensions.iter() {
            extension.dispatcher.abort_worker();
        }
    }
    /// Fans an event out to every active extension that may observe it and has a matching
    /// subscription. Each recipient receives its own clone of the event.
    fn dispatch_runtime_packet(&self, event: &RuntimePacketEvent) {
        // Clone the `Arc` of the active set so the read lock is released before dispatching.
        let Some(active_extensions) = self.inner.runtime_state.read().ok().and_then(|guard| {
            guard
                .as_ref()
                .map(|state| Arc::clone(&state.active_extensions))
        }) else {
            return;
        };
        for extension in active_extensions.iter() {
            if !extension_can_observe_event(extension, event) {
                continue;
            }
            if !extension
                .subscriptions
                .iter()
                .any(|subscription| subscription.matches(event))
            {
                continue;
            }
            extension.dispatcher.dispatch(extension.name, event.clone());
        }
    }
    /// Spawns the declared resources in order and stops at the first failure.
    ///
    /// Handles of resources spawned before the failure remain registered on `extension`, so
    /// the caller can abort them.
    async fn provision_resources(
        &self,
        extension: &Arc<ActiveRuntimeExtension>,
        resources: &[ExtensionResourceSpec],
    ) -> Result<(), String> {
        let startup_timeout = self.inner.startup_timeout;
        for resource in resources {
            let handle = match resource {
                ExtensionResourceSpec::UdpListener(spec) => {
                    spawn_udp_listener(self.clone(), extension, spec.clone()).await?
                }
                ExtensionResourceSpec::TcpListener(spec) => {
                    spawn_tcp_listener(self.clone(), extension, spec.clone()).await?
                }
                ExtensionResourceSpec::TcpConnector(spec) => {
                    spawn_tcp_connector(self.clone(), extension, spec.clone(), startup_timeout)
                        .await?
                }
                ExtensionResourceSpec::WsConnector(spec) => {
                    spawn_ws_connector(self.clone(), extension, spec.clone(), startup_timeout)
                        .await?
                }
            };
            extension.push_resource_handle(handle);
        }
        Ok(())
    }
}
fn extension_can_observe_event(
extension: &ActiveRuntimeExtension,
event: &RuntimePacketEvent,
) -> bool {
match event.source.kind {
RuntimePacketSourceKind::ObserverIngress => extension
.capabilities
.contains(&ExtensionCapability::ObserveObserverIngress),
RuntimePacketSourceKind::ExtensionResource => {
let owner_name = event.source.owner_extension.as_deref();
if owner_name == Some(extension.name) {
return true;
}
event.source.shared_tag.is_some()
&& extension
.capabilities
.contains(&ExtensionCapability::ObserveSharedExtensionStream)
}
}
}
/// Stamps events from one extension-owned resource with source metadata and hands them to
/// the host for dispatch.
#[derive(Clone)]
struct ExtensionResourceEmitter {
    host: RuntimeExtensionHost,
    owner_extension: String,
    resource_id: String,
    shared_tag: Option<String>,
    transport: RuntimePacketTransport,
    local_addr: Option<SocketAddr>,
    remote_addr: Option<SocketAddr>,
}

impl ExtensionResourceEmitter {
    /// Creates an emitter; `remote_addr` is the default peer used by [`Self::emit_event`].
    fn new(
        host: RuntimeExtensionHost,
        extension_name: &str,
        resource_id: &str,
        shared_tag: Option<String>,
        transport: RuntimePacketTransport,
        local_addr: Option<SocketAddr>,
        remote_addr: Option<SocketAddr>,
    ) -> Self {
        let owner_extension = String::from(extension_name);
        let resource_id = String::from(resource_id);
        Self {
            host,
            owner_extension,
            resource_id,
            shared_tag,
            transport,
            local_addr,
            remote_addr,
        }
    }

    /// Emits `payload` as one `Packet` event per chunk of at most `max_payload_chunk_bytes`
    /// bytes. An empty payload emits nothing.
    fn emit_payload(
        &self,
        payload: &[u8],
        websocket_frame_type: Option<RuntimeWebSocketFrameType>,
        max_payload_chunk_bytes: usize,
    ) {
        // `chunks(0)` would panic, so the chunk size is kept at least 1.
        payload
            .chunks(max_payload_chunk_bytes.max(1))
            .for_each(|piece| {
                self.emit_event(
                    RuntimePacketEventClass::Packet,
                    websocket_frame_type,
                    Arc::from(piece),
                );
            });
    }

    /// Emits one event attributed to the emitter's default remote address.
    fn emit_event(
        &self,
        event_class: RuntimePacketEventClass,
        websocket_frame_type: Option<RuntimeWebSocketFrameType>,
        bytes: Arc<[u8]>,
    ) {
        let peer = self.remote_addr;
        self.emit_event_with_remote_addr(event_class, websocket_frame_type, peer, bytes);
    }

    /// Emits one event with an explicit remote address (used by datagram sockets, where the
    /// peer varies per packet).
    fn emit_event_with_remote_addr(
        &self,
        event_class: RuntimePacketEventClass,
        websocket_frame_type: Option<RuntimeWebSocketFrameType>,
        remote_addr: Option<SocketAddr>,
        bytes: Arc<[u8]>,
    ) {
        let owner_extension = Some(self.owner_extension.clone());
        let resource_id = Some(self.resource_id.clone());
        let shared_tag = self.shared_tag.clone();
        let source = RuntimePacketSource {
            kind: RuntimePacketSourceKind::ExtensionResource,
            transport: self.transport,
            event_class,
            owner_extension,
            resource_id,
            shared_tag,
            websocket_frame_type,
            local_addr: self.local_addr,
            remote_addr,
        };
        self.host.emit_extension_packet(source, bytes);
    }
}
/// What a resource read loop needs: the emitter to publish through, and the largest payload
/// size a single emitted event may carry.
struct ExtensionResourceReadContext {
    emitter: ExtensionResourceEmitter,
    max_payload_chunk_bytes: usize,
}

impl ExtensionResourceReadContext {
    /// Bundles an emitter with its chunk-size limit.
    const fn new(emitter: ExtensionResourceEmitter, max_payload_chunk_bytes: usize) -> Self {
        Self {
            max_payload_chunk_bytes,
            emitter,
        }
    }
}
/// Binds a UDP socket for `extension` and spawns a task that emits every non-empty datagram.
///
/// Datagrams longer than the read buffer (at least `DEFAULT_RESOURCE_READ_BUFFER_BYTES`) are
/// truncated by the OS. The task ends on the first non-`Interrupted` receive error.
async fn spawn_udp_listener(
    host: RuntimeExtensionHost,
    extension: &Arc<ActiveRuntimeExtension>,
    spec: UdpListenerSpec,
) -> Result<JoinHandle<()>, String> {
    let socket = match UdpSocket::bind(spec.bind_addr).await {
        Ok(socket) => socket,
        Err(error) => {
            return Err(format!(
                "failed to bind udp listener {}: {error}",
                spec.bind_addr
            ));
        }
    };
    let local_addr = socket.local_addr().ok();
    let owner_extension = extension.name.to_owned();
    let resource_id = spec.resource_id;
    let shared_tag = visibility_tag(spec.visibility);
    let read_buffer_bytes = spec
        .read_buffer_bytes
        .max(DEFAULT_RESOURCE_READ_BUFFER_BYTES);
    let emitter = ExtensionResourceEmitter::new(
        host,
        &owner_extension,
        &resource_id,
        shared_tag,
        RuntimePacketTransport::Udp,
        local_addr,
        None,
    );
    let handle = tokio::spawn(async move {
        let mut buffer = vec![0_u8; read_buffer_bytes];
        // Receive until a terminal error, which the loop yields as its value.
        let error = loop {
            match socket.recv_from(&mut buffer).await {
                Ok((0, _)) => {}
                Ok((len, remote_addr)) => {
                    if let Some(payload) = buffer.get(..len) {
                        emitter.emit_event_with_remote_addr(
                            RuntimePacketEventClass::Packet,
                            None,
                            Some(remote_addr),
                            Arc::from(payload),
                        );
                    }
                }
                Err(error) if error.kind() == ErrorKind::Interrupted => {}
                Err(error) => break error,
            }
        };
        tracing::warn!(
            extension = owner_extension,
            resource_id,
            error = %error,
            "udp extension listener read loop terminated"
        );
    });
    Ok(handle)
}
async fn spawn_tcp_listener(
host: RuntimeExtensionHost,
extension: &Arc<ActiveRuntimeExtension>,
spec: TcpListenerSpec,
) -> Result<JoinHandle<()>, String> {
let listener = TcpListener::bind(spec.bind_addr)
.await
.map_err(|error| format!("failed to bind tcp listener {}: {error}", spec.bind_addr))?;
let owner_extension = extension.name.to_owned();
let resource_id = spec.resource_id;
let shared_tag = visibility_tag(spec.visibility);
let read_buffer_bytes = spec
.read_buffer_bytes
.max(DEFAULT_RESOURCE_READ_BUFFER_BYTES);
let handle = tokio::spawn(async move {
let mut connections = JoinSet::new();
loop {
tokio::select! {
accepted = listener.accept() => {
match accepted {
Ok((stream, remote_addr)) => {
let local_addr = stream.local_addr().ok();
let emitter = ExtensionResourceEmitter::new(
host.clone(),
&owner_extension,
&resource_id,
shared_tag.clone(),
RuntimePacketTransport::Tcp,
local_addr,
Some(remote_addr),
);
connections.spawn(read_tcp_stream_packets(
ExtensionResourceReadContext::new(emitter, read_buffer_bytes),
stream,
));
}
Err(error) => {
if error.kind() == ErrorKind::Interrupted {
continue;
}
tracing::warn!(
extension = owner_extension,
resource_id,
error = %error,
"tcp extension listener accept loop terminated"
);
break;
}
}
}
Some(_result) = connections.join_next(), if !connections.is_empty() => {}
}
}
});
Ok(handle)
}
/// Connects to `spec.remote_addr` within `startup_timeout` and spawns a task that emits
/// everything read from the stream until EOF or error.
async fn spawn_tcp_connector(
    host: RuntimeExtensionHost,
    extension: &Arc<ActiveRuntimeExtension>,
    spec: TcpConnectorSpec,
    startup_timeout: Duration,
) -> Result<JoinHandle<()>, String> {
    let connect_result = timeout(startup_timeout, TcpStream::connect(spec.remote_addr)).await;
    let stream = match connect_result {
        Err(_elapsed) => {
            return Err(format!(
                "tcp connector {} timed out after {}ms during startup",
                spec.remote_addr,
                startup_timeout.as_millis()
            ));
        }
        Ok(Err(error)) => {
            return Err(format!("failed to connect tcp {}: {error}", spec.remote_addr));
        }
        Ok(Ok(stream)) => stream,
    };
    let local_addr = stream.local_addr().ok();
    let remote_addr = stream.peer_addr().ok();
    let owner_extension = extension.name.to_owned();
    let resource_id = spec.resource_id;
    let shared_tag = visibility_tag(spec.visibility);
    let read_buffer_bytes = spec
        .read_buffer_bytes
        .max(DEFAULT_RESOURCE_READ_BUFFER_BYTES);
    Ok(tokio::spawn(async move {
        let emitter = ExtensionResourceEmitter::new(
            host,
            &owner_extension,
            &resource_id,
            shared_tag,
            RuntimePacketTransport::Tcp,
            local_addr,
            remote_addr,
        );
        let context = ExtensionResourceReadContext::new(emitter, read_buffer_bytes);
        read_tcp_stream_packets(context, stream).await;
    }))
}
/// Opens a websocket to `spec.url` within `startup_timeout` and spawns a task that forwards
/// received messages as events, split into chunks of at most the configured buffer size.
async fn spawn_ws_connector(
    host: RuntimeExtensionHost,
    extension: &Arc<ActiveRuntimeExtension>,
    spec: WsConnectorSpec,
    startup_timeout: Duration,
) -> Result<JoinHandle<()>, String> {
    let max_payload_chunk_bytes = spec
        .read_buffer_bytes
        .max(DEFAULT_RESOURCE_READ_BUFFER_BYTES);
    let transport_config = extension_websocket_transport_config(max_payload_chunk_bytes);
    let connect = connect_async_with_config(spec.url.as_str(), Some(transport_config), false);
    let (stream, _response) = match timeout(startup_timeout, connect).await {
        Err(_elapsed) => {
            return Err(format!(
                "websocket connector {} timed out after {}ms during startup",
                spec.url,
                startup_timeout.as_millis()
            ));
        }
        Ok(Err(error)) => {
            return Err(format!("failed to connect websocket {}: {error}", spec.url));
        }
        Ok(Ok(connected)) => connected,
    };
    let tcp = stream.get_ref().get_ref();
    let local_addr = tcp.local_addr().ok();
    let peer_addr = tcp.peer_addr().ok();
    let owner_extension = extension.name.to_owned();
    let resource_id = spec.resource_id;
    let shared_tag = visibility_tag(spec.visibility);
    Ok(tokio::spawn(async move {
        let emitter = ExtensionResourceEmitter::new(
            host,
            &owner_extension,
            &resource_id,
            shared_tag,
            RuntimePacketTransport::WebSocket,
            local_addr,
            peer_addr,
        );
        let context = ExtensionResourceReadContext::new(emitter, max_payload_chunk_bytes);
        read_websocket_messages(context, stream).await;
    }))
}
/// Websocket limits for extension connectors: both the message and the frame size cap equal
/// the (floored) chunk size times `EXTENSION_WEBSOCKET_MESSAGE_LIMIT_MULTIPLIER`, saturating
/// on overflow.
fn extension_websocket_transport_config(max_payload_chunk_bytes: usize) -> WebSocketConfig {
    let chunk_floor = max_payload_chunk_bytes.max(DEFAULT_RESOURCE_READ_BUFFER_BYTES);
    let size_limit = Some(chunk_floor.saturating_mul(EXTENSION_WEBSOCKET_MESSAGE_LIMIT_MULTIPLIER));
    WebSocketConfig::default()
        .max_frame_size(size_limit)
        .max_message_size(size_limit)
}
/// Reads `stream` until EOF or a non-`Interrupted` error, emitting each read as one `Packet`
/// event of at most `max_payload_chunk_bytes` bytes.
async fn read_tcp_stream_packets(context: ExtensionResourceReadContext, mut stream: TcpStream) {
    let chunk_len = context.max_payload_chunk_bytes.max(1);
    let mut buffer = vec![0_u8; chunk_len];
    let emitter = &context.emitter;
    loop {
        let len = match stream.read(&mut buffer).await {
            Ok(0) => return,
            Ok(len) => len,
            Err(error) if error.kind() == ErrorKind::Interrupted => continue,
            Err(error) => {
                tracing::warn!(
                    extension = emitter.owner_extension.as_str(),
                    resource_id = emitter.resource_id.as_str(),
                    error = %error,
                    "extension tcp stream read loop terminated"
                );
                return;
            }
        };
        if let Some(payload) = buffer.get(..len) {
            emitter.emit_event(RuntimePacketEventClass::Packet, None, Arc::from(payload));
        }
    }
}
/// Forwards websocket messages as events until the stream ends, the peer closes, or an error
/// occurs.
///
/// Data and control frames are emitted with their frame type. Pings are answered with a
/// matching pong. A close frame, or a closed-connection error, emits a `ConnectionClosed`
/// event.
async fn read_websocket_messages(
    context: ExtensionResourceReadContext,
    mut stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
) {
    let chunk_limit = context.max_payload_chunk_bytes;
    while let Some(next) = stream.next().await {
        match next {
            Ok(Message::Text(text)) => {
                let bytes = text.as_str().as_bytes();
                context.emitter.emit_payload(
                    bytes,
                    Some(RuntimeWebSocketFrameType::Text),
                    chunk_limit,
                );
            }
            Ok(Message::Binary(bytes)) => {
                context.emitter.emit_payload(
                    bytes.as_ref(),
                    Some(RuntimeWebSocketFrameType::Binary),
                    chunk_limit,
                );
            }
            Ok(Message::Ping(payload)) => {
                context.emitter.emit_payload(
                    payload.as_ref(),
                    Some(RuntimeWebSocketFrameType::Ping),
                    chunk_limit,
                );
                if let Err(error) = stream.send(Message::Pong(payload)).await {
                    tracing::warn!(
                        extension = context.emitter.owner_extension.as_str(),
                        resource_id = context.emitter.resource_id.as_str(),
                        error = %error,
                        "failed to send websocket pong frame; stopping connector"
                    );
                    break;
                }
            }
            Ok(Message::Pong(payload)) => {
                context.emitter.emit_payload(
                    payload.as_ref(),
                    Some(RuntimeWebSocketFrameType::Pong),
                    chunk_limit,
                );
            }
            Ok(Message::Close(frame)) => {
                emit_websocket_close_event(&context, frame.as_ref());
                // An already-closed connection is the expected outcome here, not a failure.
                if let Err(error) = stream.close(None).await
                    && !matches!(
                        error,
                        WebSocketError::ConnectionClosed | WebSocketError::AlreadyClosed
                    )
                {
                    tracing::warn!(
                        extension = context.emitter.owner_extension.as_str(),
                        resource_id = context.emitter.resource_id.as_str(),
                        error = %error,
                        "failed to complete websocket close handshake"
                    );
                }
                break;
            }
            Ok(Message::Frame(_)) => {
                // Raw frames are not forwarded.
            }
            Err(error) => {
                let closed = matches!(
                    error,
                    WebSocketError::ConnectionClosed | WebSocketError::AlreadyClosed
                );
                if closed {
                    emit_websocket_close_event(&context, None);
                } else {
                    tracing::warn!(
                        extension = context.emitter.owner_extension.as_str(),
                        resource_id = context.emitter.resource_id.as_str(),
                        error = %error,
                        "websocket connector read loop terminated"
                    );
                }
                break;
            }
        }
    }
}
/// Emits a `ConnectionClosed` event whose payload is the close reason (empty if there was no
/// close frame), then logs the close code and reason.
fn emit_websocket_close_event(context: &ExtensionResourceReadContext, frame: Option<&CloseFrame>) {
    let reason_bytes: &[u8] = match frame {
        Some(close_frame) => close_frame.reason.as_bytes(),
        None => &[],
    };
    context.emitter.emit_event(
        RuntimePacketEventClass::ConnectionClosed,
        None,
        Arc::from(reason_bytes),
    );
    let close_code = frame.map(|close_frame| u16::from(close_frame.code));
    let close_reason = frame
        .map(|close_frame| close_frame.reason.to_string())
        .unwrap_or_default();
    tracing::info!(
        extension = context.emitter.owner_extension.as_str(),
        resource_id = context.emitter.resource_id.as_str(),
        close_code,
        close_reason,
        "websocket connector closed by remote peer"
    );
}
/// Shared streams carry their tag (making them observable by other extensions); private
/// streams carry none.
fn visibility_tag(visibility: ExtensionStreamVisibility) -> Option<String> {
    match visibility {
        ExtensionStreamVisibility::Shared { tag } => Some(tag),
        ExtensionStreamVisibility::Private => None,
    }
}
/// Output of `validate_manifest`; startup copies these into the extension's
/// `ActiveRuntimeExtension` record.
struct ValidatedManifest {
    /// Declared capabilities, deduplicated into a set.
    capabilities: HashSet<ExtensionCapability>,
    /// Packet subscriptions the extension registered.
    subscriptions: Vec<PacketSubscription>,
}
/// Rejects blank websocket URLs and URLs that cannot be turned into a client request.
fn validate_websocket_resource_url(resource_id: &str, url: &str) -> Result<(), String> {
    if url.trim().is_empty() {
        return Err(format!(
            "resource `{resource_id}` declares empty websocket url"
        ));
    }
    match url.into_client_request() {
        Ok(_request) => Ok(()),
        Err(error) => Err(format!(
            "resource `{resource_id}` declares invalid websocket url `{url}`: {error}"
        )),
    }
}
/// Validates one packet subscription against the capabilities its extension declared.
///
/// Rules are checked in this order, and the first violation is reported:
/// 1. An `ObserverIngress` subscription requires the `ObserveObserverIngress` capability and
///    must not set any extension-resource selector (`owner_extension`, `resource_id`,
///    `shared_tag`).
/// 2. `owner_extension`, when set, must not be blank.
/// 3. `resource_id`, when set, must not be blank.
/// 4. `shared_tag`, when set, must not be blank and requires the
///    `ObserveSharedExtensionStream` capability.
///
/// # Errors
/// Returns a fixed description of the violated rule.
fn validate_packet_subscription(
    subscription: &PacketSubscription,
    capabilities: &HashSet<ExtensionCapability>,
) -> Result<(), String> {
    let targets_observer_ingress = matches!(
        subscription.source_kind,
        Some(RuntimePacketSourceKind::ObserverIngress)
    );
    if targets_observer_ingress {
        if !capabilities.contains(&ExtensionCapability::ObserveObserverIngress) {
            return Err(
                "subscription declares ObserverIngress source without ObserveObserverIngress capability"
                    .to_owned(),
            );
        }
        let uses_resource_selector = subscription.owner_extension.is_some()
            || subscription.resource_id.is_some()
            || subscription.shared_tag.is_some();
        if uses_resource_selector {
            return Err(
                "subscription declares ObserverIngress source with extension-resource-only selectors"
                    .to_owned(),
            );
        }
    }

    let blank_owner = subscription
        .owner_extension
        .as_deref()
        .is_some_and(|owner| owner.trim().is_empty());
    if blank_owner {
        return Err("subscription declares empty owner_extension".to_owned());
    }

    let blank_resource = subscription
        .resource_id
        .as_deref()
        .is_some_and(|resource| resource.trim().is_empty());
    if blank_resource {
        return Err("subscription declares empty resource_id".to_owned());
    }

    // Guard order matters: a blank tag is reported before the missing capability.
    match subscription.shared_tag.as_deref() {
        Some(tag) if tag.trim().is_empty() => {
            Err("subscription declares empty shared_tag".to_owned())
        }
        Some(_) if !capabilities.contains(&ExtensionCapability::ObserveSharedExtensionStream) => {
            Err(
                "subscription declares shared_tag without ObserveSharedExtensionStream capability"
                    .to_owned(),
            )
        }
        _ => Ok(()),
    }
}
/// Validates an extension's startup manifest against the runtime capability policy.
///
/// Checks, in order, reporting the first violation:
/// 1. `extension_name` is not blank.
/// 2. Every declared capability is allowed by `policy`, checked in declaration order.
/// 3. For each resource, in declaration order:
///    - `resource_id` is non-blank and unique within the manifest;
///    - `read_buffer_bytes` is in `1..=MAX_RESOURCE_READ_BUFFER_BYTES`;
///    - a shared visibility tag is non-blank;
///    - the capability required by the resource kind was declared;
///    - websocket connectors have a valid url.
/// 4. Each subscription passes `validate_packet_subscription`.
///
/// On success, returns the deduplicated capability set and a copy of the subscriptions.
///
/// # Errors
/// Returns a human-readable reason describing the first violation found.
fn validate_manifest(
    extension_name: &'static str,
    manifest: &ExtensionManifest,
    policy: &RuntimeExtensionCapabilityPolicy,
) -> Result<ValidatedManifest, String> {
    if extension_name.trim().is_empty() {
        return Err("extension declares empty name".to_owned());
    }
    // Walk capabilities in declaration order rather than through the HashSet below. HashSet
    // iteration order varies between runs, so the reported capability would not be stable
    // when several are denied.
    if let Some(denied) = manifest
        .capabilities
        .iter()
        .find(|capability| !policy.allows(**capability))
    {
        return Err(format!(
            "capability `{denied:?}` is not allowed by runtime policy"
        ));
    }
    let capabilities: HashSet<ExtensionCapability> =
        manifest.capabilities.iter().copied().collect();
    // Borrow ids from the manifest; no per-resource String clone is needed for the
    // duplicate check.
    let mut resource_ids = HashSet::<&str>::with_capacity(manifest.resources.len());
    for resource in &manifest.resources {
        let (resource_id, visibility, read_buffer_bytes, required_capability) = match resource {
            ExtensionResourceSpec::UdpListener(spec) => (
                &spec.resource_id,
                &spec.visibility,
                spec.read_buffer_bytes,
                ExtensionCapability::BindUdp,
            ),
            ExtensionResourceSpec::TcpListener(spec) => (
                &spec.resource_id,
                &spec.visibility,
                spec.read_buffer_bytes,
                ExtensionCapability::BindTcp,
            ),
            ExtensionResourceSpec::TcpConnector(spec) => (
                &spec.resource_id,
                &spec.visibility,
                spec.read_buffer_bytes,
                ExtensionCapability::ConnectTcp,
            ),
            ExtensionResourceSpec::WsConnector(spec) => (
                &spec.resource_id,
                &spec.visibility,
                spec.read_buffer_bytes,
                ExtensionCapability::ConnectWebSocket,
            ),
        };
        if resource_id.trim().is_empty() {
            return Err(format!(
                "extension `{extension_name}` declares empty resource_id"
            ));
        }
        if !resource_ids.insert(resource_id.as_str()) {
            return Err(format!(
                "duplicate resource_id `{resource_id}` in startup manifest for extension `{extension_name}`"
            ));
        }
        if read_buffer_bytes == 0 {
            return Err(format!(
                "resource `{resource_id}` declares zero read_buffer_bytes"
            ));
        }
        if read_buffer_bytes > MAX_RESOURCE_READ_BUFFER_BYTES {
            return Err(format!(
                "resource `{resource_id}` read_buffer_bytes {read_buffer_bytes} exceeds max {MAX_RESOURCE_READ_BUFFER_BYTES}"
            ));
        }
        if matches!(
            visibility,
            ExtensionStreamVisibility::Shared { tag } if tag.trim().is_empty()
        ) {
            return Err(format!(
                "resource `{resource_id}` declares empty shared visibility tag"
            ));
        }
        if !capabilities.contains(&required_capability) {
            return Err(format!(
                "resource `{resource_id}` requires undeclared capability `{required_capability:?}`"
            ));
        }
        if let ExtensionResourceSpec::WsConnector(spec) = resource {
            validate_websocket_resource_url(resource_id, &spec.url)?;
        }
    }
    for subscription in &manifest.subscriptions {
        validate_packet_subscription(subscription, &capabilities)?;
    }
    Ok(ValidatedManifest {
        capabilities,
        subscriptions: manifest.subscriptions.clone(),
    })
}
/// Raises `target` to `value` when `value` is larger; otherwise leaves it untouched.
///
/// Both the load and the exchange use `Ordering::Relaxed`. No write is attempted when `value`
/// is not larger than the current contents.
fn record_max_atomic(target: &AtomicU64, value: u64) {
    // `fetch_update` loads, asks the closure for a replacement, and retries a weak CAS on
    // contention. It stops without writing once the closure returns `None`. The resulting
    // `Err` just means "already at least `value`", so it is intentionally ignored.
    let _ = target.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
        (value > current).then_some(value)
    });
}
/// Extracts a readable message from a panic payload.
///
/// `panic!` payloads are usually a `&'static str` (literal message) or a `String` (formatted
/// message). Any other payload type yields the placeholder `"non-string panic payload"`.
fn panic_payload_to_string(payload: &(dyn Any + Send)) -> String {
    if let Some(literal) = payload.downcast_ref::<&str>() {
        return (*literal).to_owned();
    }
    match payload.downcast_ref::<String>() {
        Some(formatted) => formatted.clone(),
        None => "non-string panic payload".to_owned(),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::{
net::TcpListener as StdTcpListener,
sync::atomic::{AtomicBool, AtomicUsize},
};
use crate::framework::ExtensionSetupError;
use async_trait::async_trait;
use sof_support::bench::{avg_ns_per_iteration, profile_iterations};
use tokio::io::AsyncWriteExt;
use tokio::time::{Instant as TokioInstant, sleep};
use tokio_tungstenite::accept_async;
struct CounterExtension {
name: &'static str,
startup_manifest: ExtensionManifest,
packet_count: Arc<AtomicUsize>,
shutdown_wait: Duration,
shutdown_called: Arc<AtomicBool>,
}
#[async_trait]
impl RuntimeExtension for CounterExtension {
fn name(&self) -> &'static str {
self.name
}
async fn setup(
&self,
_ctx: ExtensionContext,
) -> Result<ExtensionManifest, ExtensionSetupError> {
Ok(self.startup_manifest.clone())
}
async fn on_packet_received(&self, _event: RuntimePacketEvent) {
self.packet_count.fetch_add(1, Ordering::Relaxed);
}
async fn shutdown(&self, _ctx: ExtensionContext) {
self.shutdown_called.store(true, Ordering::Relaxed);
if !self.shutdown_wait.is_zero() {
sleep(self.shutdown_wait).await;
}
}
}
struct PanicOnceExtension {
panic_seen: AtomicBool,
packet_count: Arc<AtomicUsize>,
}
#[async_trait]
impl RuntimeExtension for PanicOnceExtension {
fn name(&self) -> &'static str {
"panic-once-extension"
}
async fn setup(
&self,
_ctx: ExtensionContext,
) -> Result<ExtensionManifest, ExtensionSetupError> {
Ok(ExtensionManifest {
capabilities: vec![ExtensionCapability::ObserveObserverIngress],
resources: Vec::new(),
subscriptions: vec![PacketSubscription {
source_kind: Some(RuntimePacketSourceKind::ObserverIngress),
..PacketSubscription::default()
}],
})
}
async fn on_packet_received(&self, _event: RuntimePacketEvent) {
if !self.panic_seen.swap(true, Ordering::Relaxed) {
panic!("intentional extension panic");
}
self.packet_count.fetch_add(1, Ordering::Relaxed);
}
}
fn sample_runtime_packet_event() -> RuntimePacketEvent {
RuntimePacketEvent {
source: RuntimePacketSource {
kind: RuntimePacketSourceKind::ObserverIngress,
transport: RuntimePacketTransport::Udp,
event_class: RuntimePacketEventClass::Packet,
owner_extension: None,
resource_id: None,
shared_tag: None,
websocket_frame_type: None,
local_addr: None,
remote_addr: Some(SocketAddr::from_str("127.0.0.1:9001").expect("valid addr")),
},
bytes: Arc::from(&[7_u8; 32][..]),
observed_unix_ms: 0,
}
}
async fn invoke_extension_callback_baseline(
extension: Arc<dyn RuntimeExtension>,
event: RuntimePacketEvent,
) -> Result<(), Box<dyn Any + Send>> {
let handle = tokio::spawn(async move {
extension.on_packet_received(event).await;
});
match handle.await {
Ok(()) => Ok(()),
Err(error) => Err(error.into_panic()),
}
}
struct StartupFailExtension;
#[async_trait]
impl RuntimeExtension for StartupFailExtension {
fn name(&self) -> &'static str {
"startup-fail"
}
async fn setup(
&self,
_ctx: ExtensionContext,
) -> Result<ExtensionManifest, ExtensionSetupError> {
Err(ExtensionSetupError::new("intentional startup fail"))
}
}
struct ImplicitNameExtension;
#[async_trait]
impl RuntimeExtension for ImplicitNameExtension {
async fn setup(
&self,
_ctx: ExtensionContext,
) -> Result<ExtensionManifest, ExtensionSetupError> {
Ok(ExtensionManifest {
capabilities: vec![ExtensionCapability::ObserveObserverIngress],
resources: Vec::new(),
subscriptions: vec![PacketSubscription {
source_kind: Some(RuntimePacketSourceKind::ObserverIngress),
..PacketSubscription::default()
}],
})
}
}
#[tokio::test]
async fn startup_failure_isolated_per_extension() {
let ok_counter = Arc::new(AtomicUsize::new(0));
let host = RuntimeExtensionHost::builder()
.add_extension(StartupFailExtension)
.add_extension(CounterExtension {
name: "ok-extension",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::ObserveObserverIngress],
resources: Vec::new(),
subscriptions: vec![PacketSubscription {
source_kind: Some(RuntimePacketSourceKind::ObserverIngress),
..PacketSubscription::default()
}],
},
packet_count: Arc::clone(&ok_counter),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.discovered_extensions, 2);
assert_eq!(report.active_extensions, 1);
assert_eq!(report.failed_extensions, 1);
host.on_observer_packet(
SocketAddr::from_str("127.0.0.1:8001").expect("valid addr"),
&[1, 2, 3],
);
sleep(Duration::from_millis(50)).await;
assert_eq!(ok_counter.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn capability_policy_denies_resource_startup() {
let host = RuntimeExtensionHost::builder()
.with_capability_policy(
RuntimeExtensionCapabilityPolicy::allow_all().without(ExtensionCapability::BindUdp),
)
.add_extension(CounterExtension {
name: "bind-udp-extension",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::BindUdp],
resources: vec![ExtensionResourceSpec::UdpListener(UdpListenerSpec {
resource_id: "udp-1".to_owned(),
bind_addr: SocketAddr::from_str("127.0.0.1:0").expect("valid addr"),
visibility: ExtensionStreamVisibility::Private,
read_buffer_bytes: 128,
})],
subscriptions: Vec::new(),
},
packet_count: Arc::new(AtomicUsize::new(0)),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 0);
assert_eq!(report.failed_extensions, 1);
}
#[tokio::test]
async fn startup_rejects_empty_resource_id() {
let host = RuntimeExtensionHost::builder()
.add_extension(CounterExtension {
name: "empty-resource-id",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::BindUdp],
resources: vec![ExtensionResourceSpec::UdpListener(UdpListenerSpec {
resource_id: " ".to_owned(),
bind_addr: SocketAddr::from_str("127.0.0.1:0").expect("valid addr"),
visibility: ExtensionStreamVisibility::Private,
read_buffer_bytes: 128,
})],
subscriptions: Vec::new(),
},
packet_count: Arc::new(AtomicUsize::new(0)),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 0);
assert_eq!(report.failed_extensions, 1);
assert!(report.failures[0].reason.contains("empty resource_id"));
}
#[tokio::test]
async fn startup_rejects_empty_extension_name() {
let host = RuntimeExtensionHost::builder()
.add_extension(CounterExtension {
name: " ",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::BindUdp],
resources: vec![ExtensionResourceSpec::UdpListener(UdpListenerSpec {
resource_id: "udp-feed".to_owned(),
bind_addr: SocketAddr::from_str("127.0.0.1:0").expect("valid addr"),
visibility: ExtensionStreamVisibility::Private,
read_buffer_bytes: 128,
})],
subscriptions: Vec::new(),
},
packet_count: Arc::new(AtomicUsize::new(0)),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 0);
assert_eq!(report.failed_extensions, 1);
assert!(report.failures[0].reason.contains("empty name"));
}
#[tokio::test]
async fn startup_rejects_empty_shared_visibility_tag() {
let host = RuntimeExtensionHost::builder()
.add_extension(CounterExtension {
name: "empty-shared-tag",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::BindTcp],
resources: vec![ExtensionResourceSpec::TcpListener(TcpListenerSpec {
resource_id: "tcp-feed".to_owned(),
bind_addr: SocketAddr::from_str("127.0.0.1:0").expect("valid addr"),
visibility: ExtensionStreamVisibility::Shared {
tag: " ".to_owned(),
},
read_buffer_bytes: 128,
})],
subscriptions: Vec::new(),
},
packet_count: Arc::new(AtomicUsize::new(0)),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 0);
assert_eq!(report.failed_extensions, 1);
assert!(
report.failures[0]
.reason
.contains("empty shared visibility tag")
);
}
#[tokio::test]
async fn startup_rejects_oversized_read_buffer_bytes() {
let host = RuntimeExtensionHost::builder()
.add_extension(CounterExtension {
name: "oversized-read-buffer",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::ConnectWebSocket],
resources: vec![ExtensionResourceSpec::WsConnector(WsConnectorSpec {
resource_id: "ws-feed".to_owned(),
url: "ws://127.0.0.1:1/feed".to_owned(),
visibility: ExtensionStreamVisibility::Private,
read_buffer_bytes: MAX_RESOURCE_READ_BUFFER_BYTES.saturating_add(1),
})],
subscriptions: Vec::new(),
},
packet_count: Arc::new(AtomicUsize::new(0)),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 0);
assert_eq!(report.failed_extensions, 1);
assert!(report.failures[0].reason.contains("read_buffer_bytes"));
}
#[tokio::test]
async fn startup_rejects_zero_read_buffer_bytes() {
let host = RuntimeExtensionHost::builder()
.add_extension(CounterExtension {
name: "zero-read-buffer",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::ConnectWebSocket],
resources: vec![ExtensionResourceSpec::WsConnector(WsConnectorSpec {
resource_id: "ws-feed".to_owned(),
url: "ws://127.0.0.1:1/feed".to_owned(),
visibility: ExtensionStreamVisibility::Private,
read_buffer_bytes: 0,
})],
subscriptions: Vec::new(),
},
packet_count: Arc::new(AtomicUsize::new(0)),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 0);
assert_eq!(report.failed_extensions, 1);
assert!(report.failures[0].reason.contains("zero read_buffer_bytes"));
}
#[tokio::test]
async fn startup_rejects_invalid_websocket_url() {
let host = RuntimeExtensionHost::builder()
.add_extension(CounterExtension {
name: "invalid-websocket-url",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::ConnectWebSocket],
resources: vec![ExtensionResourceSpec::WsConnector(WsConnectorSpec {
resource_id: "ws-feed".to_owned(),
url: "not a websocket url".to_owned(),
visibility: ExtensionStreamVisibility::Private,
read_buffer_bytes: 128,
})],
subscriptions: Vec::new(),
},
packet_count: Arc::new(AtomicUsize::new(0)),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 0);
assert_eq!(report.failed_extensions, 1);
assert!(report.failures[0].reason.contains("invalid websocket url"));
}
#[tokio::test]
async fn startup_rejects_empty_subscription_shared_tag() {
let host = RuntimeExtensionHost::builder()
.add_extension(CounterExtension {
name: "empty-subscription-shared-tag",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::ObserveSharedExtensionStream],
resources: Vec::new(),
subscriptions: vec![PacketSubscription {
source_kind: Some(RuntimePacketSourceKind::ExtensionResource),
shared_tag: Some(" ".to_owned()),
..PacketSubscription::default()
}],
},
packet_count: Arc::new(AtomicUsize::new(0)),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 0);
assert_eq!(report.failed_extensions, 1);
assert!(report.failures[0].reason.contains("empty shared_tag"));
}
#[tokio::test]
async fn startup_rejects_shared_stream_subscription_without_capability() {
let host = RuntimeExtensionHost::builder()
.add_extension(CounterExtension {
name: "missing-shared-stream-capability",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::ObserveObserverIngress],
resources: Vec::new(),
subscriptions: vec![PacketSubscription {
source_kind: Some(RuntimePacketSourceKind::ExtensionResource),
shared_tag: Some("shared-feed".to_owned()),
..PacketSubscription::default()
}],
},
packet_count: Arc::new(AtomicUsize::new(0)),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 0);
assert_eq!(report.failed_extensions, 1);
assert!(
report.failures[0]
.reason
.contains("ObserveSharedExtensionStream capability")
);
}
#[tokio::test]
async fn startup_rejects_observer_ingress_subscription_without_capability() {
let host = RuntimeExtensionHost::builder()
.add_extension(CounterExtension {
name: "missing-observer-ingress-capability",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::ObserveSharedExtensionStream],
resources: Vec::new(),
subscriptions: vec![PacketSubscription {
source_kind: Some(RuntimePacketSourceKind::ObserverIngress),
..PacketSubscription::default()
}],
},
packet_count: Arc::new(AtomicUsize::new(0)),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 0);
assert_eq!(report.failed_extensions, 1);
assert!(
report.failures[0]
.reason
.contains("ObserveObserverIngress capability")
);
}
#[tokio::test]
async fn startup_rejects_observer_ingress_subscription_with_resource_selectors() {
let host = RuntimeExtensionHost::builder()
.add_extension(CounterExtension {
name: "observer-ingress-resource-selectors",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::ObserveObserverIngress],
resources: Vec::new(),
subscriptions: vec![PacketSubscription {
source_kind: Some(RuntimePacketSourceKind::ObserverIngress),
owner_extension: Some("owner".to_owned()),
..PacketSubscription::default()
}],
},
packet_count: Arc::new(AtomicUsize::new(0)),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 0);
assert_eq!(report.failed_extensions, 1);
assert!(
report.failures[0]
.reason
.contains("extension-resource-only selectors")
);
}
#[test]
fn builder_clamps_zero_startup_timeout() {
let host = RuntimeExtensionHost::builder()
.with_startup_timeout(Duration::ZERO)
.build();
assert_eq!(host.inner.startup_timeout, Duration::from_millis(1));
}
#[test]
fn builder_clamps_zero_shutdown_timeout() {
let host = RuntimeExtensionHost::builder()
.with_shutdown_timeout(Duration::ZERO)
.build();
assert_eq!(host.inner.shutdown_timeout, Duration::from_millis(1));
}
#[tokio::test]
async fn production_defaults_deny_outbound_connectors() {
let host = RuntimeExtensionHost::production_builder()
.add_extension(CounterExtension {
name: "connect-tcp-extension",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::ConnectTcp],
resources: vec![ExtensionResourceSpec::TcpConnector(TcpConnectorSpec {
resource_id: "tcp-outbound".to_owned(),
remote_addr: SocketAddr::from_str("127.0.0.1:9").expect("valid addr"),
visibility: ExtensionStreamVisibility::Private,
read_buffer_bytes: 128,
})],
subscriptions: Vec::new(),
},
packet_count: Arc::new(AtomicUsize::new(0)),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 0);
assert_eq!(report.failed_extensions, 1);
assert!(report.failures[0].reason.contains("not allowed"));
}
#[tokio::test]
async fn strict_name_policy_rejects_implicit_type_name_extensions() {
let host = RuntimeExtensionHost::builder()
.with_require_explicit_extension_names(true)
.add_extension(ImplicitNameExtension)
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 0);
assert_eq!(report.failed_extensions, 1);
assert!(
report.failures[0]
.reason
.contains("requires explicit stable extension names")
);
}
#[tokio::test]
async fn owner_only_and_shared_stream_visibility() {
let owner_counter = Arc::new(AtomicUsize::new(0));
let shared_counter = Arc::new(AtomicUsize::new(0));
let host = RuntimeExtensionHost::builder()
.add_extension(CounterExtension {
name: "owner",
startup_manifest: ExtensionManifest {
capabilities: vec![],
resources: Vec::new(),
subscriptions: vec![PacketSubscription {
source_kind: Some(RuntimePacketSourceKind::ExtensionResource),
owner_extension: Some("owner".to_owned()),
..PacketSubscription::default()
}],
},
packet_count: Arc::clone(&owner_counter),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.add_extension(CounterExtension {
name: "shared-reader",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::ObserveSharedExtensionStream],
resources: Vec::new(),
subscriptions: vec![PacketSubscription {
source_kind: Some(RuntimePacketSourceKind::ExtensionResource),
shared_tag: Some("shared-feed".to_owned()),
..PacketSubscription::default()
}],
},
packet_count: Arc::clone(&shared_counter),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 2);
host.emit_extension_packet(
RuntimePacketSource {
kind: RuntimePacketSourceKind::ExtensionResource,
transport: RuntimePacketTransport::Tcp,
event_class: RuntimePacketEventClass::Packet,
owner_extension: Some("owner".to_owned()),
resource_id: Some("feed-1".to_owned()),
shared_tag: None,
websocket_frame_type: None,
local_addr: None,
remote_addr: None,
},
Arc::from(&[1_u8][..]),
);
sleep(Duration::from_millis(50)).await;
assert_eq!(owner_counter.load(Ordering::Relaxed), 1);
assert_eq!(shared_counter.load(Ordering::Relaxed), 0);
host.emit_extension_packet(
RuntimePacketSource {
kind: RuntimePacketSourceKind::ExtensionResource,
transport: RuntimePacketTransport::Tcp,
event_class: RuntimePacketEventClass::Packet,
owner_extension: Some("owner".to_owned()),
resource_id: Some("feed-1".to_owned()),
shared_tag: Some("shared-feed".to_owned()),
websocket_frame_type: None,
local_addr: None,
remote_addr: None,
},
Arc::from(&[2_u8][..]),
);
sleep(Duration::from_millis(50)).await;
assert_eq!(owner_counter.load(Ordering::Relaxed), 2);
assert_eq!(shared_counter.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn websocket_frame_type_subscription_filters_dispatch() {
let text_counter = Arc::new(AtomicUsize::new(0));
let any_counter = Arc::new(AtomicUsize::new(0));
let host = RuntimeExtensionHost::builder()
.add_extension(CounterExtension {
name: "text-only",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::ObserveSharedExtensionStream],
resources: Vec::new(),
subscriptions: vec![PacketSubscription {
source_kind: Some(RuntimePacketSourceKind::ExtensionResource),
transport: Some(RuntimePacketTransport::WebSocket),
shared_tag: Some("ws-shared".to_owned()),
websocket_frame_type: Some(RuntimeWebSocketFrameType::Text),
..PacketSubscription::default()
}],
},
packet_count: Arc::clone(&text_counter),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.add_extension(CounterExtension {
name: "any-frame",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::ObserveSharedExtensionStream],
resources: Vec::new(),
subscriptions: vec![PacketSubscription {
source_kind: Some(RuntimePacketSourceKind::ExtensionResource),
transport: Some(RuntimePacketTransport::WebSocket),
shared_tag: Some("ws-shared".to_owned()),
..PacketSubscription::default()
}],
},
packet_count: Arc::clone(&any_counter),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 2);
host.emit_extension_packet(
RuntimePacketSource {
kind: RuntimePacketSourceKind::ExtensionResource,
transport: RuntimePacketTransport::WebSocket,
event_class: RuntimePacketEventClass::Packet,
owner_extension: Some("owner".to_owned()),
resource_id: Some("ws-feed".to_owned()),
shared_tag: Some("ws-shared".to_owned()),
websocket_frame_type: Some(RuntimeWebSocketFrameType::Text),
local_addr: None,
remote_addr: None,
},
Arc::from(&[1_u8][..]),
);
host.emit_extension_packet(
RuntimePacketSource {
kind: RuntimePacketSourceKind::ExtensionResource,
transport: RuntimePacketTransport::WebSocket,
event_class: RuntimePacketEventClass::Packet,
owner_extension: Some("owner".to_owned()),
resource_id: Some("ws-feed".to_owned()),
shared_tag: Some("ws-shared".to_owned()),
websocket_frame_type: Some(RuntimeWebSocketFrameType::Binary),
local_addr: None,
remote_addr: None,
},
Arc::from(&[2_u8][..]),
);
sleep(Duration::from_millis(50)).await;
assert_eq!(text_counter.load(Ordering::Relaxed), 1);
assert_eq!(any_counter.load(Ordering::Relaxed), 2);
}
#[tokio::test]
async fn connection_closed_event_dispatches_with_empty_payload() {
let close_counter = Arc::new(AtomicUsize::new(0));
let host = RuntimeExtensionHost::builder()
.add_extension(CounterExtension {
name: "close-reader",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::ObserveSharedExtensionStream],
resources: Vec::new(),
subscriptions: vec![PacketSubscription {
source_kind: Some(RuntimePacketSourceKind::ExtensionResource),
transport: Some(RuntimePacketTransport::WebSocket),
event_class: Some(RuntimePacketEventClass::ConnectionClosed),
shared_tag: Some("ws-close".to_owned()),
..PacketSubscription::default()
}],
},
packet_count: Arc::clone(&close_counter),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 1);
host.emit_extension_packet(
RuntimePacketSource {
kind: RuntimePacketSourceKind::ExtensionResource,
transport: RuntimePacketTransport::WebSocket,
event_class: RuntimePacketEventClass::ConnectionClosed,
owner_extension: Some("owner".to_owned()),
resource_id: Some("ws-feed".to_owned()),
shared_tag: Some("ws-close".to_owned()),
websocket_frame_type: None,
local_addr: None,
remote_addr: None,
},
Arc::from(&[][..]),
);
sleep(Duration::from_millis(50)).await;
assert_eq!(close_counter.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn tcp_listener_accepts_new_connections_while_existing_stream_stays_open() {
let probe = StdTcpListener::bind("127.0.0.1:0").expect("bind probe listener");
let bind_addr = probe.local_addr().expect("probe local addr");
drop(probe);
let packet_count = Arc::new(AtomicUsize::new(0));
let host = RuntimeExtensionHost::builder()
.add_extension(CounterExtension {
name: "tcp-listener-extension",
startup_manifest: ExtensionManifest {
capabilities: vec![ExtensionCapability::BindTcp],
resources: vec![ExtensionResourceSpec::TcpListener(TcpListenerSpec {
resource_id: "tcp-listener".to_owned(),
bind_addr,
visibility: ExtensionStreamVisibility::Private,
read_buffer_bytes: 128,
})],
subscriptions: vec![PacketSubscription {
source_kind: Some(RuntimePacketSourceKind::ExtensionResource),
transport: Some(RuntimePacketTransport::Tcp),
owner_extension: Some("tcp-listener-extension".to_owned()),
..PacketSubscription::default()
}],
},
packet_count: Arc::clone(&packet_count),
shutdown_wait: Duration::ZERO,
shutdown_called: Arc::new(AtomicBool::new(false)),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 1);
let _first = TcpStream::connect(bind_addr)
.await
.expect("connect first tcp client");
let mut second = TcpStream::connect(bind_addr)
.await
.expect("connect second tcp client");
assert!(second.write_all(b"second").await.is_ok());
sleep(Duration::from_millis(50)).await;
assert_eq!(packet_count.load(Ordering::Relaxed), 1);
host.shutdown().await;
}
struct SlowExtension {
counter: Arc<AtomicUsize>,
}
#[async_trait]
impl RuntimeExtension for SlowExtension {
fn name(&self) -> &'static str {
"slow-extension"
}
async fn setup(
&self,
_ctx: ExtensionContext,
) -> Result<ExtensionManifest, ExtensionSetupError> {
Ok(ExtensionManifest {
capabilities: vec![ExtensionCapability::ObserveObserverIngress],
resources: Vec::new(),
subscriptions: vec![PacketSubscription {
source_kind: Some(RuntimePacketSourceKind::ObserverIngress),
..PacketSubscription::default()
}],
})
}
async fn on_packet_received(&self, _event: RuntimePacketEvent) {
sleep(Duration::from_millis(120)).await;
self.counter.fetch_add(1, Ordering::Relaxed);
}
}
#[tokio::test]
async fn queue_pressure_drops_events_without_blocking() {
let counter = Arc::new(AtomicUsize::new(0));
let host = RuntimeExtensionHost::builder()
.with_event_queue_capacity(1)
.add_extension(SlowExtension {
counter: Arc::clone(&counter),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 1);
let source = SocketAddr::from_str("127.0.0.1:9001").expect("valid addr");
for _ in 0..16 {
host.on_observer_packet(source, &[7_u8; 32]);
}
sleep(Duration::from_millis(350)).await;
assert!(counter.load(Ordering::Relaxed) < 16);
assert!(host.dropped_event_count() > 0);
let metrics = host.dispatch_metrics_by_extension();
assert_eq!(metrics.len(), 1);
assert!(metrics[0].dropped_events > 0);
assert!(metrics[0].max_queue_depth >= 1);
assert_eq!(metrics[0].queue_depth, 0);
assert!(metrics[0].dispatched_events >= 1);
}
#[tokio::test]
async fn packet_callback_panic_does_not_stop_dispatcher() {
let packet_count = Arc::new(AtomicUsize::new(0));
let host = RuntimeExtensionHost::builder()
.add_extension(PanicOnceExtension {
panic_seen: AtomicBool::new(false),
packet_count: Arc::clone(&packet_count),
})
.build();
let report = host.startup().await;
assert_eq!(report.active_extensions, 1);
let source = SocketAddr::from_str("127.0.0.1:9001").expect("valid addr");
host.on_observer_packet(source, &[1_u8; 8]);
host.on_observer_packet(source, &[2_u8; 8]);
sleep(Duration::from_millis(100)).await;
assert_eq!(packet_count.load(Ordering::Relaxed), 1);
host.shutdown().await;
}
/// Profiles `invoke_extension_callback_baseline` against `invoke_extension_callback`.
///
/// Both paths are driven with a clone of the same sample packet event for
/// `profile_iterations(50_000)` rounds, each against its own idle `CounterExtension`.
/// Wall-clock totals and per-iteration averages are printed on a single line.
#[tokio::test]
#[ignore = "profiling fixture for runtime extension callback isolation"]
async fn runtime_extension_callback_isolation_profile_fixture() {
    let rounds = profile_iterations(50_000);
    // Each path gets its own extension instance so one run's state never leaks into the other.
    let idle_counter = |label: &'static str| {
        Arc::new(CounterExtension {
            name: label,
            startup_manifest: ExtensionManifest {
                capabilities: Vec::new(),
                resources: Vec::new(),
                subscriptions: Vec::new(),
            },
            packet_count: Arc::new(AtomicUsize::new(0)),
            shutdown_wait: Duration::ZERO,
            shutdown_called: Arc::new(AtomicBool::new(false)),
        })
    };
    let baseline_ext = idle_counter("baseline-counter");
    let optimized_ext = idle_counter("optimized-counter");
    let sample_event = sample_runtime_packet_event();

    let baseline_elapsed = {
        let clock = Instant::now();
        for _ in 0..rounds {
            invoke_extension_callback_baseline(
                Arc::clone(&baseline_ext) as Arc<dyn RuntimeExtension>,
                sample_event.clone(),
            )
            .await
            .expect("baseline callback");
        }
        clock.elapsed()
    };
    let optimized_elapsed = {
        let clock = Instant::now();
        for _ in 0..rounds {
            invoke_extension_callback(
                Arc::clone(&optimized_ext) as Arc<dyn RuntimeExtension>,
                sample_event.clone(),
            )
            .await
            .expect("optimized callback");
        }
        clock.elapsed()
    };

    let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, rounds);
    let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, rounds);
    println!(
        "runtime_extension_callback_isolation_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}",
        rounds,
        baseline_elapsed.as_micros(),
        optimized_elapsed.as_micros(),
        baseline_avg_ns,
        optimized_avg_ns,
        baseline_avg_ns as f64 / 1_000.0,
        optimized_avg_ns as f64 / 1_000.0,
    );
}
/// Profiles UDP-style emission where every datagram carries its own remote address.
///
/// The baseline path builds a throwaway `ExtensionResourceEmitter` per event (a clone of the
/// shared emitter with `remote_addr` overridden). The optimized path passes the address to
/// `emit_event_with_remote_addr` on the shared emitter. Totals and averages are printed.
#[test]
#[ignore = "profiling fixture for udp extension emitter remote-address churn"]
fn udp_extension_emitter_remote_addr_profile_fixture() {
    let rounds = profile_iterations(200_000);
    let local = SocketAddr::from_str("127.0.0.1:7000").expect("valid local addr");
    let emitter = ExtensionResourceEmitter::new(
        RuntimeExtensionHost::default(),
        "udp-extension",
        "udp-listener",
        None,
        RuntimePacketTransport::Udp,
        Some(local),
        None,
    );
    let peer = Some(SocketAddr::from_str("127.0.0.1:8000").expect("valid remote addr"));
    let datagram = Arc::from(&[9_u8; 256][..]);

    let baseline_elapsed = {
        let clock = Instant::now();
        for _ in 0..rounds {
            let per_datagram = ExtensionResourceEmitter {
                remote_addr: peer,
                ..emitter.clone()
            };
            per_datagram.emit_event(RuntimePacketEventClass::Packet, None, Arc::clone(&datagram));
        }
        clock.elapsed()
    };
    let optimized_elapsed = {
        let clock = Instant::now();
        for _ in 0..rounds {
            emitter.emit_event_with_remote_addr(
                RuntimePacketEventClass::Packet,
                None,
                peer,
                Arc::clone(&datagram),
            );
        }
        clock.elapsed()
    };

    let baseline_avg_ns = avg_ns_per_iteration(baseline_elapsed, rounds);
    let optimized_avg_ns = avg_ns_per_iteration(optimized_elapsed, rounds);
    println!(
        "udp_extension_emitter_remote_addr_profile_fixture iterations={} baseline_us={} optimized_us={} baseline_avg_ns_per_iteration={} optimized_avg_ns_per_iteration={} baseline_avg_us_per_iteration={:.3} optimized_avg_us_per_iteration={:.3}",
        rounds,
        baseline_elapsed.as_micros(),
        optimized_elapsed.as_micros(),
        baseline_avg_ns,
        optimized_avg_ns,
        baseline_avg_ns as f64 / 1_000.0,
        optimized_avg_ns as f64 / 1_000.0,
    );
}
/// Verifies that host shutdown honours its configured deadline.
///
/// The extension's own shutdown would take 5s, but the host is configured with a 25ms shutdown
/// timeout. Shutdown must return no earlier than that timeout, well before the extension would
/// have finished, and the extension's shutdown hook must still have been invoked.
#[tokio::test]
async fn shutdown_deadline_then_cancel() {
    let shutdown_called = Arc::new(AtomicBool::new(false));
    let (shutdown_timeout, shutdown_wait) = (Duration::from_millis(25), Duration::from_secs(5));
    let manifest = ExtensionManifest {
        capabilities: vec![ExtensionCapability::ObserveObserverIngress],
        resources: Vec::new(),
        subscriptions: vec![PacketSubscription {
            source_kind: Some(RuntimePacketSourceKind::ObserverIngress),
            ..PacketSubscription::default()
        }],
    };
    let host = RuntimeExtensionHost::builder()
        .with_shutdown_timeout(shutdown_timeout)
        .add_extension(CounterExtension {
            name: "slow-shutdown",
            startup_manifest: manifest,
            packet_count: Arc::new(AtomicUsize::new(0)),
            shutdown_wait,
            shutdown_called: Arc::clone(&shutdown_called),
        })
        .build();
    assert_eq!(host.startup().await.active_extensions, 1);

    let started = TokioInstant::now();
    host.shutdown().await;
    let elapsed = started.elapsed();

    // Lower bound: the host waited out its deadline; upper bound: it did not wait much longer.
    assert!((shutdown_timeout..Duration::from_secs(1)).contains(&elapsed));
    assert!(elapsed < shutdown_wait);
    assert!(shutdown_called.load(Ordering::Relaxed));
}
/// Starts an extension that declares one resource of every kind (UDP listener, TCP listener,
/// TCP connector, WebSocket connector) and checks that startup provisions all of them.
///
/// Local peer servers stand in for the remote ends of the two connectors. They panic on
/// accept/handshake/send errors instead of silently finishing, so the test cannot pass without
/// the connectors actually reaching them. Joining them is time-bounded so a regression fails
/// the test rather than hanging it.
#[tokio::test]
#[ignore = "requires local socket bind/connect permissions"]
async fn startup_provisions_udp_tcp_and_ws_resources() {
    // Upper bound on how long each peer server may take to finish after startup returns.
    const SERVER_TASK_DEADLINE: Duration = Duration::from_secs(5);

    let tcp_server = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind tcp server");
    let tcp_server_addr = tcp_server.local_addr().expect("tcp local addr");
    let tcp_server_task = tokio::spawn(async move {
        let (mut stream, _) = tcp_server.accept().await.expect("accept tcp connector");
        stream.write_all(b"tcp").await.expect("write tcp payload");
    });
    let ws_server = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind ws server");
    let ws_server_addr = ws_server.local_addr().expect("ws local addr");
    let ws_server_task = tokio::spawn(async move {
        let (stream, _) = ws_server.accept().await.expect("accept ws connector");
        let mut websocket = accept_async(stream).await.expect("complete ws handshake");
        websocket
            .send(Message::Text("ws".into()))
            .await
            .expect("send ws payload");
    });
    let host = RuntimeExtensionHost::builder()
        .add_extension(CounterExtension {
            name: "resource-extension",
            startup_manifest: ExtensionManifest {
                capabilities: vec![
                    ExtensionCapability::BindUdp,
                    ExtensionCapability::BindTcp,
                    ExtensionCapability::ConnectTcp,
                    ExtensionCapability::ConnectWebSocket,
                ],
                resources: vec![
                    ExtensionResourceSpec::UdpListener(UdpListenerSpec {
                        resource_id: "udp-listener".to_owned(),
                        bind_addr: SocketAddr::from_str("127.0.0.1:0").expect("addr"),
                        visibility: ExtensionStreamVisibility::Private,
                        read_buffer_bytes: 128,
                    }),
                    ExtensionResourceSpec::TcpListener(TcpListenerSpec {
                        resource_id: "tcp-listener".to_owned(),
                        bind_addr: SocketAddr::from_str("127.0.0.1:0").expect("addr"),
                        visibility: ExtensionStreamVisibility::Private,
                        read_buffer_bytes: 128,
                    }),
                    ExtensionResourceSpec::TcpConnector(TcpConnectorSpec {
                        resource_id: "tcp-connector".to_owned(),
                        remote_addr: tcp_server_addr,
                        visibility: ExtensionStreamVisibility::Private,
                        read_buffer_bytes: 128,
                    }),
                    ExtensionResourceSpec::WsConnector(WsConnectorSpec {
                        resource_id: "ws-connector".to_owned(),
                        url: format!("ws://{ws_server_addr}/feed"),
                        visibility: ExtensionStreamVisibility::Private,
                        read_buffer_bytes: 128,
                    }),
                ],
                subscriptions: vec![PacketSubscription {
                    source_kind: Some(RuntimePacketSourceKind::ExtensionResource),
                    ..PacketSubscription::default()
                }],
            },
            packet_count: Arc::new(AtomicUsize::new(0)),
            shutdown_wait: Duration::ZERO,
            shutdown_called: Arc::new(AtomicBool::new(false)),
        })
        .build();
    let report = host.startup().await;
    assert_eq!(report.active_extensions, 1);
    assert_eq!(report.failed_extensions, 0);

    // A peer that was never connected to would block in `accept` forever; bound the wait.
    let tcp_join = tokio::time::timeout(SERVER_TASK_DEADLINE, tcp_server_task)
        .await
        .expect("tcp server task did not finish in time");
    assert!(tcp_join.is_ok(), "tcp server task failed: {tcp_join:?}");
    let ws_join = tokio::time::timeout(SERVER_TASK_DEADLINE, ws_server_task)
        .await
        .expect("ws server task did not finish in time");
    assert!(ws_join.is_ok(), "ws server task failed: {ws_join:?}");
    host.shutdown().await;
}
/// Checks that a server-initiated WebSocket close reaches the owning extension as exactly one
/// `ConnectionClosed` event on its `ws-connector` resource.
///
/// Instead of a single fixed sleep, the test polls for the event under a deadline, then allows
/// a short settle window before pinning the count so a duplicate dispatch is still caught.
#[tokio::test]
#[ignore = "requires local socket bind/connect permissions"]
async fn websocket_connector_remote_close_dispatches_connection_closed() {
    // Generous upper bound for the close event to arrive / the peer server to finish.
    const EVENT_DEADLINE: Duration = Duration::from_secs(5);

    let ws_server = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind ws server");
    let ws_server_addr = ws_server.local_addr().expect("ws local addr");
    let ws_server_task = tokio::spawn(async move {
        let (stream, _) = ws_server.accept().await.expect("accept ws connector");
        let mut websocket = accept_async(stream).await.expect("complete ws handshake");
        websocket.close(None).await.expect("send ws close frame");
    });
    let closed_count = Arc::new(AtomicUsize::new(0));
    let host = RuntimeExtensionHost::builder()
        .add_extension(CounterExtension {
            name: "ws-close-extension",
            startup_manifest: ExtensionManifest {
                capabilities: vec![ExtensionCapability::ConnectWebSocket],
                resources: vec![ExtensionResourceSpec::WsConnector(WsConnectorSpec {
                    resource_id: "ws-connector".to_owned(),
                    url: format!("ws://{ws_server_addr}/feed"),
                    visibility: ExtensionStreamVisibility::Private,
                    read_buffer_bytes: 128,
                })],
                subscriptions: vec![PacketSubscription {
                    source_kind: Some(RuntimePacketSourceKind::ExtensionResource),
                    transport: Some(RuntimePacketTransport::WebSocket),
                    event_class: Some(RuntimePacketEventClass::ConnectionClosed),
                    owner_extension: Some("ws-close-extension".to_owned()),
                    ..PacketSubscription::default()
                }],
            },
            packet_count: Arc::clone(&closed_count),
            shutdown_wait: Duration::ZERO,
            shutdown_called: Arc::new(AtomicBool::new(false)),
        })
        .build();
    let report = host.startup().await;
    assert_eq!(report.active_extensions, 1);

    // Event dispatch is asynchronous; poll rather than betting on a single fixed delay.
    let observed = tokio::time::timeout(EVENT_DEADLINE, async {
        while closed_count.load(Ordering::Relaxed) == 0 {
            sleep(Duration::from_millis(10)).await;
        }
    })
    .await;
    assert!(
        observed.is_ok(),
        "ConnectionClosed event was not dispatched in time"
    );
    // Give a duplicate dispatch a moment to land before pinning the exact count.
    sleep(Duration::from_millis(50)).await;
    assert_eq!(closed_count.load(Ordering::Relaxed), 1);

    let server_join = tokio::time::timeout(EVENT_DEADLINE, ws_server_task)
        .await
        .expect("ws server task did not finish in time");
    assert!(server_join.is_ok(), "ws server task failed: {server_join:?}");
    host.shutdown().await;
}
/// Verifies that a WebSocket connector whose peer accepts TCP but never completes the
/// handshake is abandoned once the host's 50ms startup timeout elapses, and is reported as a
/// failed extension with a "timed out" reason.
#[tokio::test]
async fn startup_times_out_hung_websocket_connector() {
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind ws server");
    let listener_addr = listener.local_addr().expect("ws local addr");
    // Accept the TCP connection and hold it open without ever answering the handshake.
    let stalled_server = tokio::spawn(async move {
        let Ok((_held_open, _)) = listener.accept().await else {
            return;
        };
        sleep(Duration::from_secs(5)).await;
    });
    let connector = ExtensionResourceSpec::WsConnector(WsConnectorSpec {
        resource_id: "ws-connector".to_owned(),
        url: format!("ws://{listener_addr}/feed"),
        visibility: ExtensionStreamVisibility::Private,
        read_buffer_bytes: 128,
    });
    let host = RuntimeExtensionHost::builder()
        .with_startup_timeout(Duration::from_millis(50))
        .add_extension(CounterExtension {
            name: "hung-ws-connector",
            startup_manifest: ExtensionManifest {
                capabilities: vec![ExtensionCapability::ConnectWebSocket],
                resources: vec![connector],
                subscriptions: Vec::new(),
            },
            packet_count: Arc::new(AtomicUsize::new(0)),
            shutdown_wait: Duration::ZERO,
            shutdown_called: Arc::new(AtomicBool::new(false)),
        })
        .build();

    let report = host.startup().await;
    assert_eq!(report.active_extensions, 0);
    assert_eq!(report.failed_extensions, 1);
    assert!(report.failures[0].reason.contains("timed out"));

    stalled_server.abort();
    drop(stalled_server.await);
}
/// Pins the WebSocket message/frame caps derived from an extension's read chunk size.
///
/// Both caps must equal the chunk size times `EXTENSION_WEBSOCKET_MESSAGE_LIMIT_MULTIPLIER`.
/// A tiny requested chunk (1 byte) is floored to `DEFAULT_RESOURCE_READ_BUFFER_BYTES` first.
#[test]
fn extension_websocket_transport_config_caps_frames_from_chunk_size() {
    let scaled = |chunk: usize| chunk.saturating_mul(EXTENSION_WEBSOCKET_MESSAGE_LIMIT_MULTIPLIER);
    // (requested chunk size, effective chunk size the caps are derived from)
    let cases = [(4_096, 4_096), (1, DEFAULT_RESOURCE_READ_BUFFER_BYTES)];
    for (requested, effective) in cases {
        let transport = extension_websocket_transport_config(requested);
        let expected_cap = Some(scaled(effective));
        assert_eq!(transport.max_message_size, expected_cap);
        assert_eq!(transport.max_frame_size, expected_cap);
    }
}
}