use crate::{
attach_shared_consumer, build_shared_single_producer, MultiProcessError, SharedConsumer,
SharedDisruptorBuilder, SharedProducer,
};
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::time::Duration;
/// Coordination timeout that `FixedTopology::new` assigns by default (15 seconds).
const DEFAULT_COORDINATION_TIMEOUT: Duration = Duration::from_secs(15);
/// Number of workers in a fixed topology.
///
/// Only the counts two through eight are representable, so any value of this
/// type is a valid worker count by construction.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum WorkerCount {
    /// Two workers.
    Two,
    /// Three workers.
    Three,
    /// Four workers.
    Four,
    /// Five workers.
    Five,
    /// Six workers.
    Six,
    /// Seven workers.
    Seven,
    /// Eight workers.
    Eight,
}

impl WorkerCount {
    /// Returns the numeric worker count this variant stands for (2 through 8).
    pub const fn as_usize(self) -> usize {
        match self {
            WorkerCount::Two => 2,
            WorkerCount::Three => 3,
            WorkerCount::Four => 4,
            WorkerCount::Five => 5,
            WorkerCount::Six => 6,
            WorkerCount::Seven => 7,
            WorkerCount::Eight => 8,
        }
    }

    /// Converts a numeric worker count into its variant.
    ///
    /// Returns `None` when `value` is not one of the supported counts
    /// (anything outside 2 through 8).
    pub const fn from_usize(value: usize) -> Option<Self> {
        // Every variant, scanned linearly; `as_usize` stays the single place
        // where the variant-to-number mapping is written down.
        const CANDIDATES: [WorkerCount; 7] = [
            WorkerCount::Two,
            WorkerCount::Three,
            WorkerCount::Four,
            WorkerCount::Five,
            WorkerCount::Six,
            WorkerCount::Seven,
            WorkerCount::Eight,
        ];

        // A `while` loop because iterator adaptors are not usable in `const fn`.
        let mut slot = 0;
        while slot < CANDIDATES.len() {
            let candidate = CANDIDATES[slot];
            if candidate.as_usize() == value {
                return Some(candidate);
            }
            slot += 1;
        }
        None
    }
}
/// Errors returned by the fallible `FixedTopology` operations.
#[derive(Debug)]
pub enum InferenceTopologyError {
/// A worker index was not below the topology's worker count.
InvalidWorkerIndex {
/// The rejected worker index.
index: usize,
/// Number of workers in the topology the index was checked against.
worker_count: usize,
},
/// A `MultiProcessError` wrapped unchanged (see the `From` conversion).
MultiProcess(MultiProcessError),
}
impl Display for InferenceTopologyError {
    /// Writes the human-readable message for this error.
    ///
    /// `InvalidWorkerIndex` reports the offending index and the worker count;
    /// `MultiProcess` forwards to the wrapped error's own `Display` output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MultiProcess(inner) => Display::fmt(inner, f),
            Self::InvalidWorkerIndex { index, worker_count } => f.write_fmt(format_args!(
                "worker index {index} is out of range for fixed topology with {worker_count} workers"
            )),
        }
    }
}
impl Error for InferenceTopologyError {
    /// Returns the wrapped `MultiProcessError` for the `MultiProcess` variant,
    /// and `None` for `InvalidWorkerIndex`, which has no underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::MultiProcess(inner) => Some(inner as &(dyn Error + 'static)),
            Self::InvalidWorkerIndex { .. } => None,
        }
    }
}
impl From<MultiProcessError> for InferenceTopologyError {
fn from(error: MultiProcessError) -> Self {
Self::MultiProcess(error)
}
}
/// Shorthand for a `Result` whose error type is [`InferenceTopologyError`].
pub type InferenceTopologyResult<T> = Result<T, InferenceTopologyError>;
/// Description of a topology with one scheduler and a fixed number of workers
/// that all refer to the same named segment.
#[derive(Clone, Debug)]
pub struct FixedTopology {
/// Segment name handed to the shared producer/consumer builder functions;
/// also the stem of every worker consumer id.
segment_name: String,
/// Buffer size handed to the shared producer/consumer builder functions.
buffer_size: usize,
/// Number of workers; valid worker indices are `0..worker_count`.
worker_count: WorkerCount,
/// Timeout passed to `wait_for_consumers` when the scheduler builder is configured.
coordination_timeout: Duration,
}
impl FixedTopology {
    /// Creates a topology for `segment_name` with the given buffer size and
    /// worker count.
    ///
    /// The coordination timeout starts at `DEFAULT_COORDINATION_TIMEOUT`; use
    /// [`with_coordination_timeout`](Self::with_coordination_timeout) to change it.
    pub fn new(
        segment_name: impl Into<String>,
        buffer_size: usize,
        worker_count: WorkerCount,
    ) -> Self {
        let segment_name = segment_name.into();
        Self {
            segment_name,
            buffer_size,
            worker_count,
            coordination_timeout: DEFAULT_COORDINATION_TIMEOUT,
        }
    }

