use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use crate::checkpoint::CheckpointSaver;
use crate::interrupt::ResumeValue;
use crate::observability::{
CachePolicy as LlmCachePolicy, GraphLifecycleCallback, MetricsCollector,
};
use crate::pregel::{BudgetConfig, BudgetTracker, Durability};
use crate::runtime::Heartbeat;
use crate::store::Store;
/// Configuration carried through a runnable/graph invocation.
///
/// Groups run identity (thread, checkpoint, run ids and names), execution
/// limits, caching, cancellation, budgeting, durability, interrupt points,
/// observability hooks and resource limits. Typically built with
/// [`RunnableConfig::new`] followed by the `with_*` builder methods.
///
/// NOTE(review): the derived `Default` leaves `recursion_limit` and
/// `max_parallel_tasks` at `0`, whereas [`RunnableConfig::new`] sets them to
/// `25` and `100`. Prefer `new()` unless zero is intended — TODO confirm how
/// the executor treats a zero limit.
#[derive(Clone, Default)]
pub struct RunnableConfig {
    /// Optional thread identifier (see `with_thread_id`).
    pub thread_id: Option<String>,
    /// Optional checkpoint identifier (see `with_checkpoint_id`).
    pub checkpoint_id: Option<String>,
    /// Recursion limit; `new()` sets 25, the derived `Default` gives 0.
    pub recursion_limit: usize,
    /// Maximum number of parallel tasks; `new()` sets 100, the derived `Default` gives 0.
    pub max_parallel_tasks: usize,
    /// Optional run name (see `with_run_name`).
    pub run_name: Option<String>,
    /// Optional graph name (see `with_graph_name`).
    pub graph_name: Option<String>,
    /// Optional run identifier (see `with_run_id`).
    pub run_id: Option<String>,
    /// Optional checkpoint namespace (see `with_checkpoint_ns`).
    pub checkpoint_ns: Option<crate::checkpoint::CheckpointNamespace>,
    /// Optional cache configuration (see `with_cache`).
    pub cache: Option<CacheConfig>,
    /// Tags; `with_tag` appends, so tags accumulate across calls.
    pub tags: Vec<String>,
    /// Arbitrary JSON metadata; `with_metadata` inserts/overwrites by key.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Optional cancellation token (see `with_cancellation_token`).
    pub cancellation_token: Option<tokio_util::sync::CancellationToken>,
    /// Optional budget configuration (see `with_budget`).
    pub budget: Option<BudgetConfig>,
    /// Optional durability mode.
    pub durability: Option<Durability>,
    /// Optional callback taking a `&str`; presumably invoked with a node name
    /// when that node finishes — TODO confirm against the executor.
    #[allow(
        clippy::type_complexity,
        reason = "trait object callback requires full signature"
    )]
    pub node_finished_callback: Option<Arc<dyn Fn(&str) + Send + Sync>>,
    /// Optional value used when resuming after an interrupt.
    pub resume_value: Option<ResumeValue>,
    /// Optional list of node names to interrupt before (see `with_interrupt_before`).
    pub interrupt_before: Option<Vec<String>>,
    /// Optional list of node names to interrupt after (see `with_interrupt_after`).
    pub interrupt_after: Option<Vec<String>>,
    /// Optional metrics collector (see `with_metrics_collector`).
    pub metrics_collector: Option<Arc<dyn MetricsCollector>>,
    /// Optional lifecycle callback handler (see `with_callback_handler`).
    pub callback_handler: Option<Arc<dyn GraphLifecycleCallback>>,
    /// Optional LLM cache policy (see `with_llm_cache_policy`).
    pub llm_cache_policy: Option<LlmCachePolicy>,
    /// Optional heartbeat handle; `None` after `new()`.
    pub heartbeat: Option<Heartbeat>,
    /// Optional shared budget tracker (read via `budget_tracker()`).
    pub budget_tracker: Option<Arc<BudgetTracker>>,
    /// Optional resource limits (see `with_resource_limits`).
    pub resource_limits: Option<ResourceLimits>,
}
// Diagnostic formatting for `RunnableConfig`.
//
// Plain-data fields print their own `Debug` output. Tokens, callbacks, trait
// objects and handles are reduced to `Some("<label>")` / `None`, so only their
// presence is visible.
impl std::fmt::Debug for RunnableConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Replaces an optional opaque value with a fixed placeholder label.
        fn presence<T>(value: Option<&T>, label: &'static str) -> Option<&'static str> {
            value.map(|_| label)
        }

        let mut out = f.debug_struct("RunnableConfig");

        // Identity, limits, caching and free-form annotations.
        out.field("thread_id", &self.thread_id)
            .field("checkpoint_id", &self.checkpoint_id)
            .field("recursion_limit", &self.recursion_limit)
            .field("max_parallel_tasks", &self.max_parallel_tasks)
            .field("run_name", &self.run_name)
            .field("graph_name", &self.graph_name)
            .field("run_id", &self.run_id)
            .field("checkpoint_ns", &self.checkpoint_ns)
            .field("cache", &self.cache)
            .field("tags", &self.tags)
            .field("metadata", &self.metadata);

        // Execution control: cancellation, budgets, durability, interrupts.
        out.field(
            "cancellation_token",
            &presence(self.cancellation_token.as_ref(), "CancellationToken"),
        )
        .field("budget", &self.budget)
        .field("durability", &self.durability)
        .field(
            "node_finished_callback",
            &presence(self.node_finished_callback.as_ref(), "<fn>"),
        )
        .field("resume_value", &self.resume_value)
        .field("interrupt_before", &self.interrupt_before)
        .field("interrupt_after", &self.interrupt_after);

        // Observability hooks, trackers and limits.
        out.field(
            "metrics_collector",
            &presence(self.metrics_collector.as_ref(), "<MetricsCollector>"),
        )
        .field(
            "callback_handler",
            &presence(self.callback_handler.as_ref(), "<GraphLifecycleCallback>"),
        )
        .field(
            "llm_cache_policy",
            &presence(self.llm_cache_policy.as_ref(), "<CachePolicy>"),
        )
        .field("heartbeat", &presence(self.heartbeat.as_ref(), "<Heartbeat>"))
        .field(
            "budget_tracker",
            &presence(self.budget_tracker.as_ref(), "<BudgetTracker>"),
        )
        .field("resource_limits", &self.resource_limits);

        out.finish()
    }
}
impl RunnableConfig {
    /// Creates a configuration with a recursion limit of 25 and at most 100
    /// parallel tasks; every other field takes its `Default` value.
    #[must_use]
    pub fn new() -> Self {
        Self {
            recursion_limit: 25,
            max_parallel_tasks: 100,
            ..Self::default()
        }
    }

