1use crate::adapters::base::ProviderAdapter;
7use crate::adapters::detect_provider_from_key;
8use crate::client::HttpClient;
9use crate::models::{HttpError, HttpResult};
10use crate::streaming::{StreamCallback, StreamEvent};
11use tokio_util::sync::CancellationToken;
12use tracing::{debug, warn};
13
/// HTTP client wrapper that can translate requests and responses through a
/// provider-specific [`ProviderAdapter`].
///
/// With no adapter, payloads are sent mostly as-is (see `post_json`); with an
/// adapter, request/response conversion and SSE streaming are delegated to it.
pub struct AdaptedClient {
    /// Underlying transport used for every request.
    client: HttpClient,
    /// Optional provider adapter; `None` means no request/response conversion.
    adapter: Option<Box<dyn ProviderAdapter>>,
}
25
26impl AdaptedClient {
27 pub fn new(client: HttpClient) -> Self {
29 Self {
30 client,
31 adapter: None,
32 }
33 }
34
35 pub fn with_adapter(client: HttpClient, adapter: Box<dyn ProviderAdapter>) -> Self {
37 Self {
38 client,
39 adapter: Some(adapter),
40 }
41 }
42
43 pub fn adapter_for_provider(provider: &str) -> Option<Box<dyn ProviderAdapter>> {
53 match provider {
54 "anthropic" => Some(Box::new(crate::adapters::anthropic::AnthropicAdapter::new())),
55 "openai" => Some(Box::new(crate::adapters::openai::OpenAiAdapter::new())),
56 "gemini" | "google" => {
57 Some(Box::new(crate::adapters::gemini::GeminiAdapter::default()))
58 }
59 _ => None,
60 }
61 }
62
63 pub fn resolve_provider(provider: &str, api_key: &str) -> String {
69 if !provider.is_empty() {
70 return provider.to_string();
71 }
72 detect_provider_from_key(api_key)
73 .unwrap_or("openai")
74 .to_string()
75 }
76
77 pub async fn post_json(
79 &self,
80 payload: &serde_json::Value,
81 cancel: Option<&CancellationToken>,
82 ) -> Result<HttpResult, HttpError> {
83 let converted;
86 let effective_payload = match &self.adapter {
87 Some(adapter) => {
88 converted = adapter.convert_request(payload.clone());
89 &converted
90 }
91 None => {
92 if payload.get("_reasoning_effort").is_some() {
95 let mut cleaned = payload.clone();
96 cleaned.as_object_mut().unwrap().remove("_reasoning_effort");
97 converted = cleaned;
98 &converted
99 } else {
100 payload
101 }
102 }
103 };
104
105 let mut result = self.client.post_json(effective_payload, cancel).await?;
106
107 if let (Some(adapter), Some(body)) = (&self.adapter, &result.body)
109 && result.success
110 {
111 result.body = Some(adapter.convert_response(body.clone()));
112 }
113
114 Ok(result)
115 }
116
117 pub fn supports_streaming(&self) -> bool {
119 self.adapter
120 .as_ref()
121 .map(|a| a.supports_streaming())
122 .unwrap_or(false)
123 }
124
125 pub async fn post_json_streaming(
130 &self,
131 payload: &serde_json::Value,
132 cancel: Option<&CancellationToken>,
133 callback: &dyn StreamCallback,
134 ) -> Result<HttpResult, HttpError> {
135 let adapter = match &self.adapter {
136 Some(a) if a.supports_streaming() => a,
137 _ => {
138 return self.post_json(payload, cancel).await;
139 }
140 };
141
142 let mut converted = adapter.convert_request(payload.clone());
144 adapter.enable_streaming(&mut converted);
145
146 let base_url = self.client.api_url();
148 let streaming_url_owned = adapter.streaming_url(base_url);
149 let url = streaming_url_owned.as_deref().unwrap_or(base_url);
150
151 debug!(url = %url, "Sending streaming request");
156 let response = match self
157 .client
158 .send_streaming_request(url, &converted, cancel)
159 .await
160 {
161 Ok(resp) => resp,
162 Err(HttpError::Interrupted) => return Ok(HttpResult::interrupted()),
163 Err(e) => {
164 warn!(error = %e, "Streaming request failed after retries, soft-failing");
165 return Ok(HttpResult::fail(e.to_string(), true));
166 }
167 };
168
169 let content_type = response
170 .headers()
171 .get("content-type")
172 .and_then(|v| v.to_str().ok())
173 .unwrap_or("")
174 .to_string();
175 debug!(content_type = %content_type, status = %response.status(), "Streaming response headers received");
176 if !content_type.contains("text/event-stream") {
178 warn!(content_type = %content_type, "Streaming fallback: response is not SSE, reading as JSON");
179 let body = response
180 .json::<serde_json::Value>()
181 .await
182 .map_err(|e| HttpError::Other(format!("Failed to parse response: {e}")))?;
183
184 if let Some(error_obj) = body.get("error") {
186 let msg = error_obj
187 .get("message")
188 .and_then(|m| m.as_str())
189 .unwrap_or("Unknown API error");
190 return Err(HttpError::Other(format!("API error: {msg}")));
191 }
192
193 let converted_body = adapter.convert_response(body);
194 return Ok(HttpResult::ok(200, converted_body));
195 }
196
197 let mut final_body: Option<serde_json::Value> = None;
199 let mut accumulated_text = String::new();
200 let mut accumulated_reasoning = String::new();
201 let mut usage_data: Option<serde_json::Value> = None;
202 let mut tool_calls: Vec<serde_json::Value> = Vec::new();
203 let mut current_tool_args: std::collections::HashMap<usize, String> =
204 std::collections::HashMap::new();
205 let mut tool_call_index: std::collections::HashMap<usize, usize> =
207 std::collections::HashMap::new();
208 let mut stop_reason: Option<String> = None;
209 let mut line_buf = String::new();
210 let mut event_type: Option<String> = None;
211
212 use futures::StreamExt;
213 let mut byte_stream = response.bytes_stream();
214
215 let mut buf = Vec::new();
217
218 let mut stream_done = false;
219 let mut stream_end_reason: Option<&str> = None;
220 let stream_start = std::time::Instant::now();
221 const MAX_STREAM_DURATION: std::time::Duration = std::time::Duration::from_secs(300);
224
225 loop {
226 if stream_start.elapsed() > MAX_STREAM_DURATION {
228 warn!(
229 elapsed_secs = stream_start.elapsed().as_secs(),
230 "SSE stream total duration exceeded 300s, forcing termination"
231 );
232 stream_end_reason = Some("stream duration exceeded 5 minutes");
233 break;
234 }
235
236 let chunk_result =
237 match tokio::time::timeout(std::time::Duration::from_secs(120), byte_stream.next())
238 .await
239 {
240 Ok(Some(result)) => result,
241 Ok(None) => {
242 stream_end_reason = Some("connection closed by server");
243 break;
244 }
245 Err(_elapsed) => {
246 warn!("SSE stream idle timeout (120s with no data)");
247 stream_end_reason = Some("idle timeout (120s with no data)");
248 break;
249 }
250 };
251
252 if let Some(token) = cancel
254 && token.is_cancelled()
255 {
256 return Ok(HttpResult::interrupted());
257 }
258
259 let chunk = match chunk_result {
260 Ok(c) => c,
261 Err(e) => {
262 warn!(error = %e, "SSE stream error");
263 callback.on_event(&StreamEvent::Error(e.to_string()));
264 stream_end_reason = Some("network error during stream");
265 break;
266 }
267 };
268
269 buf.extend_from_slice(&chunk);
270
271 while let Some(newline_pos) = buf.iter().position(|&b| b == b'\n') {
273 let line_bytes = buf.drain(..=newline_pos).collect::<Vec<u8>>();
274 let line = String::from_utf8_lossy(&line_bytes).trim().to_string();
275
276 if line.is_empty() {
277 if !line_buf.is_empty() && line_buf.trim() == "data: [DONE]" {
279 stream_done = true;
280 line_buf.clear();
281 event_type = None;
282 continue;
283 }
284 if !line_buf.is_empty()
285 && let Some(data_json) = crate::streaming::parse_sse_data(&line_buf)
286 {
287 let et = event_type.as_deref().unwrap_or_else(|| {
291 data_json.get("type").and_then(|t| t.as_str()).unwrap_or("")
292 });
293 if let Some(stream_event) = adapter.parse_stream_event(et, &data_json) {
294 debug!(event_type = %et, "Stream event received");
295 match &stream_event {
296 StreamEvent::Done(body) => {
297 final_body = Some(body.clone());
298 stream_done = true;
299 }
300 StreamEvent::TextDelta(text) => {
301 accumulated_text.push_str(text);
302 }
303 StreamEvent::ReasoningBlockStart => {
304 if !accumulated_reasoning.is_empty() {
305 accumulated_reasoning.push_str("\n\n");
306 }
307 }
308 StreamEvent::ReasoningDelta(text) => {
309 accumulated_reasoning.push_str(text);
310 }
311 StreamEvent::FunctionCallStart {
312 index,
313 call_id,
314 name,
315 } => {
316 let tc_idx = tool_calls.len();
317 tool_calls.push(serde_json::json!({
318 "id": call_id,
319 "type": "function",
320 "function": {
321 "name": name,
322 "arguments": "",
323 }
324 }));
325 tool_call_index.insert(*index, tc_idx);
326 current_tool_args.insert(tc_idx, String::new());
327 }
328 StreamEvent::FunctionCallDelta { index, delta } => {
329 if let Some(&tc_idx) = tool_call_index.get(index) {
330 current_tool_args
331 .entry(tc_idx)
332 .or_default()
333 .push_str(delta);
334 }
335 }
336 StreamEvent::FunctionCallDone { index, arguments } => {
337 if let Some(&tc_idx) = tool_call_index.get(index) {
338 current_tool_args.insert(tc_idx, arguments.clone());
339 }
340 }
341 StreamEvent::UsageUpdate {
342 usage,
343 stop_reason: sr,
344 } => {
345 if let Some(u) = usage {
346 usage_data = Some(u.clone());
347 }
348 if let Some(r) = sr {
349 stop_reason = Some(r.clone());
350 }
351 }
352 StreamEvent::Error(_) => {}
353 }
354 callback.on_event(&stream_event);
355 } else {
356 debug!(event_type = %et, "Unhandled stream event type");
357 }
358 }
359 line_buf.clear();
360 event_type = None;
361 continue;
362 }
363
364 if let Some(et) = line.strip_prefix("event: ") {
365 event_type = Some(et.to_string());
366 } else if line.starts_with("data: ") {
367 if !line_buf.is_empty() {
369 if line_buf.trim() == "data: [DONE]" {
370 stream_done = true;
371 } else if let Some(data_json) = crate::streaming::parse_sse_data(&line_buf)
372 {
373 let et = event_type.as_deref().unwrap_or_else(|| {
374 data_json.get("type").and_then(|t| t.as_str()).unwrap_or("")
375 });
376 if let Some(stream_event) = adapter.parse_stream_event(et, &data_json) {
377 if let StreamEvent::Done(ref body) = stream_event {
378 final_body = Some(body.clone());
379 stream_done = true;
380 }
381 callback.on_event(&stream_event);
382 }
383 }
384 event_type = None;
385 }
386 line_buf = line;
387 }
388 }
390
391 if !stream_done && !line_buf.is_empty() {
394 if line_buf.trim() == "data: [DONE]" {
395 stream_done = true;
396 } else if let Some(data_json) = crate::streaming::parse_sse_data(&line_buf) {
397 let et = event_type.as_deref().unwrap_or_else(|| {
398 data_json.get("type").and_then(|t| t.as_str()).unwrap_or("")
399 });
400 if let Some(stream_event) = adapter.parse_stream_event(et, &data_json) {
401 if let StreamEvent::Done(ref body) = stream_event {
402 final_body = Some(body.clone());
403 stream_done = true;
404 }
405 callback.on_event(&stream_event);
406 }
407 }
408 if stream_done {
409 line_buf.clear();
410 event_type = None;
411 }
412 }
413
414 if stream_done {
415 break;
416 }
417 }
418
419 if !line_buf.is_empty()
421 && let Some(data_json) = crate::streaming::parse_sse_data(&line_buf)
422 {
423 let et = event_type
424 .as_deref()
425 .unwrap_or_else(|| data_json.get("type").and_then(|t| t.as_str()).unwrap_or(""));
426 if let Some(stream_event) = adapter.parse_stream_event(et, &data_json) {
427 if let StreamEvent::Done(ref body) = stream_event {
428 final_body = Some(body.clone());
429 }
430 callback.on_event(&stream_event);
431 }
432 }
433
434 match final_body {
436 Some(body) => {
437 let converted = adapter.convert_response(body);
438 debug!("Streaming complete, final response converted");
439 Ok(HttpResult::ok(200, converted))
440 }
441 None if !accumulated_text.is_empty()
442 || !accumulated_reasoning.is_empty()
443 || !tool_calls.is_empty() =>
444 {
445 let mut message = serde_json::json!({
449 "role": "assistant",
450 "content": if accumulated_text.is_empty() {
451 serde_json::Value::Null
452 } else {
453 serde_json::Value::String(accumulated_text)
454 },
455 });
456 if !accumulated_reasoning.is_empty() {
457 message["reasoning_content"] = serde_json::Value::String(accumulated_reasoning);
458 }
459 if !tool_calls.is_empty() {
461 let mut finalized = tool_calls;
462 for (idx, args) in ¤t_tool_args {
463 if let Some(tc) = finalized.get_mut(*idx)
464 && let Some(func) = tc.get_mut("function")
465 {
466 func["arguments"] = serde_json::Value::String(args.clone());
467 }
468 }
469 message["tool_calls"] = serde_json::Value::Array(finalized);
470 }
471 let finish = match stop_reason.as_deref() {
473 Some("end_turn") => "stop",
474 Some("max_tokens") => "length",
475 Some("tool_use") => "tool_calls",
476 Some(other) => other,
477 None => {
478 if message.get("tool_calls").is_some() {
479 "tool_calls"
480 } else {
481 "stop"
482 }
483 }
484 };
485 let response = serde_json::json!({
486 "id": "stream-accumulated",
487 "object": "chat.completion",
488 "model": "",
489 "choices": [{"index": 0, "message": message, "finish_reason": finish}],
490 "usage": usage_data.unwrap_or(serde_json::json!({})),
491 });
492 debug!("Streaming complete, built response from accumulated deltas");
493 Ok(HttpResult::ok(200, response))
494 }
495 None => {
496 let reason = stream_end_reason.unwrap_or("unknown");
497 warn!(reason = %reason, "Stream ended with no content");
498 Ok(HttpResult::fail(
499 format!("No response received from stream ({reason})"),
500 true,
501 ))
502 }
503 }
504 }
505
506 pub fn api_url(&self) -> &str {
508 self.client.api_url()
509 }
510}
511
512impl std::fmt::Debug for AdaptedClient {
513 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
514 f.debug_struct("AdaptedClient")
515 .field("api_url", &self.client.api_url())
516 .field(
517 "adapter",
518 &self
519 .adapter
520 .as_ref()
521 .map(|a| a.provider_name())
522 .unwrap_or("none"),
523 )
524 .finish()
525 }
526}
527
// Unit tests live in `adapted_client_tests.rs` (wired in via `#[path]`) and are
// compiled only for test builds.
#[cfg(test)]
#[path = "adapted_client_tests.rs"]
mod tests;