alith_interface/llms/api/openai/mod.rs

1pub mod builder;
2pub mod completion;
3
4use super::{
5    client::ApiClient,
6    config::{ApiConfig, ApiConfigTrait},
7};
8use crate::requests::{
9    completion::{
10        error::CompletionError, request::CompletionRequest, response::CompletionResponse,
11    },
12    embeddings::{EmbeddingsError, EmbeddingsRequest, EmbeddingsResponse},
13};
14use alith_devices::logging::LoggingConfig;
15use alith_models::api_model::ApiLLMModel;
16use completion::OpenAICompletionRequest;
17use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue};
18use secrecy::{ExposeSecret, SecretString};
19use serde_json::json;
20
/// Default OpenAI API host: the `api.openai.com` authority followed by the `/v1`
/// path prefix.
///
/// The value carries no scheme; `https://` is prepended when request URLs are built.
pub const OPENAI_API_HOST: &str = "api.openai.com/v1";

/// Name of the HTTP header that carries the OpenAI organization id.
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";

/// Name of the HTTP header that carries the OpenAI project id.
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";
27
28pub struct OpenAIBackend {
29    pub(crate) client: ApiClient<OpenAIConfig>,
30    pub model: ApiLLMModel,
31}
32
33impl OpenAIBackend {
34    pub fn new(mut config: OpenAIConfig, model: ApiLLMModel) -> crate::Result<Self> {
35        config.logging_config.load_logger()?;
36        config.api_config.api_key = Some(config.api_config.load_api_key()?);
37        Ok(Self {
38            client: ApiClient::new(config),
39            model,
40        })
41    }
42
43    pub(crate) async fn completion_request(
44        &self,
45        request: &CompletionRequest,
46    ) -> crate::Result<CompletionResponse, CompletionError> {
47        match self
48            .client
49            .post("/chat/completions", OpenAICompletionRequest::new(request)?)
50            .await
51        {
52            Err(e) => Err(CompletionError::ClientError(e)),
53            Ok(res) => Ok(CompletionResponse::new_from_openai(request, res)?),
54        }
55    }
56
57    pub(crate) async fn embeddings_request(
58        &self,
59        request: &EmbeddingsRequest,
60    ) -> crate::Result<EmbeddingsResponse, EmbeddingsError> {
61        match self
62            .client
63            .post(
64                "/embeddings",
65                json!({
66                    "input": request.input,
67                    "model": request.model,
68                }),
69            )
70            .await
71        {
72            Ok(res) => Ok(res),
73            Err(e) => Err(EmbeddingsError::ClientError(e)),
74        }
75    }
76}
77
78#[derive(Clone, Debug)]
79pub struct OpenAIConfig {
80    pub api_config: ApiConfig,
81    pub logging_config: LoggingConfig,
82    pub org_id: String,
83    pub project_id: String,
84}
85
86impl Default for OpenAIConfig {
87    fn default() -> Self {
88        Self {
89            api_config: ApiConfig {
90                host: OPENAI_API_HOST.to_string(),
91                port: None,
92                api_key: None,
93                api_key_env_var: "OPENAI_API_KEY".to_string(),
94            },
95            logging_config: LoggingConfig {
96                logger_name: "openai".to_string(),
97                ..Default::default()
98            },
99            org_id: Default::default(),
100            project_id: Default::default(),
101        }
102    }
103}
104
105impl OpenAIConfig {
106    pub fn new() -> Self {
107        Default::default()
108    }
109
110    /// To use a different organization id other than default
111    pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
112        self.org_id = org_id.into();
113        self
114    }
115
116    /// Non default project id
117    pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
118        self.project_id = project_id.into();
119        self
120    }
121}
122
123impl ApiConfigTrait for OpenAIConfig {
124    fn headers(&self) -> HeaderMap {
125        let mut headers = HeaderMap::new();
126
127        if !self.org_id.is_empty() {
128            if let Ok(header_value) = HeaderValue::from_str(self.org_id.as_str()) {
129                headers.insert(OPENAI_ORGANIZATION_HEADER, header_value);
130            } else {
131                crate::error!("Failed to create header value from org_id value");
132            }
133        }
134        if !self.project_id.is_empty() {
135            if let Ok(header_value) = HeaderValue::from_str(self.project_id.as_str()) {
136                headers.insert(OPENAI_PROJECT_HEADER, header_value);
137            } else {
138                crate::error!("Failed to create header value from project_id value");
139            }
140        }
141        if let Some(api_key) = self.api_key() {
142            if let Ok(header_value) =
143                HeaderValue::from_str(&format!("Bearer {}", api_key.expose_secret()))
144            {
145                headers.insert(AUTHORIZATION, header_value);
146            } else {
147                crate::error!("Failed to create header value from authorization value");
148            }
149        }
150
151        headers
152    }
153
154    fn url(&self, path: &str) -> String {
155        if self.api_config.host.starts_with("http") {
156            if let Some(port) = &self.api_config.port {
157                format!("{}:{}{}", self.api_config.host, port, path)
158            } else {
159                format!("{}{}", self.api_config.host, path)
160            }
161        } else {
162            format!("https://{}{}", self.api_config.host, path)
163        }
164    }
165
166    fn api_key(&self) -> &Option<SecretString> {
167        &self.api_config.api_key
168    }
169}