Skip to main content

entelix_cloud/foundry/
transport.rs

1//! `FoundryTransport` — `entelix_core::transports::Transport` over
2//! Azure AI Foundry. Two auth modes:
3//! - [`FoundryAuth::ApiKey`]: static `api-key` header (the bridge
4//!   path most operators use day-1).
5//! - [`FoundryAuth::Entra`]: OAuth via `azure_identity` flowing
6//!   through [`crate::refresh::CachedTokenProvider`].
7
8use std::sync::Arc;
9
10use bytes::Bytes;
11use futures::StreamExt;
12use secrecy::{ExposeSecret, SecretString};
13
14use entelix_core::codecs::EncodedRequest;
15use entelix_core::context::ExecutionContext;
16use entelix_core::error::{Error, Result};
17use entelix_core::transports::{Transport, TransportResponse, TransportStream};
18
19use crate::CloudError;
20use crate::refresh::{CachedTokenProvider, TokenRefresher};
21
22/// Auth strategy for [`FoundryTransport`].
23#[derive(Clone)]
24#[non_exhaustive]
25pub enum FoundryAuth {
26    /// Static API key — sent as `api-key: {value}`.
27    ApiKey {
28        /// Pre-resolved key, redacted in `Debug`.
29        token: SecretString,
30    },
31    /// Entra ID (Azure AD) OAuth — token resolved through the
32    /// supplied refresher.
33    Entra {
34        /// Refresher driven by `azure_identity`.
35        refresher: Arc<dyn TokenRefresher<SecretString>>,
36    },
37}
38
39impl std::fmt::Debug for FoundryAuth {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            Self::ApiKey { .. } => f.write_str("FoundryAuth::ApiKey {{ <redacted> }}"),
43            Self::Entra { .. } => f.write_str("FoundryAuth::Entra {{ .. }}"),
44        }
45    }
46}
47
48#[derive(Clone)]
49enum ResolvedAuth {
50    ApiKey(SecretString),
51    Entra(Arc<CachedTokenProvider<SecretString>>),
52}
53
54/// Azure AI Foundry HTTP transport.
55#[derive(Clone)]
56pub struct FoundryTransport {
57    client: reqwest::Client,
58    base_url: String,
59    auth: ResolvedAuth,
60}
61
62impl FoundryTransport {
63    /// Start a fluent builder.
64    pub fn builder() -> FoundryTransportBuilder {
65        FoundryTransportBuilder {
66            base_url: None,
67            auth: None,
68        }
69    }
70
71    /// Borrow the resolved base URL.
72    pub fn base_url(&self) -> &str {
73        &self.base_url
74    }
75
76    async fn build_headers(
77        &self,
78        request_headers: &http::HeaderMap,
79    ) -> Result<Vec<(String, String)>> {
80        let mut pairs: Vec<(String, String)> = Vec::with_capacity(request_headers.len() + 1);
81        for (name, value) in request_headers {
82            if let Ok(v) = value.to_str() {
83                pairs.push((name.as_str().to_owned(), v.to_owned()));
84            }
85        }
86        match &self.auth {
87            ResolvedAuth::ApiKey(token) => {
88                pairs.push(("api-key".to_owned(), token.expose_secret().to_owned()));
89            }
90            ResolvedAuth::Entra(refreshable) => {
91                let token = refreshable.current().await.map_err(Error::from)?;
92                pairs.push((
93                    "authorization".to_owned(),
94                    format!("Bearer {}", token.expose_secret()),
95                ));
96            }
97        }
98        Ok(pairs)
99    }
100
101    fn maybe_invalidate_on_unauthorized(&self, status: u16) {
102        if status == 401
103            && let ResolvedAuth::Entra(token) = &self.auth
104        {
105            token.invalidate();
106        }
107    }
108
109    fn apply_pairs(
110        req: reqwest::RequestBuilder,
111        pairs: &[(String, String)],
112    ) -> reqwest::RequestBuilder {
113        let mut out = req;
114        for (name, value) in pairs {
115            out = out.header(name.as_str(), value.as_str());
116        }
117        out
118    }
119}
120
121#[async_trait::async_trait]
122impl Transport for FoundryTransport {
123    fn name(&self) -> &'static str {
124        "foundry"
125    }
126
127    async fn send(
128        &self,
129        request: EncodedRequest,
130        ctx: &ExecutionContext,
131    ) -> Result<TransportResponse> {
132        if ctx.is_cancelled() {
133            return Err(Error::Cancelled);
134        }
135        let url = format!("{}{}", self.base_url, request.path);
136        // Header build can stall on Entra token refresh — race the
137        // caller's cancellation token so a cancel surfaces within
138        // one HTTP round-trip instead of waiting for the full
139        // OAuth refresh to complete.
140        let pairs = tokio::select! {
141            biased;
142            () = ctx.cancellation().cancelled() => return Err(Error::Cancelled),
143            p = self.build_headers(&request.headers) => p?,
144        };
145        let body_bytes = Bytes::clone(&request.body);
146        let mut http_req = self.client.request(request.method.clone(), &url);
147        http_req = Self::apply_pairs(http_req, &pairs).body(body_bytes);
148        let response = tokio::select! {
149            biased;
150            () = ctx.cancellation().cancelled() => return Err(Error::Cancelled),
151            r = http_req.send() => r,
152        }
153        .map_err(Error::provider_network_from)?;
154        let status = response.status().as_u16();
155        let headers = response.headers().clone();
156        let body = response
157            .bytes()
158            .await
159            .map_err(|e| Error::provider_http(status, format!("response body read failed: {e}")))?;
160        self.maybe_invalidate_on_unauthorized(status);
161        Ok(TransportResponse {
162            status,
163            headers,
164            body,
165        })
166    }
167
168    #[allow(tail_expr_drop_order)]
169    async fn send_streaming(
170        &self,
171        request: EncodedRequest,
172        ctx: &ExecutionContext,
173    ) -> Result<TransportStream> {
174        if ctx.is_cancelled() {
175            return Err(Error::Cancelled);
176        }
177        let url = format!("{}{}", self.base_url, request.path);
178        // Header build can stall on Entra token refresh — race the
179        // caller's cancellation token so a cancel surfaces within
180        // one HTTP round-trip instead of waiting for the full
181        // OAuth refresh to complete.
182        let pairs = tokio::select! {
183            biased;
184            () = ctx.cancellation().cancelled() => return Err(Error::Cancelled),
185            p = self.build_headers(&request.headers) => p?,
186        };
187        let body_bytes = Bytes::clone(&request.body);
188        let mut http_req = self.client.request(request.method.clone(), &url);
189        http_req = Self::apply_pairs(http_req, &pairs).body(body_bytes);
190        let response = tokio::select! {
191            biased;
192            () = ctx.cancellation().cancelled() => return Err(Error::Cancelled),
193            r = http_req.send() => r,
194        }
195        .map_err(Error::provider_network_from)?;
196        let status = response.status().as_u16();
197        let headers = response.headers().clone();
198        self.maybe_invalidate_on_unauthorized(status);
199        if !(200..300).contains(&status) {
200            let body = response.bytes().await.unwrap_or_else(|_| Bytes::new()); // silent-fallback-ok: error-response body read already failed; empty body preserves status + headers for caller diagnostics
201            let body_stream = futures::stream::once(async move { Ok::<_, Error>(body) });
202            return Ok(TransportStream {
203                status,
204                headers,
205                body: Box::pin(body_stream),
206            });
207        }
208        let cancellation = ctx.cancellation().clone();
209        let raw_stream = response.bytes_stream();
210        let body = async_stream::stream! {
211            let mut s = raw_stream;
212            loop {
213                tokio::select! {
214                    biased;
215                    () = cancellation.cancelled() => {
216                        yield Err(Error::Cancelled);
217                        return;
218                    }
219                    item = s.next() => match item {
220                        Some(Ok(b)) => yield Ok(b),
221                        Some(Err(e)) => {
222                            yield Err(Error::provider_http(status, format!("stream chunk read failed: {e}")));
223                            return;
224                        }
225                        None => return,
226                    }
227                }
228            }
229        };
230        Ok(TransportStream {
231            status,
232            headers,
233            body: Box::pin(body),
234        })
235    }
236}
237
238/// Fluent builder for [`FoundryTransport`].
239#[must_use]
240pub struct FoundryTransportBuilder {
241    base_url: Option<String>,
242    auth: Option<FoundryAuth>,
243}
244
245impl FoundryTransportBuilder {
246    /// Foundry endpoint base URL — typically
247    /// `https://{resource}.services.ai.azure.com/anthropic` or
248    /// `https://{resource}.openai.azure.com`. Required.
249    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
250        self.base_url = Some(url.into());
251        self
252    }
253
254    /// Pick the auth strategy.
255    pub fn with_auth(mut self, auth: FoundryAuth) -> Self {
256        self.auth = Some(auth);
257        self
258    }
259
260    /// Resolve and return the transport.
261    pub fn build(self) -> Result<FoundryTransport> {
262        let base_url = self
263            .base_url
264            .ok_or_else(|| Error::config("FoundryTransport: base_url is required"))?;
265        let auth = self
266            .auth
267            .ok_or_else(|| Error::config("FoundryTransport: auth is required"))?;
268        let resolved = match auth {
269            FoundryAuth::ApiKey { token } => ResolvedAuth::ApiKey(token),
270            FoundryAuth::Entra { refresher } => {
271                ResolvedAuth::Entra(Arc::new(CachedTokenProvider::new(refresher)))
272            }
273        };
274        let client = reqwest::Client::builder()
275            .build()
276            .map_err(|e| Error::config(format!("failed to build HTTP client: {e}")))?;
277        Ok(FoundryTransport {
278            client,
279            base_url,
280            auth: resolved,
281        })
282    }
283}
284
285const _: fn() = || {
286    let _ = std::marker::PhantomData::<CloudError>;
287};