use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use tokio::sync::Mutex;
use uuid::Uuid;
use chrono::{DateTime, Utc};
/// Asynchronous storage backend for [`GraphDefinition`]s, keyed by a caller-supplied id.
///
/// The `Send + Sync` supertraits allow an implementation to be shared across threads
/// and tasks.
#[async_trait]
pub trait GraphPersistence: Send + Sync {
/// Stores `definition` under `id`, replacing whatever was stored there before.
///
/// The implementations in this file key storage by `id` only; `definition.id` is not
/// compared against it.
async fn save(&self, id: &str, definition: &GraphDefinition) -> Result<(), PersistenceError>;
/// Returns an owned copy of the definition stored under `id`.
///
/// # Errors
///
/// [`PersistenceError::NotFound`] when nothing is stored under `id`.
async fn load(&self, id: &str) -> Result<GraphDefinition, PersistenceError>;
/// Removes the definition stored under `id`.
///
/// NOTE(review): confirm whether deleting an unknown id must be an error; callers
/// should be prepared to receive [`PersistenceError::NotFound`].
async fn delete(&self, id: &str) -> Result<(), PersistenceError>;
/// Reports whether a definition is stored under `id`.
async fn exists(&self, id: &str) -> Result<bool, PersistenceError>;
/// Returns the ids of all stored definitions, in no particular order.
async fn list(&self) -> Result<Vec<String>, PersistenceError>;
}
/// Errors returned by [`GraphPersistence`] implementations.
///
/// `NotFound` carries the requested id; the other variants carry a message, typically
/// the underlying error's `to_string()`.
#[derive(Debug, thiserror::Error)]
pub enum PersistenceError {
/// Nothing is stored under the given id.
#[error("Graph '{0}' not found")]
NotFound(String),
/// A definition could not be encoded for storage (JSON or BSON).
#[error("Serialization error: {0}")]
SerializationError(String),
/// Stored data could not be decoded back into a definition.
#[error("Deserialization error: {0}")]
DeserializationError(String),
/// A filesystem operation failed.
#[error("IO error: {0}")]
IoError(String),
/// The supplied graph definition (or other input) was rejected as invalid.
#[error("Invalid graph definition: {0}")]
InvalidDefinition(String),
/// An operation failed inside the MongoDB driver.
#[error("MongoDB error: {0}")]
MongoError(String),
/// A MongoDB client could not be configured or created.
#[error("Connection error: {0}")]
ConnectionError(String),
}
/// Serializable description of a graph: its nodes, edges, routers and settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphDefinition {
/// Identifier; [`GraphDefinition::new`] assigns a random UUID v4 string.
pub id: String,
/// Optional human-readable name, set via [`GraphDefinition::with_name`].
pub name: Option<String>,
/// Name of the node where execution presumably begins — TODO confirm against the executor.
pub entry_point: String,
/// Node definitions; `add_node` appends to this list.
pub nodes: Vec<NodeDefinition>,
/// Edge definitions; `add_edge` appends to this list.
pub edges: Vec<EdgeDefinition>,
/// Router definitions; `add_router` appends to this list.
pub routers: Vec<RouterDefinition>,
/// `new` defaults this to 25; presumably caps execution steps — TODO confirm.
pub recursion_limit: usize,
/// Creation time (UTC), set by `new`.
pub created_at: DateTime<Utc>,
/// Last modification time (UTC); refreshed by the `add_*` methods, not by the `with_*` builders.
pub updated_at: DateTime<Utc>,
/// Free-form JSON metadata; `new` starts it empty.
pub metadata: HashMap<String, serde_json::Value>,
}
impl GraphDefinition {
    /// Creates an empty definition that starts at `entry_point`.
    ///
    /// The result has a random UUID v4 `id`, no name, a `recursion_limit` of 25, empty
    /// node/edge/router lists and metadata, and identical `created_at`/`updated_at`
    /// timestamps taken from the current UTC time.
    pub fn new(entry_point: String) -> Self {
        let timestamp = Utc::now();
        let id = Uuid::new_v4().to_string();
        Self {
            id,
            name: None,
            entry_point,
            nodes: vec![],
            edges: vec![],
            routers: vec![],
            recursion_limit: 25,
            created_at: timestamp,
            updated_at: timestamp,
            metadata: HashMap::default(),
        }
    }

