// daimon_provider_azure — crate root (lib.rs)
//! Azure OpenAI model provider for the [Daimon](https://docs.rs/daimon) agent framework.
//!
//! The API wire format is identical to OpenAI but uses a different URL
//! structure and supports both API key and Microsoft Entra ID (Azure AD)
//! bearer token authentication.
//!
//! # Example
//!
//! ```ignore
//! use daimon_provider_azure::AzureOpenAi;
//! use daimon_core::Model;
//!
//! let model = AzureOpenAi::new(
//!     "https://my-resource.openai.azure.com",
//!     "gpt-4o",
//! );
//! ```

use std::time::Duration;

use daimon_core::{
    ChatRequest, ChatResponse, DaimonError, Message, Model, ResponseStream, Result, Role,
    StopReason, StreamEvent, ToolCall, ToolSpec, Usage,
};
use reqwest::Client;
use serde::{Deserialize, Serialize};

mod embedding;

#[cfg(feature = "servicebus")]
pub mod servicebus;

pub use embedding::AzureOpenAiEmbedding;

#[cfg(feature = "servicebus")]
pub use servicebus::ServiceBusBroker;

// Default `api-version` query parameter (a GA release) used when none is set.
const DEFAULT_API_VERSION: &str = "2024-10-21";
// Default number of retry attempts for 429 / 5xx responses in `generate`.
const DEFAULT_MAX_RETRIES: u32 = 3;

42fn build_client(timeout: Option<Duration>) -> Client {
43    let mut builder = Client::builder();
44    if let Some(t) = timeout {
45        builder = builder.timeout(t);
46    }
47    builder.build().expect("failed to build HTTP client")
48}
49
/// Azure OpenAI model provider.
///
/// Connects to an Azure OpenAI deployment. Authentication is via API key
/// (default, using the `api-key` header) or Microsoft Entra ID bearer token.
#[derive(Debug)]
pub struct AzureOpenAi {
    // Reused across requests; rebuilt when a timeout is configured.
    client: Client,
    // API key — or bearer token when `use_bearer_token` is set.
    api_key: String,
    // Base resource URL with any trailing '/' stripped at construction.
    resource_url: String,
    deployment_id: String,
    // Sent as the `api-version` query parameter on every request.
    api_version: String,
    timeout: Option<Duration>,
    // Retry budget for 429 / 5xx responses in `generate` (not streaming).
    max_retries: u32,
    // When true, send `Authorization: Bearer <key>` instead of `api-key`.
    use_bearer_token: bool,
}

