use langchainrust::{
AgentState, StateUpdate, GraphBuilder, START, END,
langgraph::{
GraphPersistence, MemoryPersistence, FilePersistence,
GraphDefinition, NodeDefinition, EdgeDefinition, NodeType, EdgeType,
},
};
#[cfg(feature = "mongodb-persistence")]
use langchainrust::langgraph::{MongoPersistence, MongoConfig};
use tempfile::TempDir;
use std::sync::Arc;
#[path = "../common/mod.rs"]
mod common;
#[cfg(feature = "mongodb-persistence")]
use common::MongoTestConfig;
#[tokio::test]
async fn test_memory_persistence_save_load() {
    // Round-trip a named definition through the in-memory backend and check
    // that its id, name and entry point come back unchanged.
    let key = "test-001";
    let store = MemoryPersistence::new();
    let original = GraphDefinition::new("entry".to_string())
        .with_id(key.to_string())
        .with_name("Test Workflow".to_string());

    store.save(key, &original).await.unwrap();
    let present = store.exists(key).await.unwrap();
    assert!(present);

    let restored = store.load(key).await.unwrap();
    assert_eq!(restored.id, key);
    assert_eq!(restored.name.as_deref(), Some("Test Workflow"));
    assert_eq!(restored.entry_point, "entry");
}
#[tokio::test]
async fn test_memory_persistence_delete() {
    // After a save, the id is present; after a delete, it is gone.
    const KEY: &str = "test-002";
    let store = MemoryPersistence::new();
    let def = GraphDefinition::new("entry".to_string()).with_id(KEY.to_string());

    store.save(KEY, &def).await.unwrap();
    let before = store.exists(KEY).await.unwrap();
    assert!(before);

    store.delete(KEY).await.unwrap();
    let after = store.exists(KEY).await.unwrap();
    assert!(!after);
}
#[tokio::test]
async fn test_memory_persistence_list() {
    // Save three definitions and confirm `list` reports exactly those ids.
    let store = MemoryPersistence::new();
    let ids: Vec<String> = (1..=3).map(|n| format!("graph-{}", n)).collect();
    for id in &ids {
        let def = GraphDefinition::new("entry".to_string()).with_id(id.clone());
        store.save(id, &def).await.unwrap();
    }

    let listed = store.list().await.unwrap();
    assert_eq!(listed.len(), 3);
    for expected in ["graph-1", "graph-2", "graph-3"].iter() {
        assert!(listed.contains(&expected.to_string()));
    }
}
#[tokio::test]
async fn test_memory_persistence_not_found() {
    // A fresh store holds nothing, so loading any id must produce an error.
    let store = MemoryPersistence::new();
    assert!(store.load("nonexistent").await.is_err());
}
#[tokio::test]
async fn test_file_persistence_save_load() {
    // Save into a throwaway directory, then read the definition back from disk.
    let dir = TempDir::new().unwrap();
    let store = FilePersistence::new(dir.path().to_str().unwrap());
    let key = "file-001";
    let original = GraphDefinition::new("process".to_string())
        .with_id(key.to_string())
        .with_name("File Test".to_string());

    store.save(key, &original).await.unwrap();
    assert!(store.exists(key).await.unwrap());

    let restored = store.load(key).await.unwrap();
    assert_eq!(restored.id, key);
    assert_eq!(restored.name.as_deref(), Some("File Test"));
}
#[tokio::test]
async fn test_file_persistence_json_format() {
let temp_dir = TempDir::new().unwrap();
let path = temp_dir.path().to_str().unwrap();
let persistence = FilePersistence::new(path);
let definition = GraphDefinition::new("start".to_string())
.with_id("json-test".to_string())
.with_recursion_limit(100);
persistence.save("json-test", &definition).await.unwrap();
let file_path = format!("{}/json-test.json", path);
let content = std::fs::read_to_string(&file_path).unwrap();
assert!(content.contains("\"id\": \"json-test\""));
assert!(content.contains("\"recursion_limit\": 100"));
assert!(content.contains("\"entry_point\": \"start\""));
}
#[tokio::test]
async fn test_file_persistence_delete() {
    // Deleting a definition must remove its backing `<id>.json` file and make
    // `exists` report false, as the in-memory delete test also checks.
    let temp_dir = TempDir::new().unwrap();
    let path = temp_dir.path().to_str().unwrap();
    let persistence = FilePersistence::new(path);
    let definition = GraphDefinition::new("entry".to_string()).with_id("del-test".to_string());
    persistence.save("del-test", &definition).await.unwrap();

    // Build the path with `Path::join` rather than string concatenation.
    let file_path = temp_dir.path().join("del-test.json");
    assert!(file_path.exists());

    persistence.delete("del-test").await.unwrap();
    assert!(!file_path.exists());
    assert!(!persistence.exists("del-test").await.unwrap());
}
#[tokio::test]
async fn test_file_persistence_list() {
    // Five definitions saved to disk should yield exactly five listed entries.
    let dir = TempDir::new().unwrap();
    let store = FilePersistence::new(dir.path().to_str().unwrap());
    for n in 1..=5 {
        let key = format!("wf-{}", n);
        let def = GraphDefinition::new("entry".to_string()).with_id(key.clone());
        store.save(&key, &def).await.unwrap();
    }

    let count = store.list().await.unwrap().len();
    assert_eq!(count, 5);
}
#[test]
fn test_graph_definition_builder() {
    // Every builder setter should be reflected verbatim in the resulting definition.
    let built = GraphDefinition::new(String::from("entry_point"))
        .with_id(String::from("custom-id"))
        .with_name(String::from("My Workflow"))
        .with_recursion_limit(50);

    assert_eq!(built.id, "custom-id");
    assert_eq!(built.name.as_deref(), Some("My Workflow"));
    assert_eq!(built.entry_point, "entry_point");
    assert_eq!(built.recursion_limit, 50);
}
#[test]
fn test_graph_definition_auto_id() {
    // Without `with_id`, the constructor must assign a non-empty id that parses as a UUID.
    let def = GraphDefinition::new("entry".to_string());
    let generated = &def.id;
    assert!(!generated.is_empty());
    let parsed = uuid::Uuid::parse_str(generated);
    assert!(parsed.is_ok());
}
#[test]
fn test_node_definition() {
    // A `NodeDefinition` must survive a JSON round trip with every field intact,
    // including its free-form `config` payload.
    let config = serde_json::json!({"key": "value"});
    let node = NodeDefinition {
        name: "process".to_string(),
        node_type: NodeType::Sync,
        config: config.clone(),
    };
    let json = serde_json::to_string(&node).unwrap();
    let parsed: NodeDefinition = serde_json::from_str(&json).unwrap();
    assert_eq!(parsed.name, "process");
    assert_eq!(parsed.node_type, NodeType::Sync);
    // Also check that the config value round-trips.
    assert_eq!(parsed.config, config);
}
#[test]
fn test_node_type_serialization() {
    // Each NodeType variant must deserialize back to itself after serialization.
    let variants = [NodeType::Sync, NodeType::Async, NodeType::Subgraph, NodeType::Custom];
    variants.iter().for_each(|variant| {
        let encoded = serde_json::to_string(variant).unwrap();
        let decoded: NodeType = serde_json::from_str(&encoded).unwrap();
        assert_eq!(&decoded, variant);
    });
}
#[test]
fn test_edge_definition_fixed() {
    // A fixed edge carries a single `target` and leaves the fan-out `targets` list unset.
    let edge = EdgeDefinition::fixed(String::from("source"), String::from("target"));
    assert_eq!(edge.edge_type, EdgeType::Fixed);
    assert_eq!(edge.source, "source");
    assert_eq!(edge.target.as_deref(), Some("target"));
    assert_eq!(edge.targets, None);
}
#[test]
fn test_edge_definition_conditional() {
    // A conditional edge records its router name, the route -> node table and a fallback target.
    let targets: std::collections::HashMap<String, String> = [("a", "node_a"), ("b", "node_b")]
        .iter()
        .map(|(route, node)| (route.to_string(), node.to_string()))
        .collect();
    let edge = EdgeDefinition::conditional(
        "decision".to_string(),
        "router".to_string(),
        targets.clone(),
        Some("default".to_string()),
    );

    assert_eq!(edge.edge_type, EdgeType::Conditional);
    assert_eq!(edge.router_name.as_deref(), Some("router"));
    assert_eq!(edge.conditional_targets.as_ref(), Some(&targets));
    assert_eq!(edge.default_target.as_deref(), Some("default"));
}
#[test]
fn test_edge_definition_fan_out() {
    // A fan-out edge stores its target list exactly as given, order included.
    let branches: Vec<String> = ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
    let edge = EdgeDefinition::fan_out("source".to_string(), branches.clone());
    assert_eq!(edge.edge_type, EdgeType::FanOut);
    assert_eq!(edge.targets, Some(branches));
}
#[test]
fn test_edge_definition_fan_in() {
    // A fan-in edge joins several sources onto a single `target` node.
    let sources = vec![String::from("a"), String::from("b")];
    let edge = EdgeDefinition::fan_in(sources, String::from("merge"));
    assert_eq!(edge.edge_type, EdgeType::FanIn);
    assert_eq!(edge.target.as_deref(), Some("merge"));
}
#[test]
fn test_edge_type_serialization() {
    // Every EdgeType variant must round-trip through serde_json unchanged.
    for expected in [EdgeType::Fixed, EdgeType::Conditional, EdgeType::FanOut, EdgeType::FanIn] {
        let encoded = serde_json::to_string(&expected).unwrap();
        let round_tripped: EdgeType = serde_json::from_str(&encoded).unwrap();
        assert_eq!(round_tripped, expected);
    }
}
#[tokio::test]
async fn test_compiled_graph_to_definition() {
    // A linear START -> step1 -> step2 -> step3 -> END graph should export three
    // nodes, four edges (including the START/END ones) and `step1` as its entry point.
    // No recursion limit is set, and the export is expected to report 25.
    let compiled = GraphBuilder::<AgentState>::new()
        .add_node_fn("step1", |state: &AgentState| Ok(StateUpdate::full(state.clone())))
        .add_node_fn("step2", |state: &AgentState| Ok(StateUpdate::full(state.clone())))
        .add_node_fn("step3", |state: &AgentState| Ok(StateUpdate::full(state.clone())))
        .add_edge(START, "step1")
        .add_edge("step1", "step2")
        .add_edge("step2", "step3")
        .add_edge("step3", END)
        .compile()
        .unwrap();

    let exported = compiled.to_definition();
    let (node_count, edge_count) = (exported.nodes.len(), exported.edges.len());
    assert_eq!(node_count, 3);
    assert_eq!(edge_count, 4);
    assert_eq!(exported.entry_point, "step1");
    assert_eq!(exported.recursion_limit, 25);
}
#[tokio::test]
async fn test_persistence_roundtrip() {
    // Export a compiled single-node graph, push it through the in-memory backend,
    // and check the node count, edge count and entry point are preserved.
    let graph = GraphBuilder::<AgentState>::new()
        .add_node_fn("process", |input: &AgentState| {
            let mut next = input.clone();
            next.set_output("done".to_string());
            Ok(StateUpdate::full(next))
        })
        .add_edge(START, "process")
        .add_edge("process", END)
        .compile()
        .unwrap();
    let exported = graph.to_definition();

    let store = MemoryPersistence::new();
    store.save("roundtrip", &exported).await.unwrap();
    let restored = store.load("roundtrip").await.unwrap();

    assert_eq!(restored.nodes.len(), exported.nodes.len());
    assert_eq!(restored.edges.len(), exported.edges.len());
    assert_eq!(restored.entry_point, exported.entry_point);
}
#[test]
fn test_definition_timestamps() {
    // `created_at` must fall inside the time window bracketing construction, and a
    // freshly built definition must report identical created/updated timestamps.
    use chrono::Utc;
    let window_start = Utc::now();
    let def = GraphDefinition::new("entry".to_string());
    let window_end = Utc::now();

    assert!(def.created_at >= window_start && def.created_at <= window_end);
    assert_eq!(def.updated_at, def.created_at);
}
#[test]
fn test_definition_metadata() {
    // Arbitrary JSON metadata attached to a definition can be read back by key.
    let mut def = GraphDefinition::new("entry".to_string());
    def.metadata.insert("version".to_string(), serde_json::json!("1.0"));
    def.metadata.insert("author".to_string(), serde_json::json!("test"));
    // Every inserted entry must be retrievable, not just the first one.
    assert_eq!(def.metadata.get("version").unwrap(), "1.0");
    assert_eq!(def.metadata.get("author").unwrap(), "test");
    // Keys that were never inserted must not appear.
    assert!(def.metadata.get("missing").is_none());
}
#[tokio::test]
async fn test_concurrent_saves() {
let persistence = Arc::new(MemoryPersistence::new());
let mut handles = vec![];
for i in 0..10 {
let p = persistence.clone();
handles.push(tokio::spawn(async move {
let def = GraphDefinition::new("entry".to_string()).with_id(format!("concurrent-{}", i));
p.save(&format!("concurrent-{}", i), &def).await.unwrap();
}));
}
for handle in handles {
handle.await.unwrap();
}
let list = persistence.list().await.unwrap();
assert_eq!(list.len(), 10);
}
#[cfg(feature = "mongodb-persistence")]
mod mongo_tests {
    //! Integration tests for `MongoPersistence`. They only compile with the
    //! `mongodb-persistence` feature and connect using the configuration from
    //! `MongoTestConfig`.
    //!
    //! Each test first purges the ids it is about to use. A document left behind
    //! by an earlier run that panicked before its trailing delete therefore cannot
    //! leak into the assertions.
    use super::*;
    use langchainrust::langgraph::MongoPersistence;

