use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use crate::{http::Task, Job, Prompt};

/// Allows you to choose a semantic representation fitting your use case.
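///
/// # Example
///
/// A sketch of the `Document`/`Query` pairing used for search. It assumes the crate is
/// named `aleph_alpha_client`, re-exports these items at its root, and that
/// `Prompt::from_text` constructs a text prompt:
///
/// ```no_run
/// use aleph_alpha_client::{Prompt, SemanticRepresentation, TaskSemanticEmbedding};
///
/// // Corpus documents are embedded with `Document` ...
/// let document = TaskSemanticEmbedding {
///     prompt: Prompt::from_text("The quick brown fox jumps over the lazy dog."),
///     representation: SemanticRepresentation::Document,
///     compress_to_size: None,
/// };
/// // ... while the short search query is embedded with `Query`. The resulting
/// // embeddings are then compared to rank documents.
/// let query = TaskSemanticEmbedding {
///     prompt: Prompt::from_text("Which animal jumps over which?"),
///     representation: SemanticRepresentation::Query,
///     compress_to_size: None,
/// };
/// ```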
#[derive(Serialize, Debug)]
#[serde(rename_all = "snake_case")]
pub enum SemanticRepresentation {
    /// Useful for comparing prompts to each other, in use cases such as clustering, classification,
    /// similarity, etc. `Symmetric` embeddings are intended to be compared with other `Symmetric`
    /// embeddings.
    Symmetric,
    /// `Document` and `Query` are used together in use cases such as search where you want to
    /// compare shorter queries against larger documents. `Document` embeddings are optimized for
    /// larger pieces of text to compare queries against.
    Document,
    /// `Document` and `Query` are used together in use cases such as search where you want to
    /// compare shorter queries against larger documents. `Query` embeddings are optimized for
    /// shorter texts, such as questions or keywords.
    Query,
}

/// Create embeddings for prompts which can be used in downstream tasks, e.g. search or
/// classification.
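///
/// # Example
///
/// A minimal construction sketch, assuming the crate is named `aleph_alpha_client`,
/// re-exports these items at its root, and offers `Prompt::from_text`; actually executing
/// the task against the API is omitted here:
///
/// ```no_run
/// use aleph_alpha_client::{Prompt, SemanticRepresentation, TaskSemanticEmbedding};
///
/// let task = TaskSemanticEmbedding {
///     prompt: Prompt::from_text("An apple a day keeps the doctor away."),
///     representation: SemanticRepresentation::Symmetric,
///     // Request the compressed 128-dimensional embedding instead of the full one.
///     compress_to_size: Some(128),
/// };
/// ```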
#[derive(Serialize, Debug)]
pub struct TaskSemanticEmbedding<'a> {
    /// The prompt (usually text) to be embedded.
    pub prompt: Prompt<'a>,
    /// Semantic representation to embed the prompt with. Choose the representation that fits
    /// your use case.
    pub representation: SemanticRepresentation,
    /// Default behaviour is to return the full embedding, but you can optionally request an
    /// embedding compressed to a smaller set of dimensions. A size of `128` is supported for
    /// every model.
    ///
    /// The 128 size is expected to have a small drop in accuracy (4-6%), with the benefit of
    /// being much smaller, which makes comparing these embeddings much faster for use cases
    /// where speed is critical.
    ///
    /// The 128 size can also perform better if you are embedding short texts or documents.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compress_to_size: Option<u32>,
}

/// Appends the model to the bare task.
/// `T` is either `TaskSemanticEmbedding` or `TaskBatchSemanticEmbedding`.
#[derive(Serialize, Debug)]
struct RequestBody<'a, T: Serialize + Debug> {
    /// Currently semantic embedding still requires a model parameter, even though
    /// "luminous-base" is the only model to support it. This makes semantic embedding both a
    /// service and a method.
    model: &'a str,
    #[serde(flatten)]
    semantic_embedding_task: &'a T,
}

/// Heap-allocated embedding. Can hold either a full or a compressed embedding.
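///
/// # Example
///
/// Embeddings are typically compared with cosine similarity. A minimal sketch; the
/// `cosine_similarity` helper below is illustrative and not part of this crate:
///
/// ```
/// fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
///     let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
///     let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
///     let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
///     dot / (norm_a * norm_b)
/// }
///
/// // Two embeddings pointing in a similar direction score close to 1.0.
/// let a = vec![1.0_f32, 0.0, 0.5];
/// let b = vec![0.9_f32, 0.1, 0.4];
/// assert!(cosine_similarity(&a, &b) > 0.9);
/// ```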
#[derive(Deserialize)]
pub struct SemanticEmbeddingOutput {
    pub embedding: Vec<f32>,
}

impl Task for TaskSemanticEmbedding<'_> {
    type Output = SemanticEmbeddingOutput;
    type ResponseBody = SemanticEmbeddingOutput;
    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = RequestBody {
            model,
            semantic_embedding_task: self,
        };
        client.post(format!("{base}/semantic_embed")).json(&body)
    }
    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
        response
    }
}

impl Job for TaskSemanticEmbedding<'_> {
    type Output = SemanticEmbeddingOutput;
    type ResponseBody = SemanticEmbeddingOutput;
    fn build_request(&self, client: &reqwest::Client, base: &str) -> reqwest::RequestBuilder {
        // Only "luminous-base" supports semantic embedding, so the model can be hardcoded.
        let model = "luminous-base";
        let body = RequestBody {
            model,
            semantic_embedding_task: self,
        };
        client.post(format!("{base}/semantic_embed")).json(&body)
    }
    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
        response
    }
}

/// Create embeddings for multiple prompts with a single request.
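///
/// # Example
///
/// A sketch of embedding several prompts at once, under the same assumptions as the
/// single-prompt task (crate-root re-exports, `Prompt::from_text`):
///
/// ```no_run
/// use aleph_alpha_client::{Prompt, SemanticRepresentation, TaskBatchSemanticEmbedding};
///
/// let task = TaskBatchSemanticEmbedding {
///     prompts: vec![
///         Prompt::from_text("First sentence to embed."),
///         Prompt::from_text("Second sentence to embed."),
///     ],
///     representation: SemanticRepresentation::Symmetric,
///     compress_to_size: None,
/// };
/// ```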
#[derive(Serialize, Debug)]
pub struct TaskBatchSemanticEmbedding<'a> {
    /// The prompts (usually text) to be embedded.
    pub prompts: Vec<Prompt<'a>>,
    /// Semantic representation to embed the prompts with. Choose the representation that fits
    /// your use case.
    pub representation: SemanticRepresentation,
    /// Default behaviour is to return the full embedding, but you can optionally request an
    /// embedding compressed to a smaller set of dimensions. A size of `128` is supported for
    /// every model.
    ///
    /// The 128 size is expected to have a small drop in accuracy (4-6%), with the benefit of
    /// being much smaller, which makes comparing these embeddings much faster for use cases
    /// where speed is critical.
    ///
    /// The 128 size can also perform better if you are embedding short texts or documents.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compress_to_size: Option<u32>,
}

/// Heap-allocated vector of embeddings. Can hold full embeddings or compressed ones.
#[derive(Deserialize)]
pub struct BatchSemanticEmbeddingOutput {
    pub embeddings: Vec<Vec<f32>>,
}

impl Job for TaskBatchSemanticEmbedding<'_> {
    type Output = BatchSemanticEmbeddingOutput;
    type ResponseBody = BatchSemanticEmbeddingOutput;
    fn build_request(&self, client: &reqwest::Client, base: &str) -> reqwest::RequestBuilder {
        // Only "luminous-base" supports semantic embedding, so the model can be hardcoded.
        let model = "luminous-base";
        let body = RequestBody {
            model,
            semantic_embedding_task: self,
        };
        client
            .post(format!("{base}/batch_semantic_embed"))
            .json(&body)
    }
    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
        response
    }
}