    /// Builder: replaces the generated id with `id`.
    pub fn with_id(self, id: String) -> Self {
        Self { id, ..self }
    }

    /// Builder: sets the display name.
    pub fn with_name(self, name: String) -> Self {
        Self {
            name: Some(name),
            ..self
        }
    }

    /// Builder: overrides the default recursion limit of 25.
    pub fn with_recursion_limit(self, limit: usize) -> Self {
        Self {
            recursion_limit: limit,
            ..self
        }
    }

    /// Appends `node` and refreshes `updated_at`.
    pub fn add_node(&mut self, node: NodeDefinition) {
        let Self {
            nodes, updated_at, ..
        } = self;
        nodes.push(node);
        *updated_at = Utc::now();
    }

    /// Appends `edge` and refreshes `updated_at`.
    pub fn add_edge(&mut self, edge: EdgeDefinition) {
        let Self {
            edges, updated_at, ..
        } = self;
        edges.push(edge);
        *updated_at = Utc::now();
    }

    /// Appends `router` and refreshes `updated_at`.
    pub fn add_router(&mut self, router: RouterDefinition) {
        let Self {
            routers,
            updated_at,
            ..
        } = self;
        routers.push(router);
        *updated_at = Utc::now();
    }
}
/// One node of a [`GraphDefinition`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeDefinition {
/// Node name; presumably what edges and `entry_point` refer to — TODO confirm.
pub name: String,
/// Kind of node.
pub node_type: NodeType,
/// Node-specific configuration as opaque JSON; its schema is not defined in this file.
pub config: serde_json::Value,
}
/// Kind of a [`NodeDefinition`].
///
/// Derives `Eq` and `Hash` (in addition to `PartialEq`) so node kinds can be used as
/// `HashMap`/`HashSet` keys.
///
/// NOTE(review): the variants are not interpreted anywhere in this file; confirm their
/// execution semantics with the consumer.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum NodeType {
    /// Presumably a node backed by a synchronous function.
    Sync,
    /// Presumably a node backed by an async function.
    Async,
    /// Presumably a node that runs a nested graph.
    Subgraph,
    /// Presumably a user-defined node kind.
    Custom,
}
/// One edge (or edge group) of a [`GraphDefinition`].
///
/// Which optional fields are populated depends on `edge_type`; the constructors in
/// `impl EdgeDefinition` fill them as noted per field. Deserialized values are not
/// validated against that layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeDefinition {
/// Kind of edge; determines which optional fields are meaningful.
pub edge_type: EdgeType,
/// Source node. `fan_in` stores the placeholder `"__fan_in__"` here and lists the real sources in `sources`.
pub source: String,
/// Single target; set by `fixed` and `fan_in`.
pub target: Option<String>,
/// Parallel targets; set by `fan_out`.
pub targets: Option<Vec<String>>,
/// Router used to pick a branch; set by `conditional`. Presumably names a `RouterDefinition` — TODO confirm.
pub router_name: Option<String>,
/// Branch map set by `conditional`; presumably router output -> target node — TODO confirm.
pub conditional_targets: Option<HashMap<String, String>>,
/// Optional fallback target for `conditional`; presumably used when no branch matches.
pub default_target: Option<String>,
/// Source nodes joined by `fan_in`.
pub sources: Option<Vec<String>>,
}
impl EdgeDefinition {
pub fn fixed(source: String, target: String) -> Self {
Self {
edge_type: EdgeType::Fixed,
source,
target: Some(target),
targets: None,
router_name: None,
conditional_targets: None,
default_target: None,
sources: None,
}
}
pub fn conditional(
source: String,
router_name: String,
targets: HashMap<String, String>,
default_target: Option<String>,
) -> Self {
Self {
edge_type: EdgeType::Conditional,
source,
target: None,
targets: None,
router_name: Some(router_name),
conditional_targets: Some(targets),
default_target,
sources: None,
}
}
pub fn fan_out(source: String, targets: Vec<String>) -> Self {
Self {
edge_type: EdgeType::FanOut,
source,
target: None,
targets: Some(targets),
router_name: None,
conditional_targets: None,
default_target: None,
sources: None,
}
}
pub fn fan_in(sources: Vec<String>, target: String) -> Self {
Self {
edge_type: EdgeType::FanIn,
source: "__fan_in__".to_string(),
target: Some(target),
targets: None,
router_name: None,
conditional_targets: None,
default_target: None,
sources: Some(sources),
}
}
}
/// Kind of an [`EdgeDefinition`]; each variant corresponds to one `EdgeDefinition`
/// constructor.
///
/// Derives `Eq` and `Hash` (in addition to `PartialEq`) so edge kinds can be used as
/// `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum EdgeType {
    /// Single unconditional edge (`EdgeDefinition::fixed`).
    Fixed,
    /// Router-selected branch (`EdgeDefinition::conditional`).
    Conditional,
    /// One source to several targets (`EdgeDefinition::fan_out`).
    FanOut,
    /// Several sources into one target (`EdgeDefinition::fan_in`).
    FanIn,
}
/// A named router stored alongside a [`GraphDefinition`].
///
/// NOTE(review): routers are only stored here; how `router_type`, `routes` and `config`
/// are interpreted is up to the consumer — confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterDefinition {
/// Router name; presumably what `EdgeDefinition::router_name` refers to — TODO confirm.
pub name: String,
/// Free-form router kind identifier.
pub router_type: String,
/// Route names; presumably the possible outcomes of this router — TODO confirm.
pub routes: Vec<String>,
/// Router-specific configuration as opaque JSON.
pub config: serde_json::Value,
}
/// [`GraphPersistence`] implementation that keeps definitions in a process-local map;
/// nothing is written outside the process.
pub struct MemoryPersistence {
/// Definitions keyed by the id passed to `save`, guarded by a tokio (async) mutex.
graphs: Mutex<HashMap<String, GraphDefinition>>,
}
impl MemoryPersistence {
pub fn new() -> Self {
Self {
graphs: Mutex::new(HashMap::new()),
}
}
}
impl Default for MemoryPersistence {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl GraphPersistence for MemoryPersistence {
    /// Stores a clone of `definition` under `id`, replacing any previous entry.
    async fn save(&self, id: &str, definition: &GraphDefinition) -> Result<(), PersistenceError> {
        let snapshot = definition.clone();
        self.graphs.lock().await.insert(id.to_owned(), snapshot);
        Ok(())
    }

