// contrag_core/embedders/openai.rs

1use serde::{Deserialize, Serialize};
2use crate::embedders::{Embedder, http_client::HttpClient};
3use crate::error::{ContragError, Result};
4use crate::types::ConnectionTestResult;
5
/// OpenAI embedder using HTTP outcalls.
///
/// Talks to the OpenAI `/v1/embeddings` endpoint (and, for
/// `generate_with_prompt`, `/v1/chat/completions`) through the crate's
/// [`HttpClient`].
pub struct OpenAIEmbedder {
    // Secret API key; sent as a `Bearer` token in the Authorization header.
    api_key: String,
    // Embedding model name, e.g. "text-embedding-3-small".
    model: String,
    // Output vector length implied by `model` (set in `OpenAIEmbedder::new`).
    dimensions: usize,
    // Embeddings endpoint URL; overridable via `with_endpoint`.
    api_endpoint: String,
    // HTTP transport used for all outbound calls.
    http_client: HttpClient,
}
14
15impl OpenAIEmbedder {
16    /// Create a new OpenAI embedder
17    pub fn new(api_key: String, model: String) -> Self {
18        let dimensions = match model.as_str() {
19            "text-embedding-3-small" => 1536,
20            "text-embedding-3-large" => 3072,
21            "text-embedding-ada-002" => 1536,
22            _ => 1536, // default
23        };
24
25        Self {
26            api_key,
27            model,
28            dimensions,
29            api_endpoint: "https://api.openai.com/v1/embeddings".to_string(),
30            http_client: HttpClient::new(),
31        }
32    }
33
34    /// Create with custom API endpoint
35    pub fn with_endpoint(mut self, endpoint: String) -> Self {
36        self.api_endpoint = endpoint;
37        self
38    }
39}
40
41#[async_trait::async_trait]
42impl Embedder for OpenAIEmbedder {
43    fn name(&self) -> &str {
44        "openai"
45    }
46
47    async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
48        if texts.is_empty() {
49            return Ok(vec![]);
50        }
51
52        let request = OpenAIEmbeddingRequest {
53            model: self.model.clone(),
54            input: texts,
55        };
56
57        let body = serde_json::to_vec(&request)
58            .map_err(|e| ContragError::SerializationError(e.to_string()))?;
59
60        let headers = vec![
61            ("Content-Type".to_string(), "application/json".to_string()),
62            ("Authorization".to_string(), format!("Bearer {}", self.api_key)),
63        ];
64
65        let response = self
66            .http_client
67            .post(self.api_endpoint.clone(), headers, body)
68            .await?;
69
70        if response.status != 200 {
71            let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string());
72            return Err(ContragError::EmbedderError(format!(
73                "OpenAI API returned status {}: {}",
74                response.status, error_text
75            )));
76        }
77
78        let embedding_response: OpenAIEmbeddingResponse = response.json()?;
79
80        Ok(embedding_response
81            .data
82            .into_iter()
83            .map(|item| item.embedding)
84            .collect())
85    }
86
87    fn dimensions(&self) -> usize {
88        self.dimensions
89    }
90
91    async fn test_connection(&self) -> Result<ConnectionTestResult> {
92        let start = ic_cdk::api::time();
93
94        match self.embed(vec!["test connection".to_string()]).await {
95            Ok(_) => {
96                let latency = (ic_cdk::api::time() - start) / 1_000_000; // Convert to ms
97                Ok(ConnectionTestResult {
98                    plugin: self.name().to_string(),
99                    connected: true,
100                    latency: Some(latency),
101                    error: None,
102                    details: Some(format!(
103                        "model: {}, dimensions: {}",
104                        self.model, self.dimensions
105                    )),
106                })
107            }
108            Err(e) => Ok(ConnectionTestResult {
109                plugin: self.name().to_string(),
110                connected: false,
111                latency: None,
112                error: Some(e.to_string()),
113                details: None,
114            }),
115        }
116    }
117
118    async fn generate_with_prompt(
119        &self,
120        text: String,
121        system_prompt: String,
122    ) -> Result<String> {
123        let request = OpenAIChatRequest {
124            model: "gpt-3.5-turbo".to_string(),
125            messages: vec![
126                ChatMessage {
127                    role: "system".to_string(),
128                    content: system_prompt,
129                },
130                ChatMessage {
131                    role: "user".to_string(),
132                    content: text,
133                },
134            ],
135            max_tokens: 1000,
136            temperature: 0.7,
137        };
138
139        let body = serde_json::to_vec(&request)
140            .map_err(|e| ContragError::SerializationError(e.to_string()))?;
141
142        let headers = vec![
143            ("Content-Type".to_string(), "application/json".to_string()),
144            ("Authorization".to_string(), format!("Bearer {}", self.api_key)),
145        ];
146
147        let response = self
148            .http_client
149            .post(
150                "https://api.openai.com/v1/chat/completions".to_string(),
151                headers,
152                body,
153            )
154            .await?;
155
156        if response.status != 200 {
157            let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string());
158            return Err(ContragError::EmbedderError(format!(
159                "OpenAI API returned status {}: {}",
160                response.status, error_text
161            )));
162        }
163
164        let chat_response: OpenAIChatResponse = response.json()?;
165
166        Ok(chat_response
167            .choices
168            .get(0)
169            .and_then(|c| Some(c.message.content.clone()))
170            .unwrap_or_default())
171    }
172}
173
174// Request/Response types for OpenAI API
175
/// Request body for `POST /v1/embeddings`.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest {
    // Embedding model name, e.g. "text-embedding-3-small".
    model: String,
    // Batch of input texts to embed.
    input: Vec<String>,
}
181
/// Subset of the `/v1/embeddings` response consumed here; any other fields
/// in the payload (e.g. usage metadata) are ignored by serde.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    // One entry per input text, in request order.
    data: Vec<EmbeddingData>,
}
186
/// Single item of the embeddings response `data` array.
#[derive(Deserialize)]
struct EmbeddingData {
    // The embedding vector itself.
    embedding: Vec<f32>,
}
191
/// Request body for `POST /v1/chat/completions`.
#[derive(Serialize)]
struct OpenAIChatRequest {
    // Chat model name, e.g. "gpt-3.5-turbo".
    model: String,
    // Conversation so far; here always a system prompt followed by one user message.
    messages: Vec<ChatMessage>,
    // Cap on tokens generated in the completion.
    max_tokens: u32,
    // Sampling temperature (0.0 = deterministic, higher = more random).
    temperature: f32,
}
199
/// One chat turn; derives both Serialize and Deserialize because it appears
/// in outgoing requests and in parsed responses.
#[derive(Serialize, Deserialize)]
struct ChatMessage {
    // Speaker role: "system", "user", or (in responses) "assistant".
    role: String,
    // Message text.
    content: String,
}
205
/// Subset of the `/v1/chat/completions` response consumed here.
#[derive(Deserialize)]
struct OpenAIChatResponse {
    // Candidate completions; only the first is used.
    choices: Vec<ChatChoice>,
}
210
/// Single candidate completion within the chat response.
#[derive(Deserialize)]
struct ChatChoice {
    // The generated assistant message.
    message: ChatMessage,
}