1use std::borrow::Cow;
10
11use crate::error::{Error, Result};
12use crate::http::client::Client;
13use crate::model::{
14 AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ThinkingContent,
15 ToolCall, Usage, UserContent,
16};
17use crate::models::CompatConfig;
18use crate::provider::{Context, Provider, StreamOptions, ToolDef};
19use crate::sse::SseStream;
20use async_trait::async_trait;
21use futures::StreamExt;
22use futures::stream::{self, Stream};
23use serde::{Deserialize, Serialize};
24use std::collections::VecDeque;
25use std::pin::Pin;
26
/// Default endpoint for the OpenAI Chat Completions API.
const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
/// Token cap applied when `StreamOptions::max_tokens` is not set.
const DEFAULT_MAX_TOKENS: u32 = 4096;
/// Fallback `HTTP-Referer` attribution header sent to OpenRouter.
const OPENROUTER_DEFAULT_HTTP_REFERER: &str = "https://github.com/Dicklesworthstone/pi_agent_rust";
/// Fallback `X-Title` attribution header sent to OpenRouter.
const OPENROUTER_DEFAULT_X_TITLE: &str = "Pi Agent Rust";
35
/// Map a role string to a `Cow`, borrowing the static spelling for the
/// well-known OpenAI roles and allocating only for unrecognized ones.
fn to_cow_role(role: &str) -> Cow<'_, str> {
    const KNOWN_ROLES: [&str; 6] = ["system", "developer", "user", "assistant", "tool", "function"];
    match KNOWN_ROLES.iter().find(|&&known| known == role) {
        Some(&known) => Cow::Borrowed(known),
        None => Cow::Owned(role.to_string()),
    }
}
54
/// True when any key in `headers` matches one of the candidate `names`,
/// compared ASCII case-insensitively (HTTP header names are case-blind).
fn map_has_any_header(headers: &std::collections::HashMap<String, String>, names: &[&str]) -> bool {
    for key in headers.keys() {
        for name in names {
            if key.eq_ignore_ascii_case(name) {
                return true;
            }
        }
    }
    false
}
60
/// Return the first environment variable among `keys` whose value is
/// non-empty after trimming, or `None` when no key qualifies.
fn first_non_empty_env(keys: &[&str]) -> Option<String> {
    for key in keys {
        if let Ok(raw) = std::env::var(key) {
            let trimmed = raw.trim();
            if !trimmed.is_empty() {
                return Some(trimmed.to_string());
            }
        }
    }
    None
}
69
70fn openrouter_default_http_referer() -> String {
71 first_non_empty_env(&["OPENROUTER_HTTP_REFERER", "PI_OPENROUTER_HTTP_REFERER"])
72 .unwrap_or_else(|| OPENROUTER_DEFAULT_HTTP_REFERER.to_string())
73}
74
75fn openrouter_default_x_title() -> String {
76 first_non_empty_env(&["OPENROUTER_X_TITLE", "PI_OPENROUTER_X_TITLE"])
77 .unwrap_or_else(|| OPENROUTER_DEFAULT_X_TITLE.to_string())
78}
79
/// Provider for OpenAI-compatible chat-completions endpoints (OpenAI
/// itself, OpenRouter, and other compatible backends via `CompatConfig`).
pub struct OpenAIProvider {
    // HTTP client used to issue the streaming POST request.
    client: Client,
    // Model id sent as the request's `model` field.
    model: String,
    // Chat-completions endpoint; defaults to the public OpenAI URL.
    base_url: String,
    // Name reported by `Provider::name`; the value "openrouter" also
    // enables OpenRouter-specific headers and routing overrides.
    provider: String,
    // Optional compatibility tweaks for non-OpenAI backends.
    compat: Option<CompatConfig>,
}
92
impl OpenAIProvider {
    /// Create a provider for `model` against the public OpenAI endpoint
    /// with a fresh HTTP client and no compatibility overrides.
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            client: Client::new(),
            model: model.into(),
            base_url: OPENAI_API_URL.to_string(),
            provider: "openai".to_string(),
            compat: None,
        }
    }

    /// Override the reported provider name. "openrouter" (any case)
    /// additionally enables OpenRouter default headers and routing merges.
    #[must_use]
    pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
        self.provider = provider.into();
        self
    }

    /// Point the provider at a different chat-completions endpoint.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    /// Replace the HTTP client (e.g. to share a configured client).
    #[must_use]
    pub fn with_client(mut self, client: Client) -> Self {
        self.client = client;
        self
    }

    /// Attach (or clear) compatibility settings for non-OpenAI backends.
    #[must_use]
    pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
        self.compat = compat;
        self
    }

    /// Assemble the wire request from the conversation context and
    /// per-call options, honoring `CompatConfig` overrides.
    pub fn build_request<'a>(
        &'a self,
        context: &'a Context<'_>,
        options: &StreamOptions,
    ) -> OpenAIRequest<'a> {
        // Some backends want a different role name for the system prompt
        // (e.g. "developer").
        let system_role = self
            .compat
            .as_ref()
            .and_then(|c| c.system_role_name.as_deref())
            .unwrap_or("system");
        let messages = Self::build_messages_with_role(context, system_role);

        // Omit the `tools` array entirely when the backend lacks support.
        let tools_supported = self
            .compat
            .as_ref()
            .and_then(|c| c.supports_tools)
            .unwrap_or(true);

        let tools: Option<Vec<OpenAITool<'a>>> = if context.tools.is_empty() || !tools_supported {
            None
        } else {
            Some(context.tools.iter().map(convert_tool_to_openai).collect())
        };

        // Exactly one of `max_tokens` / `max_completion_tokens` is sent,
        // chosen by the compat config's `max_tokens_field`.
        let use_alt_field = self
            .compat
            .as_ref()
            .and_then(|c| c.max_tokens_field.as_deref())
            .is_some_and(|f| f == "max_completion_tokens");

        let token_limit = options.max_tokens.or(Some(DEFAULT_MAX_TOKENS));
        let (max_tokens, max_completion_tokens) = if use_alt_field {
            (None, token_limit)
        } else {
            (token_limit, None)
        };

        // Ask for usage totals in the final streamed chunk unless the
        // backend is known not to support that.
        let include_usage = self
            .compat
            .as_ref()
            .and_then(|c| c.supports_usage_in_streaming)
            .unwrap_or(true);

        let stream_options = Some(OpenAIStreamOptions { include_usage });

        OpenAIRequest {
            model: &self.model,
            messages,
            max_tokens,
            max_completion_tokens,
            temperature: options.temperature,
            tools,
            stream: true,
            stream_options,
        }
    }

    /// Serialize the request to JSON and merge in any OpenRouter routing
    /// overrides.
    ///
    /// # Errors
    /// Fails when serialization fails or the routing config is malformed.
    fn build_request_json(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<serde_json::Value> {
        let request = self.build_request(context, options);
        let mut value = serde_json::to_value(request)
            .map_err(|e| Error::api(format!("Failed to serialize OpenAI request: {e}")))?;
        self.apply_openrouter_routing_overrides(&mut value)?;
        Ok(value)
    }

    /// Merge `CompatConfig::open_router_routing` keys into the serialized
    /// request body. No-op unless the provider name is "openrouter" and a
    /// routing object is configured; routing keys overwrite request keys.
    fn apply_openrouter_routing_overrides(&self, request: &mut serde_json::Value) -> Result<()> {
        if !self.provider.eq_ignore_ascii_case("openrouter") {
            return Ok(());
        }

        let Some(routing) = self
            .compat
            .as_ref()
            .and_then(|compat| compat.open_router_routing.as_ref())
        else {
            return Ok(());
        };

        let Some(request_obj) = request.as_object_mut() else {
            return Err(Error::api(
                "OpenAI request body must serialize to a JSON object",
            ));
        };
        let Some(routing_obj) = routing.as_object() else {
            return Err(Error::config(
                "openRouterRouting must be a JSON object when configured",
            ));
        };

        for (key, value) in routing_obj {
            request_obj.insert(key.clone(), value.clone());
        }
        Ok(())
    }

    /// Convert context messages to wire messages, prefixing the system
    /// prompt (under `system_role`) when one is present.
    fn build_messages_with_role<'a>(
        context: &'a Context<'_>,
        system_role: &'a str,
    ) -> Vec<OpenAIMessage<'a>> {
        // +1 leaves room for the optional system message.
        let mut messages = Vec::with_capacity(context.messages.len() + 1);

        if let Some(system) = &context.system_prompt {
            messages.push(OpenAIMessage {
                role: to_cow_role(system_role),
                content: Some(OpenAIContent::Text(Cow::Borrowed(system))),
                tool_calls: None,
                tool_call_id: None,
            });
        }

        for message in context.messages.iter() {
            // One internal message may expand to several wire messages.
            messages.extend(convert_message_to_openai(message));
        }

        messages
    }
}
265
#[async_trait]
impl Provider for OpenAIProvider {
    /// Reported provider name ("openai" unless overridden).
    fn name(&self) -> &str {
        &self.provider
    }

