use crate::error::DriftError;
use chrono::Duration;
use scouter_dataframe::parquet::tracing::service::get_trace_span_service;
use scouter_evaluate::evaluate::GenAIEvaluator;
use scouter_sql::sql::aggregator::get_trace_summary_service;
use scouter_sql::sql::traits::{GenAIDriftSqlLogic, ProfileSqlLogic};
use scouter_sql::PostgresClient;
use scouter_types::genai::{EvalSet, GenAIEvalProfile};
use scouter_types::sql::{TraceFilters, TraceSpan};
use scouter_types::{EvalRecord, Status, TraceId};
use sqlx::{Pool, Postgres};
use std::sync::Arc;
use tokio::time::sleep;
use tracing::{debug, error, instrument};
/// Outcome of waiting for the trace spans a GenAI evaluation task depends on.
enum TraceSpanResult {
    /// Spans were found and can be handed to the evaluator.
    Ready(Arc<Vec<TraceSpan>>),
    /// Spans did not show up in time; the task was rescheduled for a later poll.
    Reschedule,
    /// The task had already reached its retry limit, so no wait was attempted.
    Failed,
}
#[instrument(skip_all)]
async fn wait_for_trace_spans(
task_uid: &str,
max_wait: Duration,
initial_backoff: Duration,
) -> Result<Arc<Vec<TraceSpan>>, DriftError> {
let start = chrono::Utc::now();
let mut backoff = initial_backoff;
let summary_service = get_trace_summary_service().ok_or_else(|| {
DriftError::GenAIEvaluatorError("TraceSummaryService not initialized".to_string())
})?;
let span_service = get_trace_span_service().ok_or_else(|| {
DriftError::GenAIEvaluatorError("TraceSpanService not initialized".to_string())
})?;
loop {
let filters = TraceFilters {
queue_uid: Some(task_uid.to_string()),
limit: Some(1),
..Default::default()
};
match summary_service
.query_service
.get_paginated_traces(&filters)
.await
{
Ok(response) if !response.items.is_empty() => {
let trace_id_hex = &response.items[0].trace_id;
debug!(
"Found trace summary for task {}, trace_id={}",
task_uid, trace_id_hex
);
let trace_id_bytes = TraceId::hex_to_bytes(trace_id_hex).map_err(|e| {
DriftError::GenAIEvaluatorError(format!("Invalid trace_id hex: {}", e))
})?;
match span_service
.query_service
.get_trace_spans(Some(trace_id_bytes.as_slice()), None, None, None, None)
.await
{
Ok(spans) if !spans.is_empty() => {
debug!("Found {} spans for task {}", spans.len(), task_uid);
return Ok(Arc::new(spans));
}
Ok(_) => {
debug!(
"Trace summary found but spans not yet available for {}",
task_uid
);
}
Err(e) => {
error!("Error fetching spans from Delta Lake: {:?}", e);
}
}
}
Ok(_) => {
}
Err(e) => {
error!("Error querying trace summaries: {:?}", e);
}
}
if (chrono::Utc::now() - start) >= max_wait {
error!(
"Timeout waiting for trace spans after {:?} for task {}",
max_wait, task_uid
);
return Err(DriftError::TraceSpansNotAvailable(task_uid.to_string()));
}
debug!(
"No spans found yet for {}, waiting {:?} before retry",
task_uid, backoff
);
sleep(backoff.to_std().unwrap()).await;
backoff = std::cmp::min(backoff * 2, Duration::seconds(5));
}
}
/// Waits for the trace spans of `task`, rescheduling the task instead of blocking
/// further when they do not arrive in time.
///
/// Returns:
/// * `TraceSpanResult::Failed` immediately, without waiting, when `task.retry_count` has
///   already reached `max_retries`.
/// * `TraceSpanResult::Ready` with the spans when they arrived within `trace_wait_timeout`.
/// * `TraceSpanResult::Reschedule` after rescheduling the record with
///   `trace_reschedule_delay` when the wait reported `TraceSpansNotAvailable`.
///
/// # Errors
/// Any other error from the wait is propagated unchanged, as is any error from the
/// reschedule call.
#[instrument(skip_all)]
async fn wait_for_trace_spans_with_reschedule(
    pool: &Pool<Postgres>,
    task: &EvalRecord,
    max_retries: &i32,
    trace_wait_timeout: Duration,
    trace_backoff: Duration,
    trace_reschedule_delay: Duration,
) -> Result<TraceSpanResult, DriftError> {
    let retries_exhausted = task.retry_count >= *max_retries;
    if retries_exhausted {
        return Ok(TraceSpanResult::Failed);
    }

    let outcome = wait_for_trace_spans(&task.uid, trace_wait_timeout, trace_backoff).await;
    let spans = match outcome {
        Ok(found) => found,
        Err(DriftError::TraceSpansNotAvailable(_)) => {
            // Timed out: push the record back so a later poll can try again.
            PostgresClient::reschedule_genai_eval_record(pool, &task.uid, trace_reschedule_delay)
                .await?;
            return Ok(TraceSpanResult::Reschedule);
        }
        Err(other) => return Err(other),
    };

    Ok(TraceSpanResult::Ready(spans))
}
/// Worker that fetches pending GenAI evaluation records from Postgres, evaluates them
/// against their stored profile, and persists the results.
pub struct GenAIPoller {
    /// Pool used for every database call the poller makes.
    db_pool: Pool<Postgres>,
    /// Retry limit. It is compared against a task's reschedule count when waiting for
    /// traces, and bounds the in-process evaluation attempts.
    max_retries: i32,
    /// How long a single poll waits for trace spans before rescheduling the task.
    trace_wait_timeout: Duration,
    /// Initial delay between trace-span lookups.
    trace_backoff: Duration,
    /// Delay passed along when a task is rescheduled because its spans are not ready.
    trace_reschedule_delay: Duration,
}
impl GenAIPoller {
    /// Upper bound on the delay between in-process evaluation retries. Without it, a large
    /// `max_retries` could overflow the exponential backoff or stall the poller for hours.
    const MAX_RETRY_BACKOFF_MS: i64 = 60_000;

