use std::collections::HashMap;
use std::sync::Arc;
use petgraph::algo::toposort;
use petgraph::graph::{DiGraph, NodeIndex};
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use super::error::{GraphError, StygianError};
use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
/// A single unit of work in a pipeline, executed by the service registered
/// under [`Node::service`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
    /// Unique identifier within the pipeline; edges reference nodes by this ID.
    pub id: String,
    /// Key of the registered `ScrapingService` that should execute this node.
    pub service: String,
    /// Service-specific configuration. During execution a string `"url"`
    /// field, if present, is forwarded as the service input URL.
    pub config: serde_json::Value,
    /// Free-form metadata; deserializes to `Null` when absent from the input.
    #[serde(default)]
    pub metadata: serde_json::Value,
}
impl Node {
    /// Builds a node with the given ID, service key, and configuration;
    /// metadata starts out as `Null`.
    pub fn new(
        id: impl Into<String>,
        service: impl Into<String>,
        config: serde_json::Value,
    ) -> Self {
        Self::with_metadata(id, service, config, serde_json::Value::Null)
    }

    /// Builds a node with explicit metadata attached.
    pub fn with_metadata(
        id: impl Into<String>,
        service: impl Into<String>,
        config: serde_json::Value,
        metadata: serde_json::Value,
    ) -> Self {
        Self {
            id: id.into(),
            service: service.into(),
            config,
            metadata,
        }
    }

    /// Checks structural invariants: both `id` and `service` must be
    /// non-empty strings.
    ///
    /// # Errors
    /// Returns a `GraphError` (wrapped in `StygianError`) naming the first
    /// empty field found.
    pub fn validate(&self) -> Result<(), StygianError> {
        if self.id.is_empty() {
            return Err(GraphError::InvalidEdge("Node ID cannot be empty".into()).into());
        }
        if self.service.is_empty() {
            return Err(GraphError::InvalidEdge("Node service type cannot be empty".into()).into());
        }
        Ok(())
    }
}
/// A directed dependency between two nodes: `from` must finish before `to`
/// runs, and `to` receives `from`'s output as upstream data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edge {
    /// ID of the upstream (source) node.
    pub from: String,
    /// ID of the downstream (target) node.
    pub to: String,
    /// Optional edge-level configuration; defaults to `Null` when absent.
    #[serde(default)]
    pub config: serde_json::Value,
}
impl Edge {
    /// Builds an edge from `from` to `to` with no configuration (`Null`).
    pub fn new(from: impl Into<String>, to: impl Into<String>) -> Self {
        Self::with_config(from, to, serde_json::Value::Null)
    }

    /// Builds an edge carrying an explicit configuration value.
    pub fn with_config(
        from: impl Into<String>,
        to: impl Into<String>,
        config: serde_json::Value,
    ) -> Self {
        Self {
            from: from.into(),
            to: to.into(),
            config,
        }
    }

    /// Checks that both endpoints are non-empty strings.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidEdge` if either endpoint is empty.
    pub fn validate(&self) -> Result<(), StygianError> {
        if self.from.is_empty() || self.to.is_empty() {
            return Err(GraphError::InvalidEdge("Edge endpoints cannot be empty".into()).into());
        }
        Ok(())
    }
}
/// A named scraping pipeline: a set of nodes plus the directed edges that
/// order them. Compiled into an executable DAG by `DagExecutor::from_pipeline`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Pipeline {
    /// Human-readable pipeline name.
    pub name: String,
    /// The work units of the pipeline.
    pub nodes: Vec<Node>,
    /// Dependencies between nodes, referencing them by `Node::id`.
    pub edges: Vec<Edge>,
    /// Free-form metadata; deserializes to `Null` when absent from the input.
    #[serde(default)]
    pub metadata: serde_json::Value,
}
impl Pipeline {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
nodes: Vec::new(),
edges: Vec::new(),
metadata: serde_json::Value::Null,
}
}
pub fn add_node(&mut self, node: Node) {
self.nodes.push(node);
}
pub fn add_edge(&mut self, edge: Edge) {
self.edges.push(edge);
}
pub fn validate(&self) -> Result<(), StygianError> {
for node in &self.nodes {
node.validate()?;
}
for edge in &self.edges {
edge.validate()?;
}
Ok(())
}
}
/// The outcome of executing one node: its ID paired with the service output.
#[derive(Debug, Clone)]
pub struct NodeResult {
    /// ID of the node that produced this output.
    pub node_id: String,
    /// The raw output returned by the node's service.
    pub output: ServiceOutput,
}
/// Executes a [`Pipeline`] as a directed acyclic graph, running independent
/// nodes of each topological "wave" concurrently.
pub struct DagExecutor {
    /// The compiled dependency graph; node weights are the pipeline nodes.
    graph: DiGraph<Node, ()>,
    /// Node-ID → graph-index map built during construction. Underscore-prefixed
    /// because it is only written at build time and never read afterwards.
    _node_indices: HashMap<String, NodeIndex>,
}
impl DagExecutor {
    /// Creates an executor over an empty graph; `execute` on it yields an
    /// empty result list.
    pub fn new() -> Self {
        Self {
            graph: DiGraph::new(),
            _node_indices: HashMap::new(),
        }
    }

