use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::contract::inference::{ContextWindowPolicy, ReasoningEffort};
use crate::contract::lifecycle::StopConditionSpec;
use crate::registry_spec::{AgentBackendSpec, AgentSpec, BackendConfigError, RemoteEndpoint};
/// Tri-state patch value for an optional field.
///
/// - `None`: the field was absent from the patch, so the base value is kept.
/// - `Some(None)`: the field was explicitly `null`, so the base value is cleared.
/// - `Some(Some(v))`: the base value is replaced with `v`.
pub type NullablePatch<T> = Option<Option<T>>;
// Partial update for an `AgentSpec`; `merge_agent_spec` applies it on top of a base spec.
//
// Every field is optional and absent fields are skipped when serializing. Unknown keys are
// rejected when deserializing (`deny_unknown_fields`).
//
// Two field flavours are used:
// - `Option<T>`: "replace if present". A JSON `null` deserializes to `None`, i.e. it is
//   indistinguishable from an absent key.
// - `NullablePatch<T>`: tri-state. Absent keeps the base value, an explicit `null` clears it,
//   and a value replaces it. These fields go through the `nullable_patch` serde adapters.
//
// NOTE: plain `//` comments are used here on purpose. With the `schemars::JsonSchema` derive,
// `///` doc comments would be emitted as descriptions in the generated schema.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(default, deny_unknown_fields)]
pub struct AgentSpecPatch {
// Tri-state patch for the description.
#[serde(
default,
deserialize_with = "nullable_patch::deserialize",
serialize_with = "nullable_patch::serialize",
skip_serializing_if = "nullable_patch::is_missing"
)]
pub description: NullablePatch<String>,
// Replacement backend. When present, the merge re-derives `endpoint` from it (and, for an
// awaken backend, `model_id` / `system_prompt` where the backend provides them).
#[serde(skip_serializing_if = "Option::is_none")]
pub backend: Option<AgentBackendSpec>,
// Replacement model id. Together with `system_prompt` and `max_rounds`, patching this
// refreshes an existing awaken backend when neither `backend` nor `endpoint` is patched.
#[serde(skip_serializing_if = "Option::is_none")]
pub model_id: Option<String>,
// Replacement system prompt (see `model_id` for the backend refresh rule).
#[serde(skip_serializing_if = "Option::is_none")]
pub system_prompt: Option<String>,
// Replacement round limit (see `model_id` for the backend refresh rule).
#[serde(skip_serializing_if = "Option::is_none")]
pub max_rounds: Option<usize>,
// Replacement continuation-retry limit.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_continuation_retries: Option<usize>,
// Replacement stop-condition list; replaces the base list as a whole.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_conditions: Option<Vec<StopConditionSpec>>,
// Tri-state patch for the context window policy.
#[serde(
default,
deserialize_with = "nullable_patch::deserialize",
serialize_with = "nullable_patch::serialize",
skip_serializing_if = "nullable_patch::is_missing"
)]
pub context_policy: NullablePatch<ContextWindowPolicy>,
// Replacement plugin id list; replaces the base list as a whole.
#[serde(skip_serializing_if = "Option::is_none")]
pub plugin_ids: Option<Vec<String>>,
// Replacement hook filter set; replaces the base set as a whole.
#[serde(skip_serializing_if = "Option::is_none")]
pub active_hook_filter: Option<HashSet<String>>,
// Per-key section overrides. Unlike the other collections this is merged key by key:
// a JSON `null` value removes the key from the base, any other value inserts/overwrites it.
#[serde(skip_serializing_if = "Option::is_none")]
pub sections: Option<HashMap<String, Value>>,
// Tri-state patch for the allowed tool list.
#[serde(
default,
deserialize_with = "nullable_patch::deserialize",
serialize_with = "nullable_patch::serialize",
skip_serializing_if = "nullable_patch::is_missing"
)]
pub allowed_tools: NullablePatch<Vec<String>>,
// Tri-state patch for the allowed tool pattern list.
#[serde(
default,
deserialize_with = "nullable_patch::deserialize",
serialize_with = "nullable_patch::serialize",
skip_serializing_if = "nullable_patch::is_missing"
)]
pub allowed_tool_patterns: NullablePatch<Vec<String>>,
// Tri-state patch for the excluded tool list.
#[serde(
default,
deserialize_with = "nullable_patch::deserialize",
serialize_with = "nullable_patch::serialize",
skip_serializing_if = "nullable_patch::is_missing"
)]
pub excluded_tools: NullablePatch<Vec<String>>,
// Tri-state patch for the excluded tool pattern list.
#[serde(
default,
deserialize_with = "nullable_patch::deserialize",
serialize_with = "nullable_patch::serialize",
skip_serializing_if = "nullable_patch::is_missing"
)]
pub excluded_tool_patterns: NullablePatch<Vec<String>>,
// Replacement delegate list; replaces the base list as a whole.
#[serde(skip_serializing_if = "Option::is_none")]
pub delegates: Option<Vec<String>>,
// Tri-state patch for the reasoning effort.
#[serde(
default,
deserialize_with = "nullable_patch::deserialize",
serialize_with = "nullable_patch::serialize",
skip_serializing_if = "nullable_patch::is_missing"
)]
pub reasoning_effort: NullablePatch<ReasoningEffort>,
// Tri-state patch for the remote endpoint. Ignored by the merge when `backend` is also
// patched; otherwise a value or an explicit `null` causes the backend to be rebuilt.
#[serde(
default,
deserialize_with = "nullable_patch::deserialize",
serialize_with = "nullable_patch::serialize",
skip_serializing_if = "nullable_patch::is_missing"
)]
pub endpoint: NullablePatch<RemoteEndpoint>,
}
impl AgentSpecPatch {
    /// Returns `true` when the patch changes nothing, i.e. every field is absent.
    ///
    /// An explicit `null` in a nullable field (`Some(None)`) counts as a change, so a patch
    /// carrying one is not empty.
    pub fn is_empty(&self) -> bool {
        // Exhaustive destructuring: a field added to the struct but not listed here is a
        // compile error, so this check cannot silently fall out of date.
        let Self {
            description,
            backend,
            model_id,
            system_prompt,
            max_rounds,
            max_continuation_retries,
            stop_conditions,
            context_policy,
            plugin_ids,
            active_hook_filter,
            sections,
            allowed_tools,
            allowed_tool_patterns,
            excluded_tools,
            excluded_tool_patterns,
            delegates,
            reasoning_effort,
            endpoint,
        } = self;

        let absent_flags = [
            description.is_none(),
            backend.is_none(),
            model_id.is_none(),
            system_prompt.is_none(),
            max_rounds.is_none(),
            max_continuation_retries.is_none(),
            stop_conditions.is_none(),
            context_policy.is_none(),
            plugin_ids.is_none(),
            active_hook_filter.is_none(),
            sections.is_none(),
            allowed_tools.is_none(),
            allowed_tool_patterns.is_none(),
            excluded_tools.is_none(),
            excluded_tool_patterns.is_none(),
            delegates.is_none(),
            reasoning_effort.is_none(),
            endpoint.is_none(),
        ];
        absent_flags.iter().all(|&absent| absent)
    }
}
/// Applies `patch` on top of `base` and returns the merged [`AgentSpec`].
///
/// Field rules:
/// - `id` and `registry` are always taken from `base`.
/// - Plain `Option` patch fields replace the base value when the patch carries one.
/// - [`NullablePatch`] fields keep the base value when absent, clear it on an explicit
///   `null`, and replace it on a value.
/// - `sections` is merged per key; a JSON `null` value deletes that key.
///
/// `backend` and `endpoint` are then reconciled so that they describe the same target:
/// - If the patch supplies a backend, it wins and any `endpoint` patch is ignored. An awaken
///   backend clears `endpoint` and, where the backend provides them, overrides `model_id`
///   and `system_prompt`; any other backend sets `endpoint` from the backend itself.
/// - Otherwise an `endpoint` value rebuilds the backend from that endpoint, an explicit
///   `null` endpoint rebuilds an awaken backend from the merged fields, and a patch that
///   touches `model_id`, `system_prompt` or `max_rounds` refreshes an existing awaken
///   backend.
///
/// # Errors
///
/// Returns the [`BackendConfigError`] produced by the patched backend's `remote_endpoint()`
/// when the patch supplies a backend that is not an awaken backend.
pub fn merge_agent_spec(
    base: AgentSpec,
    patch: AgentSpecPatch,
) -> Result<AgentSpec, BackendConfigError> {
    // Record what the patch touched before its fields are moved into the merged spec.
    let backend_patched = patch.backend.is_some();
    let awaken_fields_patched =
        patch.model_id.is_some() || patch.system_prompt.is_some() || patch.max_rounds.is_some();
    // Moved out rather than cloned; it is applied during the reconciliation step below.
    let endpoint_patch = patch.endpoint;

    let mut merged = AgentSpec {
        id: base.id,
        description: merge_nullable(base.description, patch.description),
        backend: patch.backend.unwrap_or(base.backend),
        model_id: patch.model_id.unwrap_or(base.model_id),
        system_prompt: patch.system_prompt.unwrap_or(base.system_prompt),
        max_rounds: patch.max_rounds.unwrap_or(base.max_rounds),
        max_continuation_retries: patch
            .max_continuation_retries
            .unwrap_or(base.max_continuation_retries),
        stop_conditions: patch.stop_conditions.unwrap_or(base.stop_conditions),
        context_policy: merge_nullable(base.context_policy, patch.context_policy),
        plugin_ids: patch.plugin_ids.unwrap_or(base.plugin_ids),
        active_hook_filter: patch.active_hook_filter.unwrap_or(base.active_hook_filter),
        sections: merge_sections(base.sections, patch.sections),
        allowed_tools: merge_nullable(base.allowed_tools, patch.allowed_tools),
        allowed_tool_patterns: merge_nullable(
            base.allowed_tool_patterns,
            patch.allowed_tool_patterns,
        ),
        excluded_tools: merge_nullable(base.excluded_tools, patch.excluded_tools),
        excluded_tool_patterns: merge_nullable(
            base.excluded_tool_patterns,
            patch.excluded_tool_patterns,
        ),
        delegates: patch.delegates.unwrap_or(base.delegates),
        reasoning_effort: merge_nullable(base.reasoning_effort, patch.reasoning_effort),
        // Starts as the base endpoint; every path that must change it does so below.
        endpoint: base.endpoint,
        registry: base.registry,
    };

    if backend_patched {
        // An explicitly patched backend is authoritative: the endpoint is derived from it and
        // the endpoint patch (if any) is not consulted.
        if merged.backend.is_awaken() {
            merged.endpoint = None;
            if let Some(model_id) = merged.backend.awaken_model_id() {
                merged.model_id = model_id;
            }
            if let Some(system_prompt) = merged.backend.awaken_system_prompt() {
                merged.system_prompt = system_prompt;
            }
        } else {
            merged.endpoint = merged.backend.remote_endpoint()?;
        }
    } else {
        match endpoint_patch {
            // New endpoint: the backend follows the endpoint.
            Some(Some(endpoint)) => {
                merged.backend = AgentBackendSpec::from_remote_endpoint(&endpoint);
                merged.endpoint = Some(endpoint);
            }
            // Endpoint cleared: fall back to an awaken backend built from the merged fields.
            Some(None) => {
                merged.endpoint = None;
                merged.backend = AgentBackendSpec::awaken_from_fields(
                    &merged.model_id,
                    &merged.system_prompt,
                    merged.max_rounds,
                );
            }
            // Endpoint untouched, but awaken-relevant fields changed: keep the awaken backend
            // in sync with them.
            None if awaken_fields_patched && merged.backend.is_awaken() => {
                merged.backend = AgentBackendSpec::awaken_from_fields(
                    &merged.model_id,
                    &merged.system_prompt,
                    merged.max_rounds,
                );
            }
            None => {}
        }
    }

    // When neither allow-list is set after merging, both become empty lists instead of `None`.
    // NOTE(review): this also rewrites specs whose base had no allow-lists and whose patch did
    // not touch them; confirm that consumers treat `Some(vec![])` the same as `None`.
    if merged.allowed_tools.is_none() && merged.allowed_tool_patterns.is_none() {
        merged.allowed_tools = Some(Vec::new());
        merged.allowed_tool_patterns = Some(Vec::new());
    }

    Ok(merged)
}
/// Resolves a tri-state patch against the base value.
///
/// An absent patch (`None`) keeps `base`; a present patch replaces it with its inner value,
/// which is `None` for an explicit null and `Some(v)` for a new value.
fn merge_nullable<T>(base: Option<T>, patch: NullablePatch<T>) -> Option<T> {
    match patch {
        None => base,
        Some(replacement) => replacement,
    }
}
fn merge_sections(
mut base: HashMap<String, Value>,
patch: Option<HashMap<String, Value>>,
) -> HashMap<String, Value> {
let Some(patch) = patch else { return base };
for (key, value) in patch {
if value.is_null() {
base.remove(&key);
} else {
base.insert(key, value);
}
}
base
}
/// Serde adapters for tri-state `Option<Option<T>>` patch fields.
///
/// Intended to be wired up through `deserialize_with`, `serialize_with` and
/// `skip_serializing_if`, together with `#[serde(default)]` so that an absent key yields the
/// outer `None`.
mod nullable_patch {
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Serializes the inner `Option<T>` of a present patch (a value, or `null` for
    /// `Some(None)`); an absent patch is written as `null` as well.
    ///
    /// Fields that pair this with `skip_serializing_if = "nullable_patch::is_missing"` never
    /// reach the absent case, because such fields are skipped before serialization.
    pub fn serialize<S, T>(value: &Option<Option<T>>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
        T: Serialize,
    {
        if let Some(present) = value {
            return present.serialize(serializer);
        }
        serializer.serialize_none()
    }

    /// Deserializes a key that is present in the input.
    ///
    /// Whatever is parsed, including `null`, is wrapped in an outer `Some`, so `null` becomes
    /// `Some(None)` and a value becomes `Some(Some(v))`. This function never produces the
    /// outer `None` itself.
    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Option<Option<T>>, D::Error>
    where
        D: Deserializer<'de>,
        T: Deserialize<'de>,
    {
        let parsed = Option::<T>::deserialize(deserializer)?;
        Ok(Some(parsed))
    }

    /// Returns `true` when the patch field is absent (outer `None`).
    pub fn is_missing<T>(value: &Option<Option<T>>) -> bool {
        matches!(value, None)
    }
}