1use std::pin::Pin;
2
3use futures::stream::StreamExt;
4use futures::{Stream, TryStreamExt};
5use reqwest::multipart::Form;
6use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
7use serde::de::DeserializeOwned;
8use serde::Serialize;
9use tokio_util::bytes::Bytes;
10
11use crate::error::APIError;
12
/// Default base URL of a locally running Ollama server (used by `Ollama::new`).
const OLLAMA_API_ENDPOINT: &str = "http://localhost:11434";
/// Default base URL of the Replit Modelfarm proxy (used by `Modelfarm::new`).
const MODELFARM_API_ENDPOINT: &str = "https://proxy-modelfarm.replit.app";
/// Default base URL of the Perplexity API (used by `Pplx::new`).
const PPLX_API_ENDPOINT: &str = "https://api.perplexity.ai";
16
17pub struct Pplx {
18 pub http_client: reqwest::Client,
19 pub base_url: String,
20 pub api_key: String,
21}
22
23impl Pplx {
24 pub fn new(api_key: String) -> Self {
25 Self {
26 http_client: reqwest::Client::new(),
27 base_url: PPLX_API_ENDPOINT.to_string(),
28 api_key,
29 }
30 }
31
32 pub async fn get(&self, path: &str) -> Result<String, APIError> {
33 let url = format!("{}{}", &self.base_url, path);
34
35 let response = self
36 .http_client
37 .get(url)
38 .header(reqwest::header::CONTENT_TYPE, "application/json")
39 .bearer_auth(&self.api_key)
40 .send()
41 .await
42 .unwrap();
43
44 if response.status().is_server_error() {
45 return Err(APIError::EndpointError(response.text().await.unwrap()));
46 }
47
48 let response_text = response.text().await.unwrap();
49
50 #[cfg(feature = "log")]
51 log::trace!("{}", response_text);
52
53 Ok(response_text)
54 }
55
56 pub async fn post<T: Serialize>(&self, path: &str, parameters: &T) -> Result<String, APIError> {
57 let url = format!("{}{}", &self.base_url, path);
58
59 let response =
60 self.http_client
61 .post(url)
62 .header(reqwest::header::CONTENT_TYPE, "application/json")
63 .bearer_auth(&self.api_key)
64 .json(¶meters)
65 .send()
66 .await
67 .unwrap();
68
69 if !response.status().is_success() {
70 return Err(APIError::EndpointError(response.text().await.unwrap()));
71 }
72
73 Ok(response.text().await.unwrap())
74 }
75
76 pub async fn delete(&self, path: &str) -> Result<String, APIError> {
77 let url = format!("{}{}", &self.base_url, path);
78
79 let response = self
80 .http_client
81 .delete(url)
82 .header(reqwest::header::CONTENT_TYPE, "application/json")
83 .bearer_auth(&self.api_key)
84 .send()
85 .await
86 .unwrap();
87
88 if response.status().is_server_error() {
89 return Err(APIError::EndpointError(response.text().await.unwrap()));
90 }
91
92 Ok(response.text().await.unwrap())
93 }
94
95 pub async fn post_with_form(&self, path: &str, form: Form) -> Result<String, APIError> {
96 let url = format!("{}{}", &self.base_url, path);
97
98 let response = self
99 .http_client
100 .post(url)
101 .bearer_auth(&self.api_key)
103 .multipart(form)
104 .send()
105 .await
106 .unwrap();
107
108 if !response.status().is_success() {
109 return Err(APIError::EndpointError(response.text().await.unwrap()));
110 }
111
112 Ok(response.text().await.unwrap())
113 }
114
115 pub async fn post_stream<I, O>(
116 &self,
117 path: &str,
118 parameters: &I,
119 ) -> Pin<Box<dyn Stream<Item = Result<O, APIError>> + Send>>
120 where
121 I: Serialize,
122 O: DeserializeOwned + std::marker::Send + 'static,
123 {
124 let url = format!("{}{}", &self.base_url, path);
125
126 let event_source = self
127 .http_client
128 .post(url)
129 .header(reqwest::header::CONTENT_TYPE, "application/json")
130 .bearer_auth(&self.api_key)
131 .json(¶meters)
132 .eventsource()
133 .unwrap();
134
135 Pplx::process_stream::<O>(event_source).await
136 }
137
138 pub async fn process_stream<O>(
139 mut event_source: EventSource,
140 ) -> Pin<Box<dyn Stream<Item = Result<O, APIError>> + Send>>
141 where
142 O: DeserializeOwned + Send + 'static,
143 {
144 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
145
146 #[derive(serde::Deserialize)]
147 struct StreamErrorWrapper {
148 error: StreamError,
149 }
150
151 #[derive(serde::Deserialize)]
152 struct StreamError {
153 message: String,
154 #[serde(rename = "type")]
155 error_type: String,
156 _param: Option<serde_json::Value>,
157 _code: Option<u8>,
158 }
159
160 tokio::spawn(async move {
161 while let Some(event_result) = event_source.next().await {
162 match event_result {
163 Ok(event) => match event {
164 Event::Open => continue,
165 Event::Message(message) => {
166 if message.data == "[DONE]" {
167 break;
168 }
169
170 let response = match serde_json::from_str::<O>(&message.data) {
171 Ok(result) => Ok(result),
172 Err(error) => {
173 match serde_json::from_str::<StreamErrorWrapper>(&message.data) {
175 Ok(error_wrapper) => Err(APIError::StreamError(format!("OpenAI {}: {}", error_wrapper.error.error_type, error_wrapper.error.message))),
176 Err(_) => Err(APIError::StreamError(format!("OpenAI error parsing event stream: {}\nstream data: {}", error.to_string(), message.data))),
177 }
178 }
179 };
180
181 if let Err(_error) = tx.send(response) {
182 break;
183 }
184 }
185 },
186 Err(_error) => {
187 }
193 }
194 }
195
196 event_source.close();
197 });
198
199 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
200 }
201}
202
203pub struct Modelfarm {
204 pub http_client: reqwest::Client,
205 pub base_url: String,
206}
207
208impl Modelfarm {
209 pub fn new() -> Self {
210 Self {
211 http_client: reqwest::Client::new(),
212 base_url: MODELFARM_API_ENDPOINT.to_string(),
213 }
214 }
215
216 pub async fn get(&self, path: &str) -> Result<String, APIError> {
217 let url = format!("{}{}", &self.base_url, path);
218
219 let response = self
220 .http_client
221 .get(url)
222 .header(reqwest::header::CONTENT_TYPE, "application/json")
223 .send()
224 .await
225 .unwrap();
226
227 if response.status().is_server_error() {
228 return Err(APIError::EndpointError(response.text().await.unwrap()));
229 }
230
231 Ok(response.text().await.unwrap())
232 }
233
234 pub async fn post<T: Serialize>(&self, path: &str, parameters: &T) -> Result<String, APIError> {
235 let url = format!("{}{}", &self.base_url, path);
236
237 let response = self
238 .http_client
239 .post(url)
240 .header(reqwest::header::CONTENT_TYPE, "application/json")
241 .json(¶meters)
242 .send()
243 .await
244 .unwrap();
245
246 if !response.status().is_success() {
247 return Err(APIError::EndpointError(response.text().await.unwrap()));
248 }
249
250 Ok(response.text().await.unwrap())
251 }
252
253 pub async fn post_stream<I>(
254 &self,
255 path: &str,
256 parameters: &I,
257 ) -> Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>
258 where
259 I: Serialize,
260 {
261 let url = format!("{}{}", &self.base_url, path);
262
263 let request = self
264 .http_client
265 .post(url)
266 .header(reqwest::header::CONTENT_TYPE, "application/json")
267 .json(¶meters)
268 .build()
269 .unwrap();
270
271 Box::pin(
272 self.http_client
273 .execute(request)
274 .await
275 .map_err(anyhow::Error::from)
276 .unwrap()
277 .bytes_stream()
278 .map_err(anyhow::Error::from),
279 )
280 }
281}
282
283pub struct Ollama {
284 pub http_client: reqwest::Client,
285 pub base_url: String,
286}
287
288impl Ollama {
289 pub fn new() -> Self {
290 Self {
291 http_client: reqwest::Client::new(),
292 base_url: OLLAMA_API_ENDPOINT.to_string(),
293 }
294 }
295
296 pub async fn get(&self, path: &str) -> Result<String, APIError> {
297 let url = format!("{}{}", &self.base_url, path);
298
299 let response = self
300 .http_client
301 .get(url)
302 .header(reqwest::header::CONTENT_TYPE, "application/json")
303 .send()
304 .await
305 .unwrap();
306
307 if response.status().is_server_error() {
308 return Err(APIError::EndpointError(response.text().await.unwrap()));
309 }
310
311 let response_text = response.text().await.unwrap();
312
313 Ok(response_text)
314 }
315
316 pub async fn post<T: Serialize>(&self, path: &str, parameters: &T) -> Result<String, APIError> {
317 let url = format!("{}{}", &self.base_url, path);
318
319 let response = self
320 .http_client
321 .post(url)
322 .header(reqwest::header::CONTENT_TYPE, "application/json")
323 .json(¶meters)
324 .send()
325 .await
326 .unwrap();
327
328 if !response.status().is_success() {
329 return Err(APIError::EndpointError(response.text().await.unwrap()));
330 }
331
332 Ok(response.text().await.unwrap())
333 }
334
335 pub async fn post_stream<I>(
336 &self,
337 path: &str,
338 parameters: &I,
339 ) -> Pin<Box<dyn Stream<Item = anyhow::Result<Bytes>> + Send>>
340 where
341 I: Serialize,
342 {
343 let url = format!("{}{}", &self.base_url, path);
344
345 let request = self
346 .http_client
347 .post(url)
348 .header(reqwest::header::CONTENT_TYPE, "application/json")
349 .json(¶meters)
350 .build()
351 .unwrap();
352
353 Box::pin(
354 self.http_client
355 .execute(request)
356 .await
357 .map_err(anyhow::Error::from)
358 .unwrap()
359 .bytes_stream()
360 .map_err(anyhow::Error::from),
361 )
362 }
363}