use std::collections::BTreeMap;
use std::fmt;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use secrecy::SecretString;
use tokio_util::sync::CancellationToken;
use crate::providers::{CompatibilityMode, Provider};
/// Log verbosity level, used both as a configured threshold and as the level of a record.
///
/// The derived ordering follows declaration order (`Off < Error < Warn < Info < Debug`),
/// and the derived `Default` is [`LogLevel::Warn`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum LogLevel {
/// Lowest value in the ordering; as a threshold, `allows` rejects every level.
Off,
/// Error-level messages.
Error,
/// Warning-level messages; this is the default level.
#[default]
Warn,
/// Informational messages.
Info,
/// Debug-level messages; the highest value in the ordering.
Debug,
}
impl LogLevel {
    /// Returns `true` when a record at `level` should be emitted under the threshold `self`.
    ///
    /// A record passes when it is at least as severe as the threshold, i.e. `level <= self`
    /// in the derived ordering (`Off < Error < Warn < Info < Debug`).
    ///
    /// `Off` is never loggable: an `Off` threshold allows nothing, and a record tagged `Off`
    /// is rejected under every threshold.
    pub fn allows(self, level: Self) -> bool {
        // Rejecting `level == Off` also covers `self == Off`, because the only level that
        // compares `<= Off` is `Off` itself.
        level != Self::Off && level <= self
    }

    /// Returns the canonical lowercase name of the level.
    ///
    /// Each returned name is accepted by the `FromStr` implementation and parses back to
    /// the same variant.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Off => "off",
            Self::Error => "error",
            Self::Warn => "warn",
            Self::Info => "info",
            Self::Debug => "debug",
        }
    }
}
/// Parses a level name, ignoring surrounding whitespace and ASCII case.
///
/// Accepted names are `off`/`none`, `error`, `warn`/`warning`, `info` and `debug`.
impl FromStr for LogLevel {
    type Err = String;

    /// Returns the matching level, or an error message containing the trimmed,
    /// lowercased input when the name is not recognised.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        let normalized = value.trim().to_ascii_lowercase();
        let level = match normalized.as_str() {
            "off" | "none" => Self::Off,
            "error" => Self::Error,
            "warn" | "warning" => Self::Warn,
            "info" => Self::Info,
            "debug" => Self::Debug,
            other => return Err(format!("不支持的日志级别: {other}")),
        };
        Ok(level)
    }
}
/// A single log event handed to a [`Logger`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LogRecord {
/// Level of this record.
pub level: LogLevel,
/// Static label for the record's origin.
// NOTE(review): presumably the emitting component or module name — confirm with callers.
pub target: &'static str,
/// Message text.
pub message: String,
/// Additional string key/value pairs attached to the record, ordered by key.
pub fields: BTreeMap<String, String>,
}
/// Receiver of [`LogRecord`]s.
///
/// Implementors must be `Send + Sync`. Any `Fn(&LogRecord) + Send + Sync` callable
/// implements this trait through the blanket impl in this file.
pub trait Logger: Send + Sync {
/// Handles one log record.
fn log(&self, record: &LogRecord);
}
/// Lets any thread-safe callable taking `&LogRecord` act as a [`Logger`].
impl<F> Logger for F
where
    F: Fn(&LogRecord) + Send + Sync,
{
    /// Invokes the callable with the record.
    fn log(&self, record: &LogRecord) {
        self(record)
    }
}
/// Cloneable handle to a shared [`Logger`].
///
/// Cloning copies the `Arc`, so every clone forwards to the same logger instance.
#[derive(Clone)]
pub struct LoggerHandle {
// Type-erased logger shared between clones.
inner: Arc<dyn Logger>,
}
impl LoggerHandle {
pub fn new<L>(logger: L) -> Self
where
L: Logger + 'static,
{
Self {
inner: Arc::new(logger),
}
}
pub fn log(&self, record: &LogRecord) {
self.inner.log(record);
}
}
/// Opaque `Debug` output: the wrapped logger is a trait object with no `Debug` bound,
/// so only a fixed placeholder is printed.
impl fmt::Debug for LoggerHandle {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "LoggerHandle(..)")
    }
}
/// Client-wide configuration.
///
/// The `Default` implementation in this file yields `Provider::openai()`, no base URL,
/// proxy-disabling off, a 600-second timeout, 2 retries, empty default headers and query
/// parameters, no webhook secret, `LogLevel::Warn`, no logger, and
/// `CompatibilityMode::Passthrough`.
#[derive(Debug, Clone)]
pub struct ClientOptions {
/// Provider configuration for the client.
pub provider: Provider,
/// Optional base URL.
// NOTE(review): presumably overrides the provider's own base URL when set — confirm.
pub base_url: Option<String>,
/// Flag relating to proxy use when the base URL is local.
// NOTE(review): what counts as "local" is decided elsewhere — confirm against the transport code.
pub disable_proxy_for_local_base_url: bool,
/// Timeout duration (600 seconds by default).
pub timeout: Duration,
/// Maximum retry count (2 by default).
pub max_retries: u32,
/// Client-level headers, keyed by name.
// NOTE(review): presumably passed as `defaults` to `RequestOptions::merged_headers` — confirm.
pub default_headers: BTreeMap<String, String>,
/// Client-level query parameters, keyed by name.
// NOTE(review): presumably passed as `defaults` to `RequestOptions::merged_query` — confirm.
pub default_query: BTreeMap<String, String>,
/// Optional webhook secret, held in a `SecretString`.
pub webhook_secret: Option<SecretString>,
/// Configured log level (`Warn` by default).
pub log_level: LogLevel,
/// Optional logger handle; `None` by default.
pub logger: Option<LoggerHandle>,
/// Compatibility mode (`Passthrough` by default).
pub compatibility_mode: CompatibilityMode,
}
impl Default for ClientOptions {
    /// Builds the baseline configuration: `Provider::openai()`, a ten-minute timeout,
    /// two retries, `Warn` logging, passthrough compatibility, and nothing else set.
    fn default() -> Self {
        // Ten minutes, expressed in seconds.
        const DEFAULT_TIMEOUT: Duration = Duration::from_secs(600);
        const DEFAULT_MAX_RETRIES: u32 = 2;

        Self {
            provider: Provider::openai(),
            compatibility_mode: CompatibilityMode::Passthrough,
            timeout: DEFAULT_TIMEOUT,
            max_retries: DEFAULT_MAX_RETRIES,
            log_level: LogLevel::Warn,
            // Everything optional starts unset or empty.
            base_url: None,
            webhook_secret: None,
            logger: None,
            disable_proxy_for_local_base_url: false,
            default_headers: BTreeMap::new(),
            default_query: BTreeMap::new(),
        }
    }
}
/// Per-request options; the derived `Default` leaves every field empty or `None`.
#[derive(Debug, Clone, Default)]
pub struct RequestOptions {
/// Header overrides: `Some(value)` sets the header and `None` removes it from the
/// defaults when merged via `merged_headers`.
pub extra_headers: BTreeMap<String, Option<String>>,
/// Query overrides: `Some(value)` sets the parameter and `None` removes it from the
/// defaults when merged via `merged_query`.
pub extra_query: BTreeMap<String, Option<String>>,
/// Optional timeout for this request.
// NOTE(review): presumably takes precedence over `ClientOptions::timeout` when set — confirm.
pub timeout: Option<Duration>,
/// Optional retry limit for this request.
// NOTE(review): presumably takes precedence over `ClientOptions::max_retries` when set — confirm.
pub max_retries: Option<u32>,
/// Optional cancellation token.
// NOTE(review): presumably used to abort the in-flight request — confirm against the caller.
pub cancellation_token: Option<CancellationToken>,
}
impl RequestOptions {
    /// Sets header `key` to `value` for this request, replacing any earlier entry for `key`.
    pub fn insert_header<K: Into<String>, V: Into<String>>(&mut self, key: K, value: V) {
        let name = key.into();
        let entry = Some(value.into());
        self.extra_headers.insert(name, entry);
    }

