use crate::error::DataFrameError;
use arrow::array::AsArray;
use arrow::array::{BooleanBuilder, StringArray};
use arrow::datatypes::DataType;
use arrow::datatypes::UInt32Type;
use arrow_array::types::Float64Type;
use arrow_array::types::TimestampNanosecondType;
use arrow_array::Array;
use arrow_array::RecordBatch;
use arrow_array::StringViewArray;
use chrono::{DateTime, TimeZone, Utc};
use datafusion::error::{DataFusionError, Result};
use datafusion::logical_expr::ScalarFunctionArgs;
use datafusion::logical_expr::{
ColumnarValue, Expr, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion::prelude::DataFrame;
use datafusion::scalar::ScalarValue;
use deltalake::logstore::{
default_logstore, logstore_factories, LogStore, LogStoreFactory, ObjectStoreRef, StorageConfig,
};
use deltalake::DeltaResult;
use scouter_types::{BinnedMetric, BinnedMetricStats, BinnedMetrics};
use std::sync::Arc;
use tracing::{debug, error, instrument};
use url::Url;
/// Helpers for pulling strongly-typed columns out of Arrow `RecordBatch`es
/// produced by parquet-backed queries.
pub struct ParquetHelper {}

impl ParquetHelper {
    /// Returns the "feature" column of `batch` downcast to a `StringViewArray`.
    ///
    /// # Errors
    /// `MissingFieldError` if the column is absent; `DowncastError` if it is
    /// not stored as `Utf8View`.
    #[instrument(skip_all)]
    pub fn extract_feature_array(batch: &RecordBatch) -> Result<&StringViewArray, DataFrameError> {
        let feature_array = batch
            .column_by_name("feature")
            .ok_or_else(|| {
                error!("Missing 'feature' field in RecordBatch");
                DataFrameError::MissingFieldError("feature")
            })?
            .as_string_view_opt()
            .ok_or_else(|| {
                error!("Failed to downcast 'feature' field to StringViewArray");
                DataFrameError::DowncastError("StringViewArray")
            })?;
        Ok(feature_array)
    }

    /// Extracts the first row of the "created_at" list column as UTC
    /// timestamps (nanosecond precision). Null entries are skipped.
    ///
    /// Returns an empty `Vec` for a zero-row batch instead of panicking on an
    /// out-of-bounds list access.
    ///
    /// # Errors
    /// `MissingFieldError` if the column is absent; `DowncastError` if it is
    /// not a `ListArray`.
    #[instrument(skip_all)]
    pub fn extract_created_at(batch: &RecordBatch) -> Result<Vec<DateTime<Utc>>, DataFrameError> {
        let created_at_list = batch
            .column_by_name("created_at")
            .ok_or_else(|| {
                error!("Missing 'created_at' field in RecordBatch");
                DataFrameError::MissingFieldError("created_at")
            })?
            .as_list_opt::<i32>()
            .ok_or_else(|| {
                error!("Failed to downcast 'created_at' field to ListArray");
                DataFrameError::DowncastError("ListArray")
            })?;
        // `value(0)` panics on an empty ListArray; an empty batch simply has
        // no timestamps.
        if created_at_list.is_empty() {
            return Ok(Vec::new());
        }
        let created_at_array = created_at_list.value(0);
        Ok(created_at_array
            .as_primitive::<TimestampNanosecondType>()
            .iter()
            .filter_map(|ts| ts.map(|t| Utc.timestamp_nanos(t)))
            .collect())
    }
}
pub struct BinnedMetricsExtractor {}
impl BinnedMetricsExtractor {
#[instrument(skip_all)]
fn extract_stats(batch: &RecordBatch) -> Result<Vec<BinnedMetricStats>, DataFrameError> {
let stats_list = batch
.column_by_name("stats")
.ok_or_else(|| {
error!("Missing 'stats' field in RecordBatch");
DataFrameError::MissingFieldError("stats")
})?
.as_list_opt::<i32>()
.ok_or_else(|| {
error!("Failed to downcast 'stats' field to ListArray");
DataFrameError::DowncastError("ListArray")
})?
.value(0);
let stats_structs = stats_list.as_struct_opt().ok_or_else(|| {
error!("Failed to downcast 'stats' field to StructArray");
DataFrameError::DowncastError("StructArray")
})?;
let avg_array = stats_structs
.column_by_name("avg")
.ok_or_else(|| DataFrameError::MissingFieldError("avg"))
.inspect_err(|e| error!("Failed to get 'avg' field from stats: {:?}", e))?
.as_primitive_opt::<Float64Type>()
.ok_or_else(|| DataFrameError::DowncastError("Float64Array"))?;
let lower_bound_array = stats_structs
.column_by_name("lower_bound")
.ok_or_else(|| DataFrameError::MissingFieldError("lower_bound"))
.inspect_err(|e| error!("Failed to get 'lower_bound' field from stats: {:?}", e))?
.as_primitive_opt::<Float64Type>()
.ok_or_else(|| DataFrameError::DowncastError("Float64Array"))?;
let upper_bound_array = stats_structs
.column_by_name("upper_bound")
.ok_or_else(|| DataFrameError::MissingFieldError("upper_bound"))
.inspect_err(|e| error!("Failed to get 'upper_bound' field from stats: {:?}", e))?
.as_primitive_opt::<Float64Type>()
.ok_or_else(|| DataFrameError::DowncastError("Float64Array"))?;
Ok((0..stats_structs.len())
.map(|i| BinnedMetricStats {
avg: avg_array.value(i),
lower_bound: lower_bound_array.value(i),
upper_bound: upper_bound_array.value(i),
})
.collect())
}
#[instrument(skip_all)]
fn process_metric_record_batch(batch: &RecordBatch) -> Result<BinnedMetric, DataFrameError> {
debug!("Processing metric record batch");
let metric_column = batch.column_by_name("metric").ok_or_else(|| {
error!("Missing 'metric' field in RecordBatch");
DataFrameError::MissingFieldError("metric")
})?;
let metric_name = if let Some(dict_array) = metric_column.as_dictionary_opt::<UInt32Type>()
{
let values = dict_array.values();
let string_values = values.as_string_opt::<i32>().ok_or_else(|| {
error!("Failed to downcast dictionary values to StringArray");
DataFrameError::DowncastError("StringArray")
})?;
let key = dict_array.key(0).ok_or_else(|| {
error!("Failed to get key from dictionary array");
DataFrameError::MissingFieldError("dictionary key")
})?;
string_values.value(key).to_string()
} else if let Some(string_view_array) = metric_column.as_string_view_opt() {
string_view_array.value(0).to_string()
} else if let Some(string_array) = metric_column.as_string_opt::<i32>() {
string_array.value(0).to_string()
} else {
error!("Failed to downcast 'metric' field to any supported string type");
return Err(DataFrameError::DowncastError("String type"));
};
let created_at_list = ParquetHelper::extract_created_at(batch)?;
let stats = Self::extract_stats(batch)?;
Ok(BinnedMetric {
metric: metric_name,
created_at: created_at_list,
stats,
})
}
#[instrument(skip_all)]
pub async fn dataframe_to_binned_metrics(
df: DataFrame,
) -> Result<BinnedMetrics, DataFrameError> {
debug!("Converting DataFrame to binned metrics");
let batches = df.collect().await?;
let metrics: Vec<BinnedMetric> = batches
.iter()
.map(Self::process_metric_record_batch)
.collect::<Result<Vec<_>, _>>()
.inspect_err(|e| {
error!("Failed to process metric record batch: {:?}", e);
})?;
Ok(BinnedMetrics::from_vec(metrics))
}
}
/// A `LogStoreFactory` that defers to the default log store, adjusting the
/// object-store prefix for Azure (`az://`) locations.
pub(crate) struct PassthroughLogStoreFactory;

impl LogStoreFactory for PassthroughLogStoreFactory {
    fn with_options(
        &self,
        prefixed_store: ObjectStoreRef,
        root_store: ObjectStoreRef,
        location: &Url,
        options: &StorageConfig,
    ) -> DeltaResult<Arc<dyn LogStore>> {
        // Azure URLs with a non-empty path get the root store re-prefixed with
        // that path; every other case keeps the store handed to us.
        let subpath = location.path().trim_start_matches('/');
        let store = if location.scheme() == "az" && !subpath.is_empty() {
            let prefix = object_store::path::Path::from(subpath);
            Arc::new(object_store::prefix::PrefixStore::new(
                root_store.clone(),
                prefix,
            )) as ObjectStoreRef
        } else {
            prefixed_store
        };
        Ok(default_logstore(store, root_store, location, options))
    }
}
/// Registers `PassthroughLogStoreFactory` for every supported cloud URL
/// scheme that does not already have a factory installed. Existing
/// registrations are never overwritten.
pub(crate) fn register_cloud_logstore_factories() {
    const CLOUD_SCHEMES: [&str; 6] = ["gs", "s3", "s3a", "az", "abfs", "abfss"];
    let registry = logstore_factories();
    let passthrough = Arc::new(PassthroughLogStoreFactory) as Arc<dyn LogStoreFactory>;
    for scheme in CLOUD_SCHEMES {
        let key = Url::parse(&format!("{scheme}://")).expect("scheme is a valid URL prefix");
        if !registry.contains_key(&key) {
            registry.insert(key, passthrough.clone());
        }
    }
}
// Scalar UDF state for `match_attr`: only the (fixed) call signature is held.
#[derive(Debug)]
struct AttrMatchUdf {
    signature: Signature,
}

// Every `AttrMatchUdf` is constructed identically, so all instances are
// interchangeable: equality is constant-true and hashing uses only the UDF
// name. (DataFusion requires PartialEq/Eq/Hash on ScalarUDFImpl types.)
impl PartialEq for AttrMatchUdf {
    fn eq(&self, _other: &Self) -> bool {
        true
    }
}

impl Eq for AttrMatchUdf {}

impl std::hash::Hash for AttrMatchUdf {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Hash must agree with the constant-true `eq` above; the name is the
        // only stable identity.
        self.name().hash(state);
    }
}
impl AttrMatchUdf {
    /// Builds the UDF with an immutable signature that accepts either a
    /// `(Utf8View, Utf8)` or a `(Utf8, Utf8)` argument pair.
    fn new() -> Self {
        let accepted = vec![
            TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
            TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
        ];
        Self {
            signature: Signature::one_of(accepted, Volatility::Immutable),
        }
    }
}
impl ScalarUDFImpl for AttrMatchUdf {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn name(&self) -> &str {
"match_attr"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Boolean)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let args_slice = args.args;
let batch_size = args.number_rows;
let pattern_str = match &args_slice[1] {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(p)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(p))) => p.clone(),
_ => {
return Err(DataFusionError::Execution(
"match_attr: second arg must be a non-null Utf8 scalar literal".into(),
))
}
};
let inner = pattern_str.trim_matches('%');
match &args_slice[0] {
ColumnarValue::Scalar(s) => {
let matched = match s {
ScalarValue::Utf8(Some(v))
| ScalarValue::LargeUtf8(Some(v))
| ScalarValue::Utf8View(Some(v)) => v.contains(inner),
_ => false,
};
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(matched))))
}
ColumnarValue::Array(arr) => {
let mut builder = BooleanBuilder::with_capacity(batch_size);
if arr.data_type() == &DataType::Utf8View {
let view_arr = arr
.as_any()
.downcast_ref::<arrow_array::StringViewArray>()
.ok_or_else(|| {
DataFusionError::Execution(
"match_attr: expected StringViewArray for search_blob".into(),
)
})?;
for i in 0..arr.len() {
if view_arr.is_null(i) {
builder.append_null();
} else {
builder.append_value(view_arr.value(i).contains(inner));
}
}
} else {
let cast_arr =
arrow::compute::cast(arr.as_ref(), &DataType::Utf8).map_err(|e| {
DataFusionError::Execution(format!(
"match_attr: cast to Utf8 failed: {e}"
))
})?;
let str_arr =
cast_arr
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| {
DataFusionError::Execution(
"match_attr: downcast to StringArray failed".into(),
)
})?;
for i in 0..arr.len() {
if str_arr.is_null(i) {
builder.append_null();
} else {
builder.append_value(str_arr.value(i).contains(inner));
}
}
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
}
}
/// Wraps `AttrMatchUdf` as a DataFusion `ScalarUDF` ready for registration.
pub fn create_attr_match_udf() -> ScalarUDF {
    let udf_impl = AttrMatchUdf::new();
    ScalarUDF::from(udf_impl)
}

/// Convenience builder for a `match_attr(search_blob, pattern)` expression.
pub fn match_attr_expr(search_blob: Expr, pattern: Expr) -> Expr {
    let udf = create_attr_match_udf();
    udf.call(vec![search_blob, pattern])
}