1use std::borrow::Cow;
10
11use crate::error::{Error, Result};
12use crate::http::client::Client;
13use crate::model::{
14 AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ThinkingContent,
15 ToolCall, Usage, UserContent,
16};
17use crate::models::CompatConfig;
18use crate::provider::{Context, Provider, StreamOptions, ToolDef};
19use crate::sse::SseStream;
20use async_trait::async_trait;
21use futures::StreamExt;
22use futures::stream::{self, Stream};
23use serde::{Deserialize, Serialize};
24use std::collections::VecDeque;
25use std::pin::Pin;
26
/// Default Chat Completions endpoint for the hosted OpenAI API.
const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
/// Completion budget used when the caller does not set `max_tokens`.
const DEFAULT_MAX_TOKENS: u32 = 4096;
/// Fallback `HTTP-Referer` value for OpenRouter app attribution.
const OPENROUTER_DEFAULT_HTTP_REFERER: &str = "https://github.com/Dicklesworthstone/pi_agent_rust";
/// Fallback `X-Title` value for OpenRouter app attribution.
const OPENROUTER_DEFAULT_X_TITLE: &str = "Pi Agent Rust";
35
/// Map a role string onto a `Cow`, borrowing the static spelling for the
/// well-known chat roles and allocating only for unrecognized ones.
fn to_cow_role(role: &str) -> Cow<'_, str> {
    const KNOWN_ROLES: [&str; 6] = ["system", "developer", "user", "assistant", "tool", "function"];
    match KNOWN_ROLES.iter().copied().find(|&known| known == role) {
        Some(known) => Cow::Borrowed(known),
        None => Cow::Owned(role.to_string()),
    }
}
53
/// Returns `true` when `headers` contains at least one key matching any of
/// `names`, compared ASCII case-insensitively.
fn map_has_any_header(headers: &std::collections::HashMap<String, String>, names: &[&str]) -> bool {
    for key in headers.keys() {
        for name in names {
            if key.eq_ignore_ascii_case(name) {
                return true;
            }
        }
    }
    false
}
59
60fn authorization_override(
61 options: &StreamOptions,
62 compat: Option<&CompatConfig>,
63) -> Option<String> {
64 super::first_non_empty_header_value_case_insensitive(&options.headers, &["authorization"])
65 .or_else(|| {
66 compat
67 .and_then(|compat| compat.custom_headers.as_ref())
68 .and_then(|headers| {
69 super::first_non_empty_header_value_case_insensitive(
70 headers,
71 &["authorization"],
72 )
73 })
74 })
75}
76
/// Return the first environment variable among `keys` whose trimmed value is
/// non-empty, or `None` when none is set to a usable value.
fn first_non_empty_env(keys: &[&str]) -> Option<String> {
    for key in keys {
        if let Ok(raw) = std::env::var(key) {
            let trimmed = raw.trim();
            if !trimmed.is_empty() {
                return Some(trimmed.to_string());
            }
        }
    }
    None
}
85
86fn openrouter_default_http_referer() -> String {
87 first_non_empty_env(&["OPENROUTER_HTTP_REFERER", "PI_OPENROUTER_HTTP_REFERER"])
88 .unwrap_or_else(|| OPENROUTER_DEFAULT_HTTP_REFERER.to_string())
89}
90
91fn openrouter_default_x_title() -> String {
92 first_non_empty_env(&["OPENROUTER_X_TITLE", "PI_OPENROUTER_X_TITLE"])
93 .unwrap_or_else(|| OPENROUTER_DEFAULT_X_TITLE.to_string())
94}
95
/// Streaming provider for OpenAI-compatible Chat Completions APIs
/// (OpenAI itself, OpenRouter, and other compatible backends).
pub struct OpenAIProvider {
    client: Client,               // HTTP client used for the streaming request
    model: String,                // model identifier sent in the request body
    base_url: String,             // chat-completions endpoint URL
    provider: String,             // provider name, e.g. "openai" or "openrouter"
    compat: Option<CompatConfig>, // optional compatibility overrides for non-OpenAI backends
}
108
impl OpenAIProvider {
    /// Create a provider targeting the hosted OpenAI endpoint with a fresh
    /// HTTP client and no compat overrides.
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            client: Client::new(),
            model: model.into(),
            base_url: OPENAI_API_URL.to_string(),
            provider: "openai".to_string(),
            compat: None,
        }
    }

    /// Override the provider name reported by `Provider::name` (e.g.
    /// "openrouter" to enable OpenRouter-specific behavior in `stream`).
    #[must_use]
    pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
        self.provider = provider.into();
        self
    }

    /// Point the provider at a different chat-completions endpoint.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    /// Replace the HTTP client (e.g. to share a configured client).
    #[must_use]
    pub fn with_client(mut self, client: Client) -> Self {
        self.client = client;
        self
    }

    /// Attach (or clear) compatibility overrides for non-OpenAI backends.
    #[must_use]
    pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
        self.compat = compat;
        self
    }

    /// Build the Chat Completions request body from the conversation context
    /// and per-call options, honoring compat overrides for the system role
    /// name, tool support, the max-tokens field name, and streaming usage.
    pub fn build_request<'a>(
        &'a self,
        context: &'a Context<'_>,
        options: &StreamOptions,
    ) -> OpenAIRequest<'a> {
        // Some backends use "developer" (or another name) for the system role.
        let system_role = self
            .compat
            .as_ref()
            .and_then(|c| c.system_role_name.as_deref())
            .unwrap_or("system");
        let messages = Self::build_messages_with_role(context, system_role);

        let tools_supported = self
            .compat
            .as_ref()
            .and_then(|c| c.supports_tools)
            .unwrap_or(true);

        // Omit the tools field entirely when empty or unsupported.
        let tools: Option<Vec<OpenAITool<'a>>> = if context.tools.is_empty() || !tools_supported {
            None
        } else {
            Some(context.tools.iter().map(convert_tool_to_openai).collect())
        };

        // Newer OpenAI models take `max_completion_tokens` instead of
        // `max_tokens`; compat config selects which field to populate.
        let use_alt_field = self
            .compat
            .as_ref()
            .and_then(|c| c.max_tokens_field.as_deref())
            .is_some_and(|f| f == "max_completion_tokens");

        let token_limit = options.max_tokens.or(Some(DEFAULT_MAX_TOKENS));
        let (max_tokens, max_completion_tokens) = if use_alt_field {
            (None, token_limit)
        } else {
            (token_limit, None)
        };

        let include_usage = self
            .compat
            .as_ref()
            .and_then(|c| c.supports_usage_in_streaming)
            .unwrap_or(true);

        let stream_options = Some(OpenAIStreamOptions { include_usage });

        OpenAIRequest {
            model: &self.model,
            messages,
            max_tokens,
            max_completion_tokens,
            temperature: options.temperature,
            tools,
            stream: true,
            stream_options,
        }
    }

    /// Serialize the request and splice in any OpenRouter routing overrides.
    ///
    /// # Errors
    /// Returns an error if serialization fails or routing overrides are
    /// malformed (see `apply_openrouter_routing_overrides`).
    fn build_request_json(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<serde_json::Value> {
        let request = self.build_request(context, options);
        let mut value = serde_json::to_value(request)
            .map_err(|e| Error::api(format!("Failed to serialize OpenAI request: {e}")))?;
        self.apply_openrouter_routing_overrides(&mut value)?;
        Ok(value)
    }

    /// Merge compat `open_router_routing` keys into the serialized request.
    /// No-op unless the provider name is "openrouter" and routing config is
    /// present; existing request keys with the same name are overwritten.
    fn apply_openrouter_routing_overrides(&self, request: &mut serde_json::Value) -> Result<()> {
        if !self.provider.eq_ignore_ascii_case("openrouter") {
            return Ok(());
        }

        let Some(routing) = self
            .compat
            .as_ref()
            .and_then(|compat| compat.open_router_routing.as_ref())
        else {
            return Ok(());
        };

        let Some(request_obj) = request.as_object_mut() else {
            return Err(Error::api(
                "OpenAI request body must serialize to a JSON object",
            ));
        };
        let Some(routing_obj) = routing.as_object() else {
            return Err(Error::config(
                "openRouterRouting must be a JSON object when configured",
            ));
        };

        for (key, value) in routing_obj {
            request_obj.insert(key.clone(), value.clone());
        }
        Ok(())
    }

    /// Flatten the context into wire-format messages: an optional leading
    /// system message (with the configured role name) followed by the
    /// converted conversation history.
    fn build_messages_with_role<'a>(
        context: &'a Context<'_>,
        system_role: &'a str,
    ) -> Vec<OpenAIMessage<'a>> {
        let mut messages = Vec::with_capacity(context.messages.len() + 1);

        if let Some(system) = &context.system_prompt {
            messages.push(OpenAIMessage {
                role: to_cow_role(system_role),
                content: Some(OpenAIContent::Text(Cow::Borrowed(system))),
                tool_calls: None,
                tool_call_id: None,
            });
        }

        // One source message can expand into several wire messages
        // (e.g. tool results with images), hence `extend`.
        for message in context.messages.iter() {
            messages.extend(convert_message_to_openai(message));
        }

        messages
    }
}
281
#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        &self.provider
    }

    fn api(&self) -> &'static str {
        "openai-completions"
    }

    fn model_id(&self) -> &str {
        &self.model
    }

    /// Open an SSE stream against the chat-completions endpoint and adapt it
    /// into a stream of `StreamEvent`s.
    ///
    /// Auth resolution: an explicit `Authorization` header (request or compat
    /// headers) suppresses the bearer token; otherwise the key comes from
    /// `options.api_key` or the `OPENAI_API_KEY` env var.
    ///
    /// # Errors
    /// Fails when no credentials can be resolved, on non-2xx responses, or
    /// when the response is not `text/event-stream`.
    #[allow(clippy::too_many_lines)]
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        let authorization_override = authorization_override(options, self.compat.as_ref());

        // Only attach a Bearer token when no explicit Authorization header
        // will be applied below.
        let auth_value = if authorization_override.is_some() {
            None
        } else {
            Some(
                options
                    .api_key
                    .clone()
                    .or_else(|| std::env::var("OPENAI_API_KEY").ok())
                    .ok_or_else(|| {
                        Error::provider(
                            self.name(),
                            "Missing API key for provider. Configure credentials with /login <provider> or set the provider's API key env var.",
                        )
                    })?,
            )
        };

        let request_body = self.build_request_json(context, options)?;

        let mut request = self
            .client
            .post(&self.base_url)
            .header("Accept", "text/event-stream");

        if let Some(auth_value) = auth_value {
            request = request.header("Authorization", format!("Bearer {auth_value}"));
        }

        // OpenRouter attribution headers — defaulted only when the caller
        // (or compat config) did not already set them.
        if self.provider.eq_ignore_ascii_case("openrouter") {
            let compat_headers = self
                .compat
                .as_ref()
                .and_then(|compat| compat.custom_headers.as_ref());
            let has_referer = map_has_any_header(&options.headers, &["http-referer", "referer"])
                || compat_headers.is_some_and(|headers| {
                    map_has_any_header(headers, &["http-referer", "referer"])
                });
            if !has_referer {
                request = request.header("HTTP-Referer", openrouter_default_http_referer());
            }

            let has_title = map_has_any_header(&options.headers, &["x-title"])
                || compat_headers.is_some_and(|headers| map_has_any_header(headers, &["x-title"]));
            if !has_title {
                request = request.header("X-Title", openrouter_default_x_title());
            }
        }

        // Compat headers first, then per-request headers, so request-level
        // values win on conflict. Blank auth overrides are ignored by the
        // shared helper.
        if let Some(compat) = &self.compat {
            if let Some(custom_headers) = &compat.custom_headers {
                request = super::apply_headers_ignoring_blank_auth_overrides(
                    request,
                    custom_headers,
                    &["authorization"],
                );
            }
        }

        request = super::apply_headers_ignoring_blank_auth_overrides(
            request,
            &options.headers,
            &["authorization"],
        );

        let request = request.json(&request_body)?;

        let response = Box::pin(request.send()).await?;
        let status = response.status();
        if !(200..300).contains(&status) {
            // Surface the body in the error; API error details live there.
            let body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
            return Err(Error::provider(
                &self.provider,
                format!("OpenAI API error (HTTP {status}): {body}"),
            ));
        }

        // Guard against proxies/misconfigured backends returning JSON or
        // HTML with a 2xx status: SSE parsing would silently yield nothing.
        let content_type = response
            .headers()
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
            .map(|(_, value)| value.to_ascii_lowercase());
        if !content_type
            .as_deref()
            .is_some_and(|value| value.contains("text/event-stream"))
        {
            let message = content_type.map_or_else(
                || {
                    format!(
                        "OpenAI API protocol error (HTTP {status}): missing Content-Type (expected text/event-stream)"
                    )
                },
                |value| {
                    format!(
                        "OpenAI API protocol error (HTTP {status}): unexpected Content-Type {value} (expected text/event-stream)"
                    )
                },
            );
            return Err(Error::api(message));
        }

        let event_source = SseStream::new(response.bytes_stream());

        let model = self.model.clone();
        let api = self.api().to_string();
        let provider = self.name().to_string();

        // Pump the SSE stream through StreamState, draining its queued
        // events one at a time; the state machine may emit several events
        // per SSE message.
        let stream = stream::unfold(
            StreamState::new(event_source, model, api, provider),
            |mut state| async move {
                if state.done {
                    return None;
                }
                loop {
                    if let Some(event) = state.pending_events.pop_front() {
                        return Some((Ok(event), state));
                    }

                    match state.event_source.next().await {
                        Some(Ok(msg)) => {
                            // Any successful message resets the transient
                            // error streak.
                            state.transient_error_count = 0;
                            if msg.data == "[DONE]" {
                                state.done = true;
                                let reason = state.partial.stop_reason;
                                let message = std::mem::take(&mut state.partial);
                                return Some((Ok(StreamEvent::Done { reason, message }), state));
                            }

                            if let Err(e) = state.process_event(&msg.data) {
                                state.done = true;
                                return Some((Err(e), state));
                            }
                        }
                        Some(Err(e)) => {
                            // Tolerate a bounded run of transient I/O errors
                            // before giving up on the stream.
                            const MAX_CONSECUTIVE_TRANSIENT_ERRORS: usize = 5;
                            if e.kind() == std::io::ErrorKind::WriteZero
                                || e.kind() == std::io::ErrorKind::WouldBlock
                                || e.kind() == std::io::ErrorKind::TimedOut
                            {
                                state.transient_error_count += 1;
                                if state.transient_error_count <= MAX_CONSECUTIVE_TRANSIENT_ERRORS {
                                    tracing::warn!(
                                        kind = ?e.kind(),
                                        count = state.transient_error_count,
                                        "Transient error in SSE stream, continuing"
                                    );
                                    continue;
                                }
                                tracing::warn!(
                                    kind = ?e.kind(),
                                    "Error persisted after {MAX_CONSECUTIVE_TRANSIENT_ERRORS} \
                                     consecutive attempts, treating as fatal"
                                );
                            }
                            state.done = true;
                            let err = Error::api(format!("SSE error: {e}"));
                            return Some((Err(err), state));
                        }
                        None => {
                            // Stream ended without "[DONE]": finish with
                            // whatever was accumulated.
                            state.done = true;
                            let reason = state.partial.stop_reason;
                            let message = std::mem::take(&mut state.partial);
                            return Some((Ok(StreamEvent::Done { reason, message }), state));
                        }
                    }
                }
            },
        );

        Ok(Box::pin(stream))
    }
}
493
/// Accumulator that turns raw SSE chunks into `StreamEvent`s while building
/// up the final `AssistantMessage`.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    event_source: SseStream<S>,         // underlying SSE byte stream
    partial: AssistantMessage,          // message assembled from deltas so far
    tool_calls: Vec<ToolCallState>,     // in-flight tool calls keyed by wire index
    pending_events: VecDeque<StreamEvent>, // events queued for the consumer
    started: bool,                      // whether StreamEvent::Start was emitted
    done: bool,                         // terminal flag; no further polling
    transient_error_count: usize,       // consecutive transient SSE errors seen
}
511
/// Accumulated state for one streamed tool call.
struct ToolCallState {
    index: usize,         // tool-call index as reported on the wire
    content_index: usize, // position of the matching block in `partial.content`
    id: String,           // call id, accumulated across deltas
    name: String,         // function name, accumulated across deltas
    arguments: String,    // raw JSON argument text, accumulated across deltas
}
519
impl<S> StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Create an empty accumulator; the partial message carries provider
    /// metadata and the current timestamp from the start.
    fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
        Self {
            event_source,
            partial: AssistantMessage {
                content: Vec::new(),
                api,
                provider,
                model,
                usage: Usage::default(),
                stop_reason: StopReason::Stop,
                error_message: None,
                timestamp: chrono::Utc::now().timestamp_millis(),
            },
            tool_calls: Vec::new(),
            pending_events: VecDeque::new(),
            started: false,
            done: false,
            transient_error_count: 0,
        }
    }

    /// Emit `StreamEvent::Start` exactly once, before any content events.
    fn ensure_started(&mut self) {
        if !self.started {
            self.started = true;
            self.pending_events.push_back(StreamEvent::Start {
                partial: self.partial.clone(),
            });
        }
    }

    /// Parse one SSE `data:` payload and fold it into the accumulator.
    ///
    /// # Errors
    /// Returns an error when the payload is not valid chunk JSON.
    fn process_event(&mut self, data: &str) -> Result<()> {
        let chunk: OpenAIStreamChunk = serde_json::from_str(data)
            .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;

        // Usage may arrive on any chunk (typically the last); later chunks
        // overwrite earlier values.
        if let Some(usage) = chunk.usage {
            self.partial.usage.input = usage.prompt_tokens;
            self.partial.usage.output = usage.completion_tokens.unwrap_or(0);
            self.partial.usage.total_tokens = usage.total_tokens;
            if let Some(details) = usage.prompt_tokens_details {
                self.partial.usage.cache_read = details.cached_tokens.unwrap_or(0);
            }
        }

        // In-band error object: record it but keep consuming the stream.
        if let Some(error) = chunk.error {
            self.partial.stop_reason = StopReason::Error;
            if let Some(message) = error.message {
                let message = message.trim();
                if !message.is_empty() {
                    self.partial.error_message = Some(message.to_string());
                }
            }
        }

        // Only the first choice is consumed; n>1 is never requested.
        if let Some(choice) = chunk.choices.into_iter().next() {
            // A fully-empty first delta (common keep-alive/role-only chunk)
            // just triggers the Start event.
            if !self.started
                && choice.finish_reason.is_none()
                && choice.delta.content.is_none()
                && choice.delta.reasoning_content.is_none()
                && choice.delta.tool_calls.is_none()
            {
                self.ensure_started();
                return Ok(());
            }

            self.process_choice(choice);
        }

        Ok(())
    }

    /// Parse each tool call's accumulated argument text into JSON and store
    /// it on the matching content block; unparseable text becomes `Null`
    /// (logged, not fatal).
    fn finalize_tool_call_arguments(&mut self) {
        for tc in &self.tool_calls {
            let arguments: serde_json::Value = match serde_json::from_str(&tc.arguments) {
                Ok(args) => args,
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        raw = %tc.arguments,
                        "Failed to parse tool arguments as JSON"
                    );
                    serde_json::Value::Null
                }
            };

            if let Some(ContentBlock::ToolCall(block)) =
                self.partial.content.get_mut(tc.content_index)
            {
                block.arguments = arguments;
            }
        }
    }

    /// Fold one choice delta into the partial message, queueing the
    /// corresponding Start/Delta/End events.
    #[allow(clippy::too_many_lines)]
    fn process_choice(&mut self, choice: OpenAIChoice) {
        let delta = choice.delta;
        if delta.content.is_some()
            || delta.tool_calls.is_some()
            || delta.reasoning_content.is_some()
        {
            self.ensure_started();
        }

        if choice.finish_reason.is_some() {
            self.ensure_started();
        }

        // Reasoning deltas append to the trailing Thinking block, opening a
        // new one (plus ThinkingStart) when the last block is something else.
        if let Some(reasoning) = delta.reasoning_content {
            let last_is_thinking =
                matches!(self.partial.content.last(), Some(ContentBlock::Thinking(_)));

            let content_index = if last_is_thinking {
                self.partial.content.len() - 1
            } else {
                let idx = self.partial.content.len();
                self.partial
                    .content
                    .push(ContentBlock::Thinking(ThinkingContent {
                        thinking: String::new(),
                        thinking_signature: None,
                    }));

                self.pending_events
                    .push_back(StreamEvent::ThinkingStart { content_index: idx });

                idx
            };

            if let Some(ContentBlock::Thinking(t)) = self.partial.content.get_mut(content_index) {
                t.thinking.push_str(&reasoning);
            }

            self.pending_events.push_back(StreamEvent::ThinkingDelta {
                content_index,
                delta: reasoning,
            });
        }

        // Text deltas follow the same append-or-open pattern as thinking.
        if let Some(content) = delta.content {
            let last_is_text = matches!(self.partial.content.last(), Some(ContentBlock::Text(_)));

            let content_index = if last_is_text {
                self.partial.content.len() - 1
            } else {
                let idx = self.partial.content.len();

                self.partial
                    .content
                    .push(ContentBlock::Text(TextContent::new("")));

                self.pending_events
                    .push_back(StreamEvent::TextStart { content_index: idx });

                idx
            };

            if let Some(ContentBlock::Text(t)) = self.partial.content.get_mut(content_index) {
                t.text.push_str(&content);
            }

            self.pending_events.push_back(StreamEvent::TextDelta {
                content_index,
                delta: content,
            });
        }

        // Tool-call deltas are matched by wire index (which may be sparse),
        // not by position; new indices open a fresh state + content block.
        if let Some(tool_calls) = delta.tool_calls {
            for tc_delta in tool_calls {
                let index = tc_delta.index as usize;

                let tool_state_idx = if let Some(existing_idx) =
                    self.tool_calls.iter().position(|tc| tc.index == index)
                {
                    existing_idx
                } else {
                    let content_index = self.partial.content.len();

                    self.tool_calls.push(ToolCallState {
                        index,
                        content_index,
                        id: String::new(),
                        name: String::new(),
                        arguments: String::new(),
                    });

                    // Placeholder block; arguments are filled in at finish.
                    self.partial.content.push(ContentBlock::ToolCall(ToolCall {
                        id: String::new(),
                        name: String::new(),
                        arguments: serde_json::Value::Null,
                        thought_signature: None,
                    }));

                    self.pending_events
                        .push_back(StreamEvent::ToolCallStart { content_index });

                    self.tool_calls.len() - 1
                };

                let tc = &mut self.tool_calls[tool_state_idx];

                let content_index = tc.content_index;

                // id/name fragments are concatenated and mirrored onto the
                // content block so partial snapshots stay consistent.
                if let Some(id) = tc_delta.id {
                    tc.id.push_str(&id);

                    if let Some(ContentBlock::ToolCall(block)) =
                        self.partial.content.get_mut(content_index)
                    {
                        block.id.clone_from(&tc.id);
                    }
                }

                if let Some(function) = tc_delta.function {
                    if let Some(name) = function.name {
                        tc.name.push_str(&name);

                        if let Some(ContentBlock::ToolCall(block)) =
                            self.partial.content.get_mut(content_index)
                        {
                            block.name.clone_from(&tc.name);
                        }
                    }

                    if let Some(args) = function.arguments {
                        tc.arguments.push_str(&args);

                        self.pending_events.push_back(StreamEvent::ToolCallDelta {
                            content_index,
                            delta: args,
                        });
                    }
                }
            }
        }

        // finish_reason closes out the message: map the stop reason, emit
        // End events for every block, and parse tool-call arguments.
        if let Some(reason) = choice.finish_reason {
            self.partial.stop_reason = match reason.as_str() {
                "length" => StopReason::Length,
                "tool_calls" => StopReason::ToolUse,
                "content_filter" | "error" => StopReason::Error,
                _ => StopReason::Stop,
            };

            for (content_index, block) in self.partial.content.iter().enumerate() {
                if let ContentBlock::Text(t) = block {
                    self.pending_events.push_back(StreamEvent::TextEnd {
                        content_index,
                        content: t.text.clone(),
                    });
                } else if let ContentBlock::Thinking(t) = block {
                    self.pending_events.push_back(StreamEvent::ThinkingEnd {
                        content_index,
                        content: t.thinking.clone(),
                    });
                }
            }

            self.finalize_tool_call_arguments();

            for tc in &self.tool_calls {
                if let Some(ContentBlock::ToolCall(tool_call)) =
                    self.partial.content.get(tc.content_index)
                {
                    self.pending_events.push_back(StreamEvent::ToolCallEnd {
                        content_index: tc.content_index,
                        tool_call: tool_call.clone(),
                    });
                }
            }
        }
    }
}
847
/// Wire-format request body for the Chat Completions endpoint; borrows model
/// and tool data from the provider/context to avoid copies.
#[derive(Debug, Serialize)]
pub struct OpenAIRequest<'a> {
    model: &'a str,
    messages: Vec<OpenAIMessage<'a>>,
    // Exactly one of `max_tokens`/`max_completion_tokens` is populated,
    // chosen by compat config in `build_request`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAITool<'a>>>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<OpenAIStreamOptions>,
}
869
/// `stream_options` request field; `include_usage` asks the API to append a
/// usage chunk to the stream.
#[derive(Debug, Serialize)]
struct OpenAIStreamOptions {
    include_usage: bool,
}
874
/// One wire-format chat message; optional fields are omitted when absent.
#[derive(Debug, Serialize)]
struct OpenAIMessage<'a> {
    role: Cow<'a, str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAIContent<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAIToolCallRef<'a>>>,
    // Present only on role "tool" messages, linking back to the call.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<&'a str>,
}
885
/// Message content: either a plain string or a list of typed parts
/// (untagged, matching the API's two accepted shapes).
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum OpenAIContent<'a> {
    Text(Cow<'a, str>),
    Parts(Vec<OpenAIContentPart<'a>>),
}
892
/// One element of multi-part content, tagged with a `type` field
/// ("text" / "image_url") on the wire.
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OpenAIContentPart<'a> {
    Text { text: Cow<'a, str> },
    ImageUrl { image_url: OpenAIImageUrl<'a> },
}
899
/// Image reference; `url` holds a `data:` URI built from base64 content.
#[derive(Debug, Serialize)]
struct OpenAIImageUrl<'a> {
    url: String,
    // The url is owned; the phantom only keeps the lifetime parameter so
    // this type composes with the borrowed content-part enums.
    #[serde(skip)]
    _phantom: std::marker::PhantomData<&'a ()>,
}
907
/// Serialized tool call echoed back on assistant messages
/// (`type` is always "function").
#[derive(Debug, Serialize)]
struct OpenAIToolCallRef<'a> {
    id: &'a str,
    r#type: &'static str,
    function: OpenAIFunctionRef<'a>,
}
914
/// Function name plus its arguments re-serialized as a JSON string, as the
/// API expects on assistant tool calls.
#[derive(Debug, Serialize)]
struct OpenAIFunctionRef<'a> {
    name: &'a str,
    arguments: String,
}
920
/// Tool definition sent in the request (`type` is always "function").
#[derive(Debug, Serialize)]
struct OpenAITool<'a> {
    r#type: &'static str,
    function: OpenAIFunction<'a>,
}
926
/// Function declaration inside a tool definition; `parameters` is a JSON
/// Schema object borrowed from the `ToolDef`.
#[derive(Debug, Serialize)]
struct OpenAIFunction<'a> {
    name: &'a str,
    description: &'a str,
    parameters: &'a serde_json::Value,
}
933
/// One parsed SSE chunk; all fields default so keep-alive or partial chunks
/// deserialize cleanly.
#[derive(Debug, Deserialize)]
struct OpenAIStreamChunk {
    #[serde(default)]
    choices: Vec<OpenAIChoice>,
    #[serde(default)]
    usage: Option<OpenAIUsage>,
    // Some compatible backends deliver errors in-band as a chunk field.
    #[serde(default)]
    error: Option<OpenAIChunkError>,
}
947
/// One streamed choice: an incremental delta plus an optional finish reason
/// on the terminal chunk.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    delta: OpenAIDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}
954
/// Incremental message fragment; any subset of the fields may be present.
#[derive(Debug, Deserialize)]
struct OpenAIDelta {
    #[serde(default)]
    content: Option<String>,
    // Non-standard field used by some backends for reasoning/thinking text.
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAIToolCallDelta>>,
}
964
/// Fragment of a streamed tool call, correlated across chunks by `index`.
#[derive(Debug, Deserialize)]
struct OpenAIToolCallDelta {
    index: u32,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAIFunctionDelta>,
}
973
/// Fragment of a tool call's function: name and/or a slice of the raw
/// JSON argument text.
#[derive(Debug, Deserialize)]
struct OpenAIFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
981
/// Token accounting reported by the API (typically on the final chunk).
#[derive(Debug, Deserialize)]
#[allow(clippy::struct_field_names)]
struct OpenAIUsage {
    prompt_tokens: u64,
    #[serde(default)]
    completion_tokens: Option<u64>,
    total_tokens: u64,
    #[serde(default)]
    prompt_tokens_details: Option<OpenAIPromptTokensDetails>,
}
992
/// Prompt-token breakdown; `cached_tokens` counts prompt-cache hits.
#[derive(Debug, Deserialize)]
struct OpenAIPromptTokensDetails {
    #[serde(default)]
    cached_tokens: Option<u64>,
}
998
/// In-band error object some backends embed in a stream chunk.
#[derive(Debug, Deserialize)]
struct OpenAIChunkError {
    #[serde(default)]
    message: Option<String>,
}
1004
/// Convert one internal `Message` into wire-format messages.
///
/// Most variants map 1:1; a tool result carrying images expands into two
/// messages, because the "tool" role only accepts text content.
#[allow(clippy::too_many_lines)]
fn convert_message_to_openai(message: &Message) -> Vec<OpenAIMessage<'_>> {
    match message {
        Message::User(user) => vec![OpenAIMessage {
            role: Cow::Borrowed("user"),
            content: Some(convert_user_content(&user.content)),
            tool_calls: None,
            tool_call_id: None,
        }],
        // Custom messages are passed through verbatim under the user role.
        Message::Custom(custom) => vec![OpenAIMessage {
            role: Cow::Borrowed("user"),
            content: Some(OpenAIContent::Text(Cow::Borrowed(&custom.content))),
            tool_calls: None,
            tool_call_id: None,
        }],
        Message::Assistant(assistant) => {
            let mut messages = Vec::new();

            // Text blocks are joined into one string; thinking blocks are
            // dropped (they are not replayed to the API).
            let text: String = assistant
                .content
                .iter()
                .filter_map(|b| match b {
                    ContentBlock::Text(t) => Some(t.text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n\n");

            let tool_calls: Vec<OpenAIToolCallRef<'_>> = assistant
                .content
                .iter()
                .filter_map(|b| match b {
                    ContentBlock::ToolCall(tc) => Some(OpenAIToolCallRef {
                        id: &tc.id,
                        r#type: "function",
                        function: OpenAIFunctionRef {
                            name: &tc.name,
                            arguments: tc.arguments.to_string(),
                        },
                    }),
                    _ => None,
                })
                .collect();

            // Always send a content string (possibly empty) so tool-call-only
            // messages still have a well-formed content field.
            let content = if text.is_empty() {
                Some(OpenAIContent::Text(Cow::Borrowed("")))
            } else {
                Some(OpenAIContent::Text(Cow::Owned(text)))
            };

            let tool_calls = if tool_calls.is_empty() {
                None
            } else {
                Some(tool_calls)
            };

            messages.push(OpenAIMessage {
                role: Cow::Borrowed("assistant"),
                content,
                tool_calls,
                tool_call_id: None,
            });

            messages
        }
        Message::ToolResult(result) => {
            let mut text_parts = Vec::new();
            let mut image_parts = Vec::new();

            for block in &result.content {
                match block {
                    ContentBlock::Text(t) => text_parts.push(t.text.as_str()),
                    ContentBlock::Image(img) => {
                        let url = format!("data:{};base64,{}", img.mime_type, img.data);
                        image_parts.push(OpenAIContentPart::ImageUrl {
                            image_url: OpenAIImageUrl {
                                url,
                                _phantom: std::marker::PhantomData,
                            },
                        });
                    }
                    _ => {}
                }
            }

            // Tool messages need textual content; point at the follow-up
            // image message when the result was image-only.
            let text_content = if text_parts.is_empty() {
                if image_parts.is_empty() {
                    Some(OpenAIContent::Text(Cow::Borrowed("")))
                } else {
                    Some(OpenAIContent::Text(Cow::Borrowed("(see attached image)")))
                }
            } else {
                Some(OpenAIContent::Text(Cow::Owned(text_parts.join("\n"))))
            };

            let mut messages = vec![OpenAIMessage {
                role: Cow::Borrowed("tool"),
                content: text_content,
                tool_calls: None,
                tool_call_id: Some(&result.tool_call_id),
            }];

            // Images ride in a separate user message since the tool role
            // cannot carry image parts.
            if !image_parts.is_empty() {
                let mut parts = vec![OpenAIContentPart::Text {
                    text: Cow::Borrowed("Attached image(s) from tool result:"),
                }];
                parts.extend(image_parts);
                messages.push(OpenAIMessage {
                    role: Cow::Borrowed("user"),
                    content: Some(OpenAIContent::Parts(parts)),
                    tool_calls: None,
                    tool_call_id: None,
                });
            }

            messages
        }
    }
}
1135
1136fn convert_user_content(content: &UserContent) -> OpenAIContent<'_> {
1137 match content {
1138 UserContent::Text(text) => OpenAIContent::Text(Cow::Borrowed(text)),
1139 UserContent::Blocks(blocks) => {
1140 let parts: Vec<OpenAIContentPart<'_>> = blocks
1141 .iter()
1142 .filter_map(|block| match block {
1143 ContentBlock::Text(t) => Some(OpenAIContentPart::Text {
1144 text: Cow::Borrowed(&t.text),
1145 }),
1146 ContentBlock::Image(img) => {
1147 let url = format!("data:{};base64,{}", img.mime_type, img.data);
1149 Some(OpenAIContentPart::ImageUrl {
1150 image_url: OpenAIImageUrl {
1151 url,
1152 _phantom: std::marker::PhantomData,
1153 },
1154 })
1155 }
1156 _ => None,
1157 })
1158 .collect();
1159 OpenAIContent::Parts(parts)
1160 }
1161 }
1162}
1163
/// Wrap a `ToolDef` in the wire format: a "function"-typed tool whose
/// declaration borrows name, description, and JSON-Schema parameters.
fn convert_tool_to_openai(tool: &ToolDef) -> OpenAITool<'_> {
    OpenAITool {
        r#type: "function",
        function: OpenAIFunction {
            name: &tool.name,
            description: &tool.description,
            parameters: &tool.parameters,
        },
    }
}
1174
1175#[cfg(test)]
1180mod tests {
1181 use super::*;
1182 use asupersync::runtime::RuntimeBuilder;
1183 use futures::{StreamExt, stream};
1184 use serde::{Deserialize, Serialize};
1185 use serde_json::{Value, json};
1186 use std::collections::HashMap;
1187 use std::io::{Read, Write};
1188 use std::net::TcpListener;
1189 use std::path::PathBuf;
1190 use std::sync::mpsc;
1191 use std::time::Duration;
1192
1193 #[test]
1194 fn test_convert_user_text_message() {
1195 let message = Message::User(crate::model::UserMessage {
1196 content: UserContent::Text("Hello".to_string()),
1197 timestamp: 0,
1198 });
1199
1200 let converted = convert_message_to_openai(&message);
1201 assert_eq!(converted.len(), 1);
1202 assert_eq!(converted[0].role, "user");
1203 }
1204
1205 #[test]
1206 fn test_tool_conversion() {
1207 let tool = ToolDef {
1208 name: "test_tool".to_string(),
1209 description: "A test tool".to_string(),
1210 parameters: serde_json::json!({
1211 "type": "object",
1212 "properties": {
1213 "arg": {"type": "string"}
1214 }
1215 }),
1216 };
1217
1218 let converted = convert_tool_to_openai(&tool);
1219 assert_eq!(converted.r#type, "function");
1220 assert_eq!(converted.function.name, "test_tool");
1221 assert_eq!(converted.function.description, "A test tool");
1222 assert_eq!(
1223 converted.function.parameters,
1224 &serde_json::json!({
1225 "type": "object",
1226 "properties": {
1227 "arg": {"type": "string"}
1228 }
1229 })
1230 );
1231 }
1232
1233 #[test]
1234 fn test_provider_info() {
1235 let provider = OpenAIProvider::new("gpt-4o");
1236 assert_eq!(provider.name(), "openai");
1237 assert_eq!(provider.api(), "openai-completions");
1238 }
1239
1240 #[test]
1241 fn test_build_request_includes_system_tools_and_stream_options() {
1242 let provider = OpenAIProvider::new("gpt-4o");
1243 let context = Context {
1244 system_prompt: Some("You are concise.".to_string().into()),
1245 messages: vec![Message::User(crate::model::UserMessage {
1246 content: UserContent::Text("Ping".to_string()),
1247 timestamp: 0,
1248 })]
1249 .into(),
1250 tools: vec![ToolDef {
1251 name: "search".to_string(),
1252 description: "Search docs".to_string(),
1253 parameters: json!({
1254 "type": "object",
1255 "properties": {
1256 "q": { "type": "string" }
1257 },
1258 "required": ["q"]
1259 }),
1260 }]
1261 .into(),
1262 };
1263 let options = StreamOptions {
1264 temperature: Some(0.2),
1265 max_tokens: Some(123),
1266 ..Default::default()
1267 };
1268
1269 let request = provider.build_request(&context, &options);
1270 let value = serde_json::to_value(&request).expect("serialize request");
1271 assert_eq!(value["model"], "gpt-4o");
1272 assert_eq!(value["messages"][0]["role"], "system");
1273 assert_eq!(value["messages"][0]["content"], "You are concise.");
1274 assert_eq!(value["messages"][1]["role"], "user");
1275 assert_eq!(value["messages"][1]["content"], "Ping");
1276 let temperature = value["temperature"]
1277 .as_f64()
1278 .expect("temperature should serialize as number");
1279 assert!((temperature - 0.2).abs() < 1e-6);
1280 assert_eq!(value["max_tokens"], 123);
1281 assert_eq!(value["stream"], true);
1282 assert_eq!(value["stream_options"]["include_usage"], true);
1283 assert_eq!(value["tools"][0]["type"], "function");
1284 assert_eq!(value["tools"][0]["function"]["name"], "search");
1285 assert_eq!(value["tools"][0]["function"]["description"], "Search docs");
1286 assert_eq!(
1287 value["tools"][0]["function"]["parameters"],
1288 json!({
1289 "type": "object",
1290 "properties": {
1291 "q": { "type": "string" }
1292 },
1293 "required": ["q"]
1294 })
1295 );
1296 }
1297
    #[test]
    fn test_stream_accumulates_tool_call_argument_deltas() {
        // A tool call whose JSON `arguments` string is split across two deltas
        // ("{\"q\":\"ru" + "st\"}") must be stitched back into one parseable
        // object, and the finish_reason "tool_calls" must map to StopReason::ToolUse.
        let events = vec![
            // Empty first delta: exercises the "Start before any content" path.
            json!({ "choices": [{ "delta": {} }] }),
            json!({
                "choices": [{
                    "delta": {
                        "tool_calls": [{
                            "index": 0,
                            "id": "call_1",
                            "function": {
                                "name": "search",
                                "arguments": "{\"q\":\"ru"
                            }
                        }]
                    }
                }]
            }),
            // Continuation delta: same index, no id/name, only the arg tail.
            json!({
                "choices": [{
                    "delta": {
                        "tool_calls": [{
                            "index": 0,
                            "function": {
                                "arguments": "st\"}"
                            }
                        }]
                    }
                }]
            }),
            json!({ "choices": [{ "delta": {}, "finish_reason": "tool_calls" }] }),
            Value::String("[DONE]".to_string()),
        ];

        let out = collect_events(&events);
        // Both raw argument fragments must surface as ToolCallDelta events.
        assert!(
            out.iter()
                .any(|e| matches!(e, StreamEvent::ToolCallStart { .. }))
        );
        assert!(out.iter().any(
            |e| matches!(e, StreamEvent::ToolCallDelta { delta, .. } if delta == "{\"q\":\"ru")
        ));
        assert!(
            out.iter()
                .any(|e| matches!(e, StreamEvent::ToolCallDelta { delta, .. } if delta == "st\"}"))
        );
        let done = out
            .iter()
            .find_map(|event| match event {
                StreamEvent::Done { message, .. } => Some(message),
                _ => None,
            })
            .expect("done event");
        // The final message must contain the fully assembled tool call.
        let tool_call = done
            .content
            .iter()
            .find_map(|block| match block {
                ContentBlock::ToolCall(tc) => Some(tc),
                _ => None,
            })
            .expect("assembled tool call content");
        assert_eq!(tool_call.id, "call_1");
        assert_eq!(tool_call.name, "search");
        // Fragments concatenated and parsed: {"q":"rust"}.
        assert_eq!(tool_call.arguments, json!({ "q": "rust" }));
        assert!(out.iter().any(|e| matches!(
            e,
            StreamEvent::Done {
                reason: StopReason::ToolUse,
                ..
            }
        )));
    }
1370
    #[test]
    fn test_stream_handles_sparse_tool_call_index_without_panic() {
        // Some backends report tool_calls with a non-contiguous index (here 2
        // with no 0/1 before it). The accumulator must not panic or emit
        // phantom empty tool calls for the missing indices.
        let events = vec![
            json!({ "choices": [{ "delta": {} }] }),
            json!({
                "choices": [{
                    "delta": {
                        "tool_calls": [{
                            "index": 2,
                            "id": "call_sparse",
                            "function": {
                                "name": "lookup",
                                "arguments": "{\"q\":\"sparse\"}"
                            }
                        }]
                    }
                }]
            }),
            json!({ "choices": [{ "delta": {}, "finish_reason": "tool_calls" }] }),
            Value::String("[DONE]".to_string()),
        ];

        let out = collect_events(&events);
        let done = out
            .iter()
            .find_map(|event| match event {
                StreamEvent::Done { message, .. } => Some(message),
                _ => None,
            })
            .expect("done event");
        // Exactly one tool call should survive — the sparse index must not
        // produce placeholder entries.
        let tool_calls: Vec<&ToolCall> = done
            .content
            .iter()
            .filter_map(|block| match block {
                ContentBlock::ToolCall(tc) => Some(tc),
                _ => None,
            })
            .collect();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_sparse");
        assert_eq!(tool_calls[0].name, "lookup");
        assert_eq!(tool_calls[0].arguments, json!({ "q": "sparse" }));
        assert!(
            out.iter()
                .any(|event| matches!(event, StreamEvent::ToolCallStart { .. })),
            "expected tool call start event"
        );
    }
1419
1420 #[test]
1421 fn test_stream_maps_finish_reason_error_to_stop_reason_error() {
1422 let events = vec![
1423 json!({
1424 "choices": [{ "delta": {}, "finish_reason": "error" }],
1425 "error": { "message": "upstream provider timeout" }
1426 }),
1427 Value::String("[DONE]".to_string()),
1428 ];
1429
1430 let out = collect_events(&events);
1431 let done = out
1432 .iter()
1433 .find_map(|event| match event {
1434 StreamEvent::Done { reason, message } => Some((reason, message)),
1435 _ => None,
1436 })
1437 .expect("done event");
1438 assert_eq!(*done.0, StopReason::Error);
1439 assert_eq!(
1440 done.1.error_message.as_deref(),
1441 Some("upstream provider timeout")
1442 );
1443 }
1444
1445 #[test]
1446 fn test_finish_reason_without_prior_content_emits_start() {
1447 let events = vec![
1448 json!({ "choices": [{ "delta": {}, "finish_reason": "stop" }] }),
1449 Value::String("[DONE]".to_string()),
1450 ];
1451
1452 let out = collect_events(&events);
1453
1454 assert!(!out.is_empty(), "expected at least one event");
1457 assert!(
1458 matches!(out[0], StreamEvent::Start { .. }),
1459 "First event should be Start, got {:?}",
1460 out[0]
1461 );
1462 }
1463
    #[test]
    fn test_stream_emits_all_events_in_correct_order() {
        // Golden ordering test for a plain text completion: the full event
        // lifecycle must be Start -> TextStart -> TextDelta* -> TextEnd -> Done,
        // with deltas accumulated into the TextEnd content.
        let events = vec![
            json!({ "choices": [{ "delta": { "content": "Hello" } }] }),
            json!({ "choices": [{ "delta": { "content": " world" } }] }),
            json!({ "choices": [{ "delta": {}, "finish_reason": "stop" }] }),
            Value::String("[DONE]".to_string()),
        ];

        let out = collect_events(&events);

        // Start, TextStart, two TextDeltas, TextEnd, Done.
        assert_eq!(out.len(), 6, "Expected 6 events, got {}", out.len());

        assert!(
            matches!(out[0], StreamEvent::Start { .. }),
            "Event 0 should be Start, got {:?}",
            out[0]
        );

        assert!(
            matches!(
                out[1],
                StreamEvent::TextStart {
                    content_index: 0,
                    ..
                }
            ),
            "Event 1 should be TextStart at index 0, got {:?}",
            out[1]
        );

        assert!(
            matches!(&out[2], StreamEvent::TextDelta { content_index: 0, delta, .. } if delta == "Hello"),
            "Event 2 should be TextDelta 'Hello' at index 0, got {:?}",
            out[2]
        );

        assert!(
            matches!(&out[3], StreamEvent::TextDelta { content_index: 0, delta, .. } if delta == " world"),
            "Event 3 should be TextDelta ' world' at index 0, got {:?}",
            out[3]
        );

        // TextEnd carries the concatenation of all deltas.
        assert!(
            matches!(&out[4], StreamEvent::TextEnd { content_index: 0, content, .. } if content == "Hello world"),
            "Event 4 should be TextEnd 'Hello world' at index 0, got {:?}",
            out[4]
        );

        assert!(
            matches!(
                out[5],
                StreamEvent::Done {
                    reason: StopReason::Stop,
                    ..
                }
            ),
            "Event 5 should be Done with Stop reason, got {:?}",
            out[5]
        );
    }
1526
    #[test]
    fn test_build_request_applies_openrouter_routing_overrides() {
        // OpenRouter-specific routing config from CompatConfig should be merged
        // into the outgoing request JSON (`route`, `provider`, `models` keys).
        let provider = OpenAIProvider::new("openai/gpt-4o-mini")
            .with_provider_name("openrouter")
            .with_compat(Some(CompatConfig {
                open_router_routing: Some(json!({
                    "models": ["openai/gpt-4o-mini", "anthropic/claude-3.5-sonnet"],
                    "provider": {
                        "order": ["openai", "anthropic"],
                        "allow_fallbacks": false
                    },
                    "route": "fallback"
                })),
                ..CompatConfig::default()
            }));
        let context = Context {
            system_prompt: None,
            messages: vec![Message::User(crate::model::UserMessage {
                content: UserContent::Text("Ping".to_string()),
                timestamp: 0,
            })]
            .into(),
            tools: Vec::new().into(),
        };
        let options = StreamOptions::default();

        let request = provider
            .build_request_json(&context, &options)
            .expect("request json");
        // Model stays as configured; routing keys come from the compat config.
        assert_eq!(request["model"], "openai/gpt-4o-mini");
        assert_eq!(request["route"], "fallback");
        assert_eq!(request["provider"]["allow_fallbacks"], false);
        assert_eq!(request["models"][0], "openai/gpt-4o-mini");
        assert_eq!(request["models"][1], "anthropic/claude-3.5-sonnet");
    }
1562
1563 #[test]
1564 fn test_stream_sets_bearer_auth_header() {
1565 let captured = run_stream_and_capture_headers().expect("captured request");
1566 assert_eq!(
1567 captured.headers.get("authorization").map(String::as_str),
1568 Some("Bearer test-openai-key")
1569 );
1570 assert_eq!(
1571 captured.headers.get("accept").map(String::as_str),
1572 Some("text/event-stream")
1573 );
1574
1575 let body: Value = serde_json::from_str(&captured.body).expect("request body json");
1576 assert_eq!(body["stream"], true);
1577 assert_eq!(body["stream_options"]["include_usage"], true);
1578 }
1579
1580 #[test]
1581 fn test_stream_openrouter_injects_default_attribution_headers() {
1582 let options = StreamOptions {
1583 api_key: Some("test-openrouter-key".to_string()),
1584 ..Default::default()
1585 };
1586 let captured = run_stream_and_capture_headers_with(
1587 OpenAIProvider::new("openai/gpt-4o-mini").with_provider_name("openrouter"),
1588 &options,
1589 )
1590 .expect("captured request");
1591
1592 assert_eq!(
1593 captured.headers.get("http-referer").map(String::as_str),
1594 Some(OPENROUTER_DEFAULT_HTTP_REFERER)
1595 );
1596 assert_eq!(
1597 captured.headers.get("x-title").map(String::as_str),
1598 Some(OPENROUTER_DEFAULT_X_TITLE)
1599 );
1600 }
1601
    #[test]
    fn test_stream_openrouter_respects_explicit_attribution_headers() {
        // Caller-supplied HTTP-Referer / X-Title must win over the defaults the
        // OpenRouter path would otherwise inject. Lookups below use lowercase
        // keys because the test server lowercases header names on capture.
        let options = StreamOptions {
            api_key: Some("test-openrouter-key".to_string()),
            headers: HashMap::from([
                (
                    "HTTP-Referer".to_string(),
                    "https://example.test/app".to_string(),
                ),
                (
                    "X-Title".to_string(),
                    "Custom OpenRouter Client".to_string(),
                ),
            ]),
            ..Default::default()
        };
        let captured = run_stream_and_capture_headers_with(
            OpenAIProvider::new("openai/gpt-4o-mini").with_provider_name("openrouter"),
            &options,
        )
        .expect("captured request");

        assert_eq!(
            captured.headers.get("http-referer").map(String::as_str),
            Some("https://example.test/app")
        );
        assert_eq!(
            captured.headers.get("x-title").map(String::as_str),
            Some("Custom OpenRouter Client")
        );
    }
1633
    /// Top-level shape of a provider stream fixture file: a list of named cases.
    #[derive(Debug, Deserialize)]
    struct ProviderFixture {
        // Each case pairs raw SSE chunk JSON with the expected event summaries.
        cases: Vec<ProviderCase>,
    }
1638
    /// One fixture case: raw OpenAI-style stream chunks plus the expected
    /// summarized event sequence.
    #[derive(Debug, Deserialize)]
    struct ProviderCase {
        // Case label used in assertion failure messages.
        name: String,
        // Raw chunks fed through `collect_events`; a JSON string value is sent
        // as-is (e.g. "[DONE]"), any other value is serialized to JSON first.
        events: Vec<Value>,
        // Golden summaries compared against `summarize_event` output.
        expected: Vec<EventSummary>,
    }
1645
    /// Flattened, comparison-friendly view of a `StreamEvent` used by the
    /// fixture tests; fields absent from the fixture JSON default to `None`.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct EventSummary {
        // Event discriminator, e.g. "start", "text_delta", "done", "other".
        kind: String,
        #[serde(default)]
        content_index: Option<usize>,
        #[serde(default)]
        delta: Option<String>,
        #[serde(default)]
        content: Option<String>,
        #[serde(default)]
        reason: Option<String>,
    }
1658
1659 #[test]
1660 fn test_stream_fixtures() {
1661 let fixture = load_fixture("openai_stream.json");
1662 for case in fixture.cases {
1663 let events = collect_events(&case.events);
1664 let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
1665 assert_eq!(summaries, case.expected, "case {}", case.name);
1666 }
1667 }
1668
1669 fn load_fixture(file_name: &str) -> ProviderFixture {
1670 let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1671 .join("tests/fixtures/provider_responses")
1672 .join(file_name);
1673 let raw = std::fs::read_to_string(path).expect("fixture read");
1674 serde_json::from_str(&raw).expect("fixture parse")
1675 }
1676
    /// Request observed by the in-process test HTTP server.
    #[derive(Debug)]
    struct CapturedRequest {
        // Header map keyed by lowercased header name (see `parse_headers`).
        headers: HashMap<String, String>,
        // Raw request body, decoded lossily as UTF-8.
        body: String,
    }
1682
1683 fn run_stream_and_capture_headers() -> Option<CapturedRequest> {
1684 let options = StreamOptions {
1685 api_key: Some("test-openai-key".to_string()),
1686 ..Default::default()
1687 };
1688 run_stream_and_capture_headers_with(OpenAIProvider::new("gpt-4o"), &options)
1689 }
1690
    /// Drives one full stream against a local one-shot HTTP server and returns
    /// the request (headers + body) that server captured, or `None` if nothing
    /// arrived within the timeout.
    fn run_stream_and_capture_headers_with(
        provider: OpenAIProvider,
        options: &StreamOptions,
    ) -> Option<CapturedRequest> {
        // Server replies with a canned successful SSE body regardless of input.
        let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
        let provider = provider.with_base_url(base_url);
        let context = Context {
            system_prompt: None,
            messages: vec![Message::User(crate::model::UserMessage {
                content: UserContent::Text("ping".to_string()),
                timestamp: 0,
            })]
            .into(),
            tools: Vec::new().into(),
        };

        // Single-threaded runtime is enough: the server runs on its own OS thread.
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        runtime.block_on(async {
            let mut stream = provider.stream(&context, options).await.expect("stream");
            // Drain until Done so the request is fully written before we read rx.
            while let Some(event) = stream.next().await {
                if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
                    break;
                }
            }
        });

        rx.recv_timeout(Duration::from_secs(2)).ok()
    }
