Skip to main content

objectiveai_sdk/swarm/
swarm.rs

1//! Core Swarm types and validation logic.
2
3use crate::agent;
4use crate::weights::{Weights, WeightsEntry};
5use indexmap::IndexMap;
6use rust_decimal::Decimal;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use twox_hash::XxHash3_128;
10use schemars::JsonSchema;
11
12// ── Pre-validation types (no computed ID) ──────────────────────────
13
14/// An inline swarm base definition (without computed ID or metadata).
15///
16/// Contains a list of agent configurations that will be validated, deduplicated,
17/// and sorted when converting to an [`InlineSwarm`].
18#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
19#[schemars(rename = "swarm.InlineSwarmBase")]
20pub struct InlineSwarmBase {
21    /// The LLMs in this swarm, with optional counts and fallbacks.
22    pub agents: Vec<agent::InlineAgentBaseWithFallbacksOrRemoteWithCount>,
23    /// Optional weights for each agent. If `None`, uniform weights are used.
24    #[serde(skip_serializing_if = "Option::is_none")]
25    #[schemars(extend("omitempty" = true))]
26    pub weights: Option<Weights>,
27}
28
29impl InlineSwarmBase {
30    /// Validates and converts to an [`InlineSwarm`] with computed ID.
31    ///
32    /// If `weights` is `None`, uniform weights (`Decimal::ONE` per agent) are used.
33    /// Remote agent references are resolved from the provided hashmap.
34    pub fn convert(
35        self,
36        remote_agents: Option<&HashMap<String, agent::RemoteAgentBaseWithFallbacks>>,
37    ) -> Result<InlineSwarm, String> {
38        convert_base(self.agents, self.weights, remote_agents)
39    }
40}
41
42/// A remote swarm base definition with metadata (without computed ID).
43///
44/// Like [`InlineSwarmBase`] but includes a description for remote storage.
45#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
46#[schemars(rename = "swarm.RemoteSwarmBase")]
47pub struct RemoteSwarmBase {
48    /// Human-readable description of what this swarm does.
49    pub description: String,
50    #[serde(flatten)]
51    #[schemars(schema_with = "crate::flatten_schema::<InlineSwarmBase>")]
52    pub inner: InlineSwarmBase,
53}
54
55impl RemoteSwarmBase {
56    /// Validates and converts to a [`RemoteSwarm`] with computed ID.
57    pub fn convert(
58        self,
59        remote_agents: Option<&HashMap<String, agent::RemoteAgentBaseWithFallbacks>>,
60    ) -> Result<RemoteSwarm, String> {
61        Ok(RemoteSwarm {
62            description: self.description,
63            inner: self.inner.convert(remote_agents)?,
64        })
65    }
66}
67
68/// A swarm base definition, either remote (with metadata) or inline.
69#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
70#[serde(untagged)]
71#[schemars(rename = "swarm.SwarmBase")]
72pub enum SwarmBase {
73    #[schemars(title = "Remote")]
74    Remote(RemoteSwarmBase),
75    #[schemars(title = "Inline")]
76    Inline(InlineSwarmBase),
77}
78
79impl SwarmBase {
80    /// Validates and converts to a [`Swarm`] with computed ID.
81    pub fn convert(
82        self,
83        remote_agents: Option<&HashMap<String, agent::RemoteAgentBaseWithFallbacks>>,
84    ) -> Result<Swarm, String> {
85        match self {
86            SwarmBase::Remote(r) => Ok(Swarm::Remote(r.convert(remote_agents)?)),
87            SwarmBase::Inline(i) => Ok(Swarm::Inline(i.convert(remote_agents)?)),
88        }
89    }
90}
91
92// ── Post-validation types (with computed ID) ───────────────────────
93
94/// A validated inline Swarm with its computed content-addressed ID.
95///
96/// Created by converting from [`InlineSwarmBase`] via [`InlineSwarmBase::convert`].
97/// The conversion:
98/// 1. Validates and normalizes each agent
99/// 2. Merges duplicate LLMs (by full_id) and sums their counts
100/// 3. Sorts LLMs by full_id for deterministic ordering
101/// 4. Computes the swarm ID from the sorted (full_id, count) pairs
102/// 5. Aligns weights (merging duplicates by weighted average)
103///
104/// # Constraints
105///
106/// - Individual LLMs with `count: 0` are skipped
107/// - Total agent count (sum of all counts) must be between 1 and 128
108#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
109#[schemars(rename = "swarm.InlineSwarm")]
110pub struct InlineSwarm {
111    /// The deterministic content-addressed ID (22-character base62 string).
112    pub id: String,
113    /// The validated and deduplicated LLMs, sorted by full_id.
114    pub agents: Vec<agent::AgentWithFallbacksWithCount>,
115    /// The aligned weights for each agent.
116    pub weights: Weights,
117}
118
119/// A validated remote Swarm with metadata and computed content-addressed ID.
120#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
121#[schemars(rename = "swarm.RemoteSwarm")]
122pub struct RemoteSwarm {
123    pub description: String,
124    #[serde(flatten)]
125    #[schemars(schema_with = "crate::flatten_schema::<InlineSwarm>")]
126    pub inner: InlineSwarm,
127}
128
129/// A validated Swarm, either remote (with metadata) or inline.
130#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
131#[serde(untagged)]
132#[schemars(rename = "swarm.Swarm")]
133pub enum Swarm {
134    #[schemars(title = "Remote")]
135    Remote(RemoteSwarm),
136    #[schemars(title = "Inline")]
137    Inline(InlineSwarm),
138}
139
140impl InlineSwarm {
141    /// Converts back to an `InlineSwarmBase`, dropping the computed ID.
142    pub fn into_base(self) -> InlineSwarmBase {
143        InlineSwarmBase {
144            agents: self.agents.into_iter().map(|a| {
145                agent::InlineAgentBaseWithFallbacksOrRemoteWithCount {
146                    count: a.count,
147                    inner: agent::InlineAgentBaseWithFallbacksOrRemote::AgentBase(
148                        match a.inner {
149                            agent::AgentWithFallbacks::Inline(i) => agent::InlineAgentBaseWithFallbacks {
150                                inner: i.inner.into_base(),
151                                fallbacks: i.fallbacks.map(|fbs| fbs.into_iter().map(|fb| fb.into_base()).collect()),
152                            },
153                            agent::AgentWithFallbacks::Remote(r) => agent::InlineAgentBaseWithFallbacks {
154                                inner: r.inner.inner.into_base(),
155                                fallbacks: r.inner.fallbacks.map(|fbs| fbs.into_iter().map(|fb| fb.into_base()).collect()),
156                            },
157                        },
158                    ),
159                }
160            }).collect(),
161            weights: Some(self.weights),
162        }
163    }
164}
165
166impl Swarm {
167    /// Returns the inner `InlineSwarm` regardless of variant.
168    pub fn inline(&self) -> &InlineSwarm {
169        match self {
170            Swarm::Remote(r) => &r.inner,
171            Swarm::Inline(i) => i,
172        }
173    }
174
175    /// Consumes self and returns the inner `InlineSwarm`.
176    pub fn into_inline(self) -> InlineSwarm {
177        match self {
178            Swarm::Remote(r) => r.inner,
179            Swarm::Inline(i) => i,
180        }
181    }
182}
183
184// ── InlineSwarmBaseOrRemote ────────────────────────────────────────
185
186/// A swarm specification that is either an inline swarm base
187/// or a remote path reference.
188#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
189#[serde(untagged)]
190#[schemars(rename = "swarm.InlineSwarmBaseOrRemote")]
191pub enum InlineSwarmBaseOrRemote {
192    #[schemars(title = "SwarmBase")]
193    SwarmBase(InlineSwarmBase),
194    #[schemars(title = "Remote")]
195    Remote(crate::RemotePath),
196}
197
198/// Like [`InlineSwarmBaseOrRemote`] but with optional commit.
199/// Used in request types where commit resolution happens server-side.
200#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
201#[serde(untagged)]
202#[schemars(rename = "swarm.InlineSwarmBaseOrRemoteCommitOptional")]
203pub enum InlineSwarmBaseOrRemoteCommitOptional {
204    #[schemars(title = "SwarmBase")]
205    SwarmBase(InlineSwarmBase),
206    #[schemars(title = "Remote")]
207    Remote(crate::RemotePathCommitOptional),
208}
209
210// ── Private helpers ────────────────────────────────────────────────
211
212/// Validates agent fallbacks for duplicate IDs.
213fn validate_agent_fallbacks(agent: &agent::AgentWithFallbacks) -> Result<(), String> {
214    let inline = match agent {
215        agent::AgentWithFallbacks::Remote(a) => &a.inner,
216        agent::AgentWithFallbacks::Inline(a) => a,
217    };
218    if let Some(fallbacks) = &inline.fallbacks {
219        if fallbacks.iter().any(|fb| fb.id() == inline.inner.id()) {
220            return Err(format!(
221                "Agent cannot have identical primary and fallback IDs: {}",
222                inline.inner.id()
223            ));
224        }
225        for i in 0..fallbacks.len() {
226            for j in (i + 1)..fallbacks.len() {
227                if fallbacks[i].id() == fallbacks[j].id() {
228                    return Err(format!(
229                        "Agent cannot have duplicate fallback IDs: {}",
230                        fallbacks[i].id()
231                    ));
232                }
233            }
234        }
235    }
236    Ok(())
237}
238
239/// Converts an agent slot (inline or remote reference) to a validated agent.
240fn convert_agent_slot(
241    slot: agent::InlineAgentBaseWithFallbacksOrRemote,
242    remote_agents: Option<&HashMap<String, agent::RemoteAgentBaseWithFallbacks>>,
243) -> Result<agent::AgentWithFallbacks, String> {
244    match slot {
245        agent::InlineAgentBaseWithFallbacksOrRemote::AgentBase(base_with_fallbacks) => {
246            Ok(agent::AgentWithFallbacks::Inline(base_with_fallbacks.convert()?))
247        }
248        agent::InlineAgentBaseWithFallbacksOrRemote::Remote(path) => {
249            let key = path.key();
250            let remote_agents = remote_agents.ok_or_else(|| {
251                format!(
252                    "remote agent reference '{}' but no agents hashmap provided",
253                    key
254                )
255            })?;
256            let agent_base = remote_agents.get(&key).ok_or_else(|| {
257                format!(
258                    "remote agent '{}' not found in agents hashmap",
259                    key
260                )
261            })?;
262            Ok(agent::AgentWithFallbacks::Remote(agent_base.clone().convert()?))
263        }
264    }
265}
266
267/// Core conversion: validates agents, deduplicates, sorts, computes ID, aligns weights.
268///
269/// If `weights` is `None`, uniform weights (`Decimal::ONE` per agent) are used.
270fn convert_base(
271    agents: Vec<agent::InlineAgentBaseWithFallbacksOrRemoteWithCount>,
272    weights: Option<Weights>,
273    remote_agents: Option<&HashMap<String, agent::RemoteAgentBaseWithFallbacks>>,
274) -> Result<InlineSwarm, String> {
275    // Resolve weights: use provided or default to uniform
276    let weight_pairs: Vec<(Decimal, bool)> = match &weights {
277        Some(w) => {
278            if w.len() != agents.len() {
279                return Err(format!(
280                    "weights length ({}) does not match agents length ({})",
281                    w.len(),
282                    agents.len()
283                ));
284            }
285            w.to_weights_and_invert()
286        }
287        None => vec![(Decimal::ONE, false); agents.len()],
288    };
289
290    // Validate weights are in [0, 1] and at least one is positive.
291    let mut has_positive = false;
292    for (i, (weight, _)) in weight_pairs.iter().enumerate() {
293        if *weight < Decimal::ZERO || *weight > Decimal::ONE {
294            return Err(format!(
295                "weight at index {} must be between 0 and 1, got {}",
296                i, weight
297            ));
298        }
299        if *weight > Decimal::ZERO {
300            has_positive = true;
301        }
302    }
303    if !has_positive {
304        return Err(
305            "weights must have at least one positive value".to_string(),
306        );
307    }
308
309    let mut agents_with_full_id: IndexMap<
310        String,
311        (
312            agent::AgentWithFallbacksWithCount,
313            Decimal, // weighted sum
314            u64,     // total count
315            bool,    // invert
316        ),
317    > = IndexMap::with_capacity(agents.len());
318    let mut count = 0u64;
319
320    for (base_agent, (weight, invert)) in
321        agents.into_iter().zip(weight_pairs.into_iter())
322    {
323        match base_agent.count {
324            0 => continue,
325            n => count += n,
326        }
327        let converted = convert_agent_slot(base_agent.inner, remote_agents)?;
328        validate_agent_fallbacks(&converted)?;
329        let full_id = converted.full_id();
330        let agent_with_count = agent::AgentWithFallbacksWithCount {
331            count: base_agent.count,
332            inner: converted,
333        };
334        match agents_with_full_id.get_mut(&full_id) {
335            Some((
336                existing,
337                weighted_sum,
338                total_count,
339                existing_invert,
340            )) => {
341                if *existing_invert != invert {
342                    return Err(format!(
343                        "conflicting invert flags for merged agent with full_id: {}",
344                        full_id
345                    ));
346                }
347                *weighted_sum += weight * Decimal::from(agent_with_count.count);
348                *total_count += agent_with_count.count;
349                existing.count += agent_with_count.count;
350            }
351            None => {
352                let weighted_sum = weight * Decimal::from(agent_with_count.count);
353                let total_count = agent_with_count.count;
354                agents_with_full_id.insert(
355                    full_id,
356                    (agent_with_count, weighted_sum, total_count, invert),
357                );
358            }
359        }
360    }
361
362    if count == 0 || count > 128 {
363        return Err(
364            "`swarm.agents` must contain between 1 and 128 total LLMs"
365                .to_string(),
366        );
367    }
368
369    agents_with_full_id.sort_unstable_keys();
370
371    let mut hasher = XxHash3_128::with_seed(0);
372    for (full_id, (agent, _, _, _)) in &agents_with_full_id {
373        hasher.write(full_id.as_bytes());
374        let count_bytes = agent.count.to_le_bytes();
375        hasher.write(&count_bytes);
376    }
377    let id = format!("{:0>22}", base62::encode(hasher.finish_128()));
378
379    let mut result_agents = Vec::with_capacity(agents_with_full_id.len());
380    let mut entries = Vec::with_capacity(agents_with_full_id.len());
381    for (_, (agent, weighted_sum, total_count, invert)) in
382        agents_with_full_id
383    {
384        result_agents.push(agent);
385        let merged_weight = weighted_sum / Decimal::from(total_count);
386        entries.push(WeightsEntry {
387            weight: merged_weight,
388            invert: if invert { Some(true) } else { None },
389        });
390    }
391
392    Ok(InlineSwarm {
393        id,
394        agents: result_agents,
395        weights: Weights::Entries(entries),
396    })
397}
398
399/// Merge a validated agent into the dedup map.
400fn merge_agent(
401    agents_with_full_id: &mut IndexMap<String, agent::AgentWithFallbacksWithCount>,
402    agent_with_count: agent::AgentWithFallbacksWithCount,
403) {
404    let full_id = agent_with_count.inner.full_id();
405    match agents_with_full_id.get_mut(&full_id) {
406        Some(existing) => existing.count += agent_with_count.count,
407        None => {
408            agents_with_full_id.insert(full_id, agent_with_count);
409        }
410    }
411}