use std::sync::Arc;
use crate::bus::{AllocFailReason, AppIngressErrorKind, AppIngressSource, NodeEvent, TypedBus};
use crate::completion::{CompletionHandle, CompletionSink};
use crate::framework::rtt_tracker::{chain_id_from_targets, ChainContext};
use crate::framework::{
rtt_tracker::RttTracker, AddressBook, BackoffTable, BackpressureTracker, EventSource,
HoldTable, InboundDedup, OutboundQueue, PeerGate, PeerGovernor, RecordBuffer, RequestTracker,
RngU64Source, Scheduler, SerializeQueue,
};
use crate::ids::{CommandId, OpRef};
use crate::ingress::{IngressEvent, IngressQueue, COMPLETION_DETAIL_CAP};
use crate::slot_value::SlotValue;
/// Routes command completions and failures into the ingress queue as
/// [`IngressEvent`]s. The result of every `push` is deliberately discarded.
impl CompletionSink for IngressQueue {
    /// Enqueues a successful completion carrying a private copy of `result_bytes`.
    ///
    /// Two conditions turn the completion into an `AppIngressError` event instead:
    /// * the payload is longer than `completion_result_cap()` (`PerItemCapExceeded`);
    /// * the buffer for the copy cannot be reserved (`AllocationFailed` with
    ///   `HeapExhausted`).
    fn complete(&self, cmd_id: CommandId, result_bytes: &[u8]) {
        let byte_count = result_bytes.len();
        let cap = self.completion_result_cap();

        // `Vec::new()` does not allocate, so creating it ahead of the cap check
        // costs nothing; the reservation itself only runs for in-cap payloads.
        let mut payload: Vec<u8> = Vec::new();
        let rejection = if byte_count > cap {
            Some(AppIngressErrorKind::PerItemCapExceeded { cap })
        } else if crate::fallible::try_reserve_exact(&mut payload, byte_count).is_err() {
            Some(AppIngressErrorKind::AllocationFailed {
                reason: AllocFailReason::HeapExhausted,
            })
        } else {
            None
        };

        let event = match rejection {
            Some(kind) => IngressEvent::AppIngressError {
                source: AppIngressSource::Completion { command: cmd_id },
                byte_count,
                kind,
            },
            None => {
                payload.extend_from_slice(result_bytes);
                IngressEvent::Completion {
                    cmd_id,
                    results: vec![payload],
                }
            }
        };
        let _ = self.push(event);
    }

