1use async_trait::async_trait;
2use futures::StreamExt;
3use log::{debug, trace};
4use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
5use reqwest::{Client as HttpClient, Url};
6use std::str::FromStr;
7
8use ag_ui_core::event::Event;
9use ag_ui_core::types::input::RunAgentInput;
10use ag_ui_core::{AgentState, FwdProps};
11
12use crate::Agent;
13use crate::agent::AgentError;
14use crate::agent::AgentError::SerializationError;
15use crate::sse::SseResponseExt;
16use crate::stream::EventStream;
17
18pub struct HttpAgent {
19 http_client: HttpClient,
20 base_url: Url,
21 header_map: HeaderMap,
22}
23
24impl HttpAgent {
25 pub fn new(base_url: Url, header_map: HeaderMap) -> Self {
26 let http_client = HttpClient::new();
27 let mut header_map: HeaderMap = header_map;
28
29 header_map.insert("Content-Type", HeaderValue::from_static("application/json"));
30 Self {
31 http_client,
32 base_url,
33 header_map,
34 }
35 }
36
37 pub fn builder() -> HttpAgentBuilder {
38 HttpAgentBuilder::new()
39 }
40}
41
42pub struct HttpAgentBuilder {
43 base_url: Option<Url>,
44 header_map: HeaderMap,
45 http_client: Option<HttpClient>,
46}
47
48impl HttpAgentBuilder {
49 pub fn new() -> Self {
50 Self {
51 base_url: None,
52 header_map: HeaderMap::new(),
53 http_client: None,
54 }
55 }
56
57 pub fn with_url(mut self, base_url: Url) -> Self {
59 self.base_url = Some(base_url);
60 self
61 }
62
63 pub fn with_url_str(mut self, url: &str) -> Result<Self, AgentError> {
65 let parsed_url = Url::parse(url).map_err(|e| AgentError::ConfigError {
66 message: format!("Invalid URL '{url}': {e}"),
67 })?;
68 self.base_url = Some(parsed_url);
69 Ok(self)
70 }
71
72 pub fn with_headers(mut self, header_map: HeaderMap) -> Self {
74 self.header_map = header_map;
75 self
76 }
77
78 pub fn with_header(mut self, name: &str, value: &str) -> Result<Self, AgentError> {
80 let header_name = HeaderName::from_str(name).map_err(|e| AgentError::ConfigError {
81 message: format!("Invalid header name '{value}': {e}"),
82 })?;
83 let header_value = HeaderValue::from_str(value).map_err(|e| AgentError::ConfigError {
84 message: format!("Invalid header value '{value}': {e}"),
85 })?;
86 self.header_map.insert(header_name, header_value);
87 Ok(self)
88 }
89
90 pub fn with_header_typed(mut self, name: HeaderName, value: HeaderValue) -> Self {
92 self.header_map.insert(name, value);
93 self
94 }
95
96 pub fn with_bearer_token(self, token: &str) -> Result<Self, AgentError> {
98 let auth_value = format!("Bearer {token}");
99 self.with_header("Authorization", &auth_value)
100 }
101
102 pub fn with_http_client(mut self, client: HttpClient) -> Self {
104 self.http_client = Some(client);
105 self
106 }
107
108 pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
110 let client = HttpClient::builder()
111 .timeout(std::time::Duration::from_secs(timeout_secs))
112 .build()
113 .unwrap_or_else(|_| HttpClient::new());
114 self.http_client = Some(client);
115 self
116 }
117
118 pub fn build(self) -> Result<HttpAgent, AgentError> {
119 let base_url = self.base_url.ok_or(AgentError::ConfigError {
120 message: "Base URL is required".to_string(),
121 })?;
122
123 if !["http", "https"].contains(&base_url.scheme()) {
125 return Err(AgentError::ConfigError {
126 message: format!("Unsupported URL scheme: {}", base_url.scheme()),
127 });
128 }
129
130 let http_client = self.http_client.unwrap_or_default();
131
132 Ok(HttpAgent {
133 http_client,
134 base_url,
135 header_map: self.header_map,
136 })
137 }
138}
139
140impl Default for HttpAgentBuilder {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146impl From<reqwest::Error> for AgentError {
147 fn from(err: reqwest::Error) -> Self {
148 AgentError::ExecutionError {
149 message: err.to_string(),
150 }
151 }
152}
153
154#[async_trait]
155impl<StateT: AgentState, FwdPropsT: FwdProps> Agent<StateT, FwdPropsT> for HttpAgent {
156 async fn run(
157 &self,
158 input: &RunAgentInput<StateT, FwdPropsT>,
159 ) -> Result<EventStream<'async_trait, StateT>, AgentError> {
160 let response = self
162 .http_client
163 .post(self.base_url.clone())
164 .json(input)
165 .headers(self.header_map.clone())
166 .send()
167 .await?;
168
169 let stream = response
171 .event_source()
172 .await
173 .map(|result| match result {
174 Ok(event) => {
175 trace!("Received event: {event:?}");
176
177 let event_data: Event<StateT> = serde_json::from_str(&event.data)
178 .map_err(|err| SerializationError { source: err })?;
179 debug!("Deserialized event: {event_data:?}");
180
181 Ok(event_data)
182 }
183 Err(err) => Err(AgentError::ExecutionError {
184 message: err.to_string(),
185 }),
186 })
187 .boxed();
188 Ok(stream)
189 }
190}