aleph_alpha_client/lib.rs
//! Usage sample
//!
//! ```no_run
//! use aleph_alpha_client::{Client, TaskCompletion, How};
//!
//! #[tokio::main(flavor = "current_thread")]
//! async fn main() {
//!     // Authenticate against API. Fetches token.
//!     let client = Client::from_env().unwrap();
//!
//!     // Name of the model we want to use. Large models usually give better answers, but are
//!     // also more costly.
//!     let model = "luminous-base";
//!
//!     // The task we want to perform. Here we want to continue the sentence: "An apple a day ..."
//!     let task = TaskCompletion::from_text("An apple a day");
//!
//!     // Retrieve the answer from the API
//!     let response = client.completion(&task, model, &How::default()).await.unwrap();
//!
//!     // Print entire sentence with completion
//!     println!("An apple a day{}", response.completion);
//! }
//! ```

mod chat;
mod completion;
mod detokenization;
mod explanation;
mod http;
mod image_preprocessing;
mod logprobs;
mod prompt;
mod semantic_embedding;
mod stream;
mod tokenization;
mod tracing;
use dotenvy::dotenv;
use futures_util::Stream;
use http::HttpClient;
use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};
use std::env;
use std::{pin::Pin, time::Duration};
use tokenizers::Tokenizer;

pub use self::{
    chat::{ChatEvent, ChatOutput, ChatSampling, Distribution, Message, TaskChat, Usage},
    completion::{CompletionEvent, CompletionOutput, Sampling, Stopping, TaskCompletion},
    detokenization::{DetokenizationOutput, TaskDetokenization},
    explanation::{
        Explanation, ExplanationOutput, Granularity, ImageScore, ItemExplanation,
        PromptGranularity, TaskExplanation, TextScore,
    },
    http::{Error, Job, Task},
    logprobs::{Logprob, Logprobs},
    prompt::{Modality, Prompt},
    semantic_embedding::{
        SemanticRepresentation, TaskBatchSemanticEmbedding, TaskSemanticEmbedding,
        TaskSemanticEmbeddingWithInstruction,
    },
    stream::{StreamJob, StreamTask},
    tokenization::{TaskTokenization, TokenizationOutput},
    tracing::TraceContext,
};

/// Execute Jobs against the Aleph Alpha API
pub struct Client {
    /// This client does all the work of sending the requests and talking to the AA API. The only
    /// additional knowledge added by this layer is that it knows about the individual jobs which
    /// can be executed. This allows for an alternative, non-generic interface which may produce
    /// easier to read code for the end user in many use cases.
    http_client: HttpClient,
}

impl Client {
    /// A new instance of an Aleph Alpha client helping you interact with the Aleph Alpha API.
    ///
    /// Setting the token to `None` allows specifying it on a per-request basis.
    /// You may want to use only request-based authentication and skip default authentication. This
    /// is useful when writing an application which invokes the client on behalf of many different
    /// users. Having neither request nor default authentication is considered a bug and will cause
    /// a panic.
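    ///
    /// A minimal sketch of both styles; the host URL and token below are only placeholders:
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, Error, How};
    ///
    /// fn construct_clients() -> Result<(), Error> {
    ///     // One default token used for every request issued by this client.
    ///     let client = Client::with_auth("https://inference-api.example.com", "AA_API_TOKEN")?;
    ///
    ///     // Alternatively: no default token, supply one per request via `How` instead.
    ///     let client = Client::new("https://inference-api.example.com", None)?;
    ///     let how = How {
    ///         api_token: Some("AA_API_TOKEN".to_owned()),
    ///         ..How::default()
    ///     };
    ///     Ok(())
    /// }
    /// ```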
    pub fn new(host: impl Into<String>, api_token: Option<String>) -> Result<Self, Error> {
        let http_client = HttpClient::new(host.into(), api_token)?;
        Ok(Self { http_client })
    }

    /// A client instance that always uses the same token for all requests.
    pub fn with_auth(host: impl Into<String>, api_token: impl Into<String>) -> Result<Self, Error> {
        Self::new(host, Some(api_token.into()))
    }

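    /// A client instance configured from the process environment.
    ///
    /// Loads a `.env` file if one is present, then reads the API token from the `PHARIA_AI_TOKEN`
    /// environment variable and the inference host from `INFERENCE_URL`. Panics if either variable
    /// is not set.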
    pub fn from_env() -> Result<Self, Error> {
        let _ = dotenv();
        let api_token = env::var("PHARIA_AI_TOKEN").unwrap();
        let inference_url = env::var("INFERENCE_URL").unwrap();
        Self::with_auth(inference_url, api_token)
    }

    /// Execute a task with the Aleph Alpha API and fetch its result.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, How, TaskCompletion, Error};
    ///
    /// async fn print_completion() -> Result<(), Error> {
    ///     // Authenticate against API. Fetches token.
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of the model we want to use. Large models usually give better answers, but are
    ///     // also slower and more costly.
    ///     let model = "luminous-base";
    ///
    ///     // The task we want to perform. Here we want to continue the sentence: "An apple a day
    ///     // ..."
    ///     let task = TaskCompletion::from_text("An apple a day");
    ///
    ///     // Retrieve answer from API
    ///     let response = client.execute(model, &task, &How::default()).await?;
    ///
    ///     // Print entire sentence with completion
    ///     println!("An apple a day{}", response.completion);
    ///     Ok(())
    /// }
    /// ```
    #[deprecated = "Please use output_of instead."]
    pub async fn execute<T: Task>(
        &self,
        model: &str,
        task: &T,
        how: &How,
    ) -> Result<T::Output, Error> {
        self.output_of(&task.with_model(model), how).await
    }

    /// Execute any task with the Aleph Alpha API and fetch its result. This is most useful in
    /// generic code when you want to execute arbitrary task types. Otherwise prefer methods taking
    /// concrete tasks like [`Self::completion`] for improved readability.
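    ///
    /// A small sketch of the generic path, pairing a task with a model via [`Task::with_model`],
    /// mirroring what the deprecated [`Self::execute`] does internally:
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, Error, How, Task, TaskCompletion};
    ///
    /// async fn complete_generically() -> Result<(), Error> {
    ///     let client = Client::from_env()?;
    ///     let task = TaskCompletion::from_text("An apple a day");
    ///     // Pairing the task with a model yields a job which `output_of` can execute.
    ///     let job = task.with_model("luminous-base");
    ///     let output = client.output_of(&job, &How::default()).await?;
    ///     println!("An apple a day{}", output.completion);
    ///     Ok(())
    /// }
    /// ```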
    pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> {
        self.http_client.output_of(task, how).await
    }

    /// An embedding trying to capture the semantic meaning of a text. Cosine similarity can be used
    /// to find out how well two texts (or multimodal prompts) match. Useful for search use cases.
    ///
    /// See the example for [`cosine_similarity`].
    pub async fn semantic_embedding(
        &self,
        task: &TaskSemanticEmbedding<'_>,
        how: &How,
    ) -> Result<SemanticEmbeddingOutput, Error> {
        self.http_client.output_of(task, how).await
    }

    /// A batch of embeddings, each trying to capture the semantic meaning of a text.
    pub async fn batch_semantic_embedding(
        &self,
        task: &TaskBatchSemanticEmbedding<'_>,
        how: &How,
    ) -> Result<BatchSemanticEmbeddingOutput, Error> {
        self.http_client.output_of(task, how).await
    }

    /// An embedding trying to capture the semantic meaning of a text.
    ///
    /// By providing instructions, you can help the model better understand the nuances of your
    /// specific data, leading to embeddings that are more useful for your use case.
    pub async fn semantic_embedding_with_instruction(
        &self,
        task: &TaskSemanticEmbeddingWithInstruction<'_>,
        how: &How,
    ) -> Result<SemanticEmbeddingOutput, Error> {
        self.http_client.output_of(task, how).await
    }

    /// Instruct a model served by the Aleph Alpha API to continue writing a piece of text (or
    /// multimodal document).
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, How, TaskCompletion, Task, Error};
    ///
    /// async fn print_completion() -> Result<(), Error> {
    ///     // Authenticate against API. Fetches token.
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of the model we want to use. Large models usually give better answers, but are
    ///     // also slower and more costly.
    ///     let model = "luminous-base";
    ///
    ///     // The task we want to perform. Here we want to continue the sentence: "An apple a day
    ///     // ..."
    ///     let task = TaskCompletion::from_text("An apple a day");
    ///
    ///     // Retrieve answer from API
    ///     let response = client.completion(&task, model, &How::default()).await?;
    ///
    ///     // Print entire sentence with completion
    ///     println!("An apple a day{}", response.completion);
    ///     Ok(())
    /// }
    /// ```
    pub async fn completion(
        &self,
        task: &TaskCompletion<'_>,
        model: &str,
        how: &How,
    ) -> Result<CompletionOutput, Error> {
        self.http_client
            .output_of(&Task::with_model(task, model), how)
            .await
    }

    /// Instruct a model served by the Aleph Alpha API to continue writing a piece of text.
    /// Stream the response as a series of events.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, How, TaskCompletion, Error, CompletionEvent};
    /// use futures_util::StreamExt;
    ///
    /// async fn print_stream_completion() -> Result<(), Error> {
    ///     // Authenticate against API. Fetches token.
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of the model we want to use. Large models usually give better answers, but are
    ///     // also slower and more costly.
    ///     let model = "luminous-base";
    ///
    ///     // The task we want to perform. Here we want to continue the sentence: "An apple a day
    ///     // ..."
    ///     let task = TaskCompletion::from_text("An apple a day");
    ///
    ///     // Retrieve stream from API
    ///     let mut stream = client.stream_completion(&task, model, &How::default()).await?;
    ///     while let Some(Ok(event)) = stream.next().await {
    ///         if let CompletionEvent::Delta { completion, logprobs: _ } = event {
    ///             println!("{}", completion);
    ///         }
    ///     }
    ///     Ok(())
    /// }
    /// ```
    pub async fn stream_completion<'task>(
        &self,
        task: &'task TaskCompletion<'task>,
        model: &'task str,
        how: &How,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<CompletionEvent, Error>> + Send + 'task>>, Error>
    {
        self.http_client
            .stream_output_of(StreamTask::with_model(task, model), how)
            .await
    }

    /// Send a chat message to a model.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, How, TaskChat, Error, Message};
    ///
    /// async fn print_chat() -> Result<(), Error> {
    ///     // Authenticate against API. Fetches token.
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of a model that supports chat.
    ///     let model = "pharia-1-llm-7b-control";
    ///
    ///     // Create a chat task with a user message.
    ///     let message = Message::user("Hello, how are you?");
    ///     let task = TaskChat::with_message(message);
    ///
    ///     // Send the message to the model.
    ///     let response = client.chat(&task, model, &How::default()).await?;
    ///
    ///     // Print the model response
    ///     println!("{}", response.message.content);
    ///     Ok(())
    /// }
    /// ```
    pub async fn chat(
        &self,
        task: &TaskChat<'_>,
        model: &str,
        how: &How,
    ) -> Result<ChatOutput, Error> {
        self.http_client
            .output_of(&Task::with_model(task, model), how)
            .await
    }

    /// Send a chat message to a model. Stream the response as a series of events.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, How, TaskChat, Error, Message, ChatEvent};
    /// use futures_util::StreamExt;
    ///
    /// async fn print_stream_chat() -> Result<(), Error> {
    ///     // Authenticate against API. Fetches token.
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of a model that supports chat.
    ///     let model = "pharia-1-llm-7b-control";
    ///
    ///     // Create a chat task with a user message.
    ///     let message = Message::user("Hello, how are you?");
    ///     let task = TaskChat::with_message(message);
    ///
    ///     // Send the message to the model.
    ///     let mut stream = client.stream_chat(&task, model, &How::default()).await?;
    ///     while let Some(Ok(event)) = stream.next().await {
    ///         if let ChatEvent::MessageDelta { content, logprobs: _ } = event {
    ///             println!("{}", content);
    ///         }
    ///     }
    ///     Ok(())
    /// }
    /// ```
    pub async fn stream_chat<'task>(
        &self,
        task: &'task TaskChat<'_>,
        model: &'task str,
        how: &How,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatEvent, Error>> + Send + 'task>>, Error> {
        self.http_client
            .stream_output_of(StreamTask::with_model(task, model), how)
            .await
    }

    /// Returns an explanation given a prompt and a target (typically generated
    /// by a previous completion request). The explanation describes how individual parts
    /// of the prompt influenced the target.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, How, TaskCompletion, Task, Error, Granularity,
    ///     TaskExplanation, Stopping, Prompt, Sampling, Logprobs};
    ///
    /// async fn print_explanation() -> Result<(), Error> {
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of the model we want to use. Large models usually give better answers, but are
    ///     // also slower and more costly.
    ///     let model = "luminous-base";
    ///
    ///     // Input for the completion
    ///     let prompt = Prompt::from_text("An apple a day");
    ///
    ///     let task = TaskCompletion {
    ///         prompt: prompt.clone(),
    ///         stopping: Stopping::from_maximum_tokens(10),
    ///         sampling: Sampling::MOST_LIKELY,
    ///         special_tokens: false,
    ///         logprobs: Logprobs::No,
    ///         echo: false,
    ///     };
    ///     let response = client.completion(&task, model, &How::default()).await?;
    ///
    ///     let task = TaskExplanation {
    ///         prompt,                       // same input as for completion
    ///         target: &response.completion, // output of completion
    ///         granularity: Granularity::default(),
    ///     };
    ///     let response = client.explanation(&task, model, &How::default()).await?;
    ///
    ///     dbg!(&response);
    ///     Ok(())
    /// }
    /// ```
    pub async fn explanation(
        &self,
        task: &TaskExplanation<'_>,
        model: &str,
        how: &How,
    ) -> Result<ExplanationOutput, Error> {
        self.http_client
            .output_of(&task.with_model(model), how)
            .await
    }

    /// Tokenize a prompt for a specific model.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, Error, How, TaskTokenization};
    ///
    /// async fn tokenize() -> Result<(), Error> {
    ///     let client = Client::from_env()?;
    ///
    ///     // Name of the model for which we want to tokenize text.
    ///     let model = "luminous-base";
    ///
    ///     // Text prompt to be tokenized.
    ///     let prompt = "An apple a day";
    ///
    ///     let task = TaskTokenization {
    ///         prompt,
    ///         tokens: true,    // return text-tokens
    ///         token_ids: true, // return numeric token-ids
    ///     };
    ///     let responses = client.tokenize(&task, model, &How::default()).await?;
    ///
    ///     dbg!(&responses);
    ///     Ok(())
    /// }
    /// ```
    pub async fn tokenize(
        &self,
        task: &TaskTokenization<'_>,
        model: &str,
        how: &How,
    ) -> Result<TokenizationOutput, Error> {
        self.http_client
            .output_of(&task.with_model(model), how)
            .await
    }

    /// Detokenize a list of token ids into a string.
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, Error, How, TaskDetokenization};
    ///
    /// async fn detokenize() -> Result<(), Error> {
    ///     let client = Client::from_env()?;
    ///
    ///     // Specify the name of the model whose tokenizer was used to generate the input token ids.
    ///     let model = "luminous-base";
    ///
    ///     // Token ids to convert into text.
    ///     let token_ids: Vec<u32> = vec![556, 48741, 247, 2983];
    ///
    ///     let task = TaskDetokenization {
    ///         token_ids: &token_ids,
    ///     };
    ///     let responses = client.detokenize(&task, model, &How::default()).await?;
    ///
    ///     dbg!(&responses);
    ///     Ok(())
    /// }
    /// ```
    pub async fn detokenize(
        &self,
        task: &TaskDetokenization<'_>,
        model: &str,
        how: &How,
    ) -> Result<DetokenizationOutput, Error> {
        self.http_client
            .output_of(&task.with_model(model), how)
            .await
    }

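    /// Fetch the tokenizer used by the given model, so prompts can be tokenized locally, e.g. to
    /// count tokens before sending a request.
    ///
    /// A minimal sketch; the returned value is a `tokenizers::Tokenizer`, so `encode`, `len` and
    /// `get_ids` below are that crate's API rather than this one's:
    ///
    /// ```no_run
    /// use aleph_alpha_client::{Client, Error};
    ///
    /// async fn count_tokens() -> Result<(), Error> {
    ///     let client = Client::from_env()?;
    ///     let tokenizer = client.tokenizer_by_model("luminous-base", None, None).await?;
    ///     let encoding = tokenizer.encode("An apple a day", false).expect("prompt can be encoded");
    ///     println!("{} tokens: {:?}", encoding.len(), encoding.get_ids());
    ///     Ok(())
    /// }
    /// ```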
    pub async fn tokenizer_by_model(
        &self,
        model: &str,
        api_token: Option<String>,
        context: Option<TraceContext>,
    ) -> Result<Tokenizer, Error> {
        self.http_client
            .tokenizer_by_model(model, api_token, context)
            .await
    }
}

/// Controls how to execute a task
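///
/// Only the fields you care about need to be set; the rest can come from `How::default()`. A small
/// sketch using struct update syntax:
///
/// ```no_run
/// use aleph_alpha_client::How;
/// use std::time::Duration;
///
/// let how = How {
///     // Reduce load on the API; may lead to more `Busy` responses.
///     be_nice: true,
///     // Give up on the client side after 30 seconds.
///     client_timeout: Duration::from_secs(30),
///     ..How::default()
/// };
/// ```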
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct How {
    /// The be-nice flag is used to reduce load for the models you intend to use.
    /// This is commonly used if you are conducting experiments
    /// or trying things out that create a large load on the aleph-alpha-api
    /// and you do not want to increase queue time for other users too much.
    ///
    /// (!) This increases how often you get a `Busy` response.
    pub be_nice: bool,

    /// The maximum duration of a request before the client cancels it. This is not passed on to
    /// the server but only handled by the client locally, i.e. the client will not wait longer
    /// than this duration for a response.
    pub client_timeout: Duration,

    /// API token used to authenticate the request; overwrites the default token provided on setup.
    /// The default token may not provide the tracking or permissions wanted for the request.
    pub api_token: Option<String>,

    /// Optionally pass a trace context to propagate tracing information through distributed systems.
    pub trace_context: Option<TraceContext>,
}

impl Default for How {
    fn default() -> Self {
        // The aleph-alpha-api cancels requests after 5 minutes.
        let api_timeout = Duration::from_secs(300);
        Self {
            be_nice: Default::default(),
            // On the client side a request can take longer, e.g. in case of network errors,
            // therefore by default we wait slightly longer than the server does.
            client_timeout: api_timeout + Duration::from_secs(5),
            api_token: None,
            trace_context: None,
        }
    }
}

/// Intended to compare embeddings.
///
/// ```no_run
/// use aleph_alpha_client::{
///     Client, Prompt, TaskSemanticEmbedding, cosine_similarity, SemanticRepresentation, How
/// };
///
/// async fn semantic_search_with_luminous_base(client: &Client) {
///     // Given
///     let robot_fact = Prompt::from_text(
///         "A robot is a machine—especially one programmable by a computer—capable of carrying out a \
///         complex series of actions automatically.",
///     );
///     let pizza_fact = Prompt::from_text(
///         "Pizza (Italian: [ˈpittsa], Neapolitan: [ˈpittsə]) is a dish of Italian origin consisting \
///         of a usually round, flat base of leavened wheat-based dough topped with tomatoes, cheese, \
///         and often various other ingredients (such as various types of sausage, anchovies, \
///         mushrooms, onions, olives, vegetables, meat, ham, etc.), which is then baked at a high \
///         temperature, traditionally in a wood-fired oven.",
///     );
///     let query = Prompt::from_text("What is Pizza?");
///     let how = How::default();
///
///     // When
///     let robot_embedding_task = TaskSemanticEmbedding {
///         prompt: robot_fact,
///         representation: SemanticRepresentation::Document,
///         compress_to_size: Some(128),
///     };
///     let robot_embedding = client.semantic_embedding(
///         &robot_embedding_task,
///         &how,
///     ).await.unwrap().embedding;
///
///     let pizza_embedding_task = TaskSemanticEmbedding {
///         prompt: pizza_fact,
///         representation: SemanticRepresentation::Document,
///         compress_to_size: Some(128),
///     };
///     let pizza_embedding = client.semantic_embedding(
///         &pizza_embedding_task,
///         &how,
///     ).await.unwrap().embedding;
///
///     let query_embedding_task = TaskSemanticEmbedding {
///         prompt: query,
///         representation: SemanticRepresentation::Query,
///         compress_to_size: Some(128),
///     };
///     let query_embedding = client.semantic_embedding(
///         &query_embedding_task,
///         &how,
///     ).await.unwrap().embedding;
///     let similarity_pizza = cosine_similarity(&query_embedding, &pizza_embedding);
///     println!("similarity pizza: {similarity_pizza}");
///     let similarity_robot = cosine_similarity(&query_embedding, &robot_embedding);
///     println!("similarity robot: {similarity_robot}");
///
///     // Then
///
///     // The fact about pizza should be more relevant to the "What is Pizza?" question than a fact
///     // about robots.
///     assert!(similarity_pizza > similarity_robot);
/// }
/// ```
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let ab: f32 = a.iter().zip(b).map(|(a, b)| a * b).sum();
    let aa: f32 = a.iter().map(|a| a * a).sum();
    let bb: f32 = b.iter().map(|b| b * b).sum();
    let prod_len = (aa * bb).sqrt();
    ab / prod_len
}

#[cfg(test)]
mod tests {
    use crate::Prompt;

    #[test]
    fn ability_to_generate_prompt_in_local_function() {
        fn local_function() -> Prompt<'static> {
            Prompt::from_text(String::from("My test prompt"))
        }

        assert_eq!(Prompt::from_text("My test prompt"), local_function())
    }
}