use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use serde_json::Value as JsonValue;
use langgraph_checkpoint::config::RunnableConfig;
use super::base::{Runnable, RunnableError};
use super::callable::RunnableCallable;
/// Heap-allocated, pinned future resolving to a node's JSON output or a [`RunnableError`].
///
/// The future is `Send`, so it may be moved across threads. The `'static` bound is the
/// default object lifetime for `Box<dyn Trait>` and is only written out here for clarity.
pub type NodeFnFuture =
    Pin<Box<dyn Future<Output = Result<JsonValue, RunnableError>> + Send + 'static>>;
pub trait IntoNodeFunction {
fn into_runnable(self, name: &str) -> Arc<dyn Runnable>;
}
/// Asynchronous node functions: any `Fn(JsonValue, RunnableConfig) -> Fut` whose future
/// resolves to `Result<JsonValue, RunnableError>` can be used directly as a node.
///
/// The function and its future must be `Send + 'static`, and the function must also be
/// `Sync`, so the resulting runnable can be shared behind an `Arc`.
impl<F, Fut> IntoNodeFunction for F
where
    Fut: Future<Output = Result<JsonValue, RunnableError>> + Send + 'static,
    F: Fn(JsonValue, RunnableConfig) -> Fut + Send + Sync + 'static,
{
    fn into_runnable(self, name: &str) -> Arc<dyn Runnable> {
        // Delegate to the async constructor of the generic callable wrapper.
        let callable = RunnableCallable::new(name, self);
        Arc::new(callable)
    }
}
/// Newtype marking a closure as a *synchronous* node function.
///
/// The wrapped closure has the signature
/// `Fn(&JsonValue, &RunnableConfig) -> Result<JsonValue, RunnableError>`. Wrapping it lets
/// it be told apart from the blanket async-closure implementation.
pub struct SyncNodeFn<F>(pub F);

impl<F> IntoNodeFunction for SyncNodeFn<F>
where
    F: Fn(&JsonValue, &RunnableConfig) -> Result<JsonValue, RunnableError>
        + Send
        + Sync
        + 'static,
{
    fn into_runnable(self, name: &str) -> Arc<dyn Runnable> {
        // Unwrap the newtype and hand the closure to the sync constructor unchanged.
        let SyncNodeFn(body) = self;
        let callable = RunnableCallable::new_sync(name, body);
        Arc::new(callable)
    }
}
/// A runnable that is already built is used as the node directly.
///
/// The node name supplied by the caller is not used.
impl IntoNodeFunction for Arc<dyn Runnable> {
    fn into_runnable(self, _: &str) -> Arc<dyn Runnable> {
        self
    }
}
/// Newtype for a synchronous node function that only looks at its input.
///
/// The wrapped closure has the signature `Fn(&JsonValue) -> Result<JsonValue, RunnableError>`.
/// The [`RunnableConfig`] passed at call time is discarded.
pub struct NodeFn1<F>(pub F);

impl<F> IntoNodeFunction for NodeFn1<F>
where
    F: Fn(&JsonValue) -> Result<JsonValue, RunnableError> + Send + Sync + 'static,
{
    fn into_runnable(self, name: &str) -> Arc<dyn Runnable> {
        let NodeFn1(handler) = self;
        Arc::new(RunnableCallable::new_sync(
            name,
            // Adapt to the two-argument sync signature; the config argument is dropped.
            move |input: &JsonValue, _config: &RunnableConfig| handler(input),
        ))
    }
}
/// Newtype for a routing function that maps a node input to a route key.
///
/// The wrapped closure has the signature `Fn(&JsonValue) -> String`. When it is turned
/// into a runnable, the returned key is emitted as `JsonValue::String` and the route call
/// always succeeds. The [`RunnableConfig`] is ignored.
pub struct RoutingFn<F>(pub F);

impl<F> IntoNodeFunction for RoutingFn<F>
where
    F: Fn(&JsonValue) -> String + Send + Sync + 'static,
{
    fn into_runnable(self, name: &str) -> Arc<dyn Runnable> {
        let RoutingFn(router) = self;
        Arc::new(RunnableCallable::new_sync(
            name,
            move |input: &JsonValue, _config: &RunnableConfig| {
                Ok(JsonValue::String(router(input)))
            },
        ))
    }
}
/// Wraps a synchronous closure in `$crate::runnable::SyncNodeFn` so it can be used as a
/// graph node.
///
/// The closure is expected to have the signature
/// `Fn(&JsonValue, &RunnableConfig) -> Result<JsonValue, RunnableError>`. A single trailing
/// comma is accepted, consistent with `conditional_edges!`. This matters when a formatter
/// splits the call over several lines.
#[macro_export]
macro_rules! node_fn {
    ($f:expr $(,)?) => {
        $crate::runnable::SyncNodeFn($f)
    };
}
/// Wraps a routing closure in `$crate::runnable::RoutingFn`.
///
/// The closure is expected to have the signature `Fn(&JsonValue) -> String` and to return
/// the route key for the given input. A single trailing comma is accepted, consistent with
/// `conditional_edges!`.
#[macro_export]
macro_rules! routing {
    ($f:expr $(,)?) => {
        $crate::runnable::RoutingFn($f)
    };
}
/// Adds conditional edges from `$source` using `$route_fn` as the router.
///
/// Expands to
/// `$graph.add_conditional_edges($source, $crate::runnable::RoutingFn($route_fn), Some(map))`.
/// `map` is a `HashMap<String, String>` built from the `key => target` pairs. Each side is
/// converted with `.to_string()`, and a later duplicate key overwrites an earlier one.
///
/// The macro is a single method-call expression, so its value is whatever
/// `add_conditional_edges` returns. Arguments are evaluated in the order written: `$graph`,
/// `$source`, `$route_fn`, then each pair.
///
/// At least one `key => target` pair is required, and a trailing comma is accepted.
#[macro_export]
macro_rules! conditional_edges {
    ($graph:expr, $source:expr, $route_fn:expr, $($key:expr => $val:expr),+ $(,)?) => {
        $graph.add_conditional_edges(
            $source,
            $crate::runnable::RoutingFn($route_fn),
            Some({
                // Absolute path: `macro_rules!` resolves item paths at the call site, so a
                // relative `std::` would break if the caller has its own `std` in scope.
                let mut map = ::std::collections::HashMap::new();
                $(map.insert($key.to_string(), $val.to_string());)+
                map
            }),
        )
    };
}