    /// Enqueues a `CompletionFailed` event for `cmd_id`.
    ///
    /// `detail` is cut to at most `COMPLETION_DETAIL_CAP` bytes; the cut point
    /// is moved backwards until it lands on a UTF-8 character boundary so the
    /// slice below never splits a code point.
    fn fail(&self, cmd_id: CommandId, detail: &str) {
        let keep = if detail.len() > COMPLETION_DETAIL_CAP {
            // Index 0 is always a boundary, so the search cannot come up empty.
            (0..=COMPLETION_DETAIL_CAP)
                .rev()
                .find(|&idx| detail.is_char_boundary(idx))
                .unwrap_or(0)
        } else {
            detail.len()
        };
        let _ = self.push(IngressEvent::CompletionFailed {
            cmd_id,
            detail: detail[..keep].to_string(),
        });
    }
}
/// Bundle of exclusive borrows over the peer-related framework tables.
///
/// Exposed to callers as the `peers` field of `RuntimeResourceRef`.
pub struct PeerCtx<'a> {
pub gate: &'a mut PeerGate,
pub backoff: &'a mut BackoffTable,
pub governor: &'a mut PeerGovernor,
/// Address book; `RuntimeResourceRef::local_addresses` looks up the local
/// peer's addresses through it.
pub addresses: &'a mut AddressBook,
pub backpressure: &'a mut BackpressureTracker,
}
/// Bundle of exclusive borrows over the networking-related framework state.
///
/// Exposed to callers as the `net` field of `RuntimeResourceRef`.
pub struct NetCtx<'a> {
pub outbound: &'a mut OutboundQueue,
/// Round-trip tracker; the wire-budget and round-trip helpers on
/// `RuntimeResourceRef` delegate to it.
pub rtt: &'a mut RttTracker,
pub requests: &'a mut RequestTracker,
pub dedup: &'a mut InboundDedup,
/// List of `(optional peer id, op ref)` pairs.
/// NOTE(review): presumably peer-resolution failures queued for later
/// handling — confirm against the code that drains this vector.
pub pending_peer_resolve_failures: &'a mut Vec<(Option<crate::ids::PeerId>, crate::ids::OpRef)>,
}
/// Exclusive borrow of the scheduler.
///
/// Exposed to callers as the `time` field of `RuntimeResourceRef`.
pub struct TimeCtx<'a> {
pub scheduler: &'a mut Scheduler,
}
/// Bundle of exclusive borrows over the state listed below.
///
/// Exposed to callers as the `syscall` field of `RuntimeResourceRef`.
pub struct SyscallCtx<'a> {
pub serialize_queue: &'a mut SerializeQueue,
pub hold_table: &'a mut HoldTable,
pub record_buffer: &'a mut RecordBuffer,
pub event_source: &'a mut EventSource,
/// String-keyed `u64` counters.
pub counters: &'a mut std::collections::HashMap<String, u64>,
/// Set of group names.
/// NOTE(review): presumably the groups whose "any" trigger has already
/// fired — confirm against the code that inserts into it.
pub any_fired_groups: &'a mut std::collections::HashSet<String>,
/// Set of `(u64, u64)` pairs; what the two components identify is not
/// visible from this file — TODO confirm and document.
pub deadline_match_fired: &'a mut std::collections::HashSet<(u64, u64)>,
/// Source of `u64` random values, held as a trait object.
pub rng: &'a mut dyn RngU64Source,
/// Buffer of application events.
/// NOTE(review): presumably drained by the runtime after the call —
/// confirm.
pub pending_app_events: &'a mut Vec<crate::bus::AppEvent>,
}
/// Optional metadata attached to the current call when it was triggered by
/// inbound traffic.
///
/// Every field is an `Option`, so the derived `Default` yields a value with
/// all fields set to `None`.
#[derive(Clone, Copy, Debug, Default)]
pub struct InboundCtx {
/// Peer the inbound message came from, if known.
pub src_peer: Option<crate::ids::PeerId>,
/// Request id carried on the wire, if any.
pub wire_req_id: Option<u64>,
/// Arrival timestamp (presumably nanoseconds, per the `_ns` suffix; the
/// clock/epoch is not visible here — TODO confirm).
pub arrival_ns: Option<u64>,
/// Remaining deadline (presumably nanoseconds, per the `_ns` suffix).
pub remaining_deadline_ns: Option<u64>,
}
/// State describing the call currently being executed.
///
/// Exposed to callers as the `current` field of `RuntimeResourceRef`.
pub struct CurrentCallCtx<'a> {
pub op_ref: OpRef,
pub exec_id: crate::ids::ExecId,
/// Id of the local peer; used as the lookup key by
/// `RuntimeResourceRef::local_addresses`.
pub self_peer: crate::ids::PeerId,
/// ONNX attribute entries of the current node.
pub node_attributes: &'a [bb_ir::proto::onnx::AttributeProto],
/// ONNX key/value metadata entries of the current node; scanned by
/// `RuntimeResourceRef::read_chain_context`.
pub node_metadata: &'a [bb_ir::proto::onnx::StringStringEntryProto],
pub inbound: InboundCtx,
/// Completions recorded during this call; appended to by
/// `RuntimeResourceRef::complete_command`.
pub pending_completions: Vec<PendingCompletion>,
/// Counter from which `RuntimeResourceRef::allocate_command_id` draws the
/// next command id.
pub next_command_id: &'a mut u64,
}
/// Aggregate view over the runtime state available during a single call:
/// the grouped context structs plus the bus, the ingress queue, the bound
/// components and the per-call context.
pub struct RuntimeResourceRef<'a> {
pub peers: PeerCtx<'a>,
pub net: NetCtx<'a>,
pub time: TimeCtx<'a>,
pub syscall: SyscallCtx<'a>,
/// Typed event bus; `publish_bus` forwards events to it.
pub bus: &'a mut TypedBus,
/// Shared ingress queue; `open_completion` clones this handle to use as
/// the completion sink.
pub ingress: Arc<IngressQueue>,
/// Read-only view used to resolve components by slot name.
pub components: ComponentsView<'a>,
pub current: CurrentCallCtx<'a>,
}
/// Read-only view used to resolve a component instance from a slot name.
///
/// Both fields are optional; the derived `Default` leaves them `None`, in
/// which case every lookup through this view returns `None`.
#[derive(Default)]
pub struct ComponentsView<'a> {
/// Component instances, indexed by the `as_u32()` value of a
/// `ComponentRef`. An entry may itself be `None`.
pub instances: Option<&'a [Option<Box<dyn crate::component::ErasedComponent>>]>,
/// Map from slot name to the `ComponentRef` bound at that slot.
pub slots: Option<&'a std::collections::HashMap<String, crate::ids::ComponentRef>>,
}
impl ComponentsView<'_> {
    /// Resolves the type-erased component bound at `slot_name`.
    ///
    /// Yields `None` when the slot map or the instance slice is absent, when
    /// the name is not in the slot map, when the referenced index lies outside
    /// the instance slice, or when the entry at that index is empty.
    pub fn for_slot(&self, slot_name: &str) -> Option<&dyn crate::component::ErasedComponent> {
        let name_to_ref = self.slots?;
        let table = self.instances?;
        let position = name_to_ref.get(slot_name)?.as_u32() as usize;
        let entry = table.get(position)?;
        entry.as_deref()
    }

    /// Resolves the component bound at `slot_name` and downcasts it to `T`.
    ///
    /// Yields `None` both when nothing resolves (see [`Self::for_slot`]) and
    /// when the resolved component is not a `T`.
    pub fn for_slot_as<T: 'static>(&self, slot_name: &str) -> Option<&T> {
        match self.for_slot(slot_name) {
            Some(erased) => {
                let any: &dyn std::any::Any = erased;
                any.downcast_ref::<T>()
            }
            None => None,
        }
    }
}
/// Reasons a typed component lookup by slot name can fail.
#[derive(Debug)]
pub enum DependencyError {
    /// No component could be resolved for the slot.
    NotBound {
        /// Name of the slot that was requested.
        slot: String,
    },
    /// A component was resolved, but it is not of the requested type.
    TypeMismatch {
        /// Name of the slot that was requested.
        slot: String,
        /// Name of the type the caller asked for.
        expected: &'static str,
    },
}

