pub mod claude;
pub mod ollama;
pub mod openai;
pub mod retry;
use std::pin::Pin;
use anyhow::Result;
use async_trait::async_trait;
use futures::Stream;
use crate::config::provider::ProviderConfig;
use crate::conversation::message::Message;
use crate::stream::StreamEvent;
use crate::tool::ToolDef;
/// Choice reported by [`LlmProvider::reasoning_history_policy`].
///
/// NOTE(review): presumably decides whether earlier reasoning content is
/// replayed to the model as part of the conversation history — the consumer
/// is not visible in this file, so confirm against the call sites.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReasoningPolicy {
    /// Keep reasoning content.
    Include,
    /// Drop reasoning content (the trait's default).
    Exclude,
}
/// Fixed placeholder text for reasoning content.
///
/// NOTE(review): presumably substituted where a message carries no recorded
/// reasoning — the usage is outside this file, confirm against callers.
pub const REASONING_PLACEHOLDER: &str = "(no reasoning recorded)";
/// Common interface implemented by every LLM backend in this module.
///
/// Implementors must be `Send + Sync` so a boxed provider can be shared
/// across threads.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Starts a chat completion for `messages`, optionally exposing `tools`.
    ///
    /// # Errors
    ///
    /// Returns an error when the stream cannot be created; individual stream
    /// items are themselves `Result`s and may carry later failures.
    fn chat_stream(
        &self,
        messages: &[Message],
        tools: Option<&[ToolDef]>,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>>;

    /// Name of the model this provider talks to.
    fn model_name(&self) -> &str;

    /// Reason the provider cannot be used, or `None` (the default) when it
    /// is usable.
    fn availability_error(&self) -> Option<&str> {
        None
    }

    /// Reasoning-history policy for this provider; defaults to
    /// [`ReasoningPolicy::Exclude`].
    fn reasoning_history_policy(&self) -> ReasoningPolicy {
        ReasoningPolicy::Exclude
    }
}
/// Builds the shared async HTTP client used by the providers.
///
/// * `ua_override` — user agent to send; falls back to
///   `crate::ATOMCODE_USER_AGENT` when `None`.
/// * `skip_tls_verify` — when `true`, certificate validation is disabled.
///
/// The client uses a 30 s connect timeout and an 1800 s overall request
/// timeout. If the configured builder fails, a default `reqwest::Client` is
/// returned instead — note that this fallback carries none of the settings
/// above.
pub(super) fn build_http_client(ua_override: Option<&str>, skip_tls_verify: bool) -> reqwest::Client {
    let user_agent = ua_override.unwrap_or(crate::ATOMCODE_USER_AGENT);
    let connect_limit = std::time::Duration::from_secs(30);
    let request_limit = std::time::Duration::from_secs(1800);

    let base = reqwest::Client::builder()
        .connect_timeout(connect_limit)
        .timeout(request_limit)
        .user_agent(user_agent);
    let configured = if skip_tls_verify {
        base.danger_accept_invalid_certs(true)
    } else {
        base
    };

    match configured.build() {
        Ok(client) => client,
        Err(_) => reqwest::Client::new(),
    }
}
/// Formats an HTTP failure for display.
///
/// A 429 response is compressed to `[429] <msg>` (the URL is omitted); every
/// other status keeps the verbose form containing the status, the request
/// `url` and the message.
pub(super) fn format_http_error(
    status: reqwest::StatusCode,
    url: &str,
    msg: &str,
) -> String {
    let is_rate_limited = status == reqwest::StatusCode::TOO_MANY_REQUESTS;
    if !is_rate_limited {
        return format!("API error ({}) at `{}`:\n{}", status, url, msg);
    }
    format!("[429] {}", msg)
}
pub(super) fn extract_error_message(body: &str) -> String {
let trimmed = body.trim();
if let Ok(v) = serde_json::from_str::<serde_json::Value>(trimmed) {
if let Some(detail) = v.get("detail") {
if let Some(msg) = detail.get("message").and_then(|m| m.as_str()) {
return msg.to_string();
}
if let Some(s) = detail.as_str() {
return s.to_string();
}
}
if let Some(msg) = v
.get("error")
.and_then(|e| e.get("message"))
.and_then(|m| m.as_str())
{
return msg.to_string();
}
if let Some(msg) = v.get("message").and_then(|m| m.as_str()) {
return msg.to_string();
}
}
trimmed.to_string()
}
/// Unit tests for `extract_error_message`, one per supported envelope shape.
#[cfg(test)]
mod extract_error_message_tests {
    use super::extract_error_message;

