1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use crate::checkpoint::CheckpointSaver;
12use crate::interrupt::ResumeValue;
13use crate::observability::{
14 CachePolicy as LlmCachePolicy, GraphLifecycleCallback, MetricsCollector,
15};
16use crate::pregel::{BudgetConfig, BudgetTracker, Durability};
17use crate::runtime::Heartbeat;
18use crate::store::Store;
19
20#[derive(Clone, Default)]
22pub struct RunnableConfig {
23 pub thread_id: Option<String>,
25
26 pub checkpoint_id: Option<String>,
28
29 pub recursion_limit: usize,
31
32 pub max_parallel_tasks: usize,
34
35 pub run_name: Option<String>,
37
38 pub graph_name: Option<String>,
40
41 pub run_id: Option<String>,
48
49 pub checkpoint_ns: Option<crate::checkpoint::CheckpointNamespace>,
51
52 pub cache: Option<CacheConfig>,
54
55 pub tags: Vec<String>,
57
58 pub metadata: HashMap<String, serde_json::Value>,
60
61 pub cancellation_token: Option<tokio_util::sync::CancellationToken>,
63
64 pub budget: Option<BudgetConfig>,
66
67 pub durability: Option<Durability>,
69
70 #[allow(
72 clippy::type_complexity,
73 reason = "trait object callback requires full signature"
74 )]
75 pub node_finished_callback: Option<Arc<dyn Fn(&str) + Send + Sync>>,
76
77 pub resume_value: Option<ResumeValue>,
82
83 pub interrupt_before: Option<Vec<String>>,
85
86 pub interrupt_after: Option<Vec<String>>,
88
89 pub metrics_collector: Option<Arc<dyn MetricsCollector>>,
91
92 pub callback_handler: Option<Arc<dyn GraphLifecycleCallback>>,
98
99 pub llm_cache_policy: Option<LlmCachePolicy>,
101
102 pub heartbeat: Option<Heartbeat>,
108
109 pub budget_tracker: Option<Arc<BudgetTracker>>,
115
116 pub resource_limits: Option<ResourceLimits>,
118}
119
120impl std::fmt::Debug for RunnableConfig {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("RunnableConfig")
123 .field("thread_id", &self.thread_id)
124 .field("checkpoint_id", &self.checkpoint_id)
125 .field("recursion_limit", &self.recursion_limit)
126 .field("max_parallel_tasks", &self.max_parallel_tasks)
127 .field("run_name", &self.run_name)
128 .field("graph_name", &self.graph_name)
129 .field("run_id", &self.run_id)
130 .field("checkpoint_ns", &self.checkpoint_ns)
131 .field("cache", &self.cache)
132 .field("tags", &self.tags)
133 .field("metadata", &self.metadata)
134 .field(
135 "cancellation_token",
136 &self
137 .cancellation_token
138 .as_ref()
139 .map(|_| "CancellationToken"),
140 )
141 .field("budget", &self.budget)
142 .field("durability", &self.durability)
143 .field(
144 "node_finished_callback",
145 &self.node_finished_callback.as_ref().map(|_| "<fn>"),
146 )
147 .field("resume_value", &self.resume_value)
148 .field("interrupt_before", &self.interrupt_before)
149 .field("interrupt_after", &self.interrupt_after)
150 .field(
151 "metrics_collector",
152 &self
153 .metrics_collector
154 .as_ref()
155 .map(|_| "<MetricsCollector>"),
156 )
157 .field(
158 "callback_handler",
159 &self
160 .callback_handler
161 .as_ref()
162 .map(|_| "<GraphLifecycleCallback>"),
163 )
164 .field(
165 "llm_cache_policy",
166 &self.llm_cache_policy.as_ref().map(|_| "<CachePolicy>"),
167 )
168 .field("heartbeat", &self.heartbeat.as_ref().map(|_| "<Heartbeat>"))
169 .field(
170 "budget_tracker",
171 &self.budget_tracker.as_ref().map(|_| "<BudgetTracker>"),
172 )
173 .field("resource_limits", &self.resource_limits)
174 .finish()
175 }
176}
177
178impl RunnableConfig {
179 #[must_use]
181 pub fn new() -> Self {
182 Self {
183 recursion_limit: 25,
184 max_parallel_tasks: 100,
185 heartbeat: None,
186 ..Default::default()
187 }
188 }
189
190 #[must_use]
192 pub fn with_thread_id(mut self, id: impl Into<String>) -> Self {
193 self.thread_id = Some(id.into());
194 self
195 }
196
197 #[must_use]
199 pub fn with_checkpoint_id(mut self, id: impl Into<String>) -> Self {
200 self.checkpoint_id = Some(id.into());
201 self
202 }
203
204 #[must_use]
206 pub fn with_run_id(mut self, id: impl Into<String>) -> Self {
207 self.run_id = Some(id.into());
208 self
209 }
210
211 #[must_use]
213 pub const fn with_recursion_limit(mut self, limit: usize) -> Self {
214 self.recursion_limit = limit;
215 self
216 }
217
218 #[must_use]
220 pub const fn with_max_parallel_tasks(mut self, max: usize) -> Self {
221 self.max_parallel_tasks = max;
222 self
223 }
224
225 #[must_use]
227 pub fn with_run_name(mut self, name: impl Into<String>) -> Self {
228 self.run_name = Some(name.into());
229 self
230 }
231
232 #[must_use]
234 pub fn with_graph_name(mut self, name: impl Into<String>) -> Self {
235 self.graph_name = Some(name.into());
236 self
237 }
238
239 #[must_use]
241 pub fn with_checkpoint_ns(mut self, ns: crate::checkpoint::CheckpointNamespace) -> Self {
242 self.checkpoint_ns = Some(ns);
243 self
244 }
245
246 #[must_use]
248 pub fn with_cache(mut self, cache: CacheConfig) -> Self {
249 self.cache = Some(cache);
250 self
251 }
252
253 #[must_use]
255 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
256 self.tags.push(tag.into());
257 self
258 }
259
260 #[must_use]
262 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
263 self.metadata.insert(key.into(), value);
264 self
265 }
266
267 #[must_use]
269 pub fn with_cancellation_token(mut self, token: tokio_util::sync::CancellationToken) -> Self {
270 self.cancellation_token = Some(token);
271 self
272 }
273
274 #[must_use]
276 pub fn with_budget(mut self, budget: BudgetConfig) -> Self {
277 self.budget = Some(budget);
278 self
279 }
280
281 #[must_use]
292 pub fn with_interrupt_before(mut self, nodes: Vec<String>) -> Self {
293 self.interrupt_before = Some(nodes);
294 self
295 }
296
297 #[must_use]
308 pub fn with_interrupt_after(mut self, nodes: Vec<String>) -> Self {
309 self.interrupt_after = Some(nodes);
310 self
311 }
312
313 #[must_use]
327 pub fn with_metrics_collector(mut self, collector: Arc<dyn MetricsCollector>) -> Self {
328 self.metrics_collector = Some(collector);
329 self
330 }
331
332 #[must_use]
346 pub fn with_callback_handler(mut self, handler: Arc<dyn GraphLifecycleCallback>) -> Self {
347 self.callback_handler = Some(handler);
348 self
349 }
350
351 #[must_use]
353 pub fn with_llm_cache_policy(mut self, policy: LlmCachePolicy) -> Self {
354 self.llm_cache_policy = Some(policy);
355 self
356 }
357
358 #[must_use]
363 pub const fn budget_tracker(&self) -> Option<&Arc<BudgetTracker>> {
364 self.budget_tracker.as_ref()
365 }
366
367 #[must_use]
380 pub const fn with_resource_limits(mut self, limits: ResourceLimits) -> Self {
381 self.resource_limits = Some(limits);
382 self
383 }
384}
385
386#[derive(Clone, Debug)]
388pub struct CacheConfig {
389 pub policy: CachePolicy,
391}
392
393#[derive(Clone)]
398pub struct CachePolicy {
399 #[allow(
405 clippy::type_complexity,
406 reason = "trait object requires full signature"
407 )]
408 pub key_func: Option<Arc<dyn Fn(&serde_json::Value, &RunnableConfig) -> String + Send + Sync>>,
409
410 pub ttl: Option<Duration>,
412
413 pub max_entries: Option<usize>,
415}
416
417impl std::fmt::Debug for CachePolicy {
418 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419 f.debug_struct("CachePolicy")
420 .field("key_func", &self.key_func.as_ref().map(|_| "<fn>"))
421 .field("ttl", &self.ttl)
422 .field("max_entries", &self.max_entries)
423 .finish()
424 }
425}
426
427impl Default for CachePolicy {
428 fn default() -> Self {
429 Self::default_policy()
430 }
431}
432
433impl CachePolicy {
434 #[must_use]
436 pub fn default_policy() -> Self {
437 Self {
438 key_func: None,
439 ttl: None,
440 max_entries: None,
441 }
442 }
443
444 #[must_use]
448 pub fn ttl(duration: Duration) -> Self {
449 Self {
450 key_func: None,
451 ttl: Some(duration),
452 max_entries: None,
453 }
454 }
455
456 #[must_use]
461 pub fn custom_key(
462 key_func: impl Fn(&serde_json::Value, &RunnableConfig) -> String + Send + Sync + 'static,
463 ) -> Self {
464 Self {
465 key_func: Some(Arc::new(key_func)),
466 ttl: None,
467 max_entries: None,
468 }
469 }
470}
471
472#[derive(Clone, Debug, Default)]
478pub struct TaskConfig {
479 pub retry_policy: Option<crate::graph::RetryPolicy>,
481
482 pub cache_policy: Option<CachePolicy>,
484
485 pub timeout: Option<Duration>,
487
488 pub name: Option<String>,
490}
491
492#[derive(Clone, Default)]
498pub struct EntrypointConfig {
499 pub checkpointer: Option<Arc<dyn CheckpointSaver>>,
501
502 pub store: Option<Arc<dyn Store>>,
504}
505
506impl std::fmt::Debug for EntrypointConfig {
507 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
508 f.debug_struct("EntrypointConfig")
509 .field(
510 "checkpointer",
511 &self.checkpointer.as_ref().map(|_| "<CheckpointSaver>"),
512 )
513 .field("store", &self.store.as_ref().map(|_| "<Store>"))
514 .finish()
515 }
516}
517
/// Resource limits for a run, attached via `RunnableConfig::with_resource_limits`.
///
/// Every limit is optional; `None` means the limit is not set.
///
/// `Debug` is derived: the previous hand-written impl printed exactly what the
/// derive produces (`ResourceLimits { max_state_size_bytes: ... }`).
#[derive(Clone, Debug, Default)]
pub struct ResourceLimits {
    /// Maximum state size in bytes, if set.
    pub max_state_size_bytes: Option<usize>,
}