impl std::fmt::Display for DependencyError {
    /// Renders a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DependencyError::TypeMismatch { slot, expected } => {
                f.write_fmt(format_args!("component at slot `{slot}` is not a `{expected}`"))
            }
            DependencyError::NotBound { slot } => {
                f.write_fmt(format_args!("no component bound at slot `{slot}`"))
            }
        }
    }
}

impl std::error::Error for DependencyError {}
impl RuntimeResourceRef<'_> {
pub fn local_addresses(&self) -> &[crate::framework::Address] {
self.peers
.addresses
.lookup(self.current.self_peer)
.unwrap_or(&[])
}
pub fn dependency<T: 'static>(&self, slot_name: &str) -> Result<&T, DependencyError> {
if self.components.for_slot(slot_name).is_none() {
return Err(DependencyError::NotBound {
slot: slot_name.to_string(),
});
}
self.components
.for_slot_as::<T>(slot_name)
.ok_or_else(|| DependencyError::TypeMismatch {
slot: slot_name.to_string(),
expected: std::any::type_name::<T>(),
})
}
pub fn estimate_wire_budget_ns(
&self,
target: crate::ids::NodeSiteId,
chain: Option<crate::framework::rtt_tracker::ChainContext>,
static_default_ns: u64,
) -> u64 {
self.net
.rtt
.estimate_budget_ns(target, chain, static_default_ns)
}
pub fn read_chain_context(&self) -> Option<crate::framework::rtt_tracker::ChainContext> {
let mut chain_targets: Option<&str> = None;
let mut hop_index: u8 = 0;
for prop in self.current.node_metadata {
match prop.key.as_str() {
"ai.bytesandbrains.wire.chain_targets" => {
chain_targets = Some(prop.value.as_str());
}
"ai.bytesandbrains.wire.chain_hop_index" => {
if let Ok(h) = prop.value.parse::<u8>() {
hop_index = h;
}
}
_ => {}
}
}
chain_targets.map(|targets| ChainContext {
chain_id: chain_id_from_targets(targets),
hop_index,
})
}
pub fn observe_wire_round_trip(
&mut self,
target: crate::ids::NodeSiteId,
chain: Option<crate::framework::rtt_tracker::ChainContext>,
elapsed_ns: u64,
now_ns: u64,
) {
self.net
.rtt
.observe_round_trip(target, chain, elapsed_ns, now_ns);
}
pub fn allocate_command_id(&mut self) -> CommandId {
let id = *self.current.next_command_id;
*self.current.next_command_id = self.current.next_command_id.saturating_add(1);
CommandId::from(id)
}
pub fn complete_command(
&mut self,
cmd_id: CommandId,
results: Vec<(String, Box<dyn SlotValue>)>,
) {
self.current
.pending_completions
.push(PendingCompletion { cmd_id, results });
}
pub fn publish_bus(&mut self, event: NodeEvent) {
self.bus.publish(event);
}
pub fn open_completion<R, E>(&mut self) -> CompletionHandle<R, E>
where
R: serde::Serialize,
E: std::fmt::Display,
{
let cmd_id = self.allocate_command_id();
let sink: Arc<dyn CompletionSink> = self.ingress.clone();
CompletionHandle::new(cmd_id, sink)
}
}
/// A completion recorded for a command, as queued by
/// `RuntimeResourceRef::complete_command`.
pub struct PendingCompletion {
/// Command the results belong to.
pub cmd_id: CommandId,
/// Result values, each paired with a `String`.
/// NOTE(review): the string is presumably the result/slot name — confirm
/// against the consumer of `pending_completions`.
pub results: Vec<(String, Box<dyn SlotValue>)>,
}
/// Newtype over a raw `u32` timer kind.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ComponentTimerKind(pub u32);

impl ComponentTimerKind {
    /// Wraps the raw `kind` value; usable in `const` contexts.
    pub const fn new(kind: u32) -> Self {
        ComponentTimerKind(kind)
    }

    /// Unwraps back to the raw `u32` value; usable in `const` contexts.
    pub const fn as_u32(self) -> u32 {
        let ComponentTimerKind(raw) = self;
        raw
    }
}