// llm/providers/openai_compatible/generic.rs

1use async_openai::{Client, config::OpenAIConfig};
2use schemars::Schema;
3
4use crate::provider::get_context_window;
5use crate::tool_schema::normalize_for_moonshot;
6use crate::{
7    Context, LlmError, LlmModel, LlmResponseStream, ProviderAuthMode, ProviderConnectionConfig, Result,
8    StreamingModelProvider,
9};
10
11use super::{AetherOpenAiConfig, build_chat_request, create_custom_stream_generic};
12
13/// Configuration for an OpenAI-compatible provider.
14///
15/// Each provider that uses the standard `build_chat_request → create_custom_stream_generic`
16/// flow differs only in these constants.
17pub struct ProviderConfig {
18    pub api_base: &'static str,
19    pub env_var: &'static str,
20    pub default_model: &'static str,
21    pub prefix: &'static str,
22    pub display_name: &'static str,
23    pub tool_schema_transform: Option<fn(&mut Schema)>,
24}
25
26pub const DEEPSEEK: ProviderConfig = ProviderConfig {
27    api_base: "https://api.deepseek.com",
28    env_var: "DEEPSEEK_API_KEY",
29    default_model: "deepseek-chat",
30    prefix: "deepseek",
31    display_name: "DeepSeek",
32    tool_schema_transform: None,
33};
34
35pub const MOONSHOT: ProviderConfig = ProviderConfig {
36    api_base: "https://api.moonshot.ai/v1",
37    env_var: "MOONSHOT_API_KEY",
38    default_model: "moonshot-v1-8k",
39    prefix: "moonshot",
40    display_name: "Moonshot",
41    tool_schema_transform: Some(normalize_for_moonshot),
42};
43
44pub const ZAI: ProviderConfig = ProviderConfig {
45    api_base: "https://api.z.ai/api/coding/paas/v4",
46    env_var: "ZAI_API_KEY",
47    default_model: "GLM-4.6",
48    prefix: "zai",
49    display_name: "Z.ai",
50    tool_schema_transform: None,
51};
52
53/// A generic provider for APIs that are fully OpenAI-compatible.
54pub struct GenericOpenAiProvider {
55    client: Client<AetherOpenAiConfig>,
56    model: String,
57    config: &'static ProviderConfig,
58}
59
60impl GenericOpenAiProvider {
61    pub fn from_env(config: &'static ProviderConfig) -> Result<Self> {
62        Self::from_env_with_connection(config, ProviderConnectionConfig::default())
63    }
64
65    pub fn from_env_with_connection(
66        config: &'static ProviderConfig,
67        connection: ProviderConnectionConfig,
68    ) -> Result<Self> {
69        let api_key = match connection.auth_mode {
70            ProviderAuthMode::Default => {
71                std::env::var(config.env_var).map_err(|_| LlmError::MissingApiKey(config.env_var.to_string()))?
72            }
73            ProviderAuthMode::None => String::new(),
74        };
75        Ok(Self::new_with_connection(api_key, config, connection))
76    }
77
78    pub fn new(api_key: String, config: &'static ProviderConfig) -> Self {
79        Self::new_with_connection(api_key, config, ProviderConnectionConfig::default())
80    }
81
82    pub fn new_with_connection(
83        api_key: String,
84        config: &'static ProviderConfig,
85        connection: ProviderConnectionConfig,
86    ) -> Self {
87        let api_base = connection.base_url.unwrap_or_else(|| config.api_base.to_string());
88        let openai_config = OpenAIConfig::new().with_api_key(api_key).with_api_base(api_base);
89        let openai_config = AetherOpenAiConfig::new(openai_config, connection.auth_mode);
90
91        Self { client: Client::with_config(openai_config), model: config.default_model.to_string(), config }
92    }
93
94    pub fn with_model(mut self, model: &str) -> Self {
95        self.model = model.to_string();
96        self
97    }
98}
99
100impl StreamingModelProvider for GenericOpenAiProvider {
101    fn model(&self) -> Option<LlmModel> {
102        format!("{}:{}", self.config.prefix, self.model).parse().ok()
103    }
104
105    fn context_window(&self) -> Option<u32> {
106        get_context_window(self.config.prefix, &self.model)
107    }
108
109    fn stream_response(&self, context: &Context) -> LlmResponseStream {
110        let request = match build_chat_request(&self.model, context, self.config.tool_schema_transform) {
111            Ok(req) => req,
112            Err(e) => return Box::pin(async_stream::stream! { yield Err(e); }),
113        };
114        create_custom_stream_generic(&self.client, request)
115    }
116
117    fn display_name(&self) -> String {
118        format!("{} ({})", self.config.display_name, self.model)
119    }
120}