Skip to main content

awaken_contract/
agent_spec_patch.rs

1//! Field-level override for [`AgentSpec`].
2//!
3//! Stored as JSON inside [`RecordMeta::user_overrides`] for built-in agents.
4//! Missing fields inherit from the base spec. JSON `null` clears fields whose
5//! base `AgentSpec` representation is optional.
6//! Merge happens at read time via [`merge_agent_spec`].
7
8use std::collections::{HashMap, HashSet};
9
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12
13use crate::contract::inference::{ContextWindowPolicy, ReasoningEffort};
14use crate::registry_spec::{AgentSpec, RemoteEndpoint};
15
/// Tri-state override for an `AgentSpec` field that is itself an `Option`.
///
/// The two `Option` layers keep "not mentioned" apart from "explicitly null":
///
/// | Rust value      | Patch JSON          | Effect on the base value |
/// |-----------------|---------------------|--------------------------|
/// | `None`          | key absent          | inherited unchanged      |
/// | `Some(None)`    | `null`              | cleared to `None`        |
/// | `Some(Some(v))` | any non-null value  | replaced by `v`          |
pub type NullablePatch<V> = Option<Option<V>>;
22
23/// Patch for built-in agent customization.
24///
25/// Override support covers runtime-safe AgentSpec fields. Adding more fields
26/// later is purely additive because missing fields decode as "inherit".
27///
28/// `#[serde(deny_unknown_fields)]` rejects payloads containing field names
29/// that don't exist on this struct, preventing silent drift when callers
30/// misspell or target deprecated fields.
31#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
32#[serde(default, deny_unknown_fields)]
33pub struct AgentSpecPatch {
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub model_id: Option<String>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub system_prompt: Option<String>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub max_rounds: Option<usize>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub max_continuation_retries: Option<usize>,
42    #[serde(
43        default,
44        deserialize_with = "nullable_patch::deserialize",
45        serialize_with = "nullable_patch::serialize",
46        skip_serializing_if = "nullable_patch::is_missing"
47    )]
48    pub context_policy: NullablePatch<ContextWindowPolicy>,
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub plugin_ids: Option<Vec<String>>,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub active_hook_filter: Option<HashSet<String>>,
53    /// Per-key shallow merge: patch keys override base keys; un-patched
54    /// keys preserved from base. To delete a base key, set its value to
55    /// JSON `null` in this map (handled at merge time).
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub sections: Option<HashMap<String, Value>>,
58    /// Whitelist of tool IDs. `Some([..])` overrides; `None` keeps base.
59    /// JSON `null` clears to "all tools"; missing inherits base.
60    #[serde(
61        default,
62        deserialize_with = "nullable_patch::deserialize",
63        serialize_with = "nullable_patch::serialize",
64        skip_serializing_if = "nullable_patch::is_missing"
65    )]
66    pub allowed_tools: NullablePatch<Vec<String>>,
67    /// Blacklist of tool IDs. Same semantics as `allowed_tools`.
68    #[serde(
69        default,
70        deserialize_with = "nullable_patch::deserialize",
71        serialize_with = "nullable_patch::serialize",
72        skip_serializing_if = "nullable_patch::is_missing"
73    )]
74    pub excluded_tools: NullablePatch<Vec<String>>,
75    /// Sub-agent IDs this agent can delegate to. `Some([..])` overrides.
76    #[serde(skip_serializing_if = "Option::is_none")]
77    pub delegates: Option<Vec<String>>,
78    /// Reasoning effort override. JSON `null` clears the base value.
79    #[serde(
80        default,
81        deserialize_with = "nullable_patch::deserialize",
82        serialize_with = "nullable_patch::serialize",
83        skip_serializing_if = "nullable_patch::is_missing"
84    )]
85    pub reasoning_effort: NullablePatch<ReasoningEffort>,
86    /// Remote endpoint override. JSON `null` clears the base value.
87    #[serde(
88        default,
89        deserialize_with = "nullable_patch::deserialize",
90        serialize_with = "nullable_patch::serialize",
91        skip_serializing_if = "nullable_patch::is_missing"
92    )]
93    pub endpoint: NullablePatch<RemoteEndpoint>,
94}
95
96impl AgentSpecPatch {
97    /// True when no field is set — equivalent to "no override".
98    pub fn is_empty(&self) -> bool {
99        self.model_id.is_none()
100            && self.system_prompt.is_none()
101            && self.max_rounds.is_none()
102            && self.max_continuation_retries.is_none()
103            && self.context_policy.is_none()
104            && self.plugin_ids.is_none()
105            && self.active_hook_filter.is_none()
106            && self.sections.is_none()
107            && self.allowed_tools.is_none()
108            && self.excluded_tools.is_none()
109            && self.delegates.is_none()
110            && self.reasoning_effort.is_none()
111            && self.endpoint.is_none()
112    }
113}
114
115/// Apply a [`AgentSpecPatch`] on top of a base [`AgentSpec`], producing the
116/// effective spec passed to the resolver.
117///
118/// Semantics:
119/// - Scalar fields (`model_id`, `system_prompt`, `max_rounds`,
120///   `max_continuation_retries`): patch's value if `Some`, else base.
121/// - `plugin_ids`: replace whole list when patch is `Some`.
122/// - `sections`: per-key shallow merge. Patch keys override base keys.
123///   A patch value of JSON `null` deletes the corresponding base key.
124/// - Patch-supported option fields (`allowed_tools`, `excluded_tools`,
125///   `reasoning_effort`, `context_policy`, `endpoint`) are tri-state:
126///   missing inherits, JSON `null` clears, and a JSON value overrides.
127/// - Metadata fields pass through from `base` unchanged (id, registry).
128pub fn merge_agent_spec(base: AgentSpec, patch: AgentSpecPatch) -> AgentSpec {
129    AgentSpec {
130        id: base.id,
131        model_id: patch.model_id.unwrap_or(base.model_id),
132        system_prompt: patch.system_prompt.unwrap_or(base.system_prompt),
133        max_rounds: patch.max_rounds.unwrap_or(base.max_rounds),
134        max_continuation_retries: patch
135            .max_continuation_retries
136            .unwrap_or(base.max_continuation_retries),
137        context_policy: merge_nullable(base.context_policy, patch.context_policy),
138        plugin_ids: patch.plugin_ids.unwrap_or(base.plugin_ids),
139        active_hook_filter: patch.active_hook_filter.unwrap_or(base.active_hook_filter),
140        sections: merge_sections(base.sections, patch.sections),
141        allowed_tools: merge_nullable(base.allowed_tools, patch.allowed_tools),
142        excluded_tools: merge_nullable(base.excluded_tools, patch.excluded_tools),
143        delegates: patch.delegates.unwrap_or(base.delegates),
144        reasoning_effort: merge_nullable(base.reasoning_effort, patch.reasoning_effort),
145        endpoint: merge_nullable(base.endpoint, patch.endpoint),
146        // Pass-through metadata:
147        registry: base.registry,
148    }
149}
150
151fn merge_nullable<T>(base: Option<T>, patch: NullablePatch<T>) -> Option<T> {
152    patch.unwrap_or(base)
153}
154
155fn merge_sections(
156    mut base: HashMap<String, Value>,
157    patch: Option<HashMap<String, Value>>,
158) -> HashMap<String, Value> {
159    let Some(patch) = patch else { return base };
160    for (key, value) in patch {
161        if value.is_null() {
162            base.remove(&key);
163        } else {
164            base.insert(key, value);
165        }
166    }
167    base
168}
169
170mod nullable_patch {
171    use serde::{Deserialize, Deserializer, Serialize, Serializer};
172
173    pub fn serialize<S, T>(value: &Option<Option<T>>, serializer: S) -> Result<S::Ok, S::Error>
174    where
175        S: Serializer,
176        T: Serialize,
177    {
178        match value {
179            None => serializer.serialize_none(),
180            Some(inner) => inner.serialize(serializer),
181        }
182    }
183
184    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Option<Option<T>>, D::Error>
185    where
186        D: Deserializer<'de>,
187        T: Deserialize<'de>,
188    {
189        Option::<T>::deserialize(deserializer).map(Some)
190    }
191
192    pub fn is_missing<T>(value: &Option<Option<T>>) -> bool {
193        value.is_none()
194    }
195}