    /// Identifier of the wire protocol this provider speaks.
    fn api(&self) -> &'static str {
        "openai-completions"
    }

    /// The model id requests are issued against.
    fn model_id(&self) -> &str {
        &self.model
    }

    /// Open a streaming chat completion and adapt the SSE response into
    /// a stream of `StreamEvent`s.
    ///
    /// # Errors
    /// Fails when no API key is available, the response status is not
    /// 2xx, or the response Content-Type is not `text/event-stream`.
    #[allow(clippy::too_many_lines)]
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        // A caller-supplied Authorization header takes precedence over
        // any configured API key.
        let has_authorization_header = options
            .headers
            .keys()
            .any(|key| key.eq_ignore_ascii_case("authorization"));

        let auth_value = if has_authorization_header {
            None
        } else {
            // Fall back to the per-call key, then the environment.
            Some(
                options
                    .api_key
                    .clone()
                    .or_else(|| std::env::var("OPENAI_API_KEY").ok())
                    .ok_or_else(|| {
                        Error::provider(
                            &self.provider,
                            "Missing API key for OpenAI. Set OPENAI_API_KEY or configure in settings.",
                        )
                    })?,
            )
        };

        let request_body = self.build_request_json(context, options)?;

        let mut request = self
            .client
            .post(&self.base_url)
            .header("Accept", "text/event-stream");

        if let Some(auth_value) = auth_value {
            request = request.header("Authorization", format!("Bearer {auth_value}"));
        }

        // OpenRouter attribution headers — added only when neither the
        // per-call headers nor the compat config already supplies them.
        if self.provider.eq_ignore_ascii_case("openrouter") {
            let compat_headers = self
                .compat
                .as_ref()
                .and_then(|compat| compat.custom_headers.as_ref());
            let has_referer = map_has_any_header(&options.headers, &["http-referer", "referer"])
                || compat_headers.is_some_and(|headers| {
                    map_has_any_header(headers, &["http-referer", "referer"])
                });
            if !has_referer {
                request = request.header("HTTP-Referer", openrouter_default_http_referer());
            }

            let has_title = map_has_any_header(&options.headers, &["x-title"])
                || compat_headers.is_some_and(|headers| map_has_any_header(headers, &["x-title"]));
            if !has_title {
                request = request.header("X-Title", openrouter_default_x_title());
            }
        }

        // Apply compat-configured headers, then per-call headers.
        if let Some(compat) = &self.compat {
            if let Some(custom_headers) = &compat.custom_headers {
                for (key, value) in custom_headers {
                    request = request.header(key, value);
                }
            }
        }

        for (key, value) in &options.headers {
            request = request.header(key, value);
        }

        let request = request.json(&request_body)?;

        let response = Box::pin(request.send()).await?;
        let status = response.status();
        if !(200..300).contains(&status) {
            // Include the (best-effort) response body in the error.
            let body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
            return Err(Error::provider(
                &self.provider,
                format!("OpenAI API error (HTTP {status}): {body}"),
            ));
        }

        // Reject non-SSE responses early: without this check a JSON body
        // would just yield an empty event stream.
        let content_type = response
            .headers()
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
            .map(|(_, value)| value.to_ascii_lowercase());
        if !content_type
            .as_deref()
            .is_some_and(|value| value.contains("text/event-stream"))
        {
            let message = content_type.map_or_else(
                || {
                    format!(
                        "OpenAI API protocol error (HTTP {status}): missing Content-Type (expected text/event-stream)"
                    )
                },
                |value| {
                    format!(
                        "OpenAI API protocol error (HTTP {status}): unexpected Content-Type {value} (expected text/event-stream)"
                    )
                },
            );
            return Err(Error::api(message));
        }

        let event_source = SseStream::new(response.bytes_stream());

        let model = self.model.clone();
        let api = self.api().to_string();
        let provider = self.name().to_string();

        // Drive the SSE source with `unfold`: each poll first drains any
        // events queued by a previous chunk, then pulls the next message.
        let stream = stream::unfold(
            StreamState::new(event_source, model, api, provider),
            |mut state| async move {
                if state.done {
                    return None;
                }
                loop {
                    if let Some(event) = state.pending_events.pop_front() {
                        return Some((Ok(event), state));
                    }

                    match state.event_source.next().await {
                        Some(Ok(msg)) => {
                            // "[DONE]" is the explicit end-of-stream marker.
                            if msg.data == "[DONE]" {
                                state.done = true;
                                let reason = state.partial.stop_reason;
                                let message = std::mem::take(&mut state.partial);
                                return Some((Ok(StreamEvent::Done { reason, message }), state));
                            }

                            if let Err(e) = state.process_event(&msg.data) {
                                state.done = true;
                                return Some((Err(e), state));
                            }
                        }
                        Some(Err(e)) => {
                            state.done = true;
                            let err = Error::api(format!("SSE error: {e}"));
                            return Some((Err(err), state));
                        }
                        None => {
                            // EOF without "[DONE]": finish with whatever
                            // has been accumulated so far.
                            state.done = true;
                            let reason = state.partial.stop_reason;
                            let message = std::mem::take(&mut state.partial);
                            return Some((Ok(StreamEvent::Done { reason, message }), state));
                        }
                    }
                }
            },
        );

        Ok(Box::pin(stream))
    }
}
451
/// Mutable state threaded through the `stream::unfold` loop: the SSE
/// source, the assistant message assembled so far, and events that have
/// been parsed but not yet yielded to the consumer.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    // Underlying SSE parser over the HTTP byte stream.
    event_source: SseStream<S>,
    // Assistant message accumulated from deltas; moved out in `Done`.
    partial: AssistantMessage,
    // Per-tool-call accumulators (argument JSON arrives in fragments).
    tool_calls: Vec<ToolCallState>,
    // Events produced by one chunk, drained one per stream poll.
    pending_events: VecDeque<StreamEvent>,
    // Whether the one-time `Start` event has been emitted.
    started: bool,
    // Set once the stream terminated (done marker, EOF, or error).
    done: bool,
}
467
/// Accumulator for one streamed tool call, correlating OpenAI's
/// `tool_calls[i].index` with a slot in the partial message's content.
struct ToolCallState {
    // Provider-assigned tool-call index within the choice delta; may be
    // sparse, so it is matched by value rather than by position.
    index: usize,
    // Position of the matching `ContentBlock::ToolCall` in `partial.content`.
    content_index: usize,
    // Tool-call id (may arrive in a later fragment than the first).
    id: String,
    // Function name, filled in as fragments arrive.
    name: String,
    // Raw JSON argument text, concatenated and parsed only at finish.
    arguments: String,
}
475
476impl<S> StreamState<S>
477where
478 S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
479{
480 fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
481 Self {
482 event_source,
483 partial: AssistantMessage {
484 content: Vec::new(),
485 api,
486 provider,
487 model,
488 usage: Usage::default(),
489 stop_reason: StopReason::Stop,
490 error_message: None,
491 timestamp: chrono::Utc::now().timestamp_millis(),
492 },
493 tool_calls: Vec::new(),
494 pending_events: VecDeque::new(),
495 started: false,
496 done: false,
497 }
498 }
499
500 fn ensure_started(&mut self) {
501 if !self.started {
502 self.started = true;
503 self.pending_events.push_back(StreamEvent::Start {
504 partial: self.partial.clone(),
505 });
506 }
507 }
508
509 fn process_event(&mut self, data: &str) -> Result<()> {
510 let chunk: OpenAIStreamChunk =
511 serde_json::from_str(data).map_err(|e| Error::api(format!("JSON parse error: {e}")))?;
512
513 if let Some(usage) = chunk.usage {
515 self.partial.usage.input = usage.prompt_tokens;
516 self.partial.usage.output = usage.completion_tokens.unwrap_or(0);
517 self.partial.usage.total_tokens = usage.total_tokens;
518 }
519
520 if let Some(error) = chunk.error {
521 self.partial.stop_reason = StopReason::Error;
522 if let Some(message) = error.message {
523 let message = message.trim();
524 if !message.is_empty() {
525 self.partial.error_message = Some(message.to_string());
526 }
527 }
528 }
529
530 if let Some(choice) = chunk.choices.into_iter().next() {
532 if !self.started
533 && choice.finish_reason.is_none()
534 && choice.delta.content.is_none()
535 && choice.delta.tool_calls.is_none()
536 {
537 self.ensure_started();
538 return Ok(());
539 }
540
541 self.process_choice(choice);
542 }
543
544 Ok(())
545 }
546
547 fn finalize_tool_call_arguments(&mut self) {
548 for tc in &self.tool_calls {
549 let arguments: serde_json::Value = match serde_json::from_str(&tc.arguments) {
550 Ok(args) => args,
551 Err(e) => {
552 tracing::warn!(
553 error = %e,
554 raw = %tc.arguments,
555 "Failed to parse tool arguments as JSON"
556 );
557 serde_json::Value::Null
558 }
559 };
560
561 if let Some(ContentBlock::ToolCall(block)) =
562 self.partial.content.get_mut(tc.content_index)
563 {
564 block.arguments = arguments;
565 }
566 }
567 }
568
569 #[allow(clippy::too_many_lines)]
570 fn process_choice(&mut self, choice: OpenAIChoice) {
571 let delta = choice.delta;
572 if delta.content.is_some()
573 || delta.tool_calls.is_some()
574 || delta.reasoning_content.is_some()
575 {
576 self.ensure_started();
577 }
578
579 if choice.finish_reason.is_some() {
582 self.ensure_started();
583 }
584
585 if let Some(reasoning) = delta.reasoning_content {
587 let last_is_thinking =
589 matches!(self.partial.content.last(), Some(ContentBlock::Thinking(_)));
590
591 let content_index = if last_is_thinking {
592 self.partial.content.len() - 1
593 } else {
594 let idx = self.partial.content.len();
595 self.partial
596 .content
597 .push(ContentBlock::Thinking(ThinkingContent {
598 thinking: String::new(),
599 thinking_signature: None,
600 }));
601
602 self.pending_events
603 .push_back(StreamEvent::ThinkingStart { content_index: idx });
604
605 idx
606 };
607
608 if let Some(ContentBlock::Thinking(t)) = self.partial.content.get_mut(content_index) {
609 t.thinking.push_str(&reasoning);
610 }
611
612 self.pending_events.push_back(StreamEvent::ThinkingDelta {
613 content_index,
614 delta: reasoning,
615 });
616 }
617
618 if let Some(content) = delta.content {
621 let last_is_text = matches!(self.partial.content.last(), Some(ContentBlock::Text(_)));
624
625 let content_index = if last_is_text {
626 self.partial.content.len() - 1
627 } else {
628 let idx = self.partial.content.len();
629
630 self.partial
631 .content
632 .push(ContentBlock::Text(TextContent::new("")));
633
634 self.pending_events
635 .push_back(StreamEvent::TextStart { content_index: idx });
636
637 idx
638 };
639
640 if let Some(ContentBlock::Text(t)) = self.partial.content.get_mut(content_index) {
641 t.text.push_str(&content);
642 }
643
644 self.pending_events.push_back(StreamEvent::TextDelta {
645 content_index,
646
647 delta: content,
648 });
649 }
650
651 if let Some(tool_calls) = delta.tool_calls {
654 for tc_delta in tool_calls {
655 let index = tc_delta.index as usize;
656
657 let tool_state_idx = if let Some(existing_idx) =
662 self.tool_calls.iter().position(|tc| tc.index == index)
663 {
664 existing_idx
665 } else {
666 let content_index = self.partial.content.len();
667
668 self.tool_calls.push(ToolCallState {
669 index,
670
671 content_index,
672
673 id: String::new(),
674
675 name: String::new(),
676
677 arguments: String::new(),
678 });
679
680 self.partial.content.push(ContentBlock::ToolCall(ToolCall {
683 id: String::new(),
684
685 name: String::new(),
686
687 arguments: serde_json::Value::Null,
688
689 thought_signature: None,
690 }));
691
692 self.pending_events
693 .push_back(StreamEvent::ToolCallStart { content_index });
694
695 self.tool_calls.len() - 1
696 };
697
698 let tc = &mut self.tool_calls[tool_state_idx];
699
700 let content_index = tc.content_index;
701
702 if let Some(id) = tc_delta.id {
705 tc.id = id;
706
707 if let Some(ContentBlock::ToolCall(block)) =
708 self.partial.content.get_mut(content_index)
709 {
710 block.id.clone_from(&tc.id);
711 }
712 }
713
714 if let Some(function) = tc_delta.function {
717 if let Some(name) = function.name {
718 tc.name = name;
719
720 if let Some(ContentBlock::ToolCall(block)) =
721 self.partial.content.get_mut(content_index)
722 {
723 block.name.clone_from(&tc.name);
724 }
725 }
726
727 if let Some(args) = function.arguments {
728 tc.arguments.push_str(&args);
729
730 self.pending_events.push_back(StreamEvent::ToolCallDelta {
739 content_index,
740
741 delta: args,
742 });
743 }
744 }
745 }
746 }
747
748 if let Some(reason) = choice.finish_reason {
751 self.partial.stop_reason = match reason.as_str() {
752 "length" => StopReason::Length,
753
754 "tool_calls" => StopReason::ToolUse,
755
756 "content_filter" | "error" => StopReason::Error,
757
758 _ => StopReason::Stop,
759 };
760
761 for (content_index, block) in self.partial.content.iter().enumerate() {
765 if let ContentBlock::Text(t) = block {
766 self.pending_events.push_back(StreamEvent::TextEnd {
767 content_index,
768 content: t.text.clone(),
769 });
770 } else if let ContentBlock::Thinking(t) = block {
771 self.pending_events.push_back(StreamEvent::ThinkingEnd {
772 content_index,
773 content: t.thinking.clone(),
774 });
775 }
776 }
777
778 self.finalize_tool_call_arguments();
781
782 for tc in &self.tool_calls {
785 if let Some(ContentBlock::ToolCall(tool_call)) =
786 self.partial.content.get(tc.content_index)
787 {
788 self.pending_events.push_back(StreamEvent::ToolCallEnd {
789 content_index: tc.content_index,
790
791 tool_call: tool_call.clone(),
792 });
793 }
794 }
795 }
796 }
797}
798
/// Serialized Chat Completions request body. Lifetimes borrow from the
/// provider and `Context` so building a request avoids copying text.
#[derive(Debug, Serialize)]
pub struct OpenAIRequest<'a> {
    model: &'a str,
    messages: Vec<OpenAIMessage<'a>>,
    // Exactly one of `max_tokens` / `max_completion_tokens` is set,
    // chosen by `CompatConfig::max_tokens_field` in `build_request`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAITool<'a>>>,
    // Always true: this provider only issues streaming requests.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<OpenAIStreamOptions>,
}

