use crate::error::{MiniLLMError, Result};
use crate::generator::{GeneratorInfo, NodeCompletionParameters};
use crate::json_repair::{loads, repair_json, JsonValue, RepairOptions};
use crate::message::{merge_contiguous_messages, ContentPart, Message, MessageContent, Role};
use crate::provider::{global_client, CompletionResponse, StreamingCompletion};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use uuid::Uuid;
/// Every knob a single completion call needs, resolved from an optional
/// [`NodeCompletionParameters`] so the request pipeline never branches on `None`.
struct CompletionSettings {
    // Request shaping.
    /// Per-call parameter overrides, combined with the generator's defaults.
    params: Option<crate::generator::CompletionParameters>,
    /// System prompt inserted first when the thread does not start with one.
    system_prompt: Option<String>,
    /// Text sent as a trailing assistant message and re-attached to the reply.
    force_prepend: Option<String>,
    /// Base `{key}` substitutions applied to every message of the thread.
    format_kwargs: std::collections::HashMap<String, String>,
    /// Flag the last outgoing message as a cache breakpoint.
    use_cache: bool,
    /// Per-request timeout handed to the client; `None` when unset or zero.
    timeout: Option<Duration>,

    // Reply handling.
    /// Register the reply as a real child (`true`) or a phantom one (`false`).
    add_child: bool,
    /// Run the reply text through JSON repair.
    parse_json: bool,
    /// With `parse_json`, reject replies that parse to an empty JSON value.
    crash_on_refusal: bool,
    /// Reject empty / whitespace-only replies.
    crash_on_empty: bool,

    // Retry policy.
    /// Extra attempts after the first one.
    retry: u32,
    /// Double the delay after every retry (capped by `max_back_off`).
    exp_back_off: bool,
    /// Delay between attempts, in seconds.
    back_off_time: f64,
    /// Upper bound, in seconds, for the exponential delay.
    max_back_off: f64,

    // Cost accounting.
    /// Forwarded to the client as its usage-tracking flag.
    track_cost: bool,
    /// Token price override; falls back to the generator's price.
    token_price: Option<crate::provider::TokenPrice>,
    /// Invoked with the computed cost when a successful response reports usage.
    cost_callback: Option<crate::provider::CostCallback>,
}
impl CompletionSettings {
    /// Resolve `params` (or [`NodeCompletionParameters::default`] when `None`)
    /// into owned settings. A `timeout_secs` of zero means "no timeout".
    fn from_params(params: Option<&NodeCompletionParameters>) -> Self {
        let defaults = NodeCompletionParameters::default();
        let src: &NodeCompletionParameters = match params {
            Some(given) => given,
            None => &defaults,
        };
        let timeout = match src.timeout_secs {
            Some(secs) if secs > 0 => Some(Duration::from_secs(secs)),
            _ => None,
        };
        Self {
            params: src.params.clone(),
            system_prompt: src.system_prompt.clone(),
            force_prepend: src.force_prepend.clone(),
            format_kwargs: src.format_kwargs.clone(),
            use_cache: src.use_cache,
            timeout,
            add_child: src.add_child,
            parse_json: src.parse_json,
            crash_on_refusal: src.crash_on_refusal,
            crash_on_empty: src.crash_on_empty_response,
            retry: src.retry,
            exp_back_off: src.exp_back_off,
            back_off_time: src.back_off_time,
            max_back_off: src.max_back_off,
            track_cost: src.track_cost,
            token_price: src.token_price.clone(),
            cost_callback: src.cost_callback.clone(),
        }
    }

    /// The generator's default parameters, combined with the per-call
    /// overrides through `CompletionParameters::merge` when overrides exist.
    fn merged_params(&self, generator: &GeneratorInfo) -> crate::generator::CompletionParameters {
        if let Some(overrides) = &self.params {
            return generator.default_params.merge(overrides);
        }
        generator.default_params.clone()
    }

    /// Token price to bill with: the per-call override if set, otherwise the
    /// generator's own price (which may also be absent).
    fn price<'a>(
        &'a self,
        generator: &'a GeneratorInfo,
    ) -> Option<&'a crate::provider::TokenPrice> {
        match &self.token_price {
            Some(own) => Some(own),
            None => generator.token_price.as_ref(),
        }
    }

    /// Invoke the cost callback for `response`, if a callback is configured
    /// and the response carries usage; otherwise do nothing.
    fn fire_cost_callback(&self, generator: &GeneratorInfo, response: &CompletionResponse) {
        let Some(callback) = &self.cost_callback else {
            return;
        };
        let Some(usage) = &response.usage else {
            return;
        };
        let info = generator
            .provider
            .cost_of(usage.clone(), self.price(generator))
            .into_cost_info(response.model.clone(), response.id.clone());
        callback(info);
    }
}
/// How `fetch_response` obtains a completion inside the retry loop.
#[derive(Clone, Copy, Debug)]
enum ResponseMode {
    /// One plain request/response call.
    NonStreaming,
    /// Open a stream and collect it into a single response.
    Streaming,
}
fn is_retryable(error: &MiniLLMError) -> bool {
match error {
MiniLLMError::Api { status, .. } => *status == 408 || *status == 429 || *status >= 500,
MiniLLMError::Timeout | MiniLLMError::Stream(_) => true,
MiniLLMError::Http(e) => e.is_connect() || e.is_request() || e.is_timeout(),
_ => false,
}
}
/// Repair `content` into valid JSON text.
///
/// With `crash_on_refusal`, the content is parsed first. A value that
/// `json_value_is_empty` reports as empty (`null`, `""`, `[]`, `{}`) is
/// rejected with [`MiniLLMError::NoJsonFound`].
///
/// # Errors
/// `NoJsonFound` as described above, plus any parse/repair error from the
/// JSON repair module.
fn repair_and_validate_json(content: &str, crash_on_refusal: bool) -> Result<String> {
    let options = RepairOptions::default();
    // `&&` short-circuits, so the extra parse only happens when requested.
    let refused = crash_on_refusal && json_value_is_empty(&loads(content, &options)?);
    if refused {
        return Err(MiniLLMError::NoJsonFound(content.to_string()));
    }
    let repaired = repair_json(content, &options)?;
    Ok(repaired)
}
/// Whether a parsed JSON value carries no payload: `null`, or an empty
/// string, array, or object. Every other variant counts as non-empty.
fn json_value_is_empty(value: &JsonValue) -> bool {
    matches!(value, JsonValue::Null)
        || matches!(value, JsonValue::String(text) if text.is_empty())
        || matches!(value, JsonValue::Array(items) if items.is_empty())
        || matches!(value, JsonValue::Object(fields) if fields.is_empty())
}
fn tracked_params(params: Option<&NodeCompletionParameters>) -> NodeCompletionParameters {
let mut p = params.cloned().unwrap_or_default();
p.track_cost = true;
p.cost_callback = None;
p
}
/// Like [`tracked_params`], but also fills in a 120-second timeout when the
/// caller left `timeout_secs` unset or set it to zero.
fn tracked_streaming_params(params: Option<&NodeCompletionParameters>) -> NodeCompletionParameters {
    // Fallback timeout in seconds for tracked streams.
    const DEFAULT_TRACKED_IDLE_TIMEOUT_SECS: u64 = 120;
    let mut tracked = tracked_params(params);
    let has_timeout = matches!(tracked.timeout_secs, Some(secs) if secs != 0);
    if !has_timeout {
        tracked.timeout_secs = Some(DEFAULT_TRACKED_IDLE_TIMEOUT_SECS);
    }
    tracked
}
/// Substitute `{key}` placeholders in `template` with values from `kwargs`.
///
/// * A placeholder is the text between a `{` and the next `}`. It is replaced
///   only when that exact text is a key of `kwargs`.
/// * Unknown placeholders are kept verbatim, braces included.
/// * A `{` with no later `}`, and any stray `}`, are copied unchanged.
/// * There is no escaping: `{{x}}` looks up the key `{x`.
///
/// ```text
/// "Hi {name}, {unknown}!" with {name: "Ada"}  ->  "Hi Ada, {unknown}!"
/// ```
fn apply_kwargs(template: &str, kwargs: &std::collections::HashMap<String, String>) -> String {
    let mut out = String::with_capacity(template.len());
    let mut remaining = template;
    loop {
        let Some(brace) = remaining.find('{') else {
            out.push_str(remaining);
            break;
        };
        // `tail` begins with the '{' just found; both braces are ASCII, so all
        // the byte offsets below fall on char boundaries.
        let (literal, tail) = remaining.split_at(brace);
        out.push_str(literal);
        match tail[1..].find('}') {
            None => {
                out.push_str(tail);
                break;
            }
            Some(end) => {
                let name = &tail[1..1 + end];
                if let Some(replacement) = kwargs.get(name) {
                    out.push_str(replacement);
                } else {
                    out.push_str(&tail[..end + 2]);
                }
                remaining = &tail[end + 2..];
            }
        }
    }
    out
}
/// One arena entry: a message, its per-node annotations, and its tree links.
struct NodeData {
    // Tree structure.
    /// Id of the parent entry; `None` for a root or a detached subtree root.
    parent: Option<String>,
    /// Ids of registered children, in insertion order.
    children: Vec<String>,
    /// Number of entries whose `parent` points here but that are not listed in
    /// `children` (phantom replies).
    phantom_child_count: usize,
    /// Number of live `ChatNode` handles for this entry. An entry with zero
    /// handles, no children and no phantom children is reclaimed.
    refcount: usize,

