use std::{borrow::Cow, pin::Pin, time::Duration};

use futures_util::{stream::StreamExt, Stream};
use reqwest::{header, ClientBuilder, RequestBuilder, Response, StatusCode};
use serde::Deserialize;
use thiserror::Error as ThisError;
use tokenizers::Tokenizer;

use crate::{How, StreamJob};
use async_stream::stream;

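/// A self-contained request to the Aleph Alpha API: it knows how to build its HTTP
/// request (including which model should execute it) and how to turn the response body
/// into its output.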
pub trait Job {
    /// Output returned by [`Job::body_to_output`].
    type Output;

    /// Expected answer of the Aleph Alpha API, deserialized from the response body.
    type ResponseBody: for<'de> Deserialize<'de>;

    /// Build an HTTP request for a specific client and host.
    fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder;

    /// Parse the response body into the output of the job.
    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;
}

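/// Describes what should be executed, but not which model should execute it. Combine a
/// task with a model name via [`Task::with_model`] to obtain an executable [`Job`].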
pub trait Task {
    /// Output returned by [`Task::body_to_output`].
    type Output;

    /// Expected answer of the Aleph Alpha API, deserialized from the response body.
    type ResponseBody: for<'de> Deserialize<'de>;

    /// Build an HTTP request for a specific client, host and model.
    fn build_request(&self, client: &reqwest::Client, base: &str, model: &str) -> RequestBuilder;

    /// Parse the response body into the output of the task.
    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;

    /// Turn this task into a [`Job`] by fixing the model which should execute it.
    fn with_model<'a>(&'a self, model: &'a str) -> MethodJob<'a, Self>
    where
        Self: Sized,
    {
        MethodJob { model, task: self }
    }
}

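/// A [`Task`] paired with the name of the model that should execute it, which makes it a
/// complete [`Job`]. Usually created via [`Task::with_model`].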
pub struct MethodJob<'a, T> {
    /// Name of the model in charge of executing the task.
    pub model: &'a str,
    /// Task to be executed by the model.
    pub task: &'a T,
}

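// A task together with a model name is a complete job: the stored model is filled in
// when building the request, everything else is forwarded to the underlying task.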
impl<T> Job for MethodJob<'_, T>
where
    T: Task,
{
    type Output = T::Output;

    type ResponseBody = T::ResponseBody;

    fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder {
        self.task.build_request(client, base, self.model)
    }

    fn body_to_output(&self, response: T::ResponseBody) -> T::Output {
        self.task.body_to_output(response)
    }
}

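/// Sends the actual HTTP requests to the Aleph Alpha API and deserializes the responses.
/// Stores the base URL of the API, the underlying [`reqwest::Client`] and an optional
/// API token which is used whenever a request does not provide its own.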
pub struct HttpClient {
    base: String,
    http: reqwest::Client,
    api_token: Option<String>,
}

impl HttpClient {
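    /// Create a new client for the given host. The API token is optional here, but if it
    /// is not set, every request must provide one via [`How`]; otherwise sending panics.
    ///
    /// A minimal construction sketch (host URL and token are placeholders):
    ///
    /// ```ignore
    /// let client = HttpClient::new(
    ///     "https://inference-api.example.com".to_owned(),
    ///     Some("YOUR_API_TOKEN".to_owned()),
    /// )?;
    /// ```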
    pub fn new(host: String, api_token: Option<String>) -> Result<Self, Error> {
        let http = ClientBuilder::new().build()?;

        Ok(Self {
            base: host,
            http,
            api_token,
        })
    }

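    /// Send a prepared request, adding the `nice` query parameter, the authorization
    /// header and the client-side timeout from [`How`], and translate HTTP errors into
    /// [`Error`] variants.
    ///
    /// Panics if no API token is set, neither on the client nor in `how`.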
    async fn response(&self, builder: RequestBuilder, how: &How) -> Result<Response, Error> {
        let query = if how.be_nice {
            [("nice", "true")].as_slice()
        } else {
            // Omit the parameter entirely if `be_nice` is not requested.
            [].as_slice()
        };

        // A token provided per request takes precedence over the token set on the client.
        let api_token = how
            .api_token
            .as_ref()
            .or(self.api_token.as_ref())
            .expect("API token needs to be set on client construction or per request");
        let response = builder
            .query(query)
            .header(header::AUTHORIZATION, Self::header_from_token(api_token))
            .timeout(how.client_timeout)
            .send()
            .await
            .map_err(|reqwest_error| {
                if reqwest_error.is_timeout() {
                    Error::ClientTimeout(how.client_timeout)
                } else {
                    reqwest_error.into()
                }
            })?;
        translate_http_error(response).await
    }

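    /// Execute a [`Job`] against the API and return its typed output.
    ///
    /// A minimal usage sketch; `host`, `api_token`, `task`, the model name and
    /// `How::default()` are illustrative placeholders:
    ///
    /// ```ignore
    /// let client = HttpClient::new(host, Some(api_token))?;
    /// let job = task.with_model("model-name");
    /// let output = client.output_of(&job, &How::default()).await?;
    /// ```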
    pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> {
        let builder = task.build_request(&self.http, &self.base);
        let response = self.response(builder, how).await?;
        let response_body: T::ResponseBody = response.json().await?;
        let answer = task.body_to_output(response_body);
        Ok(answer)
    }

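    /// Execute a [`StreamJob`] against the API. Returns a stream which yields one typed
    /// output per Server-Sent Event emitted by the API, or an [`Error`] per failed chunk.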
    pub async fn stream_output_of<T: StreamJob>(
        &self,
        task: &T,
        how: &How,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<T::Output, Error>> + Send>>, Error>
    where
        T::Output: 'static,
    {
        let builder = task.build_request(&self.http, &self.base);
        let response = self.response(builder, how).await?;
        let mut stream = response.bytes_stream();

        Ok(Box::pin(stream! {
            while let Some(item) = stream.next().await {
                match item {
                    Ok(bytes) => {
                        // A single chunk of the byte stream may contain more than one
                        // Server-Sent Event, so parse and yield each of them.
                        let events = Self::parse_stream_event::<T::ResponseBody>(bytes.as_ref());
                        for event in events {
                            yield event.map(|b| T::body_to_output(b));
                        }
                    }
                    Err(e) => {
                        yield Err(e.into());
                    }
                }
            }
        }))
    }

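    /// Split a raw chunk of the response byte stream into its `data: ` payloads and
    /// deserialize each of them. A single chunk may contain several events.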
    fn parse_stream_event<StreamBody>(bytes: &[u8]) -> Vec<Result<StreamBody, Error>>
    where
        StreamBody: for<'de> Deserialize<'de>,
    {
        String::from_utf8_lossy(bytes)
            .split("data: ")
            // The first split item is everything before the first "data: " prefix, i.e. empty.
            .skip(1)
            .map(|s| {
                serde_json::from_str(s).map_err(|e| Error::InvalidStream {
                    deserialization_error: e.to_string(),
                })
            })
            .collect()
    }

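    /// Build the `Authorization: Bearer` header value for the given token.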
    fn header_from_token(api_token: &str) -> header::HeaderValue {
        let mut auth_value = header::HeaderValue::from_str(&format!("Bearer {api_token}")).unwrap();
        // Mark the header as sensitive so the token is not leaked in debug output.
        auth_value.set_sensitive(true);
        auth_value
    }

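    /// Fetch the tokenizer used by the given model from the API and deserialize it with
    /// the `tokenizers` crate.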
    pub async fn tokenizer_by_model(
        &self,
        model: &str,
        api_token: Option<String>,
    ) -> Result<Tokenizer, Error> {
        // A token passed to this method takes precedence over the token set on the client.
        let api_token = api_token
            .as_ref()
            .or(self.api_token.as_ref())
            .expect("API token needs to be set on client construction or per request");
        let response = self
            .http
            .get(format!("{}/models/{model}/tokenizer", self.base))
            .header(header::AUTHORIZATION, Self::header_from_token(api_token))
            .send()
            .await?;
        let response = translate_http_error(response).await?;
        let bytes = response.bytes().await?;
        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| Error::InvalidTokenizer {
            deserialization_error: e.to_string(),
        })?;
        Ok(tokenizer)
    }
}

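/// Translate non-success HTTP responses into the matching [`Error`] variant. Successful
/// responses are passed through unchanged.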
async fn translate_http_error(response: reqwest::Response) -> Result<reqwest::Response, Error> {
    let status = response.status();
    if !status.is_success() {
        // Read the body even if it cannot be parsed as an `ApiError`; the raw text is
        // still useful for the generic `Http` error variant.
        let body = response.text().await?;
        let api_error: Result<ApiError, _> = serde_json::from_str(&body);
        let translated_error = match status {
            StatusCode::TOO_MANY_REQUESTS => Error::TooManyRequests,
            StatusCode::SERVICE_UNAVAILABLE => {
                // A service unavailable response with a `QUEUE_FULL` code means the model
                // is busy rather than down.
                if api_error.is_ok_and(|error| error.code == "QUEUE_FULL") {
                    Error::Busy
                } else {
                    Error::Unavailable
                }
            }
            _ => Error::Http {
                status: status.as_u16(),
                body,
            },
        };
        Err(translated_error)
    } else {
        Ok(response)
    }
}

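/// Error body returned by the Aleph Alpha API for non-success status codes.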
#[derive(Deserialize, Debug)]
struct ApiError<'a> {
    /// Error code as returned by the API, e.g. `"QUEUE_FULL"`.
    code: Cow<'a, str>,
}

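/// Errors returned by [`HttpClient`]: translations of API status codes as well as local
/// failures such as client-side timeouts and deserialization errors.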
#[derive(ThisError, Debug)]
pub enum Error {
    #[error(
        "You are trying to send too many requests to the API in too short an interval. Slow down a \
        bit, otherwise these errors will persist. Sorry for this, but we try to prevent DoS attacks."
    )]
    TooManyRequests,
    #[error(
        "Sorry, the request to the Aleph Alpha API has been rejected because the requested model \
        is very busy at the moment. We found it unlikely that your request would finish in a \
        reasonable timeframe, so it was rejected right away rather than making you wait. You are \
        welcome to retry your request any time."
    )]
    Busy,
    #[error(
        "The service is currently unavailable. This is likely due to a restart. Please try again \
        later."
    )]
    Unavailable,
    #[error("No response received within given timeout: {0:?}")]
    ClientTimeout(Duration),
    #[error("HTTP request failed with status code {}. Body:\n{}", status, body)]
    Http { status: u16, body: String },
    #[error(
        "Tokenizer could not be correctly deserialized. Caused by:\n{}",
        deserialization_error
    )]
    InvalidTokenizer { deserialization_error: String },
    #[error(
        "Stream event could not be correctly deserialized. Caused by:\n{}",
        deserialization_error
    )]
    InvalidStream { deserialization_error: String },
    #[error(transparent)]
    Other(#[from] reqwest::Error),
}

#[cfg(test)]
mod tests {
    use crate::{chat::ChatEvent, completion::CompletionEvent};

    use super::*;

    #[test]
    fn stream_chunk_event_is_parsed() {
        let bytes = b"data: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\" The New York Times, May 15\"}\n\n";

        let events = HttpClient::parse_stream_event::<CompletionEvent>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        match event {
            CompletionEvent::StreamChunk(chunk) => assert_eq!(chunk.index, 0),
            _ => panic!("Expected a stream chunk"),
        }
    }

    #[test]
    fn completion_summary_event_is_parsed() {
        let bytes = b"data: {\"type\":\"stream_summary\",\"index\":0,\"model_version\":\"2022-04\",\"finish_reason\":\"maximum_tokens\"}\n\ndata: {\"type\":\"completion_summary\",\"num_tokens_prompt_total\":1,\"num_tokens_generated\":7}\n\n";

        let events = HttpClient::parse_stream_event::<CompletionEvent>(bytes);

        let first = events.first().unwrap().as_ref().unwrap();
        match first {
            CompletionEvent::StreamSummary(summary) => {
                assert_eq!(summary.finish_reason, "maximum_tokens")
            }
            _ => panic!("Expected a stream summary"),
        }
        let second = events.last().unwrap().as_ref().unwrap();
        match second {
            CompletionEvent::CompletionSummary(summary) => {
                assert_eq!(summary.num_tokens_generated, 7)
            }
            _ => panic!("Expected a completion summary"),
        }
    }

    #[test]
    fn chat_stream_chunk_event_is_parsed() {
        let bytes = b"data: {\"id\":\"831e41b4-2382-4b08-990e-0a3859967f43\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"logprobs\":null}],\"created\":1729782822,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        let events = HttpClient::parse_stream_event::<ChatEvent>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        assert_eq!(event.choices[0].delta.role.as_ref().unwrap(), "assistant");
    }

    #[test]
    fn chat_stream_chunk_without_role_is_parsed() {
        let bytes = b"data: {\"id\":\"a3ceca7f-32b2-4a6c-89e7-bc8eb5327f76\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"content\":\"Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.\"},\"logprobs\":null}],\"created\":1729784197,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        let events = HttpClient::parse_stream_event::<ChatEvent>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        assert_eq!(event.choices[0].delta.content, "Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.");
    }
}