use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Name of an input source; used as the key of an [`InputCache`].
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct SourceName(pub String);

impl SourceName {
    /// Builds a source name from anything convertible into a `String`.
    pub fn new(name: impl Into<String>) -> Self {
        let owned: String = name.into();
        SourceName(owned)
    }

    /// Borrows the name as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl fmt::Display for SourceName {
    /// Writes the bare name (formatter width/fill flags are not applied).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<&str> for SourceName {
    fn from(s: &str) -> Self {
        SourceName::new(s)
    }
}

impl From<String> for SourceName {
    fn from(s: String) -> Self {
        SourceName::new(s)
    }
}
/// Encodes `value` into bytes suitable for storing in an [`InputCache`].
///
/// The wire format depends on the build profile: with `debug_assertions` enabled the
/// value is encoded as JSON, otherwise with `bincode`. Bytes produced under one profile
/// should not be expected to decode with [`deserialize`] under the other.
///
/// # Errors
/// Returns [`GraphError::Serialization`] carrying the encoder's error message.
pub fn serialize<T: Serialize>(value: &T) -> Result<Vec<u8>, GraphError> {
    #[cfg(debug_assertions)]
    let encoded = serde_json::to_vec(value).map_err(|err| err.to_string());
    #[cfg(not(debug_assertions))]
    let encoded = bincode::serialize(value).map_err(|err| err.to_string());

    encoded.map_err(GraphError::Serialization)
}
/// Decodes bytes produced by [`serialize`] back into a `T`.
///
/// Uses JSON when `debug_assertions` is enabled and `bincode` otherwise, matching the
/// encoder selected by [`serialize`] in the same build profile.
///
/// # Errors
/// Returns [`GraphError::Deserialization`] carrying the decoder's error message.
pub fn deserialize<T: DeserializeOwned>(bytes: &[u8]) -> Result<T, GraphError> {
    #[cfg(debug_assertions)]
    let decoded: Result<T, String> =
        serde_json::from_slice(bytes).map_err(|err| err.to_string());
    #[cfg(not(debug_assertions))]
    let decoded: Result<T, String> =
        bincode::deserialize(bytes).map_err(|err| err.to_string());

    decoded.map_err(GraphError::Deserialization)
}
/// Latest raw payload for each input source, keyed by [`SourceName`].
///
/// Payloads are stored as opaque bytes; [`InputCache::get`] decodes them on demand
/// with [`deserialize`].
#[derive(Clone, Debug)]
pub struct InputCache {
    entries: HashMap<SourceName, Vec<u8>>,
}

impl InputCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        InputCache {
            entries: HashMap::new(),
        }
    }

    /// Stores `bytes` as the payload for `source`, replacing any previous payload.
    pub fn update(&mut self, source: SourceName, bytes: Vec<u8>) {
        self.entries.insert(source, bytes);
    }

    /// Decodes the payload stored for `name` as a `T`.
    ///
    /// Returns `None` when `name` has no entry and `Some(Err(_))` when decoding fails.
    pub fn get<T: DeserializeOwned>(&self, name: &str) -> Option<Result<T, GraphError>> {
        self.get_raw(name).map(|bytes| deserialize::<T>(bytes))
    }

    /// Reports whether an entry exists for `name`.
    pub fn has(&self, name: &str) -> bool {
        self.get_raw(name).is_some()
    }

    /// Borrows the undecoded payload stored for `name`, if any.
    pub fn get_raw(&self, name: &str) -> Option<&[u8]> {
        let key = SourceName::new(name);
        self.entries.get(&key).map(Vec::as_slice)
    }

    /// Returns an independent copy of the cache.
    pub fn snapshot(&self) -> InputCache {
        InputCache {
            entries: self.entries.clone(),
        }
    }

    /// Number of sources that currently have a stored payload.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no source has a stored payload.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Discards every current entry and adopts the entries of `other`.
    pub fn replace_all(&mut self, other: InputCache) {
        let InputCache { entries } = other;
        self.entries = entries;
    }

    /// Names of all sources with a stored payload, in unspecified (hash-map) order.
    pub fn sources(&self) -> Vec<&SourceName> {
        self.entries.iter().map(|(name, _)| name).collect()
    }

    /// Borrows the underlying name-to-bytes map.
    pub fn entries_raw(&self) -> &HashMap<SourceName, Vec<u8>> {
        &self.entries
    }

    /// Renders every entry as a display string keyed by the source name.
    ///
    /// In debug builds, payloads that parse as JSON are shown as compact JSON text
    /// (debug builds encode with JSON in [`serialize`]). Anything else, and every
    /// payload in release builds, is shown as lowercase hex.
    pub fn entries_as_json(&self) -> HashMap<String, String> {
        let mut rendered = HashMap::with_capacity(self.entries.len());
        for (name, bytes) in &self.entries {
            let text = if !cfg!(debug_assertions) {
                hex_encode(bytes)
            } else {
                match serde_json::from_slice::<serde_json::Value>(bytes) {
                    Ok(json) => json.to_string(),
                    Err(_) => hex_encode(bytes),
                }
            };
            rendered.insert(name.as_str().to_owned(), text);
        }
        rendered
    }
}

