1use crate::agent;
4use crate::weights::{Weights, WeightsEntry};
5use indexmap::IndexMap;
6use rust_decimal::Decimal;
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use twox_hash::XxHash3_128;
11
// Unresolved swarm definition: a list of agent slots (each an inline agent or a remote
// reference, with a count) plus optional weights. Resolved into an `InlineSwarm` by
// `InlineSwarmBase::convert`.
//
// NOTE: plain `//` comments are used deliberately on this JsonSchema type; schemars turns `///`
// doc comments into schema descriptions, so adding them would change the generated schema.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[schemars(rename = "swarm.InlineSwarmBase")]
pub struct InlineSwarmBase {
    // Agent slots. `convert_base` requires their counts to sum to between 1 and 128.
    pub agents: Vec<agent::InlineAgentBaseWithFallbacksOrRemoteWithCount>,
    // Optional weights. When present, `convert_base` requires exactly one entry per element of
    // `agents`; when `None`, every agent is treated as weight 1, not inverted.
    // Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub weights: Option<Weights>,
}
36
37impl InlineSwarmBase {
38 pub fn convert(
43 self,
44 remote_agents: Option<
45 &HashMap<String, (agent::RemoteAgentBaseWithFallbacks, crate::RemotePath)>,
46 >,
47 ) -> Result<InlineSwarm, String> {
48 convert_base(self.agents, self.weights, remote_agents)
49 }
50}
51
// Unresolved swarm definition that carries a `description` in addition to the
// `InlineSwarmBase` fields. Those fields are flattened into the same serialized object by
// `#[serde(flatten)]` rather than nested under `inner`.
// (Plain `//` comments: `///` would become schemars schema descriptions.)
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[schemars(rename = "swarm.RemoteSwarmBase")]
pub struct RemoteSwarmBase {
    pub description: String,
    // Flattened: agents/weights appear next to `description` in the serialized form.
    #[serde(flatten)]
    #[schemars(schema_with = "crate::flatten_schema::<InlineSwarmBase>")]
    pub inner: InlineSwarmBase,
}
64
65impl RemoteSwarmBase {
66 pub fn convert(
68 self,
69 remote_agents: Option<
70 &HashMap<String, (agent::RemoteAgentBaseWithFallbacks, crate::RemotePath)>,
71 >,
72 ) -> Result<RemoteSwarm, String> {
73 Ok(RemoteSwarm {
74 description: self.description,
75 inner: self.inner.convert(remote_agents)?,
76 })
77 }
78}
79
// Unresolved swarm definition in either its described (`Remote`) or plain (`Inline`) form.
// `#[serde(untagged)]` tries the variants in declaration order. An object with a `description`
// string (and otherwise valid swarm fields) therefore deserializes as `Remote`; anything else
// falls through to `Inline`.
// (Plain `//` comments: `///` would become schemars schema descriptions.)
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[schemars(rename = "swarm.SwarmBase")]
pub enum SwarmBase {
    #[schemars(title = "Remote")]
    Remote(RemoteSwarmBase),
    #[schemars(title = "Inline")]
    Inline(InlineSwarmBase),
}
90
91impl SwarmBase {
92 pub fn convert(
94 self,
95 remote_agents: Option<
96 &HashMap<String, (agent::RemoteAgentBaseWithFallbacks, crate::RemotePath)>,
97 >,
98 ) -> Result<Swarm, String> {
99 match self {
100 SwarmBase::Remote(r) => {
101 Ok(Swarm::Remote(r.convert(remote_agents)?))
102 }
103 SwarmBase::Inline(i) => {
104 Ok(Swarm::Inline(i.convert(remote_agents)?))
105 }
106 }
107 }
108}
109
// Resolved swarm, as produced by `convert_base`:
// - agents are merged by full ID and sorted by it;
// - `id` is derived from a hash of every (full ID, count) pair;
// - `weights` holds one entry per element of `agents`.
// (Plain `//` comments: `///` would become schemars schema descriptions.)
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[schemars(rename = "swarm.InlineSwarm")]
pub struct InlineSwarm {
    // Deterministic content-derived identifier (22-character, zero-padded base62).
    pub id: String,
    // Resolved agents with their merged counts.
    pub agents: Vec<agent::AgentWithFallbacksWithCount>,
    // Per-agent weights, aligned with `agents`.
    pub weights: Weights,
}
136
// Resolved counterpart of `RemoteSwarmBase`: a `description` plus the resolved `InlineSwarm`
// fields, flattened into the same serialized object.
// (Plain `//` comments: `///` would become schemars schema descriptions.)
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[schemars(rename = "swarm.RemoteSwarm")]
pub struct RemoteSwarm {
    pub description: String,
    #[serde(flatten)]
    #[schemars(schema_with = "crate::flatten_schema::<InlineSwarm>")]
    pub inner: InlineSwarm,
}
146
// Resolved counterpart of `SwarmBase`. Untagged: `Remote` (which requires `description`) is
// tried before `Inline`.
// (Plain `//` comments: `///` would become schemars schema descriptions.)
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[schemars(rename = "swarm.Swarm")]
pub enum Swarm {
    #[schemars(title = "Remote")]
    Remote(RemoteSwarm),
    #[schemars(title = "Inline")]
    Inline(InlineSwarm),
}
157
158impl InlineSwarm {
159 pub fn into_base(self) -> InlineSwarmBase {
161 InlineSwarmBase {
162 agents: self
163 .agents
164 .into_iter()
165 .map(|a| agent::InlineAgentBaseWithFallbacksOrRemoteWithCount {
166 count: a.count,
167 inner:
168 agent::InlineAgentBaseWithFallbacksOrRemote::AgentBase(
169 match a.inner {
170 agent::AgentWithFallbacks::Inline(i) => {
171 agent::InlineAgentBaseWithFallbacks {
172 inner: i.inner.into_base(),
173 fallbacks: i.fallbacks.map(|fbs| {
174 fbs.into_iter()
175 .map(|fb| fb.into_base())
176 .collect()
177 }),
178 }
179 }
180 agent::AgentWithFallbacks::Remote(r) => {
181 agent::InlineAgentBaseWithFallbacks {
182 inner: r.inner.inner.into_base(),
183 fallbacks: r.inner.fallbacks.map(
184 |fbs| {
185 fbs.into_iter()
186 .map(|fb| fb.into_base())
187 .collect()
188 },
189 ),
190 }
191 }
192 },
193 ),
194 })
195 .collect(),
196 weights: Some(self.weights),
197 }
198 }
199}
200
201impl Swarm {
202 pub fn inline(&self) -> &InlineSwarm {
204 match self {
205 Swarm::Remote(r) => &r.inner,
206 Swarm::Inline(i) => i,
207 }
208 }
209
210 pub fn into_inline(self) -> InlineSwarm {
212 match self {
213 Swarm::Remote(r) => r.inner,
214 Swarm::Inline(i) => i,
215 }
216 }
217}
218
// Either an inline unresolved swarm definition or a reference to a swarm by
// `crate::RemotePath`. Untagged: the inline form (which requires `agents`) is tried first.
// (Plain `//` comments: `///` would become schemars schema descriptions.)
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[schemars(rename = "swarm.InlineSwarmBaseOrRemote")]
pub enum InlineSwarmBaseOrRemote {
    #[schemars(title = "SwarmBase")]
    SwarmBase(InlineSwarmBase),
    #[schemars(title = "Remote")]
    Remote(crate::RemotePath),
}
232
// Same shape as `InlineSwarmBaseOrRemote`, but the reference uses
// `crate::RemotePathCommitOptional` (presumably a path whose commit may be omitted; confirm
// against that type). Untagged: the inline form is tried first.
// (Plain `//` comments: `///` would become schemars schema descriptions.)
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[schemars(rename = "swarm.InlineSwarmBaseOrRemoteCommitOptional")]
pub enum InlineSwarmBaseOrRemoteCommitOptional {
    #[schemars(title = "SwarmBase")]
    SwarmBase(InlineSwarmBase),
    #[schemars(title = "Remote")]
    Remote(crate::RemotePathCommitOptional),
}
244
245fn validate_agent_fallbacks(
249 agent: &agent::AgentWithFallbacks,
250) -> Result<(), String> {
251 let inline = match agent {
252 agent::AgentWithFallbacks::Remote(a) => &a.inner,
253 agent::AgentWithFallbacks::Inline(a) => a,
254 };
255 if let Some(fallbacks) = &inline.fallbacks {
256 if fallbacks.iter().any(|fb| fb.id() == inline.inner.id()) {
257 return Err(format!(
258 "Agent cannot have identical primary and fallback IDs: {}",
259 inline.inner.id()
260 ));
261 }
262 for i in 0..fallbacks.len() {
263 for j in (i + 1)..fallbacks.len() {
264 if fallbacks[i].id() == fallbacks[j].id() {
265 return Err(format!(
266 "Agent cannot have duplicate fallback IDs: {}",
267 fallbacks[i].id()
268 ));
269 }
270 }
271 }
272 }
273 Ok(())
274}
275
276fn convert_agent_slot(
278 slot: agent::InlineAgentBaseWithFallbacksOrRemote,
279 remote_agents: Option<
280 &HashMap<String, (agent::RemoteAgentBaseWithFallbacks, crate::RemotePath)>,
281 >,
282) -> Result<agent::AgentWithFallbacks, String> {
283 match slot {
284 agent::InlineAgentBaseWithFallbacksOrRemote::AgentBase(
285 base_with_fallbacks,
286 ) => Ok(agent::AgentWithFallbacks::Inline(
287 base_with_fallbacks.convert()?,
288 )),
289 agent::InlineAgentBaseWithFallbacksOrRemote::Remote(path) => {
290 let key = path.key();
291 let remote_agents = remote_agents.ok_or_else(|| {
292 format!(
293 "remote agent reference '{}' but no agents hashmap provided",
294 key
295 )
296 })?;
297 let agent_base = remote_agents.get(&key).ok_or_else(|| {
298 format!("remote agent '{}' not found in agents hashmap", key)
299 })?;
300 Ok(agent::AgentWithFallbacks::Remote(
304 agent_base.0.clone().convert()?,
305 ))
306 }
307 }
308}
309
310fn convert_base(
314 agents: Vec<agent::InlineAgentBaseWithFallbacksOrRemoteWithCount>,
315 weights: Option<Weights>,
316 remote_agents: Option<
317 &HashMap<String, (agent::RemoteAgentBaseWithFallbacks, crate::RemotePath)>,
318 >,
319) -> Result<InlineSwarm, String> {
320 let weight_pairs: Vec<(Decimal, bool)> = match &weights {
322 Some(w) => {
323 if w.len() != agents.len() {
324 return Err(format!(
325 "weights length ({}) does not match agents length ({})",
326 w.len(),
327 agents.len()
328 ));
329 }
330 w.to_weights_and_invert()
331 }
332 None => vec![(Decimal::ONE, false); agents.len()],
333 };
334
335 let mut has_positive = false;
337 for (i, (weight, _)) in weight_pairs.iter().enumerate() {
338 if *weight < Decimal::ZERO || *weight > Decimal::ONE {
339 return Err(format!(
340 "weight at index {} must be between 0 and 1, got {}",
341 i, weight
342 ));
343 }
344 if *weight > Decimal::ZERO {
345 has_positive = true;
346 }
347 }
348 if !has_positive {
349 return Err("weights must have at least one positive value".to_string());
350 }
351
352 let mut agents_with_full_id: IndexMap<
353 String,
354 (
355 agent::AgentWithFallbacksWithCount,
356 Decimal, u64, bool, ),
360 > = IndexMap::with_capacity(agents.len());
361 let mut count = 0u64;
362
363 for (base_agent, (weight, invert)) in
364 agents.into_iter().zip(weight_pairs.into_iter())
365 {
366 match base_agent.count {
367 0 => continue,
368 n => count += n,
369 }
370 let converted = convert_agent_slot(base_agent.inner, remote_agents)?;
371 validate_agent_fallbacks(&converted)?;
372 let full_id = converted.full_id();
373 let agent_with_count = agent::AgentWithFallbacksWithCount {
374 count: base_agent.count,
375 inner: converted,
376 };
377 match agents_with_full_id.get_mut(&full_id) {
378 Some((existing, weighted_sum, total_count, existing_invert)) => {
379 if *existing_invert != invert {
380 return Err(format!(
381 "conflicting invert flags for merged agent with full_id: {}",
382 full_id
383 ));
384 }
385 *weighted_sum += weight * Decimal::from(agent_with_count.count);
386 *total_count += agent_with_count.count;
387 existing.count += agent_with_count.count;
388 }
389 None => {
390 let weighted_sum =
391 weight * Decimal::from(agent_with_count.count);
392 let total_count = agent_with_count.count;
393 agents_with_full_id.insert(
394 full_id,
395 (agent_with_count, weighted_sum, total_count, invert),
396 );
397 }
398 }
399 }
400
401 if count == 0 || count > 128 {
402 return Err("`swarm.agents` must contain between 1 and 128 total LLMs"
403 .to_string());
404 }
405
406 agents_with_full_id.sort_unstable_keys();
407
408 let mut hasher = XxHash3_128::with_seed(0);
409 for (full_id, (agent, _, _, _)) in &agents_with_full_id {
410 hasher.write(full_id.as_bytes());
411 let count_bytes = agent.count.to_le_bytes();
412 hasher.write(&count_bytes);
413 }
414 let id = format!("{:0>22}", base62::encode(hasher.finish_128()));
415
416 let mut result_agents = Vec::with_capacity(agents_with_full_id.len());
417 let mut entries = Vec::with_capacity(agents_with_full_id.len());
418 for (_, (agent, weighted_sum, total_count, invert)) in agents_with_full_id {
419 result_agents.push(agent);
420 let merged_weight = weighted_sum / Decimal::from(total_count);
421 entries.push(WeightsEntry {
422 weight: merged_weight,
423 invert: if invert { Some(true) } else { None },
424 });
425 }
426
427 Ok(InlineSwarm {
428 id,
429 agents: result_agents,
430 weights: Weights::Entries(entries),
431 })
432}
433
434fn merge_agent(
436 agents_with_full_id: &mut IndexMap<
437 String,
438 agent::AgentWithFallbacksWithCount,
439 >,
440 agent_with_count: agent::AgentWithFallbacksWithCount,
441) {
442 let full_id = agent_with_count.inner.full_id();
443 match agents_with_full_id.get_mut(&full_id) {
444 Some(existing) => existing.count += agent_with_count.count,
445 None => {
446 agents_with_full_id.insert(full_id, agent_with_count);
447 }
448 }
449}