use super::state::SharedState;
use echo_core::agent::Agent;
use echo_core::error::Result;
use futures::future::BoxFuture;
use std::sync::Arc;
use tokio::sync::Mutex;
/// The work a [`Node`] performs when it is executed.
pub(crate) enum NodeAction {
    /// Performs no work; executing it always returns `Ok(())`.
    Passthrough,
    /// Runs a caller-supplied async closure against the shared state.
    Function(Box<dyn NodeFn>),
    /// Sends the string stored under `input_key` to an agent and stores the
    /// reply under `output_key`.
    Agent {
        /// The agent, behind an async mutex so one agent can back several nodes.
        agent: Arc<Mutex<Box<dyn Agent>>>,
        /// State key the prompt is read from.
        input_key: String,
        /// State key the agent's reply is written to.
        output_key: String,
        /// `true` dispatches to `Agent::execute`, `false` to `Agent::chat`.
        use_execute: bool,
    },
}
/// Object-safe interface for the async closure behind a `Function` node.
///
/// The returned future borrows both the implementor and the state for `'s`.
pub(crate) trait NodeFn: Send + Sync {
    /// Starts the node's work against `state`; awaiting the future yields its result.
    fn call<'s>(&'s self, state: &'s SharedState) -> BoxFuture<'s, Result<()>>;
}
/// Newtype that lets any suitably-typed closure be stored as a `Box<dyn NodeFn>`.
struct FnWrapper<F>(F);

impl<F> NodeFn for FnWrapper<F>
where
    F: for<'s> Fn(&'s SharedState) -> BoxFuture<'s, Result<()>> + Send + Sync,
{
    /// Forwards straight to the wrapped closure.
    fn call<'s>(&'s self, state: &'s SharedState) -> BoxFuture<'s, Result<()>> {
        let Self(func) = self;
        func(state)
    }
}
/// A named unit of work: an identifier plus the action run by `Node::execute`.
// NOTE(review): in this module only `action` is read (by `execute`); `name` is
// never read here, which is presumably what this `allow` silences. Confirm,
// and consider scoping the allow to the field or documenting the reason.
#[allow(dead_code)]
pub(crate) struct Node {
    /// What the node does when executed.
    pub action: NodeAction,
    /// Identifier for this node.
    pub name: String,
}
impl Node {
    /// Builds a node that runs `agent` through `Agent::execute`, reading the
    /// prompt from `input_key` and writing the reply to `output_key`.
    ///
    /// The agent gets its own fresh `Arc<Mutex<..>>`; use [`Node::agent_shared`]
    /// to let several nodes drive the same agent.
    pub fn agent(
        name: impl Into<String>,
        agent: impl Agent + 'static,
        input_key: impl Into<String>,
        output_key: impl Into<String>,
    ) -> Self {
        Self::agent_with_mode(name, agent, input_key, output_key, true)
    }

    /// Like [`Node::agent`], but `use_execute` chooses between
    /// `Agent::execute` (`true`) and `Agent::chat` (`false`).
    pub fn agent_with_mode(
        name: impl Into<String>,
        agent: impl Agent + 'static,
        input_key: impl Into<String>,
        output_key: impl Into<String>,
        use_execute: bool,
    ) -> Self {
        // Explicit annotation makes the unsizing coercion to `dyn Agent` obvious.
        let boxed: Box<dyn Agent> = Box::new(agent);
        Self::agent_shared_with_mode(
            name,
            Arc::new(Mutex::new(boxed)),
            input_key,
            output_key,
            use_execute,
        )
    }

    /// Builds a node around an already-shared agent, dispatching through
    /// `Agent::execute`.
    ///
    /// Nodes holding clones of the same `Arc` take turns on its mutex, so their
    /// agent calls never overlap.
    pub fn agent_shared(
        name: impl Into<String>,
        agent: Arc<Mutex<Box<dyn Agent>>>,
        input_key: impl Into<String>,
        output_key: impl Into<String>,
    ) -> Self {
        Self::agent_shared_with_mode(name, agent, input_key, output_key, true)
    }