    /// Creates a poller that uses a clone of `db_pool`.
    ///
    /// * `max_retries` - retry limit for both trace-span reschedules and evaluation attempts.
    /// * `trace_wait_timeout` - how long one poll waits for trace spans.
    /// * `trace_backoff` - initial delay between trace-span lookups.
    /// * `trace_reschedule_delay` - delay used when a task is rescheduled for missing spans.
    pub fn new(
        db_pool: &Pool<Postgres>,
        max_retries: i32,
        trace_wait_timeout: Duration,
        trace_backoff: Duration,
        trace_reschedule_delay: Duration,
    ) -> Self {
        GenAIPoller {
            db_pool: db_pool.clone(),
            max_retries,
            trace_wait_timeout,
            trace_backoff,
            trace_reschedule_delay,
        }
    }

    /// Evaluates `record` against `profile` using `spans`, then stores the per-task
    /// results and the workflow record for `record.entity_id`.
    ///
    /// # Errors
    /// * `DriftError::GenAIEvaluatorError` if the evaluator fails.
    /// * Any error converted from the two insert calls.
    #[instrument(skip_all)]
    pub async fn process_event_record(
        &mut self,
        record: &EvalRecord,
        profile: &GenAIEvalProfile,
        spans: Arc<Vec<TraceSpan>>,
    ) -> Result<EvalSet, DriftError> {
        debug!("Processing workflow");
        let profile = Arc::new(profile.clone());

        let result_set = GenAIEvaluator::process_event_record(record, profile, spans)
            .await
            .map_err(|e| {
                error!("Failed to process drift record: {:?}", e);
                DriftError::GenAIEvaluatorError(e.to_string())
            })?;

        PostgresClient::insert_eval_task_results_batch(
            &self.db_pool,
            &result_set.records,
            &record.entity_id,
        )
        .await
        .inspect_err(|e| {
            error!("Failed to insert LLM task results: {:?}", e);
        })?;

        PostgresClient::insert_genai_eval_workflow_record(
            &self.db_pool,
            &result_set.inner,
            &record.entity_id,
        )
        .await
        .inspect_err(|e| {
            error!("Failed to insert GenAI workflow record: {:?}", e);
        })?;

        Ok(result_set)
    }

    /// Fetches one pending record and takes it through profile load, agent reset, optional
    /// trace-span wait, and evaluation.
    ///
    /// Returns `Ok(false)` when there was nothing to process (no pending record, or no
    /// profile for it). Returns `Ok(true)` when a record was processed or rescheduled.
    #[instrument(skip_all)]
    pub async fn do_poll(&mut self) -> Result<bool, DriftError> {
        let Some(task) = PostgresClient::get_pending_genai_eval_record(&self.db_pool).await?
        else {
            return Ok(false);
        };
        debug!("Processing genai drift record for profile: {}", task.uid);

        let Some(mut genai_profile) = self.load_profile(&task).await? else {
            // NOTE(review): the task's status is not updated on this path. If
            // `get_pending_genai_eval_record` changes the status when it fetches a record,
            // the record may stay in that state. Confirm against the SQL layer.
            return Ok(false);
        };

        if let Some(workflow) = &mut genai_profile.workflow {
            workflow.reset_agents().await.inspect_err(|e| {
                error!("Failed to reset agents: {:?}", e);
            })?;
        }

        let spans = if genai_profile.has_trace_assertions() {
            match self.await_trace_spans(&task).await? {
                Some(spans) => spans,
                // Rescheduled for a later poll; this poll still counts as having done work.
                None => return Ok(true),
            }
        } else {
            Arc::new(vec![])
        };

        self.process_with_retries(&task, &genai_profile, spans)
            .await?;
        Ok(true)
    }

