use crate::api::{ContentBlock, CreateMessageRequest, LlmClient};
use crate::config::SofosConfig;
use crate::error::{Result, SofosError};
use crate::repl::SteerQueue;
use crate::repl::conversation::ConversationHistory;
use crate::repl::request_builder::RequestBuilder;
use crate::session::DisplayMessage;
use crate::tools::ToolExecutor;
use crate::ui::UI;
use colored::Colorize;
use std::io::Write;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::time::{Duration, sleep};
/// Drives a single assistant "turn": sends requests to the model, renders
/// the response, executes any tools the model asks for, and loops until the
/// model stops requesting tools (or a safety limit is reached).
pub struct ResponseHandler {
    // Backend used for both streaming and non-streaming completions.
    client: LlmClient,
    // Executes model-requested tools (bash, file ops, ...).
    tool_executor: ToolExecutor,
    // Full message history sent with every request.
    conversation: ConversationHistory,
    // Terminal rendering helpers.
    ui: UI,
    // Model identifier passed through to the API.
    model: String,
    // Per-response output token cap.
    max_tokens: u32,
    // Whether to request extended thinking/reasoning.
    enable_thinking: bool,
    // Token budget for thinking when enabled.
    thinking_budget: u32,
    // Runtime configuration (e.g. max_tool_iterations). Defaulted in `new`.
    config: SofosConfig,
    // Tool definitions advertised to the model on each request.
    available_tools: Vec<crate::api::Tool>,
    // Stream deltas live vs. print the complete response at the end.
    use_streaming: bool,
    // Set externally (e.g. on ESC) to interrupt an in-flight request.
    interrupt_flag: Arc<AtomicBool>,
    // Messages the user typed mid-turn, delivered with the next tool results.
    steer_queue: SteerQueue,
}
impl ResponseHandler {
#[allow(clippy::too_many_arguments)]
pub fn new(
client: LlmClient,
tool_executor: ToolExecutor,
conversation: ConversationHistory,
model: String,
max_tokens: u32,
enable_thinking: bool,
thinking_budget: u32,
available_tools: Vec<crate::api::Tool>,
use_streaming: bool,
interrupt_flag: Arc<AtomicBool>,
steer_queue: SteerQueue,
) -> Self {
Self {
client,
tool_executor,
conversation,
ui: UI::new(),
model,
max_tokens,
enable_thinking,
thinking_budget,
config: SofosConfig::default(),
available_tools,
use_streaming,
interrupt_flag,
steer_queue,
}
}
fn drain_steer_messages(&self) -> Option<String> {
let mut queue = self
.steer_queue
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
if queue.is_empty() {
return None;
}
let messages: Vec<String> = std::mem::take(&mut *queue);
Some(messages.join("\n\n"))
}
    /// Core turn loop: renders assistant output, records it in the
    /// conversation, executes requested tools, and requests follow-up
    /// responses until the model stops asking for tools or a limit is hit.
    ///
    /// `content_blocks` holds the first response's content; each loop
    /// iteration replaces it with the next follow-up response. Token usage
    /// is accumulated into `total_input_tokens` / `total_output_tokens`.
    pub async fn handle_response(
        &mut self,
        mut content_blocks: Vec<ContentBlock>,
        display_messages: &mut Vec<DisplayMessage>,
        total_input_tokens: &mut u32,
        total_output_tokens: &mut u32,
    ) -> Result<()> {
        let mut iteration = 0;
        loop {
            iteration += 1;
            if std::env::var("SOFOS_DEBUG").is_ok() {
                eprintln!(
                    "\n=== handle_response: iteration={}, blocks={} ===",
                    iteration,
                    content_blocks.len()
                );
            }
            // Hard cap on tool round-trips to prevent infinite loops; asks
            // the model for a wrap-up summary instead of continuing.
            if iteration > self.config.max_tool_iterations {
                self.handle_max_iterations(
                    display_messages,
                    total_input_tokens,
                    total_output_tokens,
                )
                .await?;
                return Ok(());
            }
            let (text_output, tool_uses, had_reasoning) =
                self.process_content_blocks(&content_blocks);
            if !text_output.is_empty() {
                // In streaming mode the text was already printed live.
                if !self.use_streaming {
                    if iteration > 1 {
                        println!();
                    }
                    println!("{}", "Assistant:".bright_blue().bold());
                    for text in &text_output {
                        self.ui.print_assistant_text(text)?;
                    }
                }
                let combined_text = text_output.join("\n");
                display_messages.push(DisplayMessage::AssistantMessage {
                    content: combined_text,
                });
            }
            // Record the assistant turn, keeping only blocks the API accepts
            // back as assistant content.
            if !content_blocks.is_empty() {
                let message_blocks: Vec<crate::api::MessageContentBlock> = content_blocks
                    .iter()
                    .filter_map(crate::api::MessageContentBlock::from_content_block_for_api)
                    .collect();
                if !message_blocks.is_empty() {
                    self.conversation.add_assistant_with_blocks(message_blocks);
                }
            }
            // OpenAI models can emit reasoning with no visible text and no
            // tool calls; poke them once more for an actual answer.
            if tool_uses.is_empty()
                && text_output.is_empty()
                && had_reasoning
                && matches!(self.client, LlmClient::OpenAI(_))
            {
                let response = self.get_next_response(&[], display_messages).await?;
                *total_input_tokens += response.usage.input_tokens;
                *total_output_tokens += response.usage.output_tokens;
                if response.content.is_empty() {
                    println!(
                        "{}",
                        "Assistant returned reasoning but no visible response.".dimmed()
                    );
                    println!();
                    break;
                }
                content_blocks = response.content;
                continue;
            }
            // No tools requested: the turn is complete.
            if tool_uses.is_empty() {
                if text_output.is_empty() && !had_reasoning {
                    println!("{}", "Assistant returned an empty response.".dimmed());
                    println!();
                }
                break;
            }
            let (tool_results, user_cancelled) =
                self.execute_tools(&tool_uses, display_messages).await?;
            if !tool_results.is_empty() {
                if std::env::var("SOFOS_DEBUG").is_ok() {
                    eprintln!(
                        "=== Adding {} tool results to conversation ===",
                        tool_results.len()
                    );
                }
                // Any message the user typed mid-turn rides along with the
                // tool results so the model can adjust its plan.
                if let Some(steer_text) = self.drain_steer_messages() {
                    println!(
                        "{} {}",
                        "↑".bright_magenta().bold(),
                        "mid-turn message delivered to the model".bright_magenta()
                    );
                    let mut blocks = tool_results;
                    blocks.push(crate::api::MessageContentBlock::Text {
                        text: format!(
                            "[User sent this message while you were working on the current task. \
                             Take it into account and adjust your plan if needed]:\n{}",
                            steer_text
                        ),
                        cache_control: None,
                    });
                    self.conversation.add_user_with_blocks(blocks);
                } else {
                    self.conversation.add_tool_results(tool_results);
                }
            }
            // The user declined a destructive action: hand control back
            // without requesting a follow-up response.
            if user_cancelled {
                if std::env::var("SOFOS_DEBUG").is_ok() {
                    eprintln!("=== Returning early due to user cancellation ===");
                }
                return Ok(());
            }
            let response = self.get_next_response(&tool_uses, display_messages).await?;
            *total_input_tokens += response.usage.input_tokens;
            *total_output_tokens += response.usage.output_tokens;
            if std::env::var("SOFOS_DEBUG").is_ok() {
                eprintln!(
                    "\n=== Response received: stop_reason={:?}, content_blocks={} ===",
                    response.stop_reason,
                    response.content.len()
                );
            }
            if let Some(ref stop_reason) = response.stop_reason {
                if stop_reason == "max_tokens" {
                    UI::print_warning("Response was cut off due to token limit.");
                    eprintln!(
                        "Consider using --max-tokens with a higher value (current: {})",
                        self.max_tokens
                    );
                }
            }
            if response.content.is_empty() {
                println!("{}", "Assistant:".bright_blue().bold());
                println!("{}", "I've completed the tool operations but didn't generate a response. Please let me know if you need any clarification.".dimmed());
                println!();
                return Ok(());
            }
            // Feed the follow-up response back through the loop.
            content_blocks = response.content;
        }
        Ok(())
    }
fn process_content_blocks(
&self,
content_blocks: &[ContentBlock],
) -> (Vec<String>, Vec<(String, String, serde_json::Value)>, bool) {
let mut text_output = Vec::new();
let mut tool_uses = Vec::new();
let mut had_reasoning = false;
for block in content_blocks {
match block {
ContentBlock::Text { text } => {
if !text.trim().is_empty() {
text_output.push(text.clone());
}
}
ContentBlock::Thinking { thinking, .. } => {
if !self.use_streaming {
self.ui.print_thinking(thinking);
}
had_reasoning = true;
}
ContentBlock::Summary { summary } => {
if !self.use_streaming {
self.ui.print_thinking(summary);
}
had_reasoning = true;
}
ContentBlock::ToolUse { id, name, input } => {
tool_uses.push((id.clone(), name.clone(), input.clone()));
}
ContentBlock::ServerToolUse { name, input, .. } => {
if std::env::var("SOFOS_DEBUG").is_ok() {
eprintln!("Server tool use: {} with input: {:?}", name, input);
}
}
ContentBlock::WebSearchToolResult { content, .. } => {
if !content.is_empty() {
text_output
.push(format!("\n[Web search returned {} results]", content.len()));
}
}
}
}
(text_output, tool_uses, had_reasoning)
}
async fn execute_tools(
&self,
tool_uses: &[(String, String, serde_json::Value)],
display_messages: &mut Vec<DisplayMessage>,
) -> Result<(Vec<crate::api::MessageContentBlock>, bool)> {
let mut tool_results = Vec::new();
let mut user_cancelled = false;
if std::env::var("SOFOS_DEBUG").is_ok() {
eprintln!("\n=== Executing {} tools ===", tool_uses.len());
}
for (i, (tool_id, tool_name, tool_input)) in tool_uses.iter().enumerate() {
if std::env::var("SOFOS_DEBUG").is_ok() {
eprintln!(
"=== Tool {}/{}: {} (id: {}) ===",
i + 1,
tool_uses.len(),
tool_name,
&tool_id[..20.min(tool_id.len())]
);
}
let command = if tool_name == "execute_bash" {
tool_input.get("command").and_then(|v| v.as_str())
} else {
None
};
self.ui.print_tool_header(tool_name, command);
if tool_name == "execute_bash" {
print!("\x1B[?25l");
let _ = std::io::stdout().flush();
}
let result = self.tool_executor.execute(tool_name, tool_input).await;
if tool_name == "execute_bash" {
print!("\x1B[?25h");
println!();
}
match result {
Ok(output) => {
if std::env::var("SOFOS_DEBUG").is_ok() {
eprintln!(
"=== Tool {} succeeded, output length: {} ===",
i + 1,
output.text().len()
);
}
let display_output =
UI::create_tool_display_message(tool_name, tool_input, output.text());
if !display_output.is_empty() {
let ui = UI::new();
ui.print_tool_output(&display_output);
}
display_messages.push(DisplayMessage::ToolExecution {
tool_name: tool_name.clone(),
tool_input: tool_input.clone(),
tool_output: display_output.clone(),
});
tool_results.push(crate::api::MessageContentBlock::ToolResult {
tool_use_id: tool_id.clone(),
content: output.text().to_string(),
cache_control: None,
});
for image in output.images() {
tool_results.push(crate::api::MessageContentBlock::Image {
source: crate::api::ImageSource::Base64 {
media_type: image.mime_type.clone(),
data: image.base64_data.clone(),
},
cache_control: None,
});
}
if output.text().starts_with("File deletion cancelled by user")
|| output
.text()
.starts_with("Directory deletion cancelled by user")
{
user_cancelled = true;
break;
}
}
Err(e) => {
if std::env::var("SOFOS_DEBUG").is_ok() {
eprintln!("=== Tool {} failed: {} ===", i + 1, e);
}
let error_msg = format!("{}", e);
if e.is_blocked() {
UI::print_blocked_with_hint(&e);
} else {
UI::print_error_with_hint(&e);
}
println!();
display_messages.push(DisplayMessage::ToolExecution {
tool_name: tool_name.clone(),
tool_input: tool_input.clone(),
tool_output: error_msg.clone(),
});
tool_results.push(crate::api::MessageContentBlock::ToolResult {
tool_use_id: tool_id.clone(),
content: error_msg,
cache_control: None,
});
}
}
}
Ok((tool_results, user_cancelled))
}
    /// Sends the current conversation to the model and returns its response.
    ///
    /// In streaming mode deltas are printed live through a shared
    /// `StreamPrinter` and the client observes the interrupt flag itself.
    /// In non-streaming mode the request is raced against the interrupt
    /// flag; when interrupted, an explanatory message (listing the tools
    /// from `tool_uses` that already ran) is appended to the conversation
    /// and `SofosError::Interrupted` is returned so the caller can hand
    /// control back to the user.
    async fn get_next_response(
        &mut self,
        tool_uses: &[(String, String, serde_json::Value)],
        display_messages: &mut Vec<DisplayMessage>,
    ) -> Result<crate::api::CreateMessageResponse> {
        if std::env::var("SOFOS_DEBUG").is_ok() {
            eprintln!("=== About to generate response ===");
            eprintln!("\n=== DEBUG: Conversation before API call ===");
            // Dump a compact per-message summary (role + content size).
            for (i, msg) in self.conversation.messages().iter().enumerate() {
                let content_desc = match &msg.content {
                    crate::api::MessageContent::Text { content } => {
                        format!("text({})", content.len())
                    }
                    crate::api::MessageContent::Blocks { content } => {
                        format!("blocks({})", content.len())
                    }
                };
                eprintln!("Message {}: role={}, content={}", i, msg.role, content_desc);
            }
            eprintln!("===========================================\n");
        }
        if self.use_streaming {
            // One printer shared by both delta callbacks keeps text and
            // thinking output interleaved consistently.
            let printer = Arc::new(crate::ui::StreamPrinter::new());
            let p_text = printer.clone();
            let p_think = printer.clone();
            let interrupt = Arc::clone(&self.interrupt_flag);
            let request = self.build_request();
            let response_result = self
                .client
                .create_message_streaming(
                    request,
                    move |t| p_text.on_text_delta(t),
                    move |t| p_think.on_thinking_delta(t),
                    interrupt,
                )
                .await;
            // Flush any buffered partial output before returning.
            printer.finish();
            return response_result;
        }
        let request = self.build_request();
        let client = self.client.clone();
        // Race the request against the ESC interrupt flag.
        let response_result = self
            .run_interruptible(async move { client.create_message(request).await })
            .await;
        if self.interrupt_flag.load(Ordering::SeqCst) {
            println!(
                "\n{}",
                "Processing interrupted by user. You can now provide additional guidance."
                    .bright_yellow()
            );
            println!();
            let tools_executed: Vec<String> =
                tool_uses.iter().map(|(_, name, _)| name.clone()).collect();
            let interrupt_msg = format!(
                "INTERRUPT: The user pressed ESC while waiting for your response after tool execution. \
                 Tools that were executed: {}. The user wants to provide additional guidance before you continue. \
                 Wait for their next message.",
                tools_executed.join(", ")
            );
            // Prefer attaching to the last user message so the model sees the
            // interrupt in context; fall back to a standalone user message.
            if !self
                .conversation
                .append_text_to_last_user_blocks(interrupt_msg.clone())
            {
                self.conversation.add_user_message(interrupt_msg);
            }
            display_messages.push(DisplayMessage::UserMessage {
                content: format!(
                    "[Interrupted after executing: {}]",
                    tools_executed.join(", ")
                ),
            });
            return Err(SofosError::Interrupted);
        }
        response_result
    }
async fn wait_for_interrupt(flag: Arc<AtomicBool>) {
while !flag.load(Ordering::Relaxed) {
sleep(Duration::from_millis(50)).await;
}
}
    /// Runs `fut` on a spawned task, racing it against the interrupt flag.
    ///
    /// If the flag fires first the task is aborted and
    /// `SofosError::Interrupted` is returned; a panicked or cancelled task
    /// surfaces as `SofosError::Join`. The future must be `Send + 'static`
    /// because it is moved onto the runtime via `tokio::spawn`.
    async fn run_interruptible<T>(
        &self,
        fut: impl std::future::Future<Output = Result<T>> + Send + 'static,
    ) -> Result<T>
    where
        T: Send + 'static,
    {
        let mut handle = tokio::spawn(fut);
        let interrupt_flag = Arc::clone(&self.interrupt_flag);
        tokio::select! {
            // Polling `&mut handle` (not `handle`) keeps ownership so the
            // other branch can still call `abort()`.
            res = &mut handle => match res {
                Ok(inner) => inner,
                Err(e) => Err(SofosError::Join(format!("{}", e))),
            },
            _ = Self::wait_for_interrupt(interrupt_flag) => {
                handle.abort();
                Err(SofosError::Interrupted)
            }
        }
    }
async fn handle_max_iterations(
&mut self,
display_messages: &mut Vec<DisplayMessage>,
total_input_tokens: &mut u32,
total_output_tokens: &mut u32,
) -> Result<()> {
UI::print_warning("Maximum tool iterations reached. Stopping to prevent infinite loop.");
let interruption_msg = format!(
"SYSTEM INTERRUPTION: You have reached the maximum number of tool iterations ({}). \
This limit prevents infinite loops. Please provide a summary of what you've accomplished \
so far and suggest how the user should proceed. Consider breaking down the task into \
smaller steps or asking the user for clarification.",
self.config.max_tool_iterations
);
self.conversation.add_user_message(interruption_msg.clone());
display_messages.push(DisplayMessage::UserMessage {
content: "[System: Maximum tool iterations reached]".to_string(),
});
let request = self.build_request();
let client = self.client.clone();
let response_result = self
.run_interruptible(async move { client.create_message(request).await })
.await;
match response_result {
Ok(response) => {
*total_input_tokens += response.usage.input_tokens;
*total_output_tokens += response.usage.output_tokens;
for block in &response.content {
if let ContentBlock::Text { text } = block {
if !text.trim().is_empty() {
println!("{}", "Assistant:".bright_blue().bold());
self.ui.print_assistant_text(text)?;
display_messages.push(DisplayMessage::AssistantMessage {
content: text.clone(),
});
}
}
}
let message_blocks: Vec<crate::api::MessageContentBlock> = response
.content
.iter()
.filter_map(crate::api::MessageContentBlock::from_content_block_for_api)
.collect();
if !message_blocks.is_empty() {
self.conversation.add_assistant_with_blocks(message_blocks);
}
}
Err(e) => {
UI::print_error(&format!("Failed to get summary after interruption: {}", e));
return Err(e);
}
}
Ok(())
}
fn get_available_tools(&self) -> Vec<crate::api::Tool> {
self.available_tools.clone()
}
fn build_request(&self) -> CreateMessageRequest {
RequestBuilder::new(
&self.client,
&self.model,
self.max_tokens,
&self.conversation,
self.get_available_tools(),
self.enable_thinking,
self.thinking_budget,
)
.build()
}
    /// Read-only access to the accumulated conversation history.
    pub fn conversation(&self) -> &ConversationHistory {
        &self.conversation
    }
}