Skip to main content

rs_adk/text/
map_over.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use super::TextAgent;
6use crate::error::AgentError;
7use crate::state::State;
8
9/// Iterates a single agent over each item in a state list.
10/// Reads `state[list_key]`, runs agent per item (setting `state[item_key]`),
11/// collects results into `state[output_key]`.
12pub struct MapOverTextAgent {
13    name: String,
14    agent: Arc<dyn TextAgent>,
15    list_key: String,
16    item_key: String,
17    output_key: String,
18}
19
20impl MapOverTextAgent {
21    /// Create a new map-over agent that iterates over a list in state.
22    pub fn new(
23        name: impl Into<String>,
24        agent: Arc<dyn TextAgent>,
25        list_key: impl Into<String>,
26    ) -> Self {
27        Self {
28            name: name.into(),
29            agent,
30            list_key: list_key.into(),
31            item_key: "_item".into(),
32            output_key: "_results".into(),
33        }
34    }
35
36    /// Set the state key for the current item (default: "_item").
37    pub fn item_key(mut self, key: impl Into<String>) -> Self {
38        self.item_key = key.into();
39        self
40    }
41
42    /// Set the state key for the output list (default: "_results").
43    pub fn output_key(mut self, key: impl Into<String>) -> Self {
44        self.output_key = key.into();
45        self
46    }
47}
48
49#[async_trait]
50impl TextAgent for MapOverTextAgent {
51    fn name(&self) -> &str {
52        &self.name
53    }
54
55    async fn run(&self, state: &State) -> Result<String, AgentError> {
56        let items: Vec<serde_json::Value> = state.get(&self.list_key).unwrap_or_default();
57
58        let mut results = Vec::with_capacity(items.len());
59
60        for item in &items {
61            state.set(&self.item_key, item);
62            state.set("input", item.to_string());
63            let result = self.agent.run(state).await?;
64            results.push(result);
65        }
66
67        state.set(&self.output_key, &results);
68        Ok(results.join("\n"))
69    }
70}