use crate::error::PeError;
use crate::lobe::LobeRuntimeServiceFactory;
use crate::phase_store::PhaseStateStore;
use crate::state::{State, StateUpdate};
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Boxed, `Send` future returned by [`NodeFn::call`], resolving to a
/// [`NodeResult`] over update type `U`.
///
/// The `'static` bound is spelled out here; it is the same bound
/// `Box<dyn Trait>` gets by default. It means the future cannot hold borrows
/// of the state or context handed to `call`.
pub type NodeFuture<U> = Pin<Box<dyn Future<Output = NodeResult<U>> + Send + 'static>>;
/// An executable graph node operating on state type `S`.
///
/// Implementors must be `Send + Sync` so a node can be shared across tasks.
pub trait NodeFn<S: State>: Send + Sync {
    /// Name of this node (presumably the value surfaced as
    /// `NodeContext::node_name` — confirm against the executor).
    fn name(&self) -> &str;

    /// Runs the node against a borrowed `state` and execution `ctx`.
    ///
    /// Returns a [`NodeFuture`] that yields a [`NodeResult`] over `S::Update`.
    /// That future is `'static`, so anything it needs from `state` or `ctx`
    /// must be cloned or moved into it before `call` returns.
    fn call(&self, state: &S, ctx: &NodeContext) -> NodeFuture<S::Update>;
}
/// Per-invocation execution context handed to [`NodeFn::call`].
///
/// Cloning the context clones every field, including the `Arc` handles.
#[derive(Clone)]
pub struct NodeContext {
    /// Current step counter. It is measured against `recursion_limit` by
    /// `remaining_steps` / `is_last_step`.
    pub step: u32,
    /// Step budget that `step` is compared against.
    pub recursion_limit: u32,
    /// Name of the node being run (presumably equal to `NodeFn::name` — confirm).
    pub node_name: String,
    /// Why this node was scheduled.
    pub activation: ActivationReason,
    /// Free-form JSON values keyed by string.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Phase state store handle.
    pub phase_store: PhaseStateStore,
    /// Optional type-erased stream sender. Consumers presumably downcast it
    /// through `Any` to a concrete sender type — confirm with producers.
    pub stream_sender: Option<Arc<dyn Any + Send + Sync>>,
    /// Optional hook receiving tool start/complete/error callbacks.
    pub tool_observer: Option<Arc<dyn ToolObserver>>,
    /// Optional factory for lobe runtime services.
    pub lobe_runtime_service_factory: Option<Arc<dyn LobeRuntimeServiceFactory>>,
}
/// Manual `Debug` for [`NodeContext`].
///
/// The three optional `Arc<dyn ...>` handles are reported only as `has_*`
/// presence flags, so their trait objects are never formatted. `metadata` and
/// `phase_store` are not printed. The listing therefore ends with
/// `finish_non_exhaustive()`, rendered as `..`, so the output does not claim to
/// show every field.
impl std::fmt::Debug for NodeContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NodeContext")
            .field("step", &self.step)
            .field("recursion_limit", &self.recursion_limit)
            .field("node_name", &self.node_name)
            .field("activation", &self.activation)
            .field("has_stream_sender", &self.stream_sender.is_some())
            .field("has_tool_observer", &self.tool_observer.is_some())
            .field(
                "has_lobe_runtime_service_factory",
                &self.lobe_runtime_service_factory.is_some(),
            )
            // `metadata` and `phase_store` are elided; signal that with `..`
            // instead of presenting the struct as fully listed.
            .finish_non_exhaustive()
    }
}
impl NodeContext {
    /// Steps left before `step` reaches `recursion_limit`.
    ///
    /// Clamps at 0 instead of underflowing when `step` has already reached
    /// or passed the limit.
    pub fn remaining_steps(&self) -> u32 {
        u32::saturating_sub(self.recursion_limit, self.step)
    }

