modelrelay/
chat.rs

1#![cfg(any(feature = "client", feature = "blocking"))]
2
3use std::collections::HashMap;
4use std::time::Duration;
5
6use crate::errors::{APIError, Error, Result, TransportError, TransportErrorKind, ValidationError};
7#[cfg(feature = "streaming")]
8use crate::types::StreamEventKind;
9use crate::types::{
10    Model, ProxyMessage, ProxyRequest, ProxyResponse, ResponseFormat, StopReason, Usage,
11};
12
13#[cfg(feature = "blocking")]
14use crate::blocking::BlockingLLMClient;
15#[cfg(all(feature = "blocking", feature = "streaming"))]
16use crate::blocking::BlockingProxyHandle;
17#[cfg(feature = "client")]
18use crate::client::LLMClient;
19#[cfg(all(feature = "client", feature = "streaming"))]
20use crate::sse::StreamHandle;
21
22#[cfg(any(feature = "client", feature = "blocking"))]
23use crate::{ProxyOptions, RetryConfig};
24#[cfg(all(feature = "client", feature = "streaming"))]
25use futures_util::stream;
26use schemars::JsonSchema;
27#[cfg(all(feature = "client", feature = "streaming"))]
28use serde::de::DeserializeOwned;
29
/// Implements the builder methods shared by the chat request builders.
///
/// `ChatRequestBuilder` and `CustomerChatRequestBuilder` expose identical
/// methods for setting messages, sampling parameters, and per-request options;
/// this macro generates them for any builder type that has the fields these
/// methods assign (`messages`, `max_tokens`, `temperature`, `metadata`,
/// `response_format`, `stop`, `tools`, `tool_choice`, `request_id`, `headers`,
/// `timeout`, `retry`).
///
/// Every generated method consumes the builder and returns it, so each one is
/// `#[must_use]`: discarding the result silently drops the builder together
/// with the setting that was just applied.
///
/// `ProxyMessage`, `HashMap`, `ResponseFormat`, `Duration`, and `RetryConfig`
/// are referenced unqualified, so they must be in scope where the macro is
/// invoked.
macro_rules! impl_chat_builder_common {
    ($builder:ty) => {
        impl $builder {
            /// Append a message with the given role and content.
            ///
            /// The appended message carries no tool calls and no tool call id.
            #[must_use]
            pub fn message(
                mut self,
                role: crate::types::MessageRole,
                content: impl Into<String>,
            ) -> Self {
                self.messages.push(ProxyMessage {
                    role,
                    content: content.into(),
                    tool_calls: None,
                    tool_call_id: None,
                });
                self
            }

            /// Append a system message.
            #[must_use]
            pub fn system(self, content: impl Into<String>) -> Self {
                self.message(crate::types::MessageRole::System, content)
            }

            /// Append a user message.
            #[must_use]
            pub fn user(self, content: impl Into<String>) -> Self {
                self.message(crate::types::MessageRole::User, content)
            }

            /// Append an assistant message.
            #[must_use]
            pub fn assistant(self, content: impl Into<String>) -> Self {
                self.message(crate::types::MessageRole::Assistant, content)
            }

            /// Set the full message list, replacing any existing messages.
            #[must_use]
            pub fn messages(mut self, messages: Vec<ProxyMessage>) -> Self {
                self.messages = messages;
                self
            }

            /// Set the maximum number of tokens to generate.
            #[must_use]
            pub fn max_tokens(mut self, max_tokens: i64) -> Self {
                self.max_tokens = Some(max_tokens);
                self
            }

            /// Set the sampling temperature.
            #[must_use]
            pub fn temperature(mut self, temperature: f64) -> Self {
                self.temperature = Some(temperature);
                self
            }

            /// Set request metadata, replacing any previously set map.
            #[must_use]
            pub fn metadata(mut self, metadata: HashMap<String, String>) -> Self {
                self.metadata = Some(metadata);
                self
            }

            /// Add a single metadata entry.
            ///
            /// Entries whose key or value is empty or whitespace-only are ignored
            /// and the builder is returned unchanged. Key and value are stored as
            /// given (untrimmed); an existing entry with the same key is replaced.
            #[must_use]
            pub fn metadata_entry(
                mut self,
                key: impl Into<String>,
                value: impl Into<String>,
            ) -> Self {
                let key = key.into();
                let value = value.into();
                if key.trim().is_empty() || value.trim().is_empty() {
                    return self;
                }
                self.metadata.get_or_insert_with(HashMap::new).insert(key, value);
                self
            }

            /// Set the response format (e.g., JSON schema for structured outputs).
            #[must_use]
            pub fn response_format(mut self, response_format: ResponseFormat) -> Self {
                self.response_format = Some(response_format);
                self
            }

            /// Set stop sequences, replacing any previously set list.
            #[must_use]
            pub fn stop(mut self, stop: Vec<String>) -> Self {
                self.stop = Some(stop);
                self
            }

            /// Set tools available for the model to call.
            #[must_use]
            pub fn tools(mut self, tools: Vec<crate::types::Tool>) -> Self {
                self.tools = Some(tools);
                self
            }

            /// Set the tool choice strategy.
            #[must_use]
            pub fn tool_choice(mut self, tool_choice: crate::types::ToolChoice) -> Self {
                self.tool_choice = Some(tool_choice);
                self
            }

            /// Set a request ID for tracing.
            #[must_use]
            pub fn request_id(mut self, request_id: impl Into<String>) -> Self {
                self.request_id = Some(request_id.into());
                self
            }

            /// Add a custom header to the request.
            ///
            /// The builder keeps every pair in insertion order, repeated keys
            /// included; each pair is later handed to the request options.
            #[must_use]
            pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
                self.headers.push((key.into(), value.into()));
                self
            }

            /// Set the request timeout.
            #[must_use]
            pub fn timeout(mut self, timeout: Duration) -> Self {
                self.timeout = Some(timeout);
                self
            }

            /// Set retry configuration.
            #[must_use]
            pub fn retry(mut self, retry: RetryConfig) -> Self {
                self.retry = Some(retry);
                self
            }
        }
    };
}
159
/// Builder for LLM proxy chat requests (async + streaming).
///
/// Start with [`ChatRequestBuilder::new`], chain the shared setters (messages,
/// sampling parameters, per-request options), then execute with `send` /
/// `stream` (async) or `send_blocking` / `stream_blocking` (blocking),
/// depending on the enabled features. A `Default` builder has no model, so
/// `build_request` rejects it.
#[derive(Clone, Debug, Default)]
pub struct ChatRequestBuilder {
    /// Target model; `build_request` fails when this is `None`.
    pub(crate) model: Option<Model>,
    /// Maximum number of tokens to generate.
    pub(crate) max_tokens: Option<i64>,
    /// Sampling temperature.
    pub(crate) temperature: Option<f64>,
    /// Conversation messages; `build_request` requires at least one user message.
    pub(crate) messages: Vec<ProxyMessage>,
    /// Request metadata copied into the request body.
    pub(crate) metadata: Option<HashMap<String, String>>,
    /// Response format (e.g. a JSON schema for structured output).
    pub(crate) response_format: Option<ResponseFormat>,
    /// Stop sequences.
    pub(crate) stop: Option<Vec<String>>,
    /// Tools the model may call.
    pub(crate) tools: Option<Vec<crate::types::Tool>>,
    /// Tool choice strategy.
    pub(crate) tool_choice: Option<crate::types::ToolChoice>,
    /// Tracing request id; sent through `ProxyOptions`, not the request body.
    pub(crate) request_id: Option<String>,
    /// Extra headers, applied in insertion order through `ProxyOptions`.
    pub(crate) headers: Vec<(String, String)>,
    /// Request timeout, passed to `ProxyOptions::with_timeout`.
    pub(crate) timeout: Option<Duration>,
    /// Retry configuration, passed to `ProxyOptions::with_retry`.
    pub(crate) retry: Option<RetryConfig>,
}

