use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value as JsonValue;
use langgraph_checkpoint::config::RunnableConfig;
use super::base::{Runnable, RunnableError};
/// Type-erased async function that backs a [`RunnableCallable`].
///
/// It receives the input value and the config by value and returns a boxed,
/// `Send` future that resolves to the output value or a [`RunnableError`].
/// The outer `Arc` lets the function be shared across clones and threads
/// (the trait object itself is required to be `Send + Sync`).
pub type BoxedFn = Arc<
    dyn Fn(JsonValue, RunnableConfig) -> Pin<Box<dyn Future<Output = Result<JsonValue, RunnableError>> + Send>>
        + Send
        + Sync,
>;
/// A named [`Runnable`] whose behaviour is supplied by a function.
///
/// Construct it with `RunnableCallable::new` (async function) or
/// `RunnableCallable::new_sync` (synchronous function).
pub struct RunnableCallable {
    /// Type-erased function invoked on every `invoke` / `ainvoke` call.
    func: BoxedFn,
    /// Identifier returned by `Runnable::name`.
    name: String,
}
impl RunnableCallable {
pub fn new<F, Fut>(name: impl Into<String>, f: F) -> Self
where
F: Fn(JsonValue, RunnableConfig) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<JsonValue, RunnableError>> + Send + 'static,
{
Self {
name: name.into(),
func: Arc::new(move |input, config| Box::pin(f(input, config))),
}
}
pub fn new_sync<F>(name: impl Into<String>, f: F) -> Self
where
F: Fn(&JsonValue, &RunnableConfig) -> Result<JsonValue, RunnableError> + Send + Sync + 'static,
{
let f = Arc::new(f);
Self {
name: name.into(),
func: Arc::new(move |input, config| {
let f = f.clone();
Box::pin(async move { f(&input, &config) })
}),
}
}
}
#[async_trait]
impl Runnable for RunnableCallable {
fn invoke(&self, input: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, RunnableError> {
let func = self.func.clone();
let input = input.clone();
let config = config.clone();
match tokio::runtime::Handle::try_current() {
Ok(handle) => handle.block_on(crate::config::with_config(config.clone(), func(input, config))),
Err(_) => tokio::runtime::Runtime::new()
.unwrap()
.block_on(func(input, config)),
}
}
async fn ainvoke(&self, input: &JsonValue, config: &RunnableConfig) -> Result<JsonValue, RunnableError> {
let func = self.func.clone();
let input = input.clone();
let config_inner = config.clone();
crate::config::with_config(config.clone(), func(input, config_inner)).await
}
fn name(&self) -> &str {
&self.name
}
}