aleph_alpha_client/http.rs

use std::{borrow::Cow, pin::Pin, time::Duration};

use futures_util::{stream::StreamExt, Stream};
use reqwest::{header, ClientBuilder, RequestBuilder, Response, StatusCode};
use serde::Deserialize;
use thiserror::Error as ThisError;
use tokenizers::Tokenizer;

use crate::{How, StreamJob};
use async_stream::stream;

/// A job sent to the Aleph Alpha API using the HTTP client. A job wraps all the knowledge required
/// for the Aleph Alpha API to specify its result. Notably it includes the model(s) the job is
/// executed on. This allows the trait to cover services which use more than one model and task
/// type to achieve their result. A bare [`crate::TaskCompletion`], on the other hand, cannot
/// implement this trait directly, since its result depends on which model is chosen to execute it.
/// You can remedy this by turning the completion task into a job via [`Task::with_model`].
pub trait Job {
    /// Output returned by [`crate::Client::output_of`]
    type Output;

    /// Expected answer of the Aleph Alpha API
    type ResponseBody: for<'de> Deserialize<'de>;

    /// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be
    /// already set.
    fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder;

    /// Parses the response of the server into higher level structs for the user.
    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;
}

/// A task sent to the Aleph Alpha API using the HTTP client. A model must be specified before it
/// can be executed.
pub trait Task {
    /// Output returned by [`crate::Client::output_of`]
    type Output;

    /// Expected answer of the Aleph Alpha API
    type ResponseBody: for<'de> Deserialize<'de>;

    /// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be
    /// already set.
    fn build_request(&self, client: &reqwest::Client, base: &str, model: &str) -> RequestBuilder;

    /// Parses the response of the server into higher level structs for the user.
    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;

    /// Turn your task into a [`Job`] by annotating it with a model name.
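    ///
    /// A minimal sketch, using the same completion task as the [`crate::Client::output_of`]
    /// example:
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Task, TaskCompletion};
    ///
    /// // A bare completion task does not know which model it runs on ...
    /// let task = TaskCompletion::from_text("An apple a day");
    /// // ... annotating it with a model name turns it into an executable job.
    /// let job = task.with_model("luminous-base");
    /// ```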
    fn with_model<'a>(&'a self, model: &'a str) -> MethodJob<'a, Self>
    where
        Self: Sized,
    {
        MethodJob { model, task: self }
    }
}

/// Enriches a [`Task`] into a [`Job`] by attaching the model it should be executed with. Use this
/// as input for [`crate::Client::output_of`].
pub struct MethodJob<'a, T> {
    /// Name of the Aleph Alpha Model. E.g. "luminous-base".
    pub model: &'a str,
    /// Task to be executed against the model.
    pub task: &'a T,
}

impl<T> Job for MethodJob<'_, T>
where
    T: Task,
{
    type Output = T::Output;

    type ResponseBody = T::ResponseBody;

    fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder {
        self.task.build_request(client, base, self.model)
    }

    fn body_to_output(&self, response: T::ResponseBody) -> T::Output {
        self.task.body_to_output(response)
    }
}

/// Sends HTTP requests to the Aleph Alpha API
pub struct HttpClient {
    base: String,
    http: reqwest::Client,
    api_token: Option<String>,
}

impl HttpClient {
    /// In production you would typically want to set this to
    /// <https://inference-api.pharia.your-company.com>, yet you may want to use a different
    /// instance for testing.
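    ///
    /// A minimal construction sketch; the host URL and token are placeholders, and this assumes
    /// `HttpClient` is re-exported at the crate root:
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Error, HttpClient};
    ///
    /// fn make_client() -> Result<HttpClient, Error> {
    ///     // The token may also be omitted here and supplied per request via `How` instead.
    ///     HttpClient::new(
    ///         "https://inference-api.example.com".to_owned(),
    ///         Some("<your token here>".to_owned()),
    ///     )
    /// }
    /// ```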
    pub fn new(host: String, api_token: Option<String>) -> Result<Self, Error> {
        let http = ClientBuilder::new().build()?;

        Ok(Self {
            base: host,
            http,
            api_token,
        })
    }

    /// Construct and execute a request, building on top of a `RequestBuilder`
    async fn response(&self, builder: RequestBuilder, how: &How) -> Result<Response, Error> {
        let query = if how.be_nice {
            [("nice", "true")].as_slice()
        } else {
            // nice=false is default, so we just omit it.
            [].as_slice()
        };

        let api_token = how
            .api_token
            .as_ref()
            .or(self.api_token.as_ref())
            .expect("API token needs to be set on client construction or per request");
        let response = builder
            .query(query)
            .header(header::AUTHORIZATION, Self::header_from_token(api_token))
            .timeout(how.client_timeout)
            .send()
            .await
            .map_err(|reqwest_error| {
                if reqwest_error.is_timeout() {
                    Error::ClientTimeout(how.client_timeout)
                } else {
                    reqwest_error.into()
                }
            })?;
        translate_http_error(response).await
    }

    /// Execute a task with the Aleph Alpha API and fetch its result.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, How, TaskCompletion, Task, Error};
    ///
    /// async fn print_completion() -> Result<(), Error> {
    ///     // Authenticate against the API. Reads the token from the environment.
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of the model we want to use. Larger models usually give better answers, but
    ///     // are also slower and more costly.
    ///     let model = "luminous-base";
    ///
    ///     // The task we want to perform. Here we want to continue the sentence: "An apple a day
    ///     // ..."
    ///     let task = TaskCompletion::from_text("An apple a day");
    ///
    ///     // Retrieve answer from API
    ///     let response = client.output_of(&task.with_model(model), &How::default()).await?;
    ///
    ///     // Print entire sentence with completion
    ///     println!("An apple a day{}", response.completion);
    ///     Ok(())
    /// }
    /// ```
    pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> {
        let builder = task.build_request(&self.http, &self.base);
        let response = self.response(builder, how).await?;
        let response_body: T::ResponseBody = response.json().await?;
        let answer = task.body_to_output(response_body);
        Ok(answer)
    }

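    /// Execute a stream job with the Aleph Alpha API and receive its results as a stream of
    /// events.
    ///
    /// A consumption sketch for an already constructed stream job. How the job is built depends
    /// on the concrete [`StreamJob`] implementation, and the assumptions here are that
    /// [`crate::Client`] forwards to this method and re-exports `StreamJob`:
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, Error, How, StreamJob};
    /// use futures_util::StreamExt;
    ///
    /// async fn print_all_events<T>(client: &Client, job: &T) -> Result<(), Error>
    /// where
    ///     T: StreamJob + Send + Sync,
    ///     T::Output: std::fmt::Debug + 'static,
    /// {
    ///     let mut events = client.stream_output_of(job, &How::default()).await?;
    ///     while let Some(event) = events.next().await {
    ///         // Each item is a `Result`, since an individual server-sent event may fail to
    ///         // deserialize.
    ///         println!("{:?}", event?);
    ///     }
    ///     Ok(())
    /// }
    /// ```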
    pub async fn stream_output_of<T: StreamJob>(
        &self,
        task: &T,
        how: &How,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<T::Output, Error>> + Send>>, Error>
    where
        T::Output: 'static,
    {
        let builder = task.build_request(&self.http, &self.base);
        let response = self.response(builder, how).await?;
        let mut stream = response.bytes_stream();

        Ok(Box::pin(stream! {
            while let Some(item) = stream.next().await {
                match item {
                    Ok(bytes) => {
                        let events = Self::parse_stream_event::<T::ResponseBody>(bytes.as_ref());
                        for event in events {
                            yield event.map(|b| T::body_to_output(b));
                        }
                    }
                    Err(e) => {
                        yield Err(e.into());
                    }
                }
            }
        }))
    }

    /// Take a byte slice (of an SSE chunk) and parse it into the provided response body type.
    /// Each SSE chunk is expected to contain one or more JSON bodies prefixed by `data: `.
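    ///
    /// For example, a chunk containing `data: {..}\n\ndata: {..}\n\n` yields two results; see the
    /// tests below for concrete payloads.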
    fn parse_stream_event<StreamBody>(bytes: &[u8]) -> Vec<Result<StreamBody, Error>>
    where
        StreamBody: for<'de> Deserialize<'de>,
    {
        String::from_utf8_lossy(bytes)
            .split("data: ")
            .skip(1)
            .map(|s| {
                serde_json::from_str(s).map_err(|e| Error::InvalidStream {
                    deserialization_error: e.to_string(),
                })
            })
            .collect()
    }

    fn header_from_token(api_token: &str) -> header::HeaderValue {
        let mut auth_value = header::HeaderValue::from_str(&format!("Bearer {api_token}")).unwrap();
        // Mark the authorization header as sensitive, so it is masked in debug output.
        auth_value.set_sensitive(true);
        auth_value
    }

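    /// Download and parse the tokenizer of a given model.
    ///
    /// A usage sketch; the model name is a placeholder, and the assumption is that
    /// [`crate::Client`] forwards to this method:
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, Error};
    ///
    /// async fn count_tokens() -> Result<usize, Error> {
    ///     let client = Client::from_env()?;
    ///     let tokenizer = client.tokenizer_by_model("luminous-base", None).await?;
    ///     // `encode` comes from the `tokenizers` crate; `false` means no special tokens added.
    ///     let encoding = tokenizer.encode("An apple a day", false).unwrap();
    ///     Ok(encoding.get_tokens().len())
    /// }
    /// ```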
    pub async fn tokenizer_by_model(
        &self,
        model: &str,
        api_token: Option<String>,
    ) -> Result<Tokenizer, Error> {
        let api_token = api_token
            .as_ref()
            .or(self.api_token.as_ref())
            .expect("API token needs to be set on client construction or per request");
        let response = self
            .http
            .get(format!("{}/models/{model}/tokenizer", self.base))
            .header(header::AUTHORIZATION, Self::header_from_token(api_token))
            .send()
            .await?;
        let response = translate_http_error(response).await?;
        let bytes = response.bytes().await?;
        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| Error::InvalidTokenizer {
            deserialization_error: e.to_string(),
        })?;
        Ok(tokenizer)
    }
}

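/// Map unsuccessful HTTP responses to the more descriptive variants of [`Error`], keeping the raw
/// body around for the generic [`Error::Http`] case.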
async fn translate_http_error(response: reqwest::Response) -> Result<reqwest::Response, Error> {
    let status = response.status();
    if !status.is_success() {
        // Store the body in a variable, so we can still forward the error message even if it was
        // not emitted by the API itself but by an intermediate proxy like NGINX.
        let body = response.text().await?;
        // If the response is an error emitted by the API, this deserialization should succeed.
        let api_error: Result<ApiError, _> = serde_json::from_str(&body);
        let translated_error = match status {
            StatusCode::TOO_MANY_REQUESTS => Error::TooManyRequests,
            StatusCode::SERVICE_UNAVAILABLE => {
                // Presence of `api_error` implies the error originated from the API itself (rather
                // than the intermediate proxy) and so we can decode it as such.
                if api_error.is_ok_and(|error| error.code == "QUEUE_FULL") {
                    Error::Busy
                } else {
                    Error::Unavailable
                }
            }
            _ => Error::Http {
                status: status.as_u16(),
                body,
            },
        };
        Err(translated_error)
    } else {
        Ok(response)
    }
}

/// We are only interested in the error code emitted by the API.
#[derive(Deserialize, Debug)]
struct ApiError<'a> {
    /// Unique string in capital letters emitted by the API to signal different kinds of errors at
    /// a finer granularity than the HTTP status codes alone would allow for.
    ///
    /// E.g. differentiating between request rate limiting and parallel task limiting, which both
    /// map to 429 (though the former is emitted by NGINX).
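    ///
    /// For instance, a 503 response with body `{"code": "QUEUE_FULL"}` is translated into
    /// [`Error::Busy`] rather than the generic [`Error::Unavailable`].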
    code: Cow<'a, str>,
}

/// Errors returned by the Aleph Alpha Client
#[derive(ThisError, Debug)]
pub enum Error {
    /// User exceeded their current task quota.
    #[error(
        "You are trying to send too many requests to the API in too short an interval. Slow down \
        a bit, otherwise these errors will persist. Sorry for this, but we try to prevent DOS \
        attacks."
    )]
    TooManyRequests,
    /// Model is busy. Most likely due to many other users requesting its services right now.
    #[error(
        "Sorry, the request to the Aleph Alpha API has been rejected due to the requested model \
        being very busy at the moment. We found it unlikely that your request would finish in a \
        reasonable timeframe, so it was rejected right away, rather than making you wait. You are \
        welcome to retry your request any time."
    )]
    Busy,
    /// The API itself is unavailable, most likely due to a restart.
    #[error(
        "The service is currently unavailable. This is likely due to a restart. Please try again \
        later."
    )]
    Unavailable,
    #[error("No response received within given timeout: {0:?}")]
    ClientTimeout(Duration),
    /// An error on the HTTP protocol level.
    #[error("HTTP request failed with status code {}. Body:\n{}", status, body)]
    Http { status: u16, body: String },
    #[error(
        "Tokenizer could not be correctly deserialized. Caused by:\n{}",
        deserialization_error
    )]
    InvalidTokenizer { deserialization_error: String },
    /// Deserialization error of the stream event.
    #[error(
        "Stream event could not be correctly deserialized. Caused by:\n{}.",
        deserialization_error
    )]
    InvalidStream { deserialization_error: String },
    /// Most likely either TLS errors when creating the client, or IO errors.
    #[error(transparent)]
    Other(#[from] reqwest::Error),
}

#[cfg(test)]
mod tests {
    use crate::{chat::ChatEvent, completion::CompletionEvent};

    use super::*;

    #[test]
    fn stream_chunk_event_is_parsed() {
        // Given some bytes
        let bytes = b"data: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\" The New York Times, May 15\"}\n\n";

        // When they are parsed
        let events = HttpClient::parse_stream_event::<CompletionEvent>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        // Then the event is a stream chunk
        match event {
            CompletionEvent::StreamChunk(chunk) => assert_eq!(chunk.index, 0),
            _ => panic!("Expected a stream chunk"),
        }
    }

    #[test]
    fn completion_summary_event_is_parsed() {
        // Given some bytes with a stream summary and a completion summary
        let bytes = b"data: {\"type\":\"stream_summary\",\"index\":0,\"model_version\":\"2022-04\",\"finish_reason\":\"maximum_tokens\"}\n\ndata: {\"type\":\"completion_summary\",\"num_tokens_prompt_total\":1,\"num_tokens_generated\":7}\n\n";

        // When they are parsed
        let events = HttpClient::parse_stream_event::<CompletionEvent>(bytes);

        // Then the first event is a stream summary and the last event is a completion summary
        let first = events.first().unwrap().as_ref().unwrap();
        match first {
            CompletionEvent::StreamSummary(summary) => {
                assert_eq!(summary.finish_reason, "maximum_tokens")
            }
            _ => panic!("Expected a stream summary"),
        }
        let second = events.last().unwrap().as_ref().unwrap();
        match second {
            CompletionEvent::CompletionSummary(summary) => {
                assert_eq!(summary.num_tokens_generated, 7)
            }
            _ => panic!("Expected a completion summary"),
        }
    }

    #[test]
    fn chat_stream_chunk_event_is_parsed() {
        // Given some bytes
        let bytes = b"data: {\"id\":\"831e41b4-2382-4b08-990e-0a3859967f43\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"logprobs\":null}],\"created\":1729782822,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        // When they are parsed
        let events = HttpClient::parse_stream_event::<ChatEvent>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        // Then the event is a chat stream chunk
        assert_eq!(event.choices[0].delta.role.as_ref().unwrap(), "assistant");
    }

    #[test]
    fn chat_stream_chunk_without_role_is_parsed() {
        // Given some bytes without a role
        let bytes = b"data: {\"id\":\"a3ceca7f-32b2-4a6c-89e7-bc8eb5327f76\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"content\":\"Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.\"},\"logprobs\":null}],\"created\":1729784197,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

        // When they are parsed
        let events = HttpClient::parse_stream_event::<ChatEvent>(bytes);
        let event = events.first().unwrap().as_ref().unwrap();

        // Then the event is a chat stream chunk
        assert_eq!(event.choices[0].delta.content, "Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.");
    }
}