use crate::error::{MiniLLMError, Result};
use crate::generator::{GeneratorInfo, NodeCompletionParameters};
use crate::message::{merge_contiguous_messages, ContentPart, Message, MessageContent, Role};
use crate::provider::{global_client, LLMClient, StreamingCompletion};
use std::sync::{Arc, RwLock, Weak};
use std::time::Duration;
use uuid::Uuid;
/// A single message in a tree-shaped conversation.
///
/// Nodes are shared through `Arc`. A parent owns its children through strong
/// `Arc` references, while each child points back to its parent through a
/// `Weak` reference. A node therefore does NOT keep its ancestors alive: some
/// other handle (typically to the root) must be held for `thread()` to see the
/// full history.
pub struct ChatNode {
/// Identifier of this node; the constructors in this file assign a random v4 UUID string.
pub id: String,
/// The message carried by this node (immutable once the node is created).
pub message: Message,
/// Child nodes in insertion order; `get_leaf` follows the last one.
children: RwLock<Vec<Arc<ChatNode>>>,
/// Weak back-link to the parent; `None` for roots and detached nodes.
parent: RwLock<Option<Weak<ChatNode>>>,
/// Free-form JSON metadata, initialized to an empty object. The completion
/// methods in this file record `response_id`, `model`, `usage` (and, for
/// non-streaming completions, `finish_reason`) here.
pub metadata: RwLock<serde_json::Value>,
/// Values substituted into `{key}` placeholders when messages are formatted.
format_kwargs: RwLock<std::collections::HashMap<String, String>>,
}
impl ChatNode {
pub fn root(system_prompt: impl Into<String>) -> Arc<Self> {
let prompt: String = system_prompt.into();
Arc::new(Self {
id: Uuid::new_v4().to_string(),
message: Message::system(prompt),
children: RwLock::new(Vec::new()),
parent: RwLock::new(None),
metadata: RwLock::new(serde_json::json!({})),
format_kwargs: RwLock::new(std::collections::HashMap::new()),
})
}
pub fn new(message: Message) -> Arc<Self> {
Arc::new(Self {
id: Uuid::new_v4().to_string(),
message,
children: RwLock::new(Vec::new()),
parent: RwLock::new(None),
metadata: RwLock::new(serde_json::json!({})),
format_kwargs: RwLock::new(std::collections::HashMap::new()),
})
}
/// Creates a detached node holding a user message with `content`.
pub fn user(content: impl Into<MessageContent>) -> Arc<Self> {
    let message = Message::user(content);
    Self::new(message)
}
/// Creates a detached node holding an assistant message with `content`.
pub fn assistant(content: impl Into<MessageContent>) -> Arc<Self> {
    let message = Message::assistant(content);
    Self::new(message)
}
/// Attaches `child` under this node and returns it.
///
/// If `child` currently has a live parent, it is first removed from that
/// parent's child list. A node is therefore never listed under two parents
/// while pointing back at only one of them. Re-adding an existing child to the
/// same parent moves it to the end of the child list, which makes it the branch
/// followed by `get_leaf`.
///
/// The child is held strongly by this node; the child's back-link to this node
/// is a `Weak` reference.
///
/// NOTE(review): no cycle check is performed here. Attaching one of this
/// node's own ancestors makes `get_root`/`thread`/`depth` recurse without bound.
pub fn add_child(self: &Arc<Self>, child: Arc<ChatNode>) -> Arc<ChatNode> {
    // Unlink from the previous parent so its child list stays consistent with
    // the child's parent pointer.
    if let Some(previous_parent) = child.parent() {
        previous_parent
            .children
            .write()
            .unwrap()
            .retain(|existing| !Arc::ptr_eq(existing, &child));
    }
    *child.parent.write().unwrap() = Some(Arc::downgrade(self));
    self.children.write().unwrap().push(Arc::clone(&child));
    child
}
pub fn add_user(self: &Arc<Self>, content: impl Into<MessageContent>) -> Arc<ChatNode> {
self.add_child(Self::user(content))
}
pub fn add_assistant(self: &Arc<Self>, content: impl Into<MessageContent>) -> Arc<ChatNode> {
self.add_child(Self::assistant(content))
}
pub fn parent(&self) -> Option<Arc<ChatNode>> {
self.parent
.read()
.unwrap()
.as_ref()
.and_then(|w| w.upgrade())
}
pub fn children(&self) -> Vec<Arc<ChatNode>> {
self.children.read().unwrap().clone()
}
pub fn child_count(&self) -> usize {
self.children.read().unwrap().len()
}
/// Returns `true` when this node has no live parent.
///
/// A node has no live parent when it was never attached, was detached, or its
/// parent has been dropped. Checking the upgraded link (rather than only whether
/// a `Weak` is stored) keeps this consistent with `parent()`, `get_root()`
/// (which returns `self` in that case) and `depth()` (which returns 0).
pub fn is_root(&self) -> bool {
    self.parent().is_none()
}
/// Walks parent links upwards and returns the topmost reachable node (possibly `self`).
pub fn get_root(self: &Arc<Self>) -> Arc<ChatNode> {
    let mut node = Arc::clone(self);
    while let Some(up) = node.parent() {
        node = up;
    }
    node
}
/// Returns `true` when this node has no children.
pub fn is_leaf(&self) -> bool {
    self.child_count() == 0
}
/// Returns clones of the messages from the topmost reachable ancestor down to
/// this node, in conversation order (unformatted).
pub fn thread(&self) -> Vec<Message> {
    let mut lineage = Vec::new();
    self.collect_thread(&mut lineage);
    lineage.reverse();
    lineage
}
/// Appends this node's message, then each ancestor's message walking upwards
/// (i.e. leaf-to-root order).
fn collect_thread(&self, messages: &mut Vec<Message>) {
    messages.push(self.message.clone());
    let mut ancestor = self.parent();
    while let Some(node) = ancestor {
        messages.push(node.message.clone());
        ancestor = node.parent();
    }
}
/// Returns `thread()` with contiguous messages combined by `merge_contiguous_messages`.
pub fn merged_thread(&self) -> Vec<Message> {
    let raw = self.thread();
    merge_contiguous_messages(raw)
}
/// Returns the number of reachable ancestors (0 for a root).
pub fn depth(&self) -> usize {
    let mut hops = 0;
    let mut cursor = self.parent();
    while let Some(node) = cursor {
        hops += 1;
        cursor = node.parent();
    }
    hops
}
/// Searches this node and its descendants in depth-first pre-order and
/// returns the first node whose `id` equals `id`.
pub fn find_by_id(self: &Arc<Self>, id: &str) -> Option<Arc<ChatNode>> {
    let mut pending = vec![Arc::clone(self)];
    while let Some(node) = pending.pop() {
        if node.id == id {
            return Some(node);
        }
        // Push in reverse so the first child is examined first.
        pending.extend(node.children().into_iter().rev());
    }
    None
}
/// Returns the most recently added child, if any.
pub fn last_child(&self) -> Option<Arc<ChatNode>> {
    let guard = self.children.read().unwrap();
    guard.last().cloned()
}
/// Follows last children downwards and returns the node where that path ends
/// (possibly `self`).
pub fn get_leaf(self: &Arc<Self>) -> Arc<ChatNode> {
    let mut node = Arc::clone(self);
    while let Some(next) = node.last_child() {
        node = next;
    }
    node
}
/// Removes this node from its parent's child list (matched by `id`), clears its
/// parent link and returns it. Its own subtree stays attached to it.
pub fn detach(self: &Arc<Self>) -> Arc<ChatNode> {
    if let Some(former_parent) = self.parent() {
        former_parent
            .children
            .write()
            .unwrap()
            .retain(|sibling| sibling.id != self.id);
    }
    *self.parent.write().unwrap() = None;
    Arc::clone(self)
}
/// Attaches the whole tree containing `other` under this node and returns the
/// leaf reached from `other` by following last children.
///
/// The root of `other`'s tree becomes a child of this node.
///
/// # Panics
/// Panics if `other` already belongs to this node's tree. Attaching a tree's
/// own root below one of its descendants would create a cycle. `get_root`,
/// `thread` and `depth` would then recurse without bound, and the
/// strongly-held child links would leak.
pub fn merge(self: &Arc<Self>, other: &Arc<ChatNode>) -> Arc<ChatNode> {
    let other_root = other.get_root();
    assert!(
        !Arc::ptr_eq(&self.get_root(), &other_root),
        "ChatNode::merge: `other` already belongs to this tree; merging would create a cycle"
    );
    self.add_child(other_root);
    other.get_leaf()
}
/// Returns this node and all descendants in depth-first pre-order
/// (children visited in insertion order).
pub fn iter_depth_first(self: &Arc<Self>) -> Vec<Arc<ChatNode>> {
    let mut order = Vec::new();
    let mut stack = vec![Arc::clone(self)];
    while let Some(node) = stack.pop() {
        // Reverse so the first child is popped next.
        stack.extend(node.children().into_iter().rev());
        order.push(node);
    }
    order
}
/// Returns this node and all descendants level by level (breadth-first).
pub fn iter_breadth_first(self: &Arc<Self>) -> Vec<Arc<ChatNode>> {
    let mut visited = Vec::new();
    let mut frontier = std::collections::VecDeque::from(vec![Arc::clone(self)]);
    while let Some(node) = frontier.pop_front() {
        frontier.extend(node.children());
        visited.push(node);
    }
    visited
}
/// Returns every leaf of this subtree, left to right (`self` if it is a leaf).
pub fn iter_leaves(self: &Arc<Self>) -> Vec<Arc<ChatNode>> {
    let mut leaves = Vec::new();
    let mut stack = vec![Arc::clone(self)];
    while let Some(node) = stack.pop() {
        if node.is_leaf() {
            leaves.push(node);
        } else {
            stack.extend(node.children().into_iter().rev());
        }
    }
    leaves
}
/// Returns the number of nodes in this subtree, including `self`.
pub fn node_count(self: &Arc<Self>) -> usize {
    self.iter_depth_first().len()
}
/// Stores `value` under `key` in this node's metadata.
///
/// Uses `serde_json::Value` indexing: a `null` metadata value is turned into an
/// object first; a metadata value that is neither object nor null panics.
pub fn set_metadata(&self, key: &str, value: serde_json::Value) {
    self.metadata.write().unwrap()[key] = value;
}
/// Returns a clone of the metadata entry for `key`, if present.
pub fn get_metadata(&self, key: &str) -> Option<serde_json::Value> {
    let guard = self.metadata.read().unwrap();
    guard.get(key).cloned()
}
/// Sets (or overwrites) a single format kwarg on this node only.
pub fn set_format_kwarg(&self, key: &str, value: &str) {
    self.format_kwargs
        .write()
        .unwrap()
        .insert(key.to_owned(), value.to_owned());
}
/// Merges `kwargs` into this node's format kwargs, overwriting existing keys.
pub fn set_format_kwargs(&self, kwargs: &std::collections::HashMap<String, String>) {
    let mut current = self.format_kwargs.write().unwrap();
    current.extend(kwargs.iter().map(|(k, v)| (k.clone(), v.clone())));
}
/// Returns this node's own value for `key` (ancestors are not consulted).
pub fn get_format_kwarg(&self, key: &str) -> Option<String> {
    let guard = self.format_kwargs.read().unwrap();
    guard.get(key).cloned()
}
/// Returns a copy of this node's own format kwargs.
pub fn get_format_kwargs(&self) -> std::collections::HashMap<String, String> {
    let guard = self.format_kwargs.read().unwrap();
    (*guard).clone()
}
/// Merges `kwargs` into this node's format kwargs and, when `propagate` is
/// true, into every reachable ancestor's as well (walking upwards).
pub fn update_format_kwargs(
    self: &Arc<Self>,
    kwargs: &std::collections::HashMap<String, String>,
    propagate: bool,
) {
    let mut target = Some(Arc::clone(self));
    while let Some(node) = target {
        node.format_kwargs
            .write()
            .unwrap()
            .extend(kwargs.iter().map(|(k, v)| (k.clone(), v.clone())));
        target = if propagate { node.parent() } else { None };
    }
}
pub fn formatted_text(&self) -> Option<String> {
let text = self.message.content.get_text()?;
Some(self.format_string(text))
}
pub fn format_string(&self, template: &str) -> String {
let kwargs = self.format_kwargs.read().unwrap();
let mut result = template.to_string();
for (key, value) in kwargs.iter() {
let placeholder = format!("{{{}}}", key);
result = result.replace(&placeholder, value);
}
result
}
pub fn formatted_thread(&self) -> Vec<Message> {
let mut all_kwargs = std::collections::HashMap::new();
self.collect_format_kwargs(&mut all_kwargs);
self.thread()
.into_iter()
.map(|mut msg| {
match &msg.content {
MessageContent::Text(text) => {
let mut formatted = text.clone();
for (key, value) in &all_kwargs {
let placeholder = format!("{{{}}}", key);
formatted = formatted.replace(&placeholder, value);
}
msg.content = MessageContent::Text(formatted);
}
MessageContent::Parts(parts) => {
let formatted_parts: Vec<_> = parts
.iter()
.map(|part| {
if let Some(text) = part.as_text() {
let mut formatted = text.to_string();
for (key, value) in &all_kwargs {
let placeholder = format!("{{{}}}", key);
formatted = formatted.replace(&placeholder, value);
}
ContentPart::text(formatted)
} else {
part.clone()
}
})
.collect();
msg.content = MessageContent::Parts(formatted_parts);
}
}
msg
})
.collect()
}
fn collect_format_kwargs(&self, kwargs: &mut std::collections::HashMap<String, String>) {
if let Some(parent) = self.parent() {
parent.collect_format_kwargs(kwargs);
}
let my_kwargs = self.format_kwargs.read().unwrap();
for (k, v) in my_kwargs.iter() {
kwargs.insert(k.clone(), v.clone());
}
}
/// Completes the conversation ending at this node with the client returned by
/// `global_client()`; otherwise identical to `complete_with_client`.
pub async fn complete(
    self: &Arc<Self>,
    generator: &GeneratorInfo,
    params: Option<&NodeCompletionParameters>,
) -> Result<Arc<ChatNode>> {
    self.complete_with_client(global_client(), generator, params)
        .await
}
/// Requests a completion for the conversation ending at this node using
/// `client`, and attaches the assistant reply as a new child of this node.
///
/// The request is `formatted_thread()` with contiguous messages merged. When
/// `params` is given:
/// - `system_prompt` is inserted first unless the thread already starts with a
///   system message;
/// - `force_prepend` is sent as a trailing assistant message (prefill), and is
///   prepended to the reply unless the reply already starts with it;
/// - `crash_on_empty_response` rejects replies whose generated text (excluding
///   an echoed prefill) is blank;
/// - `parse_json` runs the reply through `process_json_response`;
/// - `retry`, `exp_back_off`, `back_off_time`, `max_back_off` control retrying
///   (defaults when `params` is `None`: 4 retries, fixed 1.0 s delay, 15.0 s cap).
///
/// Response id, model, usage and finish reason are stored in the new node's
/// metadata. `cost_callback` is invoked when the response reports usage.
///
/// # Errors
/// Client errors, rejected empty replies and JSON post-processing failures are
/// all retried. After the last attempt, the most recent error is returned.
pub async fn complete_with_client(
    self: &Arc<Self>,
    client: &LLMClient,
    generator: &GeneratorInfo,
    params: Option<&NodeCompletionParameters>,
) -> Result<Arc<ChatNode>> {
    use crate::provider::{CostInfo, CostTrackingType};

    let mut messages = merge_contiguous_messages(self.formatted_thread());
    if let Some(p) = params {
        if let Some(system) = &p.system_prompt {
            if messages.first().map(|m| m.role) != Some(Role::System) {
                messages.insert(0, Message::system(system.clone()));
            }
        }
        // Prefill: the model continues from this partial assistant turn.
        if let Some(prepend) = &p.force_prepend {
            messages.push(Message::assistant(prepend.clone()));
        }
    }

    let completion_params = params
        .and_then(|p| p.params.as_ref())
        .map(|p| generator.default_params.merge(p))
        .unwrap_or_else(|| generator.default_params.clone());
    let max_retries = params.map(|p| p.retry).unwrap_or(4);
    let exp_back_off = params.map(|p| p.exp_back_off).unwrap_or(false);
    let back_off_time = params.map(|p| p.back_off_time).unwrap_or(1.0);
    let max_back_off = params.map(|p| p.max_back_off).unwrap_or(15.0);
    let parse_json = params.map(|p| p.parse_json).unwrap_or(false);
    let crash_on_refusal = params.map(|p| p.crash_on_refusal).unwrap_or(false);
    let crash_on_empty = params.map(|p| p.crash_on_empty_response).unwrap_or(false);
    let force_prepend = params.and_then(|p| p.force_prepend.clone());
    let cost_tracking = params
        .map(|p| p.cost_tracking)
        .unwrap_or(CostTrackingType::None);
    let cost_callback = params.and_then(|p| p.cost_callback.clone());
    let include_usage = cost_tracking == CostTrackingType::OpenRouter;

    let mut last_error: Option<MiniLLMError> = None;
    let mut current_back_off = back_off_time;
    for attempt in 0..=max_retries {
        if attempt > 0 {
            let sleep_time = if exp_back_off {
                current_back_off.min(max_back_off)
            } else {
                back_off_time
            };
            // `Duration::from_secs_f64` panics on negative or NaN input; clamp so a
            // misconfigured delay degrades to an immediate retry instead of a panic.
            tokio::time::sleep(Duration::from_secs_f64(sleep_time.max(0.0))).await;
            if exp_back_off {
                current_back_off *= 2.0;
            }
            tracing::debug!(attempt = attempt, "Retrying completion request");
        }

        let response = match client
            .complete_with_usage_tracking(generator, &messages, &completion_params, include_usage)
            .await
        {
            Ok(r) => r,
            Err(e) => {
                last_error = Some(e);
                continue;
            }
        };

        // Judge emptiness on what the model generated. Checking after the
        // prefill is prepended would let an empty reply through whenever
        // `force_prepend` is non-empty.
        let generated = force_prepend
            .as_deref()
            .and_then(|prefix| response.content.strip_prefix(prefix))
            .unwrap_or(response.content.as_str());
        if crash_on_empty && generated.trim().is_empty() {
            last_error = Some(MiniLLMError::Other("Empty response from model".to_string()));
            continue;
        }

        let mut content = response.content.clone();
        if let Some(ref prepend) = force_prepend {
            if !content.starts_with(prepend) {
                content = format!("{}{}", prepend, content);
            }
        }
        if parse_json {
            match self.process_json_response(&content, crash_on_refusal) {
                Ok(parsed) => content = parsed,
                Err(e) => {
                    last_error = Some(e);
                    continue;
                }
            }
        }

        let assistant_node = Self::new(Message::assistant(content));
        assistant_node.set_metadata("response_id", serde_json::json!(response.id));
        assistant_node.set_metadata("model", serde_json::json!(response.model));
        if let Some(usage) = &response.usage {
            assistant_node.set_metadata("usage", serde_json::json!(usage));
        }
        if let Some(finish_reason) = &response.finish_reason {
            assistant_node.set_metadata("finish_reason", serde_json::json!(finish_reason));
        }
        if let Some(ref callback) = cost_callback {
            if let Some(usage) = &response.usage {
                let cost_info = CostInfo {
                    cost: usage.cost.unwrap_or(0.0),
                    prompt_tokens: usage.prompt_tokens,
                    completion_tokens: usage.completion_tokens,
                    total_tokens: usage.total_tokens,
                    cached_tokens: usage.cached_tokens,
                    reasoning_tokens: usage.reasoning_tokens,
                    model: response.model.clone(),
                    response_id: response.id.clone(),
                };
                callback(cost_info);
            }
        }
        return Ok(self.add_child(assistant_node));
    }
    Err(last_error.unwrap_or_else(|| MiniLLMError::Other("Max retries exceeded".to_string())))
}
/// Runs a reply that is expected to contain JSON through `repair_json` and
/// returns the repaired text.
///
/// With `crash_on_refusal`, fails when the reply contains neither `{` nor `[`,
/// or when the repaired result is empty, `""` or `{}`.
fn process_json_response(&self, content: &str, crash_on_refusal: bool) -> Result<String> {
    use crate::json_repair::{repair_json, RepairOptions};

    let has_json_opening = content.contains('{') || content.contains('[');
    if crash_on_refusal && !has_json_opening {
        return Err(MiniLLMError::Other(format!(
            "No JSON found in response: {}",
            content
        )));
    }

    let repaired = repair_json(content, &RepairOptions::default())?;
    let looks_empty = matches!(repaired.as_str(), "" | "\"\"" | "{}");
    if crash_on_refusal && looks_empty {
        return Err(MiniLLMError::Other(format!(
            "Empty JSON in response: {}",
            content
        )));
    }
    Ok(repaired)
}
/// Starts a streaming completion with the client returned by `global_client()`;
/// otherwise identical to `complete_streaming_with_client`.
pub async fn complete_streaming(
    self: &Arc<Self>,
    generator: &GeneratorInfo,
    params: Option<&NodeCompletionParameters>,
) -> Result<StreamingCompletion> {
    let shared_client = global_client();
    self.complete_streaming_with_client(shared_client, generator, params).await
}
/// Starts a streaming completion for the conversation ending at this node.
///
/// The request is built like the non-streaming path:
/// - `formatted_thread()` with contiguous messages merged;
/// - `system_prompt` inserted when the thread lacks a leading system message;
/// - `force_prepend` sent as a trailing assistant prefill, so the model
///   continues from it.
///
/// `complete_streaming_collect` prepends the prefill to the collected text
/// unless the text already starts with it. Nothing is attached to the tree here.
///
/// # Errors
/// Propagates errors returned by the client when opening the stream.
pub async fn complete_streaming_with_client(
    self: &Arc<Self>,
    client: &LLMClient,
    generator: &GeneratorInfo,
    params: Option<&NodeCompletionParameters>,
) -> Result<StreamingCompletion> {
    use crate::provider::CostTrackingType;

    let mut messages = merge_contiguous_messages(self.formatted_thread());
    if let Some(p) = params {
        if let Some(system) = &p.system_prompt {
            if messages.first().map(|m| m.role) != Some(Role::System) {
                messages.insert(0, Message::system(system.clone()));
            }
        }
        // Send the prefill as the non-streaming path does; without it the
        // model starts from scratch and the later prepend produces mismatched text.
        if let Some(prepend) = &p.force_prepend {
            messages.push(Message::assistant(prepend.clone()));
        }
    }
    let completion_params = params
        .and_then(|p| p.params.as_ref())
        .map(|p| generator.default_params.merge(p))
        .unwrap_or_else(|| generator.default_params.clone());
    let cost_tracking = params
        .map(|p| p.cost_tracking)
        .unwrap_or(CostTrackingType::None);
    let include_usage = cost_tracking == CostTrackingType::OpenRouter;
    client
        .complete_streaming_with_usage(generator, &messages, &completion_params, include_usage)
        .await
}
pub async fn complete_streaming_collect(
self: &Arc<Self>,
generator: &GeneratorInfo,
params: Option<&NodeCompletionParameters>,
) -> Result<Arc<ChatNode>> {
let stream = self.complete_streaming(generator, params).await?;
let response = stream.collect().await?;
let parse_json = params.map(|p| p.parse_json).unwrap_or(false);
let crash_on_refusal = params.map(|p| p.crash_on_refusal).unwrap_or(false);
let force_prepend = params.and_then(|p| p.force_prepend.clone());
let mut content = response.content;
if let Some(ref prepend) = force_prepend {
if !content.starts_with(prepend) {
content = format!("{}{}", prepend, content);
}
}
if parse_json {
content = self.process_json_response(&content, crash_on_refusal)?;
}
let assistant_node = Self::new(Message::assistant(content));
assistant_node.set_metadata("response_id", serde_json::json!(response.id));
assistant_node.set_metadata("model", serde_json::json!(response.model));
if let Some(usage) = &response.usage {
assistant_node.set_metadata("usage", serde_json::json!(usage));
}
if let Some(p) = params {
if let Some(ref callback) = p.cost_callback {
if let Some(usage) = &response.usage {
use crate::provider::CostInfo;
let cost_info = CostInfo {
cost: usage.cost.unwrap_or(0.0),
prompt_tokens: usage.prompt_tokens,
completion_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
cached_tokens: usage.cached_tokens,
reasoning_tokens: usage.reasoning_tokens,
model: response.model.clone(),
response_id: response.id.clone(),
};
callback(cost_info);
}
}
}
Ok(self.add_child(assistant_node))
}
/// Appends a user message under this node, then completes it with the global
/// client and default parameters. Returns the new assistant node.
pub async fn chat(
    self: &Arc<Self>,
    user_message: impl Into<MessageContent>,
    generator: &GeneratorInfo,
) -> Result<Arc<ChatNode>> {
    let question = self.add_user(user_message);
    question.complete(generator, None).await
}
/// Appends a user message under this node and starts a streaming completion
/// for it with default parameters. Returns the user node together with the stream.
pub async fn chat_streaming(
    self: &Arc<Self>,
    user_message: impl Into<MessageContent>,
    generator: &GeneratorInfo,
) -> Result<(Arc<ChatNode>, StreamingCompletion)> {
    let question = self.add_user(user_message);
    match question.complete_streaming(generator, None).await {
        Ok(stream) => Ok((question, stream)),
        Err(e) => Err(e),
    }
}
pub fn text(&self) -> Option<&str> {
self.message.text()
}
pub fn role(&self) -> Role {
self.message.role
}
}
/// Compact debug view: id, message role and number of children. Message
/// content and tree links are deliberately omitted.
impl std::fmt::Debug for ChatNode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = f.debug_struct("ChatNode");
        view.field("id", &self.id).field("role", &self.message.role);
        view.field("children_count", &self.child_count());
        view.finish()
    }
}
/// Role prefixes and message separator used by `pretty_messages`.
#[derive(Debug, Clone)]
pub struct PrettyPrintConfig {
    /// Text written before system messages.
    pub system_prefix: String,
    /// Text written before user messages.
    pub user_prefix: String,
    /// Text written before assistant messages.
    pub assistant_prefix: String,
    /// Text written between consecutive messages (nothing when empty).
    pub separator: String,
}

