// pprof_hyper_server/lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![forbid(unsafe_code)]
4#![warn(missing_docs)]
5#![warn(clippy::print_stderr)]
6#![warn(clippy::print_stdout)]
7
8use anyhow::Result;
9use async_channel::bounded;
10use async_executor::Executor;
11use async_io::Async;
12use http_body_util::Full;
13use hyper::{
14    Method, Request, Response, StatusCode, Uri,
15    body::{Bytes, Incoming},
16    service::service_fn,
17};
18use smol_hyper::rt::FuturesIo;
19use std::{
20    borrow::Cow,
21    collections::HashMap,
22    net::{SocketAddr, TcpListener, TcpStream},
23    sync::Arc,
24};
25
/// Upper bound on queued connections: 1 CPU profile + 1 memory profile.
const MAX_CONCURRENT_REQUESTS: usize = 2;
/// Body returned for unknown routes. Byte-string literal instead of
/// `"Not Found".as_bytes()` (clippy `string_lit_as_bytes`).
const NOT_FOUND: &[u8] = b"Not Found";

/// Config allows customizing global pprof config.
///
/// Every field is optional; `None` falls back to the crate defaults in
/// `pprof_cpu`. Fields are only read by the cfg-gated profile handlers,
/// hence `dead_code` is allowed for builds without those features.
#[derive(Default, Clone, Debug)]
#[allow(dead_code)]
pub struct Config<'a> {
    /// Defaults to pprof_cpu::PPROF_BLOCKLIST.
    pub pprof_blocklist: Option<&'a [&'a str]>,
    /// Defaults to pprof_cpu::PPROF_DEFAULT_SECONDS.
    pub pprof_default_seconds: Option<i32>,
    /// Defaults to pprof_cpu::PPROF_DEFAULT_SAMPLING.
    pub pprof_default_sampling: Option<i32>,
}

/// Crate defaults for CPU profiling; compiled only when the `pprof_cpu`
/// feature is on and the target is not MSVC (matching the handler's cfg).
#[cfg(all(feature = "pprof_cpu", not(target_env = "msvc")))]
mod pprof_cpu {
    /// Frames from these libraries are excluded from the profile.
    /// Typed `&[&str]` (not `&[&str; 4]`) so the list can grow or shrink
    /// without also editing the const's type.
    pub const PPROF_BLOCKLIST: &[&str] = &["libc", "libgcc", "pthread", "vdso"];
    /// Sampling window in seconds, same as golang pprof.
    pub const PPROF_DEFAULT_SECONDS: i32 = 30;
    /// Sampling frequency in Hz.
    pub const PPROF_DEFAULT_SAMPLING: i32 = 99;
}

/// One queued unit of work: an accepted client connection bundled with the
/// shared server configuration used to serve it.
struct Task<'a> {
    // The accepted TCP connection, wrapped for non-blocking async I/O.
    client: Async<TcpStream>,
    // Shared pprof settings; only read inside cfg-gated profile handlers,
    // so it is dead code when neither profiling feature is enabled.
    #[allow(dead_code)]
    config: Arc<Config<'a>>,
}

