Skip to main content

scouter_drift/custom/
drift.rs

1use crate::error::DriftError;
2use chrono::{DateTime, Utc};
3use scouter_dispatch::AlertDispatcher;
4use scouter_sql::sql::traits::CustomMetricSqlLogic;
5use scouter_sql::{sql::cache::entity_cache, PostgresClient};
6use scouter_types::custom::{ComparisonMetricAlert, CustomDriftProfile};
7use scouter_types::{AlertMap, ProfileBaseArgs};
8use sqlx::{Pool, Postgres};
9use std::collections::HashMap;
10use tracing::error;
11use tracing::info;
12pub struct CustomDrifter {
13    profile: CustomDriftProfile,
14}
15
16impl CustomDrifter {
17    pub fn new(profile: CustomDriftProfile) -> Self {
18        Self { profile }
19    }
20
21    /// Helper method to format profile identifier for logging
22    fn profile_id(&self) -> String {
23        format!(
24            "{}/{}/{}",
25            self.profile.space(),
26            self.profile.name(),
27            self.profile.version()
28        )
29    }
30
31    /// Fetches observed custom metric values from database for the given time limit
32    pub async fn get_observed_custom_metric_values(
33        &self,
34        limit_datetime: &DateTime<Utc>,
35        db_pool: &Pool<Postgres>,
36    ) -> Result<HashMap<String, f64>, DriftError> {
37        let metrics: Vec<String> = self.profile.metrics.keys().cloned().collect();
38        let entity_id = entity_cache()
39            .get_entity_id_from_uid(db_pool, &self.profile.config.uid)
40            .await?;
41
42        PostgresClient::get_custom_metric_values(db_pool, limit_datetime, &metrics, &entity_id)
43            .await
44            .inspect_err(|e| {
45                error!(
46                    "Unable to obtain custom metric data from DB for {}: {}",
47                    self.profile_id(),
48                    e
49                );
50            })
51            .map_err(Into::into)
52    }
53
54    /// Retrieves metric map and logs if no data found
55    pub async fn get_metric_map(
56        &self,
57        limit_datetime: &DateTime<Utc>,
58        db_pool: &Pool<Postgres>,
59    ) -> Result<Option<HashMap<String, f64>>, DriftError> {
60        let metric_map = self
61            .get_observed_custom_metric_values(limit_datetime, db_pool)
62            .await?;
63
64        if metric_map.is_empty() {
65            info!(
66                "No custom metric data found for {}. Skipping alert processing.",
67                self.profile_id()
68            );
69            return Ok(None);
70        }
71
72        Ok(Some(metric_map))
73    }
74
75    /// Generates alerts for metrics that exceed their thresholds and dispatches them
76    pub async fn generate_alerts(
77        &self,
78        metric_map: &HashMap<String, f64>,
79    ) -> Result<Option<Vec<AlertMap>>, DriftError> {
80        // Early return if no alert conditions configured
81        let Some(alert_conditions) = &self.profile.config.alert_config.alert_conditions else {
82            info!(
83                "No alert conditions configured for {}. Skipping alert processing.",
84                self.profile_id()
85            );
86            return Ok(None);
87        };
88
89        // Collect metrics that should trigger alerts
90        let metric_alerts: Vec<ComparisonMetricAlert> = metric_map
91            .iter()
92            .filter_map(|(name, observed_value)| {
93                let alert_condition = alert_conditions.get(name)?;
94
95                if alert_condition.should_alert(*observed_value) {
96                    Some(ComparisonMetricAlert {
97                        metric_name: name.clone(),
98                        baseline_value: alert_condition.baseline_value,
99                        observed_value: *observed_value,
100                        delta: alert_condition.delta,
101                        alert_threshold: alert_condition.alert_threshold.clone(),
102                    })
103                } else {
104                    None
105                }
106            })
107            .collect();
108
109        // Early return if no alerts to process
110        if metric_alerts.is_empty() {
111            info!(
112                "No alerts to process for {} (checked {} metrics)",
113                self.profile_id(),
114                metric_map.len()
115            );
116            return Ok(None);
117        }
118
119        // Dispatch alerts
120        let alert_dispatcher = AlertDispatcher::new(&self.profile.config).inspect_err(|e| {
121            error!(
122                "Error creating alert dispatcher for {}: {}",
123                self.profile_id(),
124                e
125            );
126        })?;
127
128        for alert in &metric_alerts {
129            alert_dispatcher
130                .process_alerts(alert)
131                .await
132                .inspect_err(|e| {
133                    error!(
134                        "Error processing alert for metric '{}' in {}: {}",
135                        alert.metric_name,
136                        self.profile_id(),
137                        e
138                    );
139                })?;
140        }
141
142        // Convert to owned maps before returning
143        Ok(Some(metric_alerts.into_iter().map(|a| a.into()).collect()))
144    }
145
146    /// Checks for alerts based on metric values since previous run
147    pub async fn check_for_alerts(
148        &self,
149        db_pool: &Pool<Postgres>,
150        previous_run: &DateTime<Utc>,
151    ) -> Result<Option<Vec<AlertMap>>, DriftError> {
152        let Some(metric_map) = self.get_metric_map(previous_run, db_pool).await? else {
153            return Ok(None);
154        };
155
156        self.generate_alerts(&metric_map).await.inspect_err(|e| {
157            error!("Error generating alerts for {}: {}", self.profile_id(), e);
158        })
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use scouter_types::custom::{CustomMetric, CustomMetricAlertConfig, CustomMetricDriftConfig};
    use scouter_types::AlertThreshold;

    /// Builds a drifter whose profile defines two metrics:
    /// - `mse`: baseline 12.02, `AlertThreshold::Above`, delta 1.0
    /// - `accuracy`: baseline 0.75, `AlertThreshold::Below`, no delta
    fn get_test_drifter() -> CustomDrifter {
        let custom_metrics = vec![
            CustomMetric::new("mse", 12.02, AlertThreshold::Above, Some(1.0)).unwrap(),
            CustomMetric::new("accuracy", 0.75, AlertThreshold::Below, None).unwrap(),
        ];

        let drift_config = CustomMetricDriftConfig::new(
            "scouter",
            "model",
            "0.1.0",
            25,
            CustomMetricAlertConfig::default(),
            None,
        )
        .unwrap();

        let profile = CustomDriftProfile::new(drift_config, custom_metrics).unwrap();

        CustomDrifter::new(profile)
    }

    /// True if `alerts` contains a custom alert for `metric`.
    fn has_alert_for(alerts: &[AlertMap], metric: &str) -> bool {
        alerts
            .iter()
            .any(|a| matches!(a, AlertMap::Custom(alert) if alert.metric_name == metric))
    }

    #[tokio::test]
    async fn test_generate_alerts_triggers_for_out_of_bounds_metrics() {
        let drifter = get_test_drifter();

        let metric_map = HashMap::from([
            // mse: Above, delta 1.0 -> triggers (14.0 > 12.02 + 1.0)
            ("mse".to_string(), 14.0),
            // accuracy: Below, no delta -> triggers (0.65 < 0.75)
            ("accuracy".to_string(), 0.65),
        ]);

        let alerts = drifter
            .generate_alerts(&metric_map)
            .await
            .expect("Should generate alerts successfully")
            .expect("Should have alerts");

        assert_eq!(alerts.len(), 2, "Should generate 2 alerts");
        assert!(has_alert_for(&alerts, "mse"), "Should have MSE alert");
        assert!(has_alert_for(&alerts, "accuracy"), "Should have accuracy alert");
    }

    #[tokio::test]
    async fn test_generate_alerts_no_trigger_within_threshold() {
        let drifter = get_test_drifter();

        let metric_map = HashMap::from([
            // mse: Above, delta 1.0 -> no trigger (12.5 < 12.02 + 1.0)
            ("mse".to_string(), 12.5),
            // accuracy: Below -> no trigger (0.76 > 0.75)
            ("accuracy".to_string(), 0.76),
        ]);

        let alerts = drifter
            .generate_alerts(&metric_map)
            .await
            .expect("Should handle no-alert case successfully");

        assert!(
            alerts.is_none(),
            "Should not generate alerts for values within threshold"
        );
    }

    #[tokio::test]
    async fn test_generate_alerts_partial_triggers() {
        let drifter = get_test_drifter();

        let metric_map = HashMap::from([
            // Only MSE should trigger
            ("mse".to_string(), 14.0),
            // Accuracy within bounds
            ("accuracy".to_string(), 0.76),
        ]);

        let alerts = drifter
            .generate_alerts(&metric_map)
            .await
            .expect("Should generate alerts successfully")
            .expect("Should have alerts");

        assert_eq!(alerts.len(), 1, "Should generate 1 alert");
        assert!(
            matches!(&alerts[0], AlertMap::Custom(alert) if alert.metric_name == "mse"),
            "Alert should be for MSE metric"
        );
    }

    #[tokio::test]
    async fn test_generate_alerts_empty_metric_map() {
        let drifter = get_test_drifter();

        let alerts = drifter
            .generate_alerts(&HashMap::new())
            .await
            .expect("Should handle an empty metric map");

        assert!(alerts.is_none(), "Empty metric map should produce no alerts");
    }

    #[tokio::test]
    async fn test_generate_alerts_ignores_unknown_metrics() {
        let drifter = get_test_drifter();

        // "latency" is not part of the profile, so it has no alert condition and an
        // extreme value must not trigger anything.
        let metric_map = HashMap::from([("latency".to_string(), 1_000_000.0)]);

        let alerts = drifter
            .generate_alerts(&metric_map)
            .await
            .expect("Should handle metrics without alert conditions");

        assert!(
            alerts.is_none(),
            "Metrics without an alert condition should be ignored"
        );
    }

    #[test]
    fn test_profile_id_formatting() {
        let drifter = get_test_drifter();

        assert_eq!(drifter.profile_id(), "scouter/model/0.1.0");
    }
}