aleph_alpha_client/semantic_embedding.rs

use serde::{Deserialize, Serialize};
use std::fmt::Debug;

use crate::{http::Task, Job, Prompt};

/// Allows you to choose the semantic representation that fits your use case.
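///
/// A minimal sketch of picking a representation (the import path assumes the crate re-exports this
/// type at its root):
///
/// ```
/// use aleph_alpha_client::SemanticRepresentation;
///
/// // Comparing peer texts (clustering, similarity, classification):
/// let _ = SemanticRepresentation::Symmetric;
/// // Search: embed the corpus with `Document` and the user's query with `Query`.
/// let _ = SemanticRepresentation::Document;
/// let _ = SemanticRepresentation::Query;
/// ```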
#[derive(Serialize, Debug)]
#[serde(rename_all = "snake_case")]
pub enum SemanticRepresentation {
    /// Useful for comparing prompts to each other, in use cases such as clustering, classification,
    /// similarity, etc. `Symmetric` embeddings are intended to be compared with other `Symmetric`
    /// embeddings.
    Symmetric,
    /// `Document` and `Query` are used together in use cases such as search where you want to
    /// compare shorter queries against larger documents. `Document` embeddings are optimized for
    /// larger pieces of text to compare queries against.
    Document,
    /// `Document` and `Query` are used together in use cases such as search where you want to
    /// compare shorter queries against larger documents. `Query` embeddings are optimized for
    /// shorter texts, such as questions or keywords.
    Query,
}

/// Create embeddings for prompts, which can be used for downstream tasks such as search or
/// classification.
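///
/// A minimal construction sketch. The import paths assume the crate re-exports these types at its
/// root, and `Prompt::from_text` is assumed from the crate's `Prompt` type:
///
/// ```no_run
/// use aleph_alpha_client::{Prompt, SemanticRepresentation, TaskSemanticEmbedding};
///
/// let task = TaskSemanticEmbedding {
///     prompt: Prompt::from_text("An apple a day keeps the doctor away."),
///     representation: SemanticRepresentation::Symmetric,
///     // Optional: trade a little accuracy for much smaller embeddings.
///     compress_to_size: Some(128),
/// };
/// ```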
#[derive(Serialize, Debug)]
pub struct TaskSemanticEmbedding<'a> {
    /// The prompt (usually text) to be embedded.
    pub prompt: Prompt<'a>,
    /// Semantic representation to embed the prompt with. Choose the representation that fits your
    /// use case.
    pub representation: SemanticRepresentation,
    /// The default behaviour is to return the full embedding, but you can optionally request an
    /// embedding compressed to a smaller number of dimensions. A size of `128` is supported for
    /// every model.
    ///
    /// The `128` size is expected to come with a small drop in accuracy (4-6%), with the benefit
    /// of being much smaller, which makes comparing these embeddings much faster for use cases
    /// where speed is critical.
    ///
    /// The `128` size can also perform better if you are embedding short texts or documents.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compress_to_size: Option<u32>,
}

/// Appends the model to the bare task. `T` is either [`TaskSemanticEmbedding`] or
/// [`TaskBatchSemanticEmbedding`].
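///
/// A self-contained sketch of what `#[serde(flatten)]` does here: the task's fields are serialized
/// next to `model` in a single flat JSON object. The types below are stand-ins, not this crate's,
/// and `serde_json` is assumed to be available as a dependency:
///
/// ```
/// use serde::Serialize;
///
/// #[derive(Serialize)]
/// struct Body<'a, T: Serialize> {
///     model: &'a str,
///     #[serde(flatten)]
///     task: &'a T,
/// }
///
/// #[derive(Serialize)]
/// struct EmbedTask {
///     representation: &'static str,
/// }
///
/// let task = EmbedTask { representation: "symmetric" };
/// let body = Body { model: "luminous-base", task: &task };
/// let json = serde_json::to_value(&body).unwrap();
/// assert_eq!(json["model"], "luminous-base");
/// assert_eq!(json["representation"], "symmetric");
/// ```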
#[derive(Serialize, Debug)]
struct RequestBody<'a, T: Serialize + Debug> {
    /// Currently semantic embedding still requires a model parameter, even though `luminous-base`
    /// is the only model that supports it. This makes semantic embedding both a Service and a
    /// Method.
    model: &'a str,
    #[serde(flatten)]
    semantic_embedding_task: &'a T,
}

/// Heap-allocated embedding. Can hold either a full or a compressed embedding.
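///
/// Embeddings with matching representations can be compared, for example with cosine similarity.
/// A self-contained sketch, where the vectors stand in for real output:
///
/// ```
/// fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
///     let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
///     let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
///     let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
///     dot / (norm_a * norm_b)
/// }
///
/// let a = [1.0_f32, 0.0];
/// let b = [0.7_f32, 0.7];
/// assert!(cosine_similarity(&a, &b) > 0.5);
/// ```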
#[derive(Deserialize)]
pub struct SemanticEmbeddingOutput {
    pub embedding: Vec<f32>,
}

impl Task for TaskSemanticEmbedding<'_> {
    type Output = SemanticEmbeddingOutput;
    type ResponseBody = SemanticEmbeddingOutput;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = RequestBody {
            model,
            semantic_embedding_task: self,
        };
        client.post(format!("{base}/semantic_embed")).json(&body)
    }

    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
        response
    }
}

impl Job for TaskSemanticEmbedding<'_> {
    type Output = SemanticEmbeddingOutput;
    type ResponseBody = SemanticEmbeddingOutput;

    fn build_request(&self, client: &reqwest::Client, base: &str) -> reqwest::RequestBuilder {
        let model = "luminous-base";
        let body = RequestBody {
            model,
            semantic_embedding_task: self,
        };
        client.post(format!("{base}/semantic_embed")).json(&body)
    }

    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
        response
    }
}

/// Create embeddings for multiple prompts
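///
/// A minimal construction sketch, under the same assumptions as for [`TaskSemanticEmbedding`]:
///
/// ```no_run
/// use aleph_alpha_client::{Prompt, SemanticRepresentation, TaskBatchSemanticEmbedding};
///
/// let task = TaskBatchSemanticEmbedding {
///     prompts: vec![
///         Prompt::from_text("An apple a day keeps the doctor away."),
///         Prompt::from_text("Once in a blue moon."),
///     ],
///     representation: SemanticRepresentation::Document,
///     compress_to_size: None,
/// };
/// ```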
#[derive(Serialize, Debug)]
pub struct TaskBatchSemanticEmbedding<'a> {
    /// The prompts (usually text) to be embedded.
    pub prompts: Vec<Prompt<'a>>,
    /// Semantic representation to embed the prompts with. Choose the representation that fits your
    /// use case.
    pub representation: SemanticRepresentation,
    /// The default behaviour is to return the full embeddings, but you can optionally request
    /// embeddings compressed to a smaller number of dimensions. A size of `128` is supported for
    /// every model.
    ///
    /// The `128` size is expected to come with a small drop in accuracy (4-6%), with the benefit
    /// of being much smaller, which makes comparing these embeddings much faster for use cases
    /// where speed is critical.
    ///
    /// The `128` size can also perform better if you are embedding short texts or documents.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compress_to_size: Option<u32>,
}

/// Heap-allocated vector of embeddings. Can hold either full or compressed embeddings.
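///
/// A sketch of ranking batch output against a query embedding by dot product, where the vectors
/// stand in for real output:
///
/// ```
/// fn dot(a: &[f32], b: &[f32]) -> f32 {
///     a.iter().zip(b).map(|(x, y)| x * y).sum()
/// }
///
/// let query = vec![1.0_f32, 0.0];
/// let embeddings = vec![vec![0.9_f32, 0.1], vec![0.1, 0.9]];
/// let best = embeddings
///     .iter()
///     .enumerate()
///     .max_by(|(_, a), (_, b)| dot(&query, a).total_cmp(&dot(&query, b)))
///     .map(|(i, _)| i);
/// assert_eq!(best, Some(0));
/// ```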
#[derive(Deserialize)]
pub struct BatchSemanticEmbeddingOutput {
    pub embeddings: Vec<Vec<f32>>,
}

impl Job for TaskBatchSemanticEmbedding<'_> {
    type Output = BatchSemanticEmbeddingOutput;
    type ResponseBody = BatchSemanticEmbeddingOutput;

    fn build_request(&self, client: &reqwest::Client, base: &str) -> reqwest::RequestBuilder {
        let model = "luminous-base";
        let body = RequestBody {
            model,
            semantic_embedding_task: self,
        };
        client
            .post(format!("{base}/batch_semantic_embed"))
            .json(&body)
    }

    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
        response
    }
}