    /// Returns a clone of the entry for `id`, or `NotFound` if there is none.
    async fn load(&self, id: &str) -> Result<GraphDefinition, PersistenceError> {
        let guard = self.graphs.lock().await;
        match guard.get(id) {
            Some(found) => Ok(found.clone()),
            None => Err(PersistenceError::NotFound(id.to_owned())),
        }
    }

    /// Removes the entry for `id`; an unknown id is reported as `NotFound`.
    async fn delete(&self, id: &str) -> Result<(), PersistenceError> {
        let removed = self.graphs.lock().await.remove(id);
        match removed {
            Some(_) => Ok(()),
            None => Err(PersistenceError::NotFound(id.to_owned())),
        }
    }

    /// Reports whether an entry exists for `id`; never fails.
    async fn exists(&self, id: &str) -> Result<bool, PersistenceError> {
        let present = self.graphs.lock().await.contains_key(id);
        Ok(present)
    }

    /// Returns every stored id in the map's (unspecified) iteration order.
    async fn list(&self) -> Result<Vec<String>, PersistenceError> {
        let guard = self.graphs.lock().await;
        let ids: Vec<String> = guard.keys().map(String::clone).collect();
        Ok(ids)
    }
}
/// [`GraphPersistence`] implementation that stores each definition as a pretty-printed
/// JSON file named `<id>.json` inside a single directory.
pub struct FilePersistence {
/// Directory holding the `<id>.json` files.
directory: PathBuf,
}
impl FilePersistence {
pub fn new(directory: impl Into<PathBuf>) -> Self {
let dir = directory.into();
if !dir.exists() {
std::fs::create_dir_all(&dir).ok();
}
Self { directory: dir }
}
fn graph_path(&self, id: &str) -> PathBuf {
self.directory.join(format!("{}.json", id))
}
}
impl Default for FilePersistence {
fn default() -> Self {
Self::new(".graph_definitions")
}
}
#[async_trait]
impl GraphPersistence for FilePersistence {
async fn save(&self, id: &str, definition: &GraphDefinition) -> Result<(), PersistenceError> {
let path = self.graph_path(id);
let json = serde_json::to_string_pretty(definition)
.map_err(|e| PersistenceError::SerializationError(e.to_string()))?;
std::fs::write(&path, json)
.map_err(|e| PersistenceError::IoError(e.to_string()))?;
Ok(())
}
async fn load(&self, id: &str) -> Result<GraphDefinition, PersistenceError> {
let path = self.graph_path(id);
if !path.exists() {
return Err(PersistenceError::NotFound(id.to_string()));
}
let json = std::fs::read_to_string(&path)
.map_err(|e| PersistenceError::IoError(e.to_string()))?;
let definition: GraphDefinition = serde_json::from_str(&json)
.map_err(|e| PersistenceError::DeserializationError(e.to_string()))?;
Ok(definition)
}
async fn delete(&self, id: &str) -> Result<(), PersistenceError> {
let path = self.graph_path(id);
if path.exists() {
std::fs::remove_file(&path)
.map_err(|e| PersistenceError::IoError(e.to_string()))?;
}
Ok(())
}
async fn exists(&self, id: &str) -> Result<bool, PersistenceError> {
let path = self.graph_path(id);
Ok(path.exists())
}
async fn list(&self) -> Result<Vec<String>, PersistenceError> {
let mut ids = Vec::new();
let entries = std::fs::read_dir(&self.directory)
.map_err(|e| PersistenceError::IoError(e.to_string()))?;
for entry in entries {
if let Ok(entry) = entry {
let path = entry.path();
if path.extension().map_or(false, |ext| ext == "json") {
if let Some(id) = path.file_stem().and_then(|s| s.to_str()) {
ids.push(id.to_string());
}
}
}
}
Ok(ids)
}
}
#[cfg(feature = "mongodb-persistence")]
mod mongo_impl {
    use super::*;
    use mongodb::{
        bson::{doc, from_document, to_document, Document},
        options::{ClientOptions, FindOptions},
        Client, Collection,
    };

