leann_core/embedding/
openai.rs1use anyhow::{Context, Result};
2use ndarray::Array2;
3use serde::{Deserialize, Serialize};
4use tracing::info;
5
6use super::EmbeddingProvider;
7use crate::settings::resolve_openai_api_key;
8
9pub struct OpenAiEmbedding {
11 model: String,
12 api_key: String,
13 base_url: String,
14 dimensions: usize,
15 client: reqwest::blocking::Client,
16}
17
18#[derive(Serialize)]
19struct EmbeddingRequest {
20 model: String,
21 input: Vec<String>,
22}
23
24#[derive(Deserialize)]
25struct EmbeddingResponse {
26 data: Vec<EmbeddingData>,
27}
28
29#[derive(Deserialize)]
30struct EmbeddingData {
31 embedding: Vec<f32>,
32}
33
34impl OpenAiEmbedding {
35 pub fn new(
36 model: &str,
37 api_key: Option<&str>,
38 base_url: Option<&str>,
39 dimensions: Option<usize>,
40 ) -> Result<Self> {
41 let api_key = resolve_openai_api_key(api_key)
42 .ok_or_else(|| anyhow::anyhow!("OpenAI API key required (set OPENAI_API_KEY)"))?;
43
44 let base_url = base_url
45 .unwrap_or("https://api.openai.com/v1")
46 .trim_end_matches('/')
47 .to_string();
48
49 let dimensions = dimensions.unwrap_or(1536);
50
51 Ok(Self {
52 model: model.to_string(),
53 api_key,
54 base_url,
55 dimensions,
56 client: reqwest::blocking::Client::new(),
57 })
58 }
59}
60
61impl EmbeddingProvider for OpenAiEmbedding {
62 fn compute_embeddings(&self, chunks: &[String]) -> Result<Array2<f32>> {
63 if chunks.is_empty() {
64 return Ok(Array2::zeros((0, self.dimensions)));
65 }
66
67 let max_batch_size = if self.base_url.contains("generativelanguage.googleapis.com") {
69 100 } else {
71 800
72 };
73
74 let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(chunks.len());
75 let num_batches = chunks.len().div_ceil(max_batch_size);
76
77 for (i, batch) in chunks.chunks(max_batch_size).enumerate() {
78 info!(
79 "OpenAI embedding batch {}/{} ({} chunks)",
80 i + 1,
81 num_batches,
82 batch.len()
83 );
84 let request = EmbeddingRequest {
85 model: self.model.clone(),
86 input: batch.to_vec(),
87 };
88
89 let response = self
90 .client
91 .post(format!("{}/embeddings", self.base_url))
92 .header("Authorization", format!("Bearer {}", self.api_key))
93 .header("Content-Type", "application/json")
94 .json(&request)
95 .send()
96 .context("sending embedding request to OpenAI")?;
97
98 let status = response.status();
99 if !status.is_success() {
100 let body = response.text().unwrap_or_default();
101 anyhow::bail!("OpenAI API error ({}): {}", status, body);
102 }
103
104 let resp: EmbeddingResponse = response
105 .json()
106 .context("parsing OpenAI embedding response")?;
107
108 for item in resp.data {
109 all_embeddings.push(item.embedding);
110 }
111 }
112
113 if all_embeddings.is_empty() {
114 return Ok(Array2::zeros((0, self.dimensions)));
115 }
116
117 let n = all_embeddings.len();
118 let d = all_embeddings[0].len();
119 let flat: Vec<f32> = all_embeddings.into_iter().flatten().collect();
120
121 Array2::from_shape_vec((n, d), flat).context("reshaping OpenAI embeddings")
122 }
123
124 fn dimensions(&self) -> usize {
125 self.dimensions
126 }
127
128 fn name(&self) -> &str {
129 "openai"
130 }
131}