1use aws_sdk_cloudwatch::Client as CloudWatchClient;
6use aws_sdk_cloudwatch::types::{Dimension, MetricDatum, StandardUnit};
7use aws_sdk_cloudwatchlogs::Client as CloudWatchLogsClient;
8use aws_sdk_cloudwatchlogs::types::InputLogEvent;
9use llm_shield_cloud::{
10 async_trait, CloudError, CloudLogger, CloudMetrics, LogEntry, LogLevel, Metric, Result,
11};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
/// Exports application metrics to Amazon CloudWatch, buffering them for
/// batched delivery via `PutMetricData`.
pub struct CloudWatchMetrics {
    /// AWS SDK CloudWatch client used for all API calls.
    client: CloudWatchClient,
    /// CloudWatch namespace every metric is published under.
    namespace: String,
    /// AWS region the client was configured for (informational).
    region: String,
    /// Pending metrics awaiting a batched `PutMetricData` call.
    batch_buffer: Arc<RwLock<Vec<Metric>>>,
    /// Number of buffered metrics that triggers an automatic flush.
    batch_size: usize,
}
56
57impl CloudWatchMetrics {
58 pub async fn new(namespace: impl Into<String>) -> Result<Self> {
68 let config = aws_config::load_from_env().await;
69 let region = config
70 .region()
71 .map(|r| r.to_string())
72 .unwrap_or_else(|| "us-east-1".to_string());
73
74 let client = CloudWatchClient::new(&config);
75 let namespace = namespace.into();
76
77 tracing::info!(
78 "Initialized CloudWatch Metrics client for namespace: {} in region: {}",
79 namespace,
80 region
81 );
82
83 Ok(Self {
84 client,
85 namespace,
86 region,
87 batch_buffer: Arc::new(RwLock::new(Vec::new())),
88 batch_size: 20, })
90 }
91
92 pub async fn new_with_config(
104 namespace: impl Into<String>,
105 region: impl Into<String>,
106 batch_size: usize,
107 ) -> Result<Self> {
108 let region_str = region.into();
109 let config = aws_config::from_env()
110 .region(aws_config::Region::new(region_str.clone()))
111 .load()
112 .await;
113
114 let client = CloudWatchClient::new(&config);
115 let namespace = namespace.into();
116
117 tracing::info!(
118 "Initialized CloudWatch Metrics client for namespace: {} in region: {} (batch size: {})",
119 namespace,
120 region_str,
121 batch_size
122 );
123
124 Ok(Self {
125 client,
126 namespace,
127 region: region_str,
128 batch_buffer: Arc::new(RwLock::new(Vec::new())),
129 batch_size: batch_size.min(1000), })
131 }
132
133 pub fn namespace(&self) -> &str {
135 &self.namespace
136 }
137
138 pub fn region(&self) -> &str {
140 &self.region
141 }
142
143 pub async fn flush(&self) -> Result<()> {
145 let mut buffer = self.batch_buffer.write().await;
146
147 if buffer.is_empty() {
148 return Ok(());
149 }
150
151 let metrics_to_send = buffer.drain(..).collect::<Vec<_>>();
152 drop(buffer); self.send_metrics_batch(&metrics_to_send).await?;
155
156 Ok(())
157 }
158
159 async fn send_metrics_batch(&self, metrics: &[Metric]) -> Result<()> {
161 if metrics.is_empty() {
162 return Ok(());
163 }
164
165 tracing::debug!("Sending {} metrics to CloudWatch", metrics.len());
166
167 let metric_data: Vec<MetricDatum> = metrics
168 .iter()
169 .map(|m| {
170 let mut datum = MetricDatum::builder()
171 .metric_name(&m.name)
172 .value(m.value)
173 .timestamp(aws_sdk_cloudwatch::primitives::DateTime::from_secs(
174 m.timestamp as i64,
175 ));
176
177 for (key, value) in &m.dimensions {
179 datum = datum.dimensions(
180 Dimension::builder()
181 .name(key.clone())
182 .value(value.clone())
183 .build(),
184 );
185 }
186
187 if let Some(ref unit_str) = m.unit {
189 if let Ok(unit) = parse_standard_unit(unit_str) {
190 datum = datum.unit(unit);
191 }
192 }
193
194 datum.build()
195 })
196 .collect();
197
198 self.client
199 .put_metric_data()
200 .namespace(&self.namespace)
201 .set_metric_data(Some(metric_data))
202 .send()
203 .await
204 .map_err(|e| CloudError::MetricsExport(e.to_string()))?;
205
206 tracing::info!("Successfully sent {} metrics to CloudWatch", metrics.len());
207
208 Ok(())
209 }
210}
211
212#[async_trait]
213impl CloudMetrics for CloudWatchMetrics {
214 async fn export_metrics(&self, metrics: &[Metric]) -> Result<()> {
215 tracing::debug!("Exporting {} metrics to CloudWatch", metrics.len());
216
217 for chunk in metrics.chunks(self.batch_size) {
219 self.send_metrics_batch(chunk).await?;
220 }
221
222 Ok(())
223 }
224
225 async fn export_metric(&self, metric: &Metric) -> Result<()> {
226 tracing::debug!("Exporting metric to CloudWatch: {}", metric.name);
227
228 let mut buffer = self.batch_buffer.write().await;
230 buffer.push(metric.clone());
231
232 if buffer.len() >= self.batch_size {
234 let metrics_to_send = buffer.drain(..).collect::<Vec<_>>();
235 drop(buffer); self.send_metrics_batch(&metrics_to_send).await?;
238 }
239
240 Ok(())
241 }
242}
243
/// Ships structured log entries to Amazon CloudWatch Logs, buffering them
/// for batched delivery via `PutLogEvents`.
pub struct CloudWatchLogger {
    /// AWS SDK CloudWatch Logs client used for all API calls.
    client: CloudWatchLogsClient,
    /// Destination log group name.
    log_group: String,
    /// Destination log stream name within the group.
    log_stream: String,
    /// AWS region the client was configured for (informational).
    region: String,
    /// Last sequence token returned by `PutLogEvents`; reused on subsequent
    /// calls when present (legacy API requirement).
    sequence_token: Arc<RwLock<Option<String>>>,
    /// Pending entries awaiting a batched `PutLogEvents` call.
    batch_buffer: Arc<RwLock<Vec<LogEntry>>>,
    /// Number of buffered entries that triggers an automatic flush.
    batch_size: usize,
}
278
279impl CloudWatchLogger {
280 pub async fn new(
291 log_group: impl Into<String>,
292 log_stream: impl Into<String>,
293 ) -> Result<Self> {
294 let config = aws_config::load_from_env().await;
295 let region = config
296 .region()
297 .map(|r| r.to_string())
298 .unwrap_or_else(|| "us-east-1".to_string());
299
300 let client = CloudWatchLogsClient::new(&config);
301 let log_group = log_group.into();
302 let log_stream = log_stream.into();
303
304 let _ = client
306 .create_log_stream()
307 .log_group_name(&log_group)
308 .log_stream_name(&log_stream)
309 .send()
310 .await;
311
312 tracing::info!(
313 "Initialized CloudWatch Logs client for log group: {} stream: {} in region: {}",
314 log_group,
315 log_stream,
316 region
317 );
318
319 Ok(Self {
320 client,
321 log_group,
322 log_stream,
323 region,
324 sequence_token: Arc::new(RwLock::new(None)),
325 batch_buffer: Arc::new(RwLock::new(Vec::new())),
326 batch_size: 100, })
328 }
329
330 pub async fn new_with_config(
343 log_group: impl Into<String>,
344 log_stream: impl Into<String>,
345 region: impl Into<String>,
346 batch_size: usize,
347 ) -> Result<Self> {
348 let region_str = region.into();
349 let config = aws_config::from_env()
350 .region(aws_config::Region::new(region_str.clone()))
351 .load()
352 .await;
353
354 let client = CloudWatchLogsClient::new(&config);
355 let log_group = log_group.into();
356 let log_stream = log_stream.into();
357
358 let _ = client
360 .create_log_stream()
361 .log_group_name(&log_group)
362 .log_stream_name(&log_stream)
363 .send()
364 .await;
365
366 tracing::info!(
367 "Initialized CloudWatch Logs client for log group: {} stream: {} in region: {} (batch size: {})",
368 log_group,
369 log_stream,
370 region_str,
371 batch_size
372 );
373
374 Ok(Self {
375 client,
376 log_group,
377 log_stream,
378 region: region_str,
379 sequence_token: Arc::new(RwLock::new(None)),
380 batch_buffer: Arc::new(RwLock::new(Vec::new())),
381 batch_size,
382 })
383 }
384
385 pub fn log_group(&self) -> &str {
387 &self.log_group
388 }
389
390 pub fn log_stream(&self) -> &str {
392 &self.log_stream
393 }
394
395 pub fn region(&self) -> &str {
397 &self.region
398 }
399
400 pub async fn flush(&self) -> Result<()> {
402 let mut buffer = self.batch_buffer.write().await;
403
404 if buffer.is_empty() {
405 return Ok(());
406 }
407
408 let logs_to_send = buffer.drain(..).collect::<Vec<_>>();
409 drop(buffer); self.send_logs_batch(&logs_to_send).await?;
412
413 Ok(())
414 }
415
416 async fn send_logs_batch(&self, entries: &[LogEntry]) -> Result<()> {
418 if entries.is_empty() {
419 return Ok(());
420 }
421
422 tracing::debug!("Sending {} log entries to CloudWatch Logs", entries.len());
423
424 let mut log_events: Vec<InputLogEvent> = entries
426 .iter()
427 .map(|entry| {
428 let timestamp = entry
429 .timestamp
430 .duration_since(std::time::UNIX_EPOCH)
431 .unwrap_or_default()
432 .as_millis() as i64;
433
434 let mut message = format!("[{}] {}", format_log_level(&entry.level), entry.message);
436
437 if !entry.labels.is_empty() {
438 message.push_str(&format!(" {:?}", entry.labels));
439 }
440
441 if let Some(ref trace_id) = entry.trace_id {
442 message.push_str(&format!(" trace_id={}", trace_id));
443 }
444
445 if let Some(ref span_id) = entry.span_id {
446 message.push_str(&format!(" span_id={}", span_id));
447 }
448
449 InputLogEvent::builder()
450 .timestamp(timestamp)
451 .message(message)
452 .build()
453 .expect("Failed to build InputLogEvent")
454 })
455 .collect();
456
457 log_events.sort_by_key(|e| e.timestamp);
459
460 let sequence_token = self.sequence_token.read().await.clone();
462
463 let mut request = self
465 .client
466 .put_log_events()
467 .log_group_name(&self.log_group)
468 .log_stream_name(&self.log_stream)
469 .set_log_events(Some(log_events));
470
471 if let Some(token) = sequence_token {
472 request = request.sequence_token(token);
473 }
474
475 let response = request
476 .send()
477 .await
478 .map_err(|e| CloudError::LogExport(e.to_string()))?;
479
480 if let Some(next_token) = response.next_sequence_token {
482 *self.sequence_token.write().await = Some(next_token);
483 }
484
485 tracing::info!(
486 "Successfully sent {} log entries to CloudWatch Logs",
487 entries.len()
488 );
489
490 Ok(())
491 }
492}
493
494#[async_trait]
495impl CloudLogger for CloudWatchLogger {
496 async fn log(&self, message: &str, level: LogLevel) -> Result<()> {
497 let entry = LogEntry {
498 timestamp: std::time::SystemTime::now(),
499 level,
500 message: message.to_string(),
501 labels: HashMap::new(),
502 trace_id: None,
503 span_id: None,
504 source: None,
505 };
506
507 self.log_structured(&entry).await
508 }
509
510 async fn log_structured(&self, entry: &LogEntry) -> Result<()> {
511 tracing::debug!("Logging structured entry to CloudWatch Logs");
512
513 let mut buffer = self.batch_buffer.write().await;
515 buffer.push(entry.clone());
516
517 if buffer.len() >= self.batch_size {
519 let logs_to_send = buffer.drain(..).collect::<Vec<_>>();
520 drop(buffer); self.send_logs_batch(&logs_to_send).await?;
523 }
524
525 Ok(())
526 }
527
528 async fn log_batch(&self, entries: &[LogEntry]) -> Result<()> {
529 tracing::debug!("Logging batch of {} entries to CloudWatch Logs", entries.len());
530
531 for chunk in entries.chunks(self.batch_size) {
533 self.send_logs_batch(chunk).await?;
534 }
535
536 Ok(())
537 }
538}
539
540fn parse_standard_unit(unit_str: &str) -> Result<StandardUnit> {
542 match unit_str.to_lowercase().as_str() {
543 "seconds" => Ok(StandardUnit::Seconds),
544 "microseconds" => Ok(StandardUnit::Microseconds),
545 "milliseconds" => Ok(StandardUnit::Milliseconds),
546 "bytes" => Ok(StandardUnit::Bytes),
547 "kilobytes" => Ok(StandardUnit::Kilobytes),
548 "megabytes" => Ok(StandardUnit::Megabytes),
549 "gigabytes" => Ok(StandardUnit::Gigabytes),
550 "terabytes" => Ok(StandardUnit::Terabytes),
551 "bits" => Ok(StandardUnit::Bits),
552 "kilobits" => Ok(StandardUnit::Kilobits),
553 "megabits" => Ok(StandardUnit::Megabits),
554 "gigabits" => Ok(StandardUnit::Gigabits),
555 "terabits" => Ok(StandardUnit::Terabits),
556 "percent" => Ok(StandardUnit::Percent),
557 "count" => Ok(StandardUnit::Count),
558 "bytes/second" | "bytes_per_second" => Ok(StandardUnit::BytesSecond),
559 "kilobytes/second" | "kilobytes_per_second" => Ok(StandardUnit::KilobytesSecond),
560 "megabytes/second" | "megabytes_per_second" => Ok(StandardUnit::MegabytesSecond),
561 "gigabytes/second" | "gigabytes_per_second" => Ok(StandardUnit::GigabytesSecond),
562 "terabytes/second" | "terabytes_per_second" => Ok(StandardUnit::TerabytesSecond),
563 "bits/second" | "bits_per_second" => Ok(StandardUnit::BitsSecond),
564 "kilobits/second" | "kilobits_per_second" => Ok(StandardUnit::KilobitsSecond),
565 "megabits/second" | "megabits_per_second" => Ok(StandardUnit::MegabitsSecond),
566 "gigabits/second" | "gigabits_per_second" => Ok(StandardUnit::GigabitsSecond),
567 "terabits/second" | "terabits_per_second" => Ok(StandardUnit::TerabitsSecond),
568 "count/second" | "count_per_second" => Ok(StandardUnit::CountSecond),
569 "none" => Ok(StandardUnit::None),
570 _ => Ok(StandardUnit::None),
571 }
572}
573
574fn format_log_level(level: &LogLevel) -> &'static str {
576 match level {
577 LogLevel::Trace => "TRACE",
578 LogLevel::Debug => "DEBUG",
579 LogLevel::Info => "INFO",
580 LogLevel::Warn => "WARN",
581 LogLevel::Error => "ERROR",
582 LogLevel::Fatal => "FATAL",
583 }
584}
585
#[cfg(test)]
mod tests {
    use super::*;