/// `stream_options` object controlling whether token-usage totals are
/// reported in the final streamed chunk.
#[derive(Debug, Serialize)]
struct OpenAIStreamOptions {
    include_usage: bool,
}

/// One chat message on the wire. `content` is omitted when absent
/// (e.g. an assistant turn consisting only of tool calls).
#[derive(Debug, Serialize)]
struct OpenAIMessage<'a> {
    role: Cow<'a, str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAIContent<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAIToolCallRef<'a>>>,
    // Set only on role "tool" messages, linking a result to its call.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<&'a str>,
}

/// Message content: a bare string or a list of multimodal parts.
/// `untagged` serializes whichever variant is present directly.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum OpenAIContent<'a> {
    Text(Cow<'a, str>),
    Parts(Vec<OpenAIContentPart<'a>>),
}

/// A single multimodal content part, tagged "text" or "image_url".
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OpenAIContentPart<'a> {
    Text { text: Cow<'a, str> },
    ImageUrl { image_url: OpenAIImageUrl<'a> },
}

/// Image reference carrying a base64 data URL. The phantom preserves the
/// lifetime parameter shared with the other borrowing wire types.
#[derive(Debug, Serialize)]
struct OpenAIImageUrl<'a> {
    url: String,
    #[serde(skip)]
    _phantom: std::marker::PhantomData<&'a ()>,
}

/// Assistant-issued tool call echoed back in conversation history.
#[derive(Debug, Serialize)]
struct OpenAIToolCallRef<'a> {
    id: &'a str,
    // Always "function" per the Chat Completions schema.
    r#type: &'static str,
    function: OpenAIFunctionRef<'a>,
}

/// Function name plus arguments re-serialized as a JSON string.
#[derive(Debug, Serialize)]
struct OpenAIFunctionRef<'a> {
    name: &'a str,
    arguments: String,
}

/// Tool definition entry in the request's `tools` array.
#[derive(Debug, Serialize)]
struct OpenAITool<'a> {
    r#type: &'static str,
    function: OpenAIFunction<'a>,
}

/// Function schema advertised to the model.
#[derive(Debug, Serialize)]
struct OpenAIFunction<'a> {
    name: &'a str,
    description: &'a str,
    parameters: &'a serde_json::Value,
}
884
/// One parsed streaming chunk. Every field defaults so that vendor
/// variations (usage-only chunks, in-band error payloads) still parse.
#[derive(Debug, Deserialize)]
struct OpenAIStreamChunk {
    #[serde(default)]
    choices: Vec<OpenAIChoice>,
    #[serde(default)]
    usage: Option<OpenAIUsage>,
    // Some OpenAI-compatible gateways embed errors inside chunks.
    #[serde(default)]
    error: Option<OpenAIChunkError>,
}

/// A single choice; only the first choice of each chunk is consumed.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    delta: OpenAIDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental content carried by one choice.
#[derive(Debug, Deserialize)]
struct OpenAIDelta {
    #[serde(default)]
    content: Option<String>,
    // Reasoning/"thinking" text emitted by some OpenAI-compatible
    // backends; absent from stock OpenAI responses.
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAIToolCallDelta>>,
}

/// Fragment of a streamed tool call, correlated by `index`.
#[derive(Debug, Deserialize)]
struct OpenAIToolCallDelta {
    index: u32,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAIFunctionDelta>,
}

/// Partial function name/arguments inside a tool-call fragment.
#[derive(Debug, Deserialize)]
struct OpenAIFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

/// Token accounting, reported in the final chunk when usage reporting
/// is enabled via `stream_options`.
#[derive(Debug, Deserialize)]
#[allow(clippy::struct_field_names)]
struct OpenAIUsage {
    prompt_tokens: u64,
    #[serde(default)]
    completion_tokens: Option<u64>,
    total_tokens: u64,
}

