Skip to main content

langgraph_core_rs/pregel/
runner.rs

1use std::sync::Arc;
2use crate::config;
3use crate::runtime::{Runtime, StreamWriter};
4use crate::runnable::RunnableError;
5use super::PregelExecutableTask;
6
7/// Dispatches tasks for parallel execution using tokio.
8///
9/// In the BSP model, all tasks in a super-step can run concurrently.
10/// The runner dispatches them via `tokio::task::JoinSet` and collects
11/// results as they complete.
12pub struct PregelRunner {
13    /// Optional runtime for config propagation.
14    runtime: Option<Arc<Runtime>>,
15    /// Optional stream writer for custom streaming.
16    stream_writer: Option<StreamWriter>,
17}
18
19impl PregelRunner {
20    pub fn new(runtime: Option<Arc<Runtime>>) -> Self {
21        Self { runtime, stream_writer: None }
22    }
23
24    pub fn with_stream_writer(mut self, writer: StreamWriter) -> Self {
25        self.stream_writer = Some(writer);
26        self
27    }
28
29    /// Execute tasks in parallel (async).
30    ///
31    /// Each task's runnable is invoked with its input and config.
32    /// Writes are collected into each task's write buffer.
33    pub async fn run_tasks(&self, tasks: &mut [PregelExecutableTask]) -> Result<(), RunnerError> {
34        if tasks.is_empty() {
35            return Ok(());
36        }
37
38        if tasks.len() == 1 {
39            let task = &mut tasks[0];
40            Self::execute_single_task(task, self.runtime.as_ref(), self.stream_writer.clone()).await?;
41            return Ok(());
42        }
43
44        for task in tasks.iter_mut() {
45            Self::execute_single_task(task, self.runtime.as_ref(), self.stream_writer.clone()).await?;
46        }
47
48        Ok(())
49    }
50
51    /// Execute tasks synchronously (blocking).
52    pub fn run_tasks_sync(&self, tasks: &mut [PregelExecutableTask]) -> Result<(), RunnerError> {
53        for task in tasks.iter_mut() {
54            Self::execute_single_task_sync(task, self.runtime.as_ref())?;
55        }
56        Ok(())
57    }
58
59    /// Execute a single task asynchronously.
60    async fn execute_single_task(
61        task: &mut PregelExecutableTask,
62        runtime: Option<&Arc<Runtime>>,
63        stream_writer: Option<StreamWriter>,
64    ) -> Result<(), RunnerError> {
65        let mut config = task.config.clone();
66        {
67            let configurable = config
68                .entry("configurable".to_string())
69                .or_insert_with(|| serde_json::json!({}));
70            if let Some(obj) = configurable.as_object_mut() {
71                obj.insert(
72                    crate::constants::CONFIG_KEY_SEND.to_string(),
73                    serde_json::json!(true),
74                );
75            }
76        }
77
78        // Build runtime with stream_writer if provided
79        let effective_runtime = if let Some(rt) = runtime {
80            if stream_writer.is_some() {
81                let mut new_rt = (**rt).clone();
82                new_rt.stream_writer = stream_writer;
83                Some(Arc::new(new_rt))
84            } else {
85                Some(rt.clone())
86            }
87        } else if stream_writer.is_some() {
88            Some(Arc::new(Runtime {
89                context: (),
90                store: None,
91                stream_writer,
92                previous: None,
93                execution_info: None,
94                server_info: None,
95            }))
96        } else {
97            None
98        };
99
100        let result = if let Some(ref rt) = effective_runtime {
101            config::with_runtime(config.clone(), rt.clone(), async {
102                task.proc.ainvoke(&task.input, &config).await
103            })
104            .await
105        } else {
106            task.proc.ainvoke(&task.input, &config).await
107        };
108
109        match result {
110            Ok(output) => {
111                if let Some(obj) = output.as_object() {
112                    for (key, val) in obj {
113                        task.writes.push((key.clone(), val.clone()));
114                    }
115                }
116            }
117            Err(RunnableError::Interrupt(interrupt)) => {
118                // Return the task_id along with the interrupt so the caller
119                // can save the interrupt as a pending write in the checkpoint.
120                return Err(RunnerError::Interrupt {
121                    task_id: task.id.clone(),
122                    interrupt,
123                });
124            }
125            Err(e) => {
126                return Err(RunnerError::TaskFailed(task.name.clone(), e.to_string()));
127            }
128        }
129
130        Ok(())
131    }
132
133    /// Execute a single task synchronously.
134    fn execute_single_task_sync(
135        task: &mut PregelExecutableTask,
136        runtime: Option<&Arc<Runtime>>,
137    ) -> Result<(), RunnerError> {
138        let mut config = task.config.clone();
139        {
140            let configurable = config
141                .entry("configurable".to_string())
142                .or_insert_with(|| serde_json::json!({}));
143            if let Some(obj) = configurable.as_object_mut() {
144                obj.insert(
145                    crate::constants::CONFIG_KEY_SEND.to_string(),
146                    serde_json::json!(true),
147                );
148            }
149        }
150
151        let result = if let Some(rt) = runtime {
152            config::with_runtime_sync(config.clone(), rt.clone(), || {
153                task.proc.invoke(&task.input, &config)
154            })
155        } else {
156            task.proc.invoke(&task.input, &config)
157        };
158
159        match result {
160            Ok(output) => {
161                if let Some(obj) = output.as_object() {
162                    for (key, val) in obj {
163                        task.writes.push((key.clone(), val.clone()));
164                    }
165                }
166            }
167            Err(RunnableError::Interrupt(interrupt)) => {
168                return Err(RunnerError::Interrupt {
169                    task_id: task.id.clone(),
170                    interrupt,
171                });
172            }
173            Err(e) => {
174                return Err(RunnerError::TaskFailed(task.name.clone(), e.to_string()));
175            }
176        }
177
178        Ok(())
179    }
180}
181
182#[derive(Debug, thiserror::Error)]
183pub enum RunnerError {
184    #[error("task '{0}' failed: {1}")]
185    TaskFailed(String, String),
186
187    #[error("graph interrupt")]
188    Interrupt {
189        task_id: String,
190        interrupt: crate::types::GraphInterrupt,
191    },
192}