use anyhow::Context;
use num_enum::{FromPrimitive, IntoPrimitive};
use std::sync::{
atomic::{AtomicU8, Ordering},
Arc, Mutex,
};
use tokio::{
runtime,
sync::Notify,
task::{JoinError, JoinSet},
};
use crate::pipeline::elements::output::{run::run_async_output, AsyncOutputStream};
use crate::pipeline::matching::OutputNamePattern;
use crate::pipeline::naming::{namespace::Namespace2, OutputName};
use crate::pipeline::util::{
channel,
stream::{ControlledStream, SharedStreamState, StreamState},
};
use crate::pipeline::{control::matching::OutputMatcher, matching::ElementNamePattern, naming::ElementKind};
use crate::{measurement::MeasurementBuffer, pipeline::error::PipelineError};
use crate::{metrics::online::MetricReader, pipeline::naming::ElementName};
use super::{
builder::{self, OutputBuilder},
run::run_blocking_output,
};
#[derive(Debug)]
pub enum ControlMessage {
Configure(ConfigureMessage),
CreateMany(CreateManyMessage),
}
#[derive(Debug)]
pub struct ConfigureMessage {
pub matcher: OutputMatcher,
pub new_state: TaskState,
}
#[derive(Debug)]
pub struct CreateManyMessage {
pub builders: Vec<(OutputName, builder::SendOutputBuilder)>,
}
#[derive(Clone, Debug, PartialEq, Eq, Copy, IntoPrimitive, FromPrimitive)]
#[repr(u8)]
pub enum TaskState {
Run,
RunDiscard,
Pause,
StopFinish,
#[num_enum(default)]
StopNow,
}
pub enum SingleOutputController {
Blocking(Arc<SharedOutputConfig>),
Async(Arc<SharedStreamState>),
}
pub struct SharedOutputConfig {
pub change_notifier: Notify,
pub atomic_state: AtomicU8,
}
impl SharedOutputConfig {
pub fn new() -> Self {
Self {
change_notifier: Notify::new(),
atomic_state: AtomicU8::new(TaskState::Run as u8),
}
}
pub fn set_state(&self, state: TaskState) {
self.atomic_state.store(state as u8, Ordering::Relaxed);
self.change_notifier.notify_one();
}
}
impl SingleOutputController {
pub fn set_state(&mut self, state: TaskState) {
match self {
SingleOutputController::Blocking(shared) => shared.set_state(state),
SingleOutputController::Async(arc) => arc.set(StreamState::from(state as u8)),
}
}
}
pub(crate) struct OutputControl {
tasks: TaskManager,
metrics: MetricReader,
}
struct TaskManager {
spawned_tasks: JoinSet<Result<(), PipelineError>>,
controllers: Vec<(OutputName, SingleOutputController)>,
rx_provider: channel::ReceiverProvider,
rt_normal: runtime::Handle,
metrics: MetricReader,
}
impl OutputControl {
pub fn new(rx_provider: channel::ReceiverProvider, rt_normal: runtime::Handle, metrics: MetricReader) -> Self {
Self {
tasks: TaskManager {
spawned_tasks: JoinSet::new(),
controllers: Vec::new(),
rx_provider,
rt_normal,
metrics: metrics.clone(),
},
metrics,
}
}
pub fn blocking_create_outputs(&mut self, outputs: Namespace2<OutputBuilder>) -> anyhow::Result<()> {
let metrics = self.metrics.blocking_read();
for ((plugin, output_name), builder) in outputs {
let mut ctx = builder::OutputBuildContext {
metrics: &metrics,
metrics_r: &self.metrics.clone(),
runtime: self.tasks.rt_normal.clone(),
};
let full_name = OutputName::new(plugin.clone(), output_name);
self.tasks
.create_output(&mut ctx, full_name, builder)
.inspect_err(|e| log::error!("Error in output creation requested by plugin {plugin}: {e:#}"))?;
}
Ok(())
}
pub async fn create_outputs(
&mut self,
builders: Vec<(OutputName, builder::SendOutputBuilder)>,
) -> anyhow::Result<()> {
let metrics = self.metrics.read().await;
let mut ctx = builder::OutputBuildContext {
metrics: &metrics,
metrics_r: &self.metrics,
runtime: self.tasks.rt_normal.clone(),
};
let n = builders.len();
log::debug!("Creating {n} outputs...");
let mut n_errors = 0;
for (name, builder) in builders {
let _ = self
.tasks
.create_output(&mut ctx, name.clone(), builder.into())
.inspect_err(|e| {
log::error!("Error while creating source '{name}': {e:?}");
n_errors += 1;
});
}
if n_errors == 0 {
Ok(())
} else {
Err(anyhow::anyhow!(
"failed to create {n_errors}/{n} outputs (see logs above)"
))
}
}
pub async fn handle_message(&mut self, msg: ControlMessage) -> anyhow::Result<()> {
match msg {
ControlMessage::Configure(msg) => self.tasks.reconfigure(msg),
ControlMessage::CreateMany(msg) => self.create_outputs(msg.builders).await?,
}
Ok(())
}
pub async fn join_next_task(&mut self) -> Result<Result<(), PipelineError>, JoinError> {
match self.tasks.spawned_tasks.join_next().await {
Some(res) => res,
None => unreachable!("join_next_task must be guarded by has_task to prevent an infinite loop"),
}
}
pub fn has_task(&self) -> bool {
!self.tasks.spawned_tasks.is_empty()
}
pub async fn shutdown<F>(mut self, handle_task_result: F)
where
F: FnMut(Result<Result<(), PipelineError>, tokio::task::JoinError>),
{
let stop_msg = ControlMessage::Configure(ConfigureMessage {
matcher: OutputMatcher::Name(OutputNamePattern::wildcard()),
new_state: TaskState::StopFinish,
});
self.handle_message(stop_msg)
.await
.expect("handle_message in shutdown should not fail");
self.tasks.shutdown(handle_task_result).await;
}
pub fn list_elements(&self, buf: &mut Vec<ElementName>, pat: &ElementNamePattern) {
if pat.kind == None || pat.kind == Some(ElementKind::Output) {
buf.extend(self.tasks.controllers.iter().filter_map(|(name, _)| {
if pat.matches(name) {
Some(name.to_owned().into())
} else {
None
}
}))
}
}
}
impl TaskManager {
fn create_output(
&mut self,
ctx: &mut builder::OutputBuildContext,
name: OutputName,
builder: OutputBuilder,
) -> anyhow::Result<()> {
match builder {
OutputBuilder::Blocking(builder) => self.create_blocking_output(ctx, name, builder),
OutputBuilder::Async(builder) => self.create_async_output(ctx, name, builder),
}
}
fn create_blocking_output(
&mut self,
ctx: &mut dyn builder::BlockingOutputBuildContext,
name: OutputName,
builder: Box<dyn builder::BlockingOutputBuilder>,
) -> anyhow::Result<()> {
let output = builder(ctx).context("output creation failed")?;
let rx = self.rx_provider.get(); let metrics = self.metrics.clone();
let config = Arc::new(SharedOutputConfig::new());
let shared_config = config.clone();
let control = SingleOutputController::Blocking(config);
self.controllers.push((name.clone(), control));
let guarded_output = Arc::new(Mutex::new(output));
match rx {
channel::ReceiverEnum::Broadcast(rx) => {
let task = run_blocking_output(name, guarded_output, rx, metrics, shared_config);
self.spawned_tasks.spawn_on(task, &self.rt_normal);
}
channel::ReceiverEnum::Single(rx) => {
let task = run_blocking_output(name, guarded_output, rx, metrics, shared_config);
self.spawned_tasks.spawn_on(task, &self.rt_normal);
}
}
Ok(())
}
fn create_async_output(
&mut self,
ctx: &mut dyn builder::AsyncOutputBuildContext,
name: OutputName,
builder: Box<dyn builder::AsyncOutputBuilder>,
) -> anyhow::Result<()> {
use channel::MeasurementReceiver;
fn box_controlled_stream<
S: futures::Stream<Item = Result<MeasurementBuffer, channel::StreamRecvError>> + Send + 'static,
>(
stream: S,
) -> (AsyncOutputStream, Arc<SharedStreamState>) {
let stream = Box::pin(ControlledStream::new(stream));
let state = stream.state();
(AsyncOutputStream(stream), state)
}
let rx = self.rx_provider.get();
let (stream, state) = match rx {
channel::ReceiverEnum::Broadcast(receiver) => box_controlled_stream(receiver.into_stream()),
channel::ReceiverEnum::Single(receiver) => box_controlled_stream(receiver.into_stream()),
};
let output = builder(ctx, stream).context("output creation failed")?;
let control = SingleOutputController::Async(state);
self.controllers.push((name.clone(), control));
let task = run_async_output(name, output);
self.spawned_tasks.spawn_on(task, &self.rt_normal);
Ok(())
}
fn reconfigure(&mut self, msg: ConfigureMessage) {
for (name, output_config) in &mut self.controllers {
if msg.matcher.matches(name) {
output_config.set_state(msg.new_state);
}
}
}
async fn shutdown<F>(self, mut handle_task_result: F)
where
F: FnMut(Result<Result<(), PipelineError>, tokio::task::JoinError>),
{
drop(self.rx_provider);
let mut spawned_tasks = self.spawned_tasks;
loop {
match spawned_tasks.join_next().await {
Some(res) => handle_task_result(res),
None => break,
}
}
}
}