    /// Known unit names map to their CloudWatch variants; unknown names fall
    /// back to `StandardUnit::None` instead of erroring.
    #[test]
    fn test_parse_standard_unit() {
        assert!(matches!(
            parse_standard_unit("count"),
            Ok(StandardUnit::Count)
        ));
        assert!(matches!(
            parse_standard_unit("bytes"),
            Ok(StandardUnit::Bytes)
        ));
        assert!(matches!(
            parse_standard_unit("seconds"),
            Ok(StandardUnit::Seconds)
        ));
        assert!(matches!(
            parse_standard_unit("percent"),
            Ok(StandardUnit::Percent)
        ));
        assert!(matches!(
            parse_standard_unit("invalid"),
            Ok(StandardUnit::None)
        ));
    }

    /// Every log level renders as its uppercase tag.
    #[test]
    fn test_format_log_level() {
        assert_eq!(format_log_level(&LogLevel::Trace), "TRACE");
        assert_eq!(format_log_level(&LogLevel::Debug), "DEBUG");
        assert_eq!(format_log_level(&LogLevel::Info), "INFO");
        assert_eq!(format_log_level(&LogLevel::Warn), "WARN");
        assert_eq!(format_log_level(&LogLevel::Error), "ERROR");
        assert_eq!(format_log_level(&LogLevel::Fatal), "FATAL");
    }

