aleph_alpha_client/lib.rs
//! Usage sample
//!
//! ```no_run
//! use aleph_alpha_client::{Client, TaskCompletion, How};
//!
//! #[tokio::main(flavor = "current_thread")]
//! async fn main() {
//!     // Authenticate against the API. Reads the token from the environment.
//!     let client = Client::from_env().unwrap();
//!
//!     // Name of the model we want to use. Larger models usually give better answers, but are
//!     // also more costly.
//!     let model = "luminous-base";
//!
//!     // The task we want to perform. Here we want to continue the sentence: "An apple a day ..."
//!     let task = TaskCompletion::from_text("An apple a day");
//!
//!     // Retrieve the answer from the API
//!     let response = client.completion(&task, model, &How::default()).await.unwrap();
//!
//!     // Print entire sentence with completion
//!     println!("An apple a day{}", response.completion);
//! }
//! ```

mod chat;
mod completion;
mod detokenization;
mod explanation;
mod http;
mod image_preprocessing;
mod logprobs;
mod prompt;
mod semantic_embedding;
mod stream;
mod tokenization;
use dotenvy::dotenv;
use futures_util::Stream;
use http::HttpClient;
use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};
use std::env;
use std::{pin::Pin, time::Duration};
use tokenizers::Tokenizer;

pub use self::{
    chat::{ChatEvent, ChatOutput, ChatSampling, Distribution, Message, TaskChat, Usage},
    completion::{CompletionEvent, CompletionOutput, Sampling, Stopping, TaskCompletion},
    detokenization::{DetokenizationOutput, TaskDetokenization},
    explanation::{
        Explanation, ExplanationOutput, Granularity, ImageScore, ItemExplanation,
        PromptGranularity, TaskExplanation, TextScore,
    },
    http::{Error, Job, Task},
    logprobs::{Logprob, Logprobs},
    prompt::{Modality, Prompt},
    semantic_embedding::{
        SemanticRepresentation, TaskBatchSemanticEmbedding, TaskSemanticEmbedding,
    },
    stream::{StreamJob, StreamTask},
    tokenization::{TaskTokenization, TokenizationOutput},
};

/// Execute Jobs against the Aleph Alpha API
pub struct Client {
    /// This client does all the work of sending the requests and talking to the AA API. The only
    /// additional knowledge added by this layer is that it knows about the individual jobs which
    /// can be executed, which allows for an alternative non-generic interface which might produce
    /// easier-to-read code for the end user in many use cases.
    http_client: HttpClient,
}

impl Client {
    /// A new instance of an Aleph Alpha client helping you interact with the Aleph Alpha API.
    ///
    /// Setting the token to None allows specifying it on a per request basis.
    /// You may want to only use request based authentication and skip default authentication. This
    /// is useful if writing an application which invokes the client on behalf of many different
    /// users. Having neither request nor default authentication is considered a bug and will cause
    /// a panic.
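    ///
    /// A minimal construction sketch; the host URL and token below are placeholders, not a real
    /// endpoint or credential:
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, Error};
    ///
    /// fn make_client() -> Result<Client, Error> {
    ///     // The token passed here is used for every request unless overridden per request.
    ///     Client::new("https://inference-api.example.com", Some("YOUR_API_TOKEN".to_owned()))
    /// }
    /// ```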
    pub fn new(host: impl Into<String>, api_token: Option<String>) -> Result<Self, Error> {
        let http_client = HttpClient::new(host.into(), api_token)?;
        Ok(Self { http_client })
    }

    /// A client instance that always uses the same token for all requests.
    pub fn with_auth(host: impl Into<String>, api_token: impl Into<String>) -> Result<Self, Error> {
        Self::new(host, Some(api_token.into()))
    }

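    /// Construct a client from the `PHARIA_AI_TOKEN` and `INFERENCE_URL` environment variables.
    ///
    /// A `.env` file in the current working directory is loaded first (via `dotenvy`), so both
    /// variables may also be provided there. Panics if either variable is not set.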
    pub fn from_env() -> Result<Self, Error> {
        let _ = dotenv();
        let api_token = env::var("PHARIA_AI_TOKEN").unwrap();
        let inference_url = env::var("INFERENCE_URL").unwrap();
        Self::with_auth(inference_url, api_token)
    }

    /// Execute a task with the aleph alpha API and fetch its result.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, How, TaskCompletion, Error};
    ///
    /// async fn print_completion() -> Result<(), Error> {
    ///     // Authenticate against the API. Reads the token from the environment.
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of the model we want to use. Larger models usually give better answers, but are
    ///     // also slower and more costly.
    ///     let model = "luminous-base";
    ///
    ///     // The task we want to perform. Here we want to continue the sentence: "An apple a day
    ///     // ..."
    ///     let task = TaskCompletion::from_text("An apple a day");
    ///
    ///     // Retrieve answer from API
    ///     let response = client.execute(model, &task, &How::default()).await?;
    ///
    ///     // Print entire sentence with completion
    ///     println!("An apple a day{}", response.completion);
    ///     Ok(())
    /// }
    /// ```
    #[deprecated = "Please use output_of instead."]
    pub async fn execute<T: Task>(
        &self,
        model: &str,
        task: &T,
        how: &How,
    ) -> Result<T::Output, Error> {
        self.output_of(&task.with_model(model), how).await
    }

    /// Execute any task with the aleph alpha API and fetch its result. This is most useful in
    /// generic code when you want to execute arbitrary task types. Otherwise prefer methods taking
    /// concrete tasks like [`Self::completion`] for improved readability.
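    ///
    /// A sketch of a generic helper built on top of this method; `run_job` is not part of this
    /// crate, merely an illustration:
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, Error, How, Job};
    ///
    /// async fn run_job<T: Job>(client: &Client, job: &T, how: &How) -> Result<T::Output, Error> {
    ///     client.output_of(job, how).await
    /// }
    /// ```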
    pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> {
        self.http_client.output_of(task, how).await
    }

    /// An embedding trying to capture the semantic meaning of a text. Cosine similarity can be
    /// used to find out how well two texts (or multimodal prompts) match. Useful for search use
    /// cases.
    ///
    /// See the example for [`cosine_similarity`].
    pub async fn semantic_embedding(
        &self,
        task: &TaskSemanticEmbedding<'_>,
        how: &How,
    ) -> Result<SemanticEmbeddingOutput, Error> {
        self.http_client.output_of(task, how).await
    }

    /// A batch of embeddings, each trying to capture the semantic meaning of a text.
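    ///
    /// A usage sketch; the `prompts`, `representation` and `compress_to_size` field names are
    /// assumed to mirror [`TaskSemanticEmbedding`] and may need adjusting:
    ///
    /// ```no_run
    /// use aleph_alpha_client::{
    ///     Client, Error, How, Prompt, SemanticRepresentation, TaskBatchSemanticEmbedding,
    /// };
    ///
    /// async fn embed_documents() -> Result<(), Error> {
    ///     let client = Client::from_env()?;
    ///     // Field names assumed to mirror TaskSemanticEmbedding.
    ///     let task = TaskBatchSemanticEmbedding {
    ///         prompts: vec![
    ///             Prompt::from_text("An apple a day"),
    ///             Prompt::from_text("Keeps the doctor away"),
    ///         ],
    ///         representation: SemanticRepresentation::Document,
    ///         compress_to_size: Some(128),
    ///     };
    ///     let _output = client.batch_semantic_embedding(&task, &How::default()).await?;
    ///     Ok(())
    /// }
    /// ```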
    pub async fn batch_semantic_embedding(
        &self,
        task: &TaskBatchSemanticEmbedding<'_>,
        how: &How,
    ) -> Result<BatchSemanticEmbeddingOutput, Error> {
        self.http_client.output_of(task, how).await
    }

    /// Instruct a model served by the aleph alpha API to continue writing a piece of text (or
    /// multimodal document).
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, How, TaskCompletion, Task, Error};
    ///
    /// async fn print_completion() -> Result<(), Error> {
    ///     // Authenticate against the API. Reads the token from the environment.
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of the model we want to use. Larger models usually give better answers, but are
    ///     // also slower and more costly.
    ///     let model = "luminous-base";
    ///
    ///     // The task we want to perform. Here we want to continue the sentence: "An apple a day
    ///     // ..."
    ///     let task = TaskCompletion::from_text("An apple a day");
    ///
    ///     // Retrieve answer from API
    ///     let response = client.completion(&task, model, &How::default()).await?;
    ///
    ///     // Print entire sentence with completion
    ///     println!("An apple a day{}", response.completion);
    ///     Ok(())
    /// }
    /// ```
    pub async fn completion(
        &self,
        task: &TaskCompletion<'_>,
        model: &str,
        how: &How,
    ) -> Result<CompletionOutput, Error> {
        self.http_client
            .output_of(&Task::with_model(task, model), how)
            .await
    }

    /// Instruct a model served by the aleph alpha API to continue writing a piece of text.
    /// Stream the response as a series of events.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, How, TaskCompletion, Error, CompletionEvent};
    /// use futures_util::StreamExt;
    ///
    /// async fn print_stream_completion() -> Result<(), Error> {
    ///     // Authenticate against the API. Reads the token from the environment.
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of the model we want to use. Larger models usually give better answers, but are
    ///     // also slower and more costly.
    ///     let model = "luminous-base";
    ///
    ///     // The task we want to perform. Here we want to continue the sentence: "An apple a day
    ///     // ..."
    ///     let task = TaskCompletion::from_text("An apple a day");
    ///
    ///     // Retrieve stream from API
    ///     let mut stream = client.stream_completion(&task, model, &How::default()).await?;
    ///     while let Some(Ok(event)) = stream.next().await {
    ///         if let CompletionEvent::Delta { completion, logprobs: _ } = event {
    ///             println!("{}", completion);
    ///         }
    ///     }
    ///     Ok(())
    /// }
    /// ```
    pub async fn stream_completion<'task>(
        &self,
        task: &'task TaskCompletion<'task>,
        model: &'task str,
        how: &How,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<CompletionEvent, Error>> + Send + 'task>>, Error>
    {
        self.http_client
            .stream_output_of(StreamTask::with_model(task, model), how)
            .await
    }

    /// Send a chat message to a model.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, How, TaskChat, Error, Message};
    ///
    /// async fn print_chat() -> Result<(), Error> {
    ///     // Authenticate against the API. Reads the token from the environment.
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of a model that supports chat.
    ///     let model = "pharia-1-llm-7b-control";
    ///
    ///     // Create a chat task with a user message.
    ///     let message = Message::user("Hello, how are you?");
    ///     let task = TaskChat::with_message(message);
    ///
    ///     // Send the message to the model.
    ///     let response = client.chat(&task, model, &How::default()).await?;
    ///
    ///     // Print the model response
    ///     println!("{}", response.message.content);
    ///     Ok(())
    /// }
    /// ```
    pub async fn chat(
        &self,
        task: &TaskChat<'_>,
        model: &str,
        how: &How,
    ) -> Result<ChatOutput, Error> {
        self.http_client
            .output_of(&Task::with_model(task, model), how)
            .await
    }

    /// Send a chat message to a model. Stream the response as a series of events.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, How, TaskChat, Error, Message, ChatEvent};
    /// use futures_util::StreamExt;
    ///
    /// async fn print_stream_chat() -> Result<(), Error> {
    ///     // Authenticate against the API. Reads the token from the environment.
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of a model that supports chat.
    ///     let model = "pharia-1-llm-7b-control";
    ///
    ///     // Create a chat task with a user message.
    ///     let message = Message::user("Hello, how are you?");
    ///     let task = TaskChat::with_message(message);
    ///
    ///     // Send the message to the model.
    ///     let mut stream = client.stream_chat(&task, model, &How::default()).await?;
    ///     while let Some(Ok(event)) = stream.next().await {
    ///         if let ChatEvent::Delta { content, logprobs: _ } = event {
    ///             println!("{}", content);
    ///         }
    ///     }
    ///     Ok(())
    /// }
    /// ```
    pub async fn stream_chat<'task>(
        &self,
        task: &'task TaskChat<'_>,
        model: &'task str,
        how: &How,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatEvent, Error>> + Send + 'task>>, Error> {
        self.http_client
            .stream_output_of(StreamTask::with_model(task, model), how)
            .await
    }

    /// Returns an explanation given a prompt and a target (typically generated
    /// by a previous completion request). The explanation describes how individual parts
    /// of the prompt influenced the target.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, How, TaskCompletion, Task, Error, Granularity,
    ///     TaskExplanation, Stopping, Prompt, Sampling, Logprobs};
    ///
    /// async fn print_explanation() -> Result<(), Error> {
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of the model we want to use. Larger models usually give better answers, but are
    ///     // also slower and more costly.
    ///     let model = "luminous-base";
    ///
    ///     // Input for the completion
    ///     let prompt = Prompt::from_text("An apple a day");
    ///
    ///     let task = TaskCompletion {
    ///         prompt: prompt.clone(),
    ///         stopping: Stopping::from_maximum_tokens(10),
    ///         sampling: Sampling::MOST_LIKELY,
    ///         special_tokens: false,
    ///         logprobs: Logprobs::No,
    ///     };
    ///     let response = client.completion(&task, model, &How::default()).await?;
    ///
    ///     let task = TaskExplanation {
    ///         prompt,                       // same input as for completion
    ///         target: &response.completion, // output of completion
    ///         granularity: Granularity::default(),
    ///     };
    ///     let response = client.explanation(&task, model, &How::default()).await?;
    ///
    ///     dbg!(&response);
    ///     Ok(())
    /// }
    /// ```
    pub async fn explanation(
        &self,
        task: &TaskExplanation<'_>,
        model: &str,
        how: &How,
    ) -> Result<ExplanationOutput, Error> {
        self.http_client
            .output_of(&task.with_model(model), how)
            .await
    }

    /// Tokenize a prompt for a specific model.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, Error, How, TaskTokenization};
    ///
    /// async fn tokenize() -> Result<(), Error> {
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of the model for which we want to tokenize text.
    ///     let model = "luminous-base";
    ///
    ///     // Text prompt to be tokenized.
    ///     let prompt = "An apple a day";
    ///
    ///     let task = TaskTokenization {
    ///         prompt,
    ///         tokens: true,    // return text-tokens
    ///         token_ids: true, // return numeric token-ids
    ///     };
    ///     let responses = client.tokenize(&task, model, &How::default()).await?;
    ///
    ///     dbg!(&responses);
    ///     Ok(())
    /// }
    /// ```
    pub async fn tokenize(
        &self,
        task: &TaskTokenization<'_>,
        model: &str,
        how: &How,
    ) -> Result<TokenizationOutput, Error> {
        self.http_client
            .output_of(&task.with_model(model), how)
            .await
    }

    /// Detokenize a list of token ids into a string.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, Error, How, TaskDetokenization};
    ///
    /// async fn detokenize() -> Result<(), Error> {
    ///     let client = Client::from_env()?;
    ///
    ///     // Specify the name of the model whose tokenizer was used to generate the input token ids.
    ///     let model = "luminous-base";
    ///
    ///     // Token ids to convert into text.
    ///     let token_ids: Vec<u32> = vec![556, 48741, 247, 2983];
    ///
    ///     let task = TaskDetokenization {
    ///         token_ids: &token_ids,
    ///     };
    ///     let responses = client.detokenize(&task, model, &How::default()).await?;
    ///
    ///     dbg!(&responses);
    ///     Ok(())
    /// }
    /// ```
    pub async fn detokenize(
        &self,
        task: &TaskDetokenization<'_>,
        model: &str,
        how: &How,
    ) -> Result<DetokenizationOutput, Error> {
        self.http_client
            .output_of(&task.with_model(model), how)
            .await
    }

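    /// Fetch the tokenizer used by the given model as a Hugging Face [`tokenizers::Tokenizer`].
    ///
    /// The optional `api_token` authenticates only this request; passing `None` is assumed to
    /// fall back to the token the client was constructed with.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, Error};
    ///
    /// async fn print_vocabulary_size() -> Result<(), Error> {
    ///     let client = Client::from_env()?;
    ///     let tokenizer = client.tokenizer_by_model("luminous-base", None).await?;
    ///     println!("vocabulary size: {}", tokenizer.get_vocab_size(true));
    ///     Ok(())
    /// }
    /// ```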
    pub async fn tokenizer_by_model(
        &self,
        model: &str,
        api_token: Option<String>,
    ) -> Result<Tokenizer, Error> {
        self.http_client.tokenizer_by_model(model, api_token).await
    }
}

/// Controls how to execute a task against the API.
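///
/// A sketch of overriding individual fields while keeping the remaining defaults:
///
/// ```
/// use aleph_alpha_client::How;
/// use std::time::Duration;
///
/// let how = How {
///     // Mark the request as low priority to reduce load on the model.
///     be_nice: true,
///     // Give up waiting for a response after 30 seconds.
///     client_timeout: Duration::from_secs(30),
///     ..Default::default()
/// };
/// ```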
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct How {
    /// The be-nice flag is used to reduce load for the models you intend to use.
    /// This is commonly used if you are conducting experiments
    /// or trying things out that create a large load on the aleph-alpha-api
    /// and you do not want to increase queue time for other users too much.
    ///
    /// (!) This increases how often you get a `Busy` response.
    pub be_nice: bool,

    /// The maximum duration of a request before the client cancels it. This is not passed on
    /// to the server but only handled by the client locally, i.e. the client will not wait longer
    /// than this duration for a response.
    pub client_timeout: Duration,

    /// API token used to authenticate the request; overrides the default token provided on setup.
    /// The default token may not provide the tracking or permissions wanted for the request.
    pub api_token: Option<String>,
}

impl Default for How {
    fn default() -> Self {
        // The Aleph Alpha API cancels requests after 5 minutes.
        let api_timeout = Duration::from_secs(300);
        Self {
            be_nice: Default::default(),
            // On the client side a request can take longer in case of network errors,
            // therefore by default we wait slightly longer.
            client_timeout: api_timeout + Duration::from_secs(5),
            api_token: None,
        }
    }
}

/// Intended to compare embeddings.
///
/// ```no_run
/// use aleph_alpha_client::{
///     Client, Prompt, TaskSemanticEmbedding, cosine_similarity, SemanticRepresentation, How
/// };
///
/// async fn semantic_search_with_luminous_base(client: &Client) {
///     // Given
///     let robot_fact = Prompt::from_text(
///         "A robot is a machine—especially one programmable by a computer—capable of carrying out a \
///         complex series of actions automatically.",
///     );
///     let pizza_fact = Prompt::from_text(
///         "Pizza (Italian: [ˈpittsa], Neapolitan: [ˈpittsə]) is a dish of Italian origin consisting \
///         of a usually round, flat base of leavened wheat-based dough topped with tomatoes, cheese, \
///         and often various other ingredients (such as various types of sausage, anchovies, \
///         mushrooms, onions, olives, vegetables, meat, ham, etc.), which is then baked at a high \
///         temperature, traditionally in a wood-fired oven.",
///     );
///     let query = Prompt::from_text("What is Pizza?");
///     let how = How::default();
///
///     // When
///     let robot_embedding_task = TaskSemanticEmbedding {
///         prompt: robot_fact,
///         representation: SemanticRepresentation::Document,
///         compress_to_size: Some(128),
///     };
///     let robot_embedding = client.semantic_embedding(
///         &robot_embedding_task,
///         &how,
///     ).await.unwrap().embedding;
///
///     let pizza_embedding_task = TaskSemanticEmbedding {
///         prompt: pizza_fact,
///         representation: SemanticRepresentation::Document,
///         compress_to_size: Some(128),
///     };
///     let pizza_embedding = client.semantic_embedding(
///         &pizza_embedding_task,
///         &how,
///     ).await.unwrap().embedding;
///
///     let query_embedding_task = TaskSemanticEmbedding {
///         prompt: query,
///         representation: SemanticRepresentation::Query,
///         compress_to_size: Some(128),
///     };
///     let query_embedding = client.semantic_embedding(
///         &query_embedding_task,
///         &how,
///     ).await.unwrap().embedding;
///     let similarity_pizza = cosine_similarity(&query_embedding, &pizza_embedding);
///     println!("similarity pizza: {similarity_pizza}");
///     let similarity_robot = cosine_similarity(&query_embedding, &robot_embedding);
///     println!("similarity robot: {similarity_robot}");
///
///     // Then
///
///     // The fact about pizza should be more relevant to the "What is Pizza?" question than a fact
///     // about robots.
///     assert!(similarity_pizza > similarity_robot);
/// }
/// ```
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Dot product of the two vectors ...
    let ab: f32 = a.iter().zip(b).map(|(a, b)| a * b).sum();
    // ... divided by the product of their Euclidean norms.
    let aa: f32 = a.iter().map(|a| a * a).sum();
    let bb: f32 = b.iter().map(|b| b * b).sum();
    let prod_len = (aa * bb).sqrt();
    ab / prod_len
}

#[cfg(test)]
mod tests {
    use crate::Prompt;

    #[test]
    fn ability_to_generate_prompt_in_local_function() {
        fn local_function() -> Prompt<'static> {
            Prompt::from_text(String::from("My test prompt"))
        }

        assert_eq!(Prompt::from_text("My test prompt"), local_function())
    }
}