1#![cfg(any(feature = "client", feature = "blocking"))]
2
3use std::collections::HashMap;
4use std::time::Duration;
5
6use crate::errors::{APIError, Error, Result, TransportError, TransportErrorKind, ValidationError};
7#[cfg(feature = "streaming")]
8use crate::types::StreamEventKind;
9use crate::types::{
10 Model, ProxyMessage, ProxyRequest, ProxyResponse, ResponseFormat, StopReason, Usage,
11};
12
13#[cfg(feature = "blocking")]
14use crate::blocking::BlockingLLMClient;
15#[cfg(all(feature = "blocking", feature = "streaming"))]
16use crate::blocking::BlockingProxyHandle;
17#[cfg(feature = "client")]
18use crate::client::LLMClient;
19#[cfg(all(feature = "client", feature = "streaming"))]
20use crate::sse::StreamHandle;
21
22#[cfg(any(feature = "client", feature = "blocking"))]
23use crate::{ProxyOptions, RetryConfig};
24#[cfg(all(feature = "client", feature = "streaming"))]
25use futures_util::stream;
26use schemars::JsonSchema;
27#[cfg(all(feature = "client", feature = "streaming"))]
28use serde::de::DeserializeOwned;
29
/// Generates the fluent setters shared by `ChatRequestBuilder` and
/// `CustomerChatRequestBuilder`.
///
/// The target type must declare `messages`, `max_tokens`, `temperature`,
/// `metadata`, `response_format`, `stop`, `tools`, `tool_choice`,
/// `request_id`, `headers`, `timeout` and `retry` fields with the same types
/// both builders in this file use.
macro_rules! impl_chat_builder_common {
    ($builder:ty) => {
        impl $builder {
            /// Appends a message with the given `role` and text `content`.
            ///
            /// The message is created with no tool calls and no tool call id.
            pub fn message(
                mut self,
                role: crate::types::MessageRole,
                content: impl Into<String>,
            ) -> Self {
                let entry = ProxyMessage {
                    role,
                    content: content.into(),
                    tool_calls: None,
                    tool_call_id: None,
                };
                self.messages.push(entry);
                self
            }

            /// Appends a `System` message.
            pub fn system(self, content: impl Into<String>) -> Self {
                self.message(crate::types::MessageRole::System, content)
            }

            /// Appends a `User` message.
            pub fn user(self, content: impl Into<String>) -> Self {
                self.message(crate::types::MessageRole::User, content)
            }

            /// Appends an `Assistant` message.
            pub fn assistant(self, content: impl Into<String>) -> Self {
                self.message(crate::types::MessageRole::Assistant, content)
            }

            /// Replaces the whole conversation with `messages`, discarding any
            /// messages appended earlier.
            pub fn messages(mut self, messages: Vec<ProxyMessage>) -> Self {
                self.messages = messages;
                self
            }

            /// Sets the `max_tokens` request option.
            pub fn max_tokens(mut self, max_tokens: i64) -> Self {
                self.max_tokens = Some(max_tokens);
                self
            }

            /// Sets the `temperature` request option.
            pub fn temperature(mut self, temperature: f64) -> Self {
                self.temperature = Some(temperature);
                self
            }

            /// Replaces all request metadata with `metadata`.
            pub fn metadata(mut self, metadata: HashMap<String, String>) -> Self {
                self.metadata = Some(metadata);
                self
            }

            /// Adds a single metadata entry.
            ///
            /// The pair is ignored when either the key or the value is empty or
            /// whitespace-only. A repeated key overwrites the earlier value.
            pub fn metadata_entry(
                mut self,
                key: impl Into<String>,
                value: impl Into<String>,
            ) -> Self {
                let key: String = key.into();
                let value: String = value.into();
                let usable = !key.trim().is_empty() && !value.trim().is_empty();
                if usable {
                    self.metadata
                        .get_or_insert_with(HashMap::new)
                        .insert(key, value);
                }
                self
            }

            /// Sets the requested response format (e.g. a JSON schema).
            pub fn response_format(mut self, response_format: ResponseFormat) -> Self {
                self.response_format = Some(response_format);
                self
            }

            /// Sets the stop sequences.
            pub fn stop(mut self, stop: Vec<String>) -> Self {
                self.stop = Some(stop);
                self
            }

            /// Sets the tool definitions offered to the model.
            pub fn tools(mut self, tools: Vec<crate::types::Tool>) -> Self {
                self.tools = Some(tools);
                self
            }

            /// Sets the tool-choice policy.
            pub fn tool_choice(mut self, tool_choice: crate::types::ToolChoice) -> Self {
                self.tool_choice = Some(tool_choice);
                self
            }

            /// Sets the request id forwarded through `ProxyOptions::with_request_id`.
            pub fn request_id(mut self, request_id: impl Into<String>) -> Self {
                self.request_id = Some(request_id.into());
                self
            }

            /// Queues an extra header forwarded through `ProxyOptions::with_header`.
            ///
            /// Calls accumulate; an existing entry with the same key is not replaced here.
            pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
                let pair = (key.into(), value.into());
                self.headers.push(pair);
                self
            }

            /// Sets a per-request timeout forwarded through `ProxyOptions::with_timeout`.
            pub fn timeout(mut self, timeout: Duration) -> Self {
                self.timeout = Some(timeout);
                self
            }

            /// Sets a per-request retry policy forwarded through `ProxyOptions::with_retry`.
            pub fn retry(mut self, retry: RetryConfig) -> Self {
                self.retry = Some(retry);
                self
            }
        }
    };
}
159
/// Fluent builder for a single proxied chat request.
///
/// The setters come from `impl_chat_builder_common!` (invoked below). Validation
/// and request assembly happen in `build_request`; transport settings are turned
/// into `ProxyOptions` by `build_options`.
#[derive(Clone, Debug, Default)]
pub struct ChatRequestBuilder {
    /// Target model. An `Option` so that `Default` can produce an empty
    /// builder; `build_request` rejects `None`.
    pub(crate) model: Option<Model>,
    /// Copied into `ProxyRequest::max_tokens`.
    pub(crate) max_tokens: Option<i64>,
    /// Copied into `ProxyRequest::temperature`.
    pub(crate) temperature: Option<f64>,
    /// Conversation in the order the messages were appended.
    pub(crate) messages: Vec<ProxyMessage>,
    /// Copied into `ProxyRequest::metadata`.
    pub(crate) metadata: Option<HashMap<String, String>>,
    /// Copied into `ProxyRequest::response_format`; must be structured for `stream_json*`.
    pub(crate) response_format: Option<ResponseFormat>,
    /// Copied into `ProxyRequest::stop`.
    pub(crate) stop: Option<Vec<String>>,
    /// Copied into `ProxyRequest::tools`.
    pub(crate) tools: Option<Vec<crate::types::Tool>>,
    /// Copied into `ProxyRequest::tool_choice`.
    pub(crate) tool_choice: Option<crate::types::ToolChoice>,
    // The fields below are per-call transport settings, not request body fields.
    pub(crate) request_id: Option<String>,
    pub(crate) headers: Vec<(String, String)>,
    pub(crate) timeout: Option<Duration>,
    pub(crate) retry: Option<RetryConfig>,
}

