manx_cli/rag/providers/
openai.rs1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use super::{EmbeddingProvider as ProviderTrait, ProviderInfo};
7
/// Embedding provider backed by the OpenAI `/v1/embeddings` HTTP endpoint.
pub struct OpenAiProvider {
    client: Client,           // reqwest client, built with a 30 s request timeout
    api_key: String,          // bearer token attached to every request
    model: String,            // e.g. "text-embedding-3-small"
    dimension: Option<usize>, // embedding dimension, cached once detected at runtime
}
15
/// JSON request body for `POST /v1/embeddings`.
#[derive(Serialize)]
struct EmbeddingRequest {
    input: String,           // the single text to embed
    model: String,           // target embedding model name
    encoding_format: String, // always "float" in this client
}
22
/// Subset of the `/v1/embeddings` response this client consumes.
/// Fields not listed here are ignored by serde's default behavior.
#[derive(Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>, // one entry per input; this client sends one input
    model: String,            // model the API reports it actually used
    usage: Usage,             // token accounting for the request
}
29
/// One embedding vector together with its position in the request batch.
#[derive(Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>, // the embedding vector itself
    index: usize,        // position within the request's input list (expected 0 here)
}
35
/// Token-usage accounting block returned by the API.
#[derive(Deserialize)]
struct Usage {
    prompt_tokens: u32, // tokens consumed by the input text
    total_tokens: u32,  // total tokens billed for the request
}
41
42impl OpenAiProvider {
43 pub fn new(api_key: String, model: String) -> Self {
45 let client = Client::builder()
46 .timeout(std::time::Duration::from_secs(30))
47 .build()
48 .unwrap();
49
50 Self {
51 client,
52 api_key,
53 model,
54 dimension: None,
55 }
56 }
57
58 #[allow(dead_code)]
60 pub async fn detect_dimension(&mut self) -> Result<usize> {
61 if let Some(dim) = self.dimension {
62 return Ok(dim);
63 }
64
65 log::info!(
66 "Detecting embedding dimension for OpenAI model: {}",
67 self.model
68 );
69
70 let test_embedding = self.call_api("test").await?;
71 let dimension = test_embedding.len();
72
73 self.dimension = Some(dimension);
74 log::info!("Detected dimension: {} for model {}", dimension, self.model);
75
76 Ok(dimension)
77 }
78
79 #[allow(dead_code)]
81 pub fn get_usage_stats(&self) -> Option<(u32, u32)> {
82 None
85 }
86
87 async fn call_api(&self, text: &str) -> Result<Vec<f32>> {
89 let request = EmbeddingRequest {
90 input: text.to_string(),
91 model: self.model.clone(),
92 encoding_format: "float".to_string(),
93 };
94
95 let response = self
96 .client
97 .post("https://api.openai.com/v1/embeddings")
98 .header("Authorization", format!("Bearer {}", self.api_key))
99 .header("Content-Type", "application/json")
100 .json(&request)
101 .send()
102 .await?;
103
104 let status = response.status();
105 if !status.is_success() {
106 let error_text = response.text().await.unwrap_or_default();
107 return Err(anyhow!(
108 "OpenAI API error: HTTP {} - {}",
109 status,
110 error_text
111 ));
112 }
113
114 let embedding_response: EmbeddingResponse = response.json().await?;
115
116 if embedding_response.data.is_empty() {
117 return Err(anyhow!("No embeddings returned from OpenAI API"));
118 }
119
120 log::debug!(
122 "OpenAI API usage: {} prompt tokens, {} total tokens",
123 embedding_response.usage.prompt_tokens,
124 embedding_response.usage.total_tokens
125 );
126
127 if embedding_response.data[0].index != 0 {
129 log::warn!(
130 "Unexpected embedding index: {}",
131 embedding_response.data[0].index
132 );
133 }
134
135 if embedding_response.model != self.model {
137 log::info!(
138 "API returned model: {} (requested: {})",
139 embedding_response.model,
140 self.model
141 );
142 }
143
144 Ok(embedding_response.data[0].embedding.clone())
145 }
146
147 pub fn get_model_info(model: &str) -> (usize, usize) {
149 match model {
150 "text-embedding-3-small" => (1536, 8191),
151 "text-embedding-3-large" => (3072, 8191),
152 "text-embedding-ada-002" => (1536, 8191),
153 _ => (1536, 8191), }
155 }
156}
157
158#[async_trait]
159impl ProviderTrait for OpenAiProvider {
160 async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
161 if text.trim().is_empty() {
162 return Err(anyhow!("Cannot embed empty text"));
163 }
164
165 let (_, max_chars) = Self::get_model_info(&self.model);
167 let truncated_text = if text.len() > max_chars {
168 &text[..max_chars]
169 } else {
170 text
171 };
172
173 self.call_api(truncated_text).await
174 }
175
176 async fn get_dimension(&self) -> Result<usize> {
177 if let Some(dim) = self.dimension {
178 Ok(dim)
179 } else {
180 let (dim, _) = Self::get_model_info(&self.model);
182 Ok(dim)
183 }
184 }
185
186 async fn health_check(&self) -> Result<()> {
187 self.call_api("test").await.map(|_| ())
188 }
189
190 fn get_info(&self) -> ProviderInfo {
191 let (_, max_length) = Self::get_model_info(&self.model);
192
193 ProviderInfo {
194 name: "OpenAI Embeddings".to_string(),
195 provider_type: "openai".to_string(),
196 model_name: Some(self.model.clone()),
197 description: format!("OpenAI embeddings model: {}", self.model),
198 max_input_length: Some(max_length),
199 }
200 }
201}