scouter-drift 0.25.0

Drift logic for Scouter
Documentation
use crate::error::DriftError;
use chrono::{DateTime, Utc};
use scouter_dispatch::AlertDispatcher;
use scouter_sql::sql::traits::GenAIDriftSqlLogic;
use scouter_sql::{sql::cache::entity_cache, PostgresClient};
use scouter_types::{custom::ComparisonMetricAlert, genai::GenAIEvalProfile};
use scouter_types::{AlertMap, ProfileBaseArgs};
use sqlx::{Pool, Postgres};
use tracing::error;
use tracing::info;

/// Checks a GenAI evaluation profile's workflow metric against the profile's
/// configured alert condition and dispatches an alert when the condition fires.
pub struct GenAIDrifter {
    profile: GenAIEvalProfile,
}

impl GenAIDrifter {
    /// Creates a drifter for the given GenAI evaluation profile.
    pub fn new(profile: GenAIEvalProfile) -> Self {
        Self { profile }
    }

    /// Formats the profile identifier as `space/name/version` for log messages.
    fn profile_id(&self) -> String {
        format!(
            "{}/{}/{}",
            self.profile.space(),
            self.profile.name(),
            self.profile.version()
        )
    }

    /// Fetches the GenAI workflow value from the database, using
    /// `limit_datetime` as the time bound for the query.
    ///
    /// The profile's `config.uid` is first resolved to an entity id through the
    /// shared entity cache; that id is then passed to
    /// `PostgresClient::get_genai_workflow_value`.
    ///
    /// Returns `Ok(None)` when the query yields no value.
    ///
    /// # Errors
    ///
    /// Returns a [`DriftError`] if the entity id cannot be resolved or the
    /// database query fails. Both failures are logged with the profile id.
    pub async fn get_workflow_value(
        &self,
        limit_datetime: &DateTime<Utc>,
        db_pool: &Pool<Postgres>,
    ) -> Result<Option<f64>, DriftError> {
        // Log lookup failures the same way as query failures below, so every
        // error path of this method carries the profile id.
        let entity_id = entity_cache()
            .get_entity_id_from_uid(db_pool, &self.profile.config.uid)
            .await
            .inspect_err(|e| {
                error!(
                    "Unable to resolve entity id for {}: {}",
                    self.profile_id(),
                    e
                );
            })?;

        PostgresClient::get_genai_workflow_value(db_pool, limit_datetime, &entity_id)
            .await
            .inspect_err(|e| {
                error!(
                    "Unable to obtain genai metric data from DB for {}: {}",
                    self.profile_id(),
                    e
                );
            })
            .map_err(Into::into)
    }

    /// Retrieves the workflow metric value, logging at `info` level when no
    /// value is found.
    ///
    /// # Errors
    ///
    /// Propagates any [`DriftError`] from [`Self::get_workflow_value`].
    pub async fn get_metric_value(
        &self,
        limit_datetime: &DateTime<Utc>,
        db_pool: &Pool<Postgres>,
    ) -> Result<Option<f64>, DriftError> {
        let value = self.get_workflow_value(limit_datetime, db_pool).await?;

        if value.is_none() {
            info!(
                "No genai metric data found for {}. Skipping alert processing.",
                self.profile_id()
            );
        }

        Ok(value)
    }

    /// Evaluates `observed_value` against the profile's alert condition and,
    /// if it should alert, dispatches a comparison alert.
    ///
    /// Returns `Ok(None)` when no alert condition is configured or the
    /// condition does not fire; otherwise returns the single dispatched alert.
    ///
    /// # Errors
    ///
    /// Returns a [`DriftError`] if the alert dispatcher cannot be created or
    /// alert processing fails. Both failures are logged with the profile id.
    pub async fn generate_alerts(
        &self,
        observed_value: f64,
    ) -> Result<Option<Vec<AlertMap>>, DriftError> {
        // Early return if no alert condition configured
        let Some(alert_condition) = &self.profile.config.alert_config.alert_condition else {
            info!(
                "No alert condition configured for {}. Skipping alert processing.",
                self.profile_id()
            );
            return Ok(None);
        };

        // Check if alert should be triggered
        if !alert_condition.should_alert(observed_value) {
            info!(
                "No alerts to process for {} (observed: {}, baseline: {})",
                self.profile_id(),
                observed_value,
                alert_condition.baseline_value
            );
            return Ok(None);
        }

        // The condition is borrowed from the profile, so the threshold is cloned
        // into the owned alert.
        let comparison_alert = ComparisonMetricAlert {
            metric_name: "genai_workflow_metric".to_string(),
            baseline_value: alert_condition.baseline_value,
            observed_value,
            delta: alert_condition.delta,
            alert_threshold: alert_condition.alert_threshold.clone(),
        };

        // Dispatch alert
        let alert_dispatcher = AlertDispatcher::new(&self.profile.config).inspect_err(|e| {
            error!(
                "Error creating alert dispatcher for {}: {}",
                self.profile_id(),
                e
            );
        })?;

        alert_dispatcher
            .process_alerts(&comparison_alert)
            .await
            .inspect_err(|e| {
                error!("Error processing alerts for {}: {}", self.profile_id(), e);
            })?;

        Ok(Some(vec![AlertMap::GenAI(comparison_alert)]))
    }

