aleph_alpha_client/
http.rs

1use std::{borrow::Cow, pin::Pin, time::Duration};
2
3use futures_util::{stream::StreamExt, Stream};
4use reqwest::{header, ClientBuilder, RequestBuilder, Response, StatusCode};
5use serde::Deserialize;
6use thiserror::Error as ThisError;
7use tokenizers::Tokenizer;
8
9use crate::{How, StreamJob};
10use async_stream::stream;
11
/// A job sent to the Aleph Alpha Api using the http client. A job wraps all the knowledge required
/// for the Aleph Alpha API to specify its result. Notably it includes the model(s) the job is
/// executed on. This allows this trait to hold in the presence of services, which use more than one
/// model and task type to achieve their result. On the other hand a bare [`crate::TaskCompletion`]
/// cannot implement this trait directly, since its result would depend on what model is chosen to
/// execute it. You can remedy this by turning a completion task into a job, calling
/// [`Task::with_model`].
pub trait Job {
    /// Output returned by [`crate::Client::output_of`]
    type Output;

    /// Expected answer of the Aleph Alpha API
    type ResponseBody: for<'de> Deserialize<'de>;

    /// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be
    /// already set.
    fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder;

    /// Parses the response of the server into higher level structs for the user.
    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;
}
33
/// A task sent to the Aleph Alpha Api using the http client. Requires to specify a model before it
/// can be executed. Use [`Task::with_model`] to turn it into a [`Job`].
pub trait Task {
    /// Output returned by [`crate::Client::output_of`]
    type Output;

    /// Expected answer of the Aleph Alpha API
    type ResponseBody: for<'de> Deserialize<'de>;

    /// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be
    /// already set.
    fn build_request(&self, client: &reqwest::Client, base: &str, model: &str) -> RequestBuilder;

    /// Parses the response of the server into higher level structs for the user.
    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;

    /// Turn your task into a [`Job`] by annotating it with a model name. The returned
    /// [`MethodJob`] borrows both the task and the model name.
    fn with_model<'a>(&'a self, model: &'a str) -> MethodJob<'a, Self>
    where
        Self: Sized,
    {
        MethodJob { model, task: self }
    }
}
58
/// Enriches the `Task` to a `Job` by appending the model it should be executed with. Use this as
/// input for [`Client::output_of`]. Usually created via [`Task::with_model`] rather than
/// constructed directly.
pub struct MethodJob<'a, T> {
    /// Name of the Aleph Alpha Model. E.g. "luminous-base".
    pub model: &'a str,
    /// Task to be executed against the model.
    pub task: &'a T,
}
67
68impl<T> Job for MethodJob<'_, T>
69where
70    T: Task,
71{
72    type Output = T::Output;
73
74    type ResponseBody = T::ResponseBody;
75
76    fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder {
77        self.task.build_request(client, base, self.model)
78    }
79
80    fn body_to_output(&self, response: T::ResponseBody) -> T::Output {
81        self.task.body_to_output(response)
82    }
83}
84
/// Sends HTTP request to the Aleph Alpha API
pub struct HttpClient {
    /// Base URL of the inference API, without a trailing slash.
    base: String,
    /// Underlying reqwest client, shared by all requests made through this instance.
    http: reqwest::Client,
    /// Default API token. Can be overridden per request via `How::api_token`; if neither is
    /// set, executing a request panics.
    api_token: Option<String>,
}
91
impl HttpClient {
    /// In production you typically would want to set this to <https://inference-api.pharia.your-company.com>.
    /// Yet you may want to use a different instance for testing.
    ///
    /// `api_token` is the default token used for authentication. It can be overridden per
    /// request via `How::api_token`; if both are absent, request execution panics.
    ///
    /// # Errors
    ///
    /// Fails if the underlying HTTP client cannot be constructed (e.g. TLS backend issues).
    pub fn new(host: String, api_token: Option<String>) -> Result<Self, Error> {
        let http = ClientBuilder::new().build()?;

        Ok(Self {
            base: host,
            http,
            api_token,
        })
    }

