use crate::error::DriftError;
use chrono::{DateTime, Utc};
use scouter_dispatch::AlertDispatcher;
use scouter_sql::sql::traits::GenAIDriftSqlLogic;
use scouter_sql::{sql::cache::entity_cache, PostgresClient};
use scouter_types::{custom::ComparisonMetricAlert, genai::GenAIEvalProfile};
use scouter_types::{AlertMap, ProfileBaseArgs};
use sqlx::{Pool, Postgres};
use tracing::error;
use tracing::info;
/// Drift checker bound to a single GenAI evaluation profile.
///
/// The wrapped profile supplies the identity (`space`/`name`/`version` and
/// `config.uid`) and the alert configuration that the methods on this type read.
pub struct GenAIDrifter {
    /// Evaluation profile this drifter reads its identity and alert settings from.
    profile: GenAIEvalProfile,
}
impl GenAIDrifter {
    /// Creates a drifter that operates on `profile`.
    pub fn new(profile: GenAIEvalProfile) -> Self {
        Self { profile }
    }

    /// Returns the `space/name/version` label used to identify the profile in log messages.
    fn profile_id(&self) -> String {
        format!(
            "{}/{}/{}",
            self.profile.space(),
            self.profile.name(),
            self.profile.version()
        )
    }

    /// Fetches the GenAI workflow value for this profile from the database.
    ///
    /// The profile's `config.uid` is first resolved to an entity id through the
    /// entity cache; that id and `limit_datetime` are then passed to
    /// `PostgresClient::get_genai_workflow_value`, whose result is returned as-is.
    ///
    /// NOTE(review): `limit_datetime` is presumably the lower time bound of the
    /// query (`check_for_alerts` passes the previous run time) — confirm against
    /// the SQL implementation.
    ///
    /// # Errors
    ///
    /// Returns a [`DriftError`] when the entity id lookup fails or when the
    /// workflow value query fails. Query failures are logged at `error` level
    /// before being converted.
    pub async fn get_workflow_value(
        &self,
        limit_datetime: &DateTime<Utc>,
        db_pool: &Pool<Postgres>,
    ) -> Result<Option<f64>, DriftError> {
        let entity_id = entity_cache()
            .get_entity_id_from_uid(db_pool, &self.profile.config.uid)
            .await?;

        PostgresClient::get_genai_workflow_value(db_pool, limit_datetime, &entity_id)
            .await
            .inspect_err(|e| {
                error!(
                    "Unable to obtain genai metric data from DB for {}: {}",
                    self.profile_id(),
                    e
                );
            })
            .map_err(Into::into)
    }

    /// Same lookup as [`Self::get_workflow_value`], with an `info` log entry when
    /// the query yields `None`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Self::get_workflow_value`].
    pub async fn get_metric_value(
        &self,
        limit_datetime: &DateTime<Utc>,
        db_pool: &Pool<Postgres>,
    ) -> Result<Option<f64>, DriftError> {
        let value = self.get_workflow_value(limit_datetime, db_pool).await?;
        if value.is_none() {
            info!(
                "No genai metric data found for {}. Skipping alert processing.",
                self.profile_id()
            );
        }
        Ok(value)
    }

    /// Evaluates `observed_value` against the profile's alert condition and
    /// dispatches an alert when the condition says one is due.
    ///
    /// Returns `Ok(None)` when the profile has no alert condition configured or
    /// when `should_alert(observed_value)` is `false`. Otherwise a
    /// `ComparisonMetricAlert` named `genai_workflow_metric` is built, handed to
    /// an `AlertDispatcher` created from the profile config, and returned as a
    /// single-element `Vec` wrapped in `AlertMap::GenAI`.
    ///
    /// # Errors
    ///
    /// Returns a [`DriftError`] when the dispatcher cannot be created or when
    /// processing the alert fails; both failures are logged at `error` level.
    pub async fn generate_alerts(
        &self,
        observed_value: f64,
    ) -> Result<Option<Vec<AlertMap>>, DriftError> {
        let Some(alert_condition) = &self.profile.config.alert_config.alert_condition else {
            info!(
                "No alert condition configured for {}. Skipping alert processing.",
                self.profile_id()
            );
            return Ok(None);
        };

        if !alert_condition.should_alert(observed_value) {
            info!(
                "No alerts to process for {} (observed: {}, baseline: {})",
                self.profile_id(),
                observed_value,
                alert_condition.baseline_value
            );
            return Ok(None);
        }

        // The metric name is only needed here, so it is built in place rather
        // than allocated once and cloned.
        let comparison_alert = ComparisonMetricAlert {
            metric_name: "genai_workflow_metric".to_string(),
            baseline_value: alert_condition.baseline_value,
            observed_value,
            delta: alert_condition.delta,
            alert_threshold: alert_condition.alert_threshold.clone(),
        };

        let alert_dispatcher = AlertDispatcher::new(&self.profile.config).inspect_err(|e| {
            error!(
                "Error creating alert dispatcher for {}: {}",
                self.profile_id(),
                e
            );
        })?;

        alert_dispatcher
            .process_alerts(&comparison_alert)
            .await
            .inspect_err(|e| {
                error!("Error processing alerts for {}: {}", self.profile_id(), e);
            })?;

        Ok(Some(vec![AlertMap::GenAI(comparison_alert)]))
    }