    // Payload.
    message: Message,
    metadata: serde_json::Value,
    format_kwargs: std::collections::HashMap<String, String>,
    cache_breakpoint: bool,
}
/// Shared arena for one conversation tree; every `ChatNode` handle into the
/// tree holds an `Arc` to it.
struct Tree {
// All entries of the tree, keyed by node id, behind a single lock.
nodes: RwLock<std::collections::HashMap<String, NodeData>>,
}
impl Tree {
fn with_root(message: Message) -> (Arc<Tree>, String) {
let id = Uuid::new_v4().to_string();
let mut nodes = std::collections::HashMap::new();
nodes.insert(id.clone(), NodeData::new(message, None));
(
Arc::new(Tree {
nodes: RwLock::new(nodes),
}),
id,
)
}
fn release(&self, id: &str) {
let mut nodes = self.nodes.write().unwrap();
if let Some(node) = nodes.get_mut(id) {
node.refcount = node.refcount.saturating_sub(1);
}
Self::reclaim_dead(&mut nodes, id.to_string());
}
fn reclaim_dead(nodes: &mut NodeMap, start: String) {
let mut current = Some(start);
while let Some(cur) = current {
let Some(node) = nodes.get(&cur) else {
break;
};
if node.refcount > 0 || !node.children.is_empty() || node.phantom_child_count > 0 {
break;
}
let parent = node.parent.clone();
nodes.remove(&cur);
if let Some(pid) = &parent {
if let Some(p) = nodes.get_mut(pid) {
let before = p.children.len();
p.children.retain(|c| c != &cur);
if p.children.len() == before {
p.phantom_child_count = p.phantom_child_count.checked_sub(1).expect(
"phantom_child_count underflow: registered/phantom invariant broken",
);
}
}
}
current = parent; }
}
fn unlink_from_parent(nodes: &mut NodeMap, id: &str) {
let parent_id = nodes.get(id).and_then(|n| n.parent.clone());
if let Some(pid) = &parent_id {
if let Some(p) = nodes.get_mut(pid) {
let before = p.children.len();
p.children.retain(|c| c != id);
if p.children.len() == before {
p.phantom_child_count = p.phantom_child_count.checked_sub(1).expect(
"phantom_child_count underflow: registered/phantom invariant broken",
);
}
}
}
if let Some(n) = nodes.get_mut(id) {
n.parent = None;
}
if let Some(pid) = parent_id {
Self::reclaim_dead(nodes, pid);
}
}
fn is_self_or_ancestor_locked(nodes: &NodeMap, start: &str, target: &str) -> bool {
let mut cur = Some(start.to_string());
while let Some(c) = cur {
if c == target {
return true;
}
cur = nodes.get(&c).and_then(|n| n.parent.clone());
}
false
}
}
/// Node-id → entry storage of one tree.
type NodeMap = std::collections::HashMap<String, NodeData>;

