use crate::error::TraceEngineError;
use crate::parquet::tracing::traits::arrow_schema_to_delta;
use crate::parquet::utils::register_cloud_logstore_factories;
use crate::storage::ObjectStore;
use arrow::array::*;
use arrow::datatypes::*;
use arrow_array::RecordBatch;
use chrono::{DateTime, Duration, Utc};
use datafusion::logical_expr::{col, lit};
use datafusion::prelude::SessionContext;
use deltalake::{DeltaTable, DeltaTableBuilder, TableProperty};
use std::sync::Arc;
use tokio::sync::RwLock as AsyncRwLock;
use tracing::{debug, info, warn};
use url::Url;
/// Name under which the control table is registered in the DataFusion session,
/// and the path segment appended to the storage base URL.
const CONTROL_TABLE_NAME: &str = "_scouter_control";
/// A `processing` claim older than this many minutes is considered abandoned
/// (e.g. the owning pod crashed) and may be reclaimed by another pod.
const STALE_LOCK_MINUTES: i64 = 30;
/// Allowed values of the control table's `status` column.
mod status {
    pub const IDLE: &str = "idle";
    pub const PROCESSING: &str = "processing";
}
/// One row of the control table: the shared coordination state of a single
/// named background task (e.g. "optimize", "vacuum") across pods.
#[derive(Debug, Clone)]
pub struct TaskRecord {
    // Unique task identifier; also the key used for delete-then-append upserts.
    pub task_name: String,
    // Either `status::IDLE` or `status::PROCESSING`.
    pub status: String,
    // Identifier of the pod that last wrote this record.
    pub pod_id: String,
    // When the record was last claimed/written.
    pub claimed_at: DateTime<Utc>,
    // When the task last completed; `None` while a run is in flight.
    pub completed_at: Option<DateTime<Utc>>,
    // Earliest time the task should run again.
    pub next_run_at: DateTime<Utc>,
}
/// Arrow schema of the control table.
///
/// Three UTF-8 identity columns followed by three UTC microsecond timestamps;
/// only `completed_at` is nullable (it is unset while a task is processing).
fn control_schema() -> Schema {
    // All timestamp columns share the same microsecond/UTC encoding.
    let utc_ts = |name: &str, nullable: bool| {
        Field::new(
            name,
            DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
            nullable,
        )
    };
    Schema::new(vec![
        Field::new("task_name", DataType::Utf8, false),
        Field::new("status", DataType::Utf8, false),
        Field::new("pod_id", DataType::Utf8, false),
        utc_ts("claimed_at", false),
        utc_ts("completed_at", true),
        utc_ts("next_run_at", false),
    ])
}
fn build_task_batch(
schema: &SchemaRef,
record: &TaskRecord,
) -> Result<RecordBatch, TraceEngineError> {
let task_name = StringArray::from(vec![record.task_name.as_str()]);
let status = StringArray::from(vec![record.status.as_str()]);
let pod_id = StringArray::from(vec![record.pod_id.as_str()]);
let claimed_at = TimestampMicrosecondArray::from(vec![record.claimed_at.timestamp_micros()])
.with_timezone("UTC");
let completed_at = if let Some(ts) = record.completed_at {
TimestampMicrosecondArray::from(vec![Some(ts.timestamp_micros())]).with_timezone("UTC")
} else {
TimestampMicrosecondArray::from(vec![None::<i64>]).with_timezone("UTC")
};
let next_run_at = TimestampMicrosecondArray::from(vec![record.next_run_at.timestamp_micros()])
.with_timezone("UTC");
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(task_name),
Arc::new(status),
Arc::new(pod_id),
Arc::new(claimed_at),
Arc::new(completed_at),
Arc::new(next_run_at),
],
)
.map_err(Into::into)
}
/// Identity of this process for claim records.
///
/// Prefers the Kubernetes-style `HOSTNAME`, then `POD_NAME`, and finally
/// falls back to a PID-derived local identifier for non-cluster runs.
pub fn get_pod_id() -> String {
    for key in ["HOSTNAME", "POD_NAME"] {
        if let Ok(value) = std::env::var(key) {
            return value;
        }
    }
    format!("local-{}", std::process::id())
}
/// Multi-pod task coordinator backed by a Delta table.
///
/// Claims are made via delete-then-append writes so Delta's optimistic
/// concurrency control resolves races between pods.
pub struct ControlTableEngine {
    // Arrow schema of the control table (see `control_schema`).
    schema: SchemaRef,
    // Underlying Delta table; the async lock serializes claim/release
    // operations within this process.
    #[allow(dead_code)]
    table: Arc<AsyncRwLock<DeltaTable>>,
    // DataFusion session used to query the registered control table.
    ctx: Arc<SessionContext>,
    // This pod's identity, written into claimed records.
    pod_id: String,
}
impl ControlTableEngine {
    /// Open (or create) the control table and register it in a fresh
    /// DataFusion session obtained from the object store.
    pub async fn new(object_store: &ObjectStore, pod_id: String) -> Result<Self, TraceEngineError> {
        let schema = Arc::new(control_schema());
        let table = build_or_create_control_table(object_store, schema.clone()).await?;
        let ctx = object_store.get_session()?;
        // A brand-new table may not yield a provider yet; registration then
        // happens lazily when the first write re-registers the table.
        if let Ok(provider) = table.table_provider().await {
            ctx.register_table(CONTROL_TABLE_NAME, provider)?;
        } else {
            info!("Empty control table at init — deferring registration until first write");
        }
        Ok(Self {
            schema,
            table: Arc::new(AsyncRwLock::new(table)),
            ctx: Arc::new(ctx),
            pod_id,
        })
    }

