1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use crate::checkpoint::CheckpointSaver;
12use crate::interrupt::ResumeValue;
13use crate::observability::{
14 CachePolicy as LlmCachePolicy, GraphLifecycleCallback, MetricsCollector,
15};
16use crate::pregel::{BudgetConfig, BudgetTracker, Durability};
17use crate::runtime::Heartbeat;
18use crate::store::Store;
19
20#[derive(Clone, Default)]
22pub struct RunnableConfig {
23 pub thread_id: Option<String>,
25
26 pub checkpoint_id: Option<String>,
28
29 pub recursion_limit: usize,
31
32 pub max_parallel_tasks: usize,
34
35 pub run_name: Option<String>,
37
38 pub graph_name: Option<String>,
40
41 pub run_id: Option<String>,
48
49 pub checkpoint_ns: Option<crate::checkpoint::CheckpointNamespace>,
51
52 pub cache: Option<CacheConfig>,
54
55 pub tags: Vec<String>,
57
58 pub metadata: HashMap<String, serde_json::Value>,
60
61 pub cancellation_token: Option<tokio_util::sync::CancellationToken>,
63
64 pub budget: Option<BudgetConfig>,
66
67 pub durability: Option<Durability>,
69
70 #[allow(
72 clippy::type_complexity,
73 reason = "trait object callback requires full signature"
74 )]
75 pub node_finished_callback: Option<Arc<dyn Fn(&str) + Send + Sync>>,
76
77 pub resume_value: Option<ResumeValue>,
82
83 pub interrupt_before: Option<Vec<String>>,
85
86 pub interrupt_after: Option<Vec<String>>,
88
89 pub metrics_collector: Option<Arc<dyn MetricsCollector>>,
91
92 pub callback_handler: Option<Arc<dyn GraphLifecycleCallback>>,
98
99 pub llm_cache_policy: Option<LlmCachePolicy>,
101
102 pub heartbeat: Option<Heartbeat>,
108
109 pub budget_tracker: Option<Arc<BudgetTracker>>,
115}
116
117impl std::fmt::Debug for RunnableConfig {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 f.debug_struct("RunnableConfig")
120 .field("thread_id", &self.thread_id)
121 .field("checkpoint_id", &self.checkpoint_id)
122 .field("recursion_limit", &self.recursion_limit)
123 .field("max_parallel_tasks", &self.max_parallel_tasks)
124 .field("run_name", &self.run_name)
125 .field("graph_name", &self.graph_name)
126 .field("run_id", &self.run_id)
127 .field("checkpoint_ns", &self.checkpoint_ns)
128 .field("cache", &self.cache)
129 .field("tags", &self.tags)
130 .field("metadata", &self.metadata)
131 .field(
132 "cancellation_token",
133 &self
134 .cancellation_token
135 .as_ref()
136 .map(|_| "CancellationToken"),
137 )
138 .field("budget", &self.budget)
139 .field("durability", &self.durability)
140 .field(
141 "node_finished_callback",
142 &self.node_finished_callback.as_ref().map(|_| "<fn>"),
143 )
144 .field("resume_value", &self.resume_value)
145 .field("interrupt_before", &self.interrupt_before)
146 .field("interrupt_after", &self.interrupt_after)
147 .field(
148 "metrics_collector",
149 &self
150 .metrics_collector
151 .as_ref()
152 .map(|_| "<MetricsCollector>"),
153 )
154 .field(
155 "callback_handler",
156 &self
157 .callback_handler
158 .as_ref()
159 .map(|_| "<GraphLifecycleCallback>"),
160 )
161 .field(
162 "llm_cache_policy",
163 &self.llm_cache_policy.as_ref().map(|_| "<CachePolicy>"),
164 )
165 .field("heartbeat", &self.heartbeat.as_ref().map(|_| "<Heartbeat>"))
166 .field(
167 "budget_tracker",
168 &self.budget_tracker.as_ref().map(|_| "<BudgetTracker>"),
169 )
170 .finish()
171 }
172}
173
174impl RunnableConfig {
175 #[must_use]
177 pub fn new() -> Self {
178 Self {
179 recursion_limit: 25,
180 max_parallel_tasks: 100,
181 heartbeat: None,
182 ..Default::default()
183 }
184 }
185
186 #[must_use]
188 pub fn with_thread_id(mut self, id: impl Into<String>) -> Self {
189 self.thread_id = Some(id.into());
190 self
191 }
192
193 #[must_use]
195 pub fn with_checkpoint_id(mut self, id: impl Into<String>) -> Self {
196 self.checkpoint_id = Some(id.into());
197 self
198 }
199
200 #[must_use]
202 pub fn with_run_id(mut self, id: impl Into<String>) -> Self {
203 self.run_id = Some(id.into());
204 self
205 }
206
207 #[must_use]
209 pub const fn with_recursion_limit(mut self, limit: usize) -> Self {
210 self.recursion_limit = limit;
211 self
212 }
213
214 #[must_use]
216 pub const fn with_max_parallel_tasks(mut self, max: usize) -> Self {
217 self.max_parallel_tasks = max;
218 self
219 }
220
221 #[must_use]
223 pub fn with_run_name(mut self, name: impl Into<String>) -> Self {
224 self.run_name = Some(name.into());
225 self
226 }
227
228 #[must_use]
230 pub fn with_graph_name(mut self, name: impl Into<String>) -> Self {
231 self.graph_name = Some(name.into());
232 self
233 }
234
235 #[must_use]
237 pub fn with_checkpoint_ns(mut self, ns: crate::checkpoint::CheckpointNamespace) -> Self {
238 self.checkpoint_ns = Some(ns);
239 self
240 }
241
242 #[must_use]
244 pub fn with_cache(mut self, cache: CacheConfig) -> Self {
245 self.cache = Some(cache);
246 self
247 }
248
249 #[must_use]
251 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
252 self.tags.push(tag.into());
253 self
254 }
255
256 #[must_use]
258 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
259 self.metadata.insert(key.into(), value);
260 self
261 }
262
263 #[must_use]
265 pub fn with_cancellation_token(mut self, token: tokio_util::sync::CancellationToken) -> Self {
266 self.cancellation_token = Some(token);
267 self
268 }
269
270 #[must_use]
272 pub fn with_budget(mut self, budget: BudgetConfig) -> Self {
273 self.budget = Some(budget);
274 self
275 }
276
277 #[must_use]
288 pub fn with_interrupt_before(mut self, nodes: Vec<String>) -> Self {
289 self.interrupt_before = Some(nodes);
290 self
291 }
292
293 #[must_use]
304 pub fn with_interrupt_after(mut self, nodes: Vec<String>) -> Self {
305 self.interrupt_after = Some(nodes);
306 self
307 }
308
309 #[must_use]
323 pub fn with_metrics_collector(mut self, collector: Arc<dyn MetricsCollector>) -> Self {
324 self.metrics_collector = Some(collector);
325 self
326 }
327
328 #[must_use]
342 pub fn with_callback_handler(mut self, handler: Arc<dyn GraphLifecycleCallback>) -> Self {
343 self.callback_handler = Some(handler);
344 self
345 }
346
347 #[must_use]
349 pub fn with_llm_cache_policy(mut self, policy: LlmCachePolicy) -> Self {
350 self.llm_cache_policy = Some(policy);
351 self
352 }
353
354 #[must_use]
359 pub const fn budget_tracker(&self) -> Option<&Arc<BudgetTracker>> {
360 self.budget_tracker.as_ref()
361 }
362}
363
364#[derive(Clone, Debug)]
366pub struct CacheConfig {
367 pub policy: CachePolicy,
369}
370
371#[derive(Clone)]
376pub struct CachePolicy {
377 #[allow(
383 clippy::type_complexity,
384 reason = "trait object requires full signature"
385 )]
386 pub key_func: Option<Arc<dyn Fn(&serde_json::Value, &RunnableConfig) -> String + Send + Sync>>,
387
388 pub ttl: Option<Duration>,
390
391 pub max_entries: Option<usize>,
393}
394
395impl std::fmt::Debug for CachePolicy {
396 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
397 f.debug_struct("CachePolicy")
398 .field("key_func", &self.key_func.as_ref().map(|_| "<fn>"))
399 .field("ttl", &self.ttl)
400 .field("max_entries", &self.max_entries)
401 .finish()
402 }
403}
404
405impl Default for CachePolicy {
406 fn default() -> Self {
407 Self::default_policy()
408 }
409}
410
411impl CachePolicy {
412 #[must_use]
414 pub fn default_policy() -> Self {
415 Self {
416 key_func: None,
417 ttl: None,
418 max_entries: None,
419 }
420 }
421
422 #[must_use]
426 pub fn ttl(duration: Duration) -> Self {
427 Self {
428 key_func: None,
429 ttl: Some(duration),
430 max_entries: None,
431 }
432 }
433
434 #[must_use]
439 pub fn custom_key(
440 key_func: impl Fn(&serde_json::Value, &RunnableConfig) -> String + Send + Sync + 'static,
441 ) -> Self {
442 Self {
443 key_func: Some(Arc::new(key_func)),
444 ttl: None,
445 max_entries: None,
446 }
447 }
448}
449
450#[derive(Clone, Debug, Default)]
456pub struct TaskConfig {
457 pub retry_policy: Option<crate::graph::RetryPolicy>,
459
460 pub cache_policy: Option<CachePolicy>,
462
463 pub timeout: Option<Duration>,
465
466 pub name: Option<String>,
468}
469
470#[derive(Clone, Default)]
476pub struct EntrypointConfig {
477 pub checkpointer: Option<Arc<dyn CheckpointSaver>>,
479
480 pub store: Option<Arc<dyn Store>>,
482}
483
484impl std::fmt::Debug for EntrypointConfig {
485 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
486 f.debug_struct("EntrypointConfig")
487 .field(
488 "checkpointer",
489 &self.checkpointer.as_ref().map(|_| "<CheckpointSaver>"),
490 )
491 .field("store", &self.store.as_ref().map(|_| "<Store>"))
492 .finish()
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_runnable_config_new() {
        let config = RunnableConfig::new();
        assert_eq!(config.recursion_limit, 25);
        assert_eq!(config.max_parallel_tasks, 100);
        assert!(config.thread_id.is_none());
        assert!(config.checkpoint_id.is_none());
        assert!(config.cancellation_token.is_none());
        assert!(config.budget.is_none());
        assert!(config.durability.is_none());
        assert!(config.resume_value.is_none());
        assert!(config.heartbeat.is_none());
    }

    #[test]
    fn test_runnable_config_with_cancellation_token() {
        let token = tokio_util::sync::CancellationToken::new();
        let config = RunnableConfig::new().with_cancellation_token(token);
        assert!(config.cancellation_token.is_some());
    }

    #[test]
    fn test_runnable_config_with_budget() {
        let budget = BudgetConfig::new().with_max_tokens(1000);
        let config = RunnableConfig::new().with_budget(budget);
        assert!(config.budget.is_some());
        assert_eq!(config.budget.as_ref().unwrap().max_tokens, Some(1000));
    }

    /// Scalar and string builders store exactly what they are given.
    #[test]
    fn test_runnable_config_builders_set_fields() {
        let config = RunnableConfig::new()
            .with_thread_id("thread-1")
            .with_checkpoint_id("cp-1")
            .with_run_id("run-1")
            .with_graph_name("graph")
            .with_recursion_limit(7)
            .with_max_parallel_tasks(3);
        assert_eq!(config.thread_id.as_deref(), Some("thread-1"));
        assert_eq!(config.checkpoint_id.as_deref(), Some("cp-1"));
        assert_eq!(config.run_id.as_deref(), Some("run-1"));
        assert_eq!(config.graph_name.as_deref(), Some("graph"));
        assert_eq!(config.recursion_limit, 7);
        assert_eq!(config.max_parallel_tasks, 3);
    }

    /// `with_tag` accumulates entries, while `with_metadata` overwrites per key.
    #[test]
    fn test_runnable_config_tags_accumulate_metadata_overwrites() {
        let config = RunnableConfig::new()
            .with_tag("a")
            .with_tag("b")
            .with_metadata("k", serde_json::json!(1))
            .with_metadata("k", serde_json::json!(2));
        assert_eq!(config.tags, vec!["a".to_string(), "b".to_string()]);
        assert_eq!(config.metadata.len(), 1);
        assert_eq!(config.metadata.get("k"), Some(&serde_json::json!(2)));
    }

    #[test]
    fn test_runnable_config_interrupt_lists() {
        let config = RunnableConfig::new()
            .with_interrupt_before(vec!["a".to_string()])
            .with_interrupt_after(vec!["b".to_string()]);
        assert_eq!(config.interrupt_before, Some(vec!["a".to_string()]));
        assert_eq!(config.interrupt_after, Some(vec!["b".to_string()]));
    }

    #[test]
    fn test_runnable_config_with_cache() {
        let cache = CacheConfig {
            policy: CachePolicy::ttl(Duration::from_secs(5)),
        };
        let config = RunnableConfig::new().with_cache(cache);
        let ttl = config.cache.as_ref().and_then(|c| c.policy.ttl);
        assert_eq!(ttl, Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_cache_policy_default() {
        let policy = CachePolicy::default_policy();
        assert!(policy.key_func.is_none());
        assert!(policy.ttl.is_none());
        assert!(policy.max_entries.is_none());
    }

    #[test]
    fn test_cache_policy_ttl() {
        let policy = CachePolicy::ttl(Duration::from_secs(60));
        assert!(policy.key_func.is_none());
        assert_eq!(policy.ttl, Some(Duration::from_secs(60)));
        assert!(policy.max_entries.is_none());
    }

    #[test]
    fn test_cache_policy_custom_key() {
        let policy =
            CachePolicy::custom_key(|val, _cfg| format!("key-{}", val.as_str().unwrap_or("")));
        assert!(policy.key_func.is_some());
        assert!(policy.ttl.is_none());
        assert!(policy.max_entries.is_none());

        let config = RunnableConfig::new();
        let key = (policy.key_func.as_ref().unwrap())(&serde_json::json!("test"), &config);
        assert_eq!(key, "key-test");
    }

    #[test]
    fn test_cache_policy_default_trait() {
        let policy = CachePolicy::default();
        assert!(policy.key_func.is_none());
        assert!(policy.ttl.is_none());
        assert!(policy.max_entries.is_none());
    }

    #[test]
    fn test_cache_policy_debug() {
        let policy = CachePolicy::ttl(Duration::from_secs(30));
        let debug_str = format!("{policy:?}");
        assert!(debug_str.contains("ttl"));
        assert!(debug_str.contains("30s"));
    }

    /// The key function is rendered as a placeholder, not omitted.
    #[test]
    fn test_cache_policy_debug_hides_key_func() {
        let policy = CachePolicy::custom_key(|_, _| String::new());
        let debug_str = format!("{policy:?}");
        assert!(debug_str.contains("key_func"));
        assert!(debug_str.contains("<fn>"));
    }

    #[test]
    fn test_task_config_default() {
        let config = TaskConfig::default();
        assert!(config.retry_policy.is_none());
        assert!(config.cache_policy.is_none());
        assert!(config.timeout.is_none());
        assert!(config.name.is_none());
    }

    #[test]
    fn test_entrypoint_config_default() {
        let config = EntrypointConfig::default();
        assert!(config.checkpointer.is_none());
        assert!(config.store.is_none());
    }

    #[test]
    fn test_entrypoint_config_debug() {
        let debug = format!("{:?}", EntrypointConfig::default());
        assert!(debug.contains("EntrypointConfig"));
        assert!(debug.contains("checkpointer: None"));
        assert!(debug.contains("store: None"));
    }

    #[test]
    fn test_runnable_config_debug_format() {
        let config = RunnableConfig::new()
            .with_thread_id("t1")
            .with_run_name("test-run");
        let debug = format!("{config:?}");
        assert!(debug.contains("t1"));
        assert!(debug.contains("test-run"));
    }

    /// Callback fields show up in `Debug` output as a placeholder.
    #[test]
    fn test_runnable_config_debug_hides_callback() {
        let config = RunnableConfig {
            node_finished_callback: Some(Arc::new(|_: &str| {})),
            ..RunnableConfig::new()
        };
        let debug = format!("{config:?}");
        assert!(debug.contains("node_finished_callback"));
        assert!(debug.contains("<fn>"));
    }
}
601
602