1use std::time::Duration;
20
21use reqwest::Client;
22use serde::{Deserialize, Serialize};
23
24mod embedding;
25
26#[cfg(feature = "servicebus")]
27pub mod servicebus;
28
29pub use embedding::AzureOpenAiEmbedding;
30
31#[cfg(feature = "servicebus")]
32pub use servicebus::ServiceBusBroker;
33
34use daimon_core::{
35 ChatRequest, ChatResponse, DaimonError, Message, Model, ResponseStream, Result, Role,
36 StopReason, StreamEvent, ToolCall, ToolSpec, Usage,
37};
38
/// Azure OpenAI REST API version sent as the `api-version` query parameter.
const DEFAULT_API_VERSION: &str = "2024-10-21";
/// Retries attempted after the initial request on 429/5xx responses.
const DEFAULT_MAX_RETRIES: u32 = 3;
41
42fn build_client(timeout: Option<Duration>) -> Client {
43 let mut builder = Client::builder();
44 if let Some(t) = timeout {
45 builder = builder.timeout(t);
46 }
47 builder.build().expect("failed to build HTTP client")
48}
49
/// Chat-completion model backed by a single Azure OpenAI deployment.
///
/// Construct with [`AzureOpenAi::new`] or [`AzureOpenAi::with_api_key`] and
/// customize via the builder-style `with_*` methods.
#[derive(Debug)]
pub struct AzureOpenAi {
    // HTTP client; rebuilt whenever a timeout is configured.
    client: Client,
    // Credential sent either as the `api-key` header or as a bearer token.
    api_key: String,
    // Base resource URL with any trailing slash stripped.
    resource_url: String,
    // Azure deployment name targeted by all requests.
    deployment_id: String,
    // Value of the `api-version` query parameter.
    api_version: String,
    // Optional per-request timeout applied to the HTTP client.
    timeout: Option<Duration>,
    // Retries after the first attempt for 429/5xx responses.
    max_retries: u32,
    // When true, authenticate with `Authorization: Bearer` instead of `api-key`.
    use_bearer_token: bool,
}
65
66impl AzureOpenAi {
67 pub fn new(resource_url: impl Into<String>, deployment_id: impl Into<String>) -> Self {
69 let api_key = std::env::var("AZURE_OPENAI_API_KEY").unwrap_or_default();
70 Self::with_api_key(resource_url, deployment_id, api_key)
71 }
72
73 pub fn with_api_key(
75 resource_url: impl Into<String>,
76 deployment_id: impl Into<String>,
77 api_key: impl Into<String>,
78 ) -> Self {
79 Self {
80 client: build_client(None),
81 api_key: api_key.into(),
82 resource_url: resource_url.into().trim_end_matches('/').to_string(),
83 deployment_id: deployment_id.into(),
84 api_version: DEFAULT_API_VERSION.to_string(),
85 timeout: None,
86 max_retries: DEFAULT_MAX_RETRIES,
87 use_bearer_token: false,
88 }
89 }
90
91 pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
93 self.api_version = version.into();
94 self
95 }
96
97 pub fn with_timeout(mut self, timeout: Duration) -> Self {
99 self.timeout = Some(timeout);
100 self.client = build_client(Some(timeout));
101 self
102 }
103
104 pub fn with_max_retries(mut self, retries: u32) -> Self {
106 self.max_retries = retries;
107 self
108 }
109
110 pub fn with_bearer_token(mut self) -> Self {
114 self.use_bearer_token = true;
115 self
116 }
117
118 fn endpoint_url(&self) -> String {
119 format!(
120 "{}/openai/deployments/{}/chat/completions",
121 self.resource_url, self.deployment_id
122 )
123 }
124
125 fn apply_auth(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
126 if self.use_bearer_token {
127 req.bearer_auth(&self.api_key)
128 } else {
129 req.header("api-key", &self.api_key)
130 }
131 }
132
133 fn build_request_body(&self, request: &ChatRequest, stream: bool) -> AzureRequest {
134 let messages: Vec<AzureMessage> = request.messages.iter().map(Into::into).collect();
135
136 let tools: Option<Vec<AzureTool>> = if request.tools.is_empty() {
137 None
138 } else {
139 Some(request.tools.iter().map(Into::into).collect())
140 };
141
142 AzureRequest {
143 messages,
144 tools,
145 temperature: request.temperature,
146 max_tokens: request.max_tokens,
147 stream,
148 }
149 }
150}
151
impl Model for AzureOpenAi {
    /// Sends a blocking (non-streaming) chat-completions request.
    ///
    /// 429 and 5xx responses are retried up to `self.max_retries` times with
    /// exponential backoff (100ms * 2^attempt); any other non-success status
    /// fails immediately.
    ///
    /// # Errors
    /// `DaimonError::Model` on transport errors, body-decode errors, or a
    /// final non-success HTTP status.
    #[tracing::instrument(skip_all, fields(deployment = %self.deployment_id))]
    async fn generate(&self, request: &ChatRequest) -> Result<ChatResponse> {
        let body = self.build_request_body(request, false);
        let url = self.endpoint_url();

        // One initial attempt plus `max_retries` retries.
        for attempt in 0..=self.max_retries {
            // reqwest builders are consumed on send, so rebuild per attempt.
            let req = self
                .client
                .post(&url)
                .query(&[("api-version", &self.api_version)])
                .json(&body);
            let req = self.apply_auth(req);

            tracing::debug!(attempt, "sending Azure OpenAI request");
            let response = req
                .send()
                .await
                .map_err(|e| DaimonError::Model(format!("Azure OpenAI HTTP error: {e}")))?;
            let status = response.status();

            if status.is_success() {
                let api_resp: AzureResponse = response
                    .json()
                    .await
                    .map_err(|e| {
                        DaimonError::Model(format!("Azure OpenAI response parse error: {e}"))
                    })?;
                tracing::debug!("received successful Azure OpenAI response");
                return parse_response(api_resp);
            }

            // Capture the error body for diagnostics; a failed read just
            // yields an empty string.
            let text = response.text().await.unwrap_or_default();
            let is_retryable = status.as_u16() == 429 || status.is_server_error();

            if is_retryable && attempt < self.max_retries {
                // NOTE(review): 2u64.pow(attempt) overflows if a caller sets
                // max_retries >= 64 via with_max_retries, and the server's
                // Retry-After header is not consulted — confirm whether either
                // matters for expected configurations.
                let delay_ms = 100 * 2u64.pow(attempt);
                tracing::debug!(status = %status, attempt, delay_ms, "retryable error, backing off");
                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
            } else {
                return Err(DaimonError::Model(format!(
                    "Azure OpenAI API error ({status}): {text}"
                )));
            }
        }

        // Every iteration either returns or sleeps-and-continues, and the
        // final iteration always returns, so this point is never reached.
        unreachable!("loop always returns or retries")
    }

