use super::config::AzureAIConfig;
use super::convert;
use crate::retry::{RetryConfig, execute_with_retry, is_retryable_model_error};
use adk_core::{
AdkError, ErrorCategory, ErrorComponent, Llm, LlmRequest, LlmResponse, LlmResponseStream, Part,
};
use async_stream::try_stream;
use async_trait::async_trait;
use futures::StreamExt;
use reqwest::Client;
use serde_json::Value;
/// Client for an Azure AI chat-completions deployment.
///
/// Bundles the connection settings taken from `AzureAIConfig`, a reusable `reqwest`
/// HTTP client, and the retry policy used when sending requests.
pub struct AzureAIClient {
    /// Model name passed to request-body construction and reported by `Llm::name`.
    model: String,
    /// Base endpoint URL; a trailing `/` is stripped when the request URL is built.
    endpoint: String,
    /// Key sent in the `api-key` request header.
    api_key: String,
    /// HTTP client cloned into every request.
    client: Client,
    /// Retry policy handed to `execute_with_retry` for each request.
    retry_config: RetryConfig,
}
impl AzureAIClient {
    /// Creates a client for the endpoint, key, and model in `config`.
    ///
    /// The retry policy starts as `RetryConfig::default()`; override it with
    /// [`Self::with_retry_config`] or [`Self::set_retry_config`].
    ///
    /// # Errors
    /// Returns a model `AdkError` if the underlying HTTP client cannot be built.
    pub fn new(config: AzureAIConfig) -> Result<Self, AdkError> {
        Client::builder()
            .build()
            .map(|client| Self {
                client,
                endpoint: config.endpoint,
                api_key: config.api_key,
                model: config.model,
                retry_config: RetryConfig::default(),
            })
            .map_err(|err| AdkError::model(format!("Failed to create HTTP client: {err}")))
    }

    /// Builder-style variant of [`Self::set_retry_config`]: replaces the retry policy and
    /// returns the client.
    #[must_use]
    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.set_retry_config(retry_config);
        self
    }

    /// Replaces the retry policy used for subsequent requests.
    pub fn set_retry_config(&mut self, retry_config: RetryConfig) {
        self.retry_config = retry_config;
    }

    /// Returns the retry policy currently in effect.
    pub fn retry_config(&self) -> &RetryConfig {
        &self.retry_config
    }

    /// Full chat-completions URL: the endpoint without trailing slashes, followed by the
    /// `/chat/completions` path and the pinned `api-version` query parameter.
    fn api_url(&self) -> String {
        let base = self.endpoint.trim_end_matches('/');
        format!("{base}/chat/completions?api-version=2024-05-01-preview")
    }
}
#[async_trait]
impl Llm for AzureAIClient {
    /// Returns the configured model name.
    fn name(&self) -> &str {
        &self.model
    }

