aleph_alpha_client/
http.rs

1use std::{borrow::Cow, pin::Pin, time::Duration};
2
3use bytes::Bytes;
4use eventsource_stream::Eventsource;
5use futures_util::{stream::StreamExt, Stream};
6use reqwest::{header, ClientBuilder, RequestBuilder, Response, StatusCode};
7use serde::Deserialize;
8use thiserror::Error as ThisError;
9use tokenizers::Tokenizer;
10
11use crate::{How, StreamJob, TraceContext};
12use async_stream::stream;
13
14/// A job send to the Aleph Alpha Api using the http client. A job wraps all the knowledge required
15/// for the Aleph Alpha API to specify its result. Notably it includes the model(s) the job is
16/// executed on. This allows this trait to hold in the presence of services, which use more than one
17/// model and task type to achieve their result. On the other hand a bare [`crate::TaskCompletion`]
18/// can not implement this trait directly, since its result would depend on what model is chosen to
19/// execute it. You can remedy this by turning completion task into a job, calling
20/// [`Task::with_model`].
21pub trait Job {
22    /// Output returned by [`crate::Client::output_of`]
23    type Output;
24
25    /// Expected answer of the Aleph Alpha API
26    type ResponseBody: for<'de> Deserialize<'de>;
27
28    /// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be
29    /// already set.
30    fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder;
31
32    /// Parses the response of the server into higher level structs for the user.
33    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;
34}
35
36/// A task send to the Aleph Alpha Api using the http client. Requires to specify a model before it
37/// can be executed.
38pub trait Task {
39    /// Output returned by [`crate::Client::output_of`]
40    type Output;
41
42    /// Expected answer of the Aleph Alpha API
43    type ResponseBody: for<'de> Deserialize<'de>;
44
45    /// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be
46    /// already set.
47    fn build_request(&self, client: &reqwest::Client, base: &str, model: &str) -> RequestBuilder;
48
49    /// Parses the response of the server into higher level structs for the user.
50    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;
51
52    /// Turn your task into [`Job`] by annotating it with a model name.
53    fn with_model<'a>(&'a self, model: &'a str) -> MethodJob<'a, Self>
54    where
55        Self: Sized,
56    {
57        MethodJob { model, task: self }
58    }
59}
60
/// A [`Task`] paired with the model it should be executed with, which makes it a [`Job`]. Use
/// this as input for [`Client::output_of`].
pub struct MethodJob<'a, T> {
    /// Name of the Aleph Alpha Model. E.g. "luminous-base".
    pub model: &'a str,
    /// The task which is executed against `model`.
    pub task: &'a T,
}
69
70impl<T> Job for MethodJob<'_, T>
71where
72    T: Task,
73{
74    type Output = T::Output;
75
76    type ResponseBody = T::ResponseBody;
77
78    fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder {
79        self.task.build_request(client, base, self.model)
80    }
81
82    fn body_to_output(&self, response: T::ResponseBody) -> T::Output {
83        self.task.body_to_output(response)
84    }
85}
86
87/// Sends HTTP request to the Aleph Alpha API
88pub struct HttpClient {
89    base: String,
90    http: reqwest::Client,
91    api_token: Option<String>,
92}
93
94impl HttpClient {
95    /// In production you typically would want set this to <https://inference-api.pharia.your-company.com>.
96    /// Yet you may want to use a different instance for testing.
97    pub fn new(host: String, api_token: Option<String>) -> Result<Self, Error> {
98        let http = ClientBuilder::new().build()?;
99
100        Ok(Self {
101            base: host,
102            http,
103            api_token,
104        })
105    }
106
107    /// Construct and execute a request building on top of a `RequestBuilder`
108    async fn response(&self, builder: RequestBuilder, how: &How) -> Result<Response, Error> {
109        let query = if how.be_nice {
110            [("nice", "true")].as_slice()
111        } else {
112            // nice=false is default, so we just omit it.
113            [].as_slice()
114        };
115
116        let api_token = how
117            .api_token
118            .as_ref()
119            .or(self.api_token.as_ref())
120            .expect("API token needs to be set on client construction or per request");
121        let mut builder = builder
122            .query(query)
123            .header(header::AUTHORIZATION, Self::header_from_token(api_token))
124            .timeout(how.client_timeout);
125
126        if let Some(trace_context) = &how.trace_context {
127            for (key, value) in trace_context.as_w3c_headers() {
128                builder = builder.header(key, value);
129            }
130        }
131
132        let response = builder.send().await.map_err(|reqwest_error| {
133            if reqwest_error.is_timeout() {
134                Error::ClientTimeout(how.client_timeout)
135            } else {
136                reqwest_error.into()
137            }
138        })?;
139        translate_http_error(response).await
140    }
141
142    /// Execute a task with the aleph alpha API and fetch its result.
143    ///
144    /// ```no_run
145    /// use aleph_alpha_client::{Client, How, TaskCompletion, Task, Error};
146    ///
147    /// async fn print_completion() -> Result<(), Error> {
148    ///     // Authenticate against API. Fetches token.
149    ///     let client = Client::from_env()?;
150    ///
151    ///     // Name of the model we we want to use. Large models give usually better answer, but are
152    ///     // also slower and more costly.
153    ///     let model = "luminous-base";
154    ///
155    ///     // The task we want to perform. Here we want to continue the sentence: "An apple a day
156    ///     // ..."
157    ///     let task = TaskCompletion::from_text("An apple a day");
158    ///
159    ///     // Retrieve answer from API
160    ///     let response = client.output_of(&task.with_model(model), &How::default()).await?;
161    ///
162    ///     // Print entire sentence with completion
163    ///     println!("An apple a day{}", response.completion);
164    ///     Ok(())
165    /// }
166    /// ```
167    pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> {
168        let builder = task.build_request(&self.http, &self.base);
169        let response = self.response(builder, how).await?;
170        let response_body: T::ResponseBody = response.json().await?;
171        let answer = task.body_to_output(response_body);
172        Ok(answer)
173    }
174
175    /// Execute a stream task with the aleph alpha API and stream its result.
176    pub async fn stream_output_of<'task, T: StreamJob + Send + Sync + 'task>(
177        &self,
178        task: T,
179        how: &How,
180    ) -> Result<Pin<Box<dyn Stream<Item = Result<T::Output, Error>> + Send + 'task>>, Error>
181    where
182        T::Output: 'static,
183    {
184        let builder = task.build_request(&self.http, &self.base);
185        let response = self.response(builder, how).await?;
186        let stream = Box::pin(response.bytes_stream());
187        Self::parse_stream_output(stream, task).await
188    }
189
190    /// Parse a stream of bytes into a stream of [`crate::StreamTask::Output`] objects.
191    ///
192    /// The [`crate::StreamTask::body_to_output`] allows each implementation to decide how to handle
193    /// the response events.
194    pub async fn parse_stream_output<'task, T: StreamJob + Send + Sync + 'task>(
195        stream: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>> + Send>>,
196        task: T,
197    ) -> Result<Pin<Box<dyn Stream<Item = Result<T::Output, Error>> + Send + 'task>>, Error>
198    where
199        T::Output: 'static,
200    {
201        let mut stream = stream.eventsource();
202
203        Ok(Box::pin(stream! {
204            while let Some(item) = stream.next().await {
205                match item {
206                    Ok(event) => {
207                        // The last stream event for the chat endpoint always is "[DONE]". We assume
208                        // that the consumer of this library is not interested in this event.
209                        if event.data.trim() == "[DONE]" {
210                            break;
211                        }
212                        // Each task defines its response body as an associated type. This allows
213                        // us to define generic parsing logic for multiple streaming tasks. In
214                        // addition, tasks define an output type, which is a higher level
215                        // abstraction over the response body. With the `body_to_output` method,
216                        // tasks define logic to parse a response body into an output. This
217                        // decouples the parsing logic from the data handed to users.
218                        match serde_json::from_str::<T::ResponseBody>(&event.data) {
219                            Ok(b) => yield Ok(task.body_to_output(b)),
220                            Err(e) => {
221                                yield Err(Error::InvalidStream {
222                                    deserialization_error: e.to_string(),
223                                });
224                            }
225                        }
226                    }
227                    Err(e) => {
228                        yield Err(Error::InvalidStream {
229                            deserialization_error: e.to_string(),
230                        });
231                    }
232                }
233            }
234        }))
235    }
236
237    fn header_from_token(api_token: &str) -> header::HeaderValue {
238        let mut auth_value = header::HeaderValue::from_str(&format!("Bearer {api_token}")).unwrap();
239        // Consider marking security-sensitive headers with `set_sensitive`.
240        auth_value.set_sensitive(true);
241        auth_value
242    }
243
244    pub async fn tokenizer_by_model(
245        &self,
246        model: &str,
247        api_token: Option<String>,
248        context: Option<TraceContext>,
249    ) -> Result<Tokenizer, Error> {
250        let api_token = api_token
251            .as_ref()
252            .or(self.api_token.as_ref())
253            .expect("API token needs to be set on client construction or per request");
254        let mut builder = self
255            .http
256            .get(format!("{}/models/{model}/tokenizer", self.base))
257            .header(header::AUTHORIZATION, Self::header_from_token(api_token));
258
259        if let Some(trace_context) = &context {
260            for (key, value) in trace_context.as_w3c_headers() {
261                builder = builder.header(key, value);
262            }
263        }
264
265        let response = builder.send().await?;
266        let response = translate_http_error(response).await?;
267        let bytes = response.bytes().await?;
268        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| Error::InvalidTokenizer {
269            deserialization_error: e.to_string(),
270        })?;
271        Ok(tokenizer)
272    }
273}
274
275async fn translate_http_error(response: reqwest::Response) -> Result<reqwest::Response, Error> {
276    let status = response.status();
277    if !status.is_success() {
278        // Store body in a variable, so we can use it, even if it is not an Error emitted by
279        // the API, but an intermediate Proxy like NGinx, so we can still forward the error
280        // message.
281        let body = response.text().await?;
282        // If the response is an error emitted by the API, this deserialization should succeed.
283        let api_error: Result<ApiError, _> = serde_json::from_str(&body);
284        let translated_error = match status {
285            StatusCode::TOO_MANY_REQUESTS => Error::TooManyRequests,
286            StatusCode::NOT_FOUND => {
287                if api_error.is_ok_and(|error| error.code == "UNKNOWN_MODEL") {
288                    Error::ModelNotFound
289                } else {
290                    Error::Http {
291                        status: status.as_u16(),
292                        body,
293                    }
294                }
295            }
296            StatusCode::SERVICE_UNAVAILABLE => {
297                // Presence of `api_error` implies the error originated from the API itself (rather
298                // than the intermediate proxy) and so we can decode it as such.
299                if api_error.is_ok_and(|error| error.code == "QUEUE_FULL") {
300                    Error::Busy
301                } else {
302                    Error::Unavailable
303                }
304            }
305            _ => Error::Http {
306                status: status.as_u16(),
307                body,
308            },
309        };
310        Err(translated_error)
311    } else {
312        Ok(response)
313    }
314}
315
316/// We are only interested in the status codes of the API.
317#[derive(Deserialize, Debug)]
318struct ApiError<'a> {
319    /// Unique string in capital letters emitted by the API to signal different kinds of errors in a
320    /// finer granularity then the HTTP status codes alone would allow for.
321    ///
322    /// E.g. Differentiating between request rate limiting and parallel tasks limiting which both
323    /// are 429 (the former is emitted by NGinx though).
324    code: Cow<'a, str>,
325}
326
327/// Errors returned by the Aleph Alpha Client
328#[derive(ThisError, Debug)]
329pub enum Error {
330    #[error(
331        "The model was not found. Please check the provided model name. You can query the list \
332        of available models at the `models` endpoint. If you believe the model should be
333        available, contact the operator of your inference server."
334    )]
335    ModelNotFound,
336    /// User exceeds his current Task Quota.
337    #[error(
338        "You are trying to send too many requests to the API in to short an interval. Slow down a \
339        bit, otherwise these error will persist. Sorry for this, but we try to prevent DOS attacks."
340    )]
341    TooManyRequests,
342    /// Model is busy. Most likely due to many other users requesting its services right now.
343    #[error(
344        "Sorry the request to the Aleph Alpha API has been rejected due to the requested model \
345        being very busy at the moment. We found it unlikely that your request would finish in a \
346        reasonable timeframe, so it was rejected right away, rather than make you wait. You are \
347        welcome to retry your request any time."
348    )]
349    Busy,
350    /// The API itself is unavailable, most likely due to restart.
351    #[error(
352        "The service is currently unavailable. This is likely due to restart. Please try again \
353        later."
354    )]
355    Unavailable,
356    #[error("No response received within given timeout: {0:?}")]
357    ClientTimeout(Duration),
358    /// An error on the Http Protocol level.
359    #[error("HTTP request failed with status code {}. Body:\n{}", status, body)]
360    Http { status: u16, body: String },
361    #[error(
362        "Tokenizer could not be correctly deserialized. Caused by:\n{}",
363        deserialization_error
364    )]
365    InvalidTokenizer { deserialization_error: String },
366    /// Deserialization error of the stream event.
367    #[error(
368        "Stream event could not be correctly deserialized. Caused by:\n{}.",
369        deserialization_error
370    )]
371    InvalidStream { deserialization_error: String },
372    /// Most likely either TLS errors creating the Client, or IO errors.
373    #[error(transparent)]
374    Other(#[from] reqwest::Error),
375}
376
#[cfg(test)]
mod tests {
    use crate::{ChatEvent, CompletionEvent, Message, TaskChat, TaskCompletion};