    /// OpenAI-style `error.message` envelope: only the message survives.
    #[test]
    fn openai_envelope_codingplan_rate_limit() {
        let body = r#"{"error":{"message":"codingplan rate limit exceeded for type='Pro'","type":"auth_error","param":"None","code":"429"}}"#;
        let extracted = extract_error_message(body);
        assert_eq!(extracted, "codingplan rate limit exceeded for type='Pro'");
    }

    /// The extracted text must not contain any of the surrounding JSON keys.
    #[test]
    fn openai_envelope_no_deployments_available() {
        let body = r#"{"error":{"message":"No deployments available for selected model. Try again in 30 seconds. Passed model=deepseek-v4-flash.","type":"None","param":"None","code":"429"}}"#;
        let extracted = extract_error_message(body);
        assert!(extracted.starts_with("No deployments available"));
        assert!(extracted.contains("Try again in 30 seconds"));
        assert!(!extracted.contains("\"code\""), "envelope keys must not leak");
    }

    /// `detail` as an object carrying its own `message`.
    #[test]
    fn atomcode_detail_envelope() {
        let body = r#"{"detail":{"code":"X","message":"detail message body"}}"#;
        let extracted = extract_error_message(body);
        assert_eq!(extracted, "detail message body");
    }

    /// `detail` as a bare string.
    #[test]
    fn fastapi_string_detail() {
        let body = r#"{"detail":"plain string detail"}"#;
        let extracted = extract_error_message(body);
        assert_eq!(extracted, "plain string detail");
    }

    /// A `message` key at the top level of the document.
    #[test]
    fn top_level_message() {
        let body = r#"{"message":"top-level message"}"#;
        let extracted = extract_error_message(body);
        assert_eq!(extracted, "top-level message");
    }

    /// Bodies that are not JSON are returned as-is, minus outer whitespace.
    #[test]
    fn non_json_body_passes_through_trimmed() {
        let extracted = extract_error_message(" upstream timeout ");
        assert_eq!(extracted, "upstream timeout");
    }
}
/// Unit tests for `format_http_error`: compact 429 form vs. verbose form.
#[cfg(test)]
mod format_http_error_tests {
    use super::format_http_error;
    use reqwest::StatusCode;

    /// A 429 drops the URL and is prefixed with `[429]`.
    #[test]
    fn rate_limit_compresses_to_bracketed_form() {
        let formatted = format_http_error(
            StatusCode::TOO_MANY_REQUESTS,
            "https://llm-api.atomgit.com/v1/chat/completions",
            "codingplan rate limit exceeded for type='Pro'",
        );
        assert_eq!(
            formatted,
            "[429] codingplan rate limit exceeded for type='Pro'"
        );
    }

    /// The compact form still carries the keywords the retry matcher needs.
    #[test]
    fn rate_limit_preserves_retry_matcher_keywords() {
        let formatted = format_http_error(
            StatusCode::TOO_MANY_REQUESTS,
            "https://x",
            "codingplan rate limit exceeded for type='Pro'",
        );
        assert!(formatted.contains("429"), "must contain literal `429`");
        assert!(formatted.contains("rate"), "must contain `rate` for matcher");
    }

    /// A non-English upstream message keeps both the code and the text.
    #[test]
    fn rate_limit_with_chinese_upstream_message_still_matches() {
        let formatted = format_http_error(
            StatusCode::TOO_MANY_REQUESTS,
            "https://x",
            "请求过于频繁,请稍后再试",
        );
        assert!(formatted.contains("429"));
        assert!(formatted.contains("请求过于频繁"));
    }

