1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
use serde::{Deserialize, Serialize};

use crate::{error::OpenAIResponseExt, AuthTokenProvider, OpenAI, OpenAIResult};

use super::API_BASE_URL;

/// Embeddings models that can be selected in a [`CreateEmbeddingsRequest`].
///
/// Each variant serializes to the model identifier string placed in the
/// request's `model` field.
#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EmbeddingsModel {
    /// Serialized as `"text-embedding-3-large"`.
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

/// JSON body posted to the `/embeddings` endpoint by `create_embeddings`.
///
/// The text to embed is borrowed for the lifetime `'a`, so building a request
/// does not copy the input.
#[derive(Debug, Clone, Serialize)]
pub struct CreateEmbeddingsRequest<'a> {
    // NOTE: field declaration order determines the order of keys in the
    // serialized JSON; keep `model` before `input`.
    /// Model used to compute the embedding.
    model: EmbeddingsModel,
    /// Text to embed.
    input: &'a str,
}

impl<'a> CreateEmbeddingsRequest<'a> {
    /// Creates a request that asks `model` to embed `input`.
    pub fn new(model: EmbeddingsModel, input: &'a str) -> Self {
        Self { input, model }
    }
}

/// Decoded JSON body returned by the `/embeddings` endpoint.
#[derive(Debug, Deserialize)]
pub struct CreateEmbeddingsResponse {
    /// Embedding entries from the response's `data` array.
    data: Vec<EmbeddingsData>,
}

impl CreateEmbeddingsResponse {
    /// Returns the embedding vector of the first entry in the response.
    ///
    /// A [`CreateEmbeddingsRequest`] carries a single input string, so the first
    /// entry is the one of interest. If the response's `data` array is empty,
    /// an empty slice is returned instead of panicking. Callers that must
    /// distinguish this case can check `is_empty()` on the result.
    pub fn embedding(&self) -> &[f32] {
        // `.first()` avoids the out-of-bounds panic `self.data[0]` would hit
        // on an empty (unexpected or malformed) response.
        self.data
            .first()
            .map(|entry| entry.embedding.as_slice())
            .unwrap_or_default()
    }
}

/// One element of the response's `data` array.
#[derive(Debug, Deserialize)]
struct EmbeddingsData {
    /// The embedding vector as returned by the API.
    embedding: Vec<f32>,
}

/// Posts `req` to `{API_BASE_URL}/embeddings` and decodes the JSON reply into a
/// [`CreateEmbeddingsResponse`].
///
/// # Errors
///
/// Returns an error if sending the request through `openai.post` fails, or if
/// `openai_response_json` fails to turn the HTTP response into a
/// `CreateEmbeddingsResponse`.
pub async fn create_embeddings<Auth>(
    openai: &OpenAI<Auth>,
    req: &CreateEmbeddingsRequest<'_>,
) -> OpenAIResult<CreateEmbeddingsResponse>
where
    Auth: AuthTokenProvider,
{
    let endpoint = format!("{API_BASE_URL}/embeddings");
    let http_response = openai.post(endpoint, req).await?;
    http_response.openai_response_json().await
}