use crate::config::Config;
use crate::session::{ProviderExchange, Session, TokenUsage};
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::str::FromStr;
/// Outcome of a single [`Layer::process`] call, together with usage and timing data.
pub struct LayerResult {
    /// Text outputs produced by the layer.
    pub outputs: Vec<String>,
    /// Provider exchange record associated with this result.
    pub exchange: ProviderExchange,
    /// Token usage for the call, when it is available.
    pub token_usage: Option<TokenUsage>,
    /// MCP tool calls associated with this result, if any.
    pub tool_calls: Option<Vec<crate::mcp::McpToolCall>>,
    /// Time attributed to provider API calls (milliseconds, per the field name).
    pub api_time_ms: u64,
    /// Time attributed to tool execution (milliseconds, per the field name).
    pub tool_time_ms: u64,
    /// Total elapsed time (milliseconds, per the field name).
    pub total_time_ms: u64,
}
/// Selects which text a layer receives as input; interpreted by [`Layer::prepare_input`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum InputMode {
    /// The most recent relevant message, combined with the current input when present.
    Last,
    /// The whole non-system conversation transcript followed by the current input.
    All,
    /// A session summary produced by `crate::session::summarize_context`.
    Summary,
}
impl InputMode {
pub fn as_str(&self) -> &'static str {
match self {
InputMode::Last => "last",
InputMode::All => "all",
InputMode::Summary => "summary",
}
}
}
impl FromStr for InputMode {
    type Err = String;

    /// Parses a mode keyword case-insensitively (`last`, `all`, `summary`).
    ///
    /// Surrounding whitespace is not trimmed. On failure the error message
    /// echoes the original input and lists the valid options.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let mode = match lowered.as_str() {
            "last" => InputMode::Last,
            "all" => InputMode::All,
            "summary" => InputMode::Summary,
            _ => {
                return Err(format!(
                    "Unknown input mode: '{}'. Valid options: last, all, summary",
                    s
                ))
            }
        };
        Ok(mode)
    }
}
fn deserialize_input_mode<'de, D>(deserializer: D) -> Result<InputMode, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let s = String::deserialize(deserializer)?;
InputMode::from_str(&s).map_err(D::Error::custom)
}
/// Output handling mode for a layer (config keywords: `none`, `append`, `replace`,
/// `last`, `restart`).
///
/// NOTE(review): the effect of each mode is implemented outside this file; confirm
/// against the layer runner before relying on the variant names alone.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OutputMode {
    /// Config keyword `none`.
    None,
    /// Config keyword `append`.
    Append,
    /// Config keyword `replace`.
    Replace,
    /// Config keyword `last`.
    Last,
    /// Config keyword `restart`.
    Restart,
}
impl OutputMode {
pub fn as_str(&self) -> &'static str {
match self {
OutputMode::None => "none",
OutputMode::Append => "append",
OutputMode::Replace => "replace",
OutputMode::Last => "last",
OutputMode::Restart => "restart",
}
}
}
impl FromStr for OutputMode {
    type Err = String;

    /// Parses an output-mode keyword case-insensitively
    /// (`none`, `append`, `replace`, `last`, `restart`).
    ///
    /// Surrounding whitespace is not trimmed. On failure the error message
    /// echoes the original input and lists the valid options.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let mode = match lowered.as_str() {
            "none" => OutputMode::None,
            "append" => OutputMode::Append,
            "replace" => OutputMode::Replace,
            "last" => OutputMode::Last,
            "restart" => OutputMode::Restart,
            _ => {
                return Err(format!(
                    "Unknown output mode: '{}'. Valid options: none, append, replace, last, restart",
                    s
                ))
            }
        };
        Ok(mode)
    }
}
/// Role attached to a layer's output (config keywords: `assistant`, `user`).
///
/// NOTE(review): presumably the message role the output is recorded under in the
/// session; confirm with the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OutputRole {
    /// Config keyword `assistant`.
    Assistant,
    /// Config keyword `user`.
    User,
}
impl OutputRole {
pub fn as_str(&self) -> &'static str {
match self {
OutputRole::Assistant => "assistant",
OutputRole::User => "user",
}
}
}
impl FromStr for OutputRole {
    type Err = String;

