Skip to main content

openai_core/
auth.rs

1//! 认证相关的通用抽象。
2
3use std::fmt;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use secrecy::{ExposeSecret, SecretString};
9
10use crate::error::{Error, Result};
11
12/// 表示一个可动态生成 API Key 的回调。
13pub type ApiKeyProvider = dyn Fn() -> Result<SecretString> + Send + Sync;
14
15/// 表示一个可异步生成 API Key 的回调。
16pub type AsyncApiKeyProvider =
17    dyn Fn() -> Pin<Box<dyn Future<Output = Result<SecretString>> + Send>> + Send + Sync;
18
19/// 表示客户端使用的 API Key 来源。
20#[derive(Clone)]
21pub enum ApiKeySource {
22    /// 使用固定字符串作为 API Key。
23    Static(SecretString),
24    /// 每次请求或重试时动态生成 API Key。
25    Dynamic(Arc<ApiKeyProvider>),
26    /// 每次请求或重试时异步生成 API Key。
27    AsyncDynamic(Arc<AsyncApiKeyProvider>),
28}
29
30impl ApiKeySource {
31    /// 创建一个静态 API Key 来源。
32    pub fn from_static<T>(value: T) -> Self
33    where
34        T: Into<String>,
35    {
36        Self::Static(SecretString::new(value.into().into()))
37    }
38
39    /// 创建一个动态 API Key 来源。
40    pub fn from_provider<F>(provider: F) -> Self
41    where
42        F: Fn() -> Result<SecretString> + Send + Sync + 'static,
43    {
44        Self::Dynamic(Arc::new(provider))
45    }
46
47    /// 创建一个异步 API Key 来源。
48    pub fn from_async_provider<F, Fut>(provider: F) -> Self
49    where
50        F: Fn() -> Fut + Send + Sync + 'static,
51        Fut: Future<Output = Result<SecretString>> + Send + 'static,
52    {
53        Self::AsyncDynamic(Arc::new(move || Box::pin(provider())))
54    }
55
56    /// 在当前时刻解析出可用的 API Key。
57    ///
58    /// # Errors
59    ///
60    /// 当动态回调返回错误时返回对应错误。
61    ///
62    /// 若来源是异步回调,请改用 [`Self::resolve_async`]。
63    pub fn resolve(&self) -> Result<SecretString> {
64        match self {
65            Self::Static(value) => Ok(value.clone()),
66            Self::Dynamic(provider) => provider(),
67            Self::AsyncDynamic(_) => Err(Error::InvalidConfig(
68                "当前 API Key 来源为异步回调,请使用 resolve_async".into(),
69            )),
70        }
71    }
72
73    /// 在当前时刻异步解析出可用的 API Key。
74    ///
75    /// # Errors
76    ///
77    /// 当动态回调返回错误时返回对应错误。
78    pub async fn resolve_async(&self) -> Result<SecretString> {
79        match self {
80            Self::Static(value) => Ok(value.clone()),
81            Self::Dynamic(provider) => provider(),
82            Self::AsyncDynamic(provider) => provider().await,
83        }
84    }
85
86    /// 返回一个可用于日志的脱敏字符串。
87    pub fn redacted(&self) -> String {
88        match self {
89            Self::Static(secret) => redact_secret(secret.expose_secret()),
90            Self::Dynamic(_) => "<dynamic-api-key-provider>".into(),
91            Self::AsyncDynamic(_) => "<async-api-key-provider>".into(),
92        }
93    }
94}
95
96impl fmt::Debug for ApiKeySource {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        f.debug_tuple("ApiKeySource")
99            .field(&self.redacted())
100            .finish()
101    }
102}
103
104fn redact_secret(secret: &str) -> String {
105    if secret.is_empty() {
106        return "<empty-secret>".into();
107    }
108
109    if secret.len() <= 8 {
110        return "********".into();
111    }
112
113    let prefix = &secret[..4];
114    let suffix = &secret[secret.len() - 4..];
115    format!("{prefix}****{suffix}")
116}
117
118impl From<SecretString> for ApiKeySource {
119    fn from(value: SecretString) -> Self {
120        Self::Static(value)
121    }
122}
123
124impl TryFrom<Option<ApiKeySource>> for ApiKeySource {
125    type Error = Error;
126
127    fn try_from(value: Option<ApiKeySource>) -> Result<Self> {
128        value.ok_or(Error::MissingCredentials)
129    }
130}