    use super::*;

    /// Model name used by all tests. No request is sent, so its value has no effect here.
    const MODEL: &str = "pharia-1-llm-7b-control";

    /// Byte stream which yields each element of `chunks` as one successfully received chunk.
    fn byte_stream(
        chunks: Vec<&'static str>,
    ) -> Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>> + Send>> {
        Box::pin(futures_util::stream::iter(
            chunks
                .into_iter()
                .map(|chunk| Ok::<_, reqwest::Error>(Bytes::from(chunk))),
        ))
    }

    #[tokio::test]
    async fn stream_chunk_event_is_parsed() {
        // Given a completion task and part of its response stream that includes a stream chunk
        let task = TaskCompletion::from_text("An apple a day");
        let job = task.with_model(MODEL);
        let bytes = "data: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\" The New York Times, May 15\"}\n\ndata: [DONE]";

        // When converting it to a stream of events
        let stream = HttpClient::parse_stream_output(byte_stream(vec![bytes]), job)
            .await
            .unwrap();
        let events = stream.collect::<Vec<_>>().await;

        // Then a completion event is yielded
        assert_eq!(events.len(), 1);
        let mut events = events.into_iter();
        assert!(
            matches!(events.next().unwrap().unwrap(), CompletionEvent::Delta { completion, .. } if completion == " The New York Times, May 15")
        );
    }

