use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use super::compiled_subgraph::{CompiledSubgraph, StateProjector};
use crate::error::GraphError;
use crate::graph::Graph;
use crate::state::state_lens::StateLens;
use crate::state::workflow_state::{MergeStrategy, WorkflowState};
use crate::stream_emitter::StreamSink;
use tokio_util::sync::CancellationToken;
/// Specification of a nested graph that operates on an `Inner` state reached
/// from an `Outer` state through a lens.
///
/// Type parameters:
/// * `Outer` - the state type the lens reads from.
/// * `Inner` - the state type the wrapped graph is defined over.
/// * `M` - the merge strategy the wrapped graph is parameterised with.
/// * `L` - the lens that yields `&mut Inner` from `&mut Outer`.
pub struct SubgraphSpec<Outer, Inner, M, L>
where
    Outer: WorkflowState,
    Inner: WorkflowState,
    M: MergeStrategy<Inner>,
    L: StateLens<Outer, Inner>,
{
    /// Shared handle to the graph defined over the `Inner` state.
    pub graph: Arc<Graph<Inner, M>>,
    /// Lens used to obtain the `Inner` state from an `Outer` state.
    pub lens: L,
    /// Step limit handed to the graph run and to the compiled subgraph.
    pub max_steps: usize,
    /// Zero-sized marker: `Outer` appears only in the lens bound, so it is
    /// recorded here to keep the type parameter in use.
    _phantom: PhantomData<Outer>,
}
impl<Outer, Inner, M, L> SubgraphSpec<Outer, Inner, M, L>
where
    Outer: WorkflowState + 'static,
    Inner: WorkflowState + 'static,
    M: MergeStrategy<Inner> + 'static,
    L: StateLens<Outer, Inner> + 'static,
{
    /// Builds a spec from a shared graph and the lens used to reach its state.
    ///
    /// The step limit starts at 1000; use [`SubgraphSpec::max_steps`] to
    /// change it.
    pub fn new(graph: Arc<Graph<Inner, M>>, lens: L) -> Self {
        // Step limit used until the caller overrides it.
        const DEFAULT_MAX_STEPS: usize = 1000;

        Self {
            graph,
            lens,
            max_steps: DEFAULT_MAX_STEPS,
            _phantom: PhantomData,
        }
    }

    /// Builder-style setter: consumes the spec and returns it with the step
    /// limit replaced by `max`.
    pub fn max_steps(self, max: usize) -> Self {
        Self {
            max_steps: max,
            ..self
        }
    }

    /// Applies the lens to `outer` and returns the resulting mutable
    /// reference to the inner state.
    pub fn project<'a>(&self, outer: &'a mut Outer) -> &'a mut Inner {
        let Self { lens, .. } = self;
        lens.get(outer)
    }

    /// Consumes the spec, wraps it in an `Arc`, and hands it to
    /// `CompiledSubgraph::new` together with the configured step limit.
    pub fn compile(self) -> CompiledSubgraph<Outer> {
        // Copy the limit out first: `self` is moved into the `Arc` below.
        let step_limit = self.max_steps;
        CompiledSubgraph::new(Arc::new(self), step_limit)
    }
}
impl<Outer, Inner, M, L> StateProjector<Outer> for SubgraphSpec<Outer, Inner, M, L>
where
    Outer: WorkflowState,
    Inner: WorkflowState + 'static,
    M: MergeStrategy<Inner> + 'static,
    L: StateLens<Outer, Inner> + 'static,
{
    /// Runs the wrapped graph against the inner state obtained from `outer`.
    ///
    /// The returned future borrows both `self` and `outer` for `'a`. It
    /// resolves to `Ok(())` when `run_inline` succeeds; an error from
    /// `run_inline` is propagated through `?`.
    fn execute<'a>(
        &'a self,
        outer: &'a mut Outer,
        stream: Option<Arc<dyn StreamSink>>,
        cancel: CancellationToken,
    ) -> Pin<Box<dyn Future<Output = Result<(), GraphError>> + Send + 'a>> {
        Box::pin(async move {
            // The engine is built directly on the projected inner state, so
            // it works through the `&mut Inner` borrowed from `outer`.
            let projected = self.lens.get(outer);
            // The two trailing optional arguments are left unset.
            let mut engine = crate::ExecutionEngine::new(projected, stream, cancel, None, None);
            let mut step_callback = crate::graph::NoopStepCallback;

            let step_limit = self.max_steps;
            self.graph
                .run_inline(&mut engine, step_limit, &mut step_callback)
                .await?;

            Ok(())
        })
    }

    /// Returns the name reported by the wrapped graph.
    fn graph_name(&self) -> &str {
        let graph = &self.graph;
        graph.name()
    }

    /// Returns the length of the wrapped graph's `node_names()` collection.
    fn node_count(&self) -> usize {
        let names = self.graph.node_names();
        names.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::State;

    /// Parent fixture that owns the nested state.
    #[derive(Debug, PartialEq)]
    struct OuterState {
        inner: InnerState,
    }

    /// Nested fixture holding a single integer.
    #[derive(Debug, PartialEq)]
    struct InnerState {
        value: i32,
    }

    /// Lens that exposes `OuterState::inner`.
    struct TestLens;

    impl StateLens<OuterState, InnerState> for TestLens {
        fn get<'a>(&self, outer: &'a mut OuterState) -> &'a mut InnerState {
            &mut outer.inner
        }
    }

    /// A write made through the lens must be visible on the outer state.
    #[test]
    fn test_subgraph_spec_projection() {
        let mut parent = OuterState {
            inner: InnerState { value: 42 },
        };
        let lens = TestLens;

        {
            // Confine the mutable borrow so `parent` can be read afterwards.
            let projected = lens.get(&mut parent);
            projected.value = 100;
        }

        assert_eq!(parent.inner.value, 100);
    }
}