impl Default for PrettyPrintConfig {
    /// `SYSTEM: ` / `\n\nUSER: ` / `\n\nASSISTANT: ` prefixes and no separator.
    fn default() -> Self {
        Self::new("SYSTEM: ", "\n\nUSER: ", "\n\nASSISTANT: ")
    }
}

impl PrettyPrintConfig {
    /// Builds a config with the given role prefixes and an empty separator.
    pub fn new(system: &str, user: &str, assistant: &str) -> Self {
        Self {
            system_prefix: String::from(system),
            user_prefix: String::from(user),
            assistant_prefix: String::from(assistant),
            separator: String::new(),
        }
    }

    /// Returns this config with `separator` replaced by `sep`.
    pub fn with_separator(self, sep: &str) -> Self {
        Self {
            separator: sep.to_owned(),
            ..self
        }
    }
}
/// Renders the formatted thread ending at `node` as one string.
///
/// Each message is written as its role prefix followed by its text.
/// `[multimodal content]` is written when the content has no text, and
/// `config.separator` goes between messages when non-empty. Uses
/// `PrettyPrintConfig::default()` when `config` is `None`. Tool messages always
/// use the fixed prefix `\n\nTOOL: `.
pub fn pretty_messages(node: &Arc<ChatNode>, config: Option<&PrettyPrintConfig>) -> String {
    let fallback = PrettyPrintConfig::default();
    let cfg = config.unwrap_or(&fallback);
    let mut out = String::new();
    for (index, msg) in node.formatted_thread().into_iter().enumerate() {
        if index > 0 && !cfg.separator.is_empty() {
            out.push_str(&cfg.separator);
        }
        let prefix: &str = match msg.role {
            Role::System => cfg.system_prefix.as_str(),
            Role::User => cfg.user_prefix.as_str(),
            Role::Assistant => cfg.assistant_prefix.as_str(),
            Role::Tool => "\n\nTOOL: ",
        };
        out.push_str(prefix);
        out.push_str(msg.content.get_text().unwrap_or("[multimodal content]"));
    }
    out
}
pub fn format_conversation(node: &Arc<ChatNode>) -> String {
pretty_messages(node, None)
}
/// Fluent builder for a linear conversation that starts with a system prompt.
///
/// The builder holds strong handles to both the root and the most recently
/// added node. `ChatNode` stores its parent link as a `Weak` reference
/// (children are owned by their parent, not the reverse). A node therefore does
/// not keep its ancestors alive: something must keep holding the root for
/// `thread()` to see the whole history.
pub struct ConversationBuilder {
    root: Arc<ChatNode>,
    current: Arc<ChatNode>,
}

