use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Name of an input source; used as the key of [`InputCache`] entries.
///
/// A thin newtype over `String`. Equality and hashing are derived from the wrapped
/// string, and the value (de)serializes through serde.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SourceName(pub String);
impl SourceName {
    /// Creates a source name from anything convertible into a `String`.
    pub fn new(name: impl Into<String>) -> Self {
        Self(name.into())
    }

    /// Returns the name as a borrowed string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl fmt::Display for SourceName {
    /// Writes the wrapped name.
    ///
    /// Delegates to `str`'s `Display` so that formatter flags such as width, fill,
    /// alignment and precision (e.g. `{:>12}`, `{:.4}`) are honoured. Re-entering
    /// `write!(f, "{}", ...)` would start a fresh format spec and silently drop them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self.as_str(), f)
    }
}

impl From<&str> for SourceName {
    /// Copies the borrowed string into a new `SourceName`.
    fn from(s: &str) -> Self {
        Self(s.to_owned())
    }
}

impl From<String> for SourceName {
    /// Wraps the owned string without copying it.
    fn from(s: String) -> Self {
        Self(s)
    }
}
/// Encodes `value` into bytes with `bincode`.
///
/// # Errors
/// Returns [`GraphError::Serialization`] carrying the encoder's message when encoding
/// fails.
pub fn serialize<T: Serialize>(value: &T) -> Result<Vec<u8>, GraphError> {
    match bincode::serialize(value) {
        Ok(encoded) => Ok(encoded),
        Err(cause) => Err(GraphError::Serialization(cause.to_string())),
    }
}
/// Decodes a `T` from `bytes` with `bincode`.
///
/// # Errors
/// Returns [`GraphError::Deserialization`] carrying the decoder's message when `bytes`
/// cannot be decoded as `T`.
pub fn deserialize<T: DeserializeOwned>(bytes: &[u8]) -> Result<T, GraphError> {
    let decoded: Result<T, _> = bincode::deserialize(bytes);
    decoded.map_err(|cause| GraphError::Deserialization(cause.to_string()))
}
#[derive(Debug, Clone)]
pub struct InputCache {
entries: HashMap<SourceName, Vec<u8>>,
}
impl InputCache {
pub fn new() -> Self {
Self {
entries: HashMap::new(),
}
}
pub fn update(&mut self, source: SourceName, bytes: Vec<u8>) {
self.entries.insert(source, bytes);
}
pub fn get<T: DeserializeOwned>(&self, name: &str) -> Option<Result<T, GraphError>> {
let bytes = self.entries.get(&SourceName::new(name))?;
Some(deserialize::<T>(bytes))
}
pub fn has(&self, name: &str) -> bool {
self.entries.contains_key(&SourceName::new(name))
}
pub fn get_raw(&self, name: &str) -> Option<&[u8]> {
self.entries
.get(&SourceName::new(name))
.map(|v| v.as_slice())
}
pub fn snapshot(&self) -> InputCache {
self.clone()
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn replace_all(&mut self, other: InputCache) {
self.entries = other.entries;
}
pub fn sources(&self) -> Vec<&SourceName> {
self.entries.keys().collect()
}
pub fn entries_raw(&self) -> &HashMap<SourceName, Vec<u8>> {
&self.entries
}
pub fn entries_as_json(&self) -> HashMap<String, String> {
self.entries
.iter()
.map(|(name, bytes)| {
let value = if cfg!(debug_assertions) {
serde_json::from_slice::<serde_json::Value>(bytes)
.map(|v| v.to_string())
.unwrap_or_else(|_| hex_encode(bytes))
} else {
hex_encode(bytes)
};
(name.as_str().to_string(), value)
})
.collect()
}
}
impl Default for InputCache {
fn default() -> Self {
Self::new()
}
}
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte.
///
/// For example, `[0x0f, 0xa0]` becomes `"0fa0"`, and an empty slice yields an empty
/// string. The output is built in one buffer sized up front rather than allocating a
/// temporary `String` per byte via `format!`.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, matching `{:02x}` formatting.
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
/// Outcome of running a graph: either its outputs or the error that stopped it.
#[derive(Debug)]
pub enum GraphResult {
    /// Execution finished; `outputs` holds the type-erased output values.
    Completed { outputs: Vec<Box<dyn Any + Send>> },
    /// Execution failed with the contained error.
    Error(GraphError),
}

impl GraphResult {
    /// Wraps `outputs` in [`GraphResult::Completed`].
    pub fn completed(outputs: Vec<Box<dyn Any + Send>>) -> Self {
        GraphResult::Completed { outputs }
    }

    /// Wraps `err` in [`GraphResult::Error`].
    pub fn error(err: GraphError) -> Self {
        GraphResult::Error(err)
    }

    /// Returns `true` for [`GraphResult::Completed`].
    pub fn is_completed(&self) -> bool {
        // Exactly two variants, so "completed" is simply "not an error".
        !self.is_error()
    }

