1#![warn(clippy::pedantic)]
2#![warn(clippy::cargo)]
3#![allow(clippy::multiple_crate_versions)]
4
5#[macro_use]
45extern crate log;
46
47use std::borrow::Cow;
48use std::collections::BTreeMap;
49use std::sync::atomic::{AtomicBool, Ordering};
50use std::sync::{Arc, Mutex};
51
52use rocket::fairing::{Fairing, Info, Kind};
53use rocket::http::Status;
54use rocket::request::local_cache_once;
55use rocket::serde::Deserialize;
56use rocket::{fairing, Build, Data, Request, Response, Rocket};
57use sentry::protocol::SpanStatus;
58use sentry::{protocol, ClientInitGuard, ClientOptions, TracesSampler, Transaction};
59
/// Operation name given to every transaction context created by
/// `RocketSentry::start_transaction`.
const TRANSACTION_OPERATION_NAME: &str = "http.server";
61
62pub struct RocketSentry {
63 guard: Mutex<Option<ClientInitGuard>>,
64 transactions_enabled: AtomicBool,
65 traces_sampler: Option<Arc<TracesSampler>>,
66}
67
68#[derive(Deserialize)]
69struct Config {
70 sentry_dsn: String,
71 sentry_traces_sample_rate: Option<f32>, }
73
74impl RocketSentry {
75 #[must_use]
76 pub fn fairing() -> impl Fairing {
77 RocketSentry::builder().build()
78 }
79
80 #[must_use]
81 pub fn builder() -> RocketSentryBuilder {
82 RocketSentryBuilder::new()
83 }
84
85 fn init(&self, dsn: &str, traces_sample_rate: f32, environment: Cow<'static, str>) {
86 let guard = sentry::init((
87 dsn,
88 ClientOptions {
89 before_send: Some(Arc::new(|event| {
90 info!("Sending event to Sentry: {}", event.event_id);
91 Some(event)
92 })),
93 traces_sample_rate,
94 traces_sampler: self.traces_sampler.clone(),
95 environment: Some(environment),
96 ..Default::default()
97 },
98 ));
99
100 if guard.is_enabled() {
101 let mut self_guard = self.guard.lock().unwrap();
103 *self_guard = Some(guard);
104
105 info!("Sentry enabled.");
106 if traces_sample_rate > 0f32 || self.traces_sampler.is_some() {
107 self.transactions_enabled.store(true, Ordering::Relaxed);
108 }
109 } else {
110 error!("Sentry did not initialize.");
111 }
112 }
113
114 fn start_transaction(name: &str) -> Transaction {
115 let transaction_context = sentry::TransactionContext::new(name, TRANSACTION_OPERATION_NAME);
116 sentry::start_transaction(transaction_context)
117 }
118
119 fn invalid_transaction() -> Transaction {
122 let name = "INVALID TRANSACTION";
123 Self::start_transaction(name)
124 }
125}
126
127#[rocket::async_trait]
128impl Fairing for RocketSentry {
129 fn info(&self) -> Info {
130 Info {
131 name: "rocket-sentry",
132 kind: Kind::Ignite | Kind::Singleton | Kind::Request | Kind::Response,
133 }
134 }
135
136 async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
137 let figment = rocket.figment();
138 let profile_name = figment.profile().to_string();
139
140 let environment = match profile_name.as_str() {
142 "debug" => Cow::Borrowed("development"),
143 "release" => Cow::Borrowed("production"),
144 _ => Cow::Owned(profile_name),
145 };
146
147 let config: figment::error::Result<Config> = figment.extract();
148 match config {
149 Ok(config) => {
150 if config.sentry_dsn.is_empty() {
151 info!("Sentry disabled.");
152 } else {
153 let traces_sample_rate = config.sentry_traces_sample_rate.unwrap_or(0f32);
154 self.init(&config.sentry_dsn, traces_sample_rate, environment);
155 }
156 }
157 Err(err) => error!("Sentry not configured: {err}"),
158 }
159 Ok(rocket)
160 }
161
162 async fn on_request(&self, request: &mut Request<'_>, _: &mut Data<'_>) {
163 if self.transactions_enabled.load(Ordering::Relaxed) {
164 let name = request_to_transaction_name(request);
165 let build_transaction = move || Self::start_transaction(&name);
166 let request_transaction = local_cache_once!(request, build_transaction);
167 request.local_cache(request_transaction);
168 }
169 }
170
171 async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
172 if self.transactions_enabled.load(Ordering::Relaxed) {
173 let request_transaction = local_cache_once!(request, Self::invalid_transaction);
175 let ongoing_transaction: &Transaction = request.local_cache(request_transaction);
176 ongoing_transaction.set_status(map_status(response.status()));
177 set_transaction_request(ongoing_transaction, request);
178 ongoing_transaction.clone().finish();
179 }
180 }
181}
182
183fn set_transaction_request(transaction: &Transaction, request: &Request) {
184 transaction.set_request(protocol::Request {
185 url: None,
186 method: Some(request.method().to_string()),
187 data: None,
188 query_string: request_to_query_string(request),
189 cookies: None,
190 headers: request_to_header_map(request),
191 env: BTreeMap::new(),
192 });
193}
194
195fn request_to_transaction_name(request: &Request) -> String {
196 let method = request.method();
197 let path = request.uri().path();
198 format!("{method} {path}")
199}
200
201fn request_to_query_string(request: &Request) -> Option<String> {
202 Some(request.uri().query()?.to_string())
203}
204
205fn map_status(status: Status) -> SpanStatus {
206 #[allow(clippy::match_same_arms)]
207 match status.code {
208 100..=299 => SpanStatus::Ok,
209 300..=399 => SpanStatus::Ok,
212 401 => SpanStatus::Unauthenticated,
213 403 => SpanStatus::PermissionDenied,
214 404 => SpanStatus::NotFound,
215 409 => SpanStatus::AlreadyExists,
216 429 => SpanStatus::ResourceExhausted,
217 400..=499 => SpanStatus::InvalidArgument,
218 501 => SpanStatus::Unimplemented,
219 503 => SpanStatus::Unavailable,
220 500..=599 => SpanStatus::InternalError,
221 _ => SpanStatus::UnknownError,
222 }
223}
224
225fn request_to_header_map(request: &Request) -> BTreeMap<String, String> {
226 request
227 .headers()
228 .iter()
229 .map(|header| (header.name().to_string(), header.value().to_string()))
230 .collect()
231}
232
233pub struct RocketSentryBuilder {
234 traces_sampler: Option<Arc<TracesSampler>>,
235}
236
237impl RocketSentryBuilder {
238 #[must_use]
239 fn new() -> RocketSentryBuilder {
240 RocketSentryBuilder {
241 traces_sampler: None,
242 }
243 }
244
245 #[must_use]
246 pub fn traces_sampler(mut self, traces_sampler: Arc<TracesSampler>) -> RocketSentryBuilder {
247 self.traces_sampler = Some(traces_sampler);
248 self
249 }
250
251 #[must_use]
252 pub fn build(self) -> RocketSentry {
253 RocketSentry {
254 guard: Mutex::new(None),
255 transactions_enabled: AtomicBool::new(false),
256 traces_sampler: self.traces_sampler,
257 }
258 }
259}
260
#[cfg(test)]
mod tests {
    use rocket::http::ContentType;
    use rocket::http::Header;
    use rocket::local::asynchronous::Client;
    use sentry::TransactionContext;
    use std::borrow::Cow;
    use std::sync::atomic::Ordering;
    use std::sync::Arc;