// Generate shared builder methods for ChatRequestBuilder
impl_chat_builder_common!(ChatRequestBuilder);
180
181impl ChatRequestBuilder {
182    /// Create a new chat request builder for the given model.
183    pub fn new(model: impl Into<Model>) -> Self {
184        Self {
185            model: Some(model.into()),
186            ..Default::default()
187        }
188    }
189
190    fn build_options(&self) -> ProxyOptions {
191        let mut opts = ProxyOptions::default();
192        if let Some(req_id) = &self.request_id {
193            opts = opts.with_request_id(req_id.clone());
194        }
195        for (k, v) in &self.headers {
196            opts = opts.with_header(k.clone(), v.clone());
197        }
198        if let Some(timeout) = self.timeout {
199            opts = opts.with_timeout(timeout);
200        }
201        if let Some(retry) = &self.retry {
202            opts = opts.with_retry(retry.clone());
203        }
204        opts
205    }
206
207    pub fn build_request(&self) -> Result<ProxyRequest> {
208        let model = self
209            .model
210            .clone()
211            .ok_or_else(|| Error::Validation("model is required".into()))?;
212
213        if self.messages.is_empty() {
214            return Err(Error::Validation(
215                ValidationError::new("at least one message is required").with_field("messages"),
216            ));
217        }
218        if !self
219            .messages
220            .iter()
221            .any(|msg| msg.role == crate::types::MessageRole::User)
222        {
223            return Err(Error::Validation(
224                ValidationError::new("at least one user message is required")
225                    .with_field("messages"),
226            ));
227        }
228
229        let req = ProxyRequest {
230            model,
231            max_tokens: self.max_tokens,
232            temperature: self.temperature,
233            messages: self.messages.clone(),
234            metadata: self.metadata.clone(),
235            response_format: self.response_format.clone(),
236            stop: self.stop.clone(),
237            tools: self.tools.clone(),
238            tool_choice: self.tool_choice.clone(),
239        };
240        req.validate()?;
241        Ok(req)
242    }
243
244    /// Execute the chat request (non-streaming, async).
245    #[cfg(feature = "client")]
246    pub async fn send(self, client: &LLMClient) -> Result<ProxyResponse> {
247        let req = self.build_request()?;
248        let opts = self.build_options();
249        client.proxy(req, opts).await
250    }
251
252    /// Execute the chat request and stream responses (async).
253    #[cfg(all(feature = "client", feature = "streaming"))]
254    pub async fn stream(self, client: &LLMClient) -> Result<StreamHandle> {
255        let req = self.build_request()?;
256        let opts = self.build_options();
257        client.proxy_stream(req, opts).await
258    }
259
260    /// Execute the chat request and stream text deltas (async).
261    #[cfg(all(feature = "client", feature = "streaming"))]
262    pub async fn stream_deltas(
263        self,
264        client: &LLMClient,
265    ) -> Result<std::pin::Pin<Box<dyn futures_core::Stream<Item = Result<String>> + Send>>> {
266        let req = self.build_request()?;
267        let opts = self.build_options();
268        client.proxy_stream_deltas(req, opts).await
269    }
270
271    /// Create a structured output builder with automatic schema generation.
272    ///
273    /// This method transitions the builder to a [`StructuredChatBuilder`] that
274    /// automatically generates a JSON schema from the type `T` and handles
275    /// validation retries.
276    ///
277    /// # Example
278    ///
279    /// ```ignore
280    /// use schemars::JsonSchema;
281    /// use serde::{Deserialize, Serialize};
282    ///
283    /// #[derive(Debug, Serialize, Deserialize, JsonSchema)]
284    /// struct Person {
285    ///     name: String,
286    ///     age: u32,
287    /// }
288    ///
289    /// let result = client.chat()
290    ///     .model("claude-sonnet-4-20250514")
291    ///     .user("Extract: John Doe, 30 years old")
292    ///     .structured::<Person>()
293    ///     .max_retries(2)
294    ///     .send()
295    ///     .await?;
296    /// ```
297    #[cfg(feature = "client")]
298    pub fn structured<T>(self) -> crate::structured::StructuredChatBuilder<T>
299    where
300        T: JsonSchema + DeserializeOwned,
301    {
302        crate::structured::StructuredChatBuilder::new(self)
303    }
304
305    /// Execute the chat request and stream structured JSON payloads (async).
306    ///
307    /// The request must include a structured response_format (type=json_schema),
308    /// and uses NDJSON framing per the /llm/proxy structured streaming contract.
309    #[cfg(all(feature = "client", feature = "streaming"))]
310    pub async fn stream_json<T>(self, client: &LLMClient) -> Result<StructuredJSONStream<T>>
311    where
312        T: DeserializeOwned,
313    {
314        let req = self.build_request()?;
315        match &req.response_format {
316            Some(format) if format.is_structured() => {}
317            Some(_) => {
318                return Err(Error::Validation(
319                    ValidationError::new("response_format must be structured (type=json_schema)")
320                        .with_field("response_format.type"),
321                ));
322            }
323            None => {
324                return Err(Error::Validation(
325                    ValidationError::new("response_format is required for structured streaming")
326                        .with_field("response_format"),
327                ));
328            }
329        }
330        let opts = self.build_options();
331        let stream = client.proxy_stream(req, opts).await?;
332        Ok(StructuredJSONStream::new(stream))
333    }
334
335    /// Execute the chat request (blocking).
336    #[cfg(feature = "blocking")]
337    pub fn send_blocking(self, client: &BlockingLLMClient) -> Result<ProxyResponse> {
338        let req = self.build_request()?;
339        let opts = self.build_options();
340        client.proxy(req, opts)
341    }
342
343    /// Execute the chat request and stream responses (blocking).
344    #[cfg(all(feature = "blocking", feature = "streaming"))]
345    pub fn stream_blocking(self, client: &BlockingLLMClient) -> Result<BlockingProxyHandle> {
346        let req = self.build_request()?;
347        let opts = self.build_options();
348        client.proxy_stream(req, opts)
349    }
350
351    /// Execute the chat request and stream text deltas (blocking).
352    #[cfg(all(feature = "blocking", feature = "streaming"))]
353    pub fn stream_deltas_blocking(
354        self,
355        client: &BlockingLLMClient,
356    ) -> Result<Box<dyn Iterator<Item = Result<String>>>> {
357        let req = self.build_request()?;
358        let opts = self.build_options();
359        client.proxy_stream_deltas(req, opts)
360    }
361
362    /// Execute the chat request and stream structured JSON payloads (blocking).
363    ///
364    /// The request must include a structured response_format (type=json_schema),
365    /// and uses NDJSON framing per the /llm/proxy structured streaming contract.
366    #[cfg(all(feature = "blocking", feature = "streaming"))]
367    pub fn stream_json_blocking<T>(
368        self,
369        client: &BlockingLLMClient,
370    ) -> Result<BlockingStructuredJSONStream<T>>
371    where
372        T: DeserializeOwned,
373    {
374        let req = self.build_request()?;
375        match &req.response_format {
376            Some(format) if format.is_structured() => {}
377            Some(_) => {
378                return Err(Error::Validation(
379                    ValidationError::new("response_format must be structured (type=json_schema)")
380                        .with_field("response_format.type"),
381                ));
382            }
383            None => {
384                return Err(Error::Validation(
385                    ValidationError::new("response_format is required for structured streaming")
386                        .with_field("response_format"),
387                ));
388            }
389        }
390        let opts = self.build_options();
391        let stream = client.proxy_stream(req, opts)?;
392        Ok(BlockingStructuredJSONStream::new(stream))
393    }
394}
395
/// Header name for customer ID attribution.
pub const CUSTOMER_ID_HEADER: &str = "X-ModelRelay-Customer-Id";

