1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
use serde::{Deserialize, Serialize};

use crate::{http::Task, Prompt};

/// Selects the semantic representation an embedding is computed for; choose the one that fits
/// your use case.
///
/// Serialized in `snake_case`, i.e. as `"symmetric"`, `"document"` or `"query"`.
#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SemanticRepresentation {
    /// Useful for comparing prompts to each other, in use cases such as clustering,
    /// classification, similarity, etc. `Symmetric` embeddings are intended to be compared with
    /// other `Symmetric` embeddings.
    Symmetric,
    /// `Document` and `Query` are used together in use cases such as search, where you want to
    /// compare shorter queries against larger documents. `Document` embeddings are optimized for
    /// larger pieces of text to compare queries against.
    Document,
    /// `Document` and `Query` are used together in use cases such as search, where you want to
    /// compare shorter queries against larger documents. `Query` embeddings are optimized for
    /// shorter texts, such as questions or keywords.
    Query,
}

/// Task that turns a prompt into an embedding for downstream use, e.g. search or classifiers.
#[derive(Serialize, Debug)]
pub struct TaskSemanticEmbedding<'a> {
    /// The prompt to embed (usually text).
    pub prompt: Prompt<'a>,
    /// Which semantic representation to embed the prompt with; pick it according to the use
    /// case you have in mind.
    pub representation: SemanticRepresentation,
    /// Optional number of dimensions to compress the embedding to. When `None`, the field is
    /// left out of the serialized task and the full embedding is returned. A size of `128` is
    /// supported by every model.
    ///
    /// Compared to the full embedding, the 128 size is expected to lose a little accuracy
    /// (4-6%). In exchange, it is much smaller, which makes comparing embeddings much faster
    /// where speed is critical. It can also perform better when embedding short texts or
    /// documents.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compress_to_size: Option<u32>,
}

/// Wire format of a semantic-embedding request: the model name plus the bare task.
///
/// The task's own fields are flattened into the same top-level object as `model`, so no nested
/// task object is produced.
#[derive(Serialize, Debug)]
struct RequestBody<'a> {
    /// Model name, serialized as the top-level `model` field.
    model: &'a str,
    /// The task whose fields (`prompt`, `representation` and, if set, `compress_to_size`) are
    /// inlined next to `model` via `#[serde(flatten)]`.
    #[serde(flatten)]
    semantic_embedding_task: &'a TaskSemanticEmbedding<'a>,
}

/// Heap-allocated embedding returned for a semantic-embedding task. It can hold either a full
/// embedding or a compressed one.
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, copy and compare results. It does
/// not derive `Eq`, because `f32` does not implement it.
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct SemanticEmbeddingOutput {
    /// The embedding vector.
    pub embedding: Vec<f32>,
}

impl Task for TaskSemanticEmbedding<'_> {
    type Output = SemanticEmbeddingOutput;

    // The deserialized response already has the shape handed to callers, so the same type
    // serves as both response body and output.
    type ResponseBody = SemanticEmbeddingOutput;

    /// Builds a POST to `{base}/semantic_embed` whose JSON body is this task with the model name
    /// added alongside its fields.
    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let url = format!("{base}/semantic_embed");
        let payload = RequestBody {
            semantic_embedding_task: self,
            model,
        };
        client.post(url).json(&payload)
    }

    /// Identity conversion: the deserialized response body is returned unchanged.
    fn body_to_output(&self, body: Self::ResponseBody) -> Self::Output {
        body
    }
}