    /// Sends `request` to the Azure AI chat-completions endpoint and returns its responses.
    ///
    /// The HTTP request goes through `execute_with_retry` with this client's retry policy and
    /// `is_retryable_model_error`. Transport failures are reported as
    /// `ErrorCategory::Unavailable`. Non-2xx statuses are mapped to an `ErrorCategory`:
    /// 401/403/404/408/429/503/529 each have their own category, other 5xx become `Internal`,
    /// and everything else becomes `InvalidInput`. These errors carry the upstream status code.
    ///
    /// With `stream == true` the body is read as server-sent events. Every `data:` line except
    /// `[DONE]` is parsed as JSON. Chunks that carry content are yielded as they arrive.
    /// Tool-call fragments are accumulated and emitted as one response of
    /// `Part::FunctionCall`s when a chunk reports `turn_complete`. With `stream == false`
    /// the whole body is parsed into a single response.
    ///
    /// The returned stream is wrapped with usage tracking under an `azure-ai` telemetry span.
    ///
    /// # Errors
    /// The stream yields an `AdkError` in these cases:
    /// - the request ultimately fails;
    /// - a streamed chunk cannot be read;
    /// - (non-streaming) the body cannot be read or is not valid JSON.
    ///
    /// Streamed chunks that are not valid JSON are logged and skipped.
    async fn generate_content(
        &self,
        request: LlmRequest,
        stream: bool,
    ) -> Result<LlmResponseStream, AdkError> {
        let usage_span = adk_telemetry::llm_generate_span("azure-ai", &self.model, stream);
        let api_url = self.api_url();
        let api_key = self.api_key.clone();
        let model = self.model.clone();
        let endpoint = self.endpoint.clone();
        let client = self.client.clone();
        let retry_config = self.retry_config.clone();
        let body = convert::build_request_body(
            &model,
            &request.contents,
            &request.tools,
            request.config.as_ref(),
            stream,
        );

        let response_stream = try_stream! {
            let response = execute_with_retry(&retry_config, is_retryable_model_error, || {
                let client = client.clone();
                let api_url = api_url.clone();
                let api_key = api_key.clone();
                let body = body.clone();
                let endpoint = endpoint.clone();
                async move {
                    let resp = client
                        .post(&api_url)
                        .header("api-key", &api_key)
                        .header("Content-Type", "application/json")
                        .json(&body)
                        .send()
                        .await
                        .map_err(|e| AdkError::new(
                            ErrorComponent::Model,
                            ErrorCategory::Unavailable,
                            "model.azure_ai.request",
                            format!("Azure AI error for endpoint={endpoint}: {e}"),
                        ).with_provider("azure-ai"))?;
                    if !resp.status().is_success() {
                        let status = resp.status();
                        let status_code = status.as_u16();
                        let error_text = resp.text().await.unwrap_or_default();
                        let category = match status_code {
                            401 => ErrorCategory::Unauthorized,
                            403 => ErrorCategory::Forbidden,
                            404 => ErrorCategory::NotFound,
                            408 => ErrorCategory::Timeout,
                            429 => ErrorCategory::RateLimited,
                            503 | 529 => ErrorCategory::Unavailable,
                            _ if status_code >= 500 => ErrorCategory::Internal,
                            _ => ErrorCategory::InvalidInput,
                        };
                        return Err(AdkError::new(
                            ErrorComponent::Model,
                            category,
                            "model.azure_ai.api_error",
                            format!("Azure AI error for endpoint={endpoint}, status={status}: {error_text}"),
                        ).with_upstream_status(status_code).with_provider("azure-ai"));
                    }
                    Ok(resp)
                }
            })
            .await?;

            if stream {
                let mut byte_stream = response.bytes_stream();
                // Buffer raw bytes and decode only complete lines. Decoding each network chunk
                // separately would turn a multi-byte UTF-8 character that straddles two chunks
                // into U+FFFD replacement characters.
                let mut buffer: Vec<u8> = Vec::new();
                // Keyed by the tool call's `index`; value is (id, name, concatenated arguments).
                let mut tool_call_accumulators: std::collections::HashMap<u32, (String, String, String)> =
                    std::collections::HashMap::new();
                let mut body_finished = false;

                loop {
                    while let Some(line) = take_sse_line(&mut buffer) {
                        // Blank lines and non-`data:` fields carry no payload here.
                        if let Some(data) = sse_data(&line) {
                            if data == "[DONE]" {
                                continue;
                            }
                            match serde_json::from_str::<Value>(data) {
                                Ok(chunk_json) => {
                                    accumulate_tool_calls(&chunk_json, &mut tool_call_accumulators);
                                    let llm_resp = convert::parse_sse_chunk(&chunk_json);
                                    if llm_resp.turn_complete {
                                        if !tool_call_accumulators.is_empty() {
                                            let mut sorted: Vec<_> =
                                                tool_call_accumulators.drain().collect();
                                            // Indices are unique map keys, so an unstable sort suffices.
                                            sorted.sort_unstable_by_key(|(idx, _)| *idx);
                                            let parts: Vec<Part> = sorted
                                                .into_iter()
                                                .map(|(_, (id, name, args_str))| {
                                                    // Unparseable argument text degrades to `{}`.
                                                    let args: Value = serde_json::from_str(&args_str)
                                                        .unwrap_or(serde_json::json!({}));
                                                    Part::FunctionCall {
                                                        name,
                                                        args,
                                                        id: Some(id),
                                                        thought_signature: None,
                                                    }
                                                })
                                                .collect();
                                            // `..llm_resp` keeps the final chunk's remaining fields
                                            // (finish reason, and e.g. usage data if parse_sse_chunk
                                            // sets it) instead of resetting them to defaults.
                                            yield LlmResponse {
                                                content: Some(adk_core::Content {
                                                    role: "model".to_string(),
                                                    parts,
                                                }),
                                                partial: false,
                                                turn_complete: true,
                                                ..llm_resp
                                            };
                                            continue;
                                        }
                                        yield llm_resp;
                                    } else if llm_resp.content.is_some() {
                                        yield llm_resp;
                                    }
                                }
                                Err(e) => {
                                    tracing::warn!("failed to parse Azure AI chunk: {e} - {data}");
                                }
                            }
                        }
                    }

                    if body_finished {
                        break;
                    }

                    match byte_stream.next().await {
                        Some(chunk_result) => {
                            let chunk = chunk_result
                                .map_err(|e| AdkError::model(format!("Azure AI stream error: {e}")))?;
                            buffer.extend_from_slice(&chunk);
                        }
                        None => {
                            body_finished = true;
                            // A final event may arrive without a trailing newline. Terminate it so
                            // the drain loop above still processes it on the last pass.
                            if !buffer.is_empty() {
                                buffer.push(b'\n');
                            }
                        }
                    }
                }

                if !tool_call_accumulators.is_empty() {
                    tracing::warn!(
                        "Azure AI stream ended before turn completion; discarding {} partial tool call(s)",
                        tool_call_accumulators.len()
                    );
                }
            } else {
                let response_text = response.text().await.map_err(|e| {
                    AdkError::model(format!("Azure AI response read failed: {e}"))
                })?;
                let response_json: Value = serde_json::from_str(&response_text).map_err(|e| {
                    AdkError::model(format!("Azure AI response parse failed: {e}"))
                })?;
                yield convert::parse_response(&response_json);
            }
        };

        Ok(crate::usage_tracking::with_usage_tracking(Box::pin(response_stream), usage_span))
    }
}

