// devsper_providers/azure_openai.rs
use devsper_core::{LlmProvider, LlmRequest, LlmResponse, LlmRole, StopReason};
2use anyhow::{anyhow, Result};
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7
/// LLM provider backed by the Azure OpenAI chat-completions REST API.
///
/// Requests are authenticated with Azure's `api-key` header (not a
/// Bearer token) and routed to a named deployment under the resource
/// endpoint, with the API version passed as a query parameter.
pub struct AzureOpenAiProvider {
    // Shared HTTP client; reused across requests for connection pooling.
    client: Client,
    // Azure API key, sent verbatim in the `api-key` request header.
    api_key: String,
    // Resource base URL; the deployment path is appended to it
    // (presumably "https://<resource>.openai.azure.com" — confirm with config).
    endpoint: String,
    // Deployment name; on Azure this selects the underlying model.
    deployment: String,
    // Value for the `api-version` query-string parameter.
    api_version: String,
}
20
21impl AzureOpenAiProvider {
22 pub fn new(
23 api_key: impl Into<String>,
24 endpoint: impl Into<String>,
25 deployment: impl Into<String>,
26 api_version: impl Into<String>,
27 ) -> Self {
28 Self {
29 client: Client::new(),
30 api_key: api_key.into(),
31 endpoint: endpoint.into(),
32 deployment: deployment.into(),
33 api_version: api_version.into(),
34 }
35 }
36}
37
/// Request body for the OpenAI-style chat-completions endpoint.
///
/// Borrows from the caller's request to avoid copying message text.
#[derive(Serialize)]
struct OaiRequest<'a> {
    model: &'a str,
    messages: Vec<OaiMessage<'a>>,
    // Omitted from the JSON entirely when unset, so the service default applies.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}
47
/// One chat message on the wire: a role string ("system" / "user" /
/// "assistant") plus its text content. Borrowed, not owned.
#[derive(Serialize)]
struct OaiMessage<'a> {
    role: &'a str,
    content: &'a str,
}
53
/// Top-level chat-completions response; only the fields this provider
/// consumes are deserialized — unknown JSON fields are ignored by serde.
#[derive(Deserialize)]
struct OaiResponse {
    choices: Vec<OaiChoice>,
    usage: OaiUsage,
    model: String,
}
60
/// A single completion choice; `finish_reason` drives the `StopReason`
/// mapping and may be absent in the payload.
#[derive(Deserialize)]
struct OaiChoice {
    message: OaiChoiceMessage,
    finish_reason: Option<String>,
}
66
/// Assistant message inside a choice. `content` can be null on the wire
/// (e.g. when the model produced no text), hence `Option`.
#[derive(Deserialize)]
struct OaiChoiceMessage {
    content: Option<String>,
}
71
/// Token accounting reported by the service; mapped to
/// `LlmResponse::{input_tokens, output_tokens}`.
#[derive(Deserialize)]
struct OaiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}
77
78fn role_str(role: &LlmRole) -> &'static str {
79 match role {
80 LlmRole::System => "system",
81 LlmRole::User | LlmRole::Tool => "user",
82 LlmRole::Assistant => "assistant",
83 }
84}
85
86#[async_trait]
87impl LlmProvider for AzureOpenAiProvider {
88 async fn generate(&self, req: LlmRequest) -> Result<LlmResponse> {
89 use tracing::Instrument;
90
91 let span = tracing::info_span!(
92 "gen_ai.chat",
93 "gen_ai.system" = self.name(),
94 "gen_ai.operation.name" = "chat",
95 "gen_ai.request.model" = req.model.as_str(),
96 "gen_ai.request.max_tokens" = req.max_tokens,
97 "gen_ai.response.model" = tracing::field::Empty,
98 "gen_ai.usage.input_tokens" = tracing::field::Empty,
99 "gen_ai.usage.output_tokens" = tracing::field::Empty,
100 );
101
102 let url = format!(
103 "{}/openai/deployments/{}/chat/completions?api-version={}",
104 self.endpoint, self.deployment, self.api_version
105 );
106
107 let messages: Vec<OaiMessage> = req
108 .messages
109 .iter()
110 .map(|m| OaiMessage {
111 role: role_str(&m.role),
112 content: &m.content,
113 })
114 .collect();
115
116 let body = OaiRequest {
118 model: &self.deployment,
119 messages,
120 max_tokens: req.max_tokens,
121 temperature: req.temperature,
122 };
123
124 debug!(deployment = %self.deployment, provider = "azure-openai", "Azure OpenAI request");
125
126 let result = async {
127 let resp = self
128 .client
129 .post(&url)
130 .header("api-key", &self.api_key)
131 .header("Content-Type", "application/json")
132 .json(&body)
133 .send()
134 .await?;
135
136 if !resp.status().is_success() {
137 let status = resp.status();
138 let text = resp.text().await.unwrap_or_default();
139 return Err(anyhow!("azure-openai API error {status}: {text}"));
140 }
141
142 let data: OaiResponse = resp.json().await?;
143 let choice = data
144 .choices
145 .into_iter()
146 .next()
147 .ok_or_else(|| anyhow!("No choices in response"))?;
148
149 let stop_reason = match choice.finish_reason.as_deref() {
150 Some("tool_calls") => StopReason::ToolUse,
151 Some("length") => StopReason::MaxTokens,
152 Some("stop") | None => StopReason::EndTurn,
153 _ => StopReason::EndTurn,
154 };
155
156 Ok(LlmResponse {
157 content: choice.message.content.unwrap_or_default(),
158 tool_calls: vec![],
159 input_tokens: data.usage.prompt_tokens,
160 output_tokens: data.usage.completion_tokens,
161 model: data.model,
162 stop_reason,
163 })
164 }
165 .instrument(span.clone())
166 .await;
167
168 if let Ok(ref resp) = result {
169 span.record("gen_ai.response.model", resp.model.as_str());
170 span.record("gen_ai.usage.input_tokens", resp.input_tokens);
171 span.record("gen_ai.usage.output_tokens", resp.output_tokens);
172 }
173 result
174 }
175
176 fn name(&self) -> &str {
177 "azure-openai"
178 }
179
180 fn supports_model(&self, model: &str) -> bool {
181 model.starts_with("azure:")
182 }
183}