// llm_brain/llm.rs

1use std::sync::OnceLock;
2
3use async_openai::Client as OpenAiSdkClient;
4use async_openai::config::OpenAIConfig;
5use async_openai::types::{CreateEmbeddingRequestArgs, EmbeddingInput};
6use async_trait::async_trait;
7use tiktoken_rs::{CoreBPE, cl100k_base};
8
9use crate::config::LlmConfig;
10use crate::embeddings::{EmbeddingModelConfig, EmbeddingProvider};
11use crate::error::{LLMBrainError, Result};
12
/// Cached cl100k_base tokenizer, loaded once and shared process-wide.
/// Initialized in `OpenAiClient::new`; used by `count_tokens`.
static BPE: OnceLock<CoreBPE> = OnceLock::new();

/// Embedding model used when the configuration does not specify one.
const DEFAULT_EMBEDDING_MODEL: &str = "text-embedding-3-small";
17
/// Client for interactions with `OpenAI`'s API, particularly for generating
/// embeddings (see the `EmbeddingProvider` impl below).
#[derive(Clone)]
pub struct OpenAiClient {
    /// The underlying `OpenAI` SDK client used for all API calls
    client: OpenAiSdkClient<OpenAIConfig>,

    /// Name of the model used for embedding generation
    /// (e.g. "text-embedding-3-small")
    embedding_model: String,

    /// Metadata (name, dimensions, max context length) for `embedding_model`
    embedding_config: EmbeddingModelConfig,
}
31
32impl OpenAiClient {
33    /// Creates a new `OpenAI` client.
34    ///
35    /// Uses the provided configuration or falls back to environment variables.
36    /// Will attempt to read `OPENAI_API_KEY` from environment if not specified
37    /// in config.
38    pub fn new(config: Option<&LlmConfig>) -> Result<Self> {
39        let api_key = config
40            .and_then(|c| c.openai_api_key.as_deref())
41            .unwrap_or("")
42            .to_owned();
43
44        let api_base = config
45            .and_then(|c| c.openai_api_base.as_deref())
46            .unwrap_or("")
47            .to_owned();
48
49        let embedding_model = config
50            .and_then(|c| c.embedding_model.as_deref())
51            .unwrap_or(DEFAULT_EMBEDDING_MODEL)
52            .to_owned();
53
54        let mut client_config = if !api_key.is_empty() {
55            OpenAIConfig::new().with_api_key(api_key)
56        } else {
57            OpenAIConfig::default()
58        };
59
60        if !api_base.is_empty() {
61            client_config = client_config.with_api_base(api_base);
62        }
63
64        let client = OpenAiSdkClient::with_config(client_config);
65
66        // Initialize BPE tokenizer on first use
67        BPE.get_or_init(|| cl100k_base().expect("Failed to load cl100k_base tokenizer for OpenAI"));
68
69        // Create default embedding model configuration
70        let embedding_config = EmbeddingModelConfig {
71            model_name: embedding_model.clone(),
72            dimensions: 1536,
73            max_context_length: 8191,
74        };
75
76        Ok(Self {
77            client,
78            embedding_model,
79            embedding_config,
80        })
81    }
82
83    /// Get current embedding model configuration
84    pub fn get_embedding_config(&self) -> &EmbeddingModelConfig {
85        &self.embedding_config
86    }
87}
88
89#[async_trait]
90impl EmbeddingProvider for OpenAiClient {
91    /// Generates an embedding for a single text.
92    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
93        let mut embeddings = self.generate_embeddings(vec![text.to_owned()]).await?;
94        if let Some(embedding) = embeddings.pop() {
95            Ok(embedding)
96        } else {
97            Err(LLMBrainError::ApiError(
98                "OpenAI API returned no embedding for single text input".to_owned(),
99            ))
100        }
101    }
102
103    /// Generates embeddings for multiple texts in a single API call.
104    ///
105    /// # Arguments
106    ///
107    /// * `texts` - Vector of strings to generate embeddings for
108    ///
109    /// # Returns
110    ///
111    /// Vector of embedding vectors, in the same order as the input texts
112    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
113        if texts.is_empty() {
114            return Ok(Vec::new());
115        }
116
117        let request = CreateEmbeddingRequestArgs::default()
118            .model(&self.embedding_model)
119            .input(EmbeddingInput::StringArray(texts))
120            .build()
121            .map_err(|e| {
122                LLMBrainError::ApiError(format!("Failed to build OpenAI embedding request: {e}"))
123            })?;
124
125        let response = self
126            .client
127            .embeddings()
128            .create(request)
129            .await
130            .map_err(|e| {
131                LLMBrainError::ApiError(format!("OpenAI embedding API request failed: {e}"))
132            })?;
133
134        // Collect the embeddings from the response
135        let embeddings = response
136            .data
137            .into_iter()
138            .map(|embedding_obj| embedding_obj.embedding)
139            .collect();
140
141        Ok(embeddings)
142    }
143
144    /// Counts the number of tokens in a text using the cl100k_base tokenizer.
145    ///
146    /// Uses the same tokenizer as the OpenAI API to get accurate token counts.
147    fn count_tokens(&self, text: &str) -> Result<usize> {
148        let bpe = BPE.get().expect("BPE Tokenizer not initialized");
149        Ok(bpe.encode_with_special_tokens(text).len())
150    }
151}
152
// Optional: a provider-agnostic `LlmClient` trait, should multiple LLM
// backends ever be supported:
//
// #[async_trait]
// pub trait LlmClient: Send + Sync {
//     async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>>;
//     async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;
//     fn count_tokens(&self, text: &str) -> Result<usize>;
// }