Skip to main content

rusty_commit/providers/bedrock.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use aws_config::Region;
4use aws_sdk_bedrockruntime as bedrock;
5use aws_sdk_bedrockruntime::types::{ContentBlock, SystemContentBlock};
6
7use super::prompt::split_prompt;
8use super::AIProvider;
9use crate::config::Config;
10
11pub struct BedrockProvider {
12    client: bedrock::Client,
13    model: String,
14    #[allow(dead_code)]
15    region: String,
16}
17
18impl BedrockProvider {
19    pub fn new(config: &Config) -> Result<Self> {
20        let rt = tokio::runtime::Runtime::new().context("Failed to create runtime")?;
21        rt.block_on(async { Self::new_async(config).await })
22    }
23
24    async fn new_async(config: &Config) -> Result<Self> {
25        let region = config
26            .api_url
27            .as_ref()
28            .and_then(|url| {
29                url.split("bedrock.")
30                    .nth(1)
31                    .and_then(|s| s.split('.').next())
32                    .map(|s| s.to_string())
33            })
34            .unwrap_or_else(|| {
35                std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string())
36            });
37
38        let region_provider = Region::new(region.clone());
39        let shared_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
40            .region(region_provider)
41            .load()
42            .await;
43
44        let client = bedrock::Client::new(&shared_config);
45
46        let model = config
47            .model
48            .as_deref()
49            .unwrap_or("anthropic.claude-3-5-sonnet-20241022-v2:0")
50            .to_string();
51
52        Ok(Self {
53            client,
54            model,
55            region,
56        })
57    }
58
59    #[allow(dead_code)]
60    pub fn from_account(
61        account: &crate::config::accounts::AccountConfig,
62        _api_key: &str,
63        config: &Config,
64    ) -> Result<Self> {
65        let rt = tokio::runtime::Runtime::new().context("Failed to create runtime")?;
66        rt.block_on(async { Self::from_account_async(account, config).await })
67    }
68
69    async fn from_account_async(
70        account: &crate::config::accounts::AccountConfig,
71        config: &Config,
72    ) -> Result<Self> {
73        let region = account
74            .api_url
75            .as_ref()
76            .and_then(|url| {
77                url.split("bedrock.")
78                    .nth(1)
79                    .and_then(|s| s.split('.').next())
80                    .map(|s| s.to_string())
81            })
82            .unwrap_or_else(|| "us-east-1".to_string());
83
84        let region_provider = Region::new(region.clone());
85        let shared_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
86            .region(region_provider)
87            .load()
88            .await;
89
90        let client = bedrock::Client::new(&shared_config);
91
92        let model = account
93            .model
94            .as_deref()
95            .or(config.model.as_deref())
96            .unwrap_or("anthropic.claude-3-5-sonnet-20241022-v2:0")
97            .to_string();
98
99        Ok(Self {
100            client,
101            model,
102            region,
103        })
104    }
105}
106
107#[async_trait]
108impl AIProvider for BedrockProvider {
109    async fn generate_commit_message(
110        &self,
111        diff: &str,
112        context: Option<&str>,
113        full_gitmoji: bool,
114        config: &Config,
115    ) -> Result<String> {
116        let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
117
118        // Build system message
119        let system_block = SystemContentBlock::Text(system_prompt);
120
121        // Build user message
122        let user_content = ContentBlock::Text(user_prompt);
123        let user_message = bedrock::types::Message::builder()
124            .role(bedrock::types::ConversationRole::User)
125            .content(user_content)
126            .build()
127            .context("Failed to build user message")?;
128
129        let inference_config = bedrock::types::InferenceConfiguration::builder()
130            .max_tokens(config.tokens_max_output.unwrap_or(500) as i32)
131            .temperature(0.7)
132            .build();
133
134        let converse_output = self
135            .client
136            .converse()
137            .model_id(&self.model)
138            .messages(user_message)
139            .system(system_block)
140            .inference_config(inference_config)
141            .send()
142            .await
143            .context("Failed to communicate with Bedrock")?;
144
145        let message = converse_output
146            .output()
147            .and_then(|o| o.as_message().ok())
148            .context("No response from Bedrock")?;
149
150        let content = message
151            .content()
152            .first()
153            .and_then(|c| c.as_text().ok())
154            .context("Empty response from Bedrock")?;
155
156        Ok(content.trim().to_string())
157    }
158}
159
160/// ProviderBuilder for Bedrock
161pub struct BedrockProviderBuilder;
162
163impl super::registry::ProviderBuilder for BedrockProviderBuilder {
164    fn name(&self) -> &'static str {
165        "bedrock"
166    }
167
168    fn aliases(&self) -> Vec<&'static str> {
169        vec!["aws-bedrock", "amazon-bedrock"]
170    }
171
172    fn category(&self) -> super::registry::ProviderCategory {
173        super::registry::ProviderCategory::Cloud
174    }
175
176    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
177        Ok(Box::new(BedrockProvider::new(config)?))
178    }
179
180    fn requires_api_key(&self) -> bool {
181        false
182    }
183
184    fn default_model(&self) -> Option<&'static str> {
185        Some("anthropic.claude-3-5-sonnet-20241022-v2:0")
186    }
187}