1#[cfg(feature = "sql")]
2pub mod drift_executor {
3
4 use crate::error::DriftError;
5 use crate::{custom::CustomDrifter, psi::PsiDrifter, spc::SpcDrifter};
6 use chrono::{DateTime, Utc};
7
8 use scouter_sql::sql::traits::{AlertSqlLogic, ProfileSqlLogic};
9 use scouter_sql::{sql::schema::TaskRequest, PostgresClient};
10 use scouter_types::{DriftProfile, DriftTaskInfo, DriftType};
11 use sqlx::{Pool, Postgres};
12 use std::collections::BTreeMap;
13 use std::result::Result;
14 use std::result::Result::Ok;
15 use std::str::FromStr;
16 use tracing::{debug, error, info, span, Instrument, Level};
17
18 #[allow(clippy::enum_variant_names)]
19 pub enum Drifter {
20 SpcDrifter(SpcDrifter),
21 PsiDrifter(PsiDrifter),
22 CustomDrifter(CustomDrifter),
23 }
24
25 impl Drifter {
26 pub async fn check_for_alerts(
27 &self,
28 db_pool: &Pool<Postgres>,
29 previous_run: DateTime<Utc>,
30 ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
31 match self {
32 Drifter::SpcDrifter(drifter) => {
33 drifter.check_for_alerts(db_pool, previous_run).await
34 }
35 Drifter::PsiDrifter(drifter) => {
36 drifter.check_for_alerts(db_pool, previous_run).await
37 }
38 Drifter::CustomDrifter(drifter) => {
39 drifter.check_for_alerts(db_pool, previous_run).await
40 }
41 }
42 }
43 }
44
45 pub trait GetDrifter {
46 fn get_drifter(&self) -> Drifter;
47 }
48
49 impl GetDrifter for DriftProfile {
50 fn get_drifter(&self) -> Drifter {
62 match self {
63 DriftProfile::Spc(profile) => Drifter::SpcDrifter(SpcDrifter::new(profile.clone())),
64 DriftProfile::Psi(profile) => Drifter::PsiDrifter(PsiDrifter::new(profile.clone())),
65 DriftProfile::Custom(profile) => {
66 Drifter::CustomDrifter(CustomDrifter::new(profile.clone()))
67 }
68 }
69 }
70 }
71
72 pub struct DriftExecutor {
73 db_pool: Pool<Postgres>,
74 }
75
76 impl DriftExecutor {
77 pub fn new(db_pool: &Pool<Postgres>) -> Self {
78 Self {
79 db_pool: db_pool.clone(),
80 }
81 }
82
83 pub async fn _process_task(
95 &mut self,
96 profile: DriftProfile,
97 previous_run: DateTime<Utc>,
98 ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
99 profile
102 .get_drifter()
103 .check_for_alerts(&self.db_pool, previous_run)
104 .await
105 }
106
107 async fn do_poll(&mut self) -> Result<Option<TaskRequest>, DriftError> {
108 debug!("Polling for drift tasks");
109
110 let task = PostgresClient::get_drift_profile_task(&self.db_pool).await?;
112
113 let Some(task) = task else {
114 return Ok(None);
115 };
116
117 let task_info = DriftTaskInfo {
118 space: task.space.clone(),
119 name: task.name.clone(),
120 version: task.version.clone(),
121 uid: task.uid.clone(),
122 drift_type: DriftType::from_str(&task.drift_type).unwrap(),
123 };
124
125 info!(
126 "Processing drift task for profile: {}/{}/{} and type {}",
127 task.space, task.name, task.version, task.drift_type
128 );
129
130 self.process_task(&task, &task_info).await?;
131
132 PostgresClient::update_drift_profile_run_dates(
134 &self.db_pool,
135 &task_info,
136 &task.schedule,
137 )
138 .instrument(span!(Level::INFO, "Update Run Dates"))
139 .await?;
140
141 Ok(Some(task))
142 }
143
144 async fn process_task(
145 &mut self,
146 task: &TaskRequest,
147 task_info: &DriftTaskInfo,
148 ) -> Result<(), DriftError> {
149 let drift_type = DriftType::from_str(&task.drift_type).inspect_err(|e| {
151 error!("Error converting drift type: {:?}", e);
152 })?;
153
154 let profile = DriftProfile::from_str(drift_type.clone(), task.profile.clone())
156 .inspect_err(|e| {
157 error!("Error converting drift profile: {:?}", e);
158 })?;
159
160 match self._process_task(profile, task.previous_run).await {
162 Ok(Some(alerts)) => {
163 info!("Drift task processed successfully with alerts");
164
165 for alert in alerts {
167 PostgresClient::insert_drift_alert(
168 &self.db_pool,
169 task_info,
170 alert.get("entity_name").unwrap_or(&"NA".to_string()),
171 &alert,
172 &drift_type,
173 )
174 .await
175 .map_err(|e| {
176 error!("Error inserting drift alert: {:?}", e);
177 DriftError::SqlError(e)
178 })?;
179 }
180 Ok(())
181 }
182 Ok(None) => {
183 info!("Drift task processed successfully with no alerts");
184 Ok(())
185 }
186 Err(e) => {
187 error!("Error processing drift task: {:?}", e);
188 Err(DriftError::AlertProcessingError(e.to_string()))
189 }
190 }
191 }
192
193 pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
199 match self.do_poll().await? {
200 Some(_) => {
201 info!("Successfully processed drift task");
202 Ok(())
203 }
204 None => {
205 info!("No triggered schedules found in db. Sleeping for 10 seconds");
206 tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
207 Ok(())
208 }
209 }
210 }
211 }
212
#[cfg(test)]
mod tests {
    use super::*;
    use rusty_logging::logger::{LogLevel, LoggingConfig, RustyLogger};
    use scouter_settings::DatabaseSettings;
    use scouter_sql::PostgresClient;
    use scouter_types::DriftAlertRequest;
    use sqlx::{postgres::Postgres, Pool};