/// Write-lock two *distinct* trees in a deadlock-free order.
///
/// The locks are always taken in ascending address order of the `Tree`
/// allocations, so two callers locking the same pair with swapped arguments
/// cannot deadlock. The guards are returned in argument order.
///
/// # Panics
/// If `first` and `second` are the same tree; locking it twice would deadlock
/// or panic.
fn lock_two_write<'a>(
    first: &'a Arc<Tree>,
    second: &'a Arc<Tree>,
) -> (
    std::sync::RwLockWriteGuard<'a, NodeMap>,
    std::sync::RwLockWriteGuard<'a, NodeMap>,
) {
    assert!(
        !Arc::ptr_eq(first, second),
        "lock_two_write requires two distinct trees"
    );
    let first_is_lower = Arc::as_ptr(first) < Arc::as_ptr(second);
    let (lower, upper) = if first_is_lower {
        (first, second)
    } else {
        (second, first)
    };
    let lower_guard = lower.nodes.write().unwrap();
    let upper_guard = upper.nodes.write().unwrap();
    if first_is_lower {
        (lower_guard, upper_guard)
    } else {
        (upper_guard, lower_guard)
    }
}
impl NodeData {
    /// Fresh entry holding `message` under `parent`. It starts with an empty
    /// metadata object, no format kwargs, no cache breakpoint and no children.
    /// Its refcount of 1 belongs to the handle the caller is about to create.
    fn new(message: Message, parent: Option<String>) -> Self {
        Self {
            parent,
            children: Vec::new(),
            phantom_child_count: 0,
            refcount: 1,
            message,
            metadata: serde_json::Value::Object(serde_json::Map::new()),
            format_kwargs: Default::default(),
            cache_breakpoint: false,
        }
    }
}
/// Handle to one node of a shared conversation tree.
///
/// Cloning a handle increments the node's refcount; dropping it decrements it.
/// A node left with no handles and no (registered or phantom) children is
/// removed from the tree.
pub struct ChatNode {
/// Node id inside its tree (a v4 UUID string).
pub id: String,
/// Copy of the node's message taken when this handle was created.
pub message: Message,
tree: Arc<Tree>,
}
impl Clone for ChatNode {
fn clone(&self) -> Self {
self.tree
.nodes
.write()
.unwrap()
.get_mut(&self.id)
.expect("node id present in its own tree")
.refcount += 1;
Self {
id: self.id.clone(),
message: self.message.clone(),
tree: self.tree.clone(),
}
}
}
impl Drop for ChatNode {
fn drop(&mut self) {
self.tree.release(&self.id);
}
}
impl ChatNode {
/// Run `f` against entry `id` while holding the tree's read lock.
///
/// Panics if `id` is not in this handle's tree.
fn with_node<R>(&self, id: &str, f: impl FnOnce(&NodeData) -> R) -> R {
    let guard = self.tree.nodes.read().unwrap();
    let entry = guard.get(id).expect("node id present in its own tree");
    f(entry)
}
/// Run `f` against entry `id` while holding the tree's write lock.
/// `f` must not take the tree lock again.
///
/// Panics if `id` is not in this handle's tree.
fn with_node_mut<R>(&self, id: &str, f: impl FnOnce(&mut NodeData) -> R) -> R {
    let mut guard = self.tree.nodes.write().unwrap();
    let entry = guard.get_mut(id).expect("node id present in its own tree");
    f(entry)
}
fn handle(&self, id: String) -> ChatNode {
let message = {
let mut nodes = self.tree.nodes.write().unwrap();
let n = nodes.get_mut(&id).expect("node id present in its own tree");
n.refcount += 1;
n.message.clone()
};
ChatNode {
id,
message,
tree: self.tree.clone(),
}
}
fn handle_owned(&self, id: String, message: Message) -> ChatNode {
ChatNode {
id,
message,
tree: self.tree.clone(),
}
}
/// Number of entries currently stored in the arena, including phantom and
/// detached ones. Test-only introspection.
#[cfg(test)]
fn arena_len(&self) -> usize {
    let arena = self.tree.nodes.read().unwrap();
    arena.len()
}
/// Create an entry holding `message` as a registered child of `parent_id`
/// (appended after existing children) and return its handle.
///
/// Panics if `parent_id` is not in this tree.
fn insert_child(&self, parent_id: &str, message: Message) -> ChatNode {
    let child_id = Uuid::new_v4().to_string();
    let entry = NodeData::new(message.clone(), Some(parent_id.to_owned()));
    {
        let mut arena = self.tree.nodes.write().unwrap();
        arena.insert(child_id.clone(), entry);
        let parent = arena.get_mut(parent_id).expect("parent id present");
        parent.children.push(child_id.clone());
    }
    self.handle_owned(child_id, message)
}
/// Deep-copy the registered subtree rooted at `src` (which lives in another
/// tree) into this tree, as a new registered child of `parent_id`.
///
/// Messages, metadata, format kwargs and cache-breakpoint flags are copied.
/// Phantom children of the source are not. Every copy gets a fresh id and a
/// refcount of 0. Returns the id of the copied root.
///
/// NOTE(review): because copies start at refcount 0, a temporary handle to a
/// copied *leaf* (for example one produced by `children()` or the `iter_*`
/// methods) reclaims that leaf as soon as it is dropped. Traversing the copy
/// can therefore shrink it. Confirm whether copies should be pinned instead.
///
/// Panics if `parent_id` or a source entry is missing.
fn copy_subtree_under(&self, parent_id: &str, src: &ChatNode) -> String {
    debug_assert!(
        !Arc::ptr_eq(&self.tree, &src.tree),
        "copy_subtree_under must be cross-tree (same-tree would deadlock)"
    );
    let copied_root = Uuid::new_v4().to_string();
    let (src_arena, mut arena) = lock_two_write(&src.tree, &self.tree);
    // Work items: (source id, parent id in this tree, id assigned in this tree).
    let mut pending: Vec<(String, String, String)> =
        vec![(src.id.clone(), parent_id.to_string(), copied_root.clone())];
    while let Some((src_id, new_parent, new_id)) = pending.pop() {
        let original = src_arena.get(&src_id).expect("src node present");
        let mut copied_children = Vec::with_capacity(original.children.len());
        for src_child in &original.children {
            let child_copy = Uuid::new_v4().to_string();
            copied_children.push(child_copy.clone());
            pending.push((src_child.clone(), new_id.clone(), child_copy));
        }
        arena.insert(
            new_id,
            NodeData {
                parent: Some(new_parent),
                children: copied_children,
                phantom_child_count: 0,
                refcount: 0,
                message: original.message.clone(),
                metadata: original.metadata.clone(),
                format_kwargs: original.format_kwargs.clone(),
                cache_breakpoint: original.cache_breakpoint,
            },
        );
    }
    // Only the copied root is linked into the existing part of this tree.
    arena
        .get_mut(parent_id)
        .expect("parent id present")
        .children
        .push(copied_root.clone());
    copied_root
}
pub fn root(system_prompt: impl Into<String>) -> Self {
Self::new(Message::system(system_prompt.into()))
}
pub fn new(message: Message) -> Self {
let (tree, id) = Tree::with_root(message.clone());
Self { id, message, tree }
}
pub fn user(content: impl Into<MessageContent>) -> Self {
Self::new(Message::user(content))
}
pub fn assistant(content: impl Into<MessageContent>) -> Self {
Self::new(Message::assistant(content))
}
/// Attach `child` under this node and return a handle to the attached node.
///
/// * Different tree: the registered subtree rooted at `child` is deep-copied
///   into this tree. The returned handle points at the copy; `child` is
///   dropped.
/// * Same tree: `child` is *moved*. It is unlinked from its old parent (which
///   may then be reclaimed) and appended to this node's children. The returned
///   handle is `child` itself.
///
/// # Errors
/// [`MiniLLMError::InvalidParameter`] when `child` is this node or one of its
/// ancestors, because the move would create a cycle.
pub fn add_child(&self, child: ChatNode) -> Result<ChatNode> {
    if !Arc::ptr_eq(&self.tree, &child.tree) {
        let copy_id = self.copy_subtree_under(&self.id, &child);
        return Ok(self.handle(copy_id));
    }
    {
        let mut arena = self.tree.nodes.write().unwrap();
        if Tree::is_self_or_ancestor_locked(&arena, &self.id, &child.id) {
            return Err(MiniLLMError::InvalidParameter(
                "add_child would create a cycle: the child is this node or one of its ancestors"
                    .to_string(),
            ));
        }
        Tree::unlink_from_parent(&mut arena, &child.id);
        arena.get_mut(&child.id).expect("child present").parent = Some(self.id.clone());
        arena
            .get_mut(&self.id)
            .expect("self present")
            .children
            .push(child.id.clone());
    }
    Ok(child)
}
/// Append a user message as a new registered child of this node.
pub fn add_user(&self, content: impl Into<MessageContent>) -> ChatNode {
    let message = Message::user(content);
    self.insert_child(&self.id, message)
}
/// Append an assistant message as a new registered child of this node.
pub fn add_assistant(&self, content: impl Into<MessageContent>) -> ChatNode {
    let message = Message::assistant(content);
    self.insert_child(&self.id, message)
}
/// Handle to this node's parent, or `None` for a root or detached node.
/// A phantom node's parent is returned too.
pub fn parent(&self) -> Option<ChatNode> {
    self.with_node(&self.id, |entry| entry.parent.clone())
        .map(|parent_id| self.handle(parent_id))
}
/// Handles to this node's registered children in insertion order.
/// Phantom children are not included.
pub fn children(&self) -> Vec<ChatNode> {
    let child_ids: Vec<String> = self.with_node(&self.id, |entry| entry.children.clone());
    let mut handles = Vec::with_capacity(child_ids.len());
    for child_id in child_ids {
        handles.push(self.handle(child_id));
    }
    handles
}
/// Number of registered children (phantom children are not counted).
pub fn child_count(&self) -> usize {
    self.with_node(&self.id, |entry| entry.children.len())
}
/// Whether this node has no parent.
pub fn is_root(&self) -> bool {
    !self.with_node(&self.id, |entry| entry.parent.is_some())
}
/// Handle to the top of this node's parent chain.
pub fn get_root(&self) -> ChatNode {
    let mut cursor = self.id.clone();
    loop {
        match self.with_node(&cursor, |entry| entry.parent.clone()) {
            Some(up) => cursor = up,
            None => break self.handle(cursor),
        }
    }
}
/// Whether this node has no registered children (phantom children are ignored).
pub fn is_leaf(&self) -> bool {
    self.with_node(&self.id, |entry| entry.children.is_empty())
}
/// Messages from the root down to, and including, this node. No formatting
/// is applied.
pub fn thread(&self) -> Vec<Message> {
    let mut messages = Vec::new();
    for node in &self.node_path() {
        messages.push(node.message.clone());
    }
    messages
}
/// [`thread`](Self::thread) passed through `merge_contiguous_messages`.
pub fn merged_thread(&self) -> Vec<Message> {
    let raw = self.thread();
    merge_contiguous_messages(raw)
}
/// Number of parent links between this node and the top of its chain
/// (0 for a root).
pub fn depth(&self) -> usize {
    let mut hops = 0;
    let mut cursor = self.with_node(&self.id, |entry| entry.parent.clone());
    while let Some(ancestor) = cursor {
        hops += 1;
        cursor = self.with_node(&ancestor, |entry| entry.parent.clone());
    }
    hops
}
/// Search this node and its registered descendants (depth-first, pre-order)
/// for the node with the given id.
pub fn find_by_id(&self, id: &str) -> Option<ChatNode> {
    let mut found = None;
    for candidate in self.iter_depth_first() {
        if candidate.id == id {
            found = Some(candidate);
            break;
        }
    }
    found
}
/// Most recently added registered child, if any.
pub fn last_child(&self) -> Option<ChatNode> {
    let mut kids = self.children();
    kids.pop()
}
/// Follow the last registered child repeatedly and return the node where the
/// walk ends. This is the node itself when it has no children.
pub fn get_leaf(&self) -> ChatNode {
    let mut cursor = self.clone();
    loop {
        match cursor.last_child() {
            Some(next) => cursor = next,
            None => return cursor,
        }
    }
}
/// Cut this node, together with its subtree, loose from its parent so that it
/// becomes a separate root in the same arena. The old parent may be reclaimed
/// if nothing else keeps it alive. Returns a new handle to this node.
pub fn detach(&self) -> ChatNode {
    {
        let mut arena = self.tree.nodes.write().unwrap();
        Tree::unlink_from_parent(&mut arena, &self.id);
    }
    self.clone()
}
/// Attach the tree containing `other`, starting from its root, under this
/// node via [`add_child`](Self::add_child). Returns the leaf reached by
/// following last children from the attached root.
///
/// # Errors
/// Propagates `add_child`'s cycle error, for example when `other` shares this
/// node's root.
pub fn merge(&self, other: &ChatNode) -> Result<ChatNode> {
    let attached = self.add_child(other.get_root())?;
    Ok(attached.get_leaf())
}
pub fn clone_tree(&self) -> ChatNode {
let src_nodes = self.tree.nodes.read().unwrap();
let id_map: std::collections::HashMap<String, String> = src_nodes
.keys()
.map(|id| (id.clone(), Uuid::new_v4().to_string()))
.collect();
let new_id = id_map[&self.id].clone();
let new_nodes: std::collections::HashMap<String, NodeData> = src_nodes
.iter()
.map(|(id, data)| {
let new_id_for_node = id_map[id].clone();
let refcount = usize::from(new_id_for_node == new_id);
(
new_id_for_node,
NodeData {
message: data.message.clone(),
metadata: data.metadata.clone(),
format_kwargs: data.format_kwargs.clone(),
cache_breakpoint: data.cache_breakpoint,
parent: data.parent.as_ref().map(|p| id_map[p].clone()),
children: data.children.iter().map(|c| id_map[c].clone()).collect(),
phantom_child_count: data.phantom_child_count,
refcount,
},
)
})
.collect();
let new_tree = Arc::new(Tree {
nodes: RwLock::new(new_nodes),
});
ChatNode {
id: new_id,
message: self.message.clone(),
tree: new_tree,
}
}
pub fn iter_depth_first(&self) -> Vec<ChatNode> {
let mut result = Vec::new();
let mut stack = vec![self.clone()];
while let Some(node) = stack.pop() {
let children = node.children();
result.push(node);
stack.extend(children.into_iter().rev());
}
result
}
pub fn iter_breadth_first(&self) -> Vec<ChatNode> {
let mut result = Vec::new();
let mut queue = std::collections::VecDeque::new();
queue.push_back(self.clone());
while let Some(node) = queue.pop_front() {
result.push(node.clone());
for child in node.children() {
queue.push_back(child);
}
}
result
}
pub fn iter_leaves(&self) -> Vec<ChatNode> {
self.iter_depth_first()
.into_iter()
.filter(|n| n.is_leaf())
.collect()
}
pub fn node_count(&self) -> usize {
self.iter_depth_first().len()
}
/// Store `value` under `key` in this node's JSON metadata object.
pub fn set_metadata(&self, key: &str, value: serde_json::Value) {
    self.with_node_mut(&self.id, move |entry| {
        entry.metadata[key] = value;
    });
}
/// Value stored under `key` in this node's metadata, if any.
pub fn get_metadata(&self, key: &str) -> Option<serde_json::Value> {
    self.with_node(&self.id, |entry| {
        entry.metadata.get(key).map(serde_json::Value::clone)
    })
}
pub fn cache_breakpoint(&self) -> ChatNode {
self.with_node_mut(&self.id, |n| n.cache_breakpoint = true);
self.clone()
}
pub fn is_cache_breakpoint(&self) -> bool {
self.with_node(&self.id, |n| n.cache_breakpoint)
}
pub fn clear_cache_breakpoint(&self) {
self.with_node_mut(&self.id, |n| n.cache_breakpoint = false);
}
pub fn clear_all_cache_breakpoints(&self) {
let mut nodes = self.tree.nodes.write().unwrap();
for node in nodes.values_mut() {
node.cache_breakpoint = false;
}
}
/// Set one `{key}` substitution stored on this node.
pub fn set_format_kwarg(&self, key: &str, value: &str) {
    let (key, value) = (key.to_owned(), value.to_owned());
    self.with_node_mut(&self.id, move |entry| {
        entry.format_kwargs.insert(key, value);
    });
}
/// Add or overwrite several substitutions on this node.
pub fn set_format_kwargs(&self, kwargs: &std::collections::HashMap<String, String>) {
    self.with_node_mut(&self.id, |entry| {
        entry
            .format_kwargs
            .extend(kwargs.iter().map(|(k, v)| (k.clone(), v.clone())));
    });
}
/// The substitution stored on this node for `key`, if any.
pub fn get_format_kwarg(&self, key: &str) -> Option<String> {
    self.with_node(&self.id, |entry| entry.format_kwargs.get(key).map(String::clone))
}
/// Copy of every substitution stored on this node.
pub fn get_format_kwargs(&self) -> std::collections::HashMap<String, String> {
    self.with_node(&self.id, |entry| entry.format_kwargs.clone())
}
/// This node's own text with this node's own kwargs substituted; `None`
/// when `get_text` yields no text for the message content.
pub fn formatted_text(&self) -> Option<String> {
    let text = self.message.content.get_text()?;
    let kwargs = self.get_format_kwargs();
    Some(apply_kwargs(text, &kwargs))
}
/// Substitute this node's own kwargs into an arbitrary `template`.
pub fn format_string(&self, template: &str) -> String {
    let kwargs = self.get_format_kwargs();
    apply_kwargs(template, &kwargs)
}
/// [`formatted_thread_with_base`](Self::formatted_thread_with_base) with an
/// empty base.
pub fn formatted_thread(&self) -> Vec<Message> {
    let no_base = std::collections::HashMap::new();
    self.formatted_thread_with_base(&no_base)
}
/// Root-to-this-node messages with `{key}` placeholders substituted.
///
/// For each message, `base` is overlaid with that node's own kwargs, and the
/// node's values win on conflicts. Plain text and text parts are formatted;
/// other parts pass through unchanged. Each message's `cache_breakpoint` is
/// taken from the node's stored flag.
pub fn formatted_thread_with_base(
    &self,
    base: &std::collections::HashMap<String, String>,
) -> Vec<Message> {
    let mut formatted = Vec::new();
    for node in self.node_path() {
        let mut kwargs = base.clone();
        kwargs.extend(node.get_format_kwargs());
        let mut msg = node.message.clone();
        msg.cache_breakpoint = node.with_node(&node.id, |entry| entry.cache_breakpoint);
        msg.content = match msg.content {
            MessageContent::Text(text) => MessageContent::Text(apply_kwargs(&text, &kwargs)),
            MessageContent::Parts(parts) => MessageContent::Parts(
                parts
                    .into_iter()
                    .map(|part| match part.as_text() {
                        Some(text) => ContentPart::text(apply_kwargs(text, &kwargs)),
                        None => part,
                    })
                    .collect(),
            ),
        };
        formatted.push(msg);
    }
    formatted
}
/// Handles from the top of the parent chain down to this node (inclusive).
/// The walk follows `parent` links, so phantom links are included.
fn node_path(&self) -> Vec<ChatNode> {
    let mut chain = Vec::new();
    let mut cursor = Some(self.clone());
    while let Some(node) = cursor {
        cursor = node.parent();
        chain.push(node);
    }
    chain.reverse();
    chain
}
/// Request an assistant reply for the thread ending at this node and return
/// only the reply node. See [`complete_collect`](Self::complete_collect).
pub async fn complete(
    &self,
    generator: &GeneratorInfo,
    params: Option<&NodeCompletionParameters>,
) -> Result<ChatNode> {
    let (node, _response) = self.complete_collect(generator, params).await?;
    Ok(node)
}
/// Non-streaming completion with retries and post-processing. Returns the
/// reply node and the raw response.
///
/// # Errors
/// Non-retryable transport errors are returned directly. When every attempt
/// fails, `MaxRetriesExceeded` wraps the last error.
pub async fn complete_collect(
    &self,
    generator: &GeneratorInfo,
    params: Option<&NodeCompletionParameters>,
) -> Result<(ChatNode, CompletionResponse)> {
    let settings = CompletionSettings::from_params(params);
    let mode = ResponseMode::NonStreaming;
    self.run_with_retry(generator, &settings, mode).await
}
/// Send this thread once with caching enabled and `max_tokens = 0`
/// (presumably to warm the provider's prompt cache), and return the cost of
/// that call.
///
/// The cost is "unknown" when the response reports no usage. A cost callback
/// from `params`, if set, is invoked with the same cost info.
///
/// # Errors
/// Errors from the client request (no retries are attempted).
pub async fn ensure_cached(
    &self,
    generator: &GeneratorInfo,
    params: Option<&NodeCompletionParameters>,
) -> Result<crate::provider::CostInfo> {
    let warm_up = {
        let mut p = params.cloned().unwrap_or_default();
        p.use_cache = true;
        let mut request = p.params.take().unwrap_or_default();
        request.max_tokens = Some(0);
        p.params = Some(request);
        p
    };
    let settings = CompletionSettings::from_params(Some(&warm_up));
    let messages = self.prepare_messages(&settings);
    let request = settings.merged_params(generator);
    let response = global_client()
        .complete_with_usage_tracking(generator, &messages, &request, true, settings.timeout)
        .await?;
    let outcome = if let Some(usage) = &response.usage {
        generator.provider.cost_of(usage.clone(), settings.price(generator))
    } else {
        crate::provider::CostOutcome::unknown()
    };
    let info = outcome.into_cost_info(response.model.clone(), response.id.clone());
    if let Some(callback) = settings.cost_callback.as_ref() {
        callback(info.clone());
    }
    Ok(info)
}
/// Perform a single request in the given `mode`, with no retries and no
/// post-processing. A streaming request is collected into one response.
async fn fetch_response(
    &self,
    generator: &GeneratorInfo,
    settings: &CompletionSettings,
    mode: ResponseMode,
) -> Result<CompletionResponse> {
    if let ResponseMode::Streaming = mode {
        return self.start_streaming(generator, settings).await?.collect().await;
    }
    let messages = self.prepare_messages(settings);
    let request = settings.merged_params(generator);
    global_client()
        .complete_with_usage_tracking(
            generator,
            &messages,
            &request,
            settings.track_cost,
            settings.timeout,
        )
        .await
}
/// Run `fetch_response` plus post-processing up to `1 + settings.retry` times.
///
/// * Transport errors rejected by `is_retryable` are returned immediately.
/// * Retryable transport errors and post-processing failures (empty reply,
///   JSON refusal or repair failure) use up an attempt and are retried.
/// * Between attempts the task sleeps `back_off_time` seconds. With
///   `exp_back_off`, the delay instead doubles after each retry and is capped
///   at `max_back_off`.
///
/// On success the reply node is built (registered or phantom, per
/// `add_child`), and the cost callback fires for that response only.
///
/// # Errors
/// `MaxRetriesExceeded` wrapping the last error (or `EmptyResponse` if none
/// was recorded) when every attempt fails.
async fn run_with_retry(
    &self,
    generator: &GeneratorInfo,
    settings: &CompletionSettings,
    mode: ResponseMode,
) -> Result<(ChatNode, CompletionResponse)> {
    let mut last_error: Option<MiniLLMError> = None;
    let mut current_back_off = settings.back_off_time;
    for attempt in 0..=settings.retry {
        if attempt > 0 {
            let sleep_time = if settings.exp_back_off {
                current_back_off.min(settings.max_back_off)
            } else {
                settings.back_off_time
            };
            // `Duration::from_secs_f64` panics on negative, NaN, infinite or
            // overflowing input, so a misconfigured back-off would crash the
            // retry loop. Negative and NaN delays become "retry immediately";
            // a delay too large to represent saturates to `Duration::MAX`.
            let delay =
                Duration::try_from_secs_f64(sleep_time.max(0.0)).unwrap_or(Duration::MAX);
            tokio::time::sleep(delay).await;
            if settings.exp_back_off {
                current_back_off *= 2.0;
            }
            tracing::debug!(attempt = attempt, "Retrying completion request");
        }
        let response = match self.fetch_response(generator, settings, mode).await {
            Ok(r) => r,
            Err(e) => {
                if !is_retryable(&e) {
                    return Err(e);
                }
                last_error = Some(e);
                continue;
            }
        };
        // Post-processing failures are always retried: a fresh sample may succeed.
        let content = match self.postprocess_content(&response.content, settings) {
            Ok(c) => c,
            Err(e) => {
                last_error = Some(e);
                continue;
            }
        };
        let node = self.build_assistant_node(content, &response, settings.add_child);
        settings.fire_cost_callback(generator, &response);
        return Ok((node, response));
    }
    Err(MiniLLMError::MaxRetriesExceeded(Box::new(
        last_error.unwrap_or(MiniLLMError::EmptyResponse),
    )))
}
/// Build the outgoing message list for this node:
///
/// 1. Format the root-to-node thread using `settings.format_kwargs` as the
///    base, then merge contiguous messages.
/// 2. Insert `system_prompt` first unless the list already starts with a
///    system message.
/// 3. Append `force_prepend` as a trailing assistant message.
/// 4. With `use_cache`, flag the last message as a cache breakpoint.
fn prepare_messages(&self, settings: &CompletionSettings) -> Vec<Message> {
    let formatted = self.formatted_thread_with_base(&settings.format_kwargs);
    let mut messages = merge_contiguous_messages(formatted);
    let starts_with_system =
        matches!(messages.first(), Some(first) if first.role == Role::System);
    if let (Some(system), false) = (&settings.system_prompt, starts_with_system) {
        messages.insert(0, Message::system(system.clone()));
    }
    if let Some(prepend) = settings.force_prepend.as_ref() {
        messages.push(Message::assistant(prepend.clone()));
    }
    if settings.use_cache {
        if let Some(tail) = messages.last_mut() {
            tail.cache_breakpoint = true;
        }
    }
    messages
}
/// Turn raw model output into the final reply text.
///
/// * `force_prepend` is re-attached unless the provider already echoed it.
/// * With `crash_on_empty`, the reply is rejected when the *generated* part
///   (the text after any `force_prepend`) is empty or whitespace-only. The
///   check is not run on the prefixed text, because then any non-blank
///   prepend would hide an empty model reply.
/// * With `parse_json`, the resulting text is repaired and validated as JSON.
///
/// # Errors
/// `EmptyResponse` as described above; JSON errors from
/// `repair_and_validate_json`.
fn postprocess_content(&self, raw: &str, settings: &CompletionSettings) -> Result<String> {
    let prefix = settings.force_prepend.as_deref().unwrap_or("");
    // The generated continuation, whether or not the provider echoed the prefix.
    let generated = raw.strip_prefix(prefix).unwrap_or(raw);
    if settings.crash_on_empty && generated.trim().is_empty() {
        return Err(MiniLLMError::EmptyResponse);
    }
    let content = format!("{prefix}{generated}");
    if settings.parse_json {
        return repair_and_validate_json(&content, settings.crash_on_refusal);
    }
    Ok(content)
}
/// Create the assistant reply node for `response` under this node.
///
/// The message is `content` plus the response's tool calls. The node is a
/// registered child when `add_child` is true and a phantom child otherwise.
/// Metadata written: `response_id`, `model`, and, when present, `usage` and
/// `finish_reason`.
fn build_assistant_node(
    &self,
    content: String,
    response: &CompletionResponse,
    add_child: bool,
) -> ChatNode {
    let mut reply = Message::assistant(content);
    reply.tool_calls = response.tool_calls.clone();
    let node = match add_child {
        true => self.insert_child(&self.id, reply),
        false => self.insert_phantom_child(&self.id, reply),
    };
    let mut entries = vec![
        ("response_id", serde_json::json!(response.id)),
        ("model", serde_json::json!(response.model)),
    ];
    if let Some(usage) = &response.usage {
        entries.push(("usage", serde_json::json!(usage)));
    }
    if let Some(reason) = &response.finish_reason {
        entries.push(("finish_reason", serde_json::json!(reason)));
    }
    for (key, value) in entries {
        node.set_metadata(key, value);
    }
    node
}
/// Create an entry holding `message` whose `parent` is `parent_id` but which
/// is *not* listed in the parent's `children`. The parent only counts it in
/// `phantom_child_count`, which keeps the parent from being reclaimed while
/// the phantom exists.
///
/// Panics if `parent_id` is not in this tree.
fn insert_phantom_child(&self, parent_id: &str, message: Message) -> ChatNode {
    let phantom_id = Uuid::new_v4().to_string();
    let entry = NodeData::new(message.clone(), Some(parent_id.to_owned()));
    {
        let mut arena = self.tree.nodes.write().unwrap();
        arena.insert(phantom_id.clone(), entry);
        arena
            .get_mut(parent_id)
            .expect("parent id present")
            .phantom_child_count += 1;
    }
    self.handle_owned(phantom_id, message)
}
/// Start a streaming completion for the thread ending at this node.
/// No retries or post-processing are applied and no reply node is created.
pub async fn complete_streaming(
    &self,
    generator: &GeneratorInfo,
    params: Option<&NodeCompletionParameters>,
) -> Result<StreamingCompletion> {
    let resolved = CompletionSettings::from_params(params);
    self.start_streaming(generator, &resolved).await
}
/// Open the client stream for already-resolved settings.
async fn start_streaming(
    &self,
    generator: &GeneratorInfo,
    settings: &CompletionSettings,
) -> Result<StreamingCompletion> {
    let request = settings.merged_params(generator);
    let messages = self.prepare_messages(settings);
    global_client()
        .complete_streaming_with_usage(
            generator,
            &messages,
            &request,
            settings.track_cost,
            settings.timeout,
        )
        .await
}
/// Streaming completion collected into a single response, with the same
/// retry and post-processing pipeline as [`complete`](Self::complete).
pub async fn complete_streaming_collect(
    &self,
    generator: &GeneratorInfo,
    params: Option<&NodeCompletionParameters>,
) -> Result<ChatNode> {
    let settings = CompletionSettings::from_params(params);
    self.run_with_retry(generator, &settings, ResponseMode::Streaming)
        .await
        .map(|(node, _response)| node)
}
/// [`complete_collect`](Self::complete_collect) with cost tracking forced
/// on. The response cost is reported through `ctx` instead of a callback.
pub async fn complete_tracked(
    &self,
    ctx: &crate::tracking::CompletionContext,
    params: Option<&NodeCompletionParameters>,
) -> Result<ChatNode> {
    let forced = tracked_params(params);
    let (node, response) = self
        .complete_collect(&ctx.generator, Some(&forced))
        .await?;
    let cost = ctx.cost_for_response(&response).await;
    ctx.report_cost(cost).await;
    Ok(node)
}
/// Streaming variant: returns the stream wrapped in a `TrackedStream` tied
/// to `ctx`. A default timeout is applied when none was configured.
pub async fn complete_streaming_tracked(
    &self,
    ctx: &crate::tracking::CompletionContext,
    params: Option<&NodeCompletionParameters>,
) -> Result<crate::tracking::TrackedStream> {
    let forced = tracked_streaming_params(params);
    let stream = self
        .complete_streaming(&ctx.generator, Some(&forced))
        .await?;
    Ok(crate::tracking::TrackedStream::new(stream, ctx))
}
/// Collected streaming variant with retries; the cost of the successful
/// response is reported through `ctx`.
pub async fn complete_streaming_collect_tracked(
    &self,
    ctx: &crate::tracking::CompletionContext,
    params: Option<&NodeCompletionParameters>,
) -> Result<ChatNode> {
    let settings = CompletionSettings::from_params(Some(&tracked_streaming_params(params)));
    let (node, response) = self
        .run_with_retry(&ctx.generator, &settings, ResponseMode::Streaming)
        .await?;
    let cost = ctx.cost_for_response(&response).await;
    ctx.report_cost(cost).await;
    Ok(node)
}
/// Append a user message under this node and complete it with default
/// parameters; returns the assistant reply node.
pub async fn chat(
    &self,
    user_message: impl Into<MessageContent>,
    generator: &GeneratorInfo,
) -> Result<ChatNode> {
    let asked: ChatNode = self.add_user(user_message);
    asked.complete(generator, None).await
}
/// Append a user message and start streaming a reply to it. Returns the user
/// node together with the stream.
pub async fn chat_streaming(
    &self,
    user_message: impl Into<MessageContent>,
    generator: &GeneratorInfo,
) -> Result<(ChatNode, StreamingCompletion)> {
    let asked: ChatNode = self.add_user(user_message);
    let reply_stream = asked.complete_streaming(generator, None).await?;
    Ok((asked, reply_stream))
}
/// Text of this handle's message snapshot, if it has any.
pub fn text(&self) -> Option<&str> {
    let message = &self.message;
    message.text()
}
/// Role of this handle's message snapshot.
pub fn role(&self) -> Role {
    let message = &self.message;
    message.role
}
}
impl std::fmt::Debug for ChatNode {
    /// Shows the id, the role and the number of registered children, which is
    /// read from the tree under its lock.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let children_count = self.child_count();
        f.debug_struct("ChatNode")
            .field("id", &self.id)
            .field("role", &self.message.role)
            .field("children_count", &children_count)
            .finish()
    }
}
/// Role prefixes and message separator used by `pretty_messages`.
#[derive(Debug, Clone)]
pub struct PrettyPrintConfig {
    /// Written before each system message.
    pub system_prefix: String,
    /// Written before each user message.
    pub user_prefix: String,
    /// Written before each assistant message.
    pub assistant_prefix: String,
    /// Written between consecutive messages (nothing when empty).
    pub separator: String,
}
impl Default for PrettyPrintConfig {
    /// `SYSTEM: ` / `USER: ` / `ASSISTANT: ` labels, with a blank line before
    /// user and assistant turns and no separator.
    fn default() -> Self {
        Self::new("SYSTEM: ", "\n\nUSER: ", "\n\nASSISTANT: ")
    }
}
impl PrettyPrintConfig {
    /// Config with the given role prefixes and an empty separator.
    pub fn new(system: &str, user: &str, assistant: &str) -> Self {
        Self {
            system_prefix: String::from(system),
            user_prefix: String::from(user),
            assistant_prefix: String::from(assistant),
            separator: String::new(),
        }
    }
    /// Builder-style setter for the separator placed between messages.
    pub fn with_separator(self, sep: &str) -> Self {
        Self {
            separator: sep.to_owned(),
            ..self
        }
    }
}
/// Render the formatted thread ending at `node` as plain text.
///
/// Each message is written as `<role prefix><text>`, with `separator` between
/// messages. Tool messages use the fixed prefix `"\n\nTOOL: "`. A message
/// with no text but multimodal content is shown as `[multimodal content]`.
/// `None` uses `PrettyPrintConfig::default()`.
pub fn pretty_messages(node: &ChatNode, config: Option<&PrettyPrintConfig>) -> String {
    let fallback = PrettyPrintConfig::default();
    let cfg = config.unwrap_or(&fallback);
    let mut rendered = String::new();
    for (index, message) in node.formatted_thread().iter().enumerate() {
        if index != 0 {
            // Pushing an empty separator is a no-op.
            rendered.push_str(&cfg.separator);
        }
        let prefix: &str = match message.role {
            Role::System => &cfg.system_prefix,
            Role::User => &cfg.user_prefix,
            Role::Assistant => &cfg.assistant_prefix,
            Role::Tool => "\n\nTOOL: ",
        };
        rendered.push_str(prefix);
        let body = message.content.all_text();
        let shown: &str = if body.is_empty() && message.content.has_multimodal() {
            "[multimodal content]"
        } else {
            &body
        };
        rendered.push_str(shown);
    }
    rendered
}
pub fn format_conversation(node: &ChatNode) -> String {
pretty_messages(node, None)
}
/// Fluent helper for building a linear conversation under a system prompt.
pub struct ConversationBuilder {
    /// Handle to the system-prompt root.
    root: ChatNode,
    /// Handle to the most recently appended message (the root initially).
    current: ChatNode,
}
impl ConversationBuilder {
    /// Start a conversation whose root is a system message.
    pub fn new(system_prompt: impl Into<String>) -> Self {
        let root = ChatNode::root(system_prompt);
        let current = root.clone();
        Self { root, current }
    }
    /// Append a user message after the current one.
    pub fn user(self, content: impl Into<MessageContent>) -> Self {
        let current = self.current.add_user(content);
        Self { current, ..self }
    }
    /// Append an assistant message after the current one.
    pub fn assistant(self, content: impl Into<MessageContent>) -> Self {
        let current = self.current.add_assistant(content);
        Self { current, ..self }
    }
    /// New handle to the root node.
    pub fn root(&self) -> ChatNode {
        ChatNode::clone(&self.root)
    }
    /// New handle to the current (last appended) node.
    pub fn current(&self) -> ChatNode {
        ChatNode::clone(&self.current)
    }
    /// Finish building and return the last appended node.
    pub fn build(self) -> ChatNode {
        let Self { current, .. } = self;
        current
    }
}
use serde::{Deserialize, Serialize};
/// One serialized thread entry: a message plus the format kwargs stored on
/// its node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreadNode {
pub message: Message,
/// Omitted from the JSON when empty; deserializes to empty when missing.
#[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
pub format_kwargs: std::collections::HashMap<String, String>,
}
/// Saved thread format: the entries from root to leaf, in order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreadData {
pub prompts: Vec<ThreadNode>,
}
impl ChatNode {
    /// Save the root-to-this-node thread to `path` as pretty-printed JSON in
    /// the [`ThreadData`] format.
    ///
    /// Only each node's message snapshot and its format kwargs are written.
    /// Node metadata and the node-level cache-breakpoint flag are not saved.
    ///
    /// # Errors
    /// JSON serialization errors and filesystem write errors.
    pub fn save_thread(&self, path: &str) -> Result<()> {
        let json = serde_json::to_string_pretty(&self.to_thread_data())?;
        std::fs::write(path, json)?;
        Ok(())
    }

