Skip to main content

openai_core/
config.rs

1//! 客户端与请求级配置。
2
3use std::collections::BTreeMap;
4use std::fmt;
5use std::str::FromStr;
6use std::sync::Arc;
7use std::time::Duration;
8
9use secrecy::SecretString;
10use tokio_util::sync::CancellationToken;
11
12use crate::providers::{CompatibilityMode, Provider};
13
14/// SDK 日志级别。
15#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
16pub enum LogLevel {
17    /// 关闭 SDK 内部日志。
18    Off,
19    /// 仅输出错误日志。
20    Error,
21    /// 输出警告和错误日志。
22    #[default]
23    Warn,
24    /// 输出信息、警告和错误日志。
25    Info,
26    /// 输出全部调试日志。
27    Debug,
28}
29
30impl LogLevel {
31    /// 判断当前配置是否允许输出指定级别的日志。
32    pub fn allows(self, level: Self) -> bool {
33        self != Self::Off && level <= self
34    }
35
36    /// 返回日志级别的稳定字符串表示。
37    pub fn as_str(self) -> &'static str {
38        match self {
39            Self::Off => "off",
40            Self::Error => "error",
41            Self::Warn => "warn",
42            Self::Info => "info",
43            Self::Debug => "debug",
44        }
45    }
46}
47
48impl FromStr for LogLevel {
49    type Err = String;
50
51    fn from_str(value: &str) -> Result<Self, Self::Err> {
52        match value.trim().to_ascii_lowercase().as_str() {
53            "off" | "none" => Ok(Self::Off),
54            "error" => Ok(Self::Error),
55            "warn" | "warning" => Ok(Self::Warn),
56            "info" => Ok(Self::Info),
57            "debug" => Ok(Self::Debug),
58            other => Err(format!("不支持的日志级别: {other}")),
59        }
60    }
61}
62
63/// 一条 SDK 日志记录。
64#[derive(Debug, Clone, PartialEq, Eq)]
65pub struct LogRecord {
66    /// 日志级别。
67    pub level: LogLevel,
68    /// 日志目标,一般对应子系统。
69    pub target: &'static str,
70    /// 人类可读消息。
71    pub message: String,
72    /// 附加字段。
73    pub fields: BTreeMap<String, String>,
74}
75
76/// 用户自定义日志接收器。
77pub trait Logger: Send + Sync {
78    /// 处理一条日志记录。
79    fn log(&self, record: &LogRecord);
80}
81
82impl<F> Logger for F
83where
84    F: Fn(&LogRecord) + Send + Sync,
85{
86    fn log(&self, record: &LogRecord) {
87        (self)(record);
88    }
89}
90
91/// 可克隆的日志器句柄。
92#[derive(Clone)]
93pub struct LoggerHandle {
94    inner: Arc<dyn Logger>,
95}
96
97impl LoggerHandle {
98    /// 创建新的日志器句柄。
99    pub fn new<L>(logger: L) -> Self
100    where
101        L: Logger + 'static,
102    {
103        Self {
104            inner: Arc::new(logger),
105        }
106    }
107
108    /// 输出一条日志。
109    pub fn log(&self, record: &LogRecord) {
110        self.inner.log(record);
111    }
112}
113
114impl fmt::Debug for LoggerHandle {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        f.write_str("LoggerHandle(..)")
117    }
118}
119
120/// 表示客户端级别的默认配置。
121#[derive(Debug, Clone)]
122pub struct ClientOptions {
123    /// 当前客户端使用的 Provider。
124    pub provider: Provider,
125    /// 覆盖默认的基础地址。
126    pub base_url: Option<String>,
127    /// 当基础地址指向本机地址时,是否显式关闭系统代理,默认关闭。
128    pub disable_proxy_for_local_base_url: bool,
129    /// 每次请求默认超时时间。
130    pub timeout: Duration,
131    /// 默认最大重试次数。
132    pub max_retries: u32,
133    /// 发送前追加到所有请求中的默认请求头。
134    pub default_headers: BTreeMap<String, String>,
135    /// 发送前追加到所有请求中的默认查询参数。
136    pub default_query: BTreeMap<String, String>,
137    /// 可选的 Webhook 密钥。
138    pub webhook_secret: Option<SecretString>,
139    /// SDK 内部日志级别。
140    pub log_level: LogLevel,
141    /// 可选的用户日志器。
142    pub logger: Option<LoggerHandle>,
143    /// Provider 兼容校验模式。
144    pub compatibility_mode: CompatibilityMode,
145}
146
147impl Default for ClientOptions {
148    fn default() -> Self {
149        Self {
150            provider: Provider::openai(),
151            base_url: None,
152            disable_proxy_for_local_base_url: false,
153            timeout: Duration::from_secs(600),
154            max_retries: 2,
155            default_headers: BTreeMap::new(),
156            default_query: BTreeMap::new(),
157            webhook_secret: None,
158            log_level: LogLevel::Warn,
159            logger: None,
160            compatibility_mode: CompatibilityMode::Passthrough,
161        }
162    }
163}
164
165/// 表示单次请求可覆盖的配置。
166#[derive(Debug, Clone, Default)]
167pub struct RequestOptions {
168    /// 额外请求头。若值为 `None`,则会移除同名默认请求头。
169    pub extra_headers: BTreeMap<String, Option<String>>,
170    /// 额外查询参数。若值为 `None`,则会移除同名默认查询参数。
171    pub extra_query: BTreeMap<String, Option<String>>,
172    /// 覆盖客户端默认超时时间。
173    pub timeout: Option<Duration>,
174    /// 覆盖客户端默认最大重试次数。
175    pub max_retries: Option<u32>,
176    /// 可选的取消令牌。
177    pub cancellation_token: Option<CancellationToken>,
178}
179
180impl RequestOptions {
181    /// 追加或覆盖一个请求头。
182    pub fn insert_header<K, V>(&mut self, key: K, value: V)
183    where
184        K: Into<String>,
185        V: Into<String>,
186    {
187        self.extra_headers.insert(key.into(), Some(value.into()));
188    }
189
190    /// 移除一个请求头。
191    pub fn remove_header<K>(&mut self, key: K)
192    where
193        K: Into<String>,
194    {
195        self.extra_headers.insert(key.into(), None);
196    }
197
198    /// 追加或覆盖一个查询参数。
199    pub fn insert_query<K, V>(&mut self, key: K, value: V)
200    where
201        K: Into<String>,
202        V: Into<String>,
203    {
204        self.extra_query.insert(key.into(), Some(value.into()));
205    }
206
207    /// 移除一个查询参数。
208    pub fn remove_query<K>(&mut self, key: K)
209    where
210        K: Into<String>,
211    {
212        self.extra_query.insert(key.into(), None);
213    }
214
215    /// 合并客户端默认请求头与请求级请求头。
216    pub fn merged_headers(&self, defaults: &BTreeMap<String, String>) -> BTreeMap<String, String> {
217        merge_kv_maps(defaults, &self.extra_headers)
218    }
219
220    /// 合并客户端默认查询参数与请求级查询参数。
221    pub fn merged_query(&self, defaults: &BTreeMap<String, String>) -> BTreeMap<String, String> {
222        merge_kv_maps(defaults, &self.extra_query)
223    }
224}
225
226/// 合并默认键值对与请求级覆盖项。
227pub fn merge_kv_maps(
228    defaults: &BTreeMap<String, String>,
229    overrides: &BTreeMap<String, Option<String>>,
230) -> BTreeMap<String, String> {
231    let mut merged = defaults.clone();
232
233    for (key, value) in overrides {
234        match value {
235            Some(value) => {
236                merged.insert(key.clone(), value.clone());
237            }
238            None => {
239                merged.remove(key);
240            }
241        }
242    }
243
244    merged
245}