    /// Returns the attached budget tracker, if one is set.
    #[must_use]
    pub const fn budget_tracker(&self) -> Option<&Arc<BudgetTracker>> {
        self.budget_tracker.as_ref()
    }

    /// Sets `thread_id`.
    #[must_use]
    pub fn with_thread_id(self, id: impl Into<String>) -> Self {
        Self {
            thread_id: Some(id.into()),
            ..self
        }
    }

    /// Sets `checkpoint_id`.
    #[must_use]
    pub fn with_checkpoint_id(self, id: impl Into<String>) -> Self {
        Self {
            checkpoint_id: Some(id.into()),
            ..self
        }
    }

    /// Sets `run_id`.
    #[must_use]
    pub fn with_run_id(self, id: impl Into<String>) -> Self {
        Self {
            run_id: Some(id.into()),
            ..self
        }
    }

    /// Overrides `recursion_limit`.
    #[must_use]
    pub const fn with_recursion_limit(mut self, limit: usize) -> Self {
        // Mutate-and-return keeps this usable in `const` contexts.
        self.recursion_limit = limit;
        self
    }

    /// Overrides `max_parallel_tasks`.
    #[must_use]
    pub const fn with_max_parallel_tasks(mut self, max: usize) -> Self {
        self.max_parallel_tasks = max;
        self
    }

