1use std::{borrow::Cow, pin::Pin, time::Duration};
2
3use bytes::Bytes;
4use eventsource_stream::Eventsource;
5use futures_util::{stream::StreamExt, Stream};
6use reqwest::{header, ClientBuilder, RequestBuilder, Response, StatusCode};
7use serde::Deserialize;
8use thiserror::Error as ThisError;
9use tokenizers::Tokenizer;
10
11use crate::{How, StreamJob, TraceContext};
12use async_stream::stream;
13
14pub trait Job {
22 type Output;
24
25 type ResponseBody: for<'de> Deserialize<'de>;
27
28 fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder;
31
32 fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;
34}
35
36pub trait Task {
39 type Output;
41
42 type ResponseBody: for<'de> Deserialize<'de>;
44
45 fn build_request(&self, client: &reqwest::Client, base: &str, model: &str) -> RequestBuilder;
48
49 fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;
51
52 fn with_model<'a>(&'a self, model: &'a str) -> MethodJob<'a, Self>
54 where
55 Self: Sized,
56 {
57 MethodJob { model, task: self }
58 }
59}
60
/// A task together with the name of the model it is executed with. Usually created via
/// [`Task::with_model`].
pub struct MethodJob<'a, T> {
    /// Name of the model which is passed on to the task when the request is built.
    pub model: &'a str,
    /// The task doing the actual work of building the request and interpreting the response.
    pub task: &'a T,
}
69
70impl<T> Job for MethodJob<'_, T>
71where
72 T: Task,
73{
74 type Output = T::Output;
75
76 type ResponseBody = T::ResponseBody;
77
78 fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder {
79 self.task.build_request(client, base, self.model)
80 }
81
82 fn body_to_output(&self, response: T::ResponseBody) -> T::Output {
83 self.task.body_to_output(response)
84 }
85}
86
87pub struct HttpClient {
89 base: String,
90 http: reqwest::Client,
91 api_token: Option<String>,
92}
93
94impl HttpClient {
95 pub fn new(host: String, api_token: Option<String>) -> Result<Self, Error> {
98 let http = ClientBuilder::new().build()?;
99
100 Ok(Self {
101 base: host,
102 http,
103 api_token,
104 })
105 }
106
107 async fn response(&self, builder: RequestBuilder, how: &How) -> Result<Response, Error> {
109 let query = if how.be_nice {
110 [("nice", "true")].as_slice()
111 } else {
112 [].as_slice()
114 };
115
116 let api_token = how
117 .api_token
118 .as_ref()
119 .or(self.api_token.as_ref())
120 .expect("API token needs to be set on client construction or per request");
121 let mut builder = builder
122 .query(query)
123 .header(header::AUTHORIZATION, Self::header_from_token(api_token))
124 .timeout(how.client_timeout);
125
126 if let Some(trace_context) = &how.trace_context {
127 for (key, value) in trace_context.as_w3c_headers() {
128 builder = builder.header(key, value);
129 }
130 }
131
132 let response = builder.send().await.map_err(|reqwest_error| {
133 if reqwest_error.is_timeout() {
134 Error::ClientTimeout(how.client_timeout)
135 } else {
136 reqwest_error.into()
137 }
138 })?;
139 translate_http_error(response).await
140 }
141
142 pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> {
168 let builder = task.build_request(&self.http, &self.base);
169 let response = self.response(builder, how).await?;
170 let response_body: T::ResponseBody = response.json().await?;
171 let answer = task.body_to_output(response_body);
172 Ok(answer)
173 }
174
175 pub async fn stream_output_of<'task, T: StreamJob + Send + Sync + 'task>(
177 &self,
178 task: T,
179 how: &How,
180 ) -> Result<Pin<Box<dyn Stream<Item = Result<T::Output, Error>> + Send + 'task>>, Error>
181 where
182 T::Output: 'static,
183 {
184 let builder = task.build_request(&self.http, &self.base);
185 let response = self.response(builder, how).await?;
186 let stream = Box::pin(response.bytes_stream());
187 Self::parse_stream_output(stream, task).await
188 }
189
190 pub async fn parse_stream_output<'task, T: StreamJob + Send + Sync + 'task>(
195 stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>> + Send>>,
196 task: T,
197 ) -> Result<Pin<Box<dyn Stream<Item = Result<T::Output, Error>> + Send + 'task>>, Error>
198 where
199 T::Output: 'static,
200 {
201 let mut stream = stream.eventsource();
202
203 Ok(Box::pin(stream! {
204 while let Some(item) = stream.next().await {
205 match item {
206 Ok(event) => {
207 if event.data.trim() == "[DONE]" {
210 break;
211 }
212 match serde_json::from_str::<T::ResponseBody>(&event.data) {
219 Ok(b) => yield Ok(task.body_to_output(b)),
220 Err(e) => {
221 yield Err(Error::InvalidStream {
222 deserialization_error: e.to_string(),
223 });
224 }
225 }
226 }
227 Err(e) => {
228 yield Err(Error::InvalidStream {
229 deserialization_error: e.to_string(),
230 });
231 }
232 }
233 }
234 }))
235 }
236
237 fn header_from_token(api_token: &str) -> header::HeaderValue {
238 let mut auth_value = header::HeaderValue::from_str(&format!("Bearer {api_token}")).unwrap();
239 auth_value.set_sensitive(true);
241 auth_value
242 }
243
244 pub async fn tokenizer_by_model(
245 &self,
246 model: &str,
247 api_token: Option<String>,
248 context: Option<TraceContext>,
249 ) -> Result<Tokenizer, Error> {
250 let api_token = api_token
251 .as_ref()
252 .or(self.api_token.as_ref())
253 .expect("API token needs to be set on client construction or per request");
254 let mut builder = self
255 .http
256 .get(format!("{}/models/{model}/tokenizer", self.base))
257 .header(header::AUTHORIZATION, Self::header_from_token(api_token));
258
259 if let Some(trace_context) = &context {
260 for (key, value) in trace_context.as_w3c_headers() {
261 builder = builder.header(key, value);
262 }
263 }
264
265 let response = builder.send().await?;
266 let response = translate_http_error(response).await?;
267 let bytes = response.bytes().await?;
268 let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| Error::InvalidTokenizer {
269 deserialization_error: e.to_string(),
270 })?;
271 Ok(tokenizer)
272 }
273}
274
275async fn translate_http_error(response: reqwest::Response) -> Result<reqwest::Response, Error> {
276 let status = response.status();
277 if !status.is_success() {
278 let body = response.text().await?;
282 let api_error: Result<ApiError, _> = serde_json::from_str(&body);
284 let translated_error = match status {
285 StatusCode::TOO_MANY_REQUESTS => Error::TooManyRequests,
286 StatusCode::NOT_FOUND => {
287 if api_error.is_ok_and(|error| error.code == "UNKNOWN_MODEL") {
288 Error::ModelNotFound
289 } else {
290 Error::Http {
291 status: status.as_u16(),
292 body,
293 }
294 }
295 }
296 StatusCode::SERVICE_UNAVAILABLE => {
297 if api_error.is_ok_and(|error| error.code == "QUEUE_FULL") {
300 Error::Busy
301 } else {
302 Error::Unavailable
303 }
304 }
305 _ => Error::Http {
306 status: status.as_u16(),
307 body,
308 },
309 };
310 Err(translated_error)
311 } else {
312 Ok(response)
313 }
314}
315
316#[derive(Deserialize, Debug)]
318struct ApiError<'a> {
319 code: Cow<'a, str>,
325}
326
327#[derive(ThisError, Debug)]
329pub enum Error {
330 #[error(
331 "The model was not found. Please check the provided model name. You can query the list \
332 of available models at the `models` endpoint. If you believe the model should be
333 available, contact the operator of your inference server."
334 )]
335 ModelNotFound,
336 #[error(
338 "You are trying to send too many requests to the API in to short an interval. Slow down a \
339 bit, otherwise these error will persist. Sorry for this, but we try to prevent DOS attacks."
340 )]
341 TooManyRequests,
342 #[error(
344 "Sorry the request to the Aleph Alpha API has been rejected due to the requested model \
345 being very busy at the moment. We found it unlikely that your request would finish in a \
346 reasonable timeframe, so it was rejected right away, rather than make you wait. You are \
347 welcome to retry your request any time."
348 )]
349 Busy,
350 #[error(
352 "The service is currently unavailable. This is likely due to restart. Please try again \
353 later."
354 )]
355 Unavailable,
356 #[error("No response received within given timeout: {0:?}")]
357 ClientTimeout(Duration),
358 #[error("HTTP request failed with status code {}. Body:\n{}", status, body)]
360 Http { status: u16, body: String },
361 #[error(
362 "Tokenizer could not be correctly deserialized. Caused by:\n{}",
363 deserialization_error
364 )]
365 InvalidTokenizer { deserialization_error: String },
366 #[error(
368 "Stream event could not be correctly deserialized. Caused by:\n{}.",
369 deserialization_error
370 )]
371 InvalidStream { deserialization_error: String },
372 #[error(transparent)]
374 Other(#[from] reqwest::Error),
375}
376
#[cfg(test)]
mod tests {
    use crate::{ChatEvent, CompletionEvent, Message, TaskChat, TaskCompletion};