    /// Any other status keeps status, URL and message.
    #[test]
    fn non_rate_limit_keeps_verbose_form() {
        let formatted = format_http_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "https://x/v1/chat/completions",
            "upstream gateway timeout",
        );
        assert!(formatted.contains("500"));
        assert!(formatted.contains("https://x/v1/chat/completions"));
        assert!(formatted.contains("upstream gateway timeout"));
    }

    /// A 400 keeps the URL so the failing endpoint can be diagnosed.
    #[test]
    fn bad_request_keeps_url_for_diagnostics() {
        let formatted = format_http_error(
            StatusCode::BAD_REQUEST,
            "https://x/v1/chat/completions",
            "Invalid model `xyz`",
        );
        assert!(formatted.contains("400"));
        assert!(formatted.contains("https://x"));
        assert!(formatted.contains("Invalid model"));
    }
}
/// Builds the provider described by `config`.
///
/// When no API key is configured and the provider type is not `"ollama"`,
/// the key is taken from the stored login token. A present key is validated:
/// surrounding whitespace is trimmed silently, while an empty key or one
/// containing control characters is rejected.
///
/// # Errors
///
/// Fails when the stored token cannot be loaded, the API key is invalid, the
/// provider type is unknown, or the provider's own constructor fails.
pub fn create_provider(config: &ProviderConfig) -> Result<Box<dyn LlmProvider>> {
    // Resolve the fallback token first so a failed login aborts before any
    // further work is done.
    let needs_stored_token = config.api_key.is_none() && config.provider_type != "ollama";
    let stored_token = if needs_stored_token {
        Some(load_auth_token()?)
    } else {
        None
    };

    let mut config = config.clone();
    if stored_token.is_some() {
        config.api_key = stored_token;
    }

    // Validate the key; yields a replacement only when trimming changed it.
    let sanitized = match config.api_key.as_deref() {
        None => None,
        Some(key) => {
            let trimmed = key.trim();
            if trimmed.is_empty() {
                anyhow::bail!(
                    "API key for provider type '{}' is empty (or whitespace only) \
                     — check the value in your config.toml",
                    config.provider_type
                );
            }
            if trimmed.chars().any(|c| c.is_control()) {
                anyhow::bail!(
                    "API key for provider type '{}' contains control characters \
                     (newline/tab/etc.) — re-copy the key without surrounding \
                     whitespace",
                    config.provider_type
                );
            }
            if trimmed.len() != key.len() {
                Some(trimmed.to_string())
            } else {
                None
            }
        }
    };
    if let Some(clean_key) = sanitized {
        config.api_key = Some(clean_key);
    }

    let provider: Box<dyn LlmProvider> = match config.provider_type.as_str() {
        "claude" => Box::new(claude::ClaudeProvider::new(&config)?),
        "openai" => Box::new(openai::OpenAiProvider::new(&config)?),
        "ollama" => Box::new(ollama::OllamaProvider::new(&config)?),
        other => anyhow::bail!("Unknown provider type: {}", other),
    };
    Ok(provider)
}
/// Creates a stand-in provider that is permanently unavailable.
///
/// The returned provider reports `reason` from `availability_error` and
/// fails every `chat_stream` call with the same text.
pub fn unavailable_provider(reason: impl Into<String>) -> Box<dyn LlmProvider> {
    let reason: String = reason.into();
    let provider = UnavailableProvider { reason };
    Box::new(provider)
}
/// Provider placeholder used when no real backend can be constructed.
struct UnavailableProvider {
    /// Explanation of why no provider is available; surfaced to callers.
    reason: String,
}
#[async_trait]
impl LlmProvider for UnavailableProvider {
    /// Always fails, using the stored reason as the error message.
    fn chat_stream(
        &self,
        _messages: &[Message],
        _tools: Option<&[ToolDef]>,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        Err(anyhow::anyhow!("{}", self.reason))
    }

    /// There is no model behind this provider, so the name is empty.
    fn model_name(&self) -> &str {
        ""
    }

