Skip to main content

scouter_drift/genai/
poller.rs

// Module for polling GenAI drift records in the "pending" state and processing them (evaluation, result persistence, status updates).
2use crate::error::DriftError;
3use chrono::Duration;
4use scouter_dataframe::parquet::tracing::service::get_trace_span_service;
5use scouter_evaluate::evaluate::GenAIEvaluator;
6use scouter_sql::sql::aggregator::get_trace_summary_service;
7use scouter_sql::sql::traits::{GenAIDriftSqlLogic, ProfileSqlLogic};
8use scouter_sql::PostgresClient;
9use scouter_types::genai::{EvalSet, GenAIEvalProfile};
10use scouter_types::sql::{TraceFilters, TraceSpan};
11use scouter_types::{EvalRecord, Status, TraceId};
12use sqlx::{Pool, Postgres};
13use std::sync::Arc;
14use tokio::time::sleep;
15use tracing::{debug, error, instrument};
16
17enum TraceSpanResult {
18    Ready(Arc<Vec<TraceSpan>>),
19    Reschedule,
20    Failed,
21}
22
23#[instrument(skip_all)]
24/// Helper function to wait for trace spans associated with a task UID.
25/// Queries Delta Lake: first finds the trace summary by queue_uid, then fetches
26/// full spans from the trace_spans table.
27async fn wait_for_trace_spans(
28    task_uid: &str,
29    max_wait: Duration,
30    initial_backoff: Duration,
31) -> Result<Arc<Vec<TraceSpan>>, DriftError> {
32    let start = chrono::Utc::now();
33    let mut backoff = initial_backoff;
34
35    let summary_service = get_trace_summary_service().ok_or_else(|| {
36        DriftError::GenAIEvaluatorError("TraceSummaryService not initialized".to_string())
37    })?;
38
39    let span_service = get_trace_span_service().ok_or_else(|| {
40        DriftError::GenAIEvaluatorError("TraceSpanService not initialized".to_string())
41    })?;
42
43    loop {
44        // Query summaries by queue_uid to find the trace_id
45        let filters = TraceFilters {
46            queue_uid: Some(task_uid.to_string()),
47            limit: Some(1),
48            ..Default::default()
49        };
50
51        match summary_service
52            .query_service
53            .get_paginated_traces(&filters)
54            .await
55        {
56            Ok(response) if !response.items.is_empty() => {
57                let trace_id_hex = &response.items[0].trace_id;
58                debug!(
59                    "Found trace summary for task {}, trace_id={}",
60                    task_uid, trace_id_hex
61                );
62
63                // Fetch full spans from Delta Lake
64                let trace_id_bytes = TraceId::hex_to_bytes(trace_id_hex).map_err(|e| {
65                    DriftError::GenAIEvaluatorError(format!("Invalid trace_id hex: {}", e))
66                })?;
67
68                match span_service
69                    .query_service
70                    .get_trace_spans(Some(trace_id_bytes.as_slice()), None, None, None, None)
71                    .await
72                {
73                    Ok(spans) if !spans.is_empty() => {
74                        debug!("Found {} spans for task {}", spans.len(), task_uid);
75                        return Ok(Arc::new(spans));
76                    }
77                    Ok(_) => {
78                        debug!(
79                            "Trace summary found but spans not yet available for {}",
80                            task_uid
81                        );
82                    }
83                    Err(e) => {
84                        error!("Error fetching spans from Delta Lake: {:?}", e);
85                    }
86                }
87            }
88            Ok(_) => {
89                // No summary found yet
90            }
91            Err(e) => {
92                error!("Error querying trace summaries: {:?}", e);
93            }
94        }
95
96        if (chrono::Utc::now() - start) >= max_wait {
97            error!(
98                "Timeout waiting for trace spans after {:?} for task {}",
99                max_wait, task_uid
100            );
101            return Err(DriftError::TraceSpansNotAvailable(task_uid.to_string()));
102        }
103
104        debug!(
105            "No spans found yet for {}, waiting {:?} before retry",
106            task_uid, backoff
107        );
108        sleep(backoff.to_std().unwrap()).await;
109        backoff = std::cmp::min(backoff * 2, Duration::seconds(5));
110    }
111}
112
113#[instrument(skip_all)]
114async fn wait_for_trace_spans_with_reschedule(
115    pool: &Pool<Postgres>,
116    task: &EvalRecord,
117    max_retries: &i32,
118    trace_wait_timeout: Duration,
119    trace_backoff: Duration,
120    trace_reschedule_delay: Duration,
121) -> Result<TraceSpanResult, DriftError> {
122    let retry_count = task.retry_count;
123
124    if retry_count >= *max_retries {
125        return Ok(TraceSpanResult::Failed);
126    }
127
128    match wait_for_trace_spans(&task.uid, trace_wait_timeout, trace_backoff).await {
129        Ok(spans) => Ok(TraceSpanResult::Ready(spans)),
130        Err(DriftError::TraceSpansNotAvailable(_)) => {
131            PostgresClient::reschedule_genai_eval_record(pool, &task.uid, trace_reschedule_delay)
132                .await?;
133            Ok(TraceSpanResult::Reschedule)
134        }
135        Err(e) => Err(e),
136    }
137}
138
139/// Poller struct for processing GenAI drift records
140/// A few different things going on here:
141/// 1. Poll the database for "pending" GenAI drift records
142/// 2. For each record, check if trace spans are needed and available
143/// 3. If spans are needed but not available, reschedule the record for later processing
144/// 4. If spans are available or not needed, process the record using GenAIEvaluator
145/// 5. Update the record status to "processed" or "failed" based on the outcome
146pub struct GenAIPoller {
147    db_pool: Pool<Postgres>,
148    max_retries: i32,
149    trace_wait_timeout: Duration,
150    trace_backoff: Duration,
151    trace_reschedule_delay: Duration,
152}
153
154impl GenAIPoller {
155    pub fn new(
156        db_pool: &Pool<Postgres>,
157        max_retries: i32,
158        trace_wait_timeout: Duration,
159        trace_backoff: Duration,
160        trace_reschedule_delay: Duration,
161    ) -> Self {
162        GenAIPoller {
163            db_pool: db_pool.clone(),
164            max_retries,
165            trace_wait_timeout,
166            trace_backoff,
167            trace_reschedule_delay,
168        }
169    }
170
171    #[instrument(skip_all)]
172    pub async fn process_event_record(
173        &mut self,
174        record: &EvalRecord,
175        profile: &GenAIEvalProfile,
176        spans: Arc<Vec<TraceSpan>>,
177    ) -> Result<EvalSet, DriftError> {
178        debug!("Processing workflow");
179
180        // create arc mutex for profile
181        let profile = Arc::new(profile.clone());
182
183        match GenAIEvaluator::process_event_record(record, profile, spans).await {
184            Ok(result_set) => {
185                // insert task results first
186                PostgresClient::insert_eval_task_results_batch(
187                    &self.db_pool,
188                    &result_set.records,
189                    &record.entity_id,
190                )
191                .await
192                .inspect_err(|e| {
193                    error!("Failed to insert LLM task results: {:?}", e);
194                })?;
195
196                // insert workflow record
197                PostgresClient::insert_genai_eval_workflow_record(
198                    &self.db_pool,
199                    &result_set.inner,
200                    &record.entity_id,
201                )
202                .await
203                .inspect_err(|e| {
204                    error!("Failed to insert GenAI workflow record: {:?}", e);
205                })?;
206
207                return Ok(result_set);
208            }
209            Err(e) => {
210                error!("Failed to process drift record: {:?}", e);
211                return Err(DriftError::GenAIEvaluatorError(e.to_string()));
212            }
213        };
214    }
215
216    #[instrument(skip_all)]
217    pub async fn do_poll(&mut self) -> Result<bool, DriftError> {
218        let task = PostgresClient::get_pending_genai_eval_record(&self.db_pool).await?;
219
220        let Some(task) = task else {
221            return Ok(false);
222        };
223
224        debug!("Processing genai drift record for profile: {}", task.uid);
225
226        let mut genai_profile = if let Some(profile) =
227            PostgresClient::get_drift_profile(&self.db_pool, &task.entity_id).await?
228        {
229            let genai_profile: GenAIEvalProfile =
230                serde_json::from_value(profile).inspect_err(|e| {
231                    error!("Failed to deserialize GenAI drift profile: {:?}", e);
232                })?;
233            genai_profile
234        } else {
235            error!("No GenAI drift profile found for {}", task.uid);
236            return Ok(false);
237        };
238
239        let mut retry_count = 0;
240        if let Some(workflow) = &mut genai_profile.workflow {
241            workflow.reset_agents().await.inspect_err(|e| {
242                error!("Failed to reset agents: {:?}", e);
243            })?;
244        }
245
246        let spans = if genai_profile.has_trace_assertions() {
247            match wait_for_trace_spans_with_reschedule(
248                &self.db_pool,
249                &task,
250                &self.max_retries,
251                self.trace_wait_timeout,
252                self.trace_backoff,
253                self.trace_reschedule_delay,
254            )
255            .await?
256            {
257                TraceSpanResult::Ready(spans) => spans,
258                TraceSpanResult::Reschedule => {
259                    debug!(
260                        "Traces not yet available for task {}, rescheduled",
261                        task.uid
262                    );
263                    return Ok(true);
264                }
265                TraceSpanResult::Failed => {
266                    error!("Max retries exceeded for task {}", task.uid);
267                    PostgresClient::update_genai_eval_record_status(
268                        &self.db_pool,
269                        &task,
270                        Status::Failed,
271                        &0,
272                    )
273                    .await?;
274                    return Err(DriftError::TraceSpansNotAvailable(task.uid.clone()));
275                }
276            }
277        } else {
278            Arc::new(vec![])
279        };
280
281        loop {
282            match self
283                .process_event_record(&task, &genai_profile, spans.clone())
284                .await
285            {
286                Ok(result_set) => {
287                    PostgresClient::update_genai_eval_record_status(
288                        &self.db_pool,
289                        &task,
290                        Status::Processed,
291                        &result_set.inner.duration_ms,
292                    )
293                    .await?;
294                    break;
295                }
296                Err(e) => {
297                    error!(
298                        "Failed to process drift record (attempt {}): {:?}",
299                        retry_count + 1,
300                        e
301                    );
302
303                    retry_count += 1;
304                    if retry_count >= self.max_retries {
305                        // Update the record status to error
306                        PostgresClient::update_genai_eval_record_status(
307                            &self.db_pool,
308                            &task,
309                            Status::Failed,
310                            &0,
311                        )
312                        .await?;
313                        return Err(DriftError::GenAIEvaluatorError(e.to_string()));
314                    } else {
315                        // Exponential backoff before retrying
316                        let val = 100 * 2_i64.pow(retry_count as u32);
317                        sleep(Duration::milliseconds(val).to_std()?).await;
318                    }
319                }
320            }
321        }
322
323        Ok(true)
324    }
325
326    #[instrument(skip_all)]
327    pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
328        let result = self.do_poll().await;
329
330        // silent error handling
331        match result {
332            Ok(true) => {
333                debug!("Successfully processed drift record");
334                Ok(())
335            }
336            Ok(false) => {
337                sleep(Duration::seconds(1).to_std()?).await;
338                Ok(())
339            }
340            Err(e) => {
341                error!("Error processing drift record: {:?}", e);
342                Ok(())
343            }
344        }
345    }
346}