    /// Sets `run_name`.
    #[must_use]
    pub fn with_run_name(self, name: impl Into<String>) -> Self {
        Self {
            run_name: Some(name.into()),
            ..self
        }
    }

    /// Sets `graph_name`.
    #[must_use]
    pub fn with_graph_name(self, name: impl Into<String>) -> Self {
        Self {
            graph_name: Some(name.into()),
            ..self
        }
    }

    /// Sets `checkpoint_ns`.
    #[must_use]
    pub fn with_checkpoint_ns(self, ns: crate::checkpoint::CheckpointNamespace) -> Self {
        Self {
            checkpoint_ns: Some(ns),
            ..self
        }
    }

    /// Sets `cache`.
    #[must_use]
    pub fn with_cache(self, cache: CacheConfig) -> Self {
        Self {
            cache: Some(cache),
            ..self
        }
    }

    /// Appends one tag; repeated calls accumulate tags in call order.
    #[must_use]
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
        let tag: String = tag.into();
        self.tags.push(tag);
        self
    }

    /// Inserts `key` → `value` into `metadata`, replacing any earlier value for `key`.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key: String = key.into();
        self.metadata.insert(key, value);
        self
    }

    /// Sets `cancellation_token`.
    #[must_use]
    pub fn with_cancellation_token(self, token: tokio_util::sync::CancellationToken) -> Self {
        Self {
            cancellation_token: Some(token),
            ..self
        }
    }

    /// Sets `budget`.
    #[must_use]
    pub fn with_budget(self, budget: BudgetConfig) -> Self {
        Self {
            budget: Some(budget),
            ..self
        }
    }

    /// Replaces `interrupt_before` with `nodes`.
    #[must_use]
    pub fn with_interrupt_before(self, nodes: Vec<String>) -> Self {
        Self {
            interrupt_before: Some(nodes),
            ..self
        }
    }

    /// Replaces `interrupt_after` with `nodes`.
    #[must_use]
    pub fn with_interrupt_after(self, nodes: Vec<String>) -> Self {
        Self {
            interrupt_after: Some(nodes),
            ..self
        }
    }

    /// Sets `metrics_collector`.
    #[must_use]
    pub fn with_metrics_collector(self, collector: Arc<dyn MetricsCollector>) -> Self {
        Self {
            metrics_collector: Some(collector),
            ..self
        }
    }

    /// Sets `callback_handler`.
    #[must_use]
    pub fn with_callback_handler(self, handler: Arc<dyn GraphLifecycleCallback>) -> Self {
        Self {
            callback_handler: Some(handler),
            ..self
        }
    }

    /// Sets `llm_cache_policy`.
    #[must_use]
    pub fn with_llm_cache_policy(self, policy: LlmCachePolicy) -> Self {
        Self {
            llm_cache_policy: Some(policy),
            ..self
        }
    }

    /// Sets `resource_limits`.
    #[must_use]
    pub const fn with_resource_limits(mut self, limits: ResourceLimits) -> Self {
        self.resource_limits = Some(limits);
        self
    }
}
/// Cache settings attached to a [`RunnableConfig`] via `with_cache`.
///
/// Currently a thin wrapper around a single [`CachePolicy`].
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// The policy that governs caching.
    pub policy: CachePolicy,
}
/// Caching policy: optional custom key derivation, time-to-live and entry cap.
///
/// Construct with [`CachePolicy::default_policy`], [`CachePolicy::ttl`] or
/// [`CachePolicy::custom_key`]. Nothing in this file enforces `ttl` or
/// `max_entries`; the cache implementation that reads this policy does.
#[derive(Clone)]
pub struct CachePolicy {
    /// Optional function computing a cache key string from an input JSON value
    /// and the active [`RunnableConfig`]. When `None`, presumably a default key
    /// scheme is used — TODO confirm against the cache implementation.
    #[allow(
        clippy::type_complexity,
        reason = "trait object requires full signature"
    )]
    pub key_func: Option<Arc<dyn Fn(&serde_json::Value, &RunnableConfig) -> String + Send + Sync>>,
    /// Optional time-to-live for cached entries.
    pub ttl: Option<Duration>,
    /// Optional maximum number of cached entries.
    pub max_entries: Option<usize>,
}
// `key_func` is an opaque closure, so `Debug` only reports whether one is set
// (as `Some("<fn>")` / `None`); `ttl` and `max_entries` print as-is.
impl std::fmt::Debug for CachePolicy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let key_func = self.key_func.is_some().then_some("<fn>");
        let mut builder = f.debug_struct("CachePolicy");
        builder.field("key_func", &key_func);
        builder.field("ttl", &self.ttl);
        builder.field("max_entries", &self.max_entries);
        builder.finish()
    }
}
/// The default policy is [`CachePolicy::default_policy`].
impl Default for CachePolicy {
    fn default() -> Self {
        Self::default_policy()
    }
}
impl CachePolicy {
    /// Returns a policy with no custom key function, no TTL and no entry cap.
    ///
    /// `const` so policies can be built in constant contexts, matching the
    /// file's other trivially-constructible builders.
    #[must_use]
    pub const fn default_policy() -> Self {
        Self {
            key_func: None,
            ttl: None,
            max_entries: None,
        }
    }