impl_chat_builder_common!(ChatRequestBuilder);
180
181impl ChatRequestBuilder {
182 pub fn new(model: impl Into<Model>) -> Self {
184 Self {
185 model: Some(model.into()),
186 ..Default::default()
187 }
188 }
189
190 fn build_options(&self) -> ProxyOptions {
191 let mut opts = ProxyOptions::default();
192 if let Some(req_id) = &self.request_id {
193 opts = opts.with_request_id(req_id.clone());
194 }
195 for (k, v) in &self.headers {
196 opts = opts.with_header(k.clone(), v.clone());
197 }
198 if let Some(timeout) = self.timeout {
199 opts = opts.with_timeout(timeout);
200 }
201 if let Some(retry) = &self.retry {
202 opts = opts.with_retry(retry.clone());
203 }
204 opts
205 }
206
207 pub fn build_request(&self) -> Result<ProxyRequest> {
208 let model = self
209 .model
210 .clone()
211 .ok_or_else(|| Error::Validation("model is required".into()))?;
212
213 if self.messages.is_empty() {
214 return Err(Error::Validation(
215 ValidationError::new("at least one message is required").with_field("messages"),
216 ));
217 }
218 if !self
219 .messages
220 .iter()
221 .any(|msg| msg.role == crate::types::MessageRole::User)
222 {
223 return Err(Error::Validation(
224 ValidationError::new("at least one user message is required")
225 .with_field("messages"),
226 ));
227 }
228
229 let req = ProxyRequest {
230 model,
231 max_tokens: self.max_tokens,
232 temperature: self.temperature,
233 messages: self.messages.clone(),
234 metadata: self.metadata.clone(),
235 response_format: self.response_format.clone(),
236 stop: self.stop.clone(),
237 tools: self.tools.clone(),
238 tool_choice: self.tool_choice.clone(),
239 };
240 req.validate()?;
241 Ok(req)
242 }
243
244 #[cfg(feature = "client")]
246 pub async fn send(self, client: &LLMClient) -> Result<ProxyResponse> {
247 let req = self.build_request()?;
248 let opts = self.build_options();
249 client.proxy(req, opts).await
250 }
251
252 #[cfg(all(feature = "client", feature = "streaming"))]
254 pub async fn stream(self, client: &LLMClient) -> Result<StreamHandle> {
255 let req = self.build_request()?;
256 let opts = self.build_options();
257 client.proxy_stream(req, opts).await
258 }
259
260 #[cfg(all(feature = "client", feature = "streaming"))]
262 pub async fn stream_deltas(
263 self,
264 client: &LLMClient,
265 ) -> Result<std::pin::Pin<Box<dyn futures_core::Stream<Item = Result<String>> + Send>>> {
266 let req = self.build_request()?;
267 let opts = self.build_options();
268 client.proxy_stream_deltas(req, opts).await
269 }
270
271 #[cfg(feature = "client")]
298 pub fn structured<T>(self) -> crate::structured::StructuredChatBuilder<T>
299 where
300 T: JsonSchema + DeserializeOwned,
301 {
302 crate::structured::StructuredChatBuilder::new(self)
303 }
304
305 #[cfg(all(feature = "client", feature = "streaming"))]
310 pub async fn stream_json<T>(self, client: &LLMClient) -> Result<StructuredJSONStream<T>>
311 where
312 T: DeserializeOwned,
313 {
314 let req = self.build_request()?;
315 match &req.response_format {
316 Some(format) if format.is_structured() => {}
317 Some(_) => {
318 return Err(Error::Validation(
319 ValidationError::new("response_format must be structured (type=json_schema)")
320 .with_field("response_format.type"),
321 ));
322 }
323 None => {
324 return Err(Error::Validation(
325 ValidationError::new("response_format is required for structured streaming")
326 .with_field("response_format"),
327 ));
328 }
329 }
330 let opts = self.build_options();
331 let stream = client.proxy_stream(req, opts).await?;
332 Ok(StructuredJSONStream::new(stream))
333 }
334
335 #[cfg(feature = "blocking")]
337 pub fn send_blocking(self, client: &BlockingLLMClient) -> Result<ProxyResponse> {
338 let req = self.build_request()?;
339 let opts = self.build_options();
340 client.proxy(req, opts)
341 }
342
343 #[cfg(all(feature = "blocking", feature = "streaming"))]
345 pub fn stream_blocking(self, client: &BlockingLLMClient) -> Result<BlockingProxyHandle> {
346 let req = self.build_request()?;
347 let opts = self.build_options();
348 client.proxy_stream(req, opts)
349 }
350
351 #[cfg(all(feature = "blocking", feature = "streaming"))]
353 pub fn stream_deltas_blocking(
354 self,
355 client: &BlockingLLMClient,
356 ) -> Result<Box<dyn Iterator<Item = Result<String>>>> {
357 let req = self.build_request()?;
358 let opts = self.build_options();
359 client.proxy_stream_deltas(req, opts)
360 }
361
362 #[cfg(all(feature = "blocking", feature = "streaming"))]
367 pub fn stream_json_blocking<T>(
368 self,
369 client: &BlockingLLMClient,
370 ) -> Result<BlockingStructuredJSONStream<T>>
371 where
372 T: DeserializeOwned,
373 {
374 let req = self.build_request()?;
375 match &req.response_format {
376 Some(format) if format.is_structured() => {}
377 Some(_) => {
378 return Err(Error::Validation(
379 ValidationError::new("response_format must be structured (type=json_schema)")
380 .with_field("response_format.type"),
381 ));
382 }
383 None => {
384 return Err(Error::Validation(
385 ValidationError::new("response_format is required for structured streaming")
386 .with_field("response_format"),
387 ));
388 }
389 }
390 let opts = self.build_options();
391 let stream = client.proxy_stream(req, opts)?;
392 Ok(BlockingStructuredJSONStream::new(stream))
393 }
394}
395
/// Header name presumably used to attribute a request to a customer.
///
/// NOTE(review): not referenced in this file; presumably applied by the
/// client's `proxy_customer*` methods. Confirm.
pub const CUSTOMER_ID_HEADER: &str = "X-ModelRelay-Customer-Id";

