// contrag_core/embedders/openai.rs
use serde::{Deserialize, Serialize};
2use crate::embedders::{Embedder, http_client::HttpClient};
3use crate::error::{ContragError, Result};
4use crate::types::ConnectionTestResult;
5
/// Embedder plugin backed by the OpenAI embeddings API.
pub struct OpenAIEmbedder {
    api_key: String,      // bearer token sent in the Authorization header
    model: String,        // e.g. "text-embedding-3-small"
    dimensions: usize,    // output vector width, derived from `model` in `new`
    api_endpoint: String, // embeddings URL; overridable via `with_endpoint`
    http_client: HttpClient,
}
14
15impl OpenAIEmbedder {
16 pub fn new(api_key: String, model: String) -> Self {
18 let dimensions = match model.as_str() {
19 "text-embedding-3-small" => 1536,
20 "text-embedding-3-large" => 3072,
21 "text-embedding-ada-002" => 1536,
22 _ => 1536, };
24
25 Self {
26 api_key,
27 model,
28 dimensions,
29 api_endpoint: "https://api.openai.com/v1/embeddings".to_string(),
30 http_client: HttpClient::new(),
31 }
32 }
33
34 pub fn with_endpoint(mut self, endpoint: String) -> Self {
36 self.api_endpoint = endpoint;
37 self
38 }
39}
40
41#[async_trait::async_trait]
42impl Embedder for OpenAIEmbedder {
43 fn name(&self) -> &str {
44 "openai"
45 }
46
47 async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
48 if texts.is_empty() {
49 return Ok(vec![]);
50 }
51
52 let request = OpenAIEmbeddingRequest {
53 model: self.model.clone(),
54 input: texts,
55 };
56
57 let body = serde_json::to_vec(&request)
58 .map_err(|e| ContragError::SerializationError(e.to_string()))?;
59
60 let headers = vec![
61 ("Content-Type".to_string(), "application/json".to_string()),
62 ("Authorization".to_string(), format!("Bearer {}", self.api_key)),
63 ];
64
65 let response = self
66 .http_client
67 .post(self.api_endpoint.clone(), headers, body)
68 .await?;
69
70 if response.status != 200 {
71 let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string());
72 return Err(ContragError::EmbedderError(format!(
73 "OpenAI API returned status {}: {}",
74 response.status, error_text
75 )));
76 }
77
78 let embedding_response: OpenAIEmbeddingResponse = response.json()?;
79
80 Ok(embedding_response
81 .data
82 .into_iter()
83 .map(|item| item.embedding)
84 .collect())
85 }
86
87 fn dimensions(&self) -> usize {
88 self.dimensions
89 }
90
91 async fn test_connection(&self) -> Result<ConnectionTestResult> {
92 let start = ic_cdk::api::time();
93
94 match self.embed(vec!["test connection".to_string()]).await {
95 Ok(_) => {
96 let latency = (ic_cdk::api::time() - start) / 1_000_000; Ok(ConnectionTestResult {
98 plugin: self.name().to_string(),
99 connected: true,
100 latency: Some(latency),
101 error: None,
102 details: Some(format!(
103 "model: {}, dimensions: {}",
104 self.model, self.dimensions
105 )),
106 })
107 }
108 Err(e) => Ok(ConnectionTestResult {
109 plugin: self.name().to_string(),
110 connected: false,
111 latency: None,
112 error: Some(e.to_string()),
113 details: None,
114 }),
115 }
116 }
117
118 async fn generate_with_prompt(
119 &self,
120 text: String,
121 system_prompt: String,
122 ) -> Result<String> {
123 let request = OpenAIChatRequest {
124 model: "gpt-3.5-turbo".to_string(),
125 messages: vec![
126 ChatMessage {
127 role: "system".to_string(),
128 content: system_prompt,
129 },
130 ChatMessage {
131 role: "user".to_string(),
132 content: text,
133 },
134 ],
135 max_tokens: 1000,
136 temperature: 0.7,
137 };
138
139 let body = serde_json::to_vec(&request)
140 .map_err(|e| ContragError::SerializationError(e.to_string()))?;
141
142 let headers = vec![
143 ("Content-Type".to_string(), "application/json".to_string()),
144 ("Authorization".to_string(), format!("Bearer {}", self.api_key)),
145 ];
146
147 let response = self
148 .http_client
149 .post(
150 "https://api.openai.com/v1/chat/completions".to_string(),
151 headers,
152 body,
153 )
154 .await?;
155
156 if response.status != 200 {
157 let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string());
158 return Err(ContragError::EmbedderError(format!(
159 "OpenAI API returned status {}: {}",
160 response.status, error_text
161 )));
162 }
163
164 let chat_response: OpenAIChatResponse = response.json()?;
165
166 Ok(chat_response
167 .choices
168 .get(0)
169 .and_then(|c| Some(c.message.content.clone()))
170 .unwrap_or_default())
171 }
172}
173
/// JSON request body for POST /v1/embeddings.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest {
    model: String,      // embedding model name, e.g. "text-embedding-3-small"
    input: Vec<String>, // batch of texts to embed
}
181
/// JSON response body from POST /v1/embeddings (only the fields we use).
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<EmbeddingData>, // one entry per input text
}
186
/// A single embedding entry within the embeddings response.
#[derive(Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>, // the embedding vector itself
}
191
/// JSON request body for POST /v1/chat/completions.
#[derive(Serialize)]
struct OpenAIChatRequest {
    model: String,              // chat model name, e.g. "gpt-3.5-turbo"
    messages: Vec<ChatMessage>, // conversation history (system + user here)
    max_tokens: u32,            // completion length cap
    temperature: f32,           // sampling temperature
}
199
/// One chat turn; serialized in requests and deserialized from responses.
#[derive(Serialize, Deserialize)]
struct ChatMessage {
    role: String,    // "system", "user", or "assistant"
    content: String, // message text
}
205
/// JSON response body from POST /v1/chat/completions (only fields we use).
#[derive(Deserialize)]
struct OpenAIChatResponse {
    choices: Vec<ChatChoice>, // candidate completions; first one is used
}
210
/// A single completion choice within the chat response.
#[derive(Deserialize)]
struct ChatChoice {
    message: ChatMessage, // the assistant's reply
}