1use crate::agent;
4use crate::weights::{Weights, WeightsEntry};
5use indexmap::IndexMap;
6use rust_decimal::Decimal;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use twox_hash::XxHash3_128;
10use schemars::JsonSchema;
11
/// Unresolved swarm definition: a list of agent slots plus optional per-slot weights.
///
/// Resolve it into an [`InlineSwarm`] with [`InlineSwarmBase::convert`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[schemars(rename = "swarm.InlineSwarmBase")]
pub struct InlineSwarmBase {
    /// Agent slots, each carrying a count. A slot is either an inline agent base (with
    /// optional fallbacks) or a reference to a remote agent.
    pub agents: Vec<agent::InlineAgentBaseWithFallbacksOrRemoteWithCount>,
    /// Optional weights, expected to contain one entry per element of `agents`
    /// (conversion rejects a length mismatch). When absent, conversion assigns every
    /// slot weight 1. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub weights: Option<Weights>,
}
28
29impl InlineSwarmBase {
30 pub fn convert(
35 self,
36 remote_agents: Option<&HashMap<String, agent::RemoteAgentBaseWithFallbacks>>,
37 ) -> Result<InlineSwarm, String> {
38 convert_base(self.agents, self.weights, remote_agents)
39 }
40}
41
/// Unresolved swarm definition accompanied by a description.
///
/// The fields of [`InlineSwarmBase`] are flattened into the same serialized object as
/// `description`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[schemars(rename = "swarm.RemoteSwarmBase")]
pub struct RemoteSwarmBase {
    /// Free-form description of the swarm.
    pub description: String,
    /// The underlying swarm definition, serialized inline (flattened) next to
    /// `description`.
    #[serde(flatten)]
    #[schemars(schema_with = "crate::flatten_schema::<InlineSwarmBase>")]
    pub inner: InlineSwarmBase,
}
54
55impl RemoteSwarmBase {
56 pub fn convert(
58 self,
59 remote_agents: Option<&HashMap<String, agent::RemoteAgentBaseWithFallbacks>>,
60 ) -> Result<RemoteSwarm, String> {
61 Ok(RemoteSwarm {
62 description: self.description,
63 inner: self.inner.convert(remote_agents)?,
64 })
65 }
66}
67
/// An unresolved swarm, either with a description ([`RemoteSwarmBase`]) or without one
/// ([`InlineSwarmBase`]).
///
/// Deserialization is untagged. serde tries the variants in declaration order, so input
/// that deserializes as a [`RemoteSwarmBase`] (i.e. includes `description`) becomes
/// `Remote`. Otherwise it is tried as `Inline`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[schemars(rename = "swarm.SwarmBase")]
pub enum SwarmBase {
    /// Swarm definition with a description.
    #[schemars(title = "Remote")]
    Remote(RemoteSwarmBase),
    /// Plain swarm definition.
    #[schemars(title = "Inline")]
    Inline(InlineSwarmBase),
}
78
79impl SwarmBase {
80 pub fn convert(
82 self,
83 remote_agents: Option<&HashMap<String, agent::RemoteAgentBaseWithFallbacks>>,
84 ) -> Result<Swarm, String> {
85 match self {
86 SwarmBase::Remote(r) => Ok(Swarm::Remote(r.convert(remote_agents)?)),
87 SwarmBase::Inline(i) => Ok(Swarm::Inline(i.convert(remote_agents)?)),
88 }
89 }
90}
91
/// Resolved swarm with concrete agents (with counts), their weights, and an identifier.
///
/// When produced by this module's conversion:
/// - `agents` are deduplicated by full id and sorted by it.
/// - `weights` holds one entry per agent, in the same order.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[schemars(rename = "swarm.InlineSwarm")]
pub struct InlineSwarm {
    /// Identifier. Conversion derives it from each agent's full id and count, and
    /// zero-pads it to 22 base62 characters.
    pub id: String,
    /// Resolved agents, each with a count.
    pub agents: Vec<agent::AgentWithFallbacksWithCount>,
    /// Weights aligned with `agents`.
    pub weights: Weights,
}
118
/// Resolved swarm accompanied by a description.
///
/// The fields of [`InlineSwarm`] are flattened into the same serialized object as
/// `description`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[schemars(rename = "swarm.RemoteSwarm")]
pub struct RemoteSwarm {
    /// Free-form description of the swarm.
    pub description: String,
    /// The resolved swarm, serialized inline (flattened) next to `description`.
    #[serde(flatten)]
    #[schemars(schema_with = "crate::flatten_schema::<InlineSwarm>")]
    pub inner: InlineSwarm,
}
128
/// A resolved swarm, either with a description ([`RemoteSwarm`]) or without one
/// ([`InlineSwarm`]).
///
/// Deserialization is untagged, and serde tries `Remote` before `Inline`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[schemars(rename = "swarm.Swarm")]
pub enum Swarm {
    /// Resolved swarm with a description.
    #[schemars(title = "Remote")]
    Remote(RemoteSwarm),
    /// Plain resolved swarm.
    #[schemars(title = "Inline")]
    Inline(InlineSwarm),
}
139
140impl InlineSwarm {
141 pub fn into_base(self) -> InlineSwarmBase {
143 InlineSwarmBase {
144 agents: self.agents.into_iter().map(|a| {
145 agent::InlineAgentBaseWithFallbacksOrRemoteWithCount {
146 count: a.count,
147 inner: agent::InlineAgentBaseWithFallbacksOrRemote::AgentBase(
148 match a.inner {
149 agent::AgentWithFallbacks::Inline(i) => agent::InlineAgentBaseWithFallbacks {
150 inner: i.inner.into_base(),
151 fallbacks: i.fallbacks.map(|fbs| fbs.into_iter().map(|fb| fb.into_base()).collect()),
152 },
153 agent::AgentWithFallbacks::Remote(r) => agent::InlineAgentBaseWithFallbacks {
154 inner: r.inner.inner.into_base(),
155 fallbacks: r.inner.fallbacks.map(|fbs| fbs.into_iter().map(|fb| fb.into_base()).collect()),
156 },
157 },
158 ),
159 }
160 }).collect(),
161 weights: Some(self.weights),
162 }
163 }
164}
165
166impl Swarm {
167 pub fn inline(&self) -> &InlineSwarm {
169 match self {
170 Swarm::Remote(r) => &r.inner,
171 Swarm::Inline(i) => i,
172 }
173 }
174
175 pub fn into_inline(self) -> InlineSwarm {
177 match self {
178 Swarm::Remote(r) => r.inner,
179 Swarm::Inline(i) => i,
180 }
181 }
182}
183
/// Either an inline unresolved swarm definition or a [`crate::RemotePath`] reference.
///
/// Deserialization is untagged, and serde tries `SwarmBase` before `Remote`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[schemars(rename = "swarm.InlineSwarmBaseOrRemote")]
pub enum InlineSwarmBaseOrRemote {
    /// Swarm definition given inline.
    #[schemars(title = "SwarmBase")]
    SwarmBase(InlineSwarmBase),
    /// Reference to a swarm stored elsewhere.
    #[schemars(title = "Remote")]
    Remote(crate::RemotePath),
}
197
/// Either an inline unresolved swarm definition or a
/// [`crate::RemotePathCommitOptional`] reference.
///
/// NOTE(review): presumably the remote path may omit its commit here, unlike
/// `InlineSwarmBaseOrRemote`. Confirm against `crate::RemotePathCommitOptional`.
///
/// Deserialization is untagged, and serde tries `SwarmBase` before `Remote`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[schemars(rename = "swarm.InlineSwarmBaseOrRemoteCommitOptional")]
pub enum InlineSwarmBaseOrRemoteCommitOptional {
    /// Swarm definition given inline.
    #[schemars(title = "SwarmBase")]
    SwarmBase(InlineSwarmBase),
    /// Reference to a swarm stored elsewhere.
    #[schemars(title = "Remote")]
    Remote(crate::RemotePathCommitOptional),
}
209
210fn validate_agent_fallbacks(agent: &agent::AgentWithFallbacks) -> Result<(), String> {
214 let inline = match agent {
215 agent::AgentWithFallbacks::Remote(a) => &a.inner,
216 agent::AgentWithFallbacks::Inline(a) => a,
217 };
218 if let Some(fallbacks) = &inline.fallbacks {
219 if fallbacks.iter().any(|fb| fb.id() == inline.inner.id()) {
220 return Err(format!(
221 "Agent cannot have identical primary and fallback IDs: {}",
222 inline.inner.id()
223 ));
224 }
225 for i in 0..fallbacks.len() {
226 for j in (i + 1)..fallbacks.len() {
227 if fallbacks[i].id() == fallbacks[j].id() {
228 return Err(format!(
229 "Agent cannot have duplicate fallback IDs: {}",
230 fallbacks[i].id()
231 ));
232 }
233 }
234 }
235 }
236 Ok(())
237}
238
239fn convert_agent_slot(
241 slot: agent::InlineAgentBaseWithFallbacksOrRemote,
242 remote_agents: Option<&HashMap<String, agent::RemoteAgentBaseWithFallbacks>>,
243) -> Result<agent::AgentWithFallbacks, String> {
244 match slot {
245 agent::InlineAgentBaseWithFallbacksOrRemote::AgentBase(base_with_fallbacks) => {
246 Ok(agent::AgentWithFallbacks::Inline(base_with_fallbacks.convert()?))
247 }
248 agent::InlineAgentBaseWithFallbacksOrRemote::Remote(path) => {
249 let key = path.key();
250 let remote_agents = remote_agents.ok_or_else(|| {
251 format!(
252 "remote agent reference '{}' but no agents hashmap provided",
253 key
254 )
255 })?;
256 let agent_base = remote_agents.get(&key).ok_or_else(|| {
257 format!(
258 "remote agent '{}' not found in agents hashmap",
259 key
260 )
261 })?;
262 Ok(agent::AgentWithFallbacks::Remote(agent_base.clone().convert()?))
263 }
264 }
265}
266
267fn convert_base(
271 agents: Vec<agent::InlineAgentBaseWithFallbacksOrRemoteWithCount>,
272 weights: Option<Weights>,
273 remote_agents: Option<&HashMap<String, agent::RemoteAgentBaseWithFallbacks>>,
274) -> Result<InlineSwarm, String> {
275 let weight_pairs: Vec<(Decimal, bool)> = match &weights {
277 Some(w) => {
278 if w.len() != agents.len() {
279 return Err(format!(
280 "weights length ({}) does not match agents length ({})",
281 w.len(),
282 agents.len()
283 ));
284 }
285 w.to_weights_and_invert()
286 }
287 None => vec![(Decimal::ONE, false); agents.len()],
288 };
289
290 let mut has_positive = false;
292 for (i, (weight, _)) in weight_pairs.iter().enumerate() {
293 if *weight < Decimal::ZERO || *weight > Decimal::ONE {
294 return Err(format!(
295 "weight at index {} must be between 0 and 1, got {}",
296 i, weight
297 ));
298 }
299 if *weight > Decimal::ZERO {
300 has_positive = true;
301 }
302 }
303 if !has_positive {
304 return Err(
305 "weights must have at least one positive value".to_string(),
306 );
307 }
308
309 let mut agents_with_full_id: IndexMap<
310 String,
311 (
312 agent::AgentWithFallbacksWithCount,
313 Decimal, u64, bool, ),
317 > = IndexMap::with_capacity(agents.len());
318 let mut count = 0u64;
319
320 for (base_agent, (weight, invert)) in
321 agents.into_iter().zip(weight_pairs.into_iter())
322 {
323 match base_agent.count {
324 0 => continue,
325 n => count += n,
326 }
327 let converted = convert_agent_slot(base_agent.inner, remote_agents)?;
328 validate_agent_fallbacks(&converted)?;
329 let full_id = converted.full_id();
330 let agent_with_count = agent::AgentWithFallbacksWithCount {
331 count: base_agent.count,
332 inner: converted,
333 };
334 match agents_with_full_id.get_mut(&full_id) {
335 Some((
336 existing,
337 weighted_sum,
338 total_count,
339 existing_invert,
340 )) => {
341 if *existing_invert != invert {
342 return Err(format!(
343 "conflicting invert flags for merged agent with full_id: {}",
344 full_id
345 ));
346 }
347 *weighted_sum += weight * Decimal::from(agent_with_count.count);
348 *total_count += agent_with_count.count;
349 existing.count += agent_with_count.count;
350 }
351 None => {
352 let weighted_sum = weight * Decimal::from(agent_with_count.count);
353 let total_count = agent_with_count.count;
354 agents_with_full_id.insert(
355 full_id,
356 (agent_with_count, weighted_sum, total_count, invert),
357 );
358 }
359 }
360 }
361
362 if count == 0 || count > 128 {
363 return Err(
364 "`swarm.agents` must contain between 1 and 128 total LLMs"
365 .to_string(),
366 );
367 }
368
369 agents_with_full_id.sort_unstable_keys();
370
371 let mut hasher = XxHash3_128::with_seed(0);
372 for (full_id, (agent, _, _, _)) in &agents_with_full_id {
373 hasher.write(full_id.as_bytes());
374 let count_bytes = agent.count.to_le_bytes();
375 hasher.write(&count_bytes);
376 }
377 let id = format!("{:0>22}", base62::encode(hasher.finish_128()));
378
379 let mut result_agents = Vec::with_capacity(agents_with_full_id.len());
380 let mut entries = Vec::with_capacity(agents_with_full_id.len());
381 for (_, (agent, weighted_sum, total_count, invert)) in
382 agents_with_full_id
383 {
384 result_agents.push(agent);
385 let merged_weight = weighted_sum / Decimal::from(total_count);
386 entries.push(WeightsEntry {
387 weight: merged_weight,
388 invert: if invert { Some(true) } else { None },
389 });
390 }
391
392 Ok(InlineSwarm {
393 id,
394 agents: result_agents,
395 weights: Weights::Entries(entries),
396 })
397}
398
399fn merge_agent(
401 agents_with_full_id: &mut IndexMap<String, agent::AgentWithFallbacksWithCount>,
402 agent_with_count: agent::AgentWithFallbacksWithCount,
403) {
404 let full_id = agent_with_count.inner.full_id();
405 match agents_with_full_id.get_mut(&full_id) {
406 Some(existing) => existing.count += agent_with_count.count,
407 None => {
408 agents_with_full_id.insert(full_id, agent_with_count);
409 }
410 }
411}