use std::{borrow::Cow, pin::Pin, time::Duration};

use futures_util::{stream::StreamExt, Stream};
use reqwest::{header, ClientBuilder, RequestBuilder, Response, StatusCode};
use serde::Deserialize;
use thiserror::Error as ThisError;
use tokenizers::Tokenizer;

use crate::{How, StreamJob};
use async_stream::stream;

/// A job sent to the Aleph Alpha API. Usually you do not implement this yourself, but use a
/// [`Task`] together with [`Task::with_model`].
pub trait Job {
    /// Output returned by [`HttpClient::output_of`].
    type Output;

    /// Expected answer body of the Aleph Alpha API.
    type ResponseBody: for<'de> Deserialize<'de>;

    /// Prepare the request for the API. Authentication headers can be assumed to be set later.
    fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder;

    /// Parse the response body of the server into the output handed to the user.
    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;
}

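/// A task describes a payload that can be executed against any model. Binding it to a model name
/// via [`Task::with_model`] yields a [`Job`] that [`HttpClient::output_of`] can execute.
///
/// A minimal sketch (the concrete task type and the model name are placeholders; any type
/// implementing [`Task`] works):
///
/// ```ignore
/// fn bind<T: Task>(task: &T) -> MethodJob<'_, T> {
///     // Attach a model name to the task; the returned `MethodJob` implements `Job`.
///     task.with_model("pharia-1-llm-7b-control")
/// }
/// ```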
pub trait Task {
    /// Output returned by [`HttpClient::output_of`].
    type Output;

    /// Expected answer body of the Aleph Alpha API.
    type ResponseBody: for<'de> Deserialize<'de>;

    /// Prepare the request for the API, given the model the task should run on.
    fn build_request(&self, client: &reqwest::Client, base: &str, model: &str) -> RequestBuilder;

    /// Parse the response body of the server into the output handed to the user.
    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;

    /// Bind this task to a concrete model, turning it into a [`Job`].
    fn with_model<'a>(&'a self, model: &'a str) -> MethodJob<'a, Self>
    where
        Self: Sized,
    {
        MethodJob { model, task: self }
    }
}

/// A [`Task`] bound to a specific model name, which makes it a [`Job`].
pub struct MethodJob<'a, T> {
    /// Name of the model the task is executed with.
    pub model: &'a str,
    /// Task to be executed against the model.
    pub task: &'a T,
}

impl<T> Job for MethodJob<'_, T>
where
    T: Task,
{
    type Output = T::Output;

    type ResponseBody = T::ResponseBody;

    fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder {
        self.task.build_request(client, base, self.model)
    }

    fn body_to_output(&self, response: T::ResponseBody) -> T::Output {
        self.task.body_to_output(response)
    }
}

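/// Low-level client sending raw HTTP requests to the Aleph Alpha API. A construction sketch (the
/// host URL is a placeholder, and the environment variable is only an assumption for illustration):
///
/// ```ignore
/// // The token may also be omitted here and supplied per request via `How` instead.
/// let token = std::env::var("AA_API_TOKEN").ok();
/// let client = HttpClient::new("https://api.aleph-alpha.com".to_owned(), token)?;
/// ```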
pub struct HttpClient {
    base: String,
    http: reqwest::Client,
    api_token: Option<String>,
}

impl HttpClient {
    /// Create a new client talking to the API at `host`. If `api_token` is `None`, a token must
    /// be provided with each request via [`How`].
    pub fn new(host: String, api_token: Option<String>) -> Result<Self, Error> {
        let http = ClientBuilder::new().build()?;

        Ok(Self {
            base: host,
            http,
            api_token,
        })
    }

    /// Execute the prepared request, attaching authorization, the `nice` query parameter and the
    /// client timeout, and translate HTTP error statuses into [`Error`] variants.
    async fn response(&self, builder: RequestBuilder, how: &How) -> Result<Response, Error> {
        let query = if how.be_nice {
            [("nice", "true")].as_slice()
        } else {
            // nice=false is the default, so the parameter is simply omitted.
            [].as_slice()
        };

        let api_token = how
            .api_token
            .as_ref()
            .or(self.api_token.as_ref())
            .expect("API token needs to be set on client construction or per request");
        let response = builder
            .query(query)
            .header(header::AUTHORIZATION, Self::header_from_token(api_token))
            .timeout(how.client_timeout)
            .send()
            .await
            .map_err(|reqwest_error| {
                if reqwest_error.is_timeout() {
                    Error::ClientTimeout(how.client_timeout)
                } else {
                    reqwest_error.into()
                }
            })?;
        translate_http_error(response).await
    }

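    /// Execute a [`Job`] against the API and parse its response into the job's output type. A
    /// usage sketch (assuming some value `task` implementing [`Task`] and that [`How`] provides a
    /// default configuration):
    ///
    /// ```ignore
    /// let job = task.with_model("pharia-1-llm-7b-control");
    /// let output = client.output_of(&job, &How::default()).await?;
    /// ```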
    pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> {
        let builder = task.build_request(&self.http, &self.base);
        let response = self.response(builder, how).await?;
        let response_body: T::ResponseBody = response.json().await?;
        let answer = task.body_to_output(response_body);
        Ok(answer)
    }

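    /// Execute a streaming job and return its individual events as an asynchronous stream. A
    /// consumption sketch (assuming some streaming `task` and that `futures_util::StreamExt` is
    /// in scope, as it is in this module):
    ///
    /// ```ignore
    /// let mut events = client.stream_output_of(task, &How::default()).await?;
    /// while let Some(event) = events.next().await {
    ///     let event = event?;
    ///     // handle the event
    /// }
    /// ```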
    pub async fn stream_output_of<'task, T: StreamJob + Send + Sync + 'task>(
        &self,
        task: T,
        how: &How,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<T::Output, Error>> + Send + 'task>>, Error>
    where
        T::Output: 'static,
    {
        let builder = task.build_request(&self.http, &self.base);
        let response = self.response(builder, how).await?;
        let mut stream = response.bytes_stream();

        Ok(Box::pin(stream! {
            while let Some(item) = stream.next().await {
                match item {
                    Ok(bytes) => {
                        // A single HTTP chunk may carry several server-sent events.
                        let events = Self::parse_stream_event::<T::ResponseBody>(bytes.as_ref());
                        for event in events {
                            match event {
                                Ok(b) => {
                                    // `body_to_output` returns `None` for events which carry no
                                    // output for the caller; those are skipped.
                                    if let Some(output) = task.body_to_output(b) {
                                        yield Ok(output);
                                    }
                                }
                                Err(e) => {
                                    yield Err(e);
                                }
                            }
                        }
                    }
                    Err(e) => {
                        yield Err(e.into());
                    }
                }
            }
        }))
    }

    /// Split a chunk of the server-sent event stream into its individual events and deserialize
    /// them. Each event is prefixed with `data: `; the terminating `[DONE]` marker is skipped.
    fn parse_stream_event<StreamBody>(bytes: &[u8]) -> Vec<Result<StreamBody, Error>>
    where
        StreamBody: for<'de> Deserialize<'de>,
    {
        String::from_utf8_lossy(bytes)
            .split("data: ")
            .skip(1)
            .filter(|s| s.trim() != "[DONE]")
            .map(|s| {
                serde_json::from_str(s).map_err(|e| Error::InvalidStream {
                    deserialization_error: e.to_string(),
                })
            })
            .collect()
    }

    fn header_from_token(api_token: &str) -> header::HeaderValue {
        let mut auth_value = header::HeaderValue::from_str(&format!("Bearer {api_token}")).unwrap();
        // Mark the header as sensitive so the token is not exposed in debug output.
        auth_value.set_sensitive(true);
        auth_value
    }

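    /// Download the tokenizer used by the given model. A usage sketch (the model name is a
    /// placeholder; with `None`, the token set at construction is used):
    ///
    /// ```ignore
    /// let tokenizer = client
    ///     .tokenizer_by_model("pharia-1-llm-7b-control", None)
    ///     .await?;
    /// ```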
    pub async fn tokenizer_by_model(
        &self,
        model: &str,
        api_token: Option<String>,
    ) -> Result<Tokenizer, Error> {
        let api_token = api_token
            .as_ref()
            .or(self.api_token.as_ref())
            .expect("API token needs to be set on client construction or per request");
        let response = self
            .http
            .get(format!("{}/models/{model}/tokenizer", self.base))
            .header(header::AUTHORIZATION, Self::header_from_token(api_token))
            .send()
            .await?;
        let response = translate_http_error(response).await?;
        let bytes = response.bytes().await?;
        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| Error::InvalidTokenizer {
            deserialization_error: e.to_string(),
        })?;
        Ok(tokenizer)
    }
}

/// Convert non-success HTTP statuses into the matching [`Error`] variants, passing successful
/// responses through unchanged.
async fn translate_http_error(response: reqwest::Response) -> Result<reqwest::Response, Error> {
    let status = response.status();
    if !status.is_success() {
        // Read the body first, so it can be reported in the generic `Http` error even if it does
        // not parse as an `ApiError`.
        let body = response.text().await?;
        let api_error: Result<ApiError, _> = serde_json::from_str(&body);
        let translated_error = match status {
            StatusCode::TOO_MANY_REQUESTS => Error::TooManyRequests,
            StatusCode::SERVICE_UNAVAILABLE => {
                // A `QUEUE_FULL` error code means the model is busy rather than down.
                if api_error.is_ok_and(|error| error.code == "QUEUE_FULL") {
                    Error::Busy
                } else {
                    Error::Unavailable
                }
            }
            _ => Error::Http {
                status: status.as_u16(),
                body,
            },
        };
        Err(translated_error)
    } else {
        Ok(response)
    }
}

/// Error body returned by the Aleph Alpha API.
#[derive(Deserialize, Debug)]
struct ApiError<'a> {
    /// Error code emitted by the API, e.g. `QUEUE_FULL`, allowing for a finer distinction than
    /// the HTTP status code alone.
    code: Cow<'a, str>,
}

#[derive(ThisError, Debug)]
pub enum Error {
    #[error(
        "You are trying to send too many requests to the API in too short an interval. Slow down \
        a bit, otherwise these errors will persist. Sorry for this, but we try to prevent DoS \
        attacks."
    )]
    TooManyRequests,
    #[error(
        "Sorry, the request to the Aleph Alpha API has been rejected due to the requested model \
        being very busy at the moment. We found it unlikely that your request would finish in a \
        reasonable timeframe, so it was rejected right away, rather than make you wait. You are \
        welcome to retry your request any time."
    )]
    Busy,
    #[error(
        "The service is currently unavailable. This is likely due to a restart. Please try again \
        later."
    )]
    Unavailable,
    #[error("No response received within given timeout: {0:?}")]
    ClientTimeout(Duration),
    #[error("HTTP request failed with status code {}. Body:\n{}", status, body)]
    Http { status: u16, body: String },
    #[error(
        "Tokenizer could not be correctly deserialized. Caused by:\n{}",
        deserialization_error
    )]
    InvalidTokenizer { deserialization_error: String },
    #[error(
        "Stream event could not be correctly deserialized. Caused by:\n{}",
        deserialization_error
    )]
    InvalidStream { deserialization_error: String },
    #[error(transparent)]
    Other(#[from] reqwest::Error),
}

#[cfg(test)]
mod tests {
    use crate::{
        chat::{DeserializedChatChunk, StreamChatResponse, StreamMessage},
        completion::DeserializedCompletionEvent,
    };

    use super::*;

    #[test]
    fn stream_chunk_event_is_parsed() {
        let bytes = b"data: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\" The New York Times, May 15\"}\n\n";

        let events = HttpClient::parse_stream_event::<DeserializedCompletionEvent>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        match event {
            DeserializedCompletionEvent::StreamChunk { completion, .. } => {
                assert_eq!(completion, " The New York Times, May 15")
            }
            _ => panic!("Expected a stream chunk"),
        }
    }

    #[test]
    fn completion_summary_event_is_parsed() {
        let bytes = b"data: {\"type\":\"stream_summary\",\"index\":0,\"model_version\":\"2022-04\",\"finish_reason\":\"maximum_tokens\"}\n\ndata: {\"type\":\"completion_summary\",\"num_tokens_prompt_total\":1,\"num_tokens_generated\":7}\n\n";

        let events = HttpClient::parse_stream_event::<DeserializedCompletionEvent>(bytes);

        let first = events.first().unwrap().as_ref().unwrap();
        match first {
            DeserializedCompletionEvent::StreamSummary { finish_reason } => {
                assert_eq!(finish_reason, "maximum_tokens")
            }
            _ => panic!("Expected a stream summary"),
        }
        let second = events.last().unwrap().as_ref().unwrap();
        match second {
            DeserializedCompletionEvent::CompletionSummary {
                num_tokens_generated,
                ..
            } => {
                assert_eq!(*num_tokens_generated, 7)
            }
            _ => panic!("Expected a completion summary"),
        }
    }

    #[test]
    fn chat_usage_event_is_parsed() {
        let bytes = b"data: {\"id\": \"67c5b5f2-6672-4b0b-82b1-cc844127b214\",\"choices\": [],\"created\": 1739539146,\"model\": \"pharia-1-llm-7b-control\",\"system_fingerprint\": \".unknown.\",\"object\": \"chat.completion.chunk\",\"usage\": {\"prompt_tokens\": 20,\"completion_tokens\": 10,\"total_tokens\": 30}}";

        let events = HttpClient::parse_stream_event::<StreamChatResponse>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        assert_eq!(event.usage.as_ref().unwrap().prompt_tokens, 20);
        assert_eq!(event.usage.as_ref().unwrap().completion_tokens, 10);
    }

    #[test]
    fn chat_stream_chunk_event_is_parsed() {
        let bytes = b"data: {\"id\":\"831e41b4-2382-4b08-990e-0a3859967f43\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"logprobs\":null}],\"created\":1729782822,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        let events = HttpClient::parse_stream_event::<StreamChatResponse>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        assert_eq!(event.choices.len(), 1);
        assert!(
            matches!(&event.choices[0], DeserializedChatChunk::Delta { delta: StreamMessage { role: Some(role), .. }, .. } if role == "assistant")
        );
    }

    #[test]
    fn chat_stream_chunk_without_role_is_parsed() {
        let bytes = b"data: {\"id\":\"a3ceca7f-32b2-4a6c-89e7-bc8eb5327f76\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"content\":\"Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.\"},\"logprobs\":null}],\"created\":1729784197,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        let events = HttpClient::parse_stream_event::<StreamChatResponse>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        assert_eq!(event.choices.len(), 1);
        assert!(
            matches!(&event.choices[0], DeserializedChatChunk::Delta { delta: StreamMessage { content, .. }, .. } if content == "Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.")
        );
    }

    #[test]
    fn chat_stream_chunk_without_content_but_with_finish_reason_is_parsed() {
        let bytes = b"data: {\"id\":\"a3ceca7f-32b2-4a6c-89e7-bc8eb5327f76\",\"choices\":[{\"finish_reason\":\"stop\",\"index\":0,\"delta\":{},\"logprobs\":null}],\"created\":1729784197,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        let events = HttpClient::parse_stream_event::<StreamChatResponse>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        assert!(
            matches!(&event.choices[0], DeserializedChatChunk::Finished { finish_reason } if finish_reason == "stop")
        );
    }
}