// devsper_providers/azure_foundry.rs
use devsper_core::{LlmProvider, LlmRequest, LlmResponse, LlmRole, StopReason};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::debug;

const API_VERSION: &str = "2024-05-01-preview";

10pub struct AzureFoundryProvider {
16 client: Client,
17 api_key: String,
18 endpoint: String,
19 deployment: String,
20}
21
22impl AzureFoundryProvider {
23 pub fn new(
24 api_key: impl Into<String>,
25 endpoint: impl Into<String>,
26 deployment: impl Into<String>,
27 ) -> Self {
28 Self {
29 client: Client::new(),
30 api_key: api_key.into(),
31 endpoint: endpoint.into(),
32 deployment: deployment.into(),
33 }
34 }
35}
36
37#[derive(Serialize)]
38struct OaiRequest<'a> {
39 model: &'a str,
40 messages: Vec<OaiMessage<'a>>,
41 #[serde(skip_serializing_if = "Option::is_none")]
42 max_tokens: Option<u32>,
43 #[serde(skip_serializing_if = "Option::is_none")]
44 temperature: Option<f32>,
45}
46
47#[derive(Serialize)]
48struct OaiMessage<'a> {
49 role: &'a str,
50 content: &'a str,
51}
52
53#[derive(Deserialize)]
54struct OaiResponse {
55 choices: Vec<OaiChoice>,
56 usage: OaiUsage,
57 model: String,
58}
59
60#[derive(Deserialize)]
61struct OaiChoice {
62 message: OaiChoiceMessage,
63 finish_reason: Option<String>,
64}
65
66#[derive(Deserialize)]
67struct OaiChoiceMessage {
68 content: Option<String>,
69}
70
71#[derive(Deserialize)]
72struct OaiUsage {
73 prompt_tokens: u32,
74 completion_tokens: u32,
75}
76
77fn role_str(role: &LlmRole) -> &'static str {
78 match role {
79 LlmRole::System => "system",
80 LlmRole::User | LlmRole::Tool => "user",
81 LlmRole::Assistant => "assistant",
82 }
83}
84
85#[async_trait]
86impl LlmProvider for AzureFoundryProvider {
87 async fn generate(&self, req: LlmRequest) -> Result<LlmResponse> {
88 use tracing::Instrument;
89
90 let span = tracing::info_span!(
91 "gen_ai.chat",
92 "gen_ai.system" = self.name(),
93 "gen_ai.operation.name" = "chat",
94 "gen_ai.request.model" = req.model.as_str(),
95 "gen_ai.request.max_tokens" = req.max_tokens,
96 "gen_ai.response.model" = tracing::field::Empty,
97 "gen_ai.usage.input_tokens" = tracing::field::Empty,
98 "gen_ai.usage.output_tokens" = tracing::field::Empty,
99 );
100
101 let url = format!(
102 "{}/models/chat/completions?api-version={API_VERSION}",
103 self.endpoint
104 );
105
106 let messages: Vec<OaiMessage> = req
107 .messages
108 .iter()
109 .map(|m| OaiMessage {
110 role: role_str(&m.role),
111 content: &m.content,
112 })
113 .collect();
114
115 let body = OaiRequest {
117 model: &self.deployment,
118 messages,
119 max_tokens: req.max_tokens,
120 temperature: req.temperature,
121 };
122
123 debug!(deployment = %self.deployment, provider = "azure-foundry", "Azure AI Foundry request");
124
125 let result = async {
126 let resp = self
127 .client
128 .post(&url)
129 .header("api-key", &self.api_key)
130 .header("Content-Type", "application/json")
131 .json(&body)
132 .send()
133 .await?;
134
135 if !resp.status().is_success() {
136 let status = resp.status();
137 let text = resp.text().await.unwrap_or_default();
138 return Err(anyhow!("azure-foundry API error {status}: {text}"));
139 }
140
141 let data: OaiResponse = resp.json().await?;
142 let choice = data
143 .choices
144 .into_iter()
145 .next()
146 .ok_or_else(|| anyhow!("No choices in response"))?;
147
148 let stop_reason = match choice.finish_reason.as_deref() {
149 Some("tool_calls") => StopReason::ToolUse,
150 Some("length") => StopReason::MaxTokens,
151 Some("stop") | None => StopReason::EndTurn,
152 _ => StopReason::EndTurn,
153 };
154
155 Ok(LlmResponse {
156 content: choice.message.content.unwrap_or_default(),
157 tool_calls: vec![],
158 input_tokens: data.usage.prompt_tokens,
159 output_tokens: data.usage.completion_tokens,
160 model: data.model,
161 stop_reason,
162 })
163 }
164 .instrument(span.clone())
165 .await;
166
167 if let Ok(ref resp) = result {
168 span.record("gen_ai.response.model", resp.model.as_str());
169 span.record("gen_ai.usage.input_tokens", resp.input_tokens);
170 span.record("gen_ai.usage.output_tokens", resp.output_tokens);
171 }
172 result
173 }
174
175 fn name(&self) -> &str {
176 "azure-foundry"
177 }
178
179 fn supports_model(&self, model: &str) -> bool {
180 model.starts_with("foundry:")
181 }
182}