    /// Always `Some`, carrying the reason given at construction.
    fn availability_error(&self) -> Option<&str> {
        Some(self.reason.as_str())
    }
}
/// The subset of `auth.toml` that `load_auth_token` reads.
#[derive(serde::Deserialize)]
struct StoredAuth {
    /// Token handed to the provider as its API key.
    access_token: String,
    /// Token used to obtain a new access token once the current one expires.
    #[serde(default)]
    refresh_token: Option<String>,
    /// Token lifetime in seconds; when absent no expiry check is performed.
    #[serde(default)]
    expires_in: Option<i64>,
    /// Issue time in seconds since the Unix epoch; `0` when absent.
    #[serde(default)]
    created_at: i64,
}
/// Loads the access token saved by `/login`, refreshing it when it is
/// about to expire.
///
/// The token counts as expired from 300 seconds before
/// `created_at + expires_in`. An expired token is refreshed through
/// `refresh_and_save` when a refresh token is stored.
///
/// # Errors
///
/// Fails when the auth file is missing or unparsable, when the token has
/// expired and no refresh token is stored, or when the refresh itself fails.
fn load_auth_token() -> Result<String> {
    let auth_path = crate::auth::auth_file_path();
    let content = std::fs::read_to_string(&auth_path)
        .map_err(|_| anyhow::anyhow!("Not logged in — please use /login"))?;
    let auth: StoredAuth = toml::from_str(&content)
        .map_err(|_| anyhow::anyhow!("Invalid auth.toml — please use /login"))?;

    if let Some(expires_in) = auth.expires_in {
        // A system clock set before the Unix epoch must not panic; treat it
        // as time zero. The cast is clamped so it can never wrap negative.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_secs().min(i64::MAX as u64) as i64);

        // Saturating arithmetic: a corrupt auth file with extreme values
        // must not overflow (which would panic in debug builds).
        let refresh_deadline = auth
            .created_at
            .saturating_add(expires_in)
            .saturating_sub(300);

        if now >= refresh_deadline {
            if let Some(ref rt) = auth.refresh_token {
                return refresh_and_save(rt, &auth_path);
            }
            anyhow::bail!("Token expired — please use /login");
        }
    }
    Ok(auth.access_token)
}
fn refresh_and_save(refresh_token: &str, auth_path: &std::path::Path) -> Result<String> {
let client = reqwest::blocking::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(10))
.build()
.unwrap_or_else(|_| reqwest::blocking::Client::new());
let builder = client
.post(crate::auth::oauth::platform_refresh_url())
.json(&serde_json::json!({ "refresh_token": refresh_token, "provider": "atomgit" }));
let policy = crate::provider::retry::RetryPolicy::default_policy();
let resp = crate::provider::retry::send_with_retry_blocking(builder, &policy)
.map_err(|e| anyhow::anyhow!("Token refresh failed: {} — please /login", e))?;
if !resp.status().is_success() {
anyhow::bail!("Token refresh failed ({}) — please /login", resp.status());
}
#[derive(serde::Deserialize)]
struct RefreshedAuth {
access_token: String,
#[serde(default)]
token_type: Option<String>,
#[serde(default)]
refresh_token: Option<String>,
#[serde(default)]
expires_in: Option<i64>,
#[serde(default)]
user: Option<RefreshedUser>,
}
#[derive(serde::Deserialize)]
struct RefreshedUser {
id: String,
username: String,
#[serde(default)]
name: Option<String>,
#[serde(default)]
email: Option<String>,
#[serde(default)]
avatar_url: Option<String>,
}
let token: RefreshedAuth = resp
.json()
.map_err(|e| anyhow::anyhow!("Token refresh parse error: {} — please /login", e))?;
let token_type = token.token_type.as_deref().unwrap_or("Bearer");
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
let new_rt = token.refresh_token.as_deref().unwrap_or(refresh_token);
let mut content = format!(
"access_token = \"{}\"\ncreated_at = {}\nrefresh_token = \"{}\"\n",
token.access_token, now, new_rt,
);
if let Some(e) = token.expires_in {
content.push_str(&format!("expires_in = {}\n", e));
}
content.push_str(&format!("token_type = \"{}\"\n", token_type));
if let Some(user) = token.user {
content.push_str(&format!(
"\n[user]\nid = \"{}\"\nusername = \"{}\"\n",
user.id, user.username,
));
if let Some(name) = user.name {
content.push_str(&format!("name = \"{}\"\n", name));
}
if let Some(email) = user.email {
content.push_str(&format!("email = \"{}\"\n", email));
}
if let Some(avatar_url) = user.avatar_url {
content.push_str(&format!("avatar_url = \"{}\"\n", avatar_url));
}
}
let _ = crate::auth::write_auth_file_secure(auth_path, &content);
Ok(token.access_token)
}
/// Heuristically guesses from a model's name whether it accepts images.
///
/// The name is lower-cased and then checked against a list of substrings
/// that may appear anywhere and a list of prefixes the name may start with.
/// Returns `true` on the first hit and `false` for anything else, including
/// the empty string.
pub fn model_name_suggests_vision(name: &str) -> bool {
    /// Markers that count wherever they occur in the name.
    const ANYWHERE: &[&str] = &[
        "vision", "-vl", "vl-", "ocr", "-4v", "-4.1v", "llava", "qvq",
    ];
    /// Markers that only count at the very start of the name.
    const LEADING: &[&str] = &[
        "gpt-4o",
        "claude-3",
        "claude-4",
        "claude-5",
        "claude-6",
        "claude-7",
        "claude-sonnet",
        "claude-opus",
        "claude-haiku",
        "gemini",
        "pixtral",
    ];

    let lowered = name.to_lowercase();
    let has_marker = ANYWHERE.iter().any(|&marker| lowered.contains(marker));
    has_marker || LEADING.iter().any(|&prefix| lowered.starts_with(prefix))
}
/// Tests for the provider factory, the unavailable stand-in provider and the
/// vision-name heuristic.
#[cfg(test)]
mod tests {
    use super::{model_name_suggests_vision, unavailable_provider};
    use crate::config::provider::ProviderConfig;

