1use crate::error::DriftError;
2use crate::{custom::CustomDrifter, genai::GenAIDrifter, psi::PsiDrifter, spc::SpcDrifter};
3use chrono::{DateTime, Utc};
4use scouter_sql::sql::traits::{AlertSqlLogic, ProfileSqlLogic};
5use scouter_sql::{sql::schema::TaskRequest, PostgresClient};
6use scouter_types::{AlertMap, DriftProfile};
7use sqlx::{Pool, Postgres};
8use std::result::Result;
9use std::result::Result::Ok;
10
11use tracing::{debug, error, info, instrument, span, Instrument, Level};
12
13#[allow(clippy::enum_variant_names)]
14pub enum Drifter {
15 SpcDrifter(SpcDrifter),
16 PsiDrifter(PsiDrifter),
17 CustomDrifter(CustomDrifter),
18 GenAIDrifter(GenAIDrifter),
19}
20
21impl Drifter {
22 pub async fn check_for_alerts(
23 &self,
24 db_pool: &Pool<Postgres>,
25 previous_run: &DateTime<Utc>,
26 ) -> Result<Option<Vec<AlertMap>>, DriftError> {
27 match self {
28 Drifter::SpcDrifter(drifter) => drifter.check_for_alerts(db_pool, previous_run).await,
29 Drifter::PsiDrifter(drifter) => drifter.check_for_alerts(db_pool, previous_run).await,
30 Drifter::CustomDrifter(drifter) => {
31 drifter.check_for_alerts(db_pool, previous_run).await
32 }
33 Drifter::GenAIDrifter(drifter) => drifter.check_for_alerts(db_pool, previous_run).await,
34 }
35 }
36}
37
38pub trait GetDrifter {
39 fn get_drifter(&self) -> Drifter;
40}
41
42impl GetDrifter for DriftProfile {
43 fn get_drifter(&self) -> Drifter {
55 match self {
56 DriftProfile::Spc(profile) => Drifter::SpcDrifter(SpcDrifter::new(profile.clone())),
57 DriftProfile::Psi(profile) => Drifter::PsiDrifter(PsiDrifter::new(profile.clone())),
58 DriftProfile::Custom(profile) => {
59 Drifter::CustomDrifter(CustomDrifter::new(profile.clone()))
60 }
61 DriftProfile::GenAI(profile) => {
62 Drifter::GenAIDrifter(GenAIDrifter::new(profile.clone()))
63 }
64 }
65 }
66}
67
68pub struct DriftExecutor {
69 db_pool: Pool<Postgres>,
70}
71
72impl DriftExecutor {
73 pub fn new(db_pool: &Pool<Postgres>) -> Self {
74 Self {
75 db_pool: db_pool.clone(),
76 }
77 }
78
79 pub async fn _process_task(
91 &mut self,
92 profile: DriftProfile,
93 previous_run: &DateTime<Utc>,
94 ) -> Result<Option<Vec<AlertMap>>, DriftError> {
95 profile
98 .get_drifter()
99 .check_for_alerts(&self.db_pool, previous_run)
100 .await
101 }
102
103 async fn do_poll(&mut self) -> bool {
104 debug!("Polling for drift tasks");
105
106 let task = match PostgresClient::get_drift_profile_task(&self.db_pool).await {
108 Ok(task) => task,
109 Err(e) => {
110 error!("Error fetching drift task: {:?}", e);
111 return false;
112 }
113 };
114
115 let Some(task) = task else {
116 return false;
117 };
118
119 info!(
120 "Processing drift task for profile: {} and type {}",
121 task.uid, task.drift_type
122 );
123
124 match self.process_task(&task).await {
126 Ok(_) => info!(
127 "Successfully processed drift task for profile: {}",
128 task.uid
129 ),
130 Err(e) => error!(
131 "Error processing drift task for profile {}: {:?}",
132 task.uid, e
133 ),
134 }
135
136 match PostgresClient::update_drift_profile_run_dates(
137 &self.db_pool,
138 &task.entity_id,
139 &task.schedule,
140 &task.previous_run,
141 )
142 .instrument(span!(Level::INFO, "Update Run Dates"))
143 .await
144 {
145 Ok(_) => info!("Updated run dates for drift profile task: {}", task.uid),
146 Err(e) => error!(
147 "CRITICAL: Failed to reschedule task Error updating run dates for drift profile task {}: {:?}",
148 task.uid, e
149 ),
150 }
151
152 true
153 }
154
155 #[instrument(skip_all)]
156 async fn process_task(
157 &mut self,
158 task: &TaskRequest,
159 ) -> Result<(), DriftError> {
162 let profile = DriftProfile::from_str(&task.drift_type, &task.profile).inspect_err(|e| {
164 error!(
165 "Error converting drift profile for type {:?}: {:?}",
166 &task.drift_type, e
167 );
168 })?;
169
170 match self._process_task(profile, &task.previous_run).await {
171 Ok(Some(alerts)) => {
172 info!("Drift task processed successfully with alerts");
173
174 for alert in alerts {
176 PostgresClient::insert_drift_alert(&self.db_pool, &task.entity_id, &alert)
177 .await
178 .map_err(|e| {
179 error!("Error inserting drift alert: {:?}", e);
180 DriftError::SqlError(e)
181 })?;
182 }
183 Ok(())
184 }
185 Ok(None) => {
186 info!("Drift task processed successfully with no alerts");
187 Ok(())
188 }
189 Err(e) => {
190 error!("Error processing drift task: {:?}", e);
191 Err(DriftError::AlertProcessingError(e.to_string()))
192 }
193 }
194 }
195
196 #[instrument(skip_all)]
202 pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
203 match self.do_poll().await {
204 true => {
205 info!("Successfully processed drift task");
206 Ok(())
207 }
208 false => {
209 info!("No triggered schedules found in db. Sleeping for 10 seconds");
210 tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
211 Ok(())
212 }
213 }
214 }
215}
216
#[cfg(test)]
mod tests {
    use crate::GenAIPoller;

