1use crate::error::DriftError;
2use chrono::{DateTime, Utc};
3use scouter_dispatch::AlertDispatcher;
4use scouter_sql::sql::traits::LLMDriftSqlLogic;
5use scouter_sql::PostgresClient;
6use scouter_types::contracts::ServiceInfo;
7use scouter_types::{custom::ComparisonMetricAlert, llm::LLMDriftProfile, AlertThreshold};
8use sqlx::{Pool, Postgres};
9use std::collections::{BTreeMap, HashMap};
10use tracing::error;
11use tracing::info;
12
13pub struct LLMDrifter {
14 service_info: ServiceInfo,
15 profile: LLMDriftProfile,
16}
17
18impl LLMDrifter {
19 pub fn new(profile: LLMDriftProfile) -> Self {
20 Self {
21 service_info: ServiceInfo {
22 name: profile.config.name.clone(),
23 space: profile.config.space.clone(),
24 version: profile.config.version.clone(),
25 },
26 profile,
27 }
28 }
29
30 pub async fn get_observed_llm_metric_values(
31 &self,
32 limit_datetime: &DateTime<Utc>,
33 db_pool: &Pool<Postgres>,
34 ) -> Result<HashMap<String, f64>, DriftError> {
35 let metrics: Vec<String> = self
36 .profile
37 .metrics
38 .iter()
39 .map(|metric| metric.name.clone())
40 .collect();
41
42 Ok(PostgresClient::get_llm_metric_values(
43 db_pool,
44 &self.service_info,
45 limit_datetime,
46 &metrics,
47 )
48 .await
49 .inspect_err(|e| {
50 let msg = format!(
51 "Error: Unable to obtain llm metric data from DB for {}/{}/{}: {}",
52 self.service_info.space, self.service_info.name, self.service_info.version, e
53 );
54 error!(msg);
55 })?)
56 }
57
58 pub async fn get_metric_map(
59 &self,
60 limit_datetime: &DateTime<Utc>,
61 db_pool: &Pool<Postgres>,
62 ) -> Result<Option<HashMap<String, f64>>, DriftError> {
63 let metric_map = self
64 .get_observed_llm_metric_values(limit_datetime, db_pool)
65 .await?;
66
67 if metric_map.is_empty() {
68 info!(
69 "No llm metric data was found for {}/{}/{}. Skipping alert processing.",
70 self.service_info.space, self.service_info.name, self.service_info.version,
71 );
72 return Ok(None);
73 }
74
75 Ok(Some(metric_map))
76 }
77
78 fn is_out_of_bounds(
79 training_value: f64,
80 observed_value: f64,
81 alert_condition: &AlertThreshold,
82 alert_boundary: Option<f64>,
83 ) -> bool {
84 if observed_value == training_value {
85 return false;
86 }
87
88 let below_threshold = |boundary: Option<f64>| match boundary {
89 Some(b) => observed_value < training_value - b,
90 None => observed_value < training_value,
91 };
92
93 let above_threshold = |boundary: Option<f64>| match boundary {
94 Some(b) => observed_value > training_value + b,
95 None => observed_value > training_value,
96 };
97
98 match alert_condition {
99 AlertThreshold::Below => below_threshold(alert_boundary),
100 AlertThreshold::Above => above_threshold(alert_boundary),
101 AlertThreshold::Outside => {
102 below_threshold(alert_boundary) || above_threshold(alert_boundary)
103 } }
105 }
106
107 pub async fn generate_alerts(
108 &self,
109 metric_map: &HashMap<String, f64>,
110 ) -> Result<Option<Vec<ComparisonMetricAlert>>, DriftError> {
111 let metric_alerts: Vec<ComparisonMetricAlert> = metric_map
112 .iter()
113 .filter_map(|(name, observed_value)| {
114 let training_value = self
115 .profile
116 .get_metric_value(name)
117 .inspect_err(|e| {
118 let msg = format!("Error getting training value for metric {name}: {e}");
119 error!(msg);
120 })
121 .ok()?;
122 let alert_condition = &self
123 .profile
124 .config
125 .alert_config
126 .alert_conditions
127 .as_ref()
128 .unwrap()[name];
129 if Self::is_out_of_bounds(
130 training_value,
131 *observed_value,
132 &alert_condition.alert_threshold,
133 alert_condition.alert_threshold_value,
134 ) {
135 Some(ComparisonMetricAlert {
136 metric_name: name.clone(),
137 training_metric_value: training_value,
138 observed_metric_value: *observed_value,
139 alert_threshold_value: alert_condition.alert_threshold_value,
140 alert_threshold: alert_condition.alert_threshold.clone(),
141 })
142 } else {
143 None
144 }
145 })
146 .collect();
147
148 if metric_alerts.is_empty() {
149 info!(
150 "No alerts to process for {}/{}/{}",
151 self.service_info.space, self.service_info.name, self.service_info.version
152 );
153 return Ok(None);
154 }
155
156 let alert_dispatcher = AlertDispatcher::new(&self.profile.config).inspect_err(|e| {
157 let msg = format!(
158 "Error creating alert dispatcher for {}/{}/{}: {}",
159 self.service_info.space, self.service_info.name, self.service_info.version, e
160 );
161 error!(msg);
162 })?;
163
164 for alert in &metric_alerts {
165 alert_dispatcher
166 .process_alerts(alert)
167 .await
168 .inspect_err(|e| {
169 let msg = format!(
170 "Error processing alerts for {}/{}/{}: {}",
171 self.service_info.space,
172 self.service_info.name,
173 self.service_info.version,
174 e
175 );
176 error!(msg);
177 })?;
178 }
179
180 Ok(Some(metric_alerts))
181 }
182
183 fn organize_alerts(mut alerts: Vec<ComparisonMetricAlert>) -> Vec<BTreeMap<String, String>> {
184 let mut alert_vec = Vec::new();
185 alerts.iter_mut().for_each(|alert| {
186 let mut alert_map = BTreeMap::new();
187 alert_map.insert("entity_name".to_string(), alert.metric_name.clone());
188 alert_map.insert(
189 "training_metric_value".to_string(),
190 alert.training_metric_value.to_string(),
191 );
192 alert_map.insert(
193 "observed_metric_value".to_string(),
194 alert.observed_metric_value.to_string(),
195 );
196 let alert_threshold_value_str = match alert.alert_threshold_value {
197 Some(value) => value.to_string(),
198 None => "None".to_string(),
199 };
200 alert_map.insert(
201 "alert_threshold_value".to_string(),
202 alert_threshold_value_str,
203 );
204 alert_map.insert(
205 "alert_threshold".to_string(),
206 alert.alert_threshold.to_string(),
207 );
208 alert_vec.push(alert_map);
209 });
210
211 alert_vec
212 }
213
214 pub async fn check_for_alerts(
215 &self,
216 db_pool: &Pool<Postgres>,
217 previous_run: DateTime<Utc>,
218 ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
219 let metric_map = self.get_metric_map(&previous_run, db_pool).await?;
220
221 match metric_map {
222 Some(metric_map) => {
223 let alerts = self.generate_alerts(&metric_map).await.inspect_err(|e| {
224 let msg = format!(
225 "Error generating alerts for {}/{}/{}: {}",
226 self.service_info.space,
227 self.service_info.name,
228 self.service_info.version,
229 e
230 );
231 error!(msg);
232 })?;
233 match alerts {
234 Some(alerts) => Ok(Some(Self::organize_alerts(alerts))),
235 None => Ok(None),
236 }
237 }
238 None => Ok(None),
239 }
240 }
241}
242
243#[cfg(test)]
244mod tests {
245 use super::*;
246 use potato_head::{create_score_prompt, LLMTestServer};
247 use scouter_types::llm::{LLMAlertConfig, LLMDriftConfig, LLMDriftMetric, LLMDriftProfile};
248
249 async fn get_test_drifter() -> LLMDrifter {
250 let prompt = create_score_prompt(Some(vec!["input".to_string()]));
251 let metric1 = LLMDriftMetric::new(
252 "coherence",
253 5.0,
254 AlertThreshold::Below,
255 Some(0.5),
256 Some(prompt.clone()),
257 )
258 .unwrap();
259
260 let metric2 = LLMDriftMetric::new(
261 "relevancy",
262 5.0,
263 AlertThreshold::Below,
264 None,
265 Some(prompt.clone()),
266 )
267 .unwrap();
268
269 let alert_config = LLMAlertConfig::default();
270 let drift_config =
271 LLMDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();
272
273 let profile = LLMDriftProfile::from_metrics(drift_config, vec![metric1, metric2])
274 .await
275 .unwrap();
276
277 LLMDrifter::new(profile)
278 }
279
280 #[test]
281 fn test_is_out_of_bounds() {
282 let relevancy_training_value = 5.0;
284
285 let relevancy_observed_value = 4.0;
287
288 let relevancy_alert_condition = AlertThreshold::Below;
290
291 let relevancy_alert_boundary = Some(0.5);
294
295 let relevancy_is_out_of_bounds = LLMDrifter::is_out_of_bounds(
296 relevancy_training_value,
297 relevancy_observed_value,
298 &relevancy_alert_condition,
299 relevancy_alert_boundary,
300 );
301 assert!(relevancy_is_out_of_bounds);
302
303 let coherence_training_value = 0.76;
307
308 let coherence_observed_value = 0.67;
310
311 let coherence_alert_condition = AlertThreshold::Below;
313
314 let coherence_alert_boundary = None;
316
317 let coherence_is_out_of_bounds = LLMDrifter::is_out_of_bounds(
318 coherence_training_value,
319 coherence_observed_value,
320 &coherence_alert_condition,
321 coherence_alert_boundary,
322 );
323 assert!(coherence_is_out_of_bounds);
324 }
325
326 #[test]
327 fn test_generate_llm_alerts() {
328 let mut mock = LLMTestServer::new();
329 mock.start_server().unwrap();
330 let runtime = tokio::runtime::Runtime::new().unwrap();
331
332 let mut metric_map = HashMap::new();
333 metric_map.insert("coherence".to_string(), 4.0);
335 metric_map.insert("relevancy".to_string(), 4.5);
337
338 let alerts = runtime.block_on(async {
339 let drifter = get_test_drifter().await;
340 drifter.generate_alerts(&metric_map).await.unwrap().unwrap()
341 });
342
343 assert_eq!(alerts.len(), 2);
344 mock.stop_server().unwrap();
345 }
346}