1721
1722 fn success_sse_body() -> String {
1723 [
1724 r#"data: {"choices":[{"delta":{}}]}"#,
1725 "",
1726 r#"data: {"choices":[{"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}"#,
1727 "",
1728 "data: [DONE]",
1729 "",
1730 ]
1731 .join("\n")
1732 }
1733
    /// Spawns a one-shot HTTP server on an ephemeral localhost port. It accepts
    /// a single connection, captures the request's headers and body (sent over
    /// the returned channel), then replies with the given status/content-type/
    /// body and closes. Returns the server's chat-completions URL and receiver.
    fn spawn_test_server(
        status_code: u16,
        content_type: &str,
        body: &str,
    ) -> (String, mpsc::Receiver<CapturedRequest>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
        let addr = listener.local_addr().expect("local addr");
        let (tx, rx) = mpsc::channel();
        // Own the response pieces so the serving thread can be 'static.
        let body = body.to_string();
        let content_type = content_type.to_string();

        std::thread::spawn(move || {
            let (mut socket, _) = listener.accept().expect("accept");
            // Bound all reads so a misbehaving client can't hang the test.
            socket
                .set_read_timeout(Some(Duration::from_secs(2)))
                .expect("set read timeout");

            // Phase 1: read until the end-of-headers marker (\r\n\r\n).
            let mut bytes = Vec::new();
            let mut chunk = [0_u8; 4096];
            loop {
                match socket.read(&mut chunk) {
                    Ok(0) => break,
                    Ok(n) => {
                        bytes.extend_from_slice(&chunk[..n]);
                        if bytes.windows(4).any(|window| window == b"\r\n\r\n") {
                            break;
                        }
                    }
                    // Timeouts end the read loop rather than failing the test.
                    Err(err)
                        if err.kind() == std::io::ErrorKind::WouldBlock
                            || err.kind() == std::io::ErrorKind::TimedOut =>
                    {
                        break;
                    }
                    Err(err) => panic!("read error: {err}"),
                }
            }

            let header_end = bytes
                .windows(4)
                .position(|window| window == b"\r\n\r\n")
                .expect("request header boundary");
            let header_text = String::from_utf8_lossy(&bytes[..header_end]).to_string();
            let headers = parse_headers(&header_text);
            // Any bytes read past the header boundary are the start of the body.
            let mut request_body = bytes[header_end + 4..].to_vec();

            // Phase 2: read the remainder of the body per Content-Length.
            let content_length = headers
                .get("content-length")
                .and_then(|value| value.parse::<usize>().ok())
                .unwrap_or(0);
            while request_body.len() < content_length {
                match socket.read(&mut chunk) {
                    Ok(0) => break,
                    Ok(n) => request_body.extend_from_slice(&chunk[..n]),
                    Err(err)
                        if err.kind() == std::io::ErrorKind::WouldBlock
                            || err.kind() == std::io::ErrorKind::TimedOut =>
                    {
                        break;
                    }
                    Err(err) => panic!("read error: {err}"),
                }
            }

            let captured = CapturedRequest {
                headers,
                body: String::from_utf8_lossy(&request_body).to_string(),
            };
            tx.send(captured).expect("send captured request");

            // Minimal reason-phrase mapping; only the statuses tests use.
            let reason = match status_code {
                401 => "Unauthorized",
                500 => "Internal Server Error",
                _ => "OK",
            };
            let response = format!(
                "HTTP/1.1 {status_code} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
                body.len()
            );
            socket
                .write_all(response.as_bytes())
                .expect("write response");
            socket.flush().expect("flush response");
        });

        (format!("http://{addr}/chat/completions"), rx)
    }