impl ConversationBuilder {
    /// Starts a conversation whose root holds `system_prompt`.
    pub fn new(system_prompt: impl Into<String>) -> Self {
        let root = ChatNode::root(system_prompt);
        Self {
            current: root.clone(),
            root,
        }
    }

    /// Appends a user message after the current node and makes it current.
    pub fn user(mut self, content: impl Into<MessageContent>) -> Self {
        self.current = self.current.add_user(content);
        self
    }

    /// Appends an assistant message after the current node and makes it current.
    pub fn assistant(mut self, content: impl Into<MessageContent>) -> Self {
        self.current = self.current.add_assistant(content);
        self
    }

    /// Returns a handle to the root node.
    pub fn root(&self) -> Arc<ChatNode> {
        self.root.clone()
    }

    /// Returns a handle to the most recently added node.
    pub fn current(&self) -> Arc<ChatNode> {
        self.current.clone()
    }

    /// Consumes the builder and returns the most recently added node.
    ///
    /// # Warning
    /// This drops the builder's handle to the root. Ancestors of the returned
    /// node that nothing else keeps alive (e.g. a handle obtained earlier via
    /// `root()`) are freed, so `thread()` on the result will no longer include
    /// them. Prefer `into_parts`, which keeps the root.
    pub fn build(self) -> Arc<ChatNode> {
        self.current
    }

