claude_code_switcher/templates/
kat_coder.rs

1//! KatCoder (WanQing) AI provider template implementation
2
3use crate::{
4    credentials::CredentialStore,
5    settings::{ClaudeSettings, Permissions},
6    simple_selector::get_endpoint_id_interactively,
7    snapshots::SnapshotScope,
8    templates::Template,
9};
10use anyhow::{Result, anyhow};
11use atty;
12use inquire::Select;
13use std::collections::HashMap;
14
/// KatCoder AI provider variants.
///
/// The enum carries no data, so it is cheap to copy and can be compared,
/// hashed and used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KatCoderVariant {
    /// The variant that maps to the `KAT-Coder-Pro-V1` model.
    Pro,
    /// The variant that maps to the `KAT-Coder-Air-V1` model.
    Air,
}
21
22impl KatCoderVariant {
23    pub fn display_name(&self) -> &'static str {
24        match self {
25            KatCoderVariant::Pro => "KatCoder Pro (WanQing)",
26            KatCoderVariant::Air => "KatCoder Air (WanQing)",
27        }
28    }
29
30    pub fn description(&self) -> &'static str {
31        match self {
32            KatCoderVariant::Pro => {
33                "WanQing KAT-Coder Pro V1 - Professional coding AI with advanced capabilities"
34            }
35            KatCoderVariant::Air => {
36                "WanQing KAT-Coder Air V1 - Lightweight coding AI with fast response"
37            }
38        }
39    }
40
41    pub fn model_name(&self) -> &'static str {
42        match self {
43            KatCoderVariant::Pro => "KAT-Coder-Pro-V1",
44            KatCoderVariant::Air => "KAT-Coder-Air-V1",
45        }
46    }
47}
48
/// KatCoder AI provider template.
///
/// Wraps a single [`KatCoderVariant`]; the `Template` implementation reads it
/// to pick the display name, description and model name.
#[derive(Debug, Clone)]
pub struct KatCoderTemplate {
    /// Variant this template was constructed with.
    variant: KatCoderVariant,
}
54
55impl KatCoderTemplate {
56    pub fn new(variant: KatCoderVariant) -> Self {
57        Self { variant }
58    }
59
60    pub fn pro() -> Self {
61        Self::new(KatCoderVariant::Pro)
62    }
63
64    pub fn air() -> Self {
65        Self::new(KatCoderVariant::Air)
66    }
67}
68
69impl Template for KatCoderTemplate {
70    fn template_type(&self) -> crate::templates::TemplateType {
71        crate::templates::TemplateType::KatCoder
72    }
73
74    fn env_var_name(&self) -> &'static str {
75        "KAT_CODER_API_KEY"
76    }
77
78    fn display_name(&self) -> &'static str {
79        self.variant.display_name()
80    }
81
82    fn description(&self) -> &'static str {
83        self.variant.description()
84    }
85
86    fn api_key_url(&self) -> Option<&'static str> {
87        Some("https://console.volcengine.com/ark/region:ark+cn-beijing/apikey")
88    }
89
90    fn has_variants(&self) -> bool {
91        true
92    }
93
94    fn get_variants() -> Result<Vec<Self>>
95    where
96        Self: Sized,
97    {
98        Ok(vec![Self::pro(), Self::air()])
99    }
100
101    fn create_interactively() -> Result<Self>
102    where
103        Self: Sized,
104    {
105        if !atty::is(atty::Stream::Stdin) {
106            return Err(anyhow!(
107                "KatCoder requires interactive mode to select variant. Use 'kat-coder-pro' or 'kat-coder-air' explicitly if not in interactive mode."
108            ));
109        }
110
111        let variants = [
112            (
113                "KatCoder Pro",
114                "Professional coding AI with advanced capabilities",
115            ),
116            ("KatCoder Air", "Lightweight coding AI with fast response"),
117        ];
118
119        let options: Vec<String> = variants.iter().map(|(name, _)| name.to_string()).collect();
120
121        let choice = Select::new("Select KatCoder variant:", options)
122            .prompt()
123            .map_err(|e| anyhow!("Failed to get variant selection: {}", e))?;
124
125        let template = match choice.as_str() {
126            "KatCoder Pro" => Self::pro(),
127            "KatCoder Air" => Self::air(),
128            _ => unreachable!(),
129        };
130
131        Ok(template)
132    }
133
134    fn requires_additional_config(&self) -> bool {
135        true
136    }
137
138    fn get_additional_config(&self) -> Result<HashMap<String, String>> {
139        let endpoint_id = get_kat_coder_endpoint_id()?;
140        let mut config = HashMap::new();
141        config.insert("endpoint_id".to_string(), endpoint_id);
142        Ok(config)
143    }
144
145    fn create_settings(&self, api_key: &str, scope: &SnapshotScope) -> ClaudeSettings {
146        let mut settings = ClaudeSettings::new();
147
148        // Get endpoint ID for KatCoder
149        let endpoint_id = get_kat_coder_endpoint_id().unwrap_or_else(|_| "default".to_string());
150        let base_url = format!(
151            "https://wanqing.streamlakeapi.com/api/gateway/v1/endpoints/{}/claude-code-proxy",
152            endpoint_id
153        );
154
155        if matches!(scope, SnapshotScope::Common | SnapshotScope::All) {
156            settings.model = Some(self.variant.model_name().to_string());
157
158            settings.permissions = Some(Permissions {
159                allow: Some(vec![
160                    "Bash".to_string(),
161                    "Read".to_string(),
162                    "Write".to_string(),
163                    "Edit".to_string(),
164                    "MultiEdit".to_string(),
165                    "Glob".to_string(),
166                    "Grep".to_string(),
167                    "WebFetch".to_string(),
168                ]),
169                ask: None,
170                deny: Some(vec!["WebSearch".to_string()]),
171                additional_directories: None,
172                default_mode: None,
173                disable_bypass_permissions_mode: None,
174            });
175        }
176
177        if matches!(scope, SnapshotScope::Env | SnapshotScope::All) {
178            let mut env = HashMap::new();
179            env.insert("ANTHROPIC_AUTH_TOKEN".to_string(), api_key.to_string());
180            env.insert("ANTHROPIC_BASE_URL".to_string(), base_url);
181            env.insert(
182                "ANTHROPIC_MODEL".to_string(),
183                self.variant.model_name().to_string(),
184            );
185            env.insert(
186                "ANTHROPIC_SMALL_FAST_MODEL".to_string(),
187                self.variant.model_name().to_string(),
188            );
189            env.insert("API_TIMEOUT_MS".to_string(), "600000".to_string());
190            env.insert(
191                "CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC".to_string(),
192                "1".to_string(),
193            );
194            settings.env = Some(env);
195        }
196
197        settings
198    }
199}
200
201/// Get KatCoder endpoint ID from environment or prompt user
202fn get_kat_coder_endpoint_id() -> Result<String> {
203    // Try to get from environment first
204    let env_var = "WANQING_ENDPOINT_ID";
205
206    if let Ok(id) = std::env::var(env_var) {
207        println!(
208            "  ✓ Using endpoint ID from environment variable {}",
209            env_var
210        );
211        return Ok(id);
212    }
213
214    // If not found and we're in non-interactive mode, error
215    if !atty::is(atty::Stream::Stdin) {
216        return Err(anyhow!(
217            "Endpoint ID required for kat-coder template. Set {} environment variable or use interactive mode.",
218            env_var
219        ));
220    }
221
222    // Use interactive endpoint ID selector
223    let endpoint_id = get_endpoint_id_interactively(&crate::templates::TemplateType::KatCoder)?;
224
225    // Auto-save the endpoint ID if it's new and we have credentials
226    if let Ok(credential_store) = CredentialStore::new()
227        && let Ok(credentials) = credential_store
228            .store
229            .find_by_template_type(&crate::templates::TemplateType::KatCoder)
230        && !credentials.is_empty()
231    {
232        // Save endpoint ID to the most recent credential
233        let most_recent = credentials.iter().max_by_key(|c| c.created_at());
234        if let Some(credential) = most_recent {
235            if credential_store
236                .has_endpoint_id(&endpoint_id, &crate::templates::TemplateType::KatCoder)
237            {
238                println!("  ✓ Endpoint ID already saved for KatCoder");
239            } else if credential_store
240                .save_endpoint_id(credential.id(), &endpoint_id)
241                .is_ok()
242            {
243                println!("  ✓ Endpoint ID saved automatically for future use");
244            }
245        }
246    }
247
248    Ok(endpoint_id)
249}
250
251/// Create KatCoder template settings (legacy compatibility function)
252pub fn create_kat_coder_template(api_key: &str, scope: &SnapshotScope) -> ClaudeSettings {
253    let template = KatCoderTemplate::pro(); // Default to Pro for backward compatibility
254    template.create_settings(api_key, scope)
255}