/// Fluent builder for a chat request attributed to a customer.
///
/// Unlike `ChatRequestBuilder` it has no `model` field. The body is sent through
/// `proxy_customer` / `proxy_customer_stream`, with `customer_id` passed
/// separately from the body. The setters come from `impl_chat_builder_common!`
/// (invoked below).
#[derive(Clone, Debug, Default)]
pub struct CustomerChatRequestBuilder {
    /// Customer the request is attributed to. `Default` leaves it empty.
    pub(crate) customer_id: String,
    pub(crate) max_tokens: Option<i64>,
    pub(crate) temperature: Option<f64>,
    /// Conversation in the order the messages were appended.
    pub(crate) messages: Vec<ProxyMessage>,
    pub(crate) metadata: Option<HashMap<String, String>>,
    pub(crate) response_format: Option<ResponseFormat>,
    pub(crate) stop: Option<Vec<String>>,
    // NOTE(review): `tools` and `tool_choice` can be set through the shared
    // setters, but `CustomerProxyRequestBody` has no field for them.
    pub(crate) tools: Option<Vec<crate::types::Tool>>,
    pub(crate) tool_choice: Option<crate::types::ToolChoice>,
    // Per-call transport settings, turned into `ProxyOptions`.
    pub(crate) request_id: Option<String>,
    pub(crate) headers: Vec<(String, String)>,
    pub(crate) timeout: Option<Duration>,
    pub(crate) retry: Option<RetryConfig>,
}

impl_chat_builder_common!(CustomerChatRequestBuilder);
423
424impl CustomerChatRequestBuilder {
425 pub fn new(customer_id: impl Into<String>) -> Self {
427 Self {
428 customer_id: customer_id.into(),
429 ..Default::default()
430 }
431 }
432
433 fn build_options(&self) -> ProxyOptions {
434 let mut opts = ProxyOptions::default();
435 if let Some(req_id) = &self.request_id {
436 opts = opts.with_request_id(req_id.clone());
437 }
438 for (k, v) in &self.headers {
440 opts = opts.with_header(k.clone(), v.clone());
441 }
442 if let Some(timeout) = self.timeout {
443 opts = opts.with_timeout(timeout);
444 }
445 if let Some(retry) = &self.retry {
446 opts = opts.with_retry(retry.clone());
447 }
448 opts
449 }
450
451 pub(crate) fn build_request_body(&self) -> Result<CustomerProxyRequestBody> {
453 if self.messages.is_empty() {
454 return Err(Error::Validation(
455 crate::errors::ValidationError::new("at least one message is required")
456 .with_field("messages"),
457 ));
458 }
459 if !self
460 .messages
461 .iter()
462 .any(|msg| msg.role == crate::types::MessageRole::User)
463 {
464 return Err(Error::Validation(
465 crate::errors::ValidationError::new("at least one user message is required")
466 .with_field("messages"),
467 ));
468 }
469 Ok(CustomerProxyRequestBody {
470 max_tokens: self.max_tokens,
471 temperature: self.temperature,
472 messages: self.messages.clone(),
473 metadata: self.metadata.clone(),
474 response_format: self.response_format.clone(),
475 stop: self.stop.clone(),
476 })
477 }
478
479 #[cfg(feature = "client")]
481 pub async fn send(self, client: &LLMClient) -> Result<ProxyResponse> {
482 let body = self.build_request_body()?;
483 let opts = self.build_options();
484 client.proxy_customer(&self.customer_id, body, opts).await
485 }
486
487 #[cfg(all(feature = "client", feature = "streaming"))]
489 pub async fn stream(self, client: &LLMClient) -> Result<StreamHandle> {
490 let body = self.build_request_body()?;
491 let opts = self.build_options();
492 client
493 .proxy_customer_stream(&self.customer_id, body, opts)
494 .await
495 }
496
497 #[cfg(feature = "blocking")]
499 pub fn send_blocking(self, client: &BlockingLLMClient) -> Result<ProxyResponse> {
500 let body = self.build_request_body()?;
501 let opts = self.build_options();
502 client.proxy_customer(&self.customer_id, body, opts)
503 }
504
505 #[cfg(all(feature = "blocking", feature = "streaming"))]
507 pub fn stream_blocking(self, client: &BlockingLLMClient) -> Result<BlockingProxyHandle> {
508 let body = self.build_request_body()?;
509 let opts = self.build_options();
510 client.proxy_customer_stream(&self.customer_id, body, opts)
511 }
512
513 #[cfg(all(feature = "client", feature = "streaming"))]
540 pub async fn stream_json<T>(self, client: &LLMClient) -> Result<StructuredJSONStream<T>>
541 where
542 T: DeserializeOwned,
543 {
544 let body = self.build_request_body()?;
545 match &body.response_format {
546 Some(format) if format.is_structured() => {}
547 Some(_) => {
548 return Err(Error::Validation(
549 ValidationError::new("response_format must be structured (type=json_schema)")
550 .with_field("response_format.type"),
551 ));
552 }
553 None => {
554 return Err(Error::Validation(
555 ValidationError::new("response_format is required for structured streaming")
556 .with_field("response_format"),
557 ));
558 }
559 }
560 let opts = self.build_options();
561 let stream = client
562 .proxy_customer_stream(&self.customer_id, body, opts)
563 .await?;
564 Ok(StructuredJSONStream::new(stream))
565 }
566
567 #[cfg(all(feature = "blocking", feature = "streaming"))]
593 pub fn stream_json_blocking<T>(
594 self,
595 client: &BlockingLLMClient,
596 ) -> Result<BlockingStructuredJSONStream<T>>
597 where
598 T: DeserializeOwned,
599 {
600 let body = self.build_request_body()?;
601 match &body.response_format {
602 Some(format) if format.is_structured() => {}
603 Some(_) => {
604 return Err(Error::Validation(
605 ValidationError::new("response_format must be structured (type=json_schema)")
606 .with_field("response_format.type"),
607 ));
608 }
609 None => {
610 return Err(Error::Validation(
611 ValidationError::new("response_format is required for structured streaming")
612 .with_field("response_format"),
613 ));
614 }
615 }
616 let opts = self.build_options();
617 let stream = client.proxy_customer_stream(&self.customer_id, body, opts)?;
618 Ok(BlockingStructuredJSONStream::new(stream))
619 }
620}
621
/// JSON body of a customer-attributed proxy request.
///
/// There is no `model` field; the customer id is passed to the client
/// separately from this body. Optional fields that are `None` are omitted from
/// the serialized JSON.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq)]
pub struct CustomerProxyRequestBody {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Always serialized, even when empty (the builder rejects an empty list).
    pub messages: Vec<ProxyMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}