    /// Minimal provider config with the given type and API key; the base URL
    /// points at a local address so construction needs no real endpoint.
    fn cfg(provider_type: &str, api_key: &str) -> ProviderConfig {
        ProviderConfig {
            provider_type: provider_type.to_string(),
            api_key: Some(api_key.to_string()),
            model: "m".to_string(),
            base_url: Some("http://127.0.0.1:1/".to_string()),
            system_prompt: None,
            user_agent: None,
            context_window: 8000,
            max_tokens: None,
            thinking_type: None,
            thinking_keep: None,
            reasoning_history: None,
            thinking_enabled: None,
            thinking_budget: None,
            skip_tls_verify: false,
            ephemeral: false,
        }
    }

    /// The auth module must resolve to `<home>/.atomcode/auth.toml`.
    #[test]
    fn test_auth_token_path_consistency() {
        let auth_module_path = crate::auth::auth_file_path();
        let home = crate::tool::real_home_dir().unwrap_or_else(|| std::path::PathBuf::from("."));
        let expected_path = home.join(".atomcode").join("auth.toml");
        assert_eq!(
            auth_module_path, expected_path,
            "auth_file_path() should always return ~/.atomcode/auth.toml"
        );

        let has_expected_tail = auth_module_path.ends_with(".atomcode/auth.toml")
            || auth_module_path.ends_with(".atomcode\\auth.toml");
        assert!(
            has_expected_tail,
            "Path should end with .atomcode/auth.toml, got: {}",
            auth_module_path.display()
        );
    }

    /// The stand-in provider exposes its reason and an empty model name.
    #[test]
    fn unavailable_provider_reports_reason() {
        let provider = unavailable_provider("未配置 provider");
        assert_eq!(provider.model_name(), "");
        assert_eq!(provider.availability_error(), Some("未配置 provider"));
    }

    /// A control character inside the key (not at the edge) is rejected.
    #[test]
    fn create_provider_rejects_api_key_with_internal_control_chars() {
        let msg = match super::create_provider(&cfg("openai", "sk-ab\nc")) {
            Err(e) => e.to_string(),
            Ok(_) => panic!("expected Err for api_key with internal \\n"),
        };
        assert!(
            msg.contains("control character"),
            "expected control-char error, got: {}",
            msg
        );
    }

