langgraph_prebuilt/node_helpers.rs
1//! Helper functions for building graph nodes with minimal boilerplate.
2//!
3//! These utilities eliminate the manual JSON ↔ typed conversion that makes
4//! Rust examples verbose compared to Python's langchain-core.
5
6use std::io::Write;
7use serde_json::Value as JsonValue;
8use tokio_stream::StreamExt;
9use langgraph_checkpoint::config::RunnableConfig;
10use langgraph::config::get_stream_writer;
11use langgraph::runnable::RunnableError;
12use langgraph::stream::StreamPart;
13use langgraph::types::StreamMode;
14
15use crate::traits::BaseChatModel;
16use crate::types::Message;
17
18/// Extract typed messages from a graph state JSON, with an optional system prompt prepended.
19///
20/// This replaces the common 8-line pattern:
21/// ```ignore
22/// let messages_json = input.get("messages")
23/// .and_then(|m| m.as_array()).cloned().unwrap_or_default();
24/// let mut typed_messages = vec![Message::system("...")];
25/// for msg in &messages_json {
26/// if let Ok(m) = serde_json::from_value::<Message>(msg.clone()) {
27/// typed_messages.push(m);
28/// }
29/// }
30/// ```
31///
32/// With:
33/// ```ignore
34/// let messages = extract_messages(&input, Some("You are a helpful assistant."));
35/// ```
36pub fn extract_messages(input: &JsonValue, system_prompt: Option<&str>) -> Vec<Message> {
37 let messages_json = input
38 .get("messages")
39 .and_then(|m| m.as_array())
40 .cloned()
41 .unwrap_or_default();
42
43 let mut messages = Vec::with_capacity(messages_json.len() + 1);
44
45 if let Some(prompt) = system_prompt {
46 messages.push(Message::system(prompt));
47 }
48
49 for msg in &messages_json {
50 if let Ok(m) = serde_json::from_value::<Message>(msg.clone()) {
51 messages.push(m);
52 }
53 }
54
55 messages
56}
57
58/// Convert a model response into a state update JSON.
59///
60/// Wraps the response message in `{"messages": [response]}` format.
61pub fn llm_response_to_json(response: Message) -> Result<JsonValue, RunnableError> {
62 let response_json = serde_json::to_value(response)
63 .map_err(|e| RunnableError::Node(e.to_string()))?;
64 Ok(serde_json::json!({ "messages": [response_json] }))
65}
66
67/// Invoke an LLM and return a state update.
68///
69/// This is the complete LLM node logic in one call:
70/// 1. Extracts messages from input state
71/// 2. Prepends system prompt
72/// 3. Calls the model
73/// 4. Wraps response in state update format
74///
75/// # Example
76/// ```ignore
77/// let model_clone = model.clone();
78/// graph.add_node("chatbot", move |input: JsonValue, _config: RunnableConfig| {
79/// let model = model_clone.clone();
80/// async move { invoke_llm(model.as_ref(), &input, "You are a helpful assistant.") }
81/// })?;
82/// ```
83pub fn invoke_llm(
84 model: &dyn BaseChatModel,
85 input: &JsonValue,
86 system_prompt: &str,
87) -> Result<JsonValue, RunnableError> {
88 let messages = extract_messages(input, Some(system_prompt));
89 let response = model.invoke(&messages, &RunnableConfig::new())
90 .map_err(|e| RunnableError::Node(e.to_string()))?;
91 llm_response_to_json(response)
92}
93
94/// Invoke an LLM with a custom config and return a state update.
95///
96/// Same as [`invoke_llm`] but allows passing a custom config (e.g., for streaming).
97pub fn invoke_llm_with_config(
98 model: &dyn BaseChatModel,
99 input: &JsonValue,
100 system_prompt: &str,
101 config: &RunnableConfig,
102) -> Result<JsonValue, RunnableError> {
103 let messages = extract_messages(input, Some(system_prompt));
104 let response = model.invoke(messages.as_slice(), config)
105 .map_err(|e| RunnableError::Node(e.to_string()))?;
106 llm_response_to_json(response)
107}
108
109/// Stream LLM tokens via StreamWriter and return the final state update.
110///
111/// Calls `model.astream()` for token-by-token streaming. Each partial message
112/// is forwarded through the stream writer (if active) as a JSON payload:
113/// ```json
114/// {"type": "token", "content": "Hello"}
115/// ```
116///
117/// The final complete message is returned as a state update in
118/// `{"messages": [response]}` format.
119///
120/// # Example
121/// ```ignore
122/// let model_clone = model.clone();
123/// graph.add_node("chatbot", move |input: JsonValue, _config: RunnableConfig| {
124/// let model = model_clone.clone();
125/// async move { stream_llm(model.as_ref(), &input, "You are a helpful assistant.").await }
126/// })?;
127/// ```
128pub async fn stream_llm(
129 model: &(dyn BaseChatModel + Send + Sync),
130 input: &JsonValue,
131 system_prompt: &str,
132) -> Result<JsonValue, RunnableError> {
133 let messages = extract_messages(input, Some(system_prompt));
134 let writer = get_stream_writer();
135
136 let config = RunnableConfig::new();
137 let mut stream = model.astream(&messages, &config);
138 let mut accumulated_thinking = String::new();
139 let mut accumulated_content = String::new();
140 let mut tool_calls_message = None;
141
142 // Standard incremental streaming (same as LangChain / OpenAI SDK):
143 // - Each chunk yielded by the provider contains ONLY new delta tokens.
144 // - If tool calls are present, the provider yields ONE final signal chunk
145 // at the very end with has_tool_calls()=true and empty content/thinking.
146 // This lets us detect tool calls without re-printing any content.
147 // - We forward every content/thinking delta directly to the stream writer,
148 // and accumulate them ourselves for the final return value.
149 while let Some(result) = stream.next().await {
150 let chunk = result.map_err(|e| RunnableError::Node(e.to_string()))?;
151
152 if chunk.has_tool_calls() {
153 // Tool-calls signal chunk — no content to print, just capture it.
154 tool_calls_message = Some(chunk);
155 } else {
156 // Pure delta chunk — forward to stream writer and accumulate.
157 if let Some(ref w) = writer {
158 if let Some(thinking) = chunk.thinking() {
159 if !thinking.is_empty() {
160 let _ = w.try_send(serde_json::json!({
161 "type": "thinking",
162 "content": thinking,
163 }));
164 }
165 }
166 if let Some(content) = chunk.text() {
167 if !content.is_empty() {
168 let _ = w.try_send(serde_json::json!({
169 "type": "token",
170 "content": content,
171 }));
172 }
173 }
174 }
175 if let Some(thinking) = chunk.thinking() {
176 accumulated_thinking.push_str(thinking);
177 }
178 if let Some(content) = chunk.text() {
179 accumulated_content.push_str(content);
180 }
181 }
182 }
183
184 // Build the final Message from accumulated content + tool calls (if any).
185 let mut final_message = match tool_calls_message {
186 Some(tc_msg) => {
187 // Reconstruct with full accumulated content + the assembled tool calls.
188 let tool_calls = match tc_msg {
189 Message::Ai { tool_calls, .. } => tool_calls,
190 _ => vec![],
191 };
192 Message::ai_with_tool_calls(accumulated_content, tool_calls)
193 }
194 None => Message::ai(accumulated_content),
195 };
196
197 if !accumulated_thinking.is_empty() {
198 if let Message::Ai { thinking: ref mut th, .. } = final_message {
199 *th = Some(accumulated_thinking);
200 }
201 }
202
203 llm_response_to_json(final_message)
204}
205
206/// Get a field from state as i64, defaulting to 0.
207pub fn get_i64(input: &JsonValue, key: &str) -> i64 {
208 input.get(key).and_then(|v| v.as_i64()).unwrap_or(0)
209}
210
211/// Get a field from state as a string, defaulting to "".
212pub fn get_str<'a>(input: &'a JsonValue, key: &str) -> &'a str {
213 input.get(key).and_then(|v| v.as_str()).unwrap_or("")
214}
215
216/// Extract the assistant's text reply from an `invoke_llm` / `stream_llm` result.
217///
218/// Both helpers return `{"messages": [response]}`. This function digs out the
219/// `content` field of the last message so callers don't repeat the same
220/// `.get("messages") … .last() … .get("content")` chain every time.
221///
222/// # Example
223/// ```ignore
224/// let result = stream_llm(model, &input, "You are a planner.").await?;
225/// let text = response_text(&result);
226/// println!("LLM said: {}", text);
227/// ```
228pub fn response_text(result: &JsonValue) -> &str {
229 result
230 .get("messages")
231 .and_then(|m| m.as_array())
232 .and_then(|msgs| msgs.last())
233 .and_then(|m| m.get("content"))
234 .and_then(|c| c.as_str())
235 .unwrap_or("")
236}
237
238/// Print the last AI message from an `invoke` / `ainvoke` result.
239///
240/// Mirrors [`print_stream`] for non-streaming scenarios. Finds the last
241/// AI message in the `{"messages": [...]}` state, prints its thinking (if any)
242/// in dim gray followed by the content in normal color.
243///
244/// # Example
245/// ```ignore
246/// let result = agent.ainvoke(&input, &RunnableConfig::new()).await?;
247/// print_result(&result);
248/// ```
249pub fn print_result(result: &JsonValue) {
250 print_result_with_options(result, true);
251}
252
253/// Like [`print_result`] but with explicit control over thinking display.
254///
255/// When `show_thinking` is `false` the thinking block is omitted, matching the
256/// behaviour of [`print_stream_with_options`] with `show_thinking = false`.
257pub fn print_result_with_options(result: &JsonValue, show_thinking: bool) {
258 let messages = match result.get("messages").and_then(|m| m.as_array()) {
259 Some(m) => m,
260 None => return,
261 };
262
263 // Walk backwards to find the last AI message that has non-empty content.
264 for msg in messages.iter().rev() {
265 if msg.get("type").and_then(|t| t.as_str()) != Some("ai") {
266 continue;
267 }
268
269 // Print thinking in dim gray (same ANSI codes as stream_and_print).
270 if show_thinking {
271 if let Some(thinking) = msg.get("thinking").and_then(|t| t.as_str()) {
272 if !thinking.is_empty() {
273 println!("\x1b[2;90m[Thinking] {}\x1b[0m", thinking);
274 }
275 }
276 }
277
278 // Print the answer content.
279 if let Some(content) = msg.get("content").and_then(|c| c.as_str()) {
280 if !content.is_empty() {
281 println!("{}", content);
282 return;
283 }
284 }
285
286 // Fallback: mention tool calls if the last AI turn was a tool-call step.
287 if let Some(tool_calls) = msg.get("tool_calls").and_then(|tc| tc.as_array()) {
288 if !tool_calls.is_empty() {
289 println!("[Called {} tool(s)]", tool_calls.len());
290 return;
291 }
292 }
293 }
294}
295
296/// Strip markdown code fences (` ```json … ``` `) and parse the inner JSON.
297///
298/// If the text is plain JSON (no fences), it is parsed directly.
299/// Returns `None` when the text is not valid JSON after stripping.
300///
301/// # Example
302/// ```ignore
303/// let text = r#"```json\n{"title": "Plan"}\n```"#;
304/// let value = parse_json_response(text).unwrap();
305/// assert_eq!(value["title"], "Plan");
306/// ```
307pub fn parse_json_response(text: &str) -> Option<JsonValue> {
308 let trimmed = text.trim();
309 let json_str = if trimmed.starts_with("```") {
310 let start = trimmed.find('\n').map(|i| i + 1).unwrap_or(3);
311 let end = trimmed.rfind("```").unwrap_or(trimmed.len());
312 &trimmed[start..end]
313 } else {
314 trimmed
315 };
316 serde_json::from_str(json_str.trim()).ok()
317}
318
319/// Ask the LLM a single prompt and get back a parsed JSON value.
320///
321/// This combines three steps that are repeated in every "structured output" node:
322/// 1. Call `stream_llm` with a raw prompt (no state extraction)
323/// 2. Extract the response text
324/// 3. Parse JSON (stripping markdown fences if present)
325///
326/// Returns `None` when the response is not valid JSON.
327///
328/// # Example
329/// ```ignore
330/// let plan = ask_json(model, "Create a plan in JSON format", "").await;
331/// ```
332pub async fn ask_json(
333 model: &(dyn BaseChatModel + Send + Sync),
334 prompt: &str,
335 system_prompt: &str,
336) -> Result<Option<JsonValue>, RunnableError> {
337 let input = serde_json::json!({"messages": [{"type": "human", "content": prompt}]});
338 let result = stream_llm(model, &input, system_prompt).await?;
339 let text = response_text(&result);
340 Ok(parse_json_response(text))
341}
342
343/// Stream graph execution and print tokens to stdout in real-time.
344///
345/// Handles the common streaming boilerplate in examples. Tokens from
346/// `StreamMode::Custom` are printed inline (typewriter style). Node
347/// completion updates from `StreamMode::Updates` are printed as `[update]` lines.
348/// Thinking content is printed in dim gray with a `[Thinking]` prefix.
349///
350/// Returns the collected token text.
351///
352/// # Example
353/// ```ignore
354/// use langgraph::prelude::*;
355/// use langgraph_prebuilt::print_stream;
356///
357/// let mut stream = app.astream(&input, &RunnableConfig::new(), vec![StreamMode::Custom, StreamMode::Updates]);
358/// let text = print_stream(&mut stream, true).await;
359/// println!("Final: {}", text);
360/// ```
361pub async fn print_stream(
362 stream: &mut (impl StreamExt<Item = StreamPart> + Unpin),
363 print_updates: bool,
364) -> String {
365 print_stream_with_options(stream, print_updates, true).await
366}
367
368/// Like [`print_stream`] but with explicit control over thinking display.
369///
370/// When `show_thinking` is `false`, thinking/reasoning content is suppressed.
371pub async fn print_stream_with_options(
372 stream: &mut (impl StreamExt<Item = StreamPart> + Unpin),
373 print_updates: bool,
374 show_thinking: bool,
375) -> String {
376 let mut collected = String::new();
377 let mut in_thinking = false;
378
379 while let Some(part) = stream.next().await {
380 match part.mode {
381 StreamMode::Custom => {
382 if let Some(token_type) = part.data.get("type").and_then(|t| t.as_str()) {
383 match token_type {
384 "thinking" if show_thinking => {
385 if let Some(content) = part.data.get("content").and_then(|c| c.as_str()) {
386 if !in_thinking {
387 // ANSI dark gray: ESC[2;90m — dim + bright black
388 // Resets at the end of each thinking block.
389 print!("\x1b[2;90m[Thinking] ");
390 in_thinking = true;
391 }
392 print!("{}", content);
393 let _ = std::io::stdout().flush();
394 }
395 }
396 "token" => {
397 if in_thinking {
398 // End of thinking block — reset color, then new line before answer
399 print!("\x1b[0m");
400 println!();
401 in_thinking = false;
402 }
403 if let Some(content) = part.data.get("content").and_then(|c| c.as_str()) {
404 print!("{}", content);
405 let _ = std::io::stdout().flush();
406 collected.push_str(content);
407 }
408 }
409 _ => {}
410 }
411 }
412 }
413 StreamMode::Updates if print_updates => {
414 if in_thinking {
415 print!("\x1b[0m");
416 println!();
417 in_thinking = false;
418 }
419 if let Some(obj) = part.data.as_object() {
420 for (node_name, _) in obj {
421 println!("\n[update] Node '{}' completed", node_name);
422 }
423 }
424 }
425 _ => {}
426 }
427 }
428
429 if in_thinking {
430 print!("\x1b[0m");
431 println!();
432 }
433
434 collected
435}