/// Builder for customer-attributed LLM proxy chat requests.
///
/// Unlike [`ChatRequestBuilder`], this builder does not require a model since
/// the customer's tier determines which model to use. Create via
/// [`LLMClient::for_customer`].
#[derive(Clone, Debug, Default)]
pub struct CustomerChatRequestBuilder {
    /// Customer the request is attributed to. It is passed to the client's
    /// `proxy_customer*` methods rather than being placed in the request body.
    pub(crate) customer_id: String,
    /// Maximum number of tokens to generate.
    pub(crate) max_tokens: Option<i64>,
    /// Sampling temperature.
    pub(crate) temperature: Option<f64>,
    /// Conversation messages; at least one user message is required.
    pub(crate) messages: Vec<ProxyMessage>,
    /// Request metadata copied into the request body.
    pub(crate) metadata: Option<HashMap<String, String>>,
    /// Response format (e.g. a JSON schema for structured output).
    pub(crate) response_format: Option<ResponseFormat>,
    /// Stop sequences.
    pub(crate) stop: Option<Vec<String>>,
    /// NOTE(review): `CustomerProxyRequestBody` has no tools field, so this
    /// setting is not sent in the request body.
    pub(crate) tools: Option<Vec<crate::types::Tool>>,
    /// NOTE(review): `CustomerProxyRequestBody` has no tool_choice field, so
    /// this setting is not sent in the request body.
    pub(crate) tool_choice: Option<crate::types::ToolChoice>,
    /// Tracing request id; sent through `ProxyOptions`, not the request body.
    pub(crate) request_id: Option<String>,
    /// Extra headers, applied in insertion order through `ProxyOptions`.
    pub(crate) headers: Vec<(String, String)>,
    /// Request timeout, passed to `ProxyOptions::with_timeout`.
    pub(crate) timeout: Option<Duration>,
    /// Retry configuration, passed to `ProxyOptions::with_retry`.
    pub(crate) retry: Option<RetryConfig>,
}

