awaken_contract/
agent_spec_patch.rs1use std::collections::{HashMap, HashSet};
9
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12
13use crate::contract::inference::{ContextWindowPolicy, ReasoningEffort};
14use crate::registry_spec::{AgentSpec, RemoteEndpoint};
15
/// Tri-state patch value for an optional target field.
///
/// * `None` — the patch does not mention the field.
/// * `Some(None)` — the patch explicitly carries "no value".
/// * `Some(Some(v))` — the patch carries the value `v`.
pub type NullablePatch<T> = Option<Option<T>>;
22
23#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
32#[serde(default, deny_unknown_fields)]
33pub struct AgentSpecPatch {
34 #[serde(skip_serializing_if = "Option::is_none")]
35 pub model_id: Option<String>,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 pub system_prompt: Option<String>,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 pub max_rounds: Option<usize>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub max_continuation_retries: Option<usize>,
42 #[serde(
43 default,
44 deserialize_with = "nullable_patch::deserialize",
45 serialize_with = "nullable_patch::serialize",
46 skip_serializing_if = "nullable_patch::is_missing"
47 )]
48 pub context_policy: NullablePatch<ContextWindowPolicy>,
49 #[serde(skip_serializing_if = "Option::is_none")]
50 pub plugin_ids: Option<Vec<String>>,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 pub active_hook_filter: Option<HashSet<String>>,
53 #[serde(skip_serializing_if = "Option::is_none")]
57 pub sections: Option<HashMap<String, Value>>,
58 #[serde(
61 default,
62 deserialize_with = "nullable_patch::deserialize",
63 serialize_with = "nullable_patch::serialize",
64 skip_serializing_if = "nullable_patch::is_missing"
65 )]
66 pub allowed_tools: NullablePatch<Vec<String>>,
67 #[serde(
69 default,
70 deserialize_with = "nullable_patch::deserialize",
71 serialize_with = "nullable_patch::serialize",
72 skip_serializing_if = "nullable_patch::is_missing"
73 )]
74 pub excluded_tools: NullablePatch<Vec<String>>,
75 #[serde(skip_serializing_if = "Option::is_none")]
77 pub delegates: Option<Vec<String>>,
78 #[serde(
80 default,
81 deserialize_with = "nullable_patch::deserialize",
82 serialize_with = "nullable_patch::serialize",
83 skip_serializing_if = "nullable_patch::is_missing"
84 )]
85 pub reasoning_effort: NullablePatch<ReasoningEffort>,
86 #[serde(
88 default,
89 deserialize_with = "nullable_patch::deserialize",
90 serialize_with = "nullable_patch::serialize",
91 skip_serializing_if = "nullable_patch::is_missing"
92 )]
93 pub endpoint: NullablePatch<RemoteEndpoint>,
94}
95
96impl AgentSpecPatch {
97 pub fn is_empty(&self) -> bool {
99 self.model_id.is_none()
100 && self.system_prompt.is_none()
101 && self.max_rounds.is_none()
102 && self.max_continuation_retries.is_none()
103 && self.context_policy.is_none()
104 && self.plugin_ids.is_none()
105 && self.active_hook_filter.is_none()
106 && self.sections.is_none()
107 && self.allowed_tools.is_none()
108 && self.excluded_tools.is_none()
109 && self.delegates.is_none()
110 && self.reasoning_effort.is_none()
111 && self.endpoint.is_none()
112 }
113}
114
115pub fn merge_agent_spec(base: AgentSpec, patch: AgentSpecPatch) -> AgentSpec {
129 AgentSpec {
130 id: base.id,
131 model_id: patch.model_id.unwrap_or(base.model_id),
132 system_prompt: patch.system_prompt.unwrap_or(base.system_prompt),
133 max_rounds: patch.max_rounds.unwrap_or(base.max_rounds),
134 max_continuation_retries: patch
135 .max_continuation_retries
136 .unwrap_or(base.max_continuation_retries),
137 context_policy: merge_nullable(base.context_policy, patch.context_policy),
138 plugin_ids: patch.plugin_ids.unwrap_or(base.plugin_ids),
139 active_hook_filter: patch.active_hook_filter.unwrap_or(base.active_hook_filter),
140 sections: merge_sections(base.sections, patch.sections),
141 allowed_tools: merge_nullable(base.allowed_tools, patch.allowed_tools),
142 excluded_tools: merge_nullable(base.excluded_tools, patch.excluded_tools),
143 delegates: patch.delegates.unwrap_or(base.delegates),
144 reasoning_effort: merge_nullable(base.reasoning_effort, patch.reasoning_effort),
145 endpoint: merge_nullable(base.endpoint, patch.endpoint),
146 registry: base.registry,
148 }
149}
150
151fn merge_nullable<T>(base: Option<T>, patch: NullablePatch<T>) -> Option<T> {
152 patch.unwrap_or(base)
153}
154
155fn merge_sections(
156 mut base: HashMap<String, Value>,
157 patch: Option<HashMap<String, Value>>,
158) -> HashMap<String, Value> {
159 let Some(patch) = patch else { return base };
160 for (key, value) in patch {
161 if value.is_null() {
162 base.remove(&key);
163 } else {
164 base.insert(key, value);
165 }
166 }
167 base
168}
169
170mod nullable_patch {
171 use serde::{Deserialize, Deserializer, Serialize, Serializer};
172
173 pub fn serialize<S, T>(value: &Option<Option<T>>, serializer: S) -> Result<S::Ok, S::Error>
174 where
175 S: Serializer,
176 T: Serialize,
177 {
178 match value {
179 None => serializer.serialize_none(),
180 Some(inner) => inner.serialize(serializer),
181 }
182 }
183
184 pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Option<Option<T>>, D::Error>
185 where
186 D: Deserializer<'de>,
187 T: Deserialize<'de>,
188 {
189 Option::<T>::deserialize(deserializer).map(Some)
190 }
191
192 pub fn is_missing<T>(value: &Option<Option<T>>) -> bool {
193 value.is_none()
194 }
195}