Skip to main content

leann_core/embedding/
openai.rs

1use anyhow::{Context, Result};
2use ndarray::Array2;
3use serde::{Deserialize, Serialize};
4use tracing::info;
5
6use super::EmbeddingProvider;
7use crate::settings::resolve_openai_api_key;
8
9/// OpenAI embedding API client.
10pub struct OpenAiEmbedding {
11    model: String,
12    api_key: String,
13    base_url: String,
14    dimensions: usize,
15    client: reqwest::blocking::Client,
16}
17
18#[derive(Serialize)]
19struct EmbeddingRequest {
20    model: String,
21    input: Vec<String>,
22}
23
24#[derive(Deserialize)]
25struct EmbeddingResponse {
26    data: Vec<EmbeddingData>,
27}
28
29#[derive(Deserialize)]
30struct EmbeddingData {
31    embedding: Vec<f32>,
32}
33
34impl OpenAiEmbedding {
35    pub fn new(
36        model: &str,
37        api_key: Option<&str>,
38        base_url: Option<&str>,
39        dimensions: Option<usize>,
40    ) -> Result<Self> {
41        let api_key = resolve_openai_api_key(api_key)
42            .ok_or_else(|| anyhow::anyhow!("OpenAI API key required (set OPENAI_API_KEY)"))?;
43
44        let base_url = base_url
45            .unwrap_or("https://api.openai.com/v1")
46            .trim_end_matches('/')
47            .to_string();
48
49        let dimensions = dimensions.unwrap_or(1536);
50
51        Ok(Self {
52            model: model.to_string(),
53            api_key,
54            base_url,
55            dimensions,
56            client: reqwest::blocking::Client::new(),
57        })
58    }
59}
60
61impl EmbeddingProvider for OpenAiEmbedding {
62    fn compute_embeddings(&self, chunks: &[String]) -> Result<Array2<f32>> {
63        if chunks.is_empty() {
64            return Ok(Array2::zeros((0, self.dimensions)));
65        }
66
67        // Batch to avoid overwhelming the API (matches Python's max_batch_size=800)
68        let max_batch_size = if self.base_url.contains("generativelanguage.googleapis.com") {
69            100 // Gemini OpenAI-compatible endpoint limits to 100
70        } else {
71            800
72        };
73
74        let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(chunks.len());
75        let num_batches = chunks.len().div_ceil(max_batch_size);
76
77        for (i, batch) in chunks.chunks(max_batch_size).enumerate() {
78            info!(
79                "OpenAI embedding batch {}/{} ({} chunks)",
80                i + 1,
81                num_batches,
82                batch.len()
83            );
84            let request = EmbeddingRequest {
85                model: self.model.clone(),
86                input: batch.to_vec(),
87            };
88
89            let response = self
90                .client
91                .post(format!("{}/embeddings", self.base_url))
92                .header("Authorization", format!("Bearer {}", self.api_key))
93                .header("Content-Type", "application/json")
94                .json(&request)
95                .send()
96                .context("sending embedding request to OpenAI")?;
97
98            let status = response.status();
99            if !status.is_success() {
100                let body = response.text().unwrap_or_default();
101                anyhow::bail!("OpenAI API error ({}): {}", status, body);
102            }
103
104            let resp: EmbeddingResponse = response
105                .json()
106                .context("parsing OpenAI embedding response")?;
107
108            for item in resp.data {
109                all_embeddings.push(item.embedding);
110            }
111        }
112
113        if all_embeddings.is_empty() {
114            return Ok(Array2::zeros((0, self.dimensions)));
115        }
116
117        let n = all_embeddings.len();
118        let d = all_embeddings[0].len();
119        let flat: Vec<f32> = all_embeddings.into_iter().flatten().collect();
120
121        Array2::from_shape_vec((n, d), flat).context("reshaping OpenAI embeddings")
122    }
123
124    fn dimensions(&self) -> usize {
125        self.dimensions
126    }
127
128    fn name(&self) -> &str {
129        "openai"
130    }
131}