637
#[cfg(feature = "streaming")]
/// Adapts a raw proxy event stream into a sequence of text deltas.
///
/// It also records the usage, stop reason and request id carried by the
/// `MessageStop` event. Constructors exist for the async `StreamHandle`
/// (`client`) and the `BlockingProxyHandle` (`blocking`).
#[derive(Debug)]
pub struct ChatStreamAdapter<S> {
    /// Underlying event stream.
    inner: S,
    /// Set once a `MessageStop` event has been consumed.
    finished: bool,
    /// Values captured from the `MessageStop` event; `None` until it arrives.
    final_usage: Option<Usage>,
    final_stop_reason: Option<StopReason>,
    final_request_id: Option<String>,
}
648
#[cfg(all(feature = "client", feature = "streaming"))]
impl ChatStreamAdapter<StreamHandle> {
    /// Wraps an async stream handle. Nothing is read until `next_delta` is called.
    pub fn new(stream: StreamHandle) -> Self {
        Self {
            inner: stream,
            finished: false,
            final_usage: None,
            final_stop_reason: None,
            final_request_id: None,
        }
    }

    /// Request id reported by the underlying stream handle, if any.
    ///
    /// Mirrors the accessor on the blocking adapter.
    pub fn request_id(&self) -> Option<&str> {
        self.inner.request_id()
    }

    /// Returns the next text delta from a `MessageDelta` event.
    ///
    /// `MessageDelta` events without text, and all event kinds other than
    /// `MessageDelta` / `MessageStop`, are skipped.
    ///
    /// On `MessageStop` the usage, stop reason and request id are recorded and
    /// `Ok(None)` is returned. The request id falls back to the handle's id.
    /// After that, later calls return `Ok(None)` without polling the stream again.
    ///
    /// `Ok(None)` is also returned when the stream ends without a stop event.
    ///
    /// # Errors
    ///
    /// Propagates errors yielded by the underlying stream.
    pub async fn next_delta(&mut self) -> Result<Option<String>> {
        use futures_util::StreamExt;

        // The stream is logically complete after `MessageStop`; don't read past it.
        if self.finished {
            return Ok(None);
        }

        while let Some(item) = self.inner.next().await {
            let evt = item?;
            match evt.kind {
                StreamEventKind::MessageDelta => {
                    if let Some(delta) = evt.text_delta {
                        return Ok(Some(delta));
                    }
                }
                StreamEventKind::MessageStop => {
                    self.finished = true;
                    self.final_usage = evt.usage;
                    self.final_stop_reason = evt.stop_reason;
                    self.final_request_id = evt
                        .request_id
                        .or_else(|| self.inner.request_id().map(|s| s.to_string()));
                    return Ok(None);
                }
                _ => {}
            }
        }
        Ok(None)
    }

    /// Usage reported by the `MessageStop` event, once it has been consumed.
    pub fn final_usage(&self) -> Option<&Usage> {
        self.final_usage.as_ref()
    }

    /// Stop reason reported by the `MessageStop` event, once it has been consumed.
    pub fn final_stop_reason(&self) -> Option<&StopReason> {
        self.final_stop_reason.as_ref()
    }

    /// Request id captured at `MessageStop`: the event's own id, or else the handle's id.
    pub fn final_request_id(&self) -> Option<&str> {
        self.final_request_id.as_deref()
    }