    /// Snapshot of the root-to-this-node path as [`ThreadData`].
    pub fn to_thread_data(&self) -> ThreadData {
        ThreadData {
            prompts: self
                .node_path()
                .into_iter()
                .map(|node| ThreadNode {
                    message: node.message.clone(),
                    format_kwargs: node.get_format_kwargs(),
                })
                .collect(),
        }
    }

    /// Load a thread written by [`save_thread`](Self::save_thread) into a new
    /// linear tree. Returns `(root, last node)`.
    ///
    /// # Errors
    /// Filesystem read errors, JSON errors, or `EmptyThread`.
    pub fn from_thread_file(path: &str) -> Result<(ChatNode, ChatNode)> {
        Self::from_thread_json(&std::fs::read_to_string(path)?)
    }

    /// Parse a [`ThreadData`] JSON document into a new linear tree.
    ///
    /// # Errors
    /// JSON errors, or `EmptyThread` when `prompts` is empty.
    pub fn from_thread_json(json: &str) -> Result<(ChatNode, ChatNode)> {
        Self::from_thread_data(&serde_json::from_str(json)?)
    }

    /// Build a new linear tree from `data`, restoring each node's format kwargs.
    ///
    /// # Errors
    /// `EmptyThread` when `prompts` is empty.
    pub fn from_thread_data(data: &ThreadData) -> Result<(ChatNode, ChatNode)> {
        Self::build_chain(
            data.prompts
                .iter()
                .map(|e| (e.message.clone(), e.format_kwargs.clone())),
        )
    }

