Skip to main content

datafusion_distributed/observability/service.rs

1use super::{
2    GetTaskProgressResponse, ObservabilityService, TaskProgress, TaskStatus, WorkerMetrics,
3    generated::observability::{GetTaskProgressRequest, PingRequest, PingResponse},
4};
5use crate::worker::generated::worker::TaskKey;
6use crate::worker::{SingleWriteMultiRead, TaskData};
7use crate::{GetClusterWorkersRequest, GetClusterWorkersResponse, WorkerResolver};
8use datafusion::error::DataFusionError;
9use datafusion::physical_plan::ExecutionPlan;
10use moka::future::Cache;
11use std::sync::Arc;
12#[cfg(feature = "system-metrics")]
13use std::time::Duration;
14#[cfg(feature = "system-metrics")]
15use sysinfo::{Pid, ProcessRefreshKind};
16#[cfg(feature = "system-metrics")]
17use tokio::sync::watch;
18use tonic::{Request, Response, Status};
19
20type ResultTaskData = Result<TaskData, Arc<DataFusionError>>;
21
22pub struct ObservabilityServiceImpl {
23    task_data_entries: Arc<Cache<TaskKey, Arc<SingleWriteMultiRead<ResultTaskData>>>>,
24    worker_resolver: Arc<dyn WorkerResolver + Send + Sync>,
25    #[cfg(feature = "system-metrics")]
26    system: watch::Receiver<WorkerMetrics>,
27}
28
29impl ObservabilityServiceImpl {
30    pub fn new(
31        task_data_entries: Arc<Cache<TaskKey, Arc<SingleWriteMultiRead<ResultTaskData>>>>,
32        worker_resolver: Arc<dyn WorkerResolver + Send + Sync>,
33    ) -> Self {
34        #[cfg(feature = "system-metrics")]
35        let (tx, rx) = tokio::sync::watch::channel(WorkerMetrics::default());
36
37        #[cfg(feature = "system-metrics")]
38        {
39            let pid = Pid::from_u32(std::process::id());
40            let mut sys = sysinfo::System::new_all();
41
42            // Spawn background task to periodically collect and send system metrics.
43            #[allow(clippy::disallowed_methods)]
44            tokio::task::spawn(async move {
45                loop {
46                    sys.refresh_process_specifics(
47                        pid,
48                        ProcessRefreshKind::new().with_cpu().with_memory(),
49                    );
50
51                    if let Some(process) = sys.process(pid) {
52                        let num_cpus = std::thread::available_parallelism()
53                            .map(|n| n.get() as f64)
54                            .unwrap_or(1.0);
55                        let metrics = WorkerMetrics {
56                            rss_bytes: process.memory(),
57                            cpu_usage_percent: process.cpu_usage() as f64 / num_cpus,
58                        };
59                        if tx.send(metrics).is_err() {
60                            break;
61                        }
62                    } else if tx.send(WorkerMetrics::default()).is_err() {
63                        break;
64                    };
65
66                    tokio::time::sleep(Duration::from_millis(100)).await;
67                }
68            });
69        }
70        Self {
71            task_data_entries,
72            worker_resolver,
73            #[cfg(feature = "system-metrics")]
74            system: rx,
75        }
76    }
77}
78
79#[tonic::async_trait]
80impl ObservabilityService for ObservabilityServiceImpl {
81    async fn ping(&self, _request: Request<PingRequest>) -> Result<Response<PingResponse>, Status> {
82        Ok(Response::new(PingResponse { value: 1 }))
83    }
84
85    async fn get_task_progress(
86        &self,
87        _request: Request<GetTaskProgressRequest>,
88    ) -> Result<Response<GetTaskProgressResponse>, Status> {
89        let mut tasks = Vec::new();
90
91        for entry in self.task_data_entries.iter() {
92            let (internal_key, task_data_cell) = entry;
93
94            // Only include initialized tasks
95            if let Some(Ok(task_data)) = task_data_cell.read_now() {
96                let total_partitions = task_data.total_partitions() as u64;
97                let remaining = task_data.num_partitions_remaining() as u64;
98                let completed_partitions = total_partitions.saturating_sub(remaining);
99                let output_rows = output_rows_from_plan(&task_data.base_plan);
100
101                tasks.push(TaskProgress {
102                    task_key: Some((*internal_key).clone()),
103                    total_partitions,
104                    completed_partitions,
105                    status: TaskStatus::Running as i32,
106                    output_rows,
107                });
108            }
109        }
110
111        let worker_metrics = Some(self.collect_worker_metrics());
112
113        Ok(Response::new(GetTaskProgressResponse {
114            tasks,
115            worker_metrics,
116        }))
117    }
118
119    async fn get_cluster_workers(
120        &self,
121        _request: Request<GetClusterWorkersRequest>,
122    ) -> Result<Response<GetClusterWorkersResponse>, Status> {
123        let urls = self
124            .worker_resolver
125            .get_urls()
126            .map_err(|e| Status::internal(format!("Failed to resolve workers: {e}")))?;
127
128        let worker_urls = urls.into_iter().map(|url| url.to_string()).collect();
129
130        Ok(Response::new(GetClusterWorkersResponse { worker_urls }))
131    }
132}
133
134impl ObservabilityServiceImpl {
135    fn collect_worker_metrics(&self) -> WorkerMetrics {
136        #[cfg(not(feature = "system-metrics"))]
137        {
138            WorkerMetrics::default()
139        }
140
141        #[cfg(feature = "system-metrics")]
142        return *self.system.borrow();
143    }
144}
145
146/// Extracts output rows from the root plan node's metrics.
147fn output_rows_from_plan(plan: &Arc<dyn ExecutionPlan>) -> u64 {
148    plan.metrics().and_then(|m| m.output_rows()).unwrap_or(0) as u64
149}