    /// Converts the adapter into a `Stream` of text deltas.
    ///
    /// The stream ends when `next_delta` returns `Ok(None)`. Errors are yielded
    /// as items, and polling continues afterwards.
    pub fn into_stream(self) -> impl futures_core::Stream<Item = Result<String>> {
        stream::unfold(self, |mut adapter| async move {
            match adapter.next_delta().await {
                Ok(Some(delta)) => Some((Ok(delta), adapter)),
                Ok(None) => None,
                Err(err) => Some((Err(err), adapter)),
            }
        })
    }
}
714
#[cfg(feature = "streaming")]
/// Kind of a decoded structured-output record.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StructuredRecordKind {
    /// A record whose `type` is `update`. `collect` keeps only the latest one
    /// as a fallback.
    Update,
    /// A record whose `type` is `completion`. `collect` returns its payload immediately.
    Completion,
}
722
#[cfg(feature = "streaming")]
/// One decoded `update` or `completion` record from a structured stream.
#[derive(Debug, Clone)]
pub struct StructuredJSONEvent<T> {
    /// Whether this record was an `update` or a `completion`.
    pub kind: StructuredRecordKind,
    /// The record's `payload` field, deserialized into `T`.
    pub payload: T,
    /// The event's request id, falling back to the stream handle's request id.
    pub request_id: Option<String>,
    /// String entries of the record's `complete_fields` array.
    /// Non-string entries are dropped; the set is empty when the array is absent.
    pub complete_fields: std::collections::HashSet<String>,
}
735
#[cfg(all(feature = "client", feature = "streaming"))]
/// Async reader of structured-output records over a proxy event stream.
///
/// Decodes `update` / `completion` payloads into `T` and maps `error` records
/// to API errors.
pub struct StructuredJSONStream<T> {
    /// Underlying event stream.
    inner: StreamHandle,
    /// Once set, `next` returns `Ok(None)` without reading from `inner`.
    finished: bool,
    /// Set once a `completion` or `error` record has been seen. If the stream
    /// ends while this is still false, `next` reports a protocol error.
    saw_completion: bool,
    /// Ties the stream to its payload type without storing a `T`.
    _marker: std::marker::PhantomData<T>,
}
744
#[cfg(all(feature = "client", feature = "streaming"))]
impl<T> StructuredJSONStream<T>
where
    T: DeserializeOwned,
{
    /// Wraps an async stream handle whose events carry structured-output records.
    pub fn new(stream: StreamHandle) -> Self {
        Self {
            inner: stream,
            finished: false,
            saw_completion: false,
            _marker: std::marker::PhantomData,
        }
    }

    /// Returns the next `update` or `completion` record, decoded into `T`.
    ///
    /// The following are skipped:
    /// - events whose `data` is not a JSON object;
    /// - records typed `start`;
    /// - records with no `type`;
    /// - records of an unknown type.
    ///
    /// The `type` is trimmed and compared case-insensitively. Returns `Ok(None)`
    /// once the stream is finished.
    ///
    /// # Errors
    ///
    /// - Errors yielded by the underlying stream.
    /// - `Error::Transport` when an `update`/`completion` record has no `payload`.
    /// - `Error::Serialization` when the payload does not deserialize into `T`.
    /// - An API error built from an `error` record. The status defaults to 500
    ///   when it is missing or does not fit in a `u16`. The stream is finished afterwards.
    /// - `Error::Transport` when the stream ends without a `completion` or `error` record.
    pub async fn next(&mut self) -> Result<Option<StructuredJSONEvent<T>>> {
        use futures_util::StreamExt;

        if self.finished {
            return Ok(None);
        }

        while let Some(item) = self.inner.next().await {
            let evt = item?;
            let value = match evt.data {
                Some(ref v) if v.is_object() => v,
                _ => continue,
            };
            let record_type = value
                .get("type")
                .and_then(|v| v.as_str())
                .map(|s| s.trim().to_lowercase())
                .unwrap_or_default();

            match record_type.as_str() {
                "" | "start" => continue,
                "update" | "completion" => {
                    let payload_value = value.get("payload").cloned().ok_or_else(|| {
                        Error::Transport(TransportError {
                            kind: TransportErrorKind::Request,
                            message: "structured stream record missing payload".to_string(),
                            source: None,
                            retries: None,
                        })
                    })?;
                    let payload: T =
                        serde_json::from_value(payload_value).map_err(Error::Serialization)?;
                    let kind = if record_type == "update" {
                        StructuredRecordKind::Update
                    } else {
                        self.saw_completion = true;
                        StructuredRecordKind::Completion
                    };
                    let request_id = evt
                        .request_id
                        .or_else(|| self.inner.request_id().map(|s| s.to_string()));
                    let complete_fields: std::collections::HashSet<String> = value
                        .get("complete_fields")
                        .and_then(|v| v.as_array())
                        .map(|arr| {
                            arr.iter()
                                .filter_map(|v| v.as_str().map(|s| s.to_string()))
                                .collect()
                        })
                        .unwrap_or_default();
                    return Ok(Some(StructuredJSONEvent {
                        kind,
                        payload,
                        request_id,
                        complete_fields,
                    }));
                }
                "error" => {
                    // An error record is terminal: later calls must not read past it.
                    self.saw_completion = true;
                    self.finished = true;
                    let code = value
                        .get("code")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string());
                    let message = value
                        .get("message")
                        .and_then(|v| v.as_str())
                        .unwrap_or("structured stream error")
                        .to_string();
                    // Out-of-range statuses fall back to 500 instead of being
                    // silently truncated by an `as u16` cast.
                    let status = value
                        .get("status")
                        .and_then(|v| v.as_u64())
                        .and_then(|v| u16::try_from(v).ok())
                        .unwrap_or(500);
                    let request_id = evt
                        .request_id
                        .or_else(|| self.inner.request_id().map(|s| s.to_string()));
                    return Err(APIError {
                        status,
                        code,
                        message,
                        request_id,
                        fields: Vec::new(),
                        retries: None,
                        raw_body: None,
                    }
                    .into());
                }
                _ => continue,
            }
        }

        self.finished = true;
        if !self.saw_completion {
            return Err(Error::Transport(TransportError {
                kind: TransportErrorKind::Request,
                message: "structured stream ended without completion or error".to_string(),
                source: None,
                retries: None,
            }));
        }
        Ok(None)
    }

    /// Drains the stream and returns the `completion` payload.
    ///
    /// If the loop ends without a completion, the last `update` payload is
    /// returned instead. If there was none either, a transport error is returned.
    ///
    /// # Errors
    ///
    /// Any error from `next`.
    pub async fn collect(mut self) -> Result<T> {
        let mut last: Option<T> = None;
        while let Some(event) = self.next().await? {
            if matches!(event.kind, StructuredRecordKind::Completion) {
                return Ok(event.payload);
            }
            last = Some(event.payload);
        }
        match last {
            Some(payload) => Ok(payload),
            None => Err(Error::Transport(TransportError {
                kind: TransportErrorKind::Request,
                message: "structured stream ended without completion or error".to_string(),
                source: None,
                retries: None,
            })),
        }
    }

    /// Request id reported by the underlying stream handle, if any.
    pub fn request_id(&self) -> Option<&str> {
        self.inner.request_id()
    }
}
889
#[cfg(all(feature = "blocking", feature = "streaming"))]
/// Blocking reader of structured-output records over a proxy event stream.
///
/// Decodes `update` / `completion` payloads into `T` and maps `error` records
/// to API errors.
pub struct BlockingStructuredJSONStream<T> {
    /// Underlying blocking event stream.
    inner: BlockingProxyHandle,
    /// Once set, `next` returns `Ok(None)` without reading from `inner`.
    finished: bool,
    /// Set once a `completion` or `error` record has been seen. If the stream
    /// ends while this is still false, `next` reports a protocol error.
    saw_completion: bool,
    /// Ties the stream to its payload type without storing a `T`.
    _marker: std::marker::PhantomData<T>,
}
898
#[cfg(all(feature = "blocking", feature = "streaming"))]
impl<T> BlockingStructuredJSONStream<T>
where
    // Fully qualified: the file-level `DeserializeOwned` import is only compiled
    // with `client`, but this impl exists for `blocking` + `streaming` alone.
    T: serde::de::DeserializeOwned,
{
    /// Wraps a blocking stream handle whose events carry structured-output records.
    pub fn new(stream: BlockingProxyHandle) -> Self {
        Self {
            inner: stream,
            finished: false,
            saw_completion: false,
            _marker: std::marker::PhantomData,
        }
    }

    /// Returns the next `update` or `completion` record, decoded into `T`.
    ///
    /// The following are skipped:
    /// - events whose `data` is not a JSON object;
    /// - records typed `start`;
    /// - records with no `type`;
    /// - records of an unknown type.
    ///
    /// The `type` is trimmed and compared case-insensitively. Returns `Ok(None)`
    /// once the stream is finished.
    ///
    /// # Errors
    ///
    /// - Errors from the underlying handle.
    /// - `Error::Transport` when an `update`/`completion` record has no `payload`.
    /// - `Error::Serialization` when the payload does not deserialize into `T`.
    /// - An API error built from an `error` record. The status defaults to 500
    ///   when it is missing or does not fit in a `u16`. The stream is finished afterwards.
    /// - `Error::Transport` when the stream ends without a `completion` or `error` record.
    // Deliberately not `Iterator::next`: this returns `Result<Option<_>>`.
    #[allow(clippy::should_implement_trait)]
    pub fn next(&mut self) -> Result<Option<StructuredJSONEvent<T>>> {
        if self.finished {
            return Ok(None);
        }

        while let Some(evt) = self.inner.next()? {
            let value = match evt.data {
                Some(ref v) if v.is_object() => v,
                _ => continue,
            };
            let record_type = value
                .get("type")
                .and_then(|v| v.as_str())
                .map(|s| s.trim().to_lowercase())
                .unwrap_or_default();

            match record_type.as_str() {
                "" | "start" => continue,
                "update" | "completion" => {
                    let payload_value = value.get("payload").cloned().ok_or_else(|| {
                        Error::Transport(TransportError {
                            kind: TransportErrorKind::Request,
                            message: "structured stream record missing payload".to_string(),
                            source: None,
                            retries: None,
                        })
                    })?;
                    let payload: T =
                        serde_json::from_value(payload_value).map_err(Error::Serialization)?;
                    let kind = if record_type == "update" {
                        StructuredRecordKind::Update
                    } else {
                        self.saw_completion = true;
                        StructuredRecordKind::Completion
                    };
                    let request_id = evt
                        .request_id
                        .or_else(|| self.inner.request_id().map(|s| s.to_string()));
                    let complete_fields: std::collections::HashSet<String> = value
                        .get("complete_fields")
                        .and_then(|v| v.as_array())
                        .map(|arr| {
                            arr.iter()
                                .filter_map(|v| v.as_str().map(|s| s.to_string()))
                                .collect()
                        })
                        .unwrap_or_default();
                    return Ok(Some(StructuredJSONEvent {
                        kind,
                        payload,
                        request_id,
                        complete_fields,
                    }));
                }
                "error" => {
                    // An error record is terminal: later calls must not read past it.
                    self.saw_completion = true;
                    self.finished = true;
                    let code = value
                        .get("code")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string());
                    let message = value
                        .get("message")
                        .and_then(|v| v.as_str())
                        .unwrap_or("structured stream error")
                        .to_string();
                    // Out-of-range statuses fall back to 500 instead of being
                    // silently truncated by an `as u16` cast.
                    let status = value
                        .get("status")
                        .and_then(|v| v.as_u64())
                        .and_then(|v| u16::try_from(v).ok())
                        .unwrap_or(500);
                    let request_id = evt
                        .request_id
                        .or_else(|| self.inner.request_id().map(|s| s.to_string()));
                    return Err(APIError {
                        status,
                        code,
                        message,
                        request_id,
                        fields: Vec::new(),
                        retries: None,
                        raw_body: None,
                    }
                    .into());
                }
                _ => continue,
            }
        }

        self.finished = true;
        if !self.saw_completion {
            return Err(Error::Transport(TransportError {
                kind: TransportErrorKind::Request,
                message: "structured stream ended without completion or error".to_string(),
                source: None,
                retries: None,
            }));
        }
        Ok(None)
    }

    /// Drains the stream and returns the `completion` payload.
    ///
    /// If the loop ends without a completion, the last `update` payload is
    /// returned instead. If there was none either, a transport error is returned.
    ///
    /// # Errors
    ///
    /// Any error from `next`.
    pub fn collect(mut self) -> Result<T> {
        let mut last: Option<T> = None;
        while let Some(event) = self.next()? {
            if matches!(event.kind, StructuredRecordKind::Completion) {
                return Ok(event.payload);
            }
            last = Some(event.payload);
        }
        match last {
            Some(payload) => Ok(payload),
            None => Err(Error::Transport(TransportError {
                kind: TransportErrorKind::Request,
                message: "structured stream ended without completion or error".to_string(),
                source: None,
                retries: None,
            })),
        }
    }

    /// Request id reported by the underlying stream handle, if any.
    pub fn request_id(&self) -> Option<&str> {
        self.inner.request_id()
    }
}
1040
#[cfg(all(feature = "blocking", feature = "streaming"))]
impl ChatStreamAdapter<BlockingProxyHandle> {
    /// Wraps a blocking stream handle. Nothing is read until `next_delta` is called.
    pub fn new(stream: BlockingProxyHandle) -> Self {
        Self {
            inner: stream,
            finished: false,
            final_usage: None,
            final_stop_reason: None,
            final_request_id: None,
        }
    }