// Generate shared builder methods for CustomerChatRequestBuilder
impl_chat_builder_common!(CustomerChatRequestBuilder);
423
424impl CustomerChatRequestBuilder {
425    /// Create a new customer chat builder for the given customer ID.
426    pub fn new(customer_id: impl Into<String>) -> Self {
427        Self {
428            customer_id: customer_id.into(),
429            ..Default::default()
430        }
431    }
432
433    fn build_options(&self) -> ProxyOptions {
434        let mut opts = ProxyOptions::default();
435        if let Some(req_id) = &self.request_id {
436            opts = opts.with_request_id(req_id.clone());
437        }
438        // Customer ID is passed directly to proxy_customer/proxy_customer_stream
439        for (k, v) in &self.headers {
440            opts = opts.with_header(k.clone(), v.clone());
441        }
442        if let Some(timeout) = self.timeout {
443            opts = opts.with_timeout(timeout);
444        }
445        if let Some(retry) = &self.retry {
446            opts = opts.with_retry(retry.clone());
447        }
448        opts
449    }
450
451    /// Build the request body. Uses an empty model since the tier determines it.
452    pub(crate) fn build_request_body(&self) -> Result<CustomerProxyRequestBody> {
453        if self.messages.is_empty() {
454            return Err(Error::Validation(
455                crate::errors::ValidationError::new("at least one message is required")
456                    .with_field("messages"),
457            ));
458        }
459        if !self
460            .messages
461            .iter()
462            .any(|msg| msg.role == crate::types::MessageRole::User)
463        {
464            return Err(Error::Validation(
465                crate::errors::ValidationError::new("at least one user message is required")
466                    .with_field("messages"),
467            ));
468        }
469        Ok(CustomerProxyRequestBody {
470            max_tokens: self.max_tokens,
471            temperature: self.temperature,
472            messages: self.messages.clone(),
473            metadata: self.metadata.clone(),
474            response_format: self.response_format.clone(),
475            stop: self.stop.clone(),
476        })
477    }
478
479    /// Execute the chat request (non-streaming, async).
480    #[cfg(feature = "client")]
481    pub async fn send(self, client: &LLMClient) -> Result<ProxyResponse> {
482        let body = self.build_request_body()?;
483        let opts = self.build_options();
484        client.proxy_customer(&self.customer_id, body, opts).await
485    }
486
487    /// Execute the chat request and stream responses (async).
488    #[cfg(all(feature = "client", feature = "streaming"))]
489    pub async fn stream(self, client: &LLMClient) -> Result<StreamHandle> {
490        let body = self.build_request_body()?;
491        let opts = self.build_options();
492        client
493            .proxy_customer_stream(&self.customer_id, body, opts)
494            .await
495    }
496
497    /// Execute the chat request (blocking).
498    #[cfg(feature = "blocking")]
499    pub fn send_blocking(self, client: &BlockingLLMClient) -> Result<ProxyResponse> {
500        let body = self.build_request_body()?;
501        let opts = self.build_options();
502        client.proxy_customer(&self.customer_id, body, opts)
503    }
504
505    /// Execute the chat request and stream responses (blocking).
506    #[cfg(all(feature = "blocking", feature = "streaming"))]
507    pub fn stream_blocking(self, client: &BlockingLLMClient) -> Result<BlockingProxyHandle> {
508        let body = self.build_request_body()?;
509        let opts = self.build_options();
510        client.proxy_customer_stream(&self.customer_id, body, opts)
511    }
512
513    /// Execute the chat request and stream structured JSON payloads (async).
514    ///
515    /// The request must include a structured response_format (type=json_schema),
516    /// and uses NDJSON framing per the /llm/proxy structured streaming contract.
517    ///
518    /// # Example
519    ///
520    /// ```ignore
521    /// use serde::Deserialize;
522    /// use modelrelay::CustomerChatRequestBuilder;
523    ///
524    /// #[derive(Debug, Deserialize)]
525    /// struct CommitMessage {
526    ///     title: String,
527    ///     body: Option<String>,
528    /// }
529    ///
530    /// let stream = CustomerChatRequestBuilder::new("user-123")
531    ///     .user("Generate a commit message for: ...")
532    ///     .response_format(ResponseFormat::json_schema::<CommitMessage>("CommitMessage"))
533    ///     .stream_json::<CommitMessage>(&client.llm())
534    ///     .await?;
535    ///
536    /// let result = stream.collect().await?;
537    /// println!("Title: {}", result.title);
538    /// ```
539    #[cfg(all(feature = "client", feature = "streaming"))]
540    pub async fn stream_json<T>(self, client: &LLMClient) -> Result<StructuredJSONStream<T>>
541    where
542        T: DeserializeOwned,
543    {
544        let body = self.build_request_body()?;
545        match &body.response_format {
546            Some(format) if format.is_structured() => {}
547            Some(_) => {
548                return Err(Error::Validation(
549                    ValidationError::new("response_format must be structured (type=json_schema)")
550                        .with_field("response_format.type"),
551                ));
552            }
553            None => {
554                return Err(Error::Validation(
555                    ValidationError::new("response_format is required for structured streaming")
556                        .with_field("response_format"),
557                ));
558            }
559        }
560        let opts = self.build_options();
561        let stream = client
562            .proxy_customer_stream(&self.customer_id, body, opts)
563            .await?;
564        Ok(StructuredJSONStream::new(stream))
565    }
566
567    /// Execute the chat request and stream structured JSON payloads (blocking).
568    ///
569    /// The request must include a structured response_format (type=json_schema),
570    /// and uses NDJSON framing per the /llm/proxy structured streaming contract.
571    ///
572    /// # Example
573    ///
574    /// ```ignore
575    /// use modelrelay::{CustomerChatRequestBuilder, ResponseFormat};
576    /// use serde::Deserialize;
577    ///
578    /// #[derive(Debug, Deserialize)]
579    /// struct CommitMessage {
580    ///     title: String,
581    ///     body: Option<String>,
582    /// }
583    ///
584    /// let mut stream = CustomerChatRequestBuilder::new("user-123")
585    ///     .user("Generate a commit message for: ...")
586    ///     .response_format(ResponseFormat::json_schema::<CommitMessage>("CommitMessage"))
587    ///     .stream_json_blocking::<CommitMessage>(&client.llm())?;
588    ///
589    /// let result = stream.collect()?;
590    /// println!("Title: {}", result.title);
591    /// ```
592    #[cfg(all(feature = "blocking", feature = "streaming"))]
593    pub fn stream_json_blocking<T>(
594        self,
595        client: &BlockingLLMClient,
596    ) -> Result<BlockingStructuredJSONStream<T>>
597    where
598        T: DeserializeOwned,
599    {
600        let body = self.build_request_body()?;
601        match &body.response_format {
602            Some(format) if format.is_structured() => {}
603            Some(_) => {
604                return Err(Error::Validation(
605                    ValidationError::new("response_format must be structured (type=json_schema)")
606                        .with_field("response_format.type"),
607                ));
608            }
609            None => {
610                return Err(Error::Validation(
611                    ValidationError::new("response_format is required for structured streaming")
612                        .with_field("response_format"),
613                ));
614            }
615        }
616        let opts = self.build_options();
617        let stream = client.proxy_customer_stream(&self.customer_id, body, opts)?;
618        Ok(BlockingStructuredJSONStream::new(stream))
619    }
620}
621
/// Request body for customer-attributed proxy requests (no model field).
///
/// `None` fields are omitted from the serialized output; `messages` is always
/// serialized.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq)]
pub struct CustomerProxyRequestBody {
    /// Maximum number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i64>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Conversation messages.
    pub messages: Vec<ProxyMessage>,
    /// Request metadata.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, String>>,
    /// Response format (e.g. a JSON schema for structured output).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// Stop sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}
637
/// Thin adapter over streaming events to yield text deltas and final metadata.
///
/// For the async `StreamHandle` implementation, text arrives through
/// `next_delta` (or `into_stream`). The usage, stop reason, and request id from
/// the `MessageStop` event are kept and exposed through the `final_*` accessors.
#[cfg(feature = "streaming")]
#[derive(Debug)]
pub struct ChatStreamAdapter<S> {
    /// Underlying event stream.
    inner: S,
    /// Set by `next_delta` once it reaches the end of the message stream.
    finished: bool,
    /// Usage reported on the `MessageStop` event, if any.
    final_usage: Option<Usage>,
    /// Stop reason reported on the `MessageStop` event, if any.
    final_stop_reason: Option<StopReason>,
    /// Request id from the `MessageStop` event, falling back to the stream's own id.
    final_request_id: Option<String>,
}
648
#[cfg(all(feature = "client", feature = "streaming"))]
impl ChatStreamAdapter<StreamHandle> {
    /// Wrap an async [`StreamHandle`] in an adapter with no final metadata yet.
    pub fn new(stream: StreamHandle) -> Self {
        Self {
            inner: stream,
            finished: false,
            final_usage: None,
            final_stop_reason: None,
            final_request_id: None,
        }
    }

