// llm_shield_cloud_aws/observability.rs

1//! AWS CloudWatch observability integration.
2//!
3//! Provides implementations of `CloudMetrics` and `CloudLogger` traits for AWS CloudWatch.
4
5use aws_sdk_cloudwatch::Client as CloudWatchClient;
6use aws_sdk_cloudwatch::types::{Dimension, MetricDatum, StandardUnit};
7use aws_sdk_cloudwatchlogs::Client as CloudWatchLogsClient;
8use aws_sdk_cloudwatchlogs::types::InputLogEvent;
9use llm_shield_cloud::{
10    async_trait, CloudError, CloudLogger, CloudMetrics, LogEntry, LogLevel, Metric, Result,
11};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
/// AWS CloudWatch Metrics implementation of `CloudMetrics`.
///
/// This implementation provides:
/// - Batched metric export for efficiency
/// - Support for custom dimensions
/// - Automatic namespace configuration
/// - Standard and custom units
///
/// # Example
///
/// ```no_run
/// use llm_shield_cloud_aws::CloudWatchMetrics;
/// use llm_shield_cloud::{CloudMetrics, Metric};
/// use std::collections::HashMap;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let metrics = CloudWatchMetrics::new("LLMShield").await?;
///
///     let metric = Metric {
///         name: "RequestCount".to_string(),
///         value: 1.0,
///         timestamp: std::time::SystemTime::now()
///             .duration_since(std::time::UNIX_EPOCH)?
///             .as_secs(),
///         dimensions: HashMap::new(),
///         unit: Some("Count".to_string()),
///     };
///
///     metrics.export_metric(&metric).await?;
///     Ok(())
/// }
/// ```
pub struct CloudWatchMetrics {
    // SDK client used for PutMetricData calls.
    client: CloudWatchClient,
    // Namespace under which all metrics are published.
    namespace: String,
    // Region the client was configured with; exposed via `region()` only.
    region: String,
    // Metrics accumulated by `export_metric` until a flush is triggered.
    // Shared behind Arc<RwLock> so the client can be cloned/shared across tasks.
    batch_buffer: Arc<RwLock<Vec<Metric>>>,
    // Buffer length that triggers an automatic flush (also the chunk size
    // used by `export_metrics`).
    batch_size: usize,
}
56
57impl CloudWatchMetrics {
58    /// Creates a new CloudWatch Metrics client with default configuration.
59    ///
60    /// # Arguments
61    ///
62    /// * `namespace` - CloudWatch namespace (e.g., "LLMShield")
63    ///
64    /// # Errors
65    ///
66    /// Returns error if AWS configuration cannot be loaded.
67    pub async fn new(namespace: impl Into<String>) -> Result<Self> {
68        let config = aws_config::load_from_env().await;
69        let region = config
70            .region()
71            .map(|r| r.to_string())
72            .unwrap_or_else(|| "us-east-1".to_string());
73
74        let client = CloudWatchClient::new(&config);
75        let namespace = namespace.into();
76
77        tracing::info!(
78            "Initialized CloudWatch Metrics client for namespace: {} in region: {}",
79            namespace,
80            region
81        );
82
83        Ok(Self {
84            client,
85            namespace,
86            region,
87            batch_buffer: Arc::new(RwLock::new(Vec::new())),
88            batch_size: 20, // CloudWatch allows up to 1000, but 20 is safer
89        })
90    }
91
92    /// Creates a new CloudWatch Metrics client with specific region and batch size.
93    ///
94    /// # Arguments
95    ///
96    /// * `namespace` - CloudWatch namespace
97    /// * `region` - AWS region
98    /// * `batch_size` - Number of metrics to batch before sending (max 1000)
99    ///
100    /// # Errors
101    ///
102    /// Returns error if AWS configuration cannot be loaded.
103    pub async fn new_with_config(
104        namespace: impl Into<String>,
105        region: impl Into<String>,
106        batch_size: usize,
107    ) -> Result<Self> {
108        let region_str = region.into();
109        let config = aws_config::from_env()
110            .region(aws_config::Region::new(region_str.clone()))
111            .load()
112            .await;
113
114        let client = CloudWatchClient::new(&config);
115        let namespace = namespace.into();
116
117        tracing::info!(
118            "Initialized CloudWatch Metrics client for namespace: {} in region: {} (batch size: {})",
119            namespace,
120            region_str,
121            batch_size
122        );
123
124        Ok(Self {
125            client,
126            namespace,
127            region: region_str,
128            batch_buffer: Arc::new(RwLock::new(Vec::new())),
129            batch_size: batch_size.min(1000), // CloudWatch hard limit
130        })
131    }
132
133    /// Gets the namespace this client is configured for.
134    pub fn namespace(&self) -> &str {
135        &self.namespace
136    }
137
138    /// Gets the AWS region this client is configured for.
139    pub fn region(&self) -> &str {
140        &self.region
141    }
142
143    /// Flushes buffered metrics to CloudWatch.
144    pub async fn flush(&self) -> Result<()> {
145        let mut buffer = self.batch_buffer.write().await;
146
147        if buffer.is_empty() {
148            return Ok(());
149        }
150
151        let metrics_to_send = buffer.drain(..).collect::<Vec<_>>();
152        drop(buffer); // Release lock before network call
153
154        self.send_metrics_batch(&metrics_to_send).await?;
155
156        Ok(())
157    }
158
159    /// Sends a batch of metrics to CloudWatch.
160    async fn send_metrics_batch(&self, metrics: &[Metric]) -> Result<()> {
161        if metrics.is_empty() {
162            return Ok(());
163        }
164
165        tracing::debug!("Sending {} metrics to CloudWatch", metrics.len());
166
167        let metric_data: Vec<MetricDatum> = metrics
168            .iter()
169            .map(|m| {
170                let mut datum = MetricDatum::builder()
171                    .metric_name(&m.name)
172                    .value(m.value)
173                    .timestamp(aws_sdk_cloudwatch::primitives::DateTime::from_secs(
174                        m.timestamp as i64,
175                    ));
176
177                // Add dimensions
178                for (key, value) in &m.dimensions {
179                    datum = datum.dimensions(
180                        Dimension::builder()
181                            .name(key.clone())
182                            .value(value.clone())
183                            .build(),
184                    );
185                }
186
187                // Add unit if specified
188                if let Some(ref unit_str) = m.unit {
189                    if let Ok(unit) = parse_standard_unit(unit_str) {
190                        datum = datum.unit(unit);
191                    }
192                }
193
194                datum.build()
195            })
196            .collect();
197
198        self.client
199            .put_metric_data()
200            .namespace(&self.namespace)
201            .set_metric_data(Some(metric_data))
202            .send()
203            .await
204            .map_err(|e| CloudError::MetricsExport(e.to_string()))?;
205
206        tracing::info!("Successfully sent {} metrics to CloudWatch", metrics.len());
207
208        Ok(())
209    }
210}
211
212#[async_trait]
213impl CloudMetrics for CloudWatchMetrics {
214    async fn export_metrics(&self, metrics: &[Metric]) -> Result<()> {
215        tracing::debug!("Exporting {} metrics to CloudWatch", metrics.len());
216
217        // Send in batches of batch_size
218        for chunk in metrics.chunks(self.batch_size) {
219            self.send_metrics_batch(chunk).await?;
220        }
221
222        Ok(())
223    }
224
225    async fn export_metric(&self, metric: &Metric) -> Result<()> {
226        tracing::debug!("Exporting metric to CloudWatch: {}", metric.name);
227
228        // Add to buffer
229        let mut buffer = self.batch_buffer.write().await;
230        buffer.push(metric.clone());
231
232        // Flush if buffer is full
233        if buffer.len() >= self.batch_size {
234            let metrics_to_send = buffer.drain(..).collect::<Vec<_>>();
235            drop(buffer); // Release lock before network call
236
237            self.send_metrics_batch(&metrics_to_send).await?;
238        }
239
240        Ok(())
241    }
242}
243
/// AWS CloudWatch Logs implementation of `CloudLogger`.
///
/// This implementation provides:
/// - Batched log export for efficiency
/// - Structured logging support
/// - Automatic log stream creation
/// - Log group configuration
///
/// # Example
///
/// ```no_run
/// use llm_shield_cloud_aws::CloudWatchLogger;
/// use llm_shield_cloud::{CloudLogger, LogLevel};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let logger = CloudWatchLogger::new(
///         "/llm-shield/api",
///         "production-instance-1"
///     ).await?;
///
///     logger.log("Application started", LogLevel::Info).await?;
///     Ok(())
/// }
/// ```
pub struct CloudWatchLogger {
    // SDK client used for CreateLogStream / PutLogEvents calls.
    client: CloudWatchLogsClient,
    // Destination log group name (e.g., "/llm-shield/api").
    log_group: String,
    // Destination log stream name within the log group.
    log_stream: String,
    // Region the client was configured with; exposed via `region()` only.
    region: String,
    // Sequence token returned by the previous PutLogEvents call, passed on
    // the next request. NOTE(review): recent CloudWatch Logs API versions no
    // longer require sequence tokens — confirm against the SDK version in use.
    sequence_token: Arc<RwLock<Option<String>>>,
    // Log entries accumulated by `log_structured` until a flush is triggered.
    batch_buffer: Arc<RwLock<Vec<LogEntry>>>,
    // Buffer length that triggers an automatic flush (also the chunk size
    // used by `log_batch`).
    batch_size: usize,
}
278
279impl CloudWatchLogger {
280    /// Creates a new CloudWatch Logs client with default configuration.
281    ///
282    /// # Arguments
283    ///
284    /// * `log_group` - CloudWatch Logs log group name (e.g., "/llm-shield/api")
285    /// * `log_stream` - Log stream name (e.g., "instance-1")
286    ///
287    /// # Errors
288    ///
289    /// Returns error if AWS configuration cannot be loaded or log stream creation fails.
290    pub async fn new(
291        log_group: impl Into<String>,
292        log_stream: impl Into<String>,
293    ) -> Result<Self> {
294        let config = aws_config::load_from_env().await;
295        let region = config
296            .region()
297            .map(|r| r.to_string())
298            .unwrap_or_else(|| "us-east-1".to_string());
299
300        let client = CloudWatchLogsClient::new(&config);
301        let log_group = log_group.into();
302        let log_stream = log_stream.into();
303
304        // Try to create log stream (idempotent if already exists)
305        let _ = client
306            .create_log_stream()
307            .log_group_name(&log_group)
308            .log_stream_name(&log_stream)
309            .send()
310            .await;
311
312        tracing::info!(
313            "Initialized CloudWatch Logs client for log group: {} stream: {} in region: {}",
314            log_group,
315            log_stream,
316            region
317        );
318
319        Ok(Self {
320            client,
321            log_group,
322            log_stream,
323            region,
324            sequence_token: Arc::new(RwLock::new(None)),
325            batch_buffer: Arc::new(RwLock::new(Vec::new())),
326            batch_size: 100, // CloudWatch allows up to 10,000 events
327        })
328    }
329
330    /// Creates a new CloudWatch Logs client with specific configuration.
331    ///
332    /// # Arguments
333    ///
334    /// * `log_group` - CloudWatch Logs log group name
335    /// * `log_stream` - Log stream name
336    /// * `region` - AWS region
337    /// * `batch_size` - Number of log entries to batch before sending
338    ///
339    /// # Errors
340    ///
341    /// Returns error if AWS configuration cannot be loaded.
342    pub async fn new_with_config(
343        log_group: impl Into<String>,
344        log_stream: impl Into<String>,
345        region: impl Into<String>,
346        batch_size: usize,
347    ) -> Result<Self> {
348        let region_str = region.into();
349        let config = aws_config::from_env()
350            .region(aws_config::Region::new(region_str.clone()))
351            .load()
352            .await;
353
354        let client = CloudWatchLogsClient::new(&config);
355        let log_group = log_group.into();
356        let log_stream = log_stream.into();
357
358        // Try to create log stream
359        let _ = client
360            .create_log_stream()
361            .log_group_name(&log_group)
362            .log_stream_name(&log_stream)
363            .send()
364            .await;
365
366        tracing::info!(
367            "Initialized CloudWatch Logs client for log group: {} stream: {} in region: {} (batch size: {})",
368            log_group,
369            log_stream,
370            region_str,
371            batch_size
372        );
373
374        Ok(Self {
375            client,
376            log_group,
377            log_stream,
378            region: region_str,
379            sequence_token: Arc::new(RwLock::new(None)),
380            batch_buffer: Arc::new(RwLock::new(Vec::new())),
381            batch_size,
382        })
383    }
384
385    /// Gets the log group this client is configured for.
386    pub fn log_group(&self) -> &str {
387        &self.log_group
388    }
389
390    /// Gets the log stream this client is configured for.
391    pub fn log_stream(&self) -> &str {
392        &self.log_stream
393    }
394
395    /// Gets the AWS region this client is configured for.
396    pub fn region(&self) -> &str {
397        &self.region
398    }
399
400    /// Flushes buffered log entries to CloudWatch Logs.
401    pub async fn flush(&self) -> Result<()> {
402        let mut buffer = self.batch_buffer.write().await;
403
404        if buffer.is_empty() {
405            return Ok(());
406        }
407
408        let logs_to_send = buffer.drain(..).collect::<Vec<_>>();
409        drop(buffer); // Release lock before network call
410
411        self.send_logs_batch(&logs_to_send).await?;
412
413        Ok(())
414    }
415
416    /// Sends a batch of log entries to CloudWatch Logs.
417    async fn send_logs_batch(&self, entries: &[LogEntry]) -> Result<()> {
418        if entries.is_empty() {
419            return Ok(());
420        }
421
422        tracing::debug!("Sending {} log entries to CloudWatch Logs", entries.len());
423
424        // Convert LogEntry to InputLogEvent
425        let mut log_events: Vec<InputLogEvent> = entries
426            .iter()
427            .map(|entry| {
428                let timestamp = entry
429                    .timestamp
430                    .duration_since(std::time::UNIX_EPOCH)
431                    .unwrap_or_default()
432                    .as_millis() as i64;
433
434                // Format message with structured fields
435                let mut message = format!("[{}] {}", format_log_level(&entry.level), entry.message);
436
437                if !entry.labels.is_empty() {
438                    message.push_str(&format!(" {:?}", entry.labels));
439                }
440
441                if let Some(ref trace_id) = entry.trace_id {
442                    message.push_str(&format!(" trace_id={}", trace_id));
443                }
444
445                if let Some(ref span_id) = entry.span_id {
446                    message.push_str(&format!(" span_id={}", span_id));
447                }
448
449                InputLogEvent::builder()
450                    .timestamp(timestamp)
451                    .message(message)
452                    .build()
453                    .expect("Failed to build InputLogEvent")
454            })
455            .collect();
456
457        // Sort by timestamp (required by CloudWatch)
458        log_events.sort_by_key(|e| e.timestamp);
459
460        // Get current sequence token
461        let sequence_token = self.sequence_token.read().await.clone();
462
463        // Send log events
464        let mut request = self
465            .client
466            .put_log_events()
467            .log_group_name(&self.log_group)
468            .log_stream_name(&self.log_stream)
469            .set_log_events(Some(log_events));
470
471        if let Some(token) = sequence_token {
472            request = request.sequence_token(token);
473        }
474
475        let response = request
476            .send()
477            .await
478            .map_err(|e| CloudError::LogExport(e.to_string()))?;
479
480        // Update sequence token for next request
481        if let Some(next_token) = response.next_sequence_token {
482            *self.sequence_token.write().await = Some(next_token);
483        }
484
485        tracing::info!(
486            "Successfully sent {} log entries to CloudWatch Logs",
487            entries.len()
488        );
489
490        Ok(())
491    }
492}
493
494#[async_trait]
495impl CloudLogger for CloudWatchLogger {
496    async fn log(&self, message: &str, level: LogLevel) -> Result<()> {
497        let entry = LogEntry {
498            timestamp: std::time::SystemTime::now(),
499            level,
500            message: message.to_string(),
501            labels: HashMap::new(),
502            trace_id: None,
503            span_id: None,
504            source: None,
505        };
506
507        self.log_structured(&entry).await
508    }
509
510    async fn log_structured(&self, entry: &LogEntry) -> Result<()> {
511        tracing::debug!("Logging structured entry to CloudWatch Logs");
512
513        // Add to buffer
514        let mut buffer = self.batch_buffer.write().await;
515        buffer.push(entry.clone());
516
517        // Flush if buffer is full
518        if buffer.len() >= self.batch_size {
519            let logs_to_send = buffer.drain(..).collect::<Vec<_>>();
520            drop(buffer); // Release lock before network call
521
522            self.send_logs_batch(&logs_to_send).await?;
523        }
524
525        Ok(())
526    }
527
528    async fn log_batch(&self, entries: &[LogEntry]) -> Result<()> {
529        tracing::debug!("Logging batch of {} entries to CloudWatch Logs", entries.len());
530
531        // Send in batches of batch_size
532        for chunk in entries.chunks(self.batch_size) {
533            self.send_logs_batch(chunk).await?;
534        }
535
536        Ok(())
537    }
538}
539
540/// Parses a string into a CloudWatch StandardUnit.
541fn parse_standard_unit(unit_str: &str) -> Result<StandardUnit> {
542    match unit_str.to_lowercase().as_str() {
543        "seconds" => Ok(StandardUnit::Seconds),
544        "microseconds" => Ok(StandardUnit::Microseconds),
545        "milliseconds" => Ok(StandardUnit::Milliseconds),
546        "bytes" => Ok(StandardUnit::Bytes),
547        "kilobytes" => Ok(StandardUnit::Kilobytes),
548        "megabytes" => Ok(StandardUnit::Megabytes),
549        "gigabytes" => Ok(StandardUnit::Gigabytes),
550        "terabytes" => Ok(StandardUnit::Terabytes),
551        "bits" => Ok(StandardUnit::Bits),
552        "kilobits" => Ok(StandardUnit::Kilobits),
553        "megabits" => Ok(StandardUnit::Megabits),
554        "gigabits" => Ok(StandardUnit::Gigabits),
555        "terabits" => Ok(StandardUnit::Terabits),
556        "percent" => Ok(StandardUnit::Percent),
557        "count" => Ok(StandardUnit::Count),
558        "bytes/second" | "bytes_per_second" => Ok(StandardUnit::BytesSecond),
559        "kilobytes/second" | "kilobytes_per_second" => Ok(StandardUnit::KilobytesSecond),
560        "megabytes/second" | "megabytes_per_second" => Ok(StandardUnit::MegabytesSecond),
561        "gigabytes/second" | "gigabytes_per_second" => Ok(StandardUnit::GigabytesSecond),
562        "terabytes/second" | "terabytes_per_second" => Ok(StandardUnit::TerabytesSecond),
563        "bits/second" | "bits_per_second" => Ok(StandardUnit::BitsSecond),
564        "kilobits/second" | "kilobits_per_second" => Ok(StandardUnit::KilobitsSecond),
565        "megabits/second" | "megabits_per_second" => Ok(StandardUnit::MegabitsSecond),
566        "gigabits/second" | "gigabits_per_second" => Ok(StandardUnit::GigabitsSecond),
567        "terabits/second" | "terabits_per_second" => Ok(StandardUnit::TerabitsSecond),
568        "count/second" | "count_per_second" => Ok(StandardUnit::CountSecond),
569        "none" => Ok(StandardUnit::None),
570        _ => Ok(StandardUnit::None),
571    }
572}
573
/// Formats a LogLevel as a string.
///
/// Used by `send_logs_batch` to prefix each log message with a level tag
/// (e.g., `[INFO] message`). The match is exhaustive so adding a new
/// `LogLevel` variant is a compile error here, not a silent gap.
fn format_log_level(level: &LogLevel) -> &'static str {
    match level {
        LogLevel::Trace => "TRACE",
        LogLevel::Debug => "DEBUG",
        LogLevel::Info => "INFO",
        LogLevel::Warn => "WARN",
        LogLevel::Error => "ERROR",
        LogLevel::Fatal => "FATAL",
    }
}
585
#[cfg(test)]
mod tests {
    use super::*;

