1pub mod claude;
2pub mod ollama;
3pub mod openai;
4pub mod retry;
5
6use std::pin::Pin;
7
8use anyhow::Result;
9use async_trait::async_trait;
10use futures::Stream;
11
12use crate::config::provider::ProviderConfig;
13use crate::conversation::message::Message;
14use crate::stream::StreamEvent;
15use crate::tool::ToolDef;
16
/// What a provider reports, via [`LlmProvider::reasoning_history_policy`], about
/// reasoning content in the conversation history it is sent.
///
/// The trait default is [`ReasoningPolicy::Exclude`]. How each variant is applied is
/// decided by callers outside this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReasoningPolicy {
    /// Reasoning content should be kept in the history sent to the provider.
    Include,
    /// Reasoning content should be left out of the history sent to the provider.
    Exclude,
}

/// Placeholder text for a reasoning slot that has no recorded content.
/// Where it gets inserted is decided by callers outside this module.
pub const REASONING_PLACEHOLDER: &str = "(no reasoning recorded)";
60
61#[async_trait]
62pub trait LlmProvider: Send + Sync {
63 fn chat_stream(
64 &self,
65 messages: &[Message],
66 tools: Option<&[ToolDef]>,
67 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>>;
68
69 fn model_name(&self) -> &str;
70
71 fn availability_error(&self) -> Option<&str> {
72 None
73 }
74
75 fn reasoning_history_policy(&self) -> ReasoningPolicy {
80 ReasoningPolicy::Exclude
81 }
82}
83
84pub(super) fn build_http_client(ua_override: Option<&str>, skip_tls_verify: bool) -> reqwest::Client {
90 let ua = ua_override.unwrap_or(crate::ATOMCODE_USER_AGENT);
91 let mut builder = reqwest::Client::builder()
92 .connect_timeout(std::time::Duration::from_secs(30))
93 .timeout(std::time::Duration::from_secs(1800))
103 .user_agent(ua);
104
105 if skip_tls_verify {
106 builder = builder.danger_accept_invalid_certs(true);
107 }
108
109 builder.build().unwrap_or_else(|_| reqwest::Client::new())
110}
111
112pub(super) fn format_http_error(
140 status: reqwest::StatusCode,
141 url: &str,
142 msg: &str,
143) -> String {
144 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
145 format!("[429] {}", msg)
146 } else {
147 format!("API error ({}) at `{}`:\n{}", status, url, msg)
148 }
149}
150
151pub(super) fn extract_error_message(body: &str) -> String {
152 let trimmed = body.trim();
153 if let Ok(v) = serde_json::from_str::<serde_json::Value>(trimmed) {
154 if let Some(detail) = v.get("detail") {
155 if let Some(msg) = detail.get("message").and_then(|m| m.as_str()) {
156 return msg.to_string();
157 }
158 if let Some(s) = detail.as_str() {
159 return s.to_string();
160 }
161 }
162 if let Some(msg) = v
163 .get("error")
164 .and_then(|e| e.get("message"))
165 .and_then(|m| m.as_str())
166 {
167 return msg.to_string();
168 }
169 if let Some(msg) = v.get("message").and_then(|m| m.as_str()) {
170 return msg.to_string();
171 }
172 }
173 trimmed.to_string()
174}
175
#[cfg(test)]
mod extract_error_message_tests {
    use super::extract_error_message as extract;

    #[test]
    fn openai_envelope_codingplan_rate_limit() {
        let expected = "codingplan rate limit exceeded for type='Pro'";
        let body = r#"{"error":{"message":"codingplan rate limit exceeded for type='Pro'","type":"auth_error","param":"None","code":"429"}}"#;
        assert_eq!(extract(body), expected);
    }

    #[test]
    fn openai_envelope_no_deployments_available() {
        let body = r#"{"error":{"message":"No deployments available for selected model. Try again in 30 seconds. Passed model=deepseek-v4-flash.","type":"None","param":"None","code":"429"}}"#;
        let message = extract(body);
        assert!(message.starts_with("No deployments available"));
        assert!(message.contains("Try again in 30 seconds"));
        // Only the message text should come back, not the surrounding JSON.
        assert!(!message.contains("\"code\""), "envelope keys must not leak");
    }