    use crate::{
        request_to_header_map, request_to_query_string, request_to_transaction_name, RocketSentry,
    };

    /// Environment name passed to `init` in the tests below.
    const DEFAULT_ENV: Cow<'static, str> = Cow::Borrowed("TEST");

    /// DSN used to initialize the client in the `transactions_*` tests.
    const TEST_DSN: &str = "https://user@some.dsn/123";

    /// Builds a tracked local client around a default Rocket instance.
    async fn test_client() -> Client {
        Client::tracked(rocket::build())
            .await
            .expect("a default Rocket instance should launch locally")
    }

    #[rocket::async_test]
    async fn request_to_sentry_transaction_name_get_no_path() {
        let client = test_client().await;
        let request = client.get("/");

        let transaction_name = request_to_transaction_name(request.inner());

        assert_eq!(transaction_name, "GET /");
    }

    #[rocket::async_test]
    async fn request_to_sentry_transaction_name_get_some_path() {
        let client = test_client().await;
        let request = client.get("/some/path");

        let transaction_name = request_to_transaction_name(request.inner());

        assert_eq!(transaction_name, "GET /some/path");
    }

    #[rocket::async_test]
    async fn request_to_sentry_transaction_name_post_path_with_variables() {
        let client = test_client().await;
        let request = client.post("/users/6");

        let transaction_name = request_to_transaction_name(request.inner());

        assert_eq!(transaction_name, "POST /users/6");
    }