    use super::*;
    use chrono::Duration;
    use rusty_logging::logger::{LogLevel, LoggingConfig, RustyLogger};
    use scouter_settings::DatabaseSettings;
    use scouter_sql::sql::traits::{EntitySqlLogic, GenAIDriftSqlLogic, SpcSqlLogic};
    use scouter_sql::PostgresClient;
    use scouter_types::spc::SpcFeatureDriftProfile;
    use scouter_types::{
        spc::{SpcAlertConfig, SpcAlertRule, SpcDriftConfig, SpcDriftProfile},
        AlertDispatchConfig, DriftAlertPaginationRequest,
    };
    use scouter_types::{BoxedGenAIEvalRecord, DriftType, ProfileArgs, SpcRecord};
    use semver::Version;
    use sqlx::{postgres::Postgres, Pool};
    use std::collections::HashMap;

    use potato_head::mock::{create_score_prompt, LLMTestServer};
    use scouter_types::genai::{
        AssertionTask, ComparisonOperator, EvaluationTaskType, EvaluationTasks, GenAIAlertConfig,
        GenAIEvalConfig, GenAIEvalProfile, LLMJudgeTask,
    };
    use scouter_types::{AlertCondition, AlertThreshold, GenAIEvalRecord};
    use serde_json::Value;

    /// Empties every table the tests write to and initializes logging at `Info` level.
    pub async fn cleanup(pool: &Pool<Postgres>) {
        // The statements are plain DELETEs that yield no rows, so they are run with
        // `execute` (as the populate scripts are) instead of collecting rows.
        sqlx::raw_sql(
            r#"
            DELETE
            FROM scouter.spc_drift;

            DELETE
            FROM scouter.drift_entities;

            DELETE
            FROM scouter.observability_metric;

            DELETE
            FROM scouter.custom_drift;

            DELETE
            FROM scouter.drift_alert;

            DELETE
            FROM scouter.drift_profile;

            DELETE
            FROM scouter.psi_drift;

            DELETE
            FROM scouter.genai_eval_workflow;

            DELETE
            FROM scouter.genai_eval_task;

            DELETE
            FROM scouter.genai_eval_record;
            "#,
        )
        .execute(pool)
        .await
        .unwrap();

        RustyLogger::setup_logging(Some(LoggingConfig::new(
            None,
            Some(LogLevel::Info),
            None,
            None,
        )))
        .unwrap();
    }

    /// Loads `src/scripts/<file_name>` (relative to the current working directory),
    /// replaces every `(placeholder, value)` pair in order, executes the resulting SQL
    /// against `pool`, and then waits 10 ms.
    ///
    /// Panics if the script cannot be read or the SQL fails.
    async fn run_populate_script(
        pool: &Pool<Postgres>,
        file_name: &str,
        substitutions: &[(&str, String)],
    ) {
        let mut populate_path = std::env::current_dir().expect("Failed to get current directory");
        populate_path.push(format!("src/scripts/{}", file_name));

        let mut script = std::fs::read_to_string(populate_path).unwrap();
        for (placeholder, value) in substitutions {
            script = script.replace(*placeholder, value);
        }

        sqlx::raw_sql(&script).execute(pool).await.unwrap();
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
    }