1821
1822 fn parse_headers(header_text: &str) -> HashMap<String, String> {
1823 let mut headers = HashMap::new();
1824 for line in header_text.lines().skip(1) {
1825 if let Some((name, value)) = line.split_once(':') {
1826 headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string());
1827 }
1828 }
1829 headers
1830 }
1831
    /// Feeds the given chunks through a fresh `StreamState` as if they arrived
    /// over SSE, and returns every `StreamEvent` produced in order. A JSON
    /// string chunk is sent verbatim (used for "[DONE]"); other values are
    /// serialized to JSON first.
    fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        runtime.block_on(async move {
            // Wrap each chunk in an SSE frame: "data: <json>\n\n".
            let byte_stream = stream::iter(
                events
                    .iter()
                    .map(|event| {
                        let data = match event {
                            Value::String(text) => text.clone(),
                            _ => serde_json::to_string(event).expect("serialize event"),
                        };
                        format!("data: {data}\n\n").into_bytes()
                    })
                    .map(Ok),
            );
            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
            let mut state = StreamState::new(
                event_source,
                "gpt-test".to_string(),
                "openai".to_string(),
                "openai".to_string(),
            );
            let mut out = Vec::new();

            while let Some(item) = state.event_source.next().await {
                let msg = item.expect("SSE event");
                if msg.data == "[DONE]" {
                    // Flush anything still queued, then synthesize the terminal
                    // Done event from the accumulated partial message.
                    out.extend(state.pending_events.drain(..));
                    let reason = state.partial.stop_reason;
                    out.push(StreamEvent::Done {
                        reason,
                        message: std::mem::take(&mut state.partial),
                    });
                    break;
                }
                state.process_event(&msg.data).expect("process_event");
                // Drain per chunk so `out` preserves emission order.
                out.extend(state.pending_events.drain(..));
            }

            out
        })
    }
