Skip to main content

scouter_drift/genai/
drift.rs

1use crate::error::DriftError;
2use chrono::{DateTime, Utc};
3use scouter_dispatch::AlertDispatcher;
4use scouter_sql::sql::traits::GenAIDriftSqlLogic;
5use scouter_sql::{sql::cache::entity_cache, PostgresClient};
6use scouter_types::{custom::ComparisonMetricAlert, genai::GenAIEvalProfile};
7use scouter_types::{AlertMap, ProfileBaseArgs};
8use sqlx::{Pool, Postgres};
9use tracing::error;
10use tracing::info;
11
12pub struct GenAIDrifter {
13    profile: GenAIEvalProfile,
14}
15
16impl GenAIDrifter {
17    pub fn new(profile: GenAIEvalProfile) -> Self {
18        Self { profile }
19    }
20
21    /// Helper method to format profile identifier for logging
22    fn profile_id(&self) -> String {
23        format!(
24            "{}/{}/{}",
25            self.profile.space(),
26            self.profile.name(),
27            self.profile.version()
28        )
29    }
30
31    /// Fetches workflow value from database for the given time limit
32    pub async fn get_workflow_value(
33        &self,
34        limit_datetime: &DateTime<Utc>,
35        db_pool: &Pool<Postgres>,
36    ) -> Result<Option<f64>, DriftError> {
37        let entity_id = entity_cache()
38            .get_entity_id_from_uid(db_pool, &self.profile.config.uid)
39            .await?;
40
41        PostgresClient::get_genai_workflow_value(db_pool, limit_datetime, &entity_id)
42            .await
43            .inspect_err(|e| {
44                error!(
45                    "Unable to obtain genai metric data from DB for {}: {}",
46                    self.profile_id(),
47                    e
48                );
49            })
50            .map_err(Into::into)
51    }
52
53    /// Retrieves metric value and logs if not found
54    pub async fn get_metric_value(
55        &self,
56        limit_datetime: &DateTime<Utc>,
57        db_pool: &Pool<Postgres>,
58    ) -> Result<Option<f64>, DriftError> {
59        let value = self.get_workflow_value(limit_datetime, db_pool).await?;
60
61        if value.is_none() {
62            info!(
63                "No genai metric data found for {}. Skipping alert processing.",
64                self.profile_id()
65            );
66        }
67
68        Ok(value)
69    }
70
71    /// Generates alerts if conditions are met and dispatches them
72    pub async fn generate_alerts(
73        &self,
74        observed_value: f64,
75    ) -> Result<Option<Vec<AlertMap>>, DriftError> {
76        // Early return if no alert condition configured
77        let Some(alert_condition) = &self.profile.config.alert_config.alert_condition else {
78            info!(
79                "No alert condition configured for {}. Skipping alert processing.",
80                self.profile_id()
81            );
82            return Ok(None);
83        };
84
85        // Check if alert should be triggered
86        if !alert_condition.should_alert(observed_value) {
87            info!(
88                "No alerts to process for {} (observed: {}, baseline: {})",
89                self.profile_id(),
90                observed_value,
91                alert_condition.baseline_value
92            );
93            return Ok(None);
94        }
95
96        // Build comparison alert with owned data
97        let metric_name = "genai_workflow_metric".to_string();
98        let comparison_alert = ComparisonMetricAlert {
99            metric_name: metric_name.clone(),
100            baseline_value: alert_condition.baseline_value,
101            observed_value,
102            delta: alert_condition.delta,
103            alert_threshold: alert_condition.alert_threshold.clone(),
104        };
105
106        // Dispatch alert
107        let alert_dispatcher = AlertDispatcher::new(&self.profile.config).inspect_err(|e| {
108            error!(
109                "Error creating alert dispatcher for {}: {}",
110                self.profile_id(),
111                e
112            );
113        })?;
114
115        alert_dispatcher
116            .process_alerts(&comparison_alert)
117            .await
118            .inspect_err(|e| {
119                error!("Error processing alerts for {}: {}", self.profile_id(), e);
120            })?;
121
122        // Convert to owned map before returning
123        Ok(Some(vec![AlertMap::GenAI(comparison_alert)]))
124    }
125
126    /// Checks for alerts based on metric value since previous run
127    pub async fn check_for_alerts(
128        &self,
129        db_pool: &Pool<Postgres>,
130        previous_run: &DateTime<Utc>,
131    ) -> Result<Option<Vec<AlertMap>>, DriftError> {
132        let Some(metric_value) = self.get_metric_value(previous_run, db_pool).await? else {
133            return Ok(None);
134        };
135
136        self.generate_alerts(metric_value).await.inspect_err(|e| {
137            error!("Error generating alerts for {}: {}", self.profile_id(), e);
138        })
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use potato_head::mock::create_score_prompt;
    use scouter_types::genai::{ComparisonOperator, EvaluationTasks};
    use scouter_types::genai::{GenAIAlertConfig, GenAIEvalConfig, GenAIEvalProfile, LLMJudgeTask};
    use scouter_types::{
        AlertCondition, AlertDispatchConfig, AlertThreshold, ConsoleDispatchConfig,
    };
    use serde_json::Value;

    /// Builds a drifter with two LLM-judge tasks, console alert dispatch and the given
    /// (optional) alert condition.
    async fn build_drifter(alert_condition: Option<AlertCondition>) -> GenAIDrifter {
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let task1 = LLMJudgeTask::new_rs(
            "metric1",
            prompt.clone(),
            Value::Number(4.into()),
            None,
            ComparisonOperator::GreaterThanOrEqual,
            None,
            None,
            None,
            None,
        );

        // Last use of `prompt`, so it can be moved instead of cloned.
        let task2 = LLMJudgeTask::new_rs(
            "metric2",
            prompt,
            Value::Number(2.into()),
            None,
            ComparisonOperator::LessThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let tasks = EvaluationTasks::new()
            .add_task(task1)
            .add_task(task2)
            .build();

        let alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Console(ConsoleDispatchConfig { enabled: true }),
            alert_condition,
        };

        let drift_config =
            GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alert_config, None).unwrap();

        let profile = GenAIEvalProfile::new(drift_config, tasks).await.unwrap();

        GenAIDrifter::new(profile)
    }

    /// Default fixture: baseline 5.0, `Below` threshold, delta 1.0.
    async fn get_test_drifter() -> GenAIDrifter {
        build_drifter(Some(AlertCondition {
            baseline_value: 5.0,
            alert_threshold: AlertThreshold::Below,
            delta: Some(1.0),
        }))
        .await
    }

    #[tokio::test]
    async fn test_generate_alerts_triggers_when_threshold_exceeded() {
        let drifter = get_test_drifter().await;

        let observed_value = 3.0;
        let alerts = drifter
            .generate_alerts(observed_value)
            .await
            .expect("Should generate alerts successfully")
            .expect("Should generate alerts for out-of-bounds value");

        assert_eq!(alerts.len(), 1, "Expected exactly one alert");

        match &alerts[0] {
            AlertMap::GenAI(alert) => {
                assert_eq!(alert.metric_name, "genai_workflow_metric");
                assert_eq!(alert.observed_value, observed_value);
                assert_eq!(alert.baseline_value, 5.0);
                assert_eq!(alert.delta, Some(1.0));
                assert!(matches!(alert.alert_threshold, AlertThreshold::Below));
            }
            _ => panic!("Expected GenAI alert map"),
        }
    }

    #[tokio::test]
    async fn test_generate_alerts_no_trigger_within_threshold() {
        let drifter = get_test_drifter().await;

        // Workflow metric value within acceptable range
        let observed_value = 5.0;
        let alerts = drifter
            .generate_alerts(observed_value)
            .await
            .expect("Should generate alerts successfully");

        assert!(
            alerts.is_none(),
            "Should not generate alerts for value within threshold"
        );
    }

    #[tokio::test]
    async fn test_generate_alerts_skipped_without_alert_condition() {
        let drifter = build_drifter(None).await;

        // Without an alert condition every value is skipped, even one that would
        // otherwise trigger under the default fixture.
        let alerts = drifter
            .generate_alerts(0.0)
            .await
            .expect("Missing alert condition should not be an error");

        assert!(
            alerts.is_none(),
            "Should not generate alerts when no alert condition is configured"
        );
    }
}