Skip to main content

rucora_embed/
openai.rs

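//! OpenAI embedding provider implementation.
//!
//! Conventions:
//! - The API key is read from the `OPENAI_API_KEY` environment variable.
//! - The base URL defaults to `https://api.openai.com/v1` and can be overridden via `OPENAI_BASE_URL`.
//! - The model name is read from the `EMBEDDING_MODEL` environment variable; `from_env` has no
//!   built-in default and returns an error when it is unset.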
7
8use std::env;
9
10use async_trait::async_trait;
11use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
12use rucora_core::{embed::EmbeddingProvider, error::ProviderError};
13use serde_json::{Value, json};
14
15/// OpenAI Embedding Provider。
16pub struct OpenAiEmbeddingProvider {
17    client: reqwest::Client,
18    base_url: String,
19    model: String,
20    embedding_dim: Option<usize>,
21}
22
23impl OpenAiEmbeddingProvider {
24    /// 从环境变量创建 Provider。
25    pub fn from_env() -> Result<Self, ProviderError> {
26        let api_key = env::var("OPENAI_API_KEY")
27            .map_err(|_| ProviderError::Message("缺少环境变量 OPENAI_API_KEY".to_string()))?;
28        let base_url =
29            env::var("OPENAI_BASE_URL").unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
30
31        let embedding_model = env::var("EMBEDDING_MODEL")
32            .map_err(|_| ProviderError::Message("缺少环境变量 EMBEDDING_MODEL".to_string()))?;
33
34        Ok(Self::new(base_url, api_key, embedding_model))
35    }
36
37    /// 创建 Provider。
38    pub fn new(
39        base_url: impl Into<String>,
40        api_key: impl Into<String>,
41        model: impl Into<String>,
42    ) -> Self {
43        let api_key = api_key.into();
44        let model = model.into();
45        let base_url = base_url.into();
46
47        let mut headers = HeaderMap::new();
48        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
49        if let Ok(v) = HeaderValue::from_str(&format!("Bearer {api_key}")) {
50            headers.insert(AUTHORIZATION, v);
51        }
52
53        let client = reqwest::Client::builder()
54            .default_headers(headers)
55            .build()
56            .expect("reqwest client build failed");
57
58        // 根据模型确定维度
59        let embedding_dim = match model.as_str() {
60            "text-embedding-ada-002" | "text-embedding-3-small" => Some(1536),
61            "text-embedding-3-large" => Some(3072),
62            _ => None,
63        };
64
65        Self {
66            client,
67            base_url,
68            model,
69            embedding_dim,
70        }
71    }
72
73    /// 设置模型(用于切换不同嵌入模型)。
74    pub fn with_model(mut self, model: impl Into<String>) -> Self {
75        self.model = model.into();
76        // 重新计算维度
77        self.embedding_dim = match self.model.as_str() {
78            "text-embedding-ada-002" | "text-embedding-3-small" => Some(1536),
79            "text-embedding-3-large" => Some(3072),
80            _ => None,
81        };
82        self
83    }
84}
85
86#[async_trait]
87impl EmbeddingProvider for OpenAiEmbeddingProvider {
88    async fn embed(&self, text: &str) -> Result<Vec<f32>, ProviderError> {
89        let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));
90
91        let body = json!({
92            "model": self.model,
93            "input": text,
94        });
95
96        let resp = self
97            .client
98            .post(&url)
99            .json(&body)
100            .send()
101            .await
102            .map_err(|e| ProviderError::Message(e.to_string()))?;
103
104        let status = resp.status();
105        let data: Value = resp
106            .json()
107            .await
108            .map_err(|e| ProviderError::Message(e.to_string()))?;
109
110        if !status.is_success() {
111            return Err(ProviderError::Message(format!(
112                "OpenAI embedding 请求失败:status={status} body={data}"
113            )));
114        }
115
116        // 解析响应:data[0].embedding
117        let embedding = data
118            .get("data")
119            .and_then(|d| d.as_array())
120            .and_then(|arr| arr.first())
121            .and_then(|item| item.get("embedding"))
122            .and_then(|e| e.as_array())
123            .ok_or_else(|| ProviderError::Message("OpenAI 响应缺少 embedding 数据".to_string()))?
124            .iter()
125            .filter_map(|v| v.as_f64().map(|f| f as f32))
126            .collect();
127
128        Ok(embedding)
129    }
130
131    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, ProviderError> {
132        if texts.is_empty() {
133            return Ok(Vec::new());
134        }
135
136        let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));
137
138        let body = json!({
139            "model": self.model,
140            "input": texts,
141        });
142
143        let resp = self
144            .client
145            .post(&url)
146            .json(&body)
147            .send()
148            .await
149            .map_err(|e| ProviderError::Message(e.to_string()))?;
150
151        let status = resp.status();
152        let data: Value = resp
153            .json()
154            .await
155            .map_err(|e| ProviderError::Message(e.to_string()))?;
156
157        if !status.is_success() {
158            return Err(ProviderError::Message(format!(
159                "OpenAI embedding 批量请求失败:status={status} body={data}"
160            )));
161        }
162
163        // 解析响应:data[].embedding
164        let data_array = data
165            .get("data")
166            .and_then(|d| d.as_array())
167            .ok_or_else(|| ProviderError::Message("OpenAI 响应缺少 data 数组".to_string()))?;
168
169        let mut results = Vec::with_capacity(texts.len());
170        for item in data_array {
171            let embedding = item
172                .get("embedding")
173                .and_then(|e| e.as_array())
174                .ok_or_else(|| {
175                    ProviderError::Message("OpenAI 响应缺少 embedding 数据".to_string())
176                })?
177                .iter()
178                .filter_map(|v| v.as_f64().map(|f| f as f32))
179                .collect();
180            results.push(embedding);
181        }
182
183        Ok(results)
184    }
185
186    fn embedding_dim(&self) -> Option<usize> {
187        self.embedding_dim
188    }
189}