    /// Returns the segment name this topology was created with.
    pub fn segment_name(&self) -> &str {
        self.segment_name.as_str()
    }

    /// Returns the buffer size this topology was created with.
    pub fn buffer_size(&self) -> usize {
        self.buffer_size
    }

    /// Returns the configured number of workers.
    pub const fn worker_count(&self) -> WorkerCount {
        self.worker_count
    }

    /// Returns the timeout that is passed to `wait_for_consumers` by
    /// [`scheduler_builder`](Self::scheduler_builder).
    pub const fn coordination_timeout(&self) -> Duration {
        self.coordination_timeout
    }

    /// Returns the range of valid worker indices, `0..worker_count`.
    pub fn worker_indices(&self) -> std::ops::Range<usize> {
        let end = self.worker_count.as_usize();
        0..end
    }

    /// Consumes the topology and returns it with `timeout` as the
    /// coordination timeout; every other setting is carried over.
    pub fn with_coordination_timeout(self, timeout: Duration) -> Self {
        Self {
            coordination_timeout: timeout,
            ..self
        }
    }

    /// Returns the consumer id for the worker at `worker_index`, formatted as
    /// `<segment_name>_wk_<worker_index>`.
    ///
    /// # Errors
    ///
    /// Returns [`InferenceTopologyError::InvalidWorkerIndex`] when
    /// `worker_index` is not below the worker count.
    pub fn worker_consumer_id(&self, worker_index: usize) -> InferenceTopologyResult<String> {
        self.validate_worker_index(worker_index)
            .map(|()| format!("{}_{}", self.worker_prefix(), worker_index))
    }

    /// Returns a single-producer builder for this topology's segment.
    ///
    /// The builder is configured by calling `discover_consumer_with_prefix`
    /// with the worker count and the `<segment_name>_wk` prefix, then
    /// `wait_for_consumers` with the worker count and the coordination timeout.
    pub fn scheduler_builder<E>(&self) -> SharedDisruptorBuilder<E>
    where
        E: Copy + Default + 'static,
    {
        let expected_workers = self.worker_count.as_usize();
        let prefix = self.worker_prefix();
        build_shared_single_producer::<E>(&self.segment_name, self.buffer_size)
            .discover_consumer_with_prefix(expected_workers, &prefix)
            // The cast cannot truncate: `WorkerCount` never exceeds 8.
            .wait_for_consumers(expected_workers as i64, self.coordination_timeout)
    }

    /// Builds the scheduler-side producer from
    /// [`scheduler_builder`](Self::scheduler_builder), passing
    /// `default_event_fn` through to `build_producer`.
    ///
    /// # Errors
    ///
    /// Any error reported by `build_producer`, converted into
    /// [`InferenceTopologyError`].
    pub fn build_scheduler<E, F>(
        &self,
        default_event_fn: F,
    ) -> InferenceTopologyResult<SharedProducer<E>>
    where
        E: Copy + Default + 'static,
        F: FnMut() -> E,
    {
        self.scheduler_builder::<E>()
            .build_producer(default_event_fn)
            .map_err(InferenceTopologyError::from)
    }

    /// Returns a consumer builder for this topology's segment, tagged with
    /// the consumer id of the worker at `worker_index`.
    ///
    /// # Errors
    ///
    /// Returns [`InferenceTopologyError::InvalidWorkerIndex`] when
    /// `worker_index` is not below the worker count; in that case no builder
    /// is created.
    pub fn worker_builder<E>(
        &self,
        worker_index: usize,
    ) -> InferenceTopologyResult<SharedDisruptorBuilder<E>>
    where
        E: Copy + Default + 'static,
    {
        self.worker_consumer_id(worker_index).map(|consumer_id| {
            attach_shared_consumer::<E>(&self.segment_name, self.buffer_size)
                .with_consumer_id(&consumer_id)
        })
    }

    /// Builds the consumer for the worker at `worker_index` from
    /// [`worker_builder`](Self::worker_builder).
    ///
    /// # Errors
    ///
    /// Returns [`InferenceTopologyError::InvalidWorkerIndex`] for an
    /// out-of-range index, or the converted error reported by `build_consumer`.
    pub fn attach_worker<E>(
        &self,
        worker_index: usize,
    ) -> InferenceTopologyResult<SharedConsumer<E>>
    where
        E: Copy + Default + 'static,
    {
        self.worker_builder::<E>(worker_index)?
            .build_consumer()
            .map_err(InferenceTopologyError::from)
    }

    /// Checks that `worker_index` lies in `0..worker_count`.
    fn validate_worker_index(&self, worker_index: usize) -> InferenceTopologyResult<()> {
        if self.worker_indices().contains(&worker_index) {
            return Ok(());
        }
        Err(InferenceTopologyError::InvalidWorkerIndex {
            index: worker_index,
            worker_count: self.worker_count.as_usize(),
        })
    }

    /// Returns the prefix shared by all worker consumer ids: `<segment_name>_wk`.
    fn worker_prefix(&self) -> String {
        format!("{}_wk", self.segment_name())
    }
}