    /// Build a new linear tree from plain messages (no format kwargs).
    ///
    /// # Errors
    /// `EmptyThread` when `messages` is empty.
    pub fn from_messages(messages: &[Message]) -> Result<(ChatNode, ChatNode)> {
        Self::build_chain(
            messages
                .iter()
                .map(|m| (m.clone(), std::collections::HashMap::new())),
        )
    }

    /// Chain `(message, kwargs)` pairs into a fresh tree: the first pair becomes
    /// the root and each following pair a registered child of the previous one.
    /// Returns `(root, last node)`.
    ///
    /// # Errors
    /// `EmptyThread` when `nodes` yields nothing.
    fn build_chain(
        nodes: impl IntoIterator<Item = (Message, std::collections::HashMap<String, String>)>,
    ) -> Result<(ChatNode, ChatNode)> {
        let mut iter = nodes.into_iter();
        let (first_msg, first_kwargs) = iter.next().ok_or(MiniLLMError::EmptyThread)?;
        let root = ChatNode::new(first_msg);
        root.set_format_kwargs(&first_kwargs);
        let mut current = root.clone();
        for (msg, kwargs) in iter {
            // The borrow of `current.id` ends before `current` is reassigned;
            // the old handle is then dropped, but its node keeps the new child.
            current = current.insert_child(&current.id, msg);
            current.set_format_kwargs(&kwargs);
        }
        Ok((root, current))
    }
}
#[cfg(test)]
mod completion_pipeline_tests {
    //! Tests for the completion-pipeline helpers (retry classification, JSON
    //! repair/validation, response post-processing, message preparation and
    //! cache breakpoints) and for `ChatNode` tree ownership, reclamation and
    //! cycle rejection.