    /// Attempt to claim `task_name` for this pod.
    ///
    /// Returns `Ok(true)` on a successful claim; `Ok(false)` when the task
    /// is not yet due, is actively held by another pod, or this pod lost the
    /// optimistic-concurrency race on the write.
    pub async fn try_claim_task(&self, task_name: &str) -> Result<bool, TraceEngineError> {
        let mut table_guard = self.table.write().await;
        // Pull the latest table version so other pods' claims are visible;
        // failure is expected for a table with no commits yet.
        if let Err(e) = table_guard.update_incremental(None).await {
            debug!("Control table update skipped (new table): {}", e);
        }
        // Re-register so DataFusion queries the refreshed snapshot.
        let _ = self.ctx.deregister_table(CONTROL_TABLE_NAME);
        if let Ok(provider) = table_guard.table_provider().await {
            self.ctx.register_table(CONTROL_TABLE_NAME, provider)?;
        }
        let current = self
            .read_task(&table_guard_to_ctx(&self.ctx), task_name)
            .await?;
        let now = Utc::now();
        match current {
            Some(record) => {
                if record.status == status::PROCESSING {
                    // Another pod holds the task; honor the claim unless it is
                    // older than STALE_LOCK_MINUTES (owner likely died).
                    let stale_threshold = now - Duration::minutes(STALE_LOCK_MINUTES);
                    if record.claimed_at > stale_threshold {
                        debug!(
                            "Task '{}' is being processed by pod '{}' (not stale), skipping",
                            task_name, record.pod_id
                        );
                        return Ok(false);
                    }
                    warn!(
                        "Task '{}' claimed by pod '{}' is stale (claimed_at: {}), reclaiming",
                        task_name, record.pod_id, record.claimed_at
                    );
                }
                // Even an idle (or stale) task only runs once its schedule is due.
                if now < record.next_run_at {
                    debug!(
                        "Task '{}' not due until {}, skipping",
                        task_name, record.next_run_at
                    );
                    return Ok(false);
                }
                let claimed = TaskRecord {
                    task_name: task_name.to_string(),
                    status: status::PROCESSING.to_string(),
                    pod_id: self.pod_id.clone(),
                    claimed_at: now,
                    completed_at: None,
                    // Keep the existing schedule; release_task sets the next one.
                    next_run_at: record.next_run_at,
                };
                match self.write_task_update(&mut table_guard, &claimed).await {
                    Ok(()) => {
                        info!("Successfully claimed task '{}'", task_name);
                        Ok(true)
                    }
                    // NOTE(review): OCC conflicts are detected by substring match
                    // on the error text — confirm this matches the deltalake
                    // transaction-conflict error rendering across versions.
                    Err(TraceEngineError::DataTableError(ref e))
                        if e.to_string().contains("Transaction") =>
                    {
                        info!("Lost OCC race for task '{}' to another pod", task_name);
                        Ok(false)
                    }
                    Err(e) => Err(e),
                }
            }
            None => {
                // First time this task is seen: create its record, due immediately.
                let claimed = TaskRecord {
                    task_name: task_name.to_string(),
                    status: status::PROCESSING.to_string(),
                    pod_id: self.pod_id.clone(),
                    claimed_at: now,
                    completed_at: None,
                    next_run_at: now,
                };
                match self.write_task_update(&mut table_guard, &claimed).await {
                    Ok(()) => {
                        info!("Created and claimed new task '{}'", task_name);
                        Ok(true)
                    }
                    Err(TraceEngineError::DataTableError(ref e))
                        if e.to_string().contains("Transaction") =>
                    {
                        info!("Lost OCC race for new task '{}' to another pod", task_name);
                        Ok(false)
                    }
                    Err(e) => Err(e),
                }
            }
        }
    }