    /// Connects to the test database described by `MongoTestConfig`.
    async fn connect() -> MongoPersistence {
        MongoPersistence::new(MongoTestConfig::get().to_mongo_config())
            .await
            .expect("failed to connect to the MongoDB test instance")
    }

    /// Best-effort delete of `id`; the result is deliberately ignored because the
    /// id may legitimately be absent.
    async fn purge(persistence: &MongoPersistence, id: &str) {
        let _ = persistence.delete(id).await;
    }

    #[tokio::test]
    async fn test_mongo_persistence_save_load() {
        let persistence = connect().await;
        purge(&persistence, "mongo-test-001").await;
        let definition = GraphDefinition::new("entry".to_string())
            .with_id("mongo-test-001".to_string())
            .with_name("MongoDB Test Workflow".to_string());
        persistence.save("mongo-test-001", &definition).await.unwrap();
        assert!(persistence.exists("mongo-test-001").await.unwrap());
        let loaded = persistence.load("mongo-test-001").await.unwrap();
        assert_eq!(loaded.id, "mongo-test-001");
        assert_eq!(loaded.name, Some("MongoDB Test Workflow".to_string()));
        assert_eq!(loaded.entry_point, "entry");
        persistence.delete("mongo-test-001").await.unwrap();
    }

    #[tokio::test]
    async fn test_mongo_persistence_delete() {
        let persistence = connect().await;
        purge(&persistence, "mongo-test-del").await;
        let definition = GraphDefinition::new("entry".to_string())
            .with_id("mongo-test-del".to_string());
        persistence.save("mongo-test-del", &definition).await.unwrap();
        assert!(persistence.exists("mongo-test-del").await.unwrap());
        persistence.delete("mongo-test-del").await.unwrap();
        assert!(!persistence.exists("mongo-test-del").await.unwrap());
    }

