Skip to main content

rs_adk/text/
race.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use super::TextAgent;
6use crate::error::AgentError;
7use crate::state::State;
8
9/// Runs agents concurrently, returns the first to complete. Cancels the rest.
10pub struct RaceTextAgent {
11    name: String,
12    agents: Vec<Arc<dyn TextAgent>>,
13}
14
15impl RaceTextAgent {
16    /// Create a new race agent that runs agents concurrently and returns the first result.
17    pub fn new(name: impl Into<String>, agents: Vec<Arc<dyn TextAgent>>) -> Self {
18        Self {
19            name: name.into(),
20            agents,
21        }
22    }
23}
24
25#[async_trait]
26impl TextAgent for RaceTextAgent {
27    fn name(&self) -> &str {
28        &self.name
29    }
30
31    async fn run(&self, state: &State) -> Result<String, AgentError> {
32        if self.agents.is_empty() {
33            return Err(AgentError::Other("No agents in race".into()));
34        }
35
36        let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<String, AgentError>>(1);
37        let cancel = tokio_util::sync::CancellationToken::new();
38
39        let mut handles = Vec::with_capacity(self.agents.len());
40        for agent in &self.agents {
41            let agent = agent.clone();
42            let state = state.clone();
43            let tx = tx.clone();
44            let cancel = cancel.clone();
45
46            handles.push(tokio::spawn(async move {
47                tokio::select! {
48                    result = agent.run(&state) => {
49                        let _ = tx.send(result).await;
50                    }
51                    _ = cancel.cancelled() => {}
52                }
53            }));
54        }
55        drop(tx); // Close our sender so rx completes when all are done.
56
57        let result = rx
58            .recv()
59            .await
60            .unwrap_or(Err(AgentError::Other("All race agents failed".into())));
61
62        // Cancel remaining agents.
63        cancel.cancel();
64        for handle in handles {
65            handle.abort();
66        }
67
68        result
69    }
70}