66impl AzureOpenAi {
67    /// Create a new Azure OpenAI client, reading `AZURE_OPENAI_API_KEY` from the environment.
68    pub fn new(resource_url: impl Into<String>, deployment_id: impl Into<String>) -> Self {
69        let api_key = std::env::var("AZURE_OPENAI_API_KEY").unwrap_or_default();
70        Self::with_api_key(resource_url, deployment_id, api_key)
71    }
72
73    /// Create a new Azure OpenAI client with an explicit API key.
74    pub fn with_api_key(
75        resource_url: impl Into<String>,
76        deployment_id: impl Into<String>,
77        api_key: impl Into<String>,
78    ) -> Self {
79        Self {
80            client: build_client(None),
81            api_key: api_key.into(),
82            resource_url: resource_url.into().trim_end_matches('/').to_string(),
83            deployment_id: deployment_id.into(),
84            api_version: DEFAULT_API_VERSION.to_string(),
85            timeout: None,
86            max_retries: DEFAULT_MAX_RETRIES,
87            use_bearer_token: false,
88        }
89    }
90
91    /// Set the Azure OpenAI API version (default: `2024-10-21`).
92    pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
93        self.api_version = version.into();
94        self
95    }
96
97    /// Set an HTTP timeout for requests.
98    pub fn with_timeout(mut self, timeout: Duration) -> Self {
99        self.timeout = Some(timeout);
100        self.client = build_client(Some(timeout));
101        self
102    }
103
104    /// Set the maximum number of retries for transient errors.
105    pub fn with_max_retries(mut self, retries: u32) -> Self {
106        self.max_retries = retries;
107        self
108    }
109
110    /// Use `Authorization: Bearer <token>` instead of `api-key` header.
111    ///
112    /// Required for Microsoft Entra ID (Azure AD) authentication.
113    pub fn with_bearer_token(mut self) -> Self {
114        self.use_bearer_token = true;
115        self
116    }
117
118    fn endpoint_url(&self) -> String {
119        format!(
120            "{}/openai/deployments/{}/chat/completions",
121            self.resource_url, self.deployment_id
122        )
123    }
124
125    fn apply_auth(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
126        if self.use_bearer_token {
127            req.bearer_auth(&self.api_key)
128        } else {
129            req.header("api-key", &self.api_key)
130        }
131    }
132
133    fn build_request_body(&self, request: &ChatRequest, stream: bool) -> AzureRequest {
134        let messages: Vec<AzureMessage> = request.messages.iter().map(Into::into).collect();
135
136        let tools: Option<Vec<AzureTool>> = if request.tools.is_empty() {
137            None
138        } else {
139            Some(request.tools.iter().map(Into::into).collect())
140        };
141
142        AzureRequest {
143            messages,
144            tools,
145            temperature: request.temperature,
146            max_tokens: request.max_tokens,
147            stream,
148        }
149    }
150}
151
152impl Model for AzureOpenAi {
153    #[tracing::instrument(skip_all, fields(deployment = %self.deployment_id))]
154    async fn generate(&self, request: &ChatRequest) -> Result<ChatResponse> {
155        let body = self.build_request_body(request, false);
156        let url = self.endpoint_url();
157
158        for attempt in 0..=self.max_retries {
159            let req = self
160                .client
161                .post(&url)
162                .query(&[("api-version", &self.api_version)])
163                .json(&body);
164            let req = self.apply_auth(req);
165
166            tracing::debug!(attempt, "sending Azure OpenAI request");
167            let response = req
168                .send()
169                .await
170                .map_err(|e| DaimonError::Model(format!("Azure OpenAI HTTP error: {e}")))?;
171            let status = response.status();
172
173            if status.is_success() {
174                let api_resp: AzureResponse = response
175                    .json()
176                    .await
177                    .map_err(|e| {
178                        DaimonError::Model(format!("Azure OpenAI response parse error: {e}"))
179                    })?;
180                tracing::debug!("received successful Azure OpenAI response");
181                return parse_response(api_resp);
182            }
183
184            let text = response.text().await.unwrap_or_default();
185            let is_retryable = status.as_u16() == 429 || status.is_server_error();
186
187            if is_retryable && attempt < self.max_retries {
188                let delay_ms = 100 * 2u64.pow(attempt);
189                tracing::debug!(status = %status, attempt, delay_ms, "retryable error, backing off");
190                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
191            } else {
192                return Err(DaimonError::Model(format!(
193                    "Azure OpenAI API error ({status}): {text}"
194                )));
195            }
196        }
197
198        unreachable!("loop always returns or retries")
199    }
200
201    #[tracing::instrument(skip_all, fields(deployment = %self.deployment_id))]
202    async fn generate_stream(&self, request: &ChatRequest) -> Result<ResponseStream> {
203        let body = self.build_request_body(request, true);
204        let url = self.endpoint_url();
205
206        let req = self
207            .client
208            .post(&url)
209            .query(&[("api-version", &self.api_version)])
210            .json(&body);
211        let req = self.apply_auth(req);
212
213        tracing::debug!("sending Azure OpenAI streaming request");
214        let response = req
215            .send()
216            .await
217            .map_err(|e| DaimonError::Model(format!("Azure OpenAI HTTP error: {e}")))?;
218
219        if !response.status().is_success() {
220            let status = response.status();
221            let text = response.text().await.unwrap_or_default();
222            return Err(DaimonError::Model(format!(
223                "Azure OpenAI API error ({status}): {text}"
224            )));
225        }
226
227        tracing::debug!("Azure OpenAI stream established");
228        let byte_stream = response.bytes_stream();
229
230        let stream = async_stream::try_stream! {
231            use futures::StreamExt;
232
233            let mut buffer = String::new();
234            let mut stream = Box::pin(byte_stream);
235
236            while let Some(chunk) = stream.next().await {
237                let chunk = chunk.map_err(|e| DaimonError::Model(format!("Azure OpenAI stream error: {e}")))?;
238                buffer.push_str(&String::from_utf8_lossy(&chunk));
239
240                while let Some(line_end) = buffer.find('\n') {
241                    let line = buffer[..line_end].trim().to_string();
242                    buffer = buffer[line_end + 1..].to_string();
243
244                    if line.is_empty() || line == "data: [DONE]" {
245                        if line == "data: [DONE]" {
246                            yield StreamEvent::Done;
247                        }
248                        continue;
249                    }
250
251                    if let Some(data) = line.strip_prefix("data: ") {
252                        if let Ok(chunk) = serde_json::from_str::<AzureStreamChunk>(data) {
253                            for choice in &chunk.choices {
254                                if let Some(ref content) = choice.delta.content {
255                                    if !content.is_empty() {
256                                        yield StreamEvent::TextDelta(content.clone());
257                                    }
258                                }
259                                if let Some(ref tool_calls) = choice.delta.tool_calls {
260                                    for tc in tool_calls {
261                                        if let Some(ref func) = tc.function {
262                                            if let Some(ref name) = func.name {
263                                                yield StreamEvent::ToolCallStart {
264                                                    id: tc.index.to_string(),
265                                                    name: name.clone(),
266                                                };
267                                            }
268                                            if let Some(ref args) = func.arguments {
269                                                if !args.is_empty() {
270                                                    yield StreamEvent::ToolCallDelta {
271                                                        id: tc.index.to_string(),
272                                                        arguments_delta: args.clone(),
273                                                    };
274                                                }
275                                            }
276                                        }
277                                    }
278                                }
279                            }
280                        }
281                    }
282                }
283            }
284        };
285
286        Ok(Box::pin(stream))
287    }
288}
289
290fn parse_response(response: AzureResponse) -> Result<ChatResponse> {
291    let choice = response
292        .choices
293        .into_iter()
294        .next()
295        .ok_or_else(|| DaimonError::Model("no choices in Azure OpenAI response".into()))?;
296
297    let tool_calls: Vec<ToolCall> = choice
298        .message
299        .tool_calls
300        .unwrap_or_default()
301        .into_iter()
302        .map(|tc| ToolCall {
303            id: tc.id,
304            name: tc.function.name,
305            arguments: serde_json::from_str(&tc.function.arguments).unwrap_or_default(),
306        })
307        .collect();
308
309    let stop_reason = match choice.finish_reason.as_deref() {
310        Some("tool_calls") => StopReason::ToolUse,
311        Some("length") => StopReason::MaxTokens,
312        _ => StopReason::EndTurn,
313    };
314
315    let message = Message {
316        role: Role::Assistant,
317        content: choice.message.content,
318        tool_calls,
319        tool_call_id: None,
320    };
321
322    Ok(ChatResponse {
323        message,
324        stop_reason,
325        usage: response.usage.map(|u| Usage {
326            input_tokens: u.prompt_tokens,
327            output_tokens: u.completion_tokens,
328            cached_tokens: u
329                .prompt_tokens_details
330                .map(|d| d.cached_tokens)
331                .unwrap_or(0),
332        }),
333    })
334}
335
// --- Azure OpenAI API types ---