    /// Known unit strings resolve to their StandardUnit; unknown strings
    /// fall back to StandardUnit::None rather than erroring.
    #[test]
    fn test_parse_standard_unit() {
        assert!(matches!(
            parse_standard_unit("count"),
            Ok(StandardUnit::Count)
        ));
        assert!(matches!(
            parse_standard_unit("bytes"),
            Ok(StandardUnit::Bytes)
        ));
        assert!(matches!(
            parse_standard_unit("seconds"),
            Ok(StandardUnit::Seconds)
        ));
        assert!(matches!(
            parse_standard_unit("percent"),
            Ok(StandardUnit::Percent)
        ));
        assert!(matches!(
            parse_standard_unit("invalid"),
            Ok(StandardUnit::None)
        ));
    }

    /// Every LogLevel variant maps to its uppercase tag.
    #[test]
    fn test_format_log_level() {
        assert_eq!(format_log_level(&LogLevel::Trace), "TRACE");
        assert_eq!(format_log_level(&LogLevel::Debug), "DEBUG");
        assert_eq!(format_log_level(&LogLevel::Info), "INFO");
        assert_eq!(format_log_level(&LogLevel::Warn), "WARN");
        assert_eq!(format_log_level(&LogLevel::Error), "ERROR");
        assert_eq!(format_log_level(&LogLevel::Fatal), "FATAL");
    }

