Skip to main content

datafusion_distributed/observability/
service.rs

1use super::{
2    GetTaskProgressResponse, ObservabilityService, TaskProgress, TaskStatus, WorkerMetrics,
3    generated::observability::{GetTaskProgressRequest, PingRequest, PingResponse},
4};
5use crate::worker::generated::worker::TaskKey;
6use crate::worker::{SingleWriteMultiRead, TaskData};
7use crate::{GetClusterWorkersRequest, GetClusterWorkersResponse, WorkerResolver};
8use datafusion::error::DataFusionError;
9use datafusion::physical_plan::ExecutionPlan;
10use moka::future::Cache;
11use std::sync::Arc;
12#[cfg(feature = "system-metrics")]
13use std::time::Duration;
14#[cfg(feature = "system-metrics")]
15use sysinfo::{Pid, ProcessRefreshKind};
16#[cfg(feature = "system-metrics")]
17use tokio::sync::watch;
18use tonic::{Request, Response, Status};
19
20type ResultTaskData = Result<TaskData, Arc<DataFusionError>>;
21
22pub struct ObservabilityServiceImpl {
23    task_data_entries: Arc<Cache<TaskKey, Arc<SingleWriteMultiRead<ResultTaskData>>>>,
24    worker_resolver: Arc<dyn WorkerResolver + Send + Sync>,
25    #[cfg(feature = "system-metrics")]
26    system: watch::Receiver<WorkerMetrics>,
27}
28
29impl ObservabilityServiceImpl {
30    pub fn new(
31        task_data_entries: Arc<Cache<TaskKey, Arc<SingleWriteMultiRead<ResultTaskData>>>>,
32        worker_resolver: Arc<dyn WorkerResolver + Send + Sync>,
33    ) -> Self {
34        #[cfg(feature = "system-metrics")]
35        let (tx, rx) = tokio::sync::watch::channel(WorkerMetrics::default());
36
37        #[cfg(feature = "system-metrics")]
38        {
39            let pid = Pid::from_u32(std::process::id());
40            let mut sys = sysinfo::System::new_all();
41
42            // Spawn background task to periodically collect and send system metrics.
43            tokio::task::spawn(async move {
44                loop {
45                    sys.refresh_process_specifics(
46                        pid,
47                        ProcessRefreshKind::new().with_cpu().with_memory(),
48                    );
49
50                    if let Some(process) = sys.process(pid) {
51                        let num_cpus = std::thread::available_parallelism()
52                            .map(|n| n.get() as f64)
53                            .unwrap_or(1.0);
54                        let metrics = WorkerMetrics {
55                            rss_bytes: process.memory(),
56                            cpu_usage_percent: process.cpu_usage() as f64 / num_cpus,
57                        };
58                        if tx.send(metrics).is_err() {
59                            break;
60                        }
61                    } else if tx.send(WorkerMetrics::default()).is_err() {
62                        break;
63                    };
64
65                    tokio::time::sleep(Duration::from_millis(100)).await;
66                }
67            });
68        }
69        Self {
70            task_data_entries,
71            worker_resolver,
72            #[cfg(feature = "system-metrics")]
73            system: rx,
74        }
75    }
76}
77
78#[tonic::async_trait]
79impl ObservabilityService for ObservabilityServiceImpl {
80    async fn ping(&self, _request: Request<PingRequest>) -> Result<Response<PingResponse>, Status> {
81        Ok(Response::new(PingResponse { value: 1 }))
82    }
83
84    async fn get_task_progress(
85        &self,
86        _request: Request<GetTaskProgressRequest>,
87    ) -> Result<Response<GetTaskProgressResponse>, Status> {
88        let mut tasks = Vec::new();
89
90        for entry in self.task_data_entries.iter() {
91            let (internal_key, task_data_cell) = entry;
92
93            // Only include initialized tasks
94            if let Some(Ok(task_data)) = task_data_cell.read_now() {
95                let total_partitions = task_data.total_partitions() as u64;
96                let remaining = task_data.num_partitions_remaining() as u64;
97                let completed_partitions = total_partitions.saturating_sub(remaining);
98                let output_rows = output_rows_from_plan(&task_data.plan);
99
100                tasks.push(TaskProgress {
101                    task_key: Some((*internal_key).clone()),
102                    total_partitions,
103                    completed_partitions,
104                    status: TaskStatus::Running as i32,
105                    output_rows,
106                });
107            }
108        }
109
110        let worker_metrics = Some(self.collect_worker_metrics());
111
112        Ok(Response::new(GetTaskProgressResponse {
113            tasks,
114            worker_metrics,
115        }))
116    }
117
118    async fn get_cluster_workers(
119        &self,
120        _request: Request<GetClusterWorkersRequest>,
121    ) -> Result<Response<GetClusterWorkersResponse>, Status> {
122        let urls = self
123            .worker_resolver
124            .get_urls()
125            .map_err(|e| Status::internal(format!("Failed to resolve workers: {e}")))?;
126
127        let worker_urls = urls.into_iter().map(|url| url.to_string()).collect();
128
129        Ok(Response::new(GetClusterWorkersResponse { worker_urls }))
130    }
131}
132
133impl ObservabilityServiceImpl {
134    fn collect_worker_metrics(&self) -> WorkerMetrics {
135        #[cfg(not(feature = "system-metrics"))]
136        {
137            WorkerMetrics::default()
138        }
139
140        #[cfg(feature = "system-metrics")]
141        return *self.system.borrow();
142    }
143}
144
145/// Extracts output rows from the root plan node's metrics.
146fn output_rows_from_plan(plan: &Arc<dyn ExecutionPlan>) -> u64 {
147    plan.metrics().and_then(|m| m.output_rows()).unwrap_or(0) as u64
148}