ag_ui_client/
http.rs

1use async_trait::async_trait;
2use futures::StreamExt;
3use log::{debug, trace};
4use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
5use reqwest::{Client as HttpClient, Url};
6use std::str::FromStr;
7
8use ag_ui_core::event::Event;
9use ag_ui_core::types::input::RunAgentInput;
10use ag_ui_core::{AgentState, FwdProps};
11
12use crate::Agent;
13use crate::agent::AgentError;
14use crate::agent::AgentError::SerializationError;
15use crate::sse::SseResponseExt;
16use crate::stream::EventStream;
17
18pub struct HttpAgent {
19    http_client: HttpClient,
20    base_url: Url,
21    header_map: HeaderMap,
22}
23
24impl HttpAgent {
25    pub fn new(base_url: Url, header_map: HeaderMap) -> Self {
26        let http_client = HttpClient::new();
27        let mut header_map: HeaderMap = header_map;
28
29        header_map.insert("Content-Type", HeaderValue::from_static("application/json"));
30        Self {
31            http_client,
32            base_url,
33            header_map,
34        }
35    }
36
37    pub fn builder() -> HttpAgentBuilder {
38        HttpAgentBuilder::new()
39    }
40}
41
42pub struct HttpAgentBuilder {
43    base_url: Option<Url>,
44    header_map: HeaderMap,
45    http_client: Option<HttpClient>,
46}
47
48impl HttpAgentBuilder {
49    pub fn new() -> Self {
50        Self {
51            base_url: None,
52            header_map: HeaderMap::new(),
53            http_client: None,
54        }
55    }
56
57    /// Set the base URL from a Url instance
58    pub fn with_url(mut self, base_url: Url) -> Self {
59        self.base_url = Some(base_url);
60        self
61    }
62
63    /// Set the base URL from a string, returning Result for validation
64    pub fn with_url_str(mut self, url: &str) -> Result<Self, AgentError> {
65        let parsed_url = Url::parse(url).map_err(|e| AgentError::ConfigError {
66            message: format!("Invalid URL '{url}': {e}"),
67        })?;
68        self.base_url = Some(parsed_url);
69        Ok(self)
70    }
71
72    /// Replace all headers with the provided HeaderMap
73    pub fn with_headers(mut self, header_map: HeaderMap) -> Self {
74        self.header_map = header_map;
75        self
76    }
77
78    /// Add a single header by name and value strings
79    pub fn with_header(mut self, name: &str, value: &str) -> Result<Self, AgentError> {
80        let header_name = HeaderName::from_str(name).map_err(|e| AgentError::ConfigError {
81            message: format!("Invalid header name '{value}': {e}"),
82        })?;
83        let header_value = HeaderValue::from_str(value).map_err(|e| AgentError::ConfigError {
84            message: format!("Invalid header value '{value}': {e}"),
85        })?;
86        self.header_map.insert(header_name, header_value);
87        Ok(self)
88    }
89
90    /// Add a header using HeaderName and HeaderValue directly
91    pub fn with_header_typed(mut self, name: HeaderName, value: HeaderValue) -> Self {
92        self.header_map.insert(name, value);
93        self
94    }
95
96    /// Add an authorization bearer token
97    pub fn with_bearer_token(self, token: &str) -> Result<Self, AgentError> {
98        let auth_value = format!("Bearer {token}");
99        self.with_header("Authorization", &auth_value)
100    }
101
102    /// Set a custom HTTP client
103    pub fn with_http_client(mut self, client: HttpClient) -> Self {
104        self.http_client = Some(client);
105        self
106    }
107
108    /// Set request timeout in seconds
109    pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
110        let client = HttpClient::builder()
111            .timeout(std::time::Duration::from_secs(timeout_secs))
112            .build()
113            .unwrap_or_else(|_| HttpClient::new());
114        self.http_client = Some(client);
115        self
116    }
117
118    pub fn build(self) -> Result<HttpAgent, AgentError> {
119        let base_url = self.base_url.ok_or(AgentError::ConfigError {
120            message: "Base URL is required".to_string(),
121        })?;
122
123        // Validate URL scheme
124        if !["http", "https"].contains(&base_url.scheme()) {
125            return Err(AgentError::ConfigError {
126                message: format!("Unsupported URL scheme: {}", base_url.scheme()),
127            });
128        }
129
130        let http_client = self.http_client.unwrap_or_default();
131
132        Ok(HttpAgent {
133            http_client,
134            base_url,
135            header_map: self.header_map,
136        })
137    }
138}
139
140impl Default for HttpAgentBuilder {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146impl From<reqwest::Error> for AgentError {
147    fn from(err: reqwest::Error) -> Self {
148        AgentError::ExecutionError {
149            message: err.to_string(),
150        }
151    }
152}
153
154#[async_trait]
155impl<StateT: AgentState, FwdPropsT: FwdProps> Agent<StateT, FwdPropsT> for HttpAgent {
156    async fn run(
157        &self,
158        input: &RunAgentInput<StateT, FwdPropsT>,
159    ) -> Result<EventStream<'async_trait, StateT>, AgentError> {
160        // Send the request and get the response
161        let response = self
162            .http_client
163            .post(self.base_url.clone())
164            .json(input)
165            .headers(self.header_map.clone())
166            .send()
167            .await?;
168
169        // Convert the response to an SSE event stream
170        let stream = response
171            .event_source()
172            .await
173            .map(|result| match result {
174                Ok(event) => {
175                    trace!("Received event: {event:?}");
176
177                    let event_data: Event<StateT> = serde_json::from_str(&event.data)
178                        .map_err(|err| SerializationError { source: err })?;
179                    debug!("Deserialized event: {event_data:?}");
180
181                    Ok(event_data)
182                }
183                Err(err) => Err(AgentError::ExecutionError {
184                    message: err.to_string(),
185                }),
186            })
187            .boxed();
188        Ok(stream)
189    }
190}