    #[rocket::async_test]
    async fn request_to_query_string_is_none() {
        let client = test_client().await;
        let request = client.post("/");

        let query_string = request_to_query_string(request.inner());

        assert_eq!(query_string, None);
    }

    #[rocket::async_test]
    async fn request_to_query_string_single_parameter() {
        let client = test_client().await;
        let request = client.post("/?param1=value1");

        let query_string = request_to_query_string(request.inner());

        assert_eq!(query_string, Some("param1=value1".to_string()));
    }

    #[rocket::async_test]
    async fn request_to_query_string_multiple_parameters() {
        let client = test_client().await;
        // Two parameters separated by `&`.
        let request = client.post("/?param1=value1&param2=value2");

        let query_string = request_to_query_string(request.inner());

        assert_eq!(
            query_string,
            Some("param1=value1&param2=value2".to_string())
        );
    }

    #[rocket::async_test]
    async fn request_to_header_map_is_empty() {
        let client = test_client().await;
        let request = client.get("/");

        let header_map = request_to_header_map(request.inner());

        assert!(header_map.is_empty());
    }

    #[rocket::async_test]
    async fn request_to_header_map_multiple() {
        let client = test_client().await;
        let request = client
            .get("/")
            .header(ContentType::JSON)
            .header(Header::new("custom-key", "custom-value"));

        let header_map = request_to_header_map(request.inner());

        assert_eq!(
            header_map.get("custom-key"),
            Some(&"custom-value".to_string())
        );
        assert_eq!(
            header_map.get("Content-Type"),
            Some(&"application/json".to_string())
        );
    }

    /// A zero sample rate and no sampler must leave transactions disabled.
    #[rocket::async_test]
    async fn transactions_not_enabled() {
        let rocket_sentry = RocketSentry::builder().build();

        rocket_sentry.init(TEST_DSN, 0., DEFAULT_ENV);

        assert!(!rocket_sentry.transactions_enabled.load(Ordering::Relaxed));
    }

    /// Any positive sample rate enables transactions.
    #[rocket::async_test]
    async fn transactions_enabled_by_traces_sample_rate() {
        let rocket_sentry = RocketSentry::builder().build();

        rocket_sentry.init(TEST_DSN, 0.01, DEFAULT_ENV);

        assert!(rocket_sentry.transactions_enabled.load(Ordering::Relaxed));
    }

    /// A custom sampler enables transactions even with a zero sample rate.
    #[rocket::async_test]
    async fn transactions_enabled_by_traces_sampler() {
        let rocket_sentry = RocketSentry::builder()
            .traces_sampler(Arc::new(move |_: &TransactionContext| -> f32 { 0. }))
            .build();

        rocket_sentry.init(TEST_DSN, 0., DEFAULT_ENV);

        assert!(rocket_sentry.transactions_enabled.load(Ordering::Relaxed));
    }
}