Skip to main content

fastskill_core/core/
embedding.rs

1//! Embedding service for generating vector representations of text
2
3use crate::core::service::ServiceError;
4use async_trait::async_trait;
5use reqwest::Client;
6use serde::{Deserialize, Serialize};
7
8/// Response from OpenAI embeddings API
9#[derive(Debug, Deserialize)]
10struct OpenAIEmbeddingResponse {
11    data: Vec<OpenAIEmbeddingData>,
12}
13
14#[derive(Debug, Deserialize)]
15struct OpenAIEmbeddingData {
16    embedding: Vec<f32>,
17}
18
19/// Embedding service trait
20#[async_trait]
21pub trait EmbeddingService: Send + Sync {
22    /// Generate embeddings for text content
23    async fn embed_text(&self, text: &str) -> Result<Vec<f32>, ServiceError>;
24
25    /// Generate embeddings for a search query
26    async fn embed_query(&self, query: &str) -> Result<Vec<f32>, ServiceError>;
27}
28
29/// OpenAI embedding service implementation
30pub struct OpenAIEmbeddingService {
31    client: Client,
32    base_url: String,
33    model: String,
34    api_key: String,
35}
36
37impl OpenAIEmbeddingService {
38    /// Create a new OpenAI embedding service
39    pub fn new(base_url: String, model: String, api_key: String) -> Self {
40        Self {
41            client: Client::new(),
42            base_url,
43            model,
44            api_key,
45        }
46    }
47
48    /// Create from embedding config and API key
49    pub fn from_config(config: &crate::core::service::EmbeddingConfig, api_key: String) -> Self {
50        Self::new(
51            config.openai_base_url.clone(),
52            config.embedding_model.clone(),
53            api_key,
54        )
55    }
56
57    /// Make the actual API call to OpenAI
58    async fn call_openai_api(&self, text: &str) -> Result<Vec<f32>, ServiceError> {
59        #[derive(Serialize)]
60        struct OpenAIRequest {
61            input: String,
62            model: String,
63        }
64
65        let request = OpenAIRequest {
66            input: text.to_string(),
67            model: self.model.clone(),
68        };
69
70        let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));
71
72        let response = self
73            .client
74            .post(&url)
75            .header("Authorization", format!("Bearer {}", self.api_key))
76            .header("Content-Type", "application/json")
77            .json(&request)
78            .send()
79            .await
80            .map_err(|e| ServiceError::Custom(format!("OpenAI API request failed: {}", e)))?;
81
82        if !response.status().is_success() {
83            let status = response.status();
84            let body = response.text().await.unwrap_or_default();
85            return Err(ServiceError::Custom(format!(
86                "OpenAI API error {}: {}",
87                status, body
88            )));
89        }
90
91        let embedding_response: OpenAIEmbeddingResponse = response
92            .json()
93            .await
94            .map_err(|e| ServiceError::Custom(format!("Failed to parse OpenAI response: {}", e)))?;
95
96        if embedding_response.data.is_empty() {
97            return Err(ServiceError::Custom(
98                "No embeddings returned from OpenAI".to_string(),
99            ));
100        }
101
102        Ok(embedding_response.data[0].embedding.clone())
103    }
104}
105
106#[async_trait]
107impl EmbeddingService for OpenAIEmbeddingService {
108    async fn embed_text(&self, text: &str) -> Result<Vec<f32>, ServiceError> {
109        // For text content, we might want to limit length or preprocess
110        // For now, just call the API directly
111        self.call_openai_api(text).await
112    }
113
114    async fn embed_query(&self, query: &str) -> Result<Vec<f32>, ServiceError> {
115        // Queries are typically shorter, so we can pass them through directly
116        self.call_openai_api(query).await
117    }
118}