    use super::*;

    /// Model name the tasks are bound to. No request is sent by these tests.
    const MODEL: &str = "pharia-1-llm-7b-control";

    /// Runs [`HttpClient::parse_stream_output`] on a response body which arrives split into the
    /// given `chunks` and collects everything the resulting stream yields.
    async fn parse_chunks<T>(job: T, chunks: Vec<&'static str>) -> Vec<Result<T::Output, Error>>
    where
        T: StreamJob + Send + Sync,
        T::Output: 'static,
    {
        let body = futures_util::stream::iter(
            chunks
                .into_iter()
                .map(|chunk| Ok::<_, reqwest::Error>(Bytes::from(chunk))),
        );
        HttpClient::parse_stream_output(Box::pin(body), job)
            .await
            .unwrap()
            .collect::<Vec<_>>()
            .await
    }

    #[tokio::test]
    async fn stream_chunk_event_is_parsed() {
        // Given a completion chunk followed by the end-of-stream marker
        let task = TaskCompletion::from_text("An apple a day");
        let body = "data: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\" The New York Times, May 15\"}\n\ndata: [DONE]";

        // When parsing the body
        let events = parse_chunks(task.with_model(MODEL), vec![body]).await;

        // Then only the chunk is reported
        assert_eq!(events.len(), 1);
        let mut events = events.into_iter();
        assert!(matches!(
            events.next().unwrap().unwrap(),
            CompletionEvent::Delta { completion, .. } if completion == " The New York Times, May 15"
        ));
    }

    #[tokio::test]
    async fn completion_summary_event_is_parsed() {
        // Given a stream summary followed by a completion summary
        let task = TaskCompletion::from_text("An apple a day");
        let body = "data: {\"type\":\"stream_summary\",\"index\":0,\"model_version\":\"2022-04\",\"finish_reason\":\"maximum_tokens\"}\n\ndata: {\"type\":\"completion_summary\",\"num_tokens_prompt_total\":1,\"num_tokens_generated\":7}\n\n";

        // When parsing the body
        let events = parse_chunks(task.with_model(MODEL), vec![body]).await;

        // Then the finish reason is reported first, the usage second
        assert_eq!(events.len(), 2);
        let mut events = events.into_iter();
        assert!(matches!(
            events.next().unwrap().unwrap(),
            CompletionEvent::Finished { reason } if reason == "maximum_tokens"
        ));
        assert!(matches!(
            events.next().unwrap().unwrap(),
            CompletionEvent::Summary { usage, .. } if usage.prompt_tokens == 1 && usage.completion_tokens == 7
        ));
    }

    #[tokio::test]
    async fn chat_usage_event_is_parsed() {
        // Given a chat chunk without choices, but with usage
        let task = TaskChat::with_messages(vec![Message::user("An apple a day")]);
        let body = "data: {\"id\": \"67c5b5f2-6672-4b0b-82b1-cc844127b214\",\"choices\": [],\"created\": 1739539146,\"model\": \"pharia-1-llm-7b-control\",\"system_fingerprint\": \".unknown.\",\"object\": \"chat.completion.chunk\",\"usage\": {\"prompt_tokens\": 20,\"completion_tokens\": 10,\"total_tokens\": 30}}\n\n";

        // When parsing the body
        let events = parse_chunks(task.with_model(MODEL), vec![body]).await;

        // Then the usage is reported as summary
        assert_eq!(events.len(), 1);
        let mut events = events.into_iter();
        assert!(matches!(
            events.next().unwrap().unwrap(),
            ChatEvent::Summary { usage } if usage.prompt_tokens == 20 && usage.completion_tokens == 10
        ));
    }

    #[tokio::test]
    async fn chat_stream_chunk_with_role_is_parsed() {
        // Given a chat chunk whose delta contains a role
        let task = TaskChat::with_messages(vec![Message::user("An apple a day")]);
        let body = "data: {\"id\":\"831e41b4-2382-4b08-990e-0a3859967f43\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"logprobs\":null}],\"created\":1729782822,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        // When parsing the body
        let events = parse_chunks(task.with_model(MODEL), vec![body]).await;

        // Then the start of a message is reported
        assert_eq!(events.len(), 1);
        let mut events = events.into_iter();
        assert!(matches!(
            events.next().unwrap().unwrap(),
            ChatEvent::MessageStart { role } if role == "assistant"
        ));
    }

    #[tokio::test]
    async fn chat_stream_chunk_without_role_is_parsed() {
        // Given a chat chunk whose delta contains content only
        let task = TaskChat::with_messages(vec![Message::user("An apple a day")]);
        let body = "data: {\"id\":\"a3ceca7f-32b2-4a6c-89e7-bc8eb5327f76\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"content\":\"Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.\"},\"logprobs\":null}],\"created\":1729784197,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        // When parsing the body
        let events = parse_chunks(task.with_model(MODEL), vec![body]).await;

        // Then the content is reported as message delta
        assert_eq!(events.len(), 1);
        let mut events = events.into_iter();
        assert!(matches!(
            events.next().unwrap().unwrap(),
            ChatEvent::MessageDelta { content, logprobs, .. } if content == Some("Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.".to_owned()) && logprobs.is_empty()
        ));
    }

    #[tokio::test]
    async fn chat_stream_chunk_without_content_but_with_finish_reason_is_parsed() {
        // Given a chat chunk with an empty delta and a finish reason
        let task = TaskChat::with_messages(vec![Message::user("An apple a day")]);
        let body = "data: {\"id\":\"a3ceca7f-32b2-4a6c-89e7-bc8eb5327f76\",\"choices\":[{\"finish_reason\":\"stop\",\"index\":0,\"delta\":{},\"logprobs\":null}],\"created\":1729784197,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        // When parsing the body
        let events = parse_chunks(task.with_model(MODEL), vec![body]).await;

        // Then the finish reason is reported as message delta
        assert_eq!(events.len(), 1);
        let mut events = events.into_iter();
        assert!(matches!(
            events.next().unwrap().unwrap(),
            ChatEvent::MessageDelta { finish_reason, .. } if finish_reason == Some("stop".to_owned())
        ));
    }

    #[tokio::test]
    async fn sse_event_split_over_multiple_chunks() {
        // Given a single event which arrives in two pieces
        let task = TaskCompletion::from_text("An apple a day");
        let chunks = vec![
            "data: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\"",
            " Hello world\"}\n\n",
        ];

        // When parsing the pieces
        let events = parse_chunks(task.with_model(MODEL), chunks).await;

        // Then they are reported as one event
        assert_eq!(events.len(), 1);
        let mut events = events.into_iter();
        assert!(matches!(
            events.next().unwrap().unwrap(),
            CompletionEvent::Delta { completion, .. } if completion == " Hello world"
        ));
    }

    #[tokio::test]
    async fn two_sse_events_in_one_chunk() {
        // Given two events which arrive in one piece
        let task = TaskCompletion::from_text("An apple a day");
        let body = "data: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\" First\"}\n\ndata: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\" Second\"}\n\n";

        // When parsing the body
        let events = parse_chunks(task.with_model(MODEL), vec![body]).await;

        // Then both events are reported in order
        assert_eq!(events.len(), 2);
        let mut events = events.into_iter();
        assert!(matches!(
            events.next().unwrap().unwrap(),
            CompletionEvent::Delta { completion, .. } if completion == " First"
        ));
        assert!(matches!(
            events.next().unwrap().unwrap(),
            CompletionEvent::Delta { completion, .. } if completion == " Second"
        ));
    }
}