1use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::structured;
5use super::types::*;
6use super::LlmClient;
7use crate::llm::types::{ToolResultContent, ToolResultContentField};
8use crate::retry::{AttemptOutcome, RetryConfig};
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use futures::StreamExt;
12use serde::Deserialize;
13use std::collections::HashMap;
14use std::sync::Arc;
15use std::time::Instant;
16use tokio::sync::mpsc;
17
/// Client for chat-completions endpoints that speak the OpenAI wire format.
///
/// Configure it with [`OpenAiClient::new`] followed by the `with_*` builder methods.
pub struct OpenAiClient {
    /// Provider label copied into response metadata (`"openai"` unless overridden).
    pub(crate) provider_name: String,
    /// API key, sent as a `Bearer` token unless `headers` carries its own `Authorization`.
    pub(crate) api_key: SecretString,
    /// Model identifier placed in the `model` field of every request.
    pub(crate) model: String,
    /// Base URL the completions path is appended to.
    pub(crate) base_url: String,
    /// Path of the chat-completions endpoint; stored with a leading `/` by the builder.
    pub(crate) chat_completions_path: String,
    /// Extra headers added to every request.
    pub(crate) headers: HashMap<String, String>,
    /// Sampling temperature; omitted from the request when `None`.
    pub(crate) temperature: Option<f32>,
    /// Completion token cap (`max_tokens`); omitted from the request when `None`.
    pub(crate) max_tokens: Option<usize>,
    /// Transport used for both plain and streaming POSTs.
    pub(crate) http: Arc<dyn HttpClient>,
    /// Retry policy handed to `crate::retry::with_retry`.
    pub(crate) retry_config: RetryConfig,
}
31
32impl OpenAiClient {
33 pub(crate) fn parse_tool_arguments(tool_name: &str, arguments: &str) -> serde_json::Value {
34 if arguments.trim().is_empty() {
35 return serde_json::Value::Object(Default::default());
36 }
37
38 serde_json::from_str(arguments).unwrap_or_else(|e| {
39 tracing::warn!(
40 "Failed to parse tool arguments JSON for tool '{}': {}",
41 tool_name,
42 e
43 );
44 serde_json::json!({
45 "__parse_error": format!(
46 "Malformed tool arguments: {}. Raw input: {}",
47 e, arguments
48 )
49 })
50 })
51 }
52
53 fn merge_stream_text(text_content: &mut String, incoming: &str) -> Option<String> {
54 if incoming.is_empty() {
55 return None;
56 }
57 if text_content.is_empty() {
58 text_content.push_str(incoming);
59 return Some(incoming.to_string());
60 }
61 if incoming == text_content.as_str() || text_content.ends_with(incoming) {
62 return None;
63 }
64 if incoming.starts_with(text_content.as_str()) && incoming.len() > text_content.len() {
67 let suffix = &incoming[text_content.len()..];
68 if !suffix.is_empty() {
69 *text_content = incoming.to_string();
70 return Some(suffix.to_string());
71 }
72 return None;
73 }
74 if let Some(suffix) = incoming.strip_prefix(text_content.as_str()) {
75 if suffix.is_empty() {
76 return None;
77 }
78 text_content.push_str(suffix);
79 return Some(suffix.to_string());
80 }
81 text_content.push_str(incoming);
82 Some(incoming.to_string())
83 }
84
85 pub fn new(api_key: String, model: String) -> Self {
86 Self {
87 provider_name: "openai".to_string(),
88 api_key: SecretString::new(api_key),
89 model,
90 base_url: "https://api.openai.com".to_string(),
91 chat_completions_path: "/v1/chat/completions".to_string(),
92 headers: HashMap::new(),
93 temperature: None,
94 max_tokens: None,
95 http: default_http_client(),
96 retry_config: RetryConfig::default(),
97 }
98 }
99
100 pub fn with_base_url(mut self, base_url: String) -> Self {
101 self.base_url = normalize_base_url(&base_url);
102 self
103 }
104
105 pub fn with_provider_name(mut self, provider_name: impl Into<String>) -> Self {
106 self.provider_name = provider_name.into();
107 self
108 }
109
110 pub fn with_chat_completions_path(mut self, path: impl Into<String>) -> Self {
111 let path = path.into();
112 self.chat_completions_path = if path.starts_with('/') {
113 path
114 } else {
115 format!("/{}", path)
116 };
117 self
118 }
119
120 pub fn with_temperature(mut self, temperature: f32) -> Self {
121 self.temperature = Some(temperature);
122 self
123 }
124
125 pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
126 self.headers = headers;
127 self
128 }
129
130 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
131 self.max_tokens = Some(max_tokens);
132 self
133 }
134
135 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
136 self.retry_config = retry_config;
137 self
138 }
139
140 pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
141 self.http = http;
142 self
143 }
144
145 pub(crate) fn request_headers(&self) -> Vec<(String, String)> {
146 let mut headers = Vec::with_capacity(self.headers.len() + 1);
147 let has_authorization = self
148 .headers
149 .keys()
150 .any(|key| key.eq_ignore_ascii_case("authorization"));
151 if !has_authorization {
152 headers.push((
153 "Authorization".to_string(),
154 format!("Bearer {}", self.api_key.expose()),
155 ));
156 }
157 headers.extend(
158 self.headers
159 .iter()
160 .map(|(key, value)| (key.clone(), value.clone())),
161 );
162 headers
163 }
164
165 pub(crate) fn convert_messages(&self, messages: &[Message]) -> Vec<serde_json::Value> {
166 messages
167 .iter()
168 .map(|msg| {
169 let content: serde_json::Value = if msg.content.len() == 1 {
170 match &msg.content[0] {
171 ContentBlock::Text { text } => serde_json::json!(text),
172 ContentBlock::ToolResult {
173 tool_use_id,
174 content,
175 ..
176 } => {
177 let content_str = match content {
178 ToolResultContentField::Text(s) => s.clone(),
179 ToolResultContentField::Blocks(blocks) => blocks
180 .iter()
181 .filter_map(|b| {
182 if let ToolResultContent::Text { text } = b {
183 Some(text.clone())
184 } else {
185 None
186 }
187 })
188 .collect::<Vec<_>>()
189 .join("\n"),
190 };
191 return serde_json::json!({
192 "role": "tool",
193 "tool_call_id": tool_use_id,
194 "content": content_str,
195 });
196 }
197 _ => serde_json::json!(""),
198 }
199 } else {
200 serde_json::json!(msg
201 .content
202 .iter()
203 .map(|block| {
204 match block {
205 ContentBlock::Text { text } => serde_json::json!({
206 "type": "text",
207 "text": text,
208 }),
209 ContentBlock::Image { source } => serde_json::json!({
210 "type": "image_url",
211 "image_url": {
212 "url": format!(
213 "data:{};base64,{}",
214 source.media_type, source.data
215 ),
216 }
217 }),
218 ContentBlock::ToolUse { id, name, input } => serde_json::json!({
219 "type": "function",
220 "id": id,
221 "function": {
222 "name": name,
223 "arguments": input.to_string(),
224 }
225 }),
226 _ => serde_json::json!({}),
227 }
228 })
229 .collect::<Vec<_>>())
230 };
231
232 if msg.role == "assistant" {
235 let rc = msg.reasoning_content.as_deref().unwrap_or("");
236 let tool_calls: Vec<_> = msg.tool_calls();
237 if !tool_calls.is_empty() {
238 return serde_json::json!({
239 "role": "assistant",
240 "content": msg.text(),
241 "reasoning_content": rc,
242 "tool_calls": tool_calls.iter().map(|tc| {
243 serde_json::json!({
244 "id": tc.id,
245 "type": "function",
246 "function": {
247 "name": tc.name,
248 "arguments": tc.args.to_string(),
249 }
250 })
251 }).collect::<Vec<_>>(),
252 });
253 }
254 return serde_json::json!({
255 "role": "assistant",
256 "content": content,
257 "reasoning_content": rc,
258 });
259 }
260
261 serde_json::json!({
262 "role": msg.role,
263 "content": content,
264 })
265 })
266 .collect()
267 }
268
269 pub(crate) fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<serde_json::Value> {
270 tools
271 .iter()
272 .map(|t| {
273 serde_json::json!({
274 "type": "function",
275 "function": {
276 "name": t.name,
277 "description": t.description,
278 "parameters": t.parameters,
279 }
280 })
281 })
282 .collect()
283 }
284}
285
286impl OpenAiClient {
287 fn apply_directive(
292 request: &mut serde_json::Value,
293 directive: &structured::StructuredDirective,
294 ) {
295 if let Some(tool) = &directive.force_tool {
296 request["tool_choice"] = serde_json::json!({
297 "type": "function",
298 "function": { "name": tool }
299 });
300 }
301 if let Some(rf) = &directive.response_format {
302 request["response_format"] = match rf {
303 structured::ResponseFormat::JsonObject => {
304 serde_json::json!({ "type": "json_object" })
305 }
306 structured::ResponseFormat::JsonSchema { name, schema } => serde_json::json!({
307 "type": "json_schema",
308 "json_schema": { "name": name, "schema": schema, "strict": true }
309 }),
310 };
311 }
312 }
313
314 fn build_chat_request(
316 &self,
317 messages: &[Message],
318 system: Option<&str>,
319 tools: &[ToolDefinition],
320 directive: Option<&structured::StructuredDirective>,
321 ) -> serde_json::Value {
322 let mut openai_messages = Vec::new();
323
324 if let Some(sys) = system {
325 openai_messages.push(serde_json::json!({
326 "role": "system",
327 "content": sys,
328 }));
329 }
330
331 openai_messages.extend(self.convert_messages(messages));
332
333 let mut request = serde_json::json!({
334 "model": self.model,
335 "messages": openai_messages,
336 });
337
338 if let Some(temp) = self.temperature {
339 request["temperature"] = serde_json::json!(temp);
340 }
341 if let Some(max) = self.max_tokens {
342 request["max_tokens"] = serde_json::json!(max);
343 }
344
345 if !tools.is_empty() {
346 request["tools"] = serde_json::json!(self.convert_tools(tools));
347 }
348
349 if let Some(directive) = directive {
350 Self::apply_directive(&mut request, directive);
351 }
352
353 request
354 }
355
356 async fn send_request(&self, request: serde_json::Value) -> Result<LlmResponse> {
358 {
359 let request_started_at = Instant::now();
360 let url = format!("{}{}", self.base_url, self.chat_completions_path);
361 let request_headers = self.request_headers();
362
363 let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
364 let http = &self.http;
365 let url = &url;
366 let request_headers = request_headers.clone();
367 let request = &request;
368 async move {
369 let headers = request_headers
370 .iter()
371 .map(|(key, value)| (key.as_str(), value.as_str()))
372 .collect::<Vec<_>>();
373 let cancel_token = tokio_util::sync::CancellationToken::new();
375 match http.post(url, headers, request, cancel_token).await {
376 Ok(resp) => {
377 let status = reqwest::StatusCode::from_u16(resp.status)
378 .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
379 if status.is_success() {
380 AttemptOutcome::Success(resp.body)
381 } else if self.retry_config.is_retryable_status(status) {
382 AttemptOutcome::Retryable {
383 status,
384 body: resp.body,
385 retry_after: None,
386 }
387 } else {
388 AttemptOutcome::Fatal(anyhow::anyhow!(
389 "OpenAI API error at {} ({}): {}",
390 url,
391 status,
392 resp.body
393 ))
394 }
395 }
396 Err(e) => {
397 tracing::error!("HTTP error: {e:?}");
398 AttemptOutcome::Fatal(e)
399 }
400 }
401 }
402 })
403 .await?;
404
405 let parsed: OpenAiResponse =
406 serde_json::from_str(&response).context("Failed to parse OpenAI response")?;
407
408 let choice = parsed.choices.into_iter().next().context("No choices")?;
409
410 let mut content = vec![];
411
412 let reasoning_content = choice.message.reasoning_content;
413
414 let text_content = choice.message.content;
415
416 if let Some(text) = text_content {
417 if !text.is_empty() {
418 content.push(ContentBlock::Text { text });
419 }
420 }
421
422 if let Some(tool_calls) = choice.message.tool_calls {
423 for tc in tool_calls {
424 content.push(ContentBlock::ToolUse {
425 id: tc.id,
426 name: tc.function.name.clone(),
427 input: Self::parse_tool_arguments(
428 &tc.function.name,
429 &tc.function.arguments,
430 ),
431 });
432 }
433 }
434
435 let llm_response = LlmResponse {
436 message: Message {
437 role: "assistant".to_string(),
438 content,
439 reasoning_content,
440 },
441 usage: TokenUsage {
442 prompt_tokens: parsed.usage.prompt_tokens,
443 completion_tokens: parsed.usage.completion_tokens,
444 total_tokens: {
445 let t = parsed.usage.total_tokens;
446 if t == 0 {
448 parsed.usage.total_characters.unwrap_or(0)
449 } else {
450 t
451 }
452 },
453 cache_read_tokens: parsed
454 .usage
455 .prompt_tokens_details
456 .as_ref()
457 .and_then(|d| d.cached_tokens),
458 cache_write_tokens: None,
459 },
460 stop_reason: choice.finish_reason,
461 meta: Some(LlmResponseMeta {
462 provider: Some(self.provider_name.clone()),
463 request_model: Some(self.model.clone()),
464 request_url: Some(url.clone()),
465 response_id: parsed.id,
466 response_model: parsed.model,
467 response_object: parsed.object,
468 first_token_ms: None,
469 duration_ms: Some(request_started_at.elapsed().as_millis() as u64),
470 }),
471 };
472
473 crate::telemetry::record_llm_usage(
474 llm_response.usage.prompt_tokens,
475 llm_response.usage.completion_tokens,
476 llm_response.usage.total_tokens,
477 llm_response.stop_reason.as_deref(),
478 );
479
480 Ok(llm_response)
481 }
482 }
483}
484
485#[async_trait]
486impl LlmClient for OpenAiClient {
487 async fn complete(
488 &self,
489 messages: &[Message],
490 system: Option<&str>,
491 tools: &[ToolDefinition],
492 ) -> Result<LlmResponse> {
493 self.send_request(self.build_chat_request(messages, system, tools, None))
494 .await
495 }
496
497 async fn complete_structured(
498 &self,
499 messages: &[Message],
500 system: Option<&str>,
501 tools: &[ToolDefinition],
502 directive: &structured::StructuredDirective,
503 ) -> Result<LlmResponse> {
504 self.send_request(self.build_chat_request(messages, system, tools, Some(directive)))
505 .await
506 }
507
508 fn native_structured_support(&self) -> structured::NativeStructuredSupport {
509 structured::NativeStructuredSupport::JsonSchema
510 }
511
512 async fn complete_streaming(
513 &self,
514 messages: &[Message],
515 system: Option<&str>,
516 tools: &[ToolDefinition],
517 cancel_token: tokio_util::sync::CancellationToken,
518 ) -> Result<mpsc::Receiver<StreamEvent>> {
519 self.send_streaming(
520 self.build_chat_request(messages, system, tools, None),
521 cancel_token,
522 )
523 .await
524 }
525
526 async fn complete_streaming_structured(
527 &self,
528 messages: &[Message],
529 system: Option<&str>,
530 tools: &[ToolDefinition],
531 directive: &structured::StructuredDirective,
532 cancel_token: tokio_util::sync::CancellationToken,
533 ) -> Result<mpsc::Receiver<StreamEvent>> {
534 self.send_streaming(
535 self.build_chat_request(messages, system, tools, Some(directive)),
536 cancel_token,
537 )
538 .await
539 }
540}
541
542impl OpenAiClient {
543 async fn send_streaming(
545 &self,
546 mut request: serde_json::Value,
547 cancel_token: tokio_util::sync::CancellationToken,
548 ) -> Result<mpsc::Receiver<StreamEvent>> {
549 {
550 request["stream"] = serde_json::json!(true);
551 request["stream_options"] = serde_json::json!({ "include_usage": true });
552 let request_started_at = Instant::now();
553 let url = format!("{}{}", self.base_url, self.chat_completions_path);
554 let request_headers = self.request_headers();
555
556 let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
557 let http = &self.http;
558 let url = &url;
559 let request_headers = request_headers.clone();
560 let request = &request;
561 let cancel_token = cancel_token.clone();
562 async move {
563 let headers = request_headers
564 .iter()
565 .map(|(key, value)| (key.as_str(), value.as_str()))
566 .collect::<Vec<_>>();
567 let resp = tokio::select! {
569 _ = cancel_token.cancelled() => {
570 return AttemptOutcome::Fatal(anyhow::anyhow!("HTTP request cancelled"));
571 }
572 result = http.post_streaming(url, headers, request, cancel_token.clone()) => {
573 match result {
574 Ok(r) => r,
575 Err(e) => {
576 return if crate::retry::is_transient_error(&e) {
582 AttemptOutcome::Retryable {
583 status: reqwest::StatusCode::SERVICE_UNAVAILABLE,
584 body: format!("network error: {e}"),
585 retry_after: None,
586 }
587 } else {
588 AttemptOutcome::Fatal(anyhow::anyhow!(
589 "HTTP request failed: {}",
590 e
591 ))
592 };
593 }
594 }
595 }
596 };
597 let status = reqwest::StatusCode::from_u16(resp.status)
598 .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
599 if status.is_success() {
600 AttemptOutcome::Success(resp)
601 } else {
602 let retry_after = resp
603 .retry_after
604 .as_deref()
605 .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
606 if self.retry_config.is_retryable_status(status) {
607 AttemptOutcome::Retryable {
608 status,
609 body: resp.error_body,
610 retry_after,
611 }
612 } else {
613 AttemptOutcome::Fatal(anyhow::anyhow!(
614 "OpenAI API error at {} ({}): {}",
615 url,
616 status,
617 resp.error_body
618 ))
619 }
620 }
621 }
622 })
623 .await?;
624
625 let (tx, rx) = mpsc::channel(100);
626
627 let mut stream = streaming_resp.byte_stream;
628 let provider_name = self.provider_name.clone();
629 let request_model = self.model.clone();
630 let request_url = url.clone();
631 tokio::spawn(async move {
632 let mut buffer = String::new();
633 let mut content_blocks: Vec<ContentBlock> = Vec::new();
634 let mut text_content = String::new();
635 let mut reasoning_content_accum = String::new();
636 let mut tool_calls: std::collections::BTreeMap<usize, (String, String, String)> =
637 std::collections::BTreeMap::new();
638 let mut usage = TokenUsage::default();
639 let mut finish_reason = None;
640 let mut response_id = None;
641 let mut response_model = None;
642 let mut response_object = None;
643 let mut first_token_ms = None;
644 let mut saw_done = false;
645 let mut parsed_any_event = false;
646
647 while let Some(chunk_result) = stream.next().await {
648 let chunk = match chunk_result {
649 Ok(c) => c,
650 Err(e) => {
651 tracing::error!("Stream error: {}", e);
652 break;
653 }
654 };
655
656 buffer.push_str(&String::from_utf8_lossy(&chunk));
657
658 while let Some(event_end) = buffer.find("\n\n") {
659 let event_data: String = buffer.drain(..event_end).collect();
660 buffer.drain(..2);
661
662 for line in event_data.lines() {
663 if let Some(data) = line.strip_prefix("data: ") {
664 if data == "[DONE]" {
665 saw_done = true;
666 if !text_content.is_empty() {
667 content_blocks.push(ContentBlock::Text {
668 text: text_content.clone(),
669 });
670 }
671 for (_, (id, name, args)) in tool_calls.iter() {
672 content_blocks.push(ContentBlock::ToolUse {
673 id: id.clone(),
674 name: name.clone(),
675 input: Self::parse_tool_arguments(name, args),
676 });
677 }
678 tool_calls.clear();
679 crate::telemetry::record_llm_usage(
680 usage.prompt_tokens,
681 usage.completion_tokens,
682 usage.total_tokens,
683 finish_reason.as_deref(),
684 );
685 let response = LlmResponse {
686 message: Message {
687 role: "assistant".to_string(),
688 content: std::mem::take(&mut content_blocks),
689 reasoning_content: if reasoning_content_accum.is_empty()
690 {
691 None
692 } else {
693 Some(std::mem::take(&mut reasoning_content_accum))
694 },
695 },
696 usage: usage.clone(),
697 stop_reason: std::mem::take(&mut finish_reason),
698 meta: Some(LlmResponseMeta {
699 provider: Some(provider_name.clone()),
700 request_model: Some(request_model.clone()),
701 request_url: Some(request_url.clone()),
702 response_id: response_id.clone(),
703 response_model: response_model.clone(),
704 response_object: response_object.clone(),
705 first_token_ms,
706 duration_ms: Some(
707 request_started_at.elapsed().as_millis() as u64,
708 ),
709 }),
710 };
711 let _ = tx.send(StreamEvent::Done(response)).await;
712 continue;
713 }
714
715 if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(data) {
716 parsed_any_event = true;
717 if response_id.is_none() {
718 response_id = event.id.clone();
719 }
720 if response_model.is_none() {
721 response_model = event.model.clone();
722 }
723 if response_object.is_none() {
724 response_object = event.object.clone();
725 }
726 if let Some(u) = event.usage {
727 usage.prompt_tokens = u.prompt_tokens;
728 usage.completion_tokens = u.completion_tokens;
729 usage.total_tokens = u.total_tokens;
730 if usage.total_tokens == 0 {
732 usage.total_tokens = u.total_characters.unwrap_or(0);
733 }
734 usage.cache_read_tokens = u
735 .prompt_tokens_details
736 .as_ref()
737 .and_then(|d| d.cached_tokens);
738 }
739
740 if let Some(choice) = event.choices.into_iter().next() {
741 if let Some(reason) = choice.finish_reason {
742 finish_reason = Some(reason);
743 }
744
745 if let Some(message) = choice.message {
746 let skip_content = !text_content.is_empty();
749 if let Some(reasoning) = message.reasoning_content {
750 if first_token_ms.is_none() {
760 first_token_ms = Some(
761 request_started_at.elapsed().as_millis()
762 as u64,
763 );
764 }
765 if let Some(delta) = Self::merge_stream_text(
766 &mut reasoning_content_accum,
767 &reasoning,
768 ) {
769 let _ = tx
770 .send(StreamEvent::ReasoningDelta(delta))
771 .await;
772 }
773 }
774 if !skip_content {
775 if let Some(content) = message
776 .content
777 .filter(|value| !value.is_empty())
778 {
779 if first_token_ms.is_none() {
780 first_token_ms = Some(
781 request_started_at.elapsed().as_millis()
782 as u64,
783 );
784 }
785 if let Some(delta) = Self::merge_stream_text(
786 &mut text_content,
787 &content,
788 ) {
789 let _ = tx
790 .send(StreamEvent::TextDelta(delta))
791 .await;
792 }
793 }
794 }
795 if let Some(tcs) = message.tool_calls {
796 for (index, tc) in tcs.into_iter().enumerate() {
797 tool_calls.insert(
798 index,
799 (
800 tc.id,
801 tc.function.name,
802 tc.function.arguments,
803 ),
804 );
805 }
806 }
807 } else if let Some(delta) = choice.delta {
808 if let Some(ref rc) = delta.reasoning_content {
809 if first_token_ms.is_none() {
812 first_token_ms = Some(
813 request_started_at.elapsed().as_millis()
814 as u64,
815 );
816 }
817 if let Some(delta) = Self::merge_stream_text(
818 &mut reasoning_content_accum,
819 rc,
820 ) {
821 let _ = tx
822 .send(StreamEvent::ReasoningDelta(delta))
823 .await;
824 }
825 }
826
827 if let Some(content) = delta.content {
828 if first_token_ms.is_none() {
829 first_token_ms = Some(
830 request_started_at.elapsed().as_millis()
831 as u64,
832 );
833 }
834 if let Some(delta) = Self::merge_stream_text(
835 &mut text_content,
836 &content,
837 ) {
838 let _ = tx
839 .send(StreamEvent::TextDelta(delta))
840 .await;
841 }
842 }
843
844 if let Some(tcs) = delta.tool_calls {
845 for tc in tcs {
846 let entry = tool_calls
847 .entry(tc.index)
848 .or_insert_with(|| {
849 (
850 String::new(),
851 String::new(),
852 String::new(),
853 )
854 });
855
856 if let Some(id) = tc.id {
857 entry.0 = id;
858 }
859 if let Some(func) = tc.function {
860 if let Some(name) = func.name {
861 if first_token_ms.is_none() {
862 first_token_ms = Some(
863 request_started_at
864 .elapsed()
865 .as_millis()
866 as u64,
867 );
868 }
869 entry.1 = name.clone();
870 let _ = tx
871 .send(StreamEvent::ToolUseStart {
872 id: entry.0.clone(),
873 name,
874 })
875 .await;
876 }
877 if let Some(args) = func.arguments {
878 entry.2.push_str(&args);
879 let _ = tx
880 .send(
881 StreamEvent::ToolUseInputDelta(
882 args,
883 ),
884 )
885 .await;
886 }
887 }
888 }
889 }
890 }
891 }
892 }
893 }
894 }
895 }
896 }
897
898 if saw_done {
899 return;
900 }
901
902 let trailing = buffer.trim();
903 if !trailing.is_empty() {
904 if let Ok(event) = serde_json::from_str::<OpenAiStreamChunk>(trailing) {
905 parsed_any_event = true;
906 if response_id.is_none() {
907 response_id = event.id.clone();
908 }
909 if response_model.is_none() {
910 response_model = event.model.clone();
911 }
912 if response_object.is_none() {
913 response_object = event.object.clone();
914 }
915 if let Some(u) = event.usage {
916 usage.prompt_tokens = u.prompt_tokens;
917 usage.completion_tokens = u.completion_tokens;
918 usage.total_tokens = u.total_tokens;
919 usage.cache_read_tokens = u
920 .prompt_tokens_details
921 .as_ref()
922 .and_then(|d| d.cached_tokens);
923 }
924 if let Some(choice) = event.choices.into_iter().next() {
925 if let Some(reason) = choice.finish_reason {
926 finish_reason = Some(reason);
927 }
928 let skip_content = !text_content.is_empty();
931 if let Some(message) = choice.message {
932 if let Some(reasoning) = message.reasoning_content {
933 if first_token_ms.is_none() {
936 first_token_ms =
937 Some(request_started_at.elapsed().as_millis() as u64);
938 }
939 if let Some(delta) = Self::merge_stream_text(
940 &mut reasoning_content_accum,
941 &reasoning,
942 ) {
943 let _ = tx.send(StreamEvent::ReasoningDelta(delta)).await;
944 }
945 }
946 if !skip_content {
947 if let Some(content) =
948 message.content.filter(|value| !value.is_empty())
949 {
950 if first_token_ms.is_none() {
951 first_token_ms = Some(
952 request_started_at.elapsed().as_millis() as u64,
953 );
954 }
955 if let Some(delta) =
956 Self::merge_stream_text(&mut text_content, &content)
957 {
958 let _ = tx.send(StreamEvent::TextDelta(delta)).await;
959 }
960 }
961 }
962 if let Some(tcs) = message.tool_calls {
963 for (index, tc) in tcs.into_iter().enumerate() {
964 tool_calls.insert(
965 index,
966 (tc.id, tc.function.name, tc.function.arguments),
967 );
968 }
969 }
970 } else if let Some(delta) = choice.delta {
971 if let Some(ref rc) = delta.reasoning_content {
972 if first_token_ms.is_none() {
975 first_token_ms =
976 Some(request_started_at.elapsed().as_millis() as u64);
977 }
978 if let Some(delta) =
979 Self::merge_stream_text(&mut reasoning_content_accum, rc)
980 {
981 let _ = tx.send(StreamEvent::ReasoningDelta(delta)).await;
982 }
983 }
984 if let Some(content) = delta.content {
985 if first_token_ms.is_none() {
986 first_token_ms =
987 Some(request_started_at.elapsed().as_millis() as u64);
988 }
989 if let Some(delta) =
990 Self::merge_stream_text(&mut text_content, &content)
991 {
992 let _ = tx.send(StreamEvent::TextDelta(delta)).await;
993 }
994 }
995 }
996 }
997 } else if let Ok(response) = serde_json::from_str::<OpenAiResponse>(trailing) {
998 parsed_any_event = true;
999 response_id = response.id.clone();
1000 response_model = response.model.clone();
1001 response_object = response.object.clone();
1002 usage.prompt_tokens = response.usage.prompt_tokens;
1003 usage.completion_tokens = response.usage.completion_tokens;
1004 usage.total_tokens = response.usage.total_tokens;
1005 if usage.total_tokens == 0 {
1007 usage.total_tokens = response.usage.total_characters.unwrap_or(0);
1008 }
1009 usage.cache_read_tokens = response
1010 .usage
1011 .prompt_tokens_details
1012 .as_ref()
1013 .and_then(|d| d.cached_tokens);
1014
1015 if let Some(choice) = response.choices.into_iter().next() {
1016 finish_reason = choice.finish_reason;
1017 if let Some(text) =
1018 choice.message.content.filter(|text| !text.is_empty())
1019 {
1020 if first_token_ms.is_none() {
1021 first_token_ms =
1022 Some(request_started_at.elapsed().as_millis() as u64);
1023 }
1024 let _ = Self::merge_stream_text(&mut text_content, &text);
1025 }
1026 if let Some(reasoning) = choice.message.reasoning_content {
1027 reasoning_content_accum.push_str(&reasoning);
1028 }
1029 if let Some(final_tool_calls) = choice.message.tool_calls {
1030 for tc in final_tool_calls {
1031 tool_calls.insert(
1032 tool_calls.len(),
1033 (tc.id, tc.function.name, tc.function.arguments),
1034 );
1035 }
1036 }
1037 }
1038 }
1039 }
1040
1041 if parsed_any_event
1042 || !text_content.is_empty()
1043 || !tool_calls.is_empty()
1044 || !content_blocks.is_empty()
1045 {
1046 tracing::warn!(
1047 provider = %provider_name,
1048 model = %request_model,
1049 "OpenAI-compatible stream ended without [DONE]; finalizing buffered response"
1050 );
1051 if !text_content.is_empty() {
1052 content_blocks.push(ContentBlock::Text {
1053 text: text_content.clone(),
1054 });
1055 }
1056 for (_, (id, name, args)) in tool_calls.iter() {
1057 content_blocks.push(ContentBlock::ToolUse {
1058 id: id.clone(),
1059 name: name.clone(),
1060 input: Self::parse_tool_arguments(name, args),
1061 });
1062 }
1063 tool_calls.clear();
1064 crate::telemetry::record_llm_usage(
1065 usage.prompt_tokens,
1066 usage.completion_tokens,
1067 usage.total_tokens,
1068 finish_reason.as_deref(),
1069 );
1070 let response = LlmResponse {
1071 message: Message {
1072 role: "assistant".to_string(),
1073 content: std::mem::take(&mut content_blocks),
1074 reasoning_content: if reasoning_content_accum.is_empty() {
1075 None
1076 } else {
1077 Some(std::mem::take(&mut reasoning_content_accum))
1078 },
1079 },
1080 usage: usage.clone(),
1081 stop_reason: std::mem::take(&mut finish_reason),
1082 meta: Some(LlmResponseMeta {
1083 provider: Some(provider_name.clone()),
1084 request_model: Some(request_model.clone()),
1085 request_url: Some(request_url.clone()),
1086 response_id: response_id.clone(),
1087 response_model: response_model.clone(),
1088 response_object: response_object.clone(),
1089 first_token_ms,
1090 duration_ms: Some(request_started_at.elapsed().as_millis() as u64),
1091 }),
1092 };
1093 let _ = tx.send(StreamEvent::Done(response)).await;
1094 } else {
1095 tracing::warn!(
1096 provider = %provider_name,
1097 model = %request_model,
1098 trailing = %trailing.chars().take(400).collect::<String>(),
1099 "OpenAI-compatible stream ended without any parseable events"
1100 );
1101 }
1102 });
1103
1104 Ok(rx)
1105 }
1106 }
1107}
1108
/// Body of a non-streaming chat-completions response.
///
/// `id`, `object` and `model` are optional; `choices` and `usage` must be present for
/// deserialization to succeed.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiResponse {
    /// Response identifier, copied into response metadata.
    #[serde(default)]
    pub(crate) id: Option<String>,
    /// Object type string reported by the server.
    #[serde(default)]
    pub(crate) object: Option<String>,
    /// Model name reported by the server.
    #[serde(default)]
    pub(crate) model: Option<String>,
    /// Completion choices; only the first one is consumed by this client.
    pub(crate) choices: Vec<OpenAiChoice>,
    /// Token accounting for the request.
    pub(crate) usage: OpenAiUsage,
}
1121
/// One completion choice of a non-streaming response.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiChoice {
    /// The assistant message produced for this choice.
    pub(crate) message: OpenAiMessage,
    /// Why generation stopped; forwarded as the response's stop reason.
    pub(crate) finish_reason: Option<String>,
}
1127
/// A complete assistant message as returned by the API.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiMessage {
    /// Reasoning text; also accepted under the key `reasoning`.
    #[serde(alias = "reasoning")]
    pub(crate) reasoning_content: Option<String>,
    /// Visible answer text, if any.
    pub(crate) content: Option<String>,
    /// Tool calls requested by the model, if any.
    pub(crate) tool_calls: Option<Vec<OpenAiToolCall>>,
}
1139
/// A complete (non-incremental) tool call.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCall {
    /// Identifier of the call; becomes the `ToolUse` block id.
    pub(crate) id: String,
    /// Function name and arguments of the call.
    pub(crate) function: OpenAiFunction,
}
1145
/// Function part of a complete tool call.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunction {
    /// Name of the tool to invoke.
    pub(crate) name: String,
    /// Arguments as a JSON-encoded string; parsed via `parse_tool_arguments`.
    pub(crate) arguments: String,
}
1151
/// Token accounting block; every field defaults when missing.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiUsage {
    /// Tokens consumed by the prompt.
    #[serde(default)]
    pub(crate) prompt_tokens: usize,
    /// Tokens produced in the completion.
    #[serde(default)]
    pub(crate) completion_tokens: usize,
    /// Total tokens; when zero the client uses `total_characters` instead.
    #[serde(default)]
    pub(crate) total_tokens: usize,
    /// Character-based total, used as a fallback when `total_tokens` is zero.
    #[serde(default)]
    pub(crate) total_characters: Option<usize>,
    /// Optional prompt-token breakdown (source of the cached-token count).
    #[serde(default)]
    pub(crate) prompt_tokens_details: Option<OpenAiPromptTokensDetails>,
}
1167
/// Breakdown of prompt tokens.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiPromptTokensDetails {
    /// Cached prompt tokens; surfaced as `cache_read_tokens`.
    #[serde(default)]
    pub(crate) cached_tokens: Option<usize>,
}
1173
/// One JSON payload of a streamed chat completion.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChunk {
    /// Response identifier; the first one seen is kept.
    #[serde(default)]
    pub(crate) id: Option<String>,
    /// Object type string; the first one seen is kept.
    #[serde(default)]
    pub(crate) object: Option<String>,
    /// Model name; the first one seen is kept.
    #[serde(default)]
    pub(crate) model: Option<String>,
    /// Choices in this chunk; only the first one is consumed by this client.
    pub(crate) choices: Vec<OpenAiStreamChoice>,
    /// Usage report, when the chunk carries one.
    pub(crate) usage: Option<OpenAiUsage>,
}
1186
/// One choice inside a stream chunk.
///
/// The stream handler prefers `message` when present and only falls back to `delta`
/// otherwise.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiStreamChoice {
    /// Full message, as sent by servers that stream whole messages rather than deltas.
    pub(crate) message: Option<OpenAiMessage>,
    /// Incremental update to the message.
    pub(crate) delta: Option<OpenAiDelta>,
    /// Why generation stopped, once known.
    pub(crate) finish_reason: Option<String>,
}
1193
/// Incremental message update inside a stream chunk.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiDelta {
    /// Reasoning text fragment; also accepted under the key `reasoning`.
    #[serde(alias = "reasoning")]
    pub(crate) reasoning_content: Option<String>,
    /// Answer text fragment.
    pub(crate) content: Option<String>,
    /// Tool-call fragments, addressed by index.
    pub(crate) tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