    /// Mark `task_name` idle after a successful run and schedule the next
    /// run `next_run_interval` from now.
    pub async fn release_task(
        &self,
        task_name: &str,
        next_run_interval: Duration,
    ) -> Result<(), TraceEngineError> {
        let mut table_guard = self.table.write().await;
        let now = Utc::now();
        let released = TaskRecord {
            task_name: task_name.to_string(),
            status: status::IDLE.to_string(),
            pod_id: self.pod_id.clone(),
            claimed_at: now,
            completed_at: Some(now),
            next_run_at: now + next_run_interval,
        };
        self.write_task_update(&mut table_guard, &released).await?;
        info!(
            "Released task '{}', next run at {}",
            task_name, released.next_run_at
        );
        Ok(())
    }

    /// Mark `task_name` idle after a failed run, preserving its previous
    /// `next_run_at` so the task is retryable without extra delay.
    pub async fn release_task_on_failure(&self, task_name: &str) -> Result<(), TraceEngineError> {
        let mut table_guard = self.table.write().await;
        if let Err(e) = table_guard.update_incremental(None).await {
            debug!("Control table update skipped: {}", e);
        }
        let _ = self.ctx.deregister_table(CONTROL_TABLE_NAME);
        if let Ok(provider) = table_guard.table_provider().await {
            self.ctx.register_table(CONTROL_TABLE_NAME, provider)?;
        }
        let current = self
            .read_task(&table_guard_to_ctx(&self.ctx), task_name)
            .await?;
        let now = Utc::now();
        // Keep the existing schedule; fall back to "due now" if the record vanished.
        let next_run = current.map(|r| r.next_run_at).unwrap_or(now);
        let released = TaskRecord {
            task_name: task_name.to_string(),
            status: status::IDLE.to_string(),
            pod_id: self.pod_id.clone(),
            claimed_at: now,
            completed_at: Some(now),
            next_run_at: next_run,
        };
        self.write_task_update(&mut table_guard, &released).await?;
        warn!(
            "Released task '{}' after failure, next_run_at unchanged: {}",
            task_name, next_run
        );
        Ok(())
    }

    /// Read-only check: would `try_claim_task` currently consider this task
    /// runnable? Unknown tasks are always due.
    pub async fn is_task_due(&self, task_name: &str) -> Result<bool, TraceEngineError> {
        let mut table_guard = self.table.write().await;
        if let Err(e) = table_guard.update_incremental(None).await {
            debug!("Control table update skipped: {}", e);
        }
        let _ = self.ctx.deregister_table(CONTROL_TABLE_NAME);
        if let Ok(provider) = table_guard.table_provider().await {
            self.ctx.register_table(CONTROL_TABLE_NAME, provider)?;
        }
        let current = self
            .read_task(&table_guard_to_ctx(&self.ctx), task_name)
            .await?;
        let now = Utc::now();
        match current {
            Some(record) => {
                if record.status == status::PROCESSING {
                    // An in-flight task is only "due" once its claim is stale.
                    let stale_threshold = now - Duration::minutes(STALE_LOCK_MINUTES);
                    Ok(record.claimed_at <= stale_threshold)
                } else {
                    Ok(now >= record.next_run_at)
                }
            }
            None => Ok(true),
        }
    }

