1use crate::error::DriftError;
2use chrono::{DateTime, Utc};
3use scouter_dispatch::AlertDispatcher;
4use scouter_sql::sql::traits::CustomMetricSqlLogic;
5use scouter_sql::{sql::cache::entity_cache, PostgresClient};
6use scouter_types::custom::{ComparisonMetricAlert, CustomDriftProfile};
7use scouter_types::{AlertMap, ProfileBaseArgs};
8use sqlx::{Pool, Postgres};
9use std::collections::HashMap;
10use tracing::error;
11use tracing::info;
/// Evaluates the alert conditions of a [`CustomDriftProfile`] against observed custom-metric
/// values and dispatches any resulting alerts.
pub struct CustomDrifter {
    /// Profile supplying the metric set, the alert configuration, and the space/name/version
    /// identity used in log messages.
    profile: CustomDriftProfile,
}
15
16impl CustomDrifter {
17 pub fn new(profile: CustomDriftProfile) -> Self {
18 Self { profile }
19 }
20
21 fn profile_id(&self) -> String {
23 format!(
24 "{}/{}/{}",
25 self.profile.space(),
26 self.profile.name(),
27 self.profile.version()
28 )
29 }
30
31 pub async fn get_observed_custom_metric_values(
33 &self,
34 limit_datetime: &DateTime<Utc>,
35 db_pool: &Pool<Postgres>,
36 ) -> Result<HashMap<String, f64>, DriftError> {
37 let metrics: Vec<String> = self.profile.metrics.keys().cloned().collect();
38 let entity_id = entity_cache()
39 .get_entity_id_from_uid(db_pool, &self.profile.config.uid)
40 .await?;
41
42 PostgresClient::get_custom_metric_values(db_pool, limit_datetime, &metrics, &entity_id)
43 .await
44 .inspect_err(|e| {
45 error!(
46 "Unable to obtain custom metric data from DB for {}: {}",
47 self.profile_id(),
48 e
49 );
50 })
51 .map_err(Into::into)
52 }
53
54 pub async fn get_metric_map(
56 &self,
57 limit_datetime: &DateTime<Utc>,
58 db_pool: &Pool<Postgres>,
59 ) -> Result<Option<HashMap<String, f64>>, DriftError> {
60 let metric_map = self
61 .get_observed_custom_metric_values(limit_datetime, db_pool)
62 .await?;
63
64 if metric_map.is_empty() {
65 info!(
66 "No custom metric data found for {}. Skipping alert processing.",
67 self.profile_id()
68 );
69 return Ok(None);
70 }
71
72 Ok(Some(metric_map))
73 }
74
75 pub async fn generate_alerts(
77 &self,
78 metric_map: &HashMap<String, f64>,
79 ) -> Result<Option<Vec<AlertMap>>, DriftError> {
80 let Some(alert_conditions) = &self.profile.config.alert_config.alert_conditions else {
82 info!(
83 "No alert conditions configured for {}. Skipping alert processing.",
84 self.profile_id()
85 );
86 return Ok(None);
87 };
88
89 let metric_alerts: Vec<ComparisonMetricAlert> = metric_map
91 .iter()
92 .filter_map(|(name, observed_value)| {
93 let alert_condition = alert_conditions.get(name)?;
94
95 if alert_condition.should_alert(*observed_value) {
96 Some(ComparisonMetricAlert {
97 metric_name: name.clone(),
98 baseline_value: alert_condition.baseline_value,
99 observed_value: *observed_value,
100 delta: alert_condition.delta,
101 alert_threshold: alert_condition.alert_threshold.clone(),
102 })
103 } else {
104 None
105 }
106 })
107 .collect();
108
109 if metric_alerts.is_empty() {
111 info!(
112 "No alerts to process for {} (checked {} metrics)",
113 self.profile_id(),
114 metric_map.len()
115 );
116 return Ok(None);
117 }
118
119 let alert_dispatcher = AlertDispatcher::new(&self.profile.config).inspect_err(|e| {
121 error!(
122 "Error creating alert dispatcher for {}: {}",
123 self.profile_id(),
124 e
125 );
126 })?;
127
128 for alert in &metric_alerts {
129 alert_dispatcher
130 .process_alerts(alert)
131 .await
132 .inspect_err(|e| {
133 error!(
134 "Error processing alert for metric '{}' in {}: {}",
135 alert.metric_name,
136 self.profile_id(),
137 e
138 );
139 })?;
140 }
141
142 Ok(Some(metric_alerts.into_iter().map(|a| a.into()).collect()))
144 }
145
146 pub async fn check_for_alerts(
148 &self,
149 db_pool: &Pool<Postgres>,
150 previous_run: &DateTime<Utc>,
151 ) -> Result<Option<Vec<AlertMap>>, DriftError> {
152 let Some(metric_map) = self.get_metric_map(previous_run, db_pool).await? else {
153 return Ok(None);
154 };
155
156 self.generate_alerts(&metric_map).await.inspect_err(|e| {
157 error!("Error generating alerts for {}: {}", self.profile_id(), e);
158 })
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use scouter_types::custom::{CustomMetric, CustomMetricAlertConfig, CustomMetricDriftConfig};
    use scouter_types::AlertThreshold;

    /// Builds a drifter with two metrics:
    /// - `mse`: baseline 12.02, alerts above, delta 1.0
    /// - `accuracy`: baseline 0.75, alerts below, no delta
    fn get_test_drifter() -> CustomDrifter {
        let custom_metrics = vec![
            CustomMetric::new("mse", 12.02, AlertThreshold::Above, Some(1.0)).unwrap(),
            CustomMetric::new("accuracy", 0.75, AlertThreshold::Below, None).unwrap(),
        ];

        let drift_config = CustomMetricDriftConfig::new(
            "scouter",
            "model",
            "0.1.0",
            25,
            CustomMetricAlertConfig::default(),
            None,
        )
        .unwrap();

        let profile = CustomDriftProfile::new(drift_config, custom_metrics).unwrap();

        CustomDrifter::new(profile)
    }

    /// Returns the metric names of all custom alerts, sorted, so assertions do not depend on
    /// alert ordering.
    fn alerted_metric_names(alerts: &[AlertMap]) -> Vec<String> {
        let mut names: Vec<String> = alerts
            .iter()
            .filter_map(|a| match a {
                AlertMap::Custom(alert) => Some(alert.metric_name.to_string()),
                _ => None,
            })
            .collect();
        names.sort();
        names
    }

    #[tokio::test]
    async fn test_generate_alerts_triggers_for_out_of_bounds_metrics() {
        let drifter = get_test_drifter();

        let metric_map = HashMap::from([
            ("mse".to_string(), 14.0),
            ("accuracy".to_string(), 0.65),
        ]);

        let alerts = drifter
            .generate_alerts(&metric_map)
            .await
            .expect("Should generate alerts successfully")
            .expect("Should have alerts");

        assert_eq!(alerts.len(), 2, "Should generate 2 alerts");
        assert_eq!(
            alerted_metric_names(&alerts),
            vec!["accuracy", "mse"],
            "Should have MSE and accuracy alerts"
        );
    }

    #[tokio::test]
    async fn test_generate_alerts_no_trigger_within_threshold() {
        let drifter = get_test_drifter();

        let metric_map = HashMap::from([
            ("mse".to_string(), 12.5),
            ("accuracy".to_string(), 0.76),
        ]);

        let alerts = drifter
            .generate_alerts(&metric_map)
            .await
            .expect("Should handle no-alert case successfully");

        assert!(
            alerts.is_none(),
            "Should not generate alerts for values within threshold"
        );
    }

    #[tokio::test]
    async fn test_generate_alerts_partial_triggers() {
        let drifter = get_test_drifter();

        let metric_map = HashMap::from([
            ("mse".to_string(), 14.0),
            ("accuracy".to_string(), 0.76),
        ]);

        let alerts = drifter
            .generate_alerts(&metric_map)
            .await
            .expect("Should generate alerts successfully")
            .expect("Should have alerts");

        assert_eq!(alerts.len(), 1, "Should generate 1 alert");
        assert_eq!(
            alerted_metric_names(&alerts),
            vec!["mse"],
            "Alert should be for MSE metric"
        );
    }

    #[tokio::test]
    async fn test_generate_alerts_ignores_metrics_without_conditions() {
        let drifter = get_test_drifter();

        // A metric with no configured alert condition never alerts, however extreme its value.
        let unknown_only = HashMap::from([("unknown_metric".to_string(), 1_000.0)]);
        let alerts = drifter
            .generate_alerts(&unknown_only)
            .await
            .expect("Should handle unknown metrics successfully");
        assert!(alerts.is_none(), "Unknown metrics should not alert");

        // Mixed with a triggering known metric, only the known metric alerts.
        let mixed = HashMap::from([
            ("unknown_metric".to_string(), 1_000.0),
            ("mse".to_string(), 14.0),
        ]);
        let alerts = drifter
            .generate_alerts(&mixed)
            .await
            .expect("Should generate alerts successfully")
            .expect("Should have alerts");
        assert_eq!(alerted_metric_names(&alerts), vec!["mse"]);
    }

    #[test]
    fn test_profile_id_formatting() {
        let drifter = get_test_drifter();
        let profile_id = drifter.profile_id();

        assert_eq!(profile_id, "scouter/model/0.1.0");
    }
}