    /// Request id reported by the underlying stream handle, if any.
    pub fn request_id(&self) -> Option<&str> {
        self.inner.request_id()
    }

    /// Returns the next text delta from a `MessageDelta` event.
    ///
    /// `MessageDelta` events without text, and all event kinds other than
    /// `MessageDelta` / `MessageStop`, are skipped.
    ///
    /// On `MessageStop` the usage, stop reason and request id are recorded and
    /// `Ok(None)` is returned. The request id falls back to the handle's id.
    /// After that, later calls return `Ok(None)` without reading the handle again.
    ///
    /// `Ok(None)` is also returned when the stream ends without a stop event.
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying handle.
    pub fn next_delta(&mut self) -> Result<Option<String>> {
        // The stream is logically complete after `MessageStop`; don't read past it.
        if self.finished {
            return Ok(None);
        }

        while let Some(evt) = self.inner.next()? {
            match evt.kind {
                StreamEventKind::MessageDelta => {
                    if let Some(delta) = evt.text_delta {
                        return Ok(Some(delta));
                    }
                }
                StreamEventKind::MessageStop => {
                    self.finished = true;
                    self.final_usage = evt.usage;
                    self.final_stop_reason = evt.stop_reason;
                    self.final_request_id = evt
                        .request_id
                        .or_else(|| self.inner.request_id().map(|s| s.to_string()));
                    return Ok(None);
                }
                _ => {}
            }
        }
        Ok(None)
    }

    /// Usage reported by the `MessageStop` event, once it has been consumed.
    pub fn final_usage(&self) -> Option<&Usage> {
        self.final_usage.as_ref()
    }

