llm/providers/openai_compatible/
generic.rs1use async_openai::{Client, config::OpenAIConfig};
2use schemars::Schema;
3
4use crate::provider::get_context_window;
5use crate::tool_schema::normalize_for_moonshot;
6use crate::{
7 Context, LlmError, LlmModel, LlmResponseStream, ProviderAuthMode, ProviderConnectionConfig, Result,
8 StreamingModelProvider,
9};
10
11use super::{AetherOpenAiConfig, build_chat_request, create_custom_stream_generic};
12
13pub struct ProviderConfig {
18 pub api_base: &'static str,
19 pub env_var: &'static str,
20 pub default_model: &'static str,
21 pub prefix: &'static str,
22 pub display_name: &'static str,
23 pub tool_schema_transform: Option<fn(&mut Schema)>,
24}
25
26pub const DEEPSEEK: ProviderConfig = ProviderConfig {
27 api_base: "https://api.deepseek.com",
28 env_var: "DEEPSEEK_API_KEY",
29 default_model: "deepseek-chat",
30 prefix: "deepseek",
31 display_name: "DeepSeek",
32 tool_schema_transform: None,
33};
34
35pub const MOONSHOT: ProviderConfig = ProviderConfig {
36 api_base: "https://api.moonshot.ai/v1",
37 env_var: "MOONSHOT_API_KEY",
38 default_model: "moonshot-v1-8k",
39 prefix: "moonshot",
40 display_name: "Moonshot",
41 tool_schema_transform: Some(normalize_for_moonshot),
42};
43
44pub const ZAI: ProviderConfig = ProviderConfig {
45 api_base: "https://api.z.ai/api/coding/paas/v4",
46 env_var: "ZAI_API_KEY",
47 default_model: "GLM-4.6",
48 prefix: "zai",
49 display_name: "Z.ai",
50 tool_schema_transform: None,
51};
52
53pub struct GenericOpenAiProvider {
55 client: Client<AetherOpenAiConfig>,
56 model: String,
57 config: &'static ProviderConfig,
58}
59
60impl GenericOpenAiProvider {
61 pub fn from_env(config: &'static ProviderConfig) -> Result<Self> {
62 Self::from_env_with_connection(config, ProviderConnectionConfig::default())
63 }
64
65 pub fn from_env_with_connection(
66 config: &'static ProviderConfig,
67 connection: ProviderConnectionConfig,
68 ) -> Result<Self> {
69 let api_key = match connection.auth_mode {
70 ProviderAuthMode::Default => {
71 std::env::var(config.env_var).map_err(|_| LlmError::MissingApiKey(config.env_var.to_string()))?
72 }
73 ProviderAuthMode::None => String::new(),
74 };
75 Ok(Self::new_with_connection(api_key, config, connection))
76 }
77
78 pub fn new(api_key: String, config: &'static ProviderConfig) -> Self {
79 Self::new_with_connection(api_key, config, ProviderConnectionConfig::default())
80 }
81
82 pub fn new_with_connection(
83 api_key: String,
84 config: &'static ProviderConfig,
85 connection: ProviderConnectionConfig,
86 ) -> Self {
87 let api_base = connection.base_url.unwrap_or_else(|| config.api_base.to_string());
88 let openai_config = OpenAIConfig::new().with_api_key(api_key).with_api_base(api_base);
89 let openai_config = AetherOpenAiConfig::new(openai_config, connection.auth_mode);
90
91 Self { client: Client::with_config(openai_config), model: config.default_model.to_string(), config }
92 }
93
94 pub fn with_model(mut self, model: &str) -> Self {
95 self.model = model.to_string();
96 self
97 }
98}
99
100impl StreamingModelProvider for GenericOpenAiProvider {
101 fn model(&self) -> Option<LlmModel> {
102 format!("{}:{}", self.config.prefix, self.model).parse().ok()
103 }
104
105 fn context_window(&self) -> Option<u32> {
106 get_context_window(self.config.prefix, &self.model)
107 }
108
109 fn stream_response(&self, context: &Context) -> LlmResponseStream {
110 let request = match build_chat_request(&self.model, context, self.config.tool_schema_transform) {
111 Ok(req) => req,
112 Err(e) => return Box::pin(async_stream::stream! { yield Err(e); }),
113 };
114 create_custom_stream_generic(&self.client, request)
115 }
116
117 fn display_name(&self) -> String {
118 format!("{} ({})", self.config.display_name, self.model)
119 }
120}