1use async_trait::async_trait;
4use bytes::Bytes;
5use futures::{Stream, StreamExt};
6use reqwest::Client;
7use serde::Deserialize;
8use serde_json::Value as JsonValue;
9use std::pin::Pin;
10
11use super::openai_responses_shared::parse_streaming_json;
12use super::shared_client;
13use crate::{
14 error::ProviderError, Api, AssistantMessage, ContentBlock, Context, Model, Provider,
15 ProviderEvent, StopReason, StreamOptions, TextContent, ThinkingContent, Usage,
16};
17
18fn is_zai(model: &Model) -> bool {
20 model.provider.eq_ignore_ascii_case("zai") || model.base_url.contains("api.z.ai")
21}
22
/// Provider that talks to a `/chat/completions` streaming endpoint.
#[derive(Clone)]
pub struct OpenAiProvider {
    /// HTTP client obtained from `shared_client()`; held by `'static` reference.
    client: &'static Client,
    /// Provider-level API key. A key supplied through the per-request stream
    /// options is consulted first; this one is the fallback.
    api_key: Option<String>,
    /// Optional base URL override. When `None`, the model's own `base_url`
    /// is used to build the request URL.
    base_url: Option<String>,
}
30
31impl OpenAiProvider {
32 pub fn new() -> Self {
37 Self {
38 client: shared_client(),
39 api_key: None,
40 base_url: None,
41 }
42 }
43
44 pub fn with_api_key(api_key: impl Into<String>) -> Self {
46 Self {
47 client: shared_client(),
48 api_key: Some(api_key.into()),
49 base_url: None,
50 }
51 }
52
53 pub fn with_base_url(base_url: &str) -> Self {
57 Self {
58 client: shared_client(),
59 api_key: None,
60 base_url: Some(base_url.to_string()),
61 }
62 }
63
64 pub fn with_base_url_and_key(base_url: &str, api_key: Option<String>) -> Self {
68 Self {
69 client: shared_client(),
70 api_key,
71 base_url: Some(base_url.to_string()),
72 }
73 }
74}
75
76impl Default for OpenAiProvider {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82#[async_trait]
83impl Provider for OpenAiProvider {
84 async fn stream(
85 &self,
86 model: &Model,
87 context: &Context,
88 options: Option<StreamOptions>,
89 ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
90 let options = options.unwrap_or_default();
91
92 let effective_base_url = self.base_url.as_deref().unwrap_or(&model.base_url);
94 let url = format!("{}/chat/completions", effective_base_url);
95
96 let api_key = options
98 .api_key
99 .as_ref()
100 .or(self.api_key.as_ref())
101 .ok_or_else(|| ProviderError::MissingApiKey)?;
102
103 let messages = build_messages(context)?;
105
106 let mut body = serde_json::json!({
108 "model": model.id,
109 "messages": messages,
110 "stream": true,
111 "stream_options": { "include_usage": true },
112 });
113
114 if let Some(temp) = options.temperature {
116 body["temperature"] = serde_json::json!(temp);
117 }
118
119 if let Some(max) = options.max_tokens {
120 body["max_tokens"] = serde_json::json!(max);
121 }
122
123 if !context.tools.is_empty() {
125 body["tools"] = build_tools(&context.tools)?;
126 }
127
128 if is_zai(model) {
132 if model.reasoning {
133 body["enable_thinking"] = serde_json::json!(true);
134 }
135 if !context.tools.is_empty() {
136 body["tool_stream"] = serde_json::json!(true);
137 }
138 }
139
140 tracing::info!(
141 "Sending request to {} model={} body_len={} enable_thinking={} tool_stream={}",
142 url,
143 model.id,
144 body.to_string().len(),
145 body.get("enable_thinking").is_some(),
146 body.get("tool_stream").is_some()
147 );
148 tracing::debug!("Request body: {}", body.to_string());
149
150 let mut headers = reqwest::header::HeaderMap::new();
152 headers.insert(
153 reqwest::header::AUTHORIZATION,
154 format!("Bearer {}", api_key)
155 .parse()
156 .expect("valid bearer header"),
157 );
158 headers.insert(
159 reqwest::header::CONTENT_TYPE,
160 "application/json".parse().expect("valid header value"),
161 );
162
163 for (k, v) in &options.headers {
164 if let (Ok(name), Ok(value)) = (
165 k.parse::<reqwest::header::HeaderName>(),
166 v.parse::<reqwest::header::HeaderValue>(),
167 ) {
168 headers.insert(name, value);
169 }
170 }
171
172 let response = self
174 .client
175 .post(&url)
176 .headers(headers)
177 .json(&body)
178 .send()
179 .await
180 .map_err(ProviderError::RequestFailed)?;
181
182 if !response.status().is_success() {
183 let status = response.status();
184 let body: String = response.text().await.unwrap_or_default();
185 return Err(ProviderError::HttpError(status.as_u16(), body));
186 }
187
188 let provider_name = model.provider.clone();
190 let model_id = model.id.clone();
191
192 let start_event = ProviderEvent::Start {
194 partial: AssistantMessage::new(Api::OpenAiCompletions, &provider_name, &model_id),
195 };
196
197 let stream = response
208 .bytes_stream()
209 .scan(
210 (
211 Vec::new(),
212 std::collections::HashMap::<usize, (String, String, String)>::new(),
213 std::collections::HashMap::<String, usize>::new(), false,
215 AssistantMessage::new(Api::OpenAiCompletions, &provider_name, &model_id),
216 ),
217 move |(
218 pending_bytes,
219 pending_tc,
220 tc_id_to_index,
221 thinking_started,
222 accumulated_output,
223 ),
224 chunk: Result<Bytes, reqwest::Error>| {
225 let events = match chunk {
226 Ok(bytes) => {
227 let mut combined =
229 Vec::with_capacity(pending_bytes.len() + bytes.len());
230 combined.extend_from_slice(pending_bytes);
231 combined.extend_from_slice(&bytes);
232
233 let (text, trailing) = split_complete_lines(&combined);
237 *pending_bytes = trailing;
238
239 tracing::debug!(
240 "parse_sse_events input: {} bytes, {} lines",
241 text.len(),
242 text.lines().count()
243 );
244 let raw_events = parse_sse_events(
245 &text,
246 &provider_name,
247 &model_id,
248 accumulated_output,
249 );
250 tracing::debug!("parse_sse_events output: {} events", raw_events.len());
251
252 let mut processed = Vec::new();
254 for event in raw_events {
255 match &event {
256 ProviderEvent::ThinkingDelta { content_index, .. } => {
257 if !*thinking_started {
259 *thinking_started = true;
260 processed.push(ProviderEvent::ThinkingStart {
261 content_index: *content_index,
262 partial: AssistantMessage::new(
263 Api::OpenAiCompletions,
264 &provider_name,
265 &model_id,
266 ),
267 });
268 }
269 processed.push(event);
270 }
271 ProviderEvent::ToolCallStart {
272 content_index,
273 tool_call_id,
274 tool_name,
275 ..
276 } => {
277 let entry =
278 pending_tc.entry(*content_index).or_insert_with(|| {
279 (String::new(), String::new(), String::new())
280 });
281 if let Some(ref id) = tool_call_id {
282 if !id.is_empty() {
283 entry.0 = id.clone();
284 tc_id_to_index.insert(id.clone(), *content_index);
285 }
286 }
287 if let Some(ref name) = tool_name {
288 if !name.is_empty() {
289 entry.1 = name.clone();
290 }
291 }
292 processed.push(event);
293 }
294 ProviderEvent::ToolCallDelta {
295 content_index,
296 delta,
297 ..
298 } => {
299 let idx = if pending_tc.contains_key(content_index) {
301 *content_index
302 } else {
303 tc_id_to_index
305 .values()
306 .copied()
307 .find(|i| *i == *content_index)
308 .unwrap_or(*content_index)
309 };
310 let entry = pending_tc.entry(idx).or_insert_with(|| {
311 (String::new(), String::new(), String::new())
312 });
313 tracing::debug!(
314 "[TC-DELTA] idx={}, delta_len={}, accumulated_len={}",
315 idx,
316 delta.len(),
317 entry.2.len() + delta.len()
318 );
319 entry.2.push_str(delta);
320 processed.push(event);
321 }
322 ProviderEvent::ToolCallEnd { .. } => {
323 processed.push(event);
325 }
326 ProviderEvent::Done { reason, .. } => {
327 if matches!(reason, StopReason::ToolUse) {
329 let mut indices: Vec<usize> =
330 pending_tc.keys().copied().collect();
331 indices.sort();
332 for idx in indices {
333 let (id, name, arguments) = &pending_tc[&idx];
334 tracing::debug!(
335 "[TC-END] idx={}, id={}, name={}, args_len={}",
336 idx,
337 id.len(),
338 name.len(),
339 arguments.len()
340 );
341 let args_value = parse_streaming_json(arguments);
342 processed.push(ProviderEvent::ToolCallEnd {
343 content_index: idx,
344 tool_call: crate::ToolCall {
345 content_type:
346 crate::messages::ToolCallType::ToolCall,
347 id: id.clone(),
348 name: name.clone(),
349 arguments: args_value,
350 thought_signature: None,
351 },
352 partial: AssistantMessage::new(
353 Api::OpenAiCompletions,
354 &provider_name,
355 &model_id,
356 ),
357 });
358 }
359 }
360 pending_tc.clear();
364 tc_id_to_index.clear();
365 processed.push(event);
366 }
367 _ => {
368 processed.push(event);
369 }
370 }
371 }
372 processed
373 }
374 Err(e) => {
375 vec![ProviderEvent::Error {
376 reason: StopReason::Error,
377 error: create_error_message(
378 &e.to_string(),
379 &provider_name,
380 &model_id,
381 ),
382 }]
383 }
384 };
385 async move { Some(futures::stream::iter(events)) }
387 },
388 )
389 .flatten();
390
391 let stream_with_start = futures::stream::once(async move { start_event }).chain(stream);
393 Ok(Box::pin(stream_with_start))
394 }
395
396 fn name(&self) -> &str {
397 "openai"
398 }
399}
400
401fn build_messages(context: &Context) -> Result<Vec<JsonValue>, ProviderError> {
403 let mut messages = Vec::new();
404
405 if let Some(ref prompt) = context.system_prompt {
407 messages.push(serde_json::json!({
408 "role": "system",
409 "content": prompt,
410 }));
411 }
412
413 for msg in &context.messages {
415 match msg {
416 crate::Message::User(u) => {
417 let content: String = match &u.content {
418 crate::MessageContent::Text(s) => s.clone(),
419 crate::MessageContent::Blocks(blocks) => blocks_to_content(blocks)?.to_string(),
420 };
421 messages.push(serde_json::json!({
422 "role": "user",
423 "content": content,
424 }));
425 }
426 crate::Message::Assistant(a) => {
427 let mut text_parts = Vec::new();
429 let mut tool_calls = Vec::new();
430 for block in &a.content {
431 match block {
432 ContentBlock::Text(t) => {
433 text_parts.push(t.text.clone());
434 }
435 ContentBlock::Thinking(_) => {
436 }
438 ContentBlock::ToolCall(tc) => {
439 tool_calls.push(serde_json::json!({
440 "id": tc.id,
441 "type": "function",
442 "function": {
443 "name": tc.name,
444 "arguments": tc.arguments.to_string(),
445 },
446 }));
447 }
448 ContentBlock::Image(_) | ContentBlock::Unknown(_) => {}
449 }
450 }
451 let mut msg = serde_json::json!({
452 "role": "assistant",
453 "content": text_parts.join(""),
454 });
455 if !tool_calls.is_empty() {
456 msg["tool_calls"] = serde_json::json!(tool_calls);
457 }
458 messages.push(msg);
459 }
460 crate::Message::ToolResult(t) => {
461 let result_text: String = t
462 .content
463 .iter()
464 .filter_map(|b| b.as_text())
465 .collect::<Vec<_>>()
466 .join("");
467 messages.push(serde_json::json!({
468 "role": "tool",
469 "tool_call_id": t.tool_call_id,
470 "content": result_text,
471 }));
472 }
473 }
474 }
475
476 Ok(messages)
477}
478
479fn blocks_to_content(blocks: &[ContentBlock]) -> Result<JsonValue, ProviderError> {
481 if blocks.len() == 1 {
482 if let Some(text) = blocks[0].as_text() {
483 return Ok(JsonValue::String(text.to_string()));
484 }
485 }
486
487 let items: Result<Vec<_>, _> = blocks
488 .iter()
489 .map(|block| match block {
490 ContentBlock::Text(t) => Ok(serde_json::json!({
491 "type": "text",
492 "text": t.text,
493 })),
494 ContentBlock::ToolCall(tc) => Ok(serde_json::json!({
495 "type": "function",
496 "id": tc.id,
497 "function": {
498 "name": tc.name,
499 "arguments": tc.arguments.to_string(),
500 },
501 })),
502 ContentBlock::Thinking(th) => Ok(serde_json::json!({
503 "type": "thinking",
504 "thinking": th.thinking,
505 })),
506 ContentBlock::Image(img) => Ok(serde_json::json!({
507 "type": "image_url",
508 "image_url": {
509 "url": format!("data:{};base64,{}", img.mime_type, img.data),
510 },
511 })),
512 ContentBlock::Unknown(_) => Err(ProviderError::InvalidResponse(
513 "Unknown content block type".into(),
514 )),
515 })
516 .collect();
517
518 Ok(serde_json::json!(items?))
519}
520
521fn build_tools(tools: &[crate::Tool]) -> Result<JsonValue, ProviderError> {
523 let items: Vec<_> = tools
524 .iter()
525 .map(|tool| {
526 serde_json::json!({
527 "type": "function",
528 "function": {
529 "name": tool.name,
530 "description": tool.description,
531 "parameters": tool.parameters,
532 },
533 })
534 })
535 .collect();
536
537 Ok(serde_json::json!(items))
538}
539
/// Splits `bytes` into its longest valid UTF-8 prefix and the remaining bytes.
///
/// Returns `(prefix, rest)`: `prefix` is the decoded text up to the first
/// byte that does not continue a valid UTF-8 sequence, and `rest` is
/// everything from that byte onward (empty when the whole input is valid).
fn find_valid_utf8_prefix(bytes: &[u8]) -> (String, Vec<u8>) {
    let valid_len = match std::str::from_utf8(bytes) {
        Ok(_) => bytes.len(),
        Err(err) => err.valid_up_to(),
    };
    let (head, tail) = bytes.split_at(valid_len);
    // `head` is valid UTF-8 by construction, so the lossy decode never
    // substitutes anything here.
    (String::from_utf8_lossy(head).into_owned(), tail.to_vec())
}
555
556pub fn split_complete_lines(bytes: &[u8]) -> (String, Vec<u8>) {
560 match bytes.iter().rposition(|&b| b == b'\n') {
562 Some(last_nl) => {
563 let split_at = last_nl + 1;
564 let complete = match std::str::from_utf8(&bytes[..split_at]) {
565 Ok(s) => s.to_string(),
566 Err(_) => {
567 let (s, _) = find_valid_utf8_prefix(&bytes[..split_at]);
568 s
569 }
570 };
571 let trailing = bytes[split_at..].to_vec();
572 (complete, trailing)
573 }
574 None => {
575 (String::new(), bytes.to_vec())
578 }
579 }
580}
581
582fn parse_sse_events(
593 text: &str,
594 _provider: &str,
595 _model_id: &str,
596 output: &mut AssistantMessage,
597) -> Vec<ProviderEvent> {
598 let mut events = Vec::new();
599
600 let estimated_events = text.split('\n').filter(|l| l.starts_with("data: ")).count();
602 events.reserve(estimated_events);
603
604 let mut accumulated_usage = Usage::default();
605
606 for line in text.split('\n') {
607 let line = line.trim_end_matches('\r');
608 if line.is_empty() {
609 continue;
610 }
611
612 if !line.starts_with("data: ") {
614 continue;
615 }
616
617 let data = &line[6..]; if data == "[DONE]" {
621 break;
622 }
623
624 if data.is_empty() {
625 continue;
626 }
627
628 let chunk = match serde_json::from_str::<SSEChunk>(data) {
629 Ok(c) => c,
630 Err(_) => continue,
631 };
632
633 for choice in &chunk.choices {
634 if let Some(delta) = &choice.delta {
635 if let Some(content) = &delta.content {
636 let last_text_idx = output
638 .content
639 .iter()
640 .rposition(|b| matches!(b, ContentBlock::Text(_)));
641 if let Some(idx) = last_text_idx {
642 if let ContentBlock::Text(t) = &mut output.content[idx] {
643 t.text.push_str(content);
644 }
645 } else {
646 output
647 .content
648 .push(ContentBlock::Text(TextContent::new(content.clone())));
649 }
650 events.push(ProviderEvent::TextDelta {
651 content_index: choice.index,
652 delta: content.clone(),
653 partial: output.clone(),
654 });
655 }
656
657 if let Some(ref reasoning) = delta.reasoning_content {
659 if !reasoning.is_empty() {
660 let last_think_idx = output
662 .content
663 .iter()
664 .rposition(|b| matches!(b, ContentBlock::Thinking(_)));
665 if let Some(idx) = last_think_idx {
666 if let ContentBlock::Thinking(t) = &mut output.content[idx] {
667 t.thinking.push_str(reasoning);
668 }
669 } else {
670 output
671 .content
672 .push(ContentBlock::Thinking(ThinkingContent::new(
673 reasoning.clone(),
674 )));
675 }
676 events.push(ProviderEvent::ThinkingDelta {
677 content_index: choice.index,
678 delta: reasoning.clone(),
679 partial: output.clone(),
680 });
681 }
682 }
683
684 if let Some(tool_calls) = &delta.tool_calls {
685 for tc in tool_calls {
686 let tc_index = tc.index.unwrap_or(choice.index);
687
688 if tc.id.is_some()
690 || tc.function.as_ref().and_then(|f| f.name.as_ref()).is_some()
691 {
692 events.push(ProviderEvent::ToolCallStart {
693 content_index: tc_index,
694 tool_call_id: tc.id.clone(),
695 tool_name: tc.function.as_ref().and_then(|f| f.name.clone()),
696 partial: output.clone(),
697 });
698 }
699
700 if let Some(func) = &tc.function {
702 events.push(ProviderEvent::ToolCallDelta {
703 content_index: tc_index,
704 delta: func.arguments.clone().unwrap_or_default(),
705 partial: output.clone(),
706 });
707 }
708 }
709 }
710 }
711
712 if choice.finish_reason.is_some() {
713 let reason = match choice.finish_reason.as_deref() {
714 Some("stop") | Some("end") => StopReason::Stop,
715 Some("length") => StopReason::Length,
716 Some("tool_calls") | Some("function_call") => StopReason::ToolUse,
717 Some("content_filter") => StopReason::Error,
718 Some(unknown) => {
719 tracing::warn!("Unknown finish_reason: '{}', treating as Error", unknown);
720 StopReason::Error
721 }
722 None => StopReason::Stop,
723 };
724 tracing::info!("finish_reason={:?} → {:?}", choice.finish_reason, reason);
725
726 let mut done_msg = output.clone();
727 done_msg.stop_reason = reason;
728 done_msg.usage = accumulated_usage.clone();
729 events.push(ProviderEvent::Done {
730 reason,
731 message: done_msg,
732 });
733 }
734 }
735
736 if let Some(chunk_usage) = chunk.usage {
738 accumulated_usage.input = chunk_usage.prompt_tokens;
739 accumulated_usage.output = chunk_usage.completion_tokens;
740 accumulated_usage.cache_read = chunk_usage
741 .prompt_tokens_details
742 .as_ref()
743 .map(|d| d.cached_tokens)
744 .unwrap_or(0);
745 accumulated_usage.total_tokens = chunk_usage.total_tokens;
746 }
747 }
748
749 events
750}
751
752fn create_error_message(msg: &str, provider: &str, model_id: &str) -> AssistantMessage {
754 let mut message = AssistantMessage::new(Api::OpenAiCompletions, provider, model_id);
755 message.stop_reason = StopReason::Error;
756 message.error_message = Some(msg.to_string());
757 message
758}
759
/// One JSON payload taken from a `data:` line of the streaming response, as
/// deserialized by `parse_sse_events`.
#[derive(Debug, Deserialize)]
struct SSEChunk {
    // Not read by the parser. Unlike `_model` below, this field has no serde
    // rename.
    // NOTE(review): presumably the JSON `id` key is therefore never captured
    // here; confirm that is intended.
    _id: Option<String>,
    // Not read by the parser; mapped from the JSON key `model`.
    #[serde(rename = "model")]
    _model: Option<String>,
    // Per-choice deltas and finish reasons; may be empty.
    choices: Vec<Choice>,
    // Token usage, when the chunk carries it.
    usage: Option<UsageInfo>,
}
770
/// One entry of a chunk's `choices` array.
#[derive(Debug, Deserialize)]
struct Choice {
    // Used as the `content_index` of text and thinking deltas, and as the
    // fallback index for tool-call fragments that carry none.
    index: usize,
    // Incremental content; `None` when the chunk has no delta (e.g. `null`).
    delta: Option<Delta>,
    // When present, the parser emits a `Done` event for this choice.
    finish_reason: Option<String>,
}
778
/// Incremental payload of a choice. Every field is optional.
#[derive(Debug, Deserialize)]
struct Delta {
    // Text fragment; reported as a `TextDelta`.
    content: Option<String>,
    // Reasoning fragment; reported as a `ThinkingDelta` when non-empty.
    reasoning_content: Option<String>,
    // Tool-call fragments; reported as `ToolCallStart` / `ToolCallDelta`.
    tool_calls: Option<Vec<ToolCallDelta>>,
}
785
/// One fragment of a streamed tool call.
#[derive(Debug, Deserialize)]
struct ToolCallDelta {
    // Index of the tool call this fragment belongs to; the parser falls back
    // to the choice index when it is absent.
    index: Option<usize>,
    // Tool-call id; its presence makes the parser emit a `ToolCallStart`.
    id: Option<String>,
    // Mapped from the JSON key `type`; not read by the parser.
    #[serde(rename = "type")]
    _type_: Option<String>,
    // Function name and/or argument text carried by this fragment.
    function: Option<FunctionDelta>,
}
795
/// Function part of a tool-call fragment.
#[derive(Debug, Deserialize)]
struct FunctionDelta {
    // Function name; its presence makes the parser emit a `ToolCallStart`.
    name: Option<String>,
    // Piece of the argument text; treated as empty when absent.
    arguments: Option<String>,
}
802
803#[derive(Debug, Deserialize, Clone)]
804struct UsageInfo {
805 prompt_tokens: usize,
806 completion_tokens: usize,
807 total_tokens: usize,
808 #[serde(rename = "prompt_tokens_details")]
809 prompt_tokens_details: Option<PromptTokensDetails>,
810}
811
812#[derive(Debug, Deserialize, Clone)]
813struct PromptTokensDetails {
814 #[serde(rename = "cached_tokens")]
815 cached_tokens: usize,
816}
817
#[cfg(test)]
mod tests {
    use super::*;

    const PROVIDER: &str = "openai";
    const MODEL: &str = "gpt-4o";

    /// Runs the parser over `sse`, starting from a fresh assistant message.
    fn parse_sse(sse: &str) -> Vec<ProviderEvent> {
        let mut scratch = AssistantMessage::new(Api::OpenAiCompletions, PROVIDER, MODEL);
        parse_sse_events(sse, PROVIDER, MODEL, &mut scratch)
    }

    // ---- text deltas ----

    /// A single content chunk yields exactly one `TextDelta`.
    #[test]
    fn parse_single_text_event() {
        let events = parse_sse(
            "data: {\"id\":\"chatcmpl-1\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n",
        );
        assert_eq!(events.len(), 1);
        if let ProviderEvent::TextDelta {
            delta,
            content_index,
            ..
        } = &events[0]
        {
            assert_eq!(delta, "Hello");
            assert_eq!(*content_index, 0);
        } else {
            panic!("expected TextDelta, got {:?}", &events[0]);
        }
    }

    /// Consecutive content chunks yield their deltas in order.
    #[test]
    fn parse_multiple_text_events() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hel\"}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"lo!\"}}]}\n",
            "\n"
        );
        let events = parse_sse(sse);
        assert_eq!(events.len(), 2);

        let mut texts: Vec<&str> = Vec::new();
        for event in &events {
            if let ProviderEvent::TextDelta { delta, .. } = event {
                texts.push(delta.as_str());
            }
        }
        assert_eq!(texts, vec!["Hel", "lo!"]);
    }

    /// Nothing after `[DONE]` is parsed.
    #[test]
    fn parse_done_terminator() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"X\"}}]}\n",
            "\n",
            "data: [DONE]\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"NEVER\"}}]}\n"
        );
        let events = parse_sse(sse);
        assert_eq!(events.len(), 1);
        if let ProviderEvent::TextDelta { delta, .. } = &events[0] {
            assert_eq!(delta, "X");
        } else {
            panic!("expected TextDelta, got {:?}", &events[0]);
        }
    }

    // ---- finish reasons ----

    /// `finish_reason: "stop"` maps to `StopReason::Stop`.
    #[test]
    fn parse_finish_reason_stop() {
        let events = parse_sse(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"stop\"}]}\n\n",
        );
        assert_eq!(events.len(), 1);
        if let ProviderEvent::Done { reason, .. } = &events[0] {
            assert!(matches!(reason, StopReason::Stop));
        } else {
            panic!("expected Done, got {:?}", &events[0]);
        }
    }

    /// `finish_reason: "length"` maps to `StopReason::Length`.
    #[test]
    fn parse_finish_reason_length() {
        let events = parse_sse(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"length\"}]}\n\n",
        );
        if let ProviderEvent::Done { reason, .. } = &events[0] {
            assert!(matches!(reason, StopReason::Length));
        } else {
            panic!("expected Done with Length, got {:?}", &events[0]);
        }
    }

    /// `finish_reason: "tool_calls"` maps to `StopReason::ToolUse`.
    #[test]
    fn parse_finish_reason_tool_calls() {
        let events = parse_sse(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"tool_calls\"}]}\n\n",
        );
        if let ProviderEvent::Done { reason, .. } = &events[0] {
            assert!(matches!(reason, StopReason::ToolUse));
        } else {
            panic!("expected Done with ToolUse, got {:?}", &events[0]);
        }
    }

    // ---- tool calls ----

    /// The first fragment yields start + delta; the follow-up yields a delta.
    #[test]
    fn parse_tool_call_deltas() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"\"}}]}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"city\\\":\\\"SF\\\"}\"}}]}}]}\n",
            "\n"
        );
        let events = parse_sse(sse);
        assert_eq!(events.len(), 3);

        let mut starts: Vec<&str> = Vec::new();
        let mut deltas: Vec<&str> = Vec::new();
        for event in &events {
            match event {
                ProviderEvent::ToolCallStart { tool_name, .. } => {
                    if let Some(name) = tool_name.as_deref() {
                        starts.push(name);
                    }
                }
                ProviderEvent::ToolCallDelta { delta, .. } => deltas.push(delta.as_str()),
                _ => {}
            }
        }
        assert_eq!(starts, vec!["get_weather"]);
        assert_eq!(deltas, vec!["", "{\"city\":\"SF\"}"]);
    }

    /// A function object without `arguments` still yields an empty delta.
    #[test]
    fn parse_tool_call_with_no_arguments_field() {
        let events = parse_sse(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"name\":\"run\"}}]}}]}\n\n",
        );
        assert_eq!(events.len(), 2);
        if let ProviderEvent::ToolCallStart { tool_name, .. } = &events[0] {
            assert_eq!(tool_name.as_deref(), Some("run"));
        } else {
            panic!("expected ToolCallStart, got {:?}", &events[0]);
        }
        if let ProviderEvent::ToolCallDelta { delta, .. } = &events[1] {
            assert_eq!(delta, "");
        } else {
            panic!("expected ToolCallDelta, got {:?}", &events[1]);
        }
    }

    // ---- usage ----

    /// Usage seen before the finish chunk is attached to the `Done` message.
    #[test]
    fn parse_usage_in_chunk() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":8,\"total_tokens\":18,\"prompt_tokens_details\":{\"cached_tokens\":3}}}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"stop\"}]}\n"
        );
        let events = parse_sse(sse);
        assert_eq!(events.len(), 2);
        if let ProviderEvent::Done { message, .. } = &events[1] {
            assert_eq!(message.usage.input, 10);
            assert_eq!(message.usage.output, 8);
            assert_eq!(message.usage.total_tokens, 18);
            assert_eq!(message.usage.cache_read, 3);
        } else {
            panic!("expected Done, got {:?}", &events[1]);
        }
    }

    /// Without `prompt_tokens_details`, `cache_read` is zero.
    #[test]
    fn parse_usage_without_cache_details() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"stop\"}]}\n"
        );
        let events = parse_sse(sse);
        if let ProviderEvent::Done { message, .. } = &events[0] {
            assert_eq!(message.usage.input, 5);
            assert_eq!(message.usage.output, 2);
            assert_eq!(message.usage.cache_read, 0);
        } else {
            panic!("expected Done, got {:?}", &events[0]);
        }
    }

    // ---- input that produces nothing or must be skipped ----

    /// Empty input yields no events.
    #[test]
    fn parse_empty_input() {
        assert!(parse_sse("").is_empty());
    }

    /// Blank lines alone yield no events.
    #[test]
    fn parse_only_empty_lines() {
        assert!(parse_sse("\n\n\n").is_empty());
    }

    /// Malformed payloads are skipped; later valid ones are still parsed.
    #[test]
    fn parse_malformed_json_after_data() {
        let events = parse_sse(
            "data: {not json at all}\ndata: also bad\ndata: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ok\"}}]}\n",
        );
        assert_eq!(events.len(), 1);
        if let ProviderEvent::TextDelta { delta, .. } = &events[0] {
            assert_eq!(delta, "ok");
        } else {
            panic!("expected TextDelta, got {:?}", &events[0]);
        }
    }

    /// A `data: ` line with an empty payload is skipped.
    #[test]
    fn parse_empty_data_line() {
        let events = parse_sse(
            "data: \ndata: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"X\"}}]}\n",
        );
        assert_eq!(events.len(), 1);
    }

    /// `event:`, `id:` and `retry:` lines are ignored.
    #[test]
    fn parse_non_data_lines_ignored() {
        let events = parse_sse(
            "event: ping\nid: 42\nretry: 5000\ndata: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Y\"}}]}\n",
        );
        assert_eq!(events.len(), 1);
    }

    /// CRLF line endings are accepted.
    #[test]
    fn parse_carriage_return_line_endings() {
        let events = parse_sse(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"CR\"}}]}\r\n\r\n",
        );
        assert_eq!(events.len(), 1);
        if let ProviderEvent::TextDelta { delta, .. } = &events[0] {
            assert_eq!(delta, "CR");
        } else {
            panic!("expected TextDelta, got {:?}", &events[0]);
        }
    }

    // ---- end to end ----

    /// Text, a tool call and a tool-use finish together yield five events.
    #[test]
    fn parse_full_stream_with_text_tool_and_done() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Let me\"}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" check\"}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"search\",\"arguments\":\"{\\\"q\\\":\\\"rust\\\"}\"}}]}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"tool_calls\"}]}\n",
            "\n",
            "data: [DONE]\n"
        );
        let events = parse_sse(sse);
        assert_eq!(events.len(), 5);

        let (mut texts, mut starts, mut deltas, mut dones) = (0, 0, 0, 0);
        for event in &events {
            if let ProviderEvent::Done { reason, .. } = event {
                assert!(matches!(reason, StopReason::ToolUse));
            }
            let counter = match event {
                ProviderEvent::TextDelta { .. } => &mut texts,
                ProviderEvent::ToolCallStart { .. } => &mut starts,
                ProviderEvent::ToolCallDelta { .. } => &mut deltas,
                ProviderEvent::Done { .. } => &mut dones,
                other => panic!("unexpected event: {other:?}"),
            };
            *counter += 1;
        }
        assert_eq!(texts, 2);
        assert_eq!(starts, 1);
        assert_eq!(deltas, 1);
        assert_eq!(dones, 1);
    }
}