// scouter_drift/drifter.rs

1#[cfg(feature = "sql")]
2pub mod drift_executor {
3
4    use crate::error::DriftError;
5    use crate::{custom::CustomDrifter, llm::LLMDrifter, psi::PsiDrifter, spc::SpcDrifter};
6    use chrono::{DateTime, Utc};
7
8    use scouter_sql::sql::traits::{AlertSqlLogic, ProfileSqlLogic};
9    use scouter_sql::{sql::schema::TaskRequest, PostgresClient};
10    use scouter_types::{DriftProfile, DriftTaskInfo, DriftType};
11    use sqlx::{Pool, Postgres};
12    use std::collections::BTreeMap;
13    use std::result::Result;
14    use std::result::Result::Ok;
15    use std::str::FromStr;
16    use tracing::{debug, error, info, instrument, span, Instrument, Level};
17
    /// Dispatch wrapper over the concrete drifter implementations,
    /// one variant per supported drift type.
    #[allow(clippy::enum_variant_names)] // variant names intentionally mirror the wrapped types
    pub enum Drifter {
        SpcDrifter(SpcDrifter),
        PsiDrifter(PsiDrifter),
        CustomDrifter(CustomDrifter),
        LLMDrifter(LLMDrifter),
    }
25
26    impl Drifter {
27        pub async fn check_for_alerts(
28            &self,
29            db_pool: &Pool<Postgres>,
30            previous_run: DateTime<Utc>,
31        ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
32            match self {
33                Drifter::SpcDrifter(drifter) => {
34                    drifter.check_for_alerts(db_pool, previous_run).await
35                }
36                Drifter::PsiDrifter(drifter) => {
37                    drifter.check_for_alerts(db_pool, previous_run).await
38                }
39                Drifter::CustomDrifter(drifter) => {
40                    drifter.check_for_alerts(db_pool, previous_run).await
41                }
42                Drifter::LLMDrifter(drifter) => {
43                    drifter.check_for_alerts(db_pool, previous_run).await
44                }
45            }
46        }
47    }
48
    /// Conversion from a drift profile into the matching `Drifter` variant.
    pub trait GetDrifter {
        /// Build a `Drifter` for this profile.
        fn get_drifter(&self) -> Drifter;
    }
52
    impl GetDrifter for DriftProfile {
        /// Get a Drifter for processing drift profile tasks.
        ///
        /// Each `DriftProfile` variant is cloned into its dedicated drifter
        /// implementation (SPC, PSI, custom metric, or LLM).
        ///
        /// # Returns
        ///
        /// * `Drifter` - Drifter enum wrapping the concrete drifter
        fn get_drifter(&self) -> Drifter {
            match self {
                DriftProfile::Spc(profile) => Drifter::SpcDrifter(SpcDrifter::new(profile.clone())),
                DriftProfile::Psi(profile) => Drifter::PsiDrifter(PsiDrifter::new(profile.clone())),
                DriftProfile::Custom(profile) => {
                    Drifter::CustomDrifter(CustomDrifter::new(profile.clone()))
                }
                DriftProfile::LLM(profile) => Drifter::LLMDrifter(LLMDrifter::new(profile.clone())),
            }
        }
    }
76
    /// Polls the database for scheduled drift tasks and processes them.
    pub struct DriftExecutor {
        // Shared Postgres connection pool used for all task/alert queries.
        db_pool: Pool<Postgres>,
    }
80
81    impl DriftExecutor {
82        pub fn new(db_pool: &Pool<Postgres>) -> Self {
83            Self {
84                db_pool: db_pool.clone(),
85            }
86        }
87
88        /// Process a single drift computation task
89        ///
90        /// # Arguments
91        ///
92        /// * `drift_profile` - Drift profile to compute drift for
93        /// * `previous_run` - Previous run timestamp
94        /// * `schedule` - Schedule for drift computation
95        /// * `transaction` - Postgres transaction
96        ///
97        /// # Returns
98        ///
99        pub async fn _process_task(
100            &mut self,
101            profile: DriftProfile,
102            previous_run: DateTime<Utc>,
103        ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
104            // match Drifter enum
105
106            profile
107                .get_drifter()
108                .check_for_alerts(&self.db_pool, previous_run)
109                .await
110        }
111
112        async fn do_poll(&mut self) -> Result<Option<TaskRequest>, DriftError> {
113            debug!("Polling for drift tasks");
114
115            // Get task from the database (query uses skip lock to pull task and update to processing)
116            let task = PostgresClient::get_drift_profile_task(&self.db_pool).await?;
117
118            let Some(task) = task else {
119                return Ok(None);
120            };
121
122            let task_info = DriftTaskInfo {
123                space: task.space.clone(),
124                name: task.name.clone(),
125                version: task.version.clone(),
126                uid: task.uid.clone(),
127                drift_type: DriftType::from_str(&task.drift_type).unwrap(),
128            };
129
130            info!(
131                "Processing drift task for profile: {}/{}/{} and type {}",
132                task.space, task.name, task.version, task.drift_type
133            );
134
135            self.process_task(&task, &task_info).await?;
136
137            // Update the run dates while still holding the lock
138            PostgresClient::update_drift_profile_run_dates(
139                &self.db_pool,
140                &task_info,
141                &task.schedule,
142            )
143            .instrument(span!(Level::INFO, "Update Run Dates"))
144            .await?;
145
146            Ok(Some(task))
147        }
148
149        #[instrument(skip_all)]
150        async fn process_task(
151            &mut self,
152            task: &TaskRequest,
153            task_info: &DriftTaskInfo,
154        ) -> Result<(), DriftError> {
155            // get the drift type
156            let drift_type = DriftType::from_str(&task.drift_type).inspect_err(|e| {
157                error!("Error converting drift type: {:?}", e);
158            })?;
159
160            // get the drift profile
161            let profile = DriftProfile::from_str(drift_type.clone(), task.profile.clone())
162                .inspect_err(|e| {
163                    error!(
164                        "Error converting drift profile for type {:?}: {:?}",
165                        drift_type, e
166                    );
167                })?;
168
169            // check for alerts
170            match self._process_task(profile, task.previous_run).await {
171                Ok(Some(alerts)) => {
172                    info!("Drift task processed successfully with alerts");
173
174                    // Insert alerts atomically within the same transaction
175                    for alert in alerts {
176                        PostgresClient::insert_drift_alert(
177                            &self.db_pool,
178                            task_info,
179                            alert.get("entity_name").unwrap_or(&"NA".to_string()),
180                            &alert,
181                            &drift_type,
182                        )
183                        .await
184                        .map_err(|e| {
185                            error!("Error inserting drift alert: {:?}", e);
186                            DriftError::SqlError(e)
187                        })?;
188                    }
189                    Ok(())
190                }
191                Ok(None) => {
192                    info!("Drift task processed successfully with no alerts");
193                    Ok(())
194                }
195                Err(e) => {
196                    error!("Error processing drift task: {:?}", e);
197                    Err(DriftError::AlertProcessingError(e.to_string()))
198                }
199            }
200        }
201
202        /// Execute single drift computation and alerting
203        ///
204        /// # Returns
205        ///
206        /// * `Result<()>` - Result of drift computation and alerting
207        #[instrument(skip_all)]
208        pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
209            match self.do_poll().await? {
210                Some(_) => {
211                    info!("Successfully processed drift task");
212                    Ok(())
213                }
214                None => {
215                    info!("No triggered schedules found in db. Sleeping for 10 seconds");
216                    tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
217                    Ok(())
218                }
219            }
220        }
221    }
222
223    #[cfg(test)]
224    mod tests {
225        use super::*;
226        use rusty_logging::logger::{LogLevel, LoggingConfig, RustyLogger};
227        use scouter_settings::DatabaseSettings;
228        use scouter_sql::PostgresClient;
229        use scouter_types::DriftAlertRequest;
230        use sqlx::{postgres::Postgres, Pool};
231
232        pub async fn cleanup(pool: &Pool<Postgres>) {
233            sqlx::raw_sql(
234                r#"
235                DELETE 
236                FROM scouter.spc_drift;
237
238                DELETE 
239                FROM scouter.observability_metric;
240
241                DELETE
242                FROM scouter.custom_drift;
243
244                DELETE
245                FROM scouter.drift_alert;
246
247                DELETE
248                FROM scouter.drift_profile;
249
250                DELETE
251                FROM scouter.psi_drift;
252                "#,
253            )
254            .fetch_all(pool)
255            .await
256            .unwrap();
257
258            RustyLogger::setup_logging(Some(LoggingConfig::new(
259                None,
260                Some(LogLevel::Info),
261                None,
262                None,
263            )))
264            .unwrap();
265        }
266
        // Integration test: requires a running Postgres with the scouter schema.
        #[tokio::test]
        async fn test_drift_executor_spc() {
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();

            cleanup(&db_pool).await;

            // Seed SPC drift records and a due task from the fixture script.
            let mut populate_path =
                std::env::current_dir().expect("Failed to get current directory");
            populate_path.push("src/scripts/populate_spc.sql");

            let script = std::fs::read_to_string(populate_path).unwrap();
            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();

            let mut drift_executor = DriftExecutor::new(&db_pool);

            // A single poll should pick up and process the seeded task.
            drift_executor.poll_for_tasks().await.unwrap();

            // get alerts from db
            let request = DriftAlertRequest {
                space: "statworld".to_string(),
                name: "test_app".to_string(),
                version: "0.1.0".to_string(),
                limit_datetime: None,
                active: None,
                limit: None,
            };
            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
                .await
                .unwrap();
            assert!(!alerts.is_empty());
        }
300
        // Integration test: requires a running Postgres with the scouter schema.
        #[tokio::test]
        async fn test_drift_executor_spc_missing_feature_data() {
            // this tests the scenario where only 1 of 2 features has data in the db when polling
            // for tasks. Need to ensure this does not fail and the present feature and data are
            // still processed
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();
            cleanup(&db_pool).await;

            // Seed drift data for only one of the profile's features.
            let mut populate_path =
                std::env::current_dir().expect("Failed to get current directory");
            populate_path.push("src/scripts/populate_spc_alert.sql");

            let script = std::fs::read_to_string(populate_path).unwrap();
            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();

            let mut drift_executor = DriftExecutor::new(&db_pool);

            // Polling must succeed even though one feature has no data.
            drift_executor.poll_for_tasks().await.unwrap();

            // get alerts from db
            let request = DriftAlertRequest {
                space: "statworld".to_string(),
                name: "test_app".to_string(),
                version: "0.1.0".to_string(),
                limit_datetime: None,
                active: None,
                limit: None,
            };
            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
                .await
                .unwrap();

            // The feature that does have data should still produce alerts.
            assert!(!alerts.is_empty());
        }
337
338        #[tokio::test]
339        async fn test_drift_executor_psi() {
340            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
341                .await
342                .unwrap();
343
344            cleanup(&db_pool).await;
345
346            let mut populate_path =
347                std::env::current_dir().expect("Failed to get current directory");
348            populate_path.push("src/scripts/populate_psi.sql");
349
350            let mut script = std::fs::read_to_string(populate_path).unwrap();
351            let bin_count = 1000;
352            let skew_feature = "feature_1";
353            let skew_factor = 10;
354            let apply_skew = true;
355            script = script.replace("{{bin_count}}", &bin_count.to_string());
356            script = script.replace("{{skew_feature}}", skew_feature);
357            script = script.replace("{{skew_factor}}", &skew_factor.to_string());
358            script = script.replace("{{apply_skew}}", &apply_skew.to_string());
359            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
360
361            let mut drift_executor = DriftExecutor::new(&db_pool);
362
363            drift_executor.poll_for_tasks().await.unwrap();
364
365            // get alerts from db
366            let request = DriftAlertRequest {
367                space: "scouter".to_string(),
368                name: "model".to_string(),
369                version: "0.1.0".to_string(),
370                limit_datetime: None,
371                active: None,
372                limit: None,
373            };
374            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
375                .await
376                .unwrap();
377
378            assert_eq!(alerts.len(), 1);
379        }
380
381        /// This test verifies that the PSI drift executor does **not** generate any drift alerts
382        /// when there are **not enough target samples** to meet the minimum threshold required
383        /// for PSI calculation.
384        ///
385        /// This arg determines how many bin counts to simulate for a production environment.
386        /// In the script there are 3 features, each with 10 bins.
387        /// `bin_count = 2` means we simulate 2 observations per bin.
388        /// So for each feature: 10 bins * 2 samples = 20 samples inserted PER insert.
389        /// Since the script inserts each feature's data 3 times (simulating 3 production batches),
390        /// each feature ends up with: 20 samples * 3 = 60 samples total.
391        /// This is below the required threshold of >100 samples per feature for PSI calculation,
392        /// so no drift alert should be generated.
393        #[tokio::test]
394        async fn test_drift_executor_psi_not_enough_target_samples() {
395            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
396                .await
397                .unwrap();
398
399            cleanup(&db_pool).await;
400
401            let mut populate_path =
402                std::env::current_dir().expect("Failed to get current directory");
403            populate_path.push("src/scripts/populate_psi.sql");
404
405            let mut script = std::fs::read_to_string(populate_path).unwrap();
406            let bin_count = 2;
407            let skew_feature = "feature_1";
408            let skew_factor = 1;
409            let apply_skew = false;
410            script = script.replace("{{bin_count}}", &bin_count.to_string());
411            script = script.replace("{{skew_feature}}", skew_feature);
412            script = script.replace("{{skew_factor}}", &skew_factor.to_string());
413            script = script.replace("{{apply_skew}}", &apply_skew.to_string());
414            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
415
416            let mut drift_executor = DriftExecutor::new(&db_pool);
417
418            drift_executor.poll_for_tasks().await.unwrap();
419
420            // get alerts from db
421            let request = DriftAlertRequest {
422                space: "scouter".to_string(),
423                name: "model".to_string(),
424                version: "0.1.0".to_string(),
425                limit_datetime: None,
426                active: None,
427                limit: None,
428            };
429            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
430                .await
431                .unwrap();
432
433            assert!(alerts.is_empty());
434        }
435
        // Integration test: requires a running Postgres with the scouter schema.
        #[tokio::test]
        async fn test_drift_executor_custom() {
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();

            cleanup(&db_pool).await;

            // Seed custom-metric drift records and a due task from the fixture script.
            let mut populate_path =
                std::env::current_dir().expect("Failed to get current directory");
            populate_path.push("src/scripts/populate_custom.sql");

            let script = std::fs::read_to_string(populate_path).unwrap();
            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();

            let mut drift_executor = DriftExecutor::new(&db_pool);

            // A single poll should pick up and process the seeded task.
            drift_executor.poll_for_tasks().await.unwrap();

            // get alerts from db
            let request = DriftAlertRequest {
                space: "scouter".to_string(),
                name: "model".to_string(),
                version: "0.1.0".to_string(),
                limit_datetime: None,
                active: None,
                limit: None,
            };
            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
                .await
                .unwrap();

            // The seeded custom metric should yield exactly one alert.
            assert_eq!(alerts.len(), 1);
        }
470    }
471}