    /// Database used when `MONGO_DATABASE` is unset.
    const DEFAULT_DATABASE: &str = "langgraph";
    /// Collection used when `MONGO_COLLECTION` is unset.
    const DEFAULT_COLLECTION: &str = "graph_definitions";

    /// Reads `key` from the environment, falling back to `default` when it is unset or
    /// not valid Unicode.
    fn env_or(key: &str, default: &str) -> String {
        std::env::var(key).unwrap_or_else(|_| default.to_string())
    }

    /// Connection settings for [`MongoPersistence`].
    pub struct MongoConfig {
        /// MongoDB connection string, parsed with `ClientOptions::parse`.
        pub uri: String,
        /// Name of the database holding the collection.
        pub database: String,
        /// Name of the collection storing one document per graph.
        pub collection: String,
    }

    impl MongoConfig {
        /// Creates a config from explicit values.
        pub fn new(uri: String, database: String, collection: String) -> Self {
            Self {
                uri,
                database,
                collection,
            }
        }

        /// Builds a config from `MONGO_URI`, `MONGO_DATABASE` (default `"langgraph"`) and
        /// `MONGO_COLLECTION` (default `"graph_definitions"`).
        ///
        /// # Panics
        ///
        /// Panics if `MONGO_URI` is unset or not valid Unicode. Prefer
        /// [`MongoConfig::try_from_env`] in code that must not abort.
        pub fn from_env() -> Self {
            Self {
                uri: std::env::var("MONGO_URI").expect("MONGO_URI environment variable not set"),
                database: env_or("MONGO_DATABASE", DEFAULT_DATABASE),
                collection: env_or("MONGO_COLLECTION", DEFAULT_COLLECTION),
            }
        }

        /// Like [`MongoConfig::from_env`], but reports an unusable `MONGO_URI` as an error
        /// instead of panicking.
        ///
        /// # Errors
        ///
        /// [`PersistenceError::ConnectionError`] if `MONGO_URI` is unset or not valid
        /// Unicode.
        pub fn try_from_env() -> Result<Self, PersistenceError> {
            let uri = std::env::var("MONGO_URI").map_err(|e| {
                PersistenceError::ConnectionError(format!("cannot read MONGO_URI: {}", e))
            })?;
            Ok(Self {
                uri,
                database: env_or("MONGO_DATABASE", DEFAULT_DATABASE),
                collection: env_or("MONGO_COLLECTION", DEFAULT_COLLECTION),
            })
        }
    }

