Skip to main content

leann_core/embedding/
openai.rs

1use anyhow::{Context, Result};
2use ndarray::Array2;
3use serde::{Deserialize, Serialize};
4use tracing::info;
5
6use super::EmbeddingProvider;
7use crate::settings::resolve_openai_api_key;
8
9/// OpenAI embedding API client.
10pub struct OpenAiEmbedding {
11    model: String,
12    api_key: String,
13    base_url: String,
14    dimensions: usize,
15    client: reqwest::blocking::Client,
16}
17
18#[derive(Serialize)]
19struct EmbeddingRequest {
20    model: String,
21    input: Vec<String>,
22}
23
24#[derive(Deserialize)]
25struct EmbeddingResponse {
26    data: Vec<EmbeddingData>,
27}
28
29#[derive(Deserialize)]
30struct EmbeddingData {
31    embedding: Vec<f32>,
32}
33
34impl OpenAiEmbedding {
35    pub fn new(
36        model: &str,
37        api_key: Option<&str>,
38        base_url: Option<&str>,
39        dimensions: Option<usize>,
40    ) -> Result<Self> {
41        let api_key = resolve_openai_api_key(api_key)
42            .ok_or_else(|| anyhow::anyhow!("OpenAI API key required (set OPENAI_API_KEY)"))?;
43
44        let base_url = base_url
45            .unwrap_or("https://api.openai.com/v1")
46            .trim_end_matches('/')
47            .to_string();
48
49        let dimensions = dimensions.unwrap_or(1536);
50
51        Ok(Self {
52            model: model.to_string(),
53            api_key,
54            base_url,
55            dimensions,
56            client: reqwest::blocking::Client::new(),
57        })
58    }
59}
60
61impl EmbeddingProvider for OpenAiEmbedding {
62    fn compute_embeddings(
63        &self,
64        chunks: &[String],
65        progress: Option<&dyn crate::hnsw::IndexProgress>,
66    ) -> Result<Array2<f32>> {
67        if chunks.is_empty() {
68            return Ok(Array2::zeros((0, self.dimensions)));
69        }
70
71        // Batch to avoid overwhelming the API (matches Python's max_batch_size=800)
72        let max_batch_size = if self.base_url.contains("generativelanguage.googleapis.com") {
73            100 // Gemini OpenAI-compatible endpoint limits to 100
74        } else {
75            800
76        };
77
78        let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(chunks.len());
79        let num_batches = chunks.len().div_ceil(max_batch_size);
80
81        for (i, batch) in chunks.chunks(max_batch_size).enumerate() {
82            info!(
83                "OpenAI embedding batch {}/{} ({} chunks)",
84                i + 1,
85                num_batches,
86                batch.len()
87            );
88            let request = EmbeddingRequest {
89                model: self.model.clone(),
90                input: batch.to_vec(),
91            };
92
93            let response = self
94                .client
95                .post(format!("{}/embeddings", self.base_url))
96                .header("Authorization", format!("Bearer {}", self.api_key))
97                .header("Content-Type", "application/json")
98                .json(&request)
99                .send()
100                .context("sending embedding request to OpenAI")?;
101
102            let status = response.status();
103            if !status.is_success() {
104                let body = response.text().unwrap_or_default();
105                anyhow::bail!("OpenAI API error ({}): {}", status, body);
106            }
107
108            let resp: EmbeddingResponse = response
109                .json()
110                .context("parsing OpenAI embedding response")?;
111
112            for item in resp.data {
113                all_embeddings.push(item.embedding);
114            }
115
116            if let Some(p) = progress {
117                p.progress(all_embeddings.len());
118            }
119        }
120
121        if all_embeddings.is_empty() {
122            return Ok(Array2::zeros((0, self.dimensions)));
123        }
124
125        let n = all_embeddings.len();
126        let d = all_embeddings[0].len();
127        let flat: Vec<f32> = all_embeddings.into_iter().flatten().collect();
128
129        Array2::from_shape_vec((n, d), flat).context("reshaping OpenAI embeddings")
130    }
131
132    fn dimensions(&self) -> usize {
133        self.dimensions
134    }
135
136    fn name(&self) -> &str {
137        "openai"
138    }
139}