1use std::env;
9
10use async_trait::async_trait;
11use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
12use rucora_core::{embed::EmbeddingProvider, error::ProviderError};
13use serde_json::{Value, json};
14
15pub struct OpenAiEmbeddingProvider {
17 client: reqwest::Client,
18 base_url: String,
19 model: String,
20 embedding_dim: Option<usize>,
21}
22
23impl OpenAiEmbeddingProvider {
24 pub fn from_env() -> Result<Self, ProviderError> {
26 let api_key = env::var("OPENAI_API_KEY")
27 .map_err(|_| ProviderError::Message("缺少环境变量 OPENAI_API_KEY".to_string()))?;
28 let base_url =
29 env::var("OPENAI_BASE_URL").unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
30
31 let embedding_model = env::var("EMBEDDING_MODEL")
32 .map_err(|_| ProviderError::Message("缺少环境变量 EMBEDDING_MODEL".to_string()))?;
33
34 Ok(Self::new(base_url, api_key, embedding_model))
35 }
36
37 pub fn new(
39 base_url: impl Into<String>,
40 api_key: impl Into<String>,
41 model: impl Into<String>,
42 ) -> Self {
43 let api_key = api_key.into();
44 let model = model.into();
45 let base_url = base_url.into();
46
47 let mut headers = HeaderMap::new();
48 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
49 if let Ok(v) = HeaderValue::from_str(&format!("Bearer {api_key}")) {
50 headers.insert(AUTHORIZATION, v);
51 }
52
53 let client = reqwest::Client::builder()
54 .default_headers(headers)
55 .build()
56 .expect("reqwest client build failed");
57
58 let embedding_dim = match model.as_str() {
60 "text-embedding-ada-002" | "text-embedding-3-small" => Some(1536),
61 "text-embedding-3-large" => Some(3072),
62 _ => None,
63 };
64
65 Self {
66 client,
67 base_url,
68 model,
69 embedding_dim,
70 }
71 }
72
73 pub fn with_model(mut self, model: impl Into<String>) -> Self {
75 self.model = model.into();
76 self.embedding_dim = match self.model.as_str() {
78 "text-embedding-ada-002" | "text-embedding-3-small" => Some(1536),
79 "text-embedding-3-large" => Some(3072),
80 _ => None,
81 };
82 self
83 }
84}
85
86#[async_trait]
87impl EmbeddingProvider for OpenAiEmbeddingProvider {
88 async fn embed(&self, text: &str) -> Result<Vec<f32>, ProviderError> {
89 let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));
90
91 let body = json!({
92 "model": self.model,
93 "input": text,
94 });
95
96 let resp = self
97 .client
98 .post(&url)
99 .json(&body)
100 .send()
101 .await
102 .map_err(|e| ProviderError::Message(e.to_string()))?;
103
104 let status = resp.status();
105 let data: Value = resp
106 .json()
107 .await
108 .map_err(|e| ProviderError::Message(e.to_string()))?;
109
110 if !status.is_success() {
111 return Err(ProviderError::Message(format!(
112 "OpenAI embedding 请求失败:status={status} body={data}"
113 )));
114 }
115
116 let embedding = data
118 .get("data")
119 .and_then(|d| d.as_array())
120 .and_then(|arr| arr.first())
121 .and_then(|item| item.get("embedding"))
122 .and_then(|e| e.as_array())
123 .ok_or_else(|| ProviderError::Message("OpenAI 响应缺少 embedding 数据".to_string()))?
124 .iter()
125 .filter_map(|v| v.as_f64().map(|f| f as f32))
126 .collect();
127
128 Ok(embedding)
129 }
130
131 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, ProviderError> {
132 if texts.is_empty() {
133 return Ok(Vec::new());
134 }
135
136 let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));
137
138 let body = json!({
139 "model": self.model,
140 "input": texts,
141 });
142
143 let resp = self
144 .client
145 .post(&url)
146 .json(&body)
147 .send()
148 .await
149 .map_err(|e| ProviderError::Message(e.to_string()))?;
150
151 let status = resp.status();
152 let data: Value = resp
153 .json()
154 .await
155 .map_err(|e| ProviderError::Message(e.to_string()))?;
156
157 if !status.is_success() {
158 return Err(ProviderError::Message(format!(
159 "OpenAI embedding 批量请求失败:status={status} body={data}"
160 )));
161 }
162
163 let data_array = data
165 .get("data")
166 .and_then(|d| d.as_array())
167 .ok_or_else(|| ProviderError::Message("OpenAI 响应缺少 data 数组".to_string()))?;
168
169 let mut results = Vec::with_capacity(texts.len());
170 for item in data_array {
171 let embedding = item
172 .get("embedding")
173 .and_then(|e| e.as_array())
174 .ok_or_else(|| {
175 ProviderError::Message("OpenAI 响应缺少 embedding 数据".to_string())
176 })?
177 .iter()
178 .filter_map(|v| v.as_f64().map(|f| f as f32))
179 .collect();
180 results.push(embedding);
181 }
182
183 Ok(results)
184 }
185
186 fn embedding_dim(&self) -> Option<usize> {
187 self.embedding_dim
188 }
189}