// pylate_rs/types.rs

use serde::{Deserialize, Serialize};
2
3/// Input structure for similarity computation.
4///
5/// Contains two lists of strings: queries and documents, for which
6/// a similarity matrix will be computed.
7#[derive(Serialize, Deserialize, Debug, Clone)]
8pub struct SimilarityInput {
9    /// A list of query strings.
10    pub queries: Vec<String>,
11    /// A list of document strings.
12    pub documents: Vec<String>,
13}
14
15/// Input structure for the encoding process.
16///
17/// Contains a single list of sentences to be encoded into embeddings.
18#[derive(Serialize, Deserialize, Debug, Clone)]
19pub struct EncodeInput {
20    /// A list of sentences (queries or documents) to be encoded.
21    pub sentences: Vec<String>,
22    /// An optional batch size to override the model's default.
23    pub batch_size: Option<usize>,
24}
25
26/// Output structure for the encoding process.
27///
28/// Contains the resulting embeddings for a batch of sentences.
29#[derive(Serialize, Deserialize, Debug)]
30pub struct EncodeOutput {
31    /// A nested vector representing the embeddings.
32    /// The structure is `[batch_size, sequence_length, embedding_dimension]`.
33    pub embeddings: Vec<Vec<Vec<f32>>>,
34}
35
36/// Output structure for the similarity computation.
37///
38/// Contains a matrix of similarity scores between queries and documents.
39#[derive(Serialize, Deserialize, Debug)]
40pub struct Similarities {
41    /// A 2D vector where `data[i][j]` is the similarity score
42    /// between the i-th query and the j-th document.
43    pub data: Vec<Vec<f32>>,
44}
45
46/// Output structure for the raw similarity matrix computation.
47///
48/// This provides a detailed, un-reduced view of the similarity scores,
49/// along with the tokens for queries and documents for inspection.
50#[derive(Serialize, Deserialize, Debug)]
51pub struct RawSimilarityOutput {
52    /// The raw similarity matrix with dimensions
53    /// `[num_queries, num_documents, query_length, document_length]`.
54    pub similarity_matrix: Vec<Vec<Vec<Vec<f32>>>>,
55    /// The tokens corresponding to each query.
56    pub query_tokens: Vec<Vec<String>>,
57    /// The tokens corresponding to each document.
58    pub document_tokens: Vec<Vec<String>>,
59}