    /// Parses a role keyword case-insensitively (`assistant`, `user`).
    ///
    /// Surrounding whitespace is not trimmed. On failure the error message
    /// echoes the original input and lists the valid options.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let role = match lowered.as_str() {
            "assistant" => OutputRole::Assistant,
            "user" => OutputRole::User,
            _ => {
                return Err(format!(
                    "Unknown output role: '{}'. Valid options: assistant, user",
                    s
                ))
            }
        };
        Ok(role)
    }
}
fn deserialize_output_mode<'de, D>(deserializer: D) -> Result<OutputMode, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let s = String::deserialize(deserializer)?;
OutputMode::from_str(&s).map_err(D::Error::custom)
}
fn deserialize_output_role<'de, D>(deserializer: D) -> Result<OutputRole, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let s = String::deserialize(deserializer)?;
OutputRole::from_str(&s).map_err(D::Error::custom)
}
/// Per-layer MCP settings: which servers the layer may use and which tools are allowed.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct LayerMcpConfig {
    /// References to MCP servers from the base config; defaults to empty when absent.
    #[serde(default)]
    pub server_refs: Vec<String>,
    /// Tool allow-list patterns checked by [`LayerMcpConfig::is_tool_allowed`];
    /// defaults to empty, which allows every tool.
    #[serde(default)]
    pub allowed_tools: Vec<String>,
}
impl LayerMcpConfig {
    /// Returns whether `tool_name` on `server_name` passes this layer's allow-list.
    ///
    /// An empty `allowed_tools` list allows everything. Otherwise the tool is allowed
    /// when any pattern matches:
    /// * `server:*` — any tool on `server`;
    /// * `server:prefix*` — tools on `server` whose name starts with `prefix`;
    /// * `server:tool` — exactly `tool` on `server`;
    /// * `tool` (no `:`) — exactly `tool` on any server; no wildcard support here.
    pub fn is_tool_allowed(&self, tool_name: &str, server_name: &str) -> bool {
        if self.allowed_tools.is_empty() {
            return true;
        }
        self.allowed_tools
            .iter()
            .any(|pattern| match pattern.split_once(':') {
                // Bare pattern: exact tool-name match regardless of server.
                None => tool_name == pattern,
                Some((server, _)) if server != server_name => false,
                // A trailing `*` is a prefix wildcard; a lone `*` leaves the empty
                // prefix, which every tool name starts with.
                Some((_, tool_pattern)) => match tool_pattern.strip_suffix('*') {
                    Some(prefix) => tool_name.starts_with(prefix),
                    None => tool_name == tool_pattern,
                },
            })
    }
}
/// Configuration for a single layer, as read from config.
///
/// NOTE(review): `input_mode`, `output_mode` and `output_role` use custom
/// case-insensitive deserializers but the derived `Serialize`, so they serialize as
/// their variant names (e.g. `"Last"`); the deserializers still accept that form.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LayerConfig {
    /// Layer identifier; used in log messages and in the default system prompt.
    pub name: String,
    /// Model override; when `None`, `get_effective_model` falls back to the session model.
    pub model: Option<String>,
    /// Raw system prompt, which may contain placeholders expanded by
    /// `process_and_cache_system_prompt`.
    pub system_prompt: Option<String>,
    /// Free-form description; not read by the code in this file.
    pub description: String,
    /// Sampling setting; not read here, presumably forwarded to the provider by the caller.
    pub temperature: f32,
    /// Sampling setting; not read here, presumably forwarded to the provider by the caller.
    pub top_p: f32,
    /// Sampling setting; not read here, presumably forwarded to the provider by the caller.
    pub top_k: u32,
    /// Token limit; not read here, presumably forwarded to the provider by the caller.
    pub max_tokens: u32,
    /// Which text the layer receives (see `Layer::prepare_input`); parsed case-insensitively.
    #[serde(deserialize_with = "deserialize_input_mode")]
    pub input_mode: InputMode,
    /// Output handling mode; parsed case-insensitively.
    #[serde(deserialize_with = "deserialize_output_mode")]
    pub output_mode: OutputMode,
    /// Role attached to the layer's output; parsed case-insensitively.
    #[serde(deserialize_with = "deserialize_output_role")]
    pub output_role: OutputRole,
    /// MCP server/tool scoping for this layer; defaults to empty when absent.
    #[serde(default)]
    pub mcp: LayerMcpConfig,
    /// Values substituted for `{{key}}` placeholders in the system prompt; defaults to empty.
    #[serde(default)]
    pub parameters: std::collections::HashMap<String, serde_json::Value>,
    /// Cached placeholder-expanded system prompt; never serialized or deserialized.
    #[serde(skip)]
    pub processed_system_prompt: Option<String>,
}
impl LayerConfig {
pub fn get_effective_model(&self, session_model: &str) -> String {
self.model
.clone()
.unwrap_or_else(|| session_model.to_string())
}
pub fn get_merged_config_for_layer(
&self,
base_config: &crate::config::Config,
) -> crate::config::Config {
let mut merged_config = base_config.clone();
if !self.mcp.server_refs.is_empty() {
let layer_mcp_config = crate::config::RoleMcpConfig {
server_refs: self.mcp.server_refs.clone(),
allowed_tools: self.mcp.allowed_tools.clone(),
};
let enabled_servers =
layer_mcp_config.get_enabled_servers(&base_config.mcp.servers, None);
crate::log_debug!(
"Layer '{}' enabling {} servers from server_refs: {:?}",
self.name,
enabled_servers.len(),
self.mcp.server_refs
);
merged_config.mcp = crate::config::McpConfig {
servers: enabled_servers,
allowed_tools: self.mcp.allowed_tools.clone(),
};
} else {
merged_config.mcp.servers.clear();
merged_config.mcp.allowed_tools.clear();
}
merged_config
}
pub fn get_effective_system_prompt(&self) -> String {
if let Some(ref processed) = self.processed_system_prompt {
processed.clone()
} else {
if let Some(ref custom_prompt) = self.system_prompt {
custom_prompt.clone()
} else {
format!("You are a specialized AI layer named '{}'. Process the input according to your purpose.", self.name)
}
}
}
pub async fn process_and_cache_system_prompt(&mut self, project_dir: &std::path::Path) {
if let Some(ref custom_prompt) = self.system_prompt {
let processed = self
.process_prompt_placeholders_async(custom_prompt, project_dir)
.await;
self.processed_system_prompt = Some(processed);
} else {
panic!("CRITICAL CONFIG ERROR: Layer '{}' missing system_prompt. All layers must have system_prompt defined in config.", self.name);
}
}
async fn process_prompt_placeholders_async(
&self,
prompt: &str,
project_dir: &std::path::Path,
) -> String {
let mut processed = prompt.to_string();
processed =
crate::session::helper_functions::process_placeholders_async(&processed, project_dir)
.await;
for (key, value) in &self.parameters {
let replacement = match value {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Number(n) => n.to_string(),
serde_json::Value::Bool(b) => b.to_string(),
_ => serde_json::to_string(value).unwrap_or_default(),
};
processed = processed.replace(&format!("{{{{{}}}}}", key), &replacement);
}
processed
}
}
/// A processing stage that turns input text into a [`LayerResult`].
///
/// Implementors supply [`Layer::name`], [`Layer::config`] and the async
/// [`Layer::process`] step. [`Layer::prepare_input`] has a default implementation
/// driven by the layer's configured [`InputMode`].
#[async_trait]
pub trait Layer {
    /// Name of this layer.
    fn name(&self) -> &str;