    /// Inserts an SPC profile plus 100 `col_1` records and expects one poll to store
    /// at least one alert.
    #[tokio::test]
    async fn test_drift_executor_spc() {
        let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();

        cleanup(&db_pool).await;

        let alert_config = SpcAlertConfig {
            rule: SpcAlertRule::default(),
            schedule: "* * * * * * *".to_string(),
            features_to_monitor: vec!["col_1".to_string(), "col_3".to_string()],
            dispatch_config: AlertDispatchConfig::default(),
        };

        let config = SpcDriftConfig::new(
            "statworld",
            "test_app",
            "0.1.0",
            Some(true),
            Some(25),
            Some(alert_config),
            None,
        )
        .unwrap();

        let col1_profile = SpcFeatureDriftProfile {
            id: "col_1".to_string(),
            center: -3.997113080300062,
            one_ucl: -1.9742384896265417,
            one_lcl: -6.019987670973582,
            two_ucl: 0.048636101046978464,
            two_lcl: -8.042862261647102,
            three_ucl: 2.071510691720498,
            three_lcl: -10.065736852320622,
            timestamp: Utc::now(),
        };

        let col3_profile = SpcFeatureDriftProfile {
            id: "col_3".to_string(),
            center: -3.937652409303277,
            one_ucl: -2.0275656995100224,
            one_lcl: -5.8477391190965315,
            two_ucl: -0.1174789897167674,
            two_lcl: -7.757825828889787,
            three_ucl: 1.7926077200764872,
            three_lcl: -9.66791253868304,
            timestamp: Utc::now(),
        };

        let drift_profile = DriftProfile::Spc(SpcDriftProfile {
            config,
            features: HashMap::from([
                (col1_profile.id.clone(), col1_profile),
                (col3_profile.id.clone(), col3_profile),
            ]),
            scouter_version: "0.1.0".to_string(),
        });

        let profile_args = ProfileArgs {
            space: "statworld".to_string(),
            name: "test_app".to_string(),
            version: Some("0.1.0".to_string()),
            schedule: "* * * * * *".to_string(),
            scouter_version: "0.1.0".to_string(),
            drift_type: DriftType::Spc,
        };

        let version = Version::new(0, 1, 0);
        let uid = PostgresClient::insert_drift_profile(
            &db_pool,
            &drift_profile,
            &profile_args,
            &version,
            &true,
            &true,
        )
        .await
        .unwrap();
        let entity_id = PostgresClient::get_entity_id_from_uid(&db_pool, &uid)
            .await
            .unwrap();

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        // Values start at 10.0 and grow by 1 per record, which is above the
        // `three_ucl` of the `col_1` profile defined above.
        let mut records = vec![];
        for i in 0..100 {
            let record = SpcRecord {
                created_at: Utc::now() + chrono::Duration::seconds(i),
                uid: uid.clone(),
                feature: "col_1".to_string(),
                value: 10.0 + i as f64,
                entity_id,
            };
            records.push(record);
        }

        PostgresClient::insert_spc_drift_records_batch(&db_pool, &records, &entity_id)
            .await
            .unwrap();

        let mut drift_executor = DriftExecutor::new(&db_pool);
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        drift_executor.poll_for_tasks().await.unwrap();

        let request = DriftAlertPaginationRequest {
            active: None,
            limit: None,
            uid: uid.clone(),
            ..Default::default()
        };

        let entity_id = PostgresClient::get_entity_id_from_uid(&db_pool, &uid)
            .await
            .unwrap();

        let alerts = PostgresClient::get_paginated_drift_alerts(&db_pool, &request, &entity_id)
            .await
            .unwrap();
        assert!(!alerts.items.is_empty());
    }

