use arrow_array::ArrayRef;
use arrow_schema::DataType;
use datafusion::execution::SendableRecordBatchStream;
use smol_str::SmolStr;
use crate::errors::FnError;
/// Signature advertised by an algorithm, as returned from
/// [`AlgorithmProvider::signature`].
///
/// Derives `PartialEq`/`Eq` so two signatures can be compared directly
/// (both `arrow_schema::Field` and `String` support equality).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AlgorithmSignature {
    /// Arrow fields describing the algorithm's output columns — presumably the
    /// schema of the batches produced by `run`; confirm with implementors.
    pub output_fields: Vec<arrow_schema::Field>,
    /// Free-form documentation text for the algorithm.
    pub docs: String,
}
/// Per-invocation inputs handed to [`AlgorithmProvider::run`].
///
/// The type is `#[non_exhaustive]`, so code outside this crate cannot use a
/// struct literal. It must go through [`AlgorithmContext::new`], optionally
/// followed by [`AlgorithmContext::with_host`].
#[non_exhaustive]
pub struct AlgorithmContext<'a> {
    /// Borrowed configuration text. Given the name, presumably a JSON document;
    /// this type does not parse or validate it.
    pub config_json: &'a str,
    /// Optional borrowed host handle; `None` unless one was attached.
    pub host: Option<&'a dyn AlgorithmHost>,
}
// Hand-written because `dyn AlgorithmHost` has no `Debug` bound. The host is
// therefore summarised as a boolean `host_bound` entry rather than printed.
impl std::fmt::Debug for AlgorithmContext<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_host = self.host.is_some();
        let mut builder = f.debug_struct("AlgorithmContext");
        builder.field("config_json", &self.config_json);
        builder.field("host_bound", &has_host);
        builder.finish()
    }
}
impl<'a> AlgorithmContext<'a> {
    /// Builds a context around `config_json` with no host attached.
    #[must_use]
    pub fn new(config_json: &'a str) -> Self {
        Self {
            host: None,
            config_json,
        }
    }

    /// Consumes this context and returns it with `host` attached.
    ///
    /// Any host set earlier is overwritten.
    #[must_use]
    pub fn with_host(self, host: &'a dyn AlgorithmHost) -> Self {
        // All fields are `Copy` references, so struct-update simply carries
        // `config_json` over unchanged.
        Self {
            host: Some(host),
            ..self
        }
    }
}
/// Host object that can be attached to an [`AlgorithmContext`].
///
/// The `Sync + Send` supertraits allow a shared reference to the host to be
/// used from other threads.
pub trait AlgorithmHost: Sync + Send {
    /// Exposes the host as [`std::any::Any`], so callers can attempt a
    /// `downcast_ref` to a concrete host type.
    fn as_any(&self) -> &dyn std::any::Any;
}
/// A pluggable algorithm whose result is a stream of record batches.
pub trait AlgorithmProvider: Sync + Send {
    /// Returns the algorithm's advertised signature (output fields and docs).
    fn signature(&self) -> &AlgorithmSignature;

    /// Starts the algorithm for the given per-invocation context.
    ///
    /// # Errors
    ///
    /// Returns an [`FnError`] when the implementation fails to produce a stream.
    fn run(&self, ctx: AlgorithmContext<'_>) -> Result<SendableRecordBatchStream, FnError>;
}
/// Signature advertised by a Pregel-style program, as returned from
/// [`PregelProgramProvider::signature`].
///
/// Derives `PartialEq`/`Eq` so signatures can be compared directly. Every field
/// (`DataType`, `AggregationMode`, `Option<u64>`) already supports equality.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PregelSignature {
    /// Arrow type of the program state — presumably per-vertex state; confirm.
    pub state_type: DataType,
    /// Arrow type of the messages the program exchanges.
    pub message_type: DataType,
    /// Execution/aggregation strategy requested by the program.
    pub aggregation_mode: AggregationMode,
    /// Optional cap on the number of supersteps. `None` presumably means no
    /// explicit cap; confirm against the engine that enforces it.
    pub max_supersteps: Option<u64>,
}
/// Execution/aggregation strategy for a Pregel program.
///
/// `Hash` is derived alongside `Eq`, so modes can be used as `HashMap` or
/// `HashSet` keys. The enum is `#[non_exhaustive]`, so `match`es outside this
/// crate need a wildcard arm.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum AggregationMode {
    /// Presumably bulk-synchronous parallel supersteps — confirm with the engine.
    Bsp,
    /// Asynchronous mode with shared state (exact semantics live in the engine,
    /// not in this file).
    AsyncShared,
    /// Asynchronous message-passing mode (exact semantics live in the engine,
    /// not in this file).
    AsyncMessaging,
}
/// Outcome of a Pregel compute step. The API that produces and consumes this
/// value is not visible in this file.
///
/// Derives `Clone`. Cloning copies the `Vec` and bumps the reference counts of
/// the `ArrayRef` payloads; the underlying array data is not copied.
#[derive(Clone, Debug)]
pub struct ComputeOutcome {
    /// Halt flag for this step — presumably a vote to halt; confirm with the engine.
    pub halt: bool,
    /// Outgoing `(key, payload)` pairs. The key is presumably a destination
    /// vertex identifier — verify against the engine.
    pub outgoing: Vec<(SmolStr, ArrayRef)>,
}
/// Runtime counters handed to [`PregelProgramProvider::halt`].
///
/// `Default` yields all-zero counters. `PartialEq`/`Eq` allow snapshots to be
/// compared directly.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct PregelStats {
    /// Number of vertices reported as active.
    pub active_vertices: u64,
    /// Number of messages reported as sent.
    pub messages_sent: u64,
    /// Duration of the most recent superstep in milliseconds (presumably
    /// wall-clock; confirm with the engine).
    pub last_superstep_ms: u64,
}
/// A pluggable Pregel-style program.
pub trait PregelProgramProvider: Sync + Send {
    /// Returns the program's signature (state/message types, aggregation mode,
    /// optional superstep cap).
    fn signature(&self) -> &PregelSignature;

    /// Halting hook, called with the current superstep number and stats.
    ///
    /// Returning `true` presumably asks the engine to stop; confirm with the
    /// engine. The default implementation ignores its inputs and always
    /// returns `false`.
    fn halt(&self, superstep: u64, stats: &PregelStats) -> bool {
        let _ = (superstep, stats);
        false
    }
}