    /// Compiles a pipeline into an executable DAG.
    ///
    /// # Errors
    /// - validation errors from [`Pipeline::validate`]
    /// - `GraphError::InvalidPipeline` for duplicate node IDs
    /// - `GraphError::NodeNotFound` for edges referencing unknown nodes
    /// - `GraphError::CycleDetected` when the edges form a cycle
    pub fn from_pipeline(pipeline: &Pipeline) -> Result<Self, StygianError> {
        pipeline.validate()?;
        let mut graph = DiGraph::new();
        let mut node_indices = HashMap::new();
        for node in &pipeline.nodes {
            // Reject duplicates explicitly: a repeated ID would otherwise
            // silently overwrite the earlier index in `node_indices`,
            // re-routing every edge with that endpoint to the last node.
            if node_indices.contains_key(&node.id) {
                return Err(GraphError::InvalidPipeline(format!(
                    "Duplicate node ID '{}'",
                    node.id
                ))
                .into());
            }
            let idx = graph.add_node(node.clone());
            node_indices.insert(node.id.clone(), idx);
        }
        for edge in &pipeline.edges {
            let from_idx = node_indices
                .get(&edge.from)
                .ok_or_else(|| GraphError::NodeNotFound(edge.from.clone()))?;
            let to_idx = node_indices
                .get(&edge.to)
                .ok_or_else(|| GraphError::NodeNotFound(edge.to.clone()))?;
            graph.add_edge(*from_idx, *to_idx, ());
        }
        if petgraph::algo::is_cyclic_directed(&graph) {
            return Err(GraphError::CycleDetected.into());
        }
        Ok(Self {
            graph,
            _node_indices: node_indices,
        })
    }

    /// Runs the DAG to completion. Nodes are grouped into dependency "waves"
    /// (see [`Self::build_execution_waves`]); each wave's nodes run
    /// concurrently as spawned tasks, and the next wave starts only after the
    /// whole wave has finished.
    ///
    /// Each node receives a `ServiceInput` whose `url` comes from its
    /// config's `"url"` string field (empty when absent) and whose `params`
    /// is a JSON object mapping each direct predecessor's ID to that
    /// predecessor's output data string.
    ///
    /// Returns one `NodeResult` per executed node, in topological order.
    ///
    /// # Errors
    /// - `GraphError::InvalidPipeline` when a node's service type has no
    ///   entry in `services`
    /// - `GraphError::ExecutionFailed` when a spawned task panics/aborts
    /// - any error returned by a service's `execute`
    pub async fn execute(
        &self,
        services: &HashMap<String, Arc<dyn ScrapingService>>,
    ) -> Result<Vec<NodeResult>, StygianError> {
        let topo_order = toposort(&self.graph, None).map_err(|_| GraphError::CycleDetected)?;
        let waves = self.build_execution_waves(&topo_order);
        // Shared store of completed outputs, keyed by node ID; read by
        // downstream waves to assemble upstream params.
        let results: Arc<Mutex<HashMap<String, ServiceOutput>>> =
            Arc::new(Mutex::new(HashMap::new()));
        for wave in waves {
            let mut handles = Vec::new();
            for node_idx in wave {
                let node = self.graph[node_idx].clone();
                let service = services.get(&node.service).cloned().ok_or_else(|| {
                    GraphError::InvalidPipeline(format!(
                        "No service registered for type '{}'",
                        node.service
                    ))
                })?;
                // Snapshot the outputs of this node's direct predecessors.
                // All predecessors belong to earlier waves, so their results
                // are already in the store.
                let upstream_data = {
                    let store = results.lock().await;
                    let mut data = serde_json::Map::new();
                    for pred_idx in self
                        .graph
                        .neighbors_directed(node_idx, petgraph::Direction::Incoming)
                    {
                        let pred_id = &self.graph[pred_idx].id;
                        if let Some(out) = store.get(pred_id) {
                            data.insert(
                                pred_id.clone(),
                                serde_json::Value::String(out.data.clone()),
                            );
                        }
                    }
                    serde_json::Value::Object(data)
                };
                let input = ServiceInput {
                    url: node
                        .config
                        .get("url")
                        .and_then(|v| v.as_str())
                        .unwrap_or("")
                        .to_string(),
                    params: upstream_data,
                };
                let results_clone = Arc::clone(&results);
                let node_id = node.id.clone();
                handles.push(tokio::spawn(async move {
                    let output = service.execute(input).await?;
                    results_clone
                        .lock()
                        .await
                        .insert(node_id.clone(), output.clone());
                    Ok::<NodeResult, StygianError>(NodeResult { node_id, output })
                }));
            }
            // Barrier: the next wave must not start until every task in this
            // wave has completed. The per-task NodeResult is discarded here;
            // final results are re-read from the store below.
            for handle in handles {
                handle
                    .await
                    .map_err(|e| GraphError::ExecutionFailed(format!("Task join error: {e}")))??;
            }
        }
        // Emit results in deterministic topological order rather than
        // completion order.
        let store = results.lock().await;
        let final_results = topo_order
            .iter()
            .filter_map(|idx| {
                let node_id = &self.graph[*idx].id;
                store.get(node_id).map(|output| NodeResult {
                    node_id: node_id.clone(),
                    output: output.clone(),
                })
            })
            .collect();
        Ok(final_results)
    }