/// In-band error payload carried by some chunks.
#[derive(Debug, Deserialize)]
struct OpenAIChunkError {
    #[serde(default)]
    message: Option<String>,
}
947
948fn convert_message_to_openai(message: &Message) -> Vec<OpenAIMessage<'_>> {
953 match message {
954 Message::User(user) => vec![OpenAIMessage {
955 role: Cow::Borrowed("user"),
956 content: Some(convert_user_content(&user.content)),
957 tool_calls: None,
958 tool_call_id: None,
959 }],
960 Message::Custom(custom) => vec![OpenAIMessage {
961 role: Cow::Borrowed("user"),
962 content: Some(OpenAIContent::Text(Cow::Borrowed(&custom.content))),
963 tool_calls: None,
964 tool_call_id: None,
965 }],
966 Message::Assistant(assistant) => {
967 let mut messages = Vec::new();
968
969 let text: String = assistant
971 .content
972 .iter()
973 .filter_map(|b| match b {
974 ContentBlock::Text(t) => Some(t.text.as_str()),
975 _ => None,
976 })
977 .collect::<String>();
978
979 let tool_calls: Vec<OpenAIToolCallRef<'_>> = assistant
981 .content
982 .iter()
983 .filter_map(|b| match b {
984 ContentBlock::ToolCall(tc) => Some(OpenAIToolCallRef {
985 id: &tc.id,
986 r#type: "function",
987 function: OpenAIFunctionRef {
988 name: &tc.name,
989 arguments: tc.arguments.to_string(),
990 },
991 }),
992 _ => None,
993 })
994 .collect();
995
996 let content = if text.is_empty() {
997 None
998 } else {
999 Some(OpenAIContent::Text(Cow::Owned(text)))
1000 };
1001
1002 let tool_calls = if tool_calls.is_empty() {
1003 None
1004 } else {
1005 Some(tool_calls)
1006 };
1007
1008 messages.push(OpenAIMessage {
1009 role: Cow::Borrowed("assistant"),
1010 content,
1011 tool_calls,
1012 tool_call_id: None,
1013 });
1014
1015 messages
1016 }
1017 Message::ToolResult(result) => {
1018 let parts: Vec<OpenAIContentPart<'_>> = result
1020 .content
1021 .iter()
1022 .filter_map(|block| match block {
1023 ContentBlock::Text(t) => Some(OpenAIContentPart::Text {
1024 text: Cow::Borrowed(&t.text),
1025 }),
1026 ContentBlock::Image(img) => {
1027 let url = format!("data:{};base64,{}", img.mime_type, img.data);
1028 Some(OpenAIContentPart::ImageUrl {
1029 image_url: OpenAIImageUrl {
1030 url,
1031 _phantom: std::marker::PhantomData,
1032 },
1033 })
1034 }
1035 _ => None,
1036 })
1037 .collect();
1038
1039 let content = if parts.is_empty() {
1040 None
1041 } else if parts.len() == 1 && matches!(parts[0], OpenAIContentPart::Text { .. }) {
1042 if let OpenAIContentPart::Text { text } = &parts[0] {
1044 Some(OpenAIContent::Text(text.clone()))
1045 } else {
1046 Some(OpenAIContent::Parts(parts))
1047 }
1048 } else {
1049 Some(OpenAIContent::Parts(parts))
1050 };
1051
1052 vec![OpenAIMessage {
1053 role: Cow::Borrowed("tool"),
1054 content,
1055 tool_calls: None,
1056 tool_call_id: Some(&result.tool_call_id),
1057 }]
1058 }
1059 }
1060}
1061
1062fn convert_user_content(content: &UserContent) -> OpenAIContent<'_> {
1063 match content {
1064 UserContent::Text(text) => OpenAIContent::Text(Cow::Borrowed(text)),
1065 UserContent::Blocks(blocks) => {
1066 let parts: Vec<OpenAIContentPart<'_>> = blocks
1067 .iter()
1068 .filter_map(|block| match block {
1069 ContentBlock::Text(t) => Some(OpenAIContentPart::Text {
1070 text: Cow::Borrowed(&t.text),
1071 }),
1072 ContentBlock::Image(img) => {
1073 let url = format!("data:{};base64,{}", img.mime_type, img.data);
1075 Some(OpenAIContentPart::ImageUrl {
1076 image_url: OpenAIImageUrl {
1077 url,
1078 _phantom: std::marker::PhantomData,
1079 },
1080 })
1081 }
1082 _ => None,
1083 })
1084 .collect();
1085 OpenAIContent::Parts(parts)
1086 }
1087 }
1088}
1089
1090fn convert_tool_to_openai(tool: &ToolDef) -> OpenAITool<'_> {
1091 OpenAITool {
1092 r#type: "function",
1093 function: OpenAIFunction {
1094 name: &tool.name,
1095 description: &tool.description,
1096 parameters: &tool.parameters,
1097 },
1098 }
1099}
1100
1101#[cfg(test)]
1106mod tests {
1107 use super::*;
1108 use asupersync::runtime::RuntimeBuilder;
1109 use futures::{StreamExt, stream};
1110 use serde::{Deserialize, Serialize};
1111 use serde_json::{Value, json};
1112 use std::collections::HashMap;
1113 use std::io::{Read, Write};
1114 use std::net::TcpListener;
1115 use std::path::PathBuf;
1116 use std::sync::mpsc;
1117 use std::time::Duration;
1118
1119 #[test]
1120 fn test_convert_user_text_message() {
1121 let message = Message::User(crate::model::UserMessage {
1122 content: UserContent::Text("Hello".to_string()),
1123 timestamp: 0,
1124 });
1125
1126 let converted = convert_message_to_openai(&message);
1127 assert_eq!(converted.len(), 1);
1128 assert_eq!(converted[0].role, "user");
1129 }
1130
1131 #[test]
1132 fn test_tool_conversion() {
1133 let tool = ToolDef {
1134 name: "test_tool".to_string(),
1135 description: "A test tool".to_string(),
1136 parameters: serde_json::json!({
1137 "type": "object",
1138 "properties": {
1139 "arg": {"type": "string"}
1140 }
1141 }),
1142 };
1143
1144 let converted = convert_tool_to_openai(&tool);
1145 assert_eq!(converted.r#type, "function");
1146 assert_eq!(converted.function.name, "test_tool");
1147 assert_eq!(converted.function.description, "A test tool");
1148 assert_eq!(
1149 converted.function.parameters,
1150 &serde_json::json!({
1151 "type": "object",
1152 "properties": {
1153 "arg": {"type": "string"}
1154 }
1155 })
1156 );
1157 }
1158
1159 #[test]
1160 fn test_provider_info() {
1161 let provider = OpenAIProvider::new("gpt-4o");
1162 assert_eq!(provider.name(), "openai");
1163 assert_eq!(provider.api(), "openai-completions");
1164 }
1165
1166 #[test]
1167 fn test_build_request_includes_system_tools_and_stream_options() {
1168 let provider = OpenAIProvider::new("gpt-4o");
1169 let context = Context {
1170 system_prompt: Some("You are concise.".to_string().into()),
1171 messages: vec![Message::User(crate::model::UserMessage {
1172 content: UserContent::Text("Ping".to_string()),
1173 timestamp: 0,
1174 })]
1175 .into(),
1176 tools: vec![ToolDef {
1177 name: "search".to_string(),
1178 description: "Search docs".to_string(),
1179 parameters: json!({
1180 "type": "object",
1181 "properties": {
1182 "q": { "type": "string" }
1183 },
1184 "required": ["q"]
1185 }),
1186 }]
1187 .into(),
1188 };
1189 let options = StreamOptions {
1190 temperature: Some(0.2),
1191 max_tokens: Some(123),
1192 ..Default::default()
1193 };
1194
1195 let request = provider.build_request(&context, &options);
1196 let value = serde_json::to_value(&request).expect("serialize request");
1197 assert_eq!(value["model"], "gpt-4o");
1198 assert_eq!(value["messages"][0]["role"], "system");
1199 assert_eq!(value["messages"][0]["content"], "You are concise.");
1200 assert_eq!(value["messages"][1]["role"], "user");
1201 assert_eq!(value["messages"][1]["content"], "Ping");
1202 let temperature = value["temperature"]
1203 .as_f64()
1204 .expect("temperature should serialize as number");
1205 assert!((temperature - 0.2).abs() < 1e-6);
1206 assert_eq!(value["max_tokens"], 123);
1207 assert_eq!(value["stream"], true);
1208 assert_eq!(value["stream_options"]["include_usage"], true);
1209 assert_eq!(value["tools"][0]["type"], "function");
1210 assert_eq!(value["tools"][0]["function"]["name"], "search");
1211 assert_eq!(value["tools"][0]["function"]["description"], "Search docs");
1212 assert_eq!(
1213 value["tools"][0]["function"]["parameters"],
1214 json!({
1215 "type": "object",
1216 "properties": {
1217 "q": { "type": "string" }
1218 },
1219 "required": ["q"]
1220 })
1221 );
1222 }
1223
1224 #[test]
1225 fn test_stream_accumulates_tool_call_argument_deltas() {
1226 let events = vec![
1227 json!({ "choices": [{ "delta": {} }] }),
1228 json!({
1229 "choices": [{
1230 "delta": {
1231 "tool_calls": [{
1232 "index": 0,
1233 "id": "call_1",
1234 "function": {
1235 "name": "search",
1236 "arguments": "{\"q\":\"ru"
1237 }
1238 }]
1239 }
1240 }]
1241 }),
1242 json!({
1243 "choices": [{
1244 "delta": {
1245 "tool_calls": [{
1246 "index": 0,
1247 "function": {
1248 "arguments": "st\"}"
1249 }
1250 }]
1251 }
1252 }]
1253 }),
1254 json!({ "choices": [{ "delta": {}, "finish_reason": "tool_calls" }] }),
1255 Value::String("[DONE]".to_string()),
1256 ];
1257
1258 let out = collect_events(&events);
1259 assert!(
1260 out.iter()
1261 .any(|e| matches!(e, StreamEvent::ToolCallStart { .. }))
1262 );
1263 assert!(out.iter().any(
1264 |e| matches!(e, StreamEvent::ToolCallDelta { delta, .. } if delta == "{\"q\":\"ru")
1265 ));
1266 assert!(
1267 out.iter()
1268 .any(|e| matches!(e, StreamEvent::ToolCallDelta { delta, .. } if delta == "st\"}"))
1269 );
1270 let done = out
1271 .iter()
1272 .find_map(|event| match event {
1273 StreamEvent::Done { message, .. } => Some(message),
1274 _ => None,
1275 })
1276 .expect("done event");
1277 let tool_call = done
1278 .content
1279 .iter()
1280 .find_map(|block| match block {
1281 ContentBlock::ToolCall(tc) => Some(tc),
1282 _ => None,
1283 })
1284 .expect("assembled tool call content");
1285 assert_eq!(tool_call.id, "call_1");
1286 assert_eq!(tool_call.name, "search");
1287 assert_eq!(tool_call.arguments, json!({ "q": "rust" }));
1288 assert!(out.iter().any(|e| matches!(
1289 e,
1290 StreamEvent::Done {
1291 reason: StopReason::ToolUse,
1292 ..
1293 }
1294 )));
1295 }
1296
1297 #[test]
1298 fn test_stream_handles_sparse_tool_call_index_without_panic() {
1299 let events = vec![
1300 json!({ "choices": [{ "delta": {} }] }),
1301 json!({
1302 "choices": [{
1303 "delta": {
1304 "tool_calls": [{
1305 "index": 2,
1306 "id": "call_sparse",
1307 "function": {
1308 "name": "lookup",
1309 "arguments": "{\"q\":\"sparse\"}"
1310 }
1311 }]
1312 }
1313 }]
1314 }),
1315 json!({ "choices": [{ "delta": {}, "finish_reason": "tool_calls" }] }),
1316 Value::String("[DONE]".to_string()),
1317 ];
1318
1319 let out = collect_events(&events);
1320 let done = out
1321 .iter()
1322 .find_map(|event| match event {
1323 StreamEvent::Done { message, .. } => Some(message),
1324 _ => None,
1325 })
1326 .expect("done event");
1327 let tool_calls: Vec<&ToolCall> = done
1328 .content
1329 .iter()
1330 .filter_map(|block| match block {
1331 ContentBlock::ToolCall(tc) => Some(tc),
1332 _ => None,
1333 })
1334 .collect();
1335 assert_eq!(tool_calls.len(), 1);
1336 assert_eq!(tool_calls[0].id, "call_sparse");
1337 assert_eq!(tool_calls[0].name, "lookup");
1338 assert_eq!(tool_calls[0].arguments, json!({ "q": "sparse" }));
1339 assert!(
1340 out.iter()
1341 .any(|event| matches!(event, StreamEvent::ToolCallStart { .. })),
1342 "expected tool call start event"
1343 );
1344 }
1345
1346 #[test]
1347 fn test_stream_maps_finish_reason_error_to_stop_reason_error() {
1348 let events = vec![
1349 json!({
1350 "choices": [{ "delta": {}, "finish_reason": "error" }],
1351 "error": { "message": "upstream provider timeout" }
1352 }),
1353 Value::String("[DONE]".to_string()),
1354 ];
1355
1356 let out = collect_events(&events);
1357 let done = out
1358 .iter()
1359 .find_map(|event| match event {
1360 StreamEvent::Done { reason, message } => Some((reason, message)),
1361 _ => None,
1362 })
1363 .expect("done event");
1364 assert_eq!(*done.0, StopReason::Error);
1365 assert_eq!(
1366 done.1.error_message.as_deref(),
1367 Some("upstream provider timeout")
1368 );
1369 }
1370
1371 #[test]
1372 fn test_finish_reason_without_prior_content_emits_start() {
1373 let events = vec![
1374 json!({ "choices": [{ "delta": {}, "finish_reason": "stop" }] }),
1375 Value::String("[DONE]".to_string()),
1376 ];
1377
1378 let out = collect_events(&events);
1379
1380 assert!(!out.is_empty(), "expected at least one event");
1383 assert!(
1384 matches!(out[0], StreamEvent::Start { .. }),
1385 "First event should be Start, got {:?}",
1386 out[0]
1387 );
1388 }
1389
1390 #[test]
1391 fn test_stream_emits_all_events_in_correct_order() {
1392 let events = vec![
1393 json!({ "choices": [{ "delta": { "content": "Hello" } }] }),
1394 json!({ "choices": [{ "delta": { "content": " world" } }] }),
1395 json!({ "choices": [{ "delta": {}, "finish_reason": "stop" }] }),
1396 Value::String("[DONE]".to_string()),
1397 ];
1398
1399 let out = collect_events(&events);
1400
1401 assert_eq!(out.len(), 6, "Expected 6 events, got {}", out.len());
1403
1404 assert!(
1405 matches!(out[0], StreamEvent::Start { .. }),
1406 "Event 0 should be Start, got {:?}",
1407 out[0]
1408 );
1409
1410 assert!(
1411 matches!(
1412 out[1],
1413 StreamEvent::TextStart {
1414 content_index: 0,
1415 ..
1416 }
1417 ),
1418 "Event 1 should be TextStart at index 0, got {:?}",
1419 out[1]
1420 );
1421
1422 assert!(
1423 matches!(&out[2], StreamEvent::TextDelta { content_index: 0, delta, .. } if delta == "Hello"),
1424 "Event 2 should be TextDelta 'Hello' at index 0, got {:?}",
1425 out[2]
1426 );
1427
1428 assert!(
1429 matches!(&out[3], StreamEvent::TextDelta { content_index: 0, delta, .. } if delta == " world"),
1430 "Event 3 should be TextDelta ' world' at index 0, got {:?}",
1431 out[3]
1432 );
1433
1434 assert!(
1435 matches!(&out[4], StreamEvent::TextEnd { content_index: 0, content, .. } if content == "Hello world"),
1436 "Event 4 should be TextEnd 'Hello world' at index 0, got {:?}",
1437 out[4]
1438 );
1439
1440 assert!(
1441 matches!(
1442 out[5],
1443 StreamEvent::Done {
1444 reason: StopReason::Stop,
1445 ..
1446 }
1447 ),
1448 "Event 5 should be Done with Stop reason, got {:?}",
1449 out[5]
1450 );
1451 }
1452
1453 #[test]
1454 fn test_build_request_applies_openrouter_routing_overrides() {
1455 let provider = OpenAIProvider::new("openai/gpt-4o-mini")
1456 .with_provider_name("openrouter")
1457 .with_compat(Some(CompatConfig {
1458 open_router_routing: Some(json!({
1459 "models": ["openai/gpt-4o-mini", "anthropic/claude-3.5-sonnet"],
1460 "provider": {
1461 "order": ["openai", "anthropic"],
1462 "allow_fallbacks": false
1463 },
1464 "route": "fallback"
1465 })),
1466 ..CompatConfig::default()
1467 }));
1468 let context = Context {
1469 system_prompt: None,
1470 messages: vec![Message::User(crate::model::UserMessage {
1471 content: UserContent::Text("Ping".to_string()),
1472 timestamp: 0,
1473 })]
1474 .into(),
1475 tools: Vec::new().into(),
1476 };
1477 let options = StreamOptions::default();
1478
1479 let request = provider
1480 .build_request_json(&context, &options)
1481 .expect("request json");
1482 assert_eq!(request["model"], "openai/gpt-4o-mini");
1483 assert_eq!(request["route"], "fallback");
1484 assert_eq!(request["provider"]["allow_fallbacks"], false);
1485 assert_eq!(request["models"][0], "openai/gpt-4o-mini");
1486 assert_eq!(request["models"][1], "anthropic/claude-3.5-sonnet");
1487 }
1488
1489 #[test]
1490 fn test_stream_sets_bearer_auth_header() {
1491 let captured = run_stream_and_capture_headers().expect("captured request");
1492 assert_eq!(
1493 captured.headers.get("authorization").map(String::as_str),
1494 Some("Bearer test-openai-key")
1495 );
1496 assert_eq!(
1497 captured.headers.get("accept").map(String::as_str),
1498 Some("text/event-stream")
1499 );
1500
1501 let body: Value = serde_json::from_str(&captured.body).expect("request body json");
1502 assert_eq!(body["stream"], true);
1503 assert_eq!(body["stream_options"]["include_usage"], true);
1504 }
1505
1506 #[test]
1507 fn test_stream_openrouter_injects_default_attribution_headers() {
1508 let options = StreamOptions {
1509 api_key: Some("test-openrouter-key".to_string()),
1510 ..Default::default()
1511 };
1512 let captured = run_stream_and_capture_headers_with(
1513 OpenAIProvider::new("openai/gpt-4o-mini").with_provider_name("openrouter"),
1514 &options,
1515 )
1516 .expect("captured request");
1517
1518 assert_eq!(
1519 captured.headers.get("http-referer").map(String::as_str),
1520 Some(OPENROUTER_DEFAULT_HTTP_REFERER)
1521 );
1522 assert_eq!(
1523 captured.headers.get("x-title").map(String::as_str),
1524 Some(OPENROUTER_DEFAULT_X_TITLE)
1525 );
1526 }
1527
1528 #[test]
1529 fn test_stream_openrouter_respects_explicit_attribution_headers() {
1530 let options = StreamOptions {
1531 api_key: Some("test-openrouter-key".to_string()),
1532 headers: HashMap::from([
1533 (
1534 "HTTP-Referer".to_string(),
1535 "https://example.test/app".to_string(),
1536 ),
1537 (
1538 "X-Title".to_string(),
1539 "Custom OpenRouter Client".to_string(),
1540 ),
1541 ]),
1542 ..Default::default()
1543 };
1544 let captured = run_stream_and_capture_headers_with(
1545 OpenAIProvider::new("openai/gpt-4o-mini").with_provider_name("openrouter"),
1546 &options,
1547 )
1548 .expect("captured request");
1549
1550 assert_eq!(
1551 captured.headers.get("http-referer").map(String::as_str),
1552 Some("https://example.test/app")
1553 );
1554 assert_eq!(
1555 captured.headers.get("x-title").map(String::as_str),
1556 Some("Custom OpenRouter Client")
1557 );
1558 }
1559
    /// Top-level shape of a provider-response fixture file.
    #[derive(Debug, Deserialize)]
    struct ProviderFixture {
        // One entry per named streaming scenario.
        cases: Vec<ProviderCase>,
    }
1564
    /// A single fixture scenario: raw chunk payloads plus the expected
    /// summarized event sequence.
    #[derive(Debug, Deserialize)]
    struct ProviderCase {
        // Scenario label, used in assertion failure messages.
        name: String,
        // Raw JSON chunks fed through the stream state machine in order.
        events: Vec<Value>,
        // Expected event summaries, compared positionally against the output.
        expected: Vec<EventSummary>,
    }
1571
    /// Flattened view of a `StreamEvent` used for fixture comparison.
    /// Fields that do not apply to a given event kind stay `None` (and may be
    /// omitted in the fixture JSON thanks to `#[serde(default)]`).
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct EventSummary {
        kind: String,
        #[serde(default)]
        content_index: Option<usize>,
        #[serde(default)]
        delta: Option<String>,
        #[serde(default)]
        content: Option<String>,
        #[serde(default)]
        reason: Option<String>,
    }
1584
1585 #[test]
1586 fn test_stream_fixtures() {
1587 let fixture = load_fixture("openai_stream.json");
1588 for case in fixture.cases {
1589 let events = collect_events(&case.events);
1590 let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
1591 assert_eq!(summaries, case.expected, "case {}", case.name);
1592 }
1593 }
1594
1595 fn load_fixture(file_name: &str) -> ProviderFixture {
1596 let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1597 .join("tests/fixtures/provider_responses")
1598 .join(file_name);
1599 let raw = std::fs::read_to_string(path).expect("fixture read");
1600 serde_json::from_str(&raw).expect("fixture parse")
1601 }
1602
    /// Request captured by the stub HTTP server: header map (names lowercased
    /// by `parse_headers`) plus the raw request body.
    #[derive(Debug)]
    struct CapturedRequest {
        headers: HashMap<String, String>,
        body: String,
    }
1608
1609 fn run_stream_and_capture_headers() -> Option<CapturedRequest> {
1610 let options = StreamOptions {
1611 api_key: Some("test-openai-key".to_string()),
1612 ..Default::default()
1613 };
1614 run_stream_and_capture_headers_with(OpenAIProvider::new("gpt-4o"), &options)
1615 }
1616
1617 fn run_stream_and_capture_headers_with(
1618 provider: OpenAIProvider,
1619 options: &StreamOptions,
1620 ) -> Option<CapturedRequest> {
1621 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1622 let provider = provider.with_base_url(base_url);
1623 let context = Context {
1624 system_prompt: None,
1625 messages: vec![Message::User(crate::model::UserMessage {
1626 content: UserContent::Text("ping".to_string()),
1627 timestamp: 0,
1628 })]
1629 .into(),
1630 tools: Vec::new().into(),
1631 };
1632
1633 let runtime = RuntimeBuilder::current_thread()
1634 .build()
1635 .expect("runtime build");
1636 runtime.block_on(async {
1637 let mut stream = provider.stream(&context, options).await.expect("stream");
1638 while let Some(event) = stream.next().await {
1639 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1640 break;
1641 }
1642 }
1643 });
1644
1645 rx.recv_timeout(Duration::from_secs(2)).ok()
1646 }
1647
1648 fn success_sse_body() -> String {
1649 [
1650 r#"data: {"choices":[{"delta":{}}]}"#,
1651 "",
1652 r#"data: {"choices":[{"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}"#,
1653 "",
1654 "data: [DONE]",
1655 "",
1656 ]
1657 .join("\n")
1658 }
1659
    /// Spawn a single-use HTTP server on an ephemeral local port.
    ///
    /// The server accepts exactly one connection, captures the request's
    /// headers and body as a `CapturedRequest` delivered on the returned
    /// channel, then answers with the supplied status code, content type, and
    /// body before closing. Returns the `/chat/completions` URL to point a
    /// provider at, plus the receiving end of the capture channel.
    fn spawn_test_server(
        status_code: u16,
        content_type: &str,
        body: &str,
    ) -> (String, mpsc::Receiver<CapturedRequest>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
        let addr = listener.local_addr().expect("local addr");
        let (tx, rx) = mpsc::channel();
        // Own the response data so the serving thread is 'static.
        let body = body.to_string();
        let content_type = content_type.to_string();

        std::thread::spawn(move || {
            let (mut socket, _) = listener.accept().expect("accept");
            socket
                .set_read_timeout(Some(Duration::from_secs(2)))
                .expect("set read timeout");

            // Phase 1: read until the header terminator (CRLFCRLF) shows up.
            // Any body bytes arriving in the same reads stay in `bytes`.
            let mut bytes = Vec::new();
            let mut chunk = [0_u8; 4096];
            loop {
                match socket.read(&mut chunk) {
                    Ok(0) => break,
                    Ok(n) => {
                        bytes.extend_from_slice(&chunk[..n]);
                        if bytes.windows(4).any(|window| window == b"\r\n\r\n") {
                            break;
                        }
                    }
                    // A read timeout just ends this phase; the expect below
                    // fails loudly if the headers never fully arrived.
                    Err(err)
                        if err.kind() == std::io::ErrorKind::WouldBlock
                            || err.kind() == std::io::ErrorKind::TimedOut =>
                    {
                        break;
                    }
                    Err(err) => panic!("read request failed: {err}"),
                }
            }

            // Split the raw bytes into header text and any body prefix.
            let header_end = bytes
                .windows(4)
                .position(|window| window == b"\r\n\r\n")
                .expect("request header boundary");
            let header_text = String::from_utf8_lossy(&bytes[..header_end]).to_string();
            let headers = parse_headers(&header_text);
            let mut request_body = bytes[header_end + 4..].to_vec();

            // Phase 2: keep reading until the advertised Content-Length is
            // satisfied (0 when the header is missing or unparseable).
            let content_length = headers
                .get("content-length")
                .and_then(|value| value.parse::<usize>().ok())
                .unwrap_or(0);
            while request_body.len() < content_length {
                match socket.read(&mut chunk) {
                    Ok(0) => break,
                    Ok(n) => request_body.extend_from_slice(&chunk[..n]),
                    Err(err)
                        if err.kind() == std::io::ErrorKind::WouldBlock
                            || err.kind() == std::io::ErrorKind::TimedOut =>
                    {
                        break;
                    }
                    Err(err) => panic!("read request body failed: {err}"),
                }
            }

            let captured = CapturedRequest {
                headers,
                body: String::from_utf8_lossy(&request_body).to_string(),
            };
            tx.send(captured).expect("send captured request");

            // Minimal reason-phrase table for the statuses these tests use.
            let reason = match status_code {
                401 => "Unauthorized",
                500 => "Internal Server Error",
                _ => "OK",
            };
            let response = format!(
                "HTTP/1.1 {status_code} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
                body.len()
            );
            socket
                .write_all(response.as_bytes())
                .expect("write response");
            socket.flush().expect("flush response");
        });

        (format!("http://{addr}/chat/completions"), rx)
    }
1747
1748 fn parse_headers(header_text: &str) -> HashMap<String, String> {
1749 let mut headers = HashMap::new();
1750 for line in header_text.lines().skip(1) {
1751 if let Some((name, value)) = line.split_once(':') {
1752 headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string());
1753 }
1754 }
1755 headers
1756 }
1757
    /// Feed pre-built chunks through `StreamState::process_event` exactly as
    /// the production stream loop would, returning every emitted event.
    ///
    /// Each element of `events` is either a JSON value (serialized into an SSE
    /// `data:` payload) or a raw string such as "[DONE]" that passes through
    /// unquoted. The real SSE parser is exercised by framing each chunk as
    /// `data: …\n\n` bytes. On "[DONE]" the pending events are flushed and a
    /// final `Done` carrying the accumulated partial message is appended.
    fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        runtime.block_on(async move {
            let byte_stream = stream::iter(
                events
                    .iter()
                    .map(|event| {
                        let data = match event {
                            // Raw strings (e.g. "[DONE]") are sent as-is, not
                            // JSON-quoted.
                            Value::String(text) => text.clone(),
                            _ => serde_json::to_string(event).expect("serialize event"),
                        };
                        format!("data: {data}\n\n").into_bytes()
                    })
                    .map(Ok),
            );
            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
            let mut state = StreamState::new(
                event_source,
                "gpt-test".to_string(),
                "openai".to_string(),
                "openai".to_string(),
            );
            let mut out = Vec::new();

            while let Some(item) = state.event_source.next().await {
                let msg = item.expect("SSE event");
                if msg.data == "[DONE]" {
                    // Mirror the production loop: drain anything pending, then
                    // emit Done with the stop reason and accumulated message.
                    out.extend(state.pending_events.drain(..));
                    let reason = state.partial.stop_reason;
                    out.push(StreamEvent::Done {
                        reason,
                        message: std::mem::take(&mut state.partial),
                    });
                    break;
                }
                state.process_event(&msg.data).expect("process_event");
                out.extend(state.pending_events.drain(..));
            }

            out
        })
    }