    /// [`GraphPersistence`] implementation backed by a MongoDB collection, storing one
    /// document per graph with the graph id as `_id`.
    pub struct MongoPersistence {
        client: Client,
        collection: Collection<Document>,
        database_name: String,
        collection_name: String,
    }

    impl MongoPersistence {
        /// Parses `config.uri` and builds a client bound to `config.database` /
        /// `config.collection`.
        ///
        /// # Errors
        ///
        /// [`PersistenceError::ConnectionError`] if the URI cannot be parsed or the client
        /// cannot be built.
        pub async fn new(config: MongoConfig) -> Result<Self, PersistenceError> {
            let client_options = ClientOptions::parse(&config.uri)
                .await
                .map_err(|e| PersistenceError::ConnectionError(e.to_string()))?;
            let client = Client::with_options(client_options)
                .map_err(|e| PersistenceError::ConnectionError(e.to_string()))?;
            let database = client.database(&config.database);
            let collection = database.collection(&config.collection);
            Ok(Self {
                client,
                collection,
                database_name: config.database,
                collection_name: config.collection,
            })
        }

        /// Builds a store from the environment (see [`MongoConfig::try_from_env`]).
        ///
        /// # Errors
        ///
        /// [`PersistenceError::ConnectionError`] if `MONGO_URI` is unset or unusable, or if
        /// the client cannot be created. Unlike [`MongoConfig::from_env`], this never panics.
        pub async fn from_env() -> Result<Self, PersistenceError> {
            let config = MongoConfig::try_from_env()?;
            Self::new(config).await
        }

        /// The underlying driver client.
        pub fn client(&self) -> &Client {
            &self.client
        }

        /// Name of the collection holding graph documents.
        pub fn collection_name(&self) -> &str {
            &self.collection_name
        }

        /// Name of the database holding the collection.
        pub fn database_name(&self) -> &str {
            &self.database_name
        }
    }

    #[async_trait]
    impl GraphPersistence for MongoPersistence {
        /// Upserts the document with `_id == id`.
        ///
        /// Uses `$set`, so top-level fields present in a previously stored document but
        /// absent from `definition`'s serialized form are left in place.
        /// NOTE(review): consider `replace_one` if full replacement is intended.
        async fn save(&self, id: &str, definition: &GraphDefinition) -> Result<(), PersistenceError> {
            let document = to_document(definition)
                .map_err(|e| PersistenceError::SerializationError(e.to_string()))?;
            self.collection
                .update_one(
                    doc! { "_id": id },
                    doc! { "$set": document },
                    mongodb::options::UpdateOptions::builder()
                        .upsert(true)
                        .build(),
                )
                .await
                .map_err(|e| PersistenceError::MongoError(e.to_string()))?;
            Ok(())
        }

        /// Fetches and decodes the document with `_id == id`; `NotFound` if there is none.
        async fn load(&self, id: &str) -> Result<GraphDefinition, PersistenceError> {
            let filter = doc! { "_id": id };
            let result = self
                .collection
                .find_one(filter, None)
                .await
                .map_err(|e| PersistenceError::MongoError(e.to_string()))?;
            match result {
                Some(document) => from_document(document)
                    .map_err(|e| PersistenceError::DeserializationError(e.to_string())),
                None => Err(PersistenceError::NotFound(id.to_string())),
            }
        }

        /// Deletes the document with `_id == id`; `NotFound` if nothing was deleted.
        async fn delete(&self, id: &str) -> Result<(), PersistenceError> {
            let filter = doc! { "_id": id };
            let result = self
                .collection
                .delete_one(filter, None)
                .await
                .map_err(|e| PersistenceError::MongoError(e.to_string()))?;
            if result.deleted_count == 0 {
                Err(PersistenceError::NotFound(id.to_string()))
            } else {
                Ok(())
            }
        }

        /// Reports whether a document with `_id == id` exists.
        async fn exists(&self, id: &str) -> Result<bool, PersistenceError> {
            let filter = doc! { "_id": id };
            let count = self
                .collection
                .count_documents(filter, None)
                .await
                .map_err(|e| PersistenceError::MongoError(e.to_string()))?;
            Ok(count > 0)
        }

