devsper_providers/azure_openai.rs

use devsper_core::{LlmProvider, LlmRequest, LlmResponse, LlmRole, StopReason};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::debug;

/// Azure OpenAI provider.
///
/// URL format: `{endpoint}/openai/deployments/{deployment}/chat/completions?api-version={api_version}`
/// Auth header: `api-key: {api_key}` (NOT `Authorization: Bearer`)
/// Supports model names prefixed with "azure:" (e.g. "azure:gpt-4o").
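///
/// # Example
///
/// A minimal construction sketch; the key, endpoint, deployment, and
/// api-version values below are all placeholders:
///
/// ```ignore
/// let provider = AzureOpenAiProvider::new(
///     "my-api-key",
///     "https://my-resource.openai.azure.com",
///     "gpt-4o",
///     "2024-06-01",
/// );
/// ```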
pub struct AzureOpenAiProvider {
    client: Client,
    api_key: String,
    endpoint: String,
    deployment: String,
    api_version: String,
}

impl AzureOpenAiProvider {
    pub fn new(
        api_key: impl Into<String>,
        endpoint: impl Into<String>,
        deployment: impl Into<String>,
        api_version: impl Into<String>,
    ) -> Self {
        Self {
            client: Client::new(),
            api_key: api_key.into(),
            endpoint: endpoint.into(),
            deployment: deployment.into(),
            api_version: api_version.into(),
        }
    }
}

#[derive(Serialize)]
struct OaiRequest<'a> {
    model: &'a str,
    messages: Vec<OaiMessage<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}
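
// With `skip_serializing_if`, unset optional fields are omitted entirely; a
// populated request body serializes along these lines (illustrative values):
//
// {
//   "model": "gpt-4o",
//   "messages": [{"role": "user", "content": "Hello"}],
//   "max_tokens": 256,
//   "temperature": 0.2
// }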

#[derive(Serialize)]
struct OaiMessage<'a> {
    role: &'a str,
    content: &'a str,
}

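// The deserialize structs below accept the standard chat-completions response
// shape, abridged here with illustrative values (serde ignores any extra
// fields by default):
//
// {
//   "model": "gpt-4o-2024-08-06",
//   "choices": [{"message": {"content": "Hi!"}, "finish_reason": "stop"}],
//   "usage": {"prompt_tokens": 9, "completion_tokens": 3}
// }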
#[derive(Deserialize)]
struct OaiResponse {
    choices: Vec<OaiChoice>,
    usage: OaiUsage,
    model: String,
}

#[derive(Deserialize)]
struct OaiChoice {
    message: OaiChoiceMessage,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct OaiChoiceMessage {
    content: Option<String>,
}

#[derive(Deserialize)]
struct OaiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

/// Maps core roles onto OpenAI chat roles. Tool results are flattened into
/// "user" messages because `OaiMessage` carries no tool-call metadata.
fn role_str(role: &LlmRole) -> &'static str {
    match role {
        LlmRole::System => "system",
        LlmRole::User | LlmRole::Tool => "user",
        LlmRole::Assistant => "assistant",
    }
}

#[async_trait]
impl LlmProvider for AzureOpenAiProvider {
    async fn generate(&self, req: LlmRequest) -> Result<LlmResponse> {
        use tracing::Instrument;

        // Span fields follow the OpenTelemetry GenAI semantic conventions;
        // the response-side fields start empty and are recorded once the
        // call completes.
        let span = tracing::info_span!(
            "gen_ai.chat",
            "gen_ai.system" = self.name(),
            "gen_ai.operation.name" = "chat",
            "gen_ai.request.model" = req.model.as_str(),
            "gen_ai.request.max_tokens" = req.max_tokens,
            "gen_ai.response.model" = tracing::field::Empty,
            "gen_ai.usage.input_tokens" = tracing::field::Empty,
            "gen_ai.usage.output_tokens" = tracing::field::Empty,
        );

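        // For an illustrative endpoint and deployment, this yields a URL like:
        // https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-06-01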
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            self.endpoint, self.deployment, self.api_version
        );

        let messages: Vec<OaiMessage> = req
            .messages
            .iter()
            .map(|m| OaiMessage {
                role: role_str(&m.role),
                content: &m.content,
            })
            .collect();

        // Azure routes requests by deployment, so the deployment name doubles
        // as the model name in the request body.
        let body = OaiRequest {
            model: &self.deployment,
            messages,
            max_tokens: req.max_tokens,
            temperature: req.temperature,
        };

        debug!(deployment = %self.deployment, provider = "azure-openai", "Azure OpenAI request");

        let result = async {
            let resp = self
                .client
                .post(&url)
                .header("api-key", &self.api_key)
                .header("Content-Type", "application/json")
                .json(&body)
                .send()
                .await?;

            if !resp.status().is_success() {
                let status = resp.status();
                let text = resp.text().await.unwrap_or_default();
                return Err(anyhow!("azure-openai API error {status}: {text}"));
            }

            let data: OaiResponse = resp.json().await?;
            let choice = data
                .choices
                .into_iter()
                .next()
                .ok_or_else(|| anyhow!("No choices in response"))?;

            // "stop", a missing finish_reason, and anything unrecognized
            // (e.g. "content_filter") all map to EndTurn.
            let stop_reason = match choice.finish_reason.as_deref() {
                Some("tool_calls") => StopReason::ToolUse,
                Some("length") => StopReason::MaxTokens,
                _ => StopReason::EndTurn,
            };

            Ok(LlmResponse {
                content: choice.message.content.unwrap_or_default(),
                // Tool-call payloads are not parsed yet; only the stop
                // reason signals that the model asked for tools.
                tool_calls: vec![],
                input_tokens: data.usage.prompt_tokens,
                output_tokens: data.usage.completion_tokens,
                model: data.model,
                stop_reason,
            })
        }
        .instrument(span.clone())
        .await;

        // Record the response-side span fields only on success.
        if let Ok(ref resp) = result {
            span.record("gen_ai.response.model", resp.model.as_str());
            span.record("gen_ai.usage.input_tokens", resp.input_tokens);
            span.record("gen_ai.usage.output_tokens", resp.output_tokens);
        }
        result
    }

    fn name(&self) -> &str {
        "azure-openai"
    }

    fn supports_model(&self, model: &str) -> bool {
        model.starts_with("azure:")
    }
}