rusty_commit/providers/
bedrock.rs1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use aws_config::Region;
4use aws_sdk_bedrockruntime as bedrock;
5use aws_sdk_bedrockruntime::types::{ContentBlock, SystemContentBlock};
6
7use super::prompt::split_prompt;
8use super::AIProvider;
9use crate::config::Config;
10
11pub struct BedrockProvider {
12 client: bedrock::Client,
13 model: String,
14 #[allow(dead_code)]
15 region: String,
16}
17
18impl BedrockProvider {
19 pub fn new(config: &Config) -> Result<Self> {
20 let rt = tokio::runtime::Runtime::new().context("Failed to create runtime")?;
21 rt.block_on(async { Self::new_async(config).await })
22 }
23
24 async fn new_async(config: &Config) -> Result<Self> {
25 let region = config
26 .api_url
27 .as_ref()
28 .and_then(|url| {
29 url.split("bedrock.")
30 .nth(1)
31 .and_then(|s| s.split('.').next())
32 .map(|s| s.to_string())
33 })
34 .unwrap_or_else(|| {
35 std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string())
36 });
37
38 let region_provider = Region::new(region.clone());
39 let shared_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
40 .region(region_provider)
41 .load()
42 .await;
43
44 let client = bedrock::Client::new(&shared_config);
45
46 let model = config
47 .model
48 .as_deref()
49 .unwrap_or("anthropic.claude-3-5-sonnet-20241022-v2:0")
50 .to_string();
51
52 Ok(Self {
53 client,
54 model,
55 region,
56 })
57 }
58
59 #[allow(dead_code)]
60 pub fn from_account(
61 account: &crate::config::accounts::AccountConfig,
62 _api_key: &str,
63 config: &Config,
64 ) -> Result<Self> {
65 let rt = tokio::runtime::Runtime::new().context("Failed to create runtime")?;
66 rt.block_on(async { Self::from_account_async(account, config).await })
67 }
68
69 async fn from_account_async(
70 account: &crate::config::accounts::AccountConfig,
71 config: &Config,
72 ) -> Result<Self> {
73 let region = account
74 .api_url
75 .as_ref()
76 .and_then(|url| {
77 url.split("bedrock.")
78 .nth(1)
79 .and_then(|s| s.split('.').next())
80 .map(|s| s.to_string())
81 })
82 .unwrap_or_else(|| "us-east-1".to_string());
83
84 let region_provider = Region::new(region.clone());
85 let shared_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
86 .region(region_provider)
87 .load()
88 .await;
89
90 let client = bedrock::Client::new(&shared_config);
91
92 let model = account
93 .model
94 .as_deref()
95 .or(config.model.as_deref())
96 .unwrap_or("anthropic.claude-3-5-sonnet-20241022-v2:0")
97 .to_string();
98
99 Ok(Self {
100 client,
101 model,
102 region,
103 })
104 }
105}
106
107#[async_trait]
108impl AIProvider for BedrockProvider {
109 async fn generate_commit_message(
110 &self,
111 diff: &str,
112 context: Option<&str>,
113 full_gitmoji: bool,
114 config: &Config,
115 ) -> Result<String> {
116 let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
117
118 let system_block = SystemContentBlock::Text(system_prompt);
120
121 let user_content = ContentBlock::Text(user_prompt);
123 let user_message = bedrock::types::Message::builder()
124 .role(bedrock::types::ConversationRole::User)
125 .content(user_content)
126 .build()
127 .context("Failed to build user message")?;
128
129 let inference_config = bedrock::types::InferenceConfiguration::builder()
130 .max_tokens(config.tokens_max_output.unwrap_or(500) as i32)
131 .temperature(0.7)
132 .build();
133
134 let converse_output = self
135 .client
136 .converse()
137 .model_id(&self.model)
138 .messages(user_message)
139 .system(system_block)
140 .inference_config(inference_config)
141 .send()
142 .await
143 .context("Failed to communicate with Bedrock")?;
144
145 let message = converse_output
146 .output()
147 .and_then(|o| o.as_message().ok())
148 .context("No response from Bedrock")?;
149
150 let content = message
151 .content()
152 .first()
153 .and_then(|c| c.as_text().ok())
154 .context("Empty response from Bedrock")?;
155
156 Ok(content.trim().to_string())
157 }
158}
159
160pub struct BedrockProviderBuilder;
162
163impl super::registry::ProviderBuilder for BedrockProviderBuilder {
164 fn name(&self) -> &'static str {
165 "bedrock"
166 }
167
168 fn aliases(&self) -> Vec<&'static str> {
169 vec!["aws-bedrock", "amazon-bedrock"]
170 }
171
172 fn category(&self) -> super::registry::ProviderCategory {
173 super::registry::ProviderCategory::Cloud
174 }
175
176 fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
177 Ok(Box::new(BedrockProvider::new(config)?))
178 }
179
180 fn requires_api_key(&self) -> bool {
181 false
182 }
183
184 fn default_model(&self) -> Option<&'static str> {
185 Some("anthropic.claude-3-5-sonnet-20241022-v2:0")
186 }
187}