use std::io::Write;
use serde_json::Value as JsonValue;
use tokio_stream::StreamExt;
use langgraph_checkpoint::config::RunnableConfig;
use langgraph::config::get_stream_writer;
use langgraph::runnable::RunnableError;
use langgraph::stream::StreamPart;
use langgraph::types::StreamMode;
use crate::traits::BaseChatModel;
use crate::types::Message;
/// Builds the chat history to send to a model from a graph state value.
///
/// Reads the `"messages"` array of `input`; a missing key or a non-array value is treated as
/// an empty history. Each entry is deserialized as a [`Message`]; entries that fail to
/// deserialize are skipped silently rather than reported. When `system_prompt` is `Some`, a
/// [`Message::system`] built from it is placed first.
pub fn extract_messages(input: &JsonValue, system_prompt: Option<&str>) -> Vec<Message> {
    // Borrow the array instead of deep-cloning it; each element is still cloned once below
    // because `serde_json::from_value` takes its input by value.
    let messages_json: &[JsonValue] = input
        .get("messages")
        .and_then(|m| m.as_array())
        .map(Vec::as_slice)
        .unwrap_or(&[]);

    // +1 leaves room for the optional system prompt.
    let mut messages = Vec::with_capacity(messages_json.len() + 1);
    if let Some(prompt) = system_prompt {
        messages.push(Message::system(prompt));
    }
    messages.extend(
        messages_json
            .iter()
            .filter_map(|msg| serde_json::from_value::<Message>(msg.clone()).ok()),
    );
    messages
}
/// Wraps a single model reply as a graph state update of the form `{"messages": [reply]}`.
///
/// # Errors
/// Returns [`RunnableError::Node`] carrying the serializer's message if `response` cannot be
/// converted to JSON.
pub fn llm_response_to_json(response: Message) -> Result<JsonValue, RunnableError> {
    match serde_json::to_value(response) {
        Ok(serialized) => Ok(serde_json::json!({ "messages": [serialized] })),
        Err(err) => Err(RunnableError::Node(err.to_string())),
    }
}
pub fn invoke_llm(
model: &dyn BaseChatModel,
input: &JsonValue,
system_prompt: &str,
) -> Result<JsonValue, RunnableError> {
let messages = extract_messages(input, Some(system_prompt));
let response = model.invoke(&messages, &RunnableConfig::new())
.map_err(|e| RunnableError::Node(e.to_string()))?;
llm_response_to_json(response)
}
pub fn invoke_llm_with_config(
model: &dyn BaseChatModel,
input: &JsonValue,
system_prompt: &str,
config: &RunnableConfig,
) -> Result<JsonValue, RunnableError> {
let messages = extract_messages(input, Some(system_prompt));
let response = model.invoke(messages.as_slice(), config)
.map_err(|e| RunnableError::Node(e.to_string()))?;
llm_response_to_json(response)
}
pub async fn stream_llm(
model: &(dyn BaseChatModel + Send + Sync),
input: &JsonValue,
system_prompt: &str,
) -> Result<JsonValue, RunnableError> {
let messages = extract_messages(input, Some(system_prompt));
let writer = get_stream_writer();
let config = RunnableConfig::new();
let mut stream = model.astream(&messages, &config);
let mut accumulated_thinking = String::new();
let mut accumulated_content = String::new();
let mut tool_calls_message = None;
while let Some(result) = stream.next().await {
let chunk = result.map_err(|e| RunnableError::Node(e.to_string()))?;
if chunk.has_tool_calls() {
tool_calls_message = Some(chunk);
} else {
if let Some(ref w) = writer {
if let Some(thinking) = chunk.thinking() {
if !thinking.is_empty() {
let _ = w.try_send(serde_json::json!({
"type": "thinking",
"content": thinking,
}));
}
}
if let Some(content) = chunk.text() {
if !content.is_empty() {
let _ = w.try_send(serde_json::json!({
"type": "token",
"content": content,
}));
}
}
}
if let Some(thinking) = chunk.thinking() {
accumulated_thinking.push_str(thinking);
}
if let Some(content) = chunk.text() {
accumulated_content.push_str(content);
}
}
}
let mut final_message = match tool_calls_message {
Some(tc_msg) => {
let tool_calls = match tc_msg {
Message::Ai { tool_calls, .. } => tool_calls,
_ => vec![],
};
Message::ai_with_tool_calls(accumulated_content, tool_calls)
}
None => Message::ai(accumulated_content),
};
if !accumulated_thinking.is_empty() {
if let Message::Ai { thinking: ref mut th, .. } = final_message {
*th = Some(accumulated_thinking);
}
}
llm_response_to_json(final_message)
}
pub fn get_i64(input: &JsonValue, key: &str) -> i64 {
input.get(key).and_then(|v| v.as_i64()).unwrap_or(0)
}
/// Borrows `input[key]` as a string slice, yielding `""` when the key is absent or the value
/// is not a JSON string.
pub fn get_str<'a>(input: &'a JsonValue, key: &str) -> &'a str {
    match input.get(key) {
        Some(value) => value.as_str().unwrap_or(""),
        None => "",
    }
}
pub fn response_text(result: &JsonValue) -> &str {
result
.get("messages")
.and_then(|m| m.as_array())
.and_then(|msgs| msgs.last())
.and_then(|m| m.get("content"))
.and_then(|c| c.as_str())
.unwrap_or("")
}
/// Prints the final assistant output contained in `result`, including model thinking.
///
/// Shorthand for [`print_result_with_options`] with `show_thinking = true`.
pub fn print_result(result: &JsonValue) {
    let show_thinking = true;
    print_result_with_options(result, show_thinking)
}
pub fn print_result_with_options(result: &JsonValue, show_thinking: bool) {
let messages = match result.get("messages").and_then(|m| m.as_array()) {
Some(m) => m,
None => return,
};
for msg in messages.iter().rev() {
if msg.get("type").and_then(|t| t.as_str()) != Some("ai") {
continue;
}
if show_thinking {
if let Some(thinking) = msg.get("thinking").and_then(|t| t.as_str()) {
if !thinking.is_empty() {
println!("\x1b[2;90m[Thinking] {}\x1b[0m", thinking);
}
}
}
if let Some(content) = msg.get("content").and_then(|c| c.as_str()) {
if !content.is_empty() {
println!("{}", content);
return;
}
}
if let Some(tool_calls) = msg.get("tool_calls").and_then(|tc| tc.as_array()) {
if !tool_calls.is_empty() {
println!("[Called {} tool(s)]", tool_calls.len());
return;
}
}
}
}
/// Parses a model reply as JSON, tolerating a Markdown code fence around it.
///
/// If the trimmed text starts with a fence (three backticks), the rest of the opening fence
/// line is dropped. That is the info string, such as `json`. If the fence has no newline, the
/// body starts right after the backticks. The body is then cut at its last fence, if any.
/// A missing closing fence, as in a truncated reply, is tolerated instead of causing a panic.
///
/// Returns `None` if the remaining text (trimmed) is not valid JSON.
pub fn parse_json_response(text: &str) -> Option<JsonValue> {
    let trimmed = text.trim();
    let json_str = match trimmed.strip_prefix("```") {
        Some(after_fence) => {
            // Skip the info string on the opening fence line, if there is a line break.
            let body = after_fence
                .split_once('\n')
                .map_or(after_fence, |(_, rest)| rest);
            // Search for the closing fence only inside the body. Searching the whole text
            // would find the opening fence at index 0 when no closing fence exists. The
            // slice would then end before it starts and panic.
            body.rfind("```").map_or(body, |end| &body[..end])
        }
        None => trimmed,
    };
    serde_json::from_str(json_str.trim()).ok()
}
/// Sends `prompt` as a single human message (with `system_prompt`) and tries to parse the
/// reply as JSON.
///
/// The call goes through [`stream_llm`], so partial output is forwarded to the active stream
/// writer, if any. Returns `Ok(None)` when the reply text is not valid JSON, optionally inside
/// a code fence.
///
/// # Errors
/// Propagates any [`RunnableError`] from [`stream_llm`].
pub async fn ask_json(
    model: &(dyn BaseChatModel + Send + Sync),
    prompt: &str,
    system_prompt: &str,
) -> Result<Option<JsonValue>, RunnableError> {
    let request = serde_json::json!({"messages": [{"type": "human", "content": prompt}]});
    let reply = stream_llm(model, &request, system_prompt).await?;
    Ok(parse_json_response(response_text(&reply)))
}
/// Echoes a graph stream to stdout, showing model thinking, and returns the concatenated
/// `"token"` text.
///
/// Shorthand for [`print_stream_with_options`] with `show_thinking = true`.
pub async fn print_stream(
    stream: &mut (impl StreamExt<Item = StreamPart> + Unpin),
    print_updates: bool,
) -> String {
    let show_thinking = true;
    print_stream_with_options(stream, print_updates, show_thinking).await
}
/// Echoes a graph stream to stdout and returns the concatenated `"token"` text.
///
/// * `Custom` parts with `"type": "thinking"` are printed dimmed after a `[Thinking] ` prefix,
///   but only when `show_thinking` is set.
/// * `Custom` parts with `"type": "token"` close any open thinking run. Their `content` is
///   printed and appended to the returned string.
/// * `Updates` parts, when `print_updates` is set, close any open thinking run. A
///   `[update] Node '..' completed` line is then printed for each key of the update object.
///
/// Everything else is ignored. Stdout is flushed after each streamed fragment, and the
/// terminal colour is reset if the stream ends mid-thinking.
pub async fn print_stream_with_options(
    stream: &mut (impl StreamExt<Item = StreamPart> + Unpin),
    print_updates: bool,
    show_thinking: bool,
) -> String {
    // Ends an open dimmed thinking run: reset the colour and move to a new line.
    fn end_thinking_run(open: &mut bool) {
        if *open {
            print!("\x1b[0m");
            println!();
            *open = false;
        }
    }

    let mut answer = String::new();
    let mut thinking_open = false;

    while let Some(part) = stream.next().await {
        match part.mode {
            StreamMode::Custom => {
                let event_type = part.data.get("type").and_then(|t| t.as_str());
                let payload = part.data.get("content").and_then(|c| c.as_str());
                match (event_type, payload) {
                    (Some("thinking"), Some(text)) if show_thinking => {
                        if !thinking_open {
                            print!("\x1b[2;90m[Thinking] ");
                            thinking_open = true;
                        }
                        print!("{}", text);
                        let _ = std::io::stdout().flush();
                    }
                    (Some("token"), payload) => {
                        end_thinking_run(&mut thinking_open);
                        if let Some(text) = payload {
                            print!("{}", text);
                            let _ = std::io::stdout().flush();
                            answer.push_str(text);
                        }
                    }
                    _ => {}
                }
            }
            StreamMode::Updates if print_updates => {
                end_thinking_run(&mut thinking_open);
                if let Some(nodes) = part.data.as_object() {
                    for node_name in nodes.keys() {
                        println!("\n[update] Node '{}' completed", node_name);
                    }
                }
            }
            _ => {}
        }
    }

    end_thinking_run(&mut thinking_open);
    answer
}