    /// Marks header `key` for removal.
    ///
    /// The key is stored with `None`, which `merge_kv_maps` treats as "drop this key
    /// from the defaults".
    pub fn remove_header<K: Into<String>>(&mut self, key: K) {
        let name = key.into();
        self.extra_headers.insert(name, None);
    }

    /// Sets query parameter `key` to `value` for this request, replacing any earlier
    /// entry for `key`.
    pub fn insert_query<K: Into<String>, V: Into<String>>(&mut self, key: K, value: V) {
        let name = key.into();
        let entry = Some(value.into());
        self.extra_query.insert(name, entry);
    }

    /// Marks query parameter `key` for removal by storing it with `None`.
    pub fn remove_query<K: Into<String>>(&mut self, key: K) {
        let name = key.into();
        self.extra_query.insert(name, None);
    }

    /// Returns `defaults` with this request's header overrides and removals applied.
    pub fn merged_headers(&self, defaults: &BTreeMap<String, String>) -> BTreeMap<String, String> {
        let overrides = &self.extra_headers;
        merge_kv_maps(defaults, overrides)
    }

    /// Returns `defaults` with this request's query overrides and removals applied.
    pub fn merged_query(&self, defaults: &BTreeMap<String, String>) -> BTreeMap<String, String> {
        let overrides = &self.extra_query;
        merge_kv_maps(defaults, overrides)
    }
}
/// Applies `overrides` on top of a copy of `defaults` and returns the result.
///
/// For each override entry, `Some(value)` inserts or replaces the key, and `None`
/// removes the key if it is present. Neither input map is modified.
pub fn merge_kv_maps(
    defaults: &BTreeMap<String, String>,
    overrides: &BTreeMap<String, Option<String>>,
) -> BTreeMap<String, String> {
    let mut result = defaults.clone();
    for (name, entry) in overrides {
        if let Some(replacement) = entry {
            result.insert(name.clone(), replacement.clone());
        } else {
            result.remove(name);
        }
    }
    result
}