use crate::clients::base::{
ChunkType, LLMCallInfo, LLMClient, LLMRequestOptions, LLMResponse, LLMResponseEnvelope,
LLMUsageDetails, StreamChunk, TokenUsage,
};
use crate::context::manager::ContextManager;
use crate::core::message::{Message, MessageMeta, MessageRole, MessageType, ToolCallInfo};
use crate::core::tool_spec::ToolSpec;
use crate::error::{ForgeError, StreamError, ToolCallError};
use crate::guardrails::{ErrorTracker, ResponseValidator};
use futures_util::StreamExt;
use serde_json::Value;
use std::collections::HashSet;
/// Literal prefix placed in front of the counter by `format_tool_call_id`.
const TOOL_CALL_ID_PREFIX: &str = "call_";
/// Minimum width of the counter portion of an id; `format_tool_call_id`
/// left-pads shorter counters with `0` up to this many characters.
const TOOL_CALL_ID_WIDTH: usize = 9;
/// Outcome of a successful inference run, as produced by the retry loop in
/// this module once a response passes validation.
#[derive(Debug, Clone)]
pub struct InferenceResult {
/// The accepted response. The retry loop in this file always stores
/// `LLMResponse::ToolCalls` here, built from the validator's tool calls
/// with each call's `id` cleared.
pub response: LLMResponse,
/// Token usage reported with the final response envelope, if any.
pub usage: Option<TokenUsage>,
/// Detailed usage reported with the final response envelope, if any.
pub usage_details: Option<LLMUsageDetails>,
/// Call metadata reported with the final response envelope, if any.
pub call_info: Option<LLMCallInfo>,
/// Non-array provider payload from the envelope. The retry loop only keeps
/// it when the raw model response was a tool-call response.
pub provider_response: Option<Value>,
/// Array-shaped provider payload from the envelope (for streamed calls,
/// the collected provider-event chunks). Kept under the same condition as
/// `provider_response`.
pub provider_events: Option<Vec<Value>>,
/// Messages created during the run (context warnings, rejected assistant
/// output and retry nudges), in creation order.
pub new_messages: Vec<Message>,
/// Value of the caller's tool-call counter at the moment of return.
pub tool_call_counter: i64,
/// Number of request attempts made, including the successful one.
pub attempts: i32,
}
/// Renders a tool-call counter as a call id.
///
/// The id is `TOOL_CALL_ID_PREFIX` followed by the decimal form of `counter`,
/// left-padded with `0` so that the numeric part is at least
/// `TOOL_CALL_ID_WIDTH` characters long. Longer counters are emitted in full.
/// The padding is applied to the whole rendered number, sign included, so a
/// negative counter keeps its `-` directly in front of the digits.
pub fn format_tool_call_id(counter: i64) -> String {
    let digits = counter.to_string();
    // The decimal rendering is ASCII, so byte length equals character count.
    let padding = TOOL_CALL_ID_WIDTH.saturating_sub(digits.len());

    let mut id = String::with_capacity(TOOL_CALL_ID_PREFIX.len() + padding + digits.len());
    id.push_str(TOOL_CALL_ID_PREFIX);
    id.extend(std::iter::repeat('0').take(padding));
    id.push_str(&digits);
    id
}
mod context;
mod fold;
#[cfg(test)]
mod tests;
use context::ContextAccess;
pub use fold::fold_and_serialize;
/// Boxed, thread-safe callback that receives a reference to a stream chunk.
///
/// The streaming path in this module invokes it for every chunk whose type is
/// not `ChunkType::ProviderEvent`, including the final chunk.
pub type OnChunkFn = Box<dyn Fn(&StreamChunk) + Send + Sync>;
/// Runs inference with an exclusively borrowed [`ContextManager`], taking raw
/// sampling parameters.
///
/// The optional `sampling` map is cloned and turned into
/// [`LLMRequestOptions`] via `LLMRequestOptions::from_sampling`; every other
/// argument is forwarded unchanged to [`run_inference_with_options`], whose
/// result is returned as-is.
#[allow(clippy::too_many_arguments)]
pub async fn run_inference<C: LLMClient>(
    messages: &mut Vec<Message>,
    client: &C,
    context_manager: &mut ContextManager,
    validator: &ResponseValidator,
    error_tracker: &mut ErrorTracker,
    tool_specs: &[ToolSpec],
    tool_call_counter: &mut i64,
    step_index: i64,
    step_hint: &str,
    max_attempts: Option<i32>,
    stream: bool,
    on_chunk: Option<&OnChunkFn>,
    sampling: Option<&serde_json::Map<String, Value>>,
) -> Result<Option<InferenceResult>, ForgeError> {
    let request_options = LLMRequestOptions::from_sampling(sampling.cloned());

    run_inference_with_options(
        messages,
        client,
        context_manager,
        validator,
        error_tracker,
        tool_specs,
        tool_call_counter,
        step_index,
        step_hint,
        max_attempts,
        stream,
        on_chunk,
        request_options,
    )
    .await
}
/// Runs inference with an exclusively borrowed [`ContextManager`] and
/// pre-built request options.
///
/// Wraps `context_manager` in `ContextAccess::Direct` and delegates to the
/// shared retry loop; all other arguments are passed through untouched and
/// the loop's result is returned as-is.
#[allow(clippy::too_many_arguments)]
pub async fn run_inference_with_options<C: LLMClient>(
    messages: &mut Vec<Message>,
    client: &C,
    context_manager: &mut ContextManager,
    validator: &ResponseValidator,
    error_tracker: &mut ErrorTracker,
    tool_specs: &[ToolSpec],
    tool_call_counter: &mut i64,
    step_index: i64,
    step_hint: &str,
    max_attempts: Option<i32>,
    stream: bool,
    on_chunk: Option<&OnChunkFn>,
    options: LLMRequestOptions,
) -> Result<Option<InferenceResult>, ForgeError> {
    let direct_context = ContextAccess::Direct(context_manager);

    run_inference_with_options_inner(
        messages,
        client,
        direct_context,
        validator,
        error_tracker,
        tool_specs,
        tool_call_counter,
        step_index,
        step_hint,
        max_attempts,
        stream,
        on_chunk,
        options,
    )
    .await
}
/// Runs inference against a [`ContextManager`] held behind a
/// `tokio::sync::Mutex`, taking raw sampling parameters.
///
/// The optional `sampling` map is cloned and turned into
/// [`LLMRequestOptions`] via `LLMRequestOptions::from_sampling`; every other
/// argument is forwarded unchanged to
/// [`run_inference_with_options_shared_context`], whose result is returned
/// as-is.
#[allow(clippy::too_many_arguments)]
pub async fn run_inference_shared_context<C: LLMClient>(
    messages: &mut Vec<Message>,
    client: &C,
    context_manager: &tokio::sync::Mutex<ContextManager>,
    validator: &ResponseValidator,
    error_tracker: &mut ErrorTracker,
    tool_specs: &[ToolSpec],
    tool_call_counter: &mut i64,
    step_index: i64,
    step_hint: &str,
    max_attempts: Option<i32>,
    stream: bool,
    on_chunk: Option<&OnChunkFn>,
    sampling: Option<&serde_json::Map<String, Value>>,
) -> Result<Option<InferenceResult>, ForgeError> {
    let request_options = LLMRequestOptions::from_sampling(sampling.cloned());

    run_inference_with_options_shared_context(
        messages,
        client,
        context_manager,
        validator,
        error_tracker,
        tool_specs,
        tool_call_counter,
        step_index,
        step_hint,
        max_attempts,
        stream,
        on_chunk,
        request_options,
    )
    .await
}
/// Runs inference against a [`ContextManager`] held behind a
/// `tokio::sync::Mutex`, with pre-built request options.
///
/// Wraps the mutex reference in `ContextAccess::Shared` and delegates to the
/// shared retry loop; all other arguments are passed through untouched and
/// the loop's result is returned as-is.
#[allow(clippy::too_many_arguments)]
pub async fn run_inference_with_options_shared_context<C: LLMClient>(
    messages: &mut Vec<Message>,
    client: &C,
    context_manager: &tokio::sync::Mutex<ContextManager>,
    validator: &ResponseValidator,
    error_tracker: &mut ErrorTracker,
    tool_specs: &[ToolSpec],
    tool_call_counter: &mut i64,
    step_index: i64,
    step_hint: &str,
    max_attempts: Option<i32>,
    stream: bool,
    on_chunk: Option<&OnChunkFn>,
    options: LLMRequestOptions,
) -> Result<Option<InferenceResult>, ForgeError> {
    let shared_context = ContextAccess::Shared(context_manager);

    run_inference_with_options_inner(
        messages,
        client,
        shared_context,
        validator,
        error_tracker,
        tool_specs,
        tool_call_counter,
        step_index,
        step_hint,
        max_attempts,
        stream,
        on_chunk,
        options,
    )
    .await
}
/// Shared request/validate/retry loop behind every public `run_inference*`
/// entry point.
///
/// Each attempt:
/// 1. lets the context manager compact `messages` (replacing them in place
///    when it returns a new list) and check its thresholds;
/// 2. serializes the conversation for the client's API format, appending any
///    threshold warning to the wire payload and to `new_messages` only;
/// 3. sends the request, either streamed (forwarding chunks to `on_chunk`) or
///    as a single envelope;
/// 4. reports the observed token count to the context manager and validates
///    the response.
///
/// When validation asks for a retry, the rejected assistant output and a
/// nudge are appended to both `messages` and `new_messages` and the loop
/// continues. Otherwise the retry counter is reset and the result is
/// returned.
///
/// The number of attempts is capped at the smaller of
/// `error_tracker.max_retries() + 1` and `max_attempts` (when given).
///
/// # Returns
/// * `Ok(Some(result))` when a response passes validation.
/// * `Ok(None)` when the attempt cap is reached without one.
///
/// # Errors
/// * Any client send/stream error, converted through `ForgeError::from`.
/// * `ForgeError::Stream` when a stream ends without a `ChunkType::Final`
///   chunk carrying a response.
/// * `ForgeError::ToolCall` when the error tracker reports its retries as
///   exhausted; the error carries the raw rendering of the last response.
#[allow(clippy::too_many_arguments)]
async fn run_inference_with_options_inner<C: LLMClient>(
    messages: &mut Vec<Message>,
    client: &C,
    mut context_manager: ContextAccess<'_>,
    validator: &ResponseValidator,
    error_tracker: &mut ErrorTracker,
    tool_specs: &[ToolSpec],
    tool_call_counter: &mut i64,
    step_index: i64,
    step_hint: &str,
    max_attempts: Option<i32>,
    stream: bool,
    on_chunk: Option<&OnChunkFn>,
    options: LLMRequestOptions,
) -> Result<Option<InferenceResult>, ForgeError> {
    let mut new_messages: Vec<Message> = Vec::new();
    let mut attempts = 0;
    let retry_limit = error_tracker.max_retries().saturating_add(1);
    let max = std::cmp::min(retry_limit, max_attempts.unwrap_or(i32::MAX));
    let api_format = client.api_format().as_str();
    let tools_opt = (!tool_specs.is_empty()).then(|| tool_specs.to_vec());
    let mut next_options = options;

    while attempts < max {
        attempts += 1;
        let mut request_options = next_options.clone();

        let compacted = context_manager
            .maybe_compact(messages, step_index, Some(step_hint))
            .await;
        if let Some(new_msgs) = compacted {
            messages.clear();
            messages.extend(new_msgs);
            // The caller-supplied pre-serialized bodies no longer match the
            // compacted conversation.
            request_options.inbound_anthropic_body = None;
            request_options.initial_openai_messages = None;
        }

        let transient_warning = context_manager.check_thresholds(messages).await;
        let mut wire = fold_and_serialize(messages, api_format);
        if let Some(ref warning) = transient_warning {
            // The warning alters the payload, so the pre-serialized bodies
            // cannot be reused for this request either.
            request_options.inbound_anthropic_body = None;
            request_options.initial_openai_messages = None;
            let warning_msg = Message::new(
                MessageRole::User,
                warning.as_str(),
                MessageMeta::new(MessageType::ContextWarning),
            );
            // Sent on the wire and reported to the caller, but deliberately
            // not appended to `messages`.
            wire.push(warning_msg.serialize(api_format));
            new_messages.push(warning_msg);
        }

        // Pre-serialized bodies are only ever offered to the first attempt.
        next_options.inbound_anthropic_body = None;
        next_options.initial_openai_messages = None;

        let envelope = if stream {
            let mut stream = client
                .send_stream_with_options(wire, tools_opt.clone(), request_options)
                .await
                .map_err(ForgeError::from)?;
            let mut final_response: Option<LLMResponse> = None;
            let mut final_usage = None;
            let mut final_usage_details = None;
            let mut final_call_info = None;
            let mut provider_events = Vec::new();

            while let Some(chunk_result) = stream.next().await {
                let chunk = chunk_result.map_err(ForgeError::from)?;
                if chunk.chunk_type == ChunkType::ProviderEvent {
                    // Raw provider events are collected, never surfaced to
                    // the chunk callback.
                    if let Some(event) = chunk.provider_event {
                        provider_events.push(event);
                    }
                    continue;
                }
                if let Some(cb) = on_chunk {
                    cb(&chunk);
                }
                if chunk.chunk_type == ChunkType::Final {
                    final_usage = chunk.usage;
                    final_usage_details = chunk.usage_details;
                    final_call_info = chunk.call_info;
                    final_response = chunk.response;
                }
            }

            let response = final_response.ok_or_else(|| {
                ForgeError::Stream(StreamError::new(
                    "Stream ended without FINAL chunk - the client adapter may be malformed or the connection was interrupted",
                ))
            })?;
            // Fall back to whatever the client recorded when the final chunk
            // carried no metadata.
            let mut envelope = LLMResponseEnvelope::from_response(response).with_metadata(
                final_usage.or_else(|| client.last_usage()),
                final_usage_details.or_else(|| client.last_usage_details()),
                final_call_info.or_else(|| client.last_call_info()),
            );
            if !provider_events.is_empty() {
                envelope.provider_response = Some(Value::Array(provider_events));
            }
            envelope
        } else {
            client
                .send_envelope_with_options(wire, tools_opt.clone(), request_options)
                .await
                .map_err(ForgeError::from)?
        };

        let LLMResponseEnvelope {
            response,
            usage,
            usage_details,
            call_info,
            provider_response,
        } = envelope;

        // An array payload is treated as a list of provider events; anything
        // else stays a single provider response. The array is moved out
        // rather than cloned.
        let (provider_response, provider_events) = match provider_response {
            Some(Value::Array(events)) => (None, Some(events)),
            other => (other, None),
        };

        let observed_tokens = match usage.as_ref() {
            Some(reported) => reported.total_tokens,
            None => estimate_tokens_from_response(&response),
        };
        context_manager.update_token_count(observed_tokens).await;

        let validation = validator.validate(&response);

        if validation.needs_retry {
            error_tracker.record_retry();
            if error_tracker.retries_exhausted() {
                let raw = response_to_raw_string(&response).unwrap_or_default();
                return Err(ForgeError::ToolCall(
                    ToolCallError::new(format!(
                        "Retries exhausted after {} consecutive failed attempts",
                        error_tracker.max_retries()
                    ))
                    .with_raw_response(raw),
                ));
            }

            let nudge_content = validation
                .nudge
                .as_ref()
                .map(|n| n.content.clone())
                .unwrap_or_default();

            match &response {
                LLMResponse::Text(text) => {
                    // Replay the rejected text, then the nudge as a user turn.
                    let assistant_msg = Message::new(
                        MessageRole::Assistant,
                        &text.content,
                        MessageMeta::new(MessageType::TextResponse).with_step_index(step_index),
                    );
                    record_message(messages, &mut new_messages, assistant_msg);
                    record_message(
                        messages,
                        &mut new_messages,
                        retry_nudge_message(&nudge_content, step_index),
                    );
                }
                LLMResponse::ToolCalls(calls) if calls.is_empty() => {
                    // Nothing to replay: only the nudge is recorded.
                    record_message(
                        messages,
                        &mut new_messages,
                        retry_nudge_message(&nudge_content, step_index),
                    );
                }
                LLMResponse::ToolCalls(calls) => {
                    let mut tool_call_infos = Vec::new();
                    // Ids are generated locally and must not collide with any
                    // id already present in the conversation.
                    let mut seen_call_ids = existing_tool_call_ids(messages);
                    for tc in calls {
                        if let Some(ref reasoning) = tc.reasoning {
                            let reasoning_msg = Message::new(
                                MessageRole::Assistant,
                                reasoning.as_str(),
                                MessageMeta::new(MessageType::Reasoning)
                                    .with_step_index(step_index),
                            );
                            record_message(messages, &mut new_messages, reasoning_msg);
                        }
                        let call_id =
                            next_unique_tool_call_id(tool_call_counter, &mut seen_call_ids);
                        let info = ToolCallInfo::new(&tc.tool, Some(tc.args.clone()), &call_id);
                        tool_call_infos.push(info);
                    }

                    let tool_call_msg = Message::new(
                        MessageRole::Assistant,
                        "",
                        MessageMeta::new(MessageType::ToolCall).with_step_index(step_index),
                    )
                    .with_tool_calls(tool_call_infos.clone());
                    record_message(messages, &mut new_messages, tool_call_msg);

                    let error_prefix = validation
                        .nudge
                        .as_ref()
                        .map(|nudge| match nudge.kind.as_str() {
                            "unknown_tool" => "[UnknownTool]",
                            "invalid_arguments" => "[InvalidArguments]",
                            _ => "[Guardrail]",
                        })
                        .unwrap_or("[Guardrail]");
                    // Every rejected call gets the same error text, so it is
                    // built once.
                    let error_content = format!("{} {}", error_prefix, nudge_content);
                    for info in &tool_call_infos {
                        // Each replayed call is answered with a tool-role
                        // result carrying the nudge.
                        let result_msg = Message::new(
                            MessageRole::Tool,
                            &error_content,
                            MessageMeta::new(MessageType::RetryNudge).with_step_index(step_index),
                        )
                        .with_tool_name(&info.name)
                        .with_tool_call_id(&info.call_id);
                        record_message(messages, &mut new_messages, result_msg);
                    }
                }
            }
            continue;
        }

        error_tracker.reset_retries();

        // Validation passed here, so the provider payload is kept exactly
        // when the raw response was a tool-call response.
        let preserve_provider_response = matches!(response, LLMResponse::ToolCalls(_));
        let (provider_response, provider_events) = if preserve_provider_response {
            (provider_response, provider_events)
        } else {
            (None, None)
        };

        let mut tool_calls = validation.tool_calls.unwrap_or_default();
        for call in &mut tool_calls {
            call.id = None;
        }

        return Ok(Some(InferenceResult {
            response: LLMResponse::ToolCalls(tool_calls),
            usage,
            usage_details,
            call_info,
            provider_response,
            provider_events,
            new_messages,
            tool_call_counter: *tool_call_counter,
            attempts,
        }));
    }

    Ok(None)
}