    /// Fetches the metric value using `previous_run` as the time bound and, when
    /// a value exists, runs it through [`Self::generate_alerts`].
    ///
    /// Returns `Ok(None)` when no metric value is available or when no alert was
    /// generated.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::get_metric_value`] and
    /// [`Self::generate_alerts`]; the latter are additionally logged here.
    pub async fn check_for_alerts(
        &self,
        db_pool: &Pool<Postgres>,
        previous_run: &DateTime<Utc>,
    ) -> Result<Option<Vec<AlertMap>>, DriftError> {
        let Some(metric_value) = self.get_metric_value(previous_run, db_pool).await? else {
            return Ok(None);
        };

        self.generate_alerts(metric_value).await.inspect_err(|e| {
            error!("Error generating alerts for {}: {}", self.profile_id(), e);
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use potato_head::mock::create_score_prompt;
    use scouter_types::genai::{ComparisonOperator, EvaluationTasks};
    use scouter_types::genai::{GenAIAlertConfig, GenAIEvalConfig, GenAIEvalProfile, LLMJudgeTask};
    use scouter_types::{
        AlertCondition, AlertDispatchConfig, AlertThreshold, ConsoleDispatchConfig,
    };
    use serde_json::Value;

    /// Builds a drifter fixture with two LLM judge tasks, console dispatch, and
    /// an alert condition of baseline `5.0`, threshold `Below`, delta `1.0`.
    async fn get_test_drifter() -> GenAIDrifter {
        let score_prompt = create_score_prompt(Some(vec!["input".to_string()]));

        // Both judge tasks share the same mock prompt and differ only in name,
        // expected value and comparison operator.
        let evaluation_tasks = EvaluationTasks::new()
            .add_task(LLMJudgeTask::new_rs(
                "metric1",
                score_prompt.clone(),
                Value::Number(4.into()),
                None,
                ComparisonOperator::GreaterThanOrEqual,
                None,
                None,
                None,
                None,
            ))
            .add_task(LLMJudgeTask::new_rs(
                "metric2",
                score_prompt.clone(),
                Value::Number(2.into()),
                None,
                ComparisonOperator::LessThanOrEqual,
                None,
                None,
                None,
                None,
            ))
            .build();

        let alerting = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Console(ConsoleDispatchConfig { enabled: true }),
            alert_condition: Some(AlertCondition {
                baseline_value: 5.0,
                alert_threshold: AlertThreshold::Below,
                delta: Some(1.0),
            }),
        };

        let eval_config =
            GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alerting, None).unwrap();
        let eval_profile = GenAIEvalProfile::new(eval_config, evaluation_tasks)
            .await
            .unwrap();

        GenAIDrifter::new(eval_profile)
    }

    /// An observed value of `3.0` against the fixture's condition is expected to
    /// yield a single GenAI alert carrying that value.
    #[tokio::test]
    async fn test_generate_alerts_triggers_when_threshold_exceeded() {
        let observed_value = 3.0;
        let drifter = get_test_drifter().await;

        let outcome = drifter
            .generate_alerts(observed_value)
            .await
            .expect("Should generate alerts successfully");
        assert!(
            outcome.is_some(),
            "Should generate alerts for out-of-bounds value"
        );

        let triggered = outcome.unwrap();
        match &triggered[0] {
            AlertMap::GenAI(alert) => {
                assert_eq!(alert.metric_name, "genai_workflow_metric");
                assert_eq!(alert.observed_value, observed_value);
            }
            _ => panic!("Expected GenAI alert map"),
        }
    }

    /// An observed value equal to the baseline (`5.0`) is expected to produce no alert.
    #[tokio::test]
    async fn test_generate_alerts_no_trigger_within_threshold() {
        let observed_value = 5.0;
        let drifter = get_test_drifter().await;

        let outcome = drifter
            .generate_alerts(observed_value)
            .await
            .expect("Should generate alerts successfully");

        assert!(
            outcome.is_none(),
            "Should not generate alerts for value within threshold"
        );
    }
}