1use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::structured;
5use super::types::*;
6use super::LlmClient;
7use crate::llm::types::{ToolResultContent, ToolResultContentField};
8use crate::retry::{AttemptOutcome, RetryConfig};
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use futures::StreamExt;
12use serde::Deserialize;
13use std::collections::HashMap;
14use std::sync::Arc;
15use std::time::Instant;
16use tokio::sync::mpsc;
17
/// Chat-completions client for OpenAI and OpenAI-compatible providers.
///
/// Created with [`OpenAiClient::new`] and customised through the `with_*` builder methods.
pub struct OpenAiClient {
    /// Provider label reported in response metadata; `new` sets it to `"openai"`.
    pub(crate) provider_name: String,
    /// API key, sent as a `Bearer` token unless a custom `Authorization` header is configured.
    pub(crate) api_key: SecretString,
    /// Model identifier placed in the request's `model` field.
    pub(crate) model: String,
    /// Base URL that `chat_completions_path` is appended to.
    pub(crate) base_url: String,
    /// Path of the chat-completions endpoint (the builder guarantees a leading `/`).
    pub(crate) chat_completions_path: String,
    /// Extra headers added to every request.
    pub(crate) headers: HashMap<String, String>,
    /// Optional sampling temperature; omitted from the request when `None`.
    pub(crate) temperature: Option<f32>,
    /// Optional `max_tokens` limit; omitted from the request when `None`.
    pub(crate) max_tokens: Option<usize>,
    /// HTTP transport used for both blocking and streaming requests.
    pub(crate) http: Arc<dyn HttpClient>,
    /// Retry policy applied to each request.
    pub(crate) retry_config: RetryConfig,
}
31
32impl OpenAiClient {
33 pub(crate) fn parse_tool_arguments(tool_name: &str, arguments: &str) -> serde_json::Value {
34 if arguments.trim().is_empty() {
35 return serde_json::Value::Object(Default::default());
36 }
37
38 serde_json::from_str(arguments).unwrap_or_else(|e| {
39 tracing::warn!(
40 "Failed to parse tool arguments JSON for tool '{}': {}",
41 tool_name,
42 e
43 );
44 serde_json::json!({
45 "__parse_error": format!(
46 "Malformed tool arguments: {}. Raw input: {}",
47 e, arguments
48 )
49 })
50 })
51 }
52
53 fn merge_stream_text(text_content: &mut String, incoming: &str) -> Option<String> {
54 if incoming.is_empty() {
55 return None;
56 }
57 if text_content.is_empty() {
58 text_content.push_str(incoming);
59 return Some(incoming.to_string());
60 }
61 if incoming == text_content.as_str() || text_content.ends_with(incoming) {
62 return None;
63 }
64 if incoming.starts_with(text_content.as_str()) && incoming.len() > text_content.len() {
67 let suffix = &incoming[text_content.len()..];
68 if !suffix.is_empty() {
69 *text_content = incoming.to_string();
70 return Some(suffix.to_string());
71 }
72 return None;
73 }
74 if let Some(suffix) = incoming.strip_prefix(text_content.as_str()) {
75 if suffix.is_empty() {
76 return None;
77 }
78 text_content.push_str(suffix);
79 return Some(suffix.to_string());
80 }
81 text_content.push_str(incoming);
82 Some(incoming.to_string())
83 }
84
85 pub fn new(api_key: String, model: String) -> Self {
86 Self {
87 provider_name: "openai".to_string(),
88 api_key: SecretString::new(api_key),
89 model,
90 base_url: "https://api.openai.com".to_string(),
91 chat_completions_path: "/v1/chat/completions".to_string(),
92 headers: HashMap::new(),
93 temperature: None,
94 max_tokens: None,
95 http: default_http_client(),
96 retry_config: RetryConfig::default(),
97 }
98 }
99
100 pub fn with_base_url(mut self, base_url: String) -> Self {
101 self.base_url = normalize_base_url(&base_url);
102 self
103 }
104
105 pub fn with_provider_name(mut self, provider_name: impl Into<String>) -> Self {
106 self.provider_name = provider_name.into();
107 self
108 }
109
110 pub fn with_chat_completions_path(mut self, path: impl Into<String>) -> Self {
111 let path = path.into();
112 self.chat_completions_path = if path.starts_with('/') {
113 path
114 } else {
115 format!("/{}", path)
116 };
117 self
118 }
119
120 pub fn with_temperature(mut self, temperature: f32) -> Self {
121 self.temperature = Some(temperature);
122 self
123 }
124
125 pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
126 self.headers = headers;
127 self
128 }
129
130 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
131 self.max_tokens = Some(max_tokens);
132 self
133 }
134
135 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
136 self.retry_config = retry_config;
137 self
138 }
139
140 pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
141 self.http = http;
142 self
143 }
144
145 pub(crate) fn request_headers(&self) -> Vec<(String, String)> {
146 let mut headers = Vec::with_capacity(self.headers.len() + 1);
147 let has_authorization = self
148 .headers
149 .keys()
150 .any(|key| key.eq_ignore_ascii_case("authorization"));
151 if !has_authorization {
152 headers.push((
153 "Authorization".to_string(),
154 format!("Bearer {}", self.api_key.expose()),
155 ));
156 }
157 headers.extend(
158 self.headers
159 .iter()
160 .map(|(key, value)| (key.clone(), value.clone())),
161 );
162 headers
163 }
164
165 pub(crate) fn convert_messages(&self, messages: &[Message]) -> Vec<serde_json::Value> {
166 messages
167 .iter()
168 .map(|msg| {
169 let content: serde_json::Value = if msg.content.len() == 1 {
170 match &msg.content[0] {
171 ContentBlock::Text { text } => serde_json::json!(text),
172 ContentBlock::ToolResult {
173 tool_use_id,
174 content,
175 ..
176 } => {
177 let content_str = match content {
178 ToolResultContentField::Text(s) => s.clone(),
179 ToolResultContentField::Blocks(blocks) => blocks
180 .iter()
181 .filter_map(|b| {
182 if let ToolResultContent::Text { text } = b {
183 Some(text.clone())
184 } else {
185 None
186 }
187 })
188 .collect::<Vec<_>>()
189 .join("\n"),
190 };
191 return serde_json::json!({
192 "role": "tool",
193 "tool_call_id": tool_use_id,
194 "content": content_str,
195 });
196 }
197 _ => serde_json::json!(""),
198 }
199 } else {
200 serde_json::json!(msg
201 .content
202 .iter()
203 .map(|block| {
204 match block {
205 ContentBlock::Text { text } => serde_json::json!({
206 "type": "text",
207 "text": text,
208 }),
209 ContentBlock::Image { source } => serde_json::json!({
210 "type": "image_url",
211 "image_url": {
212 "url": format!(
213 "data:{};base64,{}",
214 source.media_type, source.data
215 ),
216 }
217 }),
218 ContentBlock::ToolUse { id, name, input } => serde_json::json!({
219 "type": "function",
220 "id": id,
221 "function": {
222 "name": name,
223 "arguments": input.to_string(),
224 }
225 }),
226 _ => serde_json::json!({}),
227 }
228 })
229 .collect::<Vec<_>>())
230 };
231
232 if msg.role == "assistant" {
235 let rc = msg.reasoning_content.as_deref().unwrap_or("");
236 let tool_calls: Vec<_> = msg.tool_calls();
237 if !tool_calls.is_empty() {
238 return serde_json::json!({
239 "role": "assistant",
240 "content": msg.text(),
241 "reasoning_content": rc,
242 "tool_calls": tool_calls.iter().map(|tc| {
243 serde_json::json!({
244 "id": tc.id,
245 "type": "function",
246 "function": {
247 "name": tc.name,
248 "arguments": tc.args.to_string(),
249 }
250 })
251 }).collect::<Vec<_>>(),
252 });
253 }
254 return serde_json::json!({
255 "role": "assistant",
256 "content": content,
257 "reasoning_content": rc,
258 });
259 }
260
261 serde_json::json!({
262 "role": msg.role,
263 "content": content,
264 })
265 })
266 .collect()
267 }
268
269 pub(crate) fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<serde_json::Value> {
270 tools
271 .iter()
272 .map(|t| {
273 serde_json::json!({
274 "type": "function",
275 "function": {
276 "name": t.name,
277 "description": t.description,
278 "parameters": t.parameters,
279 }
280 })
281 })
282 .collect()
283 }
284}
285
286impl OpenAiClient {
287 fn apply_directive(
292 request: &mut serde_json::Value,
293 directive: &structured::StructuredDirective,
294 ) {
295 if let Some(tool) = &directive.force_tool {
296 request["tool_choice"] = serde_json::json!({
297 "type": "function",
298 "function": { "name": tool }
299 });
300 }
301 if let Some(rf) = &directive.response_format {
302 request["response_format"] = match rf {
303 structured::ResponseFormat::JsonObject => {
304 serde_json::json!({ "type": "json_object" })
305 }
306 structured::ResponseFormat::JsonSchema { name, schema } => serde_json::json!({
307 "type": "json_schema",
308 "json_schema": { "name": name, "schema": schema, "strict": true }
309 }),
310 };
311 }
312 }
313
314 fn build_chat_request(
316 &self,
317 messages: &[Message],
318 system: Option<&str>,
319 tools: &[ToolDefinition],
320 directive: Option<&structured::StructuredDirective>,
321 ) -> serde_json::Value {
322 let mut openai_messages = Vec::new();
323
324 if let Some(sys) = system {
325 openai_messages.push(serde_json::json!({
326 "role": "system",
327 "content": sys,
328 }));
329 }
330
331 openai_messages.extend(self.convert_messages(messages));
332
333 let mut request = serde_json::json!({
334 "model": self.model,
335 "messages": openai_messages,
336 });
337
338 if let Some(temp) = self.temperature {
339 request["temperature"] = serde_json::json!(temp);
340 }
341 if let Some(max) = self.max_tokens {
342 request["max_tokens"] = serde_json::json!(max);
343 }
344
345 if !tools.is_empty() {
346 request["tools"] = serde_json::json!(self.convert_tools(tools));
347 }
348
349 if let Some(directive) = directive {
350 Self::apply_directive(&mut request, directive);
351 }
352
353 request
354 }
355
356 async fn send_request(&self, request: serde_json::Value) -> Result<LlmResponse> {
358 {
359 let request_started_at = Instant::now();
360 let url = format!("{}{}", self.base_url, self.chat_completions_path);
361 let request_headers = self.request_headers();
362
363 let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
364 let http = &self.http;
365 let url = &url;
366 let request_headers = request_headers.clone();
367 let request = &request;
368 async move {
369 let headers = request_headers
370 .iter()
371 .map(|(key, value)| (key.as_str(), value.as_str()))
372 .collect::<Vec<_>>();
373 let cancel_token = tokio_util::sync::CancellationToken::new();
375 match http.post(url, headers, request, cancel_token).await {
376 Ok(resp) => {
377 let status = reqwest::StatusCode::from_u16(resp.status)
378 .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
379 if status.is_success() {
380 AttemptOutcome::Success(resp.body)
381 } else if self.retry_config.is_retryable_status(status) {
382 AttemptOutcome::Retryable {
383 status,
384 body: resp.body,
385 retry_after: None,
386 }
387 } else {
388 AttemptOutcome::Fatal(anyhow::anyhow!(
389 "OpenAI API error at {} ({}): {}",
390 url,
391 status,
392 resp.body
393 ))
394 }
395 }
396 Err(e) => {
397 tracing::error!("HTTP error: {e:?}");
398 AttemptOutcome::Fatal(e)
399 }
400 }
401 }
402 })
403 .await?;
404
405 let parsed: OpenAiResponse =
406 serde_json::from_str(&response).context("Failed to parse OpenAI response")?;
407
408 let choice = parsed.choices.into_iter().next().context("No choices")?;
409
410 let mut content = vec![];
411
412 let reasoning_content = choice.message.reasoning_content;
413
414 let text_content = choice.message.content;
415
416 if let Some(text) = text_content {
417 if !text.is_empty() {
418 content.push(ContentBlock::Text { text });
419 }
420 }
421
422 if let Some(tool_calls) = choice.message.tool_calls {
423 for tc in tool_calls {
424 content.push(ContentBlock::ToolUse {
425 id: tc.id,
426 name: tc.function.name.clone(),
427 input: Self::parse_tool_arguments(
428 &tc.function.name,
429 &tc.function.arguments,
430 ),
431 });
432 }
433 }
434
435 let llm_response = LlmResponse {
436 message: Message {
437 role: "assistant".to_string(),
438 content,
439 reasoning_content,
440 },
441 usage: TokenUsage {
442 prompt_tokens: parsed.usage.prompt_tokens,
443 completion_tokens: parsed.usage.completion_tokens,
444 total_tokens: {
445 let t = parsed.usage.total_tokens;
446 if t == 0 {
448 parsed.usage.total_characters.unwrap_or(0)
449 } else {
450 t
451 }
452 },
453 cache_read_tokens: parsed
454 .usage
455 .prompt_tokens_details
456 .as_ref()
457 .and_then(|d| d.cached_tokens),
458 cache_write_tokens: None,
459 },
460 stop_reason: choice.finish_reason,
461 meta: Some(LlmResponseMeta {
462 provider: Some(self.provider_name.clone()),
463 request_model: Some(self.model.clone()),
464 request_url: Some(url.clone()),
465 response_id: parsed.id,
466 response_model: parsed.model,
467 response_object: parsed.object,
468 first_token_ms: None,
469 duration_ms: Some(request_started_at.elapsed().as_millis() as u64),
470 }),
471 };
472
473 crate::telemetry::record_llm_usage(
474 llm_response.usage.prompt_tokens,
475 llm_response.usage.completion_tokens,
476 llm_response.usage.total_tokens,
477 llm_response.stop_reason.as_deref(),
478 );
479
480 Ok(llm_response)
481 }
482 }
483}
484
485#[async_trait]
486impl LlmClient for OpenAiClient {
487 async fn complete(
488 &self,
489 messages: &[Message],
490 system: Option<&str>,
491 tools: &[ToolDefinition],
492 ) -> Result<LlmResponse> {
493 self.send_request(self.build_chat_request(messages, system, tools, None))
494 .await
495 }
496
497 async fn complete_structured(
498 &self,
499 messages: &[Message],
500 system: Option<&str>,
501 tools: &[ToolDefinition],
502 directive: &structured::StructuredDirective,
503 ) -> Result<LlmResponse> {
504 self.send_request(self.build_chat_request(messages, system, tools, Some(directive)))
505 .await
506 }
507
508 fn native_structured_support(&self) -> structured::NativeStructuredSupport {
509 structured::NativeStructuredSupport::JsonSchema
510 }
511
512 async fn complete_streaming(
513 &self,
514 messages: &[Message],
515 system: Option<&str>,
516 tools: &[ToolDefinition],
517 cancel_token: tokio_util::sync::CancellationToken,
518 ) -> Result<mpsc::Receiver<StreamEvent>> {
519 self.send_streaming(
520 self.build_chat_request(messages, system, tools, None),
521 cancel_token,
522 )
523 .await
524 }
525
526 async fn complete_streaming_structured(
527 &self,
528 messages: &[Message],
529 system: Option<&str>,
530 tools: &[ToolDefinition],
531 directive: &structured::StructuredDirective,
532 cancel_token: tokio_util::sync::CancellationToken,
533 ) -> Result<mpsc::Receiver<StreamEvent>> {
534 self.send_streaming(
535 self.build_chat_request(messages, system, tools, Some(directive)),
536 cancel_token,
537 )
538 .await
539 }
540}
541
542impl OpenAiClient {
543 async fn send_streaming(
545 &self,
546 mut request: serde_json::Value,
547 cancel_token: tokio_util::sync::CancellationToken,
548 ) -> Result<mpsc::Receiver<StreamEvent>> {
549 {
550 request["stream"] = serde_json::json!(true);
551 request["stream_options"] = serde_json::json!({ "include_usage": true });
552 let request_started_at = Instant::now();
553 let url = format!("{}{}", self.base_url, self.chat_completions_path);
554 let request_headers = self.request_headers();
555
556 let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
557 let http = &self.http;
558 let url = &url;
559 let request_headers = request_headers.clone();
560 let request = &request;
561 let cancel_token = cancel_token.clone();
562 async move {
563 let headers = request_headers
564 .iter()
565 .map(|(key, value)| (key.as_str(), value.as_str()))
566 .collect::<Vec<_>>();
567 let resp = tokio::select! {
569 _ = cancel_token.cancelled() => {
570 return AttemptOutcome::Fatal(anyhow::anyhow!("HTTP request cancelled"));
571 }
572 result = http.post_streaming(url, headers, request, cancel_token.clone()) => {
573 match result {
574 Ok(r) => r,
575 Err(e) => {
576 return AttemptOutcome::Fatal(anyhow::anyhow!("HTTP request failed: {}", e));
577 }
578 }
579 }
580 };
581 let status = reqwest::StatusCode::from_u16(resp.status)
582 .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
583 if status.is_success() {
584 AttemptOutcome::Success(resp)
585 } else {
586 let retry_after = resp
587 .retry_after
588 .as_deref()
589 .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
590 if self.retry_config.is_retryable_status(status) {
591 AttemptOutcome::Retryable {
592 status,
593 body: resp.error_body,
594 retry_after,
595 }
596 } else {
597 AttemptOutcome::Fatal(anyhow::anyhow!(
598 "OpenAI API error at {} ({}): {}",
599 url,
600 status,
601 resp.error_body
602 ))
603 }
604 }
605 }
606 })
607 .await?;
608
609 let (tx, rx) = mpsc::channel(100);
610
611 let mut stream = streaming_resp.byte_stream;
612 let provider_name = self.provider_name.clone();
613 let request_model = self.model.clone();
614 let request_url = url.clone();
615 tokio::spawn(async move {
616 let mut buffer = String::new();
617 let mut content_blocks: Vec<ContentBlock> = Vec::new();
618 let mut text_content = String::new();
619 let mut reasoning_content_accum = String::new();
620 let mut tool_calls: std::collections::BTreeMap<usize, (String, String, String)> =
621 std::collections::BTreeMap::new();
622 let mut usage = TokenUsage::default();
623 let mut finish_reason = None;
624 let mut response_id = None;
625 let mut response_model = None;
626 let mut response_object = None;
627 let mut first_token_ms = None;
628 let mut saw_done = false;
629 let mut parsed_any_event = false;
630
631 while let Some(chunk_result) = stream.next().await {
632 let chunk = match chunk_result {
633 Ok(c) => c,
634 Err(e) => {
635 tracing::error!("Stream error: {}", e);
636 break;
637 }
638 };
639
640 buffer.push_str(&String::from_utf8_lossy(&chunk));
641
642 while let Some(event_end) = buffer.find("\n\n") {
643 let event_data: String = buffer.drain(..event_end).collect();
644 buffer.drain(..2);
645
646 for line in event_data.lines() {
647 if let Some(data) = line.strip_prefix("data: ") {
648 if data == "[DONE]" {
649 saw_done = true;
650 if !text_content.is_empty() {
651 content_blocks.push(ContentBlock::Text {
652 text: text_content.clone(),
653 });
654 }
655 for (_, (id, name, args)) in tool_calls.iter() {
656 content_blocks.push(ContentBlock::ToolUse {
657 id: id.clone(),
658 name: name.clone(),
659 input: Self::parse_tool_arguments(name, args),
660 });
661 }
662 tool_calls.clear();
663 crate::telemetry::record_llm_usage(
664 usage.prompt_tokens,
665 usage.completion_tokens,
666 usage.total_tokens,
667 finish_reason.as_deref(),
668 );
669 let response = LlmResponse {
670 message: Message {
671 role: "assistant".to_string(),
672 content: std::mem::take(&mut content_blocks),
673 reasoning_content: if reasoning_content_accum.is_empty()
674 {
675 None
676 } else {
677 Some(std::mem::take(&mut reasoning_content_accum))
678 },
679 },
680 usage: usage.clone(),
681 stop_reason: std::mem::take(&mut finish_reason),
682 meta: Some(LlmResponseMeta {
683 provider: Some(provider_name.clone()),
684 request_model: Some(request_model.clone()),
685 request_url: Some(request_url.clone()),
686 response_id: response_id.clone(),
687 response_model: response_model.clone(),
688 response_object: response_object.clone(),
689 first_token_ms,
690 duration_ms: Some(
691 request_started_at.elapsed().as_millis() as u64,
692 ),
693 }),
694 };
695 let _ = tx.send(StreamEvent::Done(response)).await;
696 continue;
697 }
698
699 if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(data) {
700 parsed_any_event = true;
701 if response_id.is_none() {
702 response_id = event.id.clone();
703 }
704 if response_model.is_none() {
705 response_model = event.model.clone();
706 }
707 if response_object.is_none() {
708 response_object = event.object.clone();
709 }
710 if let Some(u) = event.usage {
711 usage.prompt_tokens = u.prompt_tokens;
712 usage.completion_tokens = u.completion_tokens;
713 usage.total_tokens = u.total_tokens;
714 if usage.total_tokens == 0 {
716 usage.total_tokens = u.total_characters.unwrap_or(0);
717 }
718 usage.cache_read_tokens = u
719 .prompt_tokens_details
720 .as_ref()
721 .and_then(|d| d.cached_tokens);
722 }
723
724 if let Some(choice) = event.choices.into_iter().next() {
725 if let Some(reason) = choice.finish_reason {
726 finish_reason = Some(reason);
727 }
728
729 if let Some(message) = choice.message {
730 let skip_content = !text_content.is_empty();
733 if let Some(reasoning) = message.reasoning_content {
734 if first_token_ms.is_none() {
744 first_token_ms = Some(
745 request_started_at.elapsed().as_millis()
746 as u64,
747 );
748 }
749 if let Some(delta) = Self::merge_stream_text(
750 &mut reasoning_content_accum,
751 &reasoning,
752 ) {
753 let _ = tx
754 .send(StreamEvent::ReasoningDelta(delta))
755 .await;
756 }
757 }
758 if !skip_content {
759 if let Some(content) = message
760 .content
761 .filter(|value| !value.is_empty())
762 {
763 if first_token_ms.is_none() {
764 first_token_ms = Some(
765 request_started_at.elapsed().as_millis()
766 as u64,
767 );
768 }
769 if let Some(delta) = Self::merge_stream_text(
770 &mut text_content,
771 &content,
772 ) {
773 let _ = tx
774 .send(StreamEvent::TextDelta(delta))
775 .await;
776 }
777 }
778 }
779 if let Some(tcs) = message.tool_calls {
780 for (index, tc) in tcs.into_iter().enumerate() {
781 tool_calls.insert(
782 index,
783 (
784 tc.id,
785 tc.function.name,
786 tc.function.arguments,
787 ),
788 );
789 }
790 }
791 } else if let Some(delta) = choice.delta {
792 if let Some(ref rc) = delta.reasoning_content {
793 if first_token_ms.is_none() {
796 first_token_ms = Some(
797 request_started_at.elapsed().as_millis()
798 as u64,
799 );
800 }
801 if let Some(delta) = Self::merge_stream_text(
802 &mut reasoning_content_accum,
803 rc,
804 ) {
805 let _ = tx
806 .send(StreamEvent::ReasoningDelta(delta))
807 .await;
808 }
809 }
810
811 if let Some(content) = delta.content {
812 if first_token_ms.is_none() {
813 first_token_ms = Some(
814 request_started_at.elapsed().as_millis()
815 as u64,
816 );
817 }
818 if let Some(delta) = Self::merge_stream_text(
819 &mut text_content,
820 &content,
821 ) {
822 let _ = tx
823 .send(StreamEvent::TextDelta(delta))
824 .await;
825 }
826 }
827
828 if let Some(tcs) = delta.tool_calls {
829 for tc in tcs {
830 let entry = tool_calls
831 .entry(tc.index)
832 .or_insert_with(|| {
833 (
834 String::new(),
835 String::new(),
836 String::new(),
837 )
838 });
839
840 if let Some(id) = tc.id {
841 entry.0 = id;
842 }
843 if let Some(func) = tc.function {
844 if let Some(name) = func.name {
845 if first_token_ms.is_none() {
846 first_token_ms = Some(
847 request_started_at
848 .elapsed()
849 .as_millis()
850 as u64,
851 );
852 }
853 entry.1 = name.clone();
854 let _ = tx
855 .send(StreamEvent::ToolUseStart {
856 id: entry.0.clone(),
857 name,
858 })
859 .await;
860 }
861 if let Some(args) = func.arguments {
862 entry.2.push_str(&args);
863 let _ = tx
864 .send(
865 StreamEvent::ToolUseInputDelta(
866 args,
867 ),
868 )
869 .await;
870 }
871 }
872 }
873 }
874 }
875 }
876 }
877 }
878 }
879 }
880 }
881
882 if saw_done {
883 return;
884 }
885
886 let trailing = buffer.trim();
887 if !trailing.is_empty() {
888 if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(trailing) {
889 parsed_any_event = true;
890 if response_id.is_none() {
891 response_id = event.id.clone();
892 }
893 if response_model.is_none() {
894 response_model = event.model.clone();
895 }
896 if response_object.is_none() {
897 response_object = event.object.clone();
898 }
899 if let Some(u) = event.usage {
900 usage.prompt_tokens = u.prompt_tokens;
901 usage.completion_tokens = u.completion_tokens;
902 usage.total_tokens = u.total_tokens;
903 usage.cache_read_tokens = u
904 .prompt_tokens_details
905 .as_ref()
906 .and_then(|d| d.cached_tokens);
907 }
908 if let Some(choice) = event.choices.into_iter().next() {
909 if let Some(reason) = choice.finish_reason {
910 finish_reason = Some(reason);
911 }
912 let skip_content = !text_content.is_empty();
915 if let Some(message) = choice.message {
916 if let Some(reasoning) = message.reasoning_content {
917 if first_token_ms.is_none() {
920 first_token_ms =
921 Some(request_started_at.elapsed().as_millis() as u64);
922 }
923 if let Some(delta) = Self::merge_stream_text(
924 &mut reasoning_content_accum,
925 &reasoning,
926 ) {
927 let _ = tx.send(StreamEvent::ReasoningDelta(delta)).await;
928 }
929 }
930 if !skip_content {
931 if let Some(content) =
932 message.content.filter(|value| !value.is_empty())
933 {
934 if first_token_ms.is_none() {
935 first_token_ms = Some(
936 request_started_at.elapsed().as_millis() as u64,
937 );
938 }
939 if let Some(delta) =
940 Self::merge_stream_text(&mut text_content, &content)
941 {
942 let _ = tx.send(StreamEvent::TextDelta(delta)).await;
943 }
944 }
945 }
946 if let Some(tcs) = message.tool_calls {
947 for (index, tc) in tcs.into_iter().enumerate() {
948 tool_calls.insert(
949 index,
950 (tc.id, tc.function.name, tc.function.arguments),
951 );
952 }
953 }
954 } else if let Some(delta) = choice.delta {
955 if let Some(ref rc) = delta.reasoning_content {
956 if first_token_ms.is_none() {
959 first_token_ms =
960 Some(request_started_at.elapsed().as_millis() as u64);
961 }
962 if let Some(delta) =
963 Self::merge_stream_text(&mut reasoning_content_accum, rc)
964 {
965 let _ = tx.send(StreamEvent::ReasoningDelta(delta)).await;
966 }
967 }
968 if let Some(content) = delta.content {
969 if first_token_ms.is_none() {
970 first_token_ms =
971 Some(request_started_at.elapsed().as_millis() as u64);
972 }
973 if let Some(delta) =
974 Self::merge_stream_text(&mut text_content, &content)
975 {
976 let _ = tx.send(StreamEvent::TextDelta(delta)).await;
977 }
978 }
979 }
980 }
981 } else if let Ok(response) = serde_json::from_str::<OpenAiResponse>(trailing) {
982 parsed_any_event = true;
983 response_id = response.id.clone();
984 response_model = response.model.clone();
985 response_object = response.object.clone();
986 usage.prompt_tokens = response.usage.prompt_tokens;
987 usage.completion_tokens = response.usage.completion_tokens;
988 usage.total_tokens = response.usage.total_tokens;
989 if usage.total_tokens == 0 {
991 usage.total_tokens = response.usage.total_characters.unwrap_or(0);
992 }
993 usage.cache_read_tokens = response
994 .usage
995 .prompt_tokens_details
996 .as_ref()
997 .and_then(|d| d.cached_tokens);
998
999 if let Some(choice) = response.choices.into_iter().next() {
1000 finish_reason = choice.finish_reason;
1001 if let Some(text) =
1002 choice.message.content.filter(|text| !text.is_empty())
1003 {
1004 if first_token_ms.is_none() {
1005 first_token_ms =
1006 Some(request_started_at.elapsed().as_millis() as u64);
1007 }
1008 let _ = Self::merge_stream_text(&mut text_content, &text);
1009 }
1010 if let Some(reasoning) = choice.message.reasoning_content {
1011 reasoning_content_accum.push_str(&reasoning);
1012 }
1013 if let Some(final_tool_calls) = choice.message.tool_calls {
1014 for tc in final_tool_calls {
1015 tool_calls.insert(
1016 tool_calls.len(),
1017 (tc.id, tc.function.name, tc.function.arguments),
1018 );
1019 }
1020 }
1021 }
1022 }
1023 }
1024
1025 if parsed_any_event
1026 || !text_content.is_empty()
1027 || !tool_calls.is_empty()
1028 || !content_blocks.is_empty()
1029 {
1030 tracing::warn!(
1031 provider = %provider_name,
1032 model = %request_model,
1033 "OpenAI-compatible stream ended without [DONE]; finalizing buffered response"
1034 );
1035 if !text_content.is_empty() {
1036 content_blocks.push(ContentBlock::Text {
1037 text: text_content.clone(),
1038 });
1039 }
1040 for (_, (id, name, args)) in tool_calls.iter() {
1041 content_blocks.push(ContentBlock::ToolUse {
1042 id: id.clone(),
1043 name: name.clone(),
1044 input: Self::parse_tool_arguments(name, args),
1045 });
1046 }
1047 tool_calls.clear();
1048 crate::telemetry::record_llm_usage(
1049 usage.prompt_tokens,
1050 usage.completion_tokens,
1051 usage.total_tokens,
1052 finish_reason.as_deref(),
1053 );
1054 let response = LlmResponse {
1055 message: Message {
1056 role: "assistant".to_string(),
1057 content: std::mem::take(&mut content_blocks),
1058 reasoning_content: if reasoning_content_accum.is_empty() {
1059 None
1060 } else {
1061 Some(std::mem::take(&mut reasoning_content_accum))
1062 },
1063 },
1064 usage: usage.clone(),
1065 stop_reason: std::mem::take(&mut finish_reason),
1066 meta: Some(LlmResponseMeta {
1067 provider: Some(provider_name.clone()),
1068 request_model: Some(request_model.clone()),
1069 request_url: Some(request_url.clone()),
1070 response_id: response_id.clone(),
1071 response_model: response_model.clone(),
1072 response_object: response_object.clone(),
1073 first_token_ms,
1074 duration_ms: Some(request_started_at.elapsed().as_millis() as u64),
1075 }),
1076 };
1077 let _ = tx.send(StreamEvent::Done(response)).await;
1078 } else {
1079 tracing::warn!(
1080 provider = %provider_name,
1081 model = %request_model,
1082 trailing = %trailing.chars().take(400).collect::<String>(),
1083 "OpenAI-compatible stream ended without any parseable events"
1084 );
1085 }
1086 });
1087
1088 Ok(rx)
1089 }
1090 }
1091}
1092
/// Body of a non-streaming chat-completions response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiResponse {
    /// Response identifier, when the provider sends one.
    #[serde(default)]
    pub(crate) id: Option<String>,
    /// Value of the response's `object` field, when present.
    #[serde(default)]
    pub(crate) object: Option<String>,
    /// Model name echoed by the provider, when present.
    #[serde(default)]
    pub(crate) model: Option<String>,
    /// Completion choices; callers in this file use only the first one.
    pub(crate) choices: Vec<OpenAiChoice>,
    /// Token accounting for the request (required: parsing fails if it is absent).
    pub(crate) usage: OpenAiUsage,
}
1105
/// One completion choice of a non-streaming response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiChoice {
    /// The assistant message produced for this choice.
    pub(crate) message: OpenAiMessage,
    /// Why generation stopped, if reported; forwarded as the response's stop reason.
    pub(crate) finish_reason: Option<String>,
}
1111
/// Assistant message as returned by the provider.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiMessage {
    /// Reasoning text; also accepted under the JSON key `reasoning`.
    #[serde(alias = "reasoning")]
    pub(crate) reasoning_content: Option<String>,
    /// Answer text, if any.
    pub(crate) content: Option<String>,
    /// Tool calls requested by the model, if any.
    pub(crate) tool_calls: Option<Vec<OpenAiToolCall>>,
}
1123
/// A complete tool call inside an assistant message.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCall {
    /// Identifier of the call; becomes the id of the resulting tool-use block.
    pub(crate) id: String,
    /// Function name and raw arguments.
    pub(crate) function: OpenAiFunction,
}
1129
/// Function part of a complete tool call.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunction {
    /// Name of the tool to invoke.
    pub(crate) name: String,
    /// Arguments as a JSON-encoded string (decoded by `parse_tool_arguments`).
    pub(crate) arguments: String,
}
1135
/// Token accounting reported by the provider; every field defaults when missing.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiUsage {
    /// Tokens consumed by the prompt.
    #[serde(default)]
    pub(crate) prompt_tokens: usize,
    /// Tokens produced by the completion.
    #[serde(default)]
    pub(crate) completion_tokens: usize,
    /// Total token count; 0 when the provider does not report it.
    #[serde(default)]
    pub(crate) total_tokens: usize,
    /// Used in this file as a fallback for `total_tokens` when that value is 0.
    #[serde(default)]
    pub(crate) total_characters: Option<usize>,
    /// Prompt token breakdown; the source of the cached-token count.
    #[serde(default)]
    pub(crate) prompt_tokens_details: Option<OpenAiPromptTokensDetails>,
}
1151
/// Breakdown of prompt tokens.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiPromptTokensDetails {
    /// Cached prompt tokens; surfaced as `cache_read_tokens` in the usage summary.
    #[serde(default)]
    pub(crate) cached_tokens: Option<usize>,
}
1157
1158#[derive(Debug, Deserialize)]
1160pub(crate) struct OpenAiStreamChunk {
1161 #[serde(default)]
1162 pub(crate) id: Option<String>,
1163 #[serde(default)]
1164 pub(crate) object: Option<String>,
1165 #[serde(default)]
1166 pub(crate) model: Option<String>,
1167 pub(crate) choices: Vec<OpenAiStreamChoice>,
1168 pub(crate) usage: Option<OpenAiUsage>,
1169}
1170
/// One choice inside a stream chunk.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChoice {
    /// A whole message; when present the stream reader uses it in preference to `delta`.
    pub(crate) message: Option<OpenAiMessage>,
    /// Incremental update to the message being generated.
    pub(crate) delta: Option<OpenAiDelta>,
    /// Why generation stopped, once reported.
    pub(crate) finish_reason: Option<String>,
}
1177
/// Incremental message update carried by a stream chunk.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiDelta {
    /// Reasoning text fragment; also accepted under the JSON key `reasoning`.
    #[serde(alias = "reasoning")]
    pub(crate) reasoning_content: Option<String>,
    /// Answer text fragment.
    pub(crate) content: Option<String>,
    /// Tool-call fragments, addressed by index.
    pub(crate) tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
