// async_llm/http/simple.rs

1use std::pin::Pin;
2
3use futures::Stream;
4use reqwest_eventsource::RequestBuilderExt;
5use serde::{de::DeserializeOwned, Serialize};
6
7use crate::{error::Error, providers::Config};
8
9use super::{stream::stream, HttpClient};
10
11#[derive(Debug, Clone)]
12pub struct SimpleHttpClient<C: Config> {
13    pub(crate) client: reqwest::Client,
14    pub(crate) config: C,
15}
16
17#[async_trait::async_trait]
18impl<C: Config> HttpClient for SimpleHttpClient<C> {
19    async fn post<I: Serialize + Send, O: DeserializeOwned>(
20        &self,
21        path: &str,
22        request: I,
23    ) -> Result<O, Error> {
24        let url = self.config.url(path);
25        let headers = self.config.headers()?;
26        let query = self.config.query();
27        let resp = self
28            .client
29            .post(&url)
30            .headers(headers)
31            .query(&query)
32            .json(&request)
33            .send()
34            .await
35            .map_err(|e| {
36                Error::HttpClient(format!(
37                    "Failed to send HTTP request. Error = {}, url = {url:?}",
38                    e
39                ))
40            })?;
41        let status_code = resp.status();
42        if status_code.is_success() {
43            let value: serde_json::Value = resp.json().await.map_err(|e| {
44                Error::HttpClient(format!(
45                    "Failed to read JSON from HTTP request. Error = {}, url = {url}",
46                    e
47                ))
48            })?;
49            if value.get("choices").is_some() {
50                return Ok(serde_json::from_value(value.clone())?);
51            }
52            if let Some(error) = value.get("error") {
53                return Err(Error::HttpClient(format!(
54                    "Failed to process HTTP request. Error = {error:?}, url = {url}",
55                )));
56            } else {
57                return Err(Error::HttpClient(format!(
58                    "Failed to process HTTP request. url = {url:?}",
59                )));
60            }
61        } else {
62            let body = resp.text().await.map_err(|e| {
63                Error::HttpClient(format!(
64                    "Failed to read text from HTTP request. Error = {}, url = {url}",
65                    e
66                ))
67            })?;
68            Err(Error::HttpClient(format!(
69            "Failed to process HTTP request. Status Code = {status_code:?}, url = {url} - body = {body}",
70        )))
71        }
72    }
73
74    async fn post_stream<I: Serialize + Send, O: DeserializeOwned + Send + 'static>(
75        &self,
76        path: &str,
77        request: I,
78    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, Error>> + Send>>, Error> {
79        let url = self.config.url(path);
80        let headers = self.config.headers()?;
81        let query = self.config.query();
82        let event_source = self
83            .client
84            .post(url)
85            .headers(headers)
86            .query(&query)
87            .json(&request)
88            .eventsource()
89            .map_err(|e| {
90                Error::HttpClient(format!("Failed to send HTTP request. Error = {}", e))
91            })?;
92        stream(event_source, self.config.stream_done_message()).await
93    }
94}
95
96impl<C: Config> SimpleHttpClient<C> {
97    pub fn new(config: C) -> Self {
98        Self {
99            client: reqwest::Client::new(),
100            config,
101        }
102    }
103}