    /// Pull the next text delta (if any) and track final usage/stop metadata.
    ///
    /// Returns `Ok(Some(delta))` for each `MessageDelta` event that carries
    /// text. Delta events without text, and all other event kinds, are skipped.
    ///
    /// On a `MessageStop` event, this method records usage, stop reason, and
    /// request id, then returns `Ok(None)`. It also returns `Ok(None)` when the
    /// inner stream ends. After either, every later call returns `Ok(None)`
    /// without polling the inner stream again.
    ///
    /// # Errors
    ///
    /// Propagates errors yielded by the underlying stream.
    pub async fn next_delta(&mut self) -> Result<Option<String>> {
        use futures_util::StreamExt;

        // Polling a stream after it has completed is not guaranteed to be
        // safe (it may panic or hang), and MessageStop is the logical end of
        // the message, so never poll again once the end has been seen.
        if self.finished {
            return Ok(None);
        }

        while let Some(item) = self.inner.next().await {
            let evt = item?;
            match evt.kind {
                StreamEventKind::MessageDelta => {
                    if let Some(delta) = evt.text_delta {
                        return Ok(Some(delta));
                    }
                }
                StreamEventKind::MessageStop => {
                    self.finished = true;
                    self.final_usage = evt.usage;
                    self.final_stop_reason = evt.stop_reason;
                    self.final_request_id = evt
                        .request_id
                        .or_else(|| self.inner.request_id().map(|s| s.to_string()));
                    return Ok(None);
                }
                _ => {}
            }
        }
        // Inner stream exhausted without a MessageStop: nothing more to read.
        self.finished = true;
        Ok(None)
    }

    /// Final usage info if the stream finished with a `MessageStop` carrying it.
    pub fn final_usage(&self) -> Option<&Usage> {
        self.final_usage.as_ref()
    }

    /// Final stop reason if the stream finished with a `MessageStop` carrying it.
    pub fn final_stop_reason(&self) -> Option<&StopReason> {
        self.final_stop_reason.as_ref()
    }

    /// Final request id if known.
    pub fn final_request_id(&self) -> Option<&str> {
        self.final_request_id.as_deref()
    }