1189
/// Fragment of a tool call inside a streamed delta.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCallDelta {
    /// Slot of the tool call this fragment belongs to; fragments with the same index
    /// are merged by the stream reader.
    pub(crate) index: usize,
    /// Call identifier, when this fragment carries it.
    pub(crate) id: Option<String>,
    /// Function name and/or argument fragment.
    pub(crate) function: Option<OpenAiFunctionDelta>,
}
1196
/// Function part of a streamed tool-call fragment.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunctionDelta {
    /// Tool name, when this fragment carries it.
    pub(crate) name: Option<String>,
    /// Piece of the JSON-encoded argument string; pieces are concatenated in order.
    pub(crate) arguments: Option<String>,
}
1202
1203#[cfg(test)]
1208mod tests {
1209 use super::*;
1210 use crate::llm::types::{Message, ToolDefinition};
1211
1212 fn make_client() -> OpenAiClient {
1213 OpenAiClient::new("test-key".to_string(), "gpt-test".to_string())
1214 }
1215
    /// HTTP client double that replays a fixed list of stream chunks.
    struct MockSseHttp {
        /// Chunks returned, in order, as the body of every streaming request.
        chunks: Vec<String>,
    }
1225
1226 #[async_trait::async_trait]
1227 impl crate::llm::http::HttpClient for MockSseHttp {
1228 async fn post(
1229 &self,
1230 _url: &str,
1231 _headers: Vec<(&str, &str)>,
1232 _body: &serde_json::Value,
1233 _cancel: tokio_util::sync::CancellationToken,
1234 ) -> anyhow::Result<crate::llm::http::HttpResponse> {
1235 anyhow::bail!("post is unused in the streaming test")
1236 }
1237
1238 async fn post_streaming(
1239 &self,
1240 _url: &str,
1241 _headers: Vec<(&str, &str)>,
1242 _body: &serde_json::Value,
1243 _cancel: tokio_util::sync::CancellationToken,
1244 ) -> anyhow::Result<crate::llm::http::StreamingHttpResponse> {
1245 let items: Vec<anyhow::Result<bytes::Bytes>> = self
1246 .chunks
1247 .iter()
1248 .map(|s| Ok(bytes::Bytes::from(s.clone())))
1249 .collect();
1250 Ok(crate::llm::http::StreamingHttpResponse {
1251 status: 200,
1252 retry_after: None,
1253 byte_stream: Box::pin(futures::stream::iter(items)),
1254 error_body: String::new(),
1255 })
1256 }
1257 }
1258
1259 fn glm_client(chunks: Vec<String>) -> OpenAiClient {
1260 OpenAiClient::new("k".to_string(), "glm-test".to_string())
1261 .with_http_client(std::sync::Arc::new(MockSseHttp { chunks }))
1262 }
1263
1264 async fn drain_to_done(client: &OpenAiClient) -> crate::llm::LlmResponse {
1265 use crate::llm::{LlmClient, StreamEvent};
1266 let mut rx = client
1267 .complete_streaming(
1268 &[Message::user("go")],
1269 None,
1270 &[],
1271 tokio_util::sync::CancellationToken::new(),
1272 )
1273 .await
1274 .expect("stream opened");
1275 let mut done = None;
1276 while let Some(ev) = rx.recv().await {
1277 if let StreamEvent::Done(resp) = ev {
1278 done = Some(resp);
1279 }
1280 }
1281 done.expect("a Done event")
1282 }
1283
1284 #[tokio::test]
1285 async fn streaming_reasoning_does_not_leak_into_content_and_keeps_tool_call() {
1286 let chunks = vec![
1287 "data: {\"choices\":[{\"delta\":{\"reasoning\":\"Let me plan the workers\"}}]}\n\n"
1288 .to_string(),
1289 "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"parallel_task\",\"arguments\":\"{}\"}}]}}]}\n\n"
1290 .to_string(),
1291 "data: [DONE]\n\n".to_string(),
1292 ];
1293 let resp = drain_to_done(&glm_client(chunks)).await;
1294 assert_eq!(resp.message.text(), "", "reasoning leaked into content");
1296 assert_eq!(
1297 resp.message.reasoning_content.as_deref(),
1298 Some("Let me plan the workers")
1299 );
1300 let calls = resp.message.tool_calls();
1302 assert_eq!(calls.len(), 1);
1303 assert_eq!(calls[0].name, "parallel_task");
1304 }
1305
1306 #[tokio::test]
1307 async fn streaming_reasoning_only_turn_yields_empty_text() {
1308 let chunks = vec![
1312 "data: {\"choices\":[{\"delta\":{\"reasoning\":\"still thinking, no answer yet\"}}]}\n\n"
1313 .to_string(),
1314 "data: [DONE]\n\n".to_string(),
1315 ];
1316 let resp = drain_to_done(&glm_client(chunks)).await;
1317 assert_eq!(resp.message.text(), "");
1318 assert_eq!(
1319 resp.message.reasoning_content.as_deref(),
1320 Some("still thinking, no answer yet")
1321 );
1322 assert!(resp.message.tool_calls().is_empty());
1323 }
1324
1325 #[test]
1326 fn test_apply_directive_forced_function_tool_choice() {
1327 let mut req = serde_json::json!({ "model": "m" });
1328 OpenAiClient::apply_directive(
1329 &mut req,
1330 &structured::StructuredDirective {
1331 force_tool: Some("emit_person".to_string()),
1332 response_format: None,
1333 },
1334 );
1335 assert_eq!(req["tool_choice"]["type"], "function");
1336 assert_eq!(req["tool_choice"]["function"]["name"], "emit_person");
1337 assert!(req.get("response_format").is_none());
1338 }
1339
1340 #[test]
1341 fn test_apply_directive_json_schema_strict() {
1342 let mut req = serde_json::json!({});
1343 OpenAiClient::apply_directive(
1344 &mut req,
1345 &structured::StructuredDirective {
1346 force_tool: None,
1347 response_format: Some(structured::ResponseFormat::JsonSchema {
1348 name: "person".to_string(),
1349 schema: serde_json::json!({ "type": "object" }),
1350 }),
1351 },
1352 );
1353 assert_eq!(req["response_format"]["type"], "json_schema");
1354 assert_eq!(req["response_format"]["json_schema"]["name"], "person");
1355 assert_eq!(req["response_format"]["json_schema"]["strict"], true);
1356 assert!(req.get("tool_choice").is_none());
1357 }
1358
1359 #[test]
1360 fn test_apply_directive_json_object() {
1361 let mut req = serde_json::json!({});
1362 OpenAiClient::apply_directive(
1363 &mut req,
1364 &structured::StructuredDirective {
1365 force_tool: None,
1366 response_format: Some(structured::ResponseFormat::JsonObject),
1367 },
1368 );
1369 assert_eq!(req["response_format"]["type"], "json_object");
1370 }
1371
1372 #[test]
1373 fn test_build_chat_request_applies_directive_and_system() {
1374 let req = make_client().build_chat_request(
1375 &[Message::user("hi")],
1376 Some("sys"),
1377 &[ToolDefinition {
1378 name: "emit_x".to_string(),
1379 description: "emit".to_string(),
1380 parameters: serde_json::json!({ "type": "object" }),
1381 }],
1382 Some(&structured::StructuredDirective {
1383 force_tool: Some("emit_x".to_string()),
1384 response_format: None,
1385 }),
1386 );
1387 assert_eq!(req["messages"][0]["role"], "system");
1388 assert_eq!(req["tool_choice"]["function"]["name"], "emit_x");
1389 assert_eq!(req["tools"][0]["function"]["name"], "emit_x");
1390 }
1391
1392 #[test]
1393 fn test_build_chat_request_without_directive_is_plain() {
1394 let req = make_client().build_chat_request(&[Message::user("hi")], None, &[], None);
1395 assert!(req.get("tool_choice").is_none());
1396 assert!(req.get("response_format").is_none());
1397 }
1398
1399 #[test]
1400 fn test_native_structured_support_is_json_schema() {
1401 assert_eq!(
1402 make_client().native_structured_support(),
1403 structured::NativeStructuredSupport::JsonSchema
1404 );
1405 }
1406}