    #[tokio::test]
    async fn test_mongo_persistence_list() {
        let persistence = connect().await;
        let ids: Vec<String> = (1..=3).map(|i| format!("mongo-list-{}", i)).collect();
        for id in &ids {
            purge(&persistence, id).await;
        }
        for id in &ids {
            let def = GraphDefinition::new("entry".to_string()).with_id(id.clone());
            persistence.save(id, &def).await.unwrap();
        }
        let list = persistence.list().await.unwrap();
        // The collection may be shared with other tests, so only count this test's ids.
        let mongo_items: Vec<_> = list
            .iter()
            .filter(|id| id.starts_with("mongo-list-"))
            .collect();
        assert_eq!(mongo_items.len(), 3);
        for id in &ids {
            purge(&persistence, id).await;
        }
    }

    #[tokio::test]
    async fn test_mongo_persistence_not_found() {
        let persistence = connect().await;
        // Make sure the id really is absent before asserting that loading it fails.
        purge(&persistence, "mongo-nonexistent").await;
        let result = persistence.load("mongo-nonexistent").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_mongo_persistence_upsert() {
        let persistence = connect().await;
        purge(&persistence, "mongo-upsert").await;
        let def1 = GraphDefinition::new("entry".to_string())
            .with_id("mongo-upsert".to_string())
            .with_name("Original Name".to_string())
            .with_recursion_limit(25);
        persistence.save("mongo-upsert", &def1).await.unwrap();
        let def2 = GraphDefinition::new("entry".to_string())
            .with_id("mongo-upsert".to_string())
            .with_name("Updated Name".to_string())
            .with_recursion_limit(50);
        // Saving under the same id again must overwrite the first definition.
        persistence.save("mongo-upsert", &def2).await.unwrap();
        let loaded = persistence.load("mongo-upsert").await.unwrap();
        assert_eq!(loaded.name, Some("Updated Name".to_string()));
        assert_eq!(loaded.recursion_limit, 50);
        persistence.delete("mongo-upsert").await.unwrap();
    }

    #[tokio::test]
    async fn test_mongo_concurrent_saves() {
        let persistence = Arc::new(connect().await);
        for i in 0..10 {
            purge(&persistence, &format!("mongo-concurrent-{}", i)).await;
        }
        let mut handles = vec![];
        for i in 0..10 {
            let p = Arc::clone(&persistence);
            handles.push(tokio::spawn(async move {
                let def = GraphDefinition::new("entry".to_string())
                    .with_id(format!("mongo-concurrent-{}", i));
                p.save(&format!("mongo-concurrent-{}", i), &def).await.unwrap();
            }));
        }
        for handle in handles {
            handle.await.unwrap();
        }
        let list = persistence.list().await.unwrap();
        let concurrent_items: Vec<_> = list
            .iter()
            .filter(|id| id.starts_with("mongo-concurrent-"))
            .collect();
        assert_eq!(concurrent_items.len(), 10);
        for i in 0..10 {
            purge(&persistence, &format!("mongo-concurrent-{}", i)).await;
        }
    }

    #[tokio::test]
    async fn test_mongo_custom_config() {
        let persistence = connect().await;
        assert!(!persistence.database_name().is_empty());
        assert!(!persistence.collection_name().is_empty());
        purge(&persistence, "custom-config-test").await;
        let def = GraphDefinition::new("test".to_string())
            .with_id("custom-config-test".to_string());
        persistence.save("custom-config-test", &def).await.unwrap();
        assert!(persistence.exists("custom-config-test").await.unwrap());
        persistence.delete("custom-config-test").await.unwrap();
    }

    #[tokio::test]
    async fn test_mongo_persistence_roundtrip() {
        let persistence = connect().await;
        purge(&persistence, "mongo-roundtrip").await;
        let graph = GraphBuilder::<AgentState>::new()
            .add_node_fn("process", |s: &AgentState| {
                let mut state = s.clone();
                state.set_output("done".to_string());
                Ok(StateUpdate::full(state))
            })
            .add_edge(START, "process")
            .add_edge("process", END)
            .compile()
            .unwrap();
        let definition = graph.to_definition();
        persistence.save("mongo-roundtrip", &definition).await.unwrap();
        let loaded = persistence.load("mongo-roundtrip").await.unwrap();
        assert_eq!(loaded.nodes.len(), definition.nodes.len());
        assert_eq!(loaded.edges.len(), definition.edges.len());
        assert_eq!(loaded.entry_point, definition.entry_point);
        persistence.delete("mongo-roundtrip").await.unwrap();
    }
}