    /// Convert to a stream of deltas.
    ///
    /// Errors are yielded as `Err` items and iteration continues afterwards.
    /// The stream ends when `next_delta` returns `Ok(None)`.
    pub fn into_stream(self) -> impl futures_core::Stream<Item = Result<String>> {
        stream::unfold(self, |mut adapter| async move {
            match adapter.next_delta().await {
                Ok(Some(delta)) => Some((Ok(delta), adapter)),
                Ok(None) => None,
                Err(err) => Some((Err(err), adapter)),
            }
        })
    }
}
714
/// Structured streaming record kinds surfaced by the helper.
#[cfg(feature = "streaming")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StructuredRecordKind {
    /// Produced from an `"update"` record (presumably an intermediate,
    /// possibly partial payload; see `complete_fields` on the event).
    Update,
    /// Produced from a `"completion"` record. Seeing one means the stream did
    /// not end prematurely.
    Completion,
}
722
/// Typed structured JSON event yielded from the NDJSON stream.
#[cfg(feature = "streaming")]
#[derive(Debug, Clone)]
pub struct StructuredJSONEvent<T> {
    /// Whether the event came from an `update` or a `completion` record.
    pub kind: StructuredRecordKind,
    /// The record's `payload` value, deserialized into `T`.
    pub payload: T,
    /// Request id from the event, falling back to the stream's request id.
    pub request_id: Option<String>,
    /// Set of field paths that are complete (have their closing delimiter).
    /// Use dot notation for nested fields (e.g., "metadata.author").
    /// Check with complete_fields.contains("fieldName").
    /// Empty when the record has no `complete_fields` array; non-string
    /// entries in that array are skipped.
    pub complete_fields: std::collections::HashSet<String>,
}
735
/// Helper over NDJSON streaming events to yield structured JSON payloads.
#[cfg(all(feature = "client", feature = "streaming"))]
pub struct StructuredJSONStream<T> {
    /// Underlying async event stream.
    inner: StreamHandle,
    /// Set once the underlying stream is exhausted; later `next` calls return `Ok(None)`.
    finished: bool,
    /// Set when a `completion` or `error` record is seen. Used to report
    /// streams that end without either.
    saw_completion: bool,
    /// Associates the stream with payload type `T` without storing a `T`.
    _marker: std::marker::PhantomData<T>,
}
744
745#[cfg(all(feature = "client", feature = "streaming"))]
746impl<T> StructuredJSONStream<T>
747where
748    T: DeserializeOwned,
749{
750    pub fn new(stream: StreamHandle) -> Self {
751        Self {
752            inner: stream,
753            finished: false,
754            saw_completion: false,
755            _marker: std::marker::PhantomData,
756        }
757    }
758
759    /// Pull the next structured JSON event, skipping start/unknown records.
760    pub async fn next(&mut self) -> Result<Option<StructuredJSONEvent<T>>> {
761        use futures_util::StreamExt;
762
763        if self.finished {
764            return Ok(None);
765        }
766
767        while let Some(item) = self.inner.next().await {
768            let evt = item?;
769            let value = match evt.data {
770                Some(ref v) if v.is_object() => v,
771                _ => continue,
772            };
773            let record_type = value
774                .get("type")
775                .and_then(|v| v.as_str())
776                .map(|s| s.trim().to_lowercase())
777                .unwrap_or_default();
778
779            match record_type.as_str() {
780                "" | "start" => continue,
781                "update" | "completion" => {
782                    let payload_value = value.get("payload").cloned().ok_or_else(|| {
783                        Error::Transport(TransportError {
784                            kind: TransportErrorKind::Request,
785                            message: "structured stream record missing payload".to_string(),
786                            source: None,
787                            retries: None,
788                        })
789                    })?;
790                    let payload: T =
791                        serde_json::from_value(payload_value).map_err(Error::Serialization)?;
792                    let kind = if record_type == "update" {
793                        StructuredRecordKind::Update
794                    } else {
795                        self.saw_completion = true;
796                        StructuredRecordKind::Completion
797                    };
798                    let request_id = evt
799                        .request_id
800                        .or_else(|| self.inner.request_id().map(|s| s.to_string()));
801                    // Extract complete_fields array and convert to HashSet
802                    let complete_fields: std::collections::HashSet<String> = value
803                        .get("complete_fields")
804                        .and_then(|v| v.as_array())
805                        .map(|arr| {
806                            arr.iter()
807                                .filter_map(|v| v.as_str().map(|s| s.to_string()))
808                                .collect()
809                        })
810                        .unwrap_or_default();
811                    return Ok(Some(StructuredJSONEvent {
812                        kind,
813                        payload,
814                        request_id,
815                        complete_fields,
816                    }));
817                }
818                "error" => {
819                    self.saw_completion = true;
820                    let code = value
821                        .get("code")
822                        .and_then(|v| v.as_str())
823                        .map(|s| s.to_string());
824                    let message = value
825                        .get("message")
826                        .and_then(|v| v.as_str())
827                        .unwrap_or("structured stream error")
828                        .to_string();
829                    let status = value
830                        .get("status")
831                        .and_then(|v| v.as_u64())
832                        .map(|v| v as u16)
833                        .unwrap_or(500);
834                    let request_id = evt
835                        .request_id
836                        .or_else(|| self.inner.request_id().map(|s| s.to_string()));
837                    return Err(APIError {
838                        status,
839                        code,
840                        message,
841                        request_id,
842                        fields: Vec::new(),
843                        retries: None,
844                        raw_body: None,
845                    }
846                    .into());
847                }
848                _ => continue,
849            }
850        }
851
852        self.finished = true;
853        if !self.saw_completion {
854            return Err(Error::Transport(TransportError {
855                kind: TransportErrorKind::Request,
856                message: "structured stream ended without completion or error".to_string(),
857                source: None,
858                retries: None,
859            }));
860        }
861        Ok(None)
862    }
863
864    /// Drain the stream and return the final structured payload from the completion record.
865    pub async fn collect(mut self) -> Result<T> {
866        let mut last: Option<T> = None;
867        while let Some(event) = self.next().await? {
868            if matches!(event.kind, StructuredRecordKind::Completion) {
869                return Ok(event.payload);
870            }
871            last = Some(event.payload);
872        }
873        match last {
874            Some(payload) => Ok(payload),
875            None => Err(Error::Transport(TransportError {
876                kind: TransportErrorKind::Request,
877                message: "structured stream ended without completion or error".to_string(),
878                source: None,
879                retries: None,
880            })),
881        }
882    }
883
884    /// Request identifier returned by the server (if any).
885    pub fn request_id(&self) -> Option<&str> {
886        self.inner.request_id()
887    }
888}
889
/// Blocking helper over NDJSON streaming events to yield structured JSON payloads.
///
/// Wraps a [`BlockingProxyHandle`] and decodes the `payload` of each structured
/// `update`/`completion` record into `T`.
#[cfg(all(feature = "blocking", feature = "streaming"))]
pub struct BlockingStructuredJSONStream<T> {
    /// Underlying blocking event source.
    inner: BlockingProxyHandle,
    /// True once a `completion` or `error` record has been observed.
    saw_completion: bool,
    /// True once the underlying handle has been exhausted; later reads yield `Ok(None)`.
    finished: bool,
    /// Ties the stream to its payload type without storing a `T`.
    _marker: std::marker::PhantomData<T>,
}
898
#[cfg(all(feature = "blocking", feature = "streaming"))]
impl<T> BlockingStructuredJSONStream<T>
where
    // Spelled out in full: the file-level `DeserializeOwned` import is gated on the `client`
    // feature, so the short name is not in scope for `blocking` + `streaming` builds that do
    // not also enable `client`.
    T: serde::de::DeserializeOwned,
{
    /// Wrap a blocking proxy handle whose events carry structured JSON records.
    pub fn new(stream: BlockingProxyHandle) -> Self {
        Self {
            inner: stream,
            finished: false,
            saw_completion: false,
            _marker: std::marker::PhantomData,
        }
    }

    /// Build the transport error used for structured-stream protocol violations.
    fn protocol_error(message: &str) -> Error {
        Error::Transport(TransportError {
            kind: TransportErrorKind::Request,
            message: message.to_string(),
            source: None,
            retries: None,
        })
    }

    /// Pull the next structured JSON event, skipping start/unknown records.
    ///
    /// Events whose `data` is not a JSON object, records with an empty or `start` type and
    /// records of an unrecognised type are skipped. `update` and `completion` records have
    /// their `payload` decoded into `T`; the event's request id falls back to the handle's.
    ///
    /// Returns `Ok(None)` when the handle is exhausted after a `completion` or `error` record,
    /// and on any call made after the handle has been exhausted.
    ///
    /// # Errors
    ///
    /// - A transport error if a record lacks `payload`, or if the handle is exhausted before
    ///   any `completion`/`error` record was seen.
    /// - `Error::Serialization` if `payload` cannot be decoded into `T`.
    /// - An API error built from an `error` record; its `status` defaults to 500 when the
    ///   field is missing, not an unsigned integer, or larger than `u16::MAX`.
    /// - Any error returned by the underlying handle.
    // Inherent, fallible `next`; this type deliberately does not implement `Iterator`.
    #[allow(clippy::should_implement_trait)]
    pub fn next(&mut self) -> Result<Option<StructuredJSONEvent<T>>> {
        if self.finished {
            return Ok(None);
        }

        while let Some(evt) = self.inner.next()? {
            // Only JSON objects can be structured records; anything else is skipped.
            let value = match evt.data {
                Some(ref v) if v.is_object() => v,
                _ => continue,
            };
            let record_type = value
                .get("type")
                .and_then(|v| v.as_str())
                .map(|s| s.trim().to_lowercase())
                .unwrap_or_default();

            match record_type.as_str() {
                "" | "start" => continue,
                "update" | "completion" => {
                    let payload_value = value.get("payload").cloned().ok_or_else(|| {
                        Self::protocol_error("structured stream record missing payload")
                    })?;
                    let payload: T =
                        serde_json::from_value(payload_value).map_err(Error::Serialization)?;
                    let kind = if record_type == "update" {
                        StructuredRecordKind::Update
                    } else {
                        self.saw_completion = true;
                        StructuredRecordKind::Completion
                    };
                    let request_id = evt
                        .request_id
                        .or_else(|| self.inner.request_id().map(|s| s.to_string()));
                    // Non-string entries of `complete_fields` are ignored.
                    let complete_fields: std::collections::HashSet<String> = value
                        .get("complete_fields")
                        .and_then(|v| v.as_array())
                        .map(|arr| {
                            arr.iter()
                                .filter_map(|v| v.as_str().map(|s| s.to_string()))
                                .collect()
                        })
                        .unwrap_or_default();
                    return Ok(Some(StructuredJSONEvent {
                        kind,
                        payload,
                        request_id,
                        complete_fields,
                    }));
                }
                "error" => {
                    self.saw_completion = true;
                    let code = value
                        .get("code")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string());
                    let message = value
                        .get("message")
                        .and_then(|v| v.as_str())
                        .unwrap_or("structured stream error")
                        .to_string();
                    // Values that do not fit in a u16 fall back to 500 instead of being
                    // silently truncated by an `as` cast.
                    let status = value
                        .get("status")
                        .and_then(|v| v.as_u64())
                        .filter(|&v| v <= u64::from(u16::MAX))
                        .map(|v| v as u16)
                        .unwrap_or(500);
                    let request_id = evt
                        .request_id
                        .or_else(|| self.inner.request_id().map(|s| s.to_string()));
                    return Err(APIError {
                        status,
                        code,
                        message,
                        request_id,
                        fields: Vec::new(),
                        retries: None,
                        raw_body: None,
                    }
                    .into());
                }
                _ => continue,
            }
        }

        self.finished = true;
        if !self.saw_completion {
            return Err(Self::protocol_error(
                "structured stream ended without completion or error",
            ));
        }
        Ok(None)
    }

    /// Drain the stream and return the final structured payload from the completion record.
    ///
    /// If no completion record reaches this call (for example because it was already taken
    /// through [`Self::next`]), the last update payload seen here is returned instead, or a
    /// transport error when there was none.
    ///
    /// # Errors
    ///
    /// Propagates every error produced by [`Self::next`].
    pub fn collect(mut self) -> Result<T> {
        let mut last: Option<T> = None;
        while let Some(event) = self.next()? {
            if matches!(event.kind, StructuredRecordKind::Completion) {
                return Ok(event.payload);
            }
            last = Some(event.payload);
        }
        last.ok_or_else(|| {
            Self::protocol_error("structured stream ended without completion or error")
        })
    }

    /// Request identifier returned by the server (if any).
    pub fn request_id(&self) -> Option<&str> {
        self.inner.request_id()
    }
}
1040
/// Blocking streaming adapter.
///
/// Wraps a [`BlockingProxyHandle`] and surfaces only the text deltas of a chat stream,
/// capturing usage, stop reason and request id from the terminating `MessageStop` event.
#[cfg(all(feature = "blocking", feature = "streaming"))]
impl ChatStreamAdapter<BlockingProxyHandle> {
    /// Wrap a blocking proxy handle; no events are read until [`Self::next_delta`] is called.
    pub fn new(stream: BlockingProxyHandle) -> Self {
        Self {
            inner: stream,
            finished: false,
            final_usage: None,
            final_stop_reason: None,
            final_request_id: None,
        }
    }