    /// Canonical agent-node constructor; every other `agent*` constructor
    /// delegates here.
    ///
    /// `use_execute` selects `Agent::execute` (`true`) or `Agent::chat` (`false`).
    pub fn agent_shared_with_mode(
        name: impl Into<String>,
        agent: Arc<Mutex<Box<dyn Agent>>>,
        input_key: impl Into<String>,
        output_key: impl Into<String>,
        use_execute: bool,
    ) -> Self {
        Self {
            name: name.into(),
            action: NodeAction::Agent {
                agent,
                input_key: input_key.into(),
                output_key: output_key.into(),
                use_execute,
            },
        }
    }

    /// Builds a node that runs an arbitrary async closure.
    ///
    /// `f` is called with the shared state on every execution, and the
    /// `Result` of its returned future becomes the node's result.
    pub fn function<F>(name: impl Into<String>, f: F) -> Self
    where
        F: for<'a> Fn(&'a SharedState) -> BoxFuture<'a, Result<()>> + Send + Sync + 'static,
    {
        Self {
            name: name.into(),
            action: NodeAction::Function(Box::new(FnWrapper(f))),
        }
    }

    /// Builds a node that does nothing when executed.
    pub fn passthrough(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            action: NodeAction::Passthrough,
        }
    }

    /// Runs this node's action against `state`.
    ///
    /// * `Agent`:
    ///   - reads the prompt from `input_key`, using an empty prompt if the
    ///     lookup yields `None`;
    ///   - calls `execute` or `chat` depending on `use_execute`;
    ///   - writes the reply under `output_key` via `merge_overwrite`;
    ///   - appends the reply through `push_message` as an assistant message.
    /// * `Function`: awaits the wrapped closure.
    /// * `Passthrough`: returns `Ok(())` without touching `state`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the agent call, from the two state updates,
    /// or from the wrapped closure.
    pub async fn execute(&self, state: &SharedState) -> Result<()> {
        match &self.action {
            NodeAction::Agent {
                agent,
                input_key,
                output_key,
                use_execute,
            } => {
                // A missing/unreadable input is deliberately treated as an empty prompt.
                let input = state.get::<String>(input_key).unwrap_or_default();
                // Scope the guard to the agent call only: the agent may be shared
                // with other nodes, so release it before the state writes below.
                let output = {
                    let agent = agent.lock().await;
                    if *use_execute {
                        agent.execute(&input).await?
                    } else {
                        agent.chat(&input).await?
                    }
                };
                state.merge_overwrite(&SharedState::from_values(
                    [(output_key.clone(), serde_json::Value::String(output.clone()))]
                        .into_iter()
                        .collect(),
                ))?;
                state.push_message(echo_core::llm::types::Message::assistant(output))?;
                Ok(())
            }
            NodeAction::Function(f) => f.call(state).await,
            NodeAction::Passthrough => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A function node reads `input`, doubles it, and stores it as `output`.
    #[tokio::test]
    async fn test_function_node() {
        let doubler = Node::function("double", |shared: &SharedState| {
            Box::pin(async move {
                let value: i64 = shared.get("input").unwrap_or(0);
                let _ = shared.set("output", value * 2);
                Ok(())
            })
        });

        let shared = SharedState::new();
        let _ = shared.set("input", 21i64);
        doubler.execute(&shared).await.unwrap();

        assert_eq!(shared.get::<i64>("output"), Some(42));
    }

    /// A passthrough node leaves previously stored values untouched.
    #[tokio::test]
    async fn test_passthrough_node() {
        let noop = Node::passthrough("noop");

        let shared = SharedState::new();
        let _ = shared.set("x", 1);
        noop.execute(&shared).await.unwrap();

        assert_eq!(shared.get::<i64>("x"), Some(1));
    }
}