1876
1877 fn collect_thinking_text(events: &[StreamEvent]) -> String {
1878 let mut out = String::new();
1879 for event in events {
1880 if let StreamEvent::ThinkingDelta { delta, .. } = event {
1881 out.push_str(delta);
1882 }
1883 }
1884 out
1885 }
1886
    /// Collapses a `StreamEvent` into the flat `EventSummary` form used by the
    /// fixture files. Event variants not covered by fixtures map to "other".
    fn summarize_event(event: &StreamEvent) -> EventSummary {
        match event {
            StreamEvent::Start { .. } => EventSummary {
                kind: "start".to_string(),
                content_index: None,
                delta: None,
                content: None,
                reason: None,
            },
            StreamEvent::TextDelta {
                content_index,
                delta,
                ..
            } => EventSummary {
                kind: "text_delta".to_string(),
                content_index: Some(*content_index),
                delta: Some(delta.clone()),
                content: None,
                reason: None,
            },
            StreamEvent::Done { reason, .. } => EventSummary {
                kind: "done".to_string(),
                content_index: None,
                delta: None,
                content: None,
                reason: Some(reason_to_string(*reason)),
            },
            StreamEvent::Error { reason, .. } => EventSummary {
                kind: "error".to_string(),
                content_index: None,
                delta: None,
                content: None,
                reason: Some(reason_to_string(*reason)),
            },
            StreamEvent::TextStart { content_index, .. } => EventSummary {
                kind: "text_start".to_string(),
                content_index: Some(*content_index),
                delta: None,
                content: None,
                reason: None,
            },
            StreamEvent::TextEnd {
                content_index,
                content,
                ..
            } => EventSummary {
                kind: "text_end".to_string(),
                content_index: Some(*content_index),
                delta: None,
                content: Some(content.clone()),
                reason: None,
            },
            // Tool-call and thinking events are not distinguished by fixtures.
            _ => EventSummary {
                kind: "other".to_string(),
                content_index: None,
                delta: None,
                content: None,
                reason: None,
            },
        }
    }
