1use std::fmt::Debug;
2use std::num::NonZeroUsize;
3use std::time::Duration;
4
5use tokio::{
6 sync::{mpsc::UnboundedReceiver, oneshot},
7 time::interval,
8};
9
10use crate::{CloudWatchClient, client::NoopClient, dispatch::LogEvent, guard::ShutdownSignal};
11
/// Configuration for how buffered log events are exported to CloudWatch.
#[derive(Debug, Clone)]
pub struct ExportConfig {
    // Number of queued events that triggers an immediate flush.
    batch_size: NonZeroUsize,
    // Period of the timer that flushes pending events between full batches.
    interval: Duration,
    // CloudWatch log group/stream that batches are written to.
    destination: LogDestination,
}
22
/// Identifies the CloudWatch log group and log stream that events are sent to.
#[derive(Debug, Clone, Default)]
pub struct LogDestination {
    /// Name of the CloudWatch log group.
    pub log_group_name: String,
    /// Name of the log stream within the group.
    pub log_stream_name: String,
}
31
32impl Default for ExportConfig {
33 fn default() -> Self {
34 Self {
35 batch_size: NonZeroUsize::new(5).unwrap(),
36 interval: Duration::from_secs(5),
37 destination: LogDestination::default(),
38 }
39 }
40}
41
42impl ExportConfig {
43 pub fn with_batch_size<T>(self, batch_size: T) -> Self
45 where
46 T: TryInto<NonZeroUsize>,
47 <T as TryInto<NonZeroUsize>>::Error: Debug,
48 {
49 Self {
50 batch_size: batch_size
51 .try_into()
52 .expect("batch size must be greater than or equal to 1"),
53 ..self
54 }
55 }
56
57 pub fn with_interval(self, interval: Duration) -> Self {
59 Self { interval, ..self }
60 }
61
62 pub fn with_log_group_name(self, log_group_name: impl Into<String>) -> Self {
64 Self {
65 destination: LogDestination {
66 log_group_name: log_group_name.into(),
67 log_stream_name: self.destination.log_stream_name,
68 },
69 ..self
70 }
71 }
72
73 pub fn with_log_stream_name(self, log_stream_name: impl Into<String>) -> Self {
75 Self {
76 destination: LogDestination {
77 log_stream_name: log_stream_name.into(),
78 log_group_name: self.destination.log_group_name,
79 },
80 ..self
81 }
82 }
83}
84
/// Background worker that buffers log events and ships them to CloudWatch in
/// batches through the wrapped client.
pub(crate) struct BatchExporter<C> {
    // Client used to deliver batches (the AWS SDK, or a test double).
    client: C,
    // Events accumulated since the last flush.
    queue: Vec<LogEvent>,
    // Batch size, flush interval, and destination settings.
    config: ExportConfig,
}
90
impl Default for BatchExporter<NoopClient> {
    /// Exporter that discards every batch, using the default configuration.
    fn default() -> Self {
        Self::new(NoopClient::new(), ExportConfig::default())
    }
}
96
impl<C> BatchExporter<C> {
    /// Creates an exporter with an empty queue for the given client and
    /// configuration.
    pub(crate) fn new(client: C, config: ExportConfig) -> Self {
        Self {
            client,
            config,
            queue: Vec::new(),
        }
    }
}
106
107impl<C> BatchExporter<C>
108where
109 C: CloudWatchClient + Send + Sync + 'static,
110{
111 pub(crate) async fn run(
112 mut self,
113 mut rx: UnboundedReceiver<LogEvent>,
114 mut shutdown_rx: oneshot::Receiver<ShutdownSignal>,
115 ) {
116 let mut interval = interval(self.config.interval);
117 let mut shutdown_signal = None;
118
119 loop {
120 tokio::select! {
121 _ = interval.tick() => {
122 if self.queue.is_empty() {
123 continue;
124 }
125 }
126
127 event = rx.recv() => {
128 let Some(event) = event else {
129 break;
130 };
131
132 self.queue.push(event);
133 if self.queue.len() < <NonZeroUsize as Into<usize>>::into(self.config.batch_size) {
134 continue
135 }
136 }
137
138 received_shutdown = &mut shutdown_rx => {
139 if let Ok(signal) = received_shutdown {
140 shutdown_signal = Some(signal);
141 }
142 while let Ok(event) = rx.try_recv() {
143 self.queue.push(event);
144 }
145 break;
146 }
147 }
148 self.flush().await;
149 }
150 self.flush().await;
151 if let Some(shutdown_signal) = shutdown_signal {
152 shutdown_signal.ack();
153 }
154 }
155
156 async fn flush(&mut self) {
157 let logs: Vec<LogEvent> = Self::take_from_queue(&mut self.queue);
158
159 if logs.is_empty() {
160 return;
161 }
162
163 if let Err(err) = self
164 .client
165 .put_logs(self.config.destination.clone(), logs)
166 .await
167 {
168 eprintln!(
169 "[tracing-cloudwatch] Unable to put logs to cloudwatch. Error: {err:?} {:?}",
170 self.config.destination
171 );
172 }
173 }
174
175 fn take_from_queue(queue: &mut Vec<LogEvent>) -> Vec<LogEvent> {
176 if cfg!(feature = "ordered_logs") {
177 let mut logs = std::mem::take(queue);
178 logs.sort_by_key(|log| log.timestamp);
179 logs
180 } else {
181 std::mem::take(queue)
182 }
183 }
184}
185
186#[cfg(test)]
187mod tests {
188 use super::*;
189
    mod helper {
        use super::*;
        use async_trait::async_trait;
        use std::sync::{Arc, Mutex};
        use tokio::time::{sleep, timeout};

        /// Test double that records every exported event in memory instead of
        /// calling AWS.
        #[derive(Clone, Default)]
        pub(super) struct RecordingClient {
            logs: Arc<Mutex<Vec<LogEvent>>>,
        }

        #[async_trait]
        impl CloudWatchClient for RecordingClient {
            // Always succeeds; appends the batch to the shared in-memory log.
            async fn put_logs(
                &self,
                _dest: LogDestination,
                logs: Vec<LogEvent>,
            ) -> Result<(), crate::client::PutLogsError> {
                self.logs.lock().unwrap().extend(logs);
                Ok(())
            }
        }

        impl RecordingClient {
            /// Total number of events exported so far.
            pub(super) fn exported_count(&self) -> usize {
                self.logs.lock().unwrap().len()
            }

            /// Messages of all exported events, in export order.
            pub(super) fn exported_messages(&self) -> Vec<String> {
                self.logs
                    .lock()
                    .unwrap()
                    .iter()
                    .map(|event| event.message.clone())
                    .collect()
            }
        }

        /// Polls until at least `expected` events have been exported, failing
        /// the test after a one-second timeout.
        pub(super) async fn wait_for_exported_count(client: &RecordingClient, expected: usize) {
            timeout(Duration::from_secs(1), async {
                loop {
                    if client.exported_count() >= expected {
                        break;
                    }
                    sleep(Duration::from_millis(10)).await;
                }
            })
            .await
            .expect("timed out waiting for exported log events");
        }
    }
241
242 mod ordering {
243 use super::*;
244 use chrono::{DateTime, Utc};
245
246 const ONE_DAY_NS: i64 = 86_400_000_000_000;
247 const DAY_ONE: DateTime<Utc> = DateTime::from_timestamp_nanos(0 + ONE_DAY_NS);
248 const DAY_TWO: DateTime<Utc> = DateTime::from_timestamp_nanos(0 + (ONE_DAY_NS * 2));
249 const DAY_THREE: DateTime<Utc> = DateTime::from_timestamp_nanos(0 + (ONE_DAY_NS * 3));
250
251 fn unordered_queue() -> Vec<LogEvent> {
252 vec![
253 LogEvent {
254 message: "1".to_string(),
255 timestamp: DAY_ONE,
256 },
257 LogEvent {
258 message: "3".to_string(),
259 timestamp: DAY_THREE,
260 },
261 LogEvent {
262 message: "2".to_string(),
263 timestamp: DAY_TWO,
264 },
265 ]
266 }
267
268 #[cfg(feature = "ordered_logs")]
269 fn assert_is_ordered(logs: Vec<LogEvent>) {
270 let mut last_timestamp = DateTime::from_timestamp_nanos(0);
271
272 for log in logs {
273 assert!(
274 log.timestamp > last_timestamp,
275 "Not true: {} > {}",
276 log.timestamp,
277 last_timestamp
278 );
279 last_timestamp = log.timestamp;
280 }
281 }
282
283 #[cfg(not(feature = "ordered_logs"))]
284 #[test]
285 fn does_not_order_logs_by_default() {
286 let mut unordered_queue = unordered_queue();
287 let still_unordered_queue =
288 BatchExporter::<NoopClient>::take_from_queue(&mut unordered_queue);
289
290 let mut still_unordered_queue_iter = still_unordered_queue.iter();
291 assert_eq!(
292 DAY_ONE,
293 still_unordered_queue_iter.next().unwrap().timestamp
294 );
295 assert_eq!(
296 DAY_THREE,
297 still_unordered_queue_iter.next().unwrap().timestamp
298 );
299 assert_eq!(
300 DAY_TWO,
301 still_unordered_queue_iter.next().unwrap().timestamp
302 );
303 }
304
305 #[cfg(feature = "ordered_logs")]
306 #[test]
307 fn orders_logs_when_enabled() {
308 let mut unordered_queue = unordered_queue();
309 let ordered_queue = BatchExporter::<NoopClient>::take_from_queue(&mut unordered_queue);
310 assert_is_ordered(ordered_queue);
311 }
312 }
313
    mod integration {
        use super::helper::{RecordingClient, wait_for_exported_count};
        use super::*;
        use chrono::Utc;
        use tokio::time::sleep;
        use tracing_subscriber::layer::SubscriberExt;

        // Queue many events and shut down immediately: the exporter must
        // drain the channel and export everything before it finishes.
        #[tokio::test(flavor = "current_thread")]
        async fn drains_all_buffered_events_on_shutdown() {
            let client = RecordingClient::default();
            // Huge batch size and long interval so no flush can fire before
            // the shutdown path runs — the final drain alone must export.
            let exporter = BatchExporter::new(
                client.clone(),
                ExportConfig::default()
                    .with_batch_size(10_000)
                    .with_interval(Duration::from_secs(60))
                    .with_log_group_name("group")
                    .with_log_stream_name("stream"),
            );

            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
            let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<ShutdownSignal>();
            let (shutdown_signal, _ack_rx) = ShutdownSignal::new();

            let total = 512;
            for idx in 0..total {
                tx.send(LogEvent {
                    message: format!("event-{idx}"),
                    timestamp: Utc::now(),
                })
                .unwrap();
            }
            drop(tx);
            shutdown_tx.send(shutdown_signal).unwrap();

            // Run the exporter to completion on this task (no spawn needed).
            exporter.run(rx, shutdown_rx).await;

            assert_eq!(
                client.exported_count(),
                total,
                "all events queued before shutdown should be exported"
            );
        }

        // End-to-end through the tracing layer: events recorded before the
        // guard shuts down must reach the client.
        #[tokio::test(flavor = "current_thread")]
        async fn exports_events_with_registry_on_guard_shutdown() {
            let client = RecordingClient::default();
            let (cw_layer, guard) = crate::layer()
                .with_code_location(false)
                .with_target(false)
                .with_client(
                    client.clone(),
                    ExportConfig::default()
                        .with_batch_size(1024)
                        .with_interval(Duration::from_secs(60))
                        .with_log_group_name("group")
                        .with_log_stream_name("stream"),
                );

            let subscriber = tracing_subscriber::registry().with(cw_layer);
            tracing::subscriber::with_default(subscriber, || {
                tracing::info!("integration-log-1");
                tracing::warn!("integration-log-2");
            });

            guard.shutdown().await;

            let messages = client.exported_messages();
            assert_eq!(messages.len(), 2);
            assert!(
                messages
                    .iter()
                    .any(|message| message.contains("integration-log-1"))
            );
            assert!(
                messages
                    .iter()
                    .any(|message| message.contains("integration-log-2"))
            );
        }

        // Reaching the configured batch size must trigger an export on its
        // own, without waiting for the interval or the shutdown drain.
        #[tokio::test(flavor = "current_thread")]
        async fn exports_when_batch_size_is_reached() {
            let client = RecordingClient::default();
            let (cw_layer, guard) = crate::layer()
                .with_code_location(false)
                .with_target(false)
                .with_client(
                    client.clone(),
                    ExportConfig::default()
                        .with_batch_size(2)
                        .with_interval(Duration::from_secs(60))
                        .with_log_group_name("group")
                        .with_log_stream_name("stream"),
                );

            let subscriber = tracing_subscriber::registry().with(cw_layer);
            // NOTE(review): presumably gives the background exporter task a
            // moment to start before events are emitted — confirm necessity.
            sleep(Duration::from_millis(20)).await;

            tracing::subscriber::with_default(subscriber, || {
                tracing::info!("batch-log-1");
                tracing::info!("batch-log-2");
            });

            // Both events must be exported before any shutdown happens.
            wait_for_exported_count(&client, 2).await;
            guard.shutdown().await;
        }

        // A partial batch must still be exported when the interval elapses,
        // even though neither the batch size nor shutdown forced a flush.
        #[tokio::test(flavor = "current_thread")]
        async fn exports_without_shutdown_when_batch_not_full() {
            let client = RecordingClient::default();
            let (cw_layer, guard) = crate::layer()
                .with_code_location(false)
                .with_target(false)
                .with_client(
                    client.clone(),
                    ExportConfig::default()
                        .with_batch_size(1024)
                        .with_interval(Duration::from_millis(200))
                        .with_log_group_name("group")
                        .with_log_stream_name("stream"),
                );

            let subscriber = tracing_subscriber::registry().with(cw_layer);
            // NOTE(review): presumably gives the background exporter task a
            // moment to start before events are emitted — confirm necessity.
            sleep(Duration::from_millis(20)).await;

            tracing::subscriber::with_default(subscriber, || {
                tracing::info!("interval-log-1");
            });

            // The 200 ms interval tick should export the lone event.
            wait_for_exported_count(&client, 1).await;
            guard.shutdown().await;
        }
    }
449
    mod recursion {
        use super::*;
        use async_trait::async_trait;
        use std::sync::{
            Arc, Mutex,
            atomic::{AtomicUsize, Ordering},
        };
        use tokio::time::sleep;
        use tokio::time::timeout;
        use tracing_subscriber::layer::SubscriberExt;

        // Safety bound: the client stops emitting its own tracing events
        // after this many `put_logs` calls, so a broken isolation guard
        // cannot loop forever and hang the test.
        const RECURSIVE_EMIT_LIMIT: usize = 5;

        /// Test double that, like the real AWS SDK, emits tracing events of
        /// its own from inside `put_logs` — used to prove those internal
        /// events are not fed back into the exporter.
        #[derive(Clone, Default)]
        struct InternalTracingClient {
            logs: Arc<Mutex<Vec<LogEvent>>>,
            put_calls: Arc<AtomicUsize>,
        }

        #[async_trait]
        impl CloudWatchClient for InternalTracingClient {
            async fn put_logs(
                &self,
                _dest: LogDestination,
                logs: Vec<LogEvent>,
            ) -> Result<(), crate::client::PutLogsError> {
                let call = self.put_calls.fetch_add(1, Ordering::Relaxed) + 1;
                self.logs.lock().unwrap().extend(logs);

                // Simulate SDK-internal logging; bounded so a feedback loop
                // terminates instead of recursing indefinitely.
                if call < RECURSIVE_EMIT_LIMIT {
                    tracing::error!("simulated-sdk-internal-error-{call}");
                }

                Ok(())
            }
        }

        impl InternalTracingClient {
            // Number of `put_logs` invocations observed so far.
            fn put_call_count(&self) -> usize {
                self.put_calls.load(Ordering::Relaxed)
            }

            // Messages of all exported events, in export order.
            fn exported_messages(&self) -> Vec<String> {
                self.logs
                    .lock()
                    .unwrap()
                    .iter()
                    .map(|event| event.message.clone())
                    .collect()
            }
        }

        /// Polls until `put_logs` has been called at least `expected` times,
        /// failing the test after a one-second timeout.
        async fn wait_for_put_calls(client: &InternalTracingClient, expected: usize) {
            timeout(Duration::from_secs(1), async {
                loop {
                    if client.put_call_count() >= expected {
                        break;
                    }
                    sleep(Duration::from_millis(10)).await;
                }
            })
            .await
            .expect("timed out waiting for put_logs calls");
        }

        // With batch size 1, the application event triggers an immediate
        // export; the client's internal tracing during that export must NOT
        // re-enter the layer and cause further `put_logs` calls.
        #[tokio::test(flavor = "current_thread")]
        async fn does_not_recurse_when_client_emits_internal_traces() {
            let client = InternalTracingClient::default();
            let (cw_layer, guard) = crate::layer()
                .with_code_location(false)
                .with_target(false)
                .with_client(
                    client.clone(),
                    ExportConfig::default()
                        .with_batch_size(1)
                        .with_interval(Duration::from_secs(60))
                        .with_log_group_name("group")
                        .with_log_stream_name("stream"),
                );

            let subscriber = tracing_subscriber::registry().with(cw_layer);
            let _default = tracing::subscriber::set_default(subscriber);

            tracing::info!("application-log-1");

            wait_for_put_calls(&client, 1).await;
            // Grace period: any recursive feedback would surface as extra
            // `put_logs` calls within this window.
            sleep(Duration::from_millis(100)).await;

            assert_eq!(
                1,
                client.put_call_count(),
                "with subscriber isolation enabled, internal tracing must not recurse"
            );

            guard.shutdown().await;

            let messages = client.exported_messages();
            assert_eq!(messages.len(), 1);
            assert!(
                messages
                    .iter()
                    .any(|message| message.contains("application-log-1"))
            );
            assert!(
                messages
                    .iter()
                    .all(|message| !message.contains("simulated-sdk-internal-error")),
                "internal tracing output must not be exported back into CloudWatch input"
            );
        }
    }
562}