rusty_commit/providers/
azure.rs1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::{header, Client};
4use serde::{Deserialize, Serialize};
5
6use super::{split_prompt, AIProvider};
7use crate::config::Config;
8
/// AI provider backed by an Azure OpenAI chat-completion deployment.
pub struct AzureProvider {
    /// Shared HTTP client reused for all requests.
    client: Client,
    /// Azure OpenAI API key, sent via the `api-key` request header.
    api_key: String,
    /// Base resource endpoint, e.g. `https://<resource>.openai.azure.com`.
    endpoint: String,
    /// Deployment name inserted into the request path (selects the model).
    deployment: String,
}
15
/// Request body for the Azure OpenAI chat-completions endpoint.
/// Field names match the wire format expected by the API.
#[derive(Serialize)]
struct AzureRequest {
    /// Conversation history: a system prompt followed by the user prompt.
    messages: Vec<Message>,
    /// Upper bound on tokens generated in the completion.
    max_tokens: u32,
    /// Sampling temperature for the completion.
    temperature: f32,
}
22
/// A single chat message in the request payload.
#[derive(Serialize)]
struct Message {
    /// Message role: `"system"` or `"user"` as used by this provider.
    role: String,
    /// Message text content.
    content: String,
}
28
/// Top-level response body from the chat-completions endpoint.
/// Only the fields this module reads are deserialized.
#[derive(Deserialize)]
struct AzureResponse {
    /// Candidate completions; the first one is used.
    choices: Vec<Choice>,
}
33
/// One completion candidate in the response.
#[derive(Deserialize)]
struct Choice {
    /// The assistant message produced for this choice.
    message: ResponseMessage,
}
38
/// The assistant message inside a completion choice.
#[derive(Deserialize)]
struct ResponseMessage {
    /// Generated text content (the commit message).
    content: String,
}
43
44impl AzureProvider {
45 pub fn new(config: &Config) -> Result<Self> {
46 let api_key = config
47 .api_key
48 .as_ref()
49 .context("Azure API key not configured. Run: rco config set RCO_API_KEY=<your_key>")?
50 .clone();
51
52 let endpoint = config
53 .api_url
54 .as_ref()
55 .context(
56 "Azure endpoint not configured. Run: rco config set RCO_API_URL=<your_endpoint>",
57 )?
58 .clone();
59
60 let deployment = config
61 .model
62 .as_deref()
63 .unwrap_or("gpt-35-turbo")
64 .to_string();
65
66 let client = Client::new();
67
68 Ok(Self {
69 client,
70 api_key,
71 endpoint,
72 deployment,
73 })
74 }
75
76 #[allow(dead_code)]
78 pub fn from_account(
79 account: &crate::config::accounts::AccountConfig,
80 api_key: &str,
81 config: &Config,
82 ) -> Result<Self> {
83 let endpoint = account
84 .api_url
85 .as_ref()
86 .context(
87 "Azure endpoint required. Set with: rco config set RCO_API_URL=<your_endpoint>",
88 )?
89 .clone();
90
91 let deployment = account
92 .model
93 .as_deref()
94 .or(config.model.as_deref())
95 .unwrap_or("gpt-35-turbo")
96 .to_string();
97
98 let client = Client::new();
99
100 Ok(Self {
101 client,
102 api_key: api_key.to_string(),
103 endpoint,
104 deployment,
105 })
106 }
107}
108
109#[async_trait]
110impl AIProvider for AzureProvider {
111 async fn generate_commit_message(
112 &self,
113 diff: &str,
114 context: Option<&str>,
115 full_gitmoji: bool,
116 config: &Config,
117 ) -> Result<String> {
118 let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
119
120 let request = AzureRequest {
121 messages: vec![
122 Message {
123 role: "system".to_string(),
124 content: system_prompt,
125 },
126 Message {
127 role: "user".to_string(),
128 content: user_prompt,
129 },
130 ],
131 max_tokens: config.tokens_max_output.unwrap_or(500),
132 temperature: 0.7,
133 };
134
135 let url = format!(
136 "{}/openai/deployments/{}/chat/completions?api-version=2024-02-01",
137 self.endpoint, self.deployment
138 );
139
140 let response = self
141 .client
142 .post(&url)
143 .header("api-key", &self.api_key)
144 .header(header::CONTENT_TYPE, "application/json")
145 .json(&request)
146 .send()
147 .await
148 .context("Failed to connect to Azure OpenAI")?;
149
150 if !response.status().is_success() {
151 let error_text = response.text().await?;
152 anyhow::bail!("Azure OpenAI API error: {}", error_text);
153 }
154
155 let azure_response: AzureResponse = response
156 .json()
157 .await
158 .context("Failed to parse Azure OpenAI response")?;
159
160 let message = azure_response
161 .choices
162 .first()
163 .map(|c| c.message.content.trim().to_string())
164 .context("No response from Azure OpenAI")?;
165
166 Ok(message)
167 }
168}
169
170pub struct AzureProviderBuilder;
172
impl super::registry::ProviderBuilder for AzureProviderBuilder {
    /// Canonical provider name used in configuration.
    fn name(&self) -> &'static str {
        "azure"
    }

    /// Alternative names accepted for this provider.
    fn aliases(&self) -> Vec<&'static str> {
        vec!["azure-openai"]
    }

    /// Instantiates the provider; fails if the key or endpoint is missing.
    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
        Ok(Box::new(AzureProvider::new(config)?))
    }

    /// Azure OpenAI always requires an API key.
    fn requires_api_key(&self) -> bool {
        true
    }

    // NOTE(review): this advertises "gpt-4o" while `AzureProvider::new`
    // falls back to "gpt-35-turbo" when no model is configured — confirm
    // which default deployment is actually intended.
    fn default_model(&self) -> Option<&'static str> {
        Some("gpt-4o")
    }
}