1205
/// Fragment of a tool call delivered in a stream chunk.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiToolCallDelta {
    /// Index of the tool call this fragment belongs to; fragments are merged per index.
    pub(crate) index: usize,
    /// Call identifier, when the fragment carries it.
    pub(crate) id: Option<String>,
    /// Function name / argument fragment, when present.
    pub(crate) function: Option<OpenAiFunctionDelta>,
}
1212
/// Function part of a streamed tool-call fragment.
#[derive(Debug, Deserialize)]
pub(crate) struct OpenAiFunctionDelta {
    /// Tool name, when the fragment carries it.
    pub(crate) name: Option<String>,
    /// Piece of the JSON argument string; pieces are concatenated in arrival order.
    pub(crate) arguments: Option<String>,
}
1218
1219#[cfg(test)]
1224mod tests {
1225 use super::*;
1226 use crate::llm::types::{Message, ToolDefinition};
1227
    /// Builds a client with a dummy key and model and otherwise default settings.
    fn make_client() -> OpenAiClient {
        OpenAiClient::new("test-key".to_string(), "gpt-test".to_string())
    }
1231
    /// HTTP transport stub that replays a fixed list of chunks as the streaming body.
    struct MockSseHttp {
        /// Chunks yielded, in order, by `post_streaming`.
        chunks: Vec<String>,
    }
1241
1242 #[async_trait::async_trait]
1243 impl crate::llm::http::HttpClient for MockSseHttp {
1244 async fn post(
1245 &self,
1246 _url: &str,
1247 _headers: Vec<(&str, &str)>,
1248 _body: &serde_json::Value,
1249 _cancel: tokio_util::sync::CancellationToken,
1250 ) -> anyhow::Result<crate::llm::http::HttpResponse> {
1251 anyhow::bail!("post is unused in the streaming test")
1252 }
1253
1254 async fn post_streaming(
1255 &self,
1256 _url: &str,
1257 _headers: Vec<(&str, &str)>,
1258 _body: &serde_json::Value,
1259 _cancel: tokio_util::sync::CancellationToken,
1260 ) -> anyhow::Result<crate::llm::http::StreamingHttpResponse> {
1261 let items: Vec<anyhow::Result<bytes::Bytes>> = self
1262 .chunks
1263 .iter()
1264 .map(|s| Ok(bytes::Bytes::from(s.clone())))
1265 .collect();
1266 Ok(crate::llm::http::StreamingHttpResponse {
1267 status: 200,
1268 retry_after: None,
1269 byte_stream: Box::pin(futures::stream::iter(items)),
1270 error_body: String::new(),
1271 })
1272 }
1273 }
1274
1275 fn glm_client(chunks: Vec<String>) -> OpenAiClient {
1276 OpenAiClient::new("k".to_string(), "glm-test".to_string())
1277 .with_http_client(std::sync::Arc::new(MockSseHttp { chunks }))
1278 }
1279
1280 async fn drain_to_done(client: &OpenAiClient) -> crate::llm::LlmResponse {
1281 use crate::llm::{LlmClient, StreamEvent};
1282 let mut rx = client
1283 .complete_streaming(
1284 &[Message::user("go")],
1285 None,
1286 &[],
1287 tokio_util::sync::CancellationToken::new(),
1288 )
1289 .await
1290 .expect("stream opened");
1291 let mut done = None;
1292 while let Some(ev) = rx.recv().await {
1293 if let StreamEvent::Done(resp) = ev {
1294 done = Some(resp);
1295 }
1296 }
1297 done.expect("a Done event")
1298 }
1299
1300 #[tokio::test]
1301 async fn streaming_reasoning_does_not_leak_into_content_and_keeps_tool_call() {
1302 let chunks = vec![
1303 "data: {\"choices\":[{\"delta\":{\"reasoning\":\"Let me plan the workers\"}}]}\n\n"
1304 .to_string(),
1305 "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"parallel_task\",\"arguments\":\"{}\"}}]}}]}\n\n"
1306 .to_string(),
1307 "data: [DONE]\n\n".to_string(),
1308 ];
1309 let resp = drain_to_done(&glm_client(chunks)).await;
1310 assert_eq!(resp.message.text(), "", "reasoning leaked into content");
1312 assert_eq!(
1313 resp.message.reasoning_content.as_deref(),
1314 Some("Let me plan the workers")
1315 );
1316 let calls = resp.message.tool_calls();
1318 assert_eq!(calls.len(), 1);
1319 assert_eq!(calls[0].name, "parallel_task");
1320 }
1321
1322 #[tokio::test]
1323 async fn streaming_reasoning_only_turn_yields_empty_text() {
1324 let chunks = vec![
1328 "data: {\"choices\":[{\"delta\":{\"reasoning\":\"still thinking, no answer yet\"}}]}\n\n"
1329 .to_string(),
1330 "data: [DONE]\n\n".to_string(),
1331 ];
1332 let resp = drain_to_done(&glm_client(chunks)).await;
1333 assert_eq!(resp.message.text(), "");
1334 assert_eq!(
1335 resp.message.reasoning_content.as_deref(),
1336 Some("still thinking, no answer yet")
1337 );
1338 assert!(resp.message.tool_calls().is_empty());
1339 }
1340
1341 #[test]
1342 fn test_apply_directive_forced_function_tool_choice() {
1343 let mut req = serde_json::json!({ "model": "m" });
1344 OpenAiClient::apply_directive(
1345 &mut req,
1346 &structured::StructuredDirective {
1347 force_tool: Some("emit_person".to_string()),
1348 response_format: None,
1349 },
1350 );
1351 assert_eq!(req["tool_choice"]["type"], "function");
1352 assert_eq!(req["tool_choice"]["function"]["name"], "emit_person");
1353 assert!(req.get("response_format").is_none());
1354 }
1355
1356 #[test]
1357 fn test_apply_directive_json_schema_strict() {
1358 let mut req = serde_json::json!({});
1359 OpenAiClient::apply_directive(
1360 &mut req,
1361 &structured::StructuredDirective {
1362 force_tool: None,
1363 response_format: Some(structured::ResponseFormat::JsonSchema {
1364 name: "person".to_string(),
1365 schema: serde_json::json!({ "type": "object" }),
1366 }),
1367 },
1368 );
1369 assert_eq!(req["response_format"]["type"], "json_schema");
1370 assert_eq!(req["response_format"]["json_schema"]["name"], "person");
1371 assert_eq!(req["response_format"]["json_schema"]["strict"], true);
1372 assert!(req.get("tool_choice").is_none());
1373 }
1374
1375 #[test]
1376 fn test_apply_directive_json_object() {
1377 let mut req = serde_json::json!({});
1378 OpenAiClient::apply_directive(
1379 &mut req,
1380 &structured::StructuredDirective {
1381 force_tool: None,
1382 response_format: Some(structured::ResponseFormat::JsonObject),
1383 },
1384 );
1385 assert_eq!(req["response_format"]["type"], "json_object");
1386 }
1387
1388 #[test]
1389 fn test_build_chat_request_applies_directive_and_system() {
1390 let req = make_client().build_chat_request(
1391 &[Message::user("hi")],
1392 Some("sys"),
1393 &[ToolDefinition {
1394 name: "emit_x".to_string(),
1395 description: "emit".to_string(),
1396 parameters: serde_json::json!({ "type": "object" }),
1397 }],
1398 Some(&structured::StructuredDirective {
1399 force_tool: Some("emit_x".to_string()),
1400 response_format: None,
1401 }),
1402 );
1403 assert_eq!(req["messages"][0]["role"], "system");
1404 assert_eq!(req["tool_choice"]["function"]["name"], "emit_x");
1405 assert_eq!(req["tools"][0]["function"]["name"], "emit_x");
1406 }
1407
1408 #[test]
1409 fn test_build_chat_request_without_directive_is_plain() {
1410 let req = make_client().build_chat_request(&[Message::user("hi")], None, &[], None);
1411 assert!(req.get("tool_choice").is_none());
1412 assert!(req.get("response_format").is_none());
1413 }
1414
1415 #[test]
1416 fn test_native_structured_support_is_json_schema() {
1417 assert_eq!(
1418 make_client().native_structured_support(),
1419 structured::NativeStructuredSupport::JsonSchema
1420 );
1421 }
1422}