use super::layer_trait::{Layer, LayerConfig, LayerResult};
use crate::config::Config;
use crate::session::{Message, Session};
use anyhow::Result;
use async_trait::async_trait;
use colored::Colorize;
/// Executes a single configured layer: builds the system/user message pair
/// for the layer and drives the chat-completion call (including optional
/// MCP tool handling — see the `Layer` impl).
pub struct LayerProcessor {
    /// Static configuration for this layer (model override, sampling
    /// parameters, system prompt, MCP server references).
    pub config: LayerConfig,
}
impl LayerProcessor {
    /// Creates a processor for the given layer configuration.
    pub fn new(config: LayerConfig) -> Self {
        Self { config }
    }

    /// Seconds since the Unix epoch.
    ///
    /// Saturates to 0 (via `unwrap_or_default`) if the system clock is set
    /// before the epoch, so message construction never panics on a
    /// misconfigured clock.
    fn current_unix_timestamp() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Builds the system + user message pair sent to the provider.
    ///
    /// The system prompt comes from the layer config and is marked as
    /// cacheable only when the effective model supports prompt caching.
    /// `input` is passed through `prepare_input` before becoming the user
    /// message; the user message is never cached.
    pub fn create_messages(&self, input: &str, session: &Session) -> Vec<Message> {
        // The layer config may override the session's model.
        let effective_model = self.config.get_effective_model(&session.info.model);
        let system_prompt = self.config.get_effective_system_prompt();
        let should_cache = crate::session::model_utils::model_supports_caching(&effective_model);

        // Exactly two messages are produced: system, then user.
        let mut messages = Vec::with_capacity(2);
        messages.push(Message {
            role: "system".to_string(),
            content: system_prompt,
            timestamp: Self::current_unix_timestamp(),
            cached: should_cache,
            ..Default::default()
        });
        let processed_input = self.prepare_input(input, session);
        messages.push(Message {
            role: "user".to_string(),
            content: processed_input,
            timestamp: Self::current_unix_timestamp(),
            cached: false,
            ..Default::default()
        });
        messages
    }
}
#[async_trait]
impl Layer for LayerProcessor {
    /// Returns the layer's configured name.
    fn name(&self) -> &str {
        &self.config.name
    }

    /// Returns this layer's configuration.
    fn config(&self) -> &LayerConfig {
        &self.config
    }

