pylate_rs/types.rs
1use serde::{Deserialize, Serialize};
2
3/// Input structure for similarity computation.
4///
5/// Contains two lists of strings: queries and documents, for which
6/// a similarity matrix will be computed.
7#[derive(Serialize, Deserialize, Debug, Clone)]
8pub struct SimilarityInput {
9 /// A list of query strings.
10 pub queries: Vec<String>,
11 /// A list of document strings.
12 pub documents: Vec<String>,
13}
14
15/// Input structure for the encoding process.
16///
17/// Contains a single list of sentences to be encoded into embeddings.
18#[derive(Serialize, Deserialize, Debug, Clone)]
19pub struct EncodeInput {
20 /// A list of sentences (queries or documents) to be encoded.
21 pub sentences: Vec<String>,
22 /// An optional batch size to override the model's default.
23 pub batch_size: Option<usize>,
24}
25
26/// Output structure for the encoding process.
27///
28/// Contains the resulting embeddings for a batch of sentences.
29#[derive(Serialize, Deserialize, Debug)]
30pub struct EncodeOutput {
31 /// A nested vector representing the embeddings.
32 /// The structure is `[batch_size, sequence_length, embedding_dimension]`.
33 pub embeddings: Vec<Vec<Vec<f32>>>,
34}
35
36/// Output structure for the similarity computation.
37///
38/// Contains a matrix of similarity scores between queries and documents.
39#[derive(Serialize, Deserialize, Debug)]
40pub struct Similarities {
41 /// A 2D vector where `data[i][j]` is the similarity score
42 /// between the i-th query and the j-th document.
43 pub data: Vec<Vec<f32>>,
44}
45
46/// Output structure for the raw similarity matrix computation.
47///
48/// This provides a detailed, un-reduced view of the similarity scores,
49/// along with the tokens for queries and documents for inspection.
50#[derive(Serialize, Deserialize, Debug)]
51pub struct RawSimilarityOutput {
52 /// The raw similarity matrix with dimensions
53 /// `[num_queries, num_documents, query_length, document_length]`.
54 pub similarity_matrix: Vec<Vec<Vec<Vec<f32>>>>,
55 /// The tokens corresponding to each query.
56 pub query_tokens: Vec<Vec<String>>,
57 /// The tokens corresponding to each document.
58 pub document_tokens: Vec<Vec<String>>,
59}