// pprof_hyper_server/lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_auto_cfg))]
3#![forbid(unsafe_code)]
4#![warn(missing_docs)]
5#![warn(clippy::print_stderr)]
6#![warn(clippy::print_stdout)]
7
8use anyhow::Result;
9use async_channel::bounded;
10use async_executor::Executor;
11use async_io::Async;
12use http_body_util::Full;
13use hyper::{
14    Method, Request, Response, StatusCode, Uri,
15    body::{Bytes, Incoming},
16    service::service_fn,
17};
18use smol_hyper::rt::FuturesIo;
19use std::{
20    borrow::Cow,
21    collections::HashMap,
22    net::{SocketAddr, TcpListener, TcpStream},
23    sync::Arc,
24};
25
/// Upper bound on queued/served requests: 1 cpu profile + 1 mem profile.
const MAX_CONCURRENT_REQUESTS: usize = 2; // 1 cpu + 1 mem
/// Body returned for any route/method the server does not handle.
const NOT_FOUND: &[u8] = b"Not Found";
28
/// Config allows customizing global pprof config.
///
/// The lifetime `'a` ties the optional blocklist slice to the caller's data,
/// so overrides can be passed without allocation. `None` fields fall back to
/// the `pprof_cpu` module defaults at request time.
#[derive(Default, Clone, Debug)]
// Fields are only read when the pprof_cpu feature is enabled.
#[allow(dead_code)]
pub struct Config<'a> {
    /// Defaults to pprof_cpu::PPROF_BLOCKLIST.
    pub pprof_blocklist: Option<&'a [&'a str]>,
    /// Defaults to pprof_cpu::PPROF_DEFAULT_SECONDS.
    pub pprof_default_seconds: Option<i32>,
    /// Defaults to pprof_cpu::PPROF_DEFAULT_SAMPLING.
    pub pprof_default_sampling: Option<i32>,
}
40
#[cfg(all(feature = "pprof_cpu", not(target_env = "msvc")))]
mod pprof_cpu {
    /// Frames from these libraries are excluded from CPU profiles.
    /// Typed as an unsized slice so entries can change without touching the type.
    pub const PPROF_BLOCKLIST: &[&str] = &["libc", "libgcc", "pthread", "vdso"];
    /// Default profiling window when no `seconds` query param is given.
    pub const PPROF_DEFAULT_SECONDS: i32 = 30; // same as golang pprof
    /// Default sampling frequency (Hz) when no `sampling` query param is given.
    pub const PPROF_DEFAULT_SAMPLING: i32 = 99;
}
47
/// One accepted client connection plus the shared server configuration.
///
/// Consumed by `handle_client`, which serves HTTP/1 requests on the socket.
struct Task<'a> {
    /// Non-blocking TCP stream for this client.
    client: Async<TcpStream>,
    /// Shared pprof configuration; only read when a pprof_* feature is enabled.
    #[allow(dead_code)]
    config: Arc<Config<'a>>,
}
53
54impl Task<'_> {
55    /// Handle a new client.
56    async fn handle_client(self) -> Result<()> {
57        hyper::server::conn::http1::Builder::new()
58            .serve_connection(
59                FuturesIo::new(&self.client),
60                service_fn(|req| self.serve(req)),
61            )
62            .await
63            .unwrap_or_default(); // don't use ? otherwise early connection close errors are propagated
64
65        Ok(())
66    }
67
68    async fn serve(&self, req: Request<Incoming>) -> Result<Response<Full<Bytes>>> {
69        match (req.method(), req.uri().path()) {
70            (&Method::GET, "/debug/pprof/allocs" | "/debug/pprof/heap") => {
71                self.memory_profile().await
72            }
73            (&Method::GET, "/debug/pprof/profile") => self.cpu_profile(req).await,
74            _ => not_found(),
75        }
76    }
77}
78
79impl Task<'_> {
80    #[cfg(all(feature = "pprof_cpu", not(target_env = "msvc")))]
81    async fn cpu_profile(&self, req: Request<Incoming>) -> Result<Response<Full<Bytes>>> {
82        use crate::pprof_cpu::*;
83        use async_io::Timer;
84        use pprof::{ProfilerGuardBuilder, protos::Message};
85        use std::time::Duration;
86
87        let params = get_params(req.uri());
88
89        let profile_seconds = parse_i32_params(
90            &params,
91            "seconds",
92            self.config
93                .pprof_default_seconds
94                .unwrap_or(PPROF_DEFAULT_SECONDS),
95        );
96        let profile_sampling = parse_i32_params(
97            &params,
98            "sampling",
99            self.config
100                .pprof_default_sampling
101                .unwrap_or(PPROF_DEFAULT_SAMPLING),
102        );
103
104        let blocklist = self.config.pprof_blocklist.unwrap_or(PPROF_BLOCKLIST);
105        let guard = ProfilerGuardBuilder::default()
106            .frequency(profile_sampling)
107            .blocklist(blocklist)
108            .build()?;
109
110        Timer::after(Duration::from_secs(profile_seconds.try_into()?)).await;
111
112        let profile = guard.report().build()?.pprof()?;
113
114        let mut content = Vec::new();
115        profile.encode(&mut content)?;
116
117        Ok(Response::new(Full::new(Bytes::from(content))))
118    }
119
120    #[cfg(any(not(feature = "pprof_cpu"), target_env = "msvc"))]
121    async fn cpu_profile(&self, _: Request<Incoming>) -> Result<Response<Full<Bytes>>> {
122        not_found()
123    }
124
125    #[cfg(all(feature = "pprof_heap", not(target_env = "msvc")))]
126    async fn memory_profile(&self) -> Result<Response<Full<Bytes>>> {
127        let prof_ctl = jemalloc_pprof::PROF_CTL.as_ref();
128
129        match prof_ctl {
130            None => Err(anyhow::anyhow!("heap profiling not activated")),
131            Some(prof_ctl) => {
132                let mut prof_ctl = prof_ctl.try_lock()?;
133
134                if !prof_ctl.activated() {
135                    return Err(anyhow::anyhow!("heap profiling not activated"));
136                }
137
138                let pprof = prof_ctl.dump_pprof()?;
139
140                Ok(Response::new(Full::new(Bytes::from(pprof))))
141            }
142        }
143    }
144
145    #[cfg(any(not(feature = "pprof_heap"), target_env = "msvc"))]
146    async fn memory_profile(&self) -> Result<Response<Full<Bytes>>> {
147        not_found()
148    }
149}
150
151#[allow(dead_code)]
152fn get_params<'a>(uri: &'a Uri) -> HashMap<Cow<'a, str>, Cow<'a, str>> {
153    let params: HashMap<Cow<'_, str>, Cow<'_, str>> = uri
154        .query()
155        .map(|v| form_urlencoded::parse(v.as_bytes()).collect())
156        .unwrap_or_default();
157
158    params
159}
160
/// Look up `name` in the query map and parse it as an `i32`.
///
/// Returns `default` when the key is missing or its value does not parse.
#[allow(dead_code)]
fn parse_i32_params<'a>(
    params: &'a HashMap<Cow<'a, str>, Cow<'a, str>>,
    name: &str,
    default: i32,
) -> i32 {
    match params.get(name) {
        Some(raw) => raw.parse().unwrap_or(default),
        None => default,
    }
}
172
173fn not_found() -> Result<Response<Full<Bytes>>> {
174    Ok(Response::builder()
175        .status(StatusCode::NOT_FOUND)
176        .body(Full::new(Bytes::from(NOT_FOUND)))
177        .unwrap_or_default())
178}
179
180/// Listens for incoming connections and serves them under pprof HTTP API.
181pub async fn serve<'a>(bind_address: SocketAddr, config: Config<'a>) -> Result<()> {
182    let listener = Async::<TcpListener>::bind(bind_address)?;
183    let (s, r) = bounded::<Task>(MAX_CONCURRENT_REQUESTS);
184    let config = Arc::new(config);
185    let ex = Arc::new(Executor::new());
186
187    ex.spawn({
188        let ex = ex.clone();
189        async move {
190            loop {
191                if let Ok(task) = r.recv().await {
192                    ex.spawn(async {
193                        task.handle_client().await.unwrap_or_default();
194                    })
195                    .detach();
196                }
197            }
198        }
199    })
200    .detach();
201
202    ex.run({
203        async move {
204            // stack max MAX_CONCURRENT_REQUESTS requests
205            // if we cannot add more tasks, drop the connection
206            // we don't need a multi threaded server to serve pprof server, but don't want it to be a source of DOS.
207            loop {
208                let listener = listener.accept().await;
209                if let Ok((client, _)) = listener {
210                    let task = Task {
211                        client,
212                        config: config.clone(),
213                    };
214
215                    // we ignore the potential error as it would mean we should drop the connection if channel is full.
216                    let _ = s.try_send(task);
217                }
218            }
219        }
220    })
221    .await;
222
223    Ok(())
224}