    /// Default batch sizes stay within CloudWatch per-request limits.
    #[test]
    fn test_batch_size_limits() {
        let metrics_batch_size = 20;
        let logs_batch_size = 100;

        // CloudWatch limits: 1000 metrics / 10,000 log events per request
        assert!(metrics_batch_size <= 1000);
        assert!(logs_batch_size <= 10000);
    }

    /// Chunking splits a metric slice into batch-sized groups.
    /// (Plain #[test]: no await is performed, so no async runtime is needed.)
    #[test]
    fn test_metric_batching() {
        let metrics = vec![
            Metric {
                name: "test1".to_string(),
                value: 1.0,
                timestamp: 1000,
                dimensions: HashMap::new(),
                unit: Some("Count".to_string()),
            },
            Metric {
                name: "test2".to_string(),
                value: 2.0,
                timestamp: 2000,
                dimensions: HashMap::new(),
                unit: Some("Count".to_string()),
            },
        ];

        // Test chunking logic
        let batch_size = 1;
        let chunks: Vec<_> = metrics.chunks(batch_size).collect();

        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].len(), 1);
        assert_eq!(chunks[1].len(), 1);
    }

    /// Entries sort by millisecond timestamp, as required by PutLogEvents.
    /// Fix: the LogEntry literals were missing the `source` field (which the
    /// `log` method initializes), so this test previously failed to compile.
    #[test]
    fn test_log_entry_sorting() {
        let mut entries = vec![
            LogEntry {
                timestamp: std::time::UNIX_EPOCH + std::time::Duration::from_secs(2000),
                level: LogLevel::Info,
                message: "second".to_string(),
                labels: HashMap::new(),
                trace_id: None,
                span_id: None,
                source: None,
            },
            LogEntry {
                timestamp: std::time::UNIX_EPOCH + std::time::Duration::from_secs(1000),
                level: LogLevel::Info,
                message: "first".to_string(),
                labels: HashMap::new(),
                trace_id: None,
                span_id: None,
                source: None,
            },
        ];

        entries.sort_by_key(|e| {
            e.timestamp
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_millis()
        });

        assert_eq!(entries[0].message, "first");
        assert_eq!(entries[1].message, "second");
    }
}