    /// Returns a policy whose only setting is a time-to-live of `duration`.
    #[must_use]
    pub const fn ttl(duration: Duration) -> Self {
        Self {
            key_func: None,
            ttl: Some(duration),
            max_entries: None,
        }
    }

    /// Returns a policy that derives cache keys with `key_func`.
    ///
    /// `key_func` receives the input JSON value and the active
    /// [`RunnableConfig`] and returns the key string. It cannot be `const`
    /// because it allocates an `Arc`.
    #[must_use]
    pub fn custom_key(
        key_func: impl Fn(&serde_json::Value, &RunnableConfig) -> String + Send + Sync + 'static,
    ) -> Self {
        Self {
            key_func: Some(Arc::new(key_func)),
            ttl: None,
            max_entries: None,
        }
    }
}
/// Per-task settings: optional retry policy, cache policy, timeout and name.
///
/// Every field defaults to `None`.
#[derive(Clone, Debug, Default)]
pub struct TaskConfig {
    /// Optional retry policy for the task.
    pub retry_policy: Option<crate::graph::RetryPolicy>,
    /// Optional cache policy for the task.
    pub cache_policy: Option<CachePolicy>,
    /// Optional timeout; presumably bounds a single task execution — TODO confirm.
    pub timeout: Option<Duration>,
    /// Optional task name.
    pub name: Option<String>,
}
/// Persistence handles for an entrypoint: an optional checkpoint saver and an
/// optional store, both `None` by default.
///
/// Both are shared trait objects (`Arc<dyn ...>`), so cloning the config
/// shares the same underlying saver/store.
#[derive(Clone, Default)]
pub struct EntrypointConfig {
    /// Optional checkpoint saver.
    pub checkpointer: Option<Arc<dyn CheckpointSaver>>,
    /// Optional store.
    pub store: Option<Arc<dyn Store>>,
}
// The saver/store are trait objects, so `Debug` shows placeholder labels
// (`Some("<CheckpointSaver>")`, `Some("<Store>")`) or `None`.
impl std::fmt::Debug for EntrypointConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let checkpointer = self.checkpointer.is_some().then_some("<CheckpointSaver>");
        let store = self.store.is_some().then_some("<Store>");
        f.debug_struct("EntrypointConfig")
            .field("checkpointer", &checkpointer)
            .field("store", &store)
            .finish()
    }
}
/// Optional resource limits, attached to a run with
/// `RunnableConfig::with_resource_limits`.
///
/// Each limit is optional; `None` means no limit is set here. Enforcement
/// happens outside this file.
// `Debug` is derived: the previous hand-written impl produced exactly the
// derived output, so the derive removes duplicated code without changing it.
#[derive(Clone, Debug, Default)]
pub struct ResourceLimits {
    /// Optional maximum state size in bytes.
    pub max_state_size_bytes: Option<usize>,
}
impl ResourceLimits {
    /// Creates limits with nothing set; equivalent to `ResourceLimits::default()`.
    ///
    /// `const` so limits can be built in constant contexts, matching
    /// [`ResourceLimits::with_max_state_size_bytes`].
    #[must_use]
    pub const fn new() -> Self {
        Self {
            max_state_size_bytes: None,
        }
    }

