scouter_drift/llm/
poller.rs1use crate::error::DriftError;
3use crate::llm::evaluator::LLMEvaluator;
4use potato_head::Score;
5use scouter_sql::sql::traits::{LLMDriftSqlLogic, ProfileSqlLogic};
6use scouter_sql::PostgresClient;
7use scouter_types::llm::LLMDriftProfile;
8use scouter_types::{DriftType, GetProfileRequest, LLMRecord, Status};
9use sqlx::{Pool, Postgres};
10use std::collections::HashMap;
11use std::time::Duration;
12use tokio::time::sleep;
13use tracing::{debug, error, info, instrument};
14
15pub struct LLMPoller {
16 db_pool: Pool<Postgres>,
17 max_retries: usize,
18}
19
20impl LLMPoller {
21 pub fn new(db_pool: &Pool<Postgres>, max_retries: usize) -> Self {
22 LLMPoller {
23 db_pool: db_pool.clone(),
24 max_retries,
25 }
26 }
27
28 #[instrument(skip_all)]
29 pub async fn process_drift_record(
30 &mut self,
31 record: &LLMRecord,
32 profile: &LLMDriftProfile,
33 ) -> Result<(HashMap<String, Score>, Option<i32>), DriftError> {
34 debug!("Processing workflow");
35
36 match LLMEvaluator::process_drift_record(record, profile).await {
37 Ok((metrics, score_map, workflow_duration)) => {
38 PostgresClient::insert_llm_metric_values_batch(&self.db_pool, &metrics)
39 .await
40 .inspect_err(|e| {
41 error!("Failed to insert LLM metric values: {:?}", e);
42 })?;
43
44 return Ok((score_map, workflow_duration));
45 }
46 Err(e) => {
47 error!("Failed to process drift record: {:?}", e);
48 return Err(DriftError::LLMEvaluatorError(e.to_string()));
49 }
50 };
51 }
52
53 #[instrument(skip_all)]
54 pub async fn do_poll(&mut self) -> Result<bool, DriftError> {
55 let task = PostgresClient::get_pending_llm_drift_record(&self.db_pool).await?;
57
58 let Some(mut task) = task else {
59 return Ok(false);
60 };
61
62 info!(
63 "Processing llm drift record for profile: {}/{}/{}",
64 task.space, task.name, task.version
65 );
66
67 let request = GetProfileRequest {
69 space: task.space.clone(),
70 name: task.name.clone(),
71 version: task.version.clone(),
72 drift_type: DriftType::LLM,
73 };
74
75 let mut llm_profile = if let Some(profile) =
76 PostgresClient::get_drift_profile(&self.db_pool, &request).await?
77 {
78 let llm_profile: LLMDriftProfile =
79 serde_json::from_value(profile).inspect_err(|e| {
80 error!("Failed to deserialize LLM drift profile: {:?}", e);
81 })?;
82 llm_profile
83 } else {
84 error!(
85 "No LLM drift profile found for {}/{}/{}",
86 task.space, task.name, task.version
87 );
88 return Ok(false);
89 };
90 let mut retry_count = 0;
91
92 llm_profile.workflow.reset_agents().await.inspect_err(|e| {
93 error!("Failed to reset agents: {:?}", e);
94 })?;
95
96 loop {
97 match self.process_drift_record(&task, &llm_profile).await {
98 Ok((result, workflow_duration)) => {
99 task.score = serde_json::to_value(result).inspect_err(|e| {
100 error!("Failed to serialize score map: {:?}", e);
101 })?;
102
103 PostgresClient::update_llm_drift_record_status(
104 &self.db_pool,
105 &task,
106 Status::Processed,
107 workflow_duration,
108 )
109 .await?;
110 break;
111 }
112 Err(e) => {
113 error!(
114 "Failed to process drift record (attempt {}): {:?}",
115 retry_count + 1,
116 e
117 );
118
119 retry_count += 1;
120 if retry_count >= self.max_retries {
121 PostgresClient::update_llm_drift_record_status(
123 &self.db_pool,
124 &task,
125 Status::Failed,
126 None,
127 )
128 .await?;
129 return Err(DriftError::LLMEvaluatorError(e.to_string()));
130 } else {
131 sleep(Duration::from_millis(100 * 2_u64.pow(retry_count as u32))).await;
133 }
134 }
135 }
136 }
137
138 Ok(true)
139 }
140
141 #[instrument(skip_all)]
142 pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
143 let result = self.do_poll().await;
144
145 match result {
147 Ok(true) => {
148 debug!("Successfully processed drift record");
149 Ok(())
150 }
151 Ok(false) => Ok(()),
152 Err(e) => {
153 error!("Error processing drift record: {:?}", e);
154 Ok(())
155 }
156 }
157 }
158}