aleph_alpha_client/semantic_embedding.rs
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

use crate::{http::Task, Job, Prompt};

/// Allows you to choose a semantic representation fitting for your use case.
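///
/// # Example
///
/// A sketch of the `Document`/`Query` pairing for search, assuming these types are re-exported
/// at the crate root. Executing the tasks requires the crate's `Client`, which is not shown
/// here.
///
/// ```no_run
/// use aleph_alpha_client::{Prompt, SemanticRepresentation, TaskSemanticEmbedding};
///
/// // Embed each document in the corpus with `Document` ...
/// let document = TaskSemanticEmbedding {
///     prompt: Prompt::from_text("Alphabet soup is a soup containing noodle letters."),
///     representation: SemanticRepresentation::Document,
///     compress_to_size: None,
/// };
/// // ... and the user input with `Query`, then compare the resulting embeddings.
/// let query = TaskSemanticEmbedding {
///     prompt: Prompt::from_text("What is in alphabet soup?"),
///     representation: SemanticRepresentation::Query,
///     compress_to_size: None,
/// };
/// ```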
#[derive(Serialize, Debug)]
#[serde(rename_all = "snake_case")]
pub enum SemanticRepresentation {
    /// Useful for comparing prompts to each other, in use cases such as clustering,
    /// classification, similarity, etc. `Symmetric` embeddings are intended to be compared
    /// with other `Symmetric` embeddings.
    Symmetric,
    /// `Document` and `Query` are used together in use cases such as search where you want to
    /// compare shorter queries against larger documents. `Document` embeddings are optimized
    /// for larger pieces of text to compare queries against.
    Document,
    /// `Document` and `Query` are used together in use cases such as search where you want to
    /// compare shorter queries against larger documents. `Query` embeddings are optimized for
    /// shorter texts, such as questions or keywords.
    Query,
}

/// Create embeddings for prompts, which can be used for downstream tasks such as search or
/// classification.
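///
/// # Example
///
/// A minimal construction sketch, assuming these types are re-exported at the crate root.
/// Sending the task requires the crate's `Client`, which is not shown here.
///
/// ```no_run
/// use aleph_alpha_client::{Prompt, SemanticRepresentation, TaskSemanticEmbedding};
///
/// let task = TaskSemanticEmbedding {
///     prompt: Prompt::from_text("An apple a day keeps the doctor away."),
///     representation: SemanticRepresentation::Symmetric,
///     // Optional: trade a little accuracy for a smaller 128-dimension embedding.
///     compress_to_size: Some(128),
/// };
/// ```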
#[derive(Serialize, Debug)]
pub struct TaskSemanticEmbedding<'a> {
    /// The prompt (usually text) to be embedded.
    pub prompt: Prompt<'a>,
    /// Semantic representation to embed the prompt with. This parameter is governed by the
    /// specific use case in mind.
    pub representation: SemanticRepresentation,
    /// Default behaviour is to return the full embedding, but you can optionally request an
    /// embedding compressed to a smaller set of dimensions. A size of `128` is supported for
    /// every model.
    ///
    /// The 128 size is expected to have a small drop in accuracy (4-6%), with the benefit of
    /// being much smaller, which makes comparing these embeddings much faster for use cases
    /// where speed is critical.
    ///
    /// The 128 size can also perform better if you are embedding short texts or documents.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compress_to_size: Option<u32>,
}

/// Appends the model to the bare task.
///
/// `T` stands for either `TaskSemanticEmbedding` or `TaskBatchSemanticEmbedding`.
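///
/// Because of `#[serde(flatten)]` below, the task's own fields serialize at the same level as
/// `model`, e.g. `{"model": "luminous-base", "prompt": ..., "representation": "symmetric"}`
/// (values elided; the exact prompt shape depends on how `Prompt` serializes).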
#[derive(Serialize, Debug)]
struct RequestBody<'a, T: Serialize + Debug> {
    /// Currently semantic embedding still requires a model parameter, even though
    /// "luminous-base" is the only model to support it. This makes semantic embedding both a
    /// service and a method.
    model: &'a str,
    #[serde(flatten)]
    semantic_embedding_task: &'a T,
}

/// Heap-allocated embedding. Can hold either a full or a compressed embedding.
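///
/// # Example
///
/// Embeddings are typically compared with cosine similarity. A sketch of such a comparison
/// (the helper below is illustrative, not part of this crate):
///
/// ```
/// fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
///     let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
///     let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
///     dot / (norm(a) * norm(b))
/// }
///
/// // Identical directions score 1.0, orthogonal directions 0.0.
/// assert!((cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 1e-6);
/// assert!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-6);
/// ```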
#[derive(Deserialize)]
pub struct SemanticEmbeddingOutput {
    pub embedding: Vec<f32>,
}

impl Task for TaskSemanticEmbedding<'_> {
    type Output = SemanticEmbeddingOutput;
    type ResponseBody = SemanticEmbeddingOutput;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = RequestBody {
            model,
            semantic_embedding_task: self,
        };
        client.post(format!("{base}/semantic_embed")).json(&body)
    }

    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
        response
    }
}

impl Job for TaskSemanticEmbedding<'_> {
    type Output = SemanticEmbeddingOutput;
    type ResponseBody = SemanticEmbeddingOutput;

    fn build_request(&self, client: &reqwest::Client, base: &str) -> reqwest::RequestBuilder {
        let model = "luminous-base";
        let body = RequestBody {
            model,
            semantic_embedding_task: self,
        };
        client.post(format!("{base}/semantic_embed")).json(&body)
    }

    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
        response
    }
}

/// Create embeddings for multiple prompts in a single request.
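///
/// # Example
///
/// A minimal construction sketch, assuming these types are re-exported at the crate root.
/// Sending the task requires the crate's `Client`, which is not shown here.
///
/// ```no_run
/// use aleph_alpha_client::{Prompt, SemanticRepresentation, TaskBatchSemanticEmbedding};
///
/// let task = TaskBatchSemanticEmbedding {
///     prompts: vec![
///         Prompt::from_text("An apple a day keeps the doctor away."),
///         Prompt::from_text("Once upon a midnight dreary."),
///     ],
///     representation: SemanticRepresentation::Symmetric,
///     compress_to_size: None,
/// };
/// ```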
#[derive(Serialize, Debug)]
pub struct TaskBatchSemanticEmbedding<'a> {
    /// The prompts (usually text) to be embedded.
    pub prompts: Vec<Prompt<'a>>,
    /// Semantic representation to embed the prompts with. This parameter is governed by the
    /// specific use case in mind.
    pub representation: SemanticRepresentation,
    /// Default behaviour is to return the full embeddings, but you can optionally request
    /// embeddings compressed to a smaller set of dimensions. A size of `128` is supported for
    /// every model.
    ///
    /// The 128 size is expected to have a small drop in accuracy (4-6%), with the benefit of
    /// being much smaller, which makes comparing these embeddings much faster for use cases
    /// where speed is critical.
    ///
    /// The 128 size can also perform better if you are embedding short texts or documents.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compress_to_size: Option<u32>,
}

/// Heap-allocated vector of embeddings. Can hold either full or compressed embeddings.
#[derive(Deserialize)]
pub struct BatchSemanticEmbeddingOutput {
    pub embeddings: Vec<Vec<f32>>,
}

impl Job for TaskBatchSemanticEmbedding<'_> {
    type Output = BatchSemanticEmbeddingOutput;
    type ResponseBody = BatchSemanticEmbeddingOutput;

    fn build_request(&self, client: &reqwest::Client, base: &str) -> reqwest::RequestBuilder {
        let model = "luminous-base";
        let body = RequestBody {
            model,
            semantic_embedding_task: self,
        };
        client
            .post(format!("{base}/batch_semantic_embed"))
            .json(&body)
    }

    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
        response
    }
}