mod types;
use std::future::Future;
use std::pin::Pin;
use crate::Result;
use crate::graph::builder::NodeContext;
use crate::graph::command::NodeResult;
use crate::graph::compiled::{CompiledGraph, GraphExecution};
use crate::graph::recursion::ChildRun;
/// Boxed, `Send` future returned by one [`Handler`] invocation; it resolves to the node's
/// `NodeResult<U>` or an error.
type NodeFuture<U> = Pin<Box<dyn Future<Output = Result<NodeResult<U>>> + Send>>;

/// Type-erased node handler: a boxed, thread-shareable (`Send + Sync`) callable that takes
/// the node's input state `S` plus its `NodeContext` and returns a [`NodeFuture`].
type Handler<S, U> = Box<dyn Fn(S, NodeContext) -> NodeFuture<U> + Send + Sync>;
/// Exposes `child`, a graph whose state type matches the parent's, as a single
/// parent-graph node.
///
/// Each invocation does three things:
/// - runs a per-invocation copy of `child` on the incoming state. The copy's namespace is
///   extended with the invoking node's id, and it carries the context's recursion frames.
/// - reports the finished run to the context's child-run sink, when one is configured.
/// - returns the child's final state as a `NodeResult::Update`.
///
/// # Errors
///
/// If the child run fails, its error is propagated with `?` and no child run is recorded.
pub fn shared_subgraph_node<State>(child: CompiledGraph<State, State>) -> Handler<State, State>
where
    State: Clone + Send + Sync + 'static,
{
    Box::new(move |state: State, ctx: NodeContext| {
        // Derive everything that needs `ctx` before building the future, so the future owns
        // its data instead of borrowing from the handler or the context.
        let recorder = ChildRunRecorder::new(&ctx);
        let scoped_child = child_for(&child, &ctx);

        Box::pin(async move {
            let finished = scoped_child.run(state).await?;
            recorder.record(&finished);
            Ok(NodeResult::Update(finished.state))
        })
    })
}
/// Exposes `child`, a graph with its own state type `C`, as a single node of a parent graph
/// whose state type is `P`.
///
/// On every invocation the handler:
/// 1. projects the parent state into the child's input with `to_child`;
/// 2. runs a per-invocation copy of `child`, whose namespace is extended with the invoking
///    node's id and which carries the context's recursion frames;
/// 3. reports the finished run to the context's child-run sink, if one is configured;
/// 4. folds the child's final state back into a parent update with `from_child` and returns
///    it as a `NodeResult::Update`.
///
/// `to_child` and `from_child` are stored once and shared between invocations through a
/// reference-counted `Arc`, so they do not need to implement `Clone`. Closures that are
/// `Clone` are still accepted unchanged.
///
/// # Errors
///
/// If the child run fails, its error is propagated with `?`. In that case nothing is
/// recorded and `from_child` is not called.
pub fn adapter_subgraph_node<P, PU, C, CU, ToChild, FromChild>(
    child: CompiledGraph<C, CU>,
    to_child: ToChild,
    from_child: FromChild,
) -> Handler<P, PU>
where
    P: Clone + Send + Sync + 'static,
    PU: Send + 'static,
    C: Clone + Send + Sync + 'static,
    CU: Send + 'static,
    ToChild: Fn(&P) -> C + Send + Sync + 'static,
    FromChild: Fn(&P, C) -> PU + Send + Sync + 'static,
{
    use std::sync::Arc;

    // The boxed future must be `'static`, so it cannot borrow the mapping closures from the
    // handler. Sharing them behind `Arc` gives each invocation a cheap refcounted handle.
    // The closures themselves no longer have to be `Clone`, and they are not fully cloned
    // on every call.
    let to_child = Arc::new(to_child);
    let from_child = Arc::new(from_child);

    Box::new(move |state: P, ctx: NodeContext| {
        let child = child_for(&child, &ctx);
        let recorder = ChildRunRecorder::new(&ctx);
        let to_child = Arc::clone(&to_child);
        let from_child = Arc::clone(&from_child);
        Box::pin(async move {
            let child_input = to_child(&state);
            let execution = child.run(child_input).await?;
            recorder.record(&execution);
            let update = from_child(&state, execution.state);
            Ok(NodeResult::Update(update))
        })
    })
}
/// Returns a clone of `child` with an extended namespace. The new namespace is the child's
/// existing namespace segments followed by one extra segment, `ctx.node_id` rendered with
/// `to_string`.
fn namespaced<S, U>(child: &CompiledGraph<S, U>, ctx: &NodeContext) -> CompiledGraph<S, U> {
    let path: Vec<String> = child
        .namespace()
        .iter()
        .cloned()
        .chain(std::iter::once(ctx.node_id.to_string()))
        .collect();
    child.clone().with_namespace(path)
}
/// Builds the per-invocation copy of `child` that a subgraph node actually runs:
/// - namespaced under the invoking node (see `namespaced`);
/// - carrying a clone of the context's recursion frames;
/// - with the invoking node's id set as its recursion node.
fn child_for<S, U>(child: &CompiledGraph<S, U>, ctx: &NodeContext) -> CompiledGraph<S, U> {
    let scoped = namespaced(child, ctx);
    let with_frames = scoped.with_recursion_frames(ctx.recursion_frames.clone());
    with_frames.with_recursion_node(ctx.node_id.clone())
}
/// Per-invocation snapshot of what is needed to report a finished child run. It holds the
/// invoking node's id and the context's child-run sink, which may be absent.
struct ChildRunRecorder {
    node: crate::harness::ids::NodeId,
    sink: Option<crate::graph::recursion::ChildRunSink>,
}

impl ChildRunRecorder {
    /// Clones the node id and the optional child-run sink out of `ctx`. The recorder can
    /// then be moved into a future that outlives the borrow of the context.
    fn new(ctx: &NodeContext) -> Self {
        let node = ctx.node_id.clone();
        let sink = ctx.child_runs.clone();
        Self { node, sink }
    }

    /// Sends a `ChildRun` describing `execution` (graph id, run id, root run id) to the sink,
    /// attributed to the recorded node. Does nothing when the context had no sink.
    fn record<S>(&self, execution: &GraphExecution<S>) {
        if let Some(sink) = self.sink.as_ref() {
            let run = ChildRun {
                node: self.node.clone(),
                graph_id: execution.graph_id.clone(),
                run_id: execution.run_id.clone(),
                root_run_id: execution.root_run_id.clone(),
                // NOTE(review): usage is always reported as the default value. The child's
                // actual usage is not forwarded here; confirm whether that is intended.
                usage: crate::harness::usage::UsageTotals::default(),
            };
            sink.record(run);
        }
    }
}
#[cfg(test)]
mod test;