1use std::pin::Pin;
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{Error, Result};
7
8use super::backend::EmbeddingBackend;
9use super::config::OpenAIConfig;
10use super::convert::to_f32_blob;
11
/// Base URL used when the configuration does not supply one.
///
/// Written without a trailing slash because request paths are appended with a
/// leading `/`.
const DEFAULT_BASE_URL: &str = "https://api.openai.com";
13
14struct Inner {
15 client: reqwest::Client,
16 api_key: String,
17 model: String,
18 dimensions: usize,
19 base_url: String,
20}
21
22pub struct OpenAIEmbedding(Arc<Inner>);
34
35impl Clone for OpenAIEmbedding {
36 fn clone(&self) -> Self {
37 Self(Arc::clone(&self.0))
38 }
39}
40
41impl OpenAIEmbedding {
42 pub fn new(client: reqwest::Client, config: &OpenAIConfig) -> Result<Self> {
48 config.validate()?;
49 let base_url = config
50 .base_url
51 .as_deref()
52 .unwrap_or(DEFAULT_BASE_URL)
53 .trim_end_matches('/')
54 .to_owned();
55 Ok(Self(Arc::new(Inner {
56 client,
57 api_key: config.api_key.clone(),
58 model: config.model.clone(),
59 dimensions: config.dimensions,
60 base_url,
61 })))
62 }
63}
64
65impl EmbeddingBackend for OpenAIEmbedding {
66 fn embed(&self, input: &str) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send + '_>> {
67 let input = input.to_owned();
68 Box::pin(async move {
69 let url = format!("{}/v1/embeddings", self.0.base_url);
70 let body = Request {
71 input: &input,
72 model: &self.0.model,
73 dimensions: self.0.dimensions,
74 };
75
76 let resp = self
77 .0
78 .client
79 .post(&url)
80 .bearer_auth(&self.0.api_key)
81 .json(&body)
82 .send()
83 .await
84 .map_err(|e| Error::internal("openai embeddings request failed").chain(e))?;
85
86 if !resp.status().is_success() {
87 let status = resp.status();
88 let text = resp.text().await.unwrap_or_default();
89 return Err(Error::internal(format!(
90 "openai embedding error: {status}: {text}"
91 )));
92 }
93
94 let parsed: Response = resp.json().await.map_err(|e| {
95 Error::internal("failed to parse openai embedding response").chain(e)
96 })?;
97
98 let values = parsed
99 .data
100 .into_iter()
101 .next()
102 .ok_or_else(|| Error::internal("openai returned empty embedding data"))?
103 .embedding;
104
105 Ok(to_f32_blob(&values))
106 })
107 }
108
109 fn dimensions(&self) -> usize {
110 self.0.dimensions
111 }
112
113 fn model_name(&self) -> &str {
114 &self.0.model
115 }
116}
117
118#[derive(Serialize)]
119struct Request<'a> {
120 input: &'a str,
121 model: &'a str,
122 dimensions: usize,
123}
124
125#[derive(Deserialize)]
126struct Response {
127 data: Vec<EmbeddingData>,
128}
129
130#[derive(Deserialize)]
131struct EmbeddingData {
132 embedding: Vec<f32>,
133}