//! scouter_drift/drifter.rs

1#[cfg(feature = "sql")]
2pub mod drift_executor {
3
4    use crate::error::DriftError;
5    use crate::{custom::CustomDrifter, llm::LLMDrifter, psi::PsiDrifter, spc::SpcDrifter};
6    use chrono::{DateTime, Utc};
7
8    use scouter_sql::sql::traits::{AlertSqlLogic, ProfileSqlLogic};
9    use scouter_sql::{sql::schema::TaskRequest, PostgresClient};
10    use scouter_types::{DriftProfile, DriftTaskInfo, DriftType};
11    use sqlx::{Pool, Postgres};
12    use std::collections::BTreeMap;
13    use std::result::Result;
14    use std::result::Result::Ok;
15    use std::str::FromStr;
16    use tracing::{debug, error, info, instrument, span, Instrument, Level};
17
    /// Dispatch wrapper over the concrete drifter implementations.
    ///
    /// Each variant owns the drifter for one drift type, so task processing
    /// can handle every profile kind through a single enum.
    // Variant names intentionally mirror the wrapped type names, hence the allow.
    #[allow(clippy::enum_variant_names)]
    pub enum Drifter {
        SpcDrifter(SpcDrifter),
        PsiDrifter(PsiDrifter),
        CustomDrifter(CustomDrifter),
        LLMDrifter(LLMDrifter),
    }
25
26    impl Drifter {
27        pub async fn check_for_alerts(
28            &self,
29            db_pool: &Pool<Postgres>,
30            previous_run: DateTime<Utc>,
31        ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
32            match self {
33                Drifter::SpcDrifter(drifter) => {
34                    drifter.check_for_alerts(db_pool, previous_run).await
35                }
36                Drifter::PsiDrifter(drifter) => {
37                    drifter.check_for_alerts(db_pool, previous_run).await
38                }
39                Drifter::CustomDrifter(drifter) => {
40                    drifter.check_for_alerts(db_pool, previous_run).await
41                }
42                Drifter::LLMDrifter(drifter) => {
43                    drifter.check_for_alerts(db_pool, previous_run).await
44                }
45            }
46        }
47    }
48
    /// Conversion from a drift profile into the matching `Drifter` variant.
    pub trait GetDrifter {
        /// Build the `Drifter` that knows how to process this profile.
        fn get_drifter(&self) -> Drifter;
    }
52
53    impl GetDrifter for DriftProfile {
54        /// Get a Drifter for processing drift profile tasks
55        ///
56        /// # Arguments
57        ///
58        /// * `name` - Name of the drift profile
59        /// * `space` - Space of the drift profile
60        /// * `version` - Version of the drift profile
61        ///
62        /// # Returns
63        ///
64        /// * `Drifter` - Drifter enum
65        fn get_drifter(&self) -> Drifter {
66            match self {
67                DriftProfile::Spc(profile) => Drifter::SpcDrifter(SpcDrifter::new(profile.clone())),
68                DriftProfile::Psi(profile) => Drifter::PsiDrifter(PsiDrifter::new(profile.clone())),
69                DriftProfile::Custom(profile) => {
70                    Drifter::CustomDrifter(CustomDrifter::new(profile.clone()))
71                }
72                DriftProfile::LLM(profile) => Drifter::LLMDrifter(LLMDrifter::new(profile.clone())),
73            }
74        }
75    }
76
    /// Polls the database for scheduled drift tasks and processes them.
    pub struct DriftExecutor {
        // Connection pool used for all task, profile, and alert queries.
        db_pool: Pool<Postgres>,
    }
80
81    impl DriftExecutor {
82        pub fn new(db_pool: &Pool<Postgres>) -> Self {
83            Self {
84                db_pool: db_pool.clone(),
85            }
86        }
87
88        /// Process a single drift computation task
89        ///
90        /// # Arguments
91        ///
92        /// * `drift_profile` - Drift profile to compute drift for
93        /// * `previous_run` - Previous run timestamp
94        /// * `schedule` - Schedule for drift computation
95        /// * `transaction` - Postgres transaction
96        ///
97        /// # Returns
98        ///
99        pub async fn _process_task(
100            &mut self,
101            profile: DriftProfile,
102            previous_run: DateTime<Utc>,
103        ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
104            // match Drifter enum
105
106            profile
107                .get_drifter()
108                .check_for_alerts(&self.db_pool, previous_run)
109                .await
110        }
111
112        async fn do_poll(&mut self) -> Result<Option<TaskRequest>, DriftError> {
113            debug!("Polling for drift tasks");
114
115            // Get task from the database (query uses skip lock to pull task and update to processing)
116            let task = PostgresClient::get_drift_profile_task(&self.db_pool).await?;
117
118            let Some(task) = task else {
119                return Ok(None);
120            };
121
122            let task_info = DriftTaskInfo {
123                space: task.space.clone(),
124                name: task.name.clone(),
125                version: task.version.clone(),
126                uid: task.uid.clone(),
127                drift_type: DriftType::from_str(&task.drift_type).unwrap(),
128            };
129
130            info!(
131                "Processing drift task for profile: {}/{}/{} and type {}",
132                task.space, task.name, task.version, task.drift_type
133            );
134
135            self.process_task(&task, &task_info).await?;
136
137            // Update the run dates while still holding the lock
138            PostgresClient::update_drift_profile_run_dates(
139                &self.db_pool,
140                &task_info,
141                &task.schedule,
142            )
143            .instrument(span!(Level::INFO, "Update Run Dates"))
144            .await?;
145
146            Ok(Some(task))
147        }
148
149        #[instrument(skip_all)]
150        async fn process_task(
151            &mut self,
152            task: &TaskRequest,
153            task_info: &DriftTaskInfo,
154        ) -> Result<(), DriftError> {
155            // get the drift type
156            let drift_type = DriftType::from_str(&task.drift_type).inspect_err(|e| {
157                error!("Error converting drift type: {:?}", e);
158            })?;
159
160            // get the drift profile
161            let profile = DriftProfile::from_str(drift_type.clone(), task.profile.clone())
162                .inspect_err(|e| {
163                    error!(
164                        "Error converting drift profile for type {:?}: {:?}",
165                        drift_type, e
166                    );
167                })?;
168
169            // check for alerts
170            match self._process_task(profile, task.previous_run).await {
171                Ok(Some(alerts)) => {
172                    info!("Drift task processed successfully with alerts");
173
174                    // Insert alerts atomically within the same transaction
175                    for alert in alerts {
176                        PostgresClient::insert_drift_alert(
177                            &self.db_pool,
178                            task_info,
179                            alert.get("entity_name").unwrap_or(&"NA".to_string()),
180                            &alert,
181                            &drift_type,
182                        )
183                        .await
184                        .map_err(|e| {
185                            error!("Error inserting drift alert: {:?}", e);
186                            DriftError::SqlError(e)
187                        })?;
188                    }
189                    Ok(())
190                }
191                Ok(None) => {
192                    info!("Drift task processed successfully with no alerts");
193                    Ok(())
194                }
195                Err(e) => {
196                    error!("Error processing drift task: {:?}", e);
197                    Err(DriftError::AlertProcessingError(e.to_string()))
198                }
199            }
200        }
201
202        /// Execute single drift computation and alerting
203        ///
204        /// # Returns
205        ///
206        /// * `Result<()>` - Result of drift computation and alerting
207        #[instrument(skip_all)]
208        pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
209            match self.do_poll().await? {
210                Some(_) => {
211                    info!("Successfully processed drift task");
212                    Ok(())
213                }
214                None => {
215                    info!("No triggered schedules found in db. Sleeping for 10 seconds");
216                    tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
217                    Ok(())
218                }
219            }
220        }
221    }
222
223    #[cfg(test)]
224    mod tests {
225        use super::*;
226        use rusty_logging::logger::{LogLevel, LoggingConfig, RustyLogger};
227        use scouter_settings::DatabaseSettings;
228        use scouter_sql::PostgresClient;
229        use scouter_types::spc::SpcFeatureDriftProfile;
230        use scouter_types::{
231            spc::{SpcAlertConfig, SpcAlertRule, SpcDriftConfig, SpcDriftProfile},
232            AlertDispatchConfig, DriftAlertRequest,
233        };
234        use scouter_types::{CommonCrons, ProfileArgs};
235        use semver::Version;
236        use sqlx::{postgres::Postgres, Pool};
237        use std::collections::HashMap;
238
239        pub async fn cleanup(pool: &Pool<Postgres>) {
240            sqlx::raw_sql(
241                r#"
242                DELETE
243                FROM scouter.spc_drift;
244
245                DELETE
246                FROM scouter.observability_metric;
247
248                DELETE
249                FROM scouter.custom_drift;
250
251                DELETE
252                FROM scouter.drift_alert;
253
254                DELETE
255                FROM scouter.drift_profile;
256
257                DELETE
258                FROM scouter.psi_drift;
259                "#,
260            )
261            .fetch_all(pool)
262            .await
263            .unwrap();
264
265            RustyLogger::setup_logging(Some(LoggingConfig::new(
266                None,
267                Some(LogLevel::Info),
268                None,
269                None,
270            )))
271            .unwrap();
272        }
273
        #[tokio::test]
        async fn test_drift_executor_spc() {
            // End-to-end SPC flow: build a profile in code, register it, seed
            // drift data from a SQL script, poll once, and expect alerts.
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();

            cleanup(&db_pool).await;

            // Monitor two features with the default SPC alert rule.
            let alert_config = SpcAlertConfig {
                rule: SpcAlertRule::default(),
                schedule: CommonCrons::EveryDay.cron().to_string(),
                features_to_monitor: vec!["col_1".to_string(), "col_3".to_string()],
                dispatch_config: AlertDispatchConfig::default(),
            };

            let config = SpcDriftConfig::new(
                Some("statworld".to_string()),
                Some("test_app".to_string()),
                Some("0.1.0".to_string()),
                Some(true),
                Some(25),
                Some(alert_config),
                None,
            )
            .unwrap();

            // Fixed control limits (center plus 1/2/3-sigma bands) for col_1.
            let col1_profile = SpcFeatureDriftProfile {
                id: "col_1".to_string(),
                center: -3.997113080300062,
                one_ucl: -1.9742384896265417,
                one_lcl: -6.019987670973582,
                two_ucl: 0.048636101046978464,
                two_lcl: -8.042862261647102,
                three_ucl: 2.071510691720498,
                three_lcl: -10.065736852320622,
                timestamp: Utc::now(),
            };

            // Fixed control limits for col_3.
            let col3_profile = SpcFeatureDriftProfile {
                id: "col_3".to_string(),
                center: -3.937652409303277,
                one_ucl: -2.0275656995100224,
                one_lcl: -5.8477391190965315,
                two_ucl: -0.1174789897167674,
                two_lcl: -7.757825828889787,
                three_ucl: 1.7926077200764872,
                three_lcl: -9.66791253868304,
                timestamp: Utc::now(),
            };

            let drift_profile = DriftProfile::Spc(SpcDriftProfile {
                config,
                features: HashMap::from([
                    (col1_profile.id.clone(), col1_profile),
                    (col3_profile.id.clone(), col3_profile),
                ]),
                scouter_version: "0.1.0".to_string(),
            });

            // Schedule "* * * * * *" fires every second, so a task becomes due
            // almost immediately after insertion.
            let profile_args = ProfileArgs {
                space: "statworld".to_string(),
                name: "test_app".to_string(),
                version: Some("0.1.0".to_string()),
                schedule: "* * * * * *".to_string(),
                scouter_version: "0.1.0".to_string(),
                drift_type: DriftType::Spc,
            };

            // Register the profile so the poller can pick up its task.
            // NOTE(review): the meaning of the two trailing boolean flags isn't
            // visible here — confirm against insert_drift_profile's signature.
            let version = Version::new(0, 1, 0);
            PostgresClient::insert_drift_profile(
                &db_pool,
                &drift_profile,
                &profile_args,
                &version,
                &true,
                &true,
            )
            .await
            .unwrap();

            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

            // Seed drift records from the populate script.
            let mut populate_path =
                std::env::current_dir().expect("Failed to get current directory");
            populate_path.push("src/scripts/populate_spc.sql");
            let script = std::fs::read_to_string(populate_path).unwrap();

            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
            let mut drift_executor = DriftExecutor::new(&db_pool);
            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

            // Process the now-due task (computes drift and inserts alerts).
            drift_executor.poll_for_tasks().await.unwrap();

            // get alerts from db
            let request = DriftAlertRequest {
                space: "statworld".to_string(),
                name: "test_app".to_string(),
                version: "0.1.0".to_string(),
                limit_datetime: None,
                active: None,
                limit: None,
            };
            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
                .await
                .unwrap();
            assert!(!alerts.is_empty());
        }
381
382        #[tokio::test]
383        async fn test_drift_executor_spc_missing_feature_data() {
384            // this tests the scenario where only 1 of 2 features has data in the db when polling
385            // for tasks. Need to ensure this does not fail and the present feature and data are
386            // still processed
387            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
388                .await
389                .unwrap();
390            cleanup(&db_pool).await;
391
392            let mut populate_path =
393                std::env::current_dir().expect("Failed to get current directory");
394            populate_path.push("src/scripts/populate_spc_alert.sql");
395
396            let script = std::fs::read_to_string(populate_path).unwrap();
397            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
398            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
399
400            let mut drift_executor = DriftExecutor::new(&db_pool);
401
402            drift_executor.poll_for_tasks().await.unwrap();
403
404            // get alerts from db
405            let request = DriftAlertRequest {
406                space: "statworld".to_string(),
407                name: "test_app".to_string(),
408                version: "0.1.0".to_string(),
409                limit_datetime: None,
410                active: None,
411                limit: None,
412            };
413            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
414                .await
415                .unwrap();
416
417            assert!(!alerts.is_empty());
418        }
419
420        #[tokio::test]
421        async fn test_drift_executor_psi() {
422            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
423                .await
424                .unwrap();
425
426            cleanup(&db_pool).await;
427
428            let mut populate_path =
429                std::env::current_dir().expect("Failed to get current directory");
430            populate_path.push("src/scripts/populate_psi.sql");
431
432            let mut script = std::fs::read_to_string(populate_path).unwrap();
433            let bin_count = 1000;
434            let skew_feature = "feature_1";
435            let skew_factor = 10;
436            let apply_skew = true;
437            script = script.replace("{{bin_count}}", &bin_count.to_string());
438            script = script.replace("{{skew_feature}}", skew_feature);
439            script = script.replace("{{skew_factor}}", &skew_factor.to_string());
440            script = script.replace("{{apply_skew}}", &apply_skew.to_string());
441            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
442            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
443
444            let mut drift_executor = DriftExecutor::new(&db_pool);
445
446            drift_executor.poll_for_tasks().await.unwrap();
447
448            // get alerts from db
449            let request = DriftAlertRequest {
450                space: "scouter".to_string(),
451                name: "model".to_string(),
452                version: "0.1.0".to_string(),
453                limit_datetime: None,
454                active: None,
455                limit: None,
456            };
457            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
458                .await
459                .unwrap();
460
461            assert_eq!(alerts.len(), 1);
462        }
463
        /// This test verifies that the PSI drift executor does **not** generate any drift alerts
        /// when there are **not enough target samples** to meet the minimum threshold required
        /// for PSI calculation.
        ///
        /// This arg determines how many bin counts to simulate for a production environment.
        /// In the script there are 3 features, each with 10 bins.
        /// `bin_count = 2` means we simulate 2 observations per bin.
        /// So for each feature: 10 bins * 2 samples = 20 samples inserted PER insert.
        /// Since the script inserts each feature's data 3 times (simulating 3 production batches),
        /// each feature ends up with: 20 samples * 3 = 60 samples total.
        /// This is below the required threshold of >100 samples per feature for PSI calculation,
        /// so no drift alert should be generated.
        #[tokio::test]
        async fn test_drift_executor_psi_not_enough_target_samples() {
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();

            cleanup(&db_pool).await;

            let mut populate_path =
                std::env::current_dir().expect("Failed to get current directory");
            populate_path.push("src/scripts/populate_psi.sql");

            // Fill the script template: tiny bin counts, no skew applied, so the
            // per-feature sample total (60) stays below the PSI minimum.
            let mut script = std::fs::read_to_string(populate_path).unwrap();
            let bin_count = 2;
            let skew_feature = "feature_1";
            let skew_factor = 1;
            let apply_skew = false;
            script = script.replace("{{bin_count}}", &bin_count.to_string());
            script = script.replace("{{skew_feature}}", skew_feature);
            script = script.replace("{{skew_factor}}", &skew_factor.to_string());
            script = script.replace("{{apply_skew}}", &apply_skew.to_string());
            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

            let mut drift_executor = DriftExecutor::new(&db_pool);

            // Process the pending task; with too few samples, no alerts fire.
            drift_executor.poll_for_tasks().await.unwrap();

            // get alerts from db
            let request = DriftAlertRequest {
                space: "scouter".to_string(),
                name: "model".to_string(),
                version: "0.1.0".to_string(),
                limit_datetime: None,
                active: None,
                limit: None,
            };
            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
                .await
                .unwrap();

            assert!(alerts.is_empty());
        }
519
520        #[tokio::test]
521        async fn test_drift_executor_custom() {
522            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
523                .await
524                .unwrap();
525
526            cleanup(&db_pool).await;
527
528            let mut populate_path =
529                std::env::current_dir().expect("Failed to get current directory");
530            populate_path.push("src/scripts/populate_custom.sql");
531
532            let script = std::fs::read_to_string(populate_path).unwrap();
533            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
534            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
535
536            let mut drift_executor = DriftExecutor::new(&db_pool);
537
538            drift_executor.poll_for_tasks().await.unwrap();
539
540            // get alerts from db
541            let request = DriftAlertRequest {
542                space: "scouter".to_string(),
543                name: "model".to_string(),
544                version: "0.1.0".to_string(),
545                limit_datetime: None,
546                active: None,
547                limit: None,
548            };
549            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
550                .await
551                .unwrap();
552
553            assert_eq!(alerts.len(), 1);
554        }
555    }
556}