impl Default for InputCache {
    fn default() -> Self {
        InputCache::new()
    }
}
/// Encodes `bytes` as a lowercase hexadecimal string, two digits per byte.
///
/// Used by [`InputCache::entries_as_json`] for payloads it does not render as JSON.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    // Exactly two output characters per input byte: reserve once up front instead of
    // allocating a temporary `String` per byte through `format!`.
    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}
/// Outcome of running a compiled computation graph.
#[derive(Debug)]
pub enum GraphResult {
    /// The run finished; `outputs` holds type-erased output values.
    Completed { outputs: Vec<Box<dyn Any + Send>> },
    /// The run failed with the given error.
    Error(GraphError),
}

impl GraphResult {
    /// A completed result carrying `outputs`.
    pub fn completed(outputs: Vec<Box<dyn Any + Send>>) -> Self {
        GraphResult::Completed { outputs }
    }

    /// A completed result with no outputs.
    pub fn completed_empty() -> Self {
        GraphResult::completed(Vec::new())
    }

    /// A failed result wrapping `err`.
    pub fn error(err: GraphError) -> Self {
        GraphResult::Error(err)
    }

    /// `true` for [`GraphResult::Completed`].
    pub fn is_completed(&self) -> bool {
        match self {
            GraphResult::Completed { .. } => true,
            GraphResult::Error(_) => false,
        }
    }

    /// `true` for [`GraphResult::Error`].
    pub fn is_error(&self) -> bool {
        !self.is_completed()
    }
}
#[derive(Debug, thiserror::Error)]
pub enum GraphError {
#[error("serialization failed: {0}")]
Serialization(String),
#[error("deserialization failed: {0}")]
Deserialization(String),
#[error("missing input: source '{0}' not found in cache")]
MissingInput(String),
#[error("node execution failed: {0}")]
NodeExecution(String),
#[error("graph execution failed: {0}")]
Execution(String),
}
/// Shareable, type-erased graph entry point: takes an [`InputCache`] by value and
/// returns a boxed `Send` future resolving to the run's [`GraphResult`].
pub type CompiledGraphFn =
    Arc<dyn Fn(InputCache) -> Pin<Box<dyn Future<Output = GraphResult> + Send>> + Send + Sync>;

/// Everything a graph constructor produces.
pub struct ComputationGraphRegistration {
    /// Callable that executes the graph.
    pub graph_fn: CompiledGraphFn,
    /// NOTE(review): accumulator names associated with this graph; semantics defined by callers.
    pub accumulator_names: Vec<String>,
    /// NOTE(review): free-form mode string; accepted values are not visible in this file.
    pub reaction_mode: String,
}

/// Factory that builds a fresh [`ComputationGraphRegistration`] on each call.
pub type ComputationGraphConstructor = Box<dyn Fn() -> ComputationGraphRegistration + Send + Sync>;

/// Shared, lock-protected map from graph name to its constructor.
pub type GlobalComputationGraphRegistry =
    Arc<parking_lot::RwLock<HashMap<String, ComputationGraphConstructor>>>;

/// Process-wide registry instance, created empty on first access.
static GLOBAL_COMPUTATION_GRAPH_REGISTRY: once_cell::sync::Lazy<GlobalComputationGraphRegistry> =
    once_cell::sync::Lazy::new(GlobalComputationGraphRegistry::default);
/// Registers `constructor` in the process-wide registry under `graph_name`.
///
/// If a constructor is already registered under the same name, it is replaced. The
/// replacement is logged at `warn` level so duplicate registrations are not silent.
/// Any displaced constructor is dropped only after the registry lock is released.
pub fn register_computation_graph_constructor<F>(graph_name: String, constructor: F)
where
    F: Fn() -> ComputationGraphRegistration + Send + Sync + 'static,
{
    // The write guard is a temporary, so it is released at the end of this statement.
    // The displaced constructor (if any) lives in `previous` and is dropped later,
    // outside the lock.
    let previous = GLOBAL_COMPUTATION_GRAPH_REGISTRY
        .write()
        .insert(graph_name.clone(), Box::new(constructor));

    if previous.is_some() {
        tracing::warn!(
            "Replaced existing computation graph constructor: {}",
            graph_name
        );
    } else {
        tracing::debug!("Registered computation graph constructor: {}", graph_name);
    }
}
pub fn global_computation_graph_registry() -> GlobalComputationGraphRegistry {
GLOBAL_COMPUTATION_GRAPH_REGISTRY.clone()
}
/// Names of all registered graphs, in unspecified (hash-map) order.
pub fn list_registered_graphs() -> Vec<String> {
    GLOBAL_COMPUTATION_GRAPH_REGISTRY
        .read()
        .keys()
        .map(String::clone)
        .collect()
}
/// Removes the constructor registered under `graph_name`, if there is one.
///
/// Removing an unknown name is a no-op. It is logged as such, not reported as a
/// successful deregistration.
pub fn deregister_computation_graph(graph_name: &str) {
    // The write guard is a temporary released at the end of this statement, so the
    // removed constructor is dropped later, outside the lock.
    let removed = GLOBAL_COMPUTATION_GRAPH_REGISTRY.write().remove(graph_name);

    if removed.is_some() {
        tracing::debug!("Deregistered computation graph constructor: {}", graph_name);
    } else {
        tracing::debug!(
            "No computation graph constructor registered under '{}'; nothing to deregister",
            graph_name
        );
    }
}
pub mod types {
pub use crate::{deserialize, serialize, GraphError, GraphResult, InputCache, SourceName};
}