use super::Callable;
use crate::graph::CompiledGraph;
use async_trait::async_trait;
use std::sync::Arc;
/// Adapts a shared [`CompiledGraph`] into a named [`Callable`].
///
/// Cloning is cheap: the graph is held behind an [`Arc`], so clones share the
/// same compiled graph and only the name string is copied.
#[derive(Clone)]
pub struct GraphCallable {
    /// Name reported by [`Callable::name`].
    name: String,
    /// The compiled graph executed by [`Callable::run`].
    graph: Arc<CompiledGraph>,
}

// `Debug` is implemented by hand because `CompiledGraph` is not known to
// implement `Debug`; only the name is printed and the graph is elided.
impl std::fmt::Debug for GraphCallable {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GraphCallable")
            .field("name", &self.name)
            .finish_non_exhaustive()
    }
}
impl GraphCallable {
    /// Creates a callable called `name` that executes `graph` when run.
    ///
    /// The graph is taken as an [`Arc`], so the caller may keep sharing it
    /// with other owners.
    pub fn new(name: impl Into<String>, graph: Arc<CompiledGraph>) -> Self {
        let name: String = name.into();
        Self { name, graph }
    }

    /// Borrows the compiled graph wrapped by this callable.
    pub fn graph(&self) -> &CompiledGraph {
        &*self.graph
    }
}
#[async_trait]
impl Callable for GraphCallable {
    /// Returns the name supplied to [`GraphCallable::new`].
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Graph-backed callables do not provide a description.
    fn description(&self) -> Option<&str> {
        None
    }

    /// Runs the wrapped graph on `input` and renders its final state as text.
    ///
    /// When the final state yields a string through `as_str`, that string is
    /// returned as-is; otherwise the state's `data` is serialized to JSON.
    ///
    /// # Errors
    ///
    /// Propagates any error from running the graph, and any error from
    /// serializing the state's `data` with `serde_json`.
    async fn run(&self, input: &str) -> anyhow::Result<String> {
        let final_state = self.graph.run(input).await?;
        let rendered = match final_state.as_str() {
            Some(text) => text.to_string(),
            None => serde_json::to_string(&final_state.data)?,
        };
        Ok(rendered)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::CompiledGraph;
    use std::collections::HashMap;

    /// Builds a graph with no nodes and no edges. Its entry point, `"start"`,
    /// therefore names a node that does not exist in `nodes`.
    fn empty_graph() -> Arc<CompiledGraph> {
        Arc::new(CompiledGraph {
            nodes: HashMap::new(),
            edges: vec![],
            conditional_edges: vec![],
            entry_point: "start".to_string(),
        })
    }

    // Construction and `name()` are synchronous, so no async runtime is needed.
    #[test]
    fn test_graph_callable_name() {
        let callable = GraphCallable::new("test-graph", empty_graph());
        assert_eq!(callable.name(), "test-graph");
    }

    #[test]
    fn test_graph_callable_description_is_none() {
        let callable = GraphCallable::new("test-graph", empty_graph());
        assert!(callable.description().is_none());
    }

    #[test]
    fn test_graph_callable_graph_accessor() {
        let graph = empty_graph();
        let callable = GraphCallable::new("test-graph", Arc::clone(&graph));
        // The accessor must hand back the very graph that was passed in.
        assert!(std::ptr::eq(callable.graph(), &*graph));
    }

    #[tokio::test]
    async fn test_graph_callable_run() {
        // The entry point refers to a missing node, so running must fail.
        let callable = GraphCallable::new("test-graph", empty_graph());
        let result = callable.run("test input").await;
        assert!(result.is_err());
    }
}