1948
1949 fn reason_to_string(reason: StopReason) -> String {
1950 match reason {
1951 StopReason::Stop => "stop",
1952 StopReason::Length => "length",
1953 StopReason::ToolUse => "tool_use",
1954 StopReason::Error => "error",
1955 StopReason::Aborted => "aborted",
1956 }
1957 .to_string()
1958 }
1959
1960 fn context_with_tools() -> Context<'static> {
1963 Context {
1964 system_prompt: Some("You are helpful.".to_string().into()),
1965 messages: vec![Message::User(crate::model::UserMessage {
1966 content: UserContent::Text("Hi".to_string()),
1967 timestamp: 0,
1968 })]
1969 .into(),
1970 tools: vec![ToolDef {
1971 name: "search".to_string(),
1972 description: "Search".to_string(),
1973 parameters: json!({"type": "object", "properties": {}}),
1974 }]
1975 .into(),
1976 }
1977 }
1978
1979 fn default_stream_options() -> StreamOptions {
1980 StreamOptions {
1981 max_tokens: Some(1024),
1982 ..Default::default()
1983 }
1984 }
1985
1986 #[test]
1987 fn compat_system_role_name_overrides_default() {
1988 let provider = OpenAIProvider::new("gpt-4o").with_compat(Some(CompatConfig {
1989 system_role_name: Some("developer".to_string()),
1990 ..Default::default()
1991 }));
1992 let context = context_with_tools();
1993 let options = default_stream_options();
1994 let req = provider.build_request(&context, &options);
1995 let value = serde_json::to_value(&req).expect("serialize");
1996 assert_eq!(
1997 value["messages"][0]["role"], "developer",
1998 "system message should use overridden role name"
1999 );
2000 }
2001
2002 #[test]
2003 fn compat_none_uses_default_system_role() {
2004 let provider = OpenAIProvider::new("gpt-4o");
2005 let context = context_with_tools();
2006 let options = default_stream_options();
2007 let req = provider.build_request(&context, &options);
2008 let value = serde_json::to_value(&req).expect("serialize");
2009 assert_eq!(
2010 value["messages"][0]["role"], "system",
2011 "default system role should be 'system'"
2012 );
2013 }
2014
2015 #[test]
2016 fn compat_supports_tools_false_omits_tools() {
2017 let provider = OpenAIProvider::new("gpt-4o").with_compat(Some(CompatConfig {
2018 supports_tools: Some(false),
2019 ..Default::default()
2020 }));
2021 let context = context_with_tools();
2022 let options = default_stream_options();
2023 let req = provider.build_request(&context, &options);
2024 let value = serde_json::to_value(&req).expect("serialize");
2025 assert!(
2026 value["tools"].is_null(),
2027 "tools should be omitted when supports_tools=false"
2028 );
2029 }
2030
2031 #[test]
2032 fn compat_supports_tools_true_includes_tools() {
2033 let provider = OpenAIProvider::new("gpt-4o").with_compat(Some(CompatConfig {
2034 supports_tools: Some(true),
2035 ..Default::default()
2036 }));
2037 let context = context_with_tools();
2038 let options = default_stream_options();
2039 let req = provider.build_request(&context, &options);
2040 let value = serde_json::to_value(&req).expect("serialize");
2041 assert!(
2042 value["tools"].is_array(),
2043 "tools should be included when supports_tools=true"
2044 );
2045 }
2046
2047 #[test]
2048 fn compat_max_tokens_field_routes_to_max_completion_tokens() {
2049 let provider = OpenAIProvider::new("o1").with_compat(Some(CompatConfig {
2050 max_tokens_field: Some("max_completion_tokens".to_string()),
2051 ..Default::default()
2052 }));
2053 let context = context_with_tools();
2054 let options = default_stream_options();
2055 let req = provider.build_request(&context, &options);
2056 let value = serde_json::to_value(&req).expect("serialize");
2057 assert!(
2058 value["max_tokens"].is_null(),
2059 "max_tokens should be absent when routed to max_completion_tokens"
2060 );
2061 assert_eq!(
2062 value["max_completion_tokens"], 1024,
2063 "max_completion_tokens should carry the token limit"
2064 );
2065 }
2066
2067 #[test]
2068 fn compat_default_routes_to_max_tokens() {
2069 let provider = OpenAIProvider::new("gpt-4o");
2070 let context = context_with_tools();
2071 let options = default_stream_options();
2072 let req = provider.build_request(&context, &options);
2073 let value = serde_json::to_value(&req).expect("serialize");
2074 assert_eq!(
2075 value["max_tokens"], 1024,
2076 "default should use max_tokens field"
2077 );
2078 assert!(
2079 value["max_completion_tokens"].is_null(),
2080 "max_completion_tokens should be absent by default"
2081 );
2082 }
2083
2084 #[test]
2085 fn compat_supports_usage_in_streaming_false() {
2086 let provider = OpenAIProvider::new("gpt-4o").with_compat(Some(CompatConfig {
2087 supports_usage_in_streaming: Some(false),
2088 ..Default::default()
2089 }));
2090 let context = context_with_tools();
2091 let options = default_stream_options();
2092 let req = provider.build_request(&context, &options);
2093 let value = serde_json::to_value(&req).expect("serialize");
2094 assert_eq!(
2095 value["stream_options"]["include_usage"], false,
2096 "include_usage should be false when supports_usage_in_streaming=false"
2097 );
2098 }
2099
2100 #[test]
2101 fn compat_combined_overrides() {
2102 let provider = OpenAIProvider::new("custom-model").with_compat(Some(CompatConfig {
2103 system_role_name: Some("developer".to_string()),
2104 max_tokens_field: Some("max_completion_tokens".to_string()),
2105 supports_tools: Some(false),
2106 supports_usage_in_streaming: Some(false),
2107 ..Default::default()
2108 }));
2109 let context = context_with_tools();
2110 let options = default_stream_options();
2111 let req = provider.build_request(&context, &options);
2112 let value = serde_json::to_value(&req).expect("serialize");
2113 assert_eq!(value["messages"][0]["role"], "developer");
2114 assert!(value["max_tokens"].is_null());
2115 assert_eq!(value["max_completion_tokens"], 1024);
2116 assert!(value["tools"].is_null());
2117 assert_eq!(value["stream_options"]["include_usage"], false);
2118 }
2119
    #[test]
    fn compat_custom_headers_injected_into_stream_request() {
        // Custom headers configured via CompatConfig must appear on the actual
        // HTTP request the provider sends while streaming.
        let mut custom = HashMap::new();
        custom.insert("X-Custom-Tag".to_string(), "test-123".to_string());
        custom.insert("X-Provider-Region".to_string(), "us-east-1".to_string());
        let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
        let provider = OpenAIProvider::new("gpt-4o")
            .with_base_url(base_url)
            .with_compat(Some(CompatConfig {
                custom_headers: Some(custom),
                ..Default::default()
            }));

        let context = Context {
            system_prompt: None,
            messages: vec![Message::User(crate::model::UserMessage {
                content: UserContent::Text("ping".to_string()),
                timestamp: 0,
            })]
            .into(),
            tools: Vec::new().into(),
        };
        let options = StreamOptions {
            api_key: Some("test-key".to_string()),
            ..Default::default()
        };

        // Drive the stream to completion so the request is fully sent before
        // we inspect what the test server captured.
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        runtime.block_on(async {
            let mut stream = provider.stream(&context, &options).await.expect("stream");
            while let Some(event) = stream.next().await {
                if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
                    break;
                }
            }
        });

        // Header keys are lowercased by the capture-side parser.
        let captured = rx
            .recv_timeout(Duration::from_secs(2))
            .expect("captured request");
        assert_eq!(
            captured.headers.get("x-custom-tag").map(String::as_str),
            Some("test-123"),
            "custom header should be present in request"
        );
        assert_eq!(
            captured
                .headers
                .get("x-provider-region")
                .map(String::as_str),
            Some("us-east-1"),
            "custom header should be present in request"
        );
    }