    /// Construct and execute a request building on top of a `RequestBuilder`
    ///
    /// Applies the `nice` query flag, the authorization header and the client timeout from
    /// `how`, sends the request, and translates error status codes into [`Error`] variants.
    ///
    /// # Panics
    ///
    /// Panics if no API token is set, neither on the client nor in `how`.
    async fn response(&self, builder: RequestBuilder, how: &How) -> Result<Response, Error> {
        let query = if how.be_nice {
            [("nice", "true")].as_slice()
        } else {
            // nice=false is default, so we just omit it.
            [].as_slice()
        };

        // The per-request token takes precedence over the token stored on the client.
        let api_token = how
            .api_token
            .as_ref()
            .or(self.api_token.as_ref())
            .expect("API token needs to be set on client construction or per request");
        let response = builder
            .query(query)
            .header(header::AUTHORIZATION, Self::header_from_token(api_token))
            .timeout(how.client_timeout)
            .send()
            .await
            .map_err(|reqwest_error| {
                // Surface timeouts as a dedicated variant so callers can react (e.g. retry
                // with a larger budget); other transport errors become `Error::Other`.
                if reqwest_error.is_timeout() {
                    Error::ClientTimeout(how.client_timeout)
                } else {
                    reqwest_error.into()
                }
            })?;
        translate_http_error(response).await
    }

    /// Execute a task with the aleph alpha API and fetch its result.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, How, TaskCompletion, Task, Error};
    ///
    /// async fn print_completion() -> Result<(), Error> {
    ///     // Authenticate against API. Fetches token.
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of the model we want to use. Large models give usually better answer, but are
    ///     // also slower and more costly.
    ///     let model = "luminous-base";
    ///
    ///     // The task we want to perform. Here we want to continue the sentence: "An apple a day
    ///     // ..."
    ///     let task = TaskCompletion::from_text("An apple a day");
    ///
    ///     // Retrieve answer from API
    ///     let response = client.output_of(&task.with_model(model), &How::default()).await?;
    ///
    ///     // Print entire sentence with completion
    ///     println!("An apple a day{}", response.completion);
    ///     Ok(())
    /// }
    /// ```
    pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> {
        let builder = task.build_request(&self.http, &self.base);
        let response = self.response(builder, how).await?;
        // Deserialize the entire body in one go, then let the job translate it for the user.
        let response_body: T::ResponseBody = response.json().await?;
        let answer = task.body_to_output(response_body);
        Ok(answer)
    }

