lattice_sdk/core/
http_client.rs

1use crate::{join_url, ApiError, ClientConfig, RequestOptions};
2use futures::{Stream, StreamExt};
3use reqwest::{
4    header::{HeaderName, HeaderValue},
5    Client, Method, Request, Response,
6};
7use serde::de::DeserializeOwned;
8use std::{
9    pin::Pin,
10    str::FromStr,
11    task::{Context, Poll},
12};
13
/// A streaming byte stream for downloading files efficiently
///
/// Wraps the chunked body of an HTTP response. Chunks can be pulled one at a
/// time, buffered into memory in one go, or consumed through the
/// `futures::Stream` interface.
pub struct ByteStream {
    // Content length reported by the response, if any; read once at construction.
    content_length: Option<u64>,
    // Pinned, type-erased stream of body chunks; each item is a chunk or a `reqwest::Error`.
    inner: Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>,
}
19
20impl ByteStream {
21    /// Create a new ByteStream from a Response
22    pub(crate) fn new(response: Response) -> Self {
23        let content_length = response.content_length();
24        let stream = response.bytes_stream();
25
26        Self {
27            content_length,
28            inner: Box::pin(stream),
29        }
30    }
31
32    /// Collect the entire stream into a `Vec<u8>`
33    ///
34    /// This consumes the stream and buffers all data into memory.
35    /// For large files, prefer using `try_next()` to process chunks incrementally.
36    ///
37    /// # Example
38    /// ```no_run
39    /// let stream = client.download_file().await?;
40    /// let bytes = stream.collect().await?;
41    /// ```
42    pub async fn collect(mut self) -> Result<Vec<u8>, ApiError> {
43        let mut result = Vec::new();
44        while let Some(chunk) = self.inner.next().await {
45            result.extend_from_slice(&chunk.map_err(ApiError::Network)?);
46        }
47        Ok(result)
48    }
49
50    /// Get the next chunk from the stream
51    ///
52    /// Returns `Ok(Some(bytes))` if a chunk is available,
53    /// `Ok(None)` if the stream is finished, or an error.
54    ///
55    /// # Example
56    /// ```no_run
57    /// let mut stream = client.download_file().await?;
58    /// while let Some(chunk) = stream.try_next().await? {
59    ///     process_chunk(&chunk);
60    /// }
61    /// ```
62    pub async fn try_next(&mut self) -> Result<Option<bytes::Bytes>, ApiError> {
63        match self.inner.next().await {
64            Some(Ok(bytes)) => Ok(Some(bytes)),
65            Some(Err(e)) => Err(ApiError::Network(e)),
66            None => Ok(None),
67        }
68    }
69
    /// Get the content length from response headers if available
    ///
    /// This is the value recorded when the stream was created; it is not
    /// updated as chunks are consumed.
    pub fn content_length(&self) -> Option<u64> {
        self.content_length
    }
74}
75
76impl Stream for ByteStream {
77    type Item = Result<bytes::Bytes, ApiError>;
78
79    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
80        match self.inner.as_mut().poll_next(cx) {
81            Poll::Ready(Some(Ok(bytes))) => Poll::Ready(Some(Ok(bytes))),
82            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(ApiError::Network(e)))),
83            Poll::Ready(None) => Poll::Ready(None),
84            Poll::Pending => Poll::Pending,
85        }
86    }
87}
88
/// Internal HTTP client that handles requests with authentication and retries
///
/// Cloning yields another handle carrying a clone of the same configuration.
#[derive(Clone)]
pub struct HttpClient {
    // Underlying `reqwest` client used to build and execute every request.
    client: Client,
    // Settings read when building requests: base URL, credentials,
    // custom headers and the default retry count.
    config: ClientConfig,
}
95
96impl HttpClient {
97    pub fn new(config: ClientConfig) -> Result<Self, ApiError> {
98        let client = Client::builder()
99            .timeout(config.timeout)
100            .user_agent(&config.user_agent)
101            .build()
102            .map_err(ApiError::Network)?;
103
104        Ok(Self { client, config })
105    }
106
107    /// Execute a request with the given method, path, and options
108    pub async fn execute_request<T>(
109        &self,
110        method: Method,
111        path: &str,
112        body: Option<serde_json::Value>,
113        query_params: Option<Vec<(String, String)>>,
114        options: Option<RequestOptions>,
115    ) -> Result<T, ApiError>
116    where
117        T: DeserializeOwned, // Generic T: DeserializeOwned means the response will be automatically deserialized into whatever type you specify:
118    {
119        let url = join_url(&self.config.base_url, path);
120        let mut request = self.client.request(method, &url);
121
122        // Apply query parameters if provided
123        if let Some(params) = query_params {
124            request = request.query(&params);
125        }
126
127        // Apply additional query parameters from options
128        if let Some(opts) = &options {
129            if !opts.additional_query_params.is_empty() {
130                request = request.query(&opts.additional_query_params);
131            }
132        }
133
134        // Apply body if provided
135        if let Some(body) = body {
136            request = request.json(&body);
137        }
138
139        // Build the request
140        let mut req = request.build().map_err(|e| ApiError::Network(e))?;
141
142        // Apply authentication and headers
143        self.apply_auth_headers(&mut req, &options)?;
144        self.apply_custom_headers(&mut req, &options)?;
145
146        // Execute with retries
147        let response = self.execute_with_retries(req, &options).await?;
148        self.parse_response(response).await
149    }
150
151    fn apply_auth_headers(
152        &self,
153        request: &mut Request,
154        options: &Option<RequestOptions>,
155    ) -> Result<(), ApiError> {
156        let headers = request.headers_mut();
157
158        // Apply API key (request options override config)
159        let api_key = options
160            .as_ref()
161            .and_then(|opts| opts.api_key.as_ref())
162            .or(self.config.api_key.as_ref());
163
164        if let Some(key) = api_key {
165            headers.insert("api_key", key.parse().map_err(|_| ApiError::InvalidHeader)?);
166        }
167
168        // Apply bearer token (request options override config)
169        let token = options
170            .as_ref()
171            .and_then(|opts| opts.token.as_ref())
172            .or(self.config.token.as_ref());
173
174        if let Some(token) = token {
175            let auth_value = format!("Bearer {}", token);
176            headers.insert(
177                "Authorization",
178                auth_value.parse().map_err(|_| ApiError::InvalidHeader)?,
179            );
180        }
181
182        Ok(())
183    }
184
185    fn apply_custom_headers(
186        &self,
187        request: &mut Request,
188        options: &Option<RequestOptions>,
189    ) -> Result<(), ApiError> {
190        let headers = request.headers_mut();
191
192        // Apply config-level custom headers
193        for (key, value) in &self.config.custom_headers {
194            headers.insert(
195                HeaderName::from_str(key).map_err(|_| ApiError::InvalidHeader)?,
196                HeaderValue::from_str(value).map_err(|_| ApiError::InvalidHeader)?,
197            );
198        }
199
200        // Apply request-level custom headers (override config)
201        if let Some(options) = options {
202            for (key, value) in &options.additional_headers {
203                headers.insert(
204                    HeaderName::from_str(key).map_err(|_| ApiError::InvalidHeader)?,
205                    HeaderValue::from_str(value).map_err(|_| ApiError::InvalidHeader)?,
206                );
207            }
208        }
209
210        Ok(())
211    }
212
213    async fn execute_with_retries(
214        &self,
215        request: Request,
216        options: &Option<RequestOptions>,
217    ) -> Result<Response, ApiError> {
218        let max_retries = options
219            .as_ref()
220            .and_then(|opts| opts.max_retries)
221            .unwrap_or(self.config.max_retries);
222
223        let mut last_error = None;
224
225        for attempt in 0..=max_retries {
226            let cloned_request = request.try_clone().ok_or(ApiError::RequestClone)?;
227
228            match self.client.execute(cloned_request).await {
229                Ok(response) if response.status().is_success() => return Ok(response),
230                Ok(response) => {
231                    let status_code = response.status().as_u16();
232                    let body = response.text().await.ok();
233                    return Err(ApiError::from_response(status_code, body.as_deref()));
234                }
235                Err(e) if attempt < max_retries => {
236                    last_error = Some(e);
237                    // Exponential backoff
238                    let delay = std::time::Duration::from_millis(100 * 2_u64.pow(attempt));
239                    tokio::time::sleep(delay).await;
240                }
241                Err(e) => return Err(ApiError::Network(e)),
242            }
243        }
244
245        Err(ApiError::Network(last_error.unwrap()))
246    }
247
248    async fn parse_response<T>(&self, response: Response) -> Result<T, ApiError>
249    where
250        T: DeserializeOwned,
251    {
252        let text = response.text().await.map_err(ApiError::Network)?;
253        serde_json::from_str(&text).map_err(ApiError::Serialization)
254    }
255
256    /// Execute a request and return a streaming response (for large file downloads)
257    ///
258    /// This method returns a `ByteStream` that can be used to download large files
259    /// efficiently without loading the entire content into memory. The stream can be
260    /// consumed chunk by chunk, written directly to disk, or collected into bytes.
261    ///
262    /// # Examples
263    ///
264    /// **Option 1: Collect all bytes into memory**
265    /// ```no_run
266    /// let stream = client.execute_stream_request(
267    ///     Method::GET,
268    ///     "/file",
269    ///     None,
270    ///     None,
271    ///     None,
272    /// ).await?;
273    ///
274    /// let bytes = stream.collect().await?;
275    /// ```
276    ///
277    /// **Option 2: Process chunks with try_next()**
278    /// ```no_run
279    /// let mut stream = client.execute_stream_request(
280    ///     Method::GET,
281    ///     "/large-file",
282    ///     None,
283    ///     None,
284    ///     None,
285    /// ).await?;
286    ///
287    /// while let Some(chunk) = stream.try_next().await? {
288    ///     process_chunk(&chunk);
289    /// }
290    /// ```
291    ///
292    /// **Option 3: Stream with futures::Stream trait**
293    /// ```no_run
294    /// use futures::StreamExt;
295    ///
296    /// let stream = client.execute_stream_request(
297    ///     Method::GET,
298    ///     "/large-file",
299    ///     None,
300    ///     None,
301    ///     None,
302    /// ).await?;
303    ///
304    /// let mut file = tokio::fs::File::create("output.mp4").await?;
305    /// let mut stream = std::pin::pin!(stream);
306    /// while let Some(chunk) = stream.next().await {
307    ///     let chunk = chunk?;
308    ///     tokio::io::AsyncWriteExt::write_all(&mut file, &chunk).await?;
309    /// }
310    /// ```
311    pub async fn execute_stream_request(
312        &self,
313        method: Method,
314        path: &str,
315        body: Option<serde_json::Value>,
316        query_params: Option<Vec<(String, String)>>,
317        options: Option<RequestOptions>,
318    ) -> Result<ByteStream, ApiError> {
319        let url = join_url(&self.config.base_url, path);
320        let mut request = self.client.request(method, &url);
321
322        // Apply query parameters if provided
323        if let Some(params) = query_params {
324            request = request.query(&params);
325        }
326
327        // Apply additional query parameters from options
328        if let Some(opts) = &options {
329            if !opts.additional_query_params.is_empty() {
330                request = request.query(&opts.additional_query_params);
331            }
332        }
333
334        // Apply body if provided
335        if let Some(body) = body {
336            request = request.json(&body);
337        }
338
339        // Build the request
340        let mut req = request.build().map_err(|e| ApiError::Network(e))?;
341
342        // Apply authentication and headers
343        self.apply_auth_headers(&mut req, &options)?;
344        self.apply_custom_headers(&mut req, &options)?;
345
346        // Execute with retries
347        let response = self.execute_with_retries(req, &options).await?;
348
349        // Return streaming response
350        Ok(ByteStream::new(response))
351    }
352
353    /// Execute a request and return an SSE stream
354    ///
355    /// This method returns an `SseStream<T>` that automatically parses
356    /// Server-Sent Events and deserializes the JSON data in each event.
357    ///
358    /// # SSE-Specific Headers
359    ///
360    /// This method automatically sets the following headers **after** applying custom headers,
361    /// which means these headers will override any user-supplied values:
362    /// - `Accept: text/event-stream` - Required for SSE protocol
363    /// - `Cache-Control: no-store` - Prevents caching of streaming responses
364    ///
365    /// This ensures proper SSE behavior even if custom headers are provided.
366    ///
367    /// # Example
368    /// ```no_run
369    /// use futures::StreamExt;
370    ///
371    /// let stream = client.execute_sse_request::<CompletionChunk>(
372    ///     Method::POST,
373    ///     "/stream",
374    ///     Some(serde_json::json!({"query": "Hello"})),
375    ///     None,
376    ///     None,
377    ///     Some("[[DONE]]".to_string()),
378    /// ).await?;
379    ///
380    /// let mut stream = std::pin::pin!(stream);
381    /// while let Some(chunk) = stream.next().await {
382    ///     let chunk = chunk?;
383    ///     println!("Received: {:?}", chunk);
384    /// }
385    /// ```
386    pub async fn execute_sse_request<T>(
387        &self,
388        method: Method,
389        path: &str,
390        body: Option<serde_json::Value>,
391        query_params: Option<Vec<(String, String)>>,
392        options: Option<RequestOptions>,
393        terminator: Option<String>,
394    ) -> Result<crate::SseStream<T>, ApiError>
395    where
396        T: DeserializeOwned + Send + 'static,
397    {
398        let url = join_url(&self.config.base_url, path);
399        let mut request = self.client.request(method, &url);
400
401        // Apply query parameters if provided
402        if let Some(params) = query_params {
403            request = request.query(&params);
404        }
405
406        // Apply additional query parameters from options
407        if let Some(opts) = &options {
408            if !opts.additional_query_params.is_empty() {
409                request = request.query(&opts.additional_query_params);
410            }
411        }
412
413        // Apply body if provided
414        if let Some(body) = body {
415            request = request.json(&body);
416        }
417
418        // Build the request
419        let mut req = request.build().map_err(|e| ApiError::Network(e))?;
420
421        // Apply authentication and headers
422        self.apply_auth_headers(&mut req, &options)?;
423        self.apply_custom_headers(&mut req, &options)?;
424
425        // SSE-specific headers
426        req.headers_mut().insert(
427            "Accept",
428            "text/event-stream"
429                .parse()
430                .map_err(|_| ApiError::InvalidHeader)?,
431        );
432        req.headers_mut().insert(
433            "Cache-Control",
434            "no-store".parse().map_err(|_| ApiError::InvalidHeader)?,
435        );
436
437        // Execute with retries
438        let response = self.execute_with_retries(req, &options).await?;
439
440        // Return SSE stream
441        crate::SseStream::new(response, terminator).await
442    }
443}