    #[tokio::test]
    async fn completion_summary_event_is_parsed() {
        // Given a completion task and part of its response stream that includes a finish reason and a summary
        let task = TaskCompletion::from_text("An apple a day");
        let job = task.with_model(MODEL);
        let bytes = "data: {\"type\":\"stream_summary\",\"index\":0,\"model_version\":\"2022-04\",\"finish_reason\":\"maximum_tokens\"}\n\ndata: {\"type\":\"completion_summary\",\"num_tokens_prompt_total\":1,\"num_tokens_generated\":7}\n\n";

        // When converting it to a stream of events
        let stream = HttpClient::parse_stream_output(byte_stream(vec![bytes]), job)
            .await
            .unwrap();
        let events = stream.collect::<Vec<_>>().await;

        // Then a finish reason event and a summary event are yielded
        assert_eq!(events.len(), 2);
        let mut events = events.into_iter();
        assert!(
            matches!(events.next().unwrap().unwrap(), CompletionEvent::Finished { reason } if reason == "maximum_tokens")
        );
        assert!(
            matches!(events.next().unwrap().unwrap(), CompletionEvent::Summary { usage, .. } if usage.prompt_tokens == 1 && usage.completion_tokens == 7)
        );
    }

    #[tokio::test]
    async fn chat_usage_event_is_parsed() {
        // Given a chat task and part of its response stream that includes a usage event
        let task = TaskChat::with_messages(vec![Message::user("An apple a day")]);
        let job = task.with_model(MODEL);
        let bytes = "data: {\"id\": \"67c5b5f2-6672-4b0b-82b1-cc844127b214\",\"choices\": [],\"created\": 1739539146,\"model\": \"pharia-1-llm-7b-control\",\"system_fingerprint\": \".unknown.\",\"object\": \"chat.completion.chunk\",\"usage\": {\"prompt_tokens\": 20,\"completion_tokens\": 10,\"total_tokens\": 30}}\n\n";

        // When converting it to a stream of events
        let stream = HttpClient::parse_stream_output(byte_stream(vec![bytes]), job)
            .await
            .unwrap();
        let events = stream.collect::<Vec<_>>().await;

        // Then a summary event is yielded
        assert_eq!(events.len(), 1);
        let mut events = events.into_iter();
        assert!(
            matches!(events.next().unwrap().unwrap(), ChatEvent::Summary { usage } if usage.prompt_tokens == 20 && usage.completion_tokens == 10)
        );
    }