    /// Runs the layer end to end.
    ///
    /// Sends `input` (wrapped via `create_messages`) to the effective model.
    /// When this layer references MCP servers and the response carries tool
    /// calls, the allowed calls are executed and a second completion is
    /// requested so the model can incorporate the tool results.
    ///
    /// Timing is accumulated separately for API calls (`api_time_ms`) and
    /// tool execution (`tool_time_ms`); `total_time_ms` covers the whole
    /// invocation. Returns an error only if cancellation was already
    /// signalled on entry or the *first* completion fails; later failures
    /// are logged and degraded (see comments below).
    async fn process(
        &self,
        input: &str,
        session: &Session,
        config: &Config,
        operation_cancelled: tokio::sync::watch::Receiver<bool>,
    ) -> Result<LayerResult> {
        // Fail fast if cancellation was requested before we started. The
        // flag is checked only here and forwarded to tool execution below;
        // the completion requests themselves are issued with
        // `cancellation_token: None`.
        if *operation_cancelled.borrow() {
            return Err(anyhow::anyhow!("Operation cancelled"));
        }
        let total_start = std::time::Instant::now();
        let mut api_time_ms = 0u64;
        let mut tool_time_ms = 0u64;
        // The layer config may override the session's model.
        let effective_model = self.config.get_effective_model(&session.info.model);
        let messages = self.create_messages(input, session);
        // First completion round: system + user messages only. A failure
        // here propagates to the caller via `?`.
        let api_start = std::time::Instant::now();
        let response = crate::session::chat_completion_with_provider(
            crate::session::ChatCompletionProviderParams {
                messages: &messages,
                model: &effective_model,
                temperature: self.config.temperature,
                top_p: self.config.top_p,
                top_k: self.config.top_k,
                max_tokens: self.config.max_tokens,
                config,
                max_retries: config.max_retries,
                cancellation_token: None,
                schema: None,
            },
        )
        .await?;
        api_time_ms += api_start.elapsed().as_millis() as u64;
        // `_finish_reason` is destructured but intentionally unused here.
        let (output, exchange, direct_tool_calls, _finish_reason) = (
            response.content,
            response.exchange,
            response.tool_calls,
            response.finish_reason,
        );
        // Tool handling applies only when this layer references MCP servers.
        if !self.config.mcp.server_refs.is_empty() {
            // Prefer structured tool calls returned by the provider; fall
            // back to parsing tool calls out of the text output.
            let tool_calls = if let Some(ref calls) = direct_tool_calls {
                calls
            } else {
                &crate::mcp::parse_tool_calls(&output)
            };
            if !tool_calls.is_empty() {
                // Keep a copy: `output` itself is still needed by the
                // fallback return path at the bottom if the follow-up
                // completion fails.
                let output_clone = output.clone();
                let mut tool_results = Vec::new();
                for tool_call in tool_calls {
                    println!("{} {}", "Tool call:".yellow(), tool_call.tool_name);
                    let server_name =
                        crate::mcp::tool_map::get_tool_server_name(&tool_call.tool_name)
                            .unwrap_or_else(|| "unknown".to_string());
                    // Skip tools this layer is not permitted to call; the
                    // refusal is printed but not treated as an error.
                    if !self
                        .config
                        .mcp
                        .is_tool_allowed(&tool_call.tool_name, &server_name)
                    {
                        println!(
                            "{} {} {}",
                            "Tool".red(),
                            tool_call.tool_name,
                            "not allowed for this layer".red()
                        );
                        continue;
                    }
                    let layer_config = self.config.get_merged_config_for_layer(config);
                    // Execute the tool with the cancellation receiver; a
                    // failed call is logged and skipped so one bad tool does
                    // not abort the whole layer.
                    let result = match crate::mcp::execute_layer_tool_call(
                        tool_call,
                        &layer_config,
                        &self.config,
                        Some(operation_cancelled.clone()),
                    )
                    .await
                    {
                        Ok((res, single_tool_time_ms)) => {
                            tool_time_ms += single_tool_time_ms;
                            res
                        }
                        Err(e) => {
                            crate::log_error!("{} {}", "Tool execution error:", e);
                            continue;
                        }
                    };
                    tool_results.push(result);
                }
                if !tool_results.is_empty() {
                    println!("{}", "Processing tool results...".cyan());
                    // Second completion round: the original messages plus the
                    // assistant's first output plus one "tool" message per
                    // tool result.
                    let mut layer_session = messages.clone();
                    layer_session.push(crate::session::Message {
                        role: "assistant".to_string(),
                        content: output_clone,
                        timestamp: std::time::SystemTime::now()
                            .duration_since(std::time::UNIX_EPOCH)
                            .unwrap_or_default()
                            .as_secs(),
                        cached: false,
                        ..Default::default()
                    });
                    for tool_result in &tool_results {
                        layer_session.push(crate::session::Message {
                            role: "tool".to_string(),
                            // A serialization failure degrades to an empty
                            // string rather than erroring out.
                            content: serde_json::to_string(&tool_result.result).unwrap_or_default(),
                            timestamp: std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .unwrap_or_default()
                                .as_secs(),
                            cached: false,
                            tool_call_id: Some(tool_result.tool_id.clone()),
                            name: Some(tool_result.tool_name.clone()),
                            ..Default::default()
                        });
                    }
                    let api_start_tool_processing = std::time::Instant::now();
                    match crate::session::chat_completion_with_provider(
                        crate::session::ChatCompletionProviderParams {
                            messages: &layer_session,
                            model: &effective_model,
                            temperature: self.config.temperature,
                            top_p: self.config.top_p,
                            top_k: self.config.top_k,
                            max_tokens: self.config.max_tokens,
                            config,
                            max_retries: config.max_retries,
                            cancellation_token: None,
                            schema: None,
                        },
                    )
                    .await
                    {
                        Ok(response) => {
                            api_time_ms += api_start_tool_processing.elapsed().as_millis() as u64;
                            let token_usage = response.exchange.usage.clone();
                            let total_time_ms = total_start.elapsed().as_millis() as u64;
                            // Success: return the follow-up completion that
                            // incorporates the tool results.
                            return Ok(LayerResult {
                                outputs: vec![response.content],
                                exchange: response.exchange,
                                token_usage,
                                tool_calls: response.tool_calls,
                                api_time_ms,
                                tool_time_ms,
                                total_time_ms,
                            });
                        }
                        Err(e) => {
                            // Logged only; execution falls through to return
                            // the first completion's output below.
                            crate::log_error!("{} {}", "Error processing tool results:", e);
                        }
                    }
                }
            }
        }
        // Fallback path: no MCP servers configured, no tool calls made, no
        // tool results collected, or the follow-up completion failed —
        // return the first completion's output.
        let token_usage = exchange.usage.clone();
        let total_time_ms = total_start.elapsed().as_millis() as u64;
        Ok(LayerResult {
            outputs: vec![output],
            exchange,
            token_usage,
            tool_calls: direct_tool_calls,
            api_time_ms,
            tool_time_ms,
            total_time_ms,
        })
    }
}