/// Request body for the chat-completions endpoint (OpenAI-compatible shape).
#[derive(Serialize)]
struct AzureRequest {
    messages: Vec<AzureMessage>,
    /// Omitted from the JSON entirely when no tools are configured.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<AzureTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// `true` selects the SSE streaming response format.
    stream: bool,
}

/// One chat message in the Azure wire format.
#[derive(Serialize, Deserialize)]
struct AzureMessage {
    /// "system", "user", "assistant", or "tool".
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    /// Present only on assistant messages that carry tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<AzureToolCall>>,
    /// Present only on tool-result messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

361impl From<&Message> for AzureMessage {
362    fn from(msg: &Message) -> Self {
363        let role = match msg.role {
364            Role::System => "system",
365            Role::User => "user",
366            Role::Assistant => "assistant",
367            Role::Tool => "tool",
368        };
369
370        let tool_calls = if msg.tool_calls.is_empty() {
371            None
372        } else {
373            Some(
374                msg.tool_calls
375                    .iter()
376                    .map(|tc| AzureToolCall {
377                        id: tc.id.clone(),
378                        r#type: "function".to_string(),
379                        function: AzureFunction {
380                            name: tc.name.clone(),
381                            arguments: serde_json::to_string(&tc.arguments).unwrap_or_default(),
382                        },
383                        index: 0,
384                    })
385                    .collect(),
386            )
387        };
388
389        Self {
390            role: role.to_string(),
391            content: msg.content.clone(),
392            tool_calls,
393            tool_call_id: msg.tool_call_id.clone(),
394        }
395    }
396}
397
/// A tool definition in the request body; `type` is always "function".
#[derive(Serialize)]
struct AzureTool {
    r#type: String,
    function: AzureToolFunction,
}

404impl From<&ToolSpec> for AzureTool {
405    fn from(spec: &ToolSpec) -> Self {
406        Self {
407            r#type: "function".to_string(),
408            function: AzureToolFunction {
409                name: spec.name.clone(),
410                description: spec.description.clone(),
411                parameters: spec.parameters.clone(),
412            },
413        }
414    }
415}
416
/// Function payload of a tool definition.
#[derive(Serialize)]
struct AzureToolFunction {
    name: String,
    description: String,
    /// JSON Schema describing the function's parameters.
    parameters: serde_json::Value,
}

/// Top-level non-streaming chat-completions response.
#[derive(Deserialize)]
struct AzureResponse {
    choices: Vec<AzureChoice>,
    usage: Option<AzureUsage>,
}

#[derive(Deserialize)]
struct AzureChoice {
    message: AzureChoiceMessage,
    /// e.g. "stop", "length", "tool_calls"; mapped in `parse_response`.
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct AzureChoiceMessage {
    content: Option<String>,
    tool_calls: Option<Vec<AzureToolCall>>,
}

/// Tool call, both serialized into requests and deserialized from responses.
/// Every field defaults so partial payloads still deserialize.
#[derive(Serialize, Deserialize)]
struct AzureToolCall {
    #[serde(default)]
    id: String,
    #[serde(default)]
    r#type: String,
    #[serde(default)]
    function: AzureFunction,
    /// Position of the call within the message.
    #[serde(default)]
    index: usize,
}

#[derive(Serialize, Deserialize, Default)]
struct AzureFunction {
    #[serde(default)]
    name: String,
    /// Argument object kept as a raw JSON-encoded string, as on the wire.
    #[serde(default)]
    arguments: String,
}

/// Token accounting block returned by the API.
#[derive(Deserialize)]
struct AzureUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    prompt_tokens_details: Option<AzurePromptTokensDetails>,
}

#[derive(Deserialize)]
struct AzurePromptTokensDetails {
    /// Prompt tokens served from cache; 0 when the field is absent.
    #[serde(default)]
    cached_tokens: u32,
}

/// One parsed SSE `data:` payload from a streaming response.
#[derive(Deserialize)]
struct AzureStreamChunk {
    choices: Vec<AzureStreamChoice>,
}

#[derive(Deserialize)]
struct AzureStreamChoice {
    delta: AzureStreamDelta,
}

/// Incremental content carried by one stream chunk.
#[derive(Deserialize)]
struct AzureStreamDelta {
    content: Option<String>,
    tool_calls: Option<Vec<AzureStreamToolCall>>,
}

#[derive(Deserialize)]
struct AzureStreamToolCall {
    /// Position of the tool call within the message; used as the event id.
    index: usize,
    function: Option<AzureStreamFunction>,
}

#[derive(Deserialize)]
struct AzureStreamFunction {
    /// NOTE(review): presumably present only on the first delta of a tool
    /// call (OpenAI streaming convention) — confirm against live traffic.
    name: Option<String>,
    arguments: Option<String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    // Constructor defaults: deployment/URL stored verbatim, default version,
    // default retry budget, api-key (not bearer) auth.
    #[test]
    fn test_azure_new_default() {
        let model = AzureOpenAi::new("https://my-resource.openai.azure.com", "gpt-4o");
        assert_eq!(model.deployment_id, "gpt-4o");
        assert_eq!(
            model.resource_url,
            "https://my-resource.openai.azure.com"
        );
        assert_eq!(model.api_version, DEFAULT_API_VERSION);
        assert_eq!(model.max_retries, DEFAULT_MAX_RETRIES);
        assert!(!model.use_bearer_token);
    }

