bullpen/
api.rs

1use std::pin::Pin;
2
3use futures::stream::StreamExt;
4use futures::{Stream, TryStreamExt};
5use reqwest::multipart::Form;
6use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
7use serde::de::DeserializeOwned;
8use serde::Serialize;
9use tokio_util::bytes::Bytes;
10
11use crate::error::APIError;
12
13const PPLX_API_ENDPOINT: &str = "https://api.perplexity.ai";
14const MODELFARM_API_ENDPOINT: &str = "https://proxy-modelfarm.replit.app";
15const OLLAMA_API_ENDPOINT: &str = "http://localhost:11434";
16
17pub struct Pplx {
18    pub http_client: reqwest::Client,
19    pub base_url: String,
20    pub api_key: String,
21}
22
23impl Pplx {
24    pub fn new(api_key: String) -> Self {
25        Self {
26            http_client: reqwest::Client::new(),
27            base_url: PPLX_API_ENDPOINT.to_string(),
28            api_key,
29        }
30    }
31
32    pub async fn get(&self, path: &str) -> Result<String, APIError> {
33        let url = format!("{}{}", &self.base_url, path);
34
35        let response = self
36            .http_client
37            .get(url)
38            .header(reqwest::header::CONTENT_TYPE, "application/json")
39            .bearer_auth(&self.api_key)
40            .send()
41            .await
42            .unwrap();
43
44        if response.status().is_server_error() {
45            return Err(APIError::EndpointError(response.text().await.unwrap()));
46        }
47
48        let response_text = response.text().await.unwrap();
49
50        #[cfg(feature = "log")]
51        log::trace!("{}", response_text);
52
53        Ok(response_text)
54    }
55
56    pub async fn post<T: Serialize>(&self, path: &str, parameters: &T) -> Result<String, APIError> {
57        let url = format!("{}{}", &self.base_url, path);
58
59        let response =
60            self.http_client
61                .post(url)
62                .header(reqwest::header::CONTENT_TYPE, "application/json")
63                .bearer_auth(&self.api_key)
64                .json(&parameters)
65                .send()
66                .await
67                .unwrap();
68
69        if !response.status().is_success() {
70            return Err(APIError::EndpointError(response.text().await.unwrap()));
71        }
72
73        Ok(response.text().await.unwrap())
74    }
75
76    pub async fn delete(&self, path: &str) -> Result<String, APIError> {
77        let url = format!("{}{}", &self.base_url, path);
78
79        let response = self
80            .http_client
81            .delete(url)
82            .header(reqwest::header::CONTENT_TYPE, "application/json")
83            .bearer_auth(&self.api_key)
84            .send()
85            .await
86            .unwrap();
87
88        if response.status().is_server_error() {
89            return Err(APIError::EndpointError(response.text().await.unwrap()));
90        }
91
92        Ok(response.text().await.unwrap())
93    }
94
95    pub async fn post_with_form(&self, path: &str, form: Form) -> Result<String, APIError> {
96        let url = format!("{}{}", &self.base_url, path);
97
98        let response = self
99            .http_client
100            .post(url)
101            // .header(reqwest::header::CONTENT_TYPE, "multipart/form-data")
102            .bearer_auth(&self.api_key)
103            .multipart(form)
104            .send()
105            .await
106            .unwrap();
107
108        if !response.status().is_success() {
109            return Err(APIError::EndpointError(response.text().await.unwrap()));
110        }
111
112        Ok(response.text().await.unwrap())
113    }
114
115    pub async fn post_stream<I, O>(
116        &self,
117        path: &str,
118        parameters: &I,
119    ) -> Pin<Box<dyn Stream<Item = Result<O, APIError>> + Send>>
120    where
121        I: Serialize,
122        O: DeserializeOwned + std::marker::Send + 'static,
123    {
124        let url = format!("{}{}", &self.base_url, path);
125
126        let event_source = self
127            .http_client
128            .post(url)
129            .header(reqwest::header::CONTENT_TYPE, "application/json")
130            .bearer_auth(&self.api_key)
131            .json(&parameters)
132            .eventsource()
133            .unwrap();
134
135        Pplx::process_stream::<O>(event_source).await
136    }
137
138    pub async fn process_stream<O>(
139        mut event_source: EventSource,
140    ) -> Pin<Box<dyn Stream<Item = Result<O, APIError>> + Send>>
141    where
142        O: DeserializeOwned + Send + 'static,
143    {
144        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
145
146        #[derive(serde::Deserialize)]
147        struct StreamErrorWrapper {
148            error: StreamError,
149        }
150
151        #[derive(serde::Deserialize)]
152        struct StreamError {
153            message: String,
154            #[serde(rename = "type")]
155            error_type: String,
156            _param: Option<serde_json::Value>,
157            _code: Option<u8>,
158        }
159
160        tokio::spawn(async move {
161            while let Some(event_result) = event_source.next().await {
162                match event_result {
163                    Ok(event) => match event {
164                        Event::Open => continue,
165                        Event::Message(message) => {
166                            if message.data == "[DONE]" {
167                                break;
168                            }
169
170                            let response = match serde_json::from_str::<O>(&message.data) {
171                                Ok(result) => Ok(result),
172                                Err(error) => {
173                                    // Try to parse an error message from the stream
174                                    match serde_json::from_str::<StreamErrorWrapper>(&message.data) {
175                                        Ok(error_wrapper) => Err(APIError::StreamError(format!("OpenAI {}: {}",  error_wrapper.error.error_type, error_wrapper.error.message))),
176                                        Err(_) => Err(APIError::StreamError(format!("OpenAI error parsing event stream: {}\nstream data: {}", error.to_string(), message.data))),
177                                    }
178                                }
179                            };
180
181                            if let Err(_error) = tx.send(response) {
182                                break;
183                            }
184                        }
185                    },
186                    Err(_error) => {
187                        // if let Err(_error) =
188                        // tx.send(Err(APIError::StreamError(error.
189                        // to_string()))) {
190                        //     break;
191                        // }
192                    }
193                }
194            }
195
196            event_source.close();
197        });
198
199        Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
200    }
201}
202
203pub struct Modelfarm {
204    pub http_client: reqwest::Client,
205    pub base_url: String,
206}
207
208impl Modelfarm {
209    pub fn new() -> Self {
210        Self {
211            http_client: reqwest::Client::new(),
212            base_url: MODELFARM_API_ENDPOINT.to_string(),
213        }
214    }
215
216    pub async fn get(&self, path: &str) -> Result<String, APIError> {
217        let url = format!("{}{}", &self.base_url, path);
218
219        let response = self
220            .http_client
221            .get(url)
222            .header(reqwest::header::CONTENT_TYPE, "application/json")
223            .send()
224            .await
225            .unwrap();
226
227        if response.status().is_server_error() {
228            return Err(APIError::EndpointError(response.text().await.unwrap()));
229        }
230
231        Ok(response.text().await.unwrap())
232    }
233
234    pub async fn post<T: Serialize>(&self, path: &str, parameters: &T) -> Result<String, APIError> {
235        let url = format!("{}{}", &self.base_url, path);
236
237        let response = self
238            .http_client
239            .post(url)
240            .header(reqwest::header::CONTENT_TYPE, "application/json")
241            .json(&parameters)
242            .send()
243            .await
244            .unwrap();
245
246        if !response.status().is_success() {
247            return Err(APIError::EndpointError(response.text().await.unwrap()));
248        }
249
250        Ok(response.text().await.unwrap())
251    }
252
253    pub async fn post_stream<I>(
254        &self,
255        path: &str,
256        parameters: &I,
257    ) -> Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>
258    where
259        I: Serialize,
260    {
261        let url = format!("{}{}", &self.base_url, path);
262
263        let request = self
264            .http_client
265            .post(url)
266            .header(reqwest::header::CONTENT_TYPE, "application/json")
267            .json(&parameters)
268            .build()
269            .unwrap();
270
271        Box::pin(
272            self.http_client
273                .execute(request)
274                .await
275                .map_err(anyhow::Error::from)
276                .unwrap()
277                .bytes_stream()
278                .map_err(anyhow::Error::from),
279        )
280    }
281}
282
283pub struct Ollama {
284    pub http_client: reqwest::Client,
285    pub base_url: String,
286}
287
288impl Ollama {
289    pub fn new() -> Self {
290        Self {
291            http_client: reqwest::Client::new(),
292            base_url: OLLAMA_API_ENDPOINT.to_string(),
293        }
294    }
295
296    pub async fn get(&self, path: &str) -> Result<String, APIError> {
297        let url = format!("{}{}", &self.base_url, path);
298
299        let response = self
300            .http_client
301            .get(url)
302            .header(reqwest::header::CONTENT_TYPE, "application/json")
303            .send()
304            .await
305            .unwrap();
306
307        if response.status().is_server_error() {
308            return Err(APIError::EndpointError(response.text().await.unwrap()));
309        }
310
311        let response_text = response.text().await.unwrap();
312
313        Ok(response_text)
314    }
315
316    pub async fn post<T: Serialize>(&self, path: &str, parameters: &T) -> Result<String, APIError> {
317        let url = format!("{}{}", &self.base_url, path);
318
319        let response = self
320            .http_client
321            .post(url)
322            .header(reqwest::header::CONTENT_TYPE, "application/json")
323            .json(&parameters)
324            .send()
325            .await
326            .unwrap();
327
328        if !response.status().is_success() {
329            return Err(APIError::EndpointError(response.text().await.unwrap()));
330        }
331
332        Ok(response.text().await.unwrap())
333    }
334
335    pub async fn post_stream<I>(
336        &self,
337        path: &str,
338        parameters: &I,
339    ) -> Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>
340    where
341        I: Serialize,
342    {
343        let url = format!("{}{}", &self.base_url, path);
344
345        let request = self
346            .http_client
347            .post(url)
348            .header(reqwest::header::CONTENT_TYPE, "application/json")
349            .json(&parameters)
350            .build()
351            .unwrap();
352
353        Box::pin(
354            self.http_client
355                .execute(request)
356                .await
357                .map_err(anyhow::Error::from)
358                .unwrap()
359                .bytes_stream()
360                .map_err(anyhow::Error::from),
361        )
362    }
363}