impl ResourceLimits {
    /// Creates a value with no limits set (identical to `Default::default()`).
    ///
    /// `const` so limits can be built in const contexts, matching
    /// [`ResourceLimits::with_max_state_size_bytes`].
    #[must_use]
    pub const fn new() -> Self {
        Self {
            max_state_size_bytes: None,
        }
    }

    /// Sets the maximum state size in bytes, replacing any previous value.
    #[must_use]
    pub const fn with_max_state_size_bytes(mut self, max: usize) -> Self {
        self.max_state_size_bytes = Some(max);
        self
    }
}
559
/// Unit tests for the config types: constructors, builders, cache policies and
/// the hand-written `Debug` impls that redact opaque fields.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_runnable_config_new() {
        let config = RunnableConfig::new();
        assert_eq!(config.recursion_limit, 25);
        assert_eq!(config.max_parallel_tasks, 100);
        assert!(config.thread_id.is_none());
        assert!(config.checkpoint_id.is_none());
        assert!(config.cancellation_token.is_none());
        assert!(config.budget.is_none());
        assert!(config.durability.is_none());
        assert!(config.resume_value.is_none());
        assert!(config.heartbeat.is_none());
    }

    #[test]
    fn test_runnable_config_with_cancellation_token() {
        let token = tokio_util::sync::CancellationToken::new();
        let config = RunnableConfig::new().with_cancellation_token(token);
        assert!(config.cancellation_token.is_some());
    }

    #[test]
    fn test_runnable_config_with_budget() {
        let budget = BudgetConfig::new().with_max_tokens(1000);
        let config = RunnableConfig::new().with_budget(budget);
        assert!(config.budget.is_some());
        assert_eq!(config.budget.as_ref().unwrap().max_tokens, Some(1000));
    }

    #[test]
    fn test_runnable_config_builders_chain() {
        let config = RunnableConfig::new()
            .with_checkpoint_id("cp-1")
            .with_run_id("run-1")
            .with_graph_name("graph")
            .with_recursion_limit(7)
            .with_max_parallel_tasks(3)
            .with_tag("a")
            .with_tag("b")
            .with_metadata("k", serde_json::json!(1))
            .with_metadata("k", serde_json::json!(2))
            .with_interrupt_before(vec!["before".to_string()])
            .with_interrupt_after(vec!["after".to_string()]);
        assert_eq!(config.checkpoint_id.as_deref(), Some("cp-1"));
        assert_eq!(config.run_id.as_deref(), Some("run-1"));
        assert_eq!(config.graph_name.as_deref(), Some("graph"));
        assert_eq!(config.recursion_limit, 7);
        assert_eq!(config.max_parallel_tasks, 3);
        // Tags accumulate in insertion order.
        assert_eq!(config.tags, vec!["a".to_string(), "b".to_string()]);
        // Re-inserting a metadata key overwrites the earlier value.
        assert_eq!(config.metadata.len(), 1);
        assert_eq!(config.metadata.get("k"), Some(&serde_json::json!(2)));
        assert_eq!(config.interrupt_before, Some(vec!["before".to_string()]));
        assert_eq!(config.interrupt_after, Some(vec!["after".to_string()]));
    }

    #[test]
    fn test_runnable_config_with_resource_limits() {
        let config = RunnableConfig::new()
            .with_resource_limits(ResourceLimits::new().with_max_state_size_bytes(1024));
        let limits = config
            .resource_limits
            .as_ref()
            .expect("resource limits were just set");
        assert_eq!(limits.max_state_size_bytes, Some(1024));
        assert!(ResourceLimits::new().max_state_size_bytes.is_none());
    }

    #[test]
    fn test_runnable_config_budget_tracker_getter_defaults_to_none() {
        assert!(RunnableConfig::new().budget_tracker().is_none());
    }

    #[test]
    fn test_cache_policy_default() {
        let policy = CachePolicy::default_policy();
        assert!(policy.key_func.is_none());
        assert!(policy.ttl.is_none());
        assert!(policy.max_entries.is_none());
    }

    #[test]
    fn test_cache_policy_ttl() {
        let policy = CachePolicy::ttl(Duration::from_secs(60));
        assert!(policy.key_func.is_none());
        assert_eq!(policy.ttl, Some(Duration::from_secs(60)));
        assert!(policy.max_entries.is_none());
    }

    #[test]
    fn test_cache_policy_custom_key() {
        let policy =
            CachePolicy::custom_key(|val, _cfg| format!("key-{}", val.as_str().unwrap_or("")));
        assert!(policy.key_func.is_some());
        assert!(policy.ttl.is_none());
        assert!(policy.max_entries.is_none());

        let config = RunnableConfig::new();
        let key = (policy.key_func.as_ref().unwrap())(&serde_json::json!("test"), &config);
        assert_eq!(key, "key-test");
    }

    #[test]
    fn test_cache_policy_default_trait() {
        let policy = CachePolicy::default();
        assert!(policy.key_func.is_none());
        assert!(policy.ttl.is_none());
        assert!(policy.max_entries.is_none());
    }

    #[test]
    fn test_cache_policy_debug() {
        let policy = CachePolicy::ttl(Duration::from_secs(30));
        let debug_str = format!("{policy:?}");
        assert!(debug_str.contains("ttl"));
        assert!(debug_str.contains("30s"));
    }

    #[test]
    fn test_cache_policy_debug_redacts_key_func() {
        let policy = CachePolicy::custom_key(|_, _| String::new());
        assert!(format!("{policy:?}").contains(r#"key_func: Some("<fn>")"#));
    }

    #[test]
    fn test_task_config_default() {
        let config = TaskConfig::default();
        assert!(config.retry_policy.is_none());
        assert!(config.cache_policy.is_none());
        assert!(config.timeout.is_none());
        assert!(config.name.is_none());
    }

    #[test]
    fn test_entrypoint_config_default() {
        let config = EntrypointConfig::default();
        assert!(config.checkpointer.is_none());
        assert!(config.store.is_none());
    }

    #[test]
    fn test_entrypoint_config_debug_format() {
        assert_eq!(
            format!("{:?}", EntrypointConfig::default()),
            "EntrypointConfig { checkpointer: None, store: None }"
        );
    }

    #[test]
    fn test_resource_limits_debug_format() {
        let limits = ResourceLimits::new().with_max_state_size_bytes(8);
        assert_eq!(
            format!("{limits:?}"),
            "ResourceLimits { max_state_size_bytes: Some(8) }"
        );
    }

    #[test]
    fn test_runnable_config_debug_format() {
        let config = RunnableConfig::new()
            .with_thread_id("t1")
            .with_run_name("test-run");
        let debug = format!("{config:?}");
        assert!(debug.contains("t1"));
        assert!(debug.contains("test-run"));
    }

    #[test]
    fn test_runnable_config_debug_redacts_opaque_fields() {
        let mut config = RunnableConfig::new()
            .with_cancellation_token(tokio_util::sync::CancellationToken::new());
        config.node_finished_callback = Some(Arc::new(|_: &str| {}));
        let debug = format!("{config:?}");
        assert!(debug.contains(r#"cancellation_token: Some("CancellationToken")"#));
        assert!(debug.contains(r#"node_finished_callback: Some("<fn>")"#));
        assert!(debug.contains("metrics_collector: None"));
    }
}
665
666