1use async_trait::async_trait;
4use bytes::Bytes;
5use futures::{Stream, StreamExt};
6use reqwest::Client;
7use serde::Deserialize;
8use serde_json::Value as JsonValue;
9use std::pin::Pin;
10
11use super::openai_responses_shared::parse_streaming_json;
12use super::shared_client;
13use crate::{
14 error::ProviderError, Api, AssistantMessage, ContentBlock, Context, Model, Provider,
15 ProviderEvent, StopReason, StreamOptions, TextContent, ThinkingContent, Usage,
16};
17
18fn is_zai(model: &Model) -> bool {
20 model.provider.eq_ignore_ascii_case("zai") || model.base_url.contains("api.z.ai")
21}
22
23#[derive(Clone)]
25pub struct OpenAiProvider {
26 client: &'static Client,
27 api_key: Option<String>,
28 base_url: Option<String>,
29}
30
31impl OpenAiProvider {
32 pub fn new() -> Self {
37 Self {
38 client: shared_client(),
39 api_key: None,
40 base_url: None,
41 }
42 }
43
44 pub fn with_api_key(api_key: impl Into<String>) -> Self {
46 Self {
47 client: shared_client(),
48 api_key: Some(api_key.into()),
49 base_url: None,
50 }
51 }
52
53 pub fn with_base_url(base_url: &str) -> Self {
57 Self {
58 client: shared_client(),
59 api_key: None,
60 base_url: Some(base_url.to_string()),
61 }
62 }
63
64 pub fn with_base_url_and_key(base_url: &str, api_key: Option<String>) -> Self {
68 Self {
69 client: shared_client(),
70 api_key,
71 base_url: Some(base_url.to_string()),
72 }
73 }
74}
75
76impl Default for OpenAiProvider {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82#[async_trait]
83impl Provider for OpenAiProvider {
84 async fn stream(
85 &self,
86 model: &Model,
87 context: &Context,
88 options: Option<StreamOptions>,
89 ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
90 let options = options.unwrap_or_default();
91
92 let effective_base_url = self.base_url.as_deref().unwrap_or(&model.base_url);
94 let url = format!("{}/chat/completions", effective_base_url);
95
96 let api_key = options
98 .api_key
99 .as_ref()
100 .or(self.api_key.as_ref())
101 .ok_or_else(|| ProviderError::MissingApiKey)?;
102
103 let messages = build_messages(context)?;
105
106 let mut body = serde_json::json!({
108 "model": model.id,
109 "messages": messages,
110 "stream": true,
111 "stream_options": { "include_usage": true },
112 });
113
114 if let Some(temp) = options.temperature {
116 body["temperature"] = serde_json::json!(temp);
117 }
118
119 if let Some(max) = options.max_tokens {
120 body["max_tokens"] = serde_json::json!(max);
121 }
122
123 if !context.tools.is_empty() {
125 body["tools"] = build_tools(&context.tools)?;
126 }
127
128 if is_zai(model) {
132 if model.reasoning {
133 body["enable_thinking"] = serde_json::json!(true);
134 }
135 if !context.tools.is_empty() {
136 body["tool_stream"] = serde_json::json!(true);
137 }
138 }
139
140 tracing::info!(
141 "Sending request to {} model={} body_len={} enable_thinking={} tool_stream={}",
142 url,
143 model.id,
144 body.to_string().len(),
145 body.get("enable_thinking").is_some(),
146 body.get("tool_stream").is_some()
147 );
148 tracing::debug!("Request body: {}", body.to_string());
149
150 let mut headers = reqwest::header::HeaderMap::new();
152 headers.insert(
153 reqwest::header::AUTHORIZATION,
154 format!("Bearer {}", api_key)
155 .parse()
156 .expect("valid bearer header"),
157 );
158 headers.insert(
159 reqwest::header::CONTENT_TYPE,
160 "application/json".parse().expect("valid header value"),
161 );
162
163 for (k, v) in &options.headers {
164 if let (Ok(name), Ok(value)) = (
165 k.parse::<reqwest::header::HeaderName>(),
166 v.parse::<reqwest::header::HeaderValue>(),
167 ) {
168 headers.insert(name, value);
169 }
170 }
171
172 let response = self
174 .client
175 .post(&url)
176 .headers(headers)
177 .json(&body)
178 .send()
179 .await
180 .map_err(ProviderError::RequestFailed)?;
181
182 if !response.status().is_success() {
183 let status = response.status();
184 let body: String = response.text().await.unwrap_or_default();
185 return Err(ProviderError::HttpError(status.as_u16(), body));
186 }
187
188 let provider_name = model.provider.clone();
190 let model_id = model.id.clone();
191
192 let start_event = ProviderEvent::Start {
194 partial: AssistantMessage::new(Api::OpenAiCompletions, &provider_name, &model_id),
195 };
196
197 let stream = response
208 .bytes_stream()
209 .scan(
210 (
211 Vec::new(),
212 std::collections::HashMap::<usize, (String, String, String)>::new(),
213 std::collections::HashMap::<String, usize>::new(), false,
215 AssistantMessage::new(Api::OpenAiCompletions, &provider_name, &model_id),
216 ),
217 move |(
218 pending_bytes,
219 pending_tc,
220 tc_id_to_index,
221 thinking_started,
222 accumulated_output,
223 ),
224 chunk: Result<Bytes, reqwest::Error>| {
225 let events = match chunk {
226 Ok(bytes) => {
227 let mut combined =
229 Vec::with_capacity(pending_bytes.len() + bytes.len());
230 combined.extend_from_slice(pending_bytes);
231 combined.extend_from_slice(&bytes);
232
233 let (text, trailing) = split_complete_lines(&combined);
237 *pending_bytes = trailing;
238
239 tracing::debug!(
240 "parse_sse_events input: {} bytes, {} lines",
241 text.len(),
242 text.lines().count()
243 );
244 let raw_events = parse_sse_events(
245 &text,
246 &provider_name,
247 &model_id,
248 accumulated_output,
249 );
250 tracing::debug!("parse_sse_events output: {} events", raw_events.len());
251
252 let mut processed = Vec::new();
254 for event in raw_events {
255 match &event {
256 ProviderEvent::ThinkingDelta { content_index, .. } => {
257 if !*thinking_started {
259 *thinking_started = true;
260 processed.push(ProviderEvent::ThinkingStart {
261 content_index: *content_index,
262 partial: AssistantMessage::new(
263 Api::OpenAiCompletions,
264 &provider_name,
265 &model_id,
266 ),
267 });
268 }
269 processed.push(event);
270 }
271 ProviderEvent::ToolCallStart {
272 content_index,
273 tool_call_id,
274 tool_name,
275 ..
276 } => {
277 let entry =
278 pending_tc.entry(*content_index).or_insert_with(|| {
279 (String::new(), String::new(), String::new())
280 });
281 if let Some(ref id) = tool_call_id {
282 if !id.is_empty() {
283 entry.0 = id.clone();
284 tc_id_to_index.insert(id.clone(), *content_index);
285 }
286 }
287 if let Some(ref name) = tool_name {
288 if !name.is_empty() {
289 entry.1 = name.clone();
290 }
291 }
292 processed.push(event);
293 }
294 ProviderEvent::ToolCallDelta {
295 content_index,
296 delta,
297 ..
298 } => {
299 let idx = if pending_tc.contains_key(content_index) {
301 *content_index
302 } else {
303 tc_id_to_index
305 .values()
306 .copied()
307 .find(|i| *i == *content_index)
308 .unwrap_or(*content_index)
309 };
310 let entry = pending_tc.entry(idx).or_insert_with(|| {
311 (String::new(), String::new(), String::new())
312 });
313 tracing::debug!(
314 "[TC-DELTA] idx={}, delta_len={}, accumulated_len={}",
315 idx,
316 delta.len(),
317 entry.2.len() + delta.len()
318 );
319 entry.2.push_str(delta);
320 processed.push(event);
321 }
322 ProviderEvent::ToolCallEnd { .. } => {
323 processed.push(event);
325 }
326 ProviderEvent::Done { reason, .. } => {
327 if matches!(reason, StopReason::ToolUse) {
329 let mut indices: Vec<usize> =
330 pending_tc.keys().copied().collect();
331 indices.sort();
332 for idx in indices {
333 let (id, name, arguments) = &pending_tc[&idx];
334 tracing::debug!(
335 "[TC-END] idx={}, id={}, name={}, args_len={}",
336 idx,
337 id.len(),
338 name.len(),
339 arguments.len()
340 );
341 let args_value = parse_streaming_json(arguments);
342 processed.push(ProviderEvent::ToolCallEnd {
343 content_index: idx,
344 tool_call: crate::ToolCall {
345 content_type:
346 crate::messages::ToolCallType::ToolCall,
347 id: id.clone(),
348 name: name.clone(),
349 arguments: args_value,
350 thought_signature: None,
351 },
352 partial: AssistantMessage::new(
353 Api::OpenAiCompletions,
354 &provider_name,
355 &model_id,
356 ),
357 });
358 }
359 }
360 pending_tc.clear();
364 tc_id_to_index.clear();
365 processed.push(event);
366 }
367 _ => {
368 processed.push(event);
369 }
370 }
371 }
372 processed
373 }
374 Err(e) => {
375 vec![ProviderEvent::Error {
376 reason: StopReason::Error,
377 error: create_error_message(
378 &e.to_string(),
379 &provider_name,
380 &model_id,
381 ),
382 }]
383 }
384 };
385 async move { Some(futures::stream::iter(events)) }
387 },
388 )
389 .flatten();
390
391 let stream_with_start = futures::stream::once(async move { start_event }).chain(stream);
393 Ok(Box::pin(stream_with_start))
394 }
395
396 fn name(&self) -> &str {
397 "openai"
398 }
399}
400
401fn build_messages(context: &Context) -> Result<Vec<JsonValue>, ProviderError> {
403 let mut messages = Vec::new();
404
405 if let Some(ref prompt) = context.system_prompt {
407 messages.push(serde_json::json!({
408 "role": "system",
409 "content": prompt,
410 }));
411 }
412
413 for msg in &context.messages {
415 match msg {
416 crate::Message::User(u) => {
417 let content: String = match &u.content {
418 crate::MessageContent::Text(s) => s.clone(),
419 crate::MessageContent::Blocks(blocks) => blocks_to_content(blocks)?.to_string(),
420 };
421 messages.push(serde_json::json!({
422 "role": "user",
423 "content": content,
424 }));
425 }
426 crate::Message::Assistant(a) => {
427 let mut text_parts = Vec::new();
429 let mut tool_calls = Vec::new();
430 for block in &a.content {
431 match block {
432 ContentBlock::Text(t) => {
433 text_parts.push(t.text.clone());
434 }
435 ContentBlock::Thinking(_) => {
436 }
438 ContentBlock::ToolCall(tc) => {
439 tool_calls.push(serde_json::json!({
440 "id": tc.id,
441 "type": "function",
442 "function": {
443 "name": tc.name,
444 "arguments": tc.arguments.to_string(),
445 },
446 }));
447 }
448 ContentBlock::Image(_) | ContentBlock::Unknown(_) => {}
449 }
450 }
451 let mut msg = serde_json::json!({
452 "role": "assistant",
453 "content": text_parts.join(""),
454 });
455 if !tool_calls.is_empty() {
456 msg["tool_calls"] = serde_json::json!(tool_calls);
457 }
458 messages.push(msg);
459 }
460 crate::Message::ToolResult(t) => {
461 let result_text: String = t
462 .content
463 .iter()
464 .filter_map(|b| b.as_text())
465 .collect::<Vec<_>>()
466 .join("");
467 messages.push(serde_json::json!({
468 "role": "tool",
469 "tool_call_id": t.tool_call_id,
470 "content": result_text,
471 }));
472 }
473 }
474 }
475
476 Ok(messages)
477}
478
479fn blocks_to_content(blocks: &[ContentBlock]) -> Result<JsonValue, ProviderError> {
481 if blocks.len() == 1 {
482 if let Some(text) = blocks[0].as_text() {
483 return Ok(JsonValue::String(text.to_string()));
484 }
485 }
486
487 let items: Result<Vec<_>, _> = blocks
488 .iter()
489 .map(|block| match block {
490 ContentBlock::Text(t) => Ok(serde_json::json!({
491 "type": "text",
492 "text": t.text,
493 })),
494 ContentBlock::ToolCall(tc) => Ok(serde_json::json!({
495 "type": "function",
496 "id": tc.id,
497 "function": {
498 "name": tc.name,
499 "arguments": tc.arguments.to_string(),
500 },
501 })),
502 ContentBlock::Thinking(th) => Ok(serde_json::json!({
503 "type": "thinking",
504 "thinking": th.thinking,
505 })),
506 ContentBlock::Image(img) => Ok(serde_json::json!({
507 "type": "image_url",
508 "image_url": {
509 "url": format!("data:{};base64,{}", img.mime_type, img.data),
510 },
511 })),
512 ContentBlock::Unknown(_) => Err(ProviderError::InvalidResponse(
513 "Unknown content block type".into(),
514 )),
515 })
516 .collect();
517
518 Ok(serde_json::json!(items?))
519}
520
521fn build_tools(tools: &[crate::Tool]) -> Result<JsonValue, ProviderError> {
523 let items: Vec<_> = tools
524 .iter()
525 .map(|tool| {
526 serde_json::json!({
527 "type": "function",
528 "function": {
529 "name": tool.name,
530 "description": tool.description,
531 "parameters": tool.parameters,
532 },
533 })
534 })
535 .collect();
536
537 Ok(serde_json::json!(items))
538}
539
/// Splits `bytes` at the end of its longest valid UTF-8 prefix.
///
/// Returns the prefix decoded as a `String` together with the remaining bytes. The
/// remainder starts at the first invalid or truncated sequence and is empty when all of
/// `bytes` is valid.
fn find_valid_utf8_prefix(bytes: &[u8]) -> (String, Vec<u8>) {
    let valid_len = match std::str::from_utf8(bytes) {
        Ok(_) => bytes.len(),
        Err(e) => e.valid_up_to(),
    };
    let (head, tail) = bytes.split_at(valid_len);
    (String::from_utf8_lossy(head).into_owned(), tail.to_vec())
}
555
556pub fn split_complete_lines(bytes: &[u8]) -> (String, Vec<u8>) {
560 match bytes.iter().rposition(|&b| b == b'\n') {
562 Some(last_nl) => {
563 let split_at = last_nl + 1;
564 let complete = match std::str::from_utf8(&bytes[..split_at]) {
565 Ok(s) => s.to_string(),
566 Err(_) => {
567 let (s, _) = find_valid_utf8_prefix(&bytes[..split_at]);
568 s
569 }
570 };
571 let trailing = bytes[split_at..].to_vec();
572 (complete, trailing)
573 }
574 None => {
575 (String::new(), bytes.to_vec())
578 }
579 }
580}
581
582fn parse_sse_events(
593 text: &str,
594 _provider: &str,
595 _model_id: &str,
596 output: &mut AssistantMessage,
597) -> Vec<ProviderEvent> {
598 let mut events = Vec::new();
599
600 let estimated_events = text.split('\n').filter(|l| l.starts_with("data: ")).count();
602 events.reserve(estimated_events);
603
604 let mut accumulated_usage = Usage::default();
605
606 for line in text.split('\n') {
607 let line = line.trim_end_matches('\r');
608 if line.is_empty() {
609 continue;
610 }
611
612 if !line.starts_with("data: ") {
614 continue;
615 }
616
617 let data = &line[6..]; if data == "[DONE]" {
621 break;
622 }
623
624 if data.is_empty() {
625 continue;
626 }
627
628 let chunk = match serde_json::from_str::<SSEChunk>(data) {
629 Ok(c) => c,
630 Err(_) => continue,
631 };
632
633 if let Some(chunk_usage) = &chunk.usage {
639 accumulated_usage.input = chunk_usage.prompt_tokens.max(accumulated_usage.input);
640 accumulated_usage.output = chunk_usage.completion_tokens.max(accumulated_usage.output);
641 accumulated_usage.cache_read = chunk_usage
642 .prompt_tokens_details
643 .as_ref()
644 .map(|d| d.cached_tokens)
645 .unwrap_or(0)
646 .max(accumulated_usage.cache_read);
647 accumulated_usage.total_tokens =
648 chunk_usage.total_tokens.max(accumulated_usage.total_tokens);
649 }
650
651 for choice in &chunk.choices {
652 if let Some(delta) = &choice.delta {
653 if let Some(content) = &delta.content {
654 let last_text_idx = output
656 .content
657 .iter()
658 .rposition(|b| matches!(b, ContentBlock::Text(_)));
659 if let Some(idx) = last_text_idx {
660 if let ContentBlock::Text(t) = &mut output.content[idx] {
661 t.text.push_str(content);
662 }
663 } else {
664 output
665 .content
666 .push(ContentBlock::Text(TextContent::new(content.clone())));
667 }
668 events.push(ProviderEvent::TextDelta {
669 content_index: choice.index,
670 delta: content.clone(),
671 partial: output.clone(),
672 });
673 }
674
675 if let Some(ref reasoning) = delta.reasoning_content {
677 if !reasoning.is_empty() {
678 let last_think_idx = output
680 .content
681 .iter()
682 .rposition(|b| matches!(b, ContentBlock::Thinking(_)));
683 if let Some(idx) = last_think_idx {
684 if let ContentBlock::Thinking(t) = &mut output.content[idx] {
685 t.thinking.push_str(reasoning);
686 }
687 } else {
688 output
689 .content
690 .push(ContentBlock::Thinking(ThinkingContent::new(
691 reasoning.clone(),
692 )));
693 }
694 events.push(ProviderEvent::ThinkingDelta {
695 content_index: choice.index,
696 delta: reasoning.clone(),
697 partial: output.clone(),
698 });
699 }
700 }
701
702 if let Some(tool_calls) = &delta.tool_calls {
703 for tc in tool_calls {
704 let tc_index = tc.index.unwrap_or(choice.index);
705
706 if tc.id.is_some()
708 || tc.function.as_ref().and_then(|f| f.name.as_ref()).is_some()
709 {
710 events.push(ProviderEvent::ToolCallStart {
711 content_index: tc_index,
712 tool_call_id: tc.id.clone(),
713 tool_name: tc.function.as_ref().and_then(|f| f.name.clone()),
714 partial: output.clone(),
715 });
716 }
717
718 if let Some(func) = &tc.function {
720 events.push(ProviderEvent::ToolCallDelta {
721 content_index: tc_index,
722 delta: func.arguments.clone().unwrap_or_default(),
723 partial: output.clone(),
724 });
725 }
726 }
727 }
728 }
729
730 if choice.finish_reason.is_some() {
731 let reason = match choice.finish_reason.as_deref() {
732 Some("stop") | Some("end") => StopReason::Stop,
733 Some("length") => StopReason::Length,
734 Some("tool_calls") | Some("function_call") => StopReason::ToolUse,
735 Some("content_filter") => StopReason::Error,
736 Some(unknown) => {
737 tracing::warn!("Unknown finish_reason: '{}', treating as Error", unknown);
738 StopReason::Error
739 }
740 None => StopReason::Stop,
741 };
742 tracing::info!("finish_reason={:?} → {:?}", choice.finish_reason, reason);
743
744 let mut done_msg = output.clone();
745 done_msg.stop_reason = reason;
746 done_msg.usage = accumulated_usage.clone();
747 events.push(ProviderEvent::Done {
748 reason,
749 message: done_msg,
750 });
751 }
752 }
753 }
754
755 events
756}
757
758fn create_error_message(msg: &str, provider: &str, model_id: &str) -> AssistantMessage {
760 let mut message = AssistantMessage::new(Api::OpenAiCompletions, provider, model_id);
761 message.stop_reason = StopReason::Error;
762 message.error_message = Some(msg.to_string());
763 message
764}
765
766#[derive(Debug, Deserialize)]
768struct SSEChunk {
770 _id: Option<String>,
771 #[serde(rename = "model")]
772 _model: Option<String>,
773 choices: Vec<Choice>,
774 usage: Option<UsageInfo>,
775}
776
777#[derive(Debug, Deserialize)]
778struct Choice {
780 index: usize,
781 delta: Option<Delta>,
782 finish_reason: Option<String>,
783}
784
785#[derive(Debug, Deserialize)]
786struct Delta {
787 content: Option<String>,
788 reasoning_content: Option<String>,
789 tool_calls: Option<Vec<ToolCallDelta>>,
790}
791
792#[derive(Debug, Deserialize)]
793struct ToolCallDelta {
795 index: Option<usize>,
796 id: Option<String>,
797 #[serde(rename = "type")]
798 _type_: Option<String>,
799 function: Option<FunctionDelta>,
800}
801
802#[derive(Debug, Deserialize)]
803struct FunctionDelta {
805 name: Option<String>,
806 arguments: Option<String>,
807}
808
809#[derive(Debug, Deserialize, Clone)]
810struct UsageInfo {
811 prompt_tokens: usize,
812 completion_tokens: usize,
813 total_tokens: usize,
814 #[serde(rename = "prompt_tokens_details")]
815 prompt_tokens_details: Option<PromptTokensDetails>,
816}
817
818#[derive(Debug, Deserialize, Clone)]
819struct PromptTokensDetails {
820 #[serde(rename = "cached_tokens")]
821 cached_tokens: usize,
822}
823
#[cfg(test)]
mod tests {
    // Unit tests for `parse_sse_events`. Each test feeds a hand-written SSE transcript
    // through `parse_sse` and inspects the returned `ProviderEvent`s.
    use super::*;

    const PROVIDER: &str = "openai";
    const MODEL: &str = "gpt-4o";

    // Parses `sse` with a fresh, empty accumulator message.
    fn parse_sse(sse: &str) -> Vec<ProviderEvent> {
        let mut output = AssistantMessage::new(Api::OpenAiCompletions, PROVIDER, MODEL);
        parse_sse_events(sse, PROVIDER, MODEL, &mut output)
    }

    // One `content` delta yields exactly one `TextDelta` for choice 0.
    #[test]
    fn parse_single_text_event() {
        let sse = "data: {\"id\":\"chatcmpl-1\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n";
        let events = parse_sse(sse);
        assert_eq!(events.len(), 1);
        match &events[0] {
            ProviderEvent::TextDelta {
                delta,
                content_index,
                ..
            } => {
                assert_eq!(delta, "Hello");
                assert_eq!(*content_index, 0);
            }
            other => panic!("expected TextDelta, got {other:?}"),
        }
    }

    // Consecutive `content` deltas are emitted in order, one event each.
    #[test]
    fn parse_multiple_text_events() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hel\"}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"lo!\"}}]}\n",
            "\n"
        );
        let events = parse_sse(sse);
        assert_eq!(events.len(), 2);
        let texts: Vec<&str> = events
            .iter()
            .filter_map(|e| match e {
                ProviderEvent::TextDelta { delta, .. } => Some(delta.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(texts, vec!["Hel", "lo!"]);
    }

    // `data: [DONE]` stops parsing; payloads after it are ignored.
    #[test]
    fn parse_done_terminator() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"X\"}}]}\n",
            "\n",
            "data: [DONE]\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"NEVER\"}}]}\n"
        );
        let events = parse_sse(sse);
        assert_eq!(events.len(), 1);
        match &events[0] {
            ProviderEvent::TextDelta { delta, .. } => assert_eq!(delta, "X"),
            other => panic!("expected TextDelta, got {other:?}"),
        }
    }

    // finish_reason "stop" (with a null delta) maps to `Done` with `StopReason::Stop`.
    #[test]
    fn parse_finish_reason_stop() {
        let sse = "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"stop\"}]}\n\n";
        let events = parse_sse(sse);
        assert_eq!(events.len(), 1);
        match &events[0] {
            ProviderEvent::Done { reason, .. } => assert!(matches!(reason, StopReason::Stop)),
            other => panic!("expected Done, got {other:?}"),
        }
    }

    // finish_reason "length" maps to `StopReason::Length`.
    #[test]
    fn parse_finish_reason_length() {
        let sse = "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"length\"}]}\n\n";
        let events = parse_sse(sse);
        match &events[0] {
            ProviderEvent::Done { reason, .. } => assert!(matches!(reason, StopReason::Length)),
            other => panic!("expected Done with Length, got {other:?}"),
        }
    }

    // finish_reason "tool_calls" maps to `StopReason::ToolUse`.
    #[test]
    fn parse_finish_reason_tool_calls() {
        let sse = "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"tool_calls\"}]}\n\n";
        let events = parse_sse(sse);
        match &events[0] {
            ProviderEvent::Done { reason, .. } => assert!(matches!(reason, StopReason::ToolUse)),
            other => panic!("expected Done with ToolUse, got {other:?}"),
        }
    }

    // The first fragment (id + name, empty args) yields a start plus an empty delta.
    // The follow-up fragment (args only) yields just a delta.
    #[test]
    fn parse_tool_call_deltas() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"\"}}]}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"city\\\":\\\"SF\\\"}\"}}]}}]}\n",
            "\n"
        );
        let events = parse_sse(sse);
        assert_eq!(events.len(), 3);
        let starts: Vec<&str> = events
            .iter()
            .filter_map(|e| match e {
                ProviderEvent::ToolCallStart { tool_name, .. } => tool_name.as_deref(),
                _ => None,
            })
            .collect();
        assert_eq!(starts, vec!["get_weather"]);
        let deltas: Vec<&str> = events
            .iter()
            .filter_map(|e| match e {
                ProviderEvent::ToolCallDelta { delta, .. } => Some(delta.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(deltas, vec!["", "{\"city\":\"SF\"}"]);
    }

    // A fragment with a name but no `arguments` field still emits an empty delta.
    #[test]
    fn parse_tool_call_with_no_arguments_field() {
        let sse = "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"name\":\"run\"}}]}}]}\n\n";
        let events = parse_sse(sse);
        assert_eq!(events.len(), 2);
        match &events[0] {
            ProviderEvent::ToolCallStart { tool_name, .. } => {
                assert_eq!(tool_name.as_deref(), Some("run"));
            }
            other => panic!("expected ToolCallStart, got {other:?}"),
        }
        match &events[1] {
            ProviderEvent::ToolCallDelta { delta, .. } => assert_eq!(delta, ""),
            other => panic!("expected ToolCallDelta, got {other:?}"),
        }
    }

    // Usage reported before the finish chunk is attached to the `Done` message,
    // including `cached_tokens` as `cache_read`.
    #[test]
    fn parse_usage_in_chunk() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":8,\"total_tokens\":18,\"prompt_tokens_details\":{\"cached_tokens\":3}}}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"stop\"}]}\n"
        );
        let events = parse_sse(sse);
        assert_eq!(events.len(), 2);
        match &events[1] {
            ProviderEvent::Done { message, .. } => {
                assert_eq!(message.usage.input, 10);
                assert_eq!(message.usage.output, 8);
                assert_eq!(message.usage.total_tokens, 18);
                assert_eq!(message.usage.cache_read, 3);
            }
            other => panic!("expected Done, got {other:?}"),
        }
    }

    // A usage-only chunk (empty `choices`, no details) emits nothing itself.
    // `cache_read` defaults to 0.
    #[test]
    fn parse_usage_without_cache_details() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"stop\"}]}\n"
        );
        let events = parse_sse(sse);
        match &events[0] {
            ProviderEvent::Done { message, .. } => {
                assert_eq!(message.usage.input, 5);
                assert_eq!(message.usage.output, 2);
                assert_eq!(message.usage.cache_read, 0);
            }
            other => panic!("expected Done, got {other:?}"),
        }
    }

    // Empty input produces no events.
    #[test]
    fn parse_empty_input() {
        let events = parse_sse("");
        assert!(events.is_empty());
    }

    // Blank lines alone produce no events.
    #[test]
    fn parse_only_empty_lines() {
        let events = parse_sse("\n\n\n");
        assert!(events.is_empty());
    }

    // Payloads that are not valid chunk JSON are skipped; later valid ones still parse.
    #[test]
    fn parse_malformed_json_after_data() {
        let sse = "data: {not json at all}\ndata: also bad\ndata: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ok\"}}]}\n";
        let events = parse_sse(sse);
        assert_eq!(events.len(), 1);
        match &events[0] {
            ProviderEvent::TextDelta { delta, .. } => assert_eq!(delta, "ok"),
            other => panic!("expected TextDelta, got {other:?}"),
        }
    }

    // An empty `data:` payload is skipped.
    #[test]
    fn parse_empty_data_line() {
        let sse = "data: \ndata: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"X\"}}]}\n";
        let events = parse_sse(sse);
        assert_eq!(events.len(), 1);
    }

    // `event:`, `id:` and `retry:` fields are ignored.
    #[test]
    fn parse_non_data_lines_ignored() {
        let sse = "event: ping\nid: 42\nretry: 5000\ndata: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Y\"}}]}\n";
        let events = parse_sse(sse);
        assert_eq!(events.len(), 1);
    }

    // CRLF line endings are handled (the trailing `\r` is trimmed).
    #[test]
    fn parse_carriage_return_line_endings() {
        let sse = "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"CR\"}}]}\r\n\r\n";
        let events = parse_sse(sse);
        assert_eq!(events.len(), 1);
        match &events[0] {
            ProviderEvent::TextDelta { delta, .. } => assert_eq!(delta, "CR"),
            other => panic!("expected TextDelta, got {other:?}"),
        }
    }

    // End-to-end transcript: two text deltas, one tool-call start + delta, then `Done`
    // with `ToolUse`. The trailing `[DONE]` adds nothing.
    #[test]
    fn parse_full_stream_with_text_tool_and_done() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Let me\"}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" check\"}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"search\",\"arguments\":\"{\\\"q\\\":\\\"rust\\\"}\"}}]}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"tool_calls\"}]}\n",
            "\n",
            "data: [DONE]\n"
        );
        let events = parse_sse(sse);
        assert_eq!(events.len(), 5);
        let mut text_count = 0;
        let mut tc_start_count = 0;
        let mut tc_delta_count = 0;
        let mut done_count = 0;
        for e in &events {
            match e {
                ProviderEvent::TextDelta { .. } => text_count += 1,
                ProviderEvent::ToolCallStart { .. } => tc_start_count += 1,
                ProviderEvent::ToolCallDelta { .. } => tc_delta_count += 1,
                ProviderEvent::Done { reason, .. } => {
                    done_count += 1;
                    assert!(matches!(reason, StopReason::ToolUse));
                }
                other => panic!("unexpected event: {other:?}"),
            }
        }
        assert_eq!(text_count, 2);
        assert_eq!(tc_start_count, 1);
        assert_eq!(tc_delta_count, 1);
        assert_eq!(done_count, 1);
    }
}