/// Removes the first newline-terminated line from `buffer` and returns it as trimmed text.
///
/// Returns `None` and leaves `buffer` untouched while no `\n` is buffered yet. Only complete
/// lines are decoded, so UTF-8 sequences split across network chunks are reassembled before
/// conversion. Genuinely invalid bytes become U+FFFD via `String::from_utf8_lossy`.
fn take_sse_line(buffer: &mut Vec<u8>) -> Option<String> {
    let newline = buffer.iter().position(|&byte| byte == b'\n')?;
    let raw: Vec<u8> = buffer.drain(..=newline).collect();
    Some(String::from_utf8_lossy(&raw).trim().to_string())
}

/// Returns the value of an SSE `data:` field line, or `None` for any other line.
///
/// The SSE format makes the single space after the colon optional, so `data: {..}` and
/// `data:{..}` both yield `{..}`.
fn sse_data(line: &str) -> Option<&str> {
    let value = line.strip_prefix("data:")?;
    Some(value.strip_prefix(' ').unwrap_or(value))
}
fn accumulate_tool_calls(
chunk: &Value,
accumulators: &mut std::collections::HashMap<u32, (String, String, String)>,
) {
let Some(tool_calls) = chunk
.get("choices")
.and_then(|c| c.get(0))
.and_then(|choice| choice.get("delta"))
.and_then(|delta| delta.get("tool_calls"))
.and_then(|tc| tc.as_array())
else {
return;
};
for tc in tool_calls {
let index = tc.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as u32;
let entry = accumulators.entry(index).or_insert_with(|| {
let id = tc.get("id").and_then(|i| i.as_str()).unwrap_or("").to_string();
(id, String::new(), String::new())
});
if let Some(id) = tc.get("id").and_then(|i| i.as_str()) {
if !id.is_empty() {
entry.0 = id.to_string();
}
}
if let Some(func) = tc.get("function") {
if let Some(name) = func.get("name").and_then(|n| n.as_str()) {
if !name.is_empty() {
entry.1 = name.to_string();
}
}
if let Some(args) = func.get("arguments").and_then(|a| a.as_str()) {
entry.2.push_str(args);
}
}
}
}