ag_ui_client/
agent.rs

1use futures::stream::StreamExt;
2use std::collections::HashSet;
3use thiserror::Error;
4
5use ag_ui_core::types::context::Context;
6use ag_ui_core::types::ids::{AgentId, MessageId, RunId, ThreadId};
7use ag_ui_core::types::input::RunAgentInput;
8use ag_ui_core::types::message::Message;
9use ag_ui_core::types::tool::Tool;
10use ag_ui_core::{AgentState, FwdProps, JsonValue};
11
12use crate::event_handler::EventHandler;
13use crate::stream::EventStream;
14use crate::subscriber::IntoSubscribers;
15
16#[derive(Debug, Clone)]
17pub struct AgentConfig<StateT = JsonValue> {
18    pub agent_id: Option<AgentId>,
19    pub description: Option<String>,
20    pub thread_id: Option<ThreadId>,
21    pub initial_messages: Option<Vec<Message>>,
22    pub initial_state: Option<StateT>,
23    pub debug: Option<bool>,
24}
25
26impl<S> Default for AgentConfig<S>
27where
28    S: Default,
29{
30    fn default() -> Self {
31        Self {
32            agent_id: None,
33            description: None,
34            thread_id: None,
35            initial_messages: None,
36            initial_state: None,
37            debug: None,
38        }
39    }
40}
41
42/// Parameters for running an agent.
43#[derive(Debug, Clone, Default)]
44pub struct RunAgentParams<StateT: AgentState = JsonValue, FwdPropsT: FwdProps = JsonValue> {
45    pub run_id: Option<RunId>,
46    pub tools: Option<Vec<Tool>>,
47    pub context: Option<Vec<Context>>,
48    pub forwarded_props: Option<FwdPropsT>,
49    pub messages: Vec<Message>,
50    pub state: StateT,
51}
52
53#[derive(Debug, Clone)]
54pub struct RunAgentResult<StateT: AgentState> {
55    pub result: JsonValue,
56    pub new_messages: Vec<Message>,
57    pub new_state: StateT,
58}
59
60pub type AgentRunState<StateT, FwdPropsT> = RunAgentInput<StateT, FwdPropsT>;
61
62#[derive(Debug, Clone)]
63pub struct AgentStateMutation<StateT = JsonValue> {
64    pub messages: Option<Vec<Message>>,
65    pub state: Option<StateT>,
66    pub stop_propagation: bool,
67}
68
69impl<StateT> Default for AgentStateMutation<StateT> {
70    fn default() -> Self {
71        Self {
72            messages: None,
73            state: None,
74            stop_propagation: false,
75        }
76    }
77}
78
79// Error types
80#[derive(Error, Debug)]
81pub enum AgentError {
82    #[error("Agent execution failed: {message}")]
83    ExecutionError { message: String },
84    #[error("Invalid configuration: {message}")]
85    ConfigError { message: String },
86    #[error("Serialization error: {source}")]
87    SerializationError {
88        #[from]
89        source: serde_json::Error,
90    },
91}
92
93// TODO: Expand documentation
94/// Agent trait
95#[async_trait::async_trait]
96pub trait Agent<StateT = JsonValue, FwdPropsT = JsonValue>: Send + Sync
97where
98    StateT: AgentState,
99    FwdPropsT: FwdProps,
100{
101    async fn run(
102        &self,
103        input: &RunAgentInput<StateT, FwdPropsT>,
104    ) -> Result<EventStream<'async_trait, StateT>, AgentError>;
105
106    // TODO: Expand documentation
107    /// The main execution method, containing the full pipeline logic.
108    async fn run_agent(
109        &self,
110        params: &RunAgentParams<StateT, FwdPropsT>,
111        subscribers: impl IntoSubscribers<StateT, FwdPropsT>,
112    ) -> Result<RunAgentResult<StateT>, AgentError> {
113        // TODO: Use Agent ID?
114        let _agent_id = AgentId::random();
115
116        let input = RunAgentInput {
117            thread_id: ThreadId::random(),
118            run_id: params.run_id.clone().unwrap_or_else(RunId::random),
119            state: params.state.clone(),
120            messages: params.messages.clone(),
121            tools: params.tools.clone().unwrap_or_default(),
122            context: params.context.clone().unwrap_or_default(),
123            // TODO: Find suitable default value
124            forwarded_props: params.forwarded_props.clone().unwrap(),
125        };
126        let current_message_ids: HashSet<&MessageId> =
127            params.messages.iter().map(|m| m.id()).collect();
128
129        // Initialize event handler with the current state
130        let subscribers = subscribers.into_subscribers();
131        let mut event_handler = EventHandler::new(
132            params.messages.clone(),
133            params.state.clone(),
134            &input,
135            subscribers,
136        );
137
138        let mut stream = self.run(&input).await?.fuse();
139
140        while let Some(event_result) = stream.next().await {
141            match event_result {
142                Ok(event) => {
143                    let mutation = event_handler.handle_event(&event).await?;
144                    event_handler.apply_mutation(mutation).await?;
145                }
146                Err(e) => {
147                    event_handler.on_error(&e).await?;
148                    return Err(e);
149                }
150            }
151        }
152
153        // Finalize the run
154        event_handler.on_finalize().await?;
155
156        // Collect new messages
157        let new_messages = event_handler
158            .messages
159            .iter()
160            .filter(|m| !current_message_ids.contains(&m.id()))
161            .cloned()
162            .collect();
163
164        Ok(RunAgentResult {
165            result: event_handler.result,
166            new_messages,
167            new_state: event_handler.state,
168        })
169    }
170}