    /// `true` once `step` has reached or passed `recursion_limit`.
    ///
    /// This holds exactly when `remaining_steps()` is 0.
    pub fn is_last_step(&self) -> bool {
        self.remaining_steps() == 0
    }
}
/// Asynchronous lifecycle hooks for node execution.
///
/// Each hook returns a boxed `Send` future. Under lifetime elision, `'_` ties
/// that future to `&self`, so it may borrow from the observer but not from
/// the `&str` arguments. Implementors must be `Send + Sync`.
pub trait NodeObserver: Send + Sync {
    /// Hook for the start of node `node_name` at `step`.
    fn on_node_start(
        &self,
        node_name: &str,
        step: u32,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>>;

    /// Hook for the completion of `node_name` at `step`, taking `duration`.
    fn on_node_complete(
        &self,
        node_name: &str,
        step: u32,
        duration: std::time::Duration,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>>;

    /// Hook for a failure of `node_name` at `step`; `error` is its text form.
    fn on_node_error(
        &self,
        node_name: &str,
        step: u32,
        error: &str,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>>;

    /// Hook for retry `attempt` (of `max_attempts`) of `node_name` at `step`.
    ///
    /// Optional to implement. The default ignores its arguments and returns a
    /// future that is already complete.
    fn on_node_retry(
        &self,
        _node_name: &str,
        _step: u32,
        _attempt: u32,
        _max_attempts: u32,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        Box::pin(std::future::ready(()))
    }
}
/// Asynchronous hooks for tool invocations made while a node runs.
///
/// As with `NodeObserver`, each returned future may borrow from `self` (the
/// `'_` lifetime) but not from the `&str` arguments. Implementors must be
/// `Send + Sync`.
pub trait ToolObserver: Send + Sync {
    /// Hook for the start of tool `tool_name`; `input_summary` describes its input.
    fn on_tool_start(
        &self,
        tool_name: &str,
        input_summary: &str,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>>;

    /// Hook for the successful completion of `tool_name` after `duration`.
    fn on_tool_complete(
        &self,
        tool_name: &str,
        duration: std::time::Duration,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>>;

    /// Hook for a failure of `tool_name`; `error` is its text form.
    fn on_tool_error(
        &self,
        tool_name: &str,
        error: &str,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>>;
}
/// Why a node was scheduled for execution (carried in `NodeContext::activation`).
///
/// `PartialEq`/`Eq` are derived so callers can compare activations directly,
/// e.g. `ctx.activation == ActivationReason::Resume`. Variant order is part
/// of the serialized form and must not change.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ActivationReason {
    /// Scheduled as the graph's entry point.
    EntryPoint,
    /// Reached over a plain edge; `from` presumably names the source node.
    Edge { from: String },
    /// Reached over a conditional edge; `from` presumably names the source node.
    ConditionalEdge { from: String },
    /// Re-entered after an interrupt/resume (inferred from the name — confirm).
    Resume,
    /// Re-run after a failure; `attempt` is the attempt counter
    /// (base — 0 or 1 — not visible here; confirm with the executor).
    Retry { attempt: u32 },
}
/// Outcome of one node invocation, generic over the state-update type `U`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum NodeResult<U: StateUpdate> {
    /// The node produced a state update.
    Update(U),
    /// The node asks to interrupt execution; see [`InterruptRequest`].
    Interrupt(InterruptRequest<U>),
    /// The node failed.
    Error(PeError),
    /// The node reported convergence together with a partial update; see
    /// [`ConvergenceSignal`]. `NodeResult::into_standard` lowers this to `Update`.
    Converge(ConvergenceSignal<U>),
}
impl<U: StateUpdate> NodeResult<U> {
pub fn into_standard(self) -> NodeResult<U> {
match self {
Self::Converge(signal) => NodeResult::Update(signal.partial_update),
other => other,
}
}
pub fn is_interrupt(&self) -> bool {
matches!(self, NodeResult::Interrupt(_))
}
pub fn is_error(&self) -> bool {
matches!(self, NodeResult::Error(_))
}
}
/// A node's request to interrupt execution.
///
/// Only `reason` and `resume_point` are serialized. `partial_update` is
/// `#[serde(skip)]`: it is dropped when serializing and is `None` after
/// deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InterruptRequest<U: StateUpdate> {
    /// Description of why the interrupt was requested (presumably human-readable).
    pub reason: String,
    /// Optional update produced before interrupting; never serialized.
    #[serde(skip)]
    pub partial_update: Option<U>,
    /// Where execution should resume. Its format is not visible here —
    /// confirm with the executor.
    pub resume_point: String,
}
/// Convergence report from a node, plus the update it produced.
///
/// The three metrics are serialized. `partial_update` is `#[serde(skip)]`:
/// it is omitted when serializing and filled with `U::default()` when
/// deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceSignal<U: StateUpdate> {
    /// Contribution metric. Scale and range are not defined here — confirm.
    pub actual_contribution: f64,
    /// Surprise metric. Scale and range are not defined here — confirm.
    pub surprise: f64,
    /// Quality metric. Scale and range are not defined here — confirm.
    pub quality: f64,
    /// Update applied when the signal is lowered to `NodeResult::Update`.
    #[serde(skip)]
    pub partial_update: U,
}
/// Input supplied by a human, presumably when resuming after an interrupt.
///
/// Only `approved` is required when deserializing. `feedback` and `data`
/// default to `None` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HumanInput {
    /// Whether the human approved.
    pub approved: bool,
    /// Optional free-text feedback.
    #[serde(default)]
    pub feedback: Option<String>,
    /// Optional arbitrary JSON payload.
    #[serde(default)]
    pub data: Option<serde_json::Value>,
}