Skip to main content

langgraph_prebuilt/
node_helpers.rs

1//! Helper functions for building graph nodes with minimal boilerplate.
2//!
3//! These utilities eliminate the manual JSON ↔ typed conversion that makes
4//! Rust examples verbose compared to Python's langchain-core.
5
6use std::io::Write;
7use serde_json::Value as JsonValue;
8use tokio_stream::StreamExt;
9use langgraph_checkpoint::config::RunnableConfig;
10use langgraph::config::get_stream_writer;
11use langgraph::runnable::RunnableError;
12use langgraph::stream::StreamPart;
13use langgraph::types::StreamMode;
14
15use crate::traits::BaseChatModel;
16use crate::types::Message;
17
18/// Extract typed messages from a graph state JSON, with an optional system prompt prepended.
19///
20/// This replaces the common 8-line pattern:
21/// ```ignore
22/// let messages_json = input.get("messages")
23///     .and_then(|m| m.as_array()).cloned().unwrap_or_default();
24/// let mut typed_messages = vec![Message::system("...")];
25/// for msg in &messages_json {
26///     if let Ok(m) = serde_json::from_value::<Message>(msg.clone()) {
27///         typed_messages.push(m);
28///     }
29/// }
30/// ```
31///
32/// With:
33/// ```ignore
34/// let messages = extract_messages(&input, Some("You are a helpful assistant."));
35/// ```
36pub fn extract_messages(input: &JsonValue, system_prompt: Option<&str>) -> Vec<Message> {
37    let messages_json = input
38        .get("messages")
39        .and_then(|m| m.as_array())
40        .cloned()
41        .unwrap_or_default();
42
43    let mut messages = Vec::with_capacity(messages_json.len() + 1);
44
45    if let Some(prompt) = system_prompt {
46        messages.push(Message::system(prompt));
47    }
48
49    for msg in &messages_json {
50        if let Ok(m) = serde_json::from_value::<Message>(msg.clone()) {
51            messages.push(m);
52        }
53    }
54
55    messages
56}
57
58/// Convert a model response into a state update JSON.
59///
60/// Wraps the response message in `{"messages": [response]}` format.
61pub fn llm_response_to_json(response: Message) -> Result<JsonValue, RunnableError> {
62    let response_json = serde_json::to_value(response)
63        .map_err(|e| RunnableError::Node(e.to_string()))?;
64    Ok(serde_json::json!({ "messages": [response_json] }))
65}
66
67/// Invoke an LLM and return a state update.
68///
69/// This is the complete LLM node logic in one call:
70/// 1. Extracts messages from input state
71/// 2. Prepends system prompt
72/// 3. Calls the model
73/// 4. Wraps response in state update format
74///
75/// # Example
76/// ```ignore
77/// let model_clone = model.clone();
78/// graph.add_node("chatbot", move |input: JsonValue, _config: RunnableConfig| {
79///     let model = model_clone.clone();
80///     async move { invoke_llm(model.as_ref(), &input, "You are a helpful assistant.") }
81/// })?;
82/// ```
83pub fn invoke_llm(
84    model: &dyn BaseChatModel,
85    input: &JsonValue,
86    system_prompt: &str,
87) -> Result<JsonValue, RunnableError> {
88    let messages = extract_messages(input, Some(system_prompt));
89    let response = model.invoke(&messages, &RunnableConfig::new())
90        .map_err(|e| RunnableError::Node(e.to_string()))?;
91    llm_response_to_json(response)
92}
93
94/// Invoke an LLM with a custom config and return a state update.
95///
96/// Same as [`invoke_llm`] but allows passing a custom config (e.g., for streaming).
97pub fn invoke_llm_with_config(
98    model: &dyn BaseChatModel,
99    input: &JsonValue,
100    system_prompt: &str,
101    config: &RunnableConfig,
102) -> Result<JsonValue, RunnableError> {
103    let messages = extract_messages(input, Some(system_prompt));
104    let response = model.invoke(messages.as_slice(), config)
105        .map_err(|e| RunnableError::Node(e.to_string()))?;
106    llm_response_to_json(response)
107}
108
109/// Stream LLM tokens via StreamWriter and return the final state update.
110///
111/// Calls `model.astream()` for token-by-token streaming. Each partial message
112/// is forwarded through the stream writer (if active) as a JSON payload:
113/// ```json
114/// {"type": "token", "content": "Hello"}
115/// ```
116///
117/// The final complete message is returned as a state update in
118/// `{"messages": [response]}` format.
119///
120/// # Example
121/// ```ignore
122/// let model_clone = model.clone();
123/// graph.add_node("chatbot", move |input: JsonValue, _config: RunnableConfig| {
124///     let model = model_clone.clone();
125///     async move { stream_llm(model.as_ref(), &input, "You are a helpful assistant.").await }
126/// })?;
127/// ```
128pub async fn stream_llm(
129    model: &(dyn BaseChatModel + Send + Sync),
130    input: &JsonValue,
131    system_prompt: &str,
132) -> Result<JsonValue, RunnableError> {
133    let messages = extract_messages(input, Some(system_prompt));
134    let writer = get_stream_writer();
135
136    let config = RunnableConfig::new();
137    let mut stream = model.astream(&messages, &config);
138    let mut accumulated_thinking = String::new();
139    let mut accumulated_content = String::new();
140    let mut tool_calls_message = None;
141
142    // Standard incremental streaming (same as LangChain / OpenAI SDK):
143    //   - Each chunk yielded by the provider contains ONLY new delta tokens.
144    //   - If tool calls are present, the provider yields ONE final signal chunk
145    //     at the very end with has_tool_calls()=true and empty content/thinking.
146    //     This lets us detect tool calls without re-printing any content.
147    //   - We forward every content/thinking delta directly to the stream writer,
148    //     and accumulate them ourselves for the final return value.
149    while let Some(result) = stream.next().await {
150        let chunk = result.map_err(|e| RunnableError::Node(e.to_string()))?;
151
152        if chunk.has_tool_calls() {
153            // Tool-calls signal chunk — no content to print, just capture it.
154            tool_calls_message = Some(chunk);
155        } else {
156            // Pure delta chunk — forward to stream writer and accumulate.
157            if let Some(ref w) = writer {
158                if let Some(thinking) = chunk.thinking() {
159                    if !thinking.is_empty() {
160                        let _ = w.try_send(serde_json::json!({
161                            "type": "thinking",
162                            "content": thinking,
163                        }));
164                    }
165                }
166                if let Some(content) = chunk.text() {
167                    if !content.is_empty() {
168                        let _ = w.try_send(serde_json::json!({
169                            "type": "token",
170                            "content": content,
171                        }));
172                    }
173                }
174            }
175            if let Some(thinking) = chunk.thinking() {
176                accumulated_thinking.push_str(thinking);
177            }
178            if let Some(content) = chunk.text() {
179                accumulated_content.push_str(content);
180            }
181        }
182    }
183
184    // Build the final Message from accumulated content + tool calls (if any).
185    let mut final_message = match tool_calls_message {
186        Some(tc_msg) => {
187            // Reconstruct with full accumulated content + the assembled tool calls.
188            let tool_calls = match tc_msg {
189                Message::Ai { tool_calls, .. } => tool_calls,
190                _ => vec![],
191            };
192            Message::ai_with_tool_calls(accumulated_content, tool_calls)
193        }
194        None => Message::ai(accumulated_content),
195    };
196
197    if !accumulated_thinking.is_empty() {
198        if let Message::Ai { thinking: ref mut th, .. } = final_message {
199            *th = Some(accumulated_thinking);
200        }
201    }
202
203    llm_response_to_json(final_message)
204}
205
206/// Get a field from state as i64, defaulting to 0.
207pub fn get_i64(input: &JsonValue, key: &str) -> i64 {
208    input.get(key).and_then(|v| v.as_i64()).unwrap_or(0)
209}
210
211/// Get a field from state as a string, defaulting to "".
212pub fn get_str<'a>(input: &'a JsonValue, key: &str) -> &'a str {
213    input.get(key).and_then(|v| v.as_str()).unwrap_or("")
214}
215
216/// Extract the assistant's text reply from an `invoke_llm` / `stream_llm` result.
217///
218/// Both helpers return `{"messages": [response]}`. This function digs out the
219/// `content` field of the last message so callers don't repeat the same
220/// `.get("messages") … .last() … .get("content")` chain every time.
221///
222/// # Example
223/// ```ignore
224/// let result = stream_llm(model, &input, "You are a planner.").await?;
225/// let text = response_text(&result);
226/// println!("LLM said: {}", text);
227/// ```
228pub fn response_text(result: &JsonValue) -> &str {
229    result
230        .get("messages")
231        .and_then(|m| m.as_array())
232        .and_then(|msgs| msgs.last())
233        .and_then(|m| m.get("content"))
234        .and_then(|c| c.as_str())
235        .unwrap_or("")
236}
237
238/// Print the last AI message from an `invoke` / `ainvoke` result.
239///
240/// Mirrors [`print_stream`] for non-streaming scenarios. Finds the last
241/// AI message in the `{"messages": [...]}` state, prints its thinking (if any)
242/// in dim gray followed by the content in normal color.
243///
244/// # Example
245/// ```ignore
246/// let result = agent.ainvoke(&input, &RunnableConfig::new()).await?;
247/// print_result(&result);
248/// ```
249pub fn print_result(result: &JsonValue) {
250    print_result_with_options(result, true);
251}
252
253/// Like [`print_result`] but with explicit control over thinking display.
254///
255/// When `show_thinking` is `false` the thinking block is omitted, matching the
256/// behaviour of [`print_stream_with_options`] with `show_thinking = false`.
257pub fn print_result_with_options(result: &JsonValue, show_thinking: bool) {
258    let messages = match result.get("messages").and_then(|m| m.as_array()) {
259        Some(m) => m,
260        None => return,
261    };
262
263    // Walk backwards to find the last AI message that has non-empty content.
264    for msg in messages.iter().rev() {
265        if msg.get("type").and_then(|t| t.as_str()) != Some("ai") {
266            continue;
267        }
268
269        // Print thinking in dim gray (same ANSI codes as stream_and_print).
270        if show_thinking {
271            if let Some(thinking) = msg.get("thinking").and_then(|t| t.as_str()) {
272                if !thinking.is_empty() {
273                    println!("\x1b[2;90m[Thinking] {}\x1b[0m", thinking);
274                }
275            }
276        }
277
278        // Print the answer content.
279        if let Some(content) = msg.get("content").and_then(|c| c.as_str()) {
280            if !content.is_empty() {
281                println!("{}", content);
282                return;
283            }
284        }
285
286        // Fallback: mention tool calls if the last AI turn was a tool-call step.
287        if let Some(tool_calls) = msg.get("tool_calls").and_then(|tc| tc.as_array()) {
288            if !tool_calls.is_empty() {
289                println!("[Called {} tool(s)]", tool_calls.len());
290                return;
291            }
292        }
293    }
294}
295
296/// Strip markdown code fences (` ```json … ``` `) and parse the inner JSON.
297///
298/// If the text is plain JSON (no fences), it is parsed directly.
299/// Returns `None` when the text is not valid JSON after stripping.
300///
301/// # Example
302/// ```ignore
303/// let text = r#"```json\n{"title": "Plan"}\n```"#;
304/// let value = parse_json_response(text).unwrap();
305/// assert_eq!(value["title"], "Plan");
306/// ```
307pub fn parse_json_response(text: &str) -> Option<JsonValue> {
308    let trimmed = text.trim();
309    let json_str = if trimmed.starts_with("```") {
310        let start = trimmed.find('\n').map(|i| i + 1).unwrap_or(3);
311        let end = trimmed.rfind("```").unwrap_or(trimmed.len());
312        &trimmed[start..end]
313    } else {
314        trimmed
315    };
316    serde_json::from_str(json_str.trim()).ok()
317}
318
319/// Ask the LLM a single prompt and get back a parsed JSON value.
320///
321/// This combines three steps that are repeated in every "structured output" node:
322/// 1. Call `stream_llm` with a raw prompt (no state extraction)
323/// 2. Extract the response text
324/// 3. Parse JSON (stripping markdown fences if present)
325///
326/// Returns `None` when the response is not valid JSON.
327///
328/// # Example
329/// ```ignore
330/// let plan = ask_json(model, "Create a plan in JSON format", "").await;
331/// ```
332pub async fn ask_json(
333    model: &(dyn BaseChatModel + Send + Sync),
334    prompt: &str,
335    system_prompt: &str,
336) -> Result<Option<JsonValue>, RunnableError> {
337    let input = serde_json::json!({"messages": [{"type": "human", "content": prompt}]});
338    let result = stream_llm(model, &input, system_prompt).await?;
339    let text = response_text(&result);
340    Ok(parse_json_response(text))
341}
342
343/// Stream graph execution and print tokens to stdout in real-time.
344///
345/// Handles the common streaming boilerplate in examples. Tokens from
346/// `StreamMode::Custom` are printed inline (typewriter style). Node
347/// completion updates from `StreamMode::Updates` are printed as `[update]` lines.
348/// Thinking content is printed in dim gray with a `[Thinking]` prefix.
349///
350/// Returns the collected token text.
351///
352/// # Example
353/// ```ignore
354/// use langgraph::prelude::*;
355/// use langgraph_prebuilt::print_stream;
356///
357/// let mut stream = app.astream(&input, &RunnableConfig::new(), vec![StreamMode::Custom, StreamMode::Updates]);
358/// let text = print_stream(&mut stream, true).await;
359/// println!("Final: {}", text);
360/// ```
361pub async fn print_stream(
362    stream: &mut (impl StreamExt<Item = StreamPart> + Unpin),
363    print_updates: bool,
364) -> String {
365    print_stream_with_options(stream, print_updates, true).await
366}
367
368/// Like [`print_stream`] but with explicit control over thinking display.
369///
370/// When `show_thinking` is `false`, thinking/reasoning content is suppressed.
371pub async fn print_stream_with_options(
372    stream: &mut (impl StreamExt<Item = StreamPart> + Unpin),
373    print_updates: bool,
374    show_thinking: bool,
375) -> String {
376    let mut collected = String::new();
377    let mut in_thinking = false;
378
379    while let Some(part) = stream.next().await {
380        match part.mode {
381            StreamMode::Custom => {
382                if let Some(token_type) = part.data.get("type").and_then(|t| t.as_str()) {
383                    match token_type {
384                        "thinking" if show_thinking => {
385                            if let Some(content) = part.data.get("content").and_then(|c| c.as_str()) {
386                                if !in_thinking {
387                                    // ANSI dark gray: ESC[2;90m — dim + bright black
388                                    // Resets at the end of each thinking block.
389                                    print!("\x1b[2;90m[Thinking] ");
390                                    in_thinking = true;
391                                }
392                                print!("{}", content);
393                                let _ = std::io::stdout().flush();
394                            }
395                        }
396                        "token" => {
397                            if in_thinking {
398                                // End of thinking block — reset color, then new line before answer
399                                print!("\x1b[0m");
400                                println!();
401                                in_thinking = false;
402                            }
403                            if let Some(content) = part.data.get("content").and_then(|c| c.as_str()) {
404                                print!("{}", content);
405                                let _ = std::io::stdout().flush();
406                                collected.push_str(content);
407                            }
408                        }
409                        _ => {}
410                    }
411                }
412            }
413            StreamMode::Updates if print_updates => {
414                if in_thinking {
415                    print!("\x1b[0m");
416                    println!();
417                    in_thinking = false;
418                }
419                if let Some(obj) = part.data.as_object() {
420                    for (node_name, _) in obj {
421                        println!("\n[update] Node '{}' completed", node_name);
422                    }
423                }
424            }
425            _ => {}
426        }
427    }
428
429    if in_thinking {
430        print!("\x1b[0m");
431        println!();
432    }
433
434    collected
435}