    /// Populates PSI data with skew applied to `feature_1` and expects exactly one alert.
    #[tokio::test]
    async fn test_drift_executor_psi() {
        let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();

        cleanup(&db_pool).await;

        let bin_count = 1000;
        let skew_feature = "feature_1";
        let skew_factor = 10;
        let apply_skew = true;
        run_populate_script(
            &db_pool,
            "populate_psi.sql",
            &[
                ("{{bin_count}}", bin_count.to_string()),
                ("{{skew_feature}}", skew_feature.to_string()),
                ("{{skew_factor}}", skew_factor.to_string()),
                ("{{apply_skew}}", apply_skew.to_string()),
            ],
        )
        .await;

        let mut drift_executor = DriftExecutor::new(&db_pool);

        drift_executor.poll_for_tasks().await.unwrap();

        let request = DriftAlertPaginationRequest {
            uid: "019ae1b4-3003-77c1-8eed-2ec005e85963".to_string(),
            active: None,
            limit: None,
            ..Default::default()
        };

        let entity_id = PostgresClient::get_entity_id_from_space_name_version_drift_type(
            &db_pool,
            "scouter",
            "model",
            "0.1.0",
            DriftType::Psi.to_string(),
        )
        .await
        .unwrap();

        let alerts = PostgresClient::get_paginated_drift_alerts(&db_pool, &request, &entity_id)
            .await
            .unwrap();

        assert_eq!(alerts.items.len(), 1);
    }

    /// Populates PSI data with a bin count of 2 and no skew and expects no alerts.
    #[tokio::test]
    async fn test_drift_executor_psi_not_enough_target_samples() {
        let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();

        cleanup(&db_pool).await;

        let bin_count = 2;
        let skew_feature = "feature_1";
        let skew_factor = 1;
        let apply_skew = false;
        run_populate_script(
            &db_pool,
            "populate_psi.sql",
            &[
                ("{{bin_count}}", bin_count.to_string()),
                ("{{skew_feature}}", skew_feature.to_string()),
                ("{{skew_factor}}", skew_factor.to_string()),
                ("{{apply_skew}}", apply_skew.to_string()),
            ],
        )
        .await;

        let mut drift_executor = DriftExecutor::new(&db_pool);

        drift_executor.poll_for_tasks().await.unwrap();

        let request = DriftAlertPaginationRequest {
            uid: "019ae1b4-3003-77c1-8eed-2ec005e85963".to_string(),
            active: None,
            limit: None,
            ..Default::default()
        };

        let entity_id = PostgresClient::get_entity_id_from_space_name_version_drift_type(
            &db_pool,
            "scouter",
            "model",
            "0.1.0",
            DriftType::Psi.to_string(),
        )
        .await
        .unwrap();

        let alerts = PostgresClient::get_paginated_drift_alerts(&db_pool, &request, &entity_id)
            .await
            .unwrap();

        assert!(alerts.items.is_empty());
    }

    /// Runs the custom populate script unmodified and expects exactly two alerts.
    #[tokio::test]
    async fn test_drift_executor_custom() {
        let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();

        cleanup(&db_pool).await;

        run_populate_script(&db_pool, "populate_custom.sql", &[]).await;

        let mut drift_executor = DriftExecutor::new(&db_pool);

        drift_executor.poll_for_tasks().await.unwrap();

        let request = DriftAlertPaginationRequest {
            uid: "scouter|model|0.1.0|custom".to_string(),
            ..Default::default()
        };

        let entity_id = PostgresClient::get_entity_id_from_space_name_version_drift_type(
            &db_pool,
            "scouter",
            "model",
            "0.1.0",
            DriftType::Custom.to_string(),
        )
        .await
        .unwrap();

        let alerts = PostgresClient::get_paginated_drift_alerts(&db_pool, &request, &entity_id)
            .await
            .unwrap();

        assert_eq!(alerts.items.len(), 2);
    }

