// devsper_providers/azure_foundry.rs

1use devsper_core::{LlmProvider, LlmRequest, LlmResponse, LlmRole, StopReason};
2use anyhow::{anyhow, Result};
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7
/// Query-string `api-version` value appended to every request URL.
const API_VERSION: &str = "2024-05-01-preview";
9
/// Azure AI Foundry provider.
///
/// URL format: `{endpoint}/models/chat/completions?api-version=2024-05-01-preview`
/// Auth header: `api-key: {api_key}`
/// Supports model names prefixed with "foundry:" (e.g. "foundry:phi-4").
pub struct AzureFoundryProvider {
    /// Shared HTTP client, reused across requests.
    client: Client,
    /// Sent as the `api-key` header on every call.
    api_key: String,
    /// Base URL of the Foundry resource; the completions path is appended to it.
    endpoint: String,
    /// Deployment name; sent as the `model` field in the request body.
    deployment: String,
}
21
22impl AzureFoundryProvider {
23    pub fn new(
24        api_key: impl Into<String>,
25        endpoint: impl Into<String>,
26        deployment: impl Into<String>,
27    ) -> Self {
28        Self {
29            client: Client::new(),
30            api_key: api_key.into(),
31            endpoint: endpoint.into(),
32            deployment: deployment.into(),
33        }
34    }
35}
36
/// JSON body for the OpenAI-compatible chat-completions endpoint.
/// Borrows from the caller's request to avoid copying message contents.
#[derive(Serialize)]
struct OaiRequest<'a> {
    // For Azure Foundry this carries the *deployment* name, not a model id
    // (see the request construction in `generate`).
    model: &'a str,
    messages: Vec<OaiMessage<'a>>,
    // Omitted from the payload when unset so the service default applies.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}

/// One chat message in the outgoing request.
#[derive(Serialize)]
struct OaiMessage<'a> {
    // Wire-format role string ("system"/"user"/"assistant") — see `role_str`.
    role: &'a str,
    content: &'a str,
}

/// Top-level chat-completions response. Only the fields this provider
/// consumes are modeled; serde ignores unknown fields by default.
#[derive(Deserialize)]
struct OaiResponse {
    choices: Vec<OaiChoice>,
    // NOTE(review): `usage` is required here; if the service can ever omit it,
    // deserialization of the whole response fails — confirm against the API.
    usage: OaiUsage,
    model: String,
}

/// A single completion choice.
#[derive(Deserialize)]
struct OaiChoice {
    message: OaiChoiceMessage,
    // Absent or unrecognized values fall back to `StopReason::EndTurn`
    // in `generate`.
    finish_reason: Option<String>,
}

/// Assistant message payload within a choice. `content` is `Option` so a
/// missing/null content deserializes cleanly (replaced with "" downstream).
#[derive(Deserialize)]
struct OaiChoiceMessage {
    content: Option<String>,
}

/// Token accounting reported by the service.
#[derive(Deserialize)]
struct OaiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}
76
77fn role_str(role: &LlmRole) -> &'static str {
78    match role {
79        LlmRole::System => "system",
80        LlmRole::User | LlmRole::Tool => "user",
81        LlmRole::Assistant => "assistant",
82    }
83}
84
85#[async_trait]
86impl LlmProvider for AzureFoundryProvider {
87    async fn generate(&self, req: LlmRequest) -> Result<LlmResponse> {
88        use tracing::Instrument;
89
90        let span = tracing::info_span!(
91            "gen_ai.chat",
92            "gen_ai.system" = self.name(),
93            "gen_ai.operation.name" = "chat",
94            "gen_ai.request.model" = req.model.as_str(),
95            "gen_ai.request.max_tokens" = req.max_tokens,
96            "gen_ai.response.model" = tracing::field::Empty,
97            "gen_ai.usage.input_tokens" = tracing::field::Empty,
98            "gen_ai.usage.output_tokens" = tracing::field::Empty,
99        );
100
101        let url = format!(
102            "{}/models/chat/completions?api-version={API_VERSION}",
103            self.endpoint
104        );
105
106        let messages: Vec<OaiMessage> = req
107            .messages
108            .iter()
109            .map(|m| OaiMessage {
110                role: role_str(&m.role),
111                content: &m.content,
112            })
113            .collect();
114
115        // Deployment name is sent as the model in the request body
116        let body = OaiRequest {
117            model: &self.deployment,
118            messages,
119            max_tokens: req.max_tokens,
120            temperature: req.temperature,
121        };
122
123        debug!(deployment = %self.deployment, provider = "azure-foundry", "Azure AI Foundry request");
124
125        let result = async {
126            let resp = self
127                .client
128                .post(&url)
129                .header("api-key", &self.api_key)
130                .header("Content-Type", "application/json")
131                .json(&body)
132                .send()
133                .await?;
134
135            if !resp.status().is_success() {
136                let status = resp.status();
137                let text = resp.text().await.unwrap_or_default();
138                return Err(anyhow!("azure-foundry API error {status}: {text}"));
139            }
140
141            let data: OaiResponse = resp.json().await?;
142            let choice = data
143                .choices
144                .into_iter()
145                .next()
146                .ok_or_else(|| anyhow!("No choices in response"))?;
147
148            let stop_reason = match choice.finish_reason.as_deref() {
149                Some("tool_calls") => StopReason::ToolUse,
150                Some("length") => StopReason::MaxTokens,
151                Some("stop") | None => StopReason::EndTurn,
152                _ => StopReason::EndTurn,
153            };
154
155            Ok(LlmResponse {
156                content: choice.message.content.unwrap_or_default(),
157                tool_calls: vec![],
158                input_tokens: data.usage.prompt_tokens,
159                output_tokens: data.usage.completion_tokens,
160                model: data.model,
161                stop_reason,
162            })
163        }
164        .instrument(span.clone())
165        .await;
166
167        if let Ok(ref resp) = result {
168            span.record("gen_ai.response.model", resp.model.as_str());
169            span.record("gen_ai.usage.input_tokens", resp.input_tokens);
170            span.record("gen_ai.usage.output_tokens", resp.output_tokens);
171        }
172        result
173    }
174
175    fn name(&self) -> &str {
176        "azure-foundry"
177    }
178
179    fn supports_model(&self, model: &str) -> bool {
180        model.starts_with("foundry:")
181    }
182}