    #[tokio::test]
    async fn chat_stream_chunk_with_role_is_parsed() {
        // Given a chat task and part of its response stream that includes a stream chunk with a role
        let task = TaskChat::with_messages(vec![Message::user("An apple a day")]);
        let job = task.with_model(MODEL);
        let bytes = "data: {\"id\":\"831e41b4-2382-4b08-990e-0a3859967f43\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"logprobs\":null}],\"created\":1729782822,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        // When converting it to a stream of events
        let stream = HttpClient::parse_stream_output(byte_stream(vec![bytes]), job)
            .await
            .unwrap();
        let events = stream.collect::<Vec<_>>().await;

        // Then a message start event with a role is yielded
        assert_eq!(events.len(), 1);
        let mut events = events.into_iter();
        assert!(
            matches!(events.next().unwrap().unwrap(), ChatEvent::MessageStart { role } if role == "assistant")
        );
    }

    #[tokio::test]
    async fn chat_stream_chunk_without_role_is_parsed() {
        // Given a chat task and part of its response stream that includes a pure content stream chunk
        let task = TaskChat::with_messages(vec![Message::user("An apple a day")]);
        let job = task.with_model(MODEL);
        let bytes = "data: {\"id\":\"a3ceca7f-32b2-4a6c-89e7-bc8eb5327f76\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"content\":\"Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.\"},\"logprobs\":null}],\"created\":1729784197,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        // When converting it to a stream of events
        let stream = HttpClient::parse_stream_output(byte_stream(vec![bytes]), job)
            .await
            .unwrap();
        let events = stream.collect::<Vec<_>>().await;

        // Then a message delta event with content is yielded
        assert_eq!(events.len(), 1);
        let mut events = events.into_iter();
        assert!(
            matches!(events.next().unwrap().unwrap(), ChatEvent::MessageDelta { content, logprobs, .. } if content == Some("Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.".to_owned()) && logprobs.is_empty())
        );
    }

