1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use super::TextAgent;
6use crate::error::AgentError;
7use crate::state::State;
8
9pub struct MapOverTextAgent {
13 name: String,
14 agent: Arc<dyn TextAgent>,
15 list_key: String,
16 item_key: String,
17 output_key: String,
18}
19
20impl MapOverTextAgent {
21 pub fn new(
23 name: impl Into<String>,
24 agent: Arc<dyn TextAgent>,
25 list_key: impl Into<String>,
26 ) -> Self {
27 Self {
28 name: name.into(),
29 agent,
30 list_key: list_key.into(),
31 item_key: "_item".into(),
32 output_key: "_results".into(),
33 }
34 }
35
36 pub fn item_key(mut self, key: impl Into<String>) -> Self {
38 self.item_key = key.into();
39 self
40 }
41
42 pub fn output_key(mut self, key: impl Into<String>) -> Self {
44 self.output_key = key.into();
45 self
46 }
47}
48
49#[async_trait]
50impl TextAgent for MapOverTextAgent {
51 fn name(&self) -> &str {
52 &self.name
53 }
54
55 async fn run(&self, state: &State) -> Result<String, AgentError> {
56 let items: Vec<serde_json::Value> = state.get(&self.list_key).unwrap_or_default();
57
58 let mut results = Vec::with_capacity(items.len());
59
60 for item in &items {
61 state.set(&self.item_key, item);
62 state.set("input", item.to_string());
63 let result = self.agent.run(state).await?;
64 results.push(result);
65 }
66
67 state.set(&self.output_key, &results);
68 Ok(results.join("\n"))
69 }
70}