// fastskill_core/core/embedding.rs
use std::time::Duration;

use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

use crate::core::service::ServiceError;
7
8#[derive(Debug, Deserialize)]
10struct OpenAIEmbeddingResponse {
11 data: Vec<OpenAIEmbeddingData>,
12}
13
14#[derive(Debug, Deserialize)]
15struct OpenAIEmbeddingData {
16 embedding: Vec<f32>,
17}
18
19#[async_trait]
21pub trait EmbeddingService: Send + Sync {
22 async fn embed_text(&self, text: &str) -> Result<Vec<f32>, ServiceError>;
24
25 async fn embed_query(&self, query: &str) -> Result<Vec<f32>, ServiceError>;
27}
28
29pub struct OpenAIEmbeddingService {
31 client: Client,
32 base_url: String,
33 model: String,
34 api_key: String,
35}
36
37impl OpenAIEmbeddingService {
38 pub fn new(base_url: String, model: String, api_key: String) -> Self {
40 Self {
41 client: Client::new(),
42 base_url,
43 model,
44 api_key,
45 }
46 }
47
48 pub fn from_config(config: &crate::core::service::EmbeddingConfig, api_key: String) -> Self {
50 Self::new(
51 config.openai_base_url.clone(),
52 config.embedding_model.clone(),
53 api_key,
54 )
55 }
56
57 async fn call_openai_api(&self, text: &str) -> Result<Vec<f32>, ServiceError> {
59 #[derive(Serialize)]
60 struct OpenAIRequest {
61 input: String,
62 model: String,
63 }
64
65 let request = OpenAIRequest {
66 input: text.to_string(),
67 model: self.model.clone(),
68 };
69
70 let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));
71
72 let response = self
73 .client
74 .post(&url)
75 .header("Authorization", format!("Bearer {}", self.api_key))
76 .header("Content-Type", "application/json")
77 .json(&request)
78 .send()
79 .await
80 .map_err(|e| ServiceError::Custom(format!("OpenAI API request failed: {}", e)))?;
81
82 if !response.status().is_success() {
83 let status = response.status();
84 let body = response.text().await.unwrap_or_default();
85 return Err(ServiceError::Custom(format!(
86 "OpenAI API error {}: {}",
87 status, body
88 )));
89 }
90
91 let embedding_response: OpenAIEmbeddingResponse = response
92 .json()
93 .await
94 .map_err(|e| ServiceError::Custom(format!("Failed to parse OpenAI response: {}", e)))?;
95
96 if embedding_response.data.is_empty() {
97 return Err(ServiceError::Custom(
98 "No embeddings returned from OpenAI".to_string(),
99 ));
100 }
101
102 Ok(embedding_response.data[0].embedding.clone())
103 }
104}
105
106#[async_trait]
107impl EmbeddingService for OpenAIEmbeddingService {
108 async fn embed_text(&self, text: &str) -> Result<Vec<f32>, ServiceError> {
109 self.call_openai_api(text).await
112 }
113
114 async fn embed_query(&self, query: &str) -> Result<Vec<f32>, ServiceError> {
115 self.call_openai_api(query).await
117 }
118}