    /// Groups nodes into "waves" by longest-path depth: a node's level is
    /// 1 + the maximum level of its predecessors (0 for sources). All nodes
    /// in a wave are mutually independent and can run concurrently.
    ///
    /// Iterating `topo_order` guarantees every predecessor's level is
    /// computed before its successors'.
    fn build_execution_waves(&self, topo_order: &[NodeIndex]) -> Vec<Vec<NodeIndex>> {
        let mut level: HashMap<NodeIndex, usize> = HashMap::new();
        for &idx in topo_order {
            let max_pred_level = self
                .graph
                .neighbors_directed(idx, petgraph::Direction::Incoming)
                .map(|pred| level.get(&pred).copied().unwrap_or(0) + 1)
                .max()
                .unwrap_or(0);
            level.insert(idx, max_pred_level);
        }
        let max_level = level.values().copied().max().unwrap_or(0);
        let mut waves: Vec<Vec<NodeIndex>> = vec![Vec::new(); max_level + 1];
        for (idx, lvl) in level {
            if let Some(wave) = waves.get_mut(lvl) {
                wave.push(idx);
            }
        }
        waves
    }
}
impl Default for DagExecutor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::error::Result;

    /// `Node::new` stores the given id/service verbatim.
    #[test]
    fn test_node_creation() {
        let node = Node::new(
            "test",
            "http",
            serde_json::json!({"url": "https://example.com"}),
        );
        assert_eq!(node.id, "test");
        assert_eq!(node.service, "http");
    }

    /// `Edge::new` stores both endpoints verbatim.
    #[test]
    fn test_edge_creation() {
        let edge = Edge::new("a", "b");
        assert_eq!(edge.from, "a");
        assert_eq!(edge.to, "b");
    }

    /// A well-formed two-node pipeline passes validation.
    #[test]
    fn test_pipeline_validation() {
        let mut pipeline = Pipeline::new("test");
        pipeline.add_node(Node::new("fetch", "http", serde_json::json!({})));
        pipeline.add_node(Node::new("extract", "ai", serde_json::json!({})));
        pipeline.add_edge(Edge::new("fetch", "extract"));
        assert!(pipeline.validate().is_ok());
    }

    /// A two-node cycle (a -> b -> a) is rejected at build time.
    #[test]
    fn test_cycle_detection() {
        let mut pipeline = Pipeline::new("cyclic");
        for id in ["a", "b"] {
            pipeline.add_node(Node::new(id, "http", serde_json::json!({})));
        }
        pipeline.add_edge(Edge::new("a", "b"));
        pipeline.add_edge(Edge::new("b", "a"));
        assert!(matches!(
            DagExecutor::from_pipeline(&pipeline),
            Err(StygianError::Graph(GraphError::CycleDetected))
        ));
    }

    /// A diamond (A -> {B, C} -> D) executes every node exactly once.
    #[tokio::test]
    async fn test_diamond_concurrent_execution() -> Result<()> {
        use crate::adapters::noop::NoopService;
        let mut pipeline = Pipeline::new("diamond");
        for id in ["A", "B", "C", "D"] {
            pipeline.add_node(Node::new(id, "noop", serde_json::json!({"url": ""})));
        }
        for (from, to) in [("A", "B"), ("A", "C"), ("B", "D"), ("C", "D")] {
            pipeline.add_edge(Edge::new(from, to));
        }
        let executor = DagExecutor::from_pipeline(&pipeline)?;
        let mut services: HashMap<String, std::sync::Arc<dyn crate::ports::ScrapingService>> =
            HashMap::new();
        services.insert("noop".to_string(), std::sync::Arc::new(NoopService));
        let results = executor.execute(&services).await?;
        assert_eq!(results.len(), 4);
        let ids: Vec<&str> = results.iter().map(|r| r.node_id.as_str()).collect();
        for expected in ["A", "B", "C", "D"] {
            assert!(ids.contains(&expected));
        }
        Ok(())
    }

    /// Executing a node whose service type has no registration fails.
    #[tokio::test]
    async fn test_missing_service_returns_error() -> Result<()> {
        let mut pipeline = Pipeline::new("test");
        pipeline.add_node(Node::new("fetch", "http", serde_json::json!({})));
        let executor = DagExecutor::from_pipeline(&pipeline)?;
        let services: HashMap<String, std::sync::Arc<dyn crate::ports::ScrapingService>> =
            HashMap::new();
        assert!(executor.execute(&services).await.is_err());
        Ok(())
    }
}