omni_dev/claude/
client.rs1use crate::claude::{ai_client::AiClient, error::ClaudeError, prompts};
4use crate::claude::{bedrock_ai_client::BedrockAiClient, claude_ai_client::ClaudeAiClient};
5use crate::data::{
6 amendments::AmendmentFile, context::CommitContext, RepositoryView, RepositoryViewForAI,
7};
8use anyhow::{Context, Result};
9use tracing::debug;
10
11pub struct ClaudeClient {
13 ai_client: Box<dyn AiClient>,
15}
16
17impl ClaudeClient {
18 pub fn new(ai_client: Box<dyn AiClient>) -> Self {
20 Self { ai_client }
21 }
22
23 pub fn from_env(model: String) -> Result<Self> {
25 let api_key = std::env::var("CLAUDE_API_KEY")
27 .or_else(|_| std::env::var("ANTHROPIC_API_KEY"))
28 .map_err(|_| ClaudeError::ApiKeyNotFound)?;
29
30 let ai_client = ClaudeAiClient::new(model, api_key);
31 Ok(Self::new(Box::new(ai_client)))
32 }
33
34 pub async fn generate_amendments(&self, repo_view: &RepositoryView) -> Result<AmendmentFile> {
36 let ai_repo_view = RepositoryViewForAI::from_repository_view(repo_view.clone())
38 .context("Failed to enhance repository view with diff content")?;
39
40 let repo_yaml = crate::data::to_yaml(&ai_repo_view)
42 .context("Failed to serialize repository view to YAML")?;
43
44 let user_prompt = prompts::generate_user_prompt(&repo_yaml);
46
47 let content = self
49 .ai_client
50 .send_request(prompts::SYSTEM_PROMPT, &user_prompt)
51 .await?;
52
53 self.parse_amendment_response(&content)
55 }
56
57 pub async fn generate_contextual_amendments(
59 &self,
60 repo_view: &RepositoryView,
61 context: &CommitContext,
62 ) -> Result<AmendmentFile> {
63 let ai_repo_view = RepositoryViewForAI::from_repository_view(repo_view.clone())
65 .context("Failed to enhance repository view with diff content")?;
66
67 let repo_yaml = crate::data::to_yaml(&ai_repo_view)
69 .context("Failed to serialize repository view to YAML")?;
70
71 let system_prompt = prompts::generate_contextual_system_prompt(context);
73 let user_prompt = prompts::generate_contextual_user_prompt(&repo_yaml, context);
74
75 match &context.project.commit_guidelines {
77 Some(guidelines) => {
78 debug!(length = guidelines.len(), "Project commit guidelines found");
79 debug!(guidelines = %guidelines, "Commit guidelines content");
80 }
81 None => {
82 debug!("No project commit guidelines found");
83 }
84 }
85
86 let content = self
88 .ai_client
89 .send_request(&system_prompt, &user_prompt)
90 .await?;
91
92 self.parse_amendment_response(&content)
94 }
95
96 fn parse_amendment_response(&self, content: &str) -> Result<AmendmentFile> {
98 let yaml_content = if content.contains("```yaml") {
100 content
101 .split("```yaml")
102 .nth(1)
103 .and_then(|s| s.split("```").next())
104 .unwrap_or(content)
105 .trim()
106 } else if content.contains("```") {
107 content
109 .split("```")
110 .nth(1)
111 .and_then(|s| s.split("```").next())
112 .unwrap_or(content)
113 .trim()
114 } else {
115 content.trim()
116 };
117
118 let amendment_file: AmendmentFile = serde_yaml::from_str(yaml_content).map_err(|e| {
120 debug!(
121 error = %e,
122 content_length = content.len(),
123 yaml_length = yaml_content.len(),
124 "YAML parsing failed"
125 );
126 debug!(content = %content, "Raw Claude response");
127 debug!(yaml = %yaml_content, "Extracted YAML content");
128
129 if yaml_content.lines().any(|line| line.contains('\t')) {
131 ClaudeError::AmendmentParsingFailed("YAML parsing error: Found tab characters. YAML requires spaces for indentation.".to_string())
132 } else if yaml_content.lines().any(|line| line.trim().starts_with('-') && !line.trim().starts_with("- ")) {
133 ClaudeError::AmendmentParsingFailed("YAML parsing error: List items must have a space after the dash (- item).".to_string())
134 } else {
135 ClaudeError::AmendmentParsingFailed(format!("YAML parsing error: {}", e))
136 }
137 })?;
138
139 amendment_file
141 .validate()
142 .map_err(|e| ClaudeError::AmendmentParsingFailed(format!("Validation error: {}", e)))?;
143
144 Ok(amendment_file)
145 }
146}
147
148pub fn create_default_claude_client(model: Option<String>) -> Result<ClaudeClient> {
150 use crate::utils::settings::{get_env_var, get_env_vars};
151
152 let use_bedrock = get_env_var("CLAUDE_CODE_USE_BEDROCK")
154 .map(|val| val == "true")
155 .unwrap_or(false);
156
157 let model = model
159 .or_else(|| get_env_var("ANTHROPIC_MODEL").ok())
160 .unwrap_or_else(|| "claude-3-haiku-20240307".to_string());
161
162 if use_bedrock {
163 let skip_bedrock_auth = get_env_var("CLAUDE_CODE_SKIP_BEDROCK_AUTH")
165 .map(|val| val == "true")
166 .unwrap_or(false);
167
168 if skip_bedrock_auth {
169 let auth_token =
171 get_env_var("ANTHROPIC_AUTH_TOKEN").map_err(|_| ClaudeError::ApiKeyNotFound)?;
172
173 let base_url = get_env_var("ANTHROPIC_BEDROCK_BASE_URL")
174 .map_err(|_| ClaudeError::ApiKeyNotFound)?;
175
176 let ai_client = BedrockAiClient::new(model, auth_token, base_url);
177 return Ok(ClaudeClient::new(Box::new(ai_client)));
178 }
179 }
180
181 let api_key = get_env_vars(&[
183 "CLAUDE_API_KEY",
184 "ANTHROPIC_API_KEY",
185 "ANTHROPIC_AUTH_TOKEN",
186 ])
187 .map_err(|_| ClaudeError::ApiKeyNotFound)?;
188
189 let ai_client = ClaudeAiClient::new(model, api_key);
190 Ok(ClaudeClient::new(Box::new(ai_client)))
191}