    #[test]
    fn atomcode_detail_envelope() {
        assert_eq!(
            extract(r#"{"detail":{"code":"X","message":"detail message body"}}"#),
            "detail message body"
        );
    }

    #[test]
    fn fastapi_string_detail() {
        assert_eq!(extract(r#"{"detail":"plain string detail"}"#), "plain string detail");
    }

    #[test]
    fn top_level_message() {
        assert_eq!(extract(r#"{"message":"top-level message"}"#), "top-level message");
    }

    #[test]
    fn non_json_body_passes_through_trimmed() {
        let raw = "  upstream timeout  ";
        assert_eq!(extract(raw), "upstream timeout");
    }
}
241
#[cfg(test)]
mod format_http_error_tests {
    use super::format_http_error;
    use reqwest::StatusCode;

    #[test]
    fn rate_limit_compresses_to_bracketed_form() {
        let rendered = format_http_error(
            StatusCode::TOO_MANY_REQUESTS,
            "https://llm-api.atomgit.com/v1/chat/completions",
            "codingplan rate limit exceeded for type='Pro'",
        );
        assert_eq!(rendered, "[429] codingplan rate limit exceeded for type='Pro'");
    }

    #[test]
    fn rate_limit_preserves_retry_matcher_keywords() {
        let rendered = format_http_error(
            StatusCode::TOO_MANY_REQUESTS,
            "https://x",
            "codingplan rate limit exceeded for type='Pro'",
        );
        // Both keywords must survive the shortening.
        assert!(rendered.contains("429"), "must contain literal `429`");
        assert!(rendered.contains("rate"), "must contain `rate` for matcher");
    }

    #[test]
    fn rate_limit_with_chinese_upstream_message_still_matches() {
        let rendered =
            format_http_error(StatusCode::TOO_MANY_REQUESTS, "https://x", "请求过于频繁,请稍后再试");
        for needle in ["429", "请求过于频繁"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn non_rate_limit_keeps_verbose_form() {
        let rendered = format_http_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "https://x/v1/chat/completions",
            "upstream gateway timeout",
        );
        for needle in ["500", "https://x/v1/chat/completions", "upstream gateway timeout"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn bad_request_keeps_url_for_diagnostics() {
        let rendered = format_http_error(
            StatusCode::BAD_REQUEST,
            "https://x/v1/chat/completions",
            "Invalid model `xyz`",
        );
        for needle in ["400", "https://x", "Invalid model"] {
            assert!(rendered.contains(needle));
        }
    }
}
318
319pub fn create_provider(config: &ProviderConfig) -> Result<Box<dyn LlmProvider>> {
323 let mut config = if config.api_key.is_none() && config.provider_type != "ollama" {
324 let mut c = config.clone();
325 c.api_key = Some(load_auth_token()?);
326 c
327 } else {
328 config.clone()
329 };
330 if let Some(key) = config.api_key.as_deref() {
338 let trimmed = key.trim();
339 if trimmed.is_empty() {
340 anyhow::bail!(
341 "API key for provider type '{}' is empty (or whitespace only) \
342 — check the value in your config.toml",
343 config.provider_type
344 );
345 }
346 if trimmed.chars().any(|c| c.is_control()) {
347 anyhow::bail!(
348 "API key for provider type '{}' contains control characters \
349 (newline/tab/etc.) — re-copy the key without surrounding \
350 whitespace",
351 config.provider_type
352 );
353 }
354 if trimmed.len() != key.len() {
355 config.api_key = Some(trimmed.to_string());
358 }
359 }
360 match config.provider_type.as_str() {
361 "claude" => Ok(Box::new(claude::ClaudeProvider::new(&config)?)),
362 "openai" => Ok(Box::new(openai::OpenAiProvider::new(&config)?)),
363 "ollama" => Ok(Box::new(ollama::OllamaProvider::new(&config)?)),
364 other => anyhow::bail!("Unknown provider type: {}", other),
365 }
366}
367
368pub fn unavailable_provider(reason: impl Into<String>) -> Box<dyn LlmProvider> {
369 Box::new(UnavailableProvider {
370 reason: reason.into(),
371 })
372}
373
374struct UnavailableProvider {
375 reason: String,
376}
377
378#[async_trait]
379impl LlmProvider for UnavailableProvider {
380 fn chat_stream(
381 &self,
382 _messages: &[Message],
383 _tools: Option<&[ToolDef]>,
384 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
385 anyhow::bail!("{}", self.reason);
386 }
387
388 fn model_name(&self) -> &str {
389 ""
390 }
391
392 fn availability_error(&self) -> Option<&str> {
393 Some(&self.reason)
394 }
395}
396
397#[derive(serde::Deserialize)]
403struct StoredAuth {
404 access_token: String,
405 #[serde(default)]
406 refresh_token: Option<String>,
407 #[serde(default)]
408 expires_in: Option<i64>,
409 #[serde(default)]
410 created_at: i64,
411}
412
413fn load_auth_token() -> Result<String> {
416 let auth_path = crate::auth::auth_file_path();
417 let content = std::fs::read_to_string(&auth_path)
418 .map_err(|_| anyhow::anyhow!("Not logged in — please use /login"))?;
419 let auth: StoredAuth = toml::from_str(&content)
420 .map_err(|_| anyhow::anyhow!("Invalid auth.toml — please use /login"))?;
421
422 if let Some(expires_in) = auth.expires_in {
424 let now = std::time::SystemTime::now()
425 .duration_since(std::time::UNIX_EPOCH)
426 .unwrap()
427 .as_secs() as i64;
428 if now >= auth.created_at + expires_in - 300 {
429 if let Some(ref rt) = auth.refresh_token {
431 return refresh_and_save(rt, &auth_path);
432 }
433 anyhow::bail!("Token expired — please use /login");
434 }
435 }
436
437 Ok(auth.access_token)
438}
439
440fn refresh_and_save(refresh_token: &str, auth_path: &std::path::Path) -> Result<String> {
442 let client = reqwest::blocking::Client::builder()
446 .connect_timeout(std::time::Duration::from_secs(5))
447 .timeout(std::time::Duration::from_secs(10))
448 .build()
449 .unwrap_or_else(|_| reqwest::blocking::Client::new());
450 let builder = client
451 .post(crate::auth::oauth::platform_refresh_url())
452 .json(&serde_json::json!({ "refresh_token": refresh_token, "provider": "atomgit" }));
453 let policy = crate::provider::retry::RetryPolicy::default_policy();
454 let resp = crate::provider::retry::send_with_retry_blocking(builder, &policy)
455 .map_err(|e| anyhow::anyhow!("Token refresh failed: {} — please /login", e))?;
456
457 if !resp.status().is_success() {
458 anyhow::bail!("Token refresh failed ({}) — please /login", resp.status());
459 }
460
461 #[derive(serde::Deserialize)]
462 struct RefreshedAuth {
463 access_token: String,
464 #[serde(default)]
465 token_type: Option<String>,
466 #[serde(default)]
467 refresh_token: Option<String>,
468 #[serde(default)]
469 expires_in: Option<i64>,
470 #[serde(default)]
471 user: Option<RefreshedUser>,
472 }
473
474 #[derive(serde::Deserialize)]
475 struct RefreshedUser {
476 id: String,
477 username: String,
478 #[serde(default)]
479 name: Option<String>,
480 #[serde(default)]
481 email: Option<String>,
482 #[serde(default)]
483 avatar_url: Option<String>,
484 }
485
486 let token: RefreshedAuth = resp
487 .json()
488 .map_err(|e| anyhow::anyhow!("Token refresh parse error: {} — please /login", e))?;
489
490 let token_type = token.token_type.as_deref().unwrap_or("Bearer");
492
493 let now = std::time::SystemTime::now()
495 .duration_since(std::time::UNIX_EPOCH)
496 .unwrap()
497 .as_secs();
498 let new_rt = token.refresh_token.as_deref().unwrap_or(refresh_token);
499 let mut content = format!(
500 "access_token = \"{}\"\ncreated_at = {}\nrefresh_token = \"{}\"\n",
501 token.access_token, now, new_rt,
502 );
503 if let Some(e) = token.expires_in {
504 content.push_str(&format!("expires_in = {}\n", e));
505 }
506 content.push_str(&format!("token_type = \"{}\"\n", token_type));
507 if let Some(user) = token.user {
508 content.push_str(&format!(
509 "\n[user]\nid = \"{}\"\nusername = \"{}\"\n",
510 user.id, user.username,
511 ));
512 if let Some(name) = user.name {
513 content.push_str(&format!("name = \"{}\"\n", name));
514 }
515 if let Some(email) = user.email {
516 content.push_str(&format!("email = \"{}\"\n", email));
517 }
518 if let Some(avatar_url) = user.avatar_url {
519 content.push_str(&format!("avatar_url = \"{}\"\n", avatar_url));
520 }
521 }
522 let _ = crate::auth::write_auth_file_secure(auth_path, &content);
523
524 Ok(token.access_token)
525}
526
/// Heuristic: does `name` look like a model that accepts image input?
///
/// The check is case-insensitive and based only on the name, so it can be wrong in
/// either direction. A name counts as vision-capable if either:
/// * it contains one of `vision`, `-vl`, `vl-`, `ocr`, `-4v`, `-4.1v`, `llava`, `qvq`; or
/// * it starts with one of the multimodal family prefixes: `gpt-4o`, `claude-3`
///   through `claude-7`, `claude-sonnet`/`-opus`/`-haiku`, `gemini`, `pixtral`.
pub fn model_name_suggests_vision(name: &str) -> bool {
    // Model-family prefixes, matched at the start of the lowercased name.
    const FAMILY_PREFIXES: &[&str] = &[
        "gpt-4o",
        "claude-3",
        "claude-4",
        "claude-5",
        "claude-6",
        "claude-7",
        "claude-sonnet",
        "claude-opus",
        "claude-haiku",
        "gemini",
        "pixtral",
    ];
    // Vision/OCR markers, matched anywhere in the lowercased name.
    const MARKERS: &[&str] = &["vision", "-vl", "vl-", "ocr", "-4v", "-4.1v", "llava", "qvq"];

    let lowered = name.to_lowercase();
    FAMILY_PREFIXES.iter().any(|prefix| lowered.starts_with(*prefix))
        || MARKERS.iter().any(|marker| lowered.contains(*marker))
}
576
#[cfg(test)]
mod tests {
    use super::{create_provider, model_name_suggests_vision, unavailable_provider};
    use crate::config::provider::ProviderConfig;

    /// Minimal config with a local base URL; only provider type and API key vary per test.
    fn cfg(provider_type: &str, api_key: &str) -> ProviderConfig {
        ProviderConfig {
            provider_type: provider_type.to_string(),
            api_key: Some(api_key.to_string()),
            model: "m".to_string(),
            base_url: Some("http://127.0.0.1:1/".to_string()),
            system_prompt: None,
            user_agent: None,
            context_window: 8000,
            max_tokens: None,
            thinking_type: None,
            thinking_keep: None,
            reasoning_history: None,
            thinking_enabled: None,
            thinking_budget: None,
            skip_tls_verify: false,
            ephemeral: false,
        }
    }

    /// Run `create_provider` and return its error text, panicking with `why` on success.
    /// (A boxed provider is not `Debug`, so `unwrap_err` cannot be used.)
    fn creation_error(provider_type: &str, api_key: &str, why: &str) -> String {
        match create_provider(&cfg(provider_type, api_key)) {
            Ok(_) => panic!("{}", why),
            Err(e) => e.to_string(),
        }
    }

    #[test]
    fn test_auth_token_path_consistency() {
        let actual = crate::auth::auth_file_path();
        let home = crate::tool::real_home_dir().unwrap_or_else(|| std::path::PathBuf::from("."));
        let expected = home.join(".atomcode").join("auth.toml");

        assert_eq!(
            actual, expected,
            "auth_file_path() should always return ~/.atomcode/auth.toml"
        );

        let unix_suffix = actual.ends_with(".atomcode/auth.toml");
        let windows_suffix = actual.ends_with(".atomcode\\auth.toml");
        assert!(
            unix_suffix || windows_suffix,
            "Path should end with .atomcode/auth.toml, got: {}",
            actual.display()
        );
    }

    #[test]
    fn unavailable_provider_reports_reason() {
        let placeholder = unavailable_provider("未配置 provider");
        assert_eq!(placeholder.model_name(), "");
        assert_eq!(placeholder.availability_error(), Some("未配置 provider"));
    }

    #[test]
    fn create_provider_rejects_api_key_with_internal_control_chars() {
        let msg = creation_error(
            "openai",
            "sk-ab\nc",
            "expected Err for api_key with internal \\n",
        );
        assert!(
            msg.contains("control character"),
            "expected control-char error, got: {}",
            msg
        );
    }

    #[test]
    fn create_provider_silently_trims_trailing_newline() {
        let outcome = create_provider(&cfg("openai", "sk-abc\n"));
        assert!(
            outcome.is_ok(),
            "trailing \\n should be trimmed silently, got: {:?}",
            outcome.err().map(|e| e.to_string())
        );
    }

    #[test]
    fn create_provider_rejects_empty_or_whitespace_api_key() {
        let msg = creation_error("openai", "   ", "expected Err for whitespace-only api_key");
        assert!(
            msg.contains("empty") || msg.contains("whitespace"),
            "expected empty/whitespace error, got: {}",
            msg
        );
    }

    #[test]
    fn create_provider_silently_trims_surrounding_whitespace() {
        let outcome = create_provider(&cfg("openai", "  sk-abc  "));
        assert!(
            outcome.is_ok(),
            "trimmable key should be accepted, got: {:?}",
            outcome.err().map(|e| e.to_string())
        );
    }

    #[test]
    fn vision_heuristic_recognises_known_vision_models() {
        let vision_models = [
            "claude-3-5-sonnet",
            "claude-4-opus",
            "claude-sonnet-4-6",
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-vision-preview",
            "GLM-4V",
            "glm-4.1v-thinking",
            "Qwen2-VL-7B",
            "deepseek-vl",
            "gemini-2.0-flash",
            "pixtral-12b",
            "llava-1.6",
            "qvq-72b-preview",
        ];
        for model in vision_models {
            assert!(model_name_suggests_vision(model), "expected vision: {}", model);
        }
    }

    #[test]
    fn vision_heuristic_rejects_text_only_models() {
        let text_models = [
            "GLM-5.1",
            "glm-5.1",
            "deepseek-v4-flash",
            "Qwen/Qwen3.6-35B-A3B",
            "gpt-4-turbo",
            "kimi-k2-thinking",
            "o1-preview",
            "",
        ];
        for model in text_models {
            assert!(!model_name_suggests_vision(model), "expected text-only: {:?}", model);
        }
    }

    #[test]
    fn vision_heuristic_recognises_ocr_models() {
        let ocr_models = [
            "PaddleOCR-VL-0.9B",
            "Qwen2-VL-OCR-7B",
            "GOT-OCR-2.0",
            "PaddleOCR-2.0",
            "MinerU-OCR",
            "MonkeyOCR-1.2B",
            "got-ocr-1.0",
        ];
        for model in ocr_models {
            assert!(model_name_suggests_vision(model), "expected OCR/vision: {}", model);
        }
    }

    #[test]
    fn vision_heuristic_documented_false_positives() {
        assert!(!model_name_suggests_vision("focar-text-7b"));
    }
}