1use async_trait::async_trait;
8use serde_json::Value;
9use std::sync::LazyLock;
10
11use prompty::interfaces::{Executor, InvokerError};
12use prompty::model::Prompty;
13use prompty::types::Message;
14
15use prompty_openai::wire;
16
/// Process-wide `reqwest` client, constructed lazily on first use and reused
/// for every request sent from this file.
static HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(reqwest::Client::new);

/// `api-version` query value used when the model options carry no `apiVersion`.
const DEFAULT_API_VERSION: &str = "2025-04-01-preview";
22
/// Executor that sends chat, embedding and image requests to an Azure
/// OpenAI-style HTTP endpoint.
///
/// The type carries no state, so it derives the cheap standard traits: callers
/// can debug-print it, copy it freely, and build it with `Default::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct FoundryExecutor;
31
32#[async_trait]
33impl Executor for FoundryExecutor {
34 async fn execute(&self, agent: &Prompty, messages: &[Message]) -> Result<Value, InvokerError> {
35 let api_type = agent
36 .model
37 .api_type
38 .as_ref()
39 .map(|t| t.as_str())
40 .unwrap_or("chat");
41
42 let body = match api_type {
43 "chat" | "agent" => wire::build_chat_args(agent, messages),
44 "embedding" => wire::build_embedding_args(agent, messages),
45 "image" => wire::build_image_args(agent, messages),
46 other => {
47 return Err(InvokerError::Execute(
48 format!("Unsupported apiType: {other}").into(),
49 ));
50 }
51 };
52
53 let (url, auth_header) = build_azure_request(agent, api_type).await?;
54
55 let client = &*HTTP_CLIENT;
56 let response = client
57 .post(&url)
58 .header(auth_header.0, auth_header.1)
59 .header("Content-Type", "application/json")
60 .json(&body)
61 .send()
62 .await
63 .map_err(|e| InvokerError::Execute(format!("HTTP request failed: {e}").into()))?;
64
65 if !response.status().is_success() {
66 let status = response.status();
67 let body_text = response
68 .text()
69 .await
70 .unwrap_or_else(|_| "unable to read body".to_string());
71 return Err(InvokerError::Execute(
72 format!("Azure OpenAI API error (HTTP {status}): {body_text}").into(),
73 ));
74 }
75
76 let result: Value = response
77 .json()
78 .await
79 .map_err(|e| InvokerError::Execute(format!("Failed to parse response: {e}").into()))?;
80
81 Ok(result)
82 }
83
84 fn format_tool_messages(
85 &self,
86 _raw_response: &Value,
87 tool_calls: &[prompty::types::ToolCall],
88 tool_results: &[String],
89 _text_content: Option<&str>,
90 ) -> Vec<Message> {
91 wire::format_tool_messages(tool_calls, tool_results)
92 }
93
94 async fn execute_stream(
95 &self,
96 agent: &Prompty,
97 messages: &[Message],
98 ) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = Value> + Send>>, InvokerError> {
99 let api_type = agent
100 .model
101 .api_type
102 .as_ref()
103 .map(|t| t.as_str())
104 .unwrap_or("chat");
105 if api_type != "chat" && api_type != "agent" {
106 return Err(InvokerError::Execute(
107 format!("Foundry streaming only supports apiType 'chat', got: {api_type}").into(),
108 ));
109 }
110
111 let mut body = wire::build_chat_args(agent, messages);
112 if let Some(obj) = body.as_object_mut() {
114 obj.insert("stream".into(), Value::Bool(true));
115 }
116
117 let (url, auth_header) = build_azure_request(agent, api_type).await?;
118
119 let client = &*HTTP_CLIENT;
120 let response = client
121 .post(&url)
122 .header(auth_header.0, auth_header.1)
123 .header("Content-Type", "application/json")
124 .json(&body)
125 .send()
126 .await
127 .map_err(|e| InvokerError::Execute(format!("HTTP request failed: {e}").into()))?;
128
129 if !response.status().is_success() {
130 let status = response.status();
131 let body_text = response
132 .text()
133 .await
134 .unwrap_or_else(|_| "unable to read body".to_string());
135 return Err(InvokerError::Execute(
136 format!("Azure OpenAI API error (HTTP {status}): {body_text}").into(),
137 ));
138 }
139
140 let byte_stream = response.bytes_stream();
142 Ok(Box::pin(FoundrySseParser::new(byte_stream)))
143 }
144}
145
146fn resolve_connection(
153 agent: &Prompty,
154) -> Result<std::borrow::Cow<'_, serde_json::Value>, InvokerError> {
155 let conn = &agent.model.connection;
156 let kind = conn.get("kind").and_then(|k| k.as_str()).unwrap_or("");
157
158 if kind == "reference" {
159 let name = conn.get("name").and_then(|n| n.as_str()).ok_or_else(|| {
160 InvokerError::Execute(
161 "Reference connection missing 'name' field"
162 .to_string()
163 .into(),
164 )
165 })?;
166
167 let resolved =
168 prompty::connections::with_connection::<serde_json::Value, _>(name, |c| c.clone())
169 .map_err(|e| InvokerError::Execute(e.into()))?;
170
171 Ok(std::borrow::Cow::Owned(resolved))
172 } else {
173 Ok(std::borrow::Cow::Borrowed(conn))
174 }
175}
176
177async fn build_azure_request(
179 agent: &Prompty,
180 api_type: &str,
181) -> Result<(String, (&'static str, String)), InvokerError> {
182 let endpoint = get_endpoint(agent)?;
183 let deployment = get_deployment(agent)?;
184
185 let path = match api_type {
186 "chat" | "agent" => "chat/completions",
187 "embedding" => "embeddings",
188 "image" => "images/generations",
189 other => {
190 return Err(InvokerError::Execute(
191 format!("Unsupported apiType for Azure: {other}").into(),
192 ));
193 }
194 };
195
196 let conn = resolve_connection(agent)?;
197 let kind = conn.get("kind").and_then(|v| v.as_str()).unwrap_or("");
198 let url = if kind == "foundry" {
199 format!("{}/{}", endpoint.trim_end_matches('/'), path)
200 } else {
201 let api_version = get_api_version(agent);
202 format!(
203 "{}/openai/deployments/{}/{}?api-version={}",
204 endpoint.trim_end_matches('/'),
205 deployment,
206 path,
207 api_version,
208 )
209 };
210
211 let auth_header = get_auth_header(agent).await?;
212
213 Ok((url, auth_header))
214}
215
216fn get_endpoint(agent: &Prompty) -> Result<String, InvokerError> {
218 let conn = resolve_connection(agent)?;
219 let kind = conn.get("kind").and_then(|v| v.as_str()).unwrap_or("");
220
221 if let Some(ep) = conn.get("endpoint").and_then(|v| v.as_str()) {
223 if !ep.is_empty() {
224 return match kind {
225 "foundry" => Ok(strip_project_path(ep)),
226 _ => Ok(ep.to_string()),
227 };
228 }
229 }
230
231 if let Ok(ep) = std::env::var("AZURE_OPENAI_ENDPOINT") {
233 if !ep.is_empty() {
234 return Ok(ep);
235 }
236 }
237
238 Err(InvokerError::Execute(
239 "No Azure OpenAI endpoint found. Set AZURE_OPENAI_ENDPOINT or configure model.connection.endpoint"
240 .to_string()
241 .into(),
242 ))
243}
244
/// Rewrites a Foundry project endpoint into its `/openai/v1` base URL.
///
/// Steps, in order:
/// 1. Everything from the first `/api/projects` onwards is dropped, along with
///    trailing slashes.
/// 2. If no `://` remains, the trimmed text is returned as-is.
/// 3. Only the authority (host plus optional numeric port) is kept; any other
///    path is discarded.
/// 4. A host ending in `.services.ai.azure.com` is mapped to the matching
///    `.openai.azure.com` host.
/// 5. `/openai/v1` is appended.
fn strip_project_path(endpoint: &str) -> String {
    const PROJECT_MARKER: &str = "/api/projects";
    const FOUNDRY_SUFFIX: &str = ".services.ai.azure.com";
    const OPENAI_SUFFIX: &str = ".openai.azure.com";
    const API_ROOT: &str = "/openai/v1";

    let without_project = match endpoint.find(PROJECT_MARKER) {
        Some(idx) => &endpoint[..idx],
        None => endpoint,
    };
    let trimmed = without_project.trim_end_matches('/');

    let Some((scheme, remainder)) = trimmed.split_once("://") else {
        return trimmed.to_string();
    };

    // Everything before the first '/' is the authority; `split` always yields
    // at least one piece, so the fallback is never taken in practice.
    let authority = remainder.split('/').next().unwrap_or(remainder);

    // Split off ":<digits>" (colon included) when the text after the last colon
    // is purely numeric; anything else is treated as part of the host.
    let (host, port_suffix) = match authority.rfind(':') {
        Some(colon) if authority[colon + 1..].bytes().all(|b| b.is_ascii_digit()) => {
            authority.split_at(colon)
        }
        _ => (authority, ""),
    };

    let mut rewritten = String::with_capacity(trimmed.len() + API_ROOT.len());
    rewritten.push_str(scheme);
    rewritten.push_str("://");
    match host.strip_suffix(FOUNDRY_SUFFIX) {
        Some(resource) => {
            rewritten.push_str(resource);
            rewritten.push_str(OPENAI_SUFFIX);
        }
        None => rewritten.push_str(host),
    }
    rewritten.push_str(port_suffix);
    rewritten.push_str(API_ROOT);
    rewritten
}
274
275fn get_deployment(agent: &Prompty) -> Result<String, InvokerError> {
277 if !agent.model.id.is_empty() {
279 return Ok(agent.model.id.clone());
280 }
281
282 if let Ok(deployment) = std::env::var("AZURE_OPENAI_DEPLOYMENT") {
284 if !deployment.is_empty() {
285 return Ok(deployment);
286 }
287 }
288
289 Err(InvokerError::Execute(
290 "No deployment name found. Set model.id or AZURE_OPENAI_DEPLOYMENT"
291 .to_string()
292 .into(),
293 ))
294}
295
296fn get_api_version(agent: &Prompty) -> String {
298 if let Some(opts) = &agent.model.options {
300 if let Some(version) = opts
301 .additional_properties
302 .get("apiVersion")
303 .and_then(|v| v.as_str())
304 {
305 return version.to_string();
306 }
307 }
308
309 DEFAULT_API_VERSION.to_string()
310}
311
312async fn get_auth_header(agent: &Prompty) -> Result<(&'static str, String), InvokerError> {
318 let conn = resolve_connection(agent)?;
319 let kind = conn.get("kind").and_then(|k| k.as_str()).unwrap_or("");
320
321 if let Some(key) = conn
323 .get("apiKey")
324 .or(conn.get("api_key"))
325 .and_then(|k| k.as_str())
326 {
327 if !key.is_empty() {
328 return if kind == "foundry" {
329 Ok(("Authorization", format!("Bearer {key}")))
330 } else {
331 Ok(("api-key", key.to_string()))
332 };
333 }
334 }
335
336 if kind == "foundry" {
337 if let Ok(key) = std::env::var("AZURE_INFERENCE_CREDENTIAL") {
338 if !key.is_empty() {
339 return Ok(("Authorization", format!("Bearer {key}")));
340 }
341 }
342 return get_entra_token().await;
343 }
344
345 if let Ok(key) = std::env::var("AZURE_OPENAI_API_KEY") {
347 if !key.is_empty() {
348 return Ok(("api-key", key));
349 }
350 }
351
352 Err(InvokerError::Execute(
353 "No Azure API key found. Set AZURE_OPENAI_API_KEY or configure model.connection.apiKey"
354 .to_string()
355 .into(),
356 ))
357}
358
/// Token scope requested when authenticating through Entra ID.
#[cfg(feature = "entra_id")]
const FOUNDRY_TOKEN_SCOPE: &str = "https://ai.azure.com/.default";

/// Acquires a bearer token through `DefaultAzureCredential` and returns it as
/// an `Authorization` header.
///
/// # Errors
///
/// Returns `InvokerError::Execute` when the credential cannot be created or
/// the token request fails.
#[cfg(feature = "entra_id")]
async fn get_entra_token() -> Result<(&'static str, String), InvokerError> {
    use azure_core::credentials::TokenCredential;
    use azure_identity::DefaultAzureCredential;

    let credential = match DefaultAzureCredential::new() {
        Ok(credential) => credential,
        Err(e) => {
            return Err(InvokerError::Execute(
                format!("Failed to create DefaultAzureCredential: {e}").into(),
            ));
        }
    };

    let access_token = match credential.get_token(&[FOUNDRY_TOKEN_SCOPE]).await {
        Ok(access_token) => access_token,
        Err(e) => {
            return Err(InvokerError::Execute(
                format!("Failed to acquire Entra ID token: {e}").into(),
            ));
        }
    };

    let header_value = format!("Bearer {}", access_token.token.secret());
    Ok(("Authorization", header_value))
}
380
381#[cfg(not(feature = "entra_id"))]
383async fn get_entra_token() -> Result<(&'static str, String), InvokerError> {
384 Err(InvokerError::Execute(
385 "Foundry connection requires Entra ID auth. Enable the 'entra_id' feature on prompty-foundry, \
386 or provide an API key in model.connection.apiKey"
387 .to_string()
388 .into(),
389 ))
390}
391
392use std::collections::VecDeque;
397use std::pin::Pin;
398use std::task::{Context, Poll};
399
400use bytes::Bytes;
401use futures::Stream;
402
403struct FoundrySseParser {
405 inner: Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
406 buffer: String,
407 pending: VecDeque<Value>,
408 done: bool,
409}
410
411impl FoundrySseParser {
412 fn new(inner: impl Stream<Item = Result<Bytes, reqwest::Error>> + Send + 'static) -> Self {
413 Self {
414 inner: Box::pin(inner),
415 buffer: String::new(),
416 pending: VecDeque::new(),
417 done: false,
418 }
419 }
420
421 fn parse_buffer(&mut self) {
422 while let Some(pos) = self.buffer.find("\n\n") {
423 let event = self.buffer[..pos].to_string();
424 self.buffer = self.buffer[pos + 2..].to_string();
425
426 for line in event.lines() {
427 if let Some(data) = line
428 .strip_prefix("data: ")
429 .or_else(|| line.strip_prefix("data:"))
430 {
431 let data = data.trim();
432 if data == "[DONE]" {
433 self.done = true;
434 return;
435 }
436 match serde_json::from_str::<Value>(data) {
437 Ok(parsed) => self.pending.push_back(parsed),
438 Err(e) => {
439 self.pending.push_back(serde_json::json!({
440 "error": {
441 "type": "sse_parse_error",
442 "message": format!("Failed to parse SSE data: {e}"),
443 "raw": data,
444 }
445 }));
446 }
447 }
448 }
449 }
450 }
451 }
452}
453
454impl Stream for FoundrySseParser {
455 type Item = Value;
456
457 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
458 loop {
459 if let Some(item) = self.pending.pop_front() {
460 return Poll::Ready(Some(item));
461 }
462 if self.done {
463 return Poll::Ready(None);
464 }
465
466 match self.inner.as_mut().poll_next(cx) {
467 Poll::Ready(Some(Ok(bytes))) => {
468 match std::str::from_utf8(&bytes) {
469 Ok(text) => self.buffer.push_str(text),
470 Err(e) => {
471 self.pending.push_back(serde_json::json!({
472 "error": {
473 "type": "sse_decode_error",
474 "message": format!("Invalid UTF-8 in SSE stream: {e}"),
475 }
476 }));
477 }
478 }
479 self.parse_buffer();
480 }
481 Poll::Ready(Some(Err(e))) => {
482 self.pending.push_back(serde_json::json!({
483 "error": {
484 "type": "sse_transport_error",
485 "message": format!("SSE stream error: {e}"),
486 }
487 }));
488 self.done = true;
489 if let Some(item) = self.pending.pop_front() {
490 return Poll::Ready(Some(item));
491 }
492 return Poll::Ready(None);
493 }
494 Poll::Ready(None) => {
495 self.done = true;
496 return Poll::Ready(None);
497 }
498 Poll::Pending => {
499 return Poll::Pending;
500 }
501 }
502 }
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use prompty::model::context::LoadContext;
    use serde_json::json;
    use serial_test::serial;

    /// Builds a minimal prompt agent whose `model` section is `model_json`.
    fn make_agent(model_json: Value) -> Prompty {
        let mut data = json!({
            "name": "test",
            "kind": "prompt",
            "model": model_json,
        });
        data["instructions"] = json!("test");
        Prompty::load_from_value(&data, &LoadContext::default())
    }

    #[tokio::test]
    #[serial]
    async fn test_build_url_api_key_connection() {
        let agent = make_agent(json!({
            "id": "gpt-4",
            "connection": {
                "kind": "key",
                "endpoint": "https://myresource.openai.azure.com",
                "apiKey": "test-key"
            }
        }));
        let (url, _) = build_azure_request(&agent, "chat").await.unwrap();
        assert!(url.starts_with(
            "https://myresource.openai.azure.com/openai/deployments/gpt-4/chat/completions"
        ));
        assert!(url.contains("api-version="));
    }

    #[tokio::test]
    #[serial]
    async fn test_build_url_foundry_connection() {
        // SAFETY: every test in this module is `#[serial]`, so no other test
        // here reads or writes the environment while this one runs.
        unsafe { std::env::set_var("AZURE_INFERENCE_CREDENTIAL", "test-foundry-key") };
        let agent = make_agent(json!({
            "id": "gpt-4o",
            "connection": {
                "kind": "foundry",
                "endpoint": "https://myresource.services.ai.azure.com/api/projects/my-project",
                "name": "my-conn"
            }
        }));
        let result = build_azure_request(&agent, "chat").await;
        // Clean up before asserting so a failure cannot leak the credential
        // into later tests.
        // SAFETY: see above.
        unsafe { std::env::remove_var("AZURE_INFERENCE_CREDENTIAL") };

        let (url, _) = result.unwrap();
        assert!(url.starts_with("https://myresource.openai.azure.com/openai/v1/chat/completions"));
    }

    #[tokio::test]
    #[serial]
    async fn test_build_url_embedding() {
        let agent = make_agent(json!({
            "id": "text-embedding-3-small",
            "connection": {
                "kind": "key",
                "endpoint": "https://myresource.openai.azure.com",
                "apiKey": "test-key"
            }
        }));
        let (url, _) = build_azure_request(&agent, "embedding").await.unwrap();
        assert!(url.contains("/embeddings?"));
    }

    #[tokio::test]
    #[serial]
    async fn test_build_url_image() {
        let agent = make_agent(json!({
            "id": "dall-e-3",
            "connection": {
                "kind": "key",
                "endpoint": "https://myresource.openai.azure.com",
                "apiKey": "test-key"
            }
        }));
        let (url, _) = build_azure_request(&agent, "image").await.unwrap();
        assert!(url.contains("/images/generations?"));
    }

    #[tokio::test]
    #[serial]
    async fn test_auth_header_api_key() {
        let agent = make_agent(json!({
            "id": "gpt-4",
            "connection": {
                "kind": "key",
                "endpoint": "https://myresource.openai.azure.com",
                "apiKey": "my-azure-key"
            }
        }));
        let (name, value) = get_auth_header(&agent).await.unwrap();
        assert_eq!(name, "api-key");
        assert_eq!(value, "my-azure-key");
    }

    #[test]
    #[serial]
    fn test_strip_project_path() {
        assert_eq!(
            strip_project_path("https://myresource.services.ai.azure.com/api/projects/my-project"),
            "https://myresource.openai.azure.com/openai/v1"
        );
        assert_eq!(
            strip_project_path("https://myresource.openai.azure.com"),
            "https://myresource.openai.azure.com/openai/v1"
        );
        assert_eq!(
            strip_project_path("https://myresource.openai.azure.com/openai/v1"),
            "https://myresource.openai.azure.com/openai/v1"
        );
    }

    #[test]
    #[serial]
    fn test_deployment_from_model_id() {
        let agent = make_agent(json!({
            "id": "my-deployment-name",
            "connection": {
                "kind": "key",
                "endpoint": "https://myresource.openai.azure.com",
                "apiKey": "key"
            }
        }));
        let deployment = get_deployment(&agent).unwrap();
        assert_eq!(deployment, "my-deployment-name");
    }

    #[test]
    #[serial]
    fn test_api_version_default() {
        let agent = make_agent(json!({
            "id": "gpt-4",
            "connection": {
                "kind": "key",
                "endpoint": "https://myresource.openai.azure.com",
                "apiKey": "key"
            }
        }));
        let version = get_api_version(&agent);
        assert_eq!(version, DEFAULT_API_VERSION);
    }

    #[tokio::test]
    #[serial]
    async fn test_unsupported_api_type() {
        let agent = make_agent(json!({
            "id": "gpt-4",
            "connection": {
                "kind": "key",
                "endpoint": "https://myresource.openai.azure.com",
                "apiKey": "key"
            }
        }));
        let result = build_azure_request(&agent, "unknown").await;
        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn test_resolve_connection_passthrough() {
        let agent = make_agent(json!({
            "id": "gpt-4",
            "connection": {
                "kind": "key",
                "endpoint": "https://myresource.openai.azure.com",
                "apiKey": "test-key"
            }
        }));
        let conn = resolve_connection(&agent).unwrap();
        assert_eq!(conn.get("kind").unwrap().as_str().unwrap(), "key");
        assert_eq!(conn.get("apiKey").unwrap().as_str().unwrap(), "test-key");
    }

    #[test]
    #[serial]
    fn test_resolve_connection_reference_missing_name() {
        let agent = make_agent(json!({
            "id": "gpt-4",
            "connection": { "kind": "reference" }
        }));
        let result = resolve_connection(&agent);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("name"));
    }

    #[test]
    #[serial]
    fn test_resolve_connection_reference_success() {
        prompty::connections::clear_connections();
        prompty::connections::register_connection(
            "azure-prod",
            json!({
                "kind": "key",
                "endpoint": "https://prod.openai.azure.com",
                "apiKey": "prod-key"
            }),
        );

        let agent = make_agent(json!({
            "id": "gpt-4",
            "connection": { "kind": "reference", "name": "azure-prod" }
        }));

        let conn = resolve_connection(&agent).unwrap();
        assert_eq!(
            conn.get("endpoint").unwrap().as_str().unwrap(),
            "https://prod.openai.azure.com"
        );
        assert_eq!(conn.get("apiKey").unwrap().as_str().unwrap(), "prod-key");

        prompty::connections::clear_connections();
    }

    #[tokio::test]
    #[serial]
    async fn test_reference_connection_flows_to_auth_header() {
        prompty::connections::clear_connections();
        prompty::connections::register_connection(
            "azure-resolved",
            json!({
                "kind": "key",
                "endpoint": "https://resolved.openai.azure.com",
                "apiKey": "resolved-key"
            }),
        );

        let agent = make_agent(json!({
            "id": "gpt-4",
            "connection": { "kind": "reference", "name": "azure-resolved" }
        }));

        let (header_name, header_value) = get_auth_header(&agent).await.unwrap();
        assert_eq!(header_name, "api-key");
        assert_eq!(header_value, "resolved-key");

        prompty::connections::clear_connections();
    }

    #[tokio::test]
    #[serial]
    async fn test_auth_header_foundry_no_key_no_entra() {
        prompty::connections::clear_connections();
        // The foundry path reads AZURE_INFERENCE_CREDENTIAL, so that variable
        // must be cleared too or an ambient value would make the lookup succeed.
        // SAFETY: every test in this module is `#[serial]`, so no other test
        // here reads or writes the environment while this one runs.
        unsafe { std::env::remove_var("AZURE_INFERENCE_CREDENTIAL") };
        // SAFETY: see above.
        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };

        let agent = make_agent(json!({
            "id": "gpt-4",
            "connection": {
                "kind": "foundry",
                "endpoint": "https://resource.services.ai.azure.com/api/projects/proj"
            }
        }));

        let result = get_auth_header(&agent).await;
        assert!(result.is_err());
    }
}