    /// Checks for alerts based on the metric value since `previous_run`.
    ///
    /// Returns `Ok(None)` when there is no metric value or no alert fires.
    ///
    /// # Errors
    ///
    /// Propagates any [`DriftError`] from metric retrieval or alert generation.
    pub async fn check_for_alerts(
        &self,
        db_pool: &Pool<Postgres>,
        previous_run: &DateTime<Utc>,
    ) -> Result<Option<Vec<AlertMap>>, DriftError> {
        let Some(metric_value) = self.get_metric_value(previous_run, db_pool).await? else {
            return Ok(None);
        };

        self.generate_alerts(metric_value).await.inspect_err(|e| {
            error!("Error generating alerts for {}: {}", self.profile_id(), e);
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use potato_head::mock::create_score_prompt;
    use scouter_types::genai::{ComparisonOperator, EvaluationTasks};
    use scouter_types::genai::{GenAIAlertConfig, GenAIEvalConfig, GenAIEvalProfile, LLMJudgeTask};
    use scouter_types::{
        AlertCondition, AlertDispatchConfig, AlertThreshold, ConsoleDispatchConfig,
    };
    use serde_json::Value;

    /// Builds a drifter with two LLM-judge tasks and an alert condition of
    /// baseline 5.0, threshold `Below`, delta 1.0, dispatched to the console.
    async fn get_test_drifter() -> GenAIDrifter {
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let task1 = LLMJudgeTask::new_rs(
            "metric1",
            prompt.clone(),
            Value::Number(4.into()),
            None,
            ComparisonOperator::GreaterThanOrEqual,
            None,
            None,
            None,
            None,
        );

        // Last use of `prompt`, so it is moved rather than cloned.
        let task2 = LLMJudgeTask::new_rs(
            "metric2",
            prompt,
            Value::Number(2.into()),
            None,
            ComparisonOperator::LessThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let tasks = EvaluationTasks::new()
            .add_task(task1)
            .add_task(task2)
            .build();

        let alert_condition = AlertCondition {
            baseline_value: 5.0,
            alert_threshold: AlertThreshold::Below,
            delta: Some(1.0),
        };
        let alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Console(ConsoleDispatchConfig { enabled: true }),
            alert_condition: Some(alert_condition),
        };

        let drift_config =
            GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alert_config, None).unwrap();

        let profile = GenAIEvalProfile::new(drift_config, tasks).await.unwrap();

        GenAIDrifter::new(profile)
    }

    #[tokio::test]
    async fn test_generate_alerts_triggers_when_threshold_exceeded() {
        let drifter = get_test_drifter().await;

        // 3.0 sits 2.0 below the 5.0 baseline, more than the configured delta of 1.0.
        let observed_value = 3.0;
        let alerts = drifter
            .generate_alerts(observed_value)
            .await
            .expect("Should generate alerts successfully")
            .expect("Should generate alerts for out-of-bounds value");

        assert_eq!(alerts.len(), 1, "Should generate exactly one alert");

        match &alerts[0] {
            AlertMap::GenAI(alert) => {
                assert_eq!(alert.metric_name, "genai_workflow_metric");
                assert_eq!(alert.observed_value, observed_value);
                assert_eq!(alert.baseline_value, 5.0);
            }
            _ => panic!("Expected GenAI alert map"),
        }
    }

    #[tokio::test]
    async fn test_generate_alerts_no_trigger_within_threshold() {
        let drifter = get_test_drifter().await;

        // Workflow metric value equal to the baseline, i.e. within acceptable range
        let observed_value = 5.0;
        let alerts = drifter
            .generate_alerts(observed_value)
            .await
            .expect("Should generate alerts successfully");

        assert!(
            alerts.is_none(),
            "Should not generate alerts for value within threshold"
        );
    }
}