/// Appends `message` to the conversation and to the list of messages created
/// during this run.
fn record_message(messages: &mut Vec<Message>, new_messages: &mut Vec<Message>, message: Message) {
    messages.push(message.clone());
    new_messages.push(message);
}

/// Builds the user-role retry nudge for the given step.
fn retry_nudge_message(content: &str, step_index: i64) -> Message {
    Message::new(
        MessageRole::User,
        content,
        MessageMeta::new(MessageType::RetryNudge).with_step_index(step_index),
    )
}
/// Collects every non-empty tool-call id already present in `messages`.
///
/// Two sources are scanned: the `call_id` of each entry in a message's
/// `tool_calls`, and a message's own `tool_call_id`.
fn existing_tool_call_ids(messages: &[Message]) -> HashSet<String> {
    let issued_ids = messages
        .iter()
        .flat_map(|message| message.tool_calls.iter().flatten())
        .map(|call| &call.call_id);
    let answered_ids = messages
        .iter()
        .filter_map(|message| message.tool_call_id.as_ref());

    issued_ids
        .chain(answered_ids)
        .filter(|id| !id.is_empty())
        .cloned()
        .collect()
}
fn next_unique_tool_call_id(
tool_call_counter: &mut i64,
seen_call_ids: &mut HashSet<String>,
) -> String {
loop {
let id = format_tool_call_id(*tool_call_counter);
*tool_call_counter += 1;
if seen_call_ids.insert(id.clone()) {
return id;
}
}
}
/// Rough token estimate for a response: its byte length divided by four.
///
/// For text the length is that of the content. For tool calls it is the sum,
/// over all calls, of the tool name, the string rendering of each argument
/// value (argument keys are not counted), and the reasoning text if present.
fn estimate_tokens_from_response(response: &LLMResponse) -> i64 {
    let byte_count: usize = match response {
        LLMResponse::Text(text) => text.content.len(),
        LLMResponse::ToolCalls(calls) => {
            let mut total = 0usize;
            for call in calls.iter() {
                total += call.tool.len();
                for value in call.args.values() {
                    total += value.to_string().len();
                }
                if let Some(reasoning) = call.reasoning.as_ref() {
                    total += reasoning.len();
                }
            }
            total
        }
    };
    (byte_count as i64) / 4
}
/// Renders a response as a plain string for diagnostics.
///
/// Text responses yield their content. Tool-call responses yield
/// `tool(args_json)` per call, joined with `", "`; arguments that fail to
/// serialize are rendered as an empty string. Every current variant produces
/// `Some`.
#[allow(dead_code)]
pub(crate) fn response_to_raw_string(response: &LLMResponse) -> Option<String> {
    let raw = match response {
        LLMResponse::Text(text) => text.content.clone(),
        LLMResponse::ToolCalls(calls) => {
            let mut rendered: Vec<String> = Vec::new();
            for call in calls.iter() {
                let args_json = serde_json::to_string(&call.args).unwrap_or_default();
                rendered.push(format!("{}({})", call.tool, args_json));
            }
            rendered.join(", ")
        }
    };
    Some(raw)
}