    /// Execute a streaming task and return a stream of parsed output events.
    ///
    /// Each chunk of the HTTP response body is parsed as Server-Sent Events and every event's
    /// JSON payload is deserialized into `T::ResponseBody`, then mapped to `T::Output`.
    pub async fn stream_output_of<T: StreamJob>(
        &self,
        task: &T,
        how: &How,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<T::Output, Error>> + Send>>, Error>
    where
        T::Output: 'static,
    {
        let builder = task.build_request(&self.http, &self.base);
        let response = self.response(builder, how).await?;
        let mut stream = response.bytes_stream();

        Ok(Box::pin(stream! {
            while let Some(item) = stream.next().await {
                match item {
                    Ok(bytes) => {
                        // NOTE(review): assumes each network chunk contains only whole SSE
                        // events; an event split across two chunks (or mid UTF-8 sequence)
                        // would fail to deserialize — TODO confirm the server flushes whole
                        // events per chunk.
                        let events = Self::parse_stream_event::<T::ResponseBody>(bytes.as_ref());
                        for event in events {
                            yield event.map(|b| T::body_to_output(b));
                        }
                    }
                    Err(e) => {
                        // Transport errors are forwarded as stream items rather than ending
                        // the stream, leaving the decision to the consumer.
                        yield Err(e.into());
                    }
                }
            }
        }))
    }

    /// Take a byte slice (of a SSE) and parse it into a provided response body.
    /// Each SSE event is expected to contain one or multiple JSON bodies prefixed by `data: `.
    ///
    /// Returns one `Result` per `data: ` payload found; a payload that fails to deserialize
    /// yields [`Error::InvalidStream`] for that entry only.
    fn parse_stream_event<StreamBody>(bytes: &[u8]) -> Vec<Result<StreamBody, Error>>
    where
        StreamBody: for<'de> Deserialize<'de>,
    {
        // NOTE(review): splitting on the literal "data: " relies on that substring never
        // occurring inside a JSON payload — verify against the event bodies the API emits.
        String::from_utf8_lossy(bytes)
            .split("data: ")
            .skip(1)
            .map(|s| {
                serde_json::from_str(s).map_err(|e| Error::InvalidStream {
                    deserialization_error: e.to_string(),
                })
            })
            .collect()
    }

    /// Build a `Bearer` authorization header value from the given token.
    fn header_from_token(api_token: &str) -> header::HeaderValue {
        let mut auth_value = header::HeaderValue::from_str(&format!("Bearer {api_token}")).unwrap();
        // Mark the header as sensitive so the token is redacted from Debug output and logs.
        auth_value.set_sensitive(true);
        auth_value
    }

    /// Fetch the tokenizer definition for `model` from the API and deserialize it.
    ///
    /// # Panics
    ///
    /// Panics if no API token is set, neither on the client nor via `api_token`.
    pub async fn tokenizer_by_model(
        &self,
        model: &str,
        api_token: Option<String>,
    ) -> Result<Tokenizer, Error> {
        // Same precedence as in `response`: the explicit argument wins over the client default.
        let api_token = api_token
            .as_ref()
            .or(self.api_token.as_ref())
            .expect("API token needs to be set on client construction or per request");
        let response = self
            .http
            .get(format!("{}/models/{model}/tokenizer", self.base))
            .header(header::AUTHORIZATION, Self::header_from_token(api_token))
            .send()
            .await?;
        let response = translate_http_error(response).await?;
        let bytes = response.bytes().await?;
        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| Error::InvalidTokenizer {
            deserialization_error: e.to_string(),
        })?;
        Ok(tokenizer)
    }
}
244
245async fn translate_http_error(response: reqwest::Response) -> Result<reqwest::Response, Error> {
246    let status = response.status();
247    if !status.is_success() {
248        // Store body in a variable, so we can use it, even if it is not an Error emitted by
249        // the API, but an intermediate Proxy like NGinx, so we can still forward the error
250        // message.
251        let body = response.text().await?;
252        // If the response is an error emitted by the API, this deserialization should succeed.
253        let api_error: Result<ApiError, _> = serde_json::from_str(&body);
254        let translated_error = match status {
255            StatusCode::TOO_MANY_REQUESTS => Error::TooManyRequests,
256            StatusCode::SERVICE_UNAVAILABLE => {
257                // Presence of `api_error` implies the error originated from the API itself (rather
258                // than the intermediate proxy) and so we can decode it as such.
259                if api_error.is_ok_and(|error| error.code == "QUEUE_FULL") {
260                    Error::Busy
261                } else {
262                    Error::Unavailable
263                }
264            }
265            _ => Error::Http {
266                status: status.as_u16(),
267                body,
268            },
269        };
270        Err(translated_error)
271    } else {
272        Ok(response)
273    }
274}
275
276/// We are only interested in the status codes of the API.
277#[derive(Deserialize, Debug)]
278struct ApiError<'a> {
279    /// Unique string in capital letters emitted by the API to signal different kinds of errors in a
280    /// finer granularity then the HTTP status codes alone would allow for.
281    ///
282    /// E.g. Differentiating between request rate limiting and parallel tasks limiting which both
283    /// are 429 (the former is emitted by NGinx though).
284    code: Cow<'a, str>,
285}
286
287/// Errors returned by the Aleph Alpha Client
288#[derive(ThisError, Debug)]
289pub enum Error {
290    /// User exceeds his current Task Quota.
291    #[error(
292        "You are trying to send too many requests to the API in to short an interval. Slow down a \
293        bit, otherwise these error will persist. Sorry for this, but we try to prevent DOS attacks."
294    )]
295    TooManyRequests,
296    /// Model is busy. Most likely due to many other users requesting its services right now.
297    #[error(
298        "Sorry the request to the Aleph Alpha API has been rejected due to the requested model \
299        being very busy at the moment. We found it unlikely that your request would finish in a \
300        reasonable timeframe, so it was rejected right away, rather than make you wait. You are \
301        welcome to retry your request any time."
302    )]
303    Busy,
304    /// The API itself is unavailable, most likely due to restart.
305    #[error(
306        "The service is currently unavailable. This is likely due to restart. Please try again \
307        later."
308    )]
309    Unavailable,
310    #[error("No response received within given timeout: {0:?}")]
311    ClientTimeout(Duration),
312    /// An error on the Http Protocol level.
313    #[error("HTTP request failed with status code {}. Body:\n{}", status, body)]
314    Http { status: u16, body: String },
315    #[error(
316        "Tokenizer could not be correctly deserialized. Caused by:\n{}",
317        deserialization_error
318    )]
319    InvalidTokenizer { deserialization_error: String },
320    /// Deserialization error of the stream event.
321    #[error(
322        "Stream event could not be correctly deserialized. Caused by:\n{}.",
323        deserialization_error
324    )]
325    InvalidStream { deserialization_error: String },
326    /// Most likely either TLS errors creating the Client, or IO errors.
327    #[error(transparent)]
328    Other(#[from] reqwest::Error),
329}
330
#[cfg(test)]
mod tests {
    use crate::{chat::ChatEvent, completion::CompletionEvent};

