1use futures::stream::StreamExt;
2use std::collections::HashSet;
3use thiserror::Error;
4
5use ag_ui_core::types::context::Context;
6use ag_ui_core::types::ids::{AgentId, MessageId, RunId, ThreadId};
7use ag_ui_core::types::input::RunAgentInput;
8use ag_ui_core::types::message::Message;
9use ag_ui_core::types::tool::Tool;
10use ag_ui_core::{AgentState, FwdProps, JsonValue};
11
12use crate::event_handler::EventHandler;
13use crate::stream::EventStream;
14use crate::subscriber::IntoSubscribers;
15
16#[derive(Debug, Clone)]
17pub struct AgentConfig<StateT = JsonValue> {
18 pub agent_id: Option<AgentId>,
19 pub description: Option<String>,
20 pub thread_id: Option<ThreadId>,
21 pub initial_messages: Option<Vec<Message>>,
22 pub initial_state: Option<StateT>,
23 pub debug: Option<bool>,
24}
25
26impl<S> Default for AgentConfig<S>
27where
28 S: Default,
29{
30 fn default() -> Self {
31 Self {
32 agent_id: None,
33 description: None,
34 thread_id: None,
35 initial_messages: None,
36 initial_state: None,
37 debug: None,
38 }
39 }
40}
41
42#[derive(Debug, Clone, Default)]
44pub struct RunAgentParams<StateT: AgentState = JsonValue, FwdPropsT: FwdProps = JsonValue> {
45 pub run_id: Option<RunId>,
46 pub tools: Option<Vec<Tool>>,
47 pub context: Option<Vec<Context>>,
48 pub forwarded_props: Option<FwdPropsT>,
49 pub messages: Vec<Message>,
50 pub state: StateT,
51}
52
53#[derive(Debug, Clone)]
54pub struct RunAgentResult<StateT: AgentState> {
55 pub result: JsonValue,
56 pub new_messages: Vec<Message>,
57 pub new_state: StateT,
58}
59
/// Alias for [`RunAgentInput`], the same input type that `Agent::run` receives.
pub type AgentRunState<StateT, FwdPropsT> = RunAgentInput<StateT, FwdPropsT>;
61
62#[derive(Debug, Clone)]
63pub struct AgentStateMutation<StateT = JsonValue> {
64 pub messages: Option<Vec<Message>>,
65 pub state: Option<StateT>,
66 pub stop_propagation: bool,
67}
68
69impl<StateT> Default for AgentStateMutation<StateT> {
70 fn default() -> Self {
71 Self {
72 messages: None,
73 state: None,
74 stop_propagation: false,
75 }
76 }
77}
78
79#[derive(Error, Debug)]
81pub enum AgentError {
82 #[error("Agent execution failed: {message}")]
83 ExecutionError { message: String },
84 #[error("Invalid configuration: {message}")]
85 ConfigError { message: String },
86 #[error("Serialization error: {source}")]
87 SerializationError {
88 #[from]
89 source: serde_json::Error,
90 },
91}
92
93#[async_trait::async_trait]
96pub trait Agent<StateT = JsonValue, FwdPropsT = JsonValue>: Send + Sync
97where
98 StateT: AgentState,
99 FwdPropsT: FwdProps,
100{
101 async fn run(
102 &self,
103 input: &RunAgentInput<StateT, FwdPropsT>,
104 ) -> Result<EventStream<'async_trait, StateT>, AgentError>;
105
106 async fn run_agent(
109 &self,
110 params: &RunAgentParams<StateT, FwdPropsT>,
111 subscribers: impl IntoSubscribers<StateT, FwdPropsT>,
112 ) -> Result<RunAgentResult<StateT>, AgentError> {
113 let _agent_id = AgentId::random();
115
116 let input = RunAgentInput {
117 thread_id: ThreadId::random(),
118 run_id: params.run_id.clone().unwrap_or_else(RunId::random),
119 state: params.state.clone(),
120 messages: params.messages.clone(),
121 tools: params.tools.clone().unwrap_or_default(),
122 context: params.context.clone().unwrap_or_default(),
123 forwarded_props: params.forwarded_props.clone().unwrap(),
125 };
126 let current_message_ids: HashSet<&MessageId> =
127 params.messages.iter().map(|m| m.id()).collect();
128
129 let subscribers = subscribers.into_subscribers();
131 let mut event_handler = EventHandler::new(
132 params.messages.clone(),
133 params.state.clone(),
134 &input,
135 subscribers,
136 );
137
138 let mut stream = self.run(&input).await?.fuse();
139
140 while let Some(event_result) = stream.next().await {
141 match event_result {
142 Ok(event) => {
143 let mutation = event_handler.handle_event(&event).await?;
144 event_handler.apply_mutation(mutation).await?;
145 }
146 Err(e) => {
147 event_handler.on_error(&e).await?;
148 return Err(e);
149 }
150 }
151 }
152
153 event_handler.on_finalize().await?;
155
156 let new_messages = event_handler
158 .messages
159 .iter()
160 .filter(|m| !current_message_ids.contains(&m.id()))
161 .cloned()
162 .collect();
163
164 Ok(RunAgentResult {
165 result: event_handler.result,
166 new_messages,
167 new_state: event_handler.state,
168 })
169 }
170}