dynamo_async_openai/
config.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
5// Original Copyright (c) 2022 Himanshu Neema
6// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
7//
8// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
9// Licensed under Apache 2.0
10
11//! Client configurations: [OpenAIConfig] for OpenAI, [AzureConfig] for Azure OpenAI Service.
12use reqwest::header::{AUTHORIZATION, HeaderMap};
13use secrecy::{ExposeSecret, SecretString};
14use serde::Deserialize;
15
/// Base URL of version 1 of the OpenAI REST API; used unless another base is configured.
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Name of the HTTP header that carries the OpenAI organization id.
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";
/// Name of the HTTP header that carries the OpenAI project id.
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";

/// Name of the beta opt-in HTTP header; requests to the Assistants API must
/// include it.
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
25
26/// [crate::Client] relies on this for every API call on OpenAI
27/// or Azure OpenAI service
28pub trait Config: Send + Sync {
29    fn headers(&self) -> HeaderMap;
30    fn url(&self, path: &str) -> String;
31    fn query(&self) -> Vec<(&str, &str)>;
32
33    fn api_base(&self) -> &str;
34
35    fn api_key(&self) -> &SecretString;
36}
37
38/// Macro to implement Config trait for pointer types with dyn objects
39macro_rules! impl_config_for_ptr {
40    ($t:ty) => {
41        impl Config for $t {
42            fn headers(&self) -> HeaderMap {
43                self.as_ref().headers()
44            }
45            fn url(&self, path: &str) -> String {
46                self.as_ref().url(path)
47            }
48            fn query(&self) -> Vec<(&str, &str)> {
49                self.as_ref().query()
50            }
51            fn api_base(&self) -> &str {
52                self.as_ref().api_base()
53            }
54            fn api_key(&self) -> &SecretString {
55                self.as_ref().api_key()
56            }
57        }
58    };
59}
60
61impl_config_for_ptr!(Box<dyn Config>);
62impl_config_for_ptr!(std::sync::Arc<dyn Config>);
63
64/// Configuration for OpenAI API
65#[derive(Clone, Debug, Deserialize)]
66#[serde(default)]
67pub struct OpenAIConfig {
68    api_base: String,
69    api_key: SecretString,
70    org_id: String,
71    project_id: String,
72}
73
74impl Default for OpenAIConfig {
75    fn default() -> Self {
76        Self {
77            api_base: OPENAI_API_BASE.to_string(),
78            api_key: std::env::var("OPENAI_API_KEY")
79                .unwrap_or_else(|_| "".to_string())
80                .into(),
81            org_id: Default::default(),
82            project_id: Default::default(),
83        }
84    }
85}
86
87impl OpenAIConfig {
88    /// Create client with default [OPENAI_API_BASE] url and default API key from OPENAI_API_KEY env var
89    pub fn new() -> Self {
90        Default::default()
91    }
92
93    /// To use a different organization id other than default
94    pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
95        self.org_id = org_id.into();
96        self
97    }
98
99    /// Non default project id
100    pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
101        self.project_id = project_id.into();
102        self
103    }
104
105    /// To use a different API key different from default OPENAI_API_KEY env var
106    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
107        self.api_key = SecretString::from(api_key.into());
108        self
109    }
110
111    /// To use a API base url different from default [OPENAI_API_BASE]
112    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
113        self.api_base = api_base.into();
114        self
115    }
116
117    pub fn org_id(&self) -> &str {
118        &self.org_id
119    }
120}
121
122impl Config for OpenAIConfig {
123    fn headers(&self) -> HeaderMap {
124        let mut headers = HeaderMap::new();
125        if !self.org_id.is_empty() {
126            headers.insert(
127                OPENAI_ORGANIZATION_HEADER,
128                self.org_id.as_str().parse().unwrap(),
129            );
130        }
131
132        if !self.project_id.is_empty() {
133            headers.insert(
134                OPENAI_PROJECT_HEADER,
135                self.project_id.as_str().parse().unwrap(),
136            );
137        }
138
139        headers.insert(
140            AUTHORIZATION,
141            format!("Bearer {}", self.api_key.expose_secret())
142                .as_str()
143                .parse()
144                .unwrap(),
145        );
146
147        // hack for Assistants APIs
148        // Calls to the Assistants API require that you pass a Beta header
149        headers.insert(OPENAI_BETA_HEADER, "assistants=v2".parse().unwrap());
150
151        headers
152    }
153
154    fn url(&self, path: &str) -> String {
155        format!("{}{}", self.api_base, path)
156    }
157
158    fn api_base(&self) -> &str {
159        &self.api_base
160    }
161
162    fn api_key(&self) -> &SecretString {
163        &self.api_key
164    }
165
166    fn query(&self) -> Vec<(&str, &str)> {
167        vec![]
168    }
169}
170
171/// Configuration for Azure OpenAI Service
172#[derive(Clone, Debug, Deserialize)]
173#[serde(default)]
174pub struct AzureConfig {
175    api_version: String,
176    deployment_id: String,
177    api_base: String,
178    api_key: SecretString,
179}
180
181impl Default for AzureConfig {
182    fn default() -> Self {
183        Self {
184            api_base: Default::default(),
185            api_key: std::env::var("OPENAI_API_KEY")
186                .unwrap_or_else(|_| "".to_string())
187                .into(),
188            deployment_id: Default::default(),
189            api_version: Default::default(),
190        }
191    }
192}
193
194impl AzureConfig {
195    pub fn new() -> Self {
196        Default::default()
197    }
198
199    pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
200        self.api_version = api_version.into();
201        self
202    }
203
204    pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
205        self.deployment_id = deployment_id.into();
206        self
207    }
208
209    /// To use a different API key different from default OPENAI_API_KEY env var
210    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
211        self.api_key = SecretString::from(api_key.into());
212        self
213    }
214
215    /// API base url in form of <https://your-resource-name.openai.azure.com>
216    pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
217        self.api_base = api_base.into();
218        self
219    }
220}
221
222impl Config for AzureConfig {
223    fn headers(&self) -> HeaderMap {
224        let mut headers = HeaderMap::new();
225
226        headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
227
228        headers
229    }
230
231    fn url(&self, path: &str) -> String {
232        format!(
233            "{}/openai/deployments/{}{}",
234            self.api_base, self.deployment_id, path
235        )
236    }
237
238    fn api_base(&self) -> &str {
239        &self.api_base
240    }
241
242    fn api_key(&self) -> &SecretString {
243        &self.api_key
244    }
245
246    fn query(&self) -> Vec<(&str, &str)> {
247        vec![("api-version", &self.api_version)]
248    }
249}
250
#[cfg(test)]
mod test {
    use super::*;
    use crate::Client;
    use crate::types::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
    };
    use std::sync::Arc;

    #[test]
    fn test_client_creation() {
        // Set the key on the config rather than via `std::env::set_var`.
        // Mutating the environment is unsound while other threads read it, and the
        // harness runs `test_dynamic_dispatch` (which reads `OPENAI_API_KEY` through
        // `Default`) in parallel.
        let openai_config = OpenAIConfig::default().with_api_key("test");
        let config = Box::new(openai_config.clone()) as Box<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));

        let config = Arc::new(openai_config) as Arc<dyn Config>;
        let client = Client::with_config(config);
        assert!(client.config().url("").ends_with("/v1"));
        let cloned_client = client.clone();
        assert!(cloned_client.config().url("").ends_with("/v1"));
    }

    /// Checks at compile time that a chat request can be built through a
    /// `Box<dyn Config>` client. The value returned by `create` is discarded
    /// without being awaited.
    async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
        let _ = client.chat().create(CreateChatCompletionRequest {
            model: "gpt-4o".to_string(),
            messages: vec![ChatCompletionRequestMessage::User(
                ChatCompletionRequestUserMessage {
                    content: "Hello, world!".into(),
                    ..Default::default()
                },
            )],
            ..Default::default()
        });
    }

    #[tokio::test]
    async fn test_dynamic_dispatch() {
        let openai_config = OpenAIConfig::default();
        let azure_config = AzureConfig::default();

        let azure_client = Client::with_config(Box::new(azure_config) as Box<dyn Config>);
        let oai_client = Client::with_config(Box::new(openai_config) as Box<dyn Config>);

        dynamic_dispatch_compiles(&azure_client).await;
        dynamic_dispatch_compiles(&oai_client).await;

        // `tokio::spawn` requires the futures to be `Send + 'static`, which is what
        // this test guards. Awaiting the handles surfaces any panic inside the tasks.
        let azure_task =
            tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await });
        let oai_task = tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await });
        azure_task.await.expect("azure dispatch task panicked");
        oai_task.await.expect("openai dispatch task panicked");
    }
}