    /// Registers a GenAI profile, feeds 50 evaluation records through the poller, and
    /// expects the executor to store an alert whose observed value is below 0.8.
    ///
    /// This is a synchronous test that drives async code through an explicit runtime;
    /// the mock LLM server is started before that runtime is created.
    #[test]
    fn test_drift_executor_genai() {
        let mut mock = LLMTestServer::new();
        mock.start_server().unwrap();
        let runtime = tokio::runtime::Runtime::new().unwrap();

        let db_pool = runtime.block_on(async {
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();

            cleanup(&db_pool).await;
            db_pool
        });

        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let assertion_level_1 = AssertionTask {
            id: "input_check".to_string(),
            context_path: Some("input.foo".to_string()),
            operator: ComparisonOperator::Equals,
            expected_value: Value::String("bar".to_string()),
            description: Some("Check if input.foo is bar".to_string()),
            task_type: EvaluationTaskType::Assertion,
            depends_on: vec![],
            result: None,
            condition: false,
            item_context_path: None,
        };

        let judge_task = LLMJudgeTask::new_rs(
            "query_relevance",
            prompt.clone(),
            Value::Number(1.into()),
            Some("score".to_string()),
            ComparisonOperator::GreaterThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let assert_query_score = AssertionTask {
            id: "assert_score".to_string(),
            context_path: Some("query_relevance.score".to_string()),
            operator: ComparisonOperator::IsNumeric,
            expected_value: Value::Bool(true),
            depends_on: vec!["query_relevance".to_string()],
            task_type: EvaluationTaskType::Assertion,
            description: Some("Check that score is numeric".to_string()),
            result: None,
            condition: false,
            item_context_path: None,
        };

        let tasks = EvaluationTasks::new()
            .add_task(assertion_level_1)
            .add_task(judge_task)
            .add_task(assert_query_score)
            .build();

        let alert_condition = AlertCondition {
            baseline_value: 0.8,
            alert_threshold: AlertThreshold::Below,
            delta: Some(0.01),
        };

        let alert_config = GenAIAlertConfig {
            schedule: "* * * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::default(),
            alert_condition: Some(alert_condition),
        };

        let drift_config =
            GenAIEvalConfig::new("scouter", "genai_test", "0.1.0", 1.0, alert_config, None)
                .unwrap();

        let profile = runtime
            .block_on(async { GenAIEvalProfile::new(drift_config, tasks).await })
            .unwrap();
        let drift_profile = DriftProfile::GenAI(profile.clone());

        let profile_args = ProfileArgs {
            space: "scouter".to_string(),
            name: "genai_test".to_string(),
            version: Some("0.1.0".to_string()),
            schedule: "* * * * * *".to_string(),
            scouter_version: "0.1.0".to_string(),
            drift_type: DriftType::GenAI,
        };

        let version = Version::new(0, 1, 0);

        let uid = runtime.block_on(async {
            PostgresClient::insert_drift_profile(
                &db_pool,
                &drift_profile,
                &profile_args,
                &version,
                &true,
                &true,
            )
            .await
            .unwrap()
        });

        let entity_id = runtime.block_on(async {
            PostgresClient::get_entity_id_from_uid(&db_pool, &uid)
                .await
                .unwrap()
        });

        std::thread::sleep(std::time::Duration::from_secs(1));

        // Only every fourth record carries `input.foo == "bar"`; the rest carry "wrong".
        let mut records = vec![];
        for i in 0..50 {
            let context = serde_json::json!({
                "input": {
                    "foo": if i % 4 == 0 { "bar" } else { "wrong" }
                }
            });

            let record = GenAIEvalRecord::new_rs(
                context,
                Utc::now() + chrono::Duration::seconds(i),
                format!("UID{}", i),
                uid.clone(),
                None,
                None,
            );

            records.push(BoxedGenAIEvalRecord::new(record));
        }

        let mut poller = GenAIPoller::new(
            &db_pool,
            3,
            Duration::seconds(10),
            Duration::milliseconds(100),
            Duration::seconds(30),
        );

        // Each record is inserted and then immediately followed by one poller round.
        for record in records {
            runtime.block_on(async {
                PostgresClient::insert_genai_eval_record(&db_pool, record, &entity_id)
                    .await
                    .unwrap();

                poller.do_poll().await.unwrap();
            });
        }

        let mut drift_executor = DriftExecutor::new(&db_pool);

        runtime.block_on(async {
            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
            drift_executor.poll_for_tasks().await.unwrap();
            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        });

        let request = DriftAlertPaginationRequest {
            uid: uid.clone(),
            active: None,
            limit: None,
            ..Default::default()
        };

        let alerts = runtime.block_on(async {
            PostgresClient::get_paginated_drift_alerts(&db_pool, &request, &entity_id)
                .await
                .unwrap()
        });

        assert!(
            !alerts.items.is_empty(),
            "Expected drift alerts to be generated for low pass rate"
        );

        let alert = &alerts.items[0];

        assert_eq!(alert.alert.entity_name(), "genai_workflow_metric");

        let observed_value: f64 = match &alert.alert {
            AlertMap::GenAI(genai_alert) => genai_alert.observed_value,
            _ => panic!("Expected GenAI alert map"),
        };

        assert!(
            observed_value < 0.8,
            "Expected low pass rate to trigger alert"
        );

        mock.stop_server().unwrap();
    }
}