    /// Stop reason reported by the `MessageStop` event, once it has been consumed.
    pub fn final_stop_reason(&self) -> Option<&StopReason> {
        self.final_stop_reason.as_ref()
    }

    /// Request id captured at `MessageStop`: the event's own id, or else the handle's id.
    pub fn final_request_id(&self) -> Option<&str> {
        self.final_request_id.as_deref()
    }

    /// Converts the adapter into an iterator of text deltas.
    ///
    /// Iteration ends when `next_delta` returns `Ok(None)`. Errors are yielded
    /// as items, and iteration continues afterwards.
    // Returns an opaque iterator rather than implementing `IntoIterator`, so
    // the name clashes with the trait method by design.
    #[allow(clippy::should_implement_trait)]
    pub fn into_iter(self) -> impl Iterator<Item = Result<String>> {
        let mut adapter = self;
        std::iter::from_fn(move || match adapter.next_delta() {
            Ok(Some(delta)) => Some(Ok(delta)),
            Ok(None) => None,
            Err(err) => Some(Err(err)),
        })
    }
}
1104
1105#[cfg(test)]
1106mod tests {
1107 use super::*;
1108 use crate::types::{Model, ResponseFormatKind, StreamEvent, StreamEventKind};
1109 use crate::ClientBuilder;
1110
1111 #[test]
1112 fn build_request_requires_user_message() {
1113 let builder = ChatRequestBuilder::new(Model::from("gpt-4o-mini")).system("just a system");
1114 let err = builder.build_request().unwrap_err();
1115 match err {
1116 Error::Validation(msg) => {
1117 assert!(
1118 msg.to_string().contains("user"),
1119 "unexpected validation: {msg}"
1120 );
1121 }
1122 other => panic!("expected validation error, got {other:?}"),
1123 }
1124 }
1125
1126 #[test]
1127 fn metadata_entry_ignores_empty_pairs() {
1128 let req = ChatRequestBuilder::new(Model::from("gpt-4o-mini"))
1129 .user("hello")
1130 .metadata_entry("trace_id", "abc123")
1131 .metadata_entry("", "should_skip")
1132 .metadata_entry("empty", "")
1133 .build_request()
1134 .unwrap();
1135 let meta = req.metadata.unwrap();
1136 assert_eq!(meta.len(), 1);
1137 assert_eq!(meta.get("trace_id"), Some(&"abc123".to_string()));
1138 }
1139
1140 #[test]
1141 fn role_helpers_append_expected_roles() {
1142 use crate::types::MessageRole;
1143 let req = ChatRequestBuilder::new("gpt-4o-mini")
1144 .system("sys")
1145 .user("u1")
1146 .assistant("a1")
1147 .build_request()
1148 .unwrap();
1149 let roles: Vec<_> = req.messages.iter().map(|m| m.role).collect();
1150 assert_eq!(
1151 roles,
1152 vec![
1153 MessageRole::System,
1154 MessageRole::User,
1155 MessageRole::Assistant
1156 ]
1157 );
1158 }
1159
1160 #[cfg(all(feature = "client", feature = "streaming"))]
1161 #[tokio::test]
1162 async fn stream_json_requires_structured_response_format() {
1163 let client = ClientBuilder::new()
1164 .api_key("mr_sk_test")
1165 .build()
1166 .expect("client build");
1167
1168 let builder = ChatRequestBuilder::new(Model::from("gpt-4o-mini")).user("hi");
1170 let result = builder
1171 .clone()
1172 .stream_json::<serde_json::Value>(&client.llm())
1173 .await;
1174 match result {
1175 Err(Error::Validation(v)) => {
1176 assert!(
1177 v.to_string().contains("response_format"),
1178 "unexpected validation error: {v}"
1179 );
1180 }
1181 Ok(_) => panic!("expected Validation error, got Ok"),
1182 Err(other) => panic!("expected Validation error, got {other:?}"),
1183 }
1184
1185 let format = ResponseFormat {
1187 kind: ResponseFormatKind::Text,
1188 json_schema: None,
1189 };
1190 let builder = ChatRequestBuilder::new(Model::from("gpt-4o-mini"))
1191 .user("hi")
1192 .response_format(format);
1193 let result = builder
1194 .stream_json::<serde_json::Value>(&client.llm())
1195 .await;
1196 match result {
1197 Err(Error::Validation(v)) => {
1198 assert!(
1199 v.to_string().contains("response_format must be structured"),
1200 "unexpected validation error: {v}"
1201 );
1202 }
1203 Ok(_) => panic!("expected Validation error, got Ok"),
1204 Err(other) => panic!("expected Validation error, got {other:?}"),
1205 }
1206 }
1207
1208 #[cfg(all(feature = "client", feature = "streaming"))]
1209 #[tokio::test]
1210 async fn structured_json_stream_yields_update_and_completion() {
1211 #[derive(Debug, serde::Deserialize, PartialEq)]
1212 struct Item {
1213 id: String,
1214 }
1215
1216 #[derive(Debug, serde::Deserialize, PartialEq)]
1217 struct ItemsPayload {
1218 items: Vec<Item>,
1219 }
1220
1221 let events = vec![
1222 StreamEvent {
1223 kind: StreamEventKind::Custom,
1224 event: "structured".into(),
1225 data: Some(serde_json::json!({"type":"start","request_id":"tiers-1"})),
1226 text_delta: None,
1227 tool_call_delta: None,
1228 tool_calls: None,
1229 response_id: None,
1230 model: None,
1231 stop_reason: None,
1232 usage: None,
1233 request_id: None,
1234 raw: String::new(),
1235 },
1236 StreamEvent {
1237 kind: StreamEventKind::Custom,
1238 event: "structured".into(),
1239 data: Some(serde_json::json!({"type":"update","payload":{"items":[{"id":"one"}]}})),
1240 text_delta: None,
1241 tool_call_delta: None,
1242 tool_calls: None,
1243 response_id: None,
1244 model: None,
1245 stop_reason: None,
1246 usage: None,
1247 request_id: None,
1248 raw: String::new(),
1249 },
1250 StreamEvent {
1251 kind: StreamEventKind::Custom,
1252 event: "structured".into(),
1253 data: Some(
1254 serde_json::json!({"type":"completion","payload":{"items":[{"id":"one"},{"id":"two"}]}}),
1255 ),
1256 text_delta: None,
1257 tool_call_delta: None,
1258 tool_calls: None,
1259 response_id: None,
1260 model: None,
1261 stop_reason: None,
1262 usage: None,
1263 request_id: None,
1264 raw: String::new(),
1265 },
1266 ];
1267
1268 let handle = StreamHandle::from_events_with_request_id(
1269 events.clone(),
1270 Some("req-structured".into()),
1271 );
1272 let mut stream = StructuredJSONStream::<ItemsPayload>::new(handle);
1273
1274 let first = stream.next().await.unwrap().unwrap();
1275 assert_eq!(first.kind, StructuredRecordKind::Update);
1276 assert_eq!(first.payload.items.len(), 1);
1277 assert_eq!(first.payload.items[0].id, "one");
1278
1279 let second = stream.next().await.unwrap().unwrap();
1280 assert_eq!(second.kind, StructuredRecordKind::Completion);
1281 assert_eq!(second.payload.items.len(), 2);
1282 assert_eq!(second.request_id.as_deref(), Some("req-structured"));
1283
1284 let handle2 =
1285 StreamHandle::from_events_with_request_id(events, Some("req-structured".into()));
1286 let stream2 = StructuredJSONStream::<ItemsPayload>::new(handle2);
1287 let collected = stream2.collect().await.unwrap();
1288 assert_eq!(collected.items.len(), 2);
1289 }
1290
1291 #[cfg(all(feature = "client", feature = "streaming"))]
1292 #[tokio::test]
1293 async fn structured_json_stream_maps_error_and_protocol_violation() {
1294 let error_events = vec![StreamEvent {
1296 kind: StreamEventKind::Custom,
1297 event: "structured".into(),
1298 data: Some(
1299 serde_json::json!({"type":"error","code":"SERVICE_UNAVAILABLE","message":"upstream timeout","status":502}),
1300 ),
1301 text_delta: None,
1302 tool_call_delta: None,
1303 tool_calls: None,
1304 response_id: None,
1305 model: None,
1306 stop_reason: None,
1307 usage: None,
1308 request_id: None,
1309 raw: String::new(),
1310 }];
1311 let handle_err =
1312 StreamHandle::from_events_with_request_id(error_events, Some("req-error".into()));
1313 let mut err_stream = StructuredJSONStream::<serde_json::Value>::new(handle_err);
1314 let err = err_stream.next().await.unwrap_err();
1315 match err {
1316 Error::Api(api) => {
1317 assert_eq!(api.status, 502);
1318 assert_eq!(api.code.as_deref(), Some("SERVICE_UNAVAILABLE"));
1319 assert_eq!(api.request_id.as_deref(), Some("req-error"));
1320 }
1321 other => panic!("expected API error, got {other:?}"),
1322 }
1323
1324 let update_only = vec![StreamEvent {
1326 kind: StreamEventKind::Custom,
1327 event: "structured".into(),
1328 data: Some(serde_json::json!({"type":"update","payload":{"items":[{"id":"one"}]}})),
1329 text_delta: None,
1330 tool_call_delta: None,
1331 tool_calls: None,
1332 response_id: None,
1333 model: None,
1334 stop_reason: None,
1335 usage: None,
1336 request_id: None,
1337 raw: String::new(),
1338 }];
1339 let handle_proto =
1340 StreamHandle::from_events_with_request_id(update_only, Some("req-incomplete".into()));
1341 let stream_proto = StructuredJSONStream::<serde_json::Value>::new(handle_proto);
1342 let err = stream_proto.collect().await.unwrap_err();
1343 match err {
1344 Error::Transport(te) => {
1345 assert!(
1346 te.message
1347 .contains("structured stream ended without completion or error"),
1348 "unexpected message: {}",
1349 te.message
1350 );
1351 }
1352 other => panic!("expected Transport error, got {other:?}"),
1353 }
1354 }
1355
1356 #[cfg(all(feature = "client", feature = "streaming"))]
1357 #[tokio::test]
1358 async fn customer_stream_json_requires_structured_response_format() {
1359 let client = ClientBuilder::new()
1360 .api_key("mr_sk_test")
1361 .build()
1362 .expect("client build");
1363
1364 let builder = CustomerChatRequestBuilder::new("customer-123").user("hi");
1366 let result = builder
1367 .clone()
1368 .stream_json::<serde_json::Value>(&client.llm())
1369 .await;
1370 match result {
1371 Err(Error::Validation(v)) => {
1372 assert!(
1373 v.to_string().contains("response_format"),
1374 "unexpected validation error: {v}"
1375 );
1376 }
1377 Ok(_) => panic!("expected Validation error, got Ok"),
1378 Err(other) => panic!("expected Validation error, got {other:?}"),
1379 }
1380
1381 let format = ResponseFormat {
1383 kind: ResponseFormatKind::Text,
1384 json_schema: None,
1385 };
1386 let builder = CustomerChatRequestBuilder::new("customer-123")
1387 .user("hi")
1388 .response_format(format);
1389 let result = builder
1390 .stream_json::<serde_json::Value>(&client.llm())
1391 .await;
1392 match result {
1393 Err(Error::Validation(v)) => {
1394 assert!(
1395 v.to_string().contains("response_format must be structured"),
1396 "unexpected validation error: {v}"
1397 );
1398 }
1399 Ok(_) => panic!("expected Validation error, got Ok"),
1400 Err(other) => panic!("expected Validation error, got {other:?}"),
1401 }
1402 }
1403
1404 #[test]
1405 fn customer_build_request_body_requires_user_message() {
1406 let builder = CustomerChatRequestBuilder::new("customer-123").system("just a system");
1407 let err = builder.build_request_body().unwrap_err();
1408 match err {
1409 Error::Validation(msg) => {
1410 assert!(
1411 msg.to_string().contains("user"),
1412 "unexpected validation: {msg}"
1413 );
1414 }
1415 other => panic!("expected validation error, got {other:?}"),
1416 }
1417 }
1418
1419 #[test]
1420 fn customer_metadata_entry_ignores_empty_pairs() {
1421 let body = CustomerChatRequestBuilder::new("customer-123")
1422 .user("hello")
1423 .metadata_entry("trace_id", "abc123")
1424 .metadata_entry("", "should_skip")
1425 .metadata_entry("empty", "")
1426 .build_request_body()
1427 .unwrap();
1428 let meta = body.metadata.unwrap();
1429 assert_eq!(meta.len(), 1);
1430 assert_eq!(meta.get("trace_id"), Some(&"abc123".to_string()));
1431 }
1432}