1use crate::error::DriftError;
3use chrono::Duration;
4use scouter_dataframe::parquet::tracing::service::get_trace_span_service;
5use scouter_evaluate::evaluate::GenAIEvaluator;
6use scouter_sql::sql::aggregator::get_trace_summary_service;
7use scouter_sql::sql::traits::{GenAIDriftSqlLogic, ProfileSqlLogic};
8use scouter_sql::PostgresClient;
9use scouter_types::genai::{EvalSet, GenAIEvalProfile};
10use scouter_types::sql::{TraceFilters, TraceSpan};
11use scouter_types::{EvalRecord, Status, TraceId};
12use sqlx::{Pool, Postgres};
13use std::sync::Arc;
14use tokio::time::sleep;
15use tracing::{debug, error, instrument};
16
17enum TraceSpanResult {
18 Ready(Arc<Vec<TraceSpan>>),
19 Reschedule,
20 Failed,
21}
22
23#[instrument(skip_all)]
24async fn wait_for_trace_spans(
28 task_uid: &str,
29 max_wait: Duration,
30 initial_backoff: Duration,
31) -> Result<Arc<Vec<TraceSpan>>, DriftError> {
32 let start = chrono::Utc::now();
33 let mut backoff = initial_backoff;
34
35 let summary_service = get_trace_summary_service().ok_or_else(|| {
36 DriftError::GenAIEvaluatorError("TraceSummaryService not initialized".to_string())
37 })?;
38
39 let span_service = get_trace_span_service().ok_or_else(|| {
40 DriftError::GenAIEvaluatorError("TraceSpanService not initialized".to_string())
41 })?;
42
43 loop {
44 let filters = TraceFilters {
46 queue_uid: Some(task_uid.to_string()),
47 limit: Some(1),
48 ..Default::default()
49 };
50
51 match summary_service
52 .query_service
53 .get_paginated_traces(&filters)
54 .await
55 {
56 Ok(response) if !response.items.is_empty() => {
57 let trace_id_hex = &response.items[0].trace_id;
58 debug!(
59 "Found trace summary for task {}, trace_id={}",
60 task_uid, trace_id_hex
61 );
62
63 let trace_id_bytes = TraceId::hex_to_bytes(trace_id_hex).map_err(|e| {
65 DriftError::GenAIEvaluatorError(format!("Invalid trace_id hex: {}", e))
66 })?;
67
68 match span_service
69 .query_service
70 .get_trace_spans(Some(trace_id_bytes.as_slice()), None, None, None, None)
71 .await
72 {
73 Ok(spans) if !spans.is_empty() => {
74 debug!("Found {} spans for task {}", spans.len(), task_uid);
75 return Ok(Arc::new(spans));
76 }
77 Ok(_) => {
78 debug!(
79 "Trace summary found but spans not yet available for {}",
80 task_uid
81 );
82 }
83 Err(e) => {
84 error!("Error fetching spans from Delta Lake: {:?}", e);
85 }
86 }
87 }
88 Ok(_) => {
89 }
91 Err(e) => {
92 error!("Error querying trace summaries: {:?}", e);
93 }
94 }
95
96 if (chrono::Utc::now() - start) >= max_wait {
97 error!(
98 "Timeout waiting for trace spans after {:?} for task {}",
99 max_wait, task_uid
100 );
101 return Err(DriftError::TraceSpansNotAvailable(task_uid.to_string()));
102 }
103
104 debug!(
105 "No spans found yet for {}, waiting {:?} before retry",
106 task_uid, backoff
107 );
108 sleep(backoff.to_std().unwrap()).await;
109 backoff = std::cmp::min(backoff * 2, Duration::seconds(5));
110 }
111}
112
113#[instrument(skip_all)]
114async fn wait_for_trace_spans_with_reschedule(
115 pool: &Pool<Postgres>,
116 task: &EvalRecord,
117 max_retries: &i32,
118 trace_wait_timeout: Duration,
119 trace_backoff: Duration,
120 trace_reschedule_delay: Duration,
121) -> Result<TraceSpanResult, DriftError> {
122 let retry_count = task.retry_count;
123
124 if retry_count >= *max_retries {
125 return Ok(TraceSpanResult::Failed);
126 }
127
128 match wait_for_trace_spans(&task.uid, trace_wait_timeout, trace_backoff).await {
129 Ok(spans) => Ok(TraceSpanResult::Ready(spans)),
130 Err(DriftError::TraceSpansNotAvailable(_)) => {
131 PostgresClient::reschedule_genai_eval_record(pool, &task.uid, trace_reschedule_delay)
132 .await?;
133 Ok(TraceSpanResult::Reschedule)
134 }
135 Err(e) => Err(e),
136 }
137}
138
139pub struct GenAIPoller {
147 db_pool: Pool<Postgres>,
148 max_retries: i32,
149 trace_wait_timeout: Duration,
150 trace_backoff: Duration,
151 trace_reschedule_delay: Duration,
152}
153
154impl GenAIPoller {
155 pub fn new(
156 db_pool: &Pool<Postgres>,
157 max_retries: i32,
158 trace_wait_timeout: Duration,
159 trace_backoff: Duration,
160 trace_reschedule_delay: Duration,
161 ) -> Self {
162 GenAIPoller {
163 db_pool: db_pool.clone(),
164 max_retries,
165 trace_wait_timeout,
166 trace_backoff,
167 trace_reschedule_delay,
168 }
169 }
170
171 #[instrument(skip_all)]
172 pub async fn process_event_record(
173 &mut self,
174 record: &EvalRecord,
175 profile: &GenAIEvalProfile,
176 spans: Arc<Vec<TraceSpan>>,
177 ) -> Result<EvalSet, DriftError> {
178 debug!("Processing workflow");
179
180 let profile = Arc::new(profile.clone());
182
183 match GenAIEvaluator::process_event_record(record, profile, spans).await {
184 Ok(result_set) => {
185 PostgresClient::insert_eval_task_results_batch(
187 &self.db_pool,
188 &result_set.records,
189 &record.entity_id,
190 )
191 .await
192 .inspect_err(|e| {
193 error!("Failed to insert LLM task results: {:?}", e);
194 })?;
195
196 PostgresClient::insert_genai_eval_workflow_record(
198 &self.db_pool,
199 &result_set.inner,
200 &record.entity_id,
201 )
202 .await
203 .inspect_err(|e| {
204 error!("Failed to insert GenAI workflow record: {:?}", e);
205 })?;
206
207 return Ok(result_set);
208 }
209 Err(e) => {
210 error!("Failed to process drift record: {:?}", e);
211 return Err(DriftError::GenAIEvaluatorError(e.to_string()));
212 }
213 };
214 }
215
216 #[instrument(skip_all)]
217 pub async fn do_poll(&mut self) -> Result<bool, DriftError> {
218 let task = PostgresClient::get_pending_genai_eval_record(&self.db_pool).await?;
219
220 let Some(task) = task else {
221 return Ok(false);
222 };
223
224 debug!("Processing genai drift record for profile: {}", task.uid);
225
226 let mut genai_profile = if let Some(profile) =
227 PostgresClient::get_drift_profile(&self.db_pool, &task.entity_id).await?
228 {
229 let genai_profile: GenAIEvalProfile =
230 serde_json::from_value(profile).inspect_err(|e| {
231 error!("Failed to deserialize GenAI drift profile: {:?}", e);
232 })?;
233 genai_profile
234 } else {
235 error!("No GenAI drift profile found for {}", task.uid);
236 return Ok(false);
237 };
238
239 let mut retry_count = 0;
240 if let Some(workflow) = &mut genai_profile.workflow {
241 workflow.reset_agents().await.inspect_err(|e| {
242 error!("Failed to reset agents: {:?}", e);
243 })?;
244 }
245
246 let spans = if genai_profile.has_trace_assertions() {
247 match wait_for_trace_spans_with_reschedule(
248 &self.db_pool,
249 &task,
250 &self.max_retries,
251 self.trace_wait_timeout,
252 self.trace_backoff,
253 self.trace_reschedule_delay,
254 )
255 .await?
256 {
257 TraceSpanResult::Ready(spans) => spans,
258 TraceSpanResult::Reschedule => {
259 debug!(
260 "Traces not yet available for task {}, rescheduled",
261 task.uid
262 );
263 return Ok(true);
264 }
265 TraceSpanResult::Failed => {
266 error!("Max retries exceeded for task {}", task.uid);
267 PostgresClient::update_genai_eval_record_status(
268 &self.db_pool,
269 &task,
270 Status::Failed,
271 &0,
272 )
273 .await?;
274 return Err(DriftError::TraceSpansNotAvailable(task.uid.clone()));
275 }
276 }
277 } else {
278 Arc::new(vec![])
279 };
280
281 loop {
282 match self
283 .process_event_record(&task, &genai_profile, spans.clone())
284 .await
285 {
286 Ok(result_set) => {
287 PostgresClient::update_genai_eval_record_status(
288 &self.db_pool,
289 &task,
290 Status::Processed,
291 &result_set.inner.duration_ms,
292 )
293 .await?;
294 break;
295 }
296 Err(e) => {
297 error!(
298 "Failed to process drift record (attempt {}): {:?}",
299 retry_count + 1,
300 e
301 );
302
303 retry_count += 1;
304 if retry_count >= self.max_retries {
305 PostgresClient::update_genai_eval_record_status(
307 &self.db_pool,
308 &task,
309 Status::Failed,
310 &0,
311 )
312 .await?;
313 return Err(DriftError::GenAIEvaluatorError(e.to_string()));
314 } else {
315 let val = 100 * 2_i64.pow(retry_count as u32);
317 sleep(Duration::milliseconds(val).to_std()?).await;
318 }
319 }
320 }
321 }
322
323 Ok(true)
324 }
325
326 #[instrument(skip_all)]
327 pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
328 let result = self.do_poll().await;
329
330 match result {
332 Ok(true) => {
333 debug!("Successfully processed drift record");
334 Ok(())
335 }
336 Ok(false) => {
337 sleep(Duration::seconds(1).to_std()?).await;
338 Ok(())
339 }
340 Err(e) => {
341 error!("Error processing drift record: {:?}", e);
342 Ok(())
343 }
344 }
345 }
346}