2176
2177 #[test]
2178 fn compat_authorization_header_is_used_without_api_key() {
2179 let mut custom = HashMap::new();
2180 custom.insert(
2181 "Authorization".to_string(),
2182 "Bearer compat-openai-token".to_string(),
2183 );
2184 let provider = OpenAIProvider::new("gpt-4o").with_compat(Some(CompatConfig {
2185 custom_headers: Some(custom),
2186 ..Default::default()
2187 }));
2188 let options = StreamOptions::default();
2189
2190 let captured =
2191 run_stream_and_capture_headers_with(provider, &options).expect("captured request");
2192 assert_eq!(
2193 captured.headers.get("authorization").map(String::as_str),
2194 Some("Bearer compat-openai-token")
2195 );
2196 }
2197
2198 #[test]
2199 fn blank_compat_authorization_header_does_not_override_builtin_api_key() {
2200 let mut custom = HashMap::new();
2201 custom.insert("Authorization".to_string(), " ".to_string());
2202 let provider = OpenAIProvider::new("gpt-4o").with_compat(Some(CompatConfig {
2203 custom_headers: Some(custom),
2204 ..Default::default()
2205 }));
2206 let options = StreamOptions {
2207 api_key: Some("test-openai-key".to_string()),
2208 ..Default::default()
2209 };
2210
2211 let captured =
2212 run_stream_and_capture_headers_with(provider, &options).expect("captured request");
2213 assert_eq!(
2214 captured.headers.get("authorization").map(String::as_str),
2215 Some("Bearer test-openai-key")
2216 );
2217 }
2218
    #[test]
    fn reasoning_only_delta_emits_thinking_events() {
        // A delta that carries only `reasoning_content` (no `content`) must
        // open, feed, and close a thinking block: ThinkingStart ->
        // ThinkingDelta -> ThinkingEnd, with the text preserved.
        let events = vec![
            json!({
                "choices": [{
                    "delta": {"reasoning_content": "plan"},
                    "finish_reason": null
                }]
            }),
            json!({
                "choices": [{
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }),
            Value::String("[DONE]".to_string()),
        ];

        let out = collect_events(&events);
        assert!(
            out.iter()
                .any(|event| matches!(event, StreamEvent::ThinkingStart { .. })),
            "expected ThinkingStart for reasoning-only delta"
        );
        assert!(
            out.iter()
                .any(|event| matches!(event, StreamEvent::ThinkingDelta { .. })),
            "expected ThinkingDelta for reasoning-only delta"
        );
        assert!(
            out.iter()
                .any(|event| matches!(event, StreamEvent::ThinkingEnd { .. })),
            "expected ThinkingEnd after finish_reason"
        );
        assert_eq!(collect_thinking_text(&out), "plan");
    }
2255
    #[test]
    fn reasoning_delta_segmentation_is_lossless() {
        // The same reasoning text delivered as one delta ("abc") or split
        // across two ("a" + "bc") must produce identical concatenated output.
        let single = vec![
            json!({
                "choices": [{
                    "delta": {"reasoning_content": "abc"},
                    "finish_reason": null
                }]
            }),
            json!({
                "choices": [{
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }),
            Value::String("[DONE]".to_string()),
        ];

        let split = vec![
            json!({
                "choices": [{
                    "delta": {"reasoning_content": "a"},
                    "finish_reason": null
                }]
            }),
            json!({
                "choices": [{
                    "delta": {"reasoning_content": "bc"},
                    "finish_reason": null
                }]
            }),
            json!({
                "choices": [{
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }),
            Value::String("[DONE]".to_string()),
        ];

        let single_out = collect_events(&single);
        let split_out = collect_events(&split);

        assert_eq!(
            collect_thinking_text(&single_out),
            collect_thinking_text(&split_out),
            "segmenting reasoning deltas should not change final thinking text"
        );
    }
2305
    /// Property tests asserting that `StreamState::process_event` never panics,
    /// whether fed well-formed OpenAI-shaped chunks, adversarial junk, or mixed
    /// sequences. Errors are allowed; panics are not.
    mod proptest_process_event {
        use super::*;
        use proptest::prelude::*;

        // Fresh StreamState over an empty byte stream — events are injected
        // directly via `process_event`, so the source is never polled.
        fn make_state()
        -> StreamState<impl Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin>
        {
            let empty = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
            let sse = crate::sse::SseStream::new(Box::pin(empty));
            StreamState::new(sse, "gpt-test".into(), "openai".into(), "openai".into())
        }

        // Short strings: empty, identifier-like, or arbitrary printable ASCII.
        fn small_string() -> impl Strategy<Value = String> {
            prop_oneof![Just(String::new()), "[a-zA-Z0-9_]{1,16}", "[ -~]{0,32}",]
        }

        fn optional_string() -> impl Strategy<Value = Option<String>> {
            prop_oneof![Just(None), small_string().prop_map(Some),]
        }

        // Token counts biased toward realistic values but including zero and
        // near-u64::MAX extremes to probe overflow handling.
        fn token_count() -> impl Strategy<Value = u64> {
            prop_oneof![
                5 => 0u64..10_000u64,
                2 => Just(0u64),
                1 => Just(u64::MAX),
                1 => (u64::MAX - 100)..=u64::MAX,
            ]
        }

        // Known finish_reason values plus arbitrary strings for the unknown case.
        fn finish_reason() -> impl Strategy<Value = Option<String>> {
            prop_oneof![
                3 => Just(None),
                1 => Just(Some("stop".to_string())),
                1 => Just(Some("length".to_string())),
                1 => Just(Some("tool_calls".to_string())),
                1 => Just(Some("content_filter".to_string())),
                1 => small_string().prop_map(Some),
            ]
        }

        // Mostly small contiguous indices, plus sparse and u32::MAX extremes.
        fn tool_call_index() -> impl Strategy<Value = u32> {
            prop_oneof![
                5 => 0u32..3u32,
                1 => Just(u32::MAX),
                1 => 100u32..200u32,
            ]
        }

        // Structurally valid OpenAI chunk JSON covering content deltas, empty
        // deltas, finish-only chunks, tool-call deltas, usage, and errors.
        fn openai_chunk_json() -> impl Strategy<Value = String> {
            prop_oneof![
                3 => (small_string(), finish_reason()).prop_map(|(text, fr)| {
                    let mut choice = serde_json::json!({
                        "delta": {"content": text}
                    });
                    if let Some(reason) = fr {
                        choice["finish_reason"] = serde_json::Value::String(reason);
                    }
                    serde_json::json!({"choices": [choice]}).to_string()
                }),
                2 => Just(r#"{"choices":[{"delta":{}}]}"#.to_string()),
                2 => finish_reason()
                    .prop_filter_map("some reason", |fr| fr)
                    .prop_map(|reason| {
                        serde_json::json!({
                            "choices": [{"delta": {}, "finish_reason": reason}]
                        })
                        .to_string()
                    }),
                3 => (tool_call_index(), optional_string(), optional_string(), optional_string())
                    .prop_map(|(idx, id, name, args)| {
                        let mut tc = serde_json::json!({"index": idx});
                        if let Some(id) = id { tc["id"] = serde_json::Value::String(id); }
                        let mut func = serde_json::Map::new();
                        if let Some(n) = name { func.insert("name".into(), serde_json::Value::String(n)); }
                        if let Some(a) = args { func.insert("arguments".into(), serde_json::Value::String(a)); }
                        if !func.is_empty() { tc["function"] = serde_json::Value::Object(func); }
                        serde_json::json!({
                            "choices": [{"delta": {"tool_calls": [tc]}}]
                        })
                        .to_string()
                    }),
                2 => (token_count(), token_count(), token_count()).prop_map(|(prompt, compl, total)| {
                    serde_json::json!({
                        "choices": [],
                        "usage": {
                            "prompt_tokens": prompt,
                            "completion_tokens": compl,
                            "total_tokens": total
                        }
                    })
                    .to_string()
                }),
                1 => small_string().prop_map(|msg| {
                    serde_json::json!({
                        "choices": [],
                        "error": {"message": msg}
                    })
                    .to_string()
                }),
                1 => Just(r#"{"choices":[]}"#.to_string()),
            ]
        }

        // Malformed or type-confused payloads: truncated JSON, wrong types,
        // and raw printable-ASCII noise.
        fn chaos_json() -> impl Strategy<Value = String> {
            prop_oneof![
                Just(String::new()),
                Just("{}".to_string()),
                Just("[]".to_string()),
                Just("null".to_string()),
                Just("{".to_string()),
                Just(r#"{"choices":"not_array"}"#.to_string()),
                Just(r#"{"choices":[{"delta":null}]}"#.to_string()),
                "[a-z_]{1,20}".prop_map(|t| format!(r#"{{"type":"{t}"}}"#)),
                "[ -~]{0,64}",
            ]
        }

        proptest! {
            #![proptest_config(ProptestConfig {
                cases: 256,
                max_shrink_iters: 100,
                .. ProptestConfig::default()
            })]

            #[test]
            fn process_event_valid_never_panics(data in openai_chunk_json()) {
                let mut state = make_state();
                let _ = state.process_event(&data);
            }

            #[test]
            fn process_event_chaos_never_panics(data in chaos_json()) {
                let mut state = make_state();
                let _ = state.process_event(&data);
            }

            // Sequences exercise cross-event state (tool-call accumulation).
            #[test]
            fn process_event_sequence_never_panics(
                events in prop::collection::vec(openai_chunk_json(), 1..8)
            ) {
                let mut state = make_state();
                for event in &events {
                    let _ = state.process_event(event);
                }
            }

            #[test]
            fn process_event_mixed_sequence_never_panics(
                events in prop::collection::vec(
                    prop_oneof![openai_chunk_json(), chaos_json()],
                    1..12
                )
            ) {
                let mut state = make_state();
                for event in &events {
                    let _ = state.process_event(event);
                }
            }
        }
    }
2479}
2480
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    //! Fuzz-harness glue: exposes `StreamState::process_event` behind a
    //! minimal public wrapper so an external fuzz target can drive it
    //! with arbitrary payload strings.

    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// Item type carried by the placeholder byte stream (never actually yielded).
    type FuzzItem = std::result::Result<Vec<u8>, std::io::Error>;

    /// A pinned, permanently-empty byte stream used as the transport stand-in.
    type FuzzStream = Pin<Box<stream::Empty<FuzzItem>>>;

    /// Wrapper around `StreamState` for fuzzing. The underlying SSE
    /// transport is an empty stream, so every input reaches the state
    /// machine solely through [`Processor::process_event`].
    pub struct Processor(StreamState<FuzzStream>);

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Processor {
        /// Builds a processor whose transport never yields any bytes.
        pub fn new() -> Self {
            let transport = stream::empty::<FuzzItem>();
            let sse = crate::sse::SseStream::new(Box::pin(transport));
            Self(StreamState::new(
                sse,
                "gpt-fuzz".into(),
                "openai".into(),
                "openai".into(),
            ))
        }

        /// Feeds one raw event payload to the state machine and returns
        /// whatever stream events it produced, draining the internal queue.
        ///
        /// # Errors
        /// Propagates any error returned by `StreamState::process_event`.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            self.0.process_event(data)?;
            let produced: Vec<StreamEvent> = self.0.pending_events.drain(..).collect();
            Ok(produced)
        }
    }
}