    // NOTE(review): every test wipes and repopulates the same `scouter` tables in a shared
    // database. Running them concurrently (the default for `cargo test`) may let one test's
    // cleanup interfere with another — confirm they are run serially.

    /// Deletes all rows from the drift, alert, profile and observability tables, then
    /// initializes logging at `Info` level.
    pub async fn cleanup(pool: &Pool<Postgres>) {
        let wipe_tables = r#"
DELETE
FROM scouter.spc_drift;

DELETE
FROM scouter.observability_metric;

DELETE
FROM scouter.custom_drift;

DELETE
FROM scouter.drift_alert;

DELETE
FROM scouter.drift_profile;

DELETE
FROM scouter.psi_drift;
"#;
        sqlx::raw_sql(wipe_tables).fetch_all(pool).await.unwrap();

        let logging = LoggingConfig::new(None, Some(LogLevel::Info), None, None);
        RustyLogger::setup_logging(Some(logging)).unwrap();
    }

    /// Opens a pool with default database settings and empties the test tables.
    async fn fresh_pool() -> Pool<Postgres> {
        let pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();
        cleanup(&pool).await;
        pool
    }

    /// Resolves `relative` against the current working directory.
    fn script_path(relative: &str) -> std::path::PathBuf {
        let mut path = std::env::current_dir().expect("Failed to get current directory");
        path.push(relative);
        path
    }

    /// Reads a SQL fixture, replaces each `(placeholder, value)` pair in order, and executes it.
    async fn populate(pool: &Pool<Postgres>, relative: &str, substitutions: &[(&str, String)]) {
        let raw = std::fs::read_to_string(script_path(relative)).unwrap();
        let script = substitutions
            .iter()
            .fold(raw, |sql, (placeholder, value)| sql.replace(*placeholder, value));
        sqlx::raw_sql(&script).execute(pool).await.unwrap();
    }

    /// Placeholder values for `populate_psi.sql`, listed in the order they are substituted.
    fn psi_substitutions(
        bin_count: i32,
        skew_feature: &str,
        skew_factor: i32,
        apply_skew: bool,
    ) -> Vec<(&'static str, String)> {
        vec![
            ("{{bin_count}}", bin_count.to_string()),
            ("{{skew_feature}}", skew_feature.to_string()),
            ("{{skew_factor}}", skew_factor.to_string()),
            ("{{apply_skew}}", apply_skew.to_string()),
        ]
    }

    /// Alert query for version `0.1.0` of `space`/`name`, with no datetime, active or limit
    /// filter set.
    fn alert_request(space: &str, name: &str) -> DriftAlertRequest {
        DriftAlertRequest {
            space: space.to_string(),
            name: name.to_string(),
            version: "0.1.0".to_string(),
            limit_datetime: None,
            active: None,
            limit: None,
        }
    }

    #[tokio::test]
    async fn test_drift_executor_spc() {
        let db_pool = fresh_pool().await;
        populate(&db_pool, "src/scripts/populate_spc.sql", &[]).await;

        let mut executor = DriftExecutor::new(&db_pool);
        executor.poll_for_tasks().await.unwrap();

        let alerts =
            PostgresClient::get_drift_alerts(&db_pool, &alert_request("statworld", "test_app"))
                .await
                .unwrap();
        assert!(!alerts.is_empty());
    }

    #[tokio::test]
    async fn test_drift_executor_spc_missing_feature_data() {
        let db_pool = fresh_pool().await;
        populate(&db_pool, "src/scripts/populate_spc_alert.sql", &[]).await;

        let mut executor = DriftExecutor::new(&db_pool);
        executor.poll_for_tasks().await.unwrap();

        let alerts =
            PostgresClient::get_drift_alerts(&db_pool, &alert_request("statworld", "test_app"))
                .await
                .unwrap();
        assert!(!alerts.is_empty());
    }

    #[tokio::test]
    async fn test_drift_executor_psi() {
        let db_pool = fresh_pool().await;
        let substitutions = psi_substitutions(1000, "feature_1", 10, true);
        populate(&db_pool, "src/scripts/populate_psi.sql", &substitutions).await;

        let mut executor = DriftExecutor::new(&db_pool);
        executor.poll_for_tasks().await.unwrap();

        let alerts = PostgresClient::get_drift_alerts(&db_pool, &alert_request("scouter", "model"))
            .await
            .unwrap();
        assert_eq!(alerts.len(), 1);
    }

    #[tokio::test]
    async fn test_drift_executor_psi_not_enough_target_samples() {
        let db_pool = fresh_pool().await;
        let substitutions = psi_substitutions(2, "feature_1", 1, false);
        populate(&db_pool, "src/scripts/populate_psi.sql", &substitutions).await;

        let mut executor = DriftExecutor::new(&db_pool);
        executor.poll_for_tasks().await.unwrap();

        let alerts = PostgresClient::get_drift_alerts(&db_pool, &alert_request("scouter", "model"))
            .await
            .unwrap();
        assert!(alerts.is_empty());
    }

    #[tokio::test]
    async fn test_drift_executor_custom() {
        let db_pool = fresh_pool().await;
        populate(&db_pool, "src/scripts/populate_custom.sql", &[]).await;

        let mut executor = DriftExecutor::new(&db_pool);
        executor.poll_for_tasks().await.unwrap();

        let alerts = PostgresClient::get_drift_alerts(&db_pool, &alert_request("scouter", "model"))
            .await
            .unwrap();
        assert_eq!(alerts.len(), 1);
    }
}
461}