1use std::sync::OnceLock;
2
3use async_openai::Client as OpenAiSdkClient;
4use async_openai::config::OpenAIConfig;
5use async_openai::types::{CreateEmbeddingRequestArgs, EmbeddingInput};
6use async_trait::async_trait;
7use tiktoken_rs::{CoreBPE, cl100k_base};
8
9use crate::config::LlmConfig;
10use crate::embeddings::{EmbeddingModelConfig, EmbeddingProvider};
11use crate::error::{LLMBrainError, Result};
12
// Process-wide cl100k_base tokenizer, lazily initialized and shared by all clients.
static BPE: OnceLock<CoreBPE> = OnceLock::new();

// Embedding model used when the configuration does not specify one.
const DEFAULT_EMBEDDING_MODEL: &str = "text-embedding-3-small";
17
/// OpenAI-backed embedding client wrapping the `async-openai` SDK.
#[derive(Clone)]
pub struct OpenAiClient {
    /// Underlying `async-openai` HTTP client.
    client: OpenAiSdkClient<OpenAIConfig>,

    /// Model name sent with every embedding request.
    embedding_model: String,

    /// Metadata (name, dimensions, max context) for the configured model.
    embedding_config: EmbeddingModelConfig,
}
31
32impl OpenAiClient {
33 pub fn new(config: Option<&LlmConfig>) -> Result<Self> {
39 let api_key = config
40 .and_then(|c| c.openai_api_key.as_deref())
41 .unwrap_or("")
42 .to_owned();
43
44 let api_base = config
45 .and_then(|c| c.openai_api_base.as_deref())
46 .unwrap_or("")
47 .to_owned();
48
49 let embedding_model = config
50 .and_then(|c| c.embedding_model.as_deref())
51 .unwrap_or(DEFAULT_EMBEDDING_MODEL)
52 .to_owned();
53
54 let mut client_config = if !api_key.is_empty() {
55 OpenAIConfig::new().with_api_key(api_key)
56 } else {
57 OpenAIConfig::default()
58 };
59
60 if !api_base.is_empty() {
61 client_config = client_config.with_api_base(api_base);
62 }
63
64 let client = OpenAiSdkClient::with_config(client_config);
65
66 BPE.get_or_init(|| cl100k_base().expect("Failed to load cl100k_base tokenizer for OpenAI"));
68
69 let embedding_config = EmbeddingModelConfig {
71 model_name: embedding_model.clone(),
72 dimensions: 1536,
73 max_context_length: 8191,
74 };
75
76 Ok(Self {
77 client,
78 embedding_model,
79 embedding_config,
80 })
81 }
82
83 pub fn get_embedding_config(&self) -> &EmbeddingModelConfig {
85 &self.embedding_config
86 }
87}
88
89#[async_trait]
90impl EmbeddingProvider for OpenAiClient {
91 async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
93 let mut embeddings = self.generate_embeddings(vec![text.to_owned()]).await?;
94 if let Some(embedding) = embeddings.pop() {
95 Ok(embedding)
96 } else {
97 Err(LLMBrainError::ApiError(
98 "OpenAI API returned no embedding for single text input".to_owned(),
99 ))
100 }
101 }
102
103 async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
113 if texts.is_empty() {
114 return Ok(Vec::new());
115 }
116
117 let request = CreateEmbeddingRequestArgs::default()
118 .model(&self.embedding_model)
119 .input(EmbeddingInput::StringArray(texts))
120 .build()
121 .map_err(|e| {
122 LLMBrainError::ApiError(format!("Failed to build OpenAI embedding request: {e}"))
123 })?;
124
125 let response = self
126 .client
127 .embeddings()
128 .create(request)
129 .await
130 .map_err(|e| {
131 LLMBrainError::ApiError(format!("OpenAI embedding API request failed: {e}"))
132 })?;
133
134 let embeddings = response
136 .data
137 .into_iter()
138 .map(|embedding_obj| embedding_obj.embedding)
139 .collect();
140
141 Ok(embeddings)
142 }
143
144 fn count_tokens(&self, text: &str) -> Result<usize> {
148 let bpe = BPE.get().expect("BPE Tokenizer not initialized");
149 Ok(bpe.encode_with_special_tokens(text).len())
150 }
151}
152
153