    /// Fetch the control-table row for `task_name`, decoding row 0 of the
    /// first non-empty batch into a [`TaskRecord`]. Returns `Ok(None)` when
    /// the table is not registered yet or holds no matching row.
    async fn read_task(
        &self,
        ctx: &SessionContext,
        task_name: &str,
    ) -> Result<Option<TaskRecord>, TraceEngineError> {
        // The table is registered lazily; absence simply means "no record yet".
        let table_exists = ctx.table_exist(CONTROL_TABLE_NAME)?;
        if !table_exists {
            return Ok(None);
        }
        let df = ctx
            .table(CONTROL_TABLE_NAME)
            .await
            .map_err(TraceEngineError::DatafusionError)?;
        let df = df
            .filter(col("task_name").eq(lit(task_name)))
            .map_err(TraceEngineError::DatafusionError)?;
        let batches = df
            .collect()
            .await
            .map_err(TraceEngineError::DatafusionError)?;
        for batch in &batches {
            if batch.num_rows() == 0 {
                continue;
            }
            // Decode string column `col_name` at row 0; the cast tolerates
            // alternate string encodings of the physical column.
            let get_string = |col_name: &'static str| -> Result<String, TraceEngineError> {
                let col = batch
                    .column_by_name(col_name)
                    .ok_or(TraceEngineError::DowncastError(col_name))?;
                let casted = arrow::compute::cast(col, &DataType::Utf8)
                    .map_err(TraceEngineError::ArrowError)?;
                let arr = casted
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .ok_or(TraceEngineError::DowncastError(col_name))?;
                Ok(arr.value(0).to_string())
            };
            // Decode a microsecond timestamp column at row 0; NULL -> None.
            let get_timestamp =
                |col_name: &'static str| -> Result<Option<DateTime<Utc>>, TraceEngineError> {
                    let col = batch
                        .column_by_name(col_name)
                        .ok_or(TraceEngineError::DowncastError(col_name))?;
                    if col.is_null(0) {
                        return Ok(None);
                    }
                    let arr = col
                        .as_any()
                        .downcast_ref::<TimestampMicrosecondArray>()
                        .ok_or(TraceEngineError::DowncastError(col_name))?;
                    Ok(DateTime::from_timestamp_micros(arr.value(0)))
                };
            let task_name_val = get_string("task_name")?;
            let status_val = get_string("status")?;
            let pod_id_val = get_string("pod_id")?;
            // NOTE(review): an out-of-range stored timestamp silently becomes
            // Utc::now() here — confirm that fallback is intended.
            let claimed_at = get_timestamp("claimed_at")?.unwrap_or_else(Utc::now);
            let completed_at = get_timestamp("completed_at")?;
            let next_run_at = get_timestamp("next_run_at")?.unwrap_or_else(Utc::now);
            return Ok(Some(TaskRecord {
                task_name: task_name_val,
                status: status_val,
                pod_id: pod_id_val,
                claimed_at,
                completed_at,
                next_run_at,
            }));
        }
        Ok(None)
    }

    /// Upsert `record`: delete any existing row for the task, append the new
    /// row, then swap the refreshed table into the guard and re-register the
    /// DataFusion provider so later reads see the new version.
    async fn write_task_update(
        &self,
        table_guard: &mut DeltaTable,
        record: &TaskRecord,
    ) -> Result<(), TraceEngineError> {
        let batch = build_task_batch(&self.schema, record)?;
        // task_name is interpolated into the delete predicate below; this
        // debug-build guard rejects characters that could break the predicate.
        debug_assert!(
            record
                .task_name
                .chars()
                .all(|c| c.is_alphanumeric() || c == '_'),
            "task_name must be alphanumeric + underscore, got: {}",
            record.task_name
        );
        let predicate = format!("task_name = '{}'", record.task_name);
        let delete_result = table_guard.clone().delete().with_predicate(predicate).await;
        match delete_result {
            Ok((updated_table, _metrics)) => {
                let updated_table = updated_table
                    .write(vec![batch])
                    .with_save_mode(deltalake::protocol::SaveMode::Append)
                    .await?;
                let _ = self.ctx.deregister_table(CONTROL_TABLE_NAME);
                if let Ok(provider) = updated_table.table_provider().await {
                    self.ctx.register_table(CONTROL_TABLE_NAME, provider)?;
                }
                *table_guard = updated_table;
            }
            Err(e) => {
                // NOTE(review): "no data"/"empty" delete failures (fresh table)
                // are classified by error-message substrings — fragile across
                // deltalake versions; confirm against the error type.
                let err_msg = e.to_string();
                if !err_msg.contains("No data") && !err_msg.contains("empty") {
                    warn!(
                        "Delete before write_task_update failed unexpectedly: {}",
                        err_msg
                    );
                    return Err(TraceEngineError::DataTableError(e));
                }
                // Empty table: nothing to delete, plain append creates the row.
                let updated_table = table_guard
                    .clone()
                    .write(vec![batch])
                    .with_save_mode(deltalake::protocol::SaveMode::Append)
                    .await?;
                let _ = self.ctx.deregister_table(CONTROL_TABLE_NAME);
                if let Ok(provider) = updated_table.table_provider().await {
                    self.ctx.register_table(CONTROL_TABLE_NAME, provider)?;
                }
                *table_guard = updated_table;
            }
        }
        Ok(())
    }
}
/// Clone the shared session out of its `Arc` so callers can pass an owned
/// `SessionContext` to query helpers.
fn table_guard_to_ctx(ctx: &Arc<SessionContext>) -> SessionContext {
    (**ctx).clone()
}
/// Load the control Delta table at `<base>/_scouter_control`, creating it
/// with `schema` (and a 5-commit checkpoint interval) when it does not exist.
async fn build_or_create_control_table(
    object_store: &ObjectStore,
    schema: SchemaRef,
) -> Result<DeltaTable, TraceEngineError> {
    // Make sure cloud logstore factories are registered before building.
    register_cloud_logstore_factories();
    let base_url = object_store.get_base_url()?;
    let control_url = append_path_to_url(&base_url, CONTROL_TABLE_NAME)?;
    info!(
        "Loading control table [{}://.../{} ]",
        control_url.scheme(),
        control_url
            .path_segments()
            .and_then(|mut s| s.next_back())
            .unwrap_or(CONTROL_TABLE_NAME)
    );
    let store = object_store.as_dyn_object_store();
    // Existence check: a cheap _delta_log probe for local paths; for remote
    // stores a trial load. NOTE(review): an existing remote table is loaded
    // twice (here and again below) — acceptable at startup, but worth folding
    // into a single load if init latency matters.
    let is_delta_table = if control_url.scheme() == "file" {
        if let Ok(path) = control_url.to_file_path() {
            if !path.exists() {
                info!("Creating directory for control table: {:?}", path);
                std::fs::create_dir_all(&path)?;
            }
            path.join("_delta_log").exists()
        } else {
            false
        }
    } else {
        match DeltaTableBuilder::from_url(control_url.clone()) {
            Ok(builder) => builder
                .with_storage_backend(store.clone(), control_url.clone())
                .load()
                .await
                .is_ok(),
            Err(_) => false,
        }
    };
    if is_delta_table {
        info!(
            "Loaded existing control table [{}://.../{} ]",
            control_url.scheme(),
            control_url
                .path_segments()
                .and_then(|mut s| s.next_back())
                .unwrap_or(CONTROL_TABLE_NAME)
        );
        let table = DeltaTableBuilder::from_url(control_url.clone())?
            .with_storage_backend(store, control_url)
            .load()
            .await?;
        Ok(table)
    } else {
        info!("Creating new control table");
        // Build an (unloaded) handle, then commit the initial metadata with
        // the control schema converted to Delta fields.
        let table = DeltaTableBuilder::from_url(control_url.clone())?
            .with_storage_backend(store, control_url)
            .build()?;
        let delta_fields = arrow_schema_to_delta(&schema);
        table
            .create()
            .with_table_name(CONTROL_TABLE_NAME)
            .with_columns(delta_fields)
            .with_configuration_property(TableProperty::CheckpointInterval, Some("5"))
            .await
            .map_err(Into::into)
    }
}
/// Return `base` with `segment` appended as an extra path component.
///
/// A trailing '/' is forced onto the base path first so that `Url::join`
/// appends the segment instead of replacing the last path component.
fn append_path_to_url(base: &Url, segment: &str) -> Result<Url, TraceEngineError> {
    let mut joined = base.clone();
    if !base.path().ends_with('/') {
        let slashed = format!("{}/", joined.path());
        joined.set_path(&slashed);
    }
    Ok(joined.join(segment)?)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scouter_settings::ObjectStorageSettings;
    // Build an ObjectStore from the given settings; panics on misconfiguration.
    fn make_test_object_store(storage_settings: &ObjectStorageSettings) -> ObjectStore {
        ObjectStore::new(storage_settings).unwrap()
    }
    // Remove the default local storage root so each test starts from a clean
    // slate. NOTE(review): all tests share this one directory — they are not
    // safe to run in parallel against each other; confirm the test runner
    // serializes them (or give each test its own root).
    fn cleanup() {
        let storage_settings = ObjectStorageSettings::default();
        let current_dir = std::env::current_dir().unwrap();
        let storage_path = current_dir.join(storage_settings.storage_root());
        if storage_path.exists() {
            let _ = std::fs::remove_dir_all(storage_path);
        }
    }
    // A freshly created engine must report unknown tasks as due.
    #[tokio::test]
    async fn test_control_table_init() -> Result<(), TraceEngineError> {
        cleanup();
        let settings = ObjectStorageSettings::default();
        let object_store = make_test_object_store(&settings);
        let engine = ControlTableEngine::new(&object_store, "pod-1".to_string()).await?;
        let due = engine.is_task_due("optimize").await?;
        assert!(due, "New task should be due (never run before)");
        cleanup();
        Ok(())
    }
    // Claim -> re-claim (should fail while processing) -> release with a
    // 1-hour interval -> task no longer due.
    #[tokio::test]
    async fn test_claim_and_release() -> Result<(), TraceEngineError> {
        cleanup();
        let settings = ObjectStorageSettings::default();
        let object_store = make_test_object_store(&settings);
        let engine = ControlTableEngine::new(&object_store, "pod-1".to_string()).await?;
        let claimed = engine.try_claim_task("optimize").await?;
        assert!(claimed, "First claim should succeed");
        let claimed_again = engine.try_claim_task("optimize").await?;
        assert!(
            !claimed_again,
            "Second claim should fail (already processing)"
        );
        engine.release_task("optimize", Duration::hours(1)).await?;
        let due = engine.is_task_due("optimize").await?;
        assert!(!due, "Task should not be due yet");
        cleanup();
        Ok(())
    }
    // A zero-second interval makes the task immediately due and claimable
    // again; also exercises the failure-release path.
    #[tokio::test]
    async fn test_claim_release_then_due() -> Result<(), TraceEngineError> {
        cleanup();
        let settings = ObjectStorageSettings::default();
        let object_store = make_test_object_store(&settings);
        let engine = ControlTableEngine::new(&object_store, "pod-1".to_string()).await?;
        let claimed = engine.try_claim_task("vacuum").await?;
        assert!(claimed);
        engine.release_task("vacuum", Duration::seconds(0)).await?;
        let due = engine.is_task_due("vacuum").await?;
        assert!(due, "Task should be due after 0-second interval");
        let claimed = engine.try_claim_task("vacuum").await?;
        assert!(claimed, "Task should be claimable after release");
        engine.release_task_on_failure("vacuum").await?;
        cleanup();
        Ok(())
    }
    // Distinct task names coexist in one table and claim independently.
    #[tokio::test]
    async fn test_multiple_tasks() -> Result<(), TraceEngineError> {
        cleanup();
        let settings = ObjectStorageSettings::default();
        let object_store = make_test_object_store(&settings);
        let engine = ControlTableEngine::new(&object_store, "pod-1".to_string()).await?;
        let claimed_opt = engine.try_claim_task("optimize").await?;
        let claimed_vac = engine.try_claim_task("vacuum").await?;
        assert!(claimed_opt, "Optimize claim should succeed");
        assert!(claimed_vac, "Vacuum claim should succeed");
        engine.release_task("optimize", Duration::hours(24)).await?;
        engine.release_task("vacuum", Duration::hours(168)).await?;
        cleanup();
        Ok(())
    }
}