54impl Task<'_> {
55    /// Handle a new client.
56    async fn handle_client(self) -> Result<()> {
57        hyper::server::conn::http1::Builder::new()
58            .serve_connection(
59                FuturesIo::new(&self.client),
60                service_fn(|req| self.serve(req)),
61            )
62            .await
63            .unwrap_or_default(); // don't use ? otherwise early connection close errors are propagated
64
65        Ok(())
66    }
67
68    async fn serve(&self, req: Request<Incoming>) -> Result<Response<Full<Bytes>>> {
69        match (req.method(), req.uri().path()) {
70            (&Method::GET, "/debug/pprof/allocs" | "/debug/pprof/heap") => {
71                self.memory_profile().await
72            }
73            (&Method::GET, "/debug/pprof/profile") => self.cpu_profile(req).await,
74            _ => not_found(),
75        }
76    }
77}
78
impl Task<'_> {
    /// Serve `/debug/pprof/profile`: sample the CPU for `seconds`, then
    /// return the profile as a gzip-compressed pprof protobuf.
    ///
    /// Query parameters (each falls back to `Config`, then crate defaults):
    /// - `seconds`: sampling window length
    /// - `sampling`: sampling frequency in Hz
    #[cfg(all(feature = "pprof_cpu", not(target_env = "msvc")))]
    async fn cpu_profile(&self, req: Request<Incoming>) -> Result<Response<Full<Bytes>>> {
        use crate::pprof_cpu::*;
        use async_io::Timer;
        use flate2::write::GzEncoder;
        use pprof::{ProfilerGuardBuilder, protos::Message};
        use std::io::Write;
        use std::time::Duration;

        let params = get_params(req.uri());

        // Resolution order: query param -> Config override -> crate default.
        let profile_seconds = parse_i32_params(
            &params,
            "seconds",
            self.config
                .pprof_default_seconds
                .unwrap_or(PPROF_DEFAULT_SECONDS),
        );
        let profile_sampling = parse_i32_params(
            &params,
            "sampling",
            self.config
                .pprof_default_sampling
                .unwrap_or(PPROF_DEFAULT_SAMPLING),
        );

        // Start sampling now; the guard keeps the profiler active until
        // `report()` below.
        let blocklist = self.config.pprof_blocklist.unwrap_or(PPROF_BLOCKLIST);
        let guard = ProfilerGuardBuilder::default()
            .frequency(profile_sampling)
            .blocklist(blocklist)
            .build()?;

        // Async sleep for the sampling window. `try_into` turns a negative
        // `seconds` query value into an error response.
        Timer::after(Duration::from_secs(profile_seconds.try_into()?)).await;

        let profile = guard.report().build()?.pprof()?;

        // Serialize the protobuf, then gzip it — pprof tooling accepts
        // gzip-compressed profiles.
        let mut content = Vec::new();
        profile.encode(&mut content)?;

        let mut gz = GzEncoder::new(Vec::new(), flate2::Compression::default());
        gz.write_all(&content)?;
        let compressed = gz.finish()?;

        Ok(Response::new(Full::new(Bytes::from(compressed))))
    }

    /// Fallback when CPU profiling is compiled out: plain 404.
    #[cfg(any(not(feature = "pprof_cpu"), target_env = "msvc"))]
    async fn cpu_profile(&self, _: Request<Incoming>) -> Result<Response<Full<Bytes>>> {
        not_found()
    }

    /// Serve `/debug/pprof/allocs` and `/debug/pprof/heap`: dump the
    /// jemalloc heap profile in pprof format.
    ///
    /// Errors if jemalloc profiling was not activated at startup.
    #[cfg(all(feature = "pprof_heap", not(target_env = "msvc")))]
    async fn memory_profile(&self) -> Result<Response<Full<Bytes>>> {
        let prof_ctl = jemalloc_pprof::PROF_CTL.as_ref();

        match prof_ctl {
            None => Err(anyhow::anyhow!("heap profiling not activated")),
            Some(prof_ctl) => {
                // try_lock so a concurrent dump errors out instead of
                // blocking the executor — NOTE(review): presumably
                // intentional; confirm a busy error is acceptable here.
                let mut prof_ctl = prof_ctl.try_lock()?;

                if !prof_ctl.activated() {
                    return Err(anyhow::anyhow!("heap profiling not activated"));
                }

                let pprof = prof_ctl.dump_pprof()?;

                Ok(Response::new(Full::new(Bytes::from(pprof))))
            }
        }
    }

    /// Fallback when heap profiling is compiled out: plain 404.
    #[cfg(any(not(feature = "pprof_heap"), target_env = "msvc"))]
    async fn memory_profile(&self) -> Result<Response<Full<Bytes>>> {
        not_found()
    }
}

157#[allow(dead_code)]
158fn get_params<'a>(uri: &'a Uri) -> HashMap<Cow<'a, str>, Cow<'a, str>> {
159    let params: HashMap<Cow<'_, str>, Cow<'_, str>> = uri
160        .query()
161        .map(|v| form_urlencoded::parse(v.as_bytes()).collect())
162        .unwrap_or_default();
163
164    params
165}
166
/// Look up `name` in the query-parameter map and parse it as an `i32`.
/// Missing or unparsable values yield `default`.
#[allow(dead_code)]
fn parse_i32_params<'a>(
    params: &'a HashMap<Cow<'a, str>, Cow<'a, str>>,
    name: &str,
    default: i32,
) -> i32 {
    match params.get(name) {
        Some(raw) => raw.parse::<i32>().unwrap_or(default),
        None => default,
    }
}

179fn not_found() -> Result<Response<Full<Bytes>>> {
180    Ok(Response::builder()
181        .status(StatusCode::NOT_FOUND)
182        .body(Full::new(Bytes::from(NOT_FOUND)))
183        .unwrap_or_default())
184}
185
186/// Listens for incoming connections and serves them under pprof HTTP API.
187pub async fn serve<'a>(bind_address: SocketAddr, config: Config<'a>) -> Result<()> {
188    let listener = Async::<TcpListener>::bind(bind_address)?;
189    let (s, r) = bounded::<Task>(MAX_CONCURRENT_REQUESTS);
190    let config = Arc::new(config);
191    let ex = Arc::new(Executor::new());
192
193    ex.spawn({
194        let ex = ex.clone();
195        async move {
196            loop {
197                if let Ok(task) = r.recv().await {
198                    ex.spawn(async {
199                        task.handle_client().await.unwrap_or_default();
200                    })
201                    .detach();
202                }
203            }
204        }
205    })
206    .detach();
207
208    ex.run({
209        async move {
210            // stack max MAX_CONCURRENT_REQUESTS requests
211            // if we cannot add more tasks, drop the connection
212            // we don't need a multi threaded server to serve pprof server, but don't want it to be a source of DOS.
213            loop {
214                let listener = listener.accept().await;
215                if let Ok((client, _)) = listener {
216                    let task = Task {
217                        client,
218                        config: config.clone(),
219                    };
220
221                    // we ignore the potential error as it would mean we should drop the connection if channel is full.
222                    let _ = s.try_send(task);
223                }
224            }
225        }
226    })
227    .await;
228
229    Ok(())
230}