    /// A newline at the end of the key is trimmed rather than rejected.
    #[test]
    fn create_provider_silently_trims_trailing_newline() {
        let result = super::create_provider(&cfg("openai", "sk-abc\n"));
        assert!(
            result.is_ok(),
            "trailing \\n should be trimmed silently, got: {:?}",
            result.err().map(|e| e.to_string())
        );
    }

    /// A key consisting only of whitespace is rejected.
    #[test]
    fn create_provider_rejects_empty_or_whitespace_api_key() {
        let msg = match super::create_provider(&cfg("openai", " ")) {
            Err(e) => e.to_string(),
            Ok(_) => panic!("expected Err for whitespace-only api_key"),
        };
        assert!(
            msg.contains("empty") || msg.contains("whitespace"),
            "expected empty/whitespace error, got: {}",
            msg
        );
    }

    /// Whitespace around an otherwise valid key is trimmed silently.
    #[test]
    fn create_provider_silently_trims_surrounding_whitespace() {
        let result = super::create_provider(&cfg("openai", " sk-abc "));
        assert!(
            result.is_ok(),
            "trimmable key should be accepted, got: {:?}",
            result.err().map(|e| e.to_string())
        );
    }

    /// Names of well-known multimodal models are recognised.
    #[test]
    fn vision_heuristic_recognises_known_vision_models() {
        assert!(model_name_suggests_vision("claude-3-5-sonnet"));
        assert!(model_name_suggests_vision("claude-4-opus"));
        assert!(model_name_suggests_vision("claude-sonnet-4-6"));
        assert!(model_name_suggests_vision("gpt-4o"));
        assert!(model_name_suggests_vision("gpt-4o-mini"));
        assert!(model_name_suggests_vision("gpt-4-vision-preview"));
        assert!(model_name_suggests_vision("GLM-4V"));
        assert!(model_name_suggests_vision("glm-4.1v-thinking"));
        assert!(model_name_suggests_vision("Qwen2-VL-7B"));
        assert!(model_name_suggests_vision("deepseek-vl"));
        assert!(model_name_suggests_vision("gemini-2.0-flash"));
        assert!(model_name_suggests_vision("pixtral-12b"));
        assert!(model_name_suggests_vision("llava-1.6"));
        assert!(model_name_suggests_vision("qvq-72b-preview"));
    }

    /// Text-only model names, and the empty name, are not flagged.
    #[test]
    fn vision_heuristic_rejects_text_only_models() {
        assert!(!model_name_suggests_vision("GLM-5.1"));
        assert!(!model_name_suggests_vision("glm-5.1"));
        assert!(!model_name_suggests_vision("deepseek-v4-flash"));
        assert!(!model_name_suggests_vision("Qwen/Qwen3.6-35B-A3B"));
        assert!(!model_name_suggests_vision("gpt-4-turbo"));
        assert!(!model_name_suggests_vision("kimi-k2-thinking"));
        assert!(!model_name_suggests_vision("o1-preview"));
        assert!(!model_name_suggests_vision(""));
    }

    /// Any name containing `ocr` (case-insensitive) counts as vision-capable.
    #[test]
    fn vision_heuristic_recognises_ocr_models() {
        assert!(model_name_suggests_vision("PaddleOCR-VL-0.9B"));
        assert!(model_name_suggests_vision("Qwen2-VL-OCR-7B"));
        assert!(model_name_suggests_vision("GOT-OCR-2.0"));
        assert!(model_name_suggests_vision("PaddleOCR-2.0"));
        assert!(model_name_suggests_vision("MinerU-OCR"));
        assert!(model_name_suggests_vision("MonkeyOCR-1.2B"));
        assert!(model_name_suggests_vision("got-ocr-1.0"));
    }

    /// A look-alike name that contains no marker must stay unflagged.
    #[test]
    fn vision_heuristic_documented_false_positives() {
        assert!(!model_name_suggests_vision("focar-text-7b"));
    }
}