    /// Consumes the builder and returns `(root, current)`.
    ///
    /// Holding the returned root keeps the whole conversation alive. The tuple
    /// order matches `ChatNode::from_thread_data` and `ChatNode::from_messages`.
    pub fn into_parts(self) -> (Arc<ChatNode>, Arc<ChatNode>) {
        (self.root, self.current)
    }
}
use serde::{Deserialize, Serialize};
/// Serializable snapshot of a linear conversation thread. `ChatNode::save_thread`
/// writes it; `ChatNode::from_thread_file` / `from_thread_json` read it back.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreadData {
/// Messages in conversation order, from the root to the leaf.
pub prompts: Vec<ThreadMessage>,
/// Format kwargs for `{key}` placeholders. `Some(value)` entries are applied
/// when loading; `None` marks a placeholder without a value (see
/// `get_placeholder_keys`). Defaults to an empty map when absent from the JSON.
#[serde(default)]
pub required_kwargs: std::collections::HashMap<String, Option<String>>,
}
impl ThreadData {
    /// Returns the kwargs that have a value, dropping `None` placeholders.
    pub fn get_kwargs(&self) -> std::collections::HashMap<String, String> {
        let mut resolved = std::collections::HashMap::new();
        for (name, maybe_value) in &self.required_kwargs {
            if let Some(value) = maybe_value {
                resolved.insert(name.clone(), value.clone());
            }
        }
        resolved
    }

