aleph_alpha_client/http.rs

use std::{borrow::Cow, pin::Pin, time::Duration};

use futures_util::{stream::StreamExt, Stream};
use reqwest::{header, ClientBuilder, RequestBuilder, Response, StatusCode};
use serde::Deserialize;
use thiserror::Error as ThisError;
use tokenizers::Tokenizer;

use crate::{How, StreamJob};
use async_stream::stream;
/// A job sent to the Aleph Alpha API using the HTTP client. A job wraps all the knowledge required
/// for the Aleph Alpha API to specify its result. Notably, it includes the model(s) the job is
/// executed on. This allows the trait to hold even for services which use more than one model and
/// task type to achieve their result. A bare [`crate::TaskCompletion`], on the other hand, cannot
/// implement this trait directly, since its result would depend on which model is chosen to
/// execute it. You can remedy this by turning the completion task into a job via
/// [`Task::with_model`].
pub trait Job {
    /// Output returned by [`crate::Client::output_of`]
    type Output;

    /// Expected answer of the Aleph Alpha API
    type ResponseBody: for<'de> Deserialize<'de>;

    /// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be
    /// already set.
    fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder;

    /// Parses the response of the server into higher level structs for the user.
    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;
}

/// A task sent to the Aleph Alpha API using the HTTP client. Requires a model to be specified
/// before it can be executed.
pub trait Task {
    /// Output returned by [`crate::Client::output_of`]
    type Output;

    /// Expected answer of the Aleph Alpha API
    type ResponseBody: for<'de> Deserialize<'de>;

    /// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be
    /// already set.
    fn build_request(&self, client: &reqwest::Client, base: &str, model: &str) -> RequestBuilder;

    /// Parses the response of the server into higher level structs for the user.
    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;

    /// Turn your task into a [`Job`] by annotating it with a model name.
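    ///
    /// A minimal sketch, mirroring the example on [`HttpClient::output_of`]:
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, Error, How, Task, TaskCompletion};
    ///
    /// async fn complete_with_model() -> Result<(), Error> {
    ///     let client = Client::from_env()?;
    ///     let task = TaskCompletion::from_text("An apple a day");
    ///     // Binding the model-agnostic task to a concrete model turns it into a job.
    ///     let job = task.with_model("luminous-base");
    ///     let output = client.output_of(&job, &How::default()).await?;
    ///     println!("{}", output.completion);
    ///     Ok(())
    /// }
    /// ```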
    fn with_model<'a>(&'a self, model: &'a str) -> MethodJob<'a, Self>
    where
        Self: Sized,
    {
        MethodJob { model, task: self }
    }
}

/// Enriches a `Task` into a `Job` by adding the model it should be executed with. Use this as
/// input for [`Client::output_of`].
pub struct MethodJob<'a, T> {
    /// Name of the Aleph Alpha Model. E.g. "luminous-base".
    pub model: &'a str,
    /// Task to be executed against the model.
    pub task: &'a T,
}

impl<T> Job for MethodJob<'_, T>
where
    T: Task,
{
    type Output = T::Output;

    type ResponseBody = T::ResponseBody;

    fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder {
        self.task.build_request(client, base, self.model)
    }

    fn body_to_output(&self, response: T::ResponseBody) -> T::Output {
        self.task.body_to_output(response)
    }
}

/// Sends HTTP requests to the Aleph Alpha API.
pub struct HttpClient {
    base: String,
    http: reqwest::Client,
    api_token: Option<String>,
}

impl HttpClient {
    /// In production you would typically want to set this to <https://inference-api.pharia.your-company.com>.
    /// Yet you may want to use a different instance for testing.
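    ///
    /// A minimal sketch; the host and environment variable below are placeholders for your own
    /// setup, and the example assumes `HttpClient` is exported at the crate root like `Client`.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Error, HttpClient};
    ///
    /// fn make_client() -> Result<HttpClient, Error> {
    ///     // The token may alternatively be provided per request via `How`.
    ///     let api_token = std::env::var("AA_API_TOKEN").ok();
    ///     HttpClient::new("https://inference-api.example.com".to_owned(), api_token)
    /// }
    /// ```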
    pub fn new(host: String, api_token: Option<String>) -> Result<Self, Error> {
        let http = ClientBuilder::new().build()?;

        Ok(Self {
            base: host,
            http,
            api_token,
        })
    }

    /// Construct and execute a request, building on top of a `RequestBuilder`.
    async fn response(&self, builder: RequestBuilder, how: &How) -> Result<Response, Error> {
        let query = if how.be_nice {
            [("nice", "true")].as_slice()
        } else {
            // nice=false is the default, so we just omit it.
            [].as_slice()
        };

        let api_token = how
            .api_token
            .as_ref()
            .or(self.api_token.as_ref())
            .expect("API token needs to be set on client construction or per request");
        let response = builder
            .query(query)
            .header(header::AUTHORIZATION, Self::header_from_token(api_token))
            .timeout(how.client_timeout)
            .send()
            .await
            .map_err(|reqwest_error| {
                if reqwest_error.is_timeout() {
                    Error::ClientTimeout(how.client_timeout)
                } else {
                    reqwest_error.into()
                }
            })?;
        translate_http_error(response).await
    }

    /// Execute a task with the Aleph Alpha API and fetch its result.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, How, TaskCompletion, Task, Error};
    ///
    /// async fn print_completion() -> Result<(), Error> {
    ///     // Authenticate against API. Fetches token.
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of the model we want to use. Larger models usually give better answers, but
    ///     // are also slower and more costly.
    ///     let model = "luminous-base";
    ///
    ///     // The task we want to perform. Here we want to continue the sentence: "An apple a day
    ///     // ..."
    ///     let task = TaskCompletion::from_text("An apple a day");
    ///
    ///     // Retrieve answer from API
    ///     let response = client.output_of(&task.with_model(model), &How::default()).await?;
    ///
    ///     // Print entire sentence with completion
    ///     println!("An apple a day{}", response.completion);
    ///     Ok(())
    /// }
    /// ```
    pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> {
        let builder = task.build_request(&self.http, &self.base);
        let response = self.response(builder, how).await?;
        let response_body: T::ResponseBody = response.json().await?;
        let answer = task.body_to_output(response_body);
        Ok(answer)
    }

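    /// Execute a stream job against the Aleph Alpha API and return the results as an asynchronous
    /// stream of outputs. Stream events which do not translate to an output for the user (see
    /// [`StreamJob::body_to_output`]) are skipped.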
    pub async fn stream_output_of<'task, T: StreamJob + Send + Sync + 'task>(
        &self,
        task: T,
        how: &How,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<T::Output, Error>> + Send + 'task>>, Error>
    where
        T::Output: 'static,
    {
        let builder = task.build_request(&self.http, &self.base);
        let response = self.response(builder, how).await?;
        let mut stream = response.bytes_stream();

        Ok(Box::pin(stream! {
            while let Some(item) = stream.next().await {
                match item {
                    Ok(bytes) => {
                        let events = Self::parse_stream_event::<T::ResponseBody>(bytes.as_ref());
                        for event in events {
                            match event {
                                // Check if the output should be yielded or skipped
                                Ok(b) => if let Some(output) = task.body_to_output(b) {
                                    yield Ok(output);
                                }
                                Err(e) => {
                                    yield Err(e);
                                }
                            }
                        }
                    }
                    Err(e) => {
                        yield Err(e.into());
                    }
                }
            }
        }))
    }

    /// Take a byte slice (of an SSE event) and parse it into the provided response body.
    /// Each SSE event is expected to contain one or multiple JSON bodies prefixed by `data: `.
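    ///
    /// For example, a single event on the wire may look like this (see the tests below for more
    /// variants):
    ///
    /// ```text
    /// data: {"type":"stream_chunk","index":0,"completion":" The New York Times"}
    /// ```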
    fn parse_stream_event<StreamBody>(bytes: &[u8]) -> Vec<Result<StreamBody, Error>>
    where
        StreamBody: for<'de> Deserialize<'de>,
    {
        String::from_utf8_lossy(bytes)
            .split("data: ")
            .skip(1)
            // The last stream event for the chat endpoint (not for the completion endpoint) is
            // always "[DONE]". While we could model this as a variant of the `ChatStreamChunk`
            // enum, the value of doing so is unclear, so we ignore it here.
            .filter(|s| s.trim() != "[DONE]")
            .map(|s| {
                serde_json::from_str(s).map_err(|e| Error::InvalidStream {
                    deserialization_error: e.to_string(),
                })
            })
            .collect()
    }

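    /// Build the value of the HTTP `Authorization` header from the given API token.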
    fn header_from_token(api_token: &str) -> header::HeaderValue {
        let mut auth_value = header::HeaderValue::from_str(&format!("Bearer {api_token}")).unwrap();
        // Mark the header as sensitive so the token is not exposed in debug output.
        auth_value.set_sensitive(true);
        auth_value
    }

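    /// Fetch the tokenizer of the given model from the API and deserialize it. An API token
    /// provided per call takes precedence over the one set on client construction.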
    pub async fn tokenizer_by_model(
        &self,
        model: &str,
        api_token: Option<String>,
    ) -> Result<Tokenizer, Error> {
        let api_token = api_token
            .as_ref()
            .or(self.api_token.as_ref())
            .expect("API token needs to be set on client construction or per request");
        let response = self
            .http
            .get(format!("{}/models/{model}/tokenizer", self.base))
            .header(header::AUTHORIZATION, Self::header_from_token(api_token))
            .send()
            .await?;
        let response = translate_http_error(response).await?;
        let bytes = response.bytes().await?;
        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| Error::InvalidTokenizer {
            deserialization_error: e.to_string(),
        })?;
        Ok(tokenizer)
    }
}

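/// Translate well-known error status codes into the matching [`Error`] variants. Successful
/// responses are passed through unchanged.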
async fn translate_http_error(response: reqwest::Response) -> Result<reqwest::Response, Error> {
    let status = response.status();
    if !status.is_success() {
        // Store the body in a variable, so we can still forward the error message even if it was
        // not emitted by the API itself but by an intermediate proxy like NGinx.
        let body = response.text().await?;
        // If the response is an error emitted by the API, this deserialization should succeed.
        let api_error: Result<ApiError, _> = serde_json::from_str(&body);
        let translated_error = match status {
            StatusCode::TOO_MANY_REQUESTS => Error::TooManyRequests,
            StatusCode::SERVICE_UNAVAILABLE => {
                // A successfully decoded `api_error` implies the error originated from the API
                // itself (rather than the intermediate proxy), so we can inspect its code.
                if api_error.is_ok_and(|error| error.code == "QUEUE_FULL") {
                    Error::Busy
                } else {
                    Error::Unavailable
                }
            }
            _ => Error::Http {
                status: status.as_u16(),
                body,
            },
        };
        Err(translated_error)
    } else {
        Ok(response)
    }
}

/// We are only interested in the error code emitted by the API.
#[derive(Deserialize, Debug)]
struct ApiError<'a> {
    /// Unique string in capital letters emitted by the API to signal different kinds of errors in
    /// a finer granularity than the HTTP status codes alone would allow for.
    ///
    /// E.g. differentiating between request rate limiting and parallel task limiting, which both
    /// map to 429 (the former is emitted by NGinx, though).
    code: Cow<'a, str>,
}

/// Errors returned by the Aleph Alpha Client
#[derive(ThisError, Debug)]
pub enum Error {
    /// User exceeded their current task quota.
    #[error(
        "You are trying to send too many requests to the API in too short an interval. Slow down a \
        bit, otherwise these errors will persist. Sorry for this, but we try to prevent DOS attacks."
    )]
    TooManyRequests,
    /// Model is busy. Most likely due to many other users requesting its services right now.
    #[error(
        "Sorry, the request to the Aleph Alpha API has been rejected due to the requested model \
        being very busy at the moment. We found it unlikely that your request would finish in a \
        reasonable timeframe, so it was rejected right away, rather than making you wait. You are \
        welcome to retry your request any time."
    )]
    Busy,
    /// The API itself is unavailable, most likely due to a restart.
    #[error(
        "The service is currently unavailable. This is likely due to a restart. Please try again \
        later."
    )]
    Unavailable,
    /// No response received within the client timeout.
    #[error("No response received within given timeout: {0:?}")]
    ClientTimeout(Duration),
    /// An error on the HTTP protocol level.
    #[error("HTTP request failed with status code {}. Body:\n{}", status, body)]
    Http { status: u16, body: String },
    /// Deserialization error of the tokenizer fetched from the API.
    #[error(
        "Tokenizer could not be correctly deserialized. Caused by:\n{}",
        deserialization_error
    )]
    InvalidTokenizer { deserialization_error: String },
    /// Deserialization error of the stream event.
    #[error(
        "Stream event could not be correctly deserialized. Caused by:\n{}.",
        deserialization_error
    )]
    InvalidStream { deserialization_error: String },
    /// Most likely either TLS errors while creating the client, or IO errors.
    #[error(transparent)]
    Other(#[from] reqwest::Error),
}

#[cfg(test)]
mod tests {
    use crate::{
        chat::{DeserializedChatChunk, StreamChatResponse, StreamMessage},
        completion::DeserializedCompletionEvent,
    };

    use super::*;

    #[test]
    fn stream_chunk_event_is_parsed() {
        // Given some bytes
        let bytes = b"data: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\" The New York Times, May 15\"}\n\n";

        // When they are parsed
        let events = HttpClient::parse_stream_event::<DeserializedCompletionEvent>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        // Then the event is a stream chunk
        match event {
            DeserializedCompletionEvent::StreamChunk { completion, .. } => {
                assert_eq!(completion, " The New York Times, May 15")
            }
            _ => panic!("Expected a stream chunk"),
        }
    }

    #[test]
    fn completion_summary_event_is_parsed() {
        // Given some bytes with a stream summary and a completion summary
        let bytes = b"data: {\"type\":\"stream_summary\",\"index\":0,\"model_version\":\"2022-04\",\"finish_reason\":\"maximum_tokens\"}\n\ndata: {\"type\":\"completion_summary\",\"num_tokens_prompt_total\":1,\"num_tokens_generated\":7}\n\n";

        // When they are parsed
        let events = HttpClient::parse_stream_event::<DeserializedCompletionEvent>(bytes);

        // Then the first event is a stream summary and the last event is a completion summary
        let first = events.first().unwrap().as_ref().unwrap();
        match first {
            DeserializedCompletionEvent::StreamSummary { finish_reason } => {
                assert_eq!(finish_reason, "maximum_tokens")
            }
            _ => panic!("Expected a stream summary"),
        }
        let second = events.last().unwrap().as_ref().unwrap();
        match second {
            DeserializedCompletionEvent::CompletionSummary {
                num_tokens_generated,
                ..
            } => {
                assert_eq!(*num_tokens_generated, 7)
            }
            _ => panic!("Expected a completion summary"),
        }
    }

    #[test]
    fn chat_usage_event_is_parsed() {
        // Given some bytes
        let bytes = b"data: {\"id\": \"67c5b5f2-6672-4b0b-82b1-cc844127b214\",\"choices\": [],\"created\": 1739539146,\"model\": \"pharia-1-llm-7b-control\",\"system_fingerprint\": \".unknown.\",\"object\": \"chat.completion.chunk\",\"usage\": {\"prompt_tokens\": 20,\"completion_tokens\": 10,\"total_tokens\": 30}}";

        // When they are parsed
        let events = HttpClient::parse_stream_event::<StreamChatResponse>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        // Then the event has a usage
        assert_eq!(event.usage.as_ref().unwrap().prompt_tokens, 20);
        assert_eq!(event.usage.as_ref().unwrap().completion_tokens, 10);
    }

    #[test]
    fn chat_stream_chunk_event_is_parsed() {
        // Given some bytes
        let bytes = b"data: {\"id\":\"831e41b4-2382-4b08-990e-0a3859967f43\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"logprobs\":null}],\"created\":1729782822,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        // When they are parsed
        let events = HttpClient::parse_stream_event::<StreamChatResponse>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        // Then the event is a chat stream chunk
        assert_eq!(event.choices.len(), 1);
        assert!(
            matches!(&event.choices[0], DeserializedChatChunk::Delta { delta: StreamMessage { role: Some(role), .. }, .. } if role == "assistant")
        );
    }

    #[test]
    fn chat_stream_chunk_without_role_is_parsed() {
        // Given some bytes without a role
        let bytes = b"data: {\"id\":\"a3ceca7f-32b2-4a6c-89e7-bc8eb5327f76\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"content\":\"Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.\"},\"logprobs\":null}],\"created\":1729784197,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        // When they are parsed
        let events = HttpClient::parse_stream_event::<StreamChatResponse>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        // Then the event is a chat stream chunk
        assert_eq!(event.choices.len(), 1);
        assert!(
            matches!(&event.choices[0], DeserializedChatChunk::Delta { delta: StreamMessage { content, .. }, .. } if content == "Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.")
        );
    }

    #[test]
    fn chat_stream_chunk_without_content_but_with_finish_reason_is_parsed() {
        // Given some bytes without a role or content but with a finish reason
        let bytes = b"data: {\"id\":\"a3ceca7f-32b2-4a6c-89e7-bc8eb5327f76\",\"choices\":[{\"finish_reason\":\"stop\",\"index\":0,\"delta\":{},\"logprobs\":null}],\"created\":1729784197,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        // When they are parsed
        let events = HttpClient::parse_stream_event::<StreamChatResponse>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        // Then the event is a chat stream chunk with a done event
        assert!(
            matches!(&event.choices[0], DeserializedChatChunk::Finished { finish_reason } if finish_reason == "stop")
        );
    }
}