    /// Loads and deserializes the profile for `task.entity_id`. Returns `None` (after
    /// logging) when no profile is stored.
    async fn load_profile(
        &self,
        task: &EvalRecord,
    ) -> Result<Option<GenAIEvalProfile>, DriftError> {
        let Some(raw) = PostgresClient::get_drift_profile(&self.db_pool, &task.entity_id).await?
        else {
            error!("No GenAI drift profile found for {}", task.uid);
            return Ok(None);
        };
        let profile: GenAIEvalProfile = serde_json::from_value(raw).inspect_err(|e| {
            error!("Failed to deserialize GenAI drift profile: {:?}", e);
        })?;
        Ok(Some(profile))
    }

    /// Waits for `task`'s trace spans. Returns `None` when the task was rescheduled, and
    /// marks it `Failed` and returns an error once its retry limit is reached.
    async fn await_trace_spans(
        &self,
        task: &EvalRecord,
    ) -> Result<Option<Arc<Vec<TraceSpan>>>, DriftError> {
        match wait_for_trace_spans_with_reschedule(
            &self.db_pool,
            task,
            &self.max_retries,
            self.trace_wait_timeout,
            self.trace_backoff,
            self.trace_reschedule_delay,
        )
        .await?
        {
            TraceSpanResult::Ready(spans) => Ok(Some(spans)),
            TraceSpanResult::Reschedule => {
                debug!(
                    "Traces not yet available for task {}, rescheduled",
                    task.uid
                );
                Ok(None)
            }
            TraceSpanResult::Failed => {
                error!("Max retries exceeded for task {}", task.uid);
                PostgresClient::update_genai_eval_record_status(
                    &self.db_pool,
                    task,
                    Status::Failed,
                    &0,
                )
                .await?;
                Err(DriftError::TraceSpansNotAvailable(task.uid.clone()))
            }
        }
    }

    /// Runs `process_event_record` until it succeeds or `max_retries` attempts have failed.
    /// On success the task is marked `Processed`; after the last failure it is marked
    /// `Failed`.
    async fn process_with_retries(
        &mut self,
        task: &EvalRecord,
        profile: &GenAIEvalProfile,
        spans: Arc<Vec<TraceSpan>>,
    ) -> Result<(), DriftError> {
        let mut retry_count: i32 = 0;
        loop {
            match self
                .process_event_record(task, profile, Arc::clone(&spans))
                .await
            {
                Ok(result_set) => {
                    PostgresClient::update_genai_eval_record_status(
                        &self.db_pool,
                        task,
                        Status::Processed,
                        &result_set.inner.duration_ms,
                    )
                    .await?;
                    return Ok(());
                }
                Err(e) => {
                    error!(
                        "Failed to process drift record (attempt {}): {:?}",
                        retry_count + 1,
                        e
                    );
                    retry_count += 1;
                    if retry_count >= self.max_retries {
                        PostgresClient::update_genai_eval_record_status(
                            &self.db_pool,
                            task,
                            Status::Failed,
                            &0,
                        )
                        .await?;
                        return Err(DriftError::GenAIEvaluatorError(e.to_string()));
                    }
                    sleep(Self::retry_backoff(retry_count).to_std()?).await;
                }
            }
        }
    }

    /// Delay before retry `retry_count`: 100 ms * 2^retry_count, computed with checked
    /// arithmetic and capped at `MAX_RETRY_BACKOFF_MS`, so it can neither overflow nor
    /// go negative.
    fn retry_backoff(retry_count: i32) -> Duration {
        // Clamped to non-negative first, so the cast to u32 cannot wrap.
        let exponent = retry_count.max(0) as u32;
        let millis = 2_i64
            .checked_pow(exponent)
            .and_then(|factor| factor.checked_mul(100))
            .map_or(Self::MAX_RETRY_BACKOFF_MS, |ms| {
                ms.min(Self::MAX_RETRY_BACKOFF_MS)
            });
        Duration::milliseconds(millis)
    }

    /// Runs one poll cycle. Sleeps one second when there was no work. Errors are logged
    /// and swallowed so the caller's loop keeps running.
    #[instrument(skip_all)]
    pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
        let result = self.do_poll().await;
        match result {
            Ok(true) => {
                debug!("Successfully processed drift record");
                Ok(())
            }
            Ok(false) => {
                sleep(Duration::seconds(1).to_std()?).await;
                Ok(())
            }
            Err(e) => {
                error!("Error processing drift record: {:?}", e);
                Ok(())
            }
        }
    }
}