    /// Returns `true` for [`GraphResult::Error`].
    pub fn is_error(&self) -> bool {
        matches!(*self, GraphResult::Error(..))
    }
}
/// Error type shared by the serialization helpers, [`InputCache`] and [`GraphResult`].
///
/// Each variant carries a detail message that is appended to the variant's fixed prefix
/// when displayed.
#[derive(Debug, thiserror::Error)]
pub enum GraphError {
/// Encoding a value to bytes failed; returned by `serialize`.
#[error("serialization failed: {0}")]
Serialization(String),
/// Decoding bytes into a value failed; returned by `deserialize` and `InputCache::get`.
#[error("deserialization failed: {0}")]
Deserialization(String),
/// The named source has no entry in the cache; the payload is the source name.
#[error("missing input: source '{0}' not found in cache")]
MissingInput(String),
/// A single node failed. NOTE(review): not constructed in this file; presumably raised
/// by generated graph code — confirm.
#[error("node execution failed: {0}")]
NodeExecution(String),
/// Whole-graph execution failed. NOTE(review): not constructed in this file — confirm
/// where it is produced.
#[error("graph execution failed: {0}")]
Execution(String),
}
/// Type-erased, shareable entry point of a compiled graph.
///
/// Takes an [`InputCache`] by value and returns a boxed `Send` future that resolves to a
/// [`GraphResult`]. Wrapping it in `Arc` makes clones cheap, and the `Send + Sync` bounds
/// allow sharing it across threads.
pub type CompiledGraphFn =
Arc<dyn Fn(InputCache) -> Pin<Box<dyn Future<Output = GraphResult> + Send>> + Send + Sync>;
/// Everything needed to register a computation graph: its entry point plus reaction
/// metadata.
pub struct ComputationGraphRegistration {
/// Compiled entry point that executes the graph.
pub graph_fn: CompiledGraphFn,
/// Name of the reactor that triggers this graph, if any. NOTE(review): presumably
/// `None` for graphs whose `Graph::IS_TRIGGERLESS` is `true` — confirm.
pub trigger_reactor: Option<String>,
/// Names of the accumulators associated with this graph.
pub accumulator_names: Vec<String>,
/// Reaction mode as a string. NOTE(review): presumably one of `ReactionMode::as_str()`'s
/// values (`"when_any"` / `"when_all"`). A `ReactionMode` field would be safer, but
/// switching to it is a breaking change.
pub reaction_mode: String,
}
/// Boxed, thread-shareable factory that builds a fresh [`ComputationGraphRegistration`]
/// on each call.
pub type ComputationGraphConstructor = Box<dyn Fn() -> ComputationGraphRegistration + Send + Sync>;
/// Reaction policy of a reactor.
///
/// NOTE(review): by name, `WhenAny` / `WhenAll` presumably mean "react when any / all
/// accumulators have new data"; the actual semantics live in the runtime that consumes
/// this value — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReactionMode {
    /// String form: `"when_any"`.
    WhenAny,
    /// String form: `"when_all"`.
    WhenAll,
}

impl ReactionMode {
    /// Returns the snake_case name of the mode: `"when_any"` or `"when_all"`.
    pub const fn as_str(&self) -> &'static str {
        match self {
            ReactionMode::WhenAny => "when_any",
            ReactionMode::WhenAll => "when_all",
        }
    }
}

impl fmt::Display for ReactionMode {
    /// Writes [`ReactionMode::as_str`].
    ///
    /// Uses `Formatter::pad` so width, fill, alignment and precision flags (e.g.
    /// `{:<10}`) are applied. `write_str` would bypass them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
/// Compile-time description of a reactor, given as associated constants.
///
/// NOTE(review): these constants appear to mirror the fields of `ReactorRegistration`;
/// the conversion between the two is not in this file — confirm where it lives.
pub trait Reactor {
/// Name of the reactor.
const NAME: &'static str;
/// Names of the accumulators this reactor uses.
const ACCUMULATORS: &'static [&'static str];
/// Reaction policy of this reactor.
const REACTION_MODE: ReactionMode;
}
/// Owned, runtime form of a reactor's registration data.
#[derive(Debug, Clone)]
pub struct ReactorRegistration {
/// Name of the reactor.
pub name: String,
/// Names of the accumulators the reactor uses.
pub accumulator_names: Vec<String>,
/// Reaction policy of the reactor.
pub reaction_mode: ReactionMode,
}
/// Boxed, thread-shareable factory that builds a fresh [`ReactorRegistration`] on each
/// call.
pub type ReactorConstructor = Box<dyn Fn() -> ReactorRegistration + Send + Sync>;
/// Compile-time metadata for a computation graph.
pub trait Graph {
/// Name of the graph.
const NAME: &'static str;
/// Whether the graph has no trigger reactor. NOTE(review): presumably corresponds to
/// `ComputationGraphRegistration::trigger_reactor` being `None` — confirm.
const IS_TRIGGERLESS: bool;
}
pub mod types {
pub use crate::{deserialize, serialize, GraphError, GraphResult, InputCache, SourceName};
}