    /// Configuration this layer uses.
    fn config(&self) -> &LayerConfig;

    /// Runs the layer on `input` in the context of `session` and `config`.
    ///
    /// `operation_cancelled` is a watch receiver carrying a `bool`; presumably a
    /// cancellation flag. How implementors react to it is not visible here.
    async fn process(
        &self,
        input: &str,
        session: &Session,
        config: &Config,
        operation_cancelled: tokio::sync::watch::Receiver<bool>,
    ) -> Result<LayerResult>;

    /// Builds the text this layer should receive, according to `config().input_mode`.
    ///
    /// * `InputMode::Last`:
    ///   - With blank `input`, returns the latest assistant message. If there is none,
    ///     it falls back to the latest user message, then to a placeholder string.
    ///   - Otherwise, returns `input`, prefixed with the latest assistant message when
    ///     one exists.
    /// * `InputMode::All`: returns every non-system message as a `[Role]` section,
    ///   followed by a `[Current task]` section holding `input`. Returns just `input`
    ///   when there are no such messages.
    /// * `InputMode::Summary`: delegates to `crate::session::summarize_context`.
    fn prepare_input(&self, input: &str, session: &Session) -> String {
        // Newest session message with the given role, if any.
        let latest = |role: &str| session.messages.iter().rfind(|m| m.role == role);

        match self.config().input_mode {
            InputMode::Last if input.trim().is_empty() => {
                match latest("assistant").or_else(|| latest("user")) {
                    Some(message) => message.content.clone(),
                    None => "No previous messages found".to_string(),
                }
            }
            InputMode::Last => match latest("assistant") {
                Some(message) => format!(
                    "Previous response:\n{}\n\nCurrent input:\n{}",
                    message.content, input
                ),
                None => input.to_string(),
            },
            InputMode::All => {
                let mut transcript = String::new();
                for message in session.messages.iter().filter(|m| m.role != "system") {
                    let label = match message.role.as_str() {
                        "assistant" => "Assistant",
                        "user" => "User",
                        other => other,
                    };
                    // Every entry is non-empty, so a non-empty buffer means "not first".
                    if !transcript.is_empty() {
                        transcript.push_str("\n\n");
                    }
                    transcript.push_str(&format!("[{}]\n{}", label, message.content));
                }
                if transcript.is_empty() {
                    input.to_string()
                } else {
                    format!("{}\n\n[Current task]\n{}", transcript, input)
                }
            }
            InputMode::Summary => crate::session::summarize_context(session, input),
        }
    }
}