    /// Request identifier reported by the underlying handle (if any).
    pub fn request_id(&self) -> Option<&str> {
        self.inner.request_id()
    }

    /// Return the next text delta, or `Ok(None)` once the stream has finished.
    ///
    /// Events other than `MessageDelta` (with text) and `MessageStop` are skipped. On
    /// `MessageStop` the final usage, stop reason and request id are captured and the
    /// adapter is marked finished. From then on, and after the handle is exhausted, this
    /// returns `Ok(None)` without reading from the handle again.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying handle.
    pub fn next_delta(&mut self) -> Result<Option<String>> {
        // Don't touch the handle after a terminal event: trailing events must not be able to
        // overwrite the `final_*` values captured from `MessageStop`.
        if self.finished {
            return Ok(None);
        }
        while let Some(evt) = self.inner.next()? {
            match evt.kind {
                StreamEventKind::MessageDelta => {
                    if let Some(delta) = evt.text_delta {
                        return Ok(Some(delta));
                    }
                }
                StreamEventKind::MessageStop => {
                    self.finished = true;
                    self.final_usage = evt.usage;
                    self.final_stop_reason = evt.stop_reason;
                    self.final_request_id = evt
                        .request_id
                        .or_else(|| self.inner.request_id().map(|s| s.to_string()));
                    return Ok(None);
                }
                _ => {}
            }
        }
        // The handle ran dry without a `MessageStop`; stop polling it from now on.
        self.finished = true;
        Ok(None)
    }

    /// Usage reported by the `MessageStop` event, once it has been seen.
    pub fn final_usage(&self) -> Option<&Usage> {
        self.final_usage.as_ref()
    }

    /// Stop reason reported by the `MessageStop` event, once it has been seen.
    pub fn final_stop_reason(&self) -> Option<&StopReason> {
        self.final_stop_reason.as_ref()
    }

    /// Request id from the `MessageStop` event, falling back to the handle's request id.
    pub fn final_request_id(&self) -> Option<&str> {
        self.final_request_id.as_deref()
    }

