1use super::traits::*;
35use crate::types::*;
36use async_trait::async_trait;
37use futures::StreamExt;
38use reqwest_eventsource::EventSource;
39use serde::Deserialize;
40use tokio::sync::mpsc;
41use tracing::{debug, warn};
42
/// Stream provider that talks to the Azure OpenAI Responses API
/// (`POST {base_url}/responses`); identifies itself as `"azure"`.
pub struct AzureOpenAiProvider;
45
46#[async_trait]
47impl StreamProvider for AzureOpenAiProvider {
48 fn provider_id(&self) -> &str {
49 "azure"
50 }
51
52 async fn stream(
53 &self,
54 config: StreamConfig, tx: mpsc::UnboundedSender<StreamEvent>, cancel: tokio_util::sync::CancellationToken, ) -> Result<Message, ProviderError> {
58 let model_config = &config.model_config;
59 let api_key = model_config.resolve_api_key().await?;
61
62 let url = format!(
76 "{}/responses?api-version=2025-01-01-preview",
77 model_config.base_url
78 );
79
80 let body = build_azure_request_body(&config);
81 debug!(
82 "Azure OpenAI request: model={} url={}",
83 config.model_config.id, url
84 );
85
86 let client = reqwest::Client::new();
87 let mut request = client
88 .post(&url)
89 .header("content-type", "application/json")
90 .header("api-key", &api_key); for (k, v) in &model_config.headers {
93 request = request.header(k, v);
94 }
95
96 let request = request.json(&body);
97 let mut es =
98 EventSource::new(request).map_err(|e| ProviderError::Network(e.to_string()))?;
99
100 let mut content: Vec<Content> = Vec::new();
101 let mut usage = Usage::default();
102 let mut stop_reason = StopReason::Stop;
103 let mut tool_call_buffers: Vec<ToolCallBuffer> = Vec::new();
104
105 let _ = tx.send(StreamEvent::Start);
106
107 loop {
108 tokio::select! {
109 _ = cancel.cancelled() => {
110 es.close();
111 return Err(ProviderError::Cancelled);
112 }
113 event = es.next() => {
114 match event {
115 None => break,
116 Some(Ok(reqwest_eventsource::Event::Open)) => {}
117 Some(Ok(reqwest_eventsource::Event::Message(msg))) => {
118 match msg.event.as_str() {
119 "response.output_text.delta" => {
120 if let Ok(data) = serde_json::from_str::<DeltaEvent>(&msg.data) {
121 let idx = content.iter().position(|c| matches!(c, Content::Text { .. }));
122 let idx = match idx {
123 Some(i) => i,
124 None => {
125 content.push(Content::Text { text: String::new() });
126 content.len() - 1
127 }
128 };
129 if let Some(Content::Text { text }) = content.get_mut(idx) {
130 text.push_str(&data.delta);
131 }
132 let _ = tx.send(StreamEvent::TextDelta {
133 content_index: idx,
134 delta: data.delta,
135 });
136 }
137 }
138 "response.function_call_arguments.start" => {
139 if let Ok(data) = serde_json::from_str::<FnCallStartEvent>(&msg.data) {
140 tool_call_buffers.push(ToolCallBuffer {
141 id: data.call_id.unwrap_or_default(),
142 name: data.name.unwrap_or_default(),
143 arguments: String::new(),
144 });
145 let buf = tool_call_buffers.last().unwrap();
146 let _ = tx.send(StreamEvent::ToolCallStart {
147 content_index: content.len() + tool_call_buffers.len() - 1,
148 id: buf.id.clone(),
149 name: buf.name.clone(),
150 });
151 }
152 }
153 "response.function_call_arguments.delta" => {
154 if let Ok(data) = serde_json::from_str::<DeltaEvent>(&msg.data) {
155 if let Some(buf) = tool_call_buffers.last_mut() {
156 buf.arguments.push_str(&data.delta);
157 let _ = tx.send(StreamEvent::ToolCallDelta {
158 content_index: content.len() + tool_call_buffers.len() - 1,
159 delta: data.delta,
160 });
161 }
162 }
163 }
164 "response.completed" => {
165 if let Ok(data) = serde_json::from_str::<CompletedEvent>(&msg.data) {
166 if let Some(resp) = data.response {
167 if let Some(u) = resp.usage {
168 usage.input = u.input_tokens;
169 usage.output = u.output_tokens;
170 usage.total_tokens = u.total_tokens;
171 }
172 }
173 }
174 break;
175 }
176 "error" => {
177 warn!("Azure OpenAI error: {}", msg.data);
178 let err_msg = Message::Assistant {
179 content: vec![Content::Text { text: String::new() }],
180 stop_reason: StopReason::Error,
181 model: config.model_config.id.clone(),
182 provider: model_config.provider.clone(),
183 usage: usage.clone(),
184 timestamp: now_ms(),
185 error_message: Some(msg.data),
186 };
187 let _ = tx.send(StreamEvent::Error { message: err_msg.clone() });
188 return Ok(err_msg);
189 }
190 _ => {}
191 }
192 }
193 Some(Err(e)) => {
194 let err_str = e.to_string();
195 warn!("Azure SSE error: {}", err_str);
196 let err_msg = Message::Assistant {
197 content: vec![Content::Text { text: String::new() }],
198 stop_reason: StopReason::Error,
199 model: config.model_config.id.clone(),
200 provider: model_config.provider.clone(),
201 usage: usage.clone(),
202 timestamp: now_ms(),
203 error_message: Some(err_str),
204 };
205 let _ = tx.send(StreamEvent::Error { message: err_msg.clone() });
206 return Ok(err_msg);
207 }
208 }
209 }
210 }
211 }
212
213 for buf in &tool_call_buffers {
214 let args = serde_json::from_str(&buf.arguments)
215 .unwrap_or(serde_json::Value::Object(Default::default()));
216 content.push(Content::ToolCall {
217 id: buf.id.clone(),
218 name: buf.name.clone(),
219 arguments: args,
220 });
221 let _ = tx.send(StreamEvent::ToolCallEnd {
222 content_index: content.len() - 1,
223 });
224 }
225
226 if content
227 .iter()
228 .any(|c| matches!(c, Content::ToolCall { .. }))
229 {
230 stop_reason = StopReason::ToolUse;
231 }
232
233 let message = Message::Assistant {
234 content,
235 stop_reason,
236 model: config.model_config.id.clone(),
237 provider: model_config.provider.clone(),
238 usage,
239 timestamp: now_ms(),
240 error_message: None,
241 };
242
243 let _ = tx.send(StreamEvent::Done {
244 message: message.clone(),
245 });
246 Ok(message)
247 }
248}
249
/// Accumulates one streamed function call until the response finishes.
#[derive(Debug, Default)]
struct ToolCallBuffer {
    /// Function call id (`call_id` in the Responses API events).
    id: String,
    /// Name of the tool being invoked.
    name: String,
    /// Raw JSON argument text, built up from streamed argument events.
    arguments: String,
}
255
256fn build_azure_request_body(config: &StreamConfig) -> serde_json::Value {
257 let mut input: Vec<serde_json::Value> = Vec::new();
259
260 for msg in &config.messages {
261 match msg {
262 Message::User { content, .. } => {
263 let user_content: Vec<serde_json::Value> = content
265 .iter()
266 .filter_map(|c| match c {
267 Content::Text { text } => Some(serde_json::json!({
268 "type": "input_text",
269 "text": text,
270 })),
271 Content::Image { data, mime_type } => Some(serde_json::json!({
272 "type": "input_image",
273 "image_url": format!("data:{};base64,{}", mime_type, data),
274 })),
275 _ => None,
276 })
277 .collect();
278
279 if user_content.len() == 1 && user_content[0]["type"] == "input_text" {
280 input.push(serde_json::json!({
282 "role": "user",
283 "content": user_content[0]["text"].as_str().unwrap_or(""),
284 }));
285 } else {
286 input.push(serde_json::json!({
288 "role": "user",
289 "content": user_content,
290 }));
291 }
292 }
293 Message::Assistant { content, .. } => {
294 for c in content {
295 match c {
296 Content::Text { text } => {
297 input.push(serde_json::json!({
298 "type": "message",
299 "role": "assistant",
300 "content": [{"type": "output_text", "text": text}],
301 }));
302 }
303 Content::ToolCall {
304 id,
305 name,
306 arguments,
307 } => {
308 input.push(serde_json::json!({
309 "type": "function_call",
310 "call_id": id,
311 "name": name,
312 "arguments": arguments.to_string(),
313 }));
314 }
315 _ => {}
316 }
317 }
318 }
319 Message::ToolResult {
320 tool_call_id,
321 content,
322 ..
323 } => {
324 let output_val = if content.iter().any(|c| matches!(c, Content::Image { .. })) {
325 let parts: Vec<serde_json::Value> = content
326 .iter()
327 .filter_map(|c| match c {
328 Content::Text { text } => Some(serde_json::json!({
329 "type": "input_text",
330 "text": text,
331 })),
332 Content::Image { data, mime_type } => Some(serde_json::json!({
333 "type": "input_image",
334 "image_url": format!("data:{};base64,{}", mime_type, data),
335 })),
336 _ => None,
337 })
338 .collect();
339 serde_json::json!(parts)
340 } else {
341 let text = content
342 .iter()
343 .find_map(|c| match c {
344 Content::Text { text } => Some(text.clone()),
345 _ => None,
346 })
347 .unwrap_or_default();
348 serde_json::json!(text)
349 };
350 input.push(serde_json::json!({
351 "type": "function_call_output",
352 "call_id": tool_call_id,
353 "output": output_val,
354 }));
355 }
356 }
357 }
358
359 let mut body = serde_json::json!({
360 "model": config.model_config.id,
361 "stream": true,
362 "input": input,
363 });
364
365 if !config.system_prompt.is_empty() {
366 body["instructions"] = serde_json::json!(config.system_prompt);
367 }
368
369 if let Some(max) = config.max_tokens {
370 body["max_output_tokens"] = serde_json::json!(max);
371 }
372
373 if !config.tools.is_empty() {
374 let tools: Vec<serde_json::Value> = config
375 .tools
376 .iter()
377 .map(|t| {
378 serde_json::json!({
379 "type": "function",
380 "name": t.name,
381 "description": t.description,
382 "parameters": t.parameters,
383 })
384 })
385 .collect();
386 body["tools"] = serde_json::json!(tools);
387 }
388
389 if let Some(temp) = config.temperature {
390 body["temperature"] = serde_json::json!(temp);
391 }
392
393 match &config.response_format {
396 ResponseFormat::Text => {} ResponseFormat::JsonObject => {
398 body["text"] = serde_json::json!({"format": {"type": "json_object"}});
399 }
400 ResponseFormat::JsonSchema {
401 schema,
402 name,
403 strict,
404 } => {
405 body["text"] = serde_json::json!({
406 "format": {
407 "type": "json_schema",
408 "name": name,
409 "schema": schema,
410 "strict": *strict,
411 },
412 });
413 }
414 }
415
416 body
417}
418
/// Minimal payload for streamed `*.delta` SSE events; only the `delta`
/// string is read, and parsing fails if it is absent.
#[derive(Deserialize)]
struct DeltaEvent {
    /// Fragment to append to the text or argument being streamed.
    delta: String,
}
424
/// Payload read from a `response.function_call_arguments.start` event.
///
/// Both fields default to `None` when missing from the JSON.
/// NOTE(review): the public Responses API streaming reference announces
/// function calls via `response.output_item.added`; confirm Azure actually
/// emits this event name.
#[derive(Deserialize)]
struct FnCallStartEvent {
    /// Function call id assigned by the model.
    #[serde(default)]
    call_id: Option<String>,
    /// Name of the function being called.
    #[serde(default)]
    name: Option<String>,
}
432
/// Payload of the `response.completed` event; only the nested `response`
/// object is read, and it defaults to `None` when absent.
#[derive(Deserialize)]
struct CompletedEvent {
    #[serde(default)]
    response: Option<ResponseData>,
}
438
/// Subset of the final response object; only `usage` is read.
#[derive(Deserialize)]
struct ResponseData {
    #[serde(default)]
    usage: Option<AzureUsage>,
}
444
/// Token counts from the final response's `usage` object; any missing field
/// defaults to 0.
#[derive(Deserialize)]
struct AzureUsage {
    #[serde(default)]
    input_tokens: u64,
    #[serde(default)]
    output_tokens: u64,
    #[serde(default)]
    total_tokens: u64,
}
453}