        /// Returns every string `_id` in the collection (non-string ids are skipped), in
        /// server order.
        async fn list(&self) -> Result<Vec<String>, PersistenceError> {
            let filter = doc! {};
            let options = FindOptions::builder()
                .projection(doc! { "_id": 1 })
                .build();
            let mut cursor = self
                .collection
                .find(filter, options)
                .await
                .map_err(|e| PersistenceError::MongoError(e.to_string()))?;
            let mut ids = Vec::new();
            while cursor
                .advance()
                .await
                .map_err(|e| PersistenceError::MongoError(e.to_string()))?
            {
                let document = cursor
                    .deserialize_current()
                    .map_err(|e| PersistenceError::DeserializationError(e.to_string()))?;
                if let Ok(id) = document.get_str("_id") {
                    ids.push(id.to_string());
                }
            }
            Ok(ids)
        }
    }
}
#[cfg(feature = "mongodb-persistence")]
pub use mongo_impl::{MongoPersistence, MongoConfig};
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a unique, not-yet-existing directory path under the system temp dir.
    fn unique_temp_dir(label: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{}_{}", label, uuid::Uuid::new_v4()))
    }

    /// True if `result` is `Err(NotFound(expected))`.
    fn is_not_found<T>(result: &Result<T, PersistenceError>, expected: &str) -> bool {
        matches!(result, Err(PersistenceError::NotFound(id)) if id == expected)
    }