    /// Iterate over text deltas until completion or error.
    ///
    /// Errors are yielded as `Some(Err(_))`; iteration is not stopped automatically after
    /// an error, so callers that keep pulling will ask the handle for more events.
    // Named `into_iter` for ergonomics; `IntoIterator` is not implemented for this type.
    #[allow(clippy::should_implement_trait)]
    pub fn into_iter(self) -> impl Iterator<Item = Result<String>> {
        let mut adapter = self;
        std::iter::from_fn(move || match adapter.next_delta() {
            Ok(Some(delta)) => Some(Ok(delta)),
            Ok(None) => None,
            Err(err) => Some(Err(err)),
        })
    }
}
1104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Model, ResponseFormatKind, StreamEvent, StreamEventKind};
    use crate::ClientBuilder;

    /// Build a `Custom` stream event named `structured` carrying `data`, with every other
    /// field left empty.
    #[cfg(all(feature = "client", feature = "streaming"))]
    fn structured_event(data: serde_json::Value) -> StreamEvent {
        StreamEvent {
            kind: StreamEventKind::Custom,
            event: "structured".into(),
            data: Some(data),
            text_delta: None,
            tool_call_delta: None,
            tool_calls: None,
            response_id: None,
            model: None,
            stop_reason: None,
            usage: None,
            request_id: None,
            raw: String::new(),
        }
    }

    /// Assert that `result` is a validation error whose message contains `needle`.
    #[cfg(all(feature = "client", feature = "streaming"))]
    fn assert_validation_error<T>(result: Result<T>, needle: &str) {
        match result {
            Err(Error::Validation(v)) => assert!(
                v.to_string().contains(needle),
                "unexpected validation error: {v}"
            ),
            Err(other) => panic!("expected Validation error, got {other:?}"),
            Ok(_) => panic!("expected Validation error, got Ok"),
        }
    }

    /// A plain-text (non-structured) response format.
    #[cfg(all(feature = "client", feature = "streaming"))]
    fn text_response_format() -> ResponseFormat {
        ResponseFormat {
            kind: ResponseFormatKind::Text,
            json_schema: None,
        }
    }

    #[test]
    fn build_request_requires_user_message() {
        let err = ChatRequestBuilder::new(Model::from("gpt-4o-mini"))
            .system("just a system")
            .build_request()
            .unwrap_err();
        if let Error::Validation(msg) = &err {
            assert!(
                msg.to_string().contains("user"),
                "unexpected validation: {msg}"
            );
        } else {
            panic!("expected validation error, got {err:?}");
        }
    }

    #[test]
    fn metadata_entry_ignores_empty_pairs() {
        let meta = ChatRequestBuilder::new(Model::from("gpt-4o-mini"))
            .user("hello")
            .metadata_entry("trace_id", "abc123")
            .metadata_entry("", "should_skip")
            .metadata_entry("empty", "")
            .build_request()
            .unwrap()
            .metadata
            .unwrap();
        assert_eq!(meta.len(), 1);
        assert_eq!(meta.get("trace_id").map(String::as_str), Some("abc123"));
    }

    #[test]
    fn role_helpers_append_expected_roles() {
        use crate::types::MessageRole;

        let req = ChatRequestBuilder::new("gpt-4o-mini")
            .system("sys")
            .user("u1")
            .assistant("a1")
            .build_request()
            .unwrap();
        let expected = vec![
            MessageRole::System,
            MessageRole::User,
            MessageRole::Assistant,
        ];
        let actual: Vec<_> = req.messages.iter().map(|m| m.role).collect();
        assert_eq!(actual, expected);
    }

    #[cfg(all(feature = "client", feature = "streaming"))]
    #[tokio::test]
    async fn stream_json_requires_structured_response_format() {
        let client = ClientBuilder::new()
            .api_key("mr_sk_test")
            .build()
            .expect("client build");

        // No response_format at all.
        let bare = ChatRequestBuilder::new(Model::from("gpt-4o-mini")).user("hi");
        assert_validation_error(
            bare.clone()
                .stream_json::<serde_json::Value>(&client.llm())
                .await,
            "response_format",
        );

        // A response_format that is present but not structured (Text).
        let text = ChatRequestBuilder::new(Model::from("gpt-4o-mini"))
            .user("hi")
            .response_format(text_response_format());
        assert_validation_error(
            text.stream_json::<serde_json::Value>(&client.llm()).await,
            "response_format must be structured",
        );
    }

    #[cfg(all(feature = "client", feature = "streaming"))]
    #[tokio::test]
    async fn structured_json_stream_yields_update_and_completion() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Item {
            id: String,
        }

        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct ItemsPayload {
            items: Vec<Item>,
        }

        let events = vec![
            structured_event(serde_json::json!({"type":"start","request_id":"tiers-1"})),
            structured_event(
                serde_json::json!({"type":"update","payload":{"items":[{"id":"one"}]}}),
            ),
            structured_event(
                serde_json::json!({"type":"completion","payload":{"items":[{"id":"one"},{"id":"two"}]}}),
            ),
        ];

        let mut stream = StructuredJSONStream::<ItemsPayload>::new(
            StreamHandle::from_events_with_request_id(
                events.clone(),
                Some("req-structured".into()),
            ),
        );

        let update = stream.next().await.unwrap().unwrap();
        assert_eq!(update.kind, StructuredRecordKind::Update);
        assert_eq!(update.payload.items.len(), 1);
        assert_eq!(update.payload.items[0].id, "one");

        let completion = stream.next().await.unwrap().unwrap();
        assert_eq!(completion.kind, StructuredRecordKind::Completion);
        assert_eq!(completion.payload.items.len(), 2);
        assert_eq!(completion.request_id.as_deref(), Some("req-structured"));

        // A fresh stream over the same events collects straight to the completion payload.
        let collected = StructuredJSONStream::<ItemsPayload>::new(
            StreamHandle::from_events_with_request_id(events, Some("req-structured".into())),
        )
        .collect()
        .await
        .unwrap();
        assert_eq!(collected.items.len(), 2);
    }

    #[cfg(all(feature = "client", feature = "streaming"))]
    #[tokio::test]
    async fn structured_json_stream_maps_error_and_protocol_violation() {
        // An `error` record surfaces as an APIError carrying the stream's request id.
        let mut err_stream = StructuredJSONStream::<serde_json::Value>::new(
            StreamHandle::from_events_with_request_id(
                vec![structured_event(
                    serde_json::json!({"type":"error","code":"SERVICE_UNAVAILABLE","message":"upstream timeout","status":502}),
                )],
                Some("req-error".into()),
            ),
        );
        match err_stream.next().await.unwrap_err() {
            Error::Api(api) => {
                assert_eq!(api.status, 502);
                assert_eq!(api.code.as_deref(), Some("SERVICE_UNAVAILABLE"));
                assert_eq!(api.request_id.as_deref(), Some("req-error"));
            }
            other => panic!("expected API error, got {other:?}"),
        }

        // A stream that ends without a completion/error record is a transport error.
        let incomplete = StructuredJSONStream::<serde_json::Value>::new(
            StreamHandle::from_events_with_request_id(
                vec![structured_event(
                    serde_json::json!({"type":"update","payload":{"items":[{"id":"one"}]}}),
                )],
                Some("req-incomplete".into()),
            ),
        );
        match incomplete.collect().await.unwrap_err() {
            Error::Transport(te) => assert!(
                te.message
                    .contains("structured stream ended without completion or error"),
                "unexpected message: {}",
                te.message
            ),
            other => panic!("expected Transport error, got {other:?}"),
        }
    }

    #[cfg(all(feature = "client", feature = "streaming"))]
    #[tokio::test]
    async fn customer_stream_json_requires_structured_response_format() {
        let client = ClientBuilder::new()
            .api_key("mr_sk_test")
            .build()
            .expect("client build");

        // No response_format at all.
        let bare = CustomerChatRequestBuilder::new("customer-123").user("hi");
        assert_validation_error(
            bare.clone()
                .stream_json::<serde_json::Value>(&client.llm())
                .await,
            "response_format",
        );

        // A response_format that is present but not structured (Text).
        let text = CustomerChatRequestBuilder::new("customer-123")
            .user("hi")
            .response_format(text_response_format());
        assert_validation_error(
            text.stream_json::<serde_json::Value>(&client.llm()).await,
            "response_format must be structured",
        );
    }

    #[test]
    fn customer_build_request_body_requires_user_message() {
        let err = CustomerChatRequestBuilder::new("customer-123")
            .system("just a system")
            .build_request_body()
            .unwrap_err();
        if let Error::Validation(msg) = &err {
            assert!(
                msg.to_string().contains("user"),
                "unexpected validation: {msg}"
            );
        } else {
            panic!("expected validation error, got {err:?}");
        }
    }

    #[test]
    fn customer_metadata_entry_ignores_empty_pairs() {
        let meta = CustomerChatRequestBuilder::new("customer-123")
            .user("hello")
            .metadata_entry("trace_id", "abc123")
            .metadata_entry("", "should_skip")
            .metadata_entry("empty", "")
            .build_request_body()
            .unwrap()
            .metadata
            .unwrap();
        assert_eq!(meta.len(), 1);
        assert_eq!(meta.get("trace_id").map(String::as_str), Some("abc123"));
    }
}