    /// Sets the maximum state size in bytes.
    #[must_use]
    pub const fn with_max_state_size_bytes(mut self, max: usize) -> Self {
        self.max_state_size_bytes = Some(max);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_runnable_config_new() {
        let config = RunnableConfig::new();
        assert_eq!(config.recursion_limit, 25);
        assert_eq!(config.max_parallel_tasks, 100);
        assert!(config.thread_id.is_none());
        assert!(config.checkpoint_id.is_none());
        assert!(config.cancellation_token.is_none());
        assert!(config.budget.is_none());
        assert!(config.durability.is_none());
        assert!(config.resume_value.is_none());
        assert!(config.heartbeat.is_none());
    }
    #[test]
    fn test_runnable_config_with_cancellation_token() {
        let token = tokio_util::sync::CancellationToken::new();
        let config = RunnableConfig::new().with_cancellation_token(token);
        assert!(config.cancellation_token.is_some());
    }
    #[test]
    fn test_runnable_config_with_budget() {
        let budget = BudgetConfig::new().with_max_tokens(1000);
        let config = RunnableConfig::new().with_budget(budget);
        assert!(config.budget.is_some());
        assert_eq!(config.budget.as_ref().unwrap().max_tokens, Some(1000));
    }
    #[test]
    fn test_cache_policy_default() {
        let policy = CachePolicy::default_policy();
        assert!(policy.key_func.is_none());
        assert!(policy.ttl.is_none());
        assert!(policy.max_entries.is_none());
    }
    #[test]
    fn test_cache_policy_ttl() {
        let policy = CachePolicy::ttl(Duration::from_secs(60));
        assert!(policy.key_func.is_none());
        assert_eq!(policy.ttl, Some(Duration::from_secs(60)));
        assert!(policy.max_entries.is_none());
    }
    #[test]
    fn test_cache_policy_custom_key() {
        let policy =
            CachePolicy::custom_key(|val, _cfg| format!("key-{}", val.as_str().unwrap_or("")));
        assert!(policy.key_func.is_some());
        assert!(policy.ttl.is_none());
        assert!(policy.max_entries.is_none());
        let config = RunnableConfig::new();
        let key = (policy.key_func.as_ref().unwrap())(&serde_json::json!("test"), &config);
        assert_eq!(key, "key-test");
    }
    #[test]
    fn test_cache_policy_default_trait() {
        let policy = CachePolicy::default();
        assert!(policy.key_func.is_none());
        assert!(policy.ttl.is_none());
        assert!(policy.max_entries.is_none());
    }
    #[test]
    fn test_cache_policy_debug() {
        let policy = CachePolicy::ttl(Duration::from_secs(30));
        let debug_str = format!("{policy:?}");
        assert!(debug_str.contains("ttl"));
        assert!(debug_str.contains("30s"));
    }
    #[test]
    fn test_task_config_default() {
        let config = TaskConfig::default();
        assert!(config.retry_policy.is_none());
        assert!(config.cache_policy.is_none());
        assert!(config.timeout.is_none());
        assert!(config.name.is_none());
    }
    #[test]
    fn test_entrypoint_config_default() {
        let config = EntrypointConfig::default();
        assert!(config.checkpointer.is_none());
        assert!(config.store.is_none());
    }
    #[test]
    fn test_runnable_config_debug_format() {
        let config = RunnableConfig::new()
            .with_thread_id("t1")
            .with_run_name("test-run");
        let debug = format!("{config:?}");
        assert!(debug.contains("t1"));
        assert!(debug.contains("test-run"));
    }
    // Each simple builder stores its argument in the matching field.
    #[test]
    fn test_runnable_config_builders_set_fields() {
        let config = RunnableConfig::new()
            .with_thread_id("thread")
            .with_checkpoint_id("ckpt")
            .with_run_id("run")
            .with_graph_name("graph")
            .with_recursion_limit(5)
            .with_max_parallel_tasks(3)
            .with_interrupt_before(vec!["a".to_string()])
            .with_interrupt_after(vec!["b".to_string()]);
        assert_eq!(config.thread_id.as_deref(), Some("thread"));
        assert_eq!(config.checkpoint_id.as_deref(), Some("ckpt"));
        assert_eq!(config.run_id.as_deref(), Some("run"));
        assert_eq!(config.graph_name.as_deref(), Some("graph"));
        assert_eq!(config.recursion_limit, 5);
        assert_eq!(config.max_parallel_tasks, 3);
        assert_eq!(config.interrupt_before, Some(vec!["a".to_string()]));
        assert_eq!(config.interrupt_after, Some(vec!["b".to_string()]));
    }
    // Tags append in order; metadata overwrites on a repeated key.
    #[test]
    fn test_runnable_config_tags_and_metadata_accumulate() {
        let config = RunnableConfig::new()
            .with_tag("first")
            .with_tag("second")
            .with_metadata("k", serde_json::json!(1))
            .with_metadata("k", serde_json::json!(2))
            .with_metadata("other", serde_json::json!("x"));
        assert_eq!(config.tags, vec!["first".to_string(), "second".to_string()]);
        assert_eq!(config.metadata.len(), 2);
        assert_eq!(config.metadata.get("k"), Some(&serde_json::json!(2)));
    }
    #[test]
    fn test_runnable_config_with_cache() {
        let config = RunnableConfig::new().with_cache(CacheConfig {
            policy: CachePolicy::ttl(Duration::from_secs(5)),
        });
        let ttl = config.cache.as_ref().and_then(|c| c.policy.ttl);
        assert_eq!(ttl, Some(Duration::from_secs(5)));
    }
    #[test]
    fn test_resource_limits_builder() {
        let limits = ResourceLimits::new();
        assert!(limits.max_state_size_bytes.is_none());
        let config = RunnableConfig::new()
            .with_resource_limits(ResourceLimits::new().with_max_state_size_bytes(1024));
        assert_eq!(
            config
                .resource_limits
                .as_ref()
                .and_then(|l| l.max_state_size_bytes),
            Some(1024)
        );
    }
    // Trait-object fields must render as placeholders, not their internals.
    #[test]
    fn test_entrypoint_config_debug_hides_trait_objects() {
        let debug = format!("{:?}", EntrypointConfig::default());
        assert!(debug.contains("checkpointer: None"));
        assert!(debug.contains("store: None"));
    }
    #[test]
    fn test_runnable_config_debug_redacts_cancellation_token() {
        let config = RunnableConfig::new()
            .with_cancellation_token(tokio_util::sync::CancellationToken::new());
        let debug = format!("{config:?}");
        assert!(debug.contains("cancellation_token: Some(\"CancellationToken\")"));
    }
}