aleph_alpha_client/semantic_embedding/
embedding.rs

1use crate::semantic_embedding::{RequestBody, DEFAULT_EMBEDDING_MODEL};
2use crate::{Job, Prompt, SemanticRepresentation, Task};
3use serde::{Deserialize, Serialize};
4
5const ENDPOINT: &str = "semantic_embed";
6
7/// Create embeddings for prompts which can be used for downstream tasks. E.g. search, classifiers
8#[derive(Serialize, Debug)]
9pub struct TaskSemanticEmbedding<'a> {
10    /// The prompt (usually text) to be embedded.
11    pub prompt: Prompt<'a>,
12    /// Semantic representation to embed the prompt with. This parameter is governed by the specific
13    /// use case in mind.
14    pub representation: SemanticRepresentation,
15    /// Default behaviour is to return the full embedding, but you can optionally request an
16    /// embedding compressed to a smaller set of dimensions. A size of `128` is supported for every
17    /// model.
18    ///
19    /// The 128 size is expected to have a small drop in accuracy performance (4-6%), with the
20    /// benefit of being much smaller, which makes comparing these embeddings much faster for use
21    /// cases where speed is critical.
22    ///
23    /// The 128 size can also perform better if you are embedding short texts or documents.
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub compress_to_size: Option<u32>,
26}
27
28/// Heap allocated embedding. Can hold full embeddings or compressed ones
29#[derive(Deserialize)]
30pub struct SemanticEmbeddingOutput {
31    pub embedding: Vec<f32>,
32}
33
34impl Task for TaskSemanticEmbedding<'_> {
35    type Output = SemanticEmbeddingOutput;
36    type ResponseBody = SemanticEmbeddingOutput;
37
38    fn build_request(
39        &self,
40        client: &reqwest::Client,
41        base: &str,
42        model: &str,
43    ) -> reqwest::RequestBuilder {
44        let body = RequestBody {
45            model,
46            semantic_embedding_task: self,
47        };
48        client.post(format!("{base}/{ENDPOINT}")).json(&body)
49    }
50
51    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
52        response
53    }
54}
55
56impl Job for TaskSemanticEmbedding<'_> {
57    type Output = SemanticEmbeddingOutput;
58    type ResponseBody = SemanticEmbeddingOutput;
59
60    fn build_request(&self, client: &reqwest::Client, base: &str) -> reqwest::RequestBuilder {
61        let body = RequestBody {
62            model: DEFAULT_EMBEDDING_MODEL,
63            semantic_embedding_task: self,
64        };
65        client.post(format!("{base}/{ENDPOINT}")).json(&body)
66    }
67
68    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
69        response
70    }
71}