skill_runtime/embeddings/
openai.rs

1//! OpenAI embedding provider implementation
2//!
3//! Uses rig-core's OpenAI client for API-based embeddings.
4//! Requires OPENAI_API_KEY environment variable.
5
6use super::{EmbeddingProvider, OpenAIEmbeddingModel};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use rig::embeddings::EmbeddingModel as RigEmbeddingModel;
10use rig::client::{EmbeddingsClient, ProviderClient};
11use rig::providers::openai::{self, Client as OpenAIClient};
12use std::sync::Arc;
13
14/// OpenAI embedding provider
15///
16/// Generates embeddings via OpenAI's API.
17/// Requires OPENAI_API_KEY environment variable to be set.
/// OpenAI embedding provider
///
/// Generates embeddings via OpenAI's API.
/// Requires OPENAI_API_KEY environment variable to be set.
pub struct OpenAIEmbedProvider {
    // rig OpenAI client; Arc-wrapped so the provider can be cloned/shared
    // cheaply — TODO confirm callers actually share it across tasks.
    client: Arc<OpenAIClient>,
    // Which OpenAI embedding model to request (selects the API model name).
    model: OpenAIEmbeddingModel,
    // Output vector dimensionality, cached from `model.dimensions()` at
    // construction so `dimensions()` doesn't re-derive it per call.
    dims: usize,
}
23
24impl OpenAIEmbedProvider {
25    /// Create a new OpenAI provider with the default model (Ada002)
26    ///
27    /// # Errors
28    /// Returns error if OPENAI_API_KEY is not set
29    pub fn new() -> Result<Self> {
30        Self::with_model(OpenAIEmbeddingModel::default())
31    }
32
33    /// Create a new OpenAI provider with a specific model
34    pub fn with_model(model: OpenAIEmbeddingModel) -> Result<Self> {
35        // Check for API key
36        std::env::var("OPENAI_API_KEY").context(
37            "OPENAI_API_KEY environment variable not set. Set it with: export OPENAI_API_KEY=your-key-here"
38        )?;
39
40        let client = Arc::new(OpenAIClient::from_env());
41        let dims = model.dimensions();
42
43        Ok(Self {
44            client,
45            model,
46            dims,
47        })
48    }
49
50    /// Create with a custom API key
51    pub fn with_api_key(api_key: &str, model: OpenAIEmbeddingModel) -> Result<Self> {
52        let client = Arc::new(OpenAIClient::new(api_key).context("Failed to create OpenAI client")?);
53        let dims = model.dimensions();
54
55        Ok(Self {
56            client,
57            model,
58            dims,
59        })
60    }
61
62    /// Create from a model name string
63    pub fn from_model_name(name: &str) -> Result<Self> {
64        let model: OpenAIEmbeddingModel = name.parse()?;
65        Self::with_model(model)
66    }
67
68    /// Get the API model name
69    fn api_model_name(&self) -> &'static str {
70        match self.model {
71            OpenAIEmbeddingModel::Ada002 => openai::TEXT_EMBEDDING_ADA_002,
72            OpenAIEmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
73            OpenAIEmbeddingModel::TextEmbedding3Large => "text-embedding-3-large",
74        }
75    }
76
77}
78
79#[async_trait]
80impl EmbeddingProvider for OpenAIEmbedProvider {
81    async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
82        if texts.is_empty() {
83            return Ok(Vec::new());
84        }
85
86        let embedding_model = self.client.embedding_model(self.api_model_name());
87
88        // Use rig's embed method
89        let embeddings = embedding_model
90            .embed_texts(texts)
91            .await
92            .context("OpenAI failed to generate embeddings")?;
93
94        // Convert from rig's Embedding type to Vec<f32>
95        let results: Vec<Vec<f32>> = embeddings
96            .into_iter()
97            .map(|emb| emb.vec.into_iter().map(|x| x as f32).collect())
98            .collect();
99
100        Ok(results)
101    }
102
103    fn dimensions(&self) -> usize {
104        self.dims
105    }
106
107    fn model_name(&self) -> &str {
108        self.api_model_name()
109    }
110
111    fn provider_name(&self) -> &str {
112        "openai"
113    }
114
115    fn max_batch_size(&self) -> usize {
116        // OpenAI API limit is 2048 texts per request
117        2048
118    }
119
120    async fn health_check(&self) -> Result<bool> {
121        // Try a minimal embedding to verify API key works
122        match self.embed_query("test").await {
123            Ok(emb) => Ok(emb.len() == self.dims),
124            Err(_) => Ok(false),
125        }
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    // Pure dimension-table checks; no provider construction (no API key needed).
    #[test]
    fn test_model_dimensions() {
        // Just test the dimensions without actually creating a provider (needs API key)
        assert_eq!(OpenAIEmbeddingModel::Ada002.dimensions(), 1536);
        assert_eq!(OpenAIEmbeddingModel::TextEmbedding3Small.dimensions(), 1536);
        assert_eq!(OpenAIEmbeddingModel::TextEmbedding3Large.dimensions(), 3072);
    }

    // Verifies the enum's own `api_name()` mapping (defined in the parent
    // module, outside this file's view) matches OpenAI's published names.
    #[test]
    fn test_api_model_names() {
        assert_eq!(OpenAIEmbeddingModel::Ada002.api_name(), "text-embedding-ada-002");
        assert_eq!(OpenAIEmbeddingModel::TextEmbedding3Small.api_name(), "text-embedding-3-small");
        assert_eq!(OpenAIEmbeddingModel::TextEmbedding3Large.api_name(), "text-embedding-3-large");
    }

    // Integration test - requires API key
    #[tokio::test]
    #[ignore = "requires OPENAI_API_KEY"]
    async fn test_embed_documents() {
        let provider = OpenAIEmbedProvider::new().unwrap();
        let texts = vec![
            "Hello world".to_string(),
            "How are you".to_string(),
        ];

        let embeddings = provider.embed_documents(texts).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), provider.dimensions());
    }

    // NOTE(review): this test mutates process-global environment state.
    // `cargo test` runs tests in parallel threads by default, so this can
    // race with any other test (or library code) reading OPENAI_API_KEY,
    // and `remove_var`/`set_var` are `unsafe` in Rust edition 2024.
    // Consider serializing it (e.g. the `serial_test` crate) or running
    // with `--test-threads=1` — TODO confirm the crate's edition/policy.
    #[test]
    fn test_missing_api_key() {
        // Temporarily unset the API key
        let original = std::env::var("OPENAI_API_KEY").ok();
        std::env::remove_var("OPENAI_API_KEY");

        let result = OpenAIEmbedProvider::new();
        assert!(result.is_err());

        // Restore if it was set
        if let Some(key) = original {
            std::env::set_var("OPENAI_API_KEY", key);
        }
    }
}