    use super::*;
    use crate::provider::CompletionResponse;

    /// Builds a minimal `CompletionResponse` carrying `content`, using the
    /// placeholder arguments `"gen-1"` and `"test-model"` (the latter is what
    /// the `model` metadata assertions below expect).
    fn response_with(content: &str) -> CompletionResponse {
        CompletionResponse::new("gen-1", "test-model", content)
    }

    #[test]
    fn retryable_classification() {
        assert!(is_retryable(&MiniLLMError::Api {
            status: 429,
            message: "rate".into()
        }));
        assert!(is_retryable(&MiniLLMError::Api {
            status: 503,
            message: "down".into()
        }));
        assert!(is_retryable(&MiniLLMError::Timeout));
        assert!(is_retryable(&MiniLLMError::Stream("boom".into())));
        assert!(!is_retryable(&MiniLLMError::Api {
            status: 401,
            message: "bad key".into()
        }));
        assert!(!is_retryable(&MiniLLMError::Api {
            status: 400,
            message: "bad req".into()
        }));
        assert!(!is_retryable(&MiniLLMError::EmptyResponse));
    }

    #[test]
    fn json_value_emptiness() {
        use crate::json_repair::{loads, RepairOptions};
        let opts = RepairOptions::default();
        assert!(json_value_is_empty(&loads("null", &opts).unwrap()));
        assert!(json_value_is_empty(&loads("\"\"", &opts).unwrap()));
        assert!(json_value_is_empty(&loads("{}", &opts).unwrap()));
        assert!(json_value_is_empty(&loads("[]", &opts).unwrap()));
        assert!(!json_value_is_empty(&loads("{\"a\":1}", &opts).unwrap()));
        assert!(!json_value_is_empty(&loads("[1]", &opts).unwrap()));
    }

    #[test]
    fn repair_validate_crash_on_refusal_rejects_empty_value() {
        let err = repair_and_validate_json("I can't help. {}", true).unwrap_err();
        assert!(matches!(err, MiniLLMError::NoJsonFound(_)));
    }

