// leann_core/embedding/openai.rs
use anyhow::{Context, Result};
use ndarray::Array2;
use serde::{Deserialize, Serialize};
use tracing::info;

use super::EmbeddingProvider;
use crate::settings::resolve_openai_api_key;
8
9pub struct OpenAiEmbedding {
11 model: String,
12 api_key: String,
13 base_url: String,
14 dimensions: usize,
15 client: reqwest::blocking::Client,
16}
17
18#[derive(Serialize)]
19struct EmbeddingRequest {
20 model: String,
21 input: Vec<String>,
22}
23
24#[derive(Deserialize)]
25struct EmbeddingResponse {
26 data: Vec<EmbeddingData>,
27}
28
29#[derive(Deserialize)]
30struct EmbeddingData {
31 embedding: Vec<f32>,
32}
33
34impl OpenAiEmbedding {
35 pub fn new(
36 model: &str,
37 api_key: Option<&str>,
38 base_url: Option<&str>,
39 dimensions: Option<usize>,
40 ) -> Result<Self> {
41 let api_key = resolve_openai_api_key(api_key)
42 .ok_or_else(|| anyhow::anyhow!("OpenAI API key required (set OPENAI_API_KEY)"))?;
43
44 let base_url = base_url
45 .unwrap_or("https://api.openai.com/v1")
46 .trim_end_matches('/')
47 .to_string();
48
49 let dimensions = dimensions.unwrap_or(1536);
50
51 Ok(Self {
52 model: model.to_string(),
53 api_key,
54 base_url,
55 dimensions,
56 client: reqwest::blocking::Client::new(),
57 })
58 }
59}
60
61impl EmbeddingProvider for OpenAiEmbedding {
62 fn compute_embeddings(
63 &self,
64 chunks: &[String],
65 progress: Option<&dyn crate::hnsw::IndexProgress>,
66 ) -> Result<Array2<f32>> {
67 if chunks.is_empty() {
68 return Ok(Array2::zeros((0, self.dimensions)));
69 }
70
71 let max_batch_size = if self.base_url.contains("generativelanguage.googleapis.com") {
73 100 } else {
75 800
76 };
77
78 let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(chunks.len());
79 let num_batches = chunks.len().div_ceil(max_batch_size);
80
81 for (i, batch) in chunks.chunks(max_batch_size).enumerate() {
82 info!(
83 "OpenAI embedding batch {}/{} ({} chunks)",
84 i + 1,
85 num_batches,
86 batch.len()
87 );
88 let request = EmbeddingRequest {
89 model: self.model.clone(),
90 input: batch.to_vec(),
91 };
92
93 let response = self
94 .client
95 .post(format!("{}/embeddings", self.base_url))
96 .header("Authorization", format!("Bearer {}", self.api_key))
97 .header("Content-Type", "application/json")
98 .json(&request)
99 .send()
100 .context("sending embedding request to OpenAI")?;
101
102 let status = response.status();
103 if !status.is_success() {
104 let body = response.text().unwrap_or_default();
105 anyhow::bail!("OpenAI API error ({}): {}", status, body);
106 }
107
108 let resp: EmbeddingResponse = response
109 .json()
110 .context("parsing OpenAI embedding response")?;
111
112 for item in resp.data {
113 all_embeddings.push(item.embedding);
114 }
115
116 if let Some(p) = progress {
117 p.progress(all_embeddings.len());
118 }
119 }
120
121 if all_embeddings.is_empty() {
122 return Ok(Array2::zeros((0, self.dimensions)));
123 }
124
125 let n = all_embeddings.len();
126 let d = all_embeddings[0].len();
127 let flat: Vec<f32> = all_embeddings.into_iter().flatten().collect();
128
129 Array2::from_shape_vec((n, d), flat).context("reshaping OpenAI embeddings")
130 }
131
132 fn dimensions(&self) -> usize {
133 self.dimensions
134 }
135
136 fn name(&self) -> &str {
137 "openai"
138 }
139}