    use super::*;

    #[test]
    fn stream_chunk_event_is_parsed() {
        // Given some bytes
        let bytes = b"data: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\" The New York Times, May 15\"}\n\n";

        // When they are parsed
        let events = HttpClient::parse_stream_event::<CompletionEvent>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        // Then the event is a stream chunk
        match event {
            CompletionEvent::StreamChunk(chunk) => assert_eq!(chunk.index, 0),
            _ => panic!("Expected a stream chunk"),
        }
    }

    #[test]
    fn completion_summary_event_is_parsed() {
        // Given some bytes with a stream summary and a completion summary
        let bytes = b"data: {\"type\":\"stream_summary\",\"index\":0,\"model_version\":\"2022-04\",\"finish_reason\":\"maximum_tokens\"}\n\ndata: {\"type\":\"completion_summary\",\"num_tokens_prompt_total\":1,\"num_tokens_generated\":7}\n\n";

        // When they are parsed
        let events = HttpClient::parse_stream_event::<CompletionEvent>(bytes);

        // Then the first event is a stream summary and the last event is a completion summary
        let first = events.first().unwrap().as_ref().unwrap();
        match first {
            CompletionEvent::StreamSummary(summary) => {
                assert_eq!(summary.finish_reason, "maximum_tokens")
            }
            // Fixed copy-paste: this arm expects a *stream* summary, not a completion summary.
            _ => panic!("Expected a stream summary"),
        }
        let second = events.last().unwrap().as_ref().unwrap();
        match second {
            CompletionEvent::CompletionSummary(summary) => {
                assert_eq!(summary.num_tokens_generated, 7)
            }
            _ => panic!("Expected a completion summary"),
        }
    }

    #[test]
    fn chat_stream_chunk_event_is_parsed() {
        // Given some bytes
        let bytes = b"data: {\"id\":\"831e41b4-2382-4b08-990e-0a3859967f43\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"logprobs\":null}],\"created\":1729782822,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        // When they are parsed
        let events = HttpClient::parse_stream_event::<ChatEvent>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        // Then the event is a chat stream chunk
        assert_eq!(event.choices[0].delta.role.as_ref().unwrap(), "assistant");
    }

    #[test]
    fn chat_stream_chunk_without_role_is_parsed() {
        // Given some bytes without a role
        let bytes = b"data: {\"id\":\"a3ceca7f-32b2-4a6c-89e7-bc8eb5327f76\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"content\":\"Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.\"},\"logprobs\":null}],\"created\":1729784197,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        // When they are parsed
        let events = HttpClient::parse_stream_event::<ChatEvent>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        // Then the event is a chat stream chunk
        assert_eq!(event.choices[0].delta.content, "Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.");
    }
}
403}