use std::collections::HashSet;
use std::sync::Arc;
use std::time::Instant;
use anyhow::Result;
use chrono::Utc;
use futures::StreamExt;
use log::{debug, info};
use neuromance_client::{ClientError, LLMClient};
use neuromance_common::chat::{Message, MessageRole};
use neuromance_common::client::{ChatRequest, ChatResponse, ToolChoice};
use neuromance_common::tools::{ToolApproval, ToolCall};
use neuromance_tools::ToolExecutor;
use crate::error::CoreError;
/// Callback consulted to approve, deny, or abort a [`ToolCall`] before execution.
pub type ToolApprovalCallback = Box<dyn Fn(&ToolCall) -> ToolApproval + Send + Sync>;
/// Callback invoked with each streamed content delta as it arrives.
pub type StreamingCallback = Box<dyn Fn(&str) + Send + Sync>;
/// Drives conversations against an [`LLMClient`], optionally executing tools
/// and streaming responses.
pub struct Core<C: LLMClient> {
    /// The underlying LLM client used for chat requests.
    pub client: C,
    /// Maximum number of turns in the tool loop; `None` means unlimited.
    pub max_turns: Option<u32>,
    /// When `true`, every tool call is approved without consulting the callback.
    pub auto_approve_tools: bool,
    /// Tool-choice strategy forwarded with each chat request.
    pub tool_choice: ToolChoice,
    /// When `true`, responses are consumed via the streaming API.
    pub streaming: bool,
    /// Registry and executor for the tools exposed to the model.
    pub tool_executor: ToolExecutor,
    /// Optional callback consulted for tool calls that are not auto-approved;
    /// if absent, such calls are denied.
    pub tool_approval_callback: Option<ToolApprovalCallback>,
    /// Optional callback receiving each streamed content delta.
    pub streaming_callback: Option<StreamingCallback>,
}
impl<C: LLMClient> Core<C> {
/// Creates a `Core` with default settings: no turn limit, manual tool
/// approval, automatic tool choice, streaming disabled, a fresh
/// [`ToolExecutor`], and no callbacks installed.
pub fn new(client: C) -> Self {
    Self {
        client,
        tool_executor: ToolExecutor::new(),
        tool_choice: ToolChoice::Auto,
        max_turns: None,
        auto_approve_tools: false,
        streaming: false,
        tool_approval_callback: None,
        streaming_callback: None,
    }
}
pub fn with_tool_approval_callback<F>(mut self, callback: F) -> Self
where
F: Fn(&ToolCall) -> ToolApproval + Send + Sync + 'static,
{
self.tool_approval_callback = Some(Box::new(callback));
self
}
pub fn with_streaming<F>(mut self, callback: F) -> Self
where
F: Fn(&str) + Send + Sync + 'static,
{
self.streaming = true;
self.streaming_callback = Some(Box::new(callback));
self
}
async fn chat_with_retry(&self, request: &ChatRequest) -> Result<ChatResponse> {
let mut last_error = None;
let config = self.client.config();
for attempt in 0..=config.retry_config.max_retries {
match self.client.chat(request).await {
Ok(response) => return Ok(response),
Err(e) => {
let is_retryable = e
.downcast_ref::<ClientError>()
.map(|client_err| client_err.is_retryable())
.unwrap_or(false);
if attempt < config.retry_config.max_retries && is_retryable {
debug!(
"Request failed (attempt {}), retrying in {:?}: {}",
attempt + 1,
config.retry_config.initial_delay,
e
);
last_error = Some(e);
tokio::time::sleep(config.retry_config.initial_delay).await;
continue;
}
last_error = Some(e);
break;
}
}
}
Err(last_error.unwrap())
}
/// Consumes a streaming chat response and folds it into a single
/// [`ChatResponse`].
///
/// Content deltas are concatenated (and forwarded to the streaming callback,
/// if configured), tool-call deltas are merged via [`ToolCall::merge_deltas`],
/// the role is taken from the first chunk that provides one, and the finish
/// reason from the last chunk that provides one. Metadata, model, usage, and
/// response id come from the final chunk.
///
/// # Errors
/// Fails if the stream itself yields an error, if the request contains no
/// messages (the conversation id is read from the first one), or if the
/// stream ends before producing a single chunk.
async fn chat_stream_accumulated(&self, request: &ChatRequest) -> Result<ChatResponse> {
    let mut stream = self.client.chat_stream(request).await?;

    let mut content = String::new();
    let mut merged_calls: Vec<ToolCall> = Vec::new();
    let mut first_role = None;
    let mut last_finish = None;
    let mut final_chunk = None;

    while let Some(next) = stream.next().await {
        let chunk = next?;

        if let Some(ref delta) = chunk.delta_content {
            content.push_str(delta);
            // Forward each content delta to the caller as it arrives.
            if let Some(ref notify) = self.streaming_callback {
                notify(delta);
            }
        }

        // Keep the first role the stream reports.
        if first_role.is_none() {
            first_role = chunk.delta_role;
        }

        if let Some(ref deltas) = chunk.delta_tool_calls {
            debug!("Received {} tool call delta(s)", deltas.len());
            merged_calls = ToolCall::merge_deltas(merged_calls, deltas);
        }

        // Keep the last finish reason the stream reports.
        if chunk.finish_reason.is_some() {
            last_finish = chunk.finish_reason;
        }

        final_chunk = Some(chunk);
    }

    let first_message = request
        .messages
        .first()
        .ok_or_else(|| anyhow::anyhow!("Request must contain at least one message"))?;
    let last_chunk =
        final_chunk.ok_or_else(|| anyhow::anyhow!("Stream ended without any chunks"))?;

    Ok(ChatResponse {
        message: Message {
            id: uuid::Uuid::new_v4(),
            conversation_id: first_message.conversation_id,
            role: first_role.unwrap_or(MessageRole::Assistant),
            content,
            tool_calls: merged_calls.into_iter().collect(),
            tool_call_id: None,
            name: None,
            timestamp: Utc::now(),
            metadata: last_chunk.metadata,
        },
        model: last_chunk.model,
        usage: last_chunk.usage,
        finish_reason: last_finish,
        created_at: last_chunk.created_at,
        response_id: last_chunk.response_id,
        metadata: std::collections::HashMap::new(),
    })
}
/// Runs the agentic chat loop: send the history, execute any tool calls the
/// model requests, append the results, and repeat until the model responds
/// without tool calls or `max_turns` is exceeded.
///
/// Returns the full message history: the input messages plus every assistant
/// and tool message produced. Tool calls are approved automatically
/// (`auto_approve_tools` or the executor's per-tool policy) or via the
/// configured approval callback; with no mechanism configured they are denied.
///
/// # Errors
/// Returns [`CoreError::MaxTurnsExceeded`] when the turn cap is hit, an error
/// if the user quits during tool approval, and propagates any client or
/// serialization error.
pub async fn chat_with_tool_loop(&self, mut messages: Vec<Message>) -> Result<Vec<Message>> {
    let mut turn_count = 0;
    // Ids of tool calls issued but not yet answered. Every handled branch
    // below removes its id, so a non-empty set at the end of a turn would
    // indicate a bookkeeping bug (logged as a warning, not fatal).
    let mut pending_tool_calls: HashSet<String> = HashSet::new();
    let start_time = Instant::now();
    // Snapshot of the history in the Arc<[Message]> form ChatRequest takes;
    // rebuilt after each turn appends new messages.
    let mut messages_arc: Arc<[Message]> = messages.clone().into();
    loop {
        let request = ChatRequest::from((self.client.config(), messages_arc.clone()))
            .with_tools(self.tool_executor.get_all_tools())
            .with_tool_choice(self.tool_choice.clone());
        info!(
            "Executing chat turn ({}/{})",
            turn_count + 1,
            self.max_turns
                .map_or("unlimited".to_string(), |max| max.to_string()),
        );
        debug!(
            "Chat request:\n {}",
            serde_json::to_string_pretty(&request)?
        );
        // Streaming and non-streaming paths converge on a ChatResponse;
        // only the non-streaming path retries transient failures.
        let response = if self.streaming {
            self.chat_stream_accumulated(&request).await?
        } else {
            self.chat_with_retry(&request).await?
        };
        debug!("Received response from LLM");
        debug!(
            "Assistant Response:\n {}",
            serde_json::to_string_pretty(&response)?
        );
        let conversation_id = response.message.conversation_id;
        // Clone the calls so the response message can be moved into history
        // before the calls are processed.
        let tool_calls = response.message.tool_calls.clone();
        let tool_calls_count = tool_calls.len();
        messages.push(response.message);
        // A response with no tool calls ends the loop.
        if tool_calls.is_empty() {
            let duration = start_time.elapsed();
            debug!(
                "No tool calls in response, chat completed in {} turns ({:.2?})",
                turn_count + 1,
                duration
            );
            return Ok(messages);
        }
        for tool_call in &tool_calls {
            let tool_name = &tool_call.function.name;
            let call_id = &tool_call.id;
            pending_tool_calls.insert(tool_call.id.clone());
            debug!("Tool Name: {} (id: {})", tool_name, call_id);
            debug!("Tool Arguments: {:?}", tool_call.function.arguments);
            // Global auto-approve or the executor's per-tool policy skips
            // the approval callback entirely.
            let is_auto_approved =
                self.auto_approve_tools || self.tool_executor.is_tool_auto_approved(tool_name);
            debug!("Tool auto-approved: {}", is_auto_approved);
            let approval = if is_auto_approved {
                ToolApproval::Approved
            } else if let Some(ref callback) = self.tool_approval_callback {
                callback(tool_call)
            } else {
                // No callback configured: fail closed.
                ToolApproval::Denied("No approval mechanism configured".to_string())
            };
            debug!("Tool Approval Status: {:?}", approval);
            match approval {
                ToolApproval::Approved => {
                    debug!("Executing tool: {}", tool_name);
                    match self.tool_executor.execute_tool(tool_call).await {
                        Ok(result) => {
                            debug!("Tool {} executed successfully", tool_name);
                            debug!("Tool result: {}", result);
                            let tool_message = Message::tool(
                                conversation_id,
                                result,
                                tool_call.id.clone(),
                                tool_call.function.name.clone(),
                            )?;
                            messages.push(tool_message);
                            pending_tool_calls.remove(&tool_call.id);
                        }
                        Err(e) => {
                            // Execution failure is reported back to the model
                            // as a tool message rather than aborting the loop.
                            debug!("Tool {} execution failed: {}", tool_name, e);
                            let error_message = Message::tool(
                                conversation_id,
                                format!("Tool execution failed: {}", e),
                                tool_call.id.clone(),
                                tool_call.function.name.clone(),
                            )?;
                            messages.push(error_message);
                            pending_tool_calls.remove(&tool_call.id);
                        }
                    }
                }
                ToolApproval::Denied(reason) => {
                    // Denials are also surfaced to the model as tool messages.
                    debug!("Tool {} denied: {}", tool_name, reason);
                    let denial_message = Message::tool(
                        conversation_id,
                        format!("Tool execution denied: {}", reason),
                        tool_call.id.clone(),
                        tool_call.function.name.clone(),
                    )?;
                    messages.push(denial_message);
                    pending_tool_calls.remove(&tool_call.id);
                }
                ToolApproval::Quit => {
                    // User aborted: stop immediately without answering the
                    // remaining tool calls.
                    debug!("User quit during tool approval");
                    return Err(CoreError::Other(anyhow::anyhow!(
                        "User quit during tool approval"
                    ))
                    .into());
                }
            }
        }
        debug!(
            "Completed processing {} tool calls, continuing conversation",
            tool_calls_count
        );
        // Refresh the request snapshot with this turn's new messages.
        messages_arc = messages.clone().into();
        if !pending_tool_calls.is_empty() {
            debug!(
                "Warning: {} tool calls still pending",
                pending_tool_calls.len()
            );
        }
        turn_count += 1;
        if let Some(max) = self.max_turns
            && turn_count >= max
        {
            return Err(CoreError::MaxTurnsExceeded(format!(
                "Exceeded maximum turns: {} (configured max: {})",
                turn_count, max
            ))
            .into());
        }
    }
}
}