    #[test]
    fn test_node_type_serialization() {
        let types = vec![NodeType::Sync, NodeType::Async, NodeType::Subgraph, NodeType::Custom];
        for t in types {
            let json = serde_json::to_string(&t).unwrap();
            let parsed: NodeType = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, t);
        }
    }

    #[test]
    fn test_edge_type_serialization() {
        let types = vec![EdgeType::Fixed, EdgeType::Conditional, EdgeType::FanOut, EdgeType::FanIn];
        for t in types {
            let json = serde_json::to_string(&t).unwrap();
            let parsed: EdgeType = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, t);
        }
    }

    #[test]
    fn test_graph_definition_builder() {
        let def = GraphDefinition::new("entry".to_string())
            .with_id("test-id".to_string())
            .with_name("Test Graph".to_string())
            .with_recursion_limit(50);
        assert_eq!(def.id, "test-id");
        assert_eq!(def.name, Some("Test Graph".to_string()));
        assert_eq!(def.entry_point, "entry");
        assert_eq!(def.recursion_limit, 50);
    }

    #[test]
    fn test_graph_definition_defaults() {
        let def = GraphDefinition::new("entry".to_string());
        assert_eq!(def.recursion_limit, 25);
        assert!(def.name.is_none());
        assert!(def.nodes.is_empty() && def.edges.is_empty() && def.routers.is_empty());
        assert!(def.metadata.is_empty());
        assert_eq!(def.created_at, def.updated_at);
        assert!(uuid::Uuid::parse_str(&def.id).is_ok());
        assert_ne!(def.id, GraphDefinition::new("entry".to_string()).id);
    }

    #[test]
    fn test_add_methods_append_and_touch_updated_at() {
        let mut def = GraphDefinition::new("a".to_string());
        let created = def.created_at;
        def.add_node(NodeDefinition {
            name: "a".to_string(),
            node_type: NodeType::Sync,
            config: serde_json::Value::Null,
        });
        def.add_edge(EdgeDefinition::fixed("a".to_string(), "b".to_string()));
        def.add_router(RouterDefinition {
            name: "r".to_string(),
            router_type: "custom".to_string(),
            routes: vec!["b".to_string()],
            config: serde_json::Value::Null,
        });
        assert_eq!(def.nodes.len(), 1);
        assert_eq!(def.edges.len(), 1);
        assert_eq!(def.routers.len(), 1);
        assert_eq!(def.created_at, created);
        assert!(def.updated_at >= created);
    }

    #[test]
    fn test_edge_constructors() {
        let fixed = EdgeDefinition::fixed("a".to_string(), "b".to_string());
        assert_eq!(fixed.edge_type, EdgeType::Fixed);
        assert_eq!(fixed.target.as_deref(), Some("b"));
        assert!(fixed.targets.is_none() && fixed.sources.is_none() && fixed.router_name.is_none());

        let mut routes = std::collections::HashMap::new();
        routes.insert("yes".to_string(), "b".to_string());
        let conditional = EdgeDefinition::conditional(
            "a".to_string(),
            "router".to_string(),
            routes.clone(),
            Some("c".to_string()),
        );
        assert_eq!(conditional.edge_type, EdgeType::Conditional);
        assert_eq!(conditional.router_name.as_deref(), Some("router"));
        assert_eq!(conditional.conditional_targets, Some(routes));
        assert_eq!(conditional.default_target.as_deref(), Some("c"));
        assert!(conditional.target.is_none());

        let fan_out = EdgeDefinition::fan_out("a".to_string(), vec!["b".to_string(), "c".to_string()]);
        assert_eq!(fan_out.edge_type, EdgeType::FanOut);
        assert_eq!(fan_out.targets, Some(vec!["b".to_string(), "c".to_string()]));

        let fan_in = EdgeDefinition::fan_in(vec!["b".to_string(), "c".to_string()], "d".to_string());
        assert_eq!(fan_in.edge_type, EdgeType::FanIn);
        assert_eq!(fan_in.source, "__fan_in__");
        assert_eq!(fan_in.sources, Some(vec!["b".to_string(), "c".to_string()]));
        assert_eq!(fan_in.target.as_deref(), Some("d"));
    }

    #[test]
    fn test_graph_definition_json_roundtrip() {
        let mut def = GraphDefinition::new("a".to_string()).with_name("G".to_string());
        def.add_edge(EdgeDefinition::fan_out("a".to_string(), vec!["b".to_string()]));
        let json = serde_json::to_string(&def).unwrap();
        let parsed: GraphDefinition = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, def.id);
        assert_eq!(parsed.name, def.name);
        assert_eq!(parsed.entry_point, def.entry_point);
        assert_eq!(parsed.recursion_limit, def.recursion_limit);
        assert_eq!(parsed.edges.len(), 1);
        assert_eq!(parsed.edges[0].edge_type, EdgeType::FanOut);
    }

    #[tokio::test]
    async fn test_memory_persistence() {
        let persistence = MemoryPersistence::new();
        let def = GraphDefinition::new("entry".to_string())
            .with_id("test-001".to_string());
        persistence.save("test-001", &def).await.unwrap();
        assert!(persistence.exists("test-001").await.unwrap());
        let loaded = persistence.load("test-001").await.unwrap();
        assert_eq!(loaded.id, "test-001");
        persistence.delete("test-001").await.unwrap();
        assert!(!persistence.exists("test-001").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_persistence_overwrite_and_list() {
        let persistence = MemoryPersistence::new();
        let first = GraphDefinition::new("one".to_string());
        let second = GraphDefinition::new("two".to_string());
        persistence.save("g", &first).await.unwrap();
        persistence.save("g", &second).await.unwrap();
        assert_eq!(persistence.load("g").await.unwrap().entry_point, "two");
        assert_eq!(persistence.list().await.unwrap(), vec!["g".to_string()]);
    }

    #[tokio::test]
    async fn test_memory_persistence_missing_id() {
        let persistence = MemoryPersistence::new();
        assert!(persistence.list().await.unwrap().is_empty());
        assert!(!persistence.exists("missing").await.unwrap());
        assert!(is_not_found(&persistence.load("missing").await, "missing"));
        assert!(is_not_found(&persistence.delete("missing").await, "missing"));
    }

    #[tokio::test]
    async fn test_file_persistence_roundtrip() {
        let dir = unique_temp_dir("graph_persistence_test");
        let persistence = FilePersistence::new(dir.clone());
        assert!(dir.is_dir());

        let def = GraphDefinition::new("entry".to_string()).with_id("file-001".to_string());
        persistence.save("file-001", &def).await.unwrap();
        assert!(persistence.exists("file-001").await.unwrap());
        assert_eq!(persistence.list().await.unwrap(), vec!["file-001".to_string()]);

        let loaded = persistence.load("file-001").await.unwrap();
        assert_eq!(loaded.id, "file-001");
        assert_eq!(loaded.entry_point, "entry");

        let replacement = GraphDefinition::new("other".to_string()).with_id("file-001".to_string());
        persistence.save("file-001", &replacement).await.unwrap();
        assert_eq!(persistence.load("file-001").await.unwrap().entry_point, "other");
        assert_eq!(persistence.list().await.unwrap(), vec!["file-001".to_string()]);

        assert!(is_not_found(&persistence.load("missing").await, "missing"));

        persistence.delete("file-001").await.unwrap();
        assert!(!persistence.exists("file-001").await.unwrap());
        std::fs::remove_dir_all(&dir).unwrap();
    }
}