Skip to main content

opendev_http/
adapted_client.rs

1//! Adapted HTTP client that wraps HttpClient + ProviderAdapter.
2//!
3//! Transparently converts requests/responses through the provider adapter
4//! so the rest of the codebase can use a uniform Chat Completions format.
5
6use crate::adapters::base::ProviderAdapter;
7use crate::adapters::detect_provider_from_key;
8use crate::client::HttpClient;
9use crate::models::{HttpError, HttpResult};
10use crate::streaming::{StreamCallback, StreamEvent};
11use tokio_util::sync::CancellationToken;
12use tracing::{debug, warn};
13
/// HTTP client with provider-specific request/response adaptation.
///
/// Wraps `HttpClient` and an optional `ProviderAdapter`. When an adapter
/// is present, `post_json` will:
/// 1. Convert the payload via `adapter.convert_request()`
/// 2. Send via `HttpClient::post_json()`
/// 3. Convert the response body via `adapter.convert_response()`, for
///    successful responses only
///
/// Without an adapter the payload is sent as given, except that the internal
/// `_reasoning_effort` field is stripped, and responses are returned unconverted.
pub struct AdaptedClient {
    // Underlying transport used for every request.
    client: HttpClient,
    // Provider adapter; `None` means Chat Completions passthrough.
    adapter: Option<Box<dyn ProviderAdapter>>,
}
25
26impl AdaptedClient {
27    /// Create an adapted client without any adapter (passthrough).
28    pub fn new(client: HttpClient) -> Self {
29        Self {
30            client,
31            adapter: None,
32        }
33    }
34
35    /// Create an adapted client with a provider adapter.
36    pub fn with_adapter(client: HttpClient, adapter: Box<dyn ProviderAdapter>) -> Self {
37        Self {
38            client,
39            adapter: Some(adapter),
40        }
41    }
42
43    /// Create an adapter for a specific provider name.
44    ///
45    /// Recognized providers:
46    /// - `"anthropic"` → [`AnthropicAdapter`](crate::adapters::anthropic::AnthropicAdapter)
47    /// - `"openai"` → [`OpenAiAdapter`](crate::adapters::openai::OpenAiAdapter)
48    /// - `"gemini"` | `"google"` → [`GeminiAdapter`](crate::adapters::gemini::GeminiAdapter)
49    ///
50    /// Returns `None` for providers that use the Chat Completions format natively
51    /// (groq, fireworks, mistral, etc.).
52    pub fn adapter_for_provider(provider: &str) -> Option<Box<dyn ProviderAdapter>> {
53        match provider {
54            "anthropic" => Some(Box::new(crate::adapters::anthropic::AnthropicAdapter::new())),
55            "openai" => Some(Box::new(crate::adapters::openai::OpenAiAdapter::new())),
56            "gemini" | "google" => {
57                Some(Box::new(crate::adapters::gemini::GeminiAdapter::default()))
58            }
59            _ => None,
60        }
61    }
62
63    /// Resolve the provider name, falling back to auto-detection from the API key.
64    ///
65    /// If `provider` is non-empty, returns it as-is. Otherwise, inspects the
66    /// API key prefix via [`detect_provider_from_key`] and returns the detected
67    /// provider or `"openai"` as the final fallback.
68    pub fn resolve_provider(provider: &str, api_key: &str) -> String {
69        if !provider.is_empty() {
70            return provider.to_string();
71        }
72        detect_provider_from_key(api_key)
73            .unwrap_or("openai")
74            .to_string()
75    }
76
77    /// POST JSON with optional request/response conversion.
78    pub async fn post_json(
79        &self,
80        payload: &serde_json::Value,
81        cancel: Option<&CancellationToken>,
82    ) -> Result<HttpResult, HttpError> {
83        // Only clone the payload when an adapter needs to transform it.
84        // For the passthrough (None) case, use the original reference directly.
85        let converted;
86        let effective_payload = match &self.adapter {
87            Some(adapter) => {
88                converted = adapter.convert_request(payload.clone());
89                &converted
90            }
91            None => {
92                // Strip internal `_reasoning_effort` field for passthrough providers
93                // that don't have an adapter to consume it.
94                if payload.get("_reasoning_effort").is_some() {
95                    let mut cleaned = payload.clone();
96                    cleaned.as_object_mut().unwrap().remove("_reasoning_effort");
97                    converted = cleaned;
98                    &converted
99                } else {
100                    payload
101                }
102            }
103        };
104
105        let mut result = self.client.post_json(effective_payload, cancel).await?;
106
107        // Convert response body back to Chat Completions format
108        if let (Some(adapter), Some(body)) = (&self.adapter, &result.body)
109            && result.success
110        {
111            result.body = Some(adapter.convert_response(body.clone()));
112        }
113
114        Ok(result)
115    }
116
117    /// Whether streaming is supported for this client's adapter.
118    pub fn supports_streaming(&self) -> bool {
119        self.adapter
120            .as_ref()
121            .map(|a| a.supports_streaming())
122            .unwrap_or(false)
123    }
124
125    /// POST JSON with SSE streaming, calling the callback for each event.
126    ///
127    /// Falls back to `post_json` if the adapter doesn't support streaming.
128    /// Returns the final accumulated response as an `HttpResult`.
129    pub async fn post_json_streaming(
130        &self,
131        payload: &serde_json::Value,
132        cancel: Option<&CancellationToken>,
133        callback: &dyn StreamCallback,
134    ) -> Result<HttpResult, HttpError> {
135        let adapter = match &self.adapter {
136            Some(a) if a.supports_streaming() => a,
137            _ => {
138                return self.post_json(payload, cancel).await;
139            }
140        };
141
142        // Convert request and add streaming flag
143        let mut converted = adapter.convert_request(payload.clone());
144        adapter.enable_streaming(&mut converted);
145
146        // Use streaming URL if the adapter provides one, otherwise fall back to client URL
147        let base_url = self.client.api_url();
148        let streaming_url_owned = adapter.streaming_url(base_url);
149        let url = streaming_url_owned.as_deref().unwrap_or(base_url);
150
151        // Send request and get raw response for streaming.
152        // On failure (after internal retries are exhausted), soft-fail to an
153        // HttpResult so the react loop can retry on the next iteration, matching
154        // the non-streaming post_json behavior.
155        debug!(url = %url, "Sending streaming request");
156        let response = match self
157            .client
158            .send_streaming_request(url, &converted, cancel)
159            .await
160        {
161            Ok(resp) => resp,
162            Err(HttpError::Interrupted) => return Ok(HttpResult::interrupted()),
163            Err(e) => {
164                warn!(error = %e, "Streaming request failed after retries, soft-failing");
165                return Ok(HttpResult::fail(e.to_string(), true));
166            }
167        };
168
169        let content_type = response
170            .headers()
171            .get("content-type")
172            .and_then(|v| v.to_str().ok())
173            .unwrap_or("")
174            .to_string();
175        debug!(content_type = %content_type, status = %response.status(), "Streaming response headers received");
176        // If the response isn't SSE, fall back to reading as JSON
177        if !content_type.contains("text/event-stream") {
178            warn!(content_type = %content_type, "Streaming fallback: response is not SSE, reading as JSON");
179            let body = response
180                .json::<serde_json::Value>()
181                .await
182                .map_err(|e| HttpError::Other(format!("Failed to parse response: {e}")))?;
183
184            // Check for API error
185            if let Some(error_obj) = body.get("error") {
186                let msg = error_obj
187                    .get("message")
188                    .and_then(|m| m.as_str())
189                    .unwrap_or("Unknown API error");
190                return Err(HttpError::Other(format!("API error: {msg}")));
191            }
192
193            let converted_body = adapter.convert_response(body);
194            return Ok(HttpResult::ok(200, converted_body));
195        }
196
197        // Read SSE events from the response body
198        let mut final_body: Option<serde_json::Value> = None;
199        let mut accumulated_text = String::new();
200        let mut accumulated_reasoning = String::new();
201        let mut usage_data: Option<serde_json::Value> = None;
202        let mut tool_calls: Vec<serde_json::Value> = Vec::new();
203        let mut current_tool_args: std::collections::HashMap<usize, String> =
204            std::collections::HashMap::new();
205        // OpenAI Responses API: map output_index → tool_call vec index
206        let mut tool_call_index: std::collections::HashMap<usize, usize> =
207            std::collections::HashMap::new();
208        let mut stop_reason: Option<String> = None;
209        let mut line_buf = String::new();
210        let mut event_type: Option<String> = None;
211
212        use futures::StreamExt;
213        let mut byte_stream = response.bytes_stream();
214
215        // Buffer for incomplete UTF-8 or line fragments
216        let mut buf = Vec::new();
217
218        let mut stream_done = false;
219        let mut stream_end_reason: Option<&str> = None;
220        let stream_start = std::time::Instant::now();
221        // Maximum total stream duration (5 minutes). Prevents indefinite hangs
222        // when the API sends heartbeat events but never completes.
223        const MAX_STREAM_DURATION: std::time::Duration = std::time::Duration::from_secs(300);
224
225        loop {
226            // Check total stream duration
227            if stream_start.elapsed() > MAX_STREAM_DURATION {
228                warn!(
229                    elapsed_secs = stream_start.elapsed().as_secs(),
230                    "SSE stream total duration exceeded 300s, forcing termination"
231                );
232                stream_end_reason = Some("stream duration exceeded 5 minutes");
233                break;
234            }
235
236            let chunk_result =
237                match tokio::time::timeout(std::time::Duration::from_secs(120), byte_stream.next())
238                    .await
239                {
240                    Ok(Some(result)) => result,
241                    Ok(None) => {
242                        stream_end_reason = Some("connection closed by server");
243                        break;
244                    }
245                    Err(_elapsed) => {
246                        warn!("SSE stream idle timeout (120s with no data)");
247                        stream_end_reason = Some("idle timeout (120s with no data)");
248                        break;
249                    }
250                };
251
252            // Check cancellation
253            if let Some(token) = cancel
254                && token.is_cancelled()
255            {
256                return Ok(HttpResult::interrupted());
257            }
258
259            let chunk = match chunk_result {
260                Ok(c) => c,
261                Err(e) => {
262                    warn!(error = %e, "SSE stream error");
263                    callback.on_event(&StreamEvent::Error(e.to_string()));
264                    stream_end_reason = Some("network error during stream");
265                    break;
266                }
267            };
268
269            buf.extend_from_slice(&chunk);
270
271            // Process complete lines from the buffer
272            while let Some(newline_pos) = buf.iter().position(|&b| b == b'\n') {
273                let line_bytes = buf.drain(..=newline_pos).collect::<Vec<u8>>();
274                let line = String::from_utf8_lossy(&line_bytes).trim().to_string();
275
276                if line.is_empty() {
277                    // Empty line = end of SSE event block
278                    if !line_buf.is_empty() && line_buf.trim() == "data: [DONE]" {
279                        stream_done = true;
280                        line_buf.clear();
281                        event_type = None;
282                        continue;
283                    }
284                    if !line_buf.is_empty()
285                        && let Some(data_json) = crate::streaming::parse_sse_data(&line_buf)
286                    {
287                        // Get event type from SSE `event:` line or from JSON `type` field.
288                        // OpenAI Responses API sends only `data:` lines with a `type` field
289                        // in the JSON payload (no `event:` lines).
290                        let et = event_type.as_deref().unwrap_or_else(|| {
291                            data_json.get("type").and_then(|t| t.as_str()).unwrap_or("")
292                        });
293                        if let Some(stream_event) = adapter.parse_stream_event(et, &data_json) {
294                            debug!(event_type = %et, "Stream event received");
295                            match &stream_event {
296                                StreamEvent::Done(body) => {
297                                    final_body = Some(body.clone());
298                                    stream_done = true;
299                                }
300                                StreamEvent::TextDelta(text) => {
301                                    accumulated_text.push_str(text);
302                                }
303                                StreamEvent::ReasoningBlockStart => {
304                                    if !accumulated_reasoning.is_empty() {
305                                        accumulated_reasoning.push_str("\n\n");
306                                    }
307                                }
308                                StreamEvent::ReasoningDelta(text) => {
309                                    accumulated_reasoning.push_str(text);
310                                }
311                                StreamEvent::FunctionCallStart {
312                                    index,
313                                    call_id,
314                                    name,
315                                } => {
316                                    let tc_idx = tool_calls.len();
317                                    tool_calls.push(serde_json::json!({
318                                        "id": call_id,
319                                        "type": "function",
320                                        "function": {
321                                            "name": name,
322                                            "arguments": "",
323                                        }
324                                    }));
325                                    tool_call_index.insert(*index, tc_idx);
326                                    current_tool_args.insert(tc_idx, String::new());
327                                }
328                                StreamEvent::FunctionCallDelta { index, delta } => {
329                                    if let Some(&tc_idx) = tool_call_index.get(index) {
330                                        current_tool_args
331                                            .entry(tc_idx)
332                                            .or_default()
333                                            .push_str(delta);
334                                    }
335                                }
336                                StreamEvent::FunctionCallDone { index, arguments } => {
337                                    if let Some(&tc_idx) = tool_call_index.get(index) {
338                                        current_tool_args.insert(tc_idx, arguments.clone());
339                                    }
340                                }
341                                StreamEvent::UsageUpdate {
342                                    usage,
343                                    stop_reason: sr,
344                                } => {
345                                    if let Some(u) = usage {
346                                        usage_data = Some(u.clone());
347                                    }
348                                    if let Some(r) = sr {
349                                        stop_reason = Some(r.clone());
350                                    }
351                                }
352                                StreamEvent::Error(_) => {}
353                            }
354                            callback.on_event(&stream_event);
355                        } else {
356                            debug!(event_type = %et, "Unhandled stream event type");
357                        }
358                    }
359                    line_buf.clear();
360                    event_type = None;
361                    continue;
362                }
363
364                if let Some(et) = line.strip_prefix("event: ") {
365                    event_type = Some(et.to_string());
366                } else if line.starts_with("data: ") {
367                    // Process any previous pending data line before starting a new one
368                    if !line_buf.is_empty() {
369                        if line_buf.trim() == "data: [DONE]" {
370                            stream_done = true;
371                        } else if let Some(data_json) = crate::streaming::parse_sse_data(&line_buf)
372                        {
373                            let et = event_type.as_deref().unwrap_or_else(|| {
374                                data_json.get("type").and_then(|t| t.as_str()).unwrap_or("")
375                            });
376                            if let Some(stream_event) = adapter.parse_stream_event(et, &data_json) {
377                                if let StreamEvent::Done(ref body) = stream_event {
378                                    final_body = Some(body.clone());
379                                    stream_done = true;
380                                }
381                                callback.on_event(&stream_event);
382                            }
383                        }
384                        event_type = None;
385                    }
386                    line_buf = line;
387                }
388                // Ignore other SSE fields (id:, retry:, comments)
389            }
390
391            // Eagerly process pending line_buf for stream-terminating events
392            // that arrive without a trailing blank line (e.g. last chunk).
393            if !stream_done && !line_buf.is_empty() {
394                if line_buf.trim() == "data: [DONE]" {
395                    stream_done = true;
396                } else if let Some(data_json) = crate::streaming::parse_sse_data(&line_buf) {
397                    let et = event_type.as_deref().unwrap_or_else(|| {
398                        data_json.get("type").and_then(|t| t.as_str()).unwrap_or("")
399                    });
400                    if let Some(stream_event) = adapter.parse_stream_event(et, &data_json) {
401                        if let StreamEvent::Done(ref body) = stream_event {
402                            final_body = Some(body.clone());
403                            stream_done = true;
404                        }
405                        callback.on_event(&stream_event);
406                    }
407                }
408                if stream_done {
409                    line_buf.clear();
410                    event_type = None;
411                }
412            }
413
414            if stream_done {
415                break;
416            }
417        }
418
419        // Process any remaining data in buffer
420        if !line_buf.is_empty()
421            && let Some(data_json) = crate::streaming::parse_sse_data(&line_buf)
422        {
423            let et = event_type
424                .as_deref()
425                .unwrap_or_else(|| data_json.get("type").and_then(|t| t.as_str()).unwrap_or(""));
426            if let Some(stream_event) = adapter.parse_stream_event(et, &data_json) {
427                if let StreamEvent::Done(ref body) = stream_event {
428                    final_body = Some(body.clone());
429                }
430                callback.on_event(&stream_event);
431            }
432        }
433
434        // Convert the final accumulated response through the adapter
435        match final_body {
436            Some(body) => {
437                let converted = adapter.convert_response(body);
438                debug!("Streaming complete, final response converted");
439                Ok(HttpResult::ok(200, converted))
440            }
441            None if !accumulated_text.is_empty()
442                || !accumulated_reasoning.is_empty()
443                || !tool_calls.is_empty() =>
444            {
445                // Build synthetic Chat Completions response from accumulated deltas.
446                // This handles providers like Anthropic that don't send a single
447                // "done" event with the full response.
448                let mut message = serde_json::json!({
449                    "role": "assistant",
450                    "content": if accumulated_text.is_empty() {
451                        serde_json::Value::Null
452                    } else {
453                        serde_json::Value::String(accumulated_text)
454                    },
455                });
456                if !accumulated_reasoning.is_empty() {
457                    message["reasoning_content"] = serde_json::Value::String(accumulated_reasoning);
458                }
459                // Finalize tool call arguments
460                if !tool_calls.is_empty() {
461                    let mut finalized = tool_calls;
462                    for (idx, args) in &current_tool_args {
463                        if let Some(tc) = finalized.get_mut(*idx)
464                            && let Some(func) = tc.get_mut("function")
465                        {
466                            func["arguments"] = serde_json::Value::String(args.clone());
467                        }
468                    }
469                    message["tool_calls"] = serde_json::Value::Array(finalized);
470                }
471                // Normalize provider-specific stop reasons to Chat Completions values
472                let finish = match stop_reason.as_deref() {
473                    Some("end_turn") => "stop",
474                    Some("max_tokens") => "length",
475                    Some("tool_use") => "tool_calls",
476                    Some(other) => other,
477                    None => {
478                        if message.get("tool_calls").is_some() {
479                            "tool_calls"
480                        } else {
481                            "stop"
482                        }
483                    }
484                };
485                let response = serde_json::json!({
486                    "id": "stream-accumulated",
487                    "object": "chat.completion",
488                    "model": "",
489                    "choices": [{"index": 0, "message": message, "finish_reason": finish}],
490                    "usage": usage_data.unwrap_or(serde_json::json!({})),
491                });
492                debug!("Streaming complete, built response from accumulated deltas");
493                Ok(HttpResult::ok(200, response))
494            }
495            None => {
496                let reason = stream_end_reason.unwrap_or("unknown");
497                warn!(reason = %reason, "Stream ended with no content");
498                Ok(HttpResult::fail(
499                    format!("No response received from stream ({reason})"),
500                    true,
501                ))
502            }
503        }
504    }
505
506    /// Get the configured API URL.
507    pub fn api_url(&self) -> &str {
508        self.client.api_url()
509    }
510}
511
512impl std::fmt::Debug for AdaptedClient {
513    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
514        f.debug_struct("AdaptedClient")
515            .field("api_url", &self.client.api_url())
516            .field(
517                "adapter",
518                &self
519                    .adapter
520                    .as_ref()
521                    .map(|a| a.provider_name())
522                    .unwrap_or("none"),
523            )
524            .finish()
525    }
526}
527
// Unit tests for this module live in the sibling file `adapted_client_tests.rs`.
#[cfg(test)]
#[path = "adapted_client_tests.rs"]
mod tests;