1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

use crate::{http::Task, Job, Prompt};

/// Allows you to choose a semantic representation fitting for your usecase.
///
/// Fieldless mode selector; `Copy` so one value can be used for several tasks.
#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum SemanticRepresentation {
    /// Useful for comparing prompts to each other, in use cases such as clustering, classification,
    /// similarity, etc. `Symmetric` embeddings are intended to be compared with other `Symmetric`
    /// embeddings.
    Symmetric,
    /// `Document` and `Query` are used together in use cases such as search where you want to
    /// compare shorter queries against larger documents. `Document` embeddings are optimized for
    /// larger pieces of text to compare queries against.
    Document,
    /// `Document` and `Query` are used together in use cases such as search where you want to
    /// compare shorter queries against larger documents. `Query` embeddings are optimized for
    /// shorter texts, such as questions or keywords.
    Query,
}

/// Create embeddings for prompts which can be used for downstream tasks. E.g. search, classifiers
///
/// For embedding many prompts in a single request, see `TaskBatchSemanticEmbedding`.
#[derive(Serialize, Debug)]
pub struct TaskSemanticEmbedding<'a> {
    /// The prompt (usually text) to be embedded.
    pub prompt: Prompt<'a>,
    /// Semantic representation to embed the prompt with. This parameter is governed by the specific
    /// usecase in mind.
    pub representation: SemanticRepresentation,
    /// Default behaviour is to return the full embedding, but you can optionally request an
    /// embedding compressed to a smaller set of dimensions. A size of `128` is supported for every
    /// model.
    ///
    /// The 128 size is expected to have a small drop in accuracy performance (4-6%), with the
    /// benefit of being much smaller, which makes comparing these embeddings much faster for use
    /// cases where speed is critical.
    ///
    /// The 128 size can also perform better if you are embedding short texts or documents.
    ///
    /// `None` is omitted from the serialized request body entirely.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compress_to_size: Option<u32>,
}

/// Appends model and hosting to the bare task
/// T stands for TaskSemanticEmbedding or TaskBatchSemanticEmbedding
#[derive(Serialize, Debug)]
struct RequestBody<'a, T: Serialize + Debug> {
    /// Currently semantic embedding still requires a model parameter, even though "luminous-base"
    /// is the only model to support it. This makes Semantic embedding both a Service and a Method.
    model: &'a str,
    /// The wrapped task. `flatten` lifts its fields to the top level of the JSON
    /// object, so the wire format is `{"model": …, <task fields…>}` with no nesting.
    #[serde(flatten)]
    semantic_embedding_task: &'a T,
}

/// Heap allocated embedding. Can hold full embeddings or compressed ones
// Debug/Clone/PartialEq make the output loggable, copyable, and comparable in tests.
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct SemanticEmbeddingOutput {
    /// The embedding vector. Its length depends on the model and on whether a
    /// `compress_to_size` (e.g. 128) was requested for the task.
    pub embedding: Vec<f32>,
}

impl Task for TaskSemanticEmbedding<'_> {
    type Output = SemanticEmbeddingOutput;
    type ResponseBody = SemanticEmbeddingOutput;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        // Wrap the task together with the caller-supplied model into the wire
        // format expected by the API, then post it to the embedding endpoint.
        let wire_body = RequestBody {
            model,
            semantic_embedding_task: self,
        };
        let endpoint = format!("{base}/semantic_embed");
        client.post(endpoint).json(&wire_body)
    }

    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
        // Response body already is the output type; pass it through unchanged.
        response
    }
}

impl Job for TaskSemanticEmbedding<'_> {
    type Output = SemanticEmbeddingOutput;
    type ResponseBody = SemanticEmbeddingOutput;

    fn build_request(&self, client: &reqwest::Client, base: &str) -> reqwest::RequestBuilder {
        // As a Job (no model argument), the model is fixed to "luminous-base",
        // the only model supporting semantic embeddings.
        let request_body = RequestBody {
            model: "luminous-base",
            semantic_embedding_task: self,
        };
        client
            .post(format!("{base}/semantic_embed"))
            .json(&request_body)
    }

    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
        // Response body already is the output type; pass it through unchanged.
        response
    }
}

/// Create embeddings for multiple prompts
#[derive(Serialize, Debug)]
pub struct TaskBatchSemanticEmbedding<'a> {
    /// The prompts (usually text) to be embedded, one embedding is returned per prompt.
    pub prompts: Vec<Prompt<'a>>,
    /// Semantic representation to embed the prompt with. This parameter is governed by the specific
    /// usecase in mind.
    pub representation: SemanticRepresentation,
    /// Default behaviour is to return the full embedding, but you can optionally request an
    /// embedding compressed to a smaller set of dimensions. A size of `128` is supported for every
    /// model.
    ///
    /// The 128 size is expected to have a small drop in accuracy performance (4-6%), with the
    /// benefit of being much smaller, which makes comparing these embeddings much faster for use
    /// cases where speed is critical.
    ///
    /// The 128 size can also perform better if you are embedding short texts or documents.
    ///
    /// `None` is omitted from the serialized request body entirely.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compress_to_size: Option<u32>,
}

/// Heap allocated vec of embeddings. Can hold full embeddings or compressed ones
// Debug/Clone/PartialEq make the output loggable, copyable, and comparable in tests.
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct BatchSemanticEmbeddingOutput {
    /// One embedding vector per input prompt, in request order — TODO confirm
    /// ordering against the API; SOURCE does not establish it.
    pub embeddings: Vec<Vec<f32>>,
}

impl Job for TaskBatchSemanticEmbedding<'_> {
    type Output = BatchSemanticEmbeddingOutput;
    type ResponseBody = BatchSemanticEmbeddingOutput;

    fn build_request(&self, client: &reqwest::Client, base: &str) -> reqwest::RequestBuilder {
        // The model is fixed to "luminous-base", the only model supporting
        // semantic embeddings; batch requests go to a dedicated endpoint.
        let request_body = RequestBody {
            model: "luminous-base",
            semantic_embedding_task: self,
        };
        let endpoint = format!("{base}/batch_semantic_embed");
        client.post(endpoint).json(&request_body)
    }

    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
        // Response body already is the output type; pass it through unchanged.
        response
    }
}