//! quantum_sdk/chat.rs
//!
//! Chat request and response types, plus SSE streaming support.

1use std::collections::HashMap;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures_util::Stream;
6use pin_project_lite::pin_project;
7use serde::{Deserialize, Serialize};
8
9use crate::client::Client;
10use crate::error::Result;
11
12/// Request body for text generation.
13#[derive(Debug, Clone, Serialize, Default)]
14pub struct ChatRequest {
15    /// Model ID that determines provider routing (e.g. "claude-sonnet-4-6", "grok-4-1-fast-non-reasoning").
16    pub model: String,
17
18    /// Conversation history.
19    pub messages: Vec<ChatMessage>,
20
21    /// Functions the model can call.
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub tools: Option<Vec<ChatTool>>,
24
25    /// Enables server-sent event streaming. Set automatically by `chat_stream`.
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub stream: Option<bool>,
28
29    /// Controls randomness (0.0-2.0).
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub temperature: Option<f64>,
32
33    /// Limits the response length.
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub max_tokens: Option<i32>,
36
37    /// Provider-specific settings (e.g. Anthropic thinking, xAI search).
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub provider_options: Option<HashMap<String, serde_json::Value>>,
40}
41
42/// A single message in a conversation.
43#[derive(Debug, Clone, Serialize, Deserialize, Default)]
44pub struct ChatMessage {
45    /// One of "system", "user", "assistant", or "tool".
46    pub role: String,
47
48    /// Text content of the message.
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub content: Option<String>,
51
52    /// Structured content for assistant messages with tool calls.
53    /// When present, takes precedence over `content`.
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub content_blocks: Option<Vec<ContentBlock>>,
56
57    /// Required when role is "tool" — references the tool_use ID.
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub tool_call_id: Option<String>,
60
61    /// Whether a tool result is an error.
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub is_error: Option<bool>,
64}
65
66impl ChatMessage {
67    /// Creates a user message.
68    pub fn user(content: impl Into<String>) -> Self {
69        Self {
70            role: "user".to_string(),
71            content: Some(content.into()),
72            ..Default::default()
73        }
74    }
75
76    /// Creates an assistant message.
77    pub fn assistant(content: impl Into<String>) -> Self {
78        Self {
79            role: "assistant".to_string(),
80            content: Some(content.into()),
81            ..Default::default()
82        }
83    }
84
85    /// Creates a system message.
86    pub fn system(content: impl Into<String>) -> Self {
87        Self {
88            role: "system".to_string(),
89            content: Some(content.into()),
90            ..Default::default()
91        }
92    }
93
94    /// Creates a tool result message.
95    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
96        Self {
97            role: "tool".to_string(),
98            content: Some(content.into()),
99            tool_call_id: Some(tool_call_id.into()),
100            ..Default::default()
101        }
102    }
103
104    /// Creates a tool error result message.
105    pub fn tool_error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
106        Self {
107            role: "tool".to_string(),
108            content: Some(content.into()),
109            tool_call_id: Some(tool_call_id.into()),
110            is_error: Some(true),
111            ..Default::default()
112        }
113    }
114}
115
116/// A single block in the response content array.
117#[derive(Debug, Clone, Serialize, Deserialize, Default)]
118pub struct ContentBlock {
119    /// One of "text", "thinking", or "tool_use".
120    #[serde(rename = "type")]
121    pub block_type: String,
122
123    /// Content for "text" and "thinking" blocks.
124    #[serde(skip_serializing_if = "Option::is_none")]
125    pub text: Option<String>,
126
127    /// Tool call identifier for "tool_use" blocks.
128    #[serde(skip_serializing_if = "Option::is_none")]
129    pub id: Option<String>,
130
131    /// Function name for "tool_use" blocks.
132    #[serde(skip_serializing_if = "Option::is_none")]
133    pub name: Option<String>,
134
135    /// Function arguments for "tool_use" blocks.
136    #[serde(skip_serializing_if = "Option::is_none")]
137    pub input: Option<HashMap<String, serde_json::Value>>,
138
139    /// Gemini thought signature — must be echoed back with tool results.
140    /// This is an opaque base64-encoded blob from Gemini 3.x models.
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub thought_signature: Option<String>,
143}
144
145/// Defines a function the model can call.
146#[derive(Debug, Clone, Serialize, Default)]
147pub struct ChatTool {
148    /// Function name.
149    pub name: String,
150
151    /// Explains what the function does.
152    pub description: String,
153
154    /// JSON Schema for the function's arguments.
155    #[serde(skip_serializing_if = "Option::is_none")]
156    pub parameters: Option<serde_json::Value>,
157}
158
159/// Response from a non-streaming chat request.
160#[derive(Debug, Clone, Deserialize)]
161pub struct ChatResponse {
162    /// Unique request identifier.
163    pub id: String,
164
165    /// Model that generated the response.
166    pub model: String,
167
168    /// List of content blocks (text, thinking, tool_use).
169    #[serde(default)]
170    pub content: Vec<ContentBlock>,
171
172    /// Token counts and cost.
173    pub usage: Option<ChatUsage>,
174
175    /// Why generation stopped ("end_turn", "tool_use", "max_tokens").
176    #[serde(default)]
177    pub stop_reason: String,
178
179    /// Citations from web search (when search is enabled via provider_options).
180    #[serde(default)]
181    pub citations: Vec<Citation>,
182
183    /// Total cost from the X-QAI-Cost-Ticks header.
184    #[serde(skip)]
185    pub cost_ticks: i64,
186
187    /// From the X-QAI-Request-Id header.
188    #[serde(skip)]
189    pub request_id: String,
190}
191
192impl ChatResponse {
193    /// Returns the concatenated text content, ignoring thinking and tool_use blocks.
194    pub fn text(&self) -> String {
195        self.content
196            .iter()
197            .filter(|b| b.block_type == "text")
198            .filter_map(|b| b.text.as_deref())
199            .collect::<Vec<_>>()
200            .join("")
201    }
202
203    /// Returns the concatenated thinking content.
204    pub fn thinking(&self) -> String {
205        self.content
206            .iter()
207            .filter(|b| b.block_type == "thinking")
208            .filter_map(|b| b.text.as_deref())
209            .collect::<Vec<_>>()
210            .join("")
211    }
212
213    /// Returns all tool_use blocks from the response.
214    pub fn tool_calls(&self) -> Vec<&ContentBlock> {
215        self.content
216            .iter()
217            .filter(|b| b.block_type == "tool_use")
218            .collect()
219    }
220}
221
222/// A source reference from web search grounding.
223#[derive(Debug, Clone, Deserialize, Serialize)]
224pub struct Citation {
225    /// Title of the cited source.
226    #[serde(default)]
227    pub title: String,
228
229    /// URL of the cited source.
230    #[serde(default)]
231    pub url: String,
232
233    /// Relevant text snippet from the source.
234    #[serde(default)]
235    pub text: String,
236
237    /// Position in the response.
238    #[serde(default)]
239    pub index: i32,
240}
241
242/// Token counts and cost for a chat response.
243#[derive(Debug, Clone, Deserialize)]
244pub struct ChatUsage {
245    pub input_tokens: i32,
246    pub output_tokens: i32,
247    pub cost_ticks: i64,
248}
249
250/// A single event from an SSE chat stream.
251#[derive(Debug, Clone)]
252pub struct StreamEvent {
253    /// Event type: "content_delta", "thinking_delta", "tool_use", "usage", "heartbeat", "error", "done".
254    pub event_type: String,
255
256    /// Incremental text for content_delta and thinking_delta events.
257    pub delta: Option<StreamDelta>,
258
259    /// Populated for tool_use events.
260    pub tool_use: Option<StreamToolUse>,
261
262    /// Populated for usage events.
263    pub usage: Option<ChatUsage>,
264
265    /// Populated for error events.
266    pub error: Option<String>,
267
268    /// True when the stream is complete.
269    pub done: bool,
270}
271
272/// Incremental text in a streaming event.
273#[derive(Debug, Clone, Deserialize)]
274pub struct StreamDelta {
275    pub text: String,
276}
277
278/// A tool call from a streaming event.
279#[derive(Debug, Clone, Deserialize)]
280pub struct StreamToolUse {
281    pub id: String,
282    pub name: String,
283    pub input: HashMap<String, serde_json::Value>,
284}
285
286/// Raw JSON from the SSE stream before parsing into typed fields.
287#[derive(Deserialize)]
288struct RawStreamEvent {
289    #[serde(rename = "type")]
290    event_type: String,
291    #[serde(default)]
292    delta: Option<StreamDelta>,
293    #[serde(default)]
294    id: Option<String>,
295    #[serde(default)]
296    name: Option<String>,
297    #[serde(default)]
298    input: Option<HashMap<String, serde_json::Value>>,
299    #[serde(default)]
300    input_tokens: Option<i32>,
301    #[serde(default)]
302    output_tokens: Option<i32>,
303    #[serde(default)]
304    cost_ticks: Option<i64>,
305    #[serde(default)]
306    message: Option<String>,
307}
308
309pin_project! {
310    /// An async stream of [`StreamEvent`]s from an SSE chat response.
311    pub struct ChatStream {
312        #[pin]
313        inner: Pin<Box<dyn Stream<Item = StreamEvent> + Send>>,
314    }
315}
316
317impl Stream for ChatStream {
318    type Item = StreamEvent;
319
320    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
321        self.project().inner.poll_next(cx)
322    }
323}
324
325impl Client {
326    /// Sends a non-streaming text generation request.
327    pub async fn chat(&self, req: &ChatRequest) -> Result<ChatResponse> {
328        let mut req = req.clone();
329        req.stream = Some(false);
330
331        let (mut resp, meta) = self.post_json::<ChatRequest, ChatResponse>("/qai/v1/chat", &req).await?;
332        resp.cost_ticks = meta.cost_ticks;
333        resp.request_id = meta.request_id;
334        if resp.model.is_empty() {
335            resp.model = meta.model;
336        }
337        Ok(resp)
338    }
339
340    /// Sends a streaming text generation request and returns an async stream of events.
341    ///
342    /// # Example
343    ///
344    /// ```no_run
345    /// use futures_util::StreamExt;
346    ///
347    /// # async fn example() -> quantum_sdk::Result<()> {
348    /// let client = quantum_sdk::Client::new("key");
349    /// let req = quantum_sdk::ChatRequest {
350    ///     model: "claude-sonnet-4-6".into(),
351    ///     messages: vec![quantum_sdk::ChatMessage::user("Hello!")],
352    ///     ..Default::default()
353    /// };
354    /// let mut stream = client.chat_stream(&req).await?;
355    /// while let Some(ev) = stream.next().await {
356    ///     if let Some(delta) = &ev.delta {
357    ///         print!("{}", delta.text);
358    ///     }
359    /// }
360    /// # Ok(())
361    /// # }
362    /// ```
363    pub async fn chat_stream(&self, req: &ChatRequest) -> Result<ChatStream> {
364        let mut req = req.clone();
365        req.stream = Some(true);
366
367        let (resp, _meta) = self.post_stream_raw("/qai/v1/chat", &req).await?;
368
369        let byte_stream = resp.bytes_stream();
370        let event_stream = sse_to_events(byte_stream);
371
372        Ok(ChatStream {
373            inner: Box::pin(event_stream),
374        })
375    }
376}
377
378/// Converts a byte stream into a stream of parsed [`StreamEvent`]s.
379fn sse_to_events<S>(byte_stream: S) -> impl Stream<Item = StreamEvent> + Send
380where
381    S: Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
382{
383    // Pin the byte stream so we can poll it inside unfold.
384    let pinned_stream = Box::pin(byte_stream);
385
386    // We accumulate bytes into lines, then parse SSE "data: " lines.
387    let line_stream = futures_util::stream::unfold(
388        (pinned_stream, String::new()),
389        |(mut stream, mut buffer)| async move {
390            use futures_util::StreamExt;
391            loop {
392                // Check if we have a complete line in the buffer.
393                if let Some(newline_pos) = buffer.find('\n') {
394                    let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
395                    buffer = buffer[newline_pos + 1..].to_string();
396                    return Some((line, (stream, buffer)));
397                }
398
399                // Read more data.
400                match stream.next().await {
401                    Some(Ok(chunk)) => {
402                        buffer.push_str(&String::from_utf8_lossy(&chunk));
403                    }
404                    Some(Err(_)) | None => {
405                        // Stream ended. Emit remaining buffer if non-empty.
406                        if !buffer.is_empty() {
407                            let remaining = std::mem::take(&mut buffer);
408                            return Some((remaining, (stream, buffer)));
409                        }
410                        return None;
411                    }
412                }
413            }
414        },
415    );
416
417    let pinned_lines = Box::pin(line_stream);
418    futures_util::stream::unfold(pinned_lines, |mut lines| async move {
419        use futures_util::StreamExt;
420        loop {
421            let line = lines.next().await?;
422
423            if !line.starts_with("data: ") {
424                continue;
425            }
426            let payload = &line["data: ".len()..];
427
428            if payload == "[DONE]" {
429                let ev = StreamEvent {
430                    event_type: "done".to_string(),
431                    delta: None,
432                    tool_use: None,
433                    usage: None,
434                    error: None,
435                    done: true,
436                };
437                return Some((ev, lines));
438            }
439
440            let raw: RawStreamEvent = match serde_json::from_str(payload) {
441                Ok(r) => r,
442                Err(e) => {
443                    let ev = StreamEvent {
444                        event_type: "error".to_string(),
445                        delta: None,
446                        tool_use: None,
447                        usage: None,
448                        error: Some(format!("parse SSE: {e}")),
449                        done: false,
450                    };
451                    return Some((ev, lines));
452                }
453            };
454
455            let mut ev = StreamEvent {
456                event_type: raw.event_type.clone(),
457                delta: None,
458                tool_use: None,
459                usage: None,
460                error: None,
461                done: false,
462            };
463
464            match raw.event_type.as_str() {
465                "content_delta" | "thinking_delta" => {
466                    ev.delta = raw.delta;
467                }
468                "tool_use" => {
469                    ev.tool_use = Some(StreamToolUse {
470                        id: raw.id.unwrap_or_default(),
471                        name: raw.name.unwrap_or_default(),
472                        input: raw.input.unwrap_or_default(),
473                    });
474                }
475                "usage" => {
476                    ev.usage = Some(ChatUsage {
477                        input_tokens: raw.input_tokens.unwrap_or(0),
478                        output_tokens: raw.output_tokens.unwrap_or(0),
479                        cost_ticks: raw.cost_ticks.unwrap_or(0),
480                    });
481                }
482                "error" => {
483                    ev.error = raw.message;
484                }
485                "heartbeat" => {}
486                _ => {}
487            }
488
489            return Some((ev, lines));
490        }
491    })
492}