Skip to main content

quantum_sdk/
chat.rs

1use std::collections::HashMap;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures_util::Stream;
6use pin_project_lite::pin_project;
7use serde::{Deserialize, Serialize};
8
9use crate::client::Client;
10use crate::error::Result;
11
12/// Request body for text generation.
13#[derive(Debug, Clone, Serialize, Default)]
14pub struct ChatRequest {
15    /// Model ID that determines provider routing (e.g. "claude-sonnet-4-6", "grok-4-1-fast-non-reasoning").
16    pub model: String,
17
18    /// Conversation history.
19    pub messages: Vec<ChatMessage>,
20
21    /// Functions the model can call.
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub tools: Option<Vec<ChatTool>>,
24
25    /// Enables server-sent event streaming. Set automatically by `chat_stream`.
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub stream: Option<bool>,
28
29    /// Controls randomness (0.0-2.0).
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub temperature: Option<f64>,
32
33    /// Limits the response length.
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub max_tokens: Option<i32>,
36
37    /// Provider-specific settings (e.g. Anthropic thinking, xAI search).
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub provider_options: Option<HashMap<String, serde_json::Value>>,
40}
41
42/// A single message in a conversation.
43#[derive(Debug, Clone, Serialize, Deserialize, Default)]
44pub struct ChatMessage {
45    /// One of "system", "user", "assistant", or "tool".
46    pub role: String,
47
48    /// Text content of the message.
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub content: Option<String>,
51
52    /// Structured content for assistant messages with tool calls.
53    /// When present, takes precedence over `content`.
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub content_blocks: Option<Vec<ContentBlock>>,
56
57    /// Required when role is "tool" — references the tool_use ID.
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub tool_call_id: Option<String>,
60
61    /// Whether a tool result is an error.
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub is_error: Option<bool>,
64}
65
66impl ChatMessage {
67    /// Creates a user message.
68    pub fn user(content: impl Into<String>) -> Self {
69        Self {
70            role: "user".to_string(),
71            content: Some(content.into()),
72            ..Default::default()
73        }
74    }
75
76    /// Creates an assistant message.
77    pub fn assistant(content: impl Into<String>) -> Self {
78        Self {
79            role: "assistant".to_string(),
80            content: Some(content.into()),
81            ..Default::default()
82        }
83    }
84
85    /// Creates a system message.
86    pub fn system(content: impl Into<String>) -> Self {
87        Self {
88            role: "system".to_string(),
89            content: Some(content.into()),
90            ..Default::default()
91        }
92    }
93
94    /// Creates a tool result message.
95    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
96        Self {
97            role: "tool".to_string(),
98            content: Some(content.into()),
99            tool_call_id: Some(tool_call_id.into()),
100            ..Default::default()
101        }
102    }
103
104    /// Creates a tool error result message.
105    pub fn tool_error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
106        Self {
107            role: "tool".to_string(),
108            content: Some(content.into()),
109            tool_call_id: Some(tool_call_id.into()),
110            is_error: Some(true),
111            ..Default::default()
112        }
113    }
114}
115
116/// A single block in the response content array.
117#[derive(Debug, Clone, Serialize, Deserialize, Default)]
118pub struct ContentBlock {
119    /// One of "text", "thinking", or "tool_use".
120    #[serde(rename = "type")]
121    pub block_type: String,
122
123    /// Content for "text" and "thinking" blocks.
124    #[serde(skip_serializing_if = "Option::is_none")]
125    pub text: Option<String>,
126
127    /// Tool call identifier for "tool_use" blocks.
128    #[serde(skip_serializing_if = "Option::is_none")]
129    pub id: Option<String>,
130
131    /// Function name for "tool_use" blocks.
132    #[serde(skip_serializing_if = "Option::is_none")]
133    pub name: Option<String>,
134
135    /// Function arguments for "tool_use" blocks.
136    #[serde(skip_serializing_if = "Option::is_none")]
137    pub input: Option<HashMap<String, serde_json::Value>>,
138}
139
140/// Defines a function the model can call.
141#[derive(Debug, Clone, Serialize, Default)]
142pub struct ChatTool {
143    /// Function name.
144    pub name: String,
145
146    /// Explains what the function does.
147    pub description: String,
148
149    /// JSON Schema for the function's arguments.
150    #[serde(skip_serializing_if = "Option::is_none")]
151    pub parameters: Option<serde_json::Value>,
152}
153
154/// Response from a non-streaming chat request.
155#[derive(Debug, Clone, Deserialize)]
156pub struct ChatResponse {
157    /// Unique request identifier.
158    pub id: String,
159
160    /// Model that generated the response.
161    pub model: String,
162
163    /// List of content blocks (text, thinking, tool_use).
164    #[serde(default)]
165    pub content: Vec<ContentBlock>,
166
167    /// Token counts and cost.
168    pub usage: Option<ChatUsage>,
169
170    /// Why generation stopped ("end_turn", "tool_use", "max_tokens").
171    #[serde(default)]
172    pub stop_reason: String,
173
174    /// Total cost from the X-QAI-Cost-Ticks header.
175    #[serde(skip)]
176    pub cost_ticks: i64,
177
178    /// From the X-QAI-Request-Id header.
179    #[serde(skip)]
180    pub request_id: String,
181}
182
183impl ChatResponse {
184    /// Returns the concatenated text content, ignoring thinking and tool_use blocks.
185    pub fn text(&self) -> String {
186        self.content
187            .iter()
188            .filter(|b| b.block_type == "text")
189            .filter_map(|b| b.text.as_deref())
190            .collect::<Vec<_>>()
191            .join("")
192    }
193
194    /// Returns the concatenated thinking content.
195    pub fn thinking(&self) -> String {
196        self.content
197            .iter()
198            .filter(|b| b.block_type == "thinking")
199            .filter_map(|b| b.text.as_deref())
200            .collect::<Vec<_>>()
201            .join("")
202    }
203
204    /// Returns all tool_use blocks from the response.
205    pub fn tool_calls(&self) -> Vec<&ContentBlock> {
206        self.content
207            .iter()
208            .filter(|b| b.block_type == "tool_use")
209            .collect()
210    }
211}
212
213/// Token counts and cost for a chat response.
214#[derive(Debug, Clone, Deserialize)]
215pub struct ChatUsage {
216    pub input_tokens: i32,
217    pub output_tokens: i32,
218    pub cost_ticks: i64,
219}
220
221/// A single event from an SSE chat stream.
222#[derive(Debug, Clone)]
223pub struct StreamEvent {
224    /// Event type: "content_delta", "thinking_delta", "tool_use", "usage", "heartbeat", "error", "done".
225    pub event_type: String,
226
227    /// Incremental text for content_delta and thinking_delta events.
228    pub delta: Option<StreamDelta>,
229
230    /// Populated for tool_use events.
231    pub tool_use: Option<StreamToolUse>,
232
233    /// Populated for usage events.
234    pub usage: Option<ChatUsage>,
235
236    /// Populated for error events.
237    pub error: Option<String>,
238
239    /// True when the stream is complete.
240    pub done: bool,
241}
242
243/// Incremental text in a streaming event.
244#[derive(Debug, Clone, Deserialize)]
245pub struct StreamDelta {
246    pub text: String,
247}
248
249/// A tool call from a streaming event.
250#[derive(Debug, Clone, Deserialize)]
251pub struct StreamToolUse {
252    pub id: String,
253    pub name: String,
254    pub input: HashMap<String, serde_json::Value>,
255}
256
257/// Raw JSON from the SSE stream before parsing into typed fields.
258#[derive(Deserialize)]
259struct RawStreamEvent {
260    #[serde(rename = "type")]
261    event_type: String,
262    #[serde(default)]
263    delta: Option<StreamDelta>,
264    #[serde(default)]
265    id: Option<String>,
266    #[serde(default)]
267    name: Option<String>,
268    #[serde(default)]
269    input: Option<HashMap<String, serde_json::Value>>,
270    #[serde(default)]
271    input_tokens: Option<i32>,
272    #[serde(default)]
273    output_tokens: Option<i32>,
274    #[serde(default)]
275    cost_ticks: Option<i64>,
276    #[serde(default)]
277    message: Option<String>,
278}
279
280pin_project! {
281    /// An async stream of [`StreamEvent`]s from an SSE chat response.
282    pub struct ChatStream {
283        #[pin]
284        inner: Pin<Box<dyn Stream<Item = StreamEvent> + Send>>,
285    }
286}
287
288impl Stream for ChatStream {
289    type Item = StreamEvent;
290
291    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
292        self.project().inner.poll_next(cx)
293    }
294}
295
296impl Client {
297    /// Sends a non-streaming text generation request.
298    pub async fn chat(&self, req: &ChatRequest) -> Result<ChatResponse> {
299        let mut req = req.clone();
300        req.stream = Some(false);
301
302        let (mut resp, meta) = self.post_json::<ChatRequest, ChatResponse>("/qai/v1/chat", &req).await?;
303        resp.cost_ticks = meta.cost_ticks;
304        resp.request_id = meta.request_id;
305        if resp.model.is_empty() {
306            resp.model = meta.model;
307        }
308        Ok(resp)
309    }
310
311    /// Sends a streaming text generation request and returns an async stream of events.
312    ///
313    /// # Example
314    ///
315    /// ```no_run
316    /// use futures_util::StreamExt;
317    ///
318    /// # async fn example() -> quantum_sdk::Result<()> {
319    /// let client = quantum_sdk::Client::new("key");
320    /// let req = quantum_sdk::ChatRequest {
321    ///     model: "claude-sonnet-4-6".into(),
322    ///     messages: vec![quantum_sdk::ChatMessage::user("Hello!")],
323    ///     ..Default::default()
324    /// };
325    /// let mut stream = client.chat_stream(&req).await?;
326    /// while let Some(ev) = stream.next().await {
327    ///     if let Some(delta) = &ev.delta {
328    ///         print!("{}", delta.text);
329    ///     }
330    /// }
331    /// # Ok(())
332    /// # }
333    /// ```
334    pub async fn chat_stream(&self, req: &ChatRequest) -> Result<ChatStream> {
335        let mut req = req.clone();
336        req.stream = Some(true);
337
338        let (resp, _meta) = self.post_stream_raw("/qai/v1/chat", &req).await?;
339
340        let byte_stream = resp.bytes_stream();
341        let event_stream = sse_to_events(byte_stream);
342
343        Ok(ChatStream {
344            inner: Box::pin(event_stream),
345        })
346    }
347}
348
349/// Converts a byte stream into a stream of parsed [`StreamEvent`]s.
350fn sse_to_events<S>(byte_stream: S) -> impl Stream<Item = StreamEvent> + Send
351where
352    S: Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
353{
354    // Pin the byte stream so we can poll it inside unfold.
355    let pinned_stream = Box::pin(byte_stream);
356
357    // We accumulate bytes into lines, then parse SSE "data: " lines.
358    let line_stream = futures_util::stream::unfold(
359        (pinned_stream, String::new()),
360        |(mut stream, mut buffer)| async move {
361            use futures_util::StreamExt;
362            loop {
363                // Check if we have a complete line in the buffer.
364                if let Some(newline_pos) = buffer.find('\n') {
365                    let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
366                    buffer = buffer[newline_pos + 1..].to_string();
367                    return Some((line, (stream, buffer)));
368                }
369
370                // Read more data.
371                match stream.next().await {
372                    Some(Ok(chunk)) => {
373                        buffer.push_str(&String::from_utf8_lossy(&chunk));
374                    }
375                    Some(Err(_)) | None => {
376                        // Stream ended. Emit remaining buffer if non-empty.
377                        if !buffer.is_empty() {
378                            let remaining = std::mem::take(&mut buffer);
379                            return Some((remaining, (stream, buffer)));
380                        }
381                        return None;
382                    }
383                }
384            }
385        },
386    );
387
388    let pinned_lines = Box::pin(line_stream);
389    futures_util::stream::unfold(pinned_lines, |mut lines| async move {
390        use futures_util::StreamExt;
391        loop {
392            let line = lines.next().await?;
393
394            if !line.starts_with("data: ") {
395                continue;
396            }
397            let payload = &line["data: ".len()..];
398
399            if payload == "[DONE]" {
400                let ev = StreamEvent {
401                    event_type: "done".to_string(),
402                    delta: None,
403                    tool_use: None,
404                    usage: None,
405                    error: None,
406                    done: true,
407                };
408                return Some((ev, lines));
409            }
410
411            let raw: RawStreamEvent = match serde_json::from_str(payload) {
412                Ok(r) => r,
413                Err(e) => {
414                    let ev = StreamEvent {
415                        event_type: "error".to_string(),
416                        delta: None,
417                        tool_use: None,
418                        usage: None,
419                        error: Some(format!("parse SSE: {e}")),
420                        done: false,
421                    };
422                    return Some((ev, lines));
423                }
424            };
425
426            let mut ev = StreamEvent {
427                event_type: raw.event_type.clone(),
428                delta: None,
429                tool_use: None,
430                usage: None,
431                error: None,
432                done: false,
433            };
434
435            match raw.event_type.as_str() {
436                "content_delta" | "thinking_delta" => {
437                    ev.delta = raw.delta;
438                }
439                "tool_use" => {
440                    ev.tool_use = Some(StreamToolUse {
441                        id: raw.id.unwrap_or_default(),
442                        name: raw.name.unwrap_or_default(),
443                        input: raw.input.unwrap_or_default(),
444                    });
445                }
446                "usage" => {
447                    ev.usage = Some(ChatUsage {
448                        input_tokens: raw.input_tokens.unwrap_or(0),
449                        output_tokens: raw.output_tokens.unwrap_or(0),
450                        cost_ticks: raw.cost_ticks.unwrap_or(0),
451                    });
452                }
453                "error" => {
454                    ev.error = raw.message;
455                }
456                "heartbeat" => {}
457                _ => {}
458            }
459
460            return Some((ev, lines));
461        }
462    })
463}