    /// Returns the names of kwargs that have no value yet (map iteration order).
    pub fn get_placeholder_keys(&self) -> Vec<String> {
        self.required_kwargs
            .iter()
            .filter(|(_, value)| value.is_none())
            .map(|(name, _)| name.clone())
            .collect()
    }
}
/// One serialized message of a `ThreadData`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreadMessage {
/// Role name. `from_thread_data` accepts `"system"`, `"user"`, `"assistant"`
/// and `"tool"` and rejects anything else.
pub role: String,
/// Message text. `to_thread_data` writes an empty string when the content
/// has no text.
pub content: String,
/// Optional image attachments; omitted from the JSON when `None`.
/// NOTE(review): `to_thread_data` always writes `None` and `from_thread_data`
/// does not restore this field.
#[serde(skip_serializing_if = "Option::is_none")]
pub image_data: Option<ThreadImageData>,
/// Optional audio attachments; omitted from the JSON when `None`.
/// NOTE(review): like `image_data`, not written or restored by this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub audio_data: Option<ThreadAudioData>,
}
/// Image attachments of a `ThreadMessage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreadImageData {
/// Image references; presumably paths, URLs or encoded data — TODO confirm.
pub images: Vec<String>,
}
/// Audio attachments of a `ThreadMessage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreadAudioData {
/// Audio file paths; empty when absent from the JSON.
#[serde(default)]
pub audio_paths: Vec<String>,
/// String-to-string audio id mapping; empty when absent from the JSON.
/// NOTE(review): key/value meaning is not visible here — confirm with producers.
#[serde(default)]
pub audio_ids: std::collections::HashMap<String, String>,
}
impl ChatNode {
pub fn save_thread(&self, path: &str) -> Result<()> {
let thread_data = self.to_thread_data();
let json = serde_json::to_string_pretty(&thread_data)
.map_err(|e| MiniLLMError::Other(format!("Failed to serialize thread: {}", e)))?;
std::fs::write(path, json)
.map_err(|e| MiniLLMError::Other(format!("Failed to write file: {}", e)))?;
Ok(())
}
pub fn to_thread_data(&self) -> ThreadData {
let messages = self.thread();
let prompts: Vec<ThreadMessage> = messages
.iter()
.map(|msg| ThreadMessage {
role: msg.role.as_str().to_string(),
content: msg.content.get_text().unwrap_or("").to_string(),
image_data: None, audio_data: None, })
.collect();
let mut all_kwargs = std::collections::HashMap::new();
self.collect_format_kwargs(&mut all_kwargs);
let required_kwargs = all_kwargs.into_iter().map(|(k, v)| (k, Some(v))).collect();
ThreadData {
prompts,
required_kwargs,
}
}
pub fn from_thread_file(path: &str) -> Result<(Arc<ChatNode>, Arc<ChatNode>)> {
let json = std::fs::read_to_string(path)
.map_err(|e| MiniLLMError::Other(format!("Failed to read file: {}", e)))?;
Self::from_thread_json(&json)
}
pub fn from_thread_json(json: &str) -> Result<(Arc<ChatNode>, Arc<ChatNode>)> {
let thread_data: ThreadData = serde_json::from_str(json)
.map_err(|e| MiniLLMError::Other(format!("Failed to parse thread JSON: {}", e)))?;
Self::from_thread_data(&thread_data)
}
pub fn from_thread_data(data: &ThreadData) -> Result<(Arc<ChatNode>, Arc<ChatNode>)> {
if data.prompts.is_empty() {
return Err(MiniLLMError::Other("Thread has no messages".to_string()));
}
let mut root: Option<Arc<ChatNode>> = None;
let mut current: Option<Arc<ChatNode>> = None;
for msg in &data.prompts {
let role = match msg.role.as_str() {
"system" => Role::System,
"user" => Role::User,
"assistant" => Role::Assistant,
"tool" => Role::Tool,
_ => return Err(MiniLLMError::Other(format!("Unknown role: {}", msg.role))),
};
let message = Message {
role,
content: MessageContent::text(&msg.content),
name: None,
tool_call_id: None,
tool_calls: None,
};
let node = ChatNode::new(message);
current = Some(match current {
Some(parent) => parent.add_child(node),
None => {
root = Some(node.clone());
node
}
});
}
if let Some(ref root_node) = root {
root_node.set_format_kwargs(&data.get_kwargs());
}
Ok((root.unwrap(), current.unwrap()))
}
pub fn from_messages(messages: &[Message]) -> Result<(Arc<ChatNode>, Arc<ChatNode>)> {
if messages.is_empty() {
return Err(MiniLLMError::Other("No messages provided".to_string()));
}
let mut root: Option<Arc<ChatNode>> = None;
let mut current: Option<Arc<ChatNode>> = None;
for msg in messages {
let node = ChatNode::new(msg.clone());
current = Some(match current {
Some(parent) => parent.add_child(node),
None => {
root = Some(node.clone());
node
}
});
}
Ok((root.unwrap(), current.unwrap()))
}
}