// horizons_ai/client.rs

use crate::error::{HorizonsError, HorizonsErrorKind};
use crate::types::Health;
use async_stream::try_stream;
use bytes::Bytes;
use futures_util::Stream;
use futures_util::StreamExt;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION};
use reqwest::{Method, Response};
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::Value;
use std::fmt;
use uuid::Uuid;
13
14#[derive(Debug, Default, Clone)]
15pub struct ClientOptions {
16    pub project_id: Option<Uuid>,
17    pub user_id: Option<Uuid>,
18    pub user_email: Option<String>,
19    pub agent_id: Option<String>,
20    pub api_key: Option<String>,
21}
22
23#[derive(Debug, Clone)]
24pub struct HorizonsClient {
25    base_url: String,
26    org_id: Uuid,
27    opts: ClientOptions,
28    http: reqwest::Client,
29}
30
31impl HorizonsClient {
32    pub fn new(base_url: impl Into<String>, org_id: Uuid) -> Self {
33        let base_url = base_url.into();
34        let base_url = base_url.trim_end_matches('/').to_string();
35        Self {
36            base_url,
37            org_id,
38            opts: ClientOptions::default(),
39            http: reqwest::Client::new(),
40        }
41    }
42
43    pub fn with_project_id(mut self, project_id: Uuid) -> Self {
44        self.opts.project_id = Some(project_id);
45        self
46    }
47
48    pub fn with_user(mut self, user_id: Uuid, user_email: Option<impl Into<String>>) -> Self {
49        self.opts.user_id = Some(user_id);
50        self.opts.user_email = user_email.map(|s| s.into());
51        self
52    }
53
54    pub fn with_agent_id(mut self, agent_id: impl Into<String>) -> Self {
55        self.opts.agent_id = Some(agent_id.into());
56        self
57    }
58
59    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
60        self.opts.api_key = Some(api_key.into());
61        self
62    }
63
64    pub fn org_id(&self) -> Uuid {
65        self.org_id
66    }
67
68    pub fn base_url(&self) -> &str {
69        &self.base_url
70    }
71
72    fn headers(&self, extra: Option<HeaderMap>) -> Result<HeaderMap, HorizonsError> {
73        fn insert(
74            headers: &mut HeaderMap,
75            name: &'static str,
76            value: impl AsRef<str>,
77        ) -> Result<(), HorizonsError> {
78            let name = HeaderName::from_bytes(name.as_bytes()).map_err(|e| {
79                HorizonsError::new(HorizonsErrorKind::Serialization, None, e.to_string())
80            })?;
81            let value = HeaderValue::from_str(value.as_ref()).map_err(|e| {
82                HorizonsError::new(HorizonsErrorKind::Serialization, None, e.to_string())
83            })?;
84            headers.insert(name, value);
85            Ok(())
86        }
87
88        let mut headers = HeaderMap::new();
89        insert(&mut headers, "x-org-id", self.org_id.to_string())?;
90        if let Some(project_id) = self.opts.project_id {
91            insert(&mut headers, "x-project-id", project_id.to_string())?;
92        }
93        if let Some(user_id) = self.opts.user_id {
94            insert(&mut headers, "x-user-id", user_id.to_string())?;
95            if let Some(email) = self.opts.user_email.as_deref() {
96                insert(&mut headers, "x-user-email", email)?;
97            }
98        }
99        if let Some(agent_id) = self.opts.agent_id.as_deref() {
100            insert(&mut headers, "x-agent-id", agent_id)?;
101        }
102        if let Some(api_key) = self.opts.api_key.as_deref() {
103            let v = format!("Bearer {api_key}");
104            headers.insert(
105                AUTHORIZATION,
106                HeaderValue::from_str(&v).map_err(|e| {
107                    HorizonsError::new(HorizonsErrorKind::Serialization, None, e.to_string())
108                })?,
109            );
110        }
111        if let Some(extra) = extra {
112            headers.extend(extra);
113        }
114        Ok(headers)
115    }
116
117    fn url(&self, path: &str) -> String {
118        if path.starts_with('/') {
119            format!("{}{}", self.base_url, path)
120        } else {
121            format!("{}/{}", self.base_url, path)
122        }
123    }
124
125    pub(crate) async fn send(
126        &self,
127        method: Method,
128        path: &str,
129        query: Option<&impl Serialize>,
130        body: Option<&impl Serialize>,
131        extra_headers: Option<HeaderMap>,
132    ) -> Result<Response, HorizonsError> {
133        let mut req = self.http.request(method, self.url(path));
134        if let Some(q) = query {
135            req = req.query(q);
136        }
137        if let Some(b) = body {
138            req = req.json(b);
139        }
140        req = req.headers(self.headers(extra_headers)?);
141        Ok(req.send().await?)
142    }
143
144    pub(crate) async fn map_error(&self, resp: Response) -> Result<HorizonsError, HorizonsError> {
145        let status = resp.status();
146        let code = status.as_u16();
147        let text = resp.text().await.unwrap_or_default();
148        let kind = if code == 404 {
149            HorizonsErrorKind::NotFound
150        } else if code == 401 || code == 403 {
151            HorizonsErrorKind::Auth
152        } else if (400..500).contains(&code) {
153            HorizonsErrorKind::Validation
154        } else {
155            HorizonsErrorKind::Server
156        };
157        Ok(HorizonsError::new(
158            kind,
159            Some(code),
160            if text.is_empty() {
161                status.to_string()
162            } else {
163                text
164            },
165        ))
166    }
167
168    pub async fn request_json<T: DeserializeOwned>(
169        &self,
170        method: Method,
171        path: &str,
172        query: Option<&impl Serialize>,
173        body: Option<&impl Serialize>,
174    ) -> Result<T, HorizonsError> {
175        let resp = self.send(method, path, query, body, None).await?;
176        if resp.status().is_success() {
177            if resp.status().as_u16() == 204 {
178                // Can't deserialize an empty body. Match the Python SDK behavior.
179                return Ok(serde_json::from_value(Value::Null)?);
180            }
181            return Ok(resp.json::<T>().await?);
182        }
183        Err(self.map_error(resp).await?)
184    }
185
186    pub async fn request_value(
187        &self,
188        method: Method,
189        path: &str,
190        query: Option<&impl Serialize>,
191        body: Option<&impl Serialize>,
192    ) -> Result<Value, HorizonsError> {
193        self.request_json::<Value>(method, path, query, body).await
194    }
195
196    pub fn sse_post(
197        &self,
198        path: &str,
199        body: impl Serialize + Send + Sync + 'static,
200    ) -> impl Stream<Item = Result<Value, HorizonsError>> + Send + 'static {
201        let client = self.clone();
202        let path = path.to_string();
203        try_stream! {
204            let mut extra = HeaderMap::new();
205            extra.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
206            let resp = client
207                .send(Method::POST, &path, None::<&()>, Some(&body), Some(extra))
208                .await?;
209
210            if resp.status().is_success() {
211                let mut stream = resp.bytes_stream();
212                let mut buf: Vec<u8> = Vec::new();
213                let mut cur_event: Option<String> = None;
214
215                while let Some(chunk) = stream.next().await {
216                    let chunk: Bytes = chunk?;
217                    buf.extend_from_slice(&chunk);
218
219                    while let Some(pos) = buf.iter().position(|b| *b == b'\n') {
220                        let mut line = buf.drain(..=pos).collect::<Vec<u8>>();
221                        if line.ends_with(b"\n") { line.pop(); }
222                        if line.ends_with(b"\r") { line.pop(); }
223                        let line = String::from_utf8_lossy(&line).to_string();
224
225                        if line.is_empty() {
226                            cur_event = None;
227                            continue;
228                        }
229                        if line.starts_with(':') {
230                            continue;
231                        }
232                        if let Some(rest) = line.strip_prefix("event:") {
233                            cur_event = Some(rest.trim().to_string());
234                            continue;
235                        }
236                        if let Some(rest) = line.strip_prefix("data:") {
237                            let data_str = rest.trim();
238                            if data_str.is_empty() {
239                                continue;
240                            }
241
242                            // Some streams may send control events.
243                            if let Some(ev) = cur_event.as_deref() {
244                                if ev == "done" {
245                                    return;
246                                }
247                                if ev == "error" {
248                                    Err(HorizonsError::new(
249                                        HorizonsErrorKind::Stream,
250                                        None,
251                                        data_str.to_string(),
252                                    ))?;
253                                }
254                            }
255
256                            // Best-effort JSON decode; skip non-JSON lines.
257                            if let Ok(v) = serde_json::from_str::<Value>(data_str) {
258                                yield v;
259                            }
260                        }
261                    }
262                }
263            } else {
264                Err(client.map_error(resp).await?)?;
265            }
266        }
267    }
268
269    pub async fn health(&self) -> Result<Health, HorizonsError> {
270        self.request_json(Method::GET, "/api/v1/health", None::<&()>, None::<&()>)
271            .await
272    }
273
274    pub fn onboard(&self) -> crate::apis::OnboardApi {
275        crate::apis::OnboardApi::new(self.clone())
276    }
277
278    pub fn events(&self) -> crate::apis::EventsApi {
279        crate::apis::EventsApi::new(self.clone())
280    }
281
282    pub fn agents(&self) -> crate::apis::AgentsApi {
283        crate::apis::AgentsApi::new(self.clone())
284    }
285
286    pub fn actions(&self) -> crate::apis::ActionsApi {
287        crate::apis::ActionsApi::new(self.clone())
288    }
289
290    pub fn memory(&self) -> crate::apis::MemoryApi {
291        crate::apis::MemoryApi::new(self.clone())
292    }
293
294    pub fn context_refresh(&self) -> crate::apis::ContextRefreshApi {
295        crate::apis::ContextRefreshApi::new(self.clone())
296    }
297
298    pub fn optimization(&self) -> crate::apis::OptimizationApi {
299        crate::apis::OptimizationApi::new(self.clone())
300    }
301
302    pub fn evaluation(&self) -> crate::apis::EvaluationApi {
303        crate::apis::EvaluationApi::new(self.clone())
304    }
305
306    pub fn engine(&self) -> crate::apis::EngineApi {
307        crate::apis::EngineApi::new(self.clone())
308    }
309
310    pub fn pipelines(&self) -> crate::apis::PipelinesApi {
311        crate::apis::PipelinesApi::new(self.clone())
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// A single trailing slash on the base URL is dropped.
    #[test]
    fn trims_trailing_slash() {
        let c = HorizonsClient::new("http://localhost:8000/", Uuid::nil());
        assert_eq!(c.base_url(), "http://localhost:8000");
    }

    /// Any number of trailing slashes is dropped; a clean URL is untouched.
    #[test]
    fn trims_repeated_trailing_slashes() {
        let c = HorizonsClient::new("http://localhost:8000///", Uuid::nil());
        assert_eq!(c.base_url(), "http://localhost:8000");

        let c = HorizonsClient::new("http://localhost:8000", Uuid::nil());
        assert_eq!(c.base_url(), "http://localhost:8000");
    }

    /// Paths are joined with exactly one slash, with or without a leading `/`.
    #[test]
    fn url_joins_with_single_slash() {
        let c = HorizonsClient::new("http://localhost:8000/", Uuid::nil());
        assert_eq!(c.url("/api/v1/health"), "http://localhost:8000/api/v1/health");
        assert_eq!(c.url("api/v1/health"), "http://localhost:8000/api/v1/health");
    }

    /// The organization id header is present even when no options are set.
    #[test]
    fn always_sends_org_id_header() {
        let c = HorizonsClient::new("http://localhost:8000", Uuid::nil());
        let headers = match c.headers(None) {
            Ok(h) => h,
            Err(_) => panic!("default headers should always be constructible"),
        };
        assert_eq!(
            headers.get("x-org-id").and_then(|v| v.to_str().ok()),
            Some("00000000-0000-0000-0000-000000000000")
        );
    }
}