Skip to main content

quantum_sdk/
chat.rs

1use std::collections::HashMap;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures_util::Stream;
6use pin_project_lite::pin_project;
7use serde::{Deserialize, Serialize};
8
9use crate::client::Client;
10use crate::error::Result;
11
12/// Deserialize null as empty Vec (Gemini sometimes returns null for array fields).
13fn null_as_empty_vec<'de, D, T>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
14where
15    D: serde::Deserializer<'de>,
16    T: Deserialize<'de>,
17{
18    Option::<Vec<T>>::deserialize(deserializer).map(|v| v.unwrap_or_default())
19}
20
21/// Deserialize null as None for Option<Vec<T>> fields.
22fn deserialize_opt_vec<'de, D, T>(deserializer: D) -> std::result::Result<Option<Vec<T>>, D::Error>
23where
24    D: serde::Deserializer<'de>,
25    T: Deserialize<'de>,
26{
27    // null → None, [] → Some([]), [...] → Some([...])
28    Ok(Option::<Vec<T>>::deserialize(deserializer).unwrap_or(None))
29}
30
31/// Request body for text generation.
32#[derive(Debug, Clone, Serialize, Default)]
33pub struct ChatRequest {
34    /// Model ID that determines provider routing (e.g. "claude-sonnet-4-6", "grok-4-1-fast-non-reasoning").
35    pub model: String,
36
37    /// Conversation history.
38    pub messages: Vec<ChatMessage>,
39
40    /// Functions the model can call.
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub tools: Option<Vec<ChatTool>>,
43
44    /// Constrains tool use: "auto" (default), "any" (force tool use), "none", or a specific tool name.
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub tool_choice: Option<String>,
47
48    /// JSON Schema for structured output constraints.
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub output_schema: Option<serde_json::Value>,
51
52    /// Enables server-sent event streaming. Set automatically by `chat_stream`.
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub stream: Option<bool>,
55
56    /// Controls randomness (0.0-2.0).
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub temperature: Option<f64>,
59
60    /// Limits the response length.
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub max_tokens: Option<i32>,
63
64    /// Provider-specific settings (e.g. Anthropic thinking, xAI search).
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub provider_options: Option<HashMap<String, serde_json::Value>>,
67}
68
69/// A single message in a conversation.
70#[derive(Debug, Clone, Serialize, Deserialize, Default)]
71pub struct ChatMessage {
72    /// One of "system", "user", "assistant", or "tool".
73    pub role: String,
74
75    /// Text content of the message.
76    #[serde(skip_serializing_if = "Option::is_none")]
77    pub content: Option<String>,
78
79    /// Structured content for assistant messages with tool calls.
80    /// When present, takes precedence over `content`.
81    #[serde(skip_serializing_if = "Option::is_none", deserialize_with = "deserialize_opt_vec", default)]
82    pub content_blocks: Option<Vec<ContentBlock>>,
83
84    /// Required when role is "tool" — references the tool_use ID.
85    #[serde(skip_serializing_if = "Option::is_none")]
86    pub tool_call_id: Option<String>,
87
88    /// Whether a tool result is an error.
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub is_error: Option<bool>,
91}
92
93impl ChatMessage {
94    /// Creates a user message.
95    pub fn user(content: impl Into<String>) -> Self {
96        Self {
97            role: "user".to_string(),
98            content: Some(content.into()),
99            ..Default::default()
100        }
101    }
102
103    /// Creates an assistant message.
104    pub fn assistant(content: impl Into<String>) -> Self {
105        Self {
106            role: "assistant".to_string(),
107            content: Some(content.into()),
108            ..Default::default()
109        }
110    }
111
112    /// Creates a system message.
113    pub fn system(content: impl Into<String>) -> Self {
114        Self {
115            role: "system".to_string(),
116            content: Some(content.into()),
117            ..Default::default()
118        }
119    }
120
121    /// Creates a tool result message.
122    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
123        Self {
124            role: "tool".to_string(),
125            content: Some(content.into()),
126            tool_call_id: Some(tool_call_id.into()),
127            ..Default::default()
128        }
129    }
130
131    /// Creates a tool error result message.
132    pub fn tool_error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
133        Self {
134            role: "tool".to_string(),
135            content: Some(content.into()),
136            tool_call_id: Some(tool_call_id.into()),
137            is_error: Some(true),
138            ..Default::default()
139        }
140    }
141}
142
143/// A single block in the response content array.
144#[derive(Debug, Clone, Serialize, Deserialize, Default)]
145pub struct ContentBlock {
146    /// One of "text", "thinking", or "tool_use".
147    #[serde(rename = "type")]
148    pub block_type: String,
149
150    /// Content for "text" and "thinking" blocks.
151    #[serde(skip_serializing_if = "Option::is_none")]
152    pub text: Option<String>,
153
154    /// Tool call identifier for "tool_use" blocks.
155    #[serde(skip_serializing_if = "Option::is_none")]
156    pub id: Option<String>,
157
158    /// Function name for "tool_use" blocks.
159    #[serde(skip_serializing_if = "Option::is_none")]
160    pub name: Option<String>,
161
162    /// Function arguments for "tool_use" blocks.
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub input: Option<HashMap<String, serde_json::Value>>,
165
166    /// Gemini thought signature — must be echoed back with tool results.
167    #[serde(skip_serializing_if = "Option::is_none")]
168    pub thought_signature: Option<String>,
169
170    /// Base64-encoded data for file/image content blocks.
171    #[serde(skip_serializing_if = "Option::is_none")]
172    pub data: Option<String>,
173
174    /// Filename for file content blocks.
175    #[serde(skip_serializing_if = "Option::is_none")]
176    pub file_name: Option<String>,
177
178    /// MIME type for file/image content blocks.
179    #[serde(skip_serializing_if = "Option::is_none")]
180    pub mime_type: Option<String>,
181}
182
183/// Defines a function the model can call.
184#[derive(Debug, Clone, Serialize, Default)]
185pub struct ChatTool {
186    /// Function name.
187    pub name: String,
188
189    /// Explains what the function does.
190    pub description: String,
191
192    /// JSON Schema for the function's arguments.
193    #[serde(skip_serializing_if = "Option::is_none")]
194    pub parameters: Option<serde_json::Value>,
195
196    /// Enable guaranteed schema validation on tool inputs (Anthropic, OpenAI).
197    #[serde(skip_serializing_if = "Option::is_none")]
198    pub strict: Option<bool>,
199}
200
201/// Response from a non-streaming chat request.
202#[derive(Debug, Clone, Deserialize)]
203pub struct ChatResponse {
204    /// Unique request identifier.
205    pub id: String,
206
207    /// Model that generated the response.
208    pub model: String,
209
210    /// List of content blocks (text, thinking, tool_use).
211    #[serde(default, deserialize_with = "null_as_empty_vec")]
212    pub content: Vec<ContentBlock>,
213
214    /// Token counts and cost.
215    pub usage: Option<ChatUsage>,
216
217    /// Why generation stopped ("end_turn", "tool_use", "max_tokens").
218    #[serde(default)]
219    pub stop_reason: String,
220
221    /// Citations from web search (when search is enabled via provider_options).
222    #[serde(default, deserialize_with = "null_as_empty_vec")]
223    pub citations: Vec<Citation>,
224
225    /// Total cost from the X-QAI-Cost-Ticks header.
226    #[serde(skip)]
227    pub cost_ticks: i64,
228
229    /// From the X-QAI-Request-Id header.
230    #[serde(skip)]
231    pub request_id: String,
232}
233
234impl ChatResponse {
235    /// Returns the concatenated text content, ignoring thinking and tool_use blocks.
236    pub fn text(&self) -> String {
237        self.content
238            .iter()
239            .filter(|b| b.block_type == "text")
240            .filter_map(|b| b.text.as_deref())
241            .collect::<Vec<_>>()
242            .join("")
243    }
244
245    /// Returns the concatenated thinking content.
246    pub fn thinking(&self) -> String {
247        self.content
248            .iter()
249            .filter(|b| b.block_type == "thinking")
250            .filter_map(|b| b.text.as_deref())
251            .collect::<Vec<_>>()
252            .join("")
253    }
254
255    /// Returns all tool_use blocks from the response.
256    pub fn tool_calls(&self) -> Vec<&ContentBlock> {
257        self.content
258            .iter()
259            .filter(|b| b.block_type == "tool_use")
260            .collect()
261    }
262}
263
264/// A source reference from web search grounding.
265#[derive(Debug, Clone, Deserialize, Serialize)]
266pub struct Citation {
267    /// Title of the cited source.
268    #[serde(default)]
269    pub title: String,
270
271    /// URL of the cited source.
272    #[serde(default)]
273    pub url: String,
274
275    /// Relevant text snippet from the source.
276    #[serde(default)]
277    pub text: String,
278
279    /// Position in the response.
280    #[serde(default)]
281    pub index: i32,
282}
283
284/// Token counts and cost for a chat response.
285#[derive(Debug, Clone, Deserialize)]
286pub struct ChatUsage {
287    pub input_tokens: i32,
288    pub output_tokens: i32,
289    pub cost_ticks: i64,
290}
291
292/// A single event from an SSE chat stream.
293///
294/// Tool-use streaming uses a triplet of events since v0.7:
295/// `tool_use_start` carries `tool_use_start`, `tool_use_input_delta`
296/// carries `tool_use_input_delta`, and `tool_use_complete` carries
297/// `tool_use_complete`. The legacy atomic `tool_use` event is still
298/// emitted by backends that haven't shipped the triplet yet — for new
299/// code, prefer the triplet fields.
300#[derive(Debug, Clone)]
301pub struct StreamEvent {
302    /// Event type: "content_delta", "thinking_delta",
303    /// "tool_use_start", "tool_use_input_delta", "tool_use_complete",
304    /// "tool_use" (legacy), "usage", "heartbeat", "error", "done".
305    pub event_type: String,
306
307    /// Incremental text for content_delta and thinking_delta events.
308    pub delta: Option<StreamDelta>,
309
310    /// Populated for legacy atomic tool_use events.
311    pub tool_use: Option<StreamToolUse>,
312
313    /// Populated for tool_use_start events.
314    pub tool_use_start: Option<StreamToolUseStart>,
315
316    /// Populated for tool_use_input_delta events.
317    pub tool_use_input_delta: Option<StreamToolUseInputDelta>,
318
319    /// Populated for tool_use_complete events.
320    pub tool_use_complete: Option<StreamToolUseComplete>,
321
322    /// Populated for usage events.
323    pub usage: Option<ChatUsage>,
324
325    /// Populated for error events.
326    pub error: Option<String>,
327
328    /// True when the stream is complete.
329    pub done: bool,
330}
331
332/// Incremental text in a streaming event.
333#[derive(Debug, Clone, Deserialize)]
334pub struct StreamDelta {
335    pub text: String,
336}
337
338/// A tool call from a legacy (atomic) streaming event.
339#[derive(Debug, Clone, Deserialize)]
340pub struct StreamToolUse {
341    pub id: String,
342    pub name: String,
343    pub input: HashMap<String, serde_json::Value>,
344}
345
346/// Tool-call start event — fires once before any input deltas.
347#[derive(Debug, Clone, Deserialize)]
348pub struct StreamToolUseStart {
349    pub id: String,
350    pub name: String,
351}
352
353/// Tool-call input delta — fires zero or more times with raw JSON fragments.
354#[derive(Debug, Clone, Deserialize)]
355pub struct StreamToolUseInputDelta {
356    pub id: String,
357    /// Raw JSON fragment. May not parse on its own; accumulate until
358    /// the corresponding `tool_use_complete` event arrives with the
359    /// authoritative `input`.
360    pub partial_json: String,
361}
362
363/// Tool-call completion event — fires exactly once per call with the
364/// server-accumulated, fully-parsed arguments.
365#[derive(Debug, Clone, Deserialize)]
366pub struct StreamToolUseComplete {
367    pub id: String,
368    pub name: String,
369    pub input: HashMap<String, serde_json::Value>,
370}
371
372/// Raw JSON from the SSE stream before parsing into typed fields.
373#[derive(Deserialize)]
374struct RawStreamEvent {
375    #[serde(rename = "type")]
376    event_type: String,
377    #[serde(default)]
378    delta: Option<StreamDelta>,
379    #[serde(default)]
380    id: Option<String>,
381    #[serde(default)]
382    name: Option<String>,
383    #[serde(default)]
384    input: Option<HashMap<String, serde_json::Value>>,
385    /// Carried by `tool_use_input_delta` events — a raw JSON fragment.
386    #[serde(default)]
387    partial_json: Option<String>,
388    #[serde(default)]
389    input_tokens: Option<i32>,
390    #[serde(default)]
391    output_tokens: Option<i32>,
392    #[serde(default)]
393    cost_ticks: Option<i64>,
394    #[serde(default)]
395    message: Option<String>,
396}
397
pin_project! {
    /// An async stream of [`StreamEvent`]s from an SSE chat response.
    ///
    /// Returned by [`Client::chat_stream`]; polling it forwards to the boxed,
    /// type-erased event stream built from the response body.
    pub struct ChatStream {
        // Boxed event pipeline produced by `sse_to_events`.
        #[pin]
        inner: Pin<Box<dyn Stream<Item = StreamEvent> + Send>>,
    }
}
405
406impl Stream for ChatStream {
407    type Item = StreamEvent;
408
409    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
410        self.project().inner.poll_next(cx)
411    }
412}
413
414impl Client {
415    /// Sends a non-streaming text generation request.
416    pub async fn chat(&self, req: &ChatRequest) -> Result<ChatResponse> {
417        let mut req = req.clone();
418        req.stream = Some(false);
419
420        let (mut resp, meta) = self.post_json::<ChatRequest, ChatResponse>("/qai/v1/chat", &req).await?;
421        resp.cost_ticks = meta.cost_ticks;
422        resp.request_id = meta.request_id;
423        if resp.model.is_empty() {
424            resp.model = meta.model;
425        }
426        Ok(resp)
427    }
428
429    /// Sends a streaming text generation request and returns an async stream of events.
430    ///
431    /// # Example
432    ///
433    /// ```no_run
434    /// use futures_util::StreamExt;
435    ///
436    /// # async fn example() -> quantum_sdk::Result<()> {
437    /// let client = quantum_sdk::Client::new("key");
438    /// let req = quantum_sdk::ChatRequest {
439    ///     model: "claude-sonnet-4-6".into(),
440    ///     messages: vec![quantum_sdk::ChatMessage::user("Hello!")],
441    ///     ..Default::default()
442    /// };
443    /// let mut stream = client.chat_stream(&req).await?;
444    /// while let Some(ev) = stream.next().await {
445    ///     if let Some(delta) = &ev.delta {
446    ///         print!("{}", delta.text);
447    ///     }
448    /// }
449    /// # Ok(())
450    /// # }
451    /// ```
452    pub async fn chat_stream(&self, req: &ChatRequest) -> Result<ChatStream> {
453        let mut req = req.clone();
454        req.stream = Some(true);
455
456        let (resp, _meta) = self.post_stream_raw("/qai/v1/chat", &req).await?;
457
458        let byte_stream = resp.bytes_stream();
459        let event_stream = sse_to_events(byte_stream);
460
461        Ok(ChatStream {
462            inner: Box::pin(event_stream),
463        })
464    }
465}
466
467/// Converts a byte stream into a stream of parsed [`StreamEvent`]s.
468fn sse_to_events<S>(byte_stream: S) -> impl Stream<Item = StreamEvent> + Send
469where
470    S: Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
471{
472    // Pin the byte stream so we can poll it inside unfold.
473    let pinned_stream = Box::pin(byte_stream);
474
475    // Accumulate raw bytes into lines to avoid splitting multi-byte UTF-8 characters.
476    // Only convert to String when we have a complete newline-terminated line.
477    let line_stream = futures_util::stream::unfold(
478        (pinned_stream, Vec::<u8>::new()),
479        |(mut stream, mut buffer)| async move {
480            use futures_util::StreamExt;
481            loop {
482                // Check if we have a complete line in the buffer.
483                if let Some(newline_pos) = buffer.iter().position(|&b| b == b'\n') {
484                    let mut line_bytes = buffer[..newline_pos].to_vec();
485                    buffer = buffer[newline_pos + 1..].to_vec();
486                    // Trim trailing \r
487                    if line_bytes.last() == Some(&b'\r') {
488                        line_bytes.pop();
489                    }
490                    let line = String::from_utf8_lossy(&line_bytes).into_owned();
491                    return Some((line, (stream, buffer)));
492                }
493
494                // Read more data.
495                match stream.next().await {
496                    Some(Ok(chunk)) => {
497                        buffer.extend_from_slice(&chunk);
498                    }
499                    Some(Err(_)) | None => {
500                        // Stream ended. Emit remaining buffer if non-empty.
501                        if !buffer.is_empty() {
502                            let remaining = String::from_utf8_lossy(&buffer).into_owned();
503                            buffer.clear();
504                            return Some((remaining, (stream, buffer)));
505                        }
506                        return None;
507                    }
508                }
509            }
510        },
511    );
512
513    let pinned_lines = Box::pin(line_stream);
514    futures_util::stream::unfold(pinned_lines, |mut lines| async move {
515        use futures_util::StreamExt;
516        loop {
517            let line = lines.next().await?;
518
519            if !line.starts_with("data: ") {
520                continue;
521            }
522            let payload = &line["data: ".len()..];
523
524            if payload == "[DONE]" {
525                let ev = StreamEvent {
526                    event_type: "done".to_string(),
527                    delta: None,
528                    tool_use: None,
529                    tool_use_start: None,
530                    tool_use_input_delta: None,
531                    tool_use_complete: None,
532                    usage: None,
533                    error: None,
534                    done: true,
535                };
536                return Some((ev, lines));
537            }
538
539            let raw: RawStreamEvent = match serde_json::from_str(payload) {
540                Ok(r) => r,
541                Err(e) => {
542                    let ev = StreamEvent {
543                        event_type: "error".to_string(),
544                        delta: None,
545                        tool_use: None,
546                        tool_use_start: None,
547                        tool_use_input_delta: None,
548                        tool_use_complete: None,
549                        usage: None,
550                        error: Some(format!("parse SSE: {e}")),
551                        done: false,
552                    };
553                    return Some((ev, lines));
554                }
555            };
556
557            let mut ev = StreamEvent {
558                event_type: raw.event_type.clone(),
559                delta: None,
560                tool_use: None,
561                tool_use_start: None,
562                tool_use_input_delta: None,
563                tool_use_complete: None,
564                usage: None,
565                error: None,
566                done: false,
567            };
568
569            match raw.event_type.as_str() {
570                "content_delta" | "thinking_delta" => {
571                    ev.delta = raw.delta;
572                }
573                "tool_use" => {
574                    // Legacy atomic event — kept for back-compat with
575                    // backends that haven't shipped the triplet (v0.7+).
576                    ev.tool_use = Some(StreamToolUse {
577                        id: raw.id.unwrap_or_default(),
578                        name: raw.name.unwrap_or_default(),
579                        input: raw.input.unwrap_or_default(),
580                    });
581                }
582                "tool_use_start" => {
583                    ev.tool_use_start = Some(StreamToolUseStart {
584                        id: raw.id.unwrap_or_default(),
585                        name: raw.name.unwrap_or_default(),
586                    });
587                }
588                "tool_use_input_delta" => {
589                    ev.tool_use_input_delta = Some(StreamToolUseInputDelta {
590                        id: raw.id.unwrap_or_default(),
591                        partial_json: raw.partial_json.unwrap_or_default(),
592                    });
593                }
594                "tool_use_complete" => {
595                    ev.tool_use_complete = Some(StreamToolUseComplete {
596                        id: raw.id.unwrap_or_default(),
597                        name: raw.name.unwrap_or_default(),
598                        input: raw.input.unwrap_or_default(),
599                    });
600                }
601                "usage" => {
602                    ev.usage = Some(ChatUsage {
603                        input_tokens: raw.input_tokens.unwrap_or(0),
604                        output_tokens: raw.output_tokens.unwrap_or(0),
605                        cost_ticks: raw.cost_ticks.unwrap_or(0),
606                    });
607                }
608                "error" => {
609                    ev.error = raw.message;
610                }
611                "heartbeat" => {}
612                _ => {}
613            }
614
615            return Some((ev, lines));
616        }
617    })
618}