    // A trailing '/' on the resource URL must not survive construction.
    #[test]
    fn test_resource_url_trailing_slash_stripped() {
        let model = AzureOpenAi::new("https://my-resource.openai.azure.com/", "gpt-4o");
        assert_eq!(
            model.resource_url,
            "https://my-resource.openai.azure.com"
        );
    }

    // URL shape: {resource}/openai/deployments/{deployment}/chat/completions.
    #[test]
    fn test_endpoint_url() {
        let model = AzureOpenAi::with_api_key(
            "https://my-resource.openai.azure.com",
            "gpt-4o",
            "key",
        );
        assert_eq!(
            model.endpoint_url(),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions"
        );
    }

    #[test]
    fn test_with_api_version() {
        let model = AzureOpenAi::new("https://x.openai.azure.com", "gpt-4o")
            .with_api_version("2025-01-01");
        assert_eq!(model.api_version, "2025-01-01");
    }

    #[test]
    fn test_with_timeout() {
        let model = AzureOpenAi::new("https://x.openai.azure.com", "gpt-4o")
            .with_timeout(Duration::from_secs(60));
        assert_eq!(model.timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_with_max_retries() {
        let model = AzureOpenAi::new("https://x.openai.azure.com", "gpt-4o")
            .with_max_retries(10);
        assert_eq!(model.max_retries, 10);
    }

    #[test]
    fn test_with_bearer_token() {
        let model =
            AzureOpenAi::new("https://x.openai.azure.com", "gpt-4o").with_bearer_token();
        assert!(model.use_bearer_token);
    }

    // --- Message / tool conversion ---

    #[test]
    fn test_message_conversion_user() {
        let msg = Message::user("hello");
        let azure: AzureMessage = (&msg).into();
        assert_eq!(azure.role, "user");
        assert_eq!(azure.content.as_deref(), Some("hello"));
        assert!(azure.tool_calls.is_none());
    }

    // Tool results map to role "tool" and carry the originating call id.
    #[test]
    fn test_message_conversion_tool_result() {
        let msg = Message::tool_result("tc_1", "42");
        let azure: AzureMessage = (&msg).into();
        assert_eq!(azure.role, "tool");
        assert_eq!(azure.tool_call_id.as_deref(), Some("tc_1"));
    }

    #[test]
    fn test_message_conversion_assistant_with_tools() {
        let msg = Message::assistant_with_tool_calls(vec![ToolCall {
            id: "tc_1".into(),
            name: "calc".into(),
            arguments: serde_json::json!({"x": 1}),
        }]);
        let azure: AzureMessage = (&msg).into();
        assert_eq!(azure.role, "assistant");
        assert!(azure.tool_calls.is_some());
        assert_eq!(azure.tool_calls.unwrap().len(), 1);
    }

    #[test]
    fn test_tool_spec_conversion() {
        let spec = ToolSpec {
            name: "search".into(),
            description: "Web search".into(),
            parameters: serde_json::json!({"type": "object"}),
        };
        let tool: AzureTool = (&spec).into();
        assert_eq!(tool.r#type, "function");
        assert_eq!(tool.function.name, "search");
    }

    // --- parse_response ---

    // Plain-text response: content, EndTurn, no tool calls, usage mapped.
    #[test]
    fn test_parse_response_text() {
        let raw = AzureResponse {
            choices: vec![AzureChoice {
                message: AzureChoiceMessage {
                    content: Some("hello".into()),
                    tool_calls: None,
                },
                finish_reason: Some("stop".into()),
            }],
            usage: Some(AzureUsage {
                prompt_tokens: 10,
                completion_tokens: 5,
                prompt_tokens_details: None,
            }),
        };
        let resp = parse_response(raw).unwrap();
        assert_eq!(resp.text(), "hello");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert!(!resp.has_tool_calls());
        assert_eq!(resp.usage.unwrap().input_tokens, 10);
    }

    // finish_reason "tool_calls" maps to StopReason::ToolUse.
    #[test]
    fn test_parse_response_tool_calls() {
        let raw = AzureResponse {
            choices: vec![AzureChoice {
                message: AzureChoiceMessage {
                    content: None,
                    tool_calls: Some(vec![AzureToolCall {
                        id: "tc_1".into(),
                        r#type: "function".into(),
                        function: AzureFunction {
                            name: "calc".into(),
                            arguments: r#"{"x":1}"#.into(),
                        },
                        index: 0,
                    }]),
                },
                finish_reason: Some("tool_calls".into()),
            }],
            usage: None,
        };
        let resp = parse_response(raw).unwrap();
        assert!(resp.has_tool_calls());
        assert_eq!(resp.tool_calls()[0].name, "calc");
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
    }

    // An empty choices array is an error, not a silent empty message.
    #[test]
    fn test_parse_response_no_choices() {
        let raw = AzureResponse {
            choices: vec![],
            usage: None,
        };
        assert!(parse_response(raw).is_err());
    }

    // All builder methods compose and each setting sticks.
    #[test]
    fn test_builder_chain() {
        let model = AzureOpenAi::with_api_key("https://x.openai.azure.com", "gpt-4o", "key")
            .with_api_version("2025-01-01")
            .with_timeout(Duration::from_secs(30))
            .with_max_retries(5)
            .with_bearer_token();

        assert_eq!(model.deployment_id, "gpt-4o");
        assert_eq!(model.api_version, "2025-01-01");
        assert_eq!(model.timeout, Some(Duration::from_secs(30)));
        assert_eq!(model.max_retries, 5);
        assert!(model.use_bearer_token);
    }
}