    /// Documents the default batch sizes against the CloudWatch service
    /// limits (1000 data points per PutMetricData request, 10,000 events
    /// per PutLogEvents request).
    #[test]
    fn test_batch_size_limits() {
        let metrics_batch_size = 20;
        let logs_batch_size = 100;

        assert!(metrics_batch_size <= 1000);
        assert!(logs_batch_size <= 10000);
    }

    /// `chunks(batch_size)` is how export paths split oversized inputs.
    #[tokio::test]
    async fn test_metric_batching() {
        let metrics = vec![
            Metric {
                name: "test1".to_string(),
                value: 1.0,
                timestamp: 1000,
                dimensions: HashMap::new(),
                unit: Some("Count".to_string()),
            },
            Metric {
                name: "test2".to_string(),
                value: 2.0,
                timestamp: 2000,
                dimensions: HashMap::new(),
                unit: Some("Count".to_string()),
            },
        ];

        let batch_size = 1;
        let chunks: Vec<_> = metrics.chunks(batch_size).collect();

        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].len(), 1);
        assert_eq!(chunks[1].len(), 1);
    }

    /// Entries sort chronologically by millisecond timestamp, matching the
    /// ordering `send_logs_batch` applies before `PutLogEvents`.
    #[tokio::test]
    async fn test_log_entry_sorting() {
        let mut entries = vec![
            LogEntry {
                timestamp: std::time::UNIX_EPOCH + std::time::Duration::from_secs(2000),
                level: LogLevel::Info,
                message: "second".to_string(),
                labels: HashMap::new(),
                trace_id: None,
                span_id: None,
                source: None,
            },
            LogEntry {
                timestamp: std::time::UNIX_EPOCH + std::time::Duration::from_secs(1000),
                level: LogLevel::Info,
                message: "first".to_string(),
                labels: HashMap::new(),
                trace_id: None,
                span_id: None,
                source: None,
            },
        ];

        entries.sort_by_key(|e| {
            e.timestamp
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_millis()
        });

        assert_eq!(entries[0].message, "first");
        assert_eq!(entries[1].message, "second");
    }
}