    /// Opens a server-sent-events stream of chat-completion deltas.
    ///
    /// The HTTP request itself is not retried; a non-success status fails
    /// immediately. The returned stream yields `StreamEvent`s parsed from
    /// SSE `data:` lines until `data: [DONE]`.
    #[tracing::instrument(skip_all, fields(deployment = %self.deployment_id))]
    async fn generate_stream(&self, request: &ChatRequest) -> Result<ResponseStream> {
        let body = self.build_request_body(request, true);
        let url = self.endpoint_url();

        let req = self
            .client
            .post(&url)
            .query(&[("api-version", &self.api_version)])
            .json(&body);
        let req = self.apply_auth(req);

        tracing::debug!("sending Azure OpenAI streaming request");
        let response = req
            .send()
            .await
            .map_err(|e| DaimonError::Model(format!("Azure OpenAI HTTP error: {e}")))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(DaimonError::Model(format!(
                "Azure OpenAI API error ({status}): {text}"
            )));
        }

        tracing::debug!("Azure OpenAI stream established");
        let byte_stream = response.bytes_stream();

        let stream = async_stream::try_stream! {
            use futures::StreamExt;

            // Accumulates partial SSE lines across network chunks.
            // NOTE(review): from_utf8_lossy below can mangle a multi-byte
            // UTF-8 character split across two chunks — confirm whether chunk
            // boundaries are guaranteed to fall on character edges.
            let mut buffer = String::new();
            let mut stream = Box::pin(byte_stream);

            while let Some(chunk) = stream.next().await {
                let chunk = chunk.map_err(|e| DaimonError::Model(format!("Azure OpenAI stream error: {e}")))?;
                buffer.push_str(&String::from_utf8_lossy(&chunk));

                // Drain every complete line in the buffer; a trailing
                // partial line is kept for the next chunk. trim() also
                // removes the '\r' of CRLF-delimited SSE lines.
                while let Some(line_end) = buffer.find('\n') {
                    let line = buffer[..line_end].trim().to_string();
                    buffer = buffer[line_end + 1..].to_string();

                    if line.is_empty() || line == "data: [DONE]" {
                        if line == "data: [DONE]" {
                            yield StreamEvent::Done;
                        }
                        continue;
                    }

                    if let Some(data) = line.strip_prefix("data: ") {
                        // Unparseable data lines are silently skipped.
                        if let Ok(chunk) = serde_json::from_str::<AzureStreamChunk>(data) {
                            for choice in &chunk.choices {
                                if let Some(ref content) = choice.delta.content {
                                    if !content.is_empty() {
                                        yield StreamEvent::TextDelta(content.clone());
                                    }
                                }
                                if let Some(ref tool_calls) = choice.delta.tool_calls {
                                    for tc in tool_calls {
                                        if let Some(ref func) = tc.function {
                                            // A delta carrying a function name opens a new
                                            // tool call; argument fragments are forwarded
                                            // as they arrive.
                                            // NOTE(review): events are keyed by the delta's
                                            // `index`, not the API-provided call id —
                                            // confirm downstream reassembly expects that.
                                            if let Some(ref name) = func.name {
                                                yield StreamEvent::ToolCallStart {
                                                    id: tc.index.to_string(),
                                                    name: name.clone(),
                                                };
                                            }
                                            if let Some(ref args) = func.arguments {
                                                if !args.is_empty() {
                                                    yield StreamEvent::ToolCallDelta {
                                                        id: tc.index.to_string(),
                                                        arguments_delta: args.clone(),
                                                    };
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        };

        Ok(Box::pin(stream))
    }
}
289
290fn parse_response(response: AzureResponse) -> Result<ChatResponse> {
291 let choice = response
292 .choices
293 .into_iter()
294 .next()
295 .ok_or_else(|| DaimonError::Model("no choices in Azure OpenAI response".into()))?;
296
297 let tool_calls: Vec<ToolCall> = choice
298 .message
299 .tool_calls
300 .unwrap_or_default()
301 .into_iter()
302 .map(|tc| ToolCall {
303 id: tc.id,
304 name: tc.function.name,
305 arguments: serde_json::from_str(&tc.function.arguments).unwrap_or_default(),
306 })
307 .collect();
308
309 let stop_reason = match choice.finish_reason.as_deref() {
310 Some("tool_calls") => StopReason::ToolUse,
311 Some("length") => StopReason::MaxTokens,
312 _ => StopReason::EndTurn,
313 };
314
315 let message = Message {
316 role: Role::Assistant,
317 content: choice.message.content,
318 tool_calls,
319 tool_call_id: None,
320 };
321
322 Ok(ChatResponse {
323 message,
324 stop_reason,
325 usage: response.usage.map(|u| Usage {
326 input_tokens: u.prompt_tokens,
327 output_tokens: u.completion_tokens,
328 cached_tokens: u
329 .prompt_tokens_details
330 .map(|d| d.cached_tokens)
331 .unwrap_or(0),
332 }),
333 })
334}
335
/// Outgoing chat-completions request body in Azure wire format.
/// Optional fields are omitted from the JSON entirely when unset.
#[derive(Serialize)]
struct AzureRequest {
    messages: Vec<AzureMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<AzureTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    stream: bool,
}
349
/// One conversation message in Azure wire format; built from a
/// daimon-core `Message` via the `From` impl below.
#[derive(Serialize, Deserialize)]
struct AzureMessage {
    // "system" | "user" | "assistant" | "tool"
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<AzureToolCall>>,
    // Present only on tool-result messages: the id of the call being answered.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
360
361impl From<&Message> for AzureMessage {
362 fn from(msg: &Message) -> Self {
363 let role = match msg.role {
364 Role::System => "system",
365 Role::User => "user",
366 Role::Assistant => "assistant",
367 Role::Tool => "tool",
368 };
369
370 let tool_calls = if msg.tool_calls.is_empty() {
371 None
372 } else {
373 Some(
374 msg.tool_calls
375 .iter()
376 .map(|tc| AzureToolCall {
377 id: tc.id.clone(),
378 r#type: "function".to_string(),
379 function: AzureFunction {
380 name: tc.name.clone(),
381 arguments: serde_json::to_string(&tc.arguments).unwrap_or_default(),
382 },
383 index: 0,
384 })
385 .collect(),
386 )
387 };
388
389 Self {
390 role: role.to_string(),
391 content: msg.content.clone(),
392 tool_calls,
393 tool_call_id: msg.tool_call_id.clone(),
394 }
395 }
396}
397
/// Tool definition envelope: `{"type": "function", "function": {...}}`.
#[derive(Serialize)]
struct AzureTool {
    r#type: String,
    function: AzureToolFunction,
}
403
404impl From<&ToolSpec> for AzureTool {
405 fn from(spec: &ToolSpec) -> Self {
406 Self {
407 r#type: "function".to_string(),
408 function: AzureToolFunction {
409 name: spec.name.clone(),
410 description: spec.description.clone(),
411 parameters: spec.parameters.clone(),
412 },
413 }
414 }
415}
416
/// Function half of a tool definition: name, description, and a JSON
/// schema for its parameters.
#[derive(Serialize)]
struct AzureToolFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
423
/// Top-level non-streaming chat-completions response payload.
#[derive(Deserialize)]
struct AzureResponse {
    choices: Vec<AzureChoice>,
    usage: Option<AzureUsage>,
}
429
/// One completion choice; only the first is consumed by `parse_response`.
#[derive(Deserialize)]
struct AzureChoice {
    message: AzureChoiceMessage,
    // e.g. "stop", "length", "tool_calls"; absent treated as end of turn.
    finish_reason: Option<String>,
}
435
/// Assistant message inside a non-streaming choice.
#[derive(Deserialize)]
struct AzureChoiceMessage {
    content: Option<String>,
    tool_calls: Option<Vec<AzureToolCall>>,
}
441
/// Tool call as both serialized into requests and read from responses.
/// All fields default so partial payloads still deserialize.
#[derive(Serialize, Deserialize)]
struct AzureToolCall {
    #[serde(default)]
    id: String,
    #[serde(default)]
    r#type: String,
    #[serde(default)]
    function: AzureFunction,
    // Position of this call within the message's tool-call list.
    #[serde(default)]
    index: usize,
}
453
/// Function name plus its arguments as a raw JSON string (the wire format
/// carries arguments as a string, not a nested object).
#[derive(Serialize, Deserialize, Default)]
struct AzureFunction {
    #[serde(default)]
    name: String,
    #[serde(default)]
    arguments: String,
}
461
/// Token accounting reported by the API for a completed request.
#[derive(Deserialize)]
struct AzureUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    prompt_tokens_details: Option<AzurePromptTokensDetails>,
}
468
/// Prompt-token breakdown; `cached_tokens` defaults to 0 when absent.
#[derive(Deserialize)]
struct AzurePromptTokensDetails {
    #[serde(default)]
    cached_tokens: u32,
}
474
/// One SSE `data:` payload of a streaming response.
#[derive(Deserialize)]
struct AzureStreamChunk {
    choices: Vec<AzureStreamChoice>,
}
479
/// Streaming counterpart of `AzureChoice`: carries an incremental delta.
#[derive(Deserialize)]
struct AzureStreamChoice {
    delta: AzureStreamDelta,
}
484
/// Incremental content and/or tool-call fragments within one stream chunk.
#[derive(Deserialize)]
struct AzureStreamDelta {
    content: Option<String>,
    tool_calls: Option<Vec<AzureStreamToolCall>>,
}
490
/// Tool-call fragment in a stream delta, keyed by positional index.
#[derive(Deserialize)]
struct AzureStreamToolCall {
    index: usize,
    function: Option<AzureStreamFunction>,
}
496
/// Function fragment: `name` arrives once, `arguments` arrives piecewise.
#[derive(Deserialize)]
struct AzureStreamFunction {
    name: Option<String>,
    arguments: Option<String>,
}
502
#[cfg(test)]
mod tests {
    use super::*;

    // `new` applies defaults: API version, retry count, header-based auth.
    #[test]
    fn test_azure_new_default() {
        let model = AzureOpenAi::new("https://my-resource.openai.azure.com", "gpt-4o");
        assert_eq!(model.deployment_id, "gpt-4o");
        assert_eq!(
            model.resource_url,
            "https://my-resource.openai.azure.com"
        );
        assert_eq!(model.api_version, DEFAULT_API_VERSION);
        assert_eq!(model.max_retries, DEFAULT_MAX_RETRIES);
        assert!(!model.use_bearer_token);
    }

    // Trailing `/` on the resource URL must not survive construction.
    #[test]
    fn test_resource_url_trailing_slash_stripped() {
        let model = AzureOpenAi::new("https://my-resource.openai.azure.com/", "gpt-4o");
        assert_eq!(
            model.resource_url,
            "https://my-resource.openai.azure.com"
        );
    }

    // Endpoint is `{resource}/openai/deployments/{deployment}/chat/completions`.
    #[test]
    fn test_endpoint_url() {
        let model = AzureOpenAi::with_api_key(
            "https://my-resource.openai.azure.com",
            "gpt-4o",
            "key",
        );
        assert_eq!(
            model.endpoint_url(),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions"
        );
    }

    // Builder override for the `api-version` query parameter.
    #[test]
    fn test_with_api_version() {
        let model = AzureOpenAi::new("https://x.openai.azure.com", "gpt-4o")
            .with_api_version("2025-01-01");
        assert_eq!(model.api_version, "2025-01-01");
    }

    // Builder stores the timeout (the client rebuild is not observable here).
    #[test]
    fn test_with_timeout() {
        let model = AzureOpenAi::new("https://x.openai.azure.com", "gpt-4o")
            .with_timeout(Duration::from_secs(60));
        assert_eq!(model.timeout, Some(Duration::from_secs(60)));
    }

    // Builder override for the retry budget.
    #[test]
    fn test_with_max_retries() {
        let model = AzureOpenAi::new("https://x.openai.azure.com", "gpt-4o")
            .with_max_retries(10);
        assert_eq!(model.max_retries, 10);
    }

    // Builder switch to bearer-token authentication.
    #[test]
    fn test_with_bearer_token() {
        let model =
            AzureOpenAi::new("https://x.openai.azure.com", "gpt-4o").with_bearer_token();
        assert!(model.use_bearer_token);
    }

    // User messages map to role "user" with plain content, no tool calls.
    #[test]
    fn test_message_conversion_user() {
        let msg = Message::user("hello");
        let azure: AzureMessage = (&msg).into();
        assert_eq!(azure.role, "user");
        assert_eq!(azure.content.as_deref(), Some("hello"));
        assert!(azure.tool_calls.is_none());
    }

    // Tool results map to role "tool" and carry the originating call id.
    #[test]
    fn test_message_conversion_tool_result() {
        let msg = Message::tool_result("tc_1", "42");
        let azure: AzureMessage = (&msg).into();
        assert_eq!(azure.role, "tool");
        assert_eq!(azure.tool_call_id.as_deref(), Some("tc_1"));
    }

    // Assistant messages with tool calls serialize every call.
    #[test]
    fn test_message_conversion_assistant_with_tools() {
        let msg = Message::assistant_with_tool_calls(vec![ToolCall {
            id: "tc_1".into(),
            name: "calc".into(),
            arguments: serde_json::json!({"x": 1}),
        }]);
        let azure: AzureMessage = (&msg).into();
        assert_eq!(azure.role, "assistant");
        assert!(azure.tool_calls.is_some());
        assert_eq!(azure.tool_calls.unwrap().len(), 1);
    }

    // Tool specs are wrapped in the `type: "function"` envelope.
    #[test]
    fn test_tool_spec_conversion() {
        let spec = ToolSpec {
            name: "search".into(),
            description: "Web search".into(),
            parameters: serde_json::json!({"type": "object"}),
        };
        let tool: AzureTool = (&spec).into();
        assert_eq!(tool.r#type, "function");
        assert_eq!(tool.function.name, "search");
    }

    // Text response: content, EndTurn stop reason, and usage mapping.
    #[test]
    fn test_parse_response_text() {
        let raw = AzureResponse {
            choices: vec![AzureChoice {
                message: AzureChoiceMessage {
                    content: Some("hello".into()),
                    tool_calls: None,
                },
                finish_reason: Some("stop".into()),
            }],
            usage: Some(AzureUsage {
                prompt_tokens: 10,
                completion_tokens: 5,
                prompt_tokens_details: None,
            }),
        };
        let resp = parse_response(raw).unwrap();
        assert_eq!(resp.text(), "hello");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert!(!resp.has_tool_calls());
        assert_eq!(resp.usage.unwrap().input_tokens, 10);
    }

    // Tool-call response: calls decoded, finish_reason "tool_calls" → ToolUse.
    #[test]
    fn test_parse_response_tool_calls() {
        let raw = AzureResponse {
            choices: vec![AzureChoice {
                message: AzureChoiceMessage {
                    content: None,
                    tool_calls: Some(vec![AzureToolCall {
                        id: "tc_1".into(),
                        r#type: "function".into(),
                        function: AzureFunction {
                            name: "calc".into(),
                            arguments: r#"{"x":1}"#.into(),
                        },
                        index: 0,
                    }]),
                },
                finish_reason: Some("tool_calls".into()),
            }],
            usage: None,
        };
        let resp = parse_response(raw).unwrap();
        assert!(resp.has_tool_calls());
        assert_eq!(resp.tool_calls()[0].name, "calc");
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
    }

    // An empty choices array is an error, not a silent empty response.
    #[test]
    fn test_parse_response_no_choices() {
        let raw = AzureResponse {
            choices: vec![],
            usage: None,
        };
        assert!(parse_response(raw).is_err());
    }

    // All builder methods compose and each setting survives the chain.
    #[test]
    fn test_builder_chain() {
        let model = AzureOpenAi::with_api_key("https://x.openai.azure.com", "gpt-4o", "key")
            .with_api_version("2025-01-01")
            .with_timeout(Duration::from_secs(30))
            .with_max_retries(5)
            .with_bearer_token();

        assert_eq!(model.deployment_id, "gpt-4o");
        assert_eq!(model.api_version, "2025-01-01");
        assert_eq!(model.timeout, Some(Duration::from_secs(30)));
        assert_eq!(model.max_retries, 5);
        assert!(model.use_bearer_token);
    }
}