    #[test]
    fn repair_validate_accepts_real_json_and_repairs() {
        let out = repair_and_validate_json("{'a': 1,}", true).unwrap();
        assert_eq!(out, r#"{"a": 1}"#);
    }

    #[test]
    fn repair_validate_no_crash_when_flag_off() {
        assert!(repair_and_validate_json("no json here", false).is_ok());
    }

    #[test]
    fn postprocess_reattaches_force_prepend_once() {
        let node = ChatNode::root("sys");
        let params = NodeCompletionParameters::new().with_force_prepend("Score: ");
        let settings = CompletionSettings::from_params(Some(&params));
        assert_eq!(
            node.postprocess_content("8/10", &settings).unwrap(),
            "Score: 8/10"
        );
        // Content that already starts with the prefix must not get it twice.
        assert_eq!(
            node.postprocess_content("Score: 8/10", &settings).unwrap(),
            "Score: 8/10"
        );
    }

    #[test]
    fn postprocess_crash_on_empty() {
        let node = ChatNode::root("sys");
        let params = NodeCompletionParameters::new().with_crash_on_empty(true);
        let settings = CompletionSettings::from_params(Some(&params));
        assert!(matches!(
            node.postprocess_content(" \n ", &settings).unwrap_err(),
            MiniLLMError::EmptyResponse
        ));
        assert_eq!(node.postprocess_content("hi", &settings).unwrap(), "hi");
    }

    #[test]
    fn postprocess_parse_json_repairs() {
        let node = ChatNode::root("sys");
        let params = NodeCompletionParameters::new().with_parse_json(true);
        let settings = CompletionSettings::from_params(Some(&params));
        assert_eq!(
            node.postprocess_content("{'a': 1,}", &settings).unwrap(),
            r#"{"a": 1}"#
        );
    }

    #[test]
    fn postprocess_default_is_passthrough() {
        let node = ChatNode::root("sys");
        let settings = CompletionSettings::from_params(None);
        assert_eq!(
            node.postprocess_content("plain text", &settings).unwrap(),
            "plain text"
        );
    }

    #[test]
    fn build_node_as_real_child_threads_tool_calls_and_metadata() {
        let root = ChatNode::root("sys");
        let user = root.add_user("hi");
        let mut response = response_with("answer");
        response.finish_reason = Some("tool_calls".into());
        response.tool_calls = Some(vec![serde_json::json!({"id": "c1"})]);
        let node = user.build_assistant_node("answer".into(), &response, true);
        assert_eq!(user.child_count(), 1);
        assert_eq!(node.parent().unwrap().id, user.id);
        assert!(node.message.tool_calls.is_some());
        assert_eq!(
            node.get_metadata("finish_reason"),
            Some(serde_json::json!("tool_calls"))
        );
        assert_eq!(
            node.get_metadata("model"),
            Some(serde_json::json!("test-model"))
        );
    }

    #[test]
    fn build_node_as_phantom_does_not_register_in_parent() {
        let root = ChatNode::root("sys");
        let user = root.add_user("hi");
        let response = response_with("answer");
        let phantom = user.build_assistant_node("answer".into(), &response, false);
        assert_eq!(user.child_count(), 0);
        assert_eq!(phantom.parent().unwrap().id, user.id);
        let thread = phantom.thread();
        assert_eq!(thread.len(), 3);
        assert_eq!(thread[2].text(), Some("answer"));
        assert_eq!(phantom.get_root().id, root.id);
    }

    #[test]
    fn clone_tree_copies_the_whole_tree_isolated() {
        let root = ChatNode::root("sys");
        let a = root.add_user("a");
        a.set_format_kwarg("k", "v");
        let _b = root.add_user("b");
        let cloned_a = a.clone_tree();
        assert_ne!(cloned_a.id, a.id);
        assert_eq!(cloned_a.text(), Some("a"));
        assert_eq!(cloned_a.get_format_kwarg("k"), Some("v".to_string()));
        let cloned_root = cloned_a.get_root();
        assert_ne!(cloned_root.id, root.id);
        assert_eq!(cloned_root.text(), Some("sys"));
        assert_eq!(
            cloned_root.node_count(),
            3,
            "clone has root + both branches"
        );
        assert_eq!(cloned_a.thread().len(), 2);
        cloned_a.set_format_kwarg("k", "changed");
        assert_eq!(a.get_format_kwarg("k"), Some("v".to_string()));
        cloned_a.add_assistant("reply");
        assert!(
            a.is_leaf(),
            "extending the clone must not touch the original"
        );
    }

    #[test]
    fn clone_tree_works_on_a_phantom_node() {
        let root = ChatNode::root("sys");
        let user = root.add_user("u");
        let phantom = user.build_assistant_node(
            "answer".into(),
            &CompletionResponse::new("g", "m", "answer"),
            false,
        );
        assert_eq!(
            user.child_count(),
            0,
            "precondition: phantom not registered"
        );
        let cloned = phantom.clone_tree();
        assert_eq!(cloned.text(), Some("answer"));
        assert_eq!(cloned.thread().len(), 3);
        assert_ne!(cloned.id, phantom.id);
    }

    #[test]
    fn holding_a_node_keeps_its_full_ancestor_history_alive() {
        // `root` and `user` go out of scope here; only the leaf handle survives.
        let leaf = {
            let root = ChatNode::root("sys");
            let user = root.add_user("hi");
            user.add_assistant("there")
        };
        assert_eq!(leaf.thread().len(), 3);
        assert_eq!(leaf.get_root().text(), Some("sys"));
        assert_eq!(leaf.thread()[0].text(), Some("sys"));
    }

    #[test]
    fn held_node_keeps_its_ancestor_chain_but_unheld_branches_are_reclaimed() {
        let root = ChatNode::root("sys");
        let kept = root.add_user("kept");
        // The returned handles are discarded immediately, so these branches
        // have no holder and are expected to be reclaimed.
        root.add_user("dropped");
        root.add_user("alsodropped");
        assert_eq!(
            root.arena_len(),
            2,
            "unheld sibling branches were reclaimed"
        );
        assert_eq!(root.child_count(), 1);
        drop(root);
        let root = kept.get_root();
        assert_eq!(root.text(), Some("sys"));
        assert_eq!(root.node_count(), 2);
        assert_eq!(kept.arena_len(), 2);
    }

    #[test]
    fn dropping_every_handle_frees_the_tree() {
        let leaf = {
            let root = ChatNode::root("sys");
            let u = root.add_user("u");
            u.add_assistant("a")
        };
        assert_eq!(leaf.get_root().node_count(), 3);
        let tree = leaf.tree.clone();
        // One reference from `leaf`, one from the local `tree` clone.
        assert_eq!(Arc::strong_count(&tree), 2);
        drop(leaf);
        assert_eq!(Arc::strong_count(&tree), 1);
    }

    #[test]
    fn phantom_node_is_reclaimed_when_its_handle_drops() {
        let root = ChatNode::root("sys");
        let user = root.add_user("hi");
        let resp = response_with("speculative");
        {
            let phantom = user.build_assistant_node("speculative".into(), &resp, false);
            assert_eq!(
                phantom.thread().len(),
                3,
                "phantom reads its ancestor chain"
            );
            assert_eq!(root.arena_len(), 3, "phantom present while held");
        }
        assert_eq!(root.arena_len(), 2, "phantom reclaimed on drop");
        assert!(user.is_leaf(), "phantom never registered as a child");
    }

    #[test]
    fn held_phantom_keeps_its_parent_chain_alive_when_ancestors_are_dropped() {
        let phantom = {
            let root = ChatNode::root("sys");
            let user = root.add_user("hi");
            let resp = response_with("speculative");
            user.build_assistant_node("speculative".into(), &resp, false)
        };
        assert_eq!(phantom.thread().len(), 3, "sys, hi, phantom");
        assert_eq!(phantom.get_root().text(), Some("sys"));
        assert_eq!(phantom.parent().unwrap().text(), Some("hi"));
        let cloned = phantom.clone_tree();
        assert_eq!(cloned.thread().len(), 3);
    }

    #[test]
    fn reparenting_a_sibling_out_from_under_a_phantoms_ancestor_keeps_the_phantom_alive() {
        let root = ChatNode::root("sys");
        let a = root.add_user("a");
        let reg = a.add_assistant("reg");
        let resp = response_with("ph");
        let phantom = a.build_assistant_node("ph".into(), &resp, false);
        drop(a);
        // Moving `reg` away leaves `a` with no registered children; only the
        // phantom still references it.
        root.add_child(reg.clone()).unwrap();
        assert_eq!(phantom.thread().len(), 3, "sys, a, phantom");
        assert_eq!(phantom.get_root().text(), Some("sys"));
        assert_eq!(phantom.parent().unwrap().text(), Some("a"));
        drop(phantom);
        assert_eq!(reg.get_root().text(), Some("sys"));
        assert!(
            reg.thread().iter().all(|m| m.text() != Some("a")),
            "a gone from reg's thread"
        );
    }

    #[test]
    fn detaching_a_phantom_decrements_its_old_parents_phantom_count() {
        let root = ChatNode::root("sys");
        let user = root.add_user("hi");
        let resp = response_with("ph");
        let phantom = user.build_assistant_node("ph".into(), &resp, false);
        assert_eq!(root.arena_len(), 3, "sys, hi, phantom");
        phantom.detach();
        assert!(phantom.is_root(), "detached phantom is a root");
        drop(user);
        assert_eq!(
            root.arena_len(),
            2,
            "old parent reclaimed (no leak): sys + detached phantom"
        );
        assert_eq!(phantom.text(), Some("ph"));
    }

    #[test]
    fn re_parenting_a_phantom_transfers_it_without_leaking_the_old_parent() {
        let root = ChatNode::root("sys");
        let a = root.add_user("a");
        let b = root.add_user("b");
        let resp = response_with("ph");
        let phantom = a.build_assistant_node("ph".into(), &resp, false);
        assert_eq!(root.arena_len(), 4, "sys, a, b, phantom");
        b.add_child(phantom.clone()).unwrap();
        assert_eq!(phantom.parent().unwrap().id, b.id);
        assert_eq!(b.child_count(), 1, "phantom is now a registered child of b");
        drop(a);
        assert_eq!(
            root.arena_len(),
            3,
            "old parent `a` reclaimed (no leak): sys, b, phantom-under-b"
        );
    }

    #[test]
    fn detached_then_dropped_subtree_is_reclaimed() {
        let root = ChatNode::root("sys");
        let a = root.add_user("a");
        let a_child = a.add_assistant("a-reply");
        root.add_user("b");
        assert_eq!(root.arena_len(), 3);
        a.detach();
        assert_eq!(root.child_count(), 0, "root lost its only child");
        assert_eq!(root.arena_len(), 3, "a's subtree still held, still present");
        drop(a_child);
        drop(a);
        assert_eq!(
            root.arena_len(),
            1,
            "detached subtree reclaimed; only sys remains"
        );
        assert_eq!(root.node_count(), 1);
    }

    #[test]
    fn concurrent_reverse_cross_tree_merges_do_not_deadlock() {
        use std::sync::Barrier;
        // Repeat to give the two opposite-order merges many chances to interleave.
        for _ in 0..200 {
            let a = ChatNode::root("A");
            let b = ChatNode::root("B");
            let a2 = a.clone();
            let b2 = b.clone();
            let barrier = Arc::new(Barrier::new(2));
            let (ba, bb) = (barrier.clone(), barrier.clone());
            let t1 = std::thread::spawn(move || {
                ba.wait();
                let _ = a.merge(&b);
            });
            let t2 = std::thread::spawn(move || {
                bb.wait();
                let _ = b2.merge(&a2);
            });
            t1.join().unwrap();
            t2.join().unwrap();
        }
    }

    #[test]
    fn deep_tree_builds_traverses_and_drops_without_overflow() {
        let leaf = {
            let root = ChatNode::root("sys");
            let mut cur = root.clone();
            for i in 0..50_000 {
                cur = cur.add_user(format!("m{i}"));
            }
            cur
        };
        assert_eq!(leaf.thread().len(), 50_001);
        // Dropping the last handle must not recurse through the whole chain.
        drop(leaf);
    }

    #[test]
    fn merge_rejects_same_tree_cycle() {
        let root = ChatNode::root("sys");
        let user = root.add_user("u");
        assert!(user.merge(&root).is_err());
        let other = ChatNode::root("other");
        assert!(user.merge(&other).is_ok());
    }

    #[test]
    fn add_child_rejects_ancestor_and_self_cycle() {
        let root = ChatNode::root("sys");
        let user = root.add_user("u");
        assert!(user.add_child(root.clone()).is_err(), "ancestor → cycle");
        assert!(user.add_child(user.clone()).is_err(), "self → cycle");
        let other = ChatNode::root("other");
        assert!(root.add_child(other).is_ok());
    }

    #[test]
    fn same_tree_reparent_reclaims_the_orphaned_old_parent() {
        let root = ChatNode::root("sys");
        let a = root.add_user("a");
        let b = a.add_assistant("b");
        drop(a);
        assert_eq!(root.arena_len(), 3);
        root.add_child(b.clone()).unwrap();
        assert_eq!(
            root.arena_len(),
            2,
            "orphaned old parent `a` must be reclaimed, not leaked"
        );
        assert_eq!(b.parent().unwrap().id, root.id);
        assert_eq!(root.child_count(), 1);
    }

    #[test]
    fn concurrent_same_tree_reparents_cannot_form_a_cycle() {
        use std::sync::Barrier;
        for _ in 0..200 {
            let root = ChatNode::root("root");
            let x = root.add_user("x");
            let y = root.add_user("y");
            let (xa, ya) = (x.clone(), y.clone());
            let (x_mover, y_target) = (x.clone(), y.clone());
            let (y_mover, x_target) = (y.clone(), x.clone());
            drop((x, y));
            let barrier = Arc::new(Barrier::new(2));
            let (b1, b2) = (barrier.clone(), barrier.clone());
            // Each thread tries to put the other node under its own; if both
            // succeeded, x and y would be each other's parent.
            let t1 = std::thread::spawn(move || {
                b1.wait();
                let _ = x_mover.add_child(y_target);
            });
            let t2 = std::thread::spawn(move || {
                b2.wait();
                let _ = y_mover.add_child(x_target);
            });
            t1.join().unwrap();
            t2.join().unwrap();
            for node in [&xa, &ya] {
                let mut steps = 0;
                let mut cur = Some(node.clone());
                while let Some(n) = cur {
                    steps += 1;
                    assert!(
                        steps < 100,
                        "parent chain did not terminate: a cycle was committed"
                    );
                    cur = n.parent();
                }
            }
        }
    }

    #[test]
    fn node_path_is_root_to_self() {
        let root = ChatNode::root("sys");
        let u = root.add_user("u");
        let a = u.add_assistant("a");
        let path = a.node_path();
        let ids: Vec<_> = path.iter().map(|n| n.id.clone()).collect();
        assert_eq!(ids, vec![root.id.clone(), u.id.clone(), a.id.clone()]);
    }

    #[test]
    fn prepare_messages_applies_system_prompt_and_force_prepend() {
        let root = ChatNode::root("base sys");
        let user = root.add_user("hello");
        let p1 = NodeCompletionParameters::new();
        let s1 = CompletionSettings::from_params(Some(&p1));
        let m1 = user.prepare_messages(&s1);
        assert_eq!(m1.first().unwrap().role, Role::System);
        assert_eq!(m1.first().unwrap().text(), Some("base sys"));
        let p2 = NodeCompletionParameters::new().with_force_prepend("Answer: ");
        let s2 = CompletionSettings::from_params(Some(&p2));
        let m2 = user.prepare_messages(&s2);
        let last = m2.last().unwrap();
        assert_eq!(last.role, Role::Assistant);
        assert_eq!(last.text(), Some("Answer: "));
    }

    #[test]
    fn prepare_messages_applies_completion_kwargs_base() {
        let root = ChatNode::root("I am {bot}");
        let user = root.add_user("hi");
        let params = NodeCompletionParameters::new().with_format_kwarg("bot", "Claude");
        let settings = CompletionSettings::from_params(Some(&params));
        let msgs = user.prepare_messages(&settings);
        assert_eq!(msgs.first().unwrap().text(), Some("I am Claude"));
    }

    #[test]
    fn marking_a_node_propagates_breakpoint_into_prepared_messages() {
        let root = ChatNode::root("sys");
        let user = root.add_user("hi");
        root.cache_breakpoint();
        assert!(root.is_cache_breakpoint());
        let settings = CompletionSettings::from_params(None);
        let msgs = user.prepare_messages(&settings);
        assert!(msgs[0].cache_breakpoint, "system marked");
        assert!(!msgs[1].cache_breakpoint, "user not marked");
    }

    #[test]
    fn use_cache_flag_marks_the_whole_prefix() {
        let root = ChatNode::root("sys");
        let user = root.add_user("hi");
        let params = NodeCompletionParameters::new().with_cache(true);
        let settings = CompletionSettings::from_params(Some(&params));
        let msgs = user.prepare_messages(&settings);
        assert!(msgs.last().unwrap().cache_breakpoint);
    }

    #[test]
    fn clear_cache_breakpoint_and_clear_all() {
        let root = ChatNode::root("sys");
        let a = root.add_user("a");
        let b = a.add_assistant("b");
        root.cache_breakpoint();
        b.cache_breakpoint();
        assert!(root.is_cache_breakpoint() && b.is_cache_breakpoint());
        b.clear_cache_breakpoint();
        assert!(!b.is_cache_breakpoint());
        assert!(root.is_cache_breakpoint());
        root.cache_breakpoint();
        b.cache_breakpoint();
        // Called on a middle node, yet must clear marks above and below it.
        a.clear_all_cache_breakpoints();
        assert!(!root.is_cache_breakpoint());
        assert!(!b.is_cache_breakpoint());
    }

    #[test]
    fn cache_breakpoint_survives_clone_tree_and_thread_serialization() {
        let root = ChatNode::root("sys");
        let user = root.add_user("hi");
        root.cache_breakpoint();
        let cloned = user.clone_tree();
        assert!(cloned.get_root().is_cache_breakpoint());
        let msgs = user.prepare_messages(&CompletionSettings::from_params(None));
        let json = serde_json::to_value(&msgs[0]).unwrap();
        assert_eq!(json["cache_breakpoint"], true);
    }
}