    #[tokio::test]
    async fn chat_stream_chunk_without_content_but_with_finish_reason_is_parsed() {
        // Given a chat task and part of its response stream that includes a stream chunk with a finish reason
        let task = TaskChat::with_messages(vec![Message::user("An apple a day")]);
        let job = task.with_model(MODEL);
        let bytes = "data: {\"id\":\"a3ceca7f-32b2-4a6c-89e7-bc8eb5327f76\",\"choices\":[{\"finish_reason\":\"stop\",\"index\":0,\"delta\":{},\"logprobs\":null}],\"created\":1729784197,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        // When converting it to a stream of events
        let stream = HttpClient::parse_stream_output(byte_stream(vec![bytes]), job)
            .await
            .unwrap();
        let events = stream.collect::<Vec<_>>().await;

        // Then a message end event with a stop reason is yielded
        assert_eq!(events.len(), 1);
        let mut events = events.into_iter();
        assert!(
            matches!(events.next().unwrap().unwrap(), ChatEvent::MessageDelta { finish_reason, .. } if finish_reason == Some("stop".to_owned()))
        );
    }

    #[tokio::test]
    async fn sse_event_split_over_multiple_chunks() {
        // Given a completion task and an SSE event split across multiple chunks
        let task = TaskCompletion::from_text("An apple a day");
        let job = task.with_model(MODEL);
        let chunks = vec![
            "data: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\"",
            " Hello world\"}\n\n",
        ];

        // When converting it to a stream of events
        let stream = HttpClient::parse_stream_output(byte_stream(chunks), job)
            .await
            .unwrap();
        let events = stream.collect::<Vec<_>>().await;

        // Then a single completion event is yielded with the complete content
        assert_eq!(events.len(), 1);
        let mut events = events.into_iter();
        assert!(
            matches!(events.next().unwrap().unwrap(), CompletionEvent::Delta { completion, .. } if completion == " Hello world")
        );
    }

    #[tokio::test]
    async fn two_sse_events_in_one_chunk() {
        // Given a completion task and two SSE events in a single chunk
        let task = TaskCompletion::from_text("An apple a day");
        let job = task.with_model(MODEL);
        let bytes = "data: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\" First\"}\n\ndata: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\" Second\"}\n\n";

        // When converting it to a stream of events
        let stream = HttpClient::parse_stream_output(byte_stream(vec![bytes]), job)
            .await
            .unwrap();
        let events = stream.collect::<Vec<_>>().await;

        // Then two completion events are yielded
        assert_eq!(events.len(), 2);
        let mut events = events.into_iter();
        assert!(
            matches!(events.next().unwrap().unwrap(), CompletionEvent::Delta { completion, .. } if completion == " First")
        );
        assert!(
            matches!(events.next().unwrap().unwrap(), CompletionEvent::Delta { completion, .. } if completion == " Second")
        );
    }
}
559}