mod metrics;
use std::any::Any;
use std::future::Future;
use std::panic::UnwindSafe;
use std::sync::OnceLock;
use std::time::Instant;
use opentelemetry::metrics::MeterProvider as _;
use opentelemetry::metrics::ObservableGauge;
use tokio::sync::oneshot;
use tracing::Instrument;
use tracing::Span;
use tracing::info_span;
use tracing_core::Dispatch;
use tracing_subscriber::util::SubscriberInitExt;
use self::metrics::ActiveComputeMetric;
use self::metrics::JobWatcher;
use self::metrics::Outcome;
use self::metrics::observe_compute_duration;
use self::metrics::observe_queue_wait_duration;
use crate::ageing_priority_queue::AgeingPriorityQueue;
use crate::ageing_priority_queue::Priority;
use crate::metrics::meter_provider;
use crate::plugins::telemetry::consts::COMPUTE_JOB_EXECUTION_SPAN_NAME;
use crate::plugins::telemetry::consts::COMPUTE_JOB_SPAN_NAME;
/// Maximum number of queued compute jobs allowed per worker thread.
///
/// Read from the `APOLLO_ROUTER_COMPUTE_QUEUE_CAPACITY_PER_THREAD`
/// environment variable; falls back to 1000 when the variable is unset or
/// does not parse as a `usize`.
fn queue_capacity() -> usize {
    const DEFAULT_CAPACITY_PER_THREAD: usize = 1000;
    match std::env::var("APOLLO_ROUTER_COMPUTE_QUEUE_CAPACITY_PER_THREAD") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_CAPACITY_PER_THREAD),
        Err(_) => DEFAULT_CAPACITY_PER_THREAD,
    }
}
/// Number of worker threads in the compute pool.
///
/// The `APOLLO_ROUTER_COMPUTE_THREADS` environment variable takes precedence
/// when set to a valid `usize`; otherwise the host's available parallelism is
/// used.
fn thread_pool_size() -> usize {
    std::env::var("APOLLO_ROUTER_COMPUTE_THREADS")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
        .unwrap_or_else(|| {
            std::thread::available_parallelism()
                .expect("available_parallelism() failed")
                .get()
        })
}
/// The kinds of CPU-heavy work that can be scheduled on the compute pool.
///
/// `IntoStaticStr` with `snake_case` serialization produces the static string
/// labels used for span fields and metric attributes (e.g. `query_planning`).
#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug, strum_macros::IntoStaticStr)]
#[strum(serialize_all = "snake_case")]
pub(crate) enum ComputeJobType {
    QueryParsing,
    QueryPlanning,
    Introspection,
}
impl From<ComputeJobType> for Priority {
    /// Maps each job type onto its scheduling priority: planning is the most
    /// urgent, introspection the least.
    fn from(job_type: ComputeJobType) -> Self {
        match job_type {
            ComputeJobType::QueryParsing => Self::P4,
            ComputeJobType::QueryPlanning => Self::P8,
            ComputeJobType::Introspection => Self::P1,
        }
    }
}
impl From<ComputeJobType> for opentelemetry::Value {
    /// Converts the job type into its static snake_case label for use as a
    /// telemetry attribute value.
    fn from(compute_job_type: ComputeJobType) -> Self {
        opentelemetry::Value::from(<&'static str>::from(compute_job_type))
    }
}
/// A unit of work queued for execution on the compute thread pool.
pub(crate) struct Job {
    // Tracing dispatcher captured at enqueue time so the worker thread emits
    // events to the same subscriber as the caller.
    subscriber: Dispatch,
    // Span that was current when the job was created; the execution span is
    // opened inside it.
    parent_span: Span,
    // Kind of computation; drives queue priority and metric/span labels.
    ty: ComputeJobType,
    // When the job was enqueued; used to measure queue wait duration.
    queue_start: Instant,
    // The actual work. Delivering the result to the caller is the closure's
    // responsibility (see `execute`, which sends over a oneshot channel).
    job_fn: Box<dyn FnOnce() + Send + 'static>,
}
/// Returns the process-wide compute job queue, lazily initialized.
///
/// On first call this spawns `thread_pool_size()` dedicated worker threads,
/// each of which loops forever pulling jobs off the queue and running them
/// under the caller's captured tracing subscriber and parent span.
fn queue() -> &'static AgeingPriorityQueue<Job> {
    static QUEUE: OnceLock<AgeingPriorityQueue<Job>> = OnceLock::new();
    QUEUE.get_or_init(|| {
        let pool_size = thread_pool_size();
        for _ in 0..pool_size {
            std::thread::spawn(|| {
                // NOTE(review): this re-entrant `queue()` call blocks in
                // `get_or_init` until the outer initialization completes;
                // that is safe here because the initializing thread never
                // waits on these workers.
                let queue = queue();
                let mut receiver = queue.receiver();
                loop {
                    // Blocks until a job is available. `age` is the priority
                    // band the job reached while waiting (it is converted to
                    // a static label below).
                    let (job, age) = receiver.blocking_recv();
                    let job_type: &'static str = job.ty.into();
                    let age: &'static str = age.into();
                    // Install the subscriber that was active when the job was
                    // enqueued so tracing events reach the caller's sink.
                    let _subscriber = job.subscriber.set_default();
                    job.parent_span.in_scope(|| {
                        let span = info_span!(
                            COMPUTE_JOB_EXECUTION_SPAN_NAME,
                            "job.type" = job_type,
                            "job.age" = age
                        );
                        span.in_scope(|| {
                            // Record queue wait, then time the computation
                            // itself; the active-jobs metric is decremented
                            // when `_active_metric` drops at scope end.
                            observe_queue_wait_duration(job.ty, job.queue_start.elapsed());
                            let _active_metric = ActiveComputeMetric::register(job.ty);
                            let job_start = Instant::now();
                            (job.job_fn)();
                            observe_compute_duration(job.ty, job_start.elapsed());
                        })
                    })
                }
            });
        }
        // Total capacity scales with pool size. "Soft" bounded — presumably
        // sends past capacity are still accepted; confirm against
        // AgeingPriorityQueue.
        AgeingPriorityQueue::soft_bounded(queue_capacity() * pool_size)
    })
}
/// Enqueues `job` on the compute thread pool and returns a future that
/// resolves to its result.
///
/// The job is wrapped in `catch_unwind`, so a panic inside `job` is returned
/// as `Err` (mirroring `std::thread::JoinHandle::join`) instead of unwinding
/// through a worker thread. The returned future completes once a worker has
/// run the job and sent the result back over a oneshot channel.
pub(crate) fn execute<T, F>(
    compute_job_type: ComputeJobType,
    job: F,
) -> impl Future<Output = std::thread::Result<T>>
where
    F: FnOnce() -> T + Send + UnwindSafe + 'static,
    T: Send + 'static,
{
    let compute_job_type_str: &'static str = compute_job_type.into();
    // Span covering the job's whole lifetime (queueing + execution).
    // "job.outcome" is declared Empty here; NOTE(review): it appears to be
    // recorded elsewhere (presumably by JobWatcher) — confirm in the metrics
    // module.
    let span = info_span!(
        COMPUTE_JOB_SPAN_NAME,
        "job.type" = compute_job_type_str,
        "job.outcome" = tracing::field::Empty
    );
    span.in_scope(|| {
        let job_watcher = JobWatcher::new(compute_job_type);
        let (tx, rx) = oneshot::channel();
        // Capture panics so they surface to the awaiting caller rather than
        // killing the worker thread. The send result is ignored: the caller
        // may have dropped the receiving future.
        let job = Box::new(move || {
            let _ = tx.send(std::panic::catch_unwind(job));
        });
        let job = Job {
            // Dispatch::default() is the currently active dispatcher (per
            // tracing-core), capturing the caller's tracing context.
            subscriber: Dispatch::default(),
            parent_span: Span::current(),
            ty: compute_job_type,
            job_fn: job,
            queue_start: Instant::now(),
        };
        // Priority is derived from the job type (planning > parsing >
        // introspection).
        queue().send(compute_job_type.into(), job);
        async move {
            let result = rx.await;
            // Record the outcome on the watcher before it drops at the end of
            // this block (its Drop presumably emits the metric — confirm).
            let mut local_job_watcher = job_watcher;
            local_job_watcher.outcome = match &result {
                Ok(Ok(_)) => Outcome::ExecutedOk,
                Ok(Err(_)) => Outcome::ExecutedError,
                Err(_) => Outcome::ChannelError,
            };
            match result {
                // Worker responded: forward its catch_unwind result as-is.
                Ok(r) => r,
                // Channel closed without a response (sender dropped); surface
                // the recv error as a boxed panic payload.
                Err(e) => Err(Box::new(e) as Box<dyn Any + Send>),
            }
        }
        .in_current_span()
    })
}
/// Whether the compute queue has reached its soft capacity; callers can use
/// this to shed load before enqueueing more work.
pub(crate) fn is_full() -> bool {
    let compute_queue = queue();
    compute_queue.is_full()
}
/// Builds an observable gauge reporting how many compute jobs are currently
/// waiting in the queue, sampled on each metrics collection cycle.
pub(crate) fn create_queue_size_gauge() -> ObservableGauge<u64> {
    let meter = meter_provider().meter("apollo/router");
    let builder = meter
        .u64_observable_gauge("apollo.router.compute_jobs.queued")
        .with_description(
            "Number of computation jobs (parsing, planning, …) waiting to be scheduled",
        );
    builder
        .with_callback(|measurement| {
            let queued = queue().queued_count() as u64;
            measurement.observe(queued, &[]);
        })
        .init()
}
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use std::time::Instant;

    use tracing_futures::WithSubscriber;

    use super::*;
    use crate::assert_snapshot_subscriber;

    /// Snapshot-tests that tracing events emitted inside a compute job are
    /// captured under the caller's span and subscriber, even though the job
    /// runs on a pool worker thread.
    #[tokio::test]
    async fn test_observability() {
        async {
            let span = info_span!("test_observability");
            async {
                tracing::info!("Outer");
                let job = execute(ComputeJobType::QueryParsing, || {
                    // Runs on a worker thread, but must still be recorded by
                    // the snapshot subscriber installed below.
                    tracing::info!("Inner");
                    1
                });
                let result = job.await.unwrap();
                assert_eq!(result, 1);
            }
            .instrument(span)
            .await;
        }
        .with_subscriber(assert_snapshot_subscriber!())
        .await;
    }

    /// Jobs must run on a pool worker thread, never on the caller's thread.
    #[tokio::test]
    async fn test_executes_on_different_thread() {
        let test_thread = std::thread::current().id();
        let job_thread = execute(ComputeJobType::QueryParsing, || std::thread::current().id())
            .await
            .expect("job panicked");
        assert_ne!(job_thread, test_thread)
    }

    /// Two ~1s jobs should overlap when the pool has at least two threads.
    /// NOTE(review): timing-based assertion — may be flaky on a heavily
    /// loaded machine.
    #[tokio::test]
    async fn test_parallelism() {
        // Can't demonstrate overlap with a single worker; skip.
        if thread_pool_size() < 2 {
            return;
        }
        let start = Instant::now();
        let one = execute(ComputeJobType::QueryPlanning, || {
            std::thread::sleep(Duration::from_millis(1_000));
            1
        });
        let two = execute(ComputeJobType::QueryPlanning, || {
            std::thread::sleep(Duration::from_millis(1_000));
            1 + 1
        });
        tokio::time::sleep(Duration::from_millis(500)).await;
        assert_eq!(one.await.unwrap(), 1);
        assert_eq!(two.await.unwrap(), 2);
        // Well under 2s total implies the two jobs ran concurrently.
        assert!(start.elapsed() < Duration::from_millis(1_400));
    }
}