1802
1803 fn summarize_event(event: &StreamEvent) -> EventSummary {
1804 match event {
1805 StreamEvent::Start { .. } => EventSummary {
1806 kind: "start".to_string(),
1807 content_index: None,
1808 delta: None,
1809 content: None,
1810 reason: None,
1811 },
1812 StreamEvent::TextDelta {
1813 content_index,
1814 delta,
1815 ..
1816 } => EventSummary {
1817 kind: "text_delta".to_string(),
1818 content_index: Some(*content_index),
1819 delta: Some(delta.clone()),
1820 content: None,
1821 reason: None,
1822 },
1823 StreamEvent::Done { reason, .. } => EventSummary {
1824 kind: "done".to_string(),
1825 content_index: None,
1826 delta: None,
1827 content: None,
1828 reason: Some(reason_to_string(*reason)),
1829 },
1830 StreamEvent::Error { reason, .. } => EventSummary {
1831 kind: "error".to_string(),
1832 content_index: None,
1833 delta: None,
1834 content: None,
1835 reason: Some(reason_to_string(*reason)),
1836 },
1837 StreamEvent::TextStart { content_index, .. } => EventSummary {
1838 kind: "text_start".to_string(),
1839 content_index: Some(*content_index),
1840 delta: None,
1841 content: None,
1842 reason: None,
1843 },
1844 StreamEvent::TextEnd {
1845 content_index,
1846 content,
1847 ..
1848 } => EventSummary {
1849 kind: "text_end".to_string(),
1850 content_index: Some(*content_index),
1851 delta: None,
1852 content: Some(content.clone()),
1853 reason: None,
1854 },
1855 _ => EventSummary {
1856 kind: "other".to_string(),
1857 content_index: None,
1858 delta: None,
1859 content: None,
1860 reason: None,
1861 },
1862 }
1863 }
1864
1865 fn reason_to_string(reason: StopReason) -> String {
1866 match reason {
1867 StopReason::Stop => "stop",
1868 StopReason::Length => "length",
1869 StopReason::ToolUse => "tool_use",
1870 StopReason::Error => "error",
1871 StopReason::Aborted => "aborted",
1872 }
1873 .to_string()
1874 }
1875
1876 fn context_with_tools() -> Context<'static> {
1879 Context {
1880 system_prompt: Some("You are helpful.".to_string().into()),
1881 messages: vec![Message::User(crate::model::UserMessage {
1882 content: UserContent::Text("Hi".to_string()),
1883 timestamp: 0,
1884 })]
1885 .into(),
1886 tools: vec![ToolDef {
1887 name: "search".to_string(),
1888 description: "Search".to_string(),
1889 parameters: json!({"type": "object", "properties": {}}),
1890 }]
1891 .into(),
1892 }
1893 }
1894
1895 fn default_stream_options() -> StreamOptions {
1896 StreamOptions {
1897 max_tokens: Some(1024),
1898 ..Default::default()
1899 }
1900 }
1901
1902 #[test]
1903 fn compat_system_role_name_overrides_default() {
1904 let provider = OpenAIProvider::new("gpt-4o").with_compat(Some(CompatConfig {
1905 system_role_name: Some("developer".to_string()),
1906 ..Default::default()
1907 }));
1908 let context = context_with_tools();
1909 let options = default_stream_options();
1910 let req = provider.build_request(&context, &options);
1911 let value = serde_json::to_value(&req).expect("serialize");
1912 assert_eq!(
1913 value["messages"][0]["role"], "developer",
1914 "system message should use overridden role name"
1915 );
1916 }
1917
1918 #[test]
1919 fn compat_none_uses_default_system_role() {
1920 let provider = OpenAIProvider::new("gpt-4o");
1921 let context = context_with_tools();
1922 let options = default_stream_options();
1923 let req = provider.build_request(&context, &options);
1924 let value = serde_json::to_value(&req).expect("serialize");
1925 assert_eq!(
1926 value["messages"][0]["role"], "system",
1927 "default system role should be 'system'"
1928 );
1929 }
1930
1931 #[test]
1932 fn compat_supports_tools_false_omits_tools() {
1933 let provider = OpenAIProvider::new("gpt-4o").with_compat(Some(CompatConfig {
1934 supports_tools: Some(false),
1935 ..Default::default()
1936 }));
1937 let context = context_with_tools();
1938 let options = default_stream_options();
1939 let req = provider.build_request(&context, &options);
1940 let value = serde_json::to_value(&req).expect("serialize");
1941 assert!(
1942 value["tools"].is_null(),
1943 "tools should be omitted when supports_tools=false"
1944 );
1945 }
1946
1947 #[test]
1948 fn compat_supports_tools_true_includes_tools() {
1949 let provider = OpenAIProvider::new("gpt-4o").with_compat(Some(CompatConfig {
1950 supports_tools: Some(true),
1951 ..Default::default()
1952 }));
1953 let context = context_with_tools();
1954 let options = default_stream_options();
1955 let req = provider.build_request(&context, &options);
1956 let value = serde_json::to_value(&req).expect("serialize");
1957 assert!(
1958 value["tools"].is_array(),
1959 "tools should be included when supports_tools=true"
1960 );
1961 }
1962
1963 #[test]
1964 fn compat_max_tokens_field_routes_to_max_completion_tokens() {
1965 let provider = OpenAIProvider::new("o1").with_compat(Some(CompatConfig {
1966 max_tokens_field: Some("max_completion_tokens".to_string()),
1967 ..Default::default()
1968 }));
1969 let context = context_with_tools();
1970 let options = default_stream_options();
1971 let req = provider.build_request(&context, &options);
1972 let value = serde_json::to_value(&req).expect("serialize");
1973 assert!(
1974 value["max_tokens"].is_null(),
1975 "max_tokens should be absent when routed to max_completion_tokens"
1976 );
1977 assert_eq!(
1978 value["max_completion_tokens"], 1024,
1979 "max_completion_tokens should carry the token limit"
1980 );
1981 }
1982
1983 #[test]
1984 fn compat_default_routes_to_max_tokens() {
1985 let provider = OpenAIProvider::new("gpt-4o");
1986 let context = context_with_tools();
1987 let options = default_stream_options();
1988 let req = provider.build_request(&context, &options);
1989 let value = serde_json::to_value(&req).expect("serialize");
1990 assert_eq!(
1991 value["max_tokens"], 1024,
1992 "default should use max_tokens field"
1993 );
1994 assert!(
1995 value["max_completion_tokens"].is_null(),
1996 "max_completion_tokens should be absent by default"
1997 );
1998 }
1999
2000 #[test]
2001 fn compat_supports_usage_in_streaming_false() {
2002 let provider = OpenAIProvider::new("gpt-4o").with_compat(Some(CompatConfig {
2003 supports_usage_in_streaming: Some(false),
2004 ..Default::default()
2005 }));
2006 let context = context_with_tools();
2007 let options = default_stream_options();
2008 let req = provider.build_request(&context, &options);
2009 let value = serde_json::to_value(&req).expect("serialize");
2010 assert_eq!(
2011 value["stream_options"]["include_usage"], false,
2012 "include_usage should be false when supports_usage_in_streaming=false"
2013 );
2014 }
2015
2016 #[test]
2017 fn compat_combined_overrides() {
2018 let provider = OpenAIProvider::new("custom-model").with_compat(Some(CompatConfig {
2019 system_role_name: Some("developer".to_string()),
2020 max_tokens_field: Some("max_completion_tokens".to_string()),
2021 supports_tools: Some(false),
2022 supports_usage_in_streaming: Some(false),
2023 ..Default::default()
2024 }));
2025 let context = context_with_tools();
2026 let options = default_stream_options();
2027 let req = provider.build_request(&context, &options);
2028 let value = serde_json::to_value(&req).expect("serialize");
2029 assert_eq!(value["messages"][0]["role"], "developer");
2030 assert!(value["max_tokens"].is_null());
2031 assert_eq!(value["max_completion_tokens"], 1024);
2032 assert!(value["tools"].is_null());
2033 assert_eq!(value["stream_options"]["include_usage"], false);
2034 }
2035
2036 #[test]
2037 fn compat_custom_headers_injected_into_stream_request() {
2038 let mut custom = HashMap::new();
2039 custom.insert("X-Custom-Tag".to_string(), "test-123".to_string());
2040 custom.insert("X-Provider-Region".to_string(), "us-east-1".to_string());
2041 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
2042 let provider = OpenAIProvider::new("gpt-4o")
2043 .with_base_url(base_url)
2044 .with_compat(Some(CompatConfig {
2045 custom_headers: Some(custom),
2046 ..Default::default()
2047 }));
2048
2049 let context = Context {
2050 system_prompt: None,
2051 messages: vec![Message::User(crate::model::UserMessage {
2052 content: UserContent::Text("ping".to_string()),
2053 timestamp: 0,
2054 })]
2055 .into(),
2056 tools: Vec::new().into(),
2057 };
2058 let options = StreamOptions {
2059 api_key: Some("test-key".to_string()),
2060 ..Default::default()
2061 };
2062
2063 let runtime = RuntimeBuilder::current_thread()
2064 .build()
2065 .expect("runtime build");
2066 runtime.block_on(async {
2067 let mut stream = provider.stream(&context, &options).await.expect("stream");
2068 while let Some(event) = stream.next().await {
2069 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
2070 break;
2071 }
2072 }
2073 });
2074
2075 let captured = rx
2076 .recv_timeout(Duration::from_secs(2))
2077 .expect("captured request");
2078 assert_eq!(
2079 captured.headers.get("x-custom-tag").map(String::as_str),
2080 Some("test-123"),
2081 "custom header should be present in request"
2082 );
2083 assert_eq!(
2084 captured
2085 .headers
2086 .get("x-provider-region")
2087 .map(String::as_str),
2088 Some("us-east-1"),
2089 "custom header should be present in request"
2090 );
2091 }
2092
    /// Property tests asserting that `StreamState::process_event` never panics
    /// on well-formed OpenAI-shaped chunks, on malformed/adversarial input, or
    /// on arbitrary sequences mixing both.
    mod proptest_process_event {
        use super::*;
        use proptest::prelude::*;

        // Fresh StreamState backed by an empty byte stream; events are pushed
        // directly via process_event rather than through the SSE source.
        fn make_state()
        -> StreamState<impl Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin>
        {
            let empty = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
            let sse = crate::sse::SseStream::new(Box::pin(empty));
            StreamState::new(sse, "gpt-test".into(), "openai".into(), "openai".into())
        }

        // Short strings: empty, identifier-like, or arbitrary printable ASCII.
        fn small_string() -> impl Strategy<Value = String> {
            prop_oneof![Just(String::new()), "[a-zA-Z0-9_]{1,16}", "[ -~]{0,32}",]
        }

        fn optional_string() -> impl Strategy<Value = Option<String>> {
            prop_oneof![Just(None), small_string().prop_map(Some),]
        }

        // Token counts weighted toward normal values but including 0 and
        // values at/near u64::MAX to probe overflow handling.
        fn token_count() -> impl Strategy<Value = u64> {
            prop_oneof![
                5 => 0u64..10_000u64,
                2 => Just(0u64),
                1 => Just(u64::MAX),
                1 => (u64::MAX - 100)..=u64::MAX,
            ]
        }

        // Known OpenAI finish reasons plus arbitrary strings, mostly absent.
        fn finish_reason() -> impl Strategy<Value = Option<String>> {
            prop_oneof![
                3 => Just(None),
                1 => Just(Some("stop".to_string())),
                1 => Just(Some("length".to_string())),
                1 => Just(Some("tool_calls".to_string())),
                1 => Just(Some("content_filter".to_string())),
                1 => small_string().prop_map(Some),
            ]
        }

        // Indices weighted toward small values but including sparse (100..200)
        // and extreme (u32::MAX) indices to stress index handling.
        fn tool_call_index() -> impl Strategy<Value = u32> {
            prop_oneof![
                5 => 0u32..3u32,
                1 => Just(u32::MAX),
                1 => 100u32..200u32,
            ]
        }

        // Well-formed chunk generator covering the shapes process_event must
        // accept: content deltas, empty deltas, finish-only chunks, tool-call
        // deltas with optional parts, usage payloads, and error payloads.
        fn openai_chunk_json() -> impl Strategy<Value = String> {
            prop_oneof![
                3 => (small_string(), finish_reason()).prop_map(|(text, fr)| {
                    let mut choice = serde_json::json!({
                        "delta": {"content": text}
                    });
                    if let Some(reason) = fr {
                        choice["finish_reason"] = serde_json::Value::String(reason);
                    }
                    serde_json::json!({"choices": [choice]}).to_string()
                }),
                2 => Just(r#"{"choices":[{"delta":{}}]}"#.to_string()),
                2 => finish_reason().prop_filter("some reason", Option::is_some).prop_map(|fr| {
                    serde_json::json!({
                        "choices": [{"delta": {}, "finish_reason": fr.unwrap()}]
                    })
                    .to_string()
                }),
                3 => (tool_call_index(), optional_string(), optional_string(), optional_string())
                    .prop_map(|(idx, id, name, args)| {
                        let mut tc = serde_json::json!({"index": idx});
                        if let Some(id) = id { tc["id"] = serde_json::Value::String(id); }
                        let mut func = serde_json::Map::new();
                        if let Some(n) = name { func.insert("name".into(), serde_json::Value::String(n)); }
                        if let Some(a) = args { func.insert("arguments".into(), serde_json::Value::String(a)); }
                        if !func.is_empty() { tc["function"] = serde_json::Value::Object(func); }
                        serde_json::json!({
                            "choices": [{"delta": {"tool_calls": [tc]}}]
                        })
                        .to_string()
                    }),
                2 => (token_count(), token_count(), token_count()).prop_map(|(prompt, compl, total)| {
                    serde_json::json!({
                        "choices": [],
                        "usage": {
                            "prompt_tokens": prompt,
                            "completion_tokens": compl,
                            "total_tokens": total
                        }
                    })
                    .to_string()
                }),
                1 => small_string().prop_map(|msg| {
                    serde_json::json!({
                        "choices": [],
                        "error": {"message": msg}
                    })
                    .to_string()
                }),
                1 => Just(r#"{"choices":[]}"#.to_string()),
            ]
        }

        // Malformed input: empty/truncated JSON, wrong types, and arbitrary
        // printable garbage. process_event may error but must not panic.
        fn chaos_json() -> impl Strategy<Value = String> {
            prop_oneof![
                Just(String::new()),
                Just("{}".to_string()),
                Just("[]".to_string()),
                Just("null".to_string()),
                Just("{".to_string()),
                Just(r#"{"choices":"not_array"}"#.to_string()),
                Just(r#"{"choices":[{"delta":null}]}"#.to_string()),
                "[a-z_]{1,20}".prop_map(|t| format!(r#"{{"type":"{t}"}}"#)),
                "[ -~]{0,64}",
            ]
        }

        proptest! {
            #![proptest_config(ProptestConfig {
                cases: 256,
                max_shrink_iters: 100,
                .. ProptestConfig::default()
            })]

            #[test]
            fn process_event_valid_never_panics(data in openai_chunk_json()) {
                let mut state = make_state();
                let _ = state.process_event(&data);
            }

            #[test]
            fn process_event_chaos_never_panics(data in chaos_json()) {
                let mut state = make_state();
                let _ = state.process_event(&data);
            }

            #[test]
            fn process_event_sequence_never_panics(
                events in prop::collection::vec(openai_chunk_json(), 1..8)
            ) {
                let mut state = make_state();
                for event in &events {
                    let _ = state.process_event(event);
                }
            }

            #[test]
            fn process_event_mixed_sequence_never_panics(
                events in prop::collection::vec(
                    prop_oneof![openai_chunk_json(), chaos_json()],
                    1..12
                )
            ) {
                let mut state = make_state();
                for event in &events {
                    let _ = state.process_event(event);
                }
            }
        }
    }
2264}
2265
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    //! Fuzzing support: a thin wrapper over `StreamState` so external fuzz
    //! targets can feed arbitrary SSE `data:` payloads into `process_event`
    //! without constructing a real byte stream.
    use super::*;
    use futures::stream;
    use std::pin::Pin;

    // The byte-stream type parameter is irrelevant for fuzzing (payloads are
    // pushed manually), so an always-empty stream stands in for the network.
    type FuzzStream =
        Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;

    /// Fuzzing handle around a `StreamState` with fixed model/provider names.
    pub struct Processor(StreamState<FuzzStream>);

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Processor {
        /// Create a processor backed by an empty byte stream.
        pub fn new() -> Self {
            let empty = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
            Self(StreamState::new(
                crate::sse::SseStream::new(Box::pin(empty)),
                "gpt-fuzz".into(),
                "openai".into(),
                "openai".into(),
            ))
        }

        /// Process one raw SSE payload, returning the events it produced (or
        /// the parse/processing error).
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            self.0.process_event(data)?;
            Ok(self.0.pending_events.drain(..).collect())
        }
    }
}