// rusty_commit/providers/azure.rs

use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::{header, Client};
use serde::{Deserialize, Serialize};

use super::{split_prompt, AIProvider};
use crate::config::Config;

9pub struct AzureProvider {
10    client: Client,
11    api_key: String,
12    endpoint: String,
13    deployment: String,
14}
15
16#[derive(Serialize)]
17struct AzureRequest {
18    messages: Vec<Message>,
19    max_tokens: u32,
20    temperature: f32,
21}
22
23#[derive(Serialize)]
24struct Message {
25    role: String,
26    content: String,
27}
28
29#[derive(Deserialize)]
30struct AzureResponse {
31    choices: Vec<Choice>,
32}
33
34#[derive(Deserialize)]
35struct Choice {
36    message: ResponseMessage,
37}
38
39#[derive(Deserialize)]
40struct ResponseMessage {
41    content: String,
42}
43
44impl AzureProvider {
45    pub fn new(config: &Config) -> Result<Self> {
46        let api_key = config
47            .api_key
48            .as_ref()
49            .context("Azure API key not configured. Run: rco config set RCO_API_KEY=<your_key>")?
50            .clone();
51
52        let endpoint = config
53            .api_url
54            .as_ref()
55            .context(
56                "Azure endpoint not configured. Run: rco config set RCO_API_URL=<your_endpoint>",
57            )?
58            .clone();
59
60        let deployment = config
61            .model
62            .as_deref()
63            .unwrap_or("gpt-35-turbo")
64            .to_string();
65
66        let client = Client::new();
67
68        Ok(Self {
69            client,
70            api_key,
71            endpoint,
72            deployment,
73        })
74    }
75
76    /// Create provider from account configuration
77    #[allow(dead_code)]
78    pub fn from_account(
79        account: &crate::config::accounts::AccountConfig,
80        api_key: &str,
81        config: &Config,
82    ) -> Result<Self> {
83        let endpoint = account
84            .api_url
85            .as_ref()
86            .context(
87                "Azure endpoint required. Set with: rco config set RCO_API_URL=<your_endpoint>",
88            )?
89            .clone();
90
91        let deployment = account
92            .model
93            .as_deref()
94            .or(config.model.as_deref())
95            .unwrap_or("gpt-35-turbo")
96            .to_string();
97
98        let client = Client::new();
99
100        Ok(Self {
101            client,
102            api_key: api_key.to_string(),
103            endpoint,
104            deployment,
105        })
106    }
107}
108
109#[async_trait]
110impl AIProvider for AzureProvider {
111    async fn generate_commit_message(
112        &self,
113        diff: &str,
114        context: Option<&str>,
115        full_gitmoji: bool,
116        config: &Config,
117    ) -> Result<String> {
118        let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);
119
120        let request = AzureRequest {
121            messages: vec![
122                Message {
123                    role: "system".to_string(),
124                    content: system_prompt,
125                },
126                Message {
127                    role: "user".to_string(),
128                    content: user_prompt,
129                },
130            ],
131            max_tokens: config.tokens_max_output.unwrap_or(500),
132            temperature: 0.7,
133        };
134
135        let url = format!(
136            "{}/openai/deployments/{}/chat/completions?api-version=2024-02-01",
137            self.endpoint, self.deployment
138        );
139
140        let response = self
141            .client
142            .post(&url)
143            .header("api-key", &self.api_key)
144            .header(header::CONTENT_TYPE, "application/json")
145            .json(&request)
146            .send()
147            .await
148            .context("Failed to connect to Azure OpenAI")?;
149
150        if !response.status().is_success() {
151            let error_text = response.text().await?;
152            anyhow::bail!("Azure OpenAI API error: {}", error_text);
153        }
154
155        let azure_response: AzureResponse = response
156            .json()
157            .await
158            .context("Failed to parse Azure OpenAI response")?;
159
160        let message = azure_response
161            .choices
162            .first()
163            .map(|c| c.message.content.trim().to_string())
164            .context("No response from Azure OpenAI")?;
165
166        Ok(message)
167    }
168}
169
/// Provider-registry builder that constructs [`AzureProvider`] instances
/// (registered as `azure`, alias `azure-openai`).
pub struct AzureProviderBuilder;

173impl super::registry::ProviderBuilder for AzureProviderBuilder {
174    fn name(&self) -> &'static str {
175        "azure"
176    }
177
178    fn aliases(&self) -> Vec<&'static str> {
179        vec!["azure-openai"]
180    }
181
182    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
183        Ok(Box::new(AzureProvider::new(config)?))
184    }
185
186